jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

[Proposal] BOS should be ignored by default in the activation store for SAEs not trained on the BOS (and future SAEs trained in SAE Lens) #333

Open jbloomAus opened 1 month ago

jbloomAus commented 1 month ago

Proposal

Ensure that the activation_store.from_sae method creates a store that removes BOS activations before adding them to the buffer. Setting the default seqpos slice to ignore the first token for all SAEs seems like a reasonable way to do this.
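
For concreteness, a minimal sketch of what that default could do when filling the buffer (the function name and the `[batch, seq_len, d_in]` shape are assumptions for illustration, not existing SAELens code):

```python
import torch

def apply_default_seqpos_slice(activations: torch.Tensor) -> torch.Tensor:
    """Drop the first sequence position (the BOS activation) from a
    [batch, seq_len, d_in] activation tensor before it goes into the
    activation store buffer -- equivalent to a seqpos slice of (1, None)."""
    return activations[:, 1:, :]
```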

Motivation

Some SAEs aren't trained on BOS tokens (as is the case with GemmaScope), and this should probably be the default for SAE training as well (it may also increase performance).

Checklist

jbloomAus commented 1 month ago

@chanind @curt-tigges let me know if that doesn't make sense.

chanind commented 1 month ago

I think the seqpos solution will only work for cases where every sequence is the same length, which is the case for OthelloGPT but isn't the case for most LLM datasets. For instance, when batching, we might end up with a BOS in the middle of the input rather than only at the front. I think we need to look for BOS tokens and mask them somehow, since there's also no guarantee that a given batch contains the same number of BOS tokens.

This is also somewhat problematic for pre-cached datasets, since we don't currently save the special token indices. We may also want to exclude other tokens depending on how the training is set up, e.g. if the dataset includes EOS tokens or SEP tokens we probably want to exclude those too (or at least want the option to).
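
Roughly what I mean by looking for BOS tokens and masking them, given a batch of token ids (the helper names here are just illustrative, not existing SAELens APIs):

```python
import torch

def bos_mask(tokens: torch.Tensor, bos_token_id: int) -> torch.Tensor:
    """True at every position holding a BOS token -- including mid-sequence
    BOS tokens introduced when documents are packed together during batching."""
    return tokens == bos_token_id

def drop_bos_activations(
    activations: torch.Tensor,  # [batch, seq_len, d_in]
    tokens: torch.Tensor,       # [batch, seq_len]
    bos_token_id: int,
) -> torch.Tensor:
    """Flatten the batch/sequence dims and keep only non-BOS activations."""
    keep = ~bos_mask(tokens, bos_token_id)
    return activations[keep]    # [n_kept, d_in]
```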

IMO the most robust solution is to return both activations AND a mask from the activations store / cached activations and have the SAE training exclude masked activations from loss calculation. In the mask we could define an int value for each type of special character (e.g. 1 = BOS, 2 = SEP, 3 = EOS) and then the config can specify which of these to exclude from training / eval. Sometimes tokenizers have the same token for multiple purposes (e.g. BOS = EOS = SEP) - we could have a scheme for this where, e.g. 12 = BOS + SEP, 123 = BOS + EOS + SEP, 13 = BOS + EOS, etc... Maybe I'm overthinking this, but I don't see an easy way to handle this without keep track of a token mask of some sort.