Right now `get_final_layer_norm` returns `None` if its heuristics can't find a final normalization module, and `TunedLens` allows the `layer_norm` property to be `nn.Identity`. I think we should be more opinionated and require the user to supply an explicit key path (e.g. `model.final_ln` or whatever) when we can't find the layer norm with heuristics. Right now, whenever a new model architecture is added to HF (e.g. LLaMA), we may not pick up on its final LN module and end up silently training a tuned lens without it, even though the LN does exist.
We're already committed to basically only supporting pre-norm transformers, since post-norm ones are less obviously iterative, and ~all pre-norm transformers have a norm module before the unembedding. So I think we don't even need to support the case where there is no final normalization.
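Here's a rough sketch of the stricter lookup I have in mind. The function name, the `norm_path` argument, and the example key paths are illustrative only, and the heuristic is deliberately simplified compared to what `get_final_layer_norm` actually does today:

```python
from typing import Optional

import torch.nn as nn


def get_final_norm(model: nn.Module, norm_path: Optional[str] = None) -> nn.Module:
    """Return the model's final normalization module, raising if it can't be found.

    If the heuristics fail, the caller must pass an explicit dotted key path,
    e.g. norm_path="transformer.ln_f" (GPT-2) or norm_path="model.norm" (LLaMA).
    """
    if norm_path is not None:
        # get_submodule raises AttributeError on a bad path, which is what we want.
        return model.get_submodule(norm_path)

    # Simplified heuristic: look for a norm-like direct child of the decoder stack.
    base = getattr(model, "base_model", model)
    for name, child in base.named_children():
        if "norm" in name.lower() or name.lower() == "ln_f":
            return child

    raise ValueError(
        "Couldn't locate the final normalization module; pass an explicit "
        "key path via norm_path, e.g. norm_path='model.norm' for LLaMA."
    )
```

The key change is that we raise instead of returning `None` or falling back to `nn.Identity`, so a new architecture that slips past the heuristics fails loudly at load time rather than silently training without the norm.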
Note: we should probably use more neutral language in the code and be agnostic about which type of normalization is used. SOTA models often don't use LayerNorm specifically; LLaMA uses RMSNorm for example. I think that's why our heuristics failed in this case.
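A norm-agnostic check could be as simple as matching on the class name instead of `isinstance(module, nn.LayerNorm)` (just a sketch; the helper name is made up):

```python
import torch.nn as nn


def is_norm_like(module: nn.Module) -> bool:
    # Matches nn.LayerNorm as well as RMSNorm-style modules like LlamaRMSNorm,
    # without hard-coding any particular normalization class.
    return "norm" in type(module).__name__.lower()
```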