I have the same issue on T5-3B (OOM because of the clone_input function) with the "eager" backend.
It happens with beam = 1, and only if we apply dynamo to the decoder part (no problem with the encoder).
It has been reported by one of our users here: https://github.com/ELS-RD/kernl/issues/188
@williamwen42 any idea of a (dirty?) workaround if a clean fix takes time to come?
Seems to be related to https://github.com/pytorch/torchdynamo/issues/1950 CC @ezyang @voznesenskym
If it is really #1950, I can give you a dirty workaround for it.
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 843e50687a..97a78f8638 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -717,7 +717,6 @@ def wrap_fx_proxy_cls(target_cls, tx, proxy, example_value=None, **options):
# TODO(voz): Find all the callsites and burn this down.
# Flipping it to an assert fails dozens of tests.
if not isinstance(example_value, torch._subclasses.FakeTensor):
- proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value)
fake_wrapper = functools.partial(wrap_to_fake_tensor_and_record, tx=tx)
example_value = fake_wrapper(example_value)
Correct me if I'm wrong, but I think the dirty fix has already been applied in https://github.com/pytorch/torchdynamo/issues/1950; I tried testing it and am still running out of memory.
Yeah, then this is a different problem, we will need to investigate
@TheExGenesis are you trying on Whisper? Where does it crash? On the latest nightlies, the issue seems to be elsewhere.
@ezyang is it possible that dynamo's eager mode has a higher memory footprint (even slightly, e.g. from 10.4 GB to 10.6 GB of CUDA memory reserved) than "real" eager mode (i.e. without dynamo)? Also, would it be possible that the garbage collector is not called with eager+dynamo as it would be for real eager mode (later, or never)?
Dynamo eager can use more memory, but we found in our benchmark suite that memory usage typically improved, because our min-cut graph partitioner can make better choices about what to save for backwards. The other known and obvious culprits for memory usage are CUDA graphs (but this is turned off by default) and fake tensor falling back to real operations for meta usage (but this is a very slight amount of extra memory, only as much as is necessary to allocate the inputs/outputs for a particular operation). @eellison, do we have an easy log level to test for the latter?
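(For reference, a minimal sketch to compare the peak footprint with and without dynamo; the helper name is made up, and it assumes a single CUDA device and that the callable passed in stands for whatever generate() call you run:)

import torch

def report_peak_memory(label, fn, *args, **kwargs):
    # Reset the allocator's peak counters, run the workload, then read the peaks back.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    out = fn(*args, **kwargs)
    torch.cuda.synchronize()
    allocated_gb = torch.cuda.max_memory_allocated() / 1024**3
    reserved_gb = torch.cuda.max_memory_reserved() / 1024**3
    print(f"{label}: max allocated {allocated_gb:.2f} GB, max reserved {reserved_gb:.2f} GB")
    return out

# Hypothetical usage: same call, with and without dynamo applied to the decoder.
# report_peak_memory("plain eager", model.generate, inputs, max_length=25)
# optimize_model(model.model.decoder)
# report_peak_memory("dynamo eager", model.generate, inputs, max_length=25)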
I'm going to bump the priority to make sure we have someone look into this.
@ezyang we don't atm, I can add one. The one-off ops culprit was actually a red herring for other things (cudagraphs): when I landed the change to run ops in inductor with fake tensors instead of regular tensors, memory compression didn't decrease at all. I think it would be worth adding a debug mode that prints out the additional memory overhead for some of the following when it's significant.
I think the remaining sources of memory overhead, in order of likelihood, are:
The issue happens at inference time with dynamo + the "eager" backend (no CUDA graphs, no Triton involved). The fix for #1950 helps, but it seems something else is not working as expected.
Code to reproduce the OOM:
import torch
import torch._dynamo as torchdynamo
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

torch.cuda.memory._record_memory_history(True)
torchdynamo.config.cache_size_limit = 512

audio_dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large").to("cuda")


def optimize_model(original_model) -> None:
    original_model.forward2 = original_model.forward

    @torchdynamo.optimize("eager")
    def run(*args, **kwargs):
        return original_model.forward2(*args, **kwargs)

    original_model.forward = run


optimize_model(model.model.decoder)

processor = WhisperProcessor.from_pretrained("openai/whisper-large")
speech_data = audio_dataset[0]["audio"]["array"]
inputs = processor(speech_data, return_tensors="pt", sampling_rate=16_000).input_features.to("cuda")

with torch.no_grad(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"):
    predicted_ids = model.generate(inputs, min_length=25, max_length=25, num_beams=5, do_sample=False)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
    assert (
        transcription == "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
    ), transcription
    print(transcription)

print("torch.cuda.memory_allocated: %fGB" % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024))
print("torch.cuda.memory_reserved: %fGB" % (torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024))
print("torch.cuda.max_memory_reserved: %fGB" % (torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024))
Traceback (most recent call last):
File "/mnt/workspace/kernl/crash.py", line 31, in <module>
predicted_ids = model.generate(inputs, min_length=25, max_length=25, num_beams=5, do_sample=False)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 34, in decorate_context
return func(*args, **kwargs)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/transformers/generation/utils.py", line 1608, in generate
return self.beam_search(
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/transformers/generation/utils.py", line 2872, in beam_search
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py", line 1251, in _reorder_cache
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py", line 1251, in <genexpr>
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.69 GiB total capacity; 20.52 GiB already allocated; 59.19 MiB free; 21.93 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
1/ it works without dynamo (memory reserved < 12 GB), i.e. if you comment out optimize_model(model.model.decoder)
2/ it OOMs with torch dynamo on a 3090 (24 GB)
The error is due to this line in the Whisper model: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L1252
This function is called by the beam decoder.
I know it because I ran with:
CUDA_LAUNCH_BLOCKING=1
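(For reference, a sketch of enabling the same thing from Python rather than the shell; the variable has to be set before CUDA is initialized, hence before the first CUDA call:)

import os

# Make CUDA kernel launches synchronous so the Python traceback points at the
# real failing call site instead of a later asynchronous error.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # imported after setting the variable so it takes effect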
Without dynamo, it prints:
mister quilter is the apostle of the middle classes and we are glad to welcome his gospel
torch.cuda.memory_allocated: 5.880043GB
torch.cuda.memory_reserved: 11.724609GB
torch.cuda.max_memory_reserved: 11.724609GB
Moreover, the new PyTorch memory profiler (torch.cuda.memory._snapshot()) reports that this is where most of the memory is allocated.
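(For anyone who wants to look at the same data, a minimal sketch of dumping that snapshot for offline inspection, assuming the private _record_memory_history/_snapshot API of these nightlies; the output file name is made up:)

import pickle
import torch

# Must be enabled before the allocations you want to trace (as in the repro above).
torch.cuda.memory._record_memory_history(True)

# ... run the model here ...

# Grab the allocator history and save it so it can be inspected offline.
snapshot = torch.cuda.memory._snapshot()
with open("cuda_memory_snapshot.pickle", "wb") as f:
    pickle.dump(snapshot, f)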
I am under the impression that with torch dynamo the garbage collector can't delete these tensors, and then the CUDA memory can't be freed. The tensors produced by this function are output by the model (in the cache of the transformer model) and then reused as input to generate the next token. One possible issue is that, for some reason I don't know, references to those tensors are captured by dynamo and they can't be garbage collected anymore. Does that make sense to you?
Not related, but still sharing: the minifier doesn't seem to catch those OOM issues; this is at least the second time it fails for me (and it works for simpler cases).
I'm not getting OOM anymore on an 80GB A100, but I am hitting the cache limit and getting no speed improvement (strictly 0.99x of baseline). Cache limit warnings:
function: 'run' (<ipython-input-3-d766d9f4b6a6>:5)
reasons: set(kwargs.keys()) == {'output_attentions', 'return_dict', 'input_features', 'output_hidden_states'}
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
[2022-12-22 03:54:12,536] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: 'forward' (/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:861)
reasons: ___check_obj_id(past_key_values, 94636661886208)
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
[2022-12-22 03:54:26,044] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: 'forward' (/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:396)
reasons: tensor 'past_key_value[0]' strides mismatch at index 0. expected 81920, actual 84480
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
[2022-12-22 03:54:31,065] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: 'forward' (/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:166)
reasons: tensor 'past_key_value[0]' strides mismatch at index 0. expected 84480, actual 85760
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
[2022-12-22 03:54:31,514] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: '_shape' (/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:158)
reasons: ___check_obj_id(self, 140545768259696)
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
[2022-12-22 03:54:38,637] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: 'forward' (/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py:120)
reasons: past_key_values_length == 64
to diagnose recompilation issues, see https://github.com/pytorch/torchdynamo/blob/main/TROUBLESHOOTING.md.
You can increase the cache limit by modifying the Dynamo config, as in the code posted just above:
torchdynamo.config.cache_size_limit = 512
Also, can you share your CUDA memory footprint after running the model? (See the code above for how to do it.)
With "eager", I can't raise the cache_size_limit above 64 without getting OOM
With "ofi", even at cache_size_limit=64, I'm getting OOM, also a bunch of these Warnings that I haven't had time to research
/root/anaconda3/envs/whisper_kernl/lib/python3.9/site-packages/torch/jit/_check.py:181: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
@gaetansnl You are using a no-op compiler which doesn't free the inputs to the backward when they are no longer needed. This will incur significant memory overhead. Could you try the default inductor backend, i.e. torch.compile? If it doesn't succeed with batch size 64, at what number does it succeed?
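(A minimal sketch of what that would look like for the decoder, assuming the torch.compile entry point of the 2.0 nightlies and the model loaded as in the repro above; it replaces the manual forward wrapper used earlier in the thread:)

import torch

# Compile only the decoder with the default inductor backend.
model.model.decoder = torch.compile(model.model.decoder)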
@TheExGenesis if you are seeing issues different from this one, please open a new issue, thank you.
Removing high priority because this is using a non-standard backend which doesn't free inputs, so a memory regression is expected.
@eellison do you have more details on what needs to be implemented in the backend? I can't use inductor because I have a custom backend implementation.
The problem is detailed here https://github.com/pytorch/pytorch/pull/83137/#issuecomment-1211320670.
To fix it for your backend, you want to return a compiled function that takes in a list of tensors by marking _boxed_call = True, and you also want to make sure the list is cleared and the inputs are freed when they are no longer needed.
https://github.com/pytorch/pytorch/pull/83137/ is a good example of a PR to follow.
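A minimal sketch of such a backend (the compiler itself is a toy that just runs the FX graph; only the _boxed_call flag and the list-clearing pattern come from the comment and PR above, and it assumes the function is invoked through AOTAutograd's boxed calling convention):

import torch

def boxed_compiler(gm: torch.fx.GraphModule, example_inputs):
    # A real backend would lower/compile gm here; this toy version runs it as-is.
    def run(args):
        # Take ownership of the tensors and clear the caller's list so those
        # references do not keep the inputs alive longer than necessary.
        local_args = list(args)
        args.clear()
        return gm(*local_args)

    # Boxed calling convention: `run` receives a single list of tensors.
    run._boxed_call = True
    return run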
CC @SherlockNoMad for custom backend this might be a good thing to document if it's not already.
I also have out of memory with inductor:
import torch
import torch._dynamo as torchdynamo
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

torchdynamo.config.cache_size_limit = 512


def optimize_model(original_model) -> None:
    original_model.forward2 = original_model.forward

    @torchdynamo.optimize("inductor")
    def run(*args, **kwargs):
        return original_model.forward2(*args, **kwargs)

    original_model.forward = run


audio_dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large").to("cuda")
optimize_model(model.model.encoder)
optimize_model(model.model.decoder)

processor = WhisperProcessor.from_pretrained("openai/whisper-large")
speech_data = audio_dataset[0]["audio"]["array"]
inputs = processor(speech_data, return_tensors="pt", sampling_rate=16_000).input_features.to("cuda")

with torch.inference_mode(), torch.autocast(dtype=torch.float16, cache_enabled=True, device_type="cuda"):
    predicted_ids = model.generate(inputs, min_length=25, max_length=25, num_beams=2, do_sample=False)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[0]
    assert transcription == "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
I just reran the code with the eager compiler + today's nightly... and no more OOM! But 2.0.0.dev20230104+cu117 raises OOM. So basically something has been fixed since my last post.
The inductor compiler still raises OOM, but in the CUDA graph step. That's not surprising: CUDA graphs copy input tensors, and this model has a huge encoder output (which appears in the cache); if it is duplicated for each decoder sequence length, it's not surprising that it OOMs (see the sketch after the traceback below).
Traceback (most recent call last):
File "/home/geantvert/workspace/kernl/toto.py", line 30, in <module>
predicted_ids = model.generate(inputs, min_length=25, max_length=25, num_beams=2, do_sample=False)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/transformers/generation/utils.py", line 1608, in generate
return self.beam_search(
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/transformers/generation/utils.py", line 2799, in beam_search
outputs = self(
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
return forward_call(*args, **kwargs)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py", line 1194, in forward
outputs = self.model(
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
return forward_call(*args, **kwargs)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py", line 1062, in forward
decoder_outputs = self.decoder(
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
return forward_call(*args, **kwargs)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
return fn(*args, **kwargs)
File "/home/geantvert/workspace/kernl/toto.py", line 14, in run
return original_model.forward2(*args, **kwargs)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/transformers/models/whisper/modeling_whisper.py", line 767, in forward
def forward(
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
return fn(*args, **kwargs)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2467, in forward
return compiled_fn(full_args)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1066, in new_fn
fw_outs = call_func_with_args(compiled_fw, args, disable_amp=disable_amp)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1022, in call_func_with_args
out = normalize_as_list(f(args))
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 216, in run
return model(new_inputs)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 233, in run
compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 272, in cudagraphify_impl
static_inputs = [
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 273, in <listcomp>
static_input(x) if idx not in static_input_idxs else x.detach()
File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 268, in static_input
buffer = torch.zeros(needed_size, dtype=x.dtype, device=x.device)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 23.69 GiB total capacity; 14.60 GiB already allocated; 31.06 MiB free; 21.29 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
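If the CUDA graph input copies are indeed the culprit here, one mitigation sketch (assuming the torch._inductor.config.triton.cudagraphs flag exists and is honored in this nightly) is to keep inductor but disable the CUDA graph capture:

import torch._inductor.config as inductor_config

# Keep inductor's codegen but skip cudagraphify, so inputs (including the large
# encoder outputs held in the cache) are not copied into static graph buffers.
inductor_config.triton.cudagraphs = False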
@gaetansnl can we close the issue?
Thanks a lot everyone!
🐛 Describe the bug
Hello! I have an out of memory error when I try to run Whisper through torchdynamo. With openai/whisper-medium, num_beams=1 and optimize_model(model.model.decoder), it works. When I set use_cache to false in generate, it segfaults instead of OOM. And I don't think the minifier is working for this case.
PyTorch: 1.14.0.dev20221130+cu117 (nightly)
Minimal reproduction
Error logs
Minified repro
No response