TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/

[Proposal] Allow tied embeddings #671

Open neelnanda-io opened 2 months ago

neelnanda-io commented 2 months ago

Proposal

TransformerLens assumes all models have untied embeddings (i.e. W_U ≠ W_E.T). This is a good assumption in general, and it needs to be true if LayerNorm is folded, but it costs more memory.
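To see why folding forces untying: folding absorbs the final LayerNorm's scale into W_U, which changes W_U but not W_E. A toy illustration of just the shapes involved (not the actual TransformerLens folding code):

```python
import torch

d_model, d_vocab = 4, 10
W_E = torch.randn(d_vocab, d_model)   # embedding, [d_vocab, d_model]
W_U = W_E.T.clone()                   # tied unembedding, [d_model, d_vocab]
ln_w = torch.randn(d_model)           # final LayerNorm scale

# Folding rescales each d_model row of W_U by the LN weights...
W_U_folded = ln_w[:, None] * W_U
# ...so W_U no longer equals W_E.T, and the tie must be broken first.
assert not torch.allclose(W_U_folded, W_E.T)
```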

This is particularly bad for Gemma models, which have tied embeddings and a very large vocab size: e.g. ~25% of Gemma 2 2.6B's parameters are W_E, and ~10% of Gemma 2 9B's. I think it would be great to load the tied models with tied embeddings by default (so W_U.data = W_E.data.T), plus a helper function to clone the matrix and untie it if need be. This would involve adding a tied_embeddings field to the Config, which defaults to False but is set to True for select models like GPT-2, Gemma, and Gemma 2, and which gets set back to False if LayerNorm folding is applied.
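A minimal sketch of how the tie and the untying helper could work. The names `tie_embeddings`/`untie_embeddings` are hypothetical, not existing TransformerLens API; `model.embed.W_E` ([d_vocab, d_model]) and `model.unembed.W_U` ([d_model, d_vocab]) follow TransformerLens's conventions:

```python
import torch

def tie_embeddings(model):
    # Hypothetical helper. .T is a view in PyTorch, so this shares
    # storage with W_E instead of keeping a second [d_vocab, d_model] copy.
    model.unembed.W_U.data = model.embed.W_E.data.T

def untie_embeddings(model):
    # Hypothetical helper to clone the matrix before anything
    # (like LayerNorm folding) needs to modify W_U independently of W_E.
    model.unembed.W_U.data = model.embed.W_E.data.T.clone().contiguous()
```

With a `cfg.tied_embeddings` flag defaulting to False, the loading code could call `tie_embeddings` for the flagged architectures and `untie_embeddings` whenever LayerNorm folding is requested.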

I'd love for people to be able to work with the Gemma 2 models while keeping a bunch of SAEs in memory, so memory efficiency is important (and folding LayerNorm isn't that important).
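For a rough sense of the savings, a back-of-the-envelope estimate (the shapes below are assumptions, approximately matching Gemma's ~256k vocab):

```python
# Assumed embedding shapes for Gemma 2 2.6B, for illustration only:
d_vocab, d_model = 256_128, 2_304
embed_params = d_vocab * d_model            # ~590M params, ~23% of 2.6B
untied_copy_gb = embed_params * 2 / 1e9     # extra bytes for an untied W_U

print(f"W_E: {embed_params/1e6:.0f}M params; an untied W_U "
      f"duplicates ~{untied_copy_gb:.1f} GB in bf16")
```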

bryce13950 commented 2 months ago

I did a quick little experiment in the architecture-specific weight conversions to see if tying the weights there is sufficient: https://github.com/TransformerLensOrg/TransformerLens/tree/experiment-gemma-weight-tying. This still needs to be tested, though. I am not sure whether what I did is enough to solve the issue, and this is the sort of change I am a bit wary about, since it can probably ripple out if done incorrectly. If you have time to mess with my branch, that would be super helpful. I am pretty full on time for the next couple of weeks wrapping some other things up, but once I do have time I would be happy to experiment with this a bit. If it seems to work well, then I will probably revise the weight conversions to share a bit of code, so that these sorts of system-wide changes can be made in a central place without too many issues.
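The shape of that change, roughly, would be to reuse the embedding tensor in the converter instead of materialising a second copy. A hedged sketch of the idea (illustrative only; this is not the code on the branch, and `convert_gemma_weights` is a stand-in name):

```python
import torch

def convert_gemma_weights(hf_model, cfg):
    """Illustrative skeleton of an architecture-specific converter."""
    state_dict = {}
    embed = hf_model.model.embed_tokens.weight      # [d_vocab, d_model]
    state_dict["embed.W_E"] = embed
    if getattr(cfg, "tied_embeddings", False):
        # Reuse the same tensor transposed: no second copy in memory.
        state_dict["unembed.W_U"] = embed.T
    else:
        state_dict["unembed.W_U"] = embed.T.clone()
    # ... remaining per-block weights would be converted here ...
    return state_dict
```

One thing to verify is that whatever loads this state dict preserves the sharing rather than copying each tensor, since a naive `load_state_dict` copies values into freshly allocated parameters.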