josephrocca closed this issue 3 years ago.
Hi there,
Does FlaxCLIPModel need more than ~7GB of GPU memory for some reason? No, it does not; it takes about 600 MB in fp32.
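As a quick cross-check of that figure: float32 weights take 4 bytes per parameter, so the 605M checkpoint in the download log further down corresponds to roughly 151M parameters. A minimal sketch (a hypothetical helper, not a transformers API) that sums the bytes of every array in a nested parameter mapping, assuming each leaf exposes .size and .dtype.itemsize as NumPy and JAX arrays do:

```python
from collections.abc import Mapping

import numpy as np


def param_bytes(params):
    """Total bytes of all array leaves in a nested mapping of parameters."""
    if isinstance(params, Mapping):
        # Recurse into sub-modules (plain dicts and Flax FrozenDicts are both Mappings)
        return sum(param_bytes(v) for v in params.values())
    # Leaf array: element count times bytes per element (4 for float32)
    return int(params.size) * params.dtype.itemsize


# Toy example: a 768x768 float32 kernel plus a 768-element bias
toy = {"dense": {"kernel": np.zeros((768, 768), np.float32),
                 "bias": np.zeros((768,), np.float32)}}
print(param_bytes(toy))  # 768*768*4 + 768*4 = 2362368
```

On a loaded Flax model, param_bytes(model.params) / 1e6 should land in the same ballpark as the ~600 MB quoted above.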
Also, this does not look like a memory error. It is most likely related to the JAX GPU installation: as described here, JAX needs the right versions of CUDA and CuDNN installed, so there may be a version mismatch between the docker image and the versions JAX requires. Could you please verify that the right versions of CUDA and CuDNN are available?
(Edit: Solved - please skip this and see follow-up comments)
@patil-suraj Thanks for your fast reply! Here are the exact reproduction steps I just took to confirm that the right versions of CUDA and CuDNN are available:
docker run --rm --gpus all -it --ipc=host ufoym/deepo:all-py36-cu111
nvcc --version
Output:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0
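If you want to script this check, the CUDA release can be pulled out of that output with a regular expression. A minimal sketch; the function name is made up, and it only relies on the "release X.Y" fragment shown in the output above:

```python
import re


def nvcc_release(nvcc_output):
    """Return the CUDA (major, minor) release from `nvcc --version` output."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        raise ValueError("no 'release X.Y' found in nvcc output")
    return int(match.group(1)), int(match.group(2))


line = "Cuda compilation tools, release 11.1, V11.1.105"
print(nvcc_release(line))  # (11, 1)
```

On the container's Python 3.6, the text can be captured with subprocess.check_output(["nvcc", "--version"], universal_newlines=True) and passed straight in.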
For CUDA v11.1, CuDNN must be version 8 as specified in the instructions you linked:
cat /usr/include/cudnn_version.h | grep CUDNN_MAJOR -A 2
Output:
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 0
#define CUDNN_PATCHLEVEL 5
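The same grep-based check, sketched in Python: parse the three #define lines from cudnn_version.h and compare the major version against the requirement. The helper names are hypothetical, and the lookup table encodes only the pairing stated in this thread (CUDA 11.1 needs cuDNN 8), not a complete compatibility matrix:

```python
import re

# Only the pairing stated in this thread (CUDA 11.1 -> cuDNN 8); not exhaustive
REQUIRED_CUDNN_MAJOR = {(11, 1): 8}


def cudnn_version(header_text):
    """Extract (major, minor, patch) from the #define lines of cudnn_version.h."""
    parts = []
    for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        match = re.search(r"#define\s+%s\s+(\d+)" % name, header_text)
        if match is None:
            raise ValueError("missing #define " + name)
        parts.append(int(match.group(1)))
    return tuple(parts)


def cudnn_matches_cuda(header_text, cuda_release):
    """True if the header's cuDNN major version is the one required for cuda_release.

    Raises KeyError for CUDA releases not listed in REQUIRED_CUDNN_MAJOR.
    """
    return cudnn_version(header_text)[0] == REQUIRED_CUDNN_MAJOR[cuda_release]


sample = "#define CUDNN_MAJOR 8\n#define CUDNN_MINOR 0\n#define CUDNN_PATCHLEVEL 5\n"
print(cudnn_version(sample))  # (8, 0, 5)
print(cudnn_matches_cuda(sample, (11, 1)))  # True
```

In the container, the header text would come from open("/usr/include/cudnn_version.h").read().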
Confirmed that /usr/local/cuda-11.1 exists, per the instructions.
Install jax and jaxlib:
pip install --upgrade pip
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
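Before loading any model, it may be worth confirming that the CUDA build of jaxlib actually picked up the GPU. A minimal sketch using jax.default_backend() and jax.devices(); the wrapper name is made up, and it returns None when jax is not installed so it can run anywhere:

```python
def jax_backend_summary():
    """Return (default_backend, device strings), or None if jax isn't installed."""
    try:
        import jax
    except ImportError:
        return None
    # Initializes JAX's backends; a working CUDA install should report "gpu" here
    return jax.default_backend(), [str(d) for d in jax.devices()]


print(jax_backend_summary())
```

A result of "cpu" after installing jax[cuda111] would point at an installation problem rather than at the model.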
Install transformers and flax:
pip install --upgrade transformers flax
Run python3 and then paste this:
import jax
from transformers import CLIPProcessor, FlaxCLIPModel
model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
And I get the following error (pasting again in case there are any slight but important differences):
Downloading: 100%|██████████| 3.98k/3.98k [00:00<00:00, 4.07MB/s]
Downloading: 100%|██████████| 605M/605M [00:11<00:00, 53.6MB/s]
INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
2021-09-04 08:46:47.901780: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:691] Failed to determine best cudnn convolution algorithm: Internal: All algorithms tried for %custom-call = (f32[1,7,7,768]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,224,224,3]{2,1,3,0} %copy.3, f32[32,32,3,768]{1,0,2,3} %copy.4), window={size=32x32 stride=32x32}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="conv_general_dilated[ batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(1, 224, 224, 3)\n padding=((0, 0), (0, 0))\n precision=None\n preferred_element_type=None\n rhs_dilation=(1, 1)\n rhs_shape=(32, 32, 3, 768)\n window_strides=(32, 32) ]"}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.
Convolution performance may be suboptimal.
2021-09-04 08:46:48.011984: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 0 failed: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3956): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 343, in from_pretrained
model = cls(config, *model_args, **model_kwargs)
File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 727, in __init__
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 105, in __init__
random_params = self.init_weights(self.key, input_shape)
File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 740, in init_weights
return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 1000, in init
method=method, mutable=mutable, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 969, in init_with_output
{}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 939, in apply
)(variables, *args, **kwargs, rngs=rngs)
File "/usr/local/lib/python3.6/dist-packages/flax/core/scope.py", line 687, in wrapper
y = fn(root, *args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 1178, in scope_fn
return fn(module.clone(parent=scope), *args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 275, in wrapped_module_method
y = fun(self, *args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 1064, in __call__
return_dict=return_dict,
File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 275, in wrapped_module_method
y = fun(self, *args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 563, in __call__
hidden_states = self.embeddings(pixel_values)
File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 275, in wrapped_module_method
y = fun(self, *args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/transformers/models/clip/modeling_flax_clip.py", line 217, in __call__
patch_embeds = self.patch_embedding(pixel_values)
File "/usr/local/lib/python3.6/dist-packages/flax/linen/module.py", line 275, in wrapped_module_method
y = fun(self, *args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/flax/linen/linear.py", line 279, in __call__
precision=self.precision)
File "/usr/local/lib/python3.6/dist-packages/jax/_src/lax/lax.py", line 633, in conv_general_dilated
preferred_element_type=preferred_element_type)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 264, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 603, in process_primitive
return primitive.impl(*tracers, **params)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 249, in apply_primitive
return compiled_fun(*args)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 365, in _execute_compiled_primitive
out_bufs = compiled.execute(input_bufs)
RuntimeError: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3956): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
I also just tried distilbert-base-uncased using that exact same environment (the same docker container instance, I mean) and got a RuntimeError: CUDA operation failed: out of memory, despite having around 7GB of memory free according to nvidia-smi:
from transformers import DistilBertTokenizer, FlaxDistilBertForMaskedLM
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = FlaxDistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
Output:
2021-09-04 08:51:12.465635: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
Downloading: 100%|██████████| 232k/232k [00:00<00:00, 337kB/s]
Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 13.8kB/s]
Downloading: 100%|██████████| 466k/466k [00:01<00:00, 397kB/s]
Downloading: 100%|██████████| 483/483 [00:00<00:00, 536kB/s]
Downloading: 100%|██████████| 268M/268M [00:05<00:00, 52.6MB/s]
INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 343, in from_pretrained
model = cls(config, *model_args, **model_kwargs)
File "/usr/local/lib/python3.6/dist-packages/transformers/models/distilbert/modeling_flax_distilbert.py", line 438, in __init__
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 105, in __init__
random_params = self.init_weights(self.key, input_shape)
File "/usr/local/lib/python3.6/dist-packages/transformers/models/distilbert/modeling_flax_distilbert.py", line 445, in init_weights
params_rng, dropout_rng = jax.random.split(rng)
File "/usr/local/lib/python3.6/dist-packages/jax/_src/random.py", line 262, in split
return _split(key, int(num)) # type: ignore
File "/usr/local/lib/python3.6/dist-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/jax/_src/api.py", line 427, in cache_miss
donated_invars=donated_invars, inline=inline)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1560, in bind
return call_bind(self, fun, *args, **params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1551, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1563, in process
return trace.process_call(self, fun, tracers, params)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 606, in process_call
return primitive.impl(f, *tracers, **params)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 595, in _xla_call_impl
return compiled_fun(*args)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 893, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: CUDA operation failed: out of memory
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 343, in from_pretrained
model = cls(config, *model_args, **model_kwargs)
File "/usr/local/lib/python3.6/dist-packages/transformers/models/distilbert/modeling_flax_distilbert.py", line 438, in __init__
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_flax_utils.py", line 105, in __init__
random_params = self.init_weights(self.key, input_shape)
File "/usr/local/lib/python3.6/dist-packages/transformers/models/distilbert/modeling_flax_distilbert.py", line 445, in init_weights
params_rng, dropout_rng = jax.random.split(rng)
File "/usr/local/lib/python3.6/dist-packages/jax/_src/random.py", line 262, in split
return _split(key, int(num)) # type: ignore
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 893, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
RuntimeError: CUDA operation failed: out of memory
I'm not sure if they're related, but there are mentions of the CUDNN_STATUS_EXECUTION_FAILED error here:
In the latter issue, hawkinsp mentions that 2GB of GPU memory is too little:
The issue is that your GPU doesn't have very much memory. Both CuDNN and JAX need some memory to work, and by default JAX allocates too much. See: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for more details. We might be able to tweak the defaults to make things work a little better on low-memory configurations, but it's a niche use case (2GB is pretty small for a current GPU).
So I wonder whether ~6.5 GB is also too little. It seems unlikely, but maybe @hawkinsp could comment?
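A rough back-of-the-envelope for hawkinsp's point, assuming (as the JAX memory allocation docs described at the time) that the preallocated fraction applies to the currently free GPU memory, with the 90% default mentioned in the edit below and ~7 GB free per nvidia-smi. Newer JAX releases use a different default, so treat the numbers as illustrative only:

```python
def headroom_gb(free_gb, mem_fraction):
    """GPU memory left outside JAX's preallocated pool (e.g. what CuDNN needs to work)."""
    return free_gb * (1.0 - mem_fraction)


for fraction in (0.9, 0.7):
    print(f"fraction={fraction}: ~{headroom_gb(7.0, fraction):.1f} GB left outside JAX's pool")
```

Under these assumptions the default leaves only about 0.7 GB for everything else, versus about 2.1 GB at a fraction of 0.7.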
Edit: Oh, setting XLA_PYTHON_CLIENT_MEM_FRACTION to something like 0.7 solves it! By default JAX pre-allocates 90% of memory.
$ export XLA_PYTHON_CLIENT_MEM_FRACTION=.7
$ python3
>>> import jax
>>> from transformers import CLIPProcessor, FlaxCLIPModel
>>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
(no errors!)
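The same fix can be applied from inside Python, as long as the variables are set before JAX sets up its GPU backend. A minimal sketch: the helper name is hypothetical, XLA_PYTHON_CLIENT_MEM_FRACTION and XLA_PYTHON_CLIENT_PREALLOCATE are the environment variables described on the JAX GPU memory allocation page linked above, and refusing to run once jax is already in sys.modules is a deliberately conservative check:

```python
import os
import sys


def configure_jax_gpu_memory(mem_fraction=None, preallocate=True):
    """Set JAX's GPU memory env vars; call this before the first `import jax`."""
    if "jax" in sys.modules:
        # Conservative: the variables are read when JAX creates its GPU backend
        raise RuntimeError("jax is already imported; set these variables earlier")
    if not preallocate:
        # Allocate on demand instead of reserving a large pool up front
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    if mem_fraction is not None:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(mem_fraction)


configure_jax_gpu_memory(mem_fraction=0.7)
# Only now: import jax, then load the model as before
```

Exporting the variable in the shell, as above, remains the simplest option; the in-Python route mainly helps in notebooks or launcher scripts.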
Unless this error can be displayed (or averted) in a more user-friendly/helpful way, I think this issue can be closed. I guess it's probably something that would need to be done in JAX rather than transformers anyway?
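One way the error could be surfaced more helpfully, sketched as a hypothetical wrapper (nothing like it is part of transformers or JAX as far as this thread shows): catch RuntimeErrors whose message contains one of the two symptoms seen above and re-raise them with a pointer to XLA_PYTHON_CLIENT_MEM_FRACTION. Matching on message substrings is brittle and only meant as an illustration:

```python
import functools

# Substrings taken from the two tracebacks in this thread
_GPU_MEMORY_SYMPTOMS = ("CUDNN_STATUS_EXECUTION_FAILED", "out of memory")

_HINT = ("JAX may have preallocated most of the GPU memory; try setting "
         "XLA_PYTHON_CLIENT_MEM_FRACTION (e.g. .7) before starting Python.")


def with_gpu_memory_hint(fn):
    """Wrap fn so matching GPU RuntimeErrors are re-raised with a hint appended."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except RuntimeError as err:
            if any(s in str(err) for s in _GPU_MEMORY_SYMPTOMS):
                raise RuntimeError(f"{err}\nHint: {_HINT}") from err
            raise  # unrelated errors pass through unchanged
    return wrapper
```

For example, load_clip = with_gpu_memory_hint(FlaxCLIPModel.from_pretrained) followed by load_clip("openai/clip-vit-base-patch32").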
Environment info
(See the reproduction steps for the docker image to get the exact environment)
transformers version: 4.10.0
Who can help
@patil-suraj
Information
Model I am using: FlaxCLIPModel
The problem arises when using the official example script.
To reproduce
Steps to reproduce the behavior:
EDIT: Please use the more rigorous reproduction instructions in my comment below.
Start with a docker image like this one:
Install transformers and jax/flax:
Run this code:
It produces the following error:
Other notes
Here's my nvidia-smi:
Does FlaxCLIPModel need more than ~7GB of GPU memory for some reason? I wouldn't have expected it to need more than a GB or two at most, given the CLIP model's parameter count.
Also worth noting: the model works fine on the CPU on my machine, and it works fine with both TPU and GPU in a Google Colab notebook. I've also tested with the ufoym/deepo:all-py36-cu111 docker image, but I get the same error.
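For comparing against JAX's preallocation, free and total GPU memory can be read in machine-friendly form with nvidia-smi's CSV query mode. A minimal sketch, assuming nvidia-smi is on PATH; the helper names are made up and the sample line in the example is invented for illustration, not the reporter's actual output:

```python
import subprocess


def parse_gpu_memory(csv_text):
    """Parse 'free, total' MiB lines from nvidia-smi's CSV query output, one per GPU."""
    gpus = []
    for line in csv_text.strip().splitlines():
        free_mib, total_mib = (int(v) for v in line.split(","))
        gpus.append({"free_mib": free_mib, "total_mib": total_mib})
    return gpus


def query_gpu_memory():
    """Run nvidia-smi and return per-GPU free/total memory in MiB."""
    out = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.free,memory.total",
         "--format=csv,noheader,nounits"],
        universal_newlines=True,  # text output; works on Python 3.6 too
    )
    return parse_gpu_memory(out)


print(parse_gpu_memory("7012, 7982\n"))  # [{'free_mib': 7012, 'total_mib': 7982}]
```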