allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0

Interpret broken for GPT2 NextTokenLM #4745

Closed: epwalsh closed this issue 3 years ago

epwalsh commented 4 years ago

Steps to reproduce:

```python
from allennlp.interpret.saliency_interpreters import (
    SimpleGradient,
)
from allennlp_models.pretrained import load_predictor

predictor = load_predictor("lm-next-token-lm-gpt2", overrides={
    "dataset_reader.max_tokens": 512,
    "model.beam_search_generator": {
        "type": "transformer",
        "beam_search": {
            "end_index": 50256,
            "max_steps": 5,
            "beam_size": 5,
        }
    }
})
interpreter = SimpleGradient(predictor)
interpreter.saliency_interpret_from_json({"sentence": "Hi there"})
```
leo-liuzy commented 3 years ago

This issue is fixed on the `interpret` branch. In the current version (allennlp=1.3.0, allennlp-models=1.3.0), the error log is:

```
Traceback (most recent call last):
  File "gpt2_bug.py", line 22, in <module>
    interpreter.saliency_interpret_from_json({"sentence": "Hi there"})
  File "/Users/leoliu/My/proj/allennlp/allennlp/interpret/saliency_interpreters/simple_gradient.py", line 34, in saliency_interpret_from_json
    grads = self.predictor.get_gradients([instance])[0]
  File "/Users/leoliu/My/proj/allennlp/allennlp/predictors/predictor.py", line 113, in get_gradients
    self._model.forward(**dataset_tensor_dict)  # type: ignore
  File "/Users/leoliu/My/proj/allennlp-models/allennlp_models/lm/models/next_token_lm.py", line 120, in forward
    targets = util.get_token_ids_from_text_field_tensors(target_ids).view(batch_size)
RuntimeError: shape '[1]' is invalid for input of size 5
```

This is caused by GPT-2 decoding with beam search (beam_size=5). The bug is in the forward() of next_token_lm (in allennlp-models): target_ids may be a 2-dim tensor of shape (batch_size, x) with x > 1, which is why targets.view(batch_size) fails. NextTokenLMPredictor's output will always be the top indices from the first beam. The fix is to assume that target_ids should have shape (batch_size, 1); otherwise (x > 1), we assume the first token in each row is the user's most desirable choice (e.g. the highest-probability one).
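A minimal sketch of the failure and of the proposed fix, using stand-in tensors rather than the model's actual outputs (the token ids below are illustrative):

```python
import torch

batch_size = 1
# With beam_size=5, target_ids comes back with shape (batch_size, 5)
# instead of the expected (batch_size, 1). The ids here are stand-ins.
target_ids = torch.tensor([[50256, 11, 13, 290, 262]])

# Current code: 5 elements cannot be viewed as shape (1,).
try:
    targets = target_ids.view(batch_size)
except RuntimeError as e:
    print(e)  # shape '[1]' is invalid for input of size 5

# Proposed fix: when x > 1, keep only the first token of each row,
# assumed to be the top token from the first beam.
if target_ids.dim() == 2 and target_ids.size(1) > 1:
    target_ids = target_ids[:, :1]
targets = target_ids.view(batch_size)  # now succeeds: shape (1,)
```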

The subsequent problem is that during this decoding process, GPT-2 embeds the decoded sequence multiple times, and SimpleGradient registers a forward hook that records the input's embedding in a list. Since both the input's embedding and its gradient are used to calculate the final score for each (sub)token, the current design makes it hard to tell which element of the embedding list we should use. For now, I changed the registered hook to distinguish the decoded sequences from the initial input embedding. This is a rough fix because we can't make that distinction when beam_size=1.
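For illustration, here is a minimal sketch of the ambiguity (the hook and the embedding layer below are hypothetical stand-ins, not the actual SimpleGradient internals):

```python
import torch

embeddings_list = []

def forward_hook(module, inputs, output):
    # Fires on every forward pass through the embedding layer, including
    # the repeated passes made while decoding with beam search.
    embeddings_list.append(output.detach())

# Stand-in for GPT-2's token embedding layer.
embedding_layer = torch.nn.Embedding(50257, 768)
handle = embedding_layer.register_forward_hook(forward_hook)

embedding_layer(torch.tensor([[101, 102]]))       # initial input (illustrative ids)
embedding_layer(torch.tensor([[101, 102, 103]]))  # a decoded beam, one step later

handle.remove()
# Two entries recorded; without extra bookkeeping it is unclear which one
# corresponds to the original input whose saliency we want.
print(len(embeddings_list))  # 2
```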