jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

"None of the algorithms provided by cuDNN heuristics worked" for Ampere NV GPUs #17523

Open bentherien opened 1 year ago

bentherien commented 1 year ago

Description

I'm training learned optimizers using JAX and a custom version of https://github.com/google/learned_optimization. I get the following warnings when training on Ampere GPUs (tested on an RTX 3090 and an A6000); no warning appears when using an RTX 2080 Ti or an RTX 8000 GPU. I'm filing this as an issue because training is much slower (2x or more) on the Ampere GPUs than on their predecessors, which should not be the case.

2023-09-07 03:07:57.541957: W external/xla/xla/service/gpu/gpu_conv_algorithm_picker.cc:779] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2023-09-07 03:07:57.541994: W external/xla/xla/service/gpu/gpu_conv_algorithm_picker.cc:782] Conv: (f32[3,2048,3,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,8192,33,33]{3,2,1,0}, f32[2048,128,31,31]{3,2,1,0}, f32[2048]{0}), window={size=31x31}, dim_labels=bf01_oi01->bf01, feature_group_count=64, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"conv_result_scale":-0.5,"activation_mode":"kNone","side_input_scale":0}

Unfortunately, I did not manage to extract a minimal reproducing example within two hours, so I have gone ahead and posted the issue anyway.

Here is the complete stack trace for reference:

2023-09-07 03:07:24.380550: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
ERROR:absl:Oryx not found! This library will still work but no summarywill be logged.
gpu
/btherien/github/learned_optimization/learned_optimization/outer_trainers/truncation_schedule.py:95: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in asarray is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  length = jnp.asarray(jnp.exp(log_length), dtype=jnp.int64)
ERROR:wandb.sdk.lib.gitlib:git root error: Cmd('git') failed due to: exit code(128)
  cmdline: git rev-parse --show-toplevel
  stderr: 'fatal: detected dubious ownership in repository at

Outer Loop:   0%|          | 0/5000 [00:00<?, ?it/s]
/btherien/github/learned_optimization/learned_optimization/outer_trainers/truncation_schedule.py:95: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in asarray is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  length = jnp.asarray(jnp.exp(log_length), dtype=jnp.int64)
2023-09-07 03:07:57.541957: W external/xla/xla/service/gpu/gpu_conv_algorithm_picker.cc:779] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2023-09-07 03:07:57.541994: W external/xla/xla/service/gpu/gpu_conv_algorithm_picker.cc:782] Conv: (f32[3,2048,3,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,8192,33,33]{3,2,1,0}, f32[2048,128,31,31]{3,2,1,0}, f32[2048]{0}), window={size=31x31}, dim_labels=bf01_oi01->bf01, feature_group_count=64, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"conv_result_scale":-0.5,"activation_mode":"kNone","side_input_scale":0}
2023-09-07 03:07:58.447426: W external/xla/xla/service/gpu/gpu_conv_algorithm_picker.cc:779] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2023-09-07 03:07:58.447452: W external/xla/xla/service/gpu/gpu_conv_algorithm_picker.cc:782] Conv: (f32[32,4096,3,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[32,8192,16,16]{3,2,1,0}, f32[4096,128,16,16]{3,2,1,0}, f32[4096]{0}), window={size=16x16 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, feature_group_count=64, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"conv_result_scale":-0.5,"activation_mode":"kNone","side_input_scale":0}
2023-09-07 03:08:01.502173: W external/xla/xla/service/gpu/gpu_conv_algorithm_picker.cc:779] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2023-09-07 03:08:01.502206: W external/xla/xla/service/gpu/gpu_conv_algorithm_picker.cc:782] Conv: (f32[64,4096,3,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[64,8192,16,16]{3,2,1,0}, f32[4096,128,16,16]{3,2,1,0}, f32[4096]{0}), window={size=16x16 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, feature_group_count=64, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"conv_result_scale":-0.5,"activation_mode":"kNone","side_input_scale":0}
Outer Loop:   0%|1         | 6/5000 [25:10<339:43:11, 244.89s/it]

What jax/jaxlib version are you using?

jax v0.4.13; jaxlib 0.4.13+cuda12.cudnn89

Which accelerator(s) are you using?

GPU

Additional system info

Docker container "benjamintherien/dev:cu12.1.1-py3.9-jax" (Ubuntu 22.04; CUDA 12.1.1; cuDNN 8.9)

NVIDIA GPU info

RTX 3090

Screen Shot 2023-09-08 at 3 41 26 PM

A6000

Screen Shot 2023-09-08 at 3 40 54 PM

RTX 8000

Screenshot, 2023-09-07 at 15 58 43

RTX 2080 ti

Screen Shot 2023-09-08 at 3 40 36 PM
hawkinsp commented 1 year ago

Can you try with a jaxlib built from head? There was at least one recent XLA change (https://github.com/openxla/xla/commit/0ef9d092689e767a431c01b332d94b76d66866c9) that should be included in a head jaxlib that may help.

(We'll probably make a new release this week, also.)

bentherien commented 1 year ago

I just built jax & jaxlib from source and installed them inside the Docker image mentioned above, using the following steps:

git clone https://github.com/google/jax
cd jax
git checkout HEAD #puts me on main
python build/build.py --enable_cuda --cuda_path /usr/local/cuda --cudnn_path /usr/lib/x86_64-linux-gnu
pip install -e .
pip install dist/*.whl 

After the above steps I get

pip list | grep jax
jax                          0.4.16.dev20230912 /btherien/github/jax
jaxlib                       0.4.16.dev20230912

However, the same warning appears again (with some small differences in the message due to batch size):

2023-09-12 17:57:33.259257: W external/xla/xla/service/gpu/conv_algorithm_picker.cc:809] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2023-09-12 17:57:33.259283: W external/xla/xla/service/gpu/conv_algorithm_picker.cc:812] Conv: (f32[3,512,3,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,2048,33,33]{3,2,1,0}, f32[512,128,31,31]{3,2,1,0}, f32[512]{0}), window={size=31x31}, dim_labels=bf01_oi01->bf01, feature_group_count=16, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"conv_result_scale":-0.5,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}

@hawkinsp should I checkout to a particular dev branch before building or specify an XLA version (e.g., following https://jax.readthedocs.io/en/latest/developer.html#building-jaxlib-from-source-with-a-modified-xla-repository)?
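For reference, the developer guide linked above builds jaxlib against a local XLA checkout via a Bazel repository override; a rough sketch (the /path/to/xla checkout location is a placeholder):

```shell
# Sketch following the JAX developer docs; /path/to/xla is a placeholder
# for a local clone of the XLA repository.
git clone https://github.com/openxla/xla /path/to/xla
python build/build.py --enable_cuda \
    --bazel_options=--override_repository=xla=/path/to/xla
```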

hawkinsp commented 1 year ago

It looks like what you did was correct.

In that case, we'd need a reproduction of the problem. One way you could give us that is to share an XLA HLO dump: set the environment variable XLA_FLAGS=--xla_dump_to=/somewhere, then zip the contents of /somewhere and attach them. Note that this effectively shares your model, which you may or may not want to do; it's up to you.
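Concretely, the dump step might look like this (train.py and the dump directory are placeholders for your own entry point and path):

```shell
# Placeholder paths/commands: adjust to your setup.
export XLA_FLAGS=--xla_dump_to=/tmp/xla_dump
python train.py                 # your training entry point (placeholder)
zip -r xla_dump.zip /tmp/xla_dump
```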

hawkinsp commented 11 months ago

Any chance you can grab that HLO dump, as requested above?

bentherien commented 11 months ago

@hawkinsp Here are the logs and a screenshot of the warning output when they were produced. This was run on an RTX 3090 GPU. Let me know if you need any more information.

xla_logs.zip

Screenshot 2023-11-20 at 6 11 16 PM
hawkinsp commented 11 months ago

Hmm. I couldn't reproduce on A100. Does this still reproduce for you with jax and jaxlib 0.4.20?

bentherien commented 11 months ago

I get an identical error with 0.4.20.

Screenshot 2023-11-20 at 11 27 08 PM
Screenshot 2023-11-20 at 11 27 34 PM