Extend the naive flash-attn notebook to allow block-wise processing of only a fraction of the blocks at a time, i.e. pass in and return the state required to continue updating the outputs for the current queries (e.g. the block max, the current sum, etc.).
With the new function, create a small test showing that all values from the split processing are "allclose()" to the result of classic dot-product attention (see naive_attn() in the notebook linked above).
Store the generated ipynb file in the notebooks folder of this repo.
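
As a rough starting point, the resumable update could look like the sketch below. It is not the notebook's code: it assumes NumPy rather than whatever tensor library the notebook uses, it re-implements naive_attn() locally as a stand-in for the notebook's version, and the names init_state(), attn_update() and finalize() are hypothetical. The state carried between calls is the running row max, the running softmax denominator, and the unnormalised output.

```python
import numpy as np

def naive_attn(q, k, v):
    """Reference dot-product attention (stand-in for the notebook's naive_attn())."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def init_state(n_q, d_v):
    """Hypothetical helper: empty state = (running max, running sum, unnormalised output)."""
    return (np.full((n_q, 1), -np.inf), np.zeros((n_q, 1)), np.zeros((n_q, d_v)))

def attn_update(q, k_blk, v_blk, state):
    """Fold one key/value block into the state and return the updated state."""
    m, l, acc = state
    s = q @ k_blk.T / np.sqrt(q.shape[-1])
    m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
    scale = np.exp(m - m_new)            # rescales earlier contributions to the new max
    p = np.exp(s - m_new)
    l_new = l * scale + p.sum(axis=-1, keepdims=True)
    acc_new = acc * scale + p @ v_blk
    return m_new, l_new, acc_new

def finalize(state):
    """Normalise the accumulated output by the running sum."""
    _, l, acc = state
    return acc / l

# Split processing: handle the first half of the blocks, stop, then resume later.
rng = np.random.default_rng(0)
q = rng.standard_normal((8, 16))
k = rng.standard_normal((32, 16))
v = rng.standard_normal((32, 16))
block = 4

state = init_state(q.shape[0], v.shape[1])
for start in range(0, 16, block):        # first call: blocks 0..3
    state = attn_update(q, k[start:start + block], v[start:start + block], state)
for start in range(16, 32, block):       # later call: blocks 4..7, reusing the state
    state = attn_update(q, k[start:start + block], v[start:start + block], state)

out = finalize(state)
assert np.allclose(out, naive_attn(q, k, v))
```

Keeping the output unnormalised until finalize() means each update only needs one rescale by exp(m - m_new); normalising after every block would also work but requires carrying the old sum into the rescale.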