If you are submitting a bug report, please fill in the following details and use the tag [bug].
Describe the bug
Folding the activation-norm scaling factor at the end of SAE training is implemented incorrectly: an SAE loaded after training does not achieve the same reconstruction quality it had during training.
Link to code. The cause is that the folding code rescales the weights but not the decoder bias `b_dec`.
Fix
This can be fixed by adding one line to the folding code:

```python
def fold_activation_norm_scaling_factor(
    self, activation_norm_scaling_factor: float
):
    self.W_enc.data = self.W_enc.data * activation_norm_scaling_factor
    self.W_dec.data = self.W_dec.data / activation_norm_scaling_factor
    # New line: previously we weren't rescaling the decoder bias.
    self.b_dec.data = self.b_dec.data / activation_norm_scaling_factor
    # Once the factor is folded in, activations no longer need scaling.
    self.cfg.normalize_activations = "none"
```
Additional context
I have checked the math, and this is the correct way to do the folding whether or not we subtract the decoder bias from the input. Applying this division post hoc also fixed my issues with loading trained SAEs.
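To illustrate the math, here is a minimal NumPy sketch (toy dimensions and a made-up scaling factor, not the SAELens implementation): during training the SAE sees activations scaled by `k`, and folding `k` into the weights only reproduces the training-time reconstruction if `b_dec` is divided by `k` as well.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_sae = 8, 16
k = 3.7  # hypothetical activation_norm_scaling_factor

W_enc = rng.normal(size=(d_in, d_sae))
b_enc = rng.normal(size=d_sae)
W_dec = rng.normal(size=(d_sae, d_in))
b_dec = rng.normal(size=d_in)

def sae(x, W_enc, b_enc, W_dec, b_dec):
    # Standard SAE forward pass: ReLU encoder, affine decoder.
    feats = np.maximum(x @ W_enc + b_enc, 0.0)
    return feats @ W_dec + b_dec

x = rng.normal(size=d_in)

# During training the SAE sees x * k, so its reconstruction lives in the
# scaled space; dividing by k maps it back to the original activation space.
recon_training = sae(x * k, W_enc, b_enc, W_dec, b_dec) / k

# Correct folding: scale W_enc by k, divide W_dec AND b_dec by k.
recon_folded = sae(x, W_enc * k, b_enc, W_dec / k, b_dec / k)
assert np.allclose(recon_training, recon_folded)

# Buggy folding (b_dec left unscaled) gives a different reconstruction.
recon_buggy = sae(x, W_enc * k, b_enc, W_dec / k, b_dec)
assert not np.allclose(recon_training, recon_buggy)
```

The encoder is unaffected because `(x * k) @ W_enc = x @ (k * W_enc)`; the whole decoder output `feats @ W_dec + b_dec` must be divided by `k`, which is why both `W_dec` and `b_dec` need rescaling.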
Checklist
- [x] I have checked that there is no similar issue in the repo (required)