jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

segmentation fault from a simple test on GCP vm #24548

Open jiyuuchc opened 3 weeks ago

jiyuuchc commented 3 weeks ago

Description

Running a simple piece of JAX code results in a segmentation fault on a GCP VM.

To reproduce:

#!/bin/bash

create_and_activate_venv() {
  if [ -z "$1" ]; then
    echo "Usage: create_and_activate_venv <venv_name>"
    return 1
  fi

  venv_name="$1"

  if [ -d "$venv_name" ]; then
    echo "Virtual environment '$venv_name' already exists."
  else
    python3 -m venv "$venv_name"
    echo "Virtual environment '$venv_name' created."
  fi

  source "$venv_name/bin/activate"
  echo "Virtual environment '$venv_name' activated."
}

create_and_activate_venv jax_test

pip install --upgrade pip
pip install "jax[cuda12]"

python -c 'import jax; k=jax.random.normal(jax.random.PRNGKey(1),[4096,4096]); jax.numpy.ones([1,4096]) @ k'

Additional information:

I tested with a T4 GPU, so the problem may not show up on a more recent GPU.
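
For completeness, a quick way to confirm which accelerator JAX actually picked up is to list the devices it sees. This is a minimal sketch; compute_capability is only exposed on CUDA devices, hence the getattr fallback:

import jax

# Print each device JAX can see, its kind, and (on CUDA devices) its compute capability.
for d in jax.devices():
    print(d.platform, d.device_kind, getattr(d, "compute_capability", "n/a"))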

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.10.15 | packaged by conda-forge | (main, Sep 20 2024, 16:37:05) [GCC 13.3.0]
device info: Tesla T4-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='instance-20241027-204625', release='5.10.0-32-cloud-amd64', version='#1 SMP Debian 5.10.223-1 (2024-08-10)', machine='x86_64')

$ nvidia-smi
Sun Oct 27 22:15:38 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla T4                       On  |   00000000:00:04.0 Off |                    0 |
| N/A   41C    P0             26W /   70W |     105MiB /  15360MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2799      C   python                                        102MiB |
+-----------------------------------------------------------------------------------------+
lockwo commented 3 weeks ago

What is the full stack trace of the error message you see?

jiyuuchc commented 3 weeks ago

There's no stack trace at all.
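
One way to get at least the Python-side frames is to enable faulthandler before importing JAX; the native (C++/CUDA) frames would still need a debugger such as gdb. A minimal sketch of the same repro with faulthandler enabled:

import faulthandler
faulthandler.enable()  # dump the Python traceback to stderr if the process receives SIGSEGV

import jax
import jax.numpy as jnp

k = jax.random.normal(jax.random.PRNGKey(1), [4096, 4096])
(jnp.ones([1, 4096]) @ k).block_until_ready()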

The error is almost certainly related to JAX not reserving enough VRAM: the nvidia-smi output above shows the Python process holding only ~102 MiB, even though JAX normally preallocates about 75% of GPU memory at startup. But I have no idea what is causing that.
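
One way to probe that hypothesis is to rerun the repro with XLA's documented GPU memory flags and see whether the crash depends on the allocator. A minimal sketch, assuming the flags are set before JAX initializes (either as environment variables in the shell or via os.environ before the import):

import os

# Documented JAX GPU memory flags; they must be set before JAX is imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"    # allocate on demand instead of ~75% upfront
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" # alternative: synchronous platform allocator

import jax
import jax.numpy as jnp

k = jax.random.normal(jax.random.PRNGKey(1), [4096, 4096])
print((jnp.ones([1, 4096]) @ k).block_until_ready().shape)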