Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Support for PyTorch 2.0 #88

Open netw0rkf10w opened 1 year ago

netw0rkf10w commented 1 year ago

PyTorch 2.0 has introduced torch.compile for accelerating training and inference. I have tried it on top of flash attention, but unfortunately torch seems unable to compile it:

[2022-12-09 15:37:59,048] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
   function: '<graph break in forward>' (/home/user/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:57)
   reasons:  ['___guarded_code.valid']

I hope you can make flash attention compatible with PyTorch 2.0 in the near future. (FYI, flash attention is still faster than non-flash attention + torch.compile for ViT.)
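
For reference, a simplified version of what I'm running is below. The entry point and signature are the flash-attn 1.x ones from flash_attn_interface, and the dimensions are just illustrative, so treat this as a sketch rather than my exact code:

```python
# Simplified reproduction: compiling a module that calls the flash-attn kernel.
# Assumes flash-attn 1.x (flash_attn_unpadded_qkvpacked_func); names and
# signatures may differ in other releases.
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func


class PackedFlashAttention(torch.nn.Module):
    def __init__(self, dropout_p: float = 0.0):
        super().__init__()
        self.dropout_p = dropout_p

    def forward(self, qkv, cu_seqlens, max_seqlen):
        # qkv: (total_tokens, 3, n_heads, head_dim), fp16/bf16 on CUDA
        return flash_attn_unpadded_qkvpacked_func(
            qkv, cu_seqlens, max_seqlen, self.dropout_p, causal=False
        )


batch, seqlen, n_heads, head_dim = 8, 196, 12, 64  # ViT-like dimensions
qkv = torch.randn(batch * seqlen, 3, n_heads, head_dim,
                  device="cuda", dtype=torch.float16)
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen,
                          device="cuda", dtype=torch.int32)

model = torch.compile(PackedFlashAttention().cuda())
out = model(qkv, cu_seqlens, seqlen)  # dynamo graph-breaks inside the wrapper
```

The graph break is reported inside flash_attn_interface.py, and repeated recompilations of that frame are what eventually trigger the cache_size_limit warning above.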

tridao commented 1 year ago

Thanks! Is there some documentation on what's required to make things compatible with torch.compile?

netw0rkf10w commented 1 year ago

@tridao That's a great question! Could you please ask the PyTorch developers directly? They are very actively looking for feedback. This issue on their repo could be useful: https://github.com/pytorch/pytorch/issues/90550.
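
For what it's worth, one mechanism the torch._dynamo docs describe for unsupported functions is allow_in_graph, which makes dynamo keep the call as a single opaque node instead of tracing into the Python wrapper where the graph break is reported. I haven't verified that this is enough for an autograd.Function with a custom backward, so this is only a sketch (flash-attn 1.x names assumed):

```python
import torch
import torch._dynamo
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

# Ask dynamo not to trace into the Python wrapper around the CUDA kernel;
# the call is kept as one node in the captured graph instead.
torch._dynamo.allow_in_graph(flash_attn_unpadded_qkvpacked_func)


def attention(qkv, cu_seqlens, max_seqlen, dropout_p=0.0):
    return flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p)


compiled_attention = torch.compile(attention)
```

The PyTorch issue linked above is probably the right place to confirm whether this, or a proper custom-op registration, is the recommended route.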