hpcaitech / Open-Sora

Open-Sora: Democratizing Efficient Video Production for All
https://hpcaitech.github.io/Open-Sora/
Apache License 2.0

Use PyTorch scaled_dot_product_attention when Flash Attention is not available #559

Open · bayley opened this issue 4 days ago

bayley commented 4 days ago

In the current code, if Flash Attention is not available, models.layers.blocks.Attention falls back to naive attention, which materializes the full attention matrix and OOMs immediately, even for modest generation lengths.
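For reference, the naive path looks roughly like the following (a minimal sketch, not the exact Open-Sora code): the full `(B, num_heads, N, N)` score matrix is built before the softmax, so memory grows quadratically with sequence length.

```python
import torch

def naive_attention(q, k, v, scale):
    # q, k, v: (B, num_heads, N, head_dim)
    # The (B, num_heads, N, N) score matrix is materialized in full,
    # which is what exhausts GPU memory at longer sequence lengths.
    attn = (q @ k.transpose(-2, -1)) * scale
    attn = attn.softmax(dim=-1)
    return attn @ v
```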

Can we fall back to torch.nn.functional.scaled_dot_product_attention instead so that pre-Ampere users can run inference as well?
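A minimal sketch of the suggested fallback, assuming `q`, `k`, `v` are already shaped `(B, num_heads, N, head_dim)` as in the naive path above; the function name and signature here are illustrative, not the repo's actual code:

```python
import torch.nn.functional as F

def sdpa_attention(q, k, v, dropout_p=0.0):
    # scaled_dot_product_attention dispatches to the best available backend
    # (flash, memory-efficient, or math); on the efficient backends it avoids
    # materializing the full N x N attention matrix.
    return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
```

On pre-Ampere GPUs the memory-efficient backend is still selectable, so this should let inference run where the flash-attn package itself is unsupported.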