jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

[Bug Report] fold_activation_norm_scaling_factor should be performed before saving all checkpoints, not only the final checkpoint #381

chanind commented 2 days ago

If using normalize_activations = expected_average_only_in, we scale activations by a scaling factor while training the SAE. Without calling fold_activation_norm_scaling_factor(), the SAE will not work properly on unscaled activations. Since we do not save the scaling factor along with the SAE, a saved SAE is unusable unless the scaling factor was folded in before saving.
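For context, here is a minimal sketch of what the folding amounts to, assuming the usual encoder/decoder parameters; the attribute names are illustrative and may not match SAELens's implementation exactly:

```python
import torch

@torch.no_grad()
def fold_scaling_factor(sae, scaling_factor: float) -> None:
    # The SAE was trained on activations multiplied by scaling_factor,
    # so absorbing the factor into W_enc makes encode(raw_x) match
    # encode(raw_x * scaling_factor) on the unfolded SAE.
    sae.W_enc.data *= scaling_factor
    # Dividing the factor out of the decoder weights and bias maps
    # reconstructions back into the original, unscaled activation space.
    sae.W_dec.data /= scaling_factor
    sae.b_dec.data /= scaling_factor
```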

The SAELens trainer calls fold_activation_norm_scaling_factor() before saving the final checkpoint, but not before saving intermediate checkpoints. As a result, the intermediate checkpoints are not usable SAEs.

We should either call fold_activation_norm_scaling_factor() before saving each checkpoint, or include the norm scaling factor in the saved checkpoint. A complication with the fold_activation_norm_scaling_factor() approach is that the folding must be applied only to the copy of the SAE weights being saved, not to the in-progress SAE being trained, since training continues on scaled activations; see the sketch below.
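A minimal sketch of the folding approach, assuming the trainer can deep-copy the SAE at checkpoint time (the function name and the save_model() call are assumptions, not necessarily the trainer's actual API):

```python
import copy

def save_checkpoint_with_folding(sae, scaling_factor: float, path: str) -> None:
    # Fold into a deep copy so the in-progress SAE keeps training on
    # scaled activations with its own weights unmodified.
    sae_to_save = copy.deepcopy(sae)
    sae_to_save.fold_activation_norm_scaling_factor(scaling_factor)
    sae_to_save.save_model(path)  # assumed save entry point
```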
