huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
135.64k stars 27.16k forks source link

ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds while trying to finetune a MT5 for Multitasking #34614

Open Pascalsmdt opened 3 weeks ago

Pascalsmdt commented 3 weeks ago

Hi @patrickvonplaten @ArthurZucker @muellerz

I am currently trying to finetune a MT5 Model for Multitasking (classification, report filling and question generation) As Iam trying and debugging for a couple days now, maybe you or some of youre Colleagues can have a look on my Code and hopefully find the mistake.

Here is the relevant Code and some outputs for clarification:

`import torch
import transformers
from tqdm import tqdm  # For progress bar
from datasets import DatasetDict

# Initialize the tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained("google/mt5-base")

def convert_to_features(example_batch, task_name):
    features = {}

    if task_name == "event_classification":
        # Tokenisierung für Text
        features = tokenizer(example_batch['text'], max_length=1000, padding="max_length", truncation=True, return_tensors="pt")

        numeric_labels = [label_to_id[example_batch['label']]]
        features["labels"] = torch.tensor(numeric_labels, dtype=torch.long).unsqueeze(1)  # Adjusting dimension
        features["decoder_input_ids"] = torch.tensor(numeric_labels, dtype=torch.long).unsqueeze(1)

    elif task_name == "report_filling":
        features = tokenizer(example_batch['input_text'], max_length=1000, padding="max_length", truncation=True, return_tensors="pt")

        labels = tokenizer(example_batch['report'], max_length=1000, padding="max_length", truncation=True, return_tensors="pt")["input_ids"]
        features["labels"] = labels
        features["decoder_input_ids"] = labels.clone()  # Clone for decoder_input_ids

    elif task_name == "question_generation":
        input_texts = []
        labels = []

        for text, missing_info, questions in zip(
            example_batch['text'], example_batch['missing_info'], example_batch['generated_questions']
        ):
            missing_info_str = ", ".join(missing_info)
            input_text = f"{text} Missing info: {missing_info_str}"
            input_texts.append(input_text)
            questions_str = " [SEP] ".join(questions)
            labels.append(questions_str)

        # Tokenisierung für input_texts
        features = tokenizer(input_texts, max_length=1000, padding="max_length", truncation=True, return_tensors="pt")

        # Tokenisierung für die Labels
        labels_encodings = tokenizer(labels, max_length=1000, padding="max_length", truncation=True, return_tensors="pt")["input_ids"]
        features["labels"] = labels_encodings
        features["decoder_input_ids"] = labels_encodings.clone()  # Clone for decoder_input_ids

    return features

def tokenize_dataset(dataset, task_name, max_length=1000):
    num_examples = len(dataset)

    input_ids = []
    attention_masks = []
    labels = []
    decoder_input_ids = []

    for i in tqdm(range(num_examples), total=num_examples):
        example = dataset[i]
        features = convert_to_features(example, task_name)

        # Instead of appending, convert to tensor and squeeze to remove extra dimensions
        input_ids.append(features['input_ids'].squeeze(0))  # Removes dimension of size 1
        attention_masks.append(features['attention_mask'].squeeze(0))
        labels.append(features['labels'].squeeze(0))  # Squeeze if necessary
        decoder_input_ids.append(features['decoder_input_ids'].squeeze(0))

    # Create a dictionary to hold the dataset
    if task_name == "question_generation":
        # Concatenate along the first dimension to flatten
        input_ids = torch.cat(input_ids, dim=0)  # shape: [total_questions, 1000]
        attention_masks = torch.cat(attention_masks, dim=0)  # shape: [total_questions, 1000]
        labels = torch.cat(labels, dim=0)  # shape: [total_questions, 1000]
        decoder_input_ids = torch.cat(decoder_input_ids, dim=0)  # shape: [total_questions, 1000]

    else:
        # For the other tasks, just stack the sequences
        input_ids = torch.stack(input_ids)  # shape: [num_examples, 1000]
        attention_masks = torch.stack(attention_masks)  # shape: [num_examples, 1000]
        labels = torch.stack(labels)  # shape: [num_examples, 1] or [num_examples, 1000]
        decoder_input_ids = torch.stack(decoder_input_ids)  # shape: [num_examples, 1] or [num_examples, 1000]

    # Create a dataset dictionary with potentially different shapes
    features = {
        'input_ids': input_ids,
        'attention_mask': attention_masks,
        'labels': labels,
        'decoder_input_ids': decoder_input_ids,
    }

    return features

# Process each task in the DatasetDict
features_dict = {}
for task_name, dataset in dataset_dict.items():
    print(f"Tokenizing dataset for task: {task_name}")
    features_dict[task_name] = tokenize_dataset(dataset, task_name)
import transformers
import torch.nn as nn

class MultiTaskModel(transformers.PreTrainedModel):
    def __init__(self, encoder, task_models):
        super().__init__(transformers.PretrainedConfig())
        self.encoder = encoder
        self.task_models = nn.ModuleDict(task_models)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        encoder = transformers.AutoModel.from_pretrained(model_name)
        task_models = {
            task: model_type_dict[task].from_pretrained(model_name, config=model_config_dict[task])
            for task in model_type_dict
        }
        return cls(encoder, task_models)

    def forward(self, input_ids, attention_mask, task_name, labels=None, decoder_input_ids=None):
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = encoder_outputs.last_hidden_state

        if task_name in self.task_models:
            task_model = self.task_models[task_name]

            if isinstance(task_model, transformers.MT5ForSequenceClassification):
                if decoder_input_ids is None:
                    raise ValueError("decoder_input_ids must be provided for classification tasks")
                return task_model(
                    input_ids=input_ids, 
                    attention_mask=attention_mask,
                    inputs_embeds=last_hidden_state, 
                    labels=labels, 
                    decoder_input_ids=decoder_input_ids  # Ensure this is passed
                )

            elif isinstance(task_model, transformers.MT5ForConditionalGeneration):
                if decoder_input_ids is None:
                    raise ValueError("decoder_input_ids must be provided for generation tasks")
                return task_model(
                    input_ids=input_ids, 
                    attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    encoder_outputs=encoder_outputs,
                    labels=labels
                )

        raise ValueError(f"Unknown task: {task_name}")

# Model configuration
model_name = "google/mt5-base"
task_model_types = {
    "event_classification": transformers.MT5ForSequenceClassification,
    "report_filling": transformers.MT5ForConditionalGeneration,
    "question_generation": transformers.MT5ForConditionalGeneration,
}
model_configurations = {
    "event_classification": transformers.MT5Config.from_pretrained(model_name, num_labels=10),  # Adjust number of labels
    "report_filling": transformers.MT5Config.from_pretrained(model_name),
    "question_generation": transformers.MT5Config.from_pretrained(model_name),
}

# Create the multi-task model
multi_task_model = MultiTaskModel.create(model_name, task_model_types, model_configurations)
from torch.utils.data import DataLoader, Dataset

# Custom Dataset für Multi-Task
class MultiTaskDataset(Dataset):
    def __init__(self, features):
        self.features = features

    def __len__(self):
        return self.features['input_ids'].shape[0]

    def __getitem__(self, idx):
        return {
            'input_ids': self.features['input_ids'][idx],
            'attention_mask': self.features['attention_mask'][idx],
            'labels': self.features['labels'][idx],
            'decoder_input_ids': self.features['input_ids'][idx]
        }

# DataLoader Creation
def create_data_loader(features, batch_size):
    dataset = MultiTaskDataset(features)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Batch-Größe
batch_size = 32
data_loaders = {task: create_data_loader(features, batch_size) for task, features in features_dict.items()}
class MultiTaskDataLoader:
    def __init__(self, data_loaders):
        self.data_loaders = data_loaders
        self.iterators = {task: iter(loader) for task, loader in data_loaders.items()}

    def __iter__(self):
        while True:
            for task, iterator in self.iterators.items():
                try:
                    batch = next(iterator)
                    print(f"\nProcessing Task: {task}")
                    print(batch)
                    # Bereite Tensoren vor
                    input_ids = batch['input_ids'].to(torch.int64)
                    attention_mask = batch['attention_mask'].to(torch.int64)

                    # Debugging-Ausgabe
                    print(f"Type of input_ids: {type(input_ids)}, Shape: {input_ids.shape}")
                    print(f"Type of attention_mask: {type(attention_mask)}, Shape: {attention_mask.shape}")

                    # Überprüfen und Umformen der Labels und decoder_input_ids
                    labels = batch['labels'].clone().detach().to(torch.int64)
                    decoder_input_ids = batch['decoder_input_ids'].clone().detach().to(torch.int64)

                    # Debugging-Ausgabe für Labels und decoder_input_ids
                    print(f"Type of labels: {type(labels)}, Shape: {labels.shape}")
                    print(f"Type of decoder_input_ids: {type(decoder_input_ids)}, Shape: {decoder_input_ids.shape}")

                    yield {
                        'input_ids': input_ids,
                        'attention_mask': attention_mask,
                        'labels': labels,
                        'decoder_input_ids': decoder_input_ids,
                        'task_name': task
                    }

                except StopIteration:
                    # Iterator zurücksetzen, wenn einer erschöpft ist
                    self.iterators[task] = iter(self.data_loaders[task])
                    continue
                except Exception as e:
                    print(f"Fehler: {e}")
                    print(f"Batch details: {batch}")

    def __len__(self):
        return min(len(loader) for loader in self.data_loaders.values())

class MultiTaskTrainer(transformers.Trainer):
    def __init__(self, *args, data_loaders=None, **kwargs):
        super().__init__(*args, **kwargs)
        if data_loaders is None:
            raise ValueError("data_loaders darf nicht None sein")
        self.data_loaders = data_loaders  # saving custom DataLoader

    def get_train_dataloader(self):
        # Verwendung des benutzerdefinierten MultiTaskDataLoader
        return MultiTaskDataLoader(self.data_loaders)

    def training_step(self, model, inputs, num_items_in_batch):
        task_name = inputs['task_name']

        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        labels = inputs['labels']
        decoder_input_ids = inputs['decoder_input_ids']

        # Forward Pass with task_name
        loss = model(input_ids=input_ids, 
                        attention_mask=attention_mask, 
                        labels=labels,
                        decoder_input_ids=decoder_input_ids,
                        task_name=task_name)  

        return loss
Output:
Type of input_ids: <class 'torch.Tensor'>, Shape: torch.Size([10, 1000])
Type of attention_mask: <class 'torch.Tensor'>, Shape: torch.Size([10, 1000])
Type of labels: <class 'torch.Tensor'>, Shape: torch.Size([10, 1])
Type of decoder_input_ids: <class 'torch.Tensor'>, Shape: torch.Size([10, 1])
# Pass MultiTaskDataLoader to Trainer
trainer = MultiTaskTrainer(
    model=multi_task_model,
    args=transformers.TrainingArguments(
        output_dir="./models/multitask_model",
        overwrite_output_dir=True,
        learning_rate=5e-5,
        per_device_train_batch_size=batch_size,
        num_train_epochs=3,
        logging_dir='./logs',
    ),
    data_loaders=data_loaders  # Pass the DataLoader instead of the dataset
)

# Start training
trainer.train()

And here the full Error msg:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[258], [line 16](vscode-notebook-cell:?execution_count=258&line=16)
      [2](vscode-notebook-cell:?execution_count=258&line=2) trainer = MultiTaskTrainer(
      [3](vscode-notebook-cell:?execution_count=258&line=3)     model=multi_task_model,
      [4](vscode-notebook-cell:?execution_count=258&line=4)     args=transformers.TrainingArguments(
   (...)
     [15](vscode-notebook-cell:?execution_count=258&line=15) # Start training
---> [16](vscode-notebook-cell:?execution_count=258&line=16) trainer.train()

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\transformers\trainer.py:2122, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   [2120](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/trainer.py:2120)         hf_hub_utils.enable_progress_bars()
   [2121](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/trainer.py:2121) else:
-> [2122](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/trainer.py:2122)     return inner_training_loop(

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\transformers\trainer.py:2474, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   [2471](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/trainer.py:2471)     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   [2473](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/trainer.py:2473) with self.accelerator.accumulate(model):
-> [2474](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/trainer.py:2474)     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
   [2476](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/trainer.py:2476) if (
   [2477](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/trainer.py:2477)     args.logging_nan_inf_filter

Cell In[257], [line 72](vscode-notebook-cell:?execution_count=257&line=72)
     [69](vscode-notebook-cell:?execution_count=257&line=69) decoder_input_ids =decoder_input_ids = inputs['decoder_input_ids'].expand(-1, 1000)
     [71](vscode-notebook-cell:?execution_count=257&line=71) # Forward Pass mit task_name
---> [72](vscode-notebook-cell:?execution_count=257&line=72) loss = model(input_ids=input_ids, 
     [73](vscode-notebook-cell:?execution_count=257&line=73)                 attention_mask=attention_mask, 
     [74](vscode-notebook-cell:?execution_count=257&line=74)                 labels=labels,
     [75](vscode-notebook-cell:?execution_count=257&line=75)                 decoder_input_ids=decoder_input_ids,
     [76](vscode-notebook-cell:?execution_count=257&line=76)                 task_name=task_name)  # Fügen Sie hier task_name hinzu
     [78](vscode-notebook-cell:?execution_count=257&line=78) return loss

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   [1734](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1734)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1735](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1735) else:
-> [1736](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1736)     return self._call_impl(*args, **kwargs)

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
  [1745](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1745)         or _global_backward_pre_hooks or _global_backward_hooks
   [1746](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1746)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1747](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1747)     return forward_call(*args, **kwargs)
   [1749](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1749) result = None
   [1750](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1750) called_always_called_hooks = set()

Cell In[244], [line 20](vscode-notebook-cell:?execution_count=244&line=20)
     [19](vscode-notebook-cell:?execution_count=244&line=19) def forward(self, input_ids, attention_mask, task_name, labels=None, decoder_input_ids=None):
---> [20](vscode-notebook-cell:?execution_count=244&line=20)     encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
     [21](vscode-notebook-cell:?execution_count=244&line=21)     last_hidden_state = encoder_outputs.last_hidden_state
     [23](vscode-notebook-cell:?execution_count=244&line=23)     if task_name in self.task_models:

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   [1734](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1734)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1735](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1735) else:
-> [1736](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1736)     return self._call_impl(*args, **kwargs)

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\transformers\models\mt5\modeling_mt5.py:1669, in MT5Model.forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   [1666](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/models/mt5/modeling_mt5.py:1666)         decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
   [1668](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/models/mt5/modeling_mt5.py:1668) # Decode
-> [1669](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/models/mt5/modeling_mt5.py:1669) decoder_outputs = self.decoder(
   [1670](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/models/mt5/modeling_mt5.py:1670)     input_ids=decoder_input_ids,
   [1671](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/models/mt5/modeling_mt5.py:1671)     attention_mask=decoder_attention_mask,

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   [1734](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1734)     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   [1735](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1735) else:
-> [1736](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1736)     return self._call_impl(*args, **kwargs)

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
    [1745](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1745)         or _global_backward_pre_hooks or _global_backward_hooks
   [1746](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1746)         or _global_forward_hooks or _global_forward_pre_hooks):
-> [1747](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1747)     return forward_call(*args, **kwargs)
   [1749](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1749) result = None
   [1750](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/torch/nn/modules/module.py:1750) called_always_called_hooks = set()

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\transformers\models\mt5\modeling_mt5.py:977, in MT5Stack.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    [975](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/models/mt5/modeling_mt5.py:975) else:
    [976](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/models/mt5/modeling_mt5.py:976)     err_msg_prefix = "decoder_" if self.is_decoder else ""
--> [977](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/models/mt5/modeling_mt5.py:977)     raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
    [979](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/models/mt5/modeling_mt5.py:979) if self.gradient_checkpointing and self.training:
    [980](file:///C:/workspace/databricks-notebooks-017/.venv/lib/site-packages/transformers/models/mt5/modeling_mt5.py:980)     if use_cache:

ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds
ArthurZucker commented 3 days ago

You should set use_cache=False (from the look of it) when you are training