Hi @patrickvonplaten @ArthurZucker @muellerz,

I am currently trying to fine-tune an MT5 model for multitasking (classification, report filling, and question generation). Since I have been trying and debugging this for a couple of days now, maybe you or some of your colleagues can have a look at my code and hopefully spot the mistake.

Here is the relevant code, together with some outputs for clarification:
```python
import torch
import transformers
from tqdm import tqdm  # For progress bar
from datasets import DatasetDict

# Initialize the tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained("google/mt5-base")

def convert_to_features(example_batch, task_name):
    features = {}
    if task_name == "event_classification":
        # Tokenization for the text
        features = tokenizer(example_batch['text'], max_length=1000, padding="max_length", truncation=True, return_tensors="pt")
        numeric_labels = [label_to_id[example_batch['label']]]
        features["labels"] = torch.tensor(numeric_labels, dtype=torch.long).unsqueeze(1)  # Adjusting dimension
        features["decoder_input_ids"] = torch.tensor(numeric_labels, dtype=torch.long).unsqueeze(1)
    elif task_name == "report_filling":
        features = tokenizer(example_batch['input_text'], max_length=1000, padding="max_length", truncation=True, return_tensors="pt")
        labels = tokenizer(example_batch['report'], max_length=1000, padding="max_length", truncation=True, return_tensors="pt")["input_ids"]
        features["labels"] = labels
        features["decoder_input_ids"] = labels.clone()  # Clone for decoder_input_ids
    elif task_name == "question_generation":
        input_texts = []
        labels = []
        for text, missing_info, questions in zip(
            example_batch['text'], example_batch['missing_info'], example_batch['generated_questions']
        ):
            missing_info_str = ", ".join(missing_info)
            input_text = f"{text} Missing info: {missing_info_str}"
            input_texts.append(input_text)
            questions_str = " [SEP] ".join(questions)
            labels.append(questions_str)
        # Tokenization for the input_texts
        features = tokenizer(input_texts, max_length=1000, padding="max_length", truncation=True, return_tensors="pt")
        # Tokenization for the labels
        labels_encodings = tokenizer(labels, max_length=1000, padding="max_length", truncation=True, return_tensors="pt")["input_ids"]
        features["labels"] = labels_encodings
        features["decoder_input_ids"] = labels_encodings.clone()  # Clone for decoder_input_ids
    return features
```
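Just to make the expected shapes concrete, this is roughly what I expect `convert_to_features` to return for a single classification example (purely illustrative; the dummy text and label below stand in for my real data, and the label is assumed to be a key of my `label_to_id` mapping):

```python
# Illustrative only: dummy example, not from my real dataset
example = {"text": "some incident description", "label": "some_event_type"}

feats = convert_to_features(example, "event_classification")
print(feats["input_ids"].shape)   # torch.Size([1, 1000])
print(feats["labels"].shape)      # torch.Size([1, 1]), i.e. a single class id
```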
```python
def tokenize_dataset(dataset, task_name, max_length=1000):
    num_examples = len(dataset)
    input_ids = []
    attention_masks = []
    labels = []
    decoder_input_ids = []
    for i in tqdm(range(num_examples), total=num_examples):
        example = dataset[i]
        features = convert_to_features(example, task_name)
        # Instead of appending, convert to tensor and squeeze to remove extra dimensions
        input_ids.append(features['input_ids'].squeeze(0))  # Removes dimension of size 1
        attention_masks.append(features['attention_mask'].squeeze(0))
        labels.append(features['labels'].squeeze(0))  # Squeeze if necessary
        decoder_input_ids.append(features['decoder_input_ids'].squeeze(0))

    # Create a dictionary to hold the dataset
    if task_name == "question_generation":
        # Concatenate along the first dimension to flatten
        input_ids = torch.cat(input_ids, dim=0)                  # shape: [total_questions, 1000]
        attention_masks = torch.cat(attention_masks, dim=0)      # shape: [total_questions, 1000]
        labels = torch.cat(labels, dim=0)                        # shape: [total_questions, 1000]
        decoder_input_ids = torch.cat(decoder_input_ids, dim=0)  # shape: [total_questions, 1000]
    else:
        # For the other tasks, just stack the sequences
        input_ids = torch.stack(input_ids)                   # shape: [num_examples, 1000]
        attention_masks = torch.stack(attention_masks)       # shape: [num_examples, 1000]
        labels = torch.stack(labels)                         # shape: [num_examples, 1] or [num_examples, 1000]
        decoder_input_ids = torch.stack(decoder_input_ids)   # shape: [num_examples, 1] or [num_examples, 1000]

    # Create a dataset dictionary with potentially different shapes
    features = {
        'input_ids': input_ids,
        'attention_mask': attention_masks,
        'labels': labels,
        'decoder_input_ids': decoder_input_ids,
    }
    return features

# Process each task in the DatasetDict
features_dict = {}
for task_name, dataset in dataset_dict.items():
    print(f"Tokenizing dataset for task: {task_name}")
    features_dict[task_name] = tokenize_dataset(dataset, task_name)
```
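As a quick sanity check on the tokenized features I print the resulting tensor shapes per task (`dataset_dict` is built earlier in my notebook and not shown here; the shapes in the comments are what I expect based on the code above):

```python
for task, feats in features_dict.items():
    print(task,
          feats["input_ids"].shape,
          feats["labels"].shape,
          feats["decoder_input_ids"].shape)
# Expected, e.g.:
#   event_classification: input_ids [N, 1000], labels [N, 1], decoder_input_ids [N, 1]
#   report_filling: all [N, 1000]
#   question_generation: all [total_questions, 1000]
```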
```python
import transformers
import torch.nn as nn

class MultiTaskModel(transformers.PreTrainedModel):
    def __init__(self, encoder, task_models):
        super().__init__(transformers.PretrainedConfig())
        self.encoder = encoder
        self.task_models = nn.ModuleDict(task_models)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        encoder = transformers.AutoModel.from_pretrained(model_name)
        task_models = {
            task: model_type_dict[task].from_pretrained(model_name, config=model_config_dict[task])
            for task in model_type_dict
        }
        return cls(encoder, task_models)

    def forward(self, input_ids, attention_mask, task_name, labels=None, decoder_input_ids=None):
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = encoder_outputs.last_hidden_state

        if task_name in self.task_models:
            task_model = self.task_models[task_name]
            if isinstance(task_model, transformers.MT5ForSequenceClassification):
                if decoder_input_ids is None:
                    raise ValueError("decoder_input_ids must be provided for classification tasks")
                return task_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    inputs_embeds=last_hidden_state,
                    labels=labels,
                    decoder_input_ids=decoder_input_ids  # Ensure this is passed
                )
            elif isinstance(task_model, transformers.MT5ForConditionalGeneration):
                if decoder_input_ids is None:
                    raise ValueError("decoder_input_ids must be provided for generation tasks")
                return task_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids,
                    encoder_outputs=encoder_outputs,
                    labels=labels
                )
        raise ValueError(f"Unknown task: {task_name}")

# Model configuration
model_name = "google/mt5-base"
task_model_types = {
    "event_classification": transformers.MT5ForSequenceClassification,
    "report_filling": transformers.MT5ForConditionalGeneration,
    "question_generation": transformers.MT5ForConditionalGeneration,
}
model_configurations = {
    "event_classification": transformers.MT5Config.from_pretrained(model_name, num_labels=10),  # Adjust number of labels
    "report_filling": transformers.MT5Config.from_pretrained(model_name),
    "question_generation": transformers.MT5Config.from_pretrained(model_name),
}

# Create the multi-task model
multi_task_model = MultiTaskModel.create(model_name, task_model_types, model_configurations)
```
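For reference, this is how I check which head class ends up behind each task and what the shared encoder actually is (just a debug print; the output in the comments is what I would expect given the config above):

```python
for task, head in multi_task_model.task_models.items():
    print(task, type(head).__name__)
# event_classification MT5ForSequenceClassification
# report_filling MT5ForConditionalGeneration
# question_generation MT5ForConditionalGeneration

print(type(multi_task_model.encoder).__name__)  # MT5Model (loaded via AutoModel)
```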
```python
from torch.utils.data import DataLoader, Dataset

# Custom dataset for multi-task training
class MultiTaskDataset(Dataset):
    def __init__(self, features):
        self.features = features

    def __len__(self):
        return self.features['input_ids'].shape[0]

    def __getitem__(self, idx):
        return {
            'input_ids': self.features['input_ids'][idx],
            'attention_mask': self.features['attention_mask'][idx],
            'labels': self.features['labels'][idx],
            'decoder_input_ids': self.features['input_ids'][idx]
        }

# DataLoader creation
def create_data_loader(features, batch_size):
    dataset = MultiTaskDataset(features)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Batch size
batch_size = 32
data_loaders = {task: create_data_loader(features, batch_size) for task, features in features_dict.items()}
```
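To see what a single raw batch from one of these loaders looks like before it goes through my custom multi-task loader, I inspect it like this (illustrative only):

```python
batch = next(iter(data_loaders["event_classification"]))
for key, value in batch.items():
    print(key, tuple(value.shape), value.dtype)
```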
```python
class MultiTaskDataLoader:
    def __init__(self, data_loaders):
        self.data_loaders = data_loaders
        self.iterators = {task: iter(loader) for task, loader in data_loaders.items()}

    def __iter__(self):
        while True:
            for task, iterator in self.iterators.items():
                try:
                    batch = next(iterator)
                    print(f"\nProcessing Task: {task}")
                    print(batch)

                    # Prepare the tensors
                    input_ids = batch['input_ids'].to(torch.int64)
                    attention_mask = batch['attention_mask'].to(torch.int64)

                    # Debug output
                    print(f"Type of input_ids: {type(input_ids)}, Shape: {input_ids.shape}")
                    print(f"Type of attention_mask: {type(attention_mask)}, Shape: {attention_mask.shape}")

                    # Check and convert the labels and decoder_input_ids
                    labels = batch['labels'].clone().detach().to(torch.int64)
                    decoder_input_ids = batch['decoder_input_ids'].clone().detach().to(torch.int64)

                    # Debug output for labels and decoder_input_ids
                    print(f"Type of labels: {type(labels)}, Shape: {labels.shape}")
                    print(f"Type of decoder_input_ids: {type(decoder_input_ids)}, Shape: {decoder_input_ids.shape}")

                    yield {
                        'input_ids': input_ids,
                        'attention_mask': attention_mask,
                        'labels': labels,
                        'decoder_input_ids': decoder_input_ids,
                        'task_name': task
                    }
                except StopIteration:
                    # Reset the iterator once it is exhausted
                    self.iterators[task] = iter(self.data_loaders[task])
                    continue
                except Exception as e:
                    print(f"Error: {e}")
                    print(f"Batch details: {batch}")

    def __len__(self):
        return min(len(loader) for loader in self.data_loaders.values())


class MultiTaskTrainer(transformers.Trainer):
    def __init__(self, *args, data_loaders=None, **kwargs):
        super().__init__(*args, **kwargs)
        if data_loaders is None:
            raise ValueError("data_loaders must not be None")
        self.data_loaders = data_loaders  # saving the custom DataLoaders

    def get_train_dataloader(self):
        # Use the custom MultiTaskDataLoader
        return MultiTaskDataLoader(self.data_loaders)

    def training_step(self, model, inputs, num_items_in_batch):
        task_name = inputs['task_name']
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        labels = inputs['labels']
        decoder_input_ids = inputs['decoder_input_ids']

        # Forward pass with task_name
        loss = model(input_ids=input_ids,
                     attention_mask=attention_mask,
                     labels=labels,
                     decoder_input_ids=decoder_input_ids,
                     task_name=task_name)
        return loss
```
Output:

```
Type of input_ids: <class 'torch.Tensor'>, Shape: torch.Size([10, 1000])
Type of attention_mask: <class 'torch.Tensor'>, Shape: torch.Size([10, 1000])
Type of labels: <class 'torch.Tensor'>, Shape: torch.Size([10, 1])
Type of decoder_input_ids: <class 'torch.Tensor'>, Shape: torch.Size([10, 1])
```
```python
# Pass MultiTaskDataLoader to Trainer
trainer = MultiTaskTrainer(
    model=multi_task_model,
    args=transformers.TrainingArguments(
        output_dir="./models/multitask_model",
        overwrite_output_dir=True,
        learning_rate=5e-5,
        per_device_train_batch_size=batch_size,
        num_train_epochs=3,
        logging_dir='./logs',
    ),
    data_loaders=data_loaders  # Pass the DataLoaders instead of the dataset
)

# Start training
trainer.train()
```
And here is the full error message:
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[258], line 16
      2 trainer = MultiTaskTrainer(
      3     model=multi_task_model,
      4     args=transformers.TrainingArguments(
    (...)
     15 # Start training
---> 16 trainer.train()

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\transformers\trainer.py:2122, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2120     hf_hub_utils.enable_progress_bars()
   2121 else:
-> 2122     return inner_training_loop(

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\transformers\trainer.py:2474, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2471 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   2473 with self.accelerator.accumulate(model):
-> 2474     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
   2476 if (
   2477     args.logging_nan_inf_filter

Cell In[257], line 72
     69 decoder_input_ids = decoder_input_ids = inputs['decoder_input_ids'].expand(-1, 1000)
     71 # Forward pass with task_name
---> 72 loss = model(input_ids=input_ids,
     73              attention_mask=attention_mask,
     74              labels=labels,
     75              decoder_input_ids=decoder_input_ids,
     76              task_name=task_name)
     78 return loss

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1745     or _global_backward_pre_hooks or _global_backward_hooks
   1746     or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

Cell In[244], line 20
     19 def forward(self, input_ids, attention_mask, task_name, labels=None, decoder_input_ids=None):
---> 20     encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
     21     last_hidden_state = encoder_outputs.last_hidden_state
     23     if task_name in self.task_models:

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\transformers\models\mt5\modeling_mt5.py:1669, in MT5Model.forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1666     decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
   1668 # Decode
-> 1669 decoder_outputs = self.decoder(
   1670     input_ids=decoder_input_ids,
   1671     attention_mask=decoder_attention_mask,

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1745     or _global_backward_pre_hooks or _global_backward_hooks
   1746     or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File c:\workspace\databricks-notebooks-017\.venv\lib\site-packages\transformers\models\mt5\modeling_mt5.py:977, in MT5Stack.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    975 else:
    976     err_msg_prefix = "decoder_" if self.is_decoder else ""
--> 977     raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
    979 if self.gradient_checkpointing and self.training:
    980     if use_cache:

ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds
```
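For what it's worth, I can reproduce the same ValueError outside of the Trainer with just the bare encoder call from my MultiTaskModel.forward on dummy token ids, so I assume the problem is not in my dataset itself (minimal sketch below, same call as in my forward):

```python
import torch
import transformers

# AutoModel resolves to MT5Model for "google/mt5-base" (as in MultiTaskModel.create)
encoder = transformers.AutoModel.from_pretrained("google/mt5-base")

dummy_ids = torch.randint(0, 1000, (1, 8))
dummy_mask = torch.ones_like(dummy_ids)

# Same call as line 20 of my forward; this raises:
# ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds
encoder(input_ids=dummy_ids, attention_mask=dummy_mask)
```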