huggingface / transformers

πŸ€— Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

InvalidArgumentError: Incompatible shapes: [5,20] vs. [5,18] [Op:Less] #4088

Closed · zirlman closed this issue 4 years ago

zirlman commented 4 years ago

πŸ› Bug

Information

Model I am using (Bert, XLNet ...): T5 - TFT5ForConditionalGeneration

Language I am using the model on (English, Chinese ...): English

The problem arises when using:

The tasks I am working on are:

To reproduce

Steps to reproduce the behavior:

  1. Download pre-trained T5 model & T5 tokenizer
  2. Encode this sentence: question: What is coronavirus? context: Coronavirus disease 2019 (COVID-19) is an infectious disease caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2). The disease was first identified in December 2019 in Wuhan, the capital of China's Hubei province, and has since spread globally, resulting in the ongoing 2019–20 coronavirus pandemic. As of 30 April 2020,[update] more than 3.19 million cases have been reported across 185 countries and territories, resulting in more than 227,000 deaths. More than 972,000 people have recovered.
  3. Generate N answers (the number doesn't matter; in my case it was 5, 7, or 10)

Code:

import tensorflow as tf  # needed for tf.squeeze below

from transformers import T5Tokenizer, TFT5ForConditionalGeneration

model_str = "t5-base"
hyperparams = dict(
    top_k=50,
    top_p=0.95,
    max_length=None,
    temperature=0.7,
    num_return_sequences=5,
    do_sample=True,
    use_cache=False,
)

tokenizer = T5Tokenizer.from_pretrained(model_str)
model = TFT5ForConditionalGeneration.from_pretrained(model_str)

def generate(input_ids):
  outputs = model.generate(input_ids, **hyperparams)
  all_outputs = []
  if outputs is not None and outputs.shape[0] == 1:
      outputs = tokenizer.decode(tf.squeeze(outputs), skip_special_tokens=True)
      all_outputs.append(outputs) 
  elif outputs is not None:
      all_outputs.extend([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])
  return all_outputs

sentence = """question: What is coronavirus? context: Coronavirus disease 2019 (COVID-19) is an infectious disease caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2). The disease was first identified in December 2019 in Wuhan, the capital of China's Hubei province, and has since spread globally, resulting in the ongoing 2019–20 coronavirus pandemic. As of 30 April 2020,[update] more than 3.19 million cases have been reported across 185 countries and territories, resulting in more than 227,000 deaths. More than 972,000 people have recovered.
""".replace("\n"," ")

input_ids = tokenizer.encode(sentence,return_tensors="tf")
generate(input_ids)
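As an aside, the two decode branches in generate() can be collapsed, since decoding row by row also covers the batch-of-one case (transformers tokenizers additionally expose batch_decode for exactly this). A minimal sketch, using a stub tokenizer standing in for the real T5Tokenizer:

```python
# Stub tokenizer for illustration only; a real transformers tokenizer's
# decode()/batch_decode() would be used in its place.
class StubTokenizer:
    def decode(self, ids, skip_special_tokens=True):
        return " ".join(str(i) for i in ids)

def generate_all(outputs, tokenizer):
    # One path handles both a single sequence and a batch:
    # iterating over a [1, seq_len] batch yields one row.
    if outputs is None:
        return []
    return [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]

print(generate_all([[1, 2], [3, 4]], StubTokenizer()))  # → ['1 2', '3 4']
```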

Expected behavior

This error happens only for some questions. If you remove the question mark from the question, you get an output. At first I thought the question mark was the problem, but other examples, both with and without a question mark, resulted in the same error.
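For context, the [5,20] vs. [5,18] in the traceback means an elementwise Less op received two tensors whose shapes cannot be broadcast together. The same shape failure can be reproduced with plain NumPy (used here only to illustrate the broadcasting rule, not the internals of generate()):

```python
import numpy as np

# Two arrays with the shapes from the traceback: the trailing
# dimensions (20 vs. 18) differ and neither is 1, so an elementwise
# comparison cannot broadcast and raises an error.
a = np.zeros((5, 20))
b = np.zeros((5, 18))
try:
    np.less(a, b)
except ValueError as e:
    print("broadcast error:", e)
```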

Environment info

patrickvonplaten commented 4 years ago

Hey @zirlman,

Thanks a lot for catching the error and the detailed error description. The PR that will fix the error is linked to the issue :-)

zirlman commented 4 years ago

@patrickvonplaten that's great. Thank you 😁