pytorch / text

Models, data loaders and abstractions for language processing, powered by PyTorch
https://pytorch.org/text
BSD 3-Clause "New" or "Revised" License

Make use of SDPA in TorchText T5 Models #2135

Open joecummings opened 1 year ago

yohann-benchetrit commented 1 year ago

Scaled dot-product attention (SDPA) in the TorchText T5 models does not currently benefit from memory-optimised fused kernels such as FlashAttention or Memory-Efficient Attention.

PyTorch provides a memory-efficient implementation of SDPA (`torch.nn.functional.scaled_dot_product_attention`), but its fused CUDA kernels do not yet support arbitrary attention masks. T5 attention, however, adds a learned relative positional bias to the attention scores, typically passed as an (arbitrary) additive component of the attention mask (see Issue 96099), so we cannot currently use the fused kernels to optimise the memory usage of T5 attention. A sketch of the constraint follows below.
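To make this concrete, here is a minimal sketch (shapes and tensors are made up for illustration, not taken from the torchtext T5 code) of passing a T5-style relative-position bias as the additive `attn_mask` of `torch.nn.functional.scaled_dot_product_attention`. With such a mask, SDPA is expected to skip the fused backends and fall back to the unfused math implementation:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only -- not the actual torchtext T5 configuration.
batch, n_heads, seq_len, head_dim = 2, 8, 128, 64
device, dtype = "cuda", torch.float16

q = torch.randn(batch, n_heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Stand-in for T5's learned relative-position bias: an arbitrary float tensor
# that SDPA accepts as an additive attn_mask of shape (batch, n_heads, L, S).
relative_position_bias = torch.randn(
    batch, n_heads, seq_len, seq_len, device=device, dtype=dtype
)

# With an arbitrary additive mask like this, the fused kernels are skipped and
# SDPA falls back to the unfused "math" implementation.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=relative_position_bias)

# Restricting SDPA to the fused backends makes the limitation explicit: with
# the math fallback disabled, this call should raise a RuntimeError because no
# enabled backend supports the arbitrary attn_mask.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_mem_efficient=True, enable_math=False
):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=relative_position_bias)
```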

There are plans to add support for arbitrary attention masks to PyTorch's memory-efficient SDPA kernel. Once that lands, we could route T5 attention through it to enable memory-efficient attention, along the lines of the sketch below.
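A rough sketch of what that could look like, reusing the tensors from the snippet above and assuming the existing `torch.backends.cuda.sdp_kernel` context manager:

```python
# Prefer the memory-efficient kernel once it accepts arbitrary additive masks
# (q, k, v and relative_position_bias as defined in the sketch above).
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_mem_efficient=True, enable_math=False
):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=relative_position_bias)
```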