When intervening attention heads (i.e. set the intervention component as head_attention_value_output, head_query_output, head_key_output, or head_value_output), pyvene will crash. Only GPT-2 works, all the other LMs have the bug.
The stack log will be like:
Traceback (most recent call last):
File "/data1/wxz/test_pyvene.py", line 21, in <module>
_, counterfactual_outputs = intervenable(
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wxz/workspace/pyvene/pyvene/models/intervenable_base.py", line 1465, in forward
raise e
File "/home/wxz/workspace/pyvene/pyvene/models/intervenable_base.py", line 1425, in forward
self._wait_for_forward_with_parallel_intervention(
File "/home/wxz/workspace/pyvene/pyvene/models/intervenable_base.py", line 1072, in _wait_for_forward_with_parallel_intervention
_ = self.model(**sources[group_id])
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1192, in forward
outputs = self.model(
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1019, in forward
layer_outputs = decoder_layer(
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 740, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 682, in forward
attn_output = self.o_proj(attn_output)
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/wxz/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1547, in _call_impl
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
File "/home/wxz/workspace/pyvene/pyvene/models/intervenable_base.py", line 747, in hook_callback
selected_output = self._gather_intervention_output(
File "/home/wxz/workspace/pyvene/pyvene/models/intervenable_base.py", line 658, in _gather_intervention_output
selected_output = gather_neurons(
File "/home/wxz/workspace/pyvene/pyvene/models/modeling_utils.py", line 276, in gather_neurons
pos_tensor_input = bhsd_to_bs_hd(head_tensor_output)
File "/home/wxz/workspace/pyvene/pyvene/models/modeling_utils.py", line 204, in bhsd_to_bs_hd
b, h, s, d = tensor.shape
ValueError: not enough values to unpack (expected 4, got 3)
I've basically figured out how to fix this bug and will submit a PR later.
Code to produce this issue.
# Minimal reproduction, adapted from the nested intervention tutorial.
import pyvene as pv

# Load a LLaMA-2 model; any LM other than GPT-2 triggers the crash
# (GPT-2 alone works).
_, tokenizer, model = pv.create_llama("/data3/MODELS/llama2-hf/llama-2-7b")

# Intervene on one attention head's value output at layer 1.
# The same crash occurs with head_query_output, head_key_output,
# and head_value_output.
head_repr = pv.RepresentationConfig(
    1,
    "head_attention_value_output",
    "h.pos",
    1,
)
config = pv.IntervenableConfig(
    model_type=type(model),
    representations=[head_repr],
    intervention_types=pv.VanillaIntervention,
)

base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [tokenizer("The capital of Italy is", return_tensors="pt")]

intervenable = pv.IntervenableModel(config, model)
_, counterfactual_outputs = intervenable(
    base,
    sources,
    {
        "sources->base": (
            [[[[0]], [[4]]]],
            [[[[0]], [[4]]]],
        )
    },
)
Contact Details
wangxz098@gmail.com
What happened?
When intervening attention heads (i.e. setting the intervention component to `head_attention_value_output`, `head_query_output`, `head_key_output`, or `head_value_output`), pyvene will crash. Only GPT-2 works; all the other LMs have the bug. The stack log will look like the traceback above.
I've basically figured out how to fix this bug and will submit a PR later.
Code to produce this issue.