TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/
MIT License

Match Huggingface MLP implementation exactly. #641

Closed joelburget closed 2 weeks ago

joelburget commented 2 weeks ago

Description

#570 showed that there are small differences in outputs between the Huggingface implementation of models, which uses a fused add-multiply (this presumably also matches how the models were trained), and TransformerLens, which doesn't. It seems likely that in bigger models (e.g. Mixtral) these errors accumulate enough to be significant. This is my attempt to begin fixing the problem.
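A minimal illustration of why fusing can change results (plain Python, not the actual TransformerLens or Huggingface code; the variable names are just stand-ins): floating-point addition is not associative, so folding the bias add into the matmul kernel can round differently from computing the matmul first and adding the bias afterwards.

```python
# Floating-point addition is not associative: the order in which the
# bias add is rounded changes the result.
a = 1e16   # stands in for one large partial product in the matmul
b = 1.0    # stands in for the bias term
c = -1e16  # stands in for another partial product that cancels a

fused_like = (a + c) + b    # cancellation happens first, then the bias add
unfused_like = (a + b) + c  # b is rounded away when added to the large a

print(fused_like)    # 1.0
print(unfused_like)  # 0.0
```

Each individual discrepancy is tiny relative to the activations, but across many layers such rounding differences can compound, which is consistent with the larger divergence seen on big models.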

Note that we take a slight readability hit, but I believe this is worth it for correctness.

For now I've only applied this to the implementation in mlp.py. There are several other places where the same fix should probably be applied.

Type of change

Bug fix (non-breaking change which fixes an issue)

Checklist:

bryce13950 commented 2 weeks ago

Very nice. A couple of things... I am not a huge fan of the test name; it's not super clear at a glance what exactly it is testing. A slightly more explicit name would make that clear, maybe something like test_compare_huggingface_logits_match_local_implementation. After debugging this Mixtral issue, I imagine this is going to become a heavily used test for making sure our implementation matches. Also, do you have time to put together a unit test for the MLP? It is structured in a much better way now, and it is a lot easier to unit test after your change.
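For what such a comparison test could look like, here is a hedged sketch (the function names and tolerance are illustrative, not the actual TransformerLens test; the real test would obtain both sets of logits from model forward passes rather than hard-coded values):

```python
import math

def assert_logits_close(hf_logits, tl_logits, atol=1e-4):
    """Fail if any pair of corresponding logits differs by more than atol."""
    for hf_row, tl_row in zip(hf_logits, tl_logits):
        for hf_val, tl_val in zip(hf_row, tl_row):
            assert math.isclose(hf_val, tl_val, abs_tol=atol), (hf_val, tl_val)

def test_compare_huggingface_logits_match_local_implementation():
    # Placeholder values; in the real test these would be the logits from
    # the Huggingface model and the TransformerLens model on the same input.
    hf_logits = [[0.10000, -1.20000], [2.50000, 0.30000]]
    tl_logits = [[0.10003, -1.19998], [2.49997, 0.30002]]
    assert_logits_close(hf_logits, tl_logits)
```

The key design choice is comparing with an explicit absolute tolerance rather than exact equality, since even a correctly matched implementation can differ at the level of floating-point rounding.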

bryce13950 commented 2 weeks ago

I am going to mess around with this now. I think it will probably be ready to merge shortly. For future reference, there is a make command (make format) that takes care of all formatting, which will then pass the CI.