jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Running on WSL2 with cuda-11.{3,4} #7600

Closed: tkoziara closed this issue 1 year ago

tkoziara commented 3 years ago

Hi

I just installed the jax CPU/GPU versions on WSL2 with Python 3.8.5, using NVIDIA's WSL Quadro driver, CUDA 11.4 and cuDNN (cudnn8_8.2.2.26). When running a simple code snippet from your README (attached below), I was getting "Could not load dynamic library 'libcuda.so.1'". After creating a symbolic link from /usr/local/cuda-11.4/targets/x86_64-linux/lib/stubs/libcuda.so to the missing library and adding its path to LD_LIBRARY_PATH, I now get:

2021-08-12 12:19:51.012484: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: UNKNOWN ERROR (34)
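The "UNKNOWN ERROR (34)" above fits the symlink workaround: in the CUDA driver API, CUresult 34 is CUDA_ERROR_STUB_LIBRARY, which is returned when the process has loaded the stub libcuda.so from the toolkit's lib/stubs directory (meant only for linking) instead of the real driver. The following is a minimal Python sketch, assuming a Linux or WSL2 machine, that loads libcuda.so.1 the way the dynamic loader would and reports what cuInit returns. The helper names probe_cuinit and describe are hypothetical. On WSL2 the real driver library is normally provided by the Windows driver under /usr/lib/wsl/lib (per NVIDIA's CUDA on WSL guide), so removing the stub symlink and checking that directory is on the loader path is a reasonable first check.

```python
import ctypes

# CUresult codes from the CUDA driver API header (cuda.h).
CUDA_SUCCESS = 0
CUDA_ERROR_STUB_LIBRARY = 34  # the loaded libcuda is the link-time stub, not a real driver


def probe_cuinit(libname="libcuda.so.1"):
    """Load the driver library by soname and call cuInit(0).

    Returns the raw CUresult as an int, or None if the library cannot be loaded.
    """
    try:
        # Resolved through LD_LIBRARY_PATH / the ldconfig cache, like any dlopen.
        lib = ctypes.CDLL(libname)
    except OSError:
        return None
    lib.cuInit.argtypes = [ctypes.c_uint]
    lib.cuInit.restype = ctypes.c_int
    return lib.cuInit(0)


def describe(result):
    """Turn a probe result into a human-readable diagnosis."""
    if result is None:
        return "libcuda.so.1 not found by the dynamic loader"
    if result == CUDA_SUCCESS:
        return "cuInit succeeded"
    if result == CUDA_ERROR_STUB_LIBRARY:
        return "loaded the stub libcuda.so (link-time only); use the real driver library instead"
    return f"cuInit failed with CUresult {result}"


if __name__ == "__main__":
    print(describe(probe_cuinit()))
```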

My question is: is this combination of library versions supported?

I also tested with CUDA 11.3 and got the same behavior (including the same initial missing-library issue).

Thank you, Tomek

import jax.numpy as jnp
from jax import grad

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0

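A quick way to confirm that the README snippet actually ran on the GPU, rather than silently falling back to the CPU, is to print the backend and devices JAX selected, and to check the gradient against the closed form d/dx tanh(x) = 1 - tanh(x)^2. A minimal sketch; the backend and device lists depend on the machine, and jnp.tanh is used only as a reference value.

```python
import jax
import jax.numpy as jnp
from jax import grad


def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)


# Which backend did JAX pick? "cpu" means no GPU backend was initialized.
print("backend:", jax.default_backend())
print("devices:", jax.devices())

# d/dx tanh(x) = 1 - tanh(x)**2, so the autodiff result can be checked analytically.
g = grad(tanh)(1.0)
expected = 1.0 - jnp.tanh(1.0) ** 2
print(float(g), float(expected))
assert jnp.allclose(g, expected, atol=1e-6)
```

Both printed numbers should be about 0.41997, whichever backend is used; only the backend line tells you whether the GPU was involved.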
hawkinsp commented 3 years ago

We support CUDA 11.3 and CUDA 11.4 on Linux, so my guess is that this is a WSL-specific issue. We don't test this configuration ourselves, so this is something the community will need to figure out...

An alternative might be to use a native Windows build: if you build jaxlib from source, this should work (see the documentation).

tkoziara commented 3 years ago

Thanks. I'll try compiling it.

IlyaOrson commented 3 years ago

I just did a fresh, working CUDA/cuDNN installation in WSL with the same versions, and jax worked out of the box.

Python 3.8.5
Cuda compilation tools, release 11.4, V11.4.100
cudnn 8_8.2.2.26

EDIT: with an NVIDIA GeForce driver.

tkoziara commented 3 years ago

@IlyaOrson Was it on a laptop with a Quadro card?

IlyaOrson commented 3 years ago

Ah, no: it was on a GTX card with the GeForce driver.

RomainSabathe commented 2 years ago

Hi @IlyaOrson, can you confirm whether cuDNN is working as well? For me, JAX finds the GPU and CUDA, and training an MLP with only Linear layers works fine. However, every time I use a Conv2D or anything else that relies on cuDNN, it fails with the error message "UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED".

The error is raised whether I use Haiku or Flax. With TensorFlow, however, it works like a charm.

Code snippet for TensorFlow:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

network = keras.Sequential([layers.Conv2D(64, 3, 2, padding="same")])

x = tf.random.uniform((8, 64, 64, 3), 0, 1)
y = network(x)

This works without issue, and I can see that the GPU is being used.

The following two snippets, however, don't run:

Haiku

import jax
import jax.numpy as jnp
import haiku as hk

def network_fn(x):
    x = hk.Conv2D(64, 3, 2, padding="SAME")(x)
    return x

rng = jax.random.PRNGKey(0)
network = hk.transform(network_fn)
params = network.init(rng, jnp.ones([1, 64, 64, 3]))

x = jax.random.uniform(rng, (8, 64, 64, 3))
out = network.apply(params, rng, x)

And Flax

import jax
import jax.numpy as jnp
from flax import linen as nn

class Network(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Pass strides as a tuple; some Flax versions reject a bare int here.
        return nn.Conv(64, kernel_size=(3, 3), strides=(2, 2), padding="SAME")(x)

rng = jax.random.PRNGKey(0)
network = Network()
params = network.init(rng, jnp.ones((1, 64, 64, 3)))["params"]

x = jax.random.uniform(rng, (8, 64, 64, 3))
y = network.apply({"params": params}, x)  # apply takes the variables dict first, then the input
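To tell whether the failure comes from Haiku/Flax or from the convolution underneath them, one option is to issue the same convolution directly through jax.lax.conv_general_dilated, which is the call the Flax traceback ends in. On a GPU backend XLA lowers it to a cuDNN custom call (the "__cudnn$convForward" target in the log), so if this minimal version fails the same way, the problem sits below the NN libraries. A sketch, assuming NHWC inputs and an HWIO kernel to match the snippets above; the kernel values are arbitrary.

```python
import jax
from jax import lax

key = jax.random.PRNGKey(0)
k_x, k_w = jax.random.split(key)

x = jax.random.uniform(k_x, (8, 64, 64, 3))        # NHWC input batch
w = jax.random.normal(k_w, (3, 3, 3, 64)) * 0.1    # HWIO kernel: 3x3, 3 -> 64 channels

out = lax.conv_general_dilated(
    x,
    w,
    window_strides=(2, 2),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
out.block_until_ready()  # force execution so any cuDNN error surfaces here
print(jax.default_backend(), out.shape)
```

With "SAME" padding and stride 2, the 64x64 spatial size halves to 32x32, so the expected output shape is (8, 32, 32, 64).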

The error is extremely verbose, but here is the gist:

Traceback (most recent call last):
  File "/home/romain/quick_flax_test.py", line 12, in <module>
    params = network.init(rng, jnp.ones((1, 64, 64, 3)))["params"]
  File "/home/romain/quick_flax_test.py", line 8, in __call__
    return nn.Conv(64, (3, 3), 2, padding="SAME")(x)
  File "/home/romain/.local/lib/python3.8/site-packages/flax/linen/linear.py", line 282, in __call__
    y = lax.conv_general_dilated(
RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%custom-call.1 = (f32[1,32,32,64]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,65,65,3]{2,1,3,0} %pad, f32[3,3,3,64]{1,0,2,3} %copy.4), window={size=3x3 stride=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n  batch_group_count=1\n  dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n  feature_group_count=1\n  lhs_dilation=(1, 1)\n  lhs_shape=(1, 64, 64, 3)\n  padding=((0, 1), (0, 1))\n  precision=None\n  preferred_element_type=None\n  rhs_dilation=(1, 1)\n  rhs_shape=(3, 3, 3, 64)\n  window_strides=(2, 2)\n]" source_file="/home/romain/.local/lib/python3.8/site-packages/flax/linen/linear.py" source_line=282}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: INTERNAL: All algorithms tried for %custom-call.1 = (f32[1,32,32,64]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,65,65,3]{2,1,3,0} %pad, f32[3,3,3,64]{1,0,2,3} %copy.4), window={size=3x3 stride=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n  batch_group_count=1\n  dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n  feature_group_count=1\n  lhs_dilation=(1, 1)\n  lhs_shape=(1, 64, 64, 3)\n  padding=((0, 1), (0, 1))\n  precision=None\n  preferred_element_type=None\n  rhs_dilation=(1, 1)\n  rhs_shape=(3, 3, 3, 64)\n  window_strides=(2, 2)\n]" source_file="/home/romain/.local/lib/python3.8/site-packages/flax/linen/linear.py" source_line=282}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.  Per-algorithm errors:
  Profiling failure on cuDNN engine eng28{k2=3,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4502): 'status'
  Profiling failure on cuDNN engine eng34{k5=0,k6=0,k7=0,k2=2,k19=0,k4=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4502): 'status'
  Profiling failure on cuDNN engine eng1{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4502): 'status'
  Profiling failure on cuDNN engine eng1{k2=2,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4502): 'status'
  Profiling failure on cuDNN engine eng34{k19=0,k4=0,k5=1,k6=0,k7=0,k2=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4502): 'status'
  Profiling failure on cuDNN engine eng34{k2=2,k19=0,k4=1,k5=0,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
IlyaOrson commented 2 years ago

@RomainSabathe Running your snippets with the latest jax, I get the same error with Haiku, and the following with Flax:

Traceback:

```python
     10 rng = jax.random.PRNGKey(0)
     11 network = Network()
---> 12 params = network.init(rng, jnp.ones((1, 64, 64, 3)))["params"]
     13 
     14 x = jax.random.uniform(rng, (8, 64, 64, 3))

~/.local/lib/python3.8/site-packages/flax/linen/module.py in init(self, rngs, method, mutable, *args, **kwargs)
    996       The initialized variable dict.
    997     """
--> 998     _, v_out = self.init_with_output(
    999         rngs, *args,
   1000         method=method, mutable=mutable, **kwargs)

~/.local/lib/python3.8/site-packages/flax/linen/module.py in init_with_output(self, rngs, method, mutable, *args, **kwargs)
    966             f'{self.__class__.__name__}, but rngs are: {rngs}')
    967       rngs = {'params': rngs}
--> 968     return self.apply(
    969         {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
    970 

~/.local/lib/python3.8/site-packages/flax/linen/module.py in apply(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)
    934       method = self.__call__
    935     method = _get_unbound_fn(method)
--> 936     return apply(
    937         method, self,
    938         mutable=mutable, capture_intermediates=capture_intermediates

~/.local/lib/python3.8/site-packages/flax/core/scope.py in wrapper(variables, rngs, *args, **kwargs)
    685               **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
    686     with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
--> 687       y = fn(root, *args, **kwargs)
    688     if mutable is not False:
    689       return y, root.mutable_variables()

~/.local/lib/python3.8/site-packages/flax/linen/module.py in scope_fn(scope, *args, **kwargs)
   1176     _context.capture_stack.append(capture_intermediates)
   1177     try:
-> 1178       return fn(module.clone(parent=scope), *args, **kwargs)
   1179     finally:
   1180       _context.capture_stack.pop()

~/.local/lib/python3.8/site-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    273       _context.module_stack.append(self)
    274       try:
--> 275         y = fun(self, *args, **kwargs)
    276         if _context.capture_stack:
    277           filter_fn = _context.capture_stack[-1]

in __call__(self, x)
      6     @nn.compact
      7     def __call__(self, x):
----> 8         return nn.Conv(64, (3, 3), 2, padding="SAME")(x)
      9 
     10 rng = jax.random.PRNGKey(0)

~/.local/lib/python3.8/site-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    273       _context.module_stack.append(self)
    274       try:
--> 275         y = fun(self, *args, **kwargs)
    276         if _context.capture_stack:
    277           filter_fn = _context.capture_stack[-1]

~/.local/lib/python3.8/site-packages/flax/linen/linear.py in __call__(self, inputs)
    268 
    269     dimension_numbers = _conv_dimension_numbers(inputs.shape)
--> 270     y = lax.conv_general_dilated(
    271         inputs,
    272         kernel,

~/.local/lib/python3.8/site-packages/jax/_src/lax/convolution.py in conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, batch_group_count, precision, preferred_element_type)
    149         dtypes.canonicalize_dtype(np.dtype(preferred_element_type)))
    150   return conv_general_dilated_p.bind(
--> 151       lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
    152       lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
    153       dimension_numbers=dnums,

TypeError: 'int' object is not iterable
```
RomainSabathe commented 2 years ago

@IlyaOrson Thanks very much for trying, I appreciate it! Apologies for the issue with Flax. I'm assuming it has to do with the kernel size, which I wrote as a tuple of ints ((3, 3)) instead of a single int (3).

Regardless, if I understand correctly, you're confirming that you're having issues with cuDNN as well? If I may ask, how do you run ConvNets with JAX on GPU, then?

Edit: I did a fresh install of WSL and, yet again, TensorFlow and PyTorch ConvNets work just fine, but JAX's Conv2D operation fails. I'm going to try building jaxlib from source instead of using the pip-installable version.

Edit 2: nope, even when compiling jaxlib from scratch, the cuDNN problem remains :(
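Not something tried in this thread: cuDNN failures during convolution algorithm selection are sometimes attributed to GPU memory pressure, and JAX preallocates a large fraction of GPU memory by default. The sketch below sets the documented XLA_PYTHON_CLIENT_PREALLOCATE variable (and shows XLA_PYTHON_CLIENT_MEM_FRACTION as an alternative) before importing jax. Whether this helps with CUDNN_STATUS_EXECUTION_FAILED on WSL2 is an assumption, not a confirmed fix; the ".50" fraction is an arbitrary illustrative value.

```python
import os

# Assumption: these must be set before JAX initializes its GPU backend;
# setting them before importing jax (as here) is the simplest way to ensure that.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate GPU memory on demand
# Alternatively, keep preallocation but cap it at a fraction of device memory:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax  # noqa: E402  (imported deliberately after configuring the environment)

print(jax.devices())
```

After this, rerunning the failing Haiku or Flax convolution in the same process would show whether memory preallocation was a factor.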

Theo-Wu commented 2 years ago

Hi, I encountered the same problem. Did you manage to fix this?

hawkinsp commented 1 year ago

I'm guessing that this issue is stale, given that all the versions involved are very old at this point. If someone is still having trouble under WSL2, please file a new issue.