huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

invalid multinomial distribution (with replacement=False, not enough non-negative category to sample) #11305

Closed Muennighoff closed 3 years ago

Muennighoff commented 3 years ago

When using "sshleifer/distilbart-cnn-6-6" with do_sample=True, the code below errors out, while the same code works for "sshleifer/distilbart-xsum-6-6". Am I missing something really obvious here? Thanks for any help!

Transformers: 4.5.1

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

model_name = "sshleifer/distilbart-cnn-6-6"
# model_name = "sshleifer/distilbart-xsum-6-6"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

text = "New York City (NYC), often simply called New York, is the most populous city in the United States"
input_ids = tokenizer.encode(text, return_tensors='pt')

sample_outputs = model.generate(
    input_ids,
    num_beams=3,
    do_sample=True,
)
sample_outputs

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

zeke-john commented 2 years ago

I have the exact same problem when I use do_sample=True. Can you re-open this issue?

LysandreJik commented 2 years ago

Maybe @gante has an idea!

gante commented 2 years ago

Hi there @Muennighoff @zeke-john 👋

I've run the script above for both models on v4.5.1 (and on v4.22.dev0) and it works with no problems -- you can see a colab here.

A potential cause for errors may be GPU memory -- generation with num_beams is memory intensive. Let me know if you have more details about your problem :)
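
For readers trying to narrow this down, here is a minimal sketch of one possible reading of the error message in the title. It is not a confirmed diagnosis for this thread. The message comes from sampling without replacement, which cannot succeed when fewer categories have positive probability than the number of samples requested. The helper names below are hypothetical, the probability vector is made up for illustration, and the assumption that beam sampling draws 2 * num_beams candidates per step should be checked against the installed version of the library.

```python
from typing import Sequence


def count_sampleable(probs: Sequence[float]) -> int:
    """Count categories with strictly positive probability.

    NaN entries compare False against 0, so they are not counted.
    """
    return sum(1 for p in probs if p > 0)


def can_sample_without_replacement(probs: Sequence[float], num_samples: int) -> bool:
    """Return True if num_samples distinct categories could be drawn from probs."""
    return count_sampleable(probs) >= num_samples


# Hypothetical next-token distribution in which almost all mass sits on two tokens
# and the remaining entries are exactly zero (for example after filtering or underflow).
probs = [0.7, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

num_beams = 3
# Assumption: beam sampling requests 2 * num_beams candidates per step.
num_samples = 2 * num_beams

print(can_sample_without_replacement(probs, num_samples))
```

If a check like this returns False for the probabilities at the failing step, then too few tokens have non-zero probability for the requested number of beams. In that case, reducing num_beams or inspecting the logits for NaN or infinite values would be a reasonable next step.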