Closed: tkoziara closed this issue 1 year ago
We support CUDA 11.3 and CUDA 11.4 on Linux, so my guess is this is a WSL-specific issue. We don't test this configuration ourselves, so this is something the community will need to figure out...
An alternative might be to use a native Windows build: if you build jaxlib from source this should work; see the documentation.
Thanks. I'll try compilation.
I just did a fresh cuda/cudnn working installation in wsl with the same versions and jax worked out of the box.
Python 3.8.5
Cuda compilation tools, release 11.4, V11.4.100
cudnn 8_8.2.2.26
EDIT: With a NVIDIA GeForce driver.
@IlyaOrson Was it on a laptop with a Quadro card?
Ah no, it was on a GTX with the GeForce driver.
Hi @IlyaOrson, can you confirm whether cuDNN is working as well? For me, JAX manages to find the GPU and CUDA, and training an MLP with only Linear layers works fine. However, every time I try to use a Conv2D or anything that goes through cuDNN, it fails with the error "UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED".
This error is raised whether I use Haiku or Flax. When using TensorFlow, however, it works like a charm.
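For what it's worth, the same cuDNN path can be exercised without Haiku or Flax. Here is a minimal sketch using only jax.lax; the shapes and the NHWC/HWIO layout are assumptions chosen to mirror the snippets below:
import jax
import jax.numpy as jnp

x = jnp.ones((8, 64, 64, 3))  # NHWC input, same shape as in the snippets below
w = jnp.ones((3, 3, 3, 64))   # HWIO kernel: 3x3, 3 input channels, 64 output channels
y = jax.lax.conv_general_dilated(
    x, w,
    window_strides=(2, 2),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print(y.shape)  # expected (8, 32, 32, 64)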
Code snippet for TensorFlow:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
network = keras.Sequential([layers.Conv2D(64, 3, 2, padding="same")])
x = tf.random.uniform((8, 64, 64, 3), 0, 1)
y = network(x)
Works without issue. I can see that the GPU is being used.
The following two snippets don't run, though:
Haiku
import jax
import jax.numpy as jnp
import haiku as hk

def network_fn(x):
    x = hk.Conv2D(64, 3, 2, padding="SAME")(x)
    return x

rng = jax.random.PRNGKey(0)
network = hk.transform(network_fn)
params = network.init(rng, jnp.ones([1, 64, 64, 3]))
x = jax.random.uniform(rng, (8, 64, 64, 3))
out = network.apply(params, rng, x)
And Flax
import jax
import jax.numpy as jnp
from flax import linen as nn

class Network(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Conv(64, (3, 3), 2, padding="SAME")(x)

rng = jax.random.PRNGKey(0)
network = Network()
params = network.init(rng, jnp.ones((1, 64, 64, 3)))["params"]
x = jax.random.uniform(rng, (8, 64, 64, 3))
y = network.apply({"params": params}, x)
The error is extremely verbose, but here is the gist:
Traceback (most recent call last):
File "/home/romain/quick_flax_test.py", line 12, in <module>
params = network.init(rng, jnp.ones((1, 64, 64, 3)))["params"]
File "/home/romain/quick_flax_test.py", line 8, in __call__
return nn.Conv(64, (3, 3), 2, padding="SAME")(x)
File "/home/romain/.local/lib/python3.8/site-packages/flax/linen/linear.py", line 282, in __call__
y = lax.conv_general_dilated(
RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%custom-call.1 = (f32[1,32,32,64]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,65,65,3]{2,1,3,0} %pad, f32[3,3,3,64]{1,0,2,3} %copy.4), window={size=3x3 stride=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(1, 64, 64, 3)\n padding=((0, 1), (0, 1))\n precision=None\n preferred_element_type=None\n rhs_dilation=(1, 1)\n rhs_shape=(3, 3, 3, 64)\n window_strides=(2, 2)\n]" source_file="/home/romain/.local/lib/python3.8/site-packages/flax/linen/linear.py" source_line=282}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
Original error: INTERNAL: All algorithms tried for %custom-call.1 = (f32[1,32,32,64]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,65,65,3]{2,1,3,0} %pad, f32[3,3,3,64]{1,0,2,3} %copy.4), window={size=3x3 stride=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(1, 64, 64, 3)\n padding=((0, 1), (0, 1))\n precision=None\n preferred_element_type=None\n rhs_dilation=(1, 1)\n rhs_shape=(3, 3, 3, 64)\n window_strides=(2, 2)\n]" source_file="/home/romain/.local/lib/python3.8/site-packages/flax/linen/linear.py" source_line=282}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm. Per-algorithm errors:
Profiling failure on cuDNN engine eng28{k2=3,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4502): 'status'
Profiling failure on cuDNN engine eng34{k5=0,k6=0,k7=0,k2=2,k19=0,k4=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4502): 'status'
Profiling failure on cuDNN engine eng1{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4502): 'status'
Profiling failure on cuDNN engine eng1{k2=2,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4502): 'status'
Profiling failure on cuDNN engine eng34{k19=0,k4=0,k5=1,k6=0,k7=0,k2=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4502): 'status'
Profiling failure on cuDNN engine eng34{k2=2,k19=0,k4=1,k5=0,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
@RomainSabathe Running your snippets with the latest JAX, I get the same error with Haiku and the following with Flax:
@IlyaOrson Thanks very much for trying, I appreciate it! Apologies for the issue with Flax. I'm assuming it has to do with the kernel size, which I wrote as a tuple of ints ((3, 3)) instead of a simple int (3).
Regardless, if I understand correctly, you're confirming that you're having issues with cuDNN as well? If I may ask, how do you run ConvNets with JAX on GPU then?
Edit: I've done a fresh install of WSL, and yet again TensorFlow and PyTorch ConvNets work just fine, but JAX's Conv2D operation fails. I'm going to try building jaxlib from source instead of using the prebuilt pip version.
Edit 2: nope, it seems that even when compiling jaxlib from scratch, this problem with cuDNN remains :(
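For anyone else hitting this, one documented JAX setting worth trying is disabling GPU memory preallocation, on the (unverified) hypothesis that cuDNN autotuning fails when it cannot allocate workspace memory under WSL2:
import os

# XLA_PYTHON_CLIENT_PREALLOCATE is a documented JAX allocator flag; whether
# it helps on WSL2 is an assumption, not something verified in this thread.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

# Re-run the failing convolution with on-demand allocation.
x = jnp.ones((8, 64, 64, 3))
w = jnp.ones((3, 3, 3, 64))
y = jax.lax.conv_general_dilated(
    x, w, window_strides=(2, 2), padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print(y.shape)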
Hi, I encountered the same problem. May I ask if you ever fixed this?
I'm guessing that this issue is stale, given all the versions are very old at this point. If someone is still having trouble under WSL2, please file a new issue.
Hi
Just installed the JAX CPU/GPU versions on WSL2 with Python 3.8.5, using NVIDIA's WSL Quadro driver, cuda-11.4, and cudnn8_8.2.2.26. When running a simple code snippet from your README (attached below), I was getting:
Could not load dynamic library 'libcuda.so.1'
and after creating a symbolic link from /usr/local/cuda-11.4/targets/x86_64-linux/lib/stubs/libcuda.so to the missing library and adding the path to LD_LIBRARY_PATH, I am now getting:
2021-08-12 12:19:51.012484: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: UNKNOWN ERROR (34)
The question is whether this combination of library versions is supported.
I also tested it with cuda-11.3 but got the same behavior (and the same initial issue with the missing library link).
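As a sanity check (jax.devices() and jax.default_backend() are public JAX APIs), the following shows which backend JAX actually initialized; when cuInit fails, JAX falls back to CPU with a warning:
import jax

# When CUDA initialization fails, JAX falls back to the CPU backend
# and prints a warning, so "cpu" here confirms the cuInit failure.
print(jax.default_backend())  # "gpu" on a working CUDA install
print(jax.devices())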
Thank you, Tomek