FlashAttention is an approximation-free, fast, and memory-efficient implementation of multi-head attention. In particular, it reduces the memory requirement from $\mathcal{O}(n^2)$ to $\mathcal{O}(n)$ in the sequence length and would therefore be highly beneficial for practitioners with limited GPU memory.
Are there any plans to implement or integrate it for the ESM models?
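For reference, a minimal sketch of the kind of drop-in usage I have in mind, assuming PyTorch >= 2.0 and a CUDA GPU: `torch.nn.functional.scaled_dot_product_attention` can dispatch to a fused FlashAttention kernel, so the attention core of an ESM layer could potentially be routed through it without changing model weights. The tensor shapes below are illustrative only, not the actual ESM internals.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes for an ESM-2-like layer: (batch, heads, seq_len, head_dim).
# These are placeholder values, not taken from the ESM codebase.
q = torch.randn(1, 20, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 20, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 20, 1024, 64, device="cuda", dtype=torch.float16)

# Dispatches to a fused FlashAttention-style kernel when one is available,
# avoiding materialization of the full (seq_len x seq_len) attention matrix.
out = F.scaled_dot_product_attention(q, k, v)
```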