google-research / vision_transformer


Run vit_jax.ipynb on AWS EC2, XlaRuntimeError, Original error: UNIMPLEMENTED: DNN library is not found. #213

Open Julia90 opened 2 years ago

Julia90 commented 2 years ago

Virtual machine: AWS EC2 g5.2xlarge, NVIDIA A10G GPU, 24 GB GPU memory. When running get_accuracy(params_repl) in the Evaluation section of vit_jax.ipynb, I got the following error:

INFO:absl:Load dataset info from ~/tensorflow_datasets/cifar10/3.0.2
  0%|          | 0/156 [00:00<?, ?it/s]2022-07-14 03:04:01.191992: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:389] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2022-07-14 03:04:01.192095: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:398] Possibly insufficient driver version: 510.47.3
2022-07-14 03:04:02.629491: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:232] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
2022-07-14 03:04:02.629520: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:234] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your deep-learning framework may have preallocated more than its fair share), or may be because this binary was not built with support for the GPU in your machine.
2022-07-14 03:04:02.632353: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:389] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2022-07-14 03:04:02.632420: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:398] Possibly insufficient driver version: 510.47.3
2022-07-14 03:04:02.634584: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:389] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2022-07-14 03:04:02.656034: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
  0%|          | 0/156 [00:04<?, ?it/s]
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
/home/ubuntu/vit/vit_jax.ipynb Cell 34 in <cell line: 2>()
      [1](vscode-notebook-cell://ssh-remote%2Bsso_julia_g5_1/home/ubuntu/vit/vit_jax.ipynb#ch0000033vscode-remote?line=0) # Random performance without fine-tuning.
----> [2](vscode-notebook-cell://ssh-remote%2Bsso_julia_g5_1/home/ubuntu/vit/vit_jax.ipynb#ch0000033vscode-remote?line=1) get_accuracy(params_repl)

/home/ubuntu/vit/vit_jax.ipynb Cell 34 in get_accuracy(params_repl)
      [4](vscode-notebook-cell://ssh-remote%2Bsso_julia_g5_1/home/ubuntu/vit/vit_jax.ipynb#ch0000033vscode-remote?line=3) steps = input_pipeline.get_dataset_info(dataset, 'test')['num_examples'] // batch_size
      [5](vscode-notebook-cell://ssh-remote%2Bsso_julia_g5_1/home/ubuntu/vit/vit_jax.ipynb#ch0000033vscode-remote?line=4) for _, batch in zip(tqdm.trange(steps), ds_test.as_numpy_iterator()):
----> [6](vscode-notebook-cell://ssh-remote%2Bsso_julia_g5_1/home/ubuntu/vit/vit_jax.ipynb#ch0000033vscode-remote?line=5)   predicted = vit_apply_repl(params_repl, batch['image'])
      [7](vscode-notebook-cell://ssh-remote%2Bsso_julia_g5_1/home/ubuntu/vit/vit_jax.ipynb#ch0000033vscode-remote?line=6)   is_same = predicted.argmax(axis=-1) == batch['label'].argmax(axis=-1)
      [8](vscode-notebook-cell://ssh-remote%2Bsso_julia_g5_1/home/ubuntu/vit/vit_jax.ipynb#ch0000033vscode-remote?line=7)   good += is_same.sum()

    [... skipping hidden 15 frame]

File ~/vit/vit_env/lib/python3.8/site-packages/jax/_src/dispatch.py:818, in backend_compile(backend, built_c, options)
    814 @profiler.annotate_function
    815 def backend_compile(backend, built_c, options):
    816   # we use a separate function call to ensure that XLA compilation appears
    817   # separately in Python profiling results
--> 818   return backend.compile(built_c, compile_options=options)

XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv-bias-activation.2 = (f32[64,7,7,768]{3,2,1,0}, u8[0]{0}) custom-call(f32[64,224,224,3]{3,2,1,0} %get-tuple-element.202, f32[32,32,3,768]{2,1,0,3} %copy, f32[768]{0} %get-tuple-element.198), window={size=32x32 stride=32x32}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", metadata={op_name="pmap(<lambda>)/jit(main)/VisionTransformer/embedding/conv_general_dilated[window_strides=(32, 32) padding=((0, 0), (0, 0)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) feature_group_count=1 batch_group_count=1 lhs_shape=(64, 224, 224, 3) rhs_shape=(32, 32, 3, 768) precision=None preferred_element_type=None]" source_file="/home/ubuntu/vit/vit_env/lib/python3.8/site-packages/flax/linen/linear.py" source_line=425}, backend_config="{\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"

Original error: UNIMPLEMENTED: DNN library is not found.

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.

I noticed via nvidia-smi that the GPU memory is almost exhausted:

Every 0.1s: nvidia-smi                               ip-10-116-53-149: Thu Jul 14 03:53:53 2022

Thu Jul 14 03:53:53 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A10G         On   | 00000000:00:1E.0 Off |                    0 |
|  0%   32C    P0    69W / 300W |  22727MiB / 23028MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                   |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1721      C   ...tu/vit/vit_env/bin/python    22723MiB |
+-----------------------------------------------------------------------------+

Here are my packages:

pip freeze
absl-py==1.1.0
aqtp==0.0.9
asttokens==2.0.5
astunparse==1.6.3
backcall==0.2.0
cached-property==1.5.2
cachetools==5.2.0
certifi==2022.6.15
charset-normalizer==2.1.0
chex==0.1.3
cloudpickle==2.1.0
clu==0.0.7
colorama==0.4.5
commonmark==0.9.1
contextlib2==21.6.0
cycler==0.11.0
dacite==1.6.0
debugpy==1.6.2
decorator==5.1.1
dill==0.3.5.1
dm-tree==0.1.7
einops==0.4.1
entrypoints==0.4
etils==0.6.0
executing==0.8.3
flatbuffers==1.12
flax==0.5.2
flaxformer @ git+https://github.com/google/flaxformer@9712a16a807ec21ad7cbf816e9f6a9c174ea795d
fonttools==4.34.4
gast==0.4.0
google-auth==2.9.1
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
googleapis-common-protos==1.56.4
grpcio==1.47.0
h5py==3.7.0
idna==3.3
importlib-metadata==4.12.0
importlib-resources==5.8.0
ipykernel==6.15.1
ipython==8.4.0
jax==0.3.14
jaxlib==0.3.14+cuda11.cudnn82
jedi==0.18.1
jupyter-client==7.3.4
jupyter-core==4.11.1
keras==2.9.0
Keras-Preprocessing==1.1.2
kiwisolver==1.4.3
libclang==14.0.1
Markdown==3.3.7
matplotlib==3.5.2
matplotlib-inline==0.1.3
ml-collections==0.1.1
msgpack==1.0.4
nest-asyncio==1.5.5
numpy==1.23.1
oauthlib==3.2.0
opt-einsum==3.3.0
optax==0.1.3
packaging==21.3
pandas==1.4.3
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.2.0
promise==2.3
prompt-toolkit==3.0.30
protobuf==3.19.4
psutil==5.9.1
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.4.8
pyasn1-modules==0.2.8
Pygments==2.12.0
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2022.1
PyYAML==6.0
pyzmq==23.2.0
requests==2.28.1
requests-oauthlib==1.3.1
rich==11.2.0
rsa==4.8
scipy==1.8.1
six==1.16.0
stack-data==0.3.0
tensorboard==2.9.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.9.1
tensorflow-cpu==2.9.1
tensorflow-datasets==4.6.0
tensorflow-estimator==2.9.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.26.0
tensorflow-metadata==1.9.0
tensorflow-probability==0.17.0
tensorflow-text==2.9.0
termcolor==1.1.0
toml==0.10.2
toolz==0.12.0
tornado==6.2
tqdm==4.64.0
traitlets==5.3.0
typing_extensions==4.3.0
urllib3==1.26.10
wcwidth==0.2.5
Werkzeug==2.1.2
wrapt==1.14.1
zipp==3.8.1

I've spent almost a week on this issue without finding a solution. Could anyone help?

TahaRazzaq commented 1 year ago

Hey @Julia90, were you able to solve the issue? I am facing a similar one.

Julia90 commented 1 year ago

Hi @TahaRazzaq, I don't think I ever found a proper solution. What did work for me was manually reducing the batch size (i.e. the amount of data sent to the device per step), after which the evaluation ran. I didn't dig into the exact root cause, but it appears to be a hardware memory limit; with this Transformer encoder it is very easy to exhaust GPU memory. As a side note, many researchers training Transformers use mixed precision or half precision, which I suspect is related to this issue.
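For illustration, a minimal sketch of what I mean (batch_size is the variable the evaluation loop in the notebook divides by; the concrete value below is just a hypothetical starting point):

# In the notebook's setup cell, pick a smaller batch size before the
# ds_train / ds_test pipelines are built, then re-run those cells and
# the evaluation cell.
batch_size = 64  # hypothetical value; halve again if the GPU still runs out of memory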

TahaRazzaq commented 1 year ago

I did try reducing the batch size and that didn't really help. Will check if the memory is exploding at some point by tweaking the dataset. Thank you!

zamalali commented 6 months ago

I've encountered similar issues in the past, and I'd love to share some steps that have worked for me when dealing with CUDA and cuDNN initialization errors, especially on AWS EC2 GPU instances like the g5.2xlarge with its NVIDIA A10G. Here's a set of solutions you can try to resolve these errors :)

1. Update NVIDIA Drivers and CUDA/cuDNN
First, make sure your NVIDIA drivers and CUDA/cuDNN installations are up to date and compatible with the TensorFlow/JAX versions you're using. Here's how I usually update my NVIDIA drivers on an Ubuntu system:

# Check your current driver version
nvidia-smi

# Update to the recommended driver version
sudo apt-get update
sudo apt-get install --no-install-recommends nvidia-driver-510

# Reboot to load the new driver
sudo reboot

# After rebooting, check the updated driver version
nvidia-smi

2. Verify CUDA and cuDNN Setup
It's crucial that your CUDA and cuDNN paths are correctly set so that TensorFlow or JAX can locate them. I usually set them like this:

# Add the CUDA libraries to the dynamic linker search path
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64
export CUDA_HOME=/usr/local/cuda
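
A quick sanity check after setting these (a minimal sketch; it only confirms that jaxlib can actually initialize a GPU backend):

import jax

# Should print a list containing a GPU device; if it only shows a CPU
# device, the CUDA/cuDNN libraries are still not being picked up.
print(jax.__version__)
print(jax.devices())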

3. Optimize GPU Memory Usage
GPU memory can fill up quickly, which might be part of the problem. Here's how I manage TensorFlow's memory usage to prevent it from hogging all available GPU memory:

import tensorflow as tf

# Check available GPU devices
gpus = tf.config.list_physical_devices('GPU')
print("Available GPU devices:", gpus)

# Enable memory growth so TensorFlow allocates GPU memory on demand
# instead of grabbing it all up front
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
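
Since the notebook itself runs on JAX, a related knob worth mentioning (these are documented JAX environment variables, not specific to this repo): by default jaxlib preallocates most of the GPU memory, which on a single 24 GB card can leave the tf.data/TFDS input pipeline and cuDNN handle creation with almost nothing. A sketch of turning that off; it must happen before jax is imported:

import os

# Disable JAX's default GPU memory preallocation so other libraries
# (e.g. TensorFlow's input pipeline) can still allocate memory.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Or, alternatively, cap the fraction JAX may preallocate:
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.75'

import jax  # import only after the variables above are set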

4. Implement Mixed Precision Training
Using mixed precision training has been a game-changer for running larger models or batches. It can significantly reduce memory consumption:

from tensorflow.keras import mixed_precision

# Set the global dtype policy (stable API since TF 2.4; replaces the
# deprecated tensorflow.keras.mixed_precision.experimental module)
mixed_precision.set_global_policy('mixed_float16')

# Your model will now automatically utilize mixed precision

5. Dynamically Adjust Batch Size
If you're still hitting memory limits, you might want to dynamically adjust your batch sizes based on the available memory:

def adjust_batch_size(available_memory, model_memory_usage, base_batch_size):
    """
    Adjust the batch size based on available GPU memory and the model's
    memory usage (both in the same unit, e.g. MiB).
    """
    max_batch_size = int(available_memory / model_memory_usage)
    return min(max_batch_size, base_batch_size)

# Example usage, with MiB values read off nvidia-smi
current_batch_size = adjust_batch_size(23028, 22723, 64)
print("Adjusted Batch Size:", current_batch_size)

Hope this helps!