TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/

[Proposal] Optionally use flash attention. #378

Open tbenthompson opened 1 year ago

tbenthompson commented 1 year ago

It would be nice to have a flag to enable flash attention in models where that would make sense. This is helpful for performance and memory usage in larger models. In my case working with Pythia 12B, I get ~50% better performance and ~4x larger batch sizes when using flash attention. I also find numerical stability in float16 to be better using flash attention, probably because the model was trained using flash attention.

The downside of using flash attention in TransformerLens is that we would not have access to intermediate quantities in the attention calculation like the attention matrix itself. This is why I would suggest having a default-off flag so that users can choose whether they need those intermediate values to be available. In addition, when only a small subset of attention intermediates are needed, it's much faster to just cache the input to the attention layer (or the residual stream) and then recompute those intermediates when needed.
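For concreteness, here is a minimal sketch of what such a default-off flag could look like inside an attention forward pass. The `use_flash_attention` argument, the tensor layout, and the surrounding shapes are hypothetical rather than TransformerLens's actual API; the fused path relies on PyTorch's `scaled_dot_product_attention`:

```python
import torch
import torch.nn.functional as F

def attention_forward(q, k, v, use_flash_attention: bool = False, causal: bool = True):
    """q, k, v: [batch, n_heads, seq, d_head] (hypothetical layout)."""
    if use_flash_attention:
        # Fused kernel: faster and more memory-efficient, but the attention
        # pattern is never materialised, so it can't be cached or hooked.
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

    # Explicit path: materialises scores and pattern so hooks can see them.
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    if causal:
        mask = torch.triu(
            torch.ones(scores.shape[-2:], dtype=torch.bool, device=scores.device),
            diagonal=1,
        )
        scores = scores.masked_fill(mask, float("-inf"))
    pattern = scores.softmax(dim=-1)
    return pattern @ v
```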

Thanks!

neelnanda-io commented 1 year ago

Seems reasonable to me, I'd be happy for someone to add this


alan-cooney commented 11 months ago

Seems very useful for sparse autoencoder training.

Docs are here, in case anyone wants to take this: https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#conclusion (I'll pick it up at some point if no one does).
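For reference, the core of that tutorial is `torch.nn.functional.scaled_dot_product_attention` plus a context manager that restricts which backend is used. A minimal sketch (shapes are arbitrary; newer PyTorch versions replace `torch.backends.cuda.sdp_kernel` with `torch.nn.attention.sdpa_kernel`):

```python
import torch
import torch.nn.functional as F

# Arbitrary Q/K/V in half precision on GPU; the flash kernel needs fp16/bf16 on CUDA.
q, k, v = (torch.randn(1, 8, 1024, 64, dtype=torch.float16, device="cuda") for _ in range(3))

# Allow only the flash backend, so it errors out rather than silently falling back.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```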

cmathw commented 8 months ago

I'd be quite keen to make a start on this soon. @alan-cooney, have you made a start already?

alan-cooney commented 8 months ago


I haven't yet so please feel free to!