I couldn't reproduce this in my environment; it's likely RTX 3080 specific. I'm asking our GPU experts to take a look.
zhangqiaorjc@skyewm-gpu-vm2:~$ cat issue_8506.py
import jax
import jax.numpy as jnp
import flax.linen as nn
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x
model = CNN()
batch = jnp.ones((32, 64, 64, 10)) # (N, H, W, C) format
variables = model.init(jax.random.PRNGKey(0), batch)
# output = model.apply(variables, batch)
zhangqiaorjc@skyewm-gpu-vm2:~$ python3 issue_8506.py
zhangqiaorjc@skyewm-gpu-vm2:~$ python3
Python 3.8.10 (default, Sep 28 2021, 16:10:42)
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import jaxlib
>>> jax.__version__
'0.2.25'
>>> jaxlib.__version__
'0.1.73'
>>>
zhangqiaorjc@skyewm-gpu-vm2:~$ nvidia-smi
Wed Nov 17 19:31:07 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03 Driver Version: 460.91.03 CUDA Version: 11.2 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla V100-SXM2... Off | 00000000:00:04.0 Off | 0 |
| N/A 38C P0 45W / 300W | 109MiB / 16160MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 Tesla V100-SXM2... Off | 00000000:00:05.0 Off | 0 |
| N/A 37C P0 45W / 300W | 4MiB / 16160MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 2 Tesla V100-SXM2... Off | 00000000:00:06.0 Off | 0 |
| N/A 39C P0 46W / 300W | 4MiB / 16160MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 3 Tesla V100-SXM2... Off | 00000000:00:07.0 Off | 0 |
| N/A 37C P0 44W / 300W | 4MiB / 16160MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 4 Tesla V100-SXM2... Off | 00000000:00:08.0 Off | 0 |
| N/A 37C P0 44W / 300W | 4MiB / 16160MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 5 Tesla V100-SXM2... Off | 00000000:00:09.0 Off | 0 |
| N/A 37C P0 43W / 300W | 4MiB / 16160MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 6 Tesla V100-SXM2... Off | 00000000:00:0A.0 Off | 0 |
| N/A 39C P0 43W / 300W | 4MiB / 16160MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 7 Tesla V100-SXM2... Off | 00000000:00:0B.0 Off | 0 |
| N/A 40C P0 43W / 300W | 4MiB / 16160MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
Yep, this is Ampere-specific, and I was able to reproduce it on an A6000 using the previous release. Yesterday's release of jaxlib 0.1.74 fixes it on my machine; can you try that?
The latest release doesn't fix it
Traceback (most recent call last):
File "/home/mark/Documents/programming/test-jax/flax_main.py", line 25, in <module>
variables = model.init(jax.random.PRNGKey(0), batch)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 884, in init
_, v_out = self.init_with_output(rngs, *args, method=method, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 862, in init_with_output
return self.apply(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 841, in apply
return apply(fn, mutable=mutable)(variables, rngs=rngs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/core/scope.py", line 608, in wrapper
y = fn(root, *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 834, in <lambda>
fn = lambda scope: method(self.clone(parent=scope), *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 277, in wrapped_module_method
y = fun(self, *args, **kwargs)
File "/home/mark/Documents/programming/test-jax/flax_main.py", line 9, in __call__
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 277, in wrapped_module_method
y = fun(self, *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py", line 269, in __call__
y = lax.conv_general_dilated(
File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/_src/lax/lax.py", line 695, in conv_general_dilated
return conv_general_dilated_p.bind(
File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/core.py", line 274, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/core.py", line 626, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 419, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/_src/util.py", line 201, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/_src/util.py", line 194, in cached
return f(*args, **kwargs)
File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 442, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 768, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars,
File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 903, in compile
self._executable = XlaCompiledComputation.from_xla_computation(
File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 932, in from_xla_computation
compiled = compile_or_get_cached(backend, xla_computation, options)
File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 871, in compile_or_get_cached
return backend_compile(backend, computation, compile_options)
File "/home/mark/Documents/programming/jax/jax-jaxlib-v0.1.74/jax/interpreters/xla.py", line 478, in backend_compile
return backend.compile(built_c, compile_options=options)
RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (f32[32,64,64,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[32,64,64,10]{2,1,3,0} %copy.3, f32[3,3,10,32]{1,0,2,3} %copy.4), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(32, 64, 64, 10)\n padding=((1, 1), (1, 1))\n precision=None\n preferred_element_type=None\n rhs_dilation=(1, 1)\n rhs_shape=(3, 3, 10, 32)\n window_strides=(1, 1)\n]" source_file="/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py" source_line=269}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
Original error: INTERNAL: All algorithms tried for %cudnn-conv = (f32[32,64,64,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[32,64,64,10]{2,1,3,0} %copy.3, f32[3,3,10,32]{1,0,2,3} %copy.4), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(32, 64, 64, 10)\n padding=((1, 1), (1, 1))\n precision=None\n preferred_element_type=None\n rhs_dilation=(1, 1)\n rhs_shape=(3, 3, 10, 32)\n window_strides=(1, 1)\n]" source_file="/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py" source_line=269}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm. Per-algorithm errors:
Profiling failure on cuDNN engine 1#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 1: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 0#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 0: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 2#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 2: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 4#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 4: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 6#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 6: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 5#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 5: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 7#TC: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
Profiling failure on cuDNN engine 7: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4139): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'
To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning.
jax.__version__ = '0.2.26'
jaxlib.__version__ = '0.1.74'
$ nvidia-smi
Thu Nov 18 22:12:41 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.00 Driver Version: 470.82.00 CUDA Version: 11.4 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:07:00.0 On | N/A |
| 0% 45C P8 32W / 340W | 344MiB / 10014MiB | 2% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1276      G   /usr/lib/xorg/Xorg                 35MiB |
|    0   N/A  N/A      1824      G   /usr/lib/xorg/Xorg                129MiB |
|    0   N/A  N/A      1959      G   /usr/bin/gnome-shell               53MiB |
|    0   N/A  N/A     81757      G   ...AAAAAAAAA= --shared-files      105MiB |
|    0   N/A  N/A     84116      G   ..._82451.log --shared-files        3MiB |
+-----------------------------------------------------------------------------+
What version of CuDNN do you have installed?
Sorry, I hadn't seen your reply.
I have the latest CUDA and cuDNN versions, 11.5 and 8.3.1. I'm happy to test any other versions that you suggest.
I have also tried building the jax project for my machine; however, that hasn't worked either.
Any other suggestions to try?
@pseudo-rnd-thoughts:
Issue #8302 solved this problem for me when running the Flax ImageNet example (set the environment variable TF_FORCE_GPU_ALLOW_GROWTH before calling TensorFlow Datasets).
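For reference, a minimal sketch of that workaround, assuming the variable is set before TensorFlow / tensorflow_datasets is imported (the MNIST load is just an illustrative example):

```python
import os

# Must be set before TensorFlow / tensorflow_datasets is imported, otherwise
# TF may already have preallocated most of the GPU before JAX gets a chance.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

import tensorflow_datasets as tfds  # noqa: E402  (imported after setting the env var on purpose)

ds = tfds.load("mnist", split="train")  # any tfds pipeline works the same way
```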
Sadly that hasn't fixed it either; I get the same error about cuDNN convolutions.
This is really strange, as a couple of people have had this bug but have all got their code working, and I'm not sure what is unusual about my system. The equivalent TensorFlow code doesn't throw an error. Any ideas on what to try?
== Version information
Jax version: 0.2.25
Jaxlib version: 0.1.73
Cuda: 11.5
Cudnn: 8.3.1
TF_FORCE_GPU_ALLOW_GROWTH: true
Ubuntu 20.04.3 LTS
== Error
Traceback (most recent call last):
File "/home/mark/Documents/programming/test-jax/flax_main.py", line 25, in <module>
variables = model.init(jax.random.PRNGKey(0), batch)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 1122, in init
_, v_out = self.init_with_output(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 1091, in init_with_output
return self.apply(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 1058, in apply
return apply(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/core/scope.py", line 706, in wrapper
y = fn(root, *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 1313, in scope_fn
return fn(module.clone(parent=scope), *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/transforms.py", line 883, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 318, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 603, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/mark/Documents/programming/test-jax/flax_main.py", line 9, in __call__
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/transforms.py", line 883, in wrapped_fn
return prewrapped_fn(self, *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 318, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 603, in _call_wrapped_method
y = fun(self, *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py", line 282, in __call__
y = lax.conv_general_dilated(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 653, in conv_general_dilated
return conv_general_dilated_p.bind(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/core.py", line 272, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/core.py", line 624, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 416, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/util.py", line 187, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/util.py", line 180, in cached
return f(*args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 439, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 759, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars,
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 892, in compile
self._executable = XlaCompiledComputation.from_xla_computation(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 921, in from_xla_computation
compiled = compile_or_get_cached(backend, xla_computation, options)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 863, in compile_or_get_cached
return backend_compile(backend, computation, compile_options)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 474, in backend_compile
return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm: INTERNAL: All algorithms tried for %cudnn-conv = (f32[32,64,64,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[32,64,64,10]{2,1,3,0} %copy.3, f32[3,3,10,32]{1,0,2,3} %copy.4), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(32, 64, 64, 10)\n padding=((1, 1), (1, 1))\n precision=None\n preferred_element_type=None\n rhs_dilation=(1, 1)\n rhs_shape=(3, 3, 10, 32)\n window_strides=(1, 1)\n]" source_file="/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py" source_line=282}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.
Convolution performance may be suboptimal. To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/mark/Documents/programming/test-jax/flax_main.py", line 25, in <module>
variables = model.init(jax.random.PRNGKey(0), batch)
File "/home/mark/Documents/programming/test-jax/flax_main.py", line 9, in __call__
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py", line 282, in __call__
y = lax.conv_general_dilated(
RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm: INTERNAL: All algorithms tried for %cudnn-conv = (f32[32,64,64,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[32,64,64,10]{2,1,3,0} %copy.3, f32[3,3,10,32]{1,0,2,3} %copy.4), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(32, 64, 64, 10)\n padding=((1, 1), (1, 1))\n precision=None\n preferred_element_type=None\n rhs_dilation=(1, 1)\n rhs_shape=(3, 3, 10, 32)\n window_strides=(1, 1)\n]" source_file="/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py" source_line=282}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.
Convolution performance may be suboptimal. To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning.
Process finished with exit code 1
== Jax / Flax code
This can be found on the github.com/google/flax page.
import jax
import jax.numpy as jnp
import flax.linen as nn
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x
model = CNN()
batch = jnp.ones((32, 64, 64, 10)) # (N, H, W, C) format
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)
== Tensorflow code
import numpy as np
import tensorflow as tf
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3)),
    tf.keras.layers.ReLU(),
    tf.keras.layers.AvgPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3)),
    tf.keras.layers.ReLU(),
    tf.keras.layers.AvgPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(10, activation='softplus')
])
batch = np.ones((32, 64, 64, 10))
output = model(batch)
Thanks
@pseudo-rnd-thoughts Go to the folder /usr/local and check which CUDA installation you have (in my case it is cuda-11.3, as I work with the docker image nvidia/cuda:11.3.0-cudnn8-devel-ubuntu20.04).
Then, under the user that you are working with (e.g. su pseudo-rnd-thoughts), type:
export PATH=/usr/local/cuda-11.3/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-11.3/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
Replace 11.3 above with whatever CUDA version you have installed.
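As a quick sanity check (my own sketch, not part of the suggestion above), you can then confirm from the same shell that JAX picks up the GPU rather than silently falling back to CPU:

```python
# Run in the same shell session where PATH / LD_LIBRARY_PATH were exported.
import jax

print(jax.default_backend())  # expected: 'gpu'
print(jax.devices())          # expected: one entry per visible GPU
```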
I'm seeing the same issue on a Quadro T2000; I tried the various fixes above and none worked.
== Version information
Jax version: 0.2.26
Jaxlib version: 0.1.75
Cuda: 11.5
Cudnn: 8.3.0
Fixed it: it is a memory allocation issue like the one suggested below, but with a different workaround: export XLA_PYTHON_CLIENT_MEM_FRACTION=0.7
I found a previous discussion with a very similar problem to mine: https://github.com/google/jax/discussions/6332
The discussion notes how JAX allocates memory: by default it preallocates 90% of GPU memory on the first JAX operation, which for us was the convolution. As the GPU also drives my display, I think there isn't enough memory available for JAX to allocate 90% of it.
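For reference, a minimal sketch of the workaround: XLA_PYTHON_CLIENT_MEM_FRACTION and XLA_PYTHON_CLIENT_PREALLOCATE are the standard JAX GPU-memory flags, 0.7 is simply the value that worked here, and both must be set before jax is imported.

```python
import os

# Cap JAX's preallocation at 70% of GPU memory instead of the default 90%.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"
# Alternative: disable preallocation entirely and allocate on demand instead.
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax            # noqa: E402  (imported after the env vars on purpose)
import jax.numpy as jnp

x = jnp.ones((32, 64, 64, 10))  # same input shape as the failing example
print(jax.devices())
```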
@rems75 does this fix the issue for you? If so, I think we can close the issue
@pseudo-rnd-thoughts I think the fix is that we need to have a minimum absolute amount of GPU RAM that we reserve for CuDNN. How much GPU RAM do you have? Is 0.7 the largest value that works? e.g., does, say, 0.8 work?
@hawkinsp I have an NVIDIA 3080 with 10GB of RAM.
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05 Driver Version: 495.29.05 CUDA Version: 11.5 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... On | 00000000:07:00.0 On | N/A |
| 0% 52C P5 43W / 340W | 287MiB / 10016MiB | 31% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 1195 G /usr/lib/xorg/Xorg 35MiB |
| 0 N/A N/A 1785 G /usr/lib/xorg/Xorg 153MiB |
| 0 N/A N/A 1913 G /usr/bin/gnome-shell 42MiB |
| 0 N/A N/A 27970 G ...AAAAAAAAA= --shared-files 40MiB |
I did a bit of testing: 80% and 85% are fine, while 90% causes the crash. So I don't think the issue is a minimum amount of GPU RAM, because requiring 9GB (90%) seems too much to me. However, if the nvidia-smi output is correct, my system was only using ~300MiB of RAM while testing, i.e. 3% of what is available, so I don't understand why 90% use is causing a problem.
@hawkinsp do you have any other questions? I'm happy to be a guinea pig to see if there is a larger underlying issue.
@pseudo-rnd-thoughts No, that seems roughly in line with what I expect. You have 10016MiB, of which JAX claims 90% (9014MiB). Your system processes claim another ~300MiB (9314MiB total), leaving only ~700MiB for cuDNN. This is apparently not enough. I think the way to fix this is for JAX to ensure that at least, say, 1GiB is left free after its allocation for cuDNN to work. I don't know what the right value for "1GiB" is, but clearly ~700MiB is too low.
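To make that budget concrete, here is a small back-of-the-envelope calculation using the numbers from this thread (illustrative only; the 300MiB system figure is approximate):

```python
# GPU memory budget for the 10016 MiB card reported above.
total_mib = 10016   # total GPU memory reported by nvidia-smi
system_mib = 300    # approx. usage by Xorg, gnome-shell, etc.

for fraction in (0.7, 0.8, 0.85, 0.9):
    jax_mib = total_mib * fraction
    left_for_cudnn = total_mib - jax_mib - system_mib
    print(f"fraction={fraction:.2f}: JAX preallocates {jax_mib:.0f} MiB, "
          f"~{left_for_cudnn:.0f} MiB left for cuDNN workspaces")
```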
@hawkinsp Thanks, I was imagining that the cudnn memory usage would be within the JAX preallocated amount. That makes a lot of sense now.
Worked for me as well. Cool.
Numbers are different in my case. This is the GPU memory with XLA_PYTHON_CLIENT_MEM_FRACTION=0.8: before, 609MiB / 4096MiB; during, 4027MiB / 4096MiB. So it seems like my cuDNN only uses ~130MiB?
Note that this is through WSL2 on a Laptop running Windows 11.
Hi everyone,
I had the exact same issue described above. I am running on WSL2 on Windows 10. I installed CUDA and cuDNN and then installed jax[gpu] via pip. After setting XLA_PYTHON_CLIENT_MEM_FRACTION=0.87, my program works perfectly on a 3080, but with 0.9 the same RuntimeError is thrown.
Pre-updated memory fraction nvidia-smi:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.52 Driver Version: 511.79 CUDA Version: 11.6 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... On | 00000000:01:00.0 Off | N/A |
| N/A 42C P8 12W / N/A | 7530MiB / 8192MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 7754 C /python3.8 N/A |
+-----------------------------------------------------------------------------+
Post-updated memory fraction nvidia-smi:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.52 Driver Version: 511.79 CUDA Version: 11.6 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... On | 00000000:01:00.0 Off | N/A |
| N/A 53C P0 57W / N/A | 7476MiB / 8192MiB | 48% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 7958 C /python3.8 N/A |
+-----------------------------------------------------------------------------+
I am very new to any sort of collaboration on repositories, so apologies if my etiquette is somewhat off, but I was wondering whether there have been any updates on this? Are there any "best practice" ways to correct it? I am currently setting XLA_PYTHON_CLIENT_MEM_FRACTION=0.87 in my bash ~/.profile file and then just running JAX as is.
Also, slightly unrelated, but what would be the best way to keep up with updates to this repository? I will be using JAX pretty religiously to build SVGP models, as I love its flexibility, and so would like to stay up to date.
Thanks for any help!
Running convolutional layers seems to cause an error where JAX does not know what cuDNN optimisation algorithm to use. This error appears to be JAX-only, as I have replicated the code with TensorFlow and no error occurs.
My jax version is 0.2.24 and my jaxlib version is 0.1.74+cuda11.cudnn82, with an NVIDIA 3080.
The example is taken from the Flax readme (https://github.com/google/flax). The bug appears to affect only convolutions, as the error does not occur for the MLP example.
I haven't been able to replicate this error elsewhere, as I don't have another GPU to use. I found this similar issue from someone who uses a 3080 like me (https://github.com/google/jax/issues/7953).