jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

segmentation fault from a simple test on GCP vm #24548

Open jiyuuchc opened 1 day ago

jiyuuchc commented 1 day ago

Description

Running a simple JAX program results in a segmentation fault on a GCP VM.

To reproduce:

#!/bin/bash

create_and_activate_venv() {
  if [ -z "$1" ]; then
    echo "Usage: create_and_activate_venv <venv_name>"
    return 1
  fi

  venv_name="$1"

  if [ -d "$venv_name" ]; then
    echo "Virtual environment '$venv_name' already exists."
  else
    python3 -m venv "$venv_name"
    echo "Virtual environment '$venv_name' created."
  fi

  source "$venv_name/bin/activate"
  echo "Virtual environment '$venv_name' activated."
}

create_and_activate_venv jax_test

pip install --upgrade pip
pip install "jax[cuda12]"  # quoted so the brackets aren't treated as a glob pattern

python -c 'import jax; k=jax.random.normal(jax.random.PRNGKey(1),[4096,4096]); jax.numpy.ones([1,4096]) @ k'
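Not part of the original report, but since the crash produces no output, one way to get more information is to rerun the repro with the standard-library `faulthandler` enabled (via `python -X faulthandler` or the `PYTHONFAULTHANDLER` environment variable), which dumps the Python-level stack on hard crashes such as SIGSEGV. A minimal sketch, demonstrated on a trivial command so it runs anywhere; the same prefix would apply to the repro line above:

```shell
# Enable the stdlib faulthandler so a SIGSEGV dumps the Python stack.
# Prefixing the actual repro command the same way should show roughly
# where in the Python call stack the crash happens.
PYTHONFAULTHANDLER=1 python -c 'import faulthandler; print(faulthandler.is_enabled())'
# prints: True
```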

Additional information:

I tested on a T4 GPU, so this may not be a problem on a more modern GPU.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.10.15 | packaged by conda-forge | (main, Sep 20 2024, 16:37:05) [GCC 13.3.0]
device info: Tesla T4-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='instance-20241027-204625', release='5.10.0-32-cloud-amd64', version='#1 SMP Debian 5.10.223-1 (2024-08-10)', machine='x86_64')

$ nvidia-smi
Sun Oct 27 22:15:38 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla T4                       On  |   00000000:00:04.0 Off |                    0 |
| N/A   41C    P0             26W /   70W |     105MiB /  15360MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2799      C   python                                        102MiB |
+-----------------------------------------------------------------------------------------+
lockwo commented 23 hours ago

What is the full stack of the error message you see?

jiyuuchc commented 16 hours ago

There's no stack trace at all.

The error is almost certainly related to JAX not reserving enough VRAM (see the nvidia-smi output above: the process holds only ~102 MiB), but I have no idea what's causing that.
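One hedged way to probe the VRAM hypothesis (not something tried in this thread) is to rerun the repro with JAX's documented XLA allocator flags changed, since by default XLA preallocates a large fraction of GPU memory at startup:

```shell
# Hypothetical debugging flags (documented XLA client options) to set
# before launching the repro script:
export XLA_PYTHON_CLIENT_PREALLOCATE=false   # allocate VRAM on demand instead of upfront
# export XLA_PYTHON_CLIENT_MEM_FRACTION=.50  # or cap preallocation at 50% of GPU memory
```

If the segfault disappears with preallocation disabled, that would point at the allocator's upfront reservation rather than the matmul itself.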