Closed epwalsh closed 3 years ago
This issue is fixed on the interpret branch.
In the current version (allennlp=1.3.0, allennlp-models=1.3.0), the error log is this:
Traceback (most recent call last):
  File "gpt2_bug.py", line 22, in <module>
    interpreter.saliency_interpret_from_json({"sentence": "Hi there"})
  File "/Users/leoliu/My/proj/allennlp/allennlp/interpret/saliency_interpreters/simple_gradient.py", line 34, in saliency_interpret_from_json
    grads = self.predictor.get_gradients([instance])[0]
  File "/Users/leoliu/My/proj/allennlp/allennlp/predictors/predictor.py", line 113, in get_gradients
    self._model.forward(**dataset_tensor_dict)  # type: ignore
  File "/Users/leoliu/My/proj/allennlp-models/allennlp_models/lm/models/next_token_lm.py", line 120, in forward
    targets = util.get_token_ids_from_text_field_tensors(target_ids).view(batch_size)
RuntimeError: shape '[1]' is invalid for input of size 5
This is caused by GPT2 decoding with beam search (beam_size=5). The bug is in the forward() of next_token_lm (in allennlp-models), where target_ids may be a 2-dim tensor of shape (batch_size, x) with x > 1; that's why xxx.view(batch_size) fails. NextTokenLMPredictor's output will always be the top indices in the first beam. The fix assumes that target_ids should have shape (batch_size, 1); otherwise (x > 1), we assume that the first token in each batch entry is the user's most desirable choice (e.g. the one with the highest probability).
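To illustrate the shape problem and the assumption described above, here is a minimal sketch. It is not the actual patch on the interpret branch; the helper name select_first_target and the example token ids are made up for illustration.

```python
import torch


def select_first_target(target_ids: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Hypothetical helper: reduce target_ids to shape (batch_size,).

    Assumes target_ids is (batch_size,), (batch_size, 1) or (batch_size, x)
    with x > 1. In the last case only the first token of each row is kept,
    on the assumption that it is the most desirable choice.
    """
    if target_ids.dim() == 1:
        return target_ids
    # Slicing the first column works for both x == 1 and x > 1.
    return target_ids[:, 0].reshape(batch_size)


# One instance whose target holds 5 candidate ids (e.g. from beam_size=5).
beam_targets = torch.tensor([[11, 12, 13, 14, 15]])  # shape (1, 5)

try:
    beam_targets.view(1)  # the call that fails in the traceback above
except RuntimeError as err:
    print("view failed:", err)

targets = select_first_target(beam_targets, batch_size=1)
print(targets.shape)  # a 1-dim tensor with one entry per batch element
```

Taking the first column rather than calling view(batch_size) keeps the (batch_size, 1) case working unchanged and silently drops the extra candidates when x > 1.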
The subsequent problem is that during this decoding process GPT2 embeds the decoded sequence multiple times, while SimpleGradient registers a forward hook that records the input's embedding in a list. Since both the input's embedding and its gradient are used to calculate the final score for each (sub)token, the current design makes it hard to tell which element of the embedding list we should use. For now, I have just changed the registered hook to distinguish between the decoded sequence and the initial input embedding. This is a crude fix, because we can't tell them apart when beam_size=1.
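The hook problem can be reproduced outside allennlp with a plain torch embedding layer. The sketch below is only an illustration: the shape-based filter is an assumed way of telling the calls apart, not necessarily what the interpret branch does, and make_filtered_hook and run are hypothetical names.

```python
import torch
from torch import nn

embedding = nn.Embedding(num_embeddings=50, embedding_dim=4)

# A naive hook records every forward call of the embedding layer.
recorded = []
handle = embedding.register_forward_hook(
    lambda module, inputs, output: recorded.append(output)
)
embedding(torch.tensor([[1, 2]]))          # initial input, shape (1, 2)
embedding(torch.tensor([[3], [4], [5]]))   # one decoding step with 3 beams
handle.remove()
print(len(recorded))  # two entries: which one is the input's embedding?


def make_filtered_hook(store, batch_size):
    """Hypothetical hook: keep only outputs whose leading dim is batch_size."""
    def hook(module, inputs, output):
        if output.shape[0] == batch_size:
            store.append(output)
    return hook


def run(beam_size):
    """Embed the input once, then one decoding step with beam_size rows."""
    store = []
    h = embedding.register_forward_hook(make_filtered_hook(store, batch_size=1))
    embedding(torch.tensor([[1, 2]]))
    embedding(torch.arange(3, 3 + beam_size).unsqueeze(1))
    h.remove()
    return store


filtered_beam3 = run(beam_size=3)  # the decoding call is filtered out
filtered_beam1 = run(beam_size=1)  # both calls pass the filter
print(len(filtered_beam3), len(filtered_beam1))
```

With beam_size=1 the decoding call has the same leading dimension as the initial input, so a filter of this kind records both, which is the ambiguity described above.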
Steps to reproduce: