thunlp / OpenPrompt

An Open-Source Framework for Prompt-Learning.
https://thunlp.github.io/OpenPrompt/

[BUG] An error occurs when running t5-3b with prefix tuning. #171

Open BAOOOOOM opened 2 years ago

BAOOOOOM commented 2 years ago

Hello, when I use the config `plm, tokenizer, model_config, WrapperClass = load_plm('t5', 't5-3b')` with `PrefixTuningTemplate`, it raises the error below. But if I run the program with the model 't5-base' or 't5-small', it runs successfully.

```
Traceback (most recent call last):
  File "prefix_classify.py", line 172, in <module>
    logits = prompt_model(inputs)
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/openprompt/pipeline_base.py", line 295, in forward
    outputs = self.prompt_model(batch)
  File "/data/bmk/anaconda3/envs/g/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/openprompt/pipeline_base.py", line 212, in forward
    outputs = self.plm(**input_batch, output_hidden_states=True)
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1601, in forward
    encoder_outputs = self.encoder(
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/wkh/anaconda3/envs/gkp/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1033, in forward
    layer_outputs = layer_module(
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 664, in forward
    self_attention_outputs = self.layer[0](
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/openprompt/prompts/prefix_tuning_template.py", line 205, in modified_encoder_forward
    return backup_encoder_forward_functions[layer_id](*args, **kwargs)
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 570, in forward
    attention_output = self.SelfAttention(
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/wkh/anaconda3/envs/g/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 499, in forward
    key_states = project(
  File "/data/bmk/anaconda3/envs/g/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 489, in project
    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 128 for tensor number 1 in the list.
```
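If it helps, the sizes in the error seem to line up with the model configs: for t5-3b the per-head key/value width `d_kv` is 128, which no longer equals `d_model // num_heads` (1024 // 32 = 32), while for t5-small and t5-base the two agree. A quick way to check this (a sketch that only prints Hugging Face config fields):

```python
from transformers import T5Config

# Compare the per-head key/value width (d_kv) against d_model // num_heads.
# For t5-small and t5-base the two agree; for t5-3b they differ (128 vs 32),
# matching the "Expected size 32 but got size 128" in the traceback above.
for name in ["t5-small", "t5-base", "t5-3b"]:
    cfg = T5Config.from_pretrained(name)
    print(name, cfg.d_model, cfg.num_heads, cfg.d_kv, cfg.d_model // cfg.num_heads)
```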

BAOOOOOM commented 2 years ago

I used prefix tuning with the t5-3b model for classification.

ningding97 commented 2 years ago

Hi, this is an interesting issue. Could you please provide more details? For example, did you use any multi-GPU technique such as DataParallel for t5-3b?
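By DataParallel I mean a wrapper along these lines (a generic sketch, not something from your script):

```python
import torch

# torch.nn.DataParallel splits each batch across GPUs and replicates the
# module, which can change the tensor shapes seen inside custom forward
# functions such as the prefix-tuning hooks.
prompt_model = torch.nn.DataParallel(prompt_model)
```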

BAOOOOOM commented 2 years ago

I haven't used multi-GPU techniques. This is the code, thanks:

```python

import argparse
import torch

parser = argparse.ArgumentParser("")
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--plm_eval_mode", action="store_true")
parser.add_argument("--model", type=str, default='t5')  # tested model are gpt2/t5
parser.add_argument("--model_name_or_path", default='t5-3b')
parser.add_argument("--train_dataset_path", default='raw_data/train.json')
parser.add_argument("--dev_dataset_path", default='raw_data/dev.json')
parser.add_argument("--save_path", default='prefix_test_output/')
args = parser.parse_args()
print(args)

from openprompt.data_utils import InputExample
import json
train_dataset=[]
with open(args.train_dataset_path) as f:
    lines=f.readlines()
    for line in lines:
        data=json.loads(line)
        question=data["question"]
        answer=data["answer"]
        id=data["id"]
        label=0
        if answer=='yes':
            label=1
        input_example = InputExample(text_a = question, label=label, guid=id)
        train_dataset.append(input_example)

dev_dataset=[]
with open(args.dev_dataset_path) as f:
    lines=f.readlines()
    for line in lines:
        data=json.loads(line)
        question=data["question"]
        answer=data["answer"]
        id=data["id"]
        label=0
        if answer=='yes':
            label=1
        input_example = InputExample(text_a = question, label=label, guid=id)
        dev_dataset.append(input_example)
print("load dataset successful!")

from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm(args.model, args.model_name_or_path)

from openprompt.prompts.prefix_tuning_template import PrefixTuningTemplate

mytemplate = PrefixTuningTemplate(model=plm, tokenizer=tokenizer, text=' {"placeholder":"text_a"} {"mask"} ', using_decoder_past_key_values=False, num_token=50)

from openprompt import PromptDataLoader
train_dataloader = PromptDataLoader(dataset=train_dataset, template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=64, decoder_max_length=3,
    batch_size=4,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")

validation_dataloader = PromptDataLoader(dataset=dev_dataset, template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=64, decoder_max_length=3,
    batch_size=4,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")

from openprompt.prompts import ManualVerbalizer

myverbalizer = ManualVerbalizer(tokenizer, num_classes=2,
                        label_words=[["no"], ["yes"]])

from openprompt import PromptForGeneration
from openprompt import PromptForClassification
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu") 

prompt_model = PromptForClassification(plm=plm, template=mytemplate, verbalizer=myverbalizer, freeze_plm=True, plm_eval_mode=args.plm_eval_mode)
prompt_model = prompt_model.to(device)

from transformers import AdamW

loss_func = torch.nn.CrossEntropyLoss()
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
    "params": [p for n, p in mytemplate.named_parameters() if (not any(nd in n for nd in no_decay)) and p.requires_grad],
    "weight_decay": 0.0,
},
{
    "params": [p for n, p in mytemplate.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
    "weight_decay": 0.0,
},
]

optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=1e-8)

from transformers.optimization import get_linear_schedule_with_warmup

from openprompt.utils.metrics import generation_metric

def evaluate(prompt_model, dataloader):
    prompt_model.eval()

    allpreds = []
    alllabels = []
    for step, inputs in enumerate(dataloader):
        inputs = inputs.to(device)
        logits = prompt_model(inputs)
        labels = inputs['label']
        alllabels.extend(labels.cpu().tolist())
        allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

    acc = sum([int(i==j) for i,j in zip(allpreds, alllabels)])/len(allpreds)

    return acc

print("start to train!")

global_step = 0
tot_loss = 0
log_loss = 0

best_acc=0
for epoch in range(100):
    prompt_model.train()
    print("epoch:",epoch)
    for step, inputs in enumerate(train_dataloader):
        global_step +=1
        inputs = inputs.to(device)
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        torch.nn.utils.clip_grad_norm_(mytemplate.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
    print("Epoch {}, average loss: {}".format(epoch, loss.item()), flush=True)
    nowacc = evaluate(prompt_model, validation_dataloader)
    print("the accuracy:",nowacc)
    if nowacc>best_acc:
        best_acc=nowacc
        torch.save(prompt_model.state_dict(), args.save_path + "7_6_t5-3b-prefix_model.ckpt")
        print("save accuracy:",best_acc)

```

BAOOOOOM commented 2 years ago

And the environment is Python 3.8, torch==1.11.0, transformers==4.20.1, openprompt==1.0.0.

1275361989 commented 2 years ago

Hi there, I face the same problem when I use prefix tuning with t5-3b for generation. I only use one GPU. The environment is Python 3.9, torch==1.10.2, transformers==4.20.1, openprompt==1.0.1.