jiyuuchc opened 3 weeks ago
Running a simple JAX program results in a segmentation fault on a GCP VM.
To reproduce:
1. Create an instance with GCP's cuda12 image.
2. Run the following script:
```shell
#!/bin/bash

create_and_activate_venv() {
    if [ -z "$1" ]; then
        echo "Usage: create_and_activate_venv <venv_name>"
        return 1
    fi
    venv_name="$1"
    if [ -d "$venv_name" ]; then
        echo "Virtual environment '$venv_name' already exists."
    else
        python3 -m venv "$venv_name"
        echo "Virtual environment '$venv_name' created."
    fi
    source "$venv_name/bin/activate"
    echo "Virtual environment '$venv_name' activated."
}

create_and_activate_venv jax_test
pip install --upgrade pip
pip install jax[cuda12]
python -c 'import jax; k=jax.random.normal(jax.random.PRNGKey(1),[4096,4096]); jax.numpy.ones([1,4096]) @ k'
```
Additional information:
I tested with a T4 GPU, so it may not reproduce on a more modern GPU.
System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.10.15 | packaged by conda-forge | (main, Sep 20 2024, 16:37:05) [GCC 13.3.0]
device info: Tesla T4-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='instance-20241027-204625', release='5.10.0-32-cloud-amd64', version='#1 SMP Debian 5.10.223-1 (2024-08-10)', machine='x86_64')
```

```
$ nvidia-smi
Sun Oct 27 22:15:38 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla T4                       On  |   00000000:00:04.0 Off |                    0 |
| N/A   41C    P0             26W /   70W |     105MiB /  15360MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2799      C   python                                        102MiB |
+-----------------------------------------------------------------------------------------+
```
What is the full stack trace of the error you see?
There's no stack trace at all.
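One way to get at least a Python-level stack out of a hard crash is the stdlib `faulthandler` module, which dumps the interpreter's stack on SIGSEGV. A sketch (the jax import is guarded so the snippet runs even on a machine without jax or a GPU installed):

```python
import faulthandler

# Dump the Python stack to stderr if the process receives a fatal
# signal such as SIGSEGV, instead of dying silently.
faulthandler.enable()

try:
    import jax  # requires `pip install "jax[cuda12]"` and a CUDA GPU
    k = jax.random.normal(jax.random.PRNGKey(1), [4096, 4096])
    print(jax.numpy.ones([1, 4096]) @ k)
except ImportError:
    print("jax not installed; faulthandler is still active")
```

The same effect is available without editing code by running `python -X faulthandler -c '...'`. The dump points only at the Python frame that triggered the native crash, but that is often enough to tell whether the fault happens at import time or at the first kernel launch.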
The error is almost certainly due to JAX not reserving enough VRAM (see the nvidia-smi output above, which shows only ~100 MiB in use). But I have no idea what's causing that.
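If VRAM preallocation is the suspect, one quick experiment is to turn it off via XLA's client environment variables before importing jax. These are the standard knobs (`XLA_PYTHON_CLIENT_PREALLOCATE`, `XLA_PYTHON_CLIENT_MEM_FRACTION`); whether they change anything for this particular crash is only a guess:

```python
import os

# Must be set before jax is imported: by default the XLA client
# preallocates ~75% of GPU memory up front. Disabling that makes
# allocations happen on demand instead.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternatively, cap the preallocated fraction:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

try:
    import jax  # requires `pip install "jax[cuda12]"` and a CUDA GPU
    k = jax.random.normal(jax.random.PRNGKey(1), [4096, 4096])
    print(jax.numpy.ones([1, 4096]) @ k)
except ImportError:
    print("jax not installed; env vars set anyway")
```

If the segfault disappears with preallocation off, that would narrow the problem to the large up-front allocation on the T4 rather than the matmul itself.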