whitebox-research / excursions


How do you train an SAE and use it to get a neuron out of superposition? #5

Open MostDeadDeveloper opened 2 months ago

zrkrlc commented 2 months ago

What's the quickest way to answer this question?

jddantes commented 1 month ago

Additional reading with sample code for "standard" SAE implementation: https://www.lesswrong.com/posts/CJPqwXoFtgkKPRay8/an-intuitive-explanation-of-sparse-autoencoders-for#fnref-XbKGsdnWC7Q7h6MzC-3

Basically: train a one-layer autoencoder on the model's activations, with an overcomplete dictionary (dict_size > d_model), a ReLU encoder, and an L1 sparsity penalty on the encoded representation.

Code snippet:

import torch
import torch.nn as nn

# D = d_model, F = dictionary_size
# e.g. if d_model = 12288 and dictionary_size = 49152
# then model_activations_D.shape = (12288,)
# encoder_DF.weight.shape = (49152, 12288)  (nn.Linear stores weight as (out_features, in_features))

class SparseAutoEncoder(nn.Module):
    """
    A one-layer autoencoder.
    """
    def __init__(self, activation_dim: int, dict_size: int):
        super().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size

        self.encoder_DF = nn.Linear(activation_dim, dict_size, bias=True)
        self.decoder_FD = nn.Linear(dict_size, activation_dim, bias=True)

    def encode(self, model_activations_D: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.encoder_DF(model_activations_D))

    def decode(self, encoded_representation_F: torch.Tensor) -> torch.Tensor:
        return self.decoder_FD(encoded_representation_F)

    def forward(self, model_activations_D: torch.Tensor):
        # Naming this `forward` lets us call the module directly, e.g. autoencoder(x).
        encoded_representation_F = self.encode(model_activations_D)
        reconstructed_model_activations_D = self.decode(encoded_representation_F)
        return reconstructed_model_activations_D, encoded_representation_F

Sample loss:

import einops

# B = batch size, D = d_model, F = dictionary_size

def calculate_loss(autoencoder: SparseAutoEncoder, model_activations_BD: torch.Tensor, l1_coefficient: float) -> torch.Tensor:
    reconstructed_model_activations_BD, encoded_representation_BF = autoencoder(model_activations_BD)
    reconstruction_error_BD = (reconstructed_model_activations_BD - model_activations_BD).pow(2)
    reconstruction_error_B = einops.reduce(reconstruction_error_BD, 'B D -> B', 'sum')
    l2_loss = reconstruction_error_B.mean()

    # Mean over the batch so the sparsity term doesn't scale with batch size.
    # The encoded activations are nonnegative after the ReLU, so a plain sum
    # over F equals the L1 norm.
    l1_norm_B = einops.reduce(encoded_representation_BF, 'B F -> B', 'sum')
    l1_loss = l1_coefficient * l1_norm_B.mean()
    loss = l2_loss + l1_loss
    return loss
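
To tie these together, here is a minimal training-loop sketch. It only uses the class and loss above; the sizes, learning rate, l1_coefficient, and the random placeholder activations are illustrative assumptions, not tuned values (in practice you'd harvest real activations from the model you want to interpret):

import torch

# Hypothetical setup: stand-in activations; replace with real ones
# collected from a model's residual stream or an MLP layer.
d_model, dict_size = 512, 2048
activations_ND = torch.randn(10_000, d_model)

autoencoder = SparseAutoEncoder(activation_dim=d_model, dict_size=dict_size)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-4)

batch_size = 256
for step in range(1_000):
    # Sample a random batch of activations.
    idx = torch.randint(0, activations_ND.shape[0], (batch_size,))
    model_activations_BD = activations_ND[idx]

    loss = calculate_loss(autoencoder, model_activations_BD, l1_coefficient=1e-3)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()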
jddantes commented 3 weeks ago

Training example from SAELens: https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb

I did some very rough, eyeballed runs of it on Colab (especially on T4/L4 GPUs).

Using a pre-trained SAE & analysing a specific feature (and using it to steer LLM responses) can be done here: https://github.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb
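
On the second half of the original question (getting a neuron out of superposition): once an SAE like the one above is trained, you decompose an activation by encoding it and seeing which dictionary features fire; each feature's decoder column is a single direction in activation space that is, ideally, monosemantic. A rough sketch continuing from the training loop above (the top-k of 5 is an arbitrary illustrative choice):

# Encode one activation vector and find its most active dictionary features.
model_activations_D = activations_ND[0]
with torch.no_grad():
    encoded_F = autoencoder.encode(model_activations_D)

top_values, top_indices = encoded_F.topk(5)
print(top_indices.tolist(), top_values.tolist())

# The candidate "de-superposed" unit is a decoder column: the direction in
# activation space that the SAE assigns to that single feature.
feature_idx = top_indices[0].item()
feature_direction_D = autoencoder.decoder_FD.weight[:, feature_idx]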

Moving this excursion to done