jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/

[Bug Report] scaling_factor broke sae_vis compatibility #99

Closed adamkarvonen closed 5 months ago

adamkarvonen commented 5 months ago

Describe the bug

I'm using this notebook on an SAE I created: basic_loading_and_analysing.ipynb. I get an error that appears to be caused by the scaling_factor key that was added to the SparseAutoencoder class, which sae_vis does not expect.

Code example

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[21], line 17
     14 print(type(sparse_autoencoder))
     15 print(sparse_autoencoder.state_dict().keys())
---> 17 sae_vis_data_gpt = SaeVisData.create(
     18     encoder=sparse_autoencoder,
     19     model=model,
     20     tokens=all_tokens,  # type: ignore
     21     cfg=feature_vis_config_gpt,
     22 )

File /opt/conda/lib/python3.10/site-packages/sae_vis/data_storing_fns.py:1017, in SaeVisData.create(cls, encoder, model, tokens, cfg, encoder_B)
   1014 # If encoder isn't an AutoEncoder, we wrap it in one
   1015 if not isinstance(encoder, AutoEncoder):
   1016     assert (
-> 1017         set(encoder.state_dict().keys()) == {"W_enc", "W_dec", "b_enc", "b_dec"}
   1018     ), "If encoder isn't an AutoEncoder, it should have weights 'W_enc', 'W_dec', 'b_enc', 'b_dec'"
   1019     d_in, d_hidden = encoder.W_enc.shape
   1020     device = encoder.W_enc.device

AssertionError: If encoder isn't an AutoEncoder, it should have weights 'W_enc', 'W_dec', 'b_enc', 'b_dec'

When running this cell:

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

test_feature_idx_gpt = list(range(10)) + [14057]

feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_point,
    features=test_feature_idx_gpt,
    batch_size=2048,
    minibatch_size_tokens=128,
    verbose=True,
)

sae_vis_data_gpt = SaeVisData.create(
    encoder=sparse_autoencoder,
    model=model,
    tokens=all_tokens,  # type: ignore
    cfg=feature_vis_config_gpt,
)

These print statements show a scaling_factor key that is not in the set checked by the assert above:

print(type(sparse_autoencoder))
print(sparse_autoencoder.state_dict().keys())

<class 'sae_lens.training.sparse_autoencoder.SparseAutoencoder'>
odict_keys(['W_enc', 'b_enc', 'W_dec', 'b_dec', 'scaling_factor'])
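
As a stopgap, something like the sketch below gets past the assert by handing sae_vis a thin module that exposes only the four expected tensors. PlainSAE is a hypothetical adapter written for illustration (not part of SAELens or sae_vis); it simply drops scaling_factor, so it only makes sense if the scaling factor is all ones, and I haven't verified that the rest of SaeVisData.create is happy with it:

import torch.nn as nn

class PlainSAE(nn.Module):
    # Hypothetical adapter: exposes exactly the four tensors sae_vis asserts on,
    # so state_dict() contains only {"W_enc", "W_dec", "b_enc", "b_dec"}.
    def __init__(self, sae):
        super().__init__()
        self.W_enc = nn.Parameter(sae.W_enc.data.clone())
        self.W_dec = nn.Parameter(sae.W_dec.data.clone())
        self.b_enc = nn.Parameter(sae.b_enc.data.clone())
        self.b_dec = nn.Parameter(sae.b_dec.data.clone())

# scaling_factor is silently dropped here, so this is only valid if it is all ones.
sae_vis_data_gpt = SaeVisData.create(
    encoder=PlainSAE(sparse_autoencoder),
    model=model,
    tokens=all_tokens,
    cfg=feature_vis_config_gpt,
)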


jbloomAus commented 5 months ago

Hey Adam,

Thanks for raising this. We're moving pretty quickly so it's easy for things to get out of sync. We did catch this and will be implementing integration tests between SAE Vis and SAE Lens shortly.

If you update sae_vis, this should be fixed. (https://github.com/callummcdougall/sae_vis/blob/d759ef0237089e72cc9cad7edc4eceb4e8cfdd00/sae_vis/data_storing_fns.py#L1026)

A bit more detail: SAE Vis currently makes strong assumptions about the autoencoder's forward pass that won't hold in general (e.g. pre-encoder subtraction of the decoder bias, ReLU activation, etc.), so we'll need to find a longer-term solution, likely a spec that lets SAE Vis treat the autoencoder as a black box.
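
For concreteness, the forward pass SAE Vis effectively assumes is roughly the standard SAE computation sketched below (my reading, not the exact code in either repo):

import torch

def assumed_sae_forward(x, W_enc, W_dec, b_enc, b_dec):
    # Pre-encoder subtraction of the decoder bias, then a ReLU encoder.
    feature_acts = torch.relu((x - b_dec) @ W_enc + b_enc)
    # Linear decoder that adds the same bias back.
    x_reconstruct = feature_acts @ W_dec + b_dec
    return feature_acts, x_reconstruct

Any SAE whose forward pass differs from this (an extra scaling factor, a different nonlinearity, no bias subtraction) could satisfy a weights-only check and still be visualised incorrectly, which is the motivation for a black-box spec.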