pcgm-team opened this issue 3 months ago
Side note: in recent JAX versions, the recommended way to create a key is jrandom.key(0) instead of jrandom.PRNGKey(0). Setting a global seed will not work either, because in JAX you need to split the key and pass it explicitly every time you call a random function. See the tutorial: https://jax.readthedocs.io/en/latest/random-numbers.html
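A minimal sketch of the split-and-pass pattern the tutorial describes. It assumes a recent JAX where jrandom.key and jax.dtypes.prng_key exist (the reporter's 0.4.30 has them). The helper sample_twice is hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

# New-style typed key: a scalar array with a special PRNG key dtype.
key = jrandom.key(0)

# Legacy raw key: a plain uint32 array of shape (2,).
legacy_key = jrandom.PRNGKey(0)


def sample_twice(key):
    """Hypothetical helper: draw two independent samples by splitting first."""
    key, sub1 = jrandom.split(key)
    a = jrandom.normal(sub1, (3,))
    key, sub2 = jrandom.split(key)
    b = jrandom.normal(sub2, (3,))
    # Return the updated key so the caller can keep splitting it.
    return key, a, b


key, a, b = sample_twice(key)

# Reusing the same key reproduces the same numbers, which is why a single
# "global seed" cannot give fresh randomness on each call.
same1 = jrandom.normal(key, (3,))
same2 = jrandom.normal(key, (3,))
```

The key is threaded through explicitly: each consumer gets its own subkey, and the caller keeps the returned key for later splits.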
So I guess the repro could be:
import jax.random as jrandom
key = jrandom.key(0)
The question is: does this code crash? If so, does it also crash on CPU?
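One hedged way to check the CPU case on a machine that also has a GPU is to run the same repro with the CPU as the default device. This sketch uses jax.devices("cpu") and the jax.default_device context manager. Setting the environment variable JAX_PLATFORMS=cpu before importing jax is an alternative.

```python
import jax
import jax.random as jrandom

# Pick the first CPU device; this works even when a GPU backend is installed.
cpu = jax.devices("cpu")[0]

# Run the repro with the CPU as the default device for newly created arrays.
with jax.default_device(cpu):
    key = jrandom.key(0)
    x = jrandom.uniform(key, (2,))
    x.block_until_ready()  # force execution so any runtime error surfaces here
```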
Hi @pcgm-team, @ayaka14732
I was unable to reproduce the reported issue on WSL2 with CUDA, using an NVIDIA GeForce RTX 2060, JAX 0.4.30, and Python 3.9. The code above executed without any error. Please see the attached screenshot.
Thank you.
Description
Easily reproducible:
This raises ValueError: std::bad_cast
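A hedged way to narrow down which call triggers the std::bad_cast is to try both key constructors, use each key once, and catch the ValueError. try_key is a hypothetical helper for illustration, not part of JAX.

```python
import jax.random as jrandom


def try_key(make_key):
    """Hypothetical triage helper: build a key, use it once, report the outcome."""
    try:
        k = make_key(0)
        # Actually run a random op so device-side failures are raised here.
        jrandom.normal(k, (1,)).block_until_ready()
        return "ok"
    except ValueError as e:  # the report shows a ValueError (std::bad_cast)
        return f"ValueError: {e}"


results = {
    "key": try_key(jrandom.key),
    "PRNGKey": try_key(jrandom.PRNGKey),
}
print(results)
```

On an affected setup, comparing the two entries would show whether the failure depends on the typed-key API or happens with either constructor.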
System info (python version, jaxlib version, accelerator, etc.)
jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.9.19 (main, May 6 2024, 19:43:03) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='Passage-17', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')
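The version and device lines above follow the format printed by jax.print_environment_info(). A minimal sketch for capturing that report as a string uses the return_string parameter available in recent JAX versions. On NVIDIA machines the function also appears to append nvidia-smi output when that tool is on the PATH.

```python
import jax

# Collect the environment report (jax/jaxlib/numpy/python versions,
# devices, platform) as a string instead of only printing it.
info = jax.print_environment_info(return_string=True)
print(info)
```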
$ nvidia-smi
Tue Jul 30 14:27:13 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.27                 Driver Version: 560.70         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3070        On  |   00000000:01:00.0  On |                  N/A |
| 33%   29C    P2             18W /  220W |    7719MiB /   8192MiB |     12%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
WSL2,