**MostDeadDeveloper** opened this issue 2 months ago
Additional reading with sample code for a "standard" SAE implementation: https://www.lesswrong.com/posts/CJPqwXoFtgkKPRay8/an-intuitive-explanation-of-sparse-autoencoders-for#fnref-XbKGsdnWC7Q7h6MzC-3

Basically, it boils down to the following code snippet:
```python
import torch
import torch.nn as nn


# D = d_model, F = dictionary_size
# e.g. if d_model = 12288 and dictionary_size = 49152,
# then model_activations_D.shape = (12288,)
# and encoder_DF.weight.shape = (49152, 12288)
# (nn.Linear stores its weight as (out_features, in_features))
class SparseAutoEncoder(nn.Module):
    """
    A one-layer autoencoder.
    """

    def __init__(self, activation_dim: int, dict_size: int):
        super().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size
        self.encoder_DF = nn.Linear(activation_dim, dict_size, bias=True)
        self.decoder_FD = nn.Linear(dict_size, activation_dim, bias=True)

    def encode(self, model_activations_D: torch.Tensor) -> torch.Tensor:
        # ReLU keeps feature activations non-negative; sparsity comes from the L1 penalty
        return torch.relu(self.encoder_DF(model_activations_D))

    def decode(self, encoded_representation_F: torch.Tensor) -> torch.Tensor:
        return self.decoder_FD(encoded_representation_F)

    def forward_pass(self, model_activations_D: torch.Tensor):
        encoded_representation_F = self.encode(model_activations_D)
        reconstructed_model_activations_D = self.decode(encoded_representation_F)
        return reconstructed_model_activations_D, encoded_representation_F
```
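A quick sanity check of the shapes, as a minimal sketch. The dimensions here are small illustrative values (the comments above use d_model = 12288 and dictionary_size = 49152, which behave the same way):

```python
# Instantiate the SAE with small toy dimensions and run one forward pass.
sae = SparseAutoEncoder(activation_dim=64, dict_size=256)

model_activations_D = torch.randn(64)
reconstructed_D, encoded_F = sae.forward_pass(model_activations_D)

print(encoded_F.shape)        # torch.Size([256]) -- the (sparse) feature activations
print(reconstructed_D.shape)  # torch.Size([64])  -- the reconstructed model activations
```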
Sample loss:
```python
import einops
import torch


# B = batch_size, D = d_model, F = dictionary_size
def calculate_loss(
    autoencoder: SparseAutoEncoder,
    model_activations_BD: torch.Tensor,
    l1_coefficient: float,
) -> torch.Tensor:
    reconstructed_model_activations_BD, encoded_representation_BF = autoencoder.forward_pass(model_activations_BD)
    reconstruction_error_BD = (reconstructed_model_activations_BD - model_activations_BD).pow(2)
    reconstruction_error_B = einops.reduce(reconstruction_error_BD, 'B D -> B', 'sum')
    l2_loss = reconstruction_error_B.mean()
    # Activations are non-negative after the ReLU, so a plain sum is the L1 norm
    l1_loss = l1_coefficient * encoded_representation_BF.sum()
    loss = l2_loss + l1_loss
    return loss
```
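And for completeness, a minimal training-step sketch around `calculate_loss`. The optimizer, learning rate, `l1_coefficient`, and the random batches (standing in for real residual-stream activations) are all illustrative assumptions, not settings from the post or from SAELens:

```python
sae = SparseAutoEncoder(activation_dim=64, dict_size=256)
# Hypothetical hyperparameters, purely for illustration
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-4)

for step in range(1000):
    # In a real run this batch would come from a buffer of cached model activations
    model_activations_BD = torch.randn(32, 64)
    loss = calculate_loss(sae, model_activations_BD, l1_coefficient=1e-3)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```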
Training example from SAELens: https://github.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb
Some very rough, eyeballed runs of this on Colab (especially on T4/L4 GPUs):
A tutorial on using a pre-trained SAE, analysing a specific feature, and using it to steer LLM responses: https://github.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb
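For intuition on what "steering" means mechanically, here is a minimal sketch using the toy `SparseAutoEncoder` above rather than SAELens (the tutorial does this properly with hooks into the LLM's forward pass; `feature_index` and `steering_scale` below are made-up illustration values):

```python
sae = SparseAutoEncoder(activation_dim=64, dict_size=256)
feature_index = 42    # hypothetical feature picked from analysis
steering_scale = 5.0  # hypothetical strength

# One column of the decoder weight is that feature's direction in activation space
# (nn.Linear stores weight as (out_features, in_features) = (D, F)).
feature_direction_D = sae.decoder_FD.weight[:, feature_index].detach()

# Steering = nudging the model's activations along the feature direction;
# in practice this addition happens inside a forward hook on the LLM.
model_activations_D = torch.randn(64)
steered_activations_D = model_activations_D + steering_scale * feature_direction_D
```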
Moving this excursion to done