chanind opened 4 weeks ago
Here is a temporary fix that worked for me, in case anyone wants to use it before the PR is merged.
import torch

def run_with_saes_filtered(tokens, filtered_ids, model, saes):
    # Ensure tokens are a torch.Tensor
    if not isinstance(tokens, torch.Tensor):
        tokens = torch.tensor(tokens, dtype=torch.long)

    # Create a mask where True marks positions the SAE should be applied to
    mask = torch.ones_like(tokens, dtype=torch.bool)
    for token_id in filtered_ids:
        mask &= tokens != token_id

    # For each SAE, add the appropriate hook
    for sae in saes:
        hook_point = sae.cfg.hook_name

        # Define the modified hook function (bind sae and mask as default args
        # so each hook keeps its own references)
        def filtered_hook(act, hook, sae=sae, mask=mask):
            # act shape: [batch_size, seq_len, hidden_size]
            # Expand mask to match the shape of act (moved to act's device in
            # case the tokens and activations live on different devices)
            mask_expanded = mask.to(act.device).unsqueeze(-1).expand_as(act)
            # Apply the SAE only to positions where the mask is True
            act = torch.where(mask_expanded, sae(act), act)
            return act

        # Add the hook to the model
        model.add_hook(hook_point, filtered_hook, dir="fwd")

    # Run the model with the tokens
    logits = model(tokens)
    # Reset the hooks after computation
    model.reset_hooks()
    return logits
# Skip the tokenizer's special tokens, dropping any the tokenizer doesn't
# define (e.g. some tokenizers report pad_token_id as None)
filtered_ids = [
    token_id
    for token_id in [
        model.tokenizer.bos_token_id,
        model.tokenizer.eos_token_id,
        model.tokenizer.pad_token_id,
    ]
    if token_id is not None
]

logits = run_with_saes_filtered(tokens, filtered_ids, model, [sae])
Proposal
We should add an option to exclude special tokens when adding a SAE into HookedSAETransformer. This could take the form of an exclude_special_tokens param for add_sae() / run_with_cache_with_saes() / run_with_saes(). This would exclude running the SAE on BOS, EOS, and SEP tokens as specified by the model tokenizer. The user could pass True to avoid these standard tokens, or pass a list (or tensor) of token_id values to exclude to further customize this behavior.
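To make the proposal concrete, here is a rough sketch of what call sites might look like. The exclude_special_tokens argument does not exist yet and is exactly the parameter being proposed; run_with_saes() and run_with_cache_with_saes() are the existing HookedSAETransformer methods.

# Hypothetical usage if the proposed exclude_special_tokens param were added:
# skip BOS/EOS/SEP as reported by model.tokenizer
logits = model.run_with_saes(tokens, saes=[sae], exclude_special_tokens=True)

# or customize by passing explicit token ids to exclude
logits, cache = model.run_with_cache_with_saes(
    tokens,
    saes=[sae],
    exclude_special_tokens=[model.tokenizer.bos_token_id],
)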
Motivation
It's often not useful to apply a SAE on special tokens, since SAEs are often not trained on special tokens and it's not particularly interesting to see SAE latents that fire on BOS. Given this is a common use-case, we should make it easy to just skip special tokens when running with a SAE using HookedSAETransformer, as this class only exists to make common use-cases for SAEs easy.
Alternatives
We could alternatively allow users to specify certain token indices (sequence positions) to avoid running the SAE on, instead of token ids. This would require more work for users but may support other use-cases where the user doesn't want to apply the SAE at certain positions. This could also be implemented separately from, or in addition to, adding an exclude_special_tokens param.
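For illustration, the position-based alternative could reuse the same hook pattern as the workaround above and only change how the mask is built. The excluded_positions argument below is hypothetical, not part of any existing API.

import torch

def position_mask(tokens, excluded_positions):
    # True where the SAE should run; False at the excluded sequence positions
    mask = torch.ones_like(tokens, dtype=torch.bool)
    mask[:, excluded_positions] = False
    return mask

# e.g. never run the SAE at position 0 (BOS) or at the final position
mask = position_mask(tokens, [0, -1])

The filtered hook would then be applied exactly as in run_with_saes_filtered, just with this position mask in place of the token-id-based one.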
Checklist