TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/
MIT License

[Bug Report] hook_resid_pre doesn't match hidden_states #346

Closed loganriggs closed 1 year ago

loganriggs commented 1 year ago

Describe the bug: cache[f"blocks.{x}.hook_resid_pre"] doesn't match hidden_states (or only matches up to a limited number of decimal places).

hidden_states comes from the transformers call model(tokens, output_hidden_states=True). I've checked whether it matches the post-LayerNorm activations instead, but it does not. It's also odd that the two match to a fixed number of decimal places for the first few layers, but not for the last few.

I expect both of these to show the state of the residual stream, but I could be misunderstanding either of them.

Code example

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import HookedTransformer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "EleutherAI/pythia-70m-deduped"

# HuggingFace model (for hidden_states) and TransformerLens model (for the cache)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tl_model = HookedTransformer.from_pretrained(model_name, device=device)

text = "The quick brown fox jumps over the lazy dog"
tokens = tokenizer(text, return_tensors="pt").input_ids.to(device)

output = model(tokens, output_hidden_states=True)
hidden_states = output.hidden_states
_, cache = tl_model.run_with_cache(tokens)

# For each layer, report the largest number of decimal places at which the TL
# residual stream entering the block matches the corresponding HF hidden state.
for x in range(tl_model.cfg.n_layers):
    for y in range(5, -1, -1):
        if torch.allclose(cache[f"blocks.{x}.hook_resid_pre"], hidden_states[x], atol=10**-y):
            print(f"blocks.{x}.hook_resid_pre matches hidden_states[{x}] to {y} decimal places")
            break

System Info: Python 3.10.11

Installed via pip install -r requirements.txt, with transformer-lens @ git+https://github.com/neelnanda-io/TransformerLens@ae32fa54ad40cb2c3f3a60f1837d0b4899c8daae as a line in the requirements file.


Hzfinfdu commented 1 year ago

I've been facing the same issue with Pythia. I suspect there is a tiny mismatch between the GPTNeoX config and the HookedTransformerConfig. I'm trying to figure it out and am also looking forward to your solution!
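A quick way to eyeball this, reusing model and tl_model from the snippet above, is to print the corresponding fields side by side. This is only a sketch; the field names are the standard ones on GPTNeoXConfig and HookedTransformerConfig and may need adjusting:

hf_cfg = model.config   # GPTNeoXConfig
tl_cfg = tl_model.cfg   # HookedTransformerConfig

# Compare the fields most likely to cause a mismatch.
print("d_model / hidden_size:        ", tl_cfg.d_model, hf_cfg.hidden_size)
print("n_layers / num_hidden_layers: ", tl_cfg.n_layers, hf_cfg.num_hidden_layers)
print("n_heads / num_attention_heads:", tl_cfg.n_heads, hf_cfg.num_attention_heads)
print("layer norm eps:               ", tl_cfg.eps, hf_cfg.layer_norm_eps)
print("initializer_range:            ", tl_cfg.initializer_range, hf_cfg.initializer_range)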

loganriggs commented 1 year ago

Thanks for pointing out the config. So I noticed two things:

  1. It goes wrong at the very beginning: cache["hook_embed"] differs from both hidden_states[0] and model.gpt_neox.embed_in(tokens), which match each other (see the sketch at the end of this comment).

  2. The initializer range looks wrong: TransformerLens uses 0.035355..., while the pythia-70m-deduped config specifies 0.02.

I changed the config to match the initializer range and re-ran it, but it didn't help. Maybe that value is only used at initialization, or something else is wrong.

Links: Pythia config
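A minimal version of the check in point 1, reusing model, tokens, cache, and hidden_states from the code example above (a sketch, not necessarily the exact code that was run):

# The HF hidden state and the raw embedding lookup agree with each other,
# but TL's hook_embed differs from both under default weight processing.
embed_hf = model.gpt_neox.embed_in(tokens)
print(torch.allclose(embed_hf, hidden_states[0], atol=1e-6))             # True
print(torch.allclose(cache["hook_embed"], hidden_states[0], atol=1e-6))  # False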

ArthurConmy commented 1 year ago

One-line fix: use HookedTransformer.from_pretrained_no_processing(model_name, device=device). I get errors of at most 3 decimal places, which I think is within the range of normal floating-point error for float32 computation (?)

The library makes changes to weight matrices (such as removing the Layer Norm bias and gain from the forward pass altogether!) to make things much easier for interpretability.

You should read https://github.com/neelnanda-io/TransformerLens/blob/main/further_comments.md#weight-processing
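A sketch of the fix applied to the original example, reusing model_name, device, tokens, and hidden_states from above:

# Load without weight processing so TL activations line up with the HF model.
tl_model_np = HookedTransformer.from_pretrained_no_processing(model_name, device=device)
_, cache_np = tl_model_np.run_with_cache(tokens)

for x in range(tl_model_np.cfg.n_layers):
    diff = (cache_np[f"blocks.{x}.hook_resid_pre"] - hidden_states[x]).abs().max().item()
    print(f"layer {x}: max |resid_pre - hidden_state| = {diff:.2e}")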

loganriggs commented 1 year ago

Thanks for the link & code! I was able to replicate this, and I also get an exact match for the embedding if I subtract the mean from the embedding weights as suggested in your link.

For my task, I'm creating reconstructions of the residual stream to find interpretable directions for model steering. I'm unsure whether the original HookedTransformer.from_pretrained() model is functionally equivalent, but the no_processing one does seem to be, since its activations differ by at most 3 decimal places.
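For reference, a sketch of the embedding check described above, reusing model, tokens, and cache from the original example (not necessarily the exact code that was run):

# With default processing, TL centers the embedding's writing directions, so
# subtracting the per-row mean (over d_model) from the HF embedding weights
# should reproduce hook_embed.
W_E_hf = model.gpt_neox.embed_in.weight                    # [d_vocab, d_model]
W_E_centered = W_E_hf - W_E_hf.mean(dim=-1, keepdim=True)
embed_centered = W_E_centered[tokens]                      # [1, seq, d_model]
print(torch.allclose(cache["hook_embed"], embed_centered, atol=1e-6))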

ArthurConmy commented 1 year ago

I believe there will be an extra term parallel to the (1, 1, ..., 1) direction in the non-processed model, because TL removes this component (as LayerNorms do). So theoretically I'd expect the TL-processed model to perform very slightly better, but neural nets are weird, so I'm not confident.
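A toy illustration (purely a sketch, not from the thread) of what "removing the component parallel to (1, 1, ..., 1)" means: LayerNorm's centering step deletes exactly this component, which is why the processed model can drop it from the weights.

import torch

resid = torch.randn(512)                                  # toy residual-stream vector
ones = torch.ones_like(resid) / resid.shape[0] ** 0.5     # unit vector along (1, 1, ..., 1)
parallel = (resid @ ones) * ones                          # component parallel to the ones direction
centered = resid - parallel                               # what survives LayerNorm's centering
print(centered.mean())                                    # ~0: the mean *is* the parallel component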