TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/
MIT License

Match Huggingface GPT2 implementation *exactly* #645

Closed by joelburget 3 months ago

joelburget commented 3 months ago

Description

I'm opening this to follow up on #570 / #641, discuss a bit of work I did, and discuss next steps.

I think there's a decent amount of evidence that TL's implementation of (at least) Mixtral differs from the "official" Hugging Face implementation. If this is a crux, I think it might be worthwhile to generate stronger evidence than the plots in #570.

My best guess is that we're seeing small differences in outputs that may compound across layers (and tokens). Again, if this is a crux, I might try to collect evidence for or against.

einops and F.linear are almost, but not quite, in agreement

Here's a snippet of code:

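```python
import einops
import torch
import torch.nn.functional as F

batch = 2
pos = 5
d_model = 768
d_head = 64
n_heads = 12

w = torch.normal(0, 0.02, size=(n_heads, d_model, d_head))
b = torch.normal(0, 0.1, size=(n_heads, d_head))
input = torch.normal(0, 1, size=(batch, pos, d_model))

# Approach 1: fold the heads into a single (n_heads * d_head, d_model) weight and call F.linear
w_ = einops.rearrange(w, "head_index d_model d_head -> (head_index d_head) d_model")
b_ = einops.rearrange(b, "head_index d_head -> (head_index d_head)")
result1 = F.linear(input, w_, b_).reshape(input.shape[0], input.shape[1], b.shape[0], b.shape[1])

# Approach 2: contract directly with einsum, then add the per-head bias
result2 = einops.einsum(
    input,
    w,
    "batch pos d_model, head_index d_model d_head -> batch pos head_index d_head",
) + b

# Mathematically identical; any nonzero value here comes from floating-point rounding
print((result1 - result2).abs().max())
```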

And a plot of the difference between result1 and result2:

[Screenshot, 2024-06-25: plot of result1 - result2]

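Ideally, there would be no difference, but it turns out that there is, probably because the floating-point operations happen in a different order. Floating-point addition is not associative, so accumulating the same products in a different order can change the last few bits of the result.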
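A minimal, library-free illustration of that effect: the same three numbers summed with different grouping give results that differ in the last bit. This is only a toy; it makes no claim about which order einsum or F.linear kernels actually use.

```python
def left_to_right(a: float, b: float, c: float) -> float:
    # (a + b) + c
    return (a + b) + c


def right_to_left(a: float, b: float, c: float) -> float:
    # a + (b + c)
    return a + (b + c)


x = left_to_right(0.1, 0.2, 0.3)
y = right_to_left(0.1, 0.2, 0.3)
print(x, y, x == y)  # 0.6000000000000001 0.6 False
```

Each individual rounding error is tiny (about 1e-16 here), which is why the concern in this thread is about how such errors might accumulate across many layers and tokens.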

This PR

This PR is intended to demonstrate that, by avoiding einsum, it is possible to match the canonical implementation exactly. The added test failed before this change and passes now.

You'll note that simple_attn_linear is now implemented using F.linear, but I haven't yet worked out how to do the same for complex_attn_linear, or whether it's even possible.
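One possible einsum-free formulation for the per-head case, sketched under assumptions: that complex_attn_linear receives a separate input per head with shape [batch, pos, head_index, d_model], weights [head_index, d_model, d_head], and bias [head_index, d_head]. The function names below are hypothetical and do not reflect TL's actual signature. This only swaps einsum for a broadcasted torch.matmul; it is not verified to match any Hugging Face kernel bit-for-bit.

```python
import torch


def complex_attn_linear_matmul(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical einsum-free per-head projection.

    x: [batch, pos, head_index, d_model]
    w: [head_index, d_model, d_head]
    b: [head_index, d_head]
    returns: [batch, pos, head_index, d_head]
    """
    # Put heads first so each head's input lines up with its own weight matrix:
    # [head_index, batch, pos, d_model]
    x_h = x.permute(2, 0, 1, 3)
    # Broadcast each head's weight over the batch dim: [head_index, 1, d_model, d_head]
    out = torch.matmul(x_h, w.unsqueeze(1))  # [head_index, batch, pos, d_head]
    # Back to [batch, pos, head_index, d_head], then add the per-head bias
    return out.permute(1, 2, 0, 3) + b


def complex_attn_linear_einsum(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Reference einsum formulation of the same contraction
    return torch.einsum("bphm,hmd->bphd", x, w) + b


torch.manual_seed(0)
x = torch.randn(2, 5, 12, 768)
w = torch.randn(12, 768, 64) * 0.02
b = torch.randn(12, 64) * 0.1
diff = (complex_attn_linear_matmul(x, w, b) - complex_attn_linear_einsum(x, w, b)).abs().max()
print(diff)
```

Whether this ordering matches a given reference model exactly would still need to be checked empirically with the kind of test this PR adds.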

Behind the scenes, the Hugging Face implementation of GPT2 uses Conv1D (which calls torch.addmm), while Mixtral uses torch.nn.Linear (which calls F.linear).
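To make the two code paths concrete, here is a minimal sketch modeled on (not copied from) a Conv1D-style layer: the weight is stored as [in_features, out_features] and the forward pass calls torch.addmm on a flattened input, whereas F.linear expects the weight as [out_features, in_features]. The class name `Conv1DSketch` is hypothetical. The two results are mathematically equal but are not guaranteed to be bitwise identical.

```python
import torch
import torch.nn.functional as F


class Conv1DSketch(torch.nn.Module):
    """Hypothetical Conv1D-style layer: weight is [nx, nf] = [in, out]."""

    def __init__(self, nf: int, nx: int):
        super().__init__()
        self.nf = nf
        self.weight = torch.nn.Parameter(torch.randn(nx, nf) * 0.02)
        self.bias = torch.nn.Parameter(torch.zeros(nf))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        size_out = tuple(x.shape[:-1]) + (self.nf,)
        # Flatten leading dims, then compute bias + x_2d @ weight in one addmm call
        x2d = x.reshape(-1, x.shape[-1])
        return torch.addmm(self.bias, x2d, self.weight).reshape(size_out)


torch.manual_seed(0)
layer = Conv1DSketch(nf=2304, nx=768)  # e.g. a fused QKV projection, 768 -> 3 * 768
x = torch.randn(2, 5, 768)

with torch.no_grad():
    via_addmm = layer(x)
    # F.linear wants [out, in], so pass the transposed weight
    via_linear = F.linear(x, layer.weight.t(), layer.bias)

print((via_addmm - via_linear).abs().max())
```

Note the layout difference: porting a Conv1D weight into an F.linear-based implementation requires a transpose, which is one more place where the two implementations can diverge.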

Tests and Questions

Getting the new test to pass for just GPT2 was a lot of effort and it raises some questions:

  1. Is it even worth it? I think this is somewhat of an empirical question and may require some more research to answer well. The answer might differ depending on the size and type of the model. For example, it may matter for Mixtral but not GPT2.
  2. Does this imply a ton of bespoke work for each model? Do some models' canonical implementations use addmm / F.linear while others use einsum? My guess is that addmm / F.linear are more common.
  3. I haven't added a test for Mixtral yet, because it would either (a) require a beefy machine to run or (b) require a way to load a single layer (for both the TL and HF models), which will be a bit of work.

Overall, my take is that addmm / F.linear is slightly less readable, but likely to match most models' implementations more closely, so it is probably worth using as the default. I also like the idea of testing that the TL implementation exactly matches HF, but I'm not sure whether this type of test would be flaky.
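A sketch of the two kinds of test being weighed here, assuming both implementations return a tensor of the same shape: an exact (bitwise) comparison with torch.equal, which is the strictest check and the one most likely to be sensitive to hardware, PyTorch version, or kernel choice, versus a tolerance-based check with torch.testing.assert_close. The helper `compare_outputs` and its return format are hypothetical.

```python
import torch


def compare_outputs(
    tl_out: torch.Tensor, hf_out: torch.Tensor, atol: float = 1e-5, rtol: float = 1e-5
) -> dict:
    """Hypothetical helper: report exact and approximate agreement between two outputs."""
    # True only if shapes match and every element is bitwise equal
    exact = torch.equal(tl_out, hf_out)
    max_abs_diff = (tl_out - hf_out).abs().max().item()
    try:
        # Raises AssertionError if any element is outside atol + rtol * |hf_out|
        torch.testing.assert_close(tl_out, hf_out, atol=atol, rtol=rtol)
        close = True
    except AssertionError:
        close = False
    return {"exact": exact, "close": close, "max_abs_diff": max_abs_diff}


a = torch.tensor([1.0, 2.0, 3.0])
print(compare_outputs(a, a.clone()))  # exact and close
print(compare_outputs(a, a + 1e-6))   # not exact, but within tolerance
```

Reporting both flags (plus the max difference) would let a test suite assert exact agreement where it is known to hold and fall back to a tolerance elsewhere.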

Type of change

Discussion, possibly bug fix.


bryce13950 commented 3 months ago

OK, so first off, I think getting these tests to pass is really worth it. We need to accept a certain amount of inaccuracy, but the current level is unacceptable. My hypothesis is that every model supported in TransformerLens deviates slightly from its Hugging Face implementation. This inaccuracy, which may actually have been larger for GPT2 than for Mixtral, has not been obvious on smaller models, because the accumulated error is not significant enough to affect them over a limited number of passes. To test this, my idea is to run a small model through a very large number of generation loops. I believe it will eventually begin to deteriorate as the error accumulates.
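One way the proposed long-generation experiment could be instrumented, as a toy sketch: run two implementations side by side from the same starting point, feed each its own output for many steps, and record how far apart they drift. `track_divergence` and the two toy step functions are hypothetical stand-ins for a TL model and its HF reference; the toy functions say nothing about how real models behave.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def track_divergence(
    step_a: Callable[[Vector], Vector],
    step_b: Callable[[Vector], Vector],
    x0: Sequence[float],
    n_steps: int,
) -> List[float]:
    """Iterate both step functions from x0; return the max abs difference after each step."""
    a, b = list(x0), list(x0)
    diffs: List[float] = []
    for _ in range(n_steps):
        a, b = step_a(a), step_b(b)
        diffs.append(max(abs(p - q) for p, q in zip(a, b)))
    return diffs


# Toy "layers": the same mixing operation, with the sum accumulated in opposite orders.
def mix_forward(v: Vector) -> Vector:
    total = 0.0
    for t in v:
        total += t
    return [t * 0.5 + total * 0.1 for t in v]


def mix_reversed(v: Vector) -> Vector:
    total = 0.0
    for t in reversed(v):
        total += t
    return [t * 0.5 + total * 0.1 for t in v]


diffs = track_divergence(mix_forward, mix_reversed, [0.1, 0.2, 0.3, 0.4], n_steps=50)
print(diffs[:3], diffs[-1])
```

With real models, the step functions would be one generation step of each implementation, and the recorded curve is what would confirm or refute the accumulation hypothesis.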

Going forward, all model implementations introduced to TransformerLens should match Hugging Face. I have spent quite a bit of time over the last couple of weeks reading through the transformers codebase, and so far I have not seen einsum used. That's not to say there are no uses, but there's probably a reason Hugging Face doesn't rely on it. In our case, accuracy is 100 times more important than readability, so removing einsum across the whole codebase may be a very important task.

Improving this accuracy is going to require a ton of work for each model, and probably a community initiative to help identify where it is most needed. It is absolutely vital, though, to fix this inaccuracy as quickly as possible, and then put systems in place that track accuracy and flag when it needs improving. This work will lay the groundwork for supporting more complicated models in the future; without a solution now, our ability to support larger models will be limited.

It isn't feasible to have full-model test suites for every model we support, but where we can have them, we should. As for slicing out a single layer and testing that, the idea has been discussed in great detail, and it will happen at some point. The hardest part right now is identifying which changes will have the most impact on the underlying issue. When I was debugging Mixtral in depth last week, there were quite a few moments when I changed something and accuracy suddenly improved. A lot of those changes were really small, with the biggest one being syncing the MoE component with Hugging Face. I think we are going to find a lot of little problems across the board, and a few larger ones.

What I think we should do is focus on improving accuracy in smaller models. I believe that if we do that, the improved accuracy will ripple out into the more complicated models. What you have done here is a really good step in that direction. I would like to wrap up a few tasks and put out a release within the next couple of days, and I would really like to include this in that release. Is there anything I can do to help that?

joelburget commented 3 months ago

> Is there anything I can do to help that?

A few ideas:

  1. General code review

  2. Help fix tests/acceptance/test_activation_cache.py::test_logit_attrs_matches_reference_code, which is failing in CI (though it passes for me locally).

  3. complex_attn_linear is only implemented as a placeholder at the moment. You could help figure out the best way to implement it without einsum (though it's probably fine to merge this as-is for now).

  4. Implement tests for more models.

I can work on all of these tasks as well, though with limited bandwidth in the next few days.