This appears to fix it but I have not trained for many epochs.
Happy to PR this, but I'm also not 100% sure whether I'm just misusing some new release, and, if not, whether there is a cleaner/better way to fix this with a more top-down refactor. I don't think tacking on a band-aid like this is the right approach; there should be a standard interface for all SAEs.
Hotfix idea below:
evals.py
import torch

from sae_lens import SAE  # imports as already present in evals.py


# HOTFIX
def shape_safe_enc_dec(activations: torch.Tensor, sae: SAE) -> torch.Tensor:
    """
    Enable the use of both shape-agnostic and non-shape-agnostic SAEs by
    handling shape transformations on their behalf.
    """
    _in_acts = activations
    _in_shape = activations.shape
    if _in_shape[-1] != sae.cfg.d_in:
        # e.g. attn.hook_z activations of shape (n_batch, n_ctx, n_heads, d_head):
        # collapse the trailing (n_heads, d_head) dims into d_model.
        assert len(_in_acts.shape) >= 3
        _in_acts = activations.flatten(-2, -1)
    _out_acts = sae.decode(sae.encode(_in_acts)).to(activations.dtype)
    if _out_acts.shape != _in_shape:
        # TODO(Someone) this should have some asserts that the shape is
        # restored the right way
        _out_acts = _out_acts.reshape(_in_shape)
    assert _out_acts.shape == _in_shape
    return _out_acts
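A quick smoke test of the helper (my own check, using a hypothetical identity stand-in rather than a real sae_lens.SAE):

# Hypothetical identity "SAE" with d_in = 1024, only for exercising the shapes.
class _FakeSAE:
    class cfg:
        d_in = 1024

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        return x

fake_sae = _FakeSAE()
flat_acts = torch.randn(16, 128, 1024)      # hook_mlp_out-style layout
headed_acts = torch.randn(16, 128, 16, 64)  # attn.hook_z-style layout
assert shape_safe_enc_dec(flat_acts, fake_sae).shape == flat_acts.shape
assert shape_safe_enc_dec(headed_acts, fake_sae).shape == headed_acts.shape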
# ... inside the reconstruction-metric code in evals.py:
if activation_store.normalize_activations:
    original_act = activation_store.apply_norm_scaling_factor(original_act)

# Send the (maybe normalised) activations into the SAE.
sae_out = shape_safe_enc_dec(original_act, sae)  # HOTFIX
del cache

l2_norm_in = torch.norm(original_act, dim=-1)
l2_norm_out = torch.norm(sae_out, dim=-1)
# Guard against dividing by (near-)zero input norms.
l2_norm_in_for_div = l2_norm_in.clone()
l2_norm_in_for_div[torch.abs(l2_norm_in_for_div) < 0.0001] = 1
l2_norm_ratio = l2_norm_out / l2_norm_in_for_div

metrics = {
    ...
}
return metrics
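As a sanity check on the near-zero-norm guard above (my own example, not from evals.py): rows with ~zero input norm divide by 1 instead of producing inf/nan:

l2_in = torch.tensor([4.0, 0.0])
l2_out = torch.tensor([2.0, 0.5])
safe = l2_in.clone()
safe[torch.abs(safe) < 0.0001] = 1  # same guard as above
print(l2_out / safe)  # tensor([0.5000, 0.5000]) -- no division by zero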
def standard_replacement_hook(activations: torch.Tensor, hook: Any):
    # Handle rescaling if the SAE expects it.
    if activation_store.normalize_activations:
        activations = activation_store.apply_norm_scaling_factor(activations)

    # SAE-class-agnostic forward pass.
    activations = shape_safe_enc_dec(activations, sae).to(activations.dtype)  # HOTFIX

    # Unscale if activations were scaled prior to going into the SAE.
    if activation_store.normalize_activations:
        activations = activation_store.unscale(activations)
    return activations
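For context, this is roughly how such a replacement hook gets wired in with TransformerLens (a sketch of the usual pattern, not the exact evals.py code; model, eval_tokens, and sae come from the surrounding eval code, and I assume the SAE config exposes hook_name):

# Run the model with the SAE spliced in at its hook point and measure loss.
recons_loss = model.run_with_hooks(
    eval_tokens,
    return_type="loss",
    fwd_hooks=[(sae.cfg.hook_name, standard_replacement_hook)],
)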
Ahhh I know what's happened. Thanks for the detailed report. 3.0 involved special-casing hook_z for analysis, so reshaping by heads is handled by the SAE class and not the store; however, the activation store used in training handles the reshaping as well, so that buffers don't have variable numbers of dimensions. This breaks run_evals in a way not covered by testing, and which you only see when it runs partway into a training run.
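To illustrate the mismatch with the shapes from the report (my own sketch): the training store buffers hook_z activations already flattened by heads, while run_evals caches the raw hook output, so the same SAE sees two different layouts:

import torch

n_batch, n_ctx, n_heads, d_head = 16, 128, 16, 64
d_model = n_heads * d_head  # 1024, i.e. sae.cfg.d_in

train_acts = torch.randn(n_batch, n_ctx, d_model)         # store: flattened by heads
eval_acts = torch.randn(n_batch, n_ctx, n_heads, d_head)  # raw attn.hook_z cache

assert train_acts.shape[-1] == d_model                    # fine as-is
assert eval_acts.flatten(-2, -1).shape[-1] == d_model     # needs flattening first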
I will do two to three things here:
@4gatepylon Let me know if you still get any issues. I didn't run a training run but I added new tests and I expect it will work now.
Description
Also, in case anyone reads this: I keep hitting an issue that is really annoying to track down. When I switch the hook point from hook_mlp_out to attn.hook_z, training fails after a fixed number of epochs (or whatever you'd call these; for me it's 9, with batch size 128, as you'll see below). It fails with the error below (first the error, then, at the bottom, the config as JSON, which is passed in as kwargs). Has anyone run into this error and found a fix? What is going on?
Error
Configuration
Code that Passes in Configuration
Solution (Not Acceptable)
Turn off wandb logging, or comment out some of the code in site-packages.
I can also just decide to never use z values and use MLP out instead. This works:
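A minimal sketch of that switch (hypothetical values; assuming SAELens's LanguageModelSAERunnerConfig, with the remaining fields as in the configuration above):

from sae_lens import LanguageModelSAERunnerConfig

# Hypothetical: train against MLP out instead of the attention z hook.
cfg = LanguageModelSAERunnerConfig(
    hook_name="blocks.0.hook_mlp_out",  # instead of "blocks.0.attn.hook_z"
    hook_layer=0,
    d_in=768,  # d_model for the model; n_heads * d_head when using hook_z
)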
More Notes
During training the activations have size n_ctx x 1 x d_model (which gets indexed on that 2nd index), whereas when you eval the size is n_batch x n_ctx x n_heads x d_head, where n_heads x d_head = d_model. => Maybe things are not getting properly reshaped?

You can look at the difference between self.activation_store.next_batch() in sae_trainer.py and eval_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts) in evals.py. Note that the latter is input into the MODEL and flows through it, whereas the former somehow comes from an activation cache (I have not yet fully understood the internal mechanism used for this).

I wonder if a disagreement between the caller and this code in training_sae.py is responsible. Compare the constructor in training_sae.py with get_recons_loss in evals.py.
The printout is ACTIVATIOSN IN STANDARD REPLACEMENT HOOK torch.Size([16, 128, 16, 64]) (typo and all). Note that all these print probes are ones I added in my venv's site-packages, so you can ignore them; they are not part of SAE Lens.

System Info
Running on an AWS g5.xlarge instance (AMI ami-05a1fdc12eccdf405).

Checklist