AlignmentResearch / tuned-lens

Tools for understanding how transformer predictions are built layer-by-layer
https://tuned-lens.readthedocs.io/en/latest/

get_final_layer_norm should throw if the LN can't be found #56

Closed · norabelrose closed this issue 1 year ago

norabelrose commented 1 year ago

Right now get_final_layer_norm returns None if its heuristics can't find a final normalization module, and TunedLens allows the layer_norm property to be nn.Identity. I think we should be more opinionated and require the user to supply an explicit key path (e.g. model.final_ln or whatever) when the heuristics fail. As it stands, whenever a new model architecture is added to HF (e.g. LLaMA), we may not pick up its final LN module, and we end up silently training a tuned lens without the LN even though the LN does exist.
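
A minimal sketch of the proposed behavior, not the library's current code: raise instead of returning None, and accept an explicit key path as a fallback. The `norm_key` parameter and the name-based heuristic below are assumptions for illustration.

```python
from typing import Optional

import torch.nn as nn


def get_final_layer_norm(model: nn.Module, norm_key: Optional[str] = None) -> nn.Module:
    """Return the model's final normalization module, raising if it can't be found."""
    if norm_key is not None:
        # User-supplied key path, e.g. "model.norm" for HF's LLaMA implementation.
        return model.get_submodule(norm_key)

    # Heuristic sketch: the last registered module whose class name contains
    # "Norm", so LayerNorm, RMSNorm, etc. all match.
    candidates = [m for _, m in model.named_modules() if "Norm" in type(m).__name__]
    if not candidates:
        raise ValueError(
            f"Could not find the final normalization module of {type(model).__name__}; "
            "pass its key path explicitly via norm_key."
        )
    return candidates[-1]
```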

We're already committed to basically only supporting pre-norm transformers, since post-norm ones are less obviously iterative, and ~all pre-norm transformers have a norm module before the unembedding. So I think we don't even need to support the case where there is no final normalization.
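
If we only support pre-norm models, the lens side could refuse an identity "norm" outright instead of silently training without it. A hypothetical validation helper (not the current TunedLens API) might look like:

```python
import torch.nn as nn


def require_final_norm(layer_norm: nn.Module) -> nn.Module:
    """Reject a missing/identity final norm rather than silently accepting it."""
    if isinstance(layer_norm, nn.Identity):
        raise ValueError(
            "A real final normalization module is required; got nn.Identity. "
            "Pass the model's final norm (or its key path) explicitly."
        )
    return layer_norm
```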

Note: we should probably use more neutral language in the code and be agnostic about which type of normalization is used. SOTA models often don't use LayerNorm specifically; LLaMA, for example, uses RMSNorm. I think that's why our heuristics failed in this case.
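
To illustrate the failure mode: HF's LLaMA norm is `LlamaRMSNorm`, which is not a subclass of nn.LayerNorm, so an isinstance-based heuristic skips it, while a class-name check is norm-agnostic. The class below is a stand-in with the same name as HF's, defined here so the snippet runs without transformers installed.

```python
import torch.nn as nn


class LlamaRMSNorm(nn.Module):
    """Stand-in sharing its class name with HF's LLaMA RMSNorm."""


norm = LlamaRMSNorm()
print(isinstance(norm, nn.LayerNorm))   # False: a LayerNorm-only check misses it
print("Norm" in type(norm).__name__)    # True: name-based matching finds it
```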