jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

Adding Mistral SAEs #178

Closed JoshEngels closed 2 weeks ago

JoshEngels commented 2 weeks ago

Description

This is a brief PR that adds support for loading the Mistral 7B SAEs we trained for https://arxiv.org/abs/2405.14860. One important note: we normalized the activations before the SAE forward pass and un-normalized them afterwards, as in the Anthropic February update. I tried adding this to the library, but it was messy because encode and decode are sometimes called individually during evaluation, so we can't simply normalize and un-normalize inside the forward pass. For now, I was thinking it would be on the user to do this correctly. Let me know if there is a better way I should do this!
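As a rough illustration of the user-side wrapping, here is a sketch. The helper name and the per-example L2 scaling are illustrative choices, not SAELens API or necessarily the exact procedure from the paper:

```python
import torch

def run_sae_normalized(sae, acts, eps=1e-6):
    # Illustrative sketch: scale each activation vector before encode,
    # then undo the scaling after decode. `sae` is assumed to expose
    # encode/decode methods; this is not SAELens API.
    norm = acts.norm(dim=-1, keepdim=True) + eps
    feats = sae.encode(acts / norm)
    recon = sae.decode(feats) * norm  # un-normalize the reconstruction
    return feats, recon
```

Because the scaling is undone outside the SAE, the reconstruction lives back in the original activation space and downstream metrics can be computed as usual.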

I also evaluated these SAEs with code here: https://github.com/JoshEngels/MultiDimensionalFeatures/blob/main/sae_multid_feature_discovery/upload_to_huggingface.py I'm happy to upload this code to the tutorial repo as well, since it's useful to see how to do the normalization, but I didn't want to make a big mess in this PR.

SAE statistics on first few hundred documents of the pile from this evaluation:

| Layer | Variance Explained | L0 | % CE Loss Recovered |
|-------|--------------------|------|---------------------|
| 8 | 74% | 81.8 | 99.6% |
| 16 | 85% | 73.7 | 99.2% |
| 24 | 72.13% | 75.2 | 98.1% |

The % CE loss recovered is quite high even though the variance explained is low. I'm a bit skeptical of this, but I don't see a bug in my evaluation code (it's mostly copied from the existing metrics), and I've seen similar behavior when I trained SAEs before.
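For reference, a common way this metric is defined (a sketch with an illustrative function name; the exact evals code may differ):

```python
def ce_loss_recovered(ce_clean, ce_sae, ce_ablate):
    # Sketch of the usual definition: 1.0 when patching in the SAE
    # reconstruction preserves the model's CE loss, 0.0 when it is no
    # better than ablating the activations entirely.
    return (ce_ablate - ce_sae) / (ce_ablate - ce_clean)
```

A high value here is not necessarily inconsistent with modest variance explained: the reconstruction error can be large in directions the model's next-token loss is relatively insensitive to.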

Type of change

Checklist:

You have tested formatting, typing and unit tests (acceptance tests not currently in use)

If you have implemented a training change, please indicate precisely how performance changes with respect to the following metrics:

Please link to wandb dashboards with a control and test group.

nix-apollo commented 2 weeks ago

I think (someone correct me if I'm wrong) that SAE config has a normalize_activations flag. This doesn't change the forward pass of the SAE itself, but it is checked by e.g. the evaluation code to know how to preprocess the activations.

You'd have to make that change in the huggingface repo.

jbloomAus commented 2 weeks ago

This is a different normalization procedure from the one Anthropic suggests in the April update, which normalizes all activations by a single constant for the whole dataset. The procedure in this PR re-estimates the normalization factor per example. With the procedure we currently have, you can fold the normalization constant into the SAE encoder post-training, but with this method you can't. On the other hand, with the procedure you've used, we don't need to worry about estimating the constant all the time. I probably want to support both, but we should have very clear flags and a condition that says you can't do both at once.
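To make the folding point concrete, here is a sketch (illustrative names and shapes, not SAELens code) of why a single dataset-wide constant can be absorbed into the weights post-training, while a per-example norm, which depends on the input, cannot:

```python
import torch

def fold_constant_norm(W_enc, b_enc, W_dec, b_dec, c):
    # If every input is scaled by 1/c before encoding and the
    # reconstruction by c afterwards, the scaling can be absorbed
    # into the weights after training.
    return W_enc / c, b_enc, W_dec * c, b_dec * c

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    feats = torch.relu(x @ W_enc + b_enc)
    return feats @ W_dec + b_dec

# Demo: the folded SAE on raw inputs matches the original SAE applied
# to normalized inputs with the reconstruction scaled back up.
d, k, c = 8, 16, 3.7
W_enc, b_enc = torch.randn(d, k), torch.randn(k)
W_dec, b_dec = torch.randn(k, d), torch.randn(d)
x = torch.randn(5, d)

recon_scaled = c * sae_forward(x / c, W_enc, b_enc, W_dec, b_dec)
recon_folded = sae_forward(x, *fold_constant_norm(W_enc, b_enc, W_dec, b_dec, c))
assert torch.allclose(recon_scaled, recon_folded, atol=1e-4)
```

The same algebra fails for a per-example norm: the scale factor varies with `x`, so there is no single constant to divide into `W_enc`.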

Not sure what to call each. Does the difference make sense to you @JoshEngels ?

codecov[bot] commented 2 weeks ago

Codecov Report

Attention: Patch coverage is 66.66667% with 15 lines in your changes missing coverage. Please review.

Project coverage is 59.59%. Comparing base (4d92975) to head (4049496). Report is 2 commits behind head on main.

| Files | Patch % | Lines |
|-------|---------|-------|
| sae_lens/evals.py | 0.00% | 2 Missing and 5 partials :warning: |
| sae_lens/toolkit/pretrained_sae_loaders.py | 28.57% | 4 Missing and 1 partial :warning: |
| sae_lens/config.py | 50.00% | 1 Missing and 1 partial :warning: |
| sae_lens/training/sae_trainer.py | 0.00% | 0 Missing and 1 partial :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #178      +/-   ##
==========================================
+ Coverage   59.25%   59.59%   +0.34%
==========================================
  Files          25       25
  Lines        2604     2636      +32
  Branches      440      445       +5
==========================================
+ Hits         1543     1571      +28
- Misses        984      987       +3
- Partials       77       78       +1
```


ianand commented 2 weeks ago

@JoshEngels @jbloomAus I noticed that in this PR the tutorial using_an_sae_as_a_steering_vector.ipynb has changed from using gpt2-small to gemma-2b as the model. Any particular reason? Should I expect it to be fine to continue using it with gpt2-small?

jbloomAus commented 2 weeks ago

That was a mistake! I'll fix it.