microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

Swapped naive dot product attention for flash attention #24

Open usryokousha opened 1 year ago

usryokousha commented 1 year ago

This pull request adds support for the Flash Attention mechanism to the MultiheadAttention module. Flash Attention is a recently proposed, IO-aware implementation of exact attention that reduces memory usage and improves training efficiency compared with the conventional dot-product attention computation. The implementation in this pull request follows the paper "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (https://arxiv.org/abs/2205.14135).
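For illustration only (this is not the actual diff in this PR): a minimal sketch of the kind of swap described above, using PyTorch's `torch.nn.functional.scaled_dot_product_attention`, which dispatches to a FlashAttention-style fused kernel when the inputs and hardware allow it. The shapes and helper names here are assumptions for the example, not code from torchscale.

```python
import torch
import torch.nn.functional as F

def naive_attention(q, k, v, dropout_p=0.0):
    # Conventional scaled dot-product attention: materializes the full
    # (seq_len x seq_len) attention matrix, which dominates memory use.
    scale = q.shape[-1] ** -0.5
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    attn = F.dropout(attn, p=dropout_p)
    return attn @ v

def fused_attention(q, k, v, dropout_p=0.0):
    # Fused path: PyTorch may dispatch to a FlashAttention kernel here,
    # computing the same result without storing the full attention matrix.
    return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

# Both functions expect (batch, num_heads, seq_len, head_dim) tensors.
q = k = v = torch.randn(2, 8, 128, 64)
assert torch.allclose(naive_attention(q, k, v), fused_attention(q, k, v), atol=1e-5)
```

Because FlashAttention computes exact attention, the two paths should agree numerically (up to floating-point tolerance); the benefit is memory and throughput, not a change in the model's outputs.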

Changes Made:

usryokousha commented 1 year ago

@microsoft-github-policy-service agree

mranzinger commented 1 year ago

I ran into some issues using this branch as-is, and created a pull request for it here: https://github.com/usryokousha/torchscale/pull/1

Please review and pull in, if applicable.

usryokousha commented 1 year ago

Please merge with master
