Add tests of flash attention correctness in TransformerLens/tests/integration/test_flash_attn.py. To explain how the test passes, I will give some definitions:

a = activation(tl, w/ flash_attn), b = activation(tl, w/o flash_attn)
a' = activation(hf, w/ flash_attn), b' = activation(hf, w/o flash_attn)
error_tl = max(|a - b|) and error_hf = max(|a' - b'|), computed over the attention, MLP, and residual stream activations in every layer.

The test passes if error_tl < error_hf * 5. In practice, error_tl is sometimes even smaller than error_hf, so the factor of 5 is not overly loose.
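The pass criterion above can be sketched as follows. This is a minimal illustration with toy activation vectors, not the actual test fixtures; in the real test, a, b, a', and b' come from hooked forward passes over every layer.

```python
# Sketch of the pass criterion. The activation values below are toy numbers;
# the real test compares attention, MLP, and residual stream activations
# gathered from forward passes with and without flash attention.

def max_abs_diff(xs, ys):
    """max(|x - y|) over two equal-length flat activation vectors."""
    return max(abs(x - y) for x, y in zip(xs, ys))

a  = [0.101, -0.502, 0.303]   # TransformerLens, with flash attention
b  = [0.100, -0.500, 0.300]   # TransformerLens, without flash attention
ap = [0.102, -0.503, 0.301]   # HuggingFace, with flash attention
bp = [0.100, -0.500, 0.300]   # HuggingFace, without flash attention

error_tl = max_abs_diff(a, b)
error_hf = max_abs_diff(ap, bp)

# Pass criterion: the flash-attention discrepancy in TransformerLens must stay
# within 5x the discrepancy HuggingFace itself shows.
assert error_tl < error_hf * 5
```

The factor of 5 acknowledges that flash attention is not bitwise-identical to standard attention even within HuggingFace itself, so the reference error is measured the same way on the HuggingFace side.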
Add a use_flash_attn option when loading a HookedTransformer model. The flash attention integration lives in TransformerLens/transformer_lens/components/abstract_attention.py: https://github.com/OpenMOSS/Language-Model-SAEs/blob/a12cc220ad0e3e8afef158dd01809ee454c73e68/TransformerLens/transformer_lens/components/abstract_attention.py#L218-L244