Closed by JoshEngels 2 weeks ago
I think (someone correct me if I'm wrong) that the SAE config has a normalize_activations flag. This doesn't change the forward pass of the SAE itself, but it is checked by e.g. the evaluation code to know how to preprocess the activations.
You'd have to make that change in the huggingface repo.
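To make the flag's role concrete, here is a minimal sketch of how evaluation code might branch on a normalize_activations setting before running metrics. The flag name comes from the discussion above; the dict-based config and the way the constant is estimated are illustrative assumptions, not the library's actual API.

```python
import numpy as np

def preprocess_for_eval(acts, cfg):
    """Hypothetical eval-side preprocessing keyed on a config flag."""
    if cfg.get("normalize_activations"):
        # Illustrative: rescale by a constant estimated from this batch.
        const = np.mean(np.linalg.norm(acts, axis=-1))
        return acts / const
    return acts

batch = np.ones((2, 4)) * 3.0  # each row has norm 6
out = preprocess_for_eval(batch, {"normalize_activations": True})
assert np.allclose(np.linalg.norm(out, axis=-1), 1.0)
```

The point is that the SAE weights are untouched; only the activations handed to it are rescaled.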
This is a different normalization procedure from the one Anthropic suggests in the April update, which normalizes all activations by a single constant for the whole dataset. The procedure in this PR re-estimates the normalization factor per example. With the procedure we currently have, you can fold the normalization constant into the SAE encoder post-training, but with this method you can't. On the other hand, with the procedure you've used, we don't need to worry about estimating the constant all the time. I probably want to support both, but we should have very clear flags and a condition that says you can't do both at once.
Not sure what to call each. Does the difference make sense to you @JoshEngels ?
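The folding argument above can be sketched numerically: if the normalization factor is one dataset-wide constant C, then encode(x / C) equals an encoder with weights W / C applied to raw x, so the constant can be absorbed after training. A per-example factor depends on x and cannot be absorbed this way. All names and shapes below are illustrative, not the library's API.

```python
import numpy as np

d_in, d_sae = 8, 32
rng = np.random.default_rng(0)
W_enc = rng.normal(size=(d_in, d_sae))  # toy encoder weights
b_enc = rng.normal(size=d_sae)
norm_const = 7.5  # hypothetical dataset-wide constant

x = rng.normal(size=d_in)

# Option A: normalize the input explicitly before encoding.
pre_acts_a = (x / norm_const) @ W_enc + b_enc

# Option B: fold the constant into the encoder weights post-training.
W_enc_folded = W_enc / norm_const
pre_acts_b = x @ W_enc_folded + b_enc

# The two are identical, which is what makes folding possible.
assert np.allclose(pre_acts_a, pre_acts_b)
```

With a per-example scale (e.g. dividing each x by its own norm), the factor varies with the input, so no fixed W_enc_folded reproduces it.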
Attention: Patch coverage is 66.66667% with 15 lines in your changes missing coverage. Please review.
Project coverage is 59.59%. Comparing base (4d92975) to head (4049496). Report is 2 commits behind head on main.
@JoshEngels @jbloomAus I noticed that in this PR the tutorial using_an_sae_as_a_steering_vector.ipynb has changed from using gpt2-small to gemma-2b as the model. Any particular reason? Should I expect it's fine to continue using with gpt2-small?
That was a mistake! I'll fix it.
Description
This is a brief PR that adds support for loading the Mistral 7B SAEs we trained for https://arxiv.org/abs/2405.14860. As an important note, we normalized the activations before the SAE forward pass and un-normalized them afterwards, as in the Anthropic February update. I tried adding this to the library, but it was messy: encode and decode are sometimes called individually during evaluation, so we can't just normalize and un-normalize inside the forward pass. For now I was thinking it would be on the user to do this correctly. Let me know if there is a better way I should do this!
I also evaluated these SAEs with the code here: https://github.com/JoshEngels/MultiDimensionalFeatures/blob/main/sae_multid_feature_discovery/upload_to_huggingface.py Happy to upload this code to the tutorial repo as well, since it's useful to see how to do the normalization, but I didn't want to make a big mess in this PR.
SAE statistics on first few hundred documents of the pile from this evaluation:
The % CE loss recovered is pretty high even though the variance explained is low; I'm a bit skeptical of this, but I don't see a bug in my evaluation code (it's mostly copied from the existing metrics), and I've also seen similar things when I trained SAEs before.
Type of change
Checklist:
You have tested formatting, typing and unit tests (acceptance tests not currently in use). Run `make check-ci` to check format and linting. (You can run `make format` to format code if needed.) (Formatting passed, but I ran into some problems with the other tests, probably unrelated to my PR.)
Performance Check.
If you have implemented a training change, please indicate precisely how performance changes with respect to the following metrics:
Please link to wandb dashboards with a control and test group.