Sruinard opened this issue 1 year ago
Hi @Sruinard, I'm facing the same issue. Have you by any chance come up with a solution? :(
Hi @dbrovml,
Downgrading seemed to work for me.
Perhaps you can try the following:
conda.yml
name: deeprl
channels:
- defaults
- conda-forge
- nvidia/label/cuda-11.4.1
dependencies:
- python=3.9
- pip
- cuda-nvcc
- cuda-toolkit=11.4
- cudnn=8.4.1.50
- pip:
- --requirement requirements.txt
requirements.txt
clu==0.0.6
flax==0.6.3
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax==0.3.25
jaxlib==0.3.25+cuda11.cudnn82 # Make sure CUDA version matches the base image.
ml-collections==0.1.0
numpy==1.22.0
optax==0.1.4
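As a quick sanity check after building an environment like the one above (this snippet is not from the original post, just a minimal sketch), you can confirm that jaxlib sees the GPU and that a plain matrix multiply, i.e. the same GEMM path behind the xla.gpu.gemm error in this issue, actually runs:

# Hypothetical sanity check, not part of the original environment files:
# verifies that the downgraded jax/jaxlib build detects the GPU and that a
# basic matrix multiply (the code path behind xla.gpu.gemm) executes.
import jax
import jax.numpy as jnp

print(jax.devices())  # expect a GPU device here, not a CPU-only device
x = jnp.ones((1024, 1024))
y = jnp.ones((1024, 1024))
print(jnp.dot(x, y).block_until_ready().shape)  # (1024, 1024) if the GEMM runs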
Any update on this? @Sruinard, did you find a workaround without downgrading? It is becoming increasingly disruptive to depend on a version < 0.4.
Hi @Sruinard
I tested the provided code on GCP Ubuntu 20.04 LTS with Tesla P100 and Tesla T4 GPUs using JAX 0.4.28 and Flax 0.8.4. I could not reproduce the error mentioned; the code runs fine. Please find the screenshots below for reference.
Could you please verify your setup with the latest JAX and Flax versions and let us know if the issue still persists?
Thank you.
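For anyone who wants to run a similar check, the snippet below is a minimal sketch of a Flax forward pass on GPU. The reporter's forward_pass.py is not reproduced in this thread, so this is only an assumed approximation, not the tested code:

# Minimal sketch, assuming a simple Dense-based model; not the reporter's
# actual forward_pass.py. Dense layers lower to GEMM on GPU, so this
# exercises the same code path as the reported error.
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        return nn.Dense(features=10)(x)

key = jax.random.PRNGKey(0)
x = jnp.ones((32, 64))
model = MLP()
params = model.init(key, x)
out = model.apply(params, x)
print(out.shape, jax.devices())  # (32, 10) plus the visible GPU device(s)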
Description
I'm facing many challenges running jax/flax on a GPU. I created a virtual machine with Ubuntu 20.04 (Azure compute instance).
After installing the required toolkits, libraries, etc., I'm stuck with the following error when running the file forward_pass.py. JAX seems to recognize the GPU, but fails with the following error:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.gemm' failed: sm_3 does not support explicit gemm algorithms.
The content of the file is as follows:
What jax/jaxlib version are you using?
jax==0.4.1, jaxlib==0.4.1+cuda11.cudnn86, flax==0.6.3
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.8.5
NVIDIA GPU info
$ nvidia-smi
$ nvcc -V
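If it helps with triage, here is a small diagnostic (an illustrative sketch, not part of the original report) that prints what jax/jaxlib report about the backend and the attached accelerator; this is useful when XLA complains about an unexpected compute capability such as sm_3:

# Hypothetical diagnostic, not from the original report: show which backend
# JAX selected and which device kind jaxlib reports for each accelerator.
import jax
import jaxlib

print("jax", jax.__version__, "jaxlib", jaxlib.__version__)
print("default backend:", jax.default_backend())  # 'gpu' when CUDA is in use
for d in jax.devices():
    print(d, d.platform, d.device_kind)  # e.g. 'gpu', 'Tesla T4'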