jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Unable to run jax on gpu #13758

Open Sruinard opened 1 year ago

Sruinard commented 1 year ago

Description

I'm having a lot of trouble running JAX/Flax on a GPU. I created a virtual machine running Ubuntu 20.04 (an Azure compute instance).

After installing the required toolkits and libraries, I get the following error when running forward_pass.py. JAX appears to recognize the GPU, but execution fails with:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.gemm' failed: sm_3 does not support explicit gemm algorithms..

The content of the file is as follows:

from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn

# Print the default backend JAX selected, e.g. "gpu" when a GPU is detected.
print(jax.default_backend())


class MLP(nn.Module):
    """Multilayer perceptron: Dense + ReLU hidden layers, then a final Dense layer."""
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x


model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))  # 32 examples with 10 features each
# Both init and apply run the Dense-layer matrix multiplications (GEMMs) on the GPU.
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)

What jax/jaxlib version are you using?

jax==0.4.1, jaxlib==0.4.1+cuda11.cudnn86, flax==0.6.3

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.8.5

NVIDIA GPU info

$ nvidia-smi

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.141.10   Driver Version: 470.141.10   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla K80           On   | 00000001:00:00.0 Off |                    0 |
| N/A   32C    P8    33W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
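`nvidia-smi` identifies the card as a Tesla K80, a Kepler-generation GPU with compute capability 3.7. The `sm_3` in the error message appears to be that major version, and JAX's installation documentation lists NVIDIA GPUs with SM 5.2 (Maxwell) or newer as supported. The sketch below is a hypothetical helper, not part of JAX: it compares a compute capability against that assumed minimum and, if JAX is installed, lists the GPUs JAX can see. Both the 5.2 threshold and the `compute_capability` device attribute are assumptions to double-check against your JAX version.

```python
from typing import List, Optional, Tuple

# Assumption: JAX's installation docs list SM 5.2 (Maxwell) as the minimum
# compute capability for its CUDA wheels; check the current docs.
MIN_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 2)


def parse_capability(cc: str) -> Tuple[int, int]:
    """Turn a compute-capability string such as "3.7" into (3, 7)."""
    major, minor = cc.split(".")
    return int(major), int(minor)


def is_supported(cc: str, minimum: Tuple[int, int] = MIN_COMPUTE_CAPABILITY) -> bool:
    """True if the compute capability is at least the assumed minimum."""
    return parse_capability(cc) >= minimum


def visible_gpus() -> List[Tuple[str, Optional[str]]]:
    """Best-effort list of (device_kind, compute_capability) for GPUs JAX sees.

    Returns [] when JAX is not installed or has no GPU backend. The
    `compute_capability` attribute is read with getattr because it is an
    assumption that every jaxlib version exposes it.
    """
    try:
        import jax

        devices = jax.devices("gpu")
    except Exception:  # ImportError, or RuntimeError when no GPU backend exists
        return []
    return [(d.device_kind, getattr(d, "compute_capability", None)) for d in devices]


if __name__ == "__main__":
    print(is_supported("3.7"))  # Tesla K80 (compute capability 3.7): False
    for kind, cc in visible_gpus():
        print(kind, cc, is_supported(cc) if cc else "capability unknown")
```

For comparison, the Tesla P100 (compute capability 6.0) and Tesla T4 (7.5) pass this check, while the K80 does not.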

$ nvcc -V

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_Oct_11_21:27:02_PDT_2021
Cuda compilation tools, release 11.4, V11.4.152
Build cuda_11.4.r11.4/compiler.30521435_0
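The two outputs above report CUDA from two sides: `nvidia-smi` shows the highest CUDA version the driver supports (11.4), while `nvcc -V` shows the installed toolkit release (also 11.4). Below is a minimal standard-library sketch that extracts both numbers from the command output and applies a conservative check that the toolkit is not newer than the driver. CUDA 11's minor-version compatibility can relax this rule in practice, so treat a failure as a hint rather than a verdict. The function names are hypothetical.

```python
import re
from typing import Tuple


def _version(pattern: str, text: str) -> Tuple[int, int]:
    """Return (major, minor) from the first match of `pattern` in `text`."""
    match = re.search(pattern, text)
    if match is None:
        raise ValueError(f"no version found for pattern {pattern!r}")
    return int(match.group(1)), int(match.group(2))


def driver_cuda_version(nvidia_smi_output: str) -> Tuple[int, int]:
    """Highest CUDA version the driver supports, read from the nvidia-smi header."""
    return _version(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_output)


def toolkit_cuda_version(nvcc_output: str) -> Tuple[int, int]:
    """Toolkit release reported by `nvcc -V`."""
    return _version(r"release (\d+)\.(\d+)", nvcc_output)


def toolkit_within_driver(nvidia_smi_output: str, nvcc_output: str) -> bool:
    """Conservative check: the toolkit is not newer than the driver's CUDA version."""
    return toolkit_cuda_version(nvcc_output) <= driver_cuda_version(nvidia_smi_output)


if __name__ == "__main__":
    smi = "| NVIDIA-SMI 470.141.10   Driver Version: 470.141.10   CUDA Version: 11.4     |"
    nvcc = "Cuda compilation tools, release 11.4, V11.4.152"
    print(toolkit_within_driver(smi, nvcc))  # True for the versions in this thread
```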

dbrovml commented 1 year ago

Hi @Sruinard, I'm facing the same issue. Have you by any chance come up with a solution? :(

Sruinard commented 1 year ago

Hi @dbrovml,

Downgrading appears to have worked for me.

Perhaps you can try the following:

conda.yml
name: deeprl
channels:
  - defaults
  - conda-forge
  - nvidia/label/cuda-11.4.1
dependencies:
  - python=3.9
  - pip
  - cuda-nvcc
  - cuda-toolkit=11.4
  - cudnn=8.4.1.50
  - pip:
      - --requirement requirements.txt

requirements.txt

clu==0.0.6
flax==0.6.3
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax==0.3.25
jaxlib==0.3.25+cuda11.cudnn82  # Make sure CUDA version matches the base image.
ml-collections==0.1.0
numpy==1.22.0
optax==0.1.4
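The comment on the `jaxlib` pin asks that the CUDA version match the environment. Below is a minimal sketch of that check. It assumes the `+cudaNN.cudnnXY` local-version convention seen in this thread (`cudnn82` read as cuDNN 8.2, `cudnn86` as 8.6), and it assumes a wheel works with the same CUDA major version and an equal-or-newer cuDNN within the same cuDNN major version. `parse_jaxlib_tag` and `matches_environment` are hypothetical names, not part of JAX or pip.

```python
import re
from typing import NamedTuple, Optional, Tuple


class CudaBuild(NamedTuple):
    cuda_major: int
    cudnn: Tuple[int, ...]


def parse_jaxlib_tag(requirement: str) -> Optional[CudaBuild]:
    """Read CUDA/cuDNN build info from a pin like 'jaxlib==0.3.25+cuda11.cudnn82'.

    Assumed convention: cudnn82 -> 8.2, cudnn86 -> 8.6, cudnn805 -> 8.0.5.
    Returns None for pins without such a tag (e.g. CPU-only wheels).
    """
    match = re.search(r"\+cuda(\d+)\.cudnn(\d)(\d)(\d*)", requirement)
    if match is None:
        return None
    cuda, major, minor, patch = match.groups()
    cudnn = (int(major), int(minor)) + ((int(patch),) if patch else ())
    return CudaBuild(int(cuda), cudnn)


def matches_environment(build: CudaBuild, toolkit: str, cudnn: str) -> bool:
    """Conservative check: same CUDA major version, and an installed cuDNN that is
    at least the version the wheel was built against, within the same major."""
    toolkit_major = int(toolkit.split(".")[0])
    installed = tuple(int(part) for part in cudnn.split("."))
    return (
        toolkit_major == build.cuda_major
        and installed[0] == build.cudnn[0]
        and installed >= build.cudnn
    )
```

With the values pinned in `conda.yml` above (`cuda-toolkit=11.4`, `cudnn=8.4.1.50`), `matches_environment(parse_jaxlib_tag("jaxlib==0.3.25+cuda11.cudnn82"), "11.4", "8.4.1.50")` returns `True`.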

jannisborn commented 1 year ago

Any update on this? @Sruinard, did you find a workaround that doesn't require downgrading? Depending on a version < 0.4 is becoming increasingly disruptive.

rajasekharporeddy commented 6 months ago

Hi @Sruinard

I tested the provided code on GCP (Ubuntu 20.04 LTS) with Tesla P100 and Tesla T4 GPUs, using JAX 0.4.28 and Flax 0.8.4. I could not reproduce the error; the code runs fine. Please see the screenshots below for reference.

[3 screenshots, not reproduced here]

Could you please verify with the latest JAX and Flax versions in your setup and let us know whether the issue persists?

Thank you.
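To answer the request above with concrete numbers, it helps to report the exact installed package versions. Below is a minimal standard-library sketch; the function name is hypothetical. Recent JAX releases also provide `jax.print_environment_info()`, which prints a fuller report.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(
    packages: Iterable[str] = ("jax", "jaxlib", "flax"),
) -> Dict[str, Optional[str]]:
    """Map each package name to its installed version, or None if it is absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```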