jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JAX doesn't work with cuda/gpu #15268

Closed: tychovdo closed this issue 1 year ago.

tychovdo commented 1 year ago

Description

Hi all,

For a new project, I am trying to install JAX with CUDA/GPU support. I installed CUDA/cuDNN using conda (cudatoolkit==11.7). The same conda environment contains working PyTorch and TensorFlow installs, which both seem to work on the GPU.

$ nvidia-smi

| NVIDIA-SMI 515.86.01    Driver Version: 515.86.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A30          Off  | 00000000:01:00.0 Off |                    0 |
| N/A   35C    P0    30W / 165W |      0MiB / 24576MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

$ nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0

pytorch does seem to work:

>>> import torch
>>> torch.ones(4).cuda() + 2
tensor([3., 3., 3., 3.], device='cuda:0')

If I try jax, I can find the GPU:

>>> import jax
>>> jax.devices()
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

But it crashes when I try to do something with it:

>>> import jax.numpy as np
>>> np.ones(5)
2023-03-28 20:24:31.407336: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:417] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2092, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype))
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1190, in full
    return broadcast(fill_value, shape)
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 756, in broadcast
    return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 784, in broadcast_in_dim
    return broadcast_in_dim_p.bind(
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/core.py", line 359, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/core.py", line 362, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/core.py", line 816, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
    return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2836, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3048, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/my_home_dir/anaconda3/envs/gpu2/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

I tried jax/jaxlib 0.4.6 and 0.4.7, and also tried the [cuda11_cudnn82] extra in the pip install.

Many thanks in advance.

What jax/jaxlib version are you using?

jax 0.4.6, jaxlib 0.4.7

Which accelerator(s) are you using?

GPU

Additional system info

Linux

NVIDIA GPU info

(Same nvidia-smi output as above.)

hawkinsp commented 1 year ago
2023-03-28 20:24:31.407336: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:417] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.

This is the key message. Note: we found CuDNN v8.5 at runtime, but JAX was built against CuDNN v8.6. You need to update CuDNN.
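The compatibility rule quoted in that log line (matching major version, equal or higher minor version) can be written as a small check. This is a minimal sketch based only on the wording of the message; the real check inside XLA may differ in details, and `parse_version` / `cudnn_compatible` are hypothetical helper names.

```python
from typing import Tuple


def parse_version(text: str) -> Tuple[int, int, int]:
    """Parse 'major.minor[.patch[...]]' into (major, minor, patch); missing parts default to 0."""
    parts = [int(p) for p in text.strip().split(".")]
    parts += [0] * (3 - len(parts))
    return parts[0], parts[1], parts[2]


def cudnn_compatible(loaded: str, compiled: str) -> bool:
    """Rule quoted in the XLA log: same major version, loaded minor >= compiled minor."""
    l_major, l_minor, _ = parse_version(loaded)
    c_major, c_minor, _ = parse_version(compiled)
    return l_major == c_major and l_minor >= c_minor


# Versions taken from the log line in this issue.
print(cudnn_compatible("8.5.0", "8.6.0"))  # False
print(cudnn_compatible("8.7.0", "8.6.0"))  # True
```

Note that the patch level plays no role in this rule, so a runtime 8.6.x satisfies a build against 8.6.0 regardless of x.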

Is there an older CuDNN somewhere on your system?
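One way to answer that question is to scan likely library directories for `libcudnn.so*` files and read the version from each resolved file name. This is a minimal sketch: the directory list (the active environment's `lib/`, `LD_LIBRARY_PATH`, and two common system locations) is an assumption, and a file on disk only shows what *could* be loaded, not what the JAX process actually loads.

```python
import glob
import os
import re
import sys
from typing import Dict, Iterable, List, Optional

# Matches 'libcudnn.so' and versioned names such as 'libcudnn.so.8.5.0'.
_LIBCUDNN = re.compile(r"libcudnn\.so(?:\.(\d+(?:\.\d+)*))?")


def default_search_dirs() -> List[str]:
    """Directories to scan; this list is an assumption, adjust it for your machine."""
    dirs = [os.path.join(sys.prefix, "lib")]  # the active (e.g. conda) environment
    dirs += [d for d in os.environ.get("LD_LIBRARY_PATH", "").split(os.pathsep) if d]
    dirs += ["/usr/lib/x86_64-linux-gnu", "/usr/local/cuda/lib64"]
    return dirs


def find_cudnn_copies(dirs: Optional[Iterable[str]] = None) -> Dict[str, Optional[str]]:
    """Map each distinct libcudnn file (symlinks resolved) to the version in its name."""
    found: Dict[str, Optional[str]] = {}
    search = dirs if dirs is not None else default_search_dirs()
    for d in search:
        for path in glob.glob(os.path.join(d, "libcudnn.so*")):
            real = os.path.realpath(path)  # collapse libcudnn.so -> .so.8 -> .so.8.x.y
            m = _LIBCUDNN.fullmatch(os.path.basename(real))
            found[real] = m.group(1) if m else None
    return found


for path, version in sorted(find_cudnn_copies().items()):
    print(version, path)
```

More than one entry with different versions (say, one under the conda env and one under `/usr/lib`) would be a candidate explanation for the 8.5 vs 8.6 mismatch.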

tychovdo commented 1 year ago

Thanks for the fast response.

As an alternative, I also tried installing a jaxlib built against an older CuDNN version, from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, using the commands:

pip install https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.7+cuda11.cudnn82-cp39-cp39-manylinux2014_x86_64.whl
pip install jax==0.4.7

If I understand correctly, this build should work with cudnn >= 8.2 (judging from the last error, my system has CuDNN v8.5). I am still getting an error, although it's a slightly different one:

>>> import jax
>>> jax.devices()
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
>>> import jax.numpy as np
>>> np.ones(5)
2023-03-28 22:27:30.612627: W external/xla/xla/stream_executor/cuda/cuda_dnn.cc:397] There was an error before creating cudnn handle: cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.
2023-03-28 22:27:30.612751: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2092, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype))
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1190, in full
    return broadcast(fill_value, shape)
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 756, in broadcast
    return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 784, in broadcast_in_dim
    return broadcast_in_dim_p.bind(
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/core.py", line 359, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/core.py", line 362, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/core.py", line 816, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
    return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2836, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3048, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/my_home_dir/anaconda3/envs/gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Again, this is in the same conda env where pytorch and tensorflow seem to have no issue with finding cudnn.
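Applying the version rule from the first log message to the wheel name used here suggests why the error changed. The `+cuda11.cudnn82` local version tag encodes the cuDNN release the wheel was built against; the sketch below reads it and compares it with the 8.5 runtime reported earlier. The parsing assumes the `cudnnXY` spelling (one-digit major, then minor) seen in this thread, and `cudnn_from_wheel` / `runtime_satisfies` are hypothetical names.

```python
import re
from typing import Tuple


def cudnn_from_wheel(filename: str) -> Tuple[int, int]:
    """Read (major, minor) from a local version tag such as '+cuda11.cudnn82'.

    Assumes the 'cudnnXY' spelling (one-digit major, then minor) seen in this thread.
    """
    m = re.search(r"\+cuda\d+\.cudnn(\d)(\d+)", filename)
    if m is None:
        raise ValueError(f"no +cudaNN.cudnnXY tag in {filename!r}")
    return int(m.group(1)), int(m.group(2))


def runtime_satisfies(runtime: Tuple[int, int], built: Tuple[int, int]) -> bool:
    """Same major version, runtime minor >= build minor (rule quoted in the XLA log)."""
    return runtime[0] == built[0] and runtime[1] >= built[1]


wheel = "jaxlib-0.4.7+cuda11.cudnn82-cp39-cp39-manylinux2014_x86_64.whl"
print(cudnn_from_wheel(wheel))                             # (8, 2)
print(runtime_satisfies((8, 5), cudnn_from_wheel(wheel)))  # True
```

Since the version check passes for this wheel, the version rule alone does not explain the second failure (`cudaGetErrorName symbol not found`, `CUDNN_STATUS_INTERNAL_ERROR`).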

Thanks a lot!

nouiz commented 1 year ago

In JAX 0.4.7, there is a new way to install JAX that uses pip packages for the CUDA libraries (everything except the driver). Can you use pip instead of conda to install it?

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

It works for us, but hasn't been widely tested. Do this in a fresh environment: it will install CUDA 12, and you don't want two CUDA versions installed in the same environment.
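To check whether an environment already mixes CUDA versions, one option is to group the pip-installed NVIDIA distributions by their `-cuXX` name suffix. This is a sketch under the assumption that the pip route installs packages named like `nvidia-cudnn-cu12`; CUDA installed through conda (e.g. `cudatoolkit`) is not a pip distribution and will not appear here. `nvidia_wheels_by_cuda` is a hypothetical helper.

```python
import re
from collections import defaultdict
from importlib import metadata
from typing import Dict, Iterable, List, Optional

# Assumed naming scheme for NVIDIA's pip wheels, e.g. 'nvidia-cudnn-cu12'.
_CU_SUFFIX = re.compile(r"nvidia-.*-cu(\d+)")


def nvidia_wheels_by_cuda(names: Optional[Iterable[str]] = None) -> Dict[str, List[str]]:
    """Group 'nvidia-*-cuXX' distribution names by their CUDA major suffix."""
    if names is None:
        names = (d.metadata.get("Name") or "" for d in metadata.distributions())
    groups: Dict[str, List[str]] = defaultdict(list)
    for name in names:
        m = _CU_SUFFIX.fullmatch(name.lower().replace("_", "-"))
        if m:
            groups[m.group(1)].append(name)
    return dict(groups)


found = nvidia_wheels_by_cuda()
if len(found) > 1:
    print("More than one CUDA major version installed via pip:", found)
```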

tychovdo commented 1 year ago

This seems to work. Thanks a lot!

mehdiataei commented 1 year ago

I don't believe this issue should be closed, as I have confirmed that I encounter the same problem when performing either cuda12_local or cuda11_local installations. So far, I have not been able to find a solution. It appears that the local installations are not functioning properly and may require attention.

mjsML commented 1 year ago

@mehdiataei can you post a repro with nvidia-smi output ... etc?

mehdiataei commented 1 year ago
Fri Mar 31 17:55:11 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.142.00   Driver Version: 450.142.00   CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  A100-SXM4-40GB      On   | 00000000:07:00.0 Off |                    0 |
| N/A   32C    P0    51W / 400W |      0MiB / 40537MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  A100-SXM4-40GB      On   | 00000000:0F:00.0 Off |                    0 |
| N/A   30C    P0    51W / 400W |      0MiB / 40537MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  A100-SXM4-40GB      On   | 00000000:47:00.0 Off |                    0 |
| N/A   31C    P0    52W / 400W |      0MiB / 40537MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  A100-SXM4-40GB      On   | 00000000:4E:00.0 Off |                    0 |
| N/A   32C    P0    56W / 400W |      0MiB / 40537MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   4  A100-SXM4-40GB      On   | 00000000:87:00.0 Off |                    0 |
| N/A   36C    P0    55W / 400W |      0MiB / 40537MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   5  A100-SXM4-40GB      On   | 00000000:90:00.0 Off |                    0 |
| N/A   34C    P0    53W / 400W |      0MiB / 40537MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   6  A100-SXM4-40GB      On   | 00000000:B7:00.0 Off |                    0 |
| N/A   35C    P0    51W / 400W |      0MiB / 40537MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   7  A100-SXM4-40GB      On   | 00000000:BD:00.0 Off |                    0 |
| N/A   36C    P0    51W / 400W |      0MiB / 40537MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+ 

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

Installing JAX in the nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 Docker image results in the error. I just verified that installing CuDNN manually solves it. Could this be a linking error?
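One way to probe the linking hypothesis is to load libcudnn through the dynamic loader and ask the library itself for its version via `cudnnGetVersion()`. This is a sketch, not a diagnosis tool: the candidate library names are assumptions for a cuDNN 8 Linux install, the decoding assumes cuDNN 8's `major*1000 + minor*100 + patch` encoding, and the library resolved here may or may not be the one jaxlib resolves (jaxlib can have its own search paths).

```python
import ctypes
import ctypes.util
from typing import Optional, Tuple

# Names to try, in order; an assumption for a cuDNN 8 Linux install.
_CANDIDATES = ("libcudnn.so.8", "libcudnn.so")


def decode_cudnn8_version(v: int) -> Tuple[int, int, int]:
    """cuDNN 8 reports its version as major*1000 + minor*100 + patch (8700 -> 8.7.0)."""
    return v // 1000, (v % 1000) // 100, v % 100


def loaded_cudnn_version() -> Optional[Tuple[int, int, int]]:
    """dlopen libcudnn and call cudnnGetVersion(); None if no library could be loaded."""
    names = list(_CANDIDATES)
    found = ctypes.util.find_library("cudnn")  # consults the ldconfig cache on Linux
    if found:
        names.append(found)
    for name in names:
        try:
            lib = ctypes.CDLL(name)  # dlopen search includes LD_LIBRARY_PATH
            get_version = lib.cudnnGetVersion
        except (OSError, AttributeError):
            continue
        get_version.restype = ctypes.c_size_t
        get_version.argtypes = []
        return decode_cudnn8_version(get_version())
    return None


# None means no libcudnn could be loaded from this process.
print(loaded_cudnn_version())
```

If this reports a version lower than the one in jaxlib's `+cuda11.cudnnXY` tag, or nothing at all, the dynamic loader is not picking up the cuDNN the image advertises.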

hawkinsp commented 1 year ago

@mehdiataei what version of cudnn is present in that docker image?

mehdiataei commented 1 year ago

ENV NV_CUDNN_VERSION=8.7.0.84

ENV NV_CUDNN_PACKAGE_NAME=libcudnn8

ENV NV_CUDNN_PACKAGE=libcudnn8=8.7.0.84-1+cuda11.8

ENV NV_CUDNN_PACKAGE_DEV=libcudnn8-dev=8.7.0.84-1+cuda11.8

/usr/include/x86_64-linux-gnu/cudnn_v8.h
/usr/include/x86_64-linux-gnu/cudnn_version_v8.h
/usr/include/x86_64-linux-gnu/cudnn_adv_train_v8.h
/usr/include/x86_64-linux-gnu/cudnn_cnn_train_v8.h
/usr/include/x86_64-linux-gnu/cudnn_adv_infer_v8.h
/usr/include/x86_64-linux-gnu/cudnn_backend_v8.h
/usr/include/x86_64-linux-gnu/cudnn_ops_train_v8.h
/usr/include/x86_64-linux-gnu/cudnn_cnn_infer_v8.h
/usr/include/x86_64-linux-gnu/cudnn_ops_infer_v8.h
/usr/include/cudnn_ops_train.h
/usr/include/cudnn_version.h
/usr/include/cudnn_backend.h
/usr/include/cudnn_ops_infer.h
/usr/include/cudnn_cnn_train.h
/usr/include/cudnn_adv_infer.h
/usr/include/cudnn_adv_train.h
/usr/include/cudnn_cnn_infer.h
/usr/include/cudnn.h
/usr/src/cudnn_samples_v8
/usr/share/lintian/overrides/libcudnn8
/usr/share/lintian/overrides/libcudnn8-dev
/usr/share/doc/libcudnn8
/usr/share/doc/libcudnn8-dev
/usr/local/lib/python3.8/dist-packages/jaxlib-0.4.7+cuda11.cudnn86.dist-info
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train_static.a
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer_static.a
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train_static_v8.a
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer_static.a
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer_static_v8.a
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train_static.a
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer_static.a
/usr/lib/x86_64-linux-gnu/libcudnn.so.8
/usr/lib/x86_64-linux-gnu/libcudnn.so
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer_static_v8.a
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer_static_v8.a
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train_static_v8.a
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train_static.a
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train_static_v8.a
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.7.0
/etc/alternatives/cudnn_version_h
/etc/alternatives/libcudnn_adv_infer_so
/etc/alternatives/cudnn_cnn_infer_h
/etc/alternatives/cudnn_adv_train_h
/etc/alternatives/libcudnn_ops_infer_so
/etc/alternatives/cudnn_backend_h
/etc/alternatives/cudnn_ops_train_h
/etc/alternatives/cudnn_cnn_train_h
/etc/alternatives/libcudnn_cnn_train_so
/etc/alternatives/libcudnn_so
/etc/alternatives/libcudnn_cnn_infer_so
/etc/alternatives/libcudnn_ops_train_so
/etc/alternatives/cudnn_adv_infer_h
/etc/alternatives/libcudnn
/etc/alternatives/cudnn_ops_infer_h
/etc/alternatives/libcudnn_adv_train_so
/var/lib/dpkg/info/libcudnn8-dev.md5sums
/var/lib/dpkg/info/libcudnn8-dev.postinst
/var/lib/dpkg/info/libcudnn8-dev.list
/var/lib/dpkg/info/libcudnn8-dev.prerm
/var/lib/dpkg/info/libcudnn8.md5sums
/var/lib/dpkg/info/libcudnn8.list
/var/lib/dpkg/alternatives/libcudnn
hawkinsp commented 1 year ago

@mehdiataei I cannot reproduce. On a cloud VM with an NVIDIA T4 GPU, I did this:

$ docker run -it  --gpus=all nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 bash

==========
== CUDA ==
==========

CUDA Version 11.8.0

Container image Copyright (c) 2016-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.

root@78346e6f8048:/# apt update
Get:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64  InRelease [1581 B]
Get:2 http://archive.ubuntu.com/ubuntu focal InRelease [265 kB]
Get:3 http://security.ubuntu.com/ubuntu focal-security InRelease [114 kB]
Get:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64  Packages [969 kB]
Get:5 http://archive.ubuntu.com/ubuntu focal-updates InRelease [114 kB]
Get:6 http://archive.ubuntu.com/ubuntu focal-backports InRelease [108 kB]
Get:7 http://security.ubuntu.com/ubuntu focal-security/universe amd64 Packages [1027 kB]
Get:8 http://archive.ubuntu.com/ubuntu focal/universe amd64 Packages [11.3 MB]
Get:9 http://security.ubuntu.com/ubuntu focal-security/main amd64 Packages [2590 kB]
Get:10 http://security.ubuntu.com/ubuntu focal-security/restricted amd64 Packages [2060 kB]
Get:11 http://security.ubuntu.com/ubuntu focal-security/multiverse amd64 Packages [28.5 kB]
Get:12 http://archive.ubuntu.com/ubuntu focal/main amd64 Packages [1275 kB]
Get:13 http://archive.ubuntu.com/ubuntu focal/multiverse amd64 Packages [177 kB]
Get:14 http://archive.ubuntu.com/ubuntu focal/restricted amd64 Packages [33.4 kB]
Get:15 http://archive.ubuntu.com/ubuntu focal-updates/universe amd64 Packages [1323 kB]
Get:16 http://archive.ubuntu.com/ubuntu focal-updates/restricted amd64 Packages [2198 kB]
Get:17 http://archive.ubuntu.com/ubuntu focal-updates/multiverse amd64 Packages [31.2 kB]
Get:18 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages [3069 kB]
Get:19 http://archive.ubuntu.com/ubuntu focal-backports/main amd64 Packages [55.2 kB]
Get:20 http://archive.ubuntu.com/ubuntu focal-backports/universe amd64 Packages [28.6 kB]
Fetched 26.8 MB in 2s (13.0 MB/s)
Reading package lists... Done
Building dependency tree
Reading state information... Done
26 packages can be upgraded. Run 'apt list --upgradable' to see them.
root@78346e6f8048:/# apt install python3-pip
Reading package lists... Done
Building dependency tree
Reading state information... Done
The following additional packages will be installed:
  file libexpat1 libexpat1-dev libmagic-mgc libmagic1 libmpdec2 libpython3-dev libpython3-stdlib libpython3.8 libpython3.8-dev libpython3.8-minimal libpython3.8-stdlib mime-support python-pip-whl python3
  python3-dev python3-distutils python3-lib2to3 python3-minimal python3-pkg-resources python3-setuptools python3-wheel python3.8 python3.8-dev python3.8-minimal zlib1g-dev
Suggested packages:
  python3-doc python3-tk python3-venv python-setuptools-doc python3.8-venv python3.8-doc binfmt-support
The following NEW packages will be installed:
  file libexpat1 libexpat1-dev libmagic-mgc libmagic1 libmpdec2 libpython3-dev libpython3-stdlib libpython3.8 libpython3.8-dev libpython3.8-minimal libpython3.8-stdlib mime-support python-pip-whl python3
  python3-dev python3-distutils python3-lib2to3 python3-minimal python3-pip python3-pkg-resources python3-setuptools python3-wheel python3.8 python3.8-dev python3.8-minimal zlib1g-dev
0 upgraded, 27 newly installed, 0 to remove and 26 not upgraded.
Need to get 14.4 MB of archives.
After this operation, 61.7 MB of additional disk space will be used.
Do you want to continue? [Y/n]
Get:1 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 libpython3.8-minimal amd64 3.8.10-0ubuntu1~20.04.7 [717 kB]
Get:2 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 libexpat1 amd64 2.2.9-1ubuntu0.6 [74.6 kB]
Get:3 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 python3.8-minimal amd64 3.8.10-0ubuntu1~20.04.7 [1903 kB]
Get:4 http://archive.ubuntu.com/ubuntu focal/main amd64 python3-minimal amd64 3.8.2-0ubuntu2 [23.6 kB]
Get:5 http://archive.ubuntu.com/ubuntu focal/main amd64 mime-support all 3.64ubuntu1 [30.6 kB]
Get:6 http://archive.ubuntu.com/ubuntu focal/main amd64 libmpdec2 amd64 2.4.2-3 [81.1 kB]
Get:7 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 libpython3.8-stdlib amd64 3.8.10-0ubuntu1~20.04.7 [1675 kB]
Get:8 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 python3.8 amd64 3.8.10-0ubuntu1~20.04.7 [387 kB]
Get:9 http://archive.ubuntu.com/ubuntu focal/main amd64 libpython3-stdlib amd64 3.8.2-0ubuntu2 [7068 B]
Get:10 http://archive.ubuntu.com/ubuntu focal/main amd64 python3 amd64 3.8.2-0ubuntu2 [47.6 kB]
Get:11 http://archive.ubuntu.com/ubuntu focal/main amd64 libmagic-mgc amd64 1:5.38-4 [218 kB]
Get:12 http://archive.ubuntu.com/ubuntu focal/main amd64 libmagic1 amd64 1:5.38-4 [75.9 kB]
Get:13 http://archive.ubuntu.com/ubuntu focal/main amd64 file amd64 1:5.38-4 [23.3 kB]
Get:14 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 python3-pkg-resources all 45.2.0-1ubuntu0.1 [130 kB]
Get:15 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 libexpat1-dev amd64 2.2.9-1ubuntu0.6 [116 kB]
Get:16 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 libpython3.8 amd64 3.8.10-0ubuntu1~20.04.7 [1626 kB]
Get:17 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 libpython3.8-dev amd64 3.8.10-0ubuntu1~20.04.7 [3953 kB]
Get:18 http://archive.ubuntu.com/ubuntu focal/main amd64 libpython3-dev amd64 3.8.2-0ubuntu2 [7236 B]
Get:19 http://archive.ubuntu.com/ubuntu focal-updates/universe amd64 python-pip-whl all 20.0.2-5ubuntu1.8 [1805 kB]
Get:20 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 zlib1g-dev amd64 1:1.2.11.dfsg-2ubuntu1.5 [155 kB]
Get:21 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 python3.8-dev amd64 3.8.10-0ubuntu1~20.04.7 [514 kB]
Get:22 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 python3-lib2to3 all 3.8.10-0ubuntu1~20.04 [76.3 kB]
Get:23 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 python3-distutils all 3.8.10-0ubuntu1~20.04 [141 kB]
Get:24 http://archive.ubuntu.com/ubuntu focal/main amd64 python3-dev amd64 3.8.2-0ubuntu2 [1212 B]
Get:25 http://archive.ubuntu.com/ubuntu focal-updates/main amd64 python3-setuptools all 45.2.0-1ubuntu0.1 [330 kB]
Get:26 http://archive.ubuntu.com/ubuntu focal-updates/universe amd64 python3-wheel all 0.34.2-1ubuntu0.1 [23.9 kB]
Get:27 http://archive.ubuntu.com/ubuntu focal-updates/universe amd64 python3-pip all 20.0.2-5ubuntu1.8 [231 kB]
Fetched 14.4 MB in 1s (9640 kB/s)
debconf: delaying package configuration, since apt-utils is not installed
Selecting previously unselected package libpython3.8-minimal:amd64.
(Reading database ... 12894 files and directories currently installed.)
Preparing to unpack .../libpython3.8-minimal_3.8.10-0ubuntu1~20.04.7_amd64.deb ...
Unpacking libpython3.8-minimal:amd64 (3.8.10-0ubuntu1~20.04.7) ...
Selecting previously unselected package libexpat1:amd64.
Preparing to unpack .../libexpat1_2.2.9-1ubuntu0.6_amd64.deb ...
Unpacking libexpat1:amd64 (2.2.9-1ubuntu0.6) ...
Selecting previously unselected package python3.8-minimal.
Preparing to unpack .../python3.8-minimal_3.8.10-0ubuntu1~20.04.7_amd64.deb ...
Unpacking python3.8-minimal (3.8.10-0ubuntu1~20.04.7) ...
Setting up libpython3.8-minimal:amd64 (3.8.10-0ubuntu1~20.04.7) ...
Setting up libexpat1:amd64 (2.2.9-1ubuntu0.6) ...
Setting up python3.8-minimal (3.8.10-0ubuntu1~20.04.7) ...
Selecting previously unselected package python3-minimal.
(Reading database ... 13185 files and directories currently installed.)
Preparing to unpack .../0-python3-minimal_3.8.2-0ubuntu2_amd64.deb ...
Unpacking python3-minimal (3.8.2-0ubuntu2) ...
Selecting previously unselected package mime-support.
Preparing to unpack .../1-mime-support_3.64ubuntu1_all.deb ...
Unpacking mime-support (3.64ubuntu1) ...
Selecting previously unselected package libmpdec2:amd64.
Preparing to unpack .../2-libmpdec2_2.4.2-3_amd64.deb ...
Unpacking libmpdec2:amd64 (2.4.2-3) ...
Selecting previously unselected package libpython3.8-stdlib:amd64.
Preparing to unpack .../3-libpython3.8-stdlib_3.8.10-0ubuntu1~20.04.7_amd64.deb ...
Unpacking libpython3.8-stdlib:amd64 (3.8.10-0ubuntu1~20.04.7) ...
Selecting previously unselected package python3.8.
Preparing to unpack .../4-python3.8_3.8.10-0ubuntu1~20.04.7_amd64.deb ...
Unpacking python3.8 (3.8.10-0ubuntu1~20.04.7) ...
Selecting previously unselected package libpython3-stdlib:amd64.
Preparing to unpack .../5-libpython3-stdlib_3.8.2-0ubuntu2_amd64.deb ...
Unpacking libpython3-stdlib:amd64 (3.8.2-0ubuntu2) ...
Setting up python3-minimal (3.8.2-0ubuntu2) ...
Selecting previously unselected package python3.
(Reading database ... 13587 files and directories currently installed.)
Preparing to unpack .../00-python3_3.8.2-0ubuntu2_amd64.deb ...
Unpacking python3 (3.8.2-0ubuntu2) ...
Selecting previously unselected package libmagic-mgc.
Preparing to unpack .../01-libmagic-mgc_1%3a5.38-4_amd64.deb ...
Unpacking libmagic-mgc (1:5.38-4) ...
Selecting previously unselected package libmagic1:amd64.
Preparing to unpack .../02-libmagic1_1%3a5.38-4_amd64.deb ...
Unpacking libmagic1:amd64 (1:5.38-4) ...
Selecting previously unselected package file.
Preparing to unpack .../03-file_1%3a5.38-4_amd64.deb ...
Unpacking file (1:5.38-4) ...
Selecting previously unselected package python3-pkg-resources.
Preparing to unpack .../04-python3-pkg-resources_45.2.0-1ubuntu0.1_all.deb ...
Unpacking python3-pkg-resources (45.2.0-1ubuntu0.1) ...
Selecting previously unselected package libexpat1-dev:amd64.
Preparing to unpack .../05-libexpat1-dev_2.2.9-1ubuntu0.6_amd64.deb ...
Unpacking libexpat1-dev:amd64 (2.2.9-1ubuntu0.6) ...
Selecting previously unselected package libpython3.8:amd64.
Preparing to unpack .../06-libpython3.8_3.8.10-0ubuntu1~20.04.7_amd64.deb ...
Unpacking libpython3.8:amd64 (3.8.10-0ubuntu1~20.04.7) ...
Selecting previously unselected package libpython3.8-dev:amd64.
Preparing to unpack .../07-libpython3.8-dev_3.8.10-0ubuntu1~20.04.7_amd64.deb ...
Unpacking libpython3.8-dev:amd64 (3.8.10-0ubuntu1~20.04.7) ...
Selecting previously unselected package libpython3-dev:amd64.
Preparing to unpack .../08-libpython3-dev_3.8.2-0ubuntu2_amd64.deb ...
Unpacking libpython3-dev:amd64 (3.8.2-0ubuntu2) ...
Selecting previously unselected package python-pip-whl.
Preparing to unpack .../09-python-pip-whl_20.0.2-5ubuntu1.8_all.deb ...
Unpacking python-pip-whl (20.0.2-5ubuntu1.8) ...
Selecting previously unselected package zlib1g-dev:amd64.
Preparing to unpack .../10-zlib1g-dev_1%3a1.2.11.dfsg-2ubuntu1.5_amd64.deb ...
Unpacking zlib1g-dev:amd64 (1:1.2.11.dfsg-2ubuntu1.5) ...
Selecting previously unselected package python3.8-dev.
Preparing to unpack .../11-python3.8-dev_3.8.10-0ubuntu1~20.04.7_amd64.deb ...
Unpacking python3.8-dev (3.8.10-0ubuntu1~20.04.7) ...
Selecting previously unselected package python3-lib2to3.
Preparing to unpack .../12-python3-lib2to3_3.8.10-0ubuntu1~20.04_all.deb ...
Unpacking python3-lib2to3 (3.8.10-0ubuntu1~20.04) ...
Selecting previously unselected package python3-distutils.
Preparing to unpack .../13-python3-distutils_3.8.10-0ubuntu1~20.04_all.deb ...
Unpacking python3-distutils (3.8.10-0ubuntu1~20.04) ...
Selecting previously unselected package python3-dev.
Preparing to unpack .../14-python3-dev_3.8.2-0ubuntu2_amd64.deb ...
Unpacking python3-dev (3.8.2-0ubuntu2) ...
Selecting previously unselected package python3-setuptools.
Preparing to unpack .../15-python3-setuptools_45.2.0-1ubuntu0.1_all.deb ...
Unpacking python3-setuptools (45.2.0-1ubuntu0.1) ...
Selecting previously unselected package python3-wheel.
Preparing to unpack .../16-python3-wheel_0.34.2-1ubuntu0.1_all.deb ...
Unpacking python3-wheel (0.34.2-1ubuntu0.1) ...
Selecting previously unselected package python3-pip.
Preparing to unpack .../17-python3-pip_20.0.2-5ubuntu1.8_all.deb ...
Unpacking python3-pip (20.0.2-5ubuntu1.8) ...
Setting up mime-support (3.64ubuntu1) ...
Setting up libmagic-mgc (1:5.38-4) ...
Setting up libmagic1:amd64 (1:5.38-4) ...
Setting up file (1:5.38-4) ...
Setting up libexpat1-dev:amd64 (2.2.9-1ubuntu0.6) ...
Setting up zlib1g-dev:amd64 (1:1.2.11.dfsg-2ubuntu1.5) ...
Setting up python-pip-whl (20.0.2-5ubuntu1.8) ...
Setting up libmpdec2:amd64 (2.4.2-3) ...
Setting up libpython3.8-stdlib:amd64 (3.8.10-0ubuntu1~20.04.7) ...
Setting up python3.8 (3.8.10-0ubuntu1~20.04.7) ...
Setting up libpython3-stdlib:amd64 (3.8.2-0ubuntu2) ...
Setting up python3 (3.8.2-0ubuntu2) ...
running python rtupdate hooks for python3.8...
running python post-rtupdate hooks for python3.8...
Setting up python3-wheel (0.34.2-1ubuntu0.1) ...
Setting up libpython3.8:amd64 (3.8.10-0ubuntu1~20.04.7) ...
Setting up python3-lib2to3 (3.8.10-0ubuntu1~20.04) ...
Setting up python3-pkg-resources (45.2.0-1ubuntu0.1) ...
Setting up python3-distutils (3.8.10-0ubuntu1~20.04) ...
Setting up python3-setuptools (45.2.0-1ubuntu0.1) ...
Setting up libpython3.8-dev:amd64 (3.8.10-0ubuntu1~20.04.7) ...
Setting up python3-pip (20.0.2-5ubuntu1.8) ...
Setting up python3.8-dev (3.8.10-0ubuntu1~20.04.7) ...
Setting up libpython3-dev:amd64 (3.8.2-0ubuntu2) ...
Setting up python3-dev (3.8.2-0ubuntu2) ...
Processing triggers for libc-bin (2.31-0ubuntu9.9) ...
root@78346e6f8048:/# pip install -q ipython
root@78346e6f8048:/# pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda11_local]
  Downloading jax-0.4.8.tar.gz (1.2 MB)
     |████████████████████████████████| 1.2 MB 4.9 MB/s
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Collecting scipy>=1.7
  Downloading scipy-1.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.5 MB)
     |████████████████████████████████| 34.5 MB 58.7 MB/s
Collecting opt-einsum
  Downloading opt_einsum-3.3.0-py3-none-any.whl (65 kB)
     |████████████████████████████████| 65 kB 7.7 MB/s
Collecting numpy>=1.21
  Downloading numpy-1.24.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
     |████████████████████████████████| 17.3 MB 47.2 MB/s
Collecting ml-dtypes>=0.0.3
  Downloading ml_dtypes-0.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (154 kB)
     |████████████████████████████████| 154 kB 58.2 MB/s
Collecting jaxlib==0.4.7+cuda11.cudnn86; extra == "cuda11_local"
  Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.7%2Bcuda11.cudnn86-cp38-cp38-manylinux2014_x86_64.whl (152.0 MB)
     |████████████████████████████████| 152.0 MB 77 kB/s
Building wheels for collected packages: jax
  Building wheel for jax (PEP 517) ... done
  Created wheel for jax: filename=jax-0.4.8-py3-none-any.whl size=1439678 sha256=861f86a7adba492588bccfeb8fb87177446bb07a3af15edae5905204b2bc9579
  Stored in directory: /root/.cache/pip/wheels/45/83/1e/3db22c5e1941c10e41c4f5cdf829b0a358146d4d0733d4a105
Successfully built jax
Installing collected packages: numpy, scipy, opt-einsum, ml-dtypes, jaxlib, jax
Successfully installed jax-0.4.8 jaxlib-0.4.7+cuda11.cudnn86 ml-dtypes-0.0.4 numpy-1.24.2 opt-einsum-3.3.0 scipy-1.10.1
root@78346e6f8048:/#
root@78346e6f8048:/# ipython
Python 3.8.10 (default, Mar 13 2023, 10:26:41)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.12.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax

In [2]: jax.devices()
Out[2]: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

In [3]: jax.numpy.ones((7,))
Out[3]: Array([1., 1., 1., 1., 1., 1., 1.], dtype=float32)

In [4]:
Do you really want to exit ([y]/n)? y
root@78346e6f8048:/# nvidia-smi
Fri Mar 31 19:40:44 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.89.02    Driver Version: 525.89.02    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    35W /  70W |      2MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
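The transcript above shows a `jaxlib-0.4.7+cuda11.cudnn86` wheel running under a driver whose `nvidia-smi` header reports `CUDA Version: 12.0`. That field is the newest CUDA version the driver supports, so a driver reporting 12.0 can generally run CUDA 11 builds. Below is a minimal sketch that pulls both numbers out and compares the major versions. It assumes the header format and wheel-tag format seen in this thread. The helper names (`parse_driver_cuda_version`, `parse_jaxlib_tag`, `driver_can_run`) are hypothetical. This is a rough heuristic, not JAX's own compatibility check.

```python
import re
from typing import NamedTuple, Tuple


class WheelCudaTag(NamedTuple):
    """CUDA/cuDNN requirements encoded in a jaxlib local version tag."""
    cuda_major: int
    cudnn: Tuple[int, int]  # (major, minor), e.g. (8, 6) for "cudnn86"


def parse_driver_cuda_version(smi_text: str) -> Tuple[int, int]:
    """Return (major, minor) from the 'CUDA Version: X.Y' field of nvidia-smi output."""
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", smi_text)
    if match is None:
        raise ValueError("no 'CUDA Version' field found")
    return int(match.group(1)), int(match.group(2))


def parse_jaxlib_tag(version: str) -> WheelCudaTag:
    """Parse a jaxlib version such as '0.4.7+cuda11.cudnn86' (hypothetical helper).

    Assumes a single-digit cuDNN major version ("cudnn86" -> 8.6).
    """
    match = re.search(r"\+cuda(\d+)\.cudnn(\d)(\d+)", version)
    if match is None:
        raise ValueError(f"not a CUDA jaxlib version: {version!r}")
    cuda_major, cudnn_major, cudnn_minor = (int(g) for g in match.groups())
    return WheelCudaTag(cuda_major, (cudnn_major, cudnn_minor))


def driver_can_run(smi_text: str, jaxlib_version: str) -> bool:
    """Heuristic only: the driver's maximum CUDA major must be >= the wheel's.

    Ignores minor versions and cuDNN, which this check cannot see.
    """
    driver_major, _ = parse_driver_cuda_version(smi_text)
    return driver_major >= parse_jaxlib_tag(jaxlib_version).cuda_major


if __name__ == "__main__":
    header = "| NVIDIA-SMI 525.89.02    Driver Version: 525.89.02    CUDA Version: 12.0     |"
    print(parse_driver_cuda_version(header))                # (12, 0)
    print(parse_jaxlib_tag("0.4.7+cuda11.cudnn86"))         # WheelCudaTag(cuda_major=11, cudnn=(8, 6))
    print(driver_can_run(header, "0.4.7+cuda11.cudnn86"))  # True
```

A passing check does not rule out the kind of failures reported in this thread. It only catches the case where the wheel targets a newer CUDA major version than the driver can handle.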
nouiz commented 1 year ago

@tychovdo This part of the error message:

2023-03-28 22:27:30.612627: W external/xla/xla/stream_executor/cuda/cuda_dnn.cc:397] There was an error before creating cudnn handle: cudaGetErrorName symbol not found. : cudaGetErrorString symbol not found.

Seems to indicate a driver issue. I made a small PR that improves the error message a little (https://github.com/openxla/xla/pull/2335). Are you able to use other software on the GPU? Reinstalling the driver may help.
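The warning quoted above says the CUDA runtime functions `cudaGetErrorName` and `cudaGetErrorString` could not be found. Below is a minimal sketch for checking, outside of JAX, whether the CUDA runtime library resolved by the system loader exports those functions. `library_exports` is a hypothetical helper. The library `ctypes` finds may not be the same one XLA loads at runtime, and `find_library` may not see libraries that live only inside a conda environment. Treat the result as a hint only.

```python
import ctypes
import ctypes.util


def library_exports(lib_name: str, symbol: str) -> bool:
    """Return True if the shared library resolved for lib_name exports symbol.

    lib_name is given without the 'lib' prefix or '.so' suffix, e.g. "cudart".
    """
    path = ctypes.util.find_library(lib_name)
    if path is None:
        return False  # not found on the loader's search path
    try:
        lib = ctypes.CDLL(path)
    except OSError:
        return False  # found but failed to load (e.g. unresolved dependencies)
    # Attribute lookup on a CDLL looks the symbol up in the library; a missing
    # symbol raises AttributeError, which hasattr() turns into False.
    return hasattr(lib, symbol)


if __name__ == "__main__":
    print("libcudart resolved to:", ctypes.util.find_library("cudart"))
    for name in ("cudaGetErrorName", "cudaGetErrorString"):
        print(name, library_exports("cudart", name))
```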

dcbrien commented 1 year ago

I'm seeing similar errors with a local install of CUDA 11.8 and cuDNN 8.8 on Ubuntu 20.04 under WSL2, with JAX installed using:

pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

As others have said, TensorFlow and the cuDNN samples run without issues.

Running any JAX code gives:

E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1207] Failed to get stream capture info: operation not supported
E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2432] Execution of replica 0 failed: INVALID_ARGUMENT: stream is uninitialized or in an error state

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.41.03              Driver Version: 531.41       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA TITAN Xp                 On | 00000000:02:00.0  On |                  N/A |
| 23%   35C    P5               16W / 250W|   1163MiB / 12288MiB |      2%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

hawkinsp commented 1 year ago

@dcbrien WSL is community-supported. This is not a configuration we test.

nouiz commented 1 year ago

https://github.com/openxla/xla/pull/2335 has been merged, so the error message should be clearer in the next release.

dcbrien commented 1 year ago

@dcbrien WSL is community-supported. This is not a configuration we test.

No problem. For what it's worth, I have it running on three machines through WSL2, and it has generally run perfectly. Only this upgrade on the Titan RTX seems to have caused issues. I upgraded from CUDA 11.4 / cuDNN 8.2 to CUDA 12.1 / cuDNN 8.8 along with the latest driver, so it could just be that combination on that card. Weird issue. The other two machines, with a 2060 and a 3050 respectively, also run CUDA 12.1 / cuDNN 8.8.

hawkinsp commented 1 year ago

I don't think this bug contains any outstanding issues that we have instructions to reproduce. Please open a new bug with reproduction instructions if any of these issues still apply!
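When filing such a bug, a short self-contained script plus its output tends to make a useful reproduction. Below is a minimal sketch that uses only public APIs: `jax.__version__`, `jax.default_backend()`, `jax.devices()`, `jax.jit`, and the standard library's `importlib.metadata`. The `smoke_test` name is hypothetical. On setups like the ones above, a CPU-only backend or a captured exception is the thing to report, together with the `nvidia-smi` output.

```python
import importlib.metadata

import jax
import jax.numpy as jnp


def smoke_test() -> dict:
    """Collect version/backend info and run one small jitted computation."""
    info = {
        "jax": jax.__version__,
        "jaxlib": importlib.metadata.version("jaxlib"),  # e.g. "0.4.7+cuda11.cudnn86"
        "backend": jax.default_backend(),  # "gpu" if the CUDA backend loaded, else e.g. "cpu"
        "devices": [str(d) for d in jax.devices()],
    }
    try:
        # block_until_ready() forces execution, so asynchronous device errors
        # (like the stream errors above) surface here rather than later.
        result = jax.jit(lambda x: (x * 2.0).sum())(jnp.ones((7,))).block_until_ready()
        info["result"] = float(result)
        info["error"] = None
    except Exception as exc:  # record the failure instead of crashing
        info["result"] = None
        info["error"] = repr(exc)
    return info


if __name__ == "__main__":
    for key, value in smoke_test().items():
        print(f"{key}: {value}")
```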