jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

[Proposal] Allow excluding special tokens when applying SAEs in HookedSAETransformer #350

Open chanind opened 4 weeks ago

chanind commented 4 weeks ago

Proposal

We should add an option to exclude special tokens when adding an SAE to HookedSAETransformer. This could take the form of an exclude_special_tokens param for add_sae() / run_with_cache_with_saes() / run_with_saes(), which would skip applying the SAE at BOS, EOS, and SEP tokens as defined by the model's tokenizer. The user could pass True to exclude these standard tokens, or pass a list (or tensor) of token_id values to exclude, to customize this behavior further.
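To make the proposed parameter concrete, here is a minimal sketch of how an exclude_special_tokens value (True, or a list/tensor of token IDs) might be resolved into a per-position mask. The helper names `resolve_excluded_ids` and `sae_keep_mask` are hypothetical and not part of SAELens. The sketch assumes a Hugging Face-style tokenizer exposing `bos_token_id`, `eos_token_id`, and `sep_token_id`, and the demo uses a stand-in tokenizer and a fake SAE output rather than a real model.

```python
from types import SimpleNamespace

import torch


def resolve_excluded_ids(tokenizer, exclude_special_tokens):
    """Hypothetical helper: map an exclude_special_tokens value to a sorted list of token ids."""
    if exclude_special_tokens is True:
        # Standard special tokens; tokenizers report None for ones they don't define
        candidates = [
            getattr(tokenizer, "bos_token_id", None),
            getattr(tokenizer, "eos_token_id", None),
            getattr(tokenizer, "sep_token_id", None),
        ]
    elif exclude_special_tokens is None or exclude_special_tokens is False:
        candidates = []
    else:
        # A user-supplied list or tensor of token ids
        candidates = torch.as_tensor(exclude_special_tokens).flatten().tolist()
    # Drop undefined ids and duplicates (GPT-2, for example, uses one id for BOS and EOS)
    return sorted({int(t) for t in candidates if t is not None})


def sae_keep_mask(tokens, excluded_ids):
    """Boolean mask shaped like tokens: True where the SAE should run."""
    if not excluded_ids:
        return torch.ones_like(tokens, dtype=torch.bool)
    excluded = torch.tensor(excluded_ids, dtype=tokens.dtype, device=tokens.device)
    return ~torch.isin(tokens, excluded)


# Toy demo: stand-in tokenizer and a fake "SAE output" of all zeros
tok = SimpleNamespace(bos_token_id=50256, eos_token_id=50256, sep_token_id=None)
tokens = torch.tensor([[50256, 464, 3797, 50256]])
mask = sae_keep_mask(tokens, resolve_excluded_ids(tok, True))
act = torch.randn(1, 4, 8)
sae_out = torch.zeros_like(act)
# Excluded (special-token) positions keep their original activations
patched = torch.where(mask[..., None], sae_out, act)
```

Applying the mask with torch.where leaves the activation shape unchanged, so downstream layers simply see the original activations at excluded positions.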

Motivation

Applying an SAE to special tokens is often not useful: SAEs are frequently not trained on special tokens, and latents that fire on BOS are not particularly interesting. Since this is a common use case, HookedSAETransformer should make it easy to skip special tokens when running with an SAE, as the class exists precisely to make common SAE use cases easy.

Alternatives

Alternatively, we could let users specify token positions (indices) at which to skip the SAE, rather than token IDs. This would require more work from users, but it would support other use cases where the user doesn't want to apply the SAE at certain positions. It could also be implemented separately from, or in addition to, an exclude_special_tokens param.
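For comparison, the position-based alternative could reduce to the same kind of boolean mask. This is a minimal sketch assuming a hypothetical `exclude_positions` argument that lists sequence positions to skip, applied identically to every row in the batch; `position_keep_mask` is an illustrative name, not an existing SAELens function.

```python
import torch


def position_keep_mask(tokens, exclude_positions):
    """Hypothetical helper: True where the SAE should run, False at the given
    sequence positions (the same positions for every batch row). Negative
    positions count from the end, as in Python indexing."""
    batch, seq_len = tokens.shape
    mask = torch.ones(batch, seq_len, dtype=torch.bool, device=tokens.device)
    positions = torch.as_tensor(list(exclude_positions), dtype=torch.long, device=tokens.device)
    if positions.numel() == 0:
        return mask
    # Reject positions outside [-seq_len, seq_len) instead of silently wrapping them
    if (positions >= seq_len).any() or (positions < -seq_len).any():
        raise IndexError("exclude_positions out of range for seq_len")
    # Normalize negative positions, then switch those columns off
    mask[:, positions % seq_len] = False
    return mask


tokens = torch.zeros(2, 5, dtype=torch.long)
# Skip the first position (often BOS) and the last position
pos_mask = position_keep_mask(tokens, [0, -1])
```

Because both approaches yield a boolean mask shaped like the tokens, they could be combined with `&` if both options were offered.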


NainaniJatinZ commented 4 weeks ago

Here is a temporary workaround that worked for me, in case anyone wants to use it before a PR is made.

import torch

def run_with_saes_filtered(tokens, filtered_ids, model, saes):
    # Ensure tokens are a [batch, seq_len] torch.Tensor
    if not isinstance(tokens, torch.Tensor):
        tokens = torch.tensor(tokens, dtype=torch.long)
    if tokens.ndim == 1:
        tokens = tokens.unsqueeze(0)

    # Create a mask where True marks positions the SAE should modify
    mask = torch.ones_like(tokens, dtype=torch.bool)
    for token_id in filtered_ids:
        if token_id is not None:  # e.g. pad_token_id may be None
            mask &= tokens != token_id

    # Build one forward hook per SAE
    fwd_hooks = []
    for sae in saes:
        # Default args bind this iteration's sae (and the mask) to the hook
        def filtered_hook(act, hook, sae=sae, mask=mask):
            # act shape: [batch_size, seq_len, ...]; move the mask to act's
            # device and reshape it to broadcast over the trailing dimension(s)
            m = mask.to(act.device).reshape(*mask.shape, *([1] * (act.ndim - 2)))
            # Apply the SAE only at positions where the mask is True
            return torch.where(m, sae(act), act)

        fwd_hooks.append((sae.cfg.hook_name, filtered_hook))

    # run_with_hooks removes these hooks afterwards, even if the forward pass fails
    return model.run_with_hooks(tokens, fwd_hooks=fwd_hooks)

# Assumes `model`, `sae`, and `tokens` are already defined
filtered_ids = [
    model.tokenizer.bos_token_id,
    model.tokenizer.eos_token_id,
    model.tokenizer.pad_token_id,
]

logits = run_with_saes_filtered(tokens, filtered_ids, model, [sae])