facebookresearch / metaseq

Repo for external large-scale work
MIT License
6.52k stars 725 forks source link

Finetuned OPT-350M throws error when loaded [Huggingface Implementation] #118

Closed Leli1024 closed 2 years ago

Leli1024 commented 2 years ago

🐛 Bug

When the OPT-350M variant is fine-tuned with Hugging Face Transformers, loading the resulting checkpoint fails with the following error:

RuntimeError: Error(s) in loading state_dict for OPTForCausalLM:
        size mismatch for lm_head.weight: copying a param with shape torch.Size([50272, 512]) from checkpoint, the shape in current model is torch.Size([50272, 1024]).

For context, I have used this code with the 125M variant: the fine-tuned model did not perform well, but it loaded without crashing. I believe the poor quality is a parameter-count issue (?), since when I compared the two base (not fine-tuned) models, only the 350M one was capable of generating coherent output.

Code to load the model

from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, OPTForCausalLM
import torch

def generate_text(model, tokenizer, prompt, num_return_sequences=5, max_length=10):
    """Sample several continuations of ``prompt`` and return them as text.

    Args:
        model: causal LM exposing ``generate`` and ``device`` (e.g. a Hugging
            Face ``OPTForCausalLM``).
        tokenizer: tokenizer used to encode ``prompt`` and decode the output.
        prompt: text to continue.
        num_return_sequences: number of sampled sequences requested from
            ``generate``.
        max_length: ``max_length`` passed to ``generate``. In Hugging Face
            this is the total length including the prompt tokens, so small
            values leave little room for new text.

    Returns:
        list of decoded strings (one per returned sequence), with special
        tokens removed.
    """
    # Move the prompt to wherever the model lives so this also works after
    # ``model.to("cuda")``.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    generated_ids = model.generate(
        input_ids,
        do_sample=True,
        num_return_sequences=num_return_sequences,
        max_length=max_length,
    )
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

# Base model the checkpoint was fine-tuned from, and the fine-tuned checkpoint.
BASE_MODEL = "facebook/opt-350m"
path = "opt/model_ckpts"

model = OPTForCausalLM.from_pretrained(path)
# The training script does not add tokens, so the base tokenizer is the right
# one. It may also not have saved a tokenizer next to the fine-tuned weights,
# so loading it from ``path`` can fail.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)

prompt = "The woman worked as a"

print(generate_text(model, tokenizer, prompt))

Training Code

import torch as th
from dataset import get_examples, GSMDataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import GPT2Config, AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

from transformers import AutoModelForCausalLM, AutoTokenizer, OPTModel, OPTConfig, OPTForCausalLM
import torch

import os

# Base model to fine-tune, and the directory the fine-tuned model is saved to.
BASE_MODEL = "facebook/opt-350m"
CKPT_DIR = "model_ckpts"

# NOTE(review): the lm_head size mismatch reported in this issue (512 vs 1024)
# is raised while transformers loads the checkpoint, not by code in this
# script. It is presumably a transformers OPT issue; confirm before changing
# anything here.

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)

# Resume from an earlier fine-tuning run if its checkpoint directory exists;
# otherwise start from the base model. Load errors are deliberately NOT
# swallowed. A broad ``except Exception: print(e)`` would fall back to the base
# model silently, and the save at the end would then overwrite the checkpoint.
if os.path.isdir(CKPT_DIR):
    model = OPTForCausalLM.from_pretrained(CKPT_DIR)
    print("model loaded")
else:
    model = OPTForCausalLM.from_pretrained(BASE_MODEL)

train_examples = get_examples("train")
train_dset = GSMDataset(tokenizer, train_examples)

# Fall back to CPU so the script still runs (slowly) without a GPU.
device = th.device("cuda" if th.cuda.is_available() else "cpu")

model.to(device)
model.train()

train_loader = DataLoader(train_dset, batch_size=4, shuffle=True)
optim = AdamW(model.parameters(), lr=1e-5)

num_epochs = 10
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optim,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

pbar = tqdm(range(num_training_steps))
for _epoch in range(num_epochs):
    for batch in train_loader:
        optim.zero_grad()
        batch = {k: v.to(device) for k, v in batch.items()}
        # The Hugging Face causal-LM loss ignores label -100. Masking every
        # position whose attention_mask is 0 keeps right padding out of the
        # loss; it also excludes the question prefix when the dataset was
        # built with loss_on_prefix=False. Passing raw input_ids as labels
        # would train the model to predict padding.
        labels = batch["input_ids"].masked_fill(batch["attention_mask"] == 0, -100)
        outputs = model(**batch, labels=labels)
        loss = outputs[0]
        loss.backward()
        optim.step()
        lr_scheduler.step()
        pbar.update(1)
        pbar.set_description(f"train_loss: {loss.item():.5f}")
pbar.close()

# Save the tokenizer next to the weights so the checkpoint directory is
# self-contained for AutoTokenizer.from_pretrained(CKPT_DIR).
model.save_pretrained(CKPT_DIR)
tokenizer.save_pretrained(CKPT_DIR)

Dataset module


import json
import os
import re

import torch as th

def read_jsonl(path: str) -> list:
    """Parse a JSON Lines file into a list of objects, skipping blank lines.

    Args:
        path: path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(path, encoding="utf-8") as fh:
        # Each line keeps its trailing "\n", so a plain ``if line`` would
        # never skip blank lines; test the stripped content instead.
        return [json.loads(line) for line in fh if line.strip()]

def get_examples(split, data_dir="data/", eos_token="<|endoftext|>", limit=None):
    """Load one split of question/answer examples and append separators.

    Args:
        split: split name; the file read is ``<data_dir>/<split>.jsonl``.
        data_dir: directory containing the split files.
        eos_token: text appended to every answer. The default is GPT-2's
            end-of-text marker. NOTE(review): the OPT tokenizer used by the
            training script presumably does not treat this string as a
            special token; consider passing ``tokenizer.eos_token`` instead.
            Confirm before changing the default.
        limit: if given, keep only the first ``limit`` examples (useful for
            quick debugging runs).

    Returns:
        list of example dicts, mutated in place. Each ``question`` gets a
        trailing newline and each ``answer`` gets ``eos_token`` appended.
    """
    path = os.path.join(data_dir, f"{split}.jsonl")
    examples = read_jsonl(path)

    if limit is not None:
        examples = examples[:limit]

    for ex in examples:
        ex["question"] += "\n"
        ex["answer"] += eos_token

    print(f"{len(examples)} {split} examples")
    return examples

# Final-answer marker: "#### " followed by a number that may carry a leading
# minus sign, decimal points and thousands separators.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Sentinel returned when no final-answer marker is present.
INVALID_ANS = "[invalid]"

def extract_answer(completion):
    """Return the number after the first ``#### `` marker in ``completion``.

    Commas are stripped from the captured number. If no marker is found, the
    ``INVALID_ANS`` sentinel is returned instead.
    """
    found = ANS_RE.search(completion)
    if found is None:
        return INVALID_ANS
    return found.group(1).strip().replace(",", "")

def is_correct(model_completion, gt_example):
    """Return True if the completion's final answer matches the reference.

    Both answers are pulled out with ``extract_answer`` and compared as
    strings (commas removed, no other numeric normalisation).

    Args:
        model_completion: text generated by the model.
        gt_example: reference example dict with an ``"answer"`` field.

    Raises:
        ValueError: if the reference answer contains no ``#### <number>``
            marker.
    """
    gt_answer = extract_answer(gt_example["answer"])
    # Use an explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, which would make malformed references silently score
    # as "[invalid]".
    if gt_answer == INVALID_ANS:
        raise ValueError(
            f"reference answer has no '#### <number>' marker: {gt_example['answer']!r}"
        )
    return extract_answer(model_completion) == gt_answer

class GSMDataset(th.utils.data.Dataset):
    """Tokenized question/answer pairs, right-padded to one common length.

    Each item is a dict of 1-D tensors:

    * ``input_ids``: question tokens, then answer tokens, then padding.
    * ``attention_mask``: ``int(loss_on_prefix)`` for question tokens, 1 for
      answer tokens and 0 for padding.

    NOTE(review): this mask is returned as ``attention_mask``, so with
    ``loss_on_prefix=False`` the question tokens are hidden from attention,
    not merely excluded from the loss. Confirm that is intended.
    """

    def __init__(self, tokenizer, examples, loss_on_prefix=True):
        """Tokenize every example up front.

        Args:
            tokenizer: callable mapping a list of strings to a mapping whose
                ``"input_ids"`` entry is a list of token-id lists. Its
                ``pad_token_id`` attribute, if present and not None, is used
                for padding.
            examples: sequence of dicts with ``"question"`` and ``"answer"``.
            loss_on_prefix: mask value given to question tokens.
        """
        self.examples = examples
        self.qns = tokenizer([ex["question"] for ex in examples], padding=False)
        self.ans = tokenizer([ex["answer"] for ex in examples], padding=False)
        self.loss_on_prefix = loss_on_prefix
        # Pad with the tokenizer's own pad id when it defines one. Otherwise
        # fall back to the previously hard-coded 0.
        pad_id = getattr(tokenizer, "pad_token_id", None)
        self.pad_token_id = 0 if pad_id is None else pad_id
        # default=0 keeps an empty example list from crashing max().
        self.max_len = max(
            (
                len(q) + len(a)
                for q, a in zip(self.qns["input_ids"], self.ans["input_ids"])
            ),
            default=0,
        )
        print(f"Max tokens: {self.max_len}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        qn_tokens = self.qns["input_ids"][idx]
        ans_tokens = self.ans["input_ids"][idx]
        pad_len = self.max_len - len(qn_tokens) - len(ans_tokens)
        tokens = qn_tokens + ans_tokens + [self.pad_token_id] * pad_len
        mask = (
            [int(self.loss_on_prefix)] * len(qn_tokens)
            + [1] * len(ans_tokens)
            + [0] * pad_len
        )
        return dict(input_ids=th.tensor(tokens), attention_mask=th.tensor(mask))
suchenzang commented 2 years ago

@Leli1024 Please open an issue in https://github.com/huggingface/transformers instead, since the code above uses a separate codebase (not metaseq).