cuda-mode / ring-attention

ring-attention experiments
Apache License 2.0

Extend educational naive flash-attn impl to allow partial kv-block processing (create naive ring-attn) #4

Closed andreaskoepf closed 6 months ago

andreaskoepf commented 7 months ago

Extend the naive flash-attn notebook to allow block-wise processing of only a fraction of the blocks at a time, i.e. pass in and out the state required to continue updating the outputs for the current queries (e.g. store the block max, the current sum, etc.).

With the new function, create a small test showing that all values of the split processing are allclose() to the same computation done with classic dot-product attention (see naive_attn() in the notebook linked above).

Store the generated ipynb file in the notebooks folder of this repo.
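A minimal NumPy sketch of what such a state-carrying block update could look like, assuming the usual online-softmax recurrence (running max, running sum, unnormalized output); the names attn_block_update and finalize are hypothetical and not taken from the notebook:

```python
import numpy as np

def naive_attn(q, k, v):
    # classic O(n^2) dot-product attention, used as the reference
    scores = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v

def attn_block_update(q, k_block, v_block, state=None):
    # process one kv block; state = (running max m, running sum l,
    # unnormalized output o), passed in and out between calls
    scores = q @ k_block.T / np.sqrt(q.shape[-1])
    if state is None:
        m = np.full((q.shape[0], 1), -np.inf)
        l = np.zeros((q.shape[0], 1))
        o = np.zeros((q.shape[0], v_block.shape[-1]))
    else:
        m, l, o = state
    m_new = np.maximum(m, scores.max(axis=-1, keepdims=True))
    p = np.exp(scores - m_new)
    corr = np.exp(m - m_new)  # rescales old sum/output to the new max
    l_new = corr * l + p.sum(axis=-1, keepdims=True)
    o_new = corr * o + p @ v_block
    return m_new, l_new, o_new

def finalize(state):
    # divide the accumulated output by the softmax denominator
    m, l, o = state
    return o / l

# feed the kv blocks one at a time and compare with the reference
rng = np.random.default_rng(0)
q = rng.standard_normal((8, 16))
k = rng.standard_normal((32, 16))
v = rng.standard_normal((32, 16))
state = None
for i in range(0, k.shape[0], 8):
    state = attn_block_update(q, k[i:i + 8], v[i:i + 8], state)
assert np.allclose(finalize(state), naive_attn(q, k, v))
```

Because each call rescales the previous partial sums by exp(m - m_new), the blocks can arrive in any order and the final result still matches the full softmax.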

lancerts commented 7 months ago

Did a more basic implementation of the flash-attn forward pass based on the paper, with annotations. The notebook is for educational purposes.

lancerts commented 7 months ago

Added an incremental version in cell 5 and below that iterates along the k and v blocks.