jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

PRNGKey problems with WSL2 CUDA #22762

Status: Open. pcgm-team opened this issue 3 months ago.

pcgm-team commented 3 months ago

Description

Easily reproducible:

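import jax.random as jrandom
import os  # not used by the repro
def set_global_seed(seed):
    # Build a PRNG key from an integer seed
    key = jrandom.PRNGKey(seed)
    return key

set_global_seed(0)  # fails on the reporter's WSL2 + CUDA setup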

This raises ValueError: std::bad_cast

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.9.19 (main, May 6 2024, 19:43:03) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='Passage-17', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')
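The fields above follow the layout of JAX's jax.print_environment_info() helper. A minimal sketch for collecting the same report together with the driver summary, assuming nvidia-smi is on the PATH inside WSL2 (the script skips it otherwise):

```python
import shutil
import subprocess

import jax

# Collect JAX/jaxlib/NumPy/Python versions, visible devices and platform as a string.
report = jax.print_environment_info(return_string=True)
print(report)

# Append the NVIDIA driver/CUDA summary only if nvidia-smi can be found.
if shutil.which("nvidia-smi"):
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
    print(result.stdout)
```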

$ nvidia-smi
Tue Jul 30 14:27:13 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.27                 Driver Version: 560.70         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3070        On  |   00000000:01:00.0  On |                  N/A |
| 33%   29C    P2             18W /  220W |    7719MiB /   8192MiB |     12%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Running under WSL2.

ayaka14732 commented 3 months ago

Side note: in recent JAX versions, the recommended API is jrandom.key(0) (a new-style typed key) rather than the legacy jrandom.PRNGKey(0). Also, setting a global seed will not work: JAX keeps no hidden global RNG state, so you need to split the key and pass a fresh subkey explicitly every time you call a random function. See the tutorial: https://jax.readthedocs.io/en/latest/random-numbers.html
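To make the side note concrete, here is a minimal sketch of both points: how a typed key relates to a legacy key, and how keys are split and passed explicitly instead of relying on a global seed. It assumes the default threefry2x32 PRNG implementation; init_params is a hypothetical helper for illustration only.

```python
import jax
import jax.numpy as jnp

# New-style typed key: a scalar array whose dtype is a PRNG key type.
key = jax.random.key(0)
# Legacy key: a raw uint32 array of shape (2,) under the default threefry PRNG.
legacy_key = jax.random.PRNGKey(0)

# key_data() exposes the raw bits behind a typed key; both forms hold the same state.
same_bits = bool(jnp.array_equal(jax.random.key_data(key), legacy_key))
print(key.shape, legacy_key.shape, same_bits)

def init_params(key, n):
    # Hypothetical helper: each draw gets its own subkey, never the same key twice.
    w_key, b_key = jax.random.split(key)
    w = jax.random.normal(w_key, (n, n))
    b = jax.random.uniform(b_key, (n,))
    return w, b

# No global seed: split first, keep `key` for later and pass `subkey` explicitly.
key, subkey = jax.random.split(key)
w, b = init_params(subkey, 3)

key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (3,))
```

Reusing one key for two draws would give correlated (identical) samples, which is why every consumer receives its own subkey.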

With that API, a minimal repro would be:

import jax.random as jrandom
key = jrandom.key(0)

The question is: does this two-line snippet crash? And if so, does it also crash on CPU?
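One way to answer the CPU question in a single run is to create the key under jax.default_device for each platform and compare the outcomes. A minimal sketch, assuming a JAX version that provides jax.random.key and jax.default_device (both exist in 0.4.30); make_key_on is a hypothetical helper. Alternatively, setting the environment variable JAX_PLATFORMS=cpu before starting Python forces the CPU backend for the whole process.

```python
import jax

def make_key_on(platform):
    # Hypothetical helper: run key creation on the first device of `platform`.
    device = jax.devices(platform)[0]  # raises if that backend is unavailable
    with jax.default_device(device):
        key = jax.random.key(0)
        # Materialize the raw key bits so any backend error surfaces here.
        return jax.random.key_data(key).block_until_ready()

# Try each backend; record failures (missing GPU backend, bad_cast, ...) instead of crashing.
results = {}
for platform in ("cpu", "gpu"):
    try:
        results[platform] = make_key_on(platform)
    except Exception as err:
        results[platform] = f"failed: {type(err).__name__}: {err}"
    print(platform, results[platform])
```

If the CPU entry succeeds while the GPU entry reports the std::bad_cast error, the problem is specific to the CUDA backend on this machine rather than to the PRNG API.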

rajasekharporeddy commented 3 months ago

Hi @pcgm-team, @ayaka14732

I was unable to reproduce the reported issue on WSL2 CUDA using an NVIDIA GeForce RTX 2060, JAX 0.4.30, and Python 3.9. The code above ran without any error. Please see the attached screenshot.


Thank you.