pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

fix LLMAttribution for old pytorch/python versions #1353

Closed DianjingLiu closed 1 month ago

DianjingLiu commented 1 month ago

Summary: When use_cached_outputs=False is set, LLMAttribution fails to run on some older versions of PyTorch/Python.
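
For reference, a minimal repro sketch of the failing path. The checkpoint, prompt, and target are illustrative, and it assumes the usual LLMAttribution + FeatureAblation setup from captum.attr:

from transformers import AutoModelForCausalLM, AutoTokenizer
from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

# Any HF causal LM exercises this path; the checkpoint here is illustrative.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
inp = TextTokenInput("Dave lives in Palm Coast, FL and is a lawyer.", tokenizer)

# use_cached_outputs=False re-feeds the full sequence to the model at each step,
# which is the code path that failed on older PyTorch/Python versions.
res = llm_attr.attribute(inp, target="lawyer", use_cached_outputs=False)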

Error message

======================================================================
ERROR: test_llm_attr_hf_compatibility_0 (tests.attr.test_llm_attr_hf_compatibility.TestLLMAttrHFCompatibility_1_cpu)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/parameterized/parameterized.py", line 620, in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
  File "/data/users/liudj/captum/tests/attr/test_llm_attr_hf_compatibility.py", line 80, in test_llm_attr_hf_compatibility
    res = llm_attr.attribute(
  File "/data/users/liudj/captum/captum/attr/_core/llm_attr.py", line 461, in attribute
    cur_attr = self.attr_method.attribute(
  File "/data/users/liudj/captum/captum/log/__init__.py", line 52, in wrapper
    return func(*args, **kwargs)
  File "/data/users/liudj/captum/captum/attr/_core/feature_ablation.py", line 292, in attribute
    initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
  File "/data/users/liudj/captum/captum/_utils/common.py", line 599, in _run_forward
    output = forward_func(
  File "/data/users/liudj/captum/captum/attr/_core/llm_attr.py", line 335, in _forward_func
    outputs = self.model.forward(model_inp, **model_kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1001, in forward
    layer_outputs = decoder_layer(
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 734, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 428, in forward
    attn_weights = attn_weights + causal_mask
RuntimeError: The size of tensor a (8) must match the size of tensor b (7) at non-singleton dimension 3

Root cause

The attention_mask was not updated to match the input as it grew with each appended token, producing the size mismatch shown in the error message above (see the test plan).
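
A hedged sketch of the kind of fix, placed in LLMAttribution._forward_func (llm_attr.py) before the non-cached forward call; the variable handling is illustrative rather than the exact patch:

# Grow the attention_mask by one column of ones per newly appended token so its
# length matches model_inp before calling the model (illustrative sketch).
attn_mask = model_kwargs.get("attention_mask")
if attn_mask is not None and attn_mask.shape[1] < model_inp.shape[1]:
    pad = torch.ones(
        attn_mask.shape[0],
        model_inp.shape[1] - attn_mask.shape[1],
        dtype=attn_mask.dtype,
        device=attn_mask.device,
    )
    model_kwargs["attention_mask"] = torch.cat([attn_mask, pad], dim=1)
outputs = self.model.forward(model_inp, **model_kwargs)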

Impacted versions

{F1876426564}

Differential Revision: D63016032

facebook-github-bot commented 1 month ago

This pull request was exported from Phabricator. Differential Revision: D63016032

facebook-github-bot commented 1 month ago

This pull request has been merged in pytorch/captum@fc910e5e0289ffd856d40503d5504d73e8b28b95.