Open Gandalf098 opened 1 year ago
cc @gante
Hey @Gandalf098 (the white, I hope ;) )
By default, the scores are not initialized and are kept as None (see here). To enable score-keeping, you must pass return_dict_in_generate=True, output_scores=True to your .generate() call.
import torch
from transformers import StoppingCriteriaList, BartForConditionalGeneration, BartTokenizer


def custom_stopping_criteria(input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
    # Print whatever generate() passes in as scores, and never stop early.
    print("Scores:", scores)
    return False


stopping_criteria = StoppingCriteriaList([custom_stopping_criteria])

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0)
tok = BartTokenizer.from_pretrained("facebook/bart-large")
example_english_phrase = "UN Chief Says There Is No <mask> in Syria"
batch = tok(example_english_phrase, return_tensors="pt")

# Both flags are needed for scores to be kept and passed to the stopping criteria.
model.generate(
    batch["input_ids"],
    stopping_criteria=stopping_criteria,
    return_dict_in_generate=True,
    output_scores=True,
)
Hi @gante and @Gandalf098,
According to the StoppingCriteria.__call__ signature and to its docstring, scores is supposed to be a torch.FloatTensor.
scores (torch.FloatTensor of shape (batch_size, config.vocab_size)) — Prediction scores of a language modeling head.
It makes sense to think of it as the last prediction scores of the language modeling head, meaning that the score-keeping here refers not to scores (the optional history of the prediction scores) but to next_token_scores (the last prediction scores, which are always available, at least for greedy decoding; we should verify this for other decoding strategies).
In that sense, I do think we should correct this point. What do you think @gante?
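To make the distinction concrete, here is a toy, dependency-free model of a greedy decoding loop. It is a simplified illustration and not the actual transformers implementation: `toy_greedy_loop`, `fake_logits` and `record` are hypothetical names, and only the way `scores` and `next_token_scores` are handled is meant to mirror what is described in this thread.

```python
from typing import Callable, List, Optional, Tuple

History = Optional[Tuple[List[float], ...]]


def fake_logits(step: int, vocab_size: int = 4) -> List[float]:
    """Hypothetical stand-in for a language modeling head (deterministic values)."""
    return [float((step + tok) % vocab_size) for tok in range(vocab_size)]


def toy_greedy_loop(
    stopping_criterion: Callable[[List[int], History], bool],
    max_steps: int = 3,
    output_scores: bool = False,
) -> List[int]:
    input_ids: List[int] = [0]
    # The history is only allocated on request; otherwise it stays None.
    scores: History = () if output_scores else None
    for step in range(max_steps):
        # The last prediction scores always exist inside the loop.
        next_token_scores = fake_logits(step)
        if scores is not None:
            scores += (next_token_scores,)
        # Greedy choice: index of the highest score.
        next_token = max(range(len(next_token_scores)), key=next_token_scores.__getitem__)
        input_ids.append(next_token)
        # The criterion receives the optional history, not next_token_scores.
        if stopping_criterion(input_ids, scores):
            break
    return input_ids


seen: List[History] = []


def record(input_ids: List[int], scores: History) -> bool:
    """Store what the criterion was given and never stop early."""
    seen.append(scores)
    return False


toy_greedy_loop(record, output_scores=False)  # criterion sees None at every step
toy_greedy_loop(record, output_scores=True)   # criterion sees a growing tuple
```

In this sketch, passing `next_token_scores` to the criterion instead would match the docstring quoted above, while passing the history keeps criteria over a sequence of scores possible.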
We might want to build some stopping criteria based on a sequence of tokens/sequence of scores, so this API is more general 🤗
We do need better docs and/or input validation, though, to detect these issues in advance. It is my priority for this month (and I'm keeping this issue open so I don't forget to address this case)
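Until such validation exists in the library, a user-side guard is one possible workaround. The sketch below is only an assumption about what that could look like: `RequireScores` is a hypothetical helper, not part of transformers, and it assumes the wrapped criterion is a plain callable like the function in the example above.

```python
class RequireScores:
    """Hypothetical wrapper that fails loudly when a score-based criterion gets scores=None."""

    def __init__(self, criterion):
        self.criterion = criterion

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        if scores is None:
            # Point the user at the flags mentioned earlier in this thread.
            raise ValueError(
                "This stopping criterion needs scores, but received None. "
                "Pass return_dict_in_generate=True, output_scores=True to .generate()."
            )
        return self.criterion(input_ids, scores, **kwargs)


def stop_after_two_steps(input_ids, scores, **kwargs) -> bool:
    """Example criterion that relies on the score history being populated."""
    return len(scores) >= 2


guarded = RequireScores(stop_after_two_steps)
```

Wrapping a criterion this way turns a silent None into an explicit error message, which makes the missing flags easier to spot.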
System Info
transformers version: 4.29.2
Who can help?
No response
Reproduction
Reproduction Steps:
Scores argument is always None
Example code:
The above code uses a stopping criterion that just prints the scores value when called (which prints None).
Expected behavior
The scores argument should be populated with values instead of being None (whether the values are taken before or after softmax doesn't matter).