In the current code, if Flash Attention is not available, models.layers.blocks.Attention falls back to naive attention, which materializes the full attention matrix and OOMs immediately even for modest generation lengths.
Could we fall back to torch.nn.functional.scaled_dot_product_attention instead, so that pre-Ampere users can run inference as well?
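A rough sketch of what the fallback could look like. The tensor layout, the dispatch flag, and the function name attention_forward are assumptions for illustration, not the module's actual internals:

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False


def attention_forward(q, k, v, causal=True):
    """Use Flash Attention when available, otherwise fall back to the
    memory-efficient SDPA kernel instead of naive attention.

    q, k, v: (batch, seq_len, n_heads, head_dim) -- layout assumed here;
    the real Attention module may use a different one.
    """
    if HAS_FLASH_ATTN:
        # flash_attn_func expects (batch, seq_len, n_heads, head_dim)
        return flash_attn_func(q, k, v, causal=causal)

    # scaled_dot_product_attention expects (batch, n_heads, seq_len, head_dim),
    # so transpose into and back out of that layout.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return out.transpose(1, 2)
```

On pre-Ampere GPUs (or CPU), SDPA picks a memory-efficient or math backend automatically, so the attention matrix is never materialized in full.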