jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

[Bug Report] fold_activation_norm_scaling_factor missing division of decoder bias #354

Closed tuomaso closed 3 weeks ago

tuomaso commented 3 weeks ago

If you are submitting a bug report, please fill in the following details and use the tag [bug].

Describe the bug
Folding the activation norm scaling factor at the end of SAE training is implemented incorrectly: an SAE loaded after training does not give the same reconstructions as it did during training. Link to code. The cause is that the folding does not rescale the decoder bias.

Fix
This can be fixed by adding the following line to the folding code:

def fold_activation_norm_scaling_factor(
    self, activation_norm_scaling_factor: float
):
    self.W_enc.data = self.W_enc.data * activation_norm_scaling_factor
    # previously weren't doing this.
    self.W_dec.data = self.W_dec.data / activation_norm_scaling_factor
    # This is my new line: also rescale the decoder bias.
    self.b_dec.data = self.b_dec.data / activation_norm_scaling_factor
    # once we normalize, we shouldn't need to scale activations.
    self.cfg.normalize_activations = "none"
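
As a sanity check, here is a minimal numerical sketch of why the extra division is needed. This is not the SAELens API: sae_forward, the tensor shapes, and the ReLU architecture with b_dec subtraction are assumptions made purely for illustration.

import torch

d_in, d_sae = 16, 64
scale = 0.37  # stands in for activation_norm_scaling_factor

W_enc = torch.randn(d_in, d_sae)
W_dec = torch.randn(d_sae, d_in)
b_enc = torch.randn(d_sae)
b_dec = torch.randn(d_in)

def sae_forward(x, W_enc, W_dec, b_enc, b_dec):
    # toy SAE: subtract the decoder bias, encode with ReLU, decode
    feats = torch.relu((x - b_dec) @ W_enc + b_enc)
    return feats @ W_dec + b_dec

x = torch.randn(8, d_in)  # raw (unscaled) model activations

# during training: inputs are scaled, reconstructions are scaled back down
recon_training = sae_forward(scale * x, W_enc, W_dec, b_enc, b_dec) / scale

# after folding (with the b_dec fix): raw inputs, rescaled parameters
recon_folded = sae_forward(x, W_enc * scale, W_dec / scale, b_enc, b_dec / scale)

assert torch.allclose(recon_training, recon_folded, atol=1e-5)

Dropping the b_dec / scale term makes the assert fail, which matches the reconstruction mismatch described above.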

Additional context
I have checked the math and this is the right way to do the folding, whether or not the decoder bias is subtracted from the input before encoding. Doing this division post hoc also fixed my issues with loading trained SAEs.
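
For reference, here is the algebra, in notation assumed for this sketch: write $c$ for activation_norm_scaling_factor, $f(x) = \mathrm{ReLU}\big((x - b_{dec})\, W_{enc} + b_{enc}\big)$ for the encoder and $\hat{x}(x) = f(x)\, W_{dec} + b_{dec}$ for the decoder, with the training-time SAE applied to $c\,x$ and its output divided by $c$. With the folded parameters $c\,W_{enc}$, $W_{dec}/c$, $b_{dec}/c$:

$$
f'(x) = \mathrm{ReLU}\Big(\big(x - \tfrac{b_{dec}}{c}\big)\, c\, W_{enc} + b_{enc}\Big)
      = \mathrm{ReLU}\big((c\,x - b_{dec})\, W_{enc} + b_{enc}\big) = f(c\,x),
$$

$$
\hat{x}'(x) = f'(x)\, \frac{W_{dec}}{c} + \frac{b_{dec}}{c}
            = \frac{f(c\,x)\, W_{dec} + b_{dec}}{c} = \frac{\hat{x}(c\,x)}{c},
$$

so the folded SAE applied to raw activations reproduces exactly the training-time reconstruction mapped back to the unscaled space. If the encoder does not subtract $b_{dec}$, the first identity holds trivially and the second is unchanged, so the same folding is correct in that case as well.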

Checklist