huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Running a `forward` pass before `generate` with AWQ fused modules breaks it #28470

Closed IlyasMoutawwakil closed 4 months ago

IlyasMoutawwakil commented 7 months ago

Reproduction

from transformers import AutoModelForCausalLM, AwqConfig, AutoTokenizer

awq_config = AwqConfig(do_fuse=True, fuse_max_seq_len=512)
model = AutoModelForCausalLM.from_pretrained(
    "casperhansen/tinyllama-1b-awq",
    quantization_config=awq_config,
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained("casperhansen/tinyllama-1b-awq")
input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt").input_ids.to("cuda")

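# A plain forward pass first...
model.forward(input_ids)
# ...makes the following generate call fail with the RuntimeError below
model.generate(input_ids, max_new_tokens=100)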

Expected behavior

The code works if only `generate` is called, but fails if a forward pass precedes it. The traceback:

Traceback (most recent call last):
  File "/workspace/llm-perf/test_.py", line 29, in <module>
    model.generate(input_ids, max_new_tokens=100)
  File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 1718, in generate
    return self.greedy_search(
  File "/home/user/.local/lib/python3.10/site-packages/transformers/generation/utils.py", line 2579, in greedy_search
    outputs = self(
  File "/home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1181, in forward
    outputs = self.model(
  File "/home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1033, in forward
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
  File "/home/user/.local/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 372, in _prepare_4d_causal_attention_mask_for_sdpa
    expanded_4d_mask = attn_mask_converter.to_4d(
  File "/home/user/.local/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py", line 136, in to_4d
    expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
RuntimeError: The size of tensor a (9) must match the size of tensor b (25) at non-singleton dimension 3

The problem seems to be related to the SDPA integration.
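The failing line in `to_4d` combines a causal mask with the expanded 2D attention mask via `masked_fill`. A simplified model of the shapes involved: the causal mask's last dimension is the query length plus the past length reported by the cache, while the expanded padding mask's last dimension is the length of the attention mask. If the cache reports a past length that disagrees with the attention mask, the two last dimensions differ and the call fails as above. The sketch below only models those shapes under that assumption; the helper names are hypothetical and this is not the transformers implementation.

```python
"""Simplified shape check for combining a causal mask with a 2D attention mask.

Hypothetical helpers; this models shapes only and is not transformers code.
"""


def causal_mask_shape(query_len: int, past_len: int) -> tuple[int, int]:
    # The causal mask spans every key the model can attend to:
    # the cached (past) positions plus the current query tokens.
    return (query_len, past_len + query_len)


def expanded_padding_mask_shape(query_len: int, attention_mask_len: int) -> tuple[int, int]:
    # The 2D attention mask is broadcast over the query dimension,
    # so its last dimension is simply the attention mask length.
    return (query_len, attention_mask_len)


def masks_compatible(query_len: int, past_len: int, attention_mask_len: int) -> bool:
    # Simplification: require the last dimensions to match exactly
    # (size-1 broadcasting is ignored here).
    causal = causal_mask_shape(query_len, past_len)
    padding = expanded_padding_mask_shape(query_len, attention_mask_len)
    return causal[-1] == padding[-1]


# Consistent bookkeeping: 9 prompt tokens, empty cache, 9-token mask.
print(masks_compatible(9, 0, 9))   # True
# Illustrative values only: a past length that disagrees with the mask length.
print(masks_compatible(1, 8, 12))  # False
```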

ArthurZucker commented 7 months ago

cc @younesbelkada and @fxmarty: if they use a static cache, then that is expected. I might fix it in #27931.

ArthurZucker commented 6 months ago

cc @younesbelkada 🤗

VictorSanh commented 5 months ago

Was this fixed? I just ran into the same error.

younesbelkada commented 5 months ago

Thanks everyone! I managed to reproduce it, and the fix should be: https://github.com/casper-hansen/AutoAWQ/pull/401 cc @casper-hansen

Note: if you run a dummy forward pass before calling `generate`, you need to explicitly pass use_cache=False to that forward pass:

from transformers import AutoModelForCausalLM, AwqConfig, AutoTokenizer

awq_config = AwqConfig(do_fuse=True, fuse_max_seq_len=512)
model = AutoModelForCausalLM.from_pretrained(
    "casperhansen/tinyllama-1b-awq",
    quantization_config=awq_config,
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained("casperhansen/tinyllama-1b-awq")
input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt").input_ids.to("cuda")

model.forward(input_ids, use_cache=False)  # workaround: disable the cache for the dummy pass
model.generate(input_ids, max_new_tokens=100)
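The static-cache remark above and the use_cache=False workaround are consistent with one plausible mechanism: the fused modules keep a preallocated cache whose write position survives between calls, so a plain forward pass leaves it advanced and the next prefill inside `generate` sees a nonzero past length. The pure-Python toy below sketches that idea under this assumption only. `ToyFusedAttention`, `start_pos`, and `simulate_generate` are hypothetical names, not AutoAWQ or transformers APIs; the actual fix is the AutoAWQ PR linked above.

```python
"""Toy model of a stateful ("static") KV cache whose write position persists
across calls. Hypothetical names; not AutoAWQ's or transformers' API."""


class ToyFusedAttention:
    def __init__(self, max_seq_len: int) -> None:
        self.max_seq_len = max_seq_len
        self.start_pos = 0  # next free slot in the preallocated cache

    def forward(self, seq_len: int, use_cache: bool = True) -> int:
        """Return the past length this call reports, then advance the cache."""
        if self.start_pos + seq_len > self.max_seq_len:
            raise RuntimeError("cache overflow")
        # Without the cache, no past positions are attended.
        reported_past = self.start_pos if use_cache else 0
        if use_cache:
            # Tokens are written into the cache and the position advances,
            # even if the caller never uses the returned cache.
            self.start_pos += seq_len
        return reported_past


def simulate_generate(attn: ToyFusedAttention, prompt_len: int, new_tokens: int) -> bool:
    """Return True if every step's reported past length matches what a
    freshly started generation loop expects."""
    expected_past = 0
    step_len = prompt_len
    for _ in range(new_tokens):
        if attn.forward(step_len) != expected_past:
            return False
        expected_past += step_len
        step_len = 1  # after the prompt, decode one token at a time
    return True


fresh = ToyFusedAttention(max_seq_len=512)
print(simulate_generate(fresh, prompt_len=9, new_tokens=5))  # True

stale = ToyFusedAttention(max_seq_len=512)
stale.forward(9)  # a plain forward pass leaves start_pos at 9
print(simulate_generate(stale, prompt_len=9, new_tokens=5))  # False

workaround = ToyFusedAttention(max_seq_len=512)
workaround.forward(9, use_cache=False)  # nothing written, start_pos stays 0
print(simulate_generate(workaround, prompt_len=9, new_tokens=5))  # True
```

In this toy, the bookkeeping error appears only when a cached call runs before generation starts, which mirrors the reported behavior; whether the real fused modules track their position exactly this way is an assumption.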

github-actions[bot] commented 4 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

younesbelkada commented 4 months ago

Closing as fixed in the latest AutoAWQ release (see the message above). Let me know if this issue is still relevant.