huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

`model_kwargs` is None when `generation_config` is passed as a dict instead of `generation.GenerationConfig` #31328

Closed AADeLucia closed 2 months ago

AADeLucia commented 4 months ago

There is a small bug when a dictionary of generation config arguments is passed to .generate() instead of a generation.GenerationConfig object. The bug happens at this line in _prepare_generation_config: https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/generation/utils.py#L1382

model_kwargs is set correctly when generation_config is a generation.GenerationConfig object, but when generation_config is a dictionary, model_kwargs is set to None and an error is later thrown in .generate() from

https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/generation/utils.py#L1633
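For context: the linked line builds model_kwargs from the return value of (roughly) generation_config.update(**kwargs). GenerationConfig.update() returns the kwargs it could not consume, while a plain dict's update() mutates in place and always returns None, which is why model_kwargs ends up as None when a dict slips through. A minimal sketch of the difference, assuming only the documented behavior of GenerationConfig.update():

from transformers import GenerationConfig

extra = {"renormalize_logits": True, "foo": 1}

# GenerationConfig.update() consumes matching attributes and returns the
# leftovers, which .generate() then treats as model kwargs:
cfg = GenerationConfig()
print(cfg.update(**extra))       # {'foo': 1}

# dict.update() mutates in place and returns None, so model_kwargs becomes
# None and .generate() fails later when it tries to use it:
cfg_dict = {}
print(cfg_dict.update(**extra))  # None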

Small example:

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

gen_config = {
    "decoder_start_token_id": 2,
    "output_scores": True,
    "return_dict_in_generate": True,
    "num_return_sequences": 1,
    "max_new_tokens": 128,
    "do_sample": False,
    "num_beams": 1,
    "renormalize_logits": True,
    "output_logits": True,
}

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

inputs = tokenizer(
    "New York (CNN)When Liana Barrientos was 23 years old, she got married in "
    "Westchester County, New York. A year later, she got married again in "
    "Westchester County, but to a different man and without divorcing her first "
    "husband. Only 18 days after that marriage, she got hitched yet again. Then, "
    "Barrientos declared 'I do' five more times, sometimes only within two weeks "
    "of each other. In 2010, she married once more, this time in the Bronx.",
    return_tensors="pt",
)

# Passing a plain dict (instead of a GenerationConfig) triggers the error:
with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        generation_config=gen_config,
    )

Possible fixes could be: converting the dict to a GenerationConfig inside _prepare_generation_config, or raising a clearer error when a dict is passed.
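For the first option, a hypothetical sketch (not an actual patch) of normalizing the input inside _prepare_generation_config might look like:

from transformers import GenerationConfig

# Hypothetical normalization step (sketch only): accept a plain dict by
# converting it to a GenerationConfig before calling .update() on it.
if isinstance(generation_config, dict):
    generation_config = GenerationConfig(**generation_config)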

hariomhpc commented 3 months ago

Hey, I want to resolve this by raising a PR. I am new to this, so please let me know how I can proceed.

zucchini-nlp commented 3 months ago

Hey!

I don't think it's a bug. The documentation says that a GenerationConfig object is expected as input. If you have a dict, you can still unpack it directly into the .generate() call like below. See the docs for more 🤗

model.generate(**inputs, **gen_config_dict)
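With the gen_config dict from the example above, that would look like this (a sketch; wrapping the dict as GenerationConfig(**gen_config) and passing it via generation_config= also works):

# Unpack the dict so .generate() receives the options as keyword arguments:
with torch.inference_mode():
    outputs = model.generate(**inputs, **gen_config)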

cc @gante to confirm

github-actions[bot] commented 3 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

gante commented 2 months ago

I confirm @zucchini-nlp's comment, working as expected 🤗