huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Raise ValueError if given max_new_tokens to `Seq2SeqTrainer.predict()` #18785

Closed · kumapo closed this issue 2 years ago

kumapo commented 2 years ago

System Info

Who can help?

@sgugger

Information

Tasks

Reproduction

import datasets
import transformers

# Build an image-captioning model from a ViT encoder and a BERT decoder
model = transformers.VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k",
    "bert-base-uncased"
)
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
feature_extractor = transformers.AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
eval_ds = datasets.load_dataset(
    "kumapo/stair_captions_dataset_script", "2014",
    data_dir="../input/coco-2014-val", split="validation", streaming=True
)
# do some preprocessing on eval_ds with map() ...
training_args = transformers.Seq2SeqTrainingArguments(
    predict_with_generate=True,
    fp16=False,
    output_dir="output/",
    report_to="none",
)
trainer = transformers.Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=transformers.default_data_collator,
)
_ = trainer.predict(eval_ds, max_new_tokens=16)

Then `ValueError: Both max_new_tokens and max_length have been set but they serve the same purpose` is raised:

ValueError                                Traceback (most recent call last)
/tmp/ipykernel_23/2318841552.py in <module>
     61     data_collator=transformers.default_data_collator,
     62 )
---> 63 _ = trainer.predict(eval_ds, max_new_tokens=16)

/opt/conda/lib/python3.7/site-packages/transformers/trainer_seq2seq.py in predict(self, test_dataset, ignore_keys, metric_key_prefix, **gen_kwargs)
    135         self._gen_kwargs = gen_kwargs
    136 
--> 137         return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
    138 
    139     def prediction_step(

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in predict(self, test_dataset, ignore_keys, metric_key_prefix)
   2844         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
   2845         output = eval_loop(
-> 2846             test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
   2847         )
   2848         total_batch_size = self.args.eval_batch_size * self.args.world_size

/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   2947 
   2948             # Prediction step
-> 2949             loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   2950             inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
   2951 

/opt/conda/lib/python3.7/site-packages/transformers/trainer_seq2seq.py in prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
    201         generated_tokens = self.model.generate(
    202             generation_inputs,
--> 203             **gen_kwargs,
    204         )
    205         # in case the batch is shorter than max length, the output should be padded

/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     25         def decorate_context(*args, **kwargs):
     26             with self.clone():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)
     29 

/opt/conda/lib/python3.7/site-packages/transformers/generation_utils.py in generate(self, inputs, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, typical_p, repetition_penalty, bad_words_ids, force_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, logits_processor, renormalize_logits, stopping_criteria, constraints, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, exponential_decay_length_penalty, **model_kwargs)
   1237         elif max_length is not None and max_new_tokens is not None:
   1238             raise ValueError(
-> 1239                 "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
   1240                 " limit to the generated output length. Remove one of those arguments. Please refer to the"
   1241                 " documentation for more information. "

ValueError: Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a limit to the generated output length. Remove one of those arguments. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
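For context, the clash appears to come from `Seq2SeqTrainer.prediction_step` filling in `max_length` from the model config whenever it is not passed explicitly, so supplying `max_new_tokens` means both limits reach `generate()`. A minimal workaround sketch under that assumption (same setup as the reproduction above) is to pass `max_length` instead:

# Workaround sketch: pass `max_length` instead of `max_new_tokens`, so the
# trainer does not also fill in a conflicting default from the model config.
# Note that `max_length` bounds the total output length, not only new tokens.
_ = trainer.predict(eval_ds, max_length=16)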

Expected behavior

No error is raised; `max_new_tokens` is simply forwarded to `generate()`.
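An alternative sketch, assuming a transformers version in which `Seq2SeqTrainingArguments` exposes `generation_max_length`, is to configure the limit on the training arguments instead of at `predict()` time:

training_args = transformers.Seq2SeqTrainingArguments(
    predict_with_generate=True,
    generation_max_length=16,  # assumed substitute for passing max_new_tokens to predict()
    fp16=False,
    output_dir="output/",
    report_to="none",
)
trainer = transformers.Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=transformers.default_data_collator,
)
_ = trainer.predict(eval_ds)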

kumapo commented 2 years ago

Thank you all for the kind support!