dinhanhx closed this issue 1 year ago
Hi @dinhanhx, can you check whether you are running the latest version of Flax? Optax should be installed with Flax via pip.
I reproduced the code in Kaggle with a P100 GPU and got no errors. You may need to click Run -> Restart and clear cell outputs after you upgrade with !pip install -U flax jax jaxlib
jax 0.4.16 jaxlib 0.4.16 flax 0.7.4 optax 0.1.7
!pip show flax jax jaxlib optax
Name: flax
Version: 0.7.4
Required-by: clu
Name: jax
Version: 0.4.16
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, clu, flax, optax, orbax-checkpoint, tensorflow
Name: jaxlib
Version: 0.4.16
Requires: ml-dtypes, numpy, scipy
Required-by: chex, clu, optax, orbax-checkpoint
Name: optax
Version: 0.1.7
Requires: absl-py, chex, jax, jaxlib, numpy
Required-by: flax
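As an alternative to reading the full pip show output above, the installed versions can also be read from inside the notebook with the standard library. This is only a minimal sketch: the helper name installed_versions is made up for illustration and is not part of Flax or JAX.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in this environment.
            found[name] = None
    return found


# Print one line per package, similar to the short version list above.
for name, ver in installed_versions(["flax", "jax", "jaxlib", "optax"]).items():
    print(f"{name}: {ver or 'not installed'}")
```

Run it again after upgrading and restarting the kernel to confirm that the new versions are the ones actually in use.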
@8bitmp3 I have just upgraded everything and tried "Run -> Restart and clear cell outputs". However, a new and strange error now occurs.
The new error happens at the RNG line.
Why do XLA and TPU messages appear in a GPU-only environment?
2023-09-30 11:37:56.723066: E external/xla/xla/stream_executor/cuda/] failed to create cublas handle: the resource allocation failed
2023-09-30 11:37:56.723102: E external/xla/xla/stream_executor/cuda/] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your deep-learning framework may have preallocated more than its fair share), or may be because this binary was not built with support for the GPU in your machine.
XlaRuntimeError Traceback (most recent call last)
Cell In[20], line 1
----> 1 state = create_train_state(cnn, init_rng, learning_rate, momentum)
2 del init_rng # Must not be used anymore.
Cell In[13], line 6, in create_train_state(module, rng, learning_rate, momentum)
4 def create_train_state(module, rng, learning_rate, momentum):
5 """Creates an initial `TrainState`."""
----> 6 params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image
7 tx = optax.sgd(learning_rate, momentum)
8 return TrainState.create(
9 apply_fn=module.apply, params=params, tx=tx,
10 metrics=Metrics.empty())
[... skipping hidden 9 frame]
Cell In[11], line 7, in CNN.__call__(self, x)
5 @nn.compact
6 def __call__(self, x):
----> 7 x = nn.Conv(features=32, kernel_size=(3, 3))(x)
8 x = nn.relu(x)
9 x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
[... skipping hidden 2 frame]
File /opt/conda/lib/python3.10/site-packages/flax/linen/, in _Conv.__call__(self, inputs)
522 else:
523 conv_general_dilated = self.conv_general_dilated
--> 524 y = conv_general_dilated(
525 inputs,
526 kernel,
527 strides,
528 padding_lax,
529 lhs_dilation=input_dilation,
530 rhs_dilation=kernel_dilation,
531 dimension_numbers=dimension_numbers,
532 feature_group_count=self.feature_group_count,
533 precision=self.precision,
534 )
535 else:
536 y = lax.conv_general_dilated_local(
537 lhs=inputs,
538 rhs=kernel,
545 precision=self.precision,
546 )
[... skipping hidden 13 frame]
File /opt/conda/lib/python3.10/site-packages/jax/_src/, in backend_compile(backend, module, options, host_callbacks)
246 return backend.compile(built_c, compile_options=options,
247 host_callbacks=host_callbacks)
248 # Some backends don't have `host_callbacks` option yet
249 # TODO(sharadmv): remove this fallback when all backends allow `compile`
250 # to take in `host_callbacks`
--> 251 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv.1 = (f32[1,32,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,28,28]{3,2,1,0} %bitcast.3, f32[32,1,3,3]{3,2,1,0} %transpose), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) feature_group_count=1 batch_group_count=1 precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_263/" source_line=7}, backend_config={"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}
Original error: INTERNAL: All algorithms tried for (f32[1,32,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,28,28]{3,2,1,0}, f32[32,1,3,3]{3,2,1,0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0} failed. Falling back to default algorithm. Per-algorithm errors:
Profiling failure on cuDNN engine eng28{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng34{k2=2,k4=1,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng1{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng34{k2=1,k4=1,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng2{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng2{k2=3,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng34{k2=1,k4=3,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng1{k2=2,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng34{k2=2,k4=2,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng0{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng28{k2=1,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng34{k2=0,k4=2,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng42{k2=1,k4=1,k5=0,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng4{}: UNKNOWN: CUDNN_STATUS_INTERNAL_ERROR
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng28{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng34{k2=2,k4=1,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng1{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng34{k2=1,k4=1,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng2{k2=4,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng2{k2=3,k3=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng34{k2=1,k4=3,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng1{k2=2,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng34{k2=2,k4=2,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng0{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng28{k2=1,k3=0}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng34{k2=0,k4=2,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng42{k2=1,k4=1,k5=0,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng4{}: UNKNOWN: CUDNN_STATUS_INTERNAL_ERROR
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng1{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng28{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng0{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng42{k2=2,k4=1,k5=0,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng42{k2=2,k4=2,k5=0,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng42{k2=0,k4=1,k5=1,k6=0,k7=0}: UNKNOWN: CUDNN_STATUS_EXECUTION_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
Profiling failure on cuDNN engine eng3{k11=2}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'status'
To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning.
This is what happened when I used XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false:
2023-09-30 11:48:53.530029: E external/xla/xla/stream_executor/cuda/] failed to create cublas handle: the resource allocation failed
2023-09-30 11:48:53.530067: E external/xla/xla/stream_executor/cuda/] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your deep-learning framework may have preallocated more than its fair share), or may be because this binary was not built with support for the GPU in your machine.
2023-09-30 11:48:53.988144: E external/xla/xla/pjrt/] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.conv.forward' failed: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'; current tracing scope: cudnn-conv.1; current profiling annotation: XlaModule:#hlo_module=jit_conv_general_dilated,program_id=11#.
XlaRuntimeError Traceback (most recent call last)
Cell In[18], line 1
----> 1 state = create_train_state(cnn, init_rng, learning_rate, momentum)
2 del init_rng # Must not be used anymore.
Cell In[11], line 6, in create_train_state(module, rng, learning_rate, momentum)
4 def create_train_state(module, rng, learning_rate, momentum):
5 """Creates an initial `TrainState`."""
----> 6 params = module.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image
7 tx = optax.sgd(learning_rate, momentum)
8 return TrainState.create(
9 apply_fn=module.apply, params=params, tx=tx,
10 metrics=Metrics.empty())
[... skipping hidden 9 frame]
Cell In[9], line 7, in CNN.__call__(self, x)
5 @nn.compact
6 def __call__(self, x):
----> 7 x = nn.Conv(features=32, kernel_size=(3, 3))(x)
8 x = nn.relu(x)
9 x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
[... skipping hidden 2 frame]
File /opt/conda/lib/python3.10/site-packages/flax/linen/, in _Conv.__call__(self, inputs)
522 else:
523 conv_general_dilated = self.conv_general_dilated
--> 524 y = conv_general_dilated(
525 inputs,
526 kernel,
527 strides,
528 padding_lax,
529 lhs_dilation=input_dilation,
530 rhs_dilation=kernel_dilation,
531 dimension_numbers=dimension_numbers,
532 feature_group_count=self.feature_group_count,
533 precision=self.precision,
534 )
535 else:
536 y = lax.conv_general_dilated_local(
537 lhs=inputs,
538 rhs=kernel,
545 precision=self.precision,
546 )
[... skipping hidden 8 frame]
File /opt/conda/lib/python3.10/site-packages/jax/_src/interpreters/, in ExecuteReplicated.__call__(self, *args)
1147 self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
1148 else:
-> 1149 results = self.xla_executable.execute_sharded(input_bufs)
1150 if dispatch.needs_check_special():
1151 out_arrays = results.disassemble_into_single_device_arrays()
XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.conv.forward' failed: CUDNN_STATUS_ALLOC_FAILED
in external/xla/xla/stream_executor/cuda/ 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd_.handle(), input_data.opaque(), filter_.handle(), filter_data.opaque(), conv_.handle(), ToConvForwardAlgo(algo), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd_.handle(), output_data.opaque())'; current tracing scope: cudnn-conv.1; current profiling annotation: XlaModule:#hlo_module=jit_conv_general_dilated,program_id=11#.
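For readers trying the same flag in a notebook: environment variables such as XLA_FLAGS generally need to be set before JAX initializes its backend, so it is safest to set them before the first import of jax. The sketch below shows one way to do that. The helper name configure_xla_env is hypothetical. XLA_PYTHON_CLIENT_PREALLOCATE is a documented JAX setting that disables GPU memory preallocation, which the cublas message above names as a possible cause; this thread does not establish that it fixes this particular failure, and the strict-picker flag alone did not.

```python
import os

STRICT_PICKER = "--xla_gpu_strict_conv_algorithm_picker="


def configure_xla_env(strict_conv_picker=False, preallocate=False):
    """Set XLA/JAX environment variables; call this before importing jax."""
    existing = os.environ.get("XLA_FLAGS", "").split()
    # Keep unrelated flags, but drop any earlier setting of the same flag.
    kept = [flag for flag in existing if not flag.startswith(STRICT_PICKER)]
    kept.append(STRICT_PICKER + ("true" if strict_conv_picker else "false"))
    os.environ["XLA_FLAGS"] = " ".join(kept)
    # "false" asks JAX not to preallocate most of the GPU memory up front.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    return {
        "XLA_FLAGS": os.environ["XLA_FLAGS"],
        "XLA_PYTHON_CLIENT_PREALLOCATE": os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"],
    }


print(configure_xla_env())
# import jax  # import JAX only after the variables above are set
```

If jax was already imported in the session, restart the kernel first; changing the variables afterwards is not expected to affect a backend that is already initialized.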
@dinhanhx I was able to start training the model in Kaggle with a P100 GPU here. Can you try opening a new Kaggle notebook with a P100? Then start with !pip install -U flax jax jaxlib, followed by the code cells copied from the Quickstart. Try running them line by line and let us know if you are still experiencing any issues.
@8bitmp3 The Quickstart code is executing now, but it's slow and doesn't use the GPU.
Do I have to move something to the GPU, as in PyTorch?
Source code:
Okay, I'm dumb: I need to install JAX with CUDA support and then restart the runtime to use the GPU.
pip install -U flax jax jaxlib
pip install -q --upgrade "jax[cuda11_pip]" -f
Thanks for your help,
Name: flax
Version: 0.7.4
Summary: Flax: A neural network library for JAX designed for flexibility
Author-email: Flax team <>
Location: /opt/conda/lib/python3.10/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: clu
Name: jax
Version: 0.4.18
Summary: Differentiate, compile, and transform Numpy code.
Author: JAX team
License: Apache-2.0
Location: /opt/conda/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, clu, flax, optax, orbax-checkpoint, tensorflow
Name: jaxlib
Version: 0.4.18+cuda11.cudnn86
Summary: XLA library for JAX
Author: JAX team
License: Apache-2.0
Location: /opt/conda/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, clu, optax, orbax-checkpoint
Name: optax
Version: 0.1.7
Summary: A gradient processing and optimisation library in JAX.
Author-email: DeepMind <>
Location: /opt/conda/lib/python3.10/site-packages
Requires: absl-py, chex, jax, jaxlib, numpy
Required-by: flax
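On the earlier question about moving things to the GPU: JAX places arrays on its default device, so once a CUDA-enabled jaxlib is installed (as the +cuda11.cudnn86 version above indicates) and the runtime has been restarted, no explicit transfer step is normally needed. A quick way to confirm which backend JAX picked is sketched below. The helper name describe_jax_backend is made up for illustration, and the import is guarded so the snippet also runs where JAX is missing.

```python
def describe_jax_backend():
    """Return the default JAX backend and its devices, or None if JAX is not importable."""
    try:
        import jax
    except ImportError:
        return None
    return {
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(device) for device in jax.devices()],
    }


info = describe_jax_backend()
print(info if info is not None else "JAX is not installed")
```

If the backend is reported as "cpu" on a GPU notebook, that matches the slow CPU-only run described above and suggests the CPU-only jaxlib wheel is still the one installed.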
:+1: You are welcome @dinhanhx :+1: If you have any more questions, feel free to create a new GitHub Issue, Discussion or reopen this Issue.
Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible also provide a summary of what steps or workarounds you have already tried.
System information
pip show flax jax jaxlib: latest on pip
Problem you have encountered:
When I simply run the Quickstart code, it fails at Step 9.
What you expected to happen:
It works
Logs, error messages, etc:
Steps to reproduce: