I am wondering what the best way is to use efficient implementations of attention. PyTorch provides the experimental torch.nn.functional.scaled_dot_product_attention (SDPA), which supports three different implementations, including flash attention. Unfortunately, we cannot use flash attention because it doesn't yet support arbitrary attention masks, which are critical for Chronos, and it's not clear when mask support will be added (see https://github.com/Dao-AILab/flash-attention/issues/840). In the meantime, SDPA falls back to another efficient implementation when a mask is provided.
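For reference, here is a minimal sketch (the shapes and mask below are illustrative, not the benchmark configuration) of calling SDPA with an arbitrary boolean mask; with attn_mask set, PyTorch skips the flash backend and dispatches to the memory-efficient or math kernel instead:

```python
# Minimal illustration: SDPA with an arbitrary (non-causal) boolean mask.
# Shapes are made up for the example.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
batch, heads, seq_len, head_dim = 2, 8, 512, 64

q = torch.randn(batch, heads, seq_len, head_dim, device=device)
k = torch.randn(batch, heads, seq_len, head_dim, device=device)
v = torch.randn(batch, heads, seq_len, head_dim, device=device)

# True = attend, False = ignore; broadcast over heads.
attn_mask = torch.rand(batch, 1, seq_len, seq_len, device=device) > 0.1

# Flash attention is skipped when attn_mask is not None; SDPA falls back
# to the memory-efficient (or math) kernel.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 8, 512, 64])
```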
I monkey-patched the T5Attention implementation in transformers with an SDPA-based version; here are the results (script below).
Results
TL;DR: SDPA is clearly faster than the implementation in transformers that we're currently using, even without flash attention.
V100 (float32)
Note: V100 doesn't support bfloat16, so SDPA won't work with bf16 there because the fused kernels don't exist for that hardware.
Using transformers (current version):
Using SDPA:
A100 (float32)
Using transformers (current version):
Using SDPA:
A100 (bfloat16)
Using transformers (current version):
Using SDPA:
Script
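A rough sketch of the kind of patch involved (simplified: it ignores past key/value caching, head pruning, and output_attentions, and assumes PyTorch >= 2.1 for the scale argument) replaces T5Attention.forward with an SDPA-based version:

```python
# Sketch of the monkey patch (simplified): swap T5Attention's explicit
# softmax(QK^T + bias)V for F.scaled_dot_product_attention, passing the
# relative position bias (with the attention mask folded in) as a float
# attn_mask. Ignores past_key_value caching, head pruning, and
# output_attentions; assumes PyTorch >= 2.1 for the `scale` argument.
import torch
import torch.nn.functional as F
from transformers.models.t5.modeling_t5 import T5Attention


def sdpa_forward(self, hidden_states, mask=None, key_value_states=None,
                 position_bias=None, **kwargs):
    batch_size, seq_length = hidden_states.shape[:2]

    def shape(states):
        # (batch, seq, d_model) -> (batch, n_heads, seq, key_value_proj_dim)
        return states.view(batch_size, -1, self.n_heads,
                           self.key_value_proj_dim).transpose(1, 2)

    # Self-attention uses hidden_states for K/V; cross-attention uses
    # the encoder's key_value_states.
    kv_input = hidden_states if key_value_states is None else key_value_states
    query_states = shape(self.q(hidden_states))
    key_states = shape(self.k(kv_input))
    value_states = shape(self.v(kv_input))
    key_length = key_states.shape[2]

    if position_bias is None:
        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(
                seq_length, key_length, device=hidden_states.device)
        else:
            position_bias = torch.zeros(
                (1, self.n_heads, seq_length, key_length),
                device=hidden_states.device, dtype=query_states.dtype)
        if mask is not None:
            # `mask` is already an additive float mask (0 / large negative).
            position_bias = position_bias + mask

    # T5 does not scale attention scores by 1/sqrt(d), hence scale=1.0.
    attn_output = F.scaled_dot_product_attention(
        query_states, key_states, value_states,
        attn_mask=position_bias,
        dropout_p=self.dropout if self.training else 0.0,
        scale=1.0,
    )
    attn_output = attn_output.transpose(1, 2).contiguous().view(
        batch_size, -1, self.inner_dim)
    attn_output = self.o(attn_output)

    # Matches T5Attention's (attn_output, present_key_value, position_bias)
    # output layout, minus caching.
    return (attn_output, None, position_bias)


T5Attention.forward = sdpa_forward
```

Passing the position bias as a floating-point attn_mask (rather than adding it to the scores manually) is what lets SDPA use a fused kernel; as noted above, it also rules out the flash backend, so the memory-efficient implementation is selected.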