TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/
MIT License

Added support for Gemma-2 #650

Closed · neelnanda-io closed this 8 hours ago

neelnanda-io commented 6 days ago

Added support for Gemma-2 models.

Key differences between Gemma-1 and Gemma-2 (per the Gemma-2 report):

- Logit soft-capping: attention scores and final logits are squashed with a tanh (see the sketch below).
- Extra RMSNorms: each attention and MLP sub-layer is normalised on its output as well as its input.
- Attention alternates between local sliding-window layers and global layers.
- Grouped-query attention.
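For intuition, the soft-capping is just a tanh squash; a minimal sketch (constants from the Gemma-2 report, not TransformerLens code):

```python
import torch

def soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Squashes values into (-cap, cap); Gemma-2 applies this to attention
    # scores (cap = 50.0) and to the final logits (cap = 30.0).
    return cap * torch.tanh(x / cap)
```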

neelnanda-io commented 6 days ago

Looks like abstract_attention.py assumed that n_heads * d_head == d_model, which is normally true but not for Gemma-2 27B. I fixed that in the latest commit (this should be a no-op for any model where that relation holds).
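For anyone following along, a sketch of the general output projection (not the actual abstract_attention.py code) that avoids the removed assumption:

```python
import torch

# Gemma-2 27B: n_heads * d_head (32 * 128 = 4096) != d_model (4608).
batch, pos, n_heads, d_head, d_model = 2, 8, 32, 128, 4608
z = torch.randn(batch, pos, n_heads, d_head)  # per-head attention outputs
W_O = torch.randn(n_heads, d_head, d_model)   # output projection

# Project each head into the residual stream and sum; this is correct whether
# or not n_heads * d_head happens to equal d_model.
out = torch.einsum("bpnh,nhd->bpd", z, W_O)
assert out.shape == (batch, pos, d_model)
```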

neelnanda-io commented 6 days ago

Hmm, there's at least a 0.2 difference in the logits for the 27B in float32, which is concerning... and quite surprising tbh. I don't see any architectural difference between the 9B and 27B, which suggests this would just be cascading errors? I did get the HF logits on CPU and the TL logits across 2 GPUs, though, which might cause some additional divergence.
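For reference, the comparison being described is along these lines (a sketch, not the exact script; model name strings are assumed, and loading both copies in float32 is memory-hungry for the 27B):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

MODEL = "google/gemma-2-27b"  # assumed HF id
tokenizer = AutoTokenizer.from_pretrained(MODEL)
tokens = tokenizer("The capital of France is", return_tensors="pt").input_ids

hf_model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float32)
with torch.no_grad():
    hf_logits = hf_model(tokens).logits

tl_model = HookedTransformer.from_pretrained("gemma-2-27b", dtype=torch.float32)
tl_logits = tl_model(tokens.to(tl_model.cfg.device), return_type="logits")

# Cascading numerical error shows up as a large max abs difference here.
print((hf_logits - tl_logits.cpu()).abs().max())
```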

ArthurConmy commented 6 days ago

> Hmm, there's at least a 0.2 difference in the logits for the 27B in float32, which is concerning... and quite surprising tbh. I don't see any architectural difference between the 9B and 27B, which suggests this would just be cascading errors? I did get the HF logits on CPU and the TL logits across 2 GPUs, though, which might cause some additional divergence.

I think you should move on and just put a loud warning when loading the 27B model. We have lots of evidence that TransformerLens's slight differences from HuggingFace cascade: we see much worse numerical error in models with many layers, e.g. here and here. @bryce13950 has mentioned trying to improve the numerical errors, so I doubt you made an implementation error. The only question to me is whether the 27B model should even be merged.
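A loud warning of the kind suggested could be as simple as this (hypothetical placement; not claiming this is where TransformerLens's loading path would put it):

```python
import logging

# Illustrative, not an exhaustive list of divergent checkpoints.
NUMERICALLY_SUSPECT = {"gemma-2-27b", "gemma-2-27b-it"}

def warn_if_numerically_suspect(model_name: str) -> None:
    if model_name in NUMERICALLY_SUSPECT:
        logging.warning(
            "%s: logits may differ noticeably from HuggingFace due to "
            "cascading numerical error; treat exact logit values with caution.",
            model_name,
        )
```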

neelnanda-io commented 5 days ago

Fair. I think the 27B should be merged, but printing a warning sounds good. I was mostly surprised at such a big jump in the error between the 9B and the 27B, when even cascading errors shouldn't explain that IMO - it's 42 vs 46 layers; the 27B is just a fair bit wider, with about twice as many neurons per MLP layer.

bryce13950 commented 4 days ago

I am going to try to fold some accuracy improvements, specifically in the MLPs, into this when I put it up. I wouldn't worry about adding warnings or anything for the time being. I have a list of models to try once that is done; this is second on the list now.

bryce13950 commented 4 days ago

We crossed paths a little bit on this. I have done quite a bit of work on MLPs recently, and we did a few very similar things. I have been thinking about how to put the two changes together, and I think I am going to wrap up my branch first, given that it affects accuracy for existing models. I am basically redoing the entire set of components, and I would bet that what I am doing will increase accuracy here as well.

bryce13950 commented 4 days ago

Just one comment from me on the code so far. I am going to wrap up my work, then come back to this to test it and load the code up locally to play around a bit.

JThh commented 1 day ago

Hey @neelnanda-io, thanks for making this PR to support the Gemma-2 series models. I am aware that this PR is not ready yet, but I am using what it has so far to train an SAE on the resid_post position. The run is being recorded here: https://wandb.ai/jiatongg/sae_semantic_entropy/runs/e88i5gcc?nw=nwuserjiatongg. It might also be a good auxiliary reference for further code adjustments to this PR.

neelnanda-io commented 1 day ago

Cool! Let me know if you run into any issues


bryce13950 commented 8 hours ago

MLP outputs are now perfect on these models
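One way to check a claim like this per layer (a sketch: it assumes HF's Gemma-2 exposes the MLP at model.model.layers[i].mlp, and uses from_pretrained_no_processing so TL's weight processing doesn't change intermediate activations):

```python
import torch
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

tokens = torch.tensor([[2, 651, 6037, 576]])  # arbitrary token ids

hf_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
captured = {}
hf_model.model.layers[0].mlp.register_forward_hook(
    lambda _mod, _inp, out: captured.update(mlp_out=out)
)
with torch.no_grad():
    hf_model(tokens)

# No folding/centering, so intermediate activations stay comparable to HF's.
tl_model = HookedTransformer.from_pretrained_no_processing("gemma-2-9b")
_, cache = tl_model.run_with_cache(tokens.to(tl_model.cfg.device))

# "Perfect" would mean a diff at (or near) floating-point precision.
print((captured["mlp_out"].cpu() - cache["blocks.0.hook_mlp_out"].cpu()).abs().max())
```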

bryce13950 commented 8 hours ago

[image attachment]