jhavukainen opened this issue 2 months ago
@jhavukainen do you know how much memory one would need to reproduce it? I've tried to repro it on a machine with 32GB of memory, and all I got was:
% python3 bug-135598.py
/Users/malfet/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
  warnings.warn(
Loading checkpoint shards: 100%|██████████████████████████████████████| 4/4 [00:00<00:00, 28.22it/s]
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Traceback (most recent call last):
  File "/Users/malfet/test/bug-135598.py", line 36, in <module>
    result = pipeline(
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/pipelines/text_generation.py", line 267, in __call__
    return super().__call__(Chat(text_inputs), **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/pipelines/base.py", line 1302, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/pipelines/base.py", line 1309, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/pipelines/base.py", line 1209, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/pipelines/text_generation.py", line 370, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/generation/utils.py", line 2215, in generate
    result = self._sample(
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/generation/utils.py", line 3206, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py", line 1190, in forward
    outputs = self.model(
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py", line 945, in forward
    layer_outputs = decoder_layer(
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py", line 676, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py", line 602, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: Invalid buffer size: 23.99 GB
Hi @malfet!
I reproduced this on a 128GB machine, but since it's an 8B model it should run on 64GB at least. TBH I was assuming it would work on a 32GB machine as well, but evidently that's not the case here.
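For a rough sense of scale, a back-of-envelope sketch (my own estimate, not a measurement from the issue): the fp16 weights of an 8B-parameter model already take about 15 GiB, before the KV cache, activations, and the attention buffer from the traceback above are counted.

```python
# Back-of-envelope memory estimate for an 8B-parameter model in fp16.
# Illustrative assumption, not a measurement from this issue.
params = 8e9          # 8B parameters
bytes_per_param = 2   # fp16
weights_gib = params * bytes_per_param / 2**30
print(f"fp16 weights alone: {weights_gib:.1f} GiB")  # ~14.9 GiB
```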
🐛 Describe the bug
Error:
failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: total bytes of NDArray > 2**32`

Requires a tiling approach similar to BinaryOp's, like the one applied to the batch matmul op in PR #133430.
Reproduces when running LLaMA v3.1 on the current PyTorch nightly.
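For illustration only, the tiling idea mentioned above can be sketched at the Python level: split scaled_dot_product_attention along the query dimension so that no single kernel launch materializes a buffer over 2**32 bytes. This is just a sketch of the approach; the real fix belongs inside the MPS backend, analogous to the batch-matmul tiling in PR #133430, and the chunk size here is an arbitrary assumption.

```python
import torch
import torch.nn.functional as F

def sdpa_query_tiled(q, k, v, chunk=1024):
    """Non-causal SDPA computed in query chunks.

    Each query attends to all keys independently, so chunking over the
    query dimension is exact for the non-causal case; a causal mask
    would have to be sliced to match each chunk.
    """
    outs = []
    for start in range(0, q.size(2), chunk):  # q: (batch, heads, seq, head_dim)
        q_chunk = q[:, :, start:start + chunk]
        outs.append(F.scaled_dot_product_attention(q_chunk, k, v))
    return torch.cat(outs, dim=2)
```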
Script to repro
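The script itself is collapsed in the issue; below is a minimal sketch reconstructed from the traceback above. The exact checkpoint name, prompt, and generation arguments are assumptions, not the original script.

```python
import torch
from transformers import pipeline

# bug-135598.py, reconstructed from the traceback; the model id and the
# prompt are assumptions -- the original script is not shown here.
pipe = pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # assumed checkpoint
    torch_dtype=torch.float16,
    device="mps",
)

messages = [{"role": "user", "content": "Tell me about PyTorch on Apple Silicon."}]
result = pipe(messages, max_new_tokens=128)
print(result[0]["generated_text"])
```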
Versions
Collecting environment information...
PyTorch version: 2.5.0.dev20240909
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.0 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.2
Libc version: N/A

Python version: 3.11.9 (main, Apr 19 2024, 11:43:47) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.6-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M1 Ultra

Versions of relevant libraries:
[pip3] numpy==2.0.1
[pip3] torch==2.5.0.dev20240909
[pip3] torchaudio==2.5.0.dev20240909
[pip3] torchvision==0.20.0.dev20240909
[conda] numpy 2.0.1 pypi_0 pypi
[conda] torch 2.5.0.dev20240909 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20240909 pypi_0 pypi
[conda] torchvision 0.20.0.dev20240909 pypi_0 pypi
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @kulinseth @albanD @malfet @DenisVieriu97