huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

`gpt2` with `output_attentions=True` has different attentions shape between flash and eager #33417

Closed: lapp0 closed this issue 2 weeks ago

lapp0 commented 2 months ago

System Info

Who can help?

@ArthurZucker @gante

Reproduction

>>> import torch
>>> import transformers
>>> model_flash = transformers.AutoModelForCausalLM.from_pretrained("gpt2", device_map="cuda", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
>>> model_eager = transformers.AutoModelForCausalLM.from_pretrained("gpt2", device_map="cuda", attn_implementation="eager")

>>> input_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]).to("cuda")

>>> eager_attns = model_eager(input_ids, output_attentions=True).attentions
>>> flash_attns = model_flash(input_ids, output_attentions=True).attentions

>>> len(eager_attns)
12
>>> len(flash_attns)
12

>>> eager_attns[0].shape
torch.Size([3, 12, 4, 4])
>>> flash_attns[0].shape
torch.Size([3, 4, 768])
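Plugging GPT-2's configuration (12 heads, hidden size 768, so 64 dimensions per head) into the two candidate shapes shows what each tensor most likely is. This is a minimal sketch. It assumes the flash path puts the merged per-head attention output into the `attentions` slot, which is inferred from the shape and not confirmed in this thread. The helper names are hypothetical.

```python
# GPT-2 (small) configuration values: 12 attention heads, hidden size 768.
N_HEAD = 12
N_EMBD = 768
HEAD_DIM = N_EMBD // N_HEAD  # 64 dimensions per head


def attention_weights_shape(batch, seq_len):
    """Shape of softmax(QK^T / sqrt(d)) per layer: one (q_len, k_len) matrix per head."""
    return (batch, N_HEAD, seq_len, seq_len)


def merged_output_shape(batch, seq_len):
    """Shape of the per-head outputs (weights @ V) after the heads are merged back together."""
    return (batch, seq_len, N_HEAD * HEAD_DIM)


# The reproduction uses 3 sequences of 4 tokens each.
print(attention_weights_shape(3, 4))  # (3, 12, 4, 4), same as eager_attns[0].shape
print(merged_output_shape(3, 4))      # (3, 4, 768), same as flash_attns[0].shape
```

The last dimension of the flash tensor is `12 * 64 = 768`, the hidden size, not a key length. So it has the layout of hidden states rather than of a probability matrix over keys.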

Expected behavior

output_attentions=True should result in an error for GPT2FlashAttention2.

Additionally, I'd like to understand what is being returned here.
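The expected behavior above could take the form of a small up-front check. Below is a sketch of such a guard. The function `check_output_attentions`, its `strict` flag, and the `FUSED_IMPLEMENTATIONS` set are hypothetical and not part of the transformers API. The set covers only `flash_attention_2`, the case discussed here.

```python
import warnings

# Hypothetical: attention backends that never materialize the attention-weights matrix.
FUSED_IMPLEMENTATIONS = {"flash_attention_2"}


def check_output_attentions(attn_implementation, output_attentions, strict=True):
    """Hypothetical guard: refuse (or warn about) output_attentions=True for fused kernels.

    Returns output_attentions unchanged when the combination is allowed.
    """
    if output_attentions and attn_implementation in FUSED_IMPLEMENTATIONS:
        msg = (
            f"output_attentions=True is not supported with "
            f"attn_implementation={attn_implementation!r}; "
            "load the model with attn_implementation='eager' to get attention weights."
        )
        if strict:
            raise ValueError(msg)
        warnings.warn(msg)
    return output_attentions


# Usage: eager passes through, flash with output_attentions=True is rejected.
check_output_attentions("eager", True)
try:
    check_output_attentions("flash_attention_2", True)
except ValueError as err:
    print(err)
```

Raising instead of warning is the design choice the report asks for. With `strict=False`, the sketch falls back to a warning.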

ArthurZucker commented 1 month ago

When you use flash_attention_2, the model cannot output the partial attention weights. I think there were warnings!
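Flash attention computes each output row in one pass over blocks of keys and values, using a running ("online") softmax. Because of this, the full attention-weights matrix never exists in memory and there is nothing to return. Below is a minimal pure-Python sketch of that idea for a single query, next to a naive reference that does materialize the weights. It is illustrative only: the block size and function names are hypothetical, and it ignores batching, heads, and causal masking.

```python
import math


def naive_attention(q, keys, values):
    """Reference: materialize every weight softmax(q.k / sqrt(d)), then mix the values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) * scale for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]  # the row that eager attention can return
    out = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]
    return out, weights


def online_attention(q, keys, values, block_size=2):
    """Flash-style pass: keep only a running max, a running softmax denominator,
    and an unnormalized output accumulator; per-block probabilities are discarded."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf
    denom = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) * scale for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum (exp(-inf) == 0.0).
        correction = math.exp(running_max - new_max)
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            p = math.exp(s - new_max)
            denom += p
            acc = [a + p * vi for a, vi in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]


q = [0.1, 0.2, 0.3]
keys = [[0.5, -0.1, 0.2], [0.3, 0.3, 0.3], [-0.2, 0.4, 0.1], [0.0, 0.1, -0.3]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0]]
ref_out, weights = naive_attention(q, keys, values)
flash_out = online_attention(q, keys, values)
```

Both functions produce the same output vector. Only `naive_attention` ever holds the normalized weights, which is the tensor `output_attentions=True` is meant to expose.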

github-actions[bot] commented 3 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.