pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

LLaMA v3.1 on MPS backend breaks in BinaryOp mps::add_sub_lerp_template #135598

Open jhavukainen opened 2 months ago

jhavukainen commented 2 months ago

πŸ› Describe the bug

Error: failed assertion `[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: total bytes of NDArray > 2**32'
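The assertion fires when a single MPSNDArray would need more than 2**32 bytes (4 GiB). Below is a minimal plain-Python sketch of that check. It assumes the limit applies to the product of the shape and the element size, exactly as the message states. The helper names are hypothetical and not part of PyTorch.

```python
import math
from typing import Sequence

# Byte limit quoted in the MPSNDArray assertion ("total bytes of NDArray > 2**32").
MPS_NDARRAY_MAX_BYTES = 2**32


def ndarray_bytes(shape: Sequence[int], itemsize: int) -> int:
    """Total bytes of a dense array with this shape and element size."""
    return math.prod(shape) * itemsize


def exceeds_mps_limit(shape: Sequence[int], itemsize: int) -> bool:
    """True if the array would trip the assertion (strictly more than 2**32 bytes)."""
    return ndarray_bytes(shape, itemsize) > MPS_NDARRAY_MAX_BYTES


def max_elements(itemsize: int) -> int:
    """Largest element count that still fits under the limit for this element size."""
    return MPS_NDARRAY_MAX_BYTES // itemsize


if __name__ == "__main__":
    BF16 = 2  # torch.bfloat16 uses 2 bytes per element
    print(max_elements(BF16))  # 2147483648, i.e. 2**31 elements
    print(exceeds_mps_limit((2**31,), BF16), exceeds_mps_limit((2**31 + 1,), BF16))  # False True
```

For bfloat16 this works out to 2**31 elements per operand. Any intermediate larger than that would hit the assertion unless the op is split into smaller pieces.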

This requires a tiling approach for BinaryOp similar to the one used for the batch matmul op in PR #133430.
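Below is a rough Python-level sketch of the tiling idea. It is not the actual Objective-C++ change in PyTorch's MPS backend. It splits two same-shaped operands along dim 0 so that no slice exceeds the byte limit, applies the binary op per slice, and concatenates the results. The helper names and the example operand shape are hypothetical.

```python
import math
from typing import Any, Callable, List, Sequence, Tuple

MPS_NDARRAY_MAX_BYTES = 2**32  # limit quoted in the MPSNDArray assertion


def plan_dim0_tiles(shape: Sequence[int], itemsize: int,
                    max_bytes: int = MPS_NDARRAY_MAX_BYTES) -> List[Tuple[int, int]]:
    """Split dim 0 into [start, stop) ranges so that no slice exceeds `max_bytes`."""
    if len(shape) == 0:
        raise ValueError("need at least one dimension to tile along")
    row_bytes = itemsize * math.prod(shape[1:])  # bytes in one index of dim 0
    if row_bytes > max_bytes:
        raise ValueError("a single dim-0 slice is already over the limit; tile a later dim")
    # Empty trailing dims make every row 0 bytes, so one tile covers everything.
    rows_per_tile = max_bytes // row_bytes if row_bytes else max(shape[0], 1)
    return [(start, min(start + rows_per_tile, shape[0]))
            for start in range(0, shape[0], rows_per_tile)]


def tiled_binary_op(op: Callable[[Any, Any], Any], a: Any, b: Any,
                    tiles: List[Tuple[int, int]],
                    concat: Callable[[List[Any]], Any]) -> Any:
    """Run `op` on matching dim-0 slices of same-shaped `a` and `b`, then concatenate."""
    return concat([op(a[start:stop], b[start:stop]) for start, stop in tiles])


if __name__ == "__main__":
    # Hypothetical attention-score-shaped operand (heads, seq, seq) in bf16 (2 bytes):
    # 7 tiles of at most 5 heads each.
    print(plan_dim0_tiles((32, 20_000, 20_000), itemsize=2))

    # Toy run with plain lists: 4-byte elements and an 8-byte budget give 2 per tile.
    def add_lists(x, y):
        return [p + q for p, q in zip(x, y)]

    def concat_lists(parts):
        return [v for part in parts for v in part]

    tiles = plan_dim0_tiles((5,), itemsize=4, max_bytes=8)
    print(tiled_binary_op(add_lists, [1, 2, 3, 4, 5], [10, 20, 30, 40, 50], tiles, concat_lists))
    # [11, 22, 33, 44, 55]
```

With PyTorch tensors the same helper could be called as `tiled_binary_op(torch.add, a, b, tiles, torch.cat)`, since `torch.cat` concatenates along dim 0 by default. Slicing along dim 0 keeps each chunk contiguous for contiguous inputs, which makes it the simplest choice. A backend-level fix would also have to handle broadcasting, as well as shapes where a single dim-0 slice is already over the limit.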

Reproduces when running LLaMA v3.1 on the current PyTorch nightly.

Script to repro

import torch
import transformers

if not torch.backends.mps.is_available():
    raise RuntimeError(
        "Please enable MPS on your machine. See https://pytorch.org/docs/stable/backends.html#torch-multiprocessing-mp"
    )
_TORCH_DEVICE = "mps"

pipeline = transformers.pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device=_TORCH_DEVICE,
    torch_dtype=torch.bfloat16
)

needle = "needle"
hay = "hay"  # 3 + 1 chars per element

target_num_tokens_approximate = 20_000
haystack = [hay] * target_num_tokens_approximate
test_index = 42
haystack[test_index] = needle
haystack_list = list(haystack)
prompt = f"""
Please return the index of the needle in the haystack as an integer.
Return only the integer number index.
<BEGIN HAYSTACK>
{" ".join(haystack_list)}
<END HAYSTACK>
"""

messages = [
    {"role": "system", "content": prompt},
]
result = pipeline(
    messages,
    max_new_tokens=200,
)
messages = result[0]["generated_text"]
completion = messages[-1]["content"]

print("successful inference")
print(f"completion: {completion}")
print(f"test_index: {test_index}")
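To help size the repro, here is a back-of-the-envelope sketch of the longest sequence whose attention-score tensor stays within the 2**32-byte limit. It rests on three assumptions that the issue does not confirm. First, the failing add operates on a (batch, heads, seq, seq) scores-shaped tensor. Second, Llama 3.1 8B uses 32 attention heads. Third, the dtype is bfloat16. The helper name is hypothetical.

```python
import math

MPS_NDARRAY_MAX_BYTES = 2**32  # limit quoted in the MPSNDArray assertion


def max_seq_len_under_limit(num_heads: int, itemsize: int, batch: int = 1,
                            max_bytes: int = MPS_NDARRAY_MAX_BYTES) -> int:
    """Largest seq_len with batch * num_heads * seq_len**2 * itemsize <= max_bytes."""
    budget = max_bytes // (batch * num_heads * itemsize)  # allowed value of seq_len**2
    return math.isqrt(budget)


if __name__ == "__main__":
    limit = max_seq_len_under_limit(num_heads=32, itemsize=2)
    print(limit)  # 8192
    # The script above targets roughly 20,000 tokens, well past this threshold.
    print(20_000 > limit)  # True
```

Under those assumptions, a `target_num_tokens_approximate` comfortably below about 8,000 should avoid an oversized operand of that shape. Haystack elements, tokens, and chat-template overhead do not map one-to-one, so treat the number as approximate.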

Versions

Collecting environment information...
PyTorch version: 2.5.0.dev20240909
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.0 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.2
Libc version: N/A

Python version: 3.11.9 (main, Apr 19 2024, 11:43:47) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.6-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M1 Ultra

Versions of relevant libraries:
[pip3] numpy==2.0.1
[pip3] torch==2.5.0.dev20240909
[pip3] torchaudio==2.5.0.dev20240909
[pip3] torchvision==0.20.0.dev20240909
[conda] numpy 2.0.1 pypi_0 pypi
[conda] torch 2.5.0.dev20240909 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20240909 pypi_0 pypi
[conda] torchvision 0.20.0.dev20240909 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @kulinseth @albanD @malfet @DenisVieriu97

malfet commented 1 week ago

@jhavukainen do you know how much memory one would need to reproduce it? I've tried to repro it on a machine with 32 GB of memory, and all I got was:

% python3 bug-135598.py 
/Users/malfet/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020
  warnings.warn(
Loading checkpoint shards: 100%|██████████████████████████████| 4/4 [00:00<00:00, 28.22it/s]
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Traceback (most recent call last):
  File "/Users/malfet/test/bug-135598.py", line 36, in <module>
    result = pipeline(
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/pipelines/text_generation.py", line 267, in __call__
    return super().__call__(Chat(text_inputs), **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/pipelines/base.py", line 1302, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/pipelines/base.py", line 1309, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/pipelines/base.py", line 1209, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/pipelines/text_generation.py", line 370, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/generation/utils.py", line 2215, in generate
    result = self._sample(
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/generation/utils.py", line 3206, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py", line 1190, in forward
    outputs = self.model(
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py", line 945, in forward
    layer_outputs = decoder_layer(
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py", line 676, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/malfet/Library/Python/3.9/lib/python/site-packages/transformers/models/llama/modeling_llama.py", line 602, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: Invalid buffer size: 23.99 GB
jhavukainen commented 1 week ago

Hi @malfet !

I reproduced this on a 128 GB machine, but since it's an 8B model it should run on a 64 GB machine at least. TBH I was assuming it would work on a 32 GB machine as well, but evidently that's not the case here.
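Here is a very rough sketch of why 32 GB may not be enough. It adds the bf16 model weights to a single full attention-score buffer. The inputs are assumptions, not measurements: about 8e9 parameters, 32 attention heads, and a prompt of about 20,000 tokens. The estimate ignores the KV cache, other activations, and any extra copies the attention implementation may create.

```python
from typing import Dict

GIB = 2**30


def llama_memory_estimate_bytes(num_params: int, seq_len: int, num_heads: int,
                                itemsize: int = 2, batch: int = 1) -> Dict[str, int]:
    """Crude peak estimate: model weights plus one full (batch, heads, seq, seq) scores buffer."""
    weights = num_params * itemsize
    scores = batch * num_heads * seq_len * seq_len * itemsize
    return {"weights": weights, "attention_scores": scores, "total": weights + scores}


if __name__ == "__main__":
    # Assumptions: ~8e9 parameters, 32 heads, bf16 (2 bytes), ~20,000-token prompt.
    est = llama_memory_estimate_bytes(num_params=8_000_000_000, seq_len=20_000, num_heads=32)
    for name, nbytes in est.items():
        print(f"{name}: {nbytes / GIB:.2f} GiB")
```

This gives roughly 15 GiB of weights plus about 24 GiB for one scores buffer, close to 39 GiB in total. That already exceeds what a 32 GB machine offers, which plausibly fits the `Invalid buffer size: 23.99 GB` failure above and the expectation that 64 GB should suffice. The real peak depends on how the attention implementation materializes its intermediates.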