huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

custom stopping_criteria function doesn't receive logits scores (receives None instead) #23674


Gandalf098 commented 1 year ago

System Info

Who can help?

No response

Information

Tasks

Reproduction

Reproduction Steps:

  1. Initialize a BART model & its tokenizer (in my case it is facebook/bart-large)
  2. Create a custom stopping_criteria function and add it to StoppingCriteriaList object
  3. Run model.generate() with your stopping criteria list as an argument

The scores argument is always None.

Example code:

import torch
from transformers import StoppingCriteriaList, BartForConditionalGeneration, BartTokenizer

def custom_stopping_criteria(input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
    print ("Scores:", scores)
    return False

stopping_criteria = StoppingCriteriaList([custom_stopping_criteria])

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0)
tok = BartTokenizer.from_pretrained("facebook/bart-large")

example_english_phrase = "UN Chief Says There Is No <mask> in Syria"
batch = tok(example_english_phrase, return_tensors="pt")

model.generate(batch["input_ids"], stopping_criteria=stopping_criteria)

The above code uses a stopping criterion that simply prints the scores value when called (and it prints None).

Expected behavior

The scores argument should be populated with logits instead of None (whether the values come before or after softmax doesn't matter).

sgugger commented 1 year ago

cc @gante

gante commented 1 year ago

Hey @Gandalf098 (the white, I hope ;) )

By default, the scores are not initialized and are kept as None (see here). To enable score-keeping, you must pass return_dict_in_generate=True, output_scores=True to your .generate() call.


import torch
from transformers import StoppingCriteriaList, BartForConditionalGeneration, BartTokenizer

def custom_stopping_criteria(input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
    print("Scores:", scores)
    return False

stopping_criteria = StoppingCriteriaList([custom_stopping_criteria])

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0)
tok = BartTokenizer.from_pretrained("facebook/bart-large")

example_english_phrase = "UN Chief Says There Is No <mask> in Syria"
batch = tok(example_english_phrase, return_tensors="pt")

model.generate(batch["input_ids"], stopping_criteria=stopping_criteria, return_dict_in_generate=True, output_scores=True)
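
With both flags set, the object returned by .generate() also exposes the per-step scores directly; a quick check, reusing the names from the snippet above:

outputs = model.generate(
    batch["input_ids"],
    stopping_criteria=stopping_criteria,
    return_dict_in_generate=True,
    output_scores=True,
)
# outputs.scores is a tuple with one (batch_size, vocab_size) tensor per
# generated step
print(len(outputs.scores), outputs.scores[0].shape)
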
ylacombe commented 1 year ago

Hi @gante and @Gandalf098,

According to the StoppingCriteria.__call__ signature and to its docstring, scores is supposed to be a torch.FloatTensor.

scores (torch.FloatTensor of shape (batch_size, config.vocab_size)) — Prediction scores of a language modeling head.

It makes sense to think of it as the last prediction scores of the language modeling head, meaning that the score-keeping here should refer not to scores (the optional history of prediction scores) but to next_token_scores (the always-available scores of the last step; this holds at least for greedy decoding, and we should verify the other decoding strategies).
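
To make that documented contract concrete, a criterion written against it would look roughly like the sketch below (illustration only; the class name MaxProbStoppingCriteria and the 0.9 threshold are made up):

import torch
from transformers import StoppingCriteria

class MaxProbStoppingCriteria(StoppingCriteria):
    # Hypothetical example: stop once the model is very confident about the
    # next token. Assumes `scores` follows the documented contract, i.e. the
    # last-step logits of shape (batch_size, config.vocab_size).
    def __init__(self, threshold: float = 0.9):
        self.threshold = threshold

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        probs = scores.softmax(dim=-1)
        return bool(probs.max() >= self.threshold)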

In that sense, I do think we should correct this point. What do you think @gante?

gante commented 1 year ago

We might want to build stopping criteria based on a sequence of tokens or a sequence of scores, so this API is intentionally more general 🤗
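
For instance, a criterion like the rough sketch below only makes sense with the history of scores (the class name, threshold, and window are invented; it assumes scores arrives as the per-step tuple that .generate() accumulates when output_scores=True):

import torch
from transformers import StoppingCriteria

class PlateauStoppingCriteria(StoppingCriteria):
    # Hypothetical example: stop once the top-token probability has stayed
    # above a threshold for the last `window` generation steps.
    def __init__(self, threshold: float = 0.95, window: int = 3):
        self.threshold = threshold
        self.window = window

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        if scores is None or len(scores) < self.window:
            return False
        recent = scores[-self.window:]
        return all(bool(step.softmax(dim=-1).max() >= self.threshold) for step in recent)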

We do need better docs and/or input validation, though, to detect these issues in advance. That is my priority for this month (and I'm keeping this issue open so I don't forget to address this case).