TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/

[Proposal] Add support for Baichuan1 and Baichuan2 #622

Open StarrySeas1 opened 1 month ago

StarrySeas1 commented 1 month ago

Hello, I would like to use this tool to analyze the Baichuan1 and Baichuan2 models. Are they supported?

bryce13950 commented 1 month ago

TransformerLens does not currently support Baichuan. Looking at its page on Hugging Face, it appears to be quite similar to LLaMA, which means it should be relatively easy to add, barring any surprises. @StarrySeas1 if you would like to use this model, I can walk you through how to do it, and you can give it a shot. It should be a matter of adding an alias to LLaMA, and then adding a configuration block to loading_from_pretrained so that it matches the config on Hugging Face: https://huggingface.co/baichuan-inc/Baichuan-7B/blob/main/config.json.
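
As a very rough sketch (not a finished implementation; the alias string and comments are just placeholders), the registration in loading_from_pretrained.py would look something like this:

```python
# transformer_lens/loading_from_pretrained.py -- sketch only

# 1. Register the official Hugging Face repo name.
OFFICIAL_MODEL_NAMES = [
    # ... existing entries ...
    "baichuan-inc/Baichuan-7B",
]

# 2. Optionally map short aliases onto the official name.
MODEL_ALIASES = {
    # ... existing entries ...
    "baichuan-inc/Baichuan-7B": ["baichuan-7b"],
}

# 3. Add a branch to convert_hf_model_config that maps the Hugging Face config
#    fields (hidden_size, num_attention_heads, ...) onto HookedTransformerConfig
#    fields, mirroring the existing LLaMA branch.
```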

StarrySeas1 commented 1 month ago

Here are the modifications I made:

  1. Added "Baichuan-13B-Chat" to the "OFFICIAL_MODEL_NAMES".

  2. Added a configuration block for Baichuan in the convert_hf_model_config function:

```python
elif "Baichuan-13B" in official_model_name:
    cfg_dict = {
        "d_model": hf_config.hidden_size,
        "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
        "n_heads": hf_config.num_attention_heads,
        "d_mlp": hf_config.intermediate_size,
        "n_layers": hf_config.num_hidden_layers,
        "n_ctx": 2048,  # Capped due to HF Tokenizer Constraints
        "d_vocab": hf_config.vocab_size,
        "eps": hf_config.rms_norm_eps,
        "act_fn": hf_config.hidden_act,
        "initializer_range": hf_config.initializer_range,
        "normalization_type": "RMS",
        "positional_embedding_type": "alibi",
        "post_embedding_ln": True,
    }
```
  3. Added a weight conversion function convert_baichuan_weights:

```python
def convert_baichuan_weights(baichuan, cfg: HookedTransformerConfig):
    state_dict = {}

    state_dict["embed.W_E"] = baichuan.model.embed_tokens.weight

    assert cfg.d_mlp is not None  # keep mypy happy

    for l in range(cfg.n_layers):
        state_dict[f"blocks.{l}.ln1.w"] = baichuan.model.layers[l].input_layernorm.weight

        # W_pack stacks the Q, K and V projections along its output dimension,
        # so split the transposed weight into three (d_model, n_heads, d_head) blocks.
        W = baichuan.model.layers[l].self_attn.W_pack.weight
        W_split = W.T.reshape(cfg.d_model, 3, cfg.n_heads, cfg.d_head)
        W_Q, W_K, W_V = W_split[:, 0], W_split[:, 1], W_split[:, 2]

        W_Q = einops.rearrange(W_Q, "m n h -> n m h", n=cfg.n_heads)
        W_K = einops.rearrange(W_K, "m n h -> n m h", n=cfg.n_heads)
        W_V = einops.rearrange(W_V, "m n h -> n m h", n=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn.W_K"] = W_K
        state_dict[f"blocks.{l}.attn.W_V"] = W_V

        # Baichuan's attention has no biases, so fill these with zeros.
        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=W_Q.device
        )
        state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=W_Q.device
        )
        state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=W_Q.device
        )

        W_O = baichuan.model.layers[l].self_attn.o_proj.weight
        W_O = einops.rearrange(W_O, "m (n h) -> n h m", n=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_O"] = W_O
        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=W_O.device
        )

        state_dict[f"blocks.{l}.ln2.w"] = baichuan.model.layers[l].post_attention_layernorm.weight

        # Gated (SwiGLU-style) MLP, as in LLaMA.
        state_dict[f"blocks.{l}.mlp.W_in"] = baichuan.model.layers[l].mlp.up_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_gate"] = baichuan.model.layers[l].mlp.gate_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=W_O.dtype)

        state_dict[f"blocks.{l}.mlp.W_out"] = baichuan.model.layers[l].mlp.down_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=W_O.dtype)

    state_dict["ln_final.w"] = baichuan.model.norm.weight
    state_dict["unembed.W_U"] = baichuan.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=W_O.dtype)

    return state_dict
```

The model can then be loaded without error with the following call:

```python
model = transformer_lens.HookedTransformer.from_pretrained(
    ckpt_path,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    device="cuda:6",
    trust_remote_code=True,
)
```
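
To sanity-check the conversion, one option is to compare the converted model's logits with the original Hugging Face implementation on the same prompt. A rough sketch (the prompt and device are arbitrary; ckpt_path is the same local checkpoint directory as above):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reference Hugging Face model and tokenizer from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)
hf_model = AutoModelForCausalLM.from_pretrained(ckpt_path, trust_remote_code=True).to("cuda:6")

prompt = "The capital of China is Beijing."  # arbitrary test prompt
tokens = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda:6")

with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tl_logits = model(tokens)  # HookedTransformer returns logits by default

# A large difference here would point at the weight conversion rather than the model itself.
print(torch.max(torch.abs(hf_logits - tl_logits)))
```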

Questions:

  1. Are the above additions correct?
  2. I compared the attention patterns of Baichuan-13B-Chat and a version of Baichuan-13B-Chat after SFT on the same sample. They are essentially the same, and in both cases the values on the first few tokens are very large. I am not sure whether this comes from the model/training or from my Baichuan support being incorrect. (A sketch of how the patterns can be extracted is below.)
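
For reference, this is roughly how the per-layer attention patterns can be extracted from the HookedTransformer for that comparison (the layer index and prompt are placeholders):

```python
import torch

prompt = "The capital of China is Beijing."  # placeholder sample
tokens = model.to_tokens(prompt)

with torch.no_grad():
    _, cache = model.run_with_cache(tokens)

layer = 0  # placeholder: any value < cfg.n_layers
# Shape: (batch, n_heads, query_pos, key_pos)
pattern = cache["pattern", layer]
print(pattern[0, :, -1, :5])  # attention paid by the last token to the first few tokens
```
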
bryce13950 commented 4 weeks ago

Thank you very much for doing the work to get this running in TransformerLens! At a glance, your implementation seems correct. However, without playing with the code directly, it is hard to say whether the discrepancies are the result of something being slightly off or are nothing to worry about. Can you set up a PR with these changes so that we can double-check everything together?