huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Beam search sometimes fails this assert error #3188

Closed Laksh1997 closed 4 years ago

Laksh1997 commented 4 years ago

🐛 Bug

Information

Model I am using: GPT2 with custom config (vocab=27)

Language I am using the model on (English, Chinese ...): Molecules... (see https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system)

The problem arises when using: the .generate(num_beams=2) function

The task I am working on is: generating molecules

The Problem

Essentially, every so often the beam search generation fails the assertion shown below.

To reproduce

Steps to reproduce the behavior:

Just do the generation with these args:

args = {
    "num_beams": 3,
    "max_length": 50,
    "temperature": 1,
    "repetition_penalty": 1,
    "length_penalty": 1,
    "do_sample": True,
    "top_k": 50,
    "top_p": 1,
}
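A minimal sketch of a call with these settings, assuming a small, randomly initialized GPT-2 with a 27-token vocabulary standing in for the custom model (the bos/eos/pad ids and batch size here are hypothetical, and the original SMILES tokenizer is not shown):

import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Assumption: a small random GPT-2 approximating the vocab=27 custom config.
config = GPT2Config(vocab_size=27, n_positions=64, n_embd=64, n_layer=2, n_head=2,
                    bos_token_id=0, eos_token_id=26)
model = GPT2LMHeadModel(config)
model.eval()

# Generate 16 sequences from scratch, each starting from the BOS token.
input_ids = torch.zeros((16, 1), dtype=torch.long)

outputs = model.generate(
    input_ids,
    num_beams=3,
    max_length=50,
    temperature=1.0,
    repetition_penalty=1.0,
    length_penalty=1.0,
    do_sample=True,
    top_k=50,
    top_p=1.0,
    pad_token_id=26,  # hypothetical id; only used if sequences finish early
)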

The generation runs fine for several batches, but then after like 100s of iterations, it sometimes bugs out with this error:

    assert len(next_batch_beam) == num_beams * (batch_ex + 1)

I then added print statements to modeling_utils to try and see what is going on. I changed the assert line to:

assert len(next_batch_beam) == num_beams * (batch_ex + 1), f"{next_batch_beam}, {num_beams}, {batch_ex}"

And with this I got:

AssertionError: [(tensor(-19.8421), tensor(26), tensor(0)), (tensor(-20.9710), tensor(26), tensor(0)), (tensor(-30.5064), tensor(5), tensor(0)), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (tensor(-17.4236), tensor(11), tensor(9)), (tensor(-26.3645), tensor(16), tensor(9)), (tensor(-23.9410), tensor(16), tensor(9)), (0, 0, 0), (0, 0, 0), (0, 0, 0), (tensor(-58.0648), tensor(0), tensor(15))], 3, 5

If you count the entries in the list, there are 16, which does not equal 3 * (5 + 1) = 18.
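For reference, my understanding of the bookkeeping this assertion checks: each batch index should contribute exactly num_beams (score, token_id, beam_id) triples to next_batch_beam, with already-finished sentences padded by (0, 0, 0), so the list should grow by num_beams per batch index. A rough, self-contained sketch of that invariant (not the actual transformers code; pick_top_beams is a hypothetical stand-in for the real top-k selection):

num_beams = 3
done = [False, False, True, False, True, False]  # which sentences in the batch are finished

def pick_top_beams(batch_ex, k):
    # Hypothetical stand-in for the real selection of the top-k next-token candidates.
    return [(-float(batch_ex + i), i, batch_ex * k + i) for i in range(k)]

next_batch_beam = []
for batch_ex, is_done in enumerate(done):
    if is_done:
        # Pad finished sentences so beam indexing stays aligned.
        next_batch_beam.extend([(0, 0, 0)] * num_beams)
        continue
    next_sent_beam = pick_top_beams(batch_ex, k=num_beams)
    assert len(next_sent_beam) == num_beams, "Beam should always be full"
    next_batch_beam.extend(next_sent_beam)
    # The invariant from the failing assertion above.
    assert len(next_batch_beam) == num_beams * (batch_ex + 1)

In the failing output above, the last batch index appears to contribute only one entry (the beam-id-15 tuple) instead of three, which is why the total is 16 rather than 18.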

Not sure what is going on here; I'm looking into the code now to try and figure it out.

patrickvonplaten commented 4 years ago

@Laksh1997 thanks a lot for reporting this error. Can you provide a code snippet and maybe a link to your data to easily reproduce this error?

In the meantime, could you check whether the error still occurs when installing from the master branch?

Laksh1997 commented 4 years ago

Right, I'll try the master branch and report any further problems.

For generating samples without any context or input, there is no point in setting do_sample to False, as one would always generate the same sample.
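A quick way to see this with a hypothetical small, untrained model: two greedy calls (do_sample=False) from the same single-token prompt always return the same sequence:

import torch
from transformers import GPT2Config, GPT2LMHeadModel

model = GPT2LMHeadModel(GPT2Config(vocab_size=27, n_embd=32, n_layer=1, n_head=1))
model.eval()  # disable dropout so the forward pass is deterministic
prompt = torch.zeros((1, 1), dtype=torch.long)

out1 = model.generate(prompt, max_length=20, do_sample=False, pad_token_id=26)
out2 = model.generate(prompt, max_length=20, do_sample=False, pad_token_id=26)
assert torch.equal(out1, out2)  # identical output on every call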

stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

Khasir commented 4 years ago

I'm running into this issue somewhat commonly as well.

Code to reproduce error:

import torch
from transformers import MarianMTModel, MarianTokenizer

torch.manual_seed(15)
phrase = "Ich verstehe nicht, was du sagen willst. Sprich doch Deutsch!"
model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-de-ZH")
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-de-ZH")

# Nucleus sampling as per https://github.com/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb
input_ids = tokenizer.prepare_translation_batch([phrase])
token_ids_p = model.generate(
    **input_ids,
    do_sample=True,
    top_p=0.9,
)
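# Note: this model's default generation settings use beam search (the traceback
# below ends in _generate_beam_search), so do_sample=True here means sampled
# beam search, which trips an assertion in the same beam search loop.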

translated_p = [tokenizer.decode(string, skip_special_tokens=True) for string in token_ids_p]
print(translated_p)

Error:

Traceback (most recent call last):
  File "temp.py", line 14, in <module>
    top_p=0.9,
  File "/Users/kaz/envs/venv-3.7/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/Users/kaz/envs/venv-3.7/lib/python3.7/site-packages/transformers/generation_utils.py", line 459, in generate
    model_specific_kwargs=model_specific_kwargs,
  File "/Users/kaz/envs/venv-3.7/lib/python3.7/site-packages/transformers/generation_utils.py", line 757, in _generate_beam_search
    assert len(next_sent_beam) == num_beams, "Beam should always be full"

@patrickvonplaten Is it possible to revive this issue?