TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/
MIT License

[Bug Report] `load_and_process_state_dict` handles LayerNorm folding poorly #219

Open afspies opened 1 year ago

afspies commented 1 year ago

Describe the bug If one attempts to load the state_dict of a model that was saved without folded LayerNorms (i.e. using LayerNorm rather than LayerNormPre), calling load_and_process_state_dict(state_dict, fold_ln=True) fails because of the strict key matching in torch's load_state_dict. This can be circumvented by instead doing:

model = HookedTransformer(model_cfg)
# Load the raw weights without folding, then fold them manually
model.load_and_process_state_dict(state_dict, fold_ln=False)
model.process_weights_(fold_ln=True)
# Re-run setup so the hooks are properly re-attached after folding
model.setup()

Without calling model.setup(), the LayerNorm hooks remain inside the model but are not properly attached, so the expected activations are not returned by run_with_cache, which breaks the ActivationCache manipulation helpers.
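
To illustrate the hook problem, here is a minimal sketch of the symptom described above, assuming model_cfg and state_dict are defined as before (the prompt string and the "ln" substring filter are just illustrative):

model = HookedTransformer(model_cfg)
model.load_and_process_state_dict(state_dict, fold_ln=False)
model.process_weights_(fold_ln=True)
# model.setup() deliberately not called here

_, cache = model.run_with_cache("hello world")
# Without setup(), the LayerNorm-related activations never make it into the
# cache, so this list comes back empty or incomplete:
print([name for name in cache.keys() if "ln" in name])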

Additionally, if the original model was saved with folded layernorms, calling load_and_process_state_dict(state_dict, fold_ln=True) raises an error because no layernorm parameters are located in the state dict.

jbloomAus commented 1 year ago

@afspies it feels like the issue here is one of not having a nicer error/clearer behavior. Would it feel like a solution to you if when you call load_and_process_state_dict(state_dict, fold_ln=True) you get a warning + automatic fix or would you prefer an informative error?

I'd lean towards an informative error personally, but curious what you think would have been best.
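
To make the first option concrete, a rough sketch of what the warning + automatic fix could look like inside load_and_process_state_dict (purely hypothetical; the key names used for detection are assumptions, not the current implementation):

import logging

# Hypothetical guard: if fold_ln=True but the state_dict has no LayerNorm
# weights left to fold, warn and skip folding instead of crashing.
has_ln_weights = any(".ln1.w" in key or key == "ln_final.w" for key in state_dict)
if fold_ln and not has_ln_weights:
    logging.warning(
        "fold_ln=True was passed, but the state_dict contains no LayerNorm "
        "weights to fold (it looks already folded); skipping folding."
    )
    fold_ln = False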

neelnanda-io commented 1 year ago

Ah, so the intended behaviour here was that if model_cfg says LayerNormPre everything works as intended - you're loading an unfolded state dict into a model that expected things to be folded. If model_cfg says LayerNorm then things break. Does this match what you see?

+1 to informative errors being good
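
To make the intended path concrete, a short sketch, assuming the relevant config field is normalization_type and that "LNPre" is the folded variant:

from copy import deepcopy
from transformer_lens import HookedTransformer

# A model loaded with fold_ln=False keeps its LayerNorm weights in the state_dict
model = HookedTransformer.from_pretrained("NeelNanda/GELU_1L512W_C4_Code", fold_ln=False)
state_dict = model.state_dict()

# A config that expects folded weights: loading the unfolded state_dict with
# fold_ln=True into this model is the intended, working path described above
cfg = deepcopy(model.cfg)
cfg.normalization_type = "LNPre"
model_folded = HookedTransformer(cfg)
model_folded.load_and_process_state_dict(state_dict, fold_ln=True)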

afspies commented 1 year ago

I would have expected fold_ln=True to always produce a HookedTransformer with properly folded layernorms, regardless of whether the state_dict being loaded already had them folded (perhaps with a suitable warning). An informative error may be sufficient, though we may also want to warn the user that separately processing weights without calling .setup() leaves the hooks broken.

Put another way, the following example illustrates both potential issues:

from copy import deepcopy
from transformer_lens import HookedTransformer

model_base = HookedTransformer.from_pretrained('NeelNanda/GELU_1L512W_C4_Code', fold_ln=True)
state_dict = model_base.state_dict()

model_two = HookedTransformer(model_base.cfg)
model_two.load_and_process_state_dict(deepcopy(state_dict), fold_ln=True)  # fails in both cases below

If the first model is loaded with folded layernorms, then the second model can't be loaded with folded layernorms (we try to fold layernorms again, but the corresponding entries are no longer in the state_dict, so folding crashes).

If the first model is loaded without folded layernorms, then the second model still can't be loaded with folded layernorms: the second model is initialized from a config without layernorms folded, which causes a key mismatch when the folding happens inside load_and_process_state_dict, unless the steps shown in the original issue description are taken (this is only an issue because torch's load_state_dict enforces strict key matching).

alan-cooney commented 1 year ago

If anyone wants to add a PR to improve the error message here that would be great
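
For reference, a rough sketch of what the informative errors discussed above could check for inside load_and_process_state_dict (hypothetical; key names, config field, and wording are assumptions, not the library's current behaviour):

# Case 1: the state_dict was saved from an already-folded model
has_ln_weights = any(".ln1.w" in key or key == "ln_final.w" for key in state_dict)
if fold_ln and not has_ln_weights:
    raise ValueError(
        "fold_ln=True, but the state_dict contains no LayerNorm weights to fold. "
        "It was probably saved from a model whose LayerNorms were already folded; "
        "load it with fold_ln=False instead."
    )

# Case 2: the target model was built with unfolded LayerNorms, so the folded
# keys won't match under torch's strict load_state_dict
if fold_ln and self.cfg.normalization_type == "LN":
    raise ValueError(
        "fold_ln=True removes the LayerNorm weights from the state_dict, but this "
        "model was built with normalization_type='LN'. Either build it with "
        "normalization_type='LNPre', or load with fold_ln=False and then call "
        "process_weights_(fold_ln=True) followed by setup()."
    )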