
RuntimeError: Unknown: CUDNN_STATUS_EXECUTION_FAILED #13416

Closed (josephrocca closed this issue 3 years ago)

josephrocca commented 3 years ago

Environment info

(See the reproduction steps below; the Docker image pins the exact environment.)
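
If the full version dump the issue template asks for is needed, transformers ships a standard helper for it (output not captured in this report):

transformers-cli env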

Who can help

@patil-suraj

Information

Model I am using: FlaxCLIPModel

The problem arises when using the official example script.

To reproduce

Steps to reproduce the behavior:

EDIT: Please use the more rigorous reproduction instructions in my comment below.

Start with a docker image like this one:

docker run --rm -it --gpus all tensorflow/tensorflow:2.4.0-gpu

Install transformers and jax/flax:

pip install --upgrade transformers jax flax jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

Run this code:

import jax
from transformers import CLIPProcessor, FlaxCLIPModel
model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")

It produces the following error:

Downloading: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 3.98k/3.98k [00:00<00:00, 4.33MB/s]
Downloading: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 605M/605M [00:10<00:00, 55.6MB/s]
INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
2021-09-04 06:57:54.998764: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:691] Failed to determine best cudnn convolution algorithm: Internal: All algorithms tried for %custom-call = (f32[1,7,7,768]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,224,224,3]{2,1,3,0} %copy.3, f32[32,32,3,768]{1,0,2,3} %copy.4), window={size=32x32 stride=32x32}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="conv_general_dilated[ batch_group_count=1\n                      dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n                      feature_group_count=1\n                      lhs_dilation=(1, 1)\n                      lhs_shape=(1, 224, 224, 3)\n                      padding=((0, 0), (0, 0))\n                      precision=None\n                      preferred_element_type=None\n                      rhs_dilation=(1, 1)\n                      rhs_shape=(32, 32, 3, 768)\n                      window_strides=(32, 32) ]"}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm. 

Convolution performance may be suboptimal.
2021-09-04 06:57:55.109797: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2040] Execution of replica 0 failed: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3990): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 343, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 727, in __init__
    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 105, in __init__
    random_params = self.init_weights(self.key, input_shape)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 740, in init_weights
    return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 1000, in init
    method=method, mutable=mutable, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 969, in init_with_output
    {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 939, in apply
    )(variables, *args, **kwargs, rngs=rngs)
  File "/usr/local/lib/python3.6/dist-packages/flax/core/scope.py", line 687, in wrapper
    y = fn(root, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 1178, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 1064, in __call__
    return_dict=return_dict,
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 563, in __call__
    hidden_states = self.embeddings(pixel_values)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 217, in __call__
    patch_embeds = self.patch_embedding(pixel_values)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/linear.py", line 279, in __call__
    precision=self.precision)
  File "/usr/local/lib/python3.6/dist-packages/jax/_src/lax/lax.py", line 633, in conv_general_dilated
    preferred_element_type=preferred_element_type)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 603, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 249, in apply_primitive
    return compiled_fun(*args)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 365, in _execute_compiled_primitive
    out_bufs = compiled.execute(input_bufs)
RuntimeError: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3990): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'

Other notes

Here's my nvidia-smi output:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
| N/A   57C    P5    22W /  N/A |   1229MiB /  7982MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

Does FlaxCLIPModel need more than ~7GB of GPU memory for some reason? I wouldn't have expected it to need more than a GB or two at most, given the CLIP model's parameter count.

It's also worth noting that the model works fine on the CPU of my machine, and it works fine with both TPU and GPU when running in a Google Colab notebook. I've also tested with the ufoym/deepo:all-py36-cu111 docker image, but I get the same error.
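
For reference, the CPU run can be reproduced even on a GPU machine by overriding JAX's platform choice; one way, assuming a jaxlib of this era that honours the JAX_PLATFORM_NAME environment variable, is:

JAX_PLATFORM_NAME=cpu python3 -c "from transformers import FlaxCLIPModel; FlaxCLIPModel.from_pretrained('openai/clip-vit-base-patch32')"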

patil-suraj commented 3 years ago

Hi there,

"Does FlaxCLIPModel need more than ~7GB of GPU memory for some reason?"

No, it does not; it takes ~600MB in fp32.
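
As a sanity check on that figure: CLIP ViT-B/32 has roughly 151M parameters, and at 4 bytes per fp32 parameter that gives about 151e6 × 4 B ≈ 605 MB, matching the 605M checkpoint download in the log above.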

Also, this does not look like a memory error. It is most probably related to the JAX GPU installation: as you can find here, JAX needs the right versions of CUDA and cuDNN installed, so there may be a version mismatch between what the docker image provides and what JAX requires. Could you please verify that the right versions of CUDA and cuDNN are available?
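
A quick, generic sanity check for the install (not specific to transformers) is to list the devices JAX can see:

import jax

print(jax.__version__)
# Should list one or more GPU devices; a CPU-only list means jaxlib
# fell back to the CPU backend (often a CUDA/cuDNN version mismatch).
print(jax.devices())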

josephrocca commented 3 years ago

(Edit: Solved - please skip this and see follow-up comments)

@patil-suraj Thanks for your fast reply! Here are the exact reproduction steps I just took to confirm that the right versions of CUDA and CuDNN are available:

docker run --rm --gpus all -it --ipc=host ufoym/deepo:all-py36-cu111
nvcc --version

Output:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0

For CUDA v11.1, CuDNN must be version 8 as specified in the instructions you linked:

cat /usr/include/cudnn_version.h | grep CUDNN_MAJOR -A 2

Output:

#define CUDNN_MAJOR 8
#define CUDNN_MINOR 0
#define CUDNN_PATCHLEVEL 5

Confirm that /usr/local/cuda-11.1 exists, per the instructions. ✅

Install jax and jaxlib:

pip install --upgrade pip
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Install transformers and flax:

pip install --upgrade transformers flax

Run python3 and then paste this:

import jax
from transformers import CLIPProcessor, FlaxCLIPModel
model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")

And I get the following error (pasting again in case there are any slight but important differences):

Downloading: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 3.98k/3.98k [00:00<00:00, 4.07MB/s]
Downloading: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 605M/605M [00:11<00:00, 53.6MB/s]
INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
2021-09-04 08:46:47.901780: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:691] Failed to determine best cudnn convolution algorithm: Internal: All algorithms tried for %custom-call = (f32[1,7,7,768]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,224,224,3]{2,1,3,0} %copy.3, f32[32,32,3,768]{1,0,2,3} %copy.4), window={size=32x32 stride=32x32}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="conv_general_dilated[ batch_group_count=1\n                      dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n                      feature_group_count=1\n                      lhs_dilation=(1, 1)\n                      lhs_shape=(1, 224, 224, 3)\n                      padding=((0, 0), (0, 0))\n                      precision=None\n                      preferred_element_type=None\n                      rhs_dilation=(1, 1)\n                      rhs_shape=(32, 32, 3, 768)\n                      window_strides=(32, 32) ]"}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm. 

Convolution performance may be suboptimal.
2021-09-04 08:46:48.011984: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 0 failed: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3956): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 343, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 727, in __init__
    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 105, in __init__
    random_params = self.init_weights(self.key, input_shape)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 740, in init_weights
    return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 1000, in init
    method=method, mutable=mutable, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 969, in init_with_output
    {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 939, in apply
    )(variables, *args, **kwargs, rngs=rngs)
  File "/usr/local/lib/python3.6/dist-packages/flax/core/scope.py", line 687, in wrapper
    y = fn(root, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 1178, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 1064, in __call__
    return_dict=return_dict,
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 563, in __call__
    hidden_states = self.embeddings(pixel_values)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 217, in __call__
    patch_embeds = self.patch_embedding(pixel_values)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/flax/linen/linear.py", line 279, in __call__
    precision=self.precision)
  File "/usr/local/lib/python3.6/dist-packages/jax/_src/lax/lax.py", line 633, in conv_general_dilated
    preferred_element_type=preferred_element_type)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 264, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 603, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 249, in apply_primitive
    return compiled_fun(*args)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 365, in _execute_compiled_primitive
    out_bufs = compiled.execute(input_bufs)
RuntimeError: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3956): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'

josephrocca commented 3 years ago

I also just tried distilbert-base-uncased in that exact same environment (the same Docker container instance) and got RuntimeError: CUDA operation failed: out of memory, despite nvidia-smi showing around 7GB of memory free:

from transformers import DistilBertTokenizer, FlaxDistilBertForMaskedLM
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = FlaxDistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')

Output:

2021-09-04 08:51:12.465635: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
Downloading: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 232k/232k [00:00<00:00, 337kB/s]
Downloading: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 28.0/28.0 [00:00<00:00, 13.8kB/s]
Downloading: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 466k/466k [00:01<00:00, 397kB/s]
Downloading: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 483/483 [00:00<00:00, 536kB/s]
Downloading: 100%|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 268M/268M [00:05<00:00, 52.6MB/s]
INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 343, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/distilbert/modeling_flax_distilbert.py", line 438, in __init__
    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 105, in __init__
    random_params = self.init_weights(self.key, input_shape)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/distilbert/modeling_flax_distilbert.py", line 445, in init_weights
    params_rng, dropout_rng = jax.random.split(rng)
  File "/usr/local/lib/python3.6/dist-packages/jax/_src/random.py", line 262, in split
    return _split(key, int(num))  # type: ignore
  File "/usr/local/lib/python3.6/dist-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/jax/_src/api.py", line 427, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 595, in _xla_call_impl
    return compiled_fun(*args)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 893, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: CUDA operation failed: out of memory

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 343, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/distilbert/modeling_flax_distilbert.py", line 438, in __init__
    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 105, in __init__
    random_params = self.init_weights(self.key, input_shape)
  File "/usr/local/lib/python3.6/dist-packages/transformers/models/distilbert/modeling_flax_distilbert.py", line 445, in init_weights
    params_rng, dropout_rng = jax.random.split(rng)
  File "/usr/local/lib/python3.6/dist-packages/jax/_src/random.py", line 262, in split
    return _split(key, int(num))  # type: ignore
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 893, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
RuntimeError: CUDA operation failed: out of memory

josephrocca commented 3 years ago

I'm not sure if they're related, but there are mentions of the CUDNN_STATUS_EXECUTION_FAILED error here:

In the latter issue, hawkinsp mentions that 2GB of GPU memory is too little:

The issue is that your GPU doesn't have very much memory. Both CuDNN and JAX need some memory to work, and by default JAX allocates too much. See: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for more details. We might be able to tweak the defaults to make things work a little better on low-memory configurations, but it's a niche use case (2GB is pretty small for a current GPU).

So I wonder if ~6.5 GB is also too little? Seems unlikely, but maybe @hawkinsp could comment?

Edit: Oh, setting XLA_PYTHON_CLIENT_MEM_FRACTION to something like 0.7 solves it! By default JAX pre-allocates 90% of the GPU's total memory, and here 0.9 × 7982 MiB ≈ 7184 MiB, which is more than the ~6750 MiB actually free (the desktop was already using ~1229 MiB), so the pre-allocation itself was failing.

$ export XLA_PYTHON_CLIENT_MEM_FRACTION=.7
$ python3
>>> import jax
>>> from transformers import CLIPProcessor, FlaxCLIPModel
>>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
(no errors!)
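
For scripts, the same workaround can be applied from Python, provided it happens before JAX initializes its backend. A minimal sketch; XLA_PYTHON_CLIENT_PREALLOCATE is the other knob documented on the JAX memory-allocation page quoted above:

import os

# These must be set before the first `import jax` in the process,
# because the XLA backend reads them once, at initialization.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"
# Alternatively, disable pre-allocation entirely and allocate on demand:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
from transformers import FlaxCLIPModel

model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")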

Unless this error can be surfaced (or averted) in a more user-friendly way, I think this issue can be closed. I'd guess any such change would need to happen in JAX rather than in transformers anyway?