huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Seq2SeqTrainer.evaluation_loop requires `labels` due to DataCollatorForSeq2Seq #14833

Closed · kleinay closed this issue 2 years ago

kleinay commented 2 years ago

Environment info

Who can help

@affjljoo3581 @patrickvonplaten

Information

I'm trying to run inference with a fine-tuned T5 model. I'm using the run_summarization script with some modifications, and the problem occurs when the predict_dataset doesn't have labels (i.e., at prediction time). The __call__ method of the DataCollatorForSeq2Seq object fails with a KeyError because it expects the dataset to have a labels key:

        # prepare decoder_input_ids
        if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"):
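            # features["labels"] is what raises the KeyError below when the dataset has no labels column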
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])
            features["decoder_input_ids"] = decoder_input_ids

Model I am using (Bert, XLNet ...): T5 and BART

The problem arises when using:

The tasks I am working on are:

Expected behavior

I should be able to run the script in prediction mode (--do_predict) without providing labels in the dataset.

patil-suraj commented 2 years ago

Good catch! The DataCollatorForSeq2Seq should check for None labels before computing decoder_input_ids. Would you like to open a PR to fix this? Happy to help with it, thanks!
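Something along these lines should do it (just a sketch of the guard, not a final patch; labels here is the list of label ids collected at the top of __call__, which is None when the features carry no labels):

        # prepare decoder_input_ids only when the batch actually contains labels
        if (
            labels is not None
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])
            features["decoder_input_ids"] = decoder_input_ids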

kleinay commented 2 years ago

I would have opened a PR, but there seems to be more to it. Modifying DataCollatorForSeq2Seq solved the issue for a BART model, but not for a T5 model. For T5, when I now try to use trainer.predict (as in the run_summarization.py script) on a dataset that only includes input_ids and attention_mask features but no labels, it fails to prepare the decoder:

Traceback (most recent call last):
  File "/home/nlp/kleinay/Parsing/Seq2Seq_QASRL_Parsing/qasrl_bart/run_summarization.py", line 936, in main
    predict_results = trainer.predict(
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/transformers/trainer_seq2seq.py", line 117, in predict
    return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/transformers/trainer.py", line 2223, in predict
    output = eval_loop(
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/transformers/trainer.py", line 2323, in evaluation_loop
    loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/transformers/trainer_seq2seq.py", line 179, in prediction_step
    outputs = model(**inputs)
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/torch/_utils.py", line 425, in reraise
    raise self.exc_type(msg)
ValueError: Caught ValueError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1612, in forward
    decoder_outputs = self.decoder(
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/nlp/kleinay/miniconda3/envs/seq2seq-qasrl/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 902, in forward
    raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds

Looking at the full stack trace, it seems that something in the logic of Seq2SeqTrainer.predict is problematic: it calls Trainer.evaluation_loop, which the docstring promises works "both with or without labels", but that in turn calls Seq2SeqTrainer.prediction_step, which seems to expect labels in the inputs dict, at least for a T5 model. So I still couldn't make trainer.predict work for T5.
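For concreteness, a minimal repro along these lines (the checkpoint name and dataset contents are placeholders, not my actual setup) would be:

    # Repro sketch: a predict dataset carrying only input_ids and attention_mask, no labels.
    from datasets import Dataset
    from transformers import (
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        DataCollatorForSeq2Seq,
        Seq2SeqTrainer,
        Seq2SeqTrainingArguments,
    )

    model_name = "t5-small"  # placeholder; in my case it is a fine-tuned T5 checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    predict_dataset = Dataset.from_dict({"text": ["summarize: an example input"]})
    predict_dataset = predict_dataset.map(
        lambda ex: tokenizer(ex["text"], truncation=True), remove_columns=["text"]
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=Seq2SeqTrainingArguments(output_dir="tmp_predict", predict_with_generate=True),
        tokenizer=tokenizer,
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    )
    predict_results = trainer.predict(predict_dataset)  # raises the ValueError above for T5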

kleinay commented 2 years ago

O.K., I've figured out what I was doing wrong. As the docs say:

Note that T5 uses the pad_token_id as the decoder_start_token_id, so when doing generation without using generate(), make sure you start it with the pad_token_id.

So for T5 models, I need to add a dummy labels feature to the predict dataset, initialized with just [tokenizer.pad_token_id]. Still, I think the logic/documentation issues I pointed out in the previous comment stand: it was hard to understand where the problem was when Trainer.evaluation_loop is promised to work without labels.
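In code, my workaround looks roughly like this (predict_dataset and tokenizer are the objects already built in run_summarization.py):

    # Workaround sketch: give every prediction example a dummy `labels` field
    # holding only the pad token id, which T5 also uses as the decoder start
    # token, so the collator can build decoder_input_ids from it.
    predict_dataset = predict_dataset.map(
        lambda example: {"labels": [tokenizer.pad_token_id]}
    )

The dummy labels only satisfy the collator and the loss computation; any metrics computed against them are of course meaningless.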

github-actions[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.