huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

BART can only generate a maximum of 20 tokens #16622

Closed: ayaka14732 closed this issue 2 years ago

ayaka14732 commented 2 years ago

Environment info

Who can help

@patil-suraj @patrickvonplaten

Information

Model I am using: BART

To reproduce

Steps to reproduce the behavior:

from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

sentences = ['At the launch of the latest report by the Intergovernmental Panel on Climate Change, on the mitigation of climate change, the UN Secretary-General called for an urgent shift of investments and subsidies from fossil fuels to renewable energy, warning that investing in new fossil fuels infrastructure is moral and economic madness.']

inputs = tokenizer(sentences, return_tensors='pt')
print('Input shape:', inputs.input_ids.shape)

generate_ids = model.generate(inputs.input_ids, num_beams=5, min_length=50)
print('Generated shape:', generate_ids.shape)

print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

Output:

Input shape: torch.Size([1, 60])
Generated shape: torch.Size([1, 20])
At the launch of the latest report by the Intergovernmental Panel on Climate Change, on

Expected behavior

The output should not be truncated.

Actual behavior

The output is truncated.

Note that the output is truncated even though min_length=50 is specified.

gante commented 2 years ago

Hi @ayaka14732 👋 That happens because the stopping conditions take precedence over everything else, including min_length. The default for max_length is 20, which is why you see 20 generated tokens. In your example, if you rewrite the generate line as generate_ids = model.generate(inputs.input_ids, num_beams=5, min_length=50, max_length=100), you'll get the results you expect.
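To illustrate why min_length cannot override the max_length default, here is a toy decoding loop in plain Python. It is a simplified sketch, not the transformers implementation. It assumes that, like transformers' MinLengthLogitsProcessor, min_length works only by blocking the end-of-sequence token, while max_length is a hard stop checked before every step. The names toy_generate, START_ID, EOS_ID, and the eos_step parameter are hypothetical.

```python
from typing import List, Optional

START_ID = 0  # hypothetical decoder start token id
EOS_ID = 2    # hypothetical end-of-sequence token id


def toy_generate(max_length: int = 20, min_length: int = 0,
                 eos_step: Optional[int] = None) -> List[int]:
    """Toy decoder: the 'model' wants to emit EOS from position eos_step onward."""
    tokens = [START_ID]
    # Stopping condition: checked before every step, so it always wins.
    while len(tokens) < max_length:
        cur_len = len(tokens)
        wants_eos = eos_step is not None and cur_len >= eos_step
        # min_length only blocks EOS; it never adds steps beyond max_length.
        if wants_eos and cur_len >= min_length:
            tokens.append(EOS_ID)
            break
        tokens.append(1000 + cur_len)  # placeholder id for a sampled token
    return tokens
```

With the defaults, toy_generate(min_length=50) returns 20 ids, mirroring the torch.Size([1, 20]) in the report, whereas toy_generate(max_length=100, min_length=50, eos_step=10) returns 51 ids ending in EOS_ID, because EOS is only allowed once 50 ids exist.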

@patrickvonplaten @patil-suraj should we raise an exception in this case? (min_length > max_length)

patrickvonplaten commented 2 years ago

@gante, yes this would work for me! Let's maybe do this in generate() before we jump into the sub-generation methods
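A minimal sketch of the check proposed above, placed in a top-level generate() before it dispatches to a decoding strategy, so a single check covers every strategy. Everything here (check_length_constraints, _greedy_search, _beam_search, and the error message) is hypothetical and only stands in for the real library code; the decoding functions are placeholders that return max_length dummy ids.

```python
from typing import List, Optional


def check_length_constraints(min_length: Optional[int], max_length: int) -> None:
    """Hypothetical up-front check: reject length settings that can never be met."""
    if min_length is not None and min_length > max_length:
        raise ValueError(
            f"min_length ({min_length}) is larger than max_length ({max_length}); "
            "generation would stop at max_length before min_length is reached."
        )


def _greedy_search(max_length: int) -> List[int]:
    # Placeholder for a real greedy decoding routine.
    return list(range(max_length))


def _beam_search(max_length: int, num_beams: int) -> List[int]:
    # Placeholder for a real beam search routine (num_beams unused here).
    return list(range(max_length))


def generate(num_beams: int = 1, min_length: Optional[int] = None,
             max_length: int = 20) -> List[int]:
    # Validate once, before dispatching, so every sub-generation method is covered.
    check_length_constraints(min_length, max_length)
    if num_beams > 1:
        return _beam_search(max_length, num_beams)
    return _greedy_search(max_length)
```

With the arguments from the original report (num_beams=5, min_length=50, default max_length=20) this sketch raises ValueError; adding max_length=100 passes the check.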

gante commented 2 years ago

@ayaka14732 if you pull from master (or install transformers==4.19.0.dev0), you should see an informative exception if you try to run your original script.

Thank you for reporting this issue :D