huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

INF encountered when using sampling with temperature. #19509

ElliottYan opened this issue 2 years ago (status: Open)

ElliottYan commented 2 years ago

System Info

transformers version == 4.24.0 (latest)

When generating samples with mBART, I encounter this problem:

[screenshot: error raised during sampling, showing INF values in the scores]

Looking deeper into the code, I find that the problem stems from the beam score being added to next_token_scores here: https://github.com/huggingface/transformers/blob/bc21aaca789f1a366c05e8b5e111632944886393/src/transformers/generation_utils.py#L2566

The initial value of beam_scores is 0, but with a temperature such as 0.5, the cumulative score is also divided by the temperature value in logit_warper at every step, so its magnitude grows larger and larger. Eventually this causes next_token_scores to overflow.
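(Editorial illustration, not from the thread.) A minimal numeric sketch of this growth, assuming only the order of operations described above: the running beam score is added to the next-token log-prob, and the sum is re-divided by the temperature on every step.

temperature = 0.5
token_logprob = -2.0    # stand-in log-prob of each sampled token
beam_scores = 0.0       # cumulative beam score starts at 0, as in beam search

for step in range(30):
    # add the running beam score to the next-token log-prob, then divide the
    # sum by the temperature, mirroring the order of operations above
    beam_scores = (beam_scores + token_logprob) / temperature

print(beam_scores)   # ~ -4.3e9: with temperature 0.5 the magnitude roughly doubles every step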

Who can help?

@patrickvonplaten @Narsil @gante

Reproduction

Here is a simple snippet that reproduces the issue.

import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
model = model.cuda()

# German source ("In an emergency call, Professor Shannon Lamb told police in a
# slightly shaky voice that he had shot his girlfriend and that the officers
# needed to come to his house.")
src = 'In einem Notruf erzählte Professor Shannon Lamb mit einer etwas zittrigen Stimme der Polizei, dass er seine Freundin erschossen habe und dass die Beamten zu seinem Haus kommen müssten.'

encoded_hi = tokenizer(src, return_tensors="pt", padding=True).to('cuda')
generated_tokens = model.generate(
    **encoded_hi,
    forced_bos_token_id=tokenizer.lang_code_to_id['en_XX'],
    temperature=0.5,
    do_sample=True,
    num_beams=10,
    num_return_sequences=10,
)
tgt_txt = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
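(Editorial illustration.) A rough back-of-the-envelope check connecting this blow-up to numeric precision; the numbers assume the doubling behavior from the earlier sketch, with the score magnitude starting around 4 after the first step.

import math

fp16_max = 65504.0   # largest finite float16 value
fp32_max = 3.4e38    # approximate largest finite float32 value
start = 4.0          # |beam_scores| after one step in the earlier sketch

# with the magnitude doubling each step, steps until overflow ~ log2(max / start)
print(math.log2(fp16_max / start))   # ~14 steps in half precision
print(math.log2(fp32_max / start))   # ~126 steps in full precision

So in half precision the scores hit the representable limit within a couple dozen generated tokens, while full precision holds out far longer, which is consistent with the precision question raised later in the thread.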

Expected behavior

I think this should be fixed, but I'm not sure what effect changing the handling of beam_scores would have.

gante commented 2 years ago

Hi @ElliottYan 👋 Thank you for pointing it out, it seems like a bug indeed. I will look into it.

ElliottYan commented 2 years ago

Great! Looking forward to your solution. For now, I just swapped these two lines (L2566 and L2567) and the error disappears, but I'm not sure whether what I did is correct.
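(Editorial illustration.) A toy comparison of the two orderings, with the two lines paraphrased from the linked generation_utils.py code rather than quoted verbatim (shapes simplified; this is not the actual patch): warping before adding beam_scores means the cumulative beam score is never re-divided by the temperature.

import torch
from transformers import TemperatureLogitsWarper

warper = TemperatureLogitsWarper(temperature=0.5)
logits = torch.full((1, 4), -2.0)      # stand-in next-token log-probs

def run(swapped):
    beam_scores = torch.zeros(1)
    for _ in range(30):
        scores = logits.clone()
        if not swapped:
            # current order: add the running beam score, then warp (divide by T)
            scores = scores + beam_scores[:, None].expand_as(scores)
            scores = warper(None, scores)
        else:
            # swapped order: warp first, then add the beam score once
            scores = warper(None, scores)
            scores = scores + beam_scores[:, None].expand_as(scores)
        beam_scores = scores[:, 0]     # carry one beam to the next step
    return beam_scores.item()

print(run(swapped=False))  # ~ -4.3e9: geometric blow-up
print(run(swapped=True))   # -120.0: linear growth only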

patrickvonplaten commented 2 years ago

Are you using half or full precision here? Also, inf values are not necessarily the sign of a bug; it might also be that mBART has some default logit processor settings that zero out values, which then lead to inf (cc @gante)
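(Editorial illustration of that last point.) A small sketch of how a default processor can legitimately introduce -inf: generate() adds ForcedBOSTokenLogitsProcessor when forced_bos_token_id is set, as in the reproduction above, and that processor masks every token except the forced one. The ids below are toy values.

import torch
from transformers import ForcedBOSTokenLogitsProcessor

proc = ForcedBOSTokenLogitsProcessor(bos_token_id=3)   # toy id for an 8-token vocab
input_ids = torch.tensor([[2]])    # length-1 prefix, so the forced-BOS rule fires
scores = torch.zeros(1, 8)

# every entry except index 3 is set to -inf; such -inf values are expected
# and are not, by themselves, a bug
print(proc(input_ids, scores))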