allenai / OLMo

Modeling, training, eval, and inference code for OLMo
https://allenai.org/olmo
Apache License 2.0

Flash attention 2.0 support #557

Closed: johnhalloran321 closed this issue 2 months ago

johnhalloran321 commented 2 months ago

🚀 The feature, motivation and pitch

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-1B",
                                             revision="step30000-tokens126B",
                                             use_flash_attention_2=True,
                                             torch_dtype=torch.bfloat16,
                                             trust_remote_code=True)

This raises:

ValueError: OLMoForCausalLM does not support Flash Attention 2.0 yet. Please request to add support where the model is hosted, on its model hub page: https://huggingface.co/allenai/OLMo-1B/discussions/new or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new

Could Flash Attention 2.0 support be added for OLMo-1B and OLMo-7B?


dumitrac commented 2 months ago

Thank you for your request, @johnhalloran321. OLMo itself already supports flash attention: https://github.com/allenai/OLMo/blob/main/olmo/model.py#L525. However, I'm still trying to figure out why the transformers library raises this error. I was able to reproduce it, by the way. Stay tuned.

dumitrac commented 2 months ago

@johnhalloran321, could you please remove the "use_flash_attention_2" argument and pass "flash_attention=True" instead? With that change, the error went away on my end. Please let me know if that resolves it.
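A minimal sketch of the suggested workaround, assuming the checkpoint is loaded through its remote code (trust_remote_code=True) and that OLMo's config exposes the "flash_attention" option the maintainer refers to. The helper names olmo_from_pretrained_kwargs and load_olmo_with_flash_attention are hypothetical. Note that "flash_attention=True" is not a transformers argument; from_pretrained uses keyword arguments that match config attributes to override the config, which is presumably how the flag reaches OLMo's own attention code.

```python
def olmo_from_pretrained_kwargs(revision=None):
    """Build kwargs for AutoModelForCausalLM.from_pretrained on an OLMo checkpoint.

    Hypothetical helper: it drops "use_flash_attention_2" entirely and sets
    OLMo's own "flash_attention" config option instead, as suggested above.
    """
    kwargs = {
        # Not a transformers flag: assumed to override OLMo's config attribute.
        "flash_attention": True,
        # The OLMo model class comes from the hub repo's remote code.
        "trust_remote_code": True,
    }
    if revision is not None:
        # Optional intermediate checkpoint, e.g. "step30000-tokens126B".
        kwargs["revision"] = revision
    return kwargs


def load_olmo_with_flash_attention(model_id="allenai/OLMo-1B", revision=None):
    """Hypothetical loader; downloads the model when called."""
    # Imports deferred so the kwargs helper works without torch installed.
    import torch
    from transformers import AutoModelForCausalLM

    return AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        **olmo_from_pretrained_kwargs(revision),
    )
```

For example, load_olmo_with_flash_attention(revision="step30000-tokens126B").to("cuda") mirrors the original call with the workaround applied. The flash-attn package and a CUDA GPU are presumably still needed once the model runs a forward pass, since flash attention kernels run on the GPU in fp16/bf16.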

johnhalloran321 commented 2 months ago

Hi @dumitrac,

Thanks for the quick response, that looks to have fixed things. Closing the issue.

Best