cdpierse / transformers-interpret

Model explainability that works seamlessly with 🤗 transformers. Explain your transformers model in just 2 lines of code.
Apache License 2.0

ZeroShotClassificationExplainer appears to be broken #125

Open tr-enjoyer opened 1 year ago

tr-enjoyer commented 1 year ago

To reproduce the issue, start an empty environment (e.g. Google Colab) and run:

```python
!pip install transformers-interpret
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import ZeroShotClassificationExplainer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
model = AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli")

zero_shot_explainer = ZeroShotClassificationExplainer(model, tokenizer)
word_attributions = zero_shot_explainer(
    "Today apple released the new Macbook showing off a range of new features found in the proprietary silicon chip computer. ",
    labels=["finance", "technology", "sports"],
)
```

If you are not running in a notebook, omit the `!pip install transformers-interpret` line and install the package from a shell instead.
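When filing a report like this, it can also help to record the exact resolved versions alongside the traceback. A small stdlib-only sketch that does this (the package names are taken from the pip output below; nothing here touches transformers-interpret's API):

```python
# Print the versions of the packages relevant to this report.
# Missing packages are reported rather than raising.
from importlib import metadata

for pkg in ["transformers-interpret", "transformers", "captum", "torch"]:
    try:
        print(pkg, metadata.version(pkg))
    except metadata.PackageNotFoundError:
        print(pkg, "not installed")
```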

The output is:

```text
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers-interpret
  Downloading transformers_interpret-0.9.6-py3-none-any.whl (45 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 45.6/45.6 KB 1.4 MB/s eta 0:00:00
Collecting captum>=0.3.1
  Downloading captum-0.6.0-py3-none-any.whl (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 32.8 MB/s eta 0:00:00
Collecting ipython<8.0.0,>=7.31.1
  Downloading ipython-7.34.0-py3-none-any.whl (793 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 793.8/793.8 KB 37.4 MB/s eta 0:00:00
Collecting transformers>=3.0.0
  Downloading transformers-4.26.0-py3-none-any.whl (6.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.3/6.3 MB 69.6 MB/s eta 0:00:00
Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.8/dist-packages (from captum>=0.3.1->transformers-interpret) (1.13.1+cu116)
Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from captum>=0.3.1->transformers-interpret) (1.21.6)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from captum>=0.3.1->transformers-interpret) (3.2.2)
Requirement already satisfied: traitlets>=4.2 in /usr/local/lib/python3.8/dist-packages (from ipython<8.0.0,>=7.31.1->transformers-interpret) (5.7.1)
Requirement already satisfied: pygments in /usr/local/lib/python3.8/dist-packages (from ipython<8.0.0,>=7.31.1->transformers-interpret) (2.6.1)
Requirement already satisfied: backcall in /usr/local/lib/python3.8/dist-packages (from ipython<8.0.0,>=7.31.1->transformers-interpret) (0.2.0)
Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.8/dist-packages (from ipython<8.0.0,>=7.31.1->transformers-interpret) (4.8.0)
Collecting jedi>=0.16
  Downloading jedi-0.18.2-py2.py3-none-any.whl (1.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 51.4 MB/s eta 0:00:00
Collecting matplotlib-inline
  Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)
Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from ipython<8.0.0,>=7.31.1->transformers-interpret) (2.0.10)
Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.8/dist-packages (from ipython<8.0.0,>=7.31.1->transformers-interpret) (57.4.0)
Requirement already satisfied: decorator in /usr/local/lib/python3.8/dist-packages (from ipython<8.0.0,>=7.31.1->transformers-interpret) (4.4.2)
Requirement already satisfied: pickleshare in /usr/local/lib/python3.8/dist-packages (from ipython<8.0.0,>=7.31.1->transformers-interpret) (0.7.5)
Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers>=3.0.0->transformers-interpret) (3.9.0)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from transformers>=3.0.0->transformers-interpret) (23.0)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.8/dist-packages (from transformers>=3.0.0->transformers-interpret) (4.64.1)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from transformers>=3.0.0->transformers-interpret) (6.0)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.6/7.6 MB 68.2 MB/s eta 0:00:00
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers>=3.0.0->transformers-interpret) (2022.6.2)
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.12.0-py3-none-any.whl (190 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 190.3/190.3 KB 14.2 MB/s eta 0:00:00
Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers>=3.0.0->transformers-interpret) (2.25.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.8/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers>=3.0.0->transformers-interpret) (4.4.0)
Requirement already satisfied: parso<0.9.0,>=0.8.0 in /usr/local/lib/python3.8/dist-packages (from jedi>=0.16->ipython<8.0.0,>=7.31.1->transformers-interpret) (0.8.3)
Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.8/dist-packages (from pexpect>4.3->ipython<8.0.0,>=7.31.1->transformers-interpret) (0.7.0)
Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.8/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython<8.0.0,>=7.31.1->transformers-interpret) (1.15.0)
Requirement already satisfied: wcwidth in /usr/local/lib/python3.8/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython<8.0.0,>=7.31.1->transformers-interpret) (0.2.6)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum>=0.3.1->transformers-interpret) (0.11.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum>=0.3.1->transformers-interpret) (1.4.4)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum>=0.3.1->transformers-interpret) (2.8.2)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->captum>=0.3.1->transformers-interpret) (3.0.9)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers>=3.0.0->transformers-interpret) (1.24.3)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers>=3.0.0->transformers-interpret) (2.10)
Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers>=3.0.0->transformers-interpret) (4.0.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers>=3.0.0->transformers-interpret) (2022.12.7)
Installing collected packages: tokenizers, matplotlib-inline, jedi, ipython, huggingface-hub, transformers, captum, transformers-interpret
  Attempting uninstall: ipython
    Found existing installation: ipython 7.9.0
    Uninstalling ipython-7.9.0:
      Successfully uninstalled ipython-7.9.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires ipython~=7.9.0, but you have ipython 7.34.0 which is incompatible.
Successfully installed captum-0.6.0 huggingface-hub-0.12.0 ipython-7.34.0 jedi-0.18.2 matplotlib-inline-0.1.6 tokenizers-0.13.2 transformers-4.26.0 transformers-interpret-0.9.6
Downloading (…)okenizer_config.json: 100% 26.0/26.0 [00:00<00:00, 1.07kB/s]
Downloading (…)lve/main/config.json: 100% 1.15k/1.15k [00:00<00:00, 60.9kB/s]
Downloading (…)olve/main/vocab.json: 100% 899k/899k [00:00<00:00, 2.73MB/s]
Downloading (…)olve/main/merges.txt: 100% 456k/456k [00:00<00:00, 1.48MB/s]
Downloading (…)/main/tokenizer.json: 100% 1.36M/1.36M [00:00<00:00, 3.66MB/s]
Downloading (…)"pytorch_model.bin";: 100% 1.63G/1.63G [00:06<00:00, 252MB/s]
```
```text
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-1-2538eac9d0fb> in <module>
     11 
     12 
---> 13 word_attributions = zero_shot_explainer(
     14     "Today apple released the new Macbook showing off a range of new features found in the proprietary silicon chip computer. ",
     15     labels = ["finance", "technology", "sports"],

9 frames
/usr/local/lib/python3.8/dist-packages/transformers_interpret/explainers/text/zero_shot_classification.py in __call__(self, text, labels, embedding_type, hypothesis_template, include_hypothesis, internal_batch_size, n_steps)
    293             self.hypothesis_text = self.hypothesis_labels[i]
    294             self.predicted_label = labels[i] + " (" + self.entailment_key.lower() + ")"
--> 295             super().__call__(
    296                 text,
    297                 class_name=self.entailment_key,

/usr/local/lib/python3.8/dist-packages/transformers_interpret/explainers/text/sequence_classification.py in __call__(self, text, index, class_name, embedding_type, internal_batch_size, n_steps)
    310         if internal_batch_size:
    311             self.internal_batch_size = internal_batch_size
--> 312         return self._run(text, index, class_name, embedding_type=embedding_type)
    313 
    314     def __str__(self):

/usr/local/lib/python3.8/dist-packages/transformers_interpret/explainers/text/sequence_classification.py in _run(self, text, index, class_name, embedding_type)
    264         self.text = self._clean_text(text)
    265 
--> 266         self._calculate_attributions(embeddings=embeddings, index=index, class_name=class_name)
    267         return self.word_attributions  # type: ignore
    268 

/usr/local/lib/python3.8/dist-packages/transformers_interpret/explainers/text/zero_shot_classification.py in _calculate_attributions(self, embeddings, class_name, index)
    201 
    202         reference_tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
--> 203         lig = LIGAttributions(
    204             self._forward,
    205             embeddings,

/usr/local/lib/python3.8/dist-packages/transformers_interpret/attributions.py in __init__(self, custom_forward, embeddings, tokens, input_ids, ref_input_ids, sep_id, attention_mask, target, token_type_ids, position_ids, ref_token_type_ids, ref_position_ids, internal_batch_size, n_steps)
     49 
     50         if self.token_type_ids is not None and self.position_ids is not None:
---> 51             self._attributions, self.delta = self.lig.attribute(
     52                 inputs=(self.input_ids, self.token_type_ids, self.position_ids),
     53                 baselines=(

/usr/local/lib/python3.8/dist-packages/captum/log/__init__.py in wrapper(*args, **kwargs)
     40             @wraps(func)
     41             def wrapper(*args, **kwargs):
---> 42                 return func(*args, **kwargs)
     43 
     44             return wrapper

/usr/local/lib/python3.8/dist-packages/captum/attr/_core/layer/layer_integrated_gradients.py in attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta, attribute_to_layer_input)
    369             self.device_ids = getattr(self.forward_func, "device_ids", None)
    370 
--> 371         inputs_layer = _forward_layer_eval(
    372             self.forward_func,
    373             inps,

/usr/local/lib/python3.8/dist-packages/captum/_utils/gradient.py in _forward_layer_eval(forward_fn, inputs, layer, additional_forward_args, device_ids, attribute_to_layer_input, grad_enabled)
    180     grad_enabled: bool = False,
    181 ) -> Union[Tuple[Tensor, ...], List[Tuple[Tensor, ...]]]:
--> 182     return _forward_layer_eval_with_neuron_grads(
    183         forward_fn,
    184         inputs,

/usr/local/lib/python3.8/dist-packages/captum/_utils/gradient.py in _forward_layer_eval_with_neuron_grads(forward_fn, inputs, layer, additional_forward_args, gradient_neuron_selector, grad_enabled, device_ids, attribute_to_layer_input)
    443 
    444     with torch.autograd.set_grad_enabled(grad_enabled):
--> 445         saved_layer = _forward_layer_distributed_eval(
    446             forward_fn,
    447             inputs,

/usr/local/lib/python3.8/dist-packages/captum/_utils/gradient.py in _forward_layer_distributed_eval(forward_fn, inputs, layer, target_ind, additional_forward_args, attribute_to_layer_input, forward_hook_with_return, require_layer_grads)
    303 
    304     if len(saved_layer) == 0:
--> 305         raise AssertionError("Forward hook did not obtain any outputs for given layer")
    306 
    307     if forward_hook_with_return:

AssertionError: Forward hook did not obtain any outputs for given layer
```

This seems like as minimal an example as possible, so I suspect the problem is in the library itself rather than in this code.
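For context on what the final assertion means: captum's `LayerIntegratedGradients` registers a forward hook on the layer it is asked to attribute, runs the model, and then checks that the hook captured at least one output. If the forward pass never executes the hooked module (for example, if the explainer resolved an embedding layer that this BART checkpoint does not actually run through), the hook fires zero times and exactly this `AssertionError` is raised. A dependency-free sketch of that failure mode (the `Module` class and names here are illustrative, not captum's actual code):

```python
# Pure-Python illustration of the failure mode behind
# "Forward hook did not obtain any outputs for given layer":
# a hook is registered on a submodule that the forward pass never calls.
class Module:
    """Toy stand-in for a torch.nn.Module with forward hooks (illustrative only)."""
    def __init__(self, name):
        self.name = name
        self.hooks = []

    def __call__(self, x):
        out = x * 2  # trivial "computation"
        for hook in self.hooks:
            hook(self, x, out)
        return out

used = Module("used")
unused = Module("unused")  # present in the model, but never called in forward()

saved_layer = {}
unused.hooks.append(lambda mod, inp, out: saved_layer.setdefault(mod.name, out))

used(1.0)  # the forward pass only exercises `used`, so the hook never fires

try:
    if len(saved_layer) == 0:
        # Mirrors the check captum makes in _forward_layer_distributed_eval.
        raise AssertionError("Forward hook did not obtain any outputs for given layer")
except AssertionError as e:
    print(e)  # same message as the reported traceback
```

If this reading is right, one plausible cause is that `ZeroShotClassificationExplainer` hands captum an embedding module that `facebook/bart-large-mnli` does not route through, which would point at a layer-resolution bug in the explainer rather than at captum.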