thunlp / OpenPrompt

An Open-Source Framework for Prompt-Learning.
https://thunlp.github.io/OpenPrompt/
Apache License 2.0

use transformers Trainer when training #292

Open · CaptWake opened this issue 1 year ago

CaptWake commented 1 year ago

Hi, I'm using the ag_news dataset from Hugging Face. I was trying to train the classifier with the transformers Trainer class using the following code:

training_args = TrainingArguments(
    output_dir='training_with_es',
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy='steps',
    report_to='none',  # disable wandb and other logging integrations
    # required for early stopping; save/eval strategies must match
    load_best_model_at_end=True,
    save_strategy='steps',
    save_steps=100,
    eval_steps=100,
    metric_for_best_model='f1',
)

trainer = Trainer(
    model=prompt_model,
    args=training_args,
    train_dataset=train_dataloader.dataloader.dataset,
    eval_dataset=valid_dataloader.dataloader.dataset,
    tokenizer=None,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None),  # (optimizer, lr_scheduler); None -> default scheduler
)

train_dataloader and valid_dataloader are instances of PromptDataLoader, and prompt_model is an instance of PromptForClassification whose plm is a pretrained BertForMaskedLM.

But when I run trainer.train() I get the following error:

TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: PromptForClassification.forward() got an unexpected keyword argument 'labels'
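
If I read the traceback correctly, the mismatch is roughly this (a simplified sketch of the two calling conventions, not actual library code):

# What Trainer does internally (simplified): it unpacks the batch dict,
# so 'labels' reaches the model as a keyword argument.
outputs = model(**inputs)

# What PromptForClassification expects: the whole batch as one positional
# argument, returning logits only; there is no 'labels' keyword.
logits = prompt_model(batch)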

So I don't know whether I've made a mistake somewhere or whether PromptForClassification just isn't meant to be used with Trainer directly. Thanks in advance. Here is the full code I used:

from tqdm import tqdm
from datasets import load_dataset
from openprompt.data_utils import InputExample
from openprompt.plms.mlm import MLMTokenizerWrapper
from transformers import BertForMaskedLM, BertTokenizer, BertConfig
from openprompt.plms import load_plm
from openprompt.prompts import ManualVerbalizer, ManualTemplate
from openprompt import PromptDataLoader
from openprompt import PromptForClassification
import numpy as np  # used in compute_metrics
import torch
from transformers import Trainer, TrainingArguments
from torch.optim import AdamW
from transformers import EarlyStoppingCallback
import evaluate

# metric function for classification evaluation; the returned keys must
# match metric_for_best_model below (hence 'f1', not 'F1 score')
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    acc_score = accuracy.compute(predictions=predictions, 
                                 references=labels)["accuracy"]
    f1_score = f1.compute(predictions=predictions, 
                          references=labels, 
                          average="micro")["f1"]
    return {"accuracy": acc_score, "f1": f1_score}

# load the dataset 
raw_dataset = load_dataset('ag_news', 
                           cache_dir="../datasets/.cache/huggingface_datasets")
print(raw_dataset['train'][0])  # quick sanity check

dataset = {}
for split in ['train', 'test']:
    dataset[split] = []
    for idx, data in tqdm(enumerate(raw_dataset[split])):
        input_example = InputExample(text_a = data['text'], 
                                     label=int(data['label']), 
                                     guid=idx)
        dataset[split].append(input_example)
print(dataset['train'][0])

# load the PLM, tokenizer, config, and tokenizer wrapper manually
# (instead of going through openprompt.plms.load_plm)
model, tokenizer, model_config, WrapperClass = (
    BertForMaskedLM.from_pretrained('nlpaueb/legal-bert-base-uncased'),
    BertTokenizer.from_pretrained('nlpaueb/legal-bert-base-uncased'),
    BertConfig.from_pretrained('nlpaueb/legal-bert-base-uncased'),
    MLMTokenizerWrapper
)

# define a Verbalizer
verbalizer = ManualVerbalizer(tokenizer=tokenizer, 
                              num_classes=4, 
                              label_words=[['World'], 
                                           ['Sports'], 
                                           ['Business'], 
                                           ['Sci/Tech']])

# define a Template
template = ManualTemplate(tokenizer=tokenizer, 
                          text='{"placeholder":"text_a"}. What topic is that? {"mask"}')

# view wrapped example
wrapped_example = template.wrap_one_example(dataset['train'][0])
print(wrapped_example)

train_dataloader = PromptDataLoader(dataset['train'], 
                                    template, 
                                    tokenizer=tokenizer, 
                                    tokenizer_wrapper_class=WrapperClass, 
                                    batch_size=64,
                                    decoder_max_length=384,
                                    max_seq_length=384, 
                                    shuffle=False, 
                                    teacher_forcing=False)

valid_dataloader = PromptDataLoader(dataset['test'], 
                                    template, 
                                    tokenizer=tokenizer, 
                                    tokenizer_wrapper_class=WrapperClass, 
                                    batch_size=64,
                                    decoder_max_length=384,
                                    max_seq_length=384, 
                                    shuffle=False, 
                                    teacher_forcing=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

prompt_model = PromptForClassification(plm=model, 
                                       template=template, 
                                       verbalizer=verbalizer, 
                                       freeze_plm=False)
prompt_model = prompt_model.to(device)

# ===========================
# training / testing section
# ===========================

# it's good practice to exempt bias and LayerNorm parameters from weight decay
no_decay = ['bias', 'LayerNorm.weight']

optimizer_grouped_parameters = [
    {'params': [p for n, p in prompt_model.named_parameters() 
                if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.named_parameters() 
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4)

accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

training_args = TrainingArguments(
    output_dir='training_with_es',
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy='steps',
    report_to='none',  # disable wandb and other logging integrations
    # required for early stopping; save/eval strategies must match
    load_best_model_at_end=True,
    save_strategy='steps',
    save_steps=100,
    eval_steps=100,
    metric_for_best_model='f1',
)

trainer = Trainer(
    model=prompt_model,
    args=training_args,
    train_dataset=train_dataloader.dataloader.dataset,
    eval_dataset=valid_dataloader.dataloader.dataset,
    tokenizer=None,
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None),  # (optimizer, lr_scheduler); None -> default scheduler
)

trainer.train()
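
In case it helps, this is the workaround direction I was considering: a thin wrapper whose forward() accepts the labels keyword and returns the loss/logits dict Trainer expects. This is only a sketch of my own, nothing official: the TrainerCompatiblePromptModel class and its loss handling are not OpenPrompt API, and I'm assuming PromptForClassification will accept the collated batch dict directly.

from torch import nn

class TrainerCompatiblePromptModel(nn.Module):
    # Hypothetical adapter: accepts labels=... from Trainer, forwards the
    # rest of the batch to PromptForClassification, and returns a dict with
    # a 'loss' key the way Trainer expects.
    def __init__(self, prompt_model):
        super().__init__()
        self.prompt_model = prompt_model
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, labels=None, **batch):
        # assumption: PromptForClassification can consume the batch dict
        # directly; it normally receives an InputFeatures batch
        logits = self.prompt_model(batch)
        if labels is not None:
            return {"loss": self.loss_fct(logits, labels), "logits": logits}
        return {"logits": logits}

# hypothetical usage:
# trainer = Trainer(model=TrainerCompatiblePromptModel(prompt_model), ...)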