Scaled dot-product attention (SDPA) for T5 does not currently use memory optimisations, such as FlashAttention or Memory-Efficient Attention.
PyTorch does provide a memory-efficient implementation of SDPA, but its fused CUDA kernels do not yet support arbitrary attention masks, so it cannot currently be used to optimise the memory usage of T5 attention. The limitation matters because T5 relies on relative positional embeddings, which are implemented as an arbitrary additive bias folded into the attention mask (see Issue 96099).
There are plans for a memory-efficient PyTorch SDPA that works with arbitrary attention masks. Once available, it could be used to enable memory-efficient attention for T5.
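To make the limitation concrete, here is a minimal sketch (with hypothetical shapes) of what T5-style attention asks of SDPA: an arbitrary additive float bias passed as `attn_mask`. With such a mask, PyTorch cannot dispatch to the fused memory-efficient kernels and falls back to the math implementation, as the reference computation below confirms.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration only.
batch, heads, seq, dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)

# A T5-style relative position bias: an arbitrary additive float term
# with one value per (head, query position, key position).
position_bias = torch.randn(1, heads, seq, seq)

# SDPA accepts it as an additive float attn_mask. On arbitrary masks
# like this, the fused memory-efficient CUDA kernels cannot be used.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=position_bias)

# Reference: explicit attention with the bias added to the scaled scores.
scores = q @ k.transpose(-2, -1) / (dim ** 0.5) + position_bias
ref = torch.softmax(scores, dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-5)
```

The bias cannot be expressed as a boolean or causal mask, which is why T5 attention needs SDPA to accept fully general additive masks before it can benefit from the memory-efficient path.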