huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Mistral with FlashAttention2 #28771

Closed: khalil-Hennara closed this issue 9 months ago

khalil-Hennara commented 9 months ago

System Info

Who can help?

No response

Information

Tasks

Reproduction

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2")

This line of code was taken from the official Mistral model page. It raises:

TypeError: MistralForCausalLM.__init__() got an unexpected keyword argument 'attn_implementation'

When using use_flash_attention_2=True instead, it works fine.
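(For reference, a minimal sketch of how the two call styles relate; the version gate is an assumption based on attn_implementation only being accepted from transformers 4.36, and it requires flash-attn to be installed with a CUDA device for the FlashAttention-2 path to actually run.)

    import torch
    import transformers
    from packaging import version
    from transformers import AutoModelForCausalLM

    MODEL_ID = "mistralai/Mistral-7B-v0.1"

    if version.parse(transformers.__version__) >= version.parse("4.36.0"):
        # Newer releases accept attn_implementation directly.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
        )
    else:
        # Older releases only expose the use_flash_attention_2 flag
        # (deprecated in later versions in favor of attn_implementation).
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            use_flash_attention_2=True,
        )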

Expected behavior

The model should load without error, using FlashAttention-2 in the background.

khalil-Hennara commented 9 months ago

I think this problem is related to @ArthurZucker and @stevhliu

IYoreI commented 9 months ago

It looks like attn_implementation is supported from version 4.36. You may need to upgrade your transformers library and try again.

ArthurZucker commented 9 months ago

Yes, as @IYoreI mentions, feel free to upgrade the transformers version!
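(A quick sanity check after upgrading, as a non-authoritative sketch: the pip commands are the usual install steps rather than anything prescribed in this thread, and config._attn_implementation is an internal field in recent versions, so treat the final print as illustrative.)

    # Upgrade first, e.g.:
    #   pip install --upgrade transformers
    #   pip install flash-attn --no-build-isolation

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1",
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )

    # Confirm which attention implementation was dispatched;
    # printing the model should also show FlashAttention-2 attention layers.
    print(model.config._attn_implementation)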

khalil-Hennara commented 9 months ago

Thanks @IYoreI and @ArthurZucker for your time.

ArthurZucker commented 9 months ago

Closing as it's resolved!