bentherien opened this issue 1 year ago
Can you try with a jaxlib built from head? There was at least one recent XLA change (https://github.com/openxla/xla/commit/0ef9d092689e767a431c01b332d94b76d66866c9) that should be included in a head jaxlib and may help.
(We'll probably make a new release this week, also.)
I just built jax & jaxlib from source and installed them inside the docker image mentioned above, using the following steps:
```bash
git clone https://github.com/google/jax
cd jax
git checkout HEAD  # puts me on main
python build/build.py --enable_cuda --cuda_path /usr/local/cuda --cudnn_path /usr/lib/x86_64-linux-gnu
pip install -e .
pip install dist/*.whl
```
After the above steps I get:

```
$ pip list | grep jax
jax     0.4.16.dev20230912   /btherien/github/jax
jaxlib  0.4.16.dev20230912
```
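As an extra sanity check (my addition, not part of the original report), the versions that actually get imported can be confirmed from Python, in case the editable install and the freshly built wheel disagree:

```python
# Confirm which jax / jaxlib the interpreter actually picks up after the source build.
import jax
import jaxlib

print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("backend:", jax.default_backend())  # should print "gpu" for a CUDA-enabled build
```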
However, the same error appears again (with some small differences in the message due to batch size):
```
2023-09-12 17:57:33.259257: W external/xla/xla/service/gpu/conv_algorithm_picker.cc:809] None of the algorithms provided by cuDNN heuristics worked; trying fallback algorithms.
2023-09-12 17:57:33.259283: W external/xla/xla/service/gpu/conv_algorithm_picker.cc:812] Conv: (f32[3,512,3,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,2048,33,33]{3,2,1,0}, f32[512,128,31,31]{3,2,1,0}, f32[512]{0}), window={size=31x31}, dim_labels=bf01_oi01->bf01, feature_group_count=16, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"conv_result_scale":-0.5,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}
```
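For what it's worth, here is a minimal sketch (my reconstruction, not code from the issue) of a grouped convolution with the shapes from the warning above, assuming NCHW/OIHW layout, VALID padding, and stride 1 as implied by the dim_labels and window size; whether XLA fuses it into the same __cudnn$convBiasActivationForward call on a given GPU is not guaranteed:

```python
# Rebuild the convolution from the HLO in the warning:
# lhs f32[3,2048,33,33], rhs f32[512,128,31,31], bias f32[512],
# feature_group_count=16, expected output f32[3,512,3,3].
import jax
import jax.numpy as jnp

@jax.jit
def grouped_conv(x, w, b):
    # dim_labels=bf01_oi01->bf01 corresponds to NCHW / OIHW / NCHW
    y = jax.lax.conv_general_dilated(
        x, w,
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        feature_group_count=16,
    )
    return y + b[None, :, None, None]  # bias add that XLA may fuse into the conv

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (3, 2048, 33, 33), jnp.float32)
w = jax.random.normal(key, (512, 128, 31, 31), jnp.float32)
b = jax.random.normal(key, (512,), jnp.float32)
print(grouped_conv(x, w, b).shape)  # expected: (3, 512, 3, 3)
```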
@hawkinsp should I check out a particular dev branch before building, or specify an XLA version (e.g., following https://jax.readthedocs.io/en/latest/developer.html#building-jaxlib-from-source-with-a-modified-xla-repository)?
It looks like what you did was correct.
In that case, we'd need a reproduction of the problem. One way you could give us that is to share an XLA HLO dump: set the environment variable XLA_FLAGS=--xla_dump_to=/somewhere, then zip the contents of /somewhere and attach them. Note this in effect shares your model, which you may or may not want to do. Up to you.
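For example (a sketch on my part; /tmp/xla_dump is just a placeholder path), the flag can be set from Python as long as it happens before jax is imported:

```python
# XLA reads XLA_FLAGS at backend initialization, so set it before importing jax.
import os
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"  # placeholder dump directory

import jax  # import after setting the flag, then run the failing workload as usual
```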
Any chance you can grab that HLO dump, as requested above?
@hawkinsp Here are the logs, along with a screenshot of the warning output from when they were produced. This was run on an RTX 3090 GPU. Let me know if you need any more information.
Hmm. I couldn't reproduce on A100. Does this still reproduce for you with jax and jaxlib 0.4.20?
I get an identical error with 0.4.20
Description
I'm training learned optimizers using JAX and a custom version of https://github.com/google/learned_optimization. I get the following warnings when training on Ampere GPUs (tested on an RTX 3090 and an A6000), but no warning appears when using an RTX 2080 Ti or an RTX 8000 GPU. I'm filing this as an issue because training is much slower (2x or more) on the Ampere GPUs than on their predecessors, which should not be the case.
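To quantify that per-step slowdown, a rough timing harness like the following could be used (my sketch; step_fn and its arguments are placeholders for the actual learned_optimization training step, and it assumes the step returns a pytree of JAX arrays):

```python
# Hypothetical timing helper: average wall-clock seconds per jitted step,
# blocking on every output array so device work is actually included.
import time
import jax

def time_step(step_fn, args, warmup=3, iters=20):
    block = lambda out: jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
    for _ in range(warmup):
        block(step_fn(*args))          # includes compilation on the first call
    start = time.perf_counter()
    for _ in range(iters):
        block(step_fn(*args))
    return (time.perf_counter() - start) / iters
```

Running this for the same step on, say, the RTX 2080 Ti and the RTX 3090 would make the 2x figure concrete.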
Unfortunately, I did not manage to extract a minimal reproducing example within two hours, so I have gone ahead and posted the issue anyway.
Here is the complete stack trace for reference:
What jax/jaxlib version are you using?
jax v0.4.13; jaxlib 0.4.13+cuda12.cudnn89
Which accelerator(s) are you using?
GPU
Additional system info
Docker container "benjamintherien/dev:cu12.1.1-py3.9-jax" (Ubuntu 22.04; CUDA 12.1.1; cuDNN 8.9)
NVIDIA GPU info
RTX 3090
A6000
RTX 8000
RTX 2080 Ti
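One extra check that may be worth running on each of these machines (my addition, not from the issue): confirm which devices JAX actually sees, since the warning only shows up on the Ampere cards.

```python
# Print the backend and the devices visible to JAX on the current machine.
import jax

print(jax.default_backend())            # expect "gpu"
for d in jax.devices():
    print(d.platform, d.device_kind)    # device_kind includes the GPU model name
```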