stanfordnlp / pyvene

Stanford NLP Python Library for Understanding and Improving PyTorch Models via Interventions
http://pyvene.ai
Apache License 2.0
545 stars, 46 forks

[External]: Attention head intervention doesn't work for models other than GPT-2 #158

Closed: Bakser closed this issue 3 weeks ago.

Bakser commented 1 month ago

Contact Details

wangxz098@gmail.com

What happened?

When intervening on attention heads (i.e., setting the intervention component to head_attention_value_output, head_query_output, head_key_output, or head_value_output), pyvene crashes. Only GPT-2 works; all the other LMs hit this bug.

The stack trace looks like this:

Traceback (most recent call last):
  File "/data1/wxz/test_pyvene.py", line 21, in <module>
    _, counterfactual_outputs = intervenable(
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wxz/workspace/pyvene/pyvene/models/intervenable_base.py", line 1465, in forward
    raise e
  File "/home/wxz/workspace/pyvene/pyvene/models/intervenable_base.py", line 1425, in forward
    self._wait_for_forward_with_parallel_intervention(
  File "/home/wxz/workspace/pyvene/pyvene/models/intervenable_base.py", line 1072, in _wait_for_forward_with_parallel_intervention
    _ = self.model(**sources[group_id])
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1192, in forward
    outputs = self.model(
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1019, in forward
    layer_outputs = decoder_layer(
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 740, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 682, in forward
    attn_output = self.o_proj(attn_output)
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1547, in _call_impl
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/home/wxz/workspace/pyvene/pyvene/models/intervenable_base.py", line 747, in hook_callback
    selected_output = self._gather_intervention_output(
  File "/home/wxz/workspace/pyvene/pyvene/models/intervenable_base.py", line 658, in _gather_intervention_output
    selected_output = gather_neurons(
  File "/home/wxz/workspace/pyvene/pyvene/models/modeling_utils.py", line 276, in gather_neurons
    pos_tensor_input = bhsd_to_bs_hd(head_tensor_output)
  File "/home/wxz/workspace/pyvene/pyvene/models/modeling_utils.py", line 204, in bhsd_to_bs_hd
    b, h, s, d = tensor.shape
ValueError: not enough values to unpack (expected 4, got 3)
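The last frame is the key: `bhsd_to_bs_hd` unpacks four dimensions (batch, heads, sequence, head_dim). The hook here runs on the input of LLaMA's `o_proj`, and by that point `modeling_llama` has already merged the heads back into a `(batch, seq, hidden)` tensor, so there are only three. The sketch below reproduces just that unpacking step. `unpack_bhsd` is a hypothetical stand-in for the failing line, not pyvene's code. NumPy stands in for PyTorch, and the sizes are made up.

```python
import numpy as np

def unpack_bhsd(tensor):
    # Hypothetical stand-in for the failing line in bhsd_to_bs_hd:
    # it assumes a 4-D (batch, heads, seq, head_dim) tensor.
    b, h, s, d = tensor.shape
    return b, h, s, d

batch, heads, seq, head_dim = 1, 4, 5, 8  # illustrative sizes only

# A tensor already split per head is 4-D, so unpacking succeeds.
split = np.zeros((batch, heads, seq, head_dim))
print(unpack_bhsd(split))  # (1, 4, 5, 8)

# Heads merged back into the hidden size, the layout LLaMA's o_proj
# receives, is 3-D, so unpacking fails as in the traceback.
merged = np.zeros((batch, seq, heads * head_dim))
try:
    unpack_bhsd(merged)
except ValueError as err:
    print(err)  # not enough values to unpack (expected 4, got 3)
```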

I have mostly figured out how to fix this bug and will submit a PR later.
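The PR is not part of this issue. Purely as a hedged illustration, one way a head-level gather could tolerate such modules is to split the merged hidden dimension into heads before converting layouts. `to_bhsd`, `merge_heads` and the `num_heads` argument are hypothetical names, not pyvene's API. The sketch assumes heads are laid out contiguously along the hidden dimension, which matches how LLaMA merges them (transpose, then reshape). NumPy's `reshape`/`transpose` stand in for PyTorch's `view`/`transpose`.

```python
import numpy as np

def to_bhsd(tensor, num_heads):
    """Return `tensor` in (batch, heads, seq, head_dim) layout.

    Hypothetical helper: 4-D inputs are assumed to be (b, h, s, d)
    already; 3-D inputs are assumed to be (b, s, h * d) with each
    head occupying a contiguous slice of the last dimension.
    """
    if tensor.ndim == 4:
        return tensor
    if tensor.ndim != 3:
        raise ValueError(f"expected a 3-D or 4-D tensor, got {tensor.ndim}-D")
    b, s, hidden = tensor.shape
    if hidden % num_heads != 0:
        raise ValueError("hidden size must be divisible by num_heads")
    d = hidden // num_heads
    # (b, s, h*d) -> (b, s, h, d) -> (b, h, s, d)
    return tensor.reshape(b, s, num_heads, d).transpose(0, 2, 1, 3)

def merge_heads(tensor):
    # Hypothetical inverse: (b, h, s, d) -> (b, s, h*d).
    b, h, s, d = tensor.shape
    return tensor.transpose(0, 2, 1, 3).reshape(b, s, h * d)

merged = np.arange(1 * 5 * 32).reshape(1, 5, 32)  # (batch, seq, hidden)
split = to_bhsd(merged, num_heads=4)
print(split.shape)  # (1, 4, 5, 8)
print(np.array_equal(merge_heads(split), merged))  # True
```

The round trip checks that splitting and merging agree on which hidden units belong to which head. Getting that mapping wrong would silently intervene on the wrong head instead of crashing.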

Code to reproduce this issue:

# Simplified from the nested intervention tutorial
import pyvene as pv
_, tokenizer, gpt = pv.create_llama("/data3/MODELS/llama2-hf/llama-2-7b")  # every LM except GPT-2 hits this bug
pv_config = pv.IntervenableConfig(
        model_type=type(gpt),
        representations=[
            pv.RepresentationConfig(
                1,
                "head_attention_value_output",  # head_query_output, head_key_output, and head_value_output fail the same way
                "h.pos",
                1, 
            ),
        ],
        intervention_types=pv.VanillaIntervention,
)
base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [tokenizer("The capital of Italy is", return_tensors="pt")]
intervenable = pv.IntervenableModel(pv_config, gpt)
_, counterfactual_outputs = intervenable(
    base,
    sources,
    {
        "sources->base": (
            [[[[0]], [[4]]]],
            [[[[0]], [[4]]]],
        )
    },
)
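The nested `unit_locations` literal above is hard to read at a glance. As used in pyvene's head-intervention tutorials (an assumption worth checking against the current docs), an `"h.pos"` location is given per intervention as a pair: a batch of head indices and a batch of position indices. This is specified once for the sources and once for the base. The hypothetical helper below is not a pyvene function; it only builds that nesting for one head and one position repeated across the batch, reproducing the literal used in the repro.

```python
def head_pos_locations(head, position, batch_size=1):
    # Hypothetical helper: a single intervention that targets the same
    # head and the same position for every example in the batch.
    heads = [[head] for _ in range(batch_size)]          # [batch][units]
    positions = [[position] for _ in range(batch_size)]  # [batch][units]
    per_intervention = [heads, positions]
    return [per_intervention]  # one entry per intervention

locs = head_pos_locations(head=0, position=4)
print(locs)  # [[[[0]], [[4]]]]

# (source locations, base locations), matching the call in the repro
unit_locations = {"sources->base": (locs, locs)}
```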