jbloomAus / SAELens

Training Sparse Autoencoders on Language Models
https://jbloomaus.github.io/SAELens/
MIT License

[Bug Report] Cannot train on z: interfaces disagree on shape? #164

Closed 4gatepylon closed 4 months ago

4gatepylon commented 4 months ago

Description

In case anyone else runs into this, I keep hitting an issue that is really annoying to track down. When I switch the hook point from hook_mlp_out to attn.hook_z, training fails after a fixed number of training steps (or whatever you'd call them; for me it's 9 with batch size 128, as you'll see below). The error comes first below, then the config as a JSON, which is passed in as kwargs.

Has anyone run into this error and found a fix? What is going on?

Error

Traceback (most recent call last):
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/saes.py", line 184, in train_saes
    sparse_autoencoder: TrainingSAE = sae_training_runner.run()
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/sae_lens/sae_training_runner.py", line 81, in run
    sae = self.run_trainer_with_interruption_handling(trainer)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/sae_lens/sae_training_runner.py", line 124, in run_trainer_with_interruption_handling
    sae = trainer.fit()
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/sae_lens/training/sae_trainer.py", line 172, in fit
    self._run_and_log_evals()
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/sae_lens/training/sae_trainer.py", line 321, in _run_and_log_evals
    eval_metrics = run_evals(
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/sae_lens/evals.py", line 28, in run_evals
    losses_df = recons_loss_batched(
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/sae_lens/evals.py", line 98, in recons_loss_batched
    score, loss, recons_loss, zero_abl_loss = get_recons_loss(
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/sae_lens/evals.py", line 196, in get_recons_loss
    recons_loss = model.run_with_hooks(
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/transformer_lens/hook_points.py", line 358, in run_with_hooks
    return hooked_model.forward(*model_args, **model_kwargs)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/transformer_lens/HookedTransformer.py", line 550, in forward
    residual = block(
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/transformer_lens/components.py", line 1585, in forward
    self.attn(
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/transformer_lens/components.py", line 587, in forward
    z = self.calculate_z_scores(v, pattern)  # [batch, pos, head_index, d_head]
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/transformer_lens/components.py", line 764, in calculate_z_scores
    z = self.hook_z(
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1595, in _call_impl
    hook_result = hook(self, args, result)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/transformer_lens/hook_points.py", line 77, in full_hook
    return hook(module_output, hook=self)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/sae_lens/evals.py", line 140, in standard_replacement_hook
    activations = sae.decode(sae.encode(activations)).to(activations.dtype)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/sae_lens/training/training_sae.py", line 156, in encode
    feature_acts, _ = self.encode_with_hidden_pre(x)
  File "/home/ubuntu/git/IfYouDontUnderstandItDontUseIt/src/.venv/lib/python3.10/site-packages/sae_lens/training/training_sae.py", line 176, in encode_with_hidden_pre
    sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))
RuntimeError: The size of tensor a (64) must match the size of tensor b (1024) at non-singleton dimension 3

Configuration

{
    "model_name": "tiny-stories-1L-21M",
    "model_class_name": "HookedTransformer",
    "hook_name": "blocks.0.attn.hook_z",
    "hook_eval": "NOT_IN_USE",
    "hook_layer": 0,
    "hook_head_index": null,
    "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
    "streaming": true,
    "is_dataset_tokenized": true,
    "context_size": 256,
    "use_cached_activations": false,
    "cached_activations_path": null,
    "d_in": 1024,
    "d_sae": 16384,
    "b_dec_init_method": "geometric_median",
    "expansion_factor": 16,
    "activation_fn": "relu",
    "normalize_sae_decoder": true,
    "noise_scale": 0.0,
    "from_pretrained_path": null,
    "apply_b_dec_to_input": true,
    "decoder_orthogonal_init": false,
    "decoder_heuristic_init": false,
    "init_encoder_as_decoder_transpose": false,
    "n_batches_in_buffer": 64,
    "training_tokens": 50000000,
    "finetuning_tokens": 0,
    "store_batch_size_prompts": 16,
    "train_batch_size_tokens": 128,
    "normalize_activations": false,
    "device": "cuda",
    "act_store_device": "cuda",
    "seed": 42,
    "dtype": "float32",
    "prepend_bos": true,
    "autocast": false,
    "autocast_lm": false,
    "compile_llm": false,
    "llm_compilation_mode": null,
    "compile_sae": false,
    "sae_compilation_mode": null,
    "adam_beta1": 0,
    "adam_beta2": 0.999,
    "mse_loss_normalization": null,
    "l1_coefficient": 0.001,
    "lp_norm": 1.0,
    "scale_sparsity_penalty_by_decoder_norm": false,
    "l1_warm_up_steps": 0,
    "lr": 0.0004,
    "lr_scheduler_name": "constant",
    "lr_warm_up_steps": 10000,
    "lr_end": 4e-05,
    "lr_decay_steps": 0,
    "n_restart_cycles": 1,
    "finetuning_method": null,
    "use_ghost_grads": false,
    "feature_sampling_window": 1000,
    "dead_feature_window": 1000,
    "dead_feature_threshold": 0.0001,
    "n_eval_batches": 10,
    "eval_batch_size_prompts": null,
    "log_to_wandb": true,
    "log_activations_store_to_wandb": false,
    "log_optimizer_state_to_wandb": false,
    "wandb_project": "sae_lens_tutorial",
    "wandb_id": null,
    "run_name": "16384-L1-0.001-LR-0.0004-Tokens-5.000e+07",
    "wandb_entity": null,
    "wandb_log_frequency": 10,
    "eval_every_n_wandb_logs": 100,
    "resume": false,
    "n_checkpoints": 0,
    "checkpoint_path": "my folder/my folder/my older/my folder/my folder/random string SAE lens appends apparently",
    "verbose": true,
    "model_kwargs": {},
    "model_from_pretrained_kwargs": {},
    "sae_lens_version": "3.0.0",
    "sae_lens_training_version": "3.0.0",
    "tokens_per_buffer": 2097152
}

Code that Passes in Configuration

cfg = LanguageModelSAERunnerConfig(
                **config_copy,
)
import json # XXX
print(json.dumps(cfg.__dict__, indent=4)) # XXX  <------ prints the config shown above
# NOTE this is used implicitly in `language_model_sae_runner`
layer_checkpoint_path = Path(cfg.checkpoint_path)
layer_checkpoint_path.mkdir(
    parents=True, exist_ok=False
) 
layer = cfg.hook_layer
assert isinstance(layer, int)
sae_training_runner = SAETrainingRunner(cfg)
sparse_autoencoder: TrainingSAE = sae_training_runner.run() # <------ FAIL

Solution (Not Acceptable)

Turn off wandb logging, or comment out some of the code in site-packages.

I can also just decide to never use z values and use MLP out instead. This works:

{
    ...
    "hook_name": "blocks.0.hook_mlp_out",
    ...
}

More Notes

I wonder if a disagreement between the caller in evals.py and this code in training_sae.py is responsible. Look:

THIS (training_sae.py, constructor)

# The training SAE will assume that the activation store handles
# reshaping.
self.reshape_fn_in = lambda x: x

VERSUS (evals.py, get_recons_loss)

def standard_replacement_hook(activations: torch.Tensor, hook: Any):
    # Handle rescaling if SAE expects it
    if activation_store.normalize_activations:
        activations = activation_store.apply_norm_scaling_factor(activations)

    # SAE class agnost forward forward pass.
    print("="*100) # XXX <------------ DOES print
    print("ACTIVATIOSN IN STANDARD REPLACEMENT HOOK", activations.shape) # XXX <----- error source
    activations = sae.decode(sae.encode(activations)).to(activations.dtype)
    print("+"*100) # XXX <---- does NOT print

    # Unscale if activations were scaled prior to going into the SAE
    if activation_store.normalize_activations:
        activations = activation_store.unscale(activations)
    return activations

AND

# we would include hook z, except that we now have base SAE's
# which will do their own reshaping for hook z.
has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v"]
if any(substring in hook_name for substring in has_head_dim_key_substrings):
    if head_index is None:
        replacement_hook = all_head_replacement_hook
    else:
        replacement_hook = single_head_replacement_hook
else:
    print("STANDARD REPLACEMENT HOOK") # XXX  <----- this probe prints when I do my config
    replacement_hook = standard_replacement_hook

The printout is ACTIVATIOSN IN STANDARD REPLACEMENT HOOK torch.Size([16, 128, 16, 64]) (with the typo and all).

Note: I added all of these print probes myself in the venv's site-packages, so ignore them; they are not part of SAE Lens.
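
To make the mismatch concrete, here is a minimal sketch built only from the shapes in the printout and the d_in in the config above (illustrative tensors, not SAELens code):

import torch

# hook_z hands the replacement hook an unflattened [batch, pos, n_heads, d_head]
# tensor, while the TrainingSAE's b_dec has shape [d_in] = [n_heads * d_head].
acts = torch.zeros(16, 128, 16, 64)   # shape from the probe printout
b_dec = torch.zeros(1024)             # d_in from the config (16 * 64)

try:
    _ = acts - b_dec                  # the subtraction in encode_with_hidden_pre
except RuntimeError as e:
    print(e)  # "The size of tensor a (64) must match the size of tensor b (1024) ..."

# Flattening the head dimensions first (as the activations store does during
# training) makes the broadcast line up:
_ = acts.flatten(-2, -1) - b_dec      # [16, 128, 1024] - [1024] broadcasts fine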


4gatepylon commented 4 months ago

Hotfix

This appears to fix it but I have not trained for many epochs.

Happy to PR this, but I'm also not 100% sure whether I'm just misusing a new release, and if not, whether there is a cleaner/better way to fix this with a more top-down refactor. I don't think tacking on a band-aid like this is the right approach; there should be a standard interface for all SAEs.

Hotfix idea below:

1. Add this to the top of evals.py

# HOTFIX
def shape_safe_enc_dec(activations: torch.Tensor, sae: SAE) -> torch.Tensor:
    """
    Enable the usage of shape-agnostic and non-shape-agnostic SAEs by handling
    shape transformations on their behalf.
    """
    _in_acts = activations
    _in_shape = activations.shape
    if _in_shape[-1] != sae.cfg.d_in:
        assert len(_in_acts.shape) >= 3
        _in_acts = activations.flatten(-2, -1)
    _out_acts = sae.decode(sae.encode(_in_acts)).to(activations.dtype)
    if _out_acts.shape != _in_shape:
        # TODO(Someone) this should have some asserts over the shape going back
        # in the right way
        _out_acts = _out_acts.reshape(_in_shape)
    assert _out_acts.shape == _in_shape
    return _out_acts
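
For what it's worth, a hypothetical smoke test of the helper (assuming sae is any SAE instance whose cfg.d_in is 1024; this check is not part of the patch itself):

acts = torch.randn(16, 128, 16, 64)   # hook_z-style activations from the failing run
out = shape_safe_enc_dec(acts, sae)   # flatten heads, encode/decode, reshape back
assert out.shape == acts.shape        # round-trips to the original shape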

2. Replace two lines in the same file

2A: run_evals function should end like this:

if activation_store.normalize_activations:
    original_act = activation_store.apply_norm_scaling_factor(original_act)

# send the (maybe normalised) activations into the SAE
# HOTFIX
sae_out = shape_safe_enc_dec(original_act, sae) # HOTFIX
del cache

l2_norm_in = torch.norm(original_act, dim=-1)
l2_norm_out = torch.norm(sae_out, dim=-1)
l2_norm_in_for_div = l2_norm_in.clone()
l2_norm_in_for_div[torch.abs(l2_norm_in_for_div) < 0.0001] = 1
l2_norm_ratio = l2_norm_out / l2_norm_in_for_div

metrics = {
   ...
}

return metrics

2B: Change the standard_replacement_hook value in get_recons_loss

def standard_replacement_hook(activations: torch.Tensor, hook: Any):
    # Handle rescaling if SAE expects it
    if activation_store.normalize_activations:
        activations = activation_store.apply_norm_scaling_factor(activations)

    # SAE-class-agnostic forward pass.
    activations = shape_safe_enc_dec(activations, sae).to(activations.dtype) # HOTFIX

    # Unscale if activations were scaled prior to going into the SAE
    if activation_store.normalize_activations:
        activations = activation_store.unscale(activations)
    return activations
jbloomAus commented 4 months ago

Ahhh, I know what's happened. Thanks for the detailed report. 3.0 involved special-casing hook_z for analysis, so reshaping by heads is handled by the SAE class rather than the store; however, the activation store used in training handles the reshaping as well, so that buffers don't have variable numbers of dimensions. This breaks run_evals in a way not covered by testing, and you only see it once a run is partway through training.
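
Restating that in code, a sketch using only the shapes from this report (not SAELens internals): during training the store hands the SAE pre-flattened rows, while run_with_hooks hands the replacement hook the raw per-head tensor.

import torch

d_in = 1024                               # n_heads * d_head = 16 * 64

# Training path: the activations store yields already-flattened rows, so the
# TrainingSAE can get away with an identity reshape_fn_in.
train_acts = torch.randn(128, d_in)       # [train_batch_size_tokens, d_in]

# Eval path: run_with_hooks passes hook_z output straight from the model,
# still split by head, so the identity reshape no longer matches d_in.
eval_acts = torch.randn(16, 128, 16, 64)  # [batch, pos, n_heads, d_head]

assert train_acts.shape[-1] == d_in
assert eval_acts.shape[-1] != d_in        # hence the broadcast error in encode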

I will do two to three things here:

  1. Fix this (not sure yet what I'll do, maybe one of the things you suggested).
  2. Add a test that would have caught this.
  3. (Maybe) Run evals once at the start of the script so if there's an issue you see it immediately.
jbloomAus commented 4 months ago

@4gatepylon Let me know if you still get any issues. I didn't run a training run but I added new tests and I expect it will work now.