OpenMOSS / Language-Model-SAEs

For OpenMOSS Mechanistic Interpretability Team's Sparse Autoencoder (SAE) research.

Accelerate Inference in TransformerLens #26

Closed: StarConnor closed this 2 months ago

StarConnor commented 3 months ago
  1. Add a `use_flash_attn` option when loading a HookedTransformer model (a sketch of points 1 and 2 is given after this list).
  2. Add FlashAttentionV2 support in TransformerLens/transformer_lens/components/abstract_attention.py: https://github.com/OpenMOSS/Language-Model-SAEs/blob/a12cc220ad0e3e8afef158dd01809ee454c73e68/TransformerLens/transformer_lens/components/abstract_attention.py#L218-L244
  3. Add tests of flash attention correctness in TransformerLens/tests/integration/test_flash_attn.py. To explain the pass condition, define: a = activation(TransformerLens with flash attention), b = activation(TransformerLens without flash attention), a' = activation(HuggingFace with flash attention), b' = activation(HuggingFace without flash attention), error_tl = max(|a - b|) and error_hf = max(|a' - b'|), computed for the attention, MLP, and residual stream activations in every layer. The test passes if error_tl < error_hf * 5. In practice error_tl is sometimes even smaller than error_hf, so the factor of 5 is not overly loose. A sketch of this check follows the list.
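
A minimal sketch of how points 1 and 2 fit together, assuming the `use_flash_attn` flag is exposed through `HookedTransformer.from_pretrained` and that the attention component dispatches to `flash_attn_func` from the `flash-attn` package; the exact wiring in the linked code may differ:

```python
import torch
from transformer_lens import HookedTransformer
from flash_attn import flash_attn_func  # FlashAttention-2 kernel

# Point 1: opt into flash attention at load time (flag name per this issue; whether it
# is a from_pretrained kwarg is an assumption). FlashAttention-2 needs fp16/bf16 inputs.
model = HookedTransformer.from_pretrained(
    "gpt2", use_flash_attn=True, dtype=torch.bfloat16
)

# Point 2: inside abstract_attention.py, the explicit softmax(QK^T / sqrt(d))V path is
# replaced by the fused kernel. q, k, v: [batch, seq_len, n_heads, d_head].
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    return flash_attn_func(q, k, v, causal=True)
```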
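And a self-contained sketch of the pass criterion from point 3 (illustrative names, not the actual test file): the flash/no-flash discrepancy in TransformerLens must stay within 5x the discrepancy that the HuggingFace implementation itself shows.

```python
import torch

def max_abs_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    """max(|x - y|) over a single activation tensor."""
    return (x - y).abs().max().item()

def flash_attn_within_tolerance(
    a: torch.Tensor,        # TransformerLens activation, with flash attention
    b: torch.Tensor,        # TransformerLens activation, without flash attention
    a_prime: torch.Tensor,  # HuggingFace activation, with flash attention
    b_prime: torch.Tensor,  # HuggingFace activation, without flash attention
    factor: float = 5.0,
) -> bool:
    error_tl = max_abs_diff(a, b)
    error_hf = max_abs_diff(a_prime, b_prime)
    return error_tl < factor * error_hf
```

The check would be repeated for the attention, MLP, and residual stream activations in every layer.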
StarConnor commented 2 months ago

Moved the test to TransformerLens/tests/integration/test_flash_attn.py and tested it with a toy attention model.