huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

MT5 data padding not working #24567

Closed hexie1995 closed 1 year ago

hexie1995 commented 1 year ago

System Info

Hello,

I am using the latest version of transformers.

I ran into this issue recently and would like some help with it. I am using MT5 ("google/mt5-base") to fine-tune on my own dataset. While processing the data, I keep getting a dimension-mismatch error, even after padding and truncating as suggested in the example:

I tried the exact same code with XLMProphetNet, XLM-RoBERTa, and XLNet, and all of them worked. Only MT5 gives me this error. The error almost always occurs at the first step at which the trainer tries to evaluate on the validation data. I suspect this has something to do with the evaluation loop, but so far I have found nothing that helps me resolve the issue.

RuntimeError: output with shape [4, 12, 1, 1] doesn't match the broadcast shape [4, 12, 1, 128]

@alexayalamcs tagging Alex here.
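For reference, the dimensions in that error appear to line up with the settings used in the reproduction script below; the mapping here is an assumption, not something stated in the traceback:

# Assumed reading of the shapes in the error, based on the script below:
#   [4, 12, 1, 1]   scores:        batch_size=4, num_heads=12 (mt5-base), query_len=1 (first generated token), key_len=1
#   [4, 12, 1, 128] position bias: batch_size=4, num_heads=12, query_len=1, key_len=128 (decoder_max_length)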

Who can help?

@sgugger

Information

Tasks

Reproduction

from transformers import AutoTokenizer, XLMProphetNetDecoder, DataCollatorWithPadding
from transformers import DataCollatorForLanguageModeling
from datasets import concatenate_datasets, load_dataset
from transformers import MT5ForConditionalGeneration, MT5Tokenizer, MT5Config, MT5Model, T5Tokenizer
import torch
from torch.utils.data import DataLoader
from transformers import Trainer
import nltk
import random
from accelerate import Accelerator
accelerator = Accelerator()
import datasets
rouge = datasets.load_metric("rouge")
import evaluate
accuracy_metric = evaluate.load("accuracy")

train = load_dataset("cnn_dailymail", "3.0.0", split = "train")
valid = load_dataset("cnn_dailymail", "3.0.0", split = "validation")
test = load_dataset("cnn_dailymail", "3.0.0", split = "test")

model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-base")

encoder_max_length=512
decoder_max_length=128

def process_data_to_model_inputs(batch):
    # tokenize the inputs and labels
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=encoder_max_length)
    outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=decoder_max_length)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["decoder_input_ids"] = outputs.input_ids
    batch["decoder_attention_mask"] = outputs.attention_mask
    batch["labels"] = outputs.input_ids.copy()

    return batch

train_data = train.select(range(16))
#train_data = train_init
#batch_size = 16
batch_size=4

train_data = train_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size=batch_size, 
    remove_columns=["article", "highlights", "id"]
)
train_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

val_data = valid.select(range(8))
#val_data = valid
val_data = val_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size=batch_size, 
    remove_columns=["article", "highlights", "id"]
)
val_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    num_train_epochs = 3, 
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=False,
    output_dir="./",
    logging_steps=2,
    #save_steps=5000,
    eval_steps=2,
    # logging_steps=1000,
    # save_steps=500,
    # eval_steps=7500,
    # warmup_steps=2000,
    # save_total_limit=3,
)

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids)
    label_str = tokenizer.batch_decode(labels_ids)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
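# note: this collator is not passed to the trainer below; the traceback later in this
# thread was produced with DataCollatorForSeq2Seq(tokenizer, model=model) as data_collator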
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=val_data,
)
trainer.train()

Expected behavior

I would expect this to run through just fine like XLMProphetNet, XLM-RoBERTa, and XLNet, but it does not.

sgugger commented 1 year ago

cc @ArthurZucker

hexie1995 commented 1 year ago

Thank you. One additional piece of information: I tried to follow, step by step, the official text summarization tutorial here: https://github.com/huggingface/notebooks/blob/main/examples/summarization.ipynb, but the same error occurred. Thanks a lot!
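For comparison, that tutorial prepares the labels through the tokenizer's target handling (text_target in newer versions, as_target_tokenizer in older ones) and lets DataCollatorForSeq2Seq pad them, rather than adding decoder_input_ids and decoder_attention_mask to the batch by hand. A minimal sketch of that style of preprocessing, adapted to the column names and variables used in the script above (so the exact arguments here are illustrative, not a confirmed fix):

from transformers import DataCollatorForSeq2Seq

def preprocess(batch):
    # tokenize inputs without padding; DataCollatorForSeq2Seq pads each batch dynamically
    model_inputs = tokenizer(batch["article"], truncation=True, max_length=encoder_max_length)
    # tokenize targets via text_target; the model derives decoder_input_ids from labels
    labels = tokenizer(text_target=batch["highlights"], truncation=True, max_length=decoder_max_length)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

collator = DataCollatorForSeq2Seq(tokenizer, model=model)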

ArthurZucker commented 1 year ago

Hey! Thanks for reporting. Could you share the entire traceback of the error? 😉

hexie1995 commented 1 year ago

Sure, here's the whole error message. Thanks a lot!

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[16], line 10
      1 data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
      2 trainer = Seq2SeqTrainer(
      3     model=model,
      4     args=training_args,
   (...)
      8     data_collator=data_collator,
      9 )
---> 10 trainer.train()
     11 output = "/output/"
     12 #trainer.save_model(output + "MT5-12-original-XLSUM-accuracy")

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\trainer.py:1645, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1640     self.model_wrapped = self.model
   1642 inner_training_loop = find_executable_batch_size(
   1643     self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
   1644 )
-> 1645 return inner_training_loop(
   1646     args=args,
   1647     resume_from_checkpoint=resume_from_checkpoint,
   1648     trial=trial,
   1649     ignore_keys_for_eval=ignore_keys_for_eval,
   1650 )

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\trainer.py:2011, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2008     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
   2009     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2011     self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
   2012 else:
   2013     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\trainer.py:2312, in Trainer._maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval)
   2310         metrics.update(dataset_metrics)
   2311 else:
-> 2312     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   2313 self._report_to_hp_search(trial, self.state.global_step, metrics)
   2315 # Run delayed LR scheduler now that metrics are populated

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\trainer_seq2seq.py:159, in Seq2SeqTrainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix, **gen_kwargs)
    154 gen_kwargs["num_beams"] = (
    155     gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
    156 )
    157 self._gen_kwargs = gen_kwargs
--> 159 return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\trainer.py:3043, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   3040 start_time = time.time()
   3042 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3043 output = eval_loop(
   3044     eval_dataloader,
   3045     description="Evaluation",
   3046     # No point gathering the predictions if there are no metrics, otherwise we defer to
   3047     # self.args.prediction_loss_only
   3048     prediction_loss_only=True if self.compute_metrics is None else None,
   3049     ignore_keys=ignore_keys,
   3050     metric_key_prefix=metric_key_prefix,
   3051 )
   3053 total_batch_size = self.args.eval_batch_size * self.args.world_size
   3054 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\trainer.py:3235, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   3232         batch_size = observed_batch_size
   3234 # Prediction step
-> 3235 loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   3236 inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
   3238 if is_torch_tpu_available():

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\trainer_seq2seq.py:276, in Seq2SeqTrainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
    270 if (
    271     "labels" in inputs
    272     and "decoder_input_ids" in inputs
    273     and inputs["labels"].shape == inputs["decoder_input_ids"].shape
    274 ):
    275     inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
--> 276 generated_tokens = self.model.generate(**inputs, **gen_kwargs)
    278 # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
    279 # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
    280 # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
    281 if self.model.generation_config._from_model_config:

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\torch\autograd\grad_mode.py:28, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_context(*args, **kwargs):
     27     with self.__class__():
---> 28         return func(*args, **kwargs)

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\generation\utils.py:1522, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
   1516         raise ValueError(
   1517             "num_return_sequences has to be 1 when doing greedy search, "
   1518             f"but is {generation_config.num_return_sequences}."
   1519         )
   1521     # 11. run greedy search
-> 1522     return self.greedy_search(
   1523         input_ids,
   1524         logits_processor=logits_processor,
   1525         stopping_criteria=stopping_criteria,
   1526         pad_token_id=generation_config.pad_token_id,
   1527         eos_token_id=generation_config.eos_token_id,
   1528         output_scores=generation_config.output_scores,
   1529         return_dict_in_generate=generation_config.return_dict_in_generate,
   1530         synced_gpus=synced_gpus,
   1531         streamer=streamer,
   1532         **model_kwargs,
   1533     )
   1535 elif is_contrastive_search_gen_mode:
   1536     if generation_config.num_return_sequences > 1:

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\generation\utils.py:2339, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2336 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2338 # forward pass to get next token
-> 2339 outputs = self(
   2340     **model_inputs,
   2341     return_dict=True,
   2342     output_attentions=output_attentions,
   2343     output_hidden_states=output_hidden_states,
   2344 )
   2346 if synced_gpus and this_peer_finished:
   2347     continue  # don't waste resources running the code we don't need

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\torch\nn\modules\module.py:1051, in Module._call_impl(self, *input, **kwargs)
   1047 # If we don't have any hooks, we want to skip the rest of the logic in
   1048 # this function, and just call forward.
   1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051     return forward_call(*input, **kwargs)
   1052 # Do not call functions when jit is used
   1053 full_backward_hooks, non_full_backward_hooks = [], []

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\models\mt5\modeling_mt5.py:1753, in MT5ForConditionalGeneration.forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1750         decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
   1752 # Decode
-> 1753 decoder_outputs = self.decoder(
   1754     input_ids=decoder_input_ids,
   1755     attention_mask=decoder_attention_mask,
   1756     inputs_embeds=decoder_inputs_embeds,
   1757     past_key_values=past_key_values,
   1758     encoder_hidden_states=hidden_states,
   1759     encoder_attention_mask=attention_mask,
   1760     head_mask=decoder_head_mask,
   1761     cross_attn_head_mask=cross_attn_head_mask,
   1762     use_cache=use_cache,
   1763     output_attentions=output_attentions,
   1764     output_hidden_states=output_hidden_states,
   1765     return_dict=return_dict,
   1766 )
   1768 sequence_output = decoder_outputs[0]
   1770 # Set device for model parallelism

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\torch\nn\modules\module.py:1051, in Module._call_impl(self, *input, **kwargs)
   1047 # If we don't have any hooks, we want to skip the rest of the logic in
   1048 # this function, and just call forward.
   1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051     return forward_call(*input, **kwargs)
   1052 # Do not call functions when jit is used
   1053 full_backward_hooks, non_full_backward_hooks = [], []

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\models\mt5\modeling_mt5.py:1062, in MT5Stack.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
   1049     layer_outputs = checkpoint(
   1050         create_custom_forward(layer_module),
   1051         hidden_states,
   (...)
   1059         None,  # past_key_value is always None with gradient checkpointing
   1060     )
   1061 else:
-> 1062     layer_outputs = layer_module(
   1063         hidden_states,
   1064         attention_mask=extended_attention_mask,
   1065         position_bias=position_bias,
   1066         encoder_hidden_states=encoder_hidden_states,
   1067         encoder_attention_mask=encoder_extended_attention_mask,
   1068         encoder_decoder_position_bias=encoder_decoder_position_bias,
   1069         layer_head_mask=layer_head_mask,
   1070         cross_attn_layer_head_mask=cross_attn_layer_head_mask,
   1071         past_key_value=past_key_value,
   1072         use_cache=use_cache,
   1073         output_attentions=output_attentions,
   1074     )
   1076 # layer_outputs is a tuple with:
   1077 # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
   1078 if use_cache is False:

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\torch\nn\modules\module.py:1051, in Module._call_impl(self, *input, **kwargs)
   1047 # If we don't have any hooks, we want to skip the rest of the logic in
   1048 # this function, and just call forward.
   1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051     return forward_call(*input, **kwargs)
   1052 # Do not call functions when jit is used
   1053 full_backward_hooks, non_full_backward_hooks = [], []

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\models\mt5\modeling_mt5.py:557, in MT5Block.forward(self, hidden_states, attention_mask, position_bias, encoder_hidden_states, encoder_attention_mask, encoder_decoder_position_bias, layer_head_mask, cross_attn_layer_head_mask, past_key_value, use_cache, output_attentions, return_dict)
    554 else:
    555     self_attn_past_key_value, cross_attn_past_key_value = None, None
--> 557 self_attention_outputs = self.layer[0](
    558     hidden_states,
    559     attention_mask=attention_mask,
    560     position_bias=position_bias,
    561     layer_head_mask=layer_head_mask,
    562     past_key_value=self_attn_past_key_value,
    563     use_cache=use_cache,
    564     output_attentions=output_attentions,
    565 )
    566 hidden_states, present_key_value_state = self_attention_outputs[:2]
    567 attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\torch\nn\modules\module.py:1051, in Module._call_impl(self, *input, **kwargs)
   1047 # If we don't have any hooks, we want to skip the rest of the logic in
   1048 # this function, and just call forward.
   1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051     return forward_call(*input, **kwargs)
   1052 # Do not call functions when jit is used
   1053 full_backward_hooks, non_full_backward_hooks = [], []

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\models\mt5\modeling_mt5.py:462, in MT5LayerSelfAttention.forward(self, hidden_states, attention_mask, position_bias, layer_head_mask, past_key_value, use_cache, output_attentions)
    451 def forward(
    452     self,
    453     hidden_states,
   (...)
    459     output_attentions=False,
    460 ):
    461     normed_hidden_states = self.layer_norm(hidden_states)
--> 462     attention_output = self.SelfAttention(
    463         normed_hidden_states,
    464         mask=attention_mask,
    465         position_bias=position_bias,
    466         layer_head_mask=layer_head_mask,
    467         past_key_value=past_key_value,
    468         use_cache=use_cache,
    469         output_attentions=output_attentions,
    470     )
    471     hidden_states = hidden_states + self.dropout(attention_output[0])
    472     outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\torch\nn\modules\module.py:1051, in Module._call_impl(self, *input, **kwargs)
   1047 # If we don't have any hooks, we want to skip the rest of the logic in
   1048 # this function, and just call forward.
   1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051     return forward_call(*input, **kwargs)
   1052 # Do not call functions when jit is used
   1053 full_backward_hooks, non_full_backward_hooks = [], []

File ~\AppData\Local\anaconda3\envs\hface\lib\site-packages\transformers\models\mt5\modeling_mt5.py:420, in MT5Attention.forward(self, hidden_states, mask, key_value_states, position_bias, past_key_value, layer_head_mask, query_length, use_cache, output_attentions)
    417 else:
    418     position_bias_masked = position_bias
--> 420 scores += position_bias_masked
    421 attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
    422     scores
    423 )  # (batch_size, n_heads, seq_length, key_length)
    424 attn_weights = nn.functional.dropout(
    425     attn_weights, p=self.dropout, training=self.training
    426 )  # (batch_size, n_heads, seq_length, key_length)

RuntimeError: output with shape [4, 12, 1, 1] doesn't match the broadcast shape [4, 12, 1, 32]
ArthurZucker commented 1 year ago

Hey! I did not have time to check this; if you can isolate a small reproduction script (without the full training loop), that would be great. Otherwise, I am investigating.

hexie1995 commented 1 year ago

Hi Arthur @ArthurZucker, the code that I shared initially is a small training loop without all the samples, and it reproduces the error once run (the training set is 16 examples and the evaluation set is 8). The runtime should be about 3 minutes at most, because it has to download the CNN/DailyMail dataset first. Thanks a lot for your help!!
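In case it helps narrow things down, a stripped-down call along these lines might hit the same broadcast error without the Trainer, under the assumption that forwarding the fixed-length decoder_attention_mask into generate() is what triggers it (that assumption is a guess, not a confirmed root cause):

import torch
from transformers import MT5ForConditionalGeneration, T5Tokenizer

model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-base")

# encoder inputs padded to a fixed length, as in the script above
enc = tokenizer(["a short test input"] * 4, padding="max_length", truncation=True,
                max_length=32, return_tensors="pt")
# decoder attention mask padded to decoder_max_length, mimicking the preprocessed batches
decoder_attention_mask = torch.ones(4, 128, dtype=torch.long)

model.generate(input_ids=enc["input_ids"],
               attention_mask=enc["attention_mask"],
               decoder_attention_mask=decoder_attention_mask)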

ArthurZucker commented 1 year ago

Ok, I'm low on bandwidth, so pinging @Rocketknight1 in case he can have a look!

ArthurZucker commented 1 year ago

Sorry @hexie1995, I did not have time to have a look 😢

Rocketknight1 commented 1 year ago

I figured this one out! Making a PR.

Rocketknight1 commented 1 year ago

@hexie1995 This should now be fixed on main! You can install from main with pip install git+https://github.com/huggingface/transformers.git. It will also be included in the next release, at which point you can go back to just pip install transformers.

And thanks for the bug report - it turns out there really was an issue deep in the transformers code that was causing this!
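To confirm the source install is actually the one in use, a quick check is to print the version; the exact string will vary, but installs from main normally carry a .dev0 suffix:

import transformers
print(transformers.__version__)  # an install from main typically reports something like "4.x.y.dev0"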

hexie1995 commented 1 year ago

Thank you! This is wonderful news. I will install the new one now.