jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Failed to determine best cudnn convolution algorithm/No GPU/TPU found #8746

Open iwldzt3011 opened 2 years ago

iwldzt3011 commented 2 years ago

RTX 3080 / CUDA 11.1 / cuDNN 8.2.1 / Ubuntu 16.04

This problem occurs with jaxlib-0.1.72+cuda111. When I update to 0.1.74, the error disappears. However, with 0.1.74, JAX cannot detect the GPU, although TensorFlow can.

So whether I use 0.1.72 or 0.1.74, I always run into a problem.

`RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm: INTERNAL: All algorithms tried for %custom-call.1 = (f32[1,112,112,64]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,229,229,3]{2,1,3,0} %pad, f32[7,7,3,64]{1,0,2,3} %copy.4), window={size=7x7 stride=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(1, 224, 224, 3)\n padding=((2, 3), (2, 3))\n precision=None\n preferred_element_type=None\n rhs_dilation=(1, 1)\n rhs_shape=(7, 7, 3, 64)\n window_strides=(2, 2)\n]" source_file="/media/node/Materials/anaconda3/envs/xmcgan/lib/python3.9/site-packages/flax/linen/linear.py" source_line=282}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.

Convolution performance may be suboptimal. To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning. `

hawkinsp commented 2 years ago

What version of the jaxlib 0.1.74 wheel did you install, and how did you install it? Try removing jaxlib and reinstalling it following the instructions here: https://github.com/google/jax#pip-installation-gpu-cuda ?

iwldzt3011 commented 2 years ago

> What version of the jaxlib 0.1.74 wheel did you install, and how did you install it? Try removing jaxlib and reinstalling it following the instructions here: https://github.com/google/jax#pip-installation-gpu-cuda ?

I use the stand-alone version of jaxlib 0.1.74, that is: pip install jaxlib

The latest jaxlib build for cuda111 at this link [https://storage.googleapis.com/jax-releases/jax_releases.html] is still 0.1.72 (jaxlib 0.1.72+cuda111), so I can't get jaxlib 0.1.74+cuda111 from it.

However, the link does offer jaxlib 0.1.74+cuda11, so I also tried that. Unfortunately, that version gives the same error as jaxlib 0.1.72+cuda111.
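
For anyone debugging the "no GPU found" half of this report, here is a minimal sketch of how one might check which jaxlib build is installed and which backend JAX selects. It relies only on `jax.devices()`, `jax.default_backend()`, and the standard-library `importlib.metadata`; the helper name `describe_jax_backend` is hypothetical. In the releases discussed here, CUDA-enabled wheels carry a suffix such as `+cuda111` in the version string, so a version without such a suffix may indicate a CPU-only build.

```python
import importlib.util
from importlib import metadata


def describe_jax_backend():
    """Summarize the installed jaxlib build and the backend JAX selects."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this environment"

    try:
        jaxlib_version = metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        jaxlib_version = "unknown"

    import jax  # imported lazily so the check above can run without jax

    backend = jax.default_backend()  # e.g. "cpu" or "gpu"
    devices = jax.devices()
    return f"jaxlib={jaxlib_version} backend={backend} devices={devices}"


print(describe_jax_backend())
```

If the reported backend is `cpu` on a machine with a working GPU, the installed wheel is a likely suspect before the driver or CUDA toolkit.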

ross-Hr commented 2 years ago

Did you fix the error?

dljjqy commented 2 years ago

I ran into this issue recently. I was running the "How to Think in JAX" documentation notebook, and Jupyter reported this error when I tried to do a convolution. Also, the error sometimes disappears, and I do not know why.

half-potato commented 2 years ago

I have this exact same issue when trying to run https://github.com/google/mipnerf. I get a "Failed to determine best cudnn convolution algorithm" error when running jax.scipy.signal.convolve2d. I only get the error when running their codebase, not when running the convolve operation on its own. It seems related to running vmap on convolve2d and to the version of CUDA + cuDNN being used.

CUDA 11.5, cuDNN 8.3.2, jax 0.3.2
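
The operation pattern described above (vmap over `jax.scipy.signal.convolve2d`) can be sketched as follows. This is an illustration only, not code from the mipnerf repository, and it is not guaranteed to trigger the error; the helper name, batch size, image size, and box-filter kernel are all assumptions chosen for the example.

```python
import importlib.util


def batched_convolve_shape(batch=4, size=32):
    """Run convolve2d under vmap and return the output shape (None without jax)."""
    if importlib.util.find_spec("jax") is None:
        return None

    import jax
    import jax.numpy as jnp
    from jax.scipy.signal import convolve2d

    images = jnp.ones((batch, size, size), dtype=jnp.float32)
    kernel = jnp.ones((3, 3), dtype=jnp.float32) / 9.0  # simple box filter

    # Map the 2-D convolution over the leading batch axis of `images`.
    batched = jax.vmap(lambda image: convolve2d(image, kernel, mode="same"))
    return tuple(batched(images).shape)


print(batched_convolve_shape())
```

Running this on the affected GPU, with and without the surrounding codebase loaded, may help separate a cuDNN problem from a memory problem.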

half-potato commented 2 years ago

Turns out it was an OOM error with a misleading error message. The solution is in #8506: set the environment variable XLA_PYTHON_CLIENT_MEM_FRACTION=0.87. There appears to be some kind of issue with how jax.scipy.signal.convolve2d handles preallocated memory. It would be nice to have a better error message for this.
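
A minimal sketch of applying that setting from Python rather than the shell, assuming it runs before JAX initializes its GPU backend (setting it before `import jax` is the simplest way to be sure). The helper name `set_gpu_memory_fraction` is hypothetical, and 0.87 is just the value quoted above; the right fraction depends on the GPU and workload.

```python
import os


def set_gpu_memory_fraction(fraction=0.87):
    """Set the share of GPU memory the XLA client preallocates.

    Intended to be called before `import jax`, so the variable is in
    place when the GPU backend starts up.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(fraction)
    return os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]


set_gpu_memory_fraction(0.87)
# import jax  # import only after the environment is configured
```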

luweizheng commented 2 years ago

I have the same error on my Titan RTX, which is based on the Turing architecture. After some trial and error, I found that the error may be related to the cuDNN version. If I point LD_LIBRARY_PATH at cuDNN 8.2.1, it works; cuDNN 8.2.4 does not.

sudhakarsingh27 commented 2 years ago

Was this issue resolved? @iwldzt3011

sudhakarsingh27 commented 2 years ago

Closing since there has been no activity and no additional information was provided.

hawkinsp commented 2 years ago

(I should add: if someone can provide instructions to reproduce the problem, e.g., on a cloud GPU VM or similar, we would love to look into it further!)

hcwinsemius commented 2 years ago

Hello all.

I don't have a GPU VM, but I can confirm I have the same problem with an EVGA 3070 Ti XC3. What may help to pin down the problem is that I installed from the conda recipe using:

conda install jax cuda-nvcc -c conda-forge -c nvidia

The nvcc version info in the conda environment reads as follows:

$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0

I found the CuDNN version in the include folder in the virtual env:

#define CUDNN_MAJOR 8
#define CUDNN_MINOR 4
#define CUDNN_PATCHLEVEL 1

Any chance this helps to reproduce the problem? If you have a temporary workaround, I'd love to try it.
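
Since several reports in this thread hinge on the exact cuDNN version, here is a small sketch for reading the version macros shown above out of a header's text. In cuDNN 8 these macros typically live in `cudnn_version.h`; the location of that file is environment-specific and is not assumed here. The function name `parse_cudnn_version` is hypothetical.

```python
import re


def parse_cudnn_version(header_text):
    """Extract (major, minor, patch) from the text of a cuDNN version header."""
    fields = {}
    for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        match = re.search(rf"^\s*#define\s+{name}\s+(\d+)", header_text, re.MULTILINE)
        if match is None:
            raise ValueError(f"{name} not found in header text")
        fields[name] = int(match.group(1))
    return fields["CUDNN_MAJOR"], fields["CUDNN_MINOR"], fields["CUDNN_PATCHLEVEL"]


# The three macro lines quoted in the comment above.
sample = """
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 4
#define CUDNN_PATCHLEVEL 1
"""
print(parse_cudnn_version(sample))  # prints (8, 4, 1)
```

Including the resulting tuple alongside the jaxlib version makes a report easier to compare with the others in this thread.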

amughrabi commented 1 year ago

I did the following and it works

export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_FLAGS="--xla_gpu_strict_conv_algorithm_picker=false --xla_gpu_force_compilation_parallelism=1"

BeaverInGreenland commented 9 months ago

> I did the following and it works
>
> export XLA_PYTHON_CLIENT_PREALLOCATE=false
> export XLA_FLAGS="--xla_gpu_strict_conv_algorithm_picker=false --xla_gpu_force_compilation_parallelism=1"

Thanks, this works!
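
For notebook users who cannot easily export shell variables, the same workaround can be sketched in Python, assuming it runs before JAX initializes its backend. The helper name `apply_workaround_env` is hypothetical; the variable and flag names are the ones quoted above. Unlike a plain `export`, this version keeps any unrelated flags already present in `XLA_FLAGS`.

```python
import os


def apply_workaround_env(env=os.environ):
    """Apply the workaround settings quoted above, keeping other XLA_FLAGS."""
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    wanted = [
        "--xla_gpu_strict_conv_algorithm_picker=false",
        "--xla_gpu_force_compilation_parallelism=1",
    ]
    current = env.get("XLA_FLAGS", "").split()
    for flag in wanted:
        name = flag.split("=", 1)[0]
        # Drop any earlier setting of the same flag so ours takes effect.
        current = [f for f in current if f.split("=", 1)[0] != name]
        current.append(flag)
    env["XLA_FLAGS"] = " ".join(current)
    return env["XLA_FLAGS"]


apply_workaround_env()
# import jax  # import only after the environment is configured
```

Note that disabling the strict algorithm picker only lets XLA fall back to a default convolution algorithm; as the error message says, performance may be suboptimal and the root cause remains.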