jbloomAus opened 1 month ago
@chanind @curt-tigges let me know if that doesn't make sense.
I think the seqpos solution will only work when every sequence is the same length, which is the case for OthelloGPT but not for most LLM datasets. For instance, when batching, we might end up with a BOS in the middle of the input rather than only at the front. I think we need to look for BOS tokens and mask them somehow, since there's also no guarantee that a given batch contains the same number of BOS tokens. This is also somewhat problematic for pre-cached datasets, since we don't currently save the special token indices. We may also want to exclude other tokens depending on how training is set up; e.g. if the dataset includes EOS or SEP tokens, we probably want to exclude those too (or at least have the option to).
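A minimal illustration of why a position-based slice misses a mid-context BOS while a token-id mask does not. It assumes documents are packed back to back into fixed-length contexts during batching; the token ids and helper names are hypothetical and not part of any existing activations store API.

```python
from typing import List, Set

# Hypothetical token ids for illustration; real values depend on the tokenizer.
BOS_ID = 1
EOS_ID = 2


def seqpos_slice_keep(tokens: List[int], start: int = 1) -> List[bool]:
    """Keep-mask that drops the first `start` positions, regardless of content."""
    return [pos >= start for pos in range(len(tokens))]


def special_token_keep(tokens: List[int], exclude_ids: Set[int]) -> List[bool]:
    """Keep-mask that drops every position holding an excluded token id."""
    return [tok not in exclude_ids for tok in tokens]


# Two documents packed into one fixed-length context: the second BOS lands mid-sequence.
context = [BOS_ID, 17, 42, EOS_ID, BOS_ID, 99, 5, 8]

by_slice = seqpos_slice_keep(context)
by_token = special_token_keep(context, {BOS_ID})

# Positions of BOS tokens that survive each mask.
print([pos for pos, keep in enumerate(by_slice) if keep and context[pos] == BOS_ID])  # [4]
print([pos for pos, keep in enumerate(by_token) if keep and context[pos] == BOS_ID])  # []
```

The slice keeps the BOS at position 4, while the token-id mask removes it. Passing `{BOS_ID, EOS_ID}` also covers the EOS case mentioned above.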
IMO the most robust solution is to return both activations AND a mask from the activations store / cached activations, and have SAE training exclude masked activations from the loss calculation. In the mask we could define an int value for each type of special token (e.g. 1 = BOS, 2 = SEP, 3 = EOS), and the config can then specify which of these to exclude from training / eval. Sometimes tokenizers use the same token for multiple purposes (e.g. BOS = EOS = SEP); we could have a scheme for this where, e.g., 12 = BOS + SEP, 123 = BOS + EOS + SEP, 13 = BOS + EOS, etc. Maybe I'm overthinking this, but I don't see an easy way to handle it without keeping track of a token mask of some sort.
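A minimal sketch of the activations-plus-mask idea, with the loss skipping excluded positions. It swaps the decimal-digit codes (12, 123, 13) for bit flags via Python's `enum.IntFlag`, so a token shared by BOS, EOS and SEP is just the OR of its roles and exclusion is a bitwise AND. `Special`, `build_role_mask`, and `masked_mse` are hypothetical names, not an existing SAELens API, and plain lists stand in for tensors.

```python
from enum import IntFlag
from typing import Dict, List, Sequence


class Special(IntFlag):
    """Hypothetical bit flags for special-token roles; a token may carry several."""
    NONE = 0
    BOS = 1
    SEP = 2
    EOS = 4


def build_role_mask(tokens: Sequence[int], role_by_id: Dict[int, Special]) -> List[Special]:
    """Map each token to the OR of its special roles (NONE for ordinary tokens)."""
    return [role_by_id.get(tok, Special.NONE) for tok in tokens]


def masked_mse(
    recon: Sequence[Sequence[float]],
    target: Sequence[Sequence[float]],
    roles: Sequence[Special],
    exclude: Special,
) -> float:
    """Mean squared error over positions whose roles share no bit with `exclude`."""
    total, count = 0.0, 0
    for r_vec, t_vec, role in zip(recon, target, roles):
        if role & exclude:
            continue  # excluded special token: contributes nothing to the loss
        total += sum((r - t) ** 2 for r, t in zip(r_vec, t_vec))
        count += len(t_vec)
    return total / count if count else 0.0


# Hypothetical tokenizer where id 0 serves as BOS, SEP and EOS, and id 5 is a plain SEP.
role_by_id = {0: Special.BOS | Special.SEP | Special.EOS, 5: Special.SEP}
tokens = [0, 10, 11, 5, 12]
roles = build_role_mask(tokens, role_by_id)

target = [[0.0, 0.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
recon = [[9.0, 9.0], [1.0, 1.0], [2.0, 1.0], [5.0, 5.0], [1.0, 1.0]]

print(masked_mse(recon, target, roles, Special.BOS))  # 4.125
print(masked_mse(recon, target, roles, Special.BOS | Special.SEP))  # 0.16666666666666666
```

With bit flags, adding another special-token type only needs the next power of two, and the config's exclusion set is itself a single `Special` value.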
Proposal
Ensure that the `activation_store.from_sae` method creates a store that removes BOS activations before adding them to the buffer. Setting the default seqpos slice to ignore the first token on all SAEs seems like a reasonable way to do this.
Motivation
Some SAEs aren't trained on BOS tokens (as is the case with GemmaScope), and this should probably be the default in SAE training (it may also increase performance).
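To make the proposed default concrete, here is a minimal sketch of dropping position 0 from each sequence's activations before they enter the buffer. `fill_buffer` and `DEFAULT_SEQPOS_SLICE` are hypothetical names, not SAELens's actual store code, and nested lists stand in for activation tensors.

```python
from typing import List, Optional, Sequence, Tuple

# Hypothetical default: drop position 0 (the BOS slot) and keep the rest.
DEFAULT_SEQPOS_SLICE: Tuple[Optional[int], Optional[int]] = (1, None)


def fill_buffer(
    buffer: List[List[float]],
    batch_acts: Sequence[Sequence[List[float]]],
    seqpos_slice: Tuple[Optional[int], Optional[int]] = DEFAULT_SEQPOS_SLICE,
) -> None:
    """Append per-token activations to `buffer`, keeping only the sliced positions."""
    keep = slice(*seqpos_slice)
    for seq_acts in batch_acts:  # one entry per sequence in the batch
        buffer.extend(seq_acts[keep])


batch = [
    [[0.0], [1.0], [2.0]],  # sequence 1: position 0 holds BOS
    [[10.0], [11.0], [12.0]],  # sequence 2: position 0 holds BOS
]
buffer: List[List[float]] = []
fill_buffer(buffer, batch)
print(buffer)  # [[1.0], [2.0], [11.0], [12.0]]
```

This only removes BOS when it always sits at position 0, which is the limitation raised in the discussion above for packed or variable-length datasets.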