**zzhhjjj** opened 5 months ago
We can automatically disable FlashAttention on hardware that doesn't support it (pre-Ampere GPUs):
```python
import torch


def supports_flash_attention(device_id: int) -> bool:
    """Check if a GPU supports FlashAttention."""
    major, minor = torch.cuda.get_device_capability(device_id)
    # FlashAttention requires Ampere/Ada (SM 8.x) or Hopper (SM 9.0).
    is_sm8x = major == 8
    is_sm90 = major == 9 and minor == 0
    return is_sm8x or is_sm90
```
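A hedged sketch of how this check could gate the attention backend at model-construction time; `attn_implementation` and the `"sdpa"` fallback are illustrative names borrowed from the transformers convention, not necessarily this repo's actual config:

```python
import torch

# Illustrative wiring (assumed names): pick FlashAttention when the
# current device supports it, otherwise fall back to PyTorch SDPA.
if torch.cuda.is_available() and supports_flash_attention(torch.cuda.current_device()):
    attn_implementation = "flash_attention_2"
else:
    attn_implementation = "sdpa"
print(f"Using attention implementation: {attn_implementation}")
```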
Also: add an end-to-end test for Llama.
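For the end-to-end test, a minimal pytest-style sketch; the `hf-internal-testing/tiny-random-LlamaForCausalLM` checkpoint and the transformers API are assumptions for illustration, and the real test should exercise this repo's own Llama entry points:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def test_llama_end_to_end():
    # Tiny random Llama checkpoint commonly used for smoke tests;
    # small enough to run on CPU in CI.
    model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    inputs = tokenizer("Hello, world", return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=8)
    # Generation should append new tokens after the prompt.
    assert output.shape[-1] > inputs["input_ids"].shape[-1]
```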