salesforce / CodeT5

Home of CodeT5: Open Code LLMs for Code Understanding and Generation
https://arxiv.org/abs/2305.07922
BSD 3-Clause "New" or "Revised" License

Unable to match results on code generation #47

Closed · sindhura97 closed this 2 years ago

sindhura97 commented 2 years ago

I am unable to get validation-set performance for code generation that comes close to the test-set numbers reported in the paper when using your fine-tuned checkpoint. I am getting a BLEU score of 29.49 and an EM of 12.65. Here is my code. Am I doing something wrong?

from datasets import load_dataset

class Example(object):
    """A single NL-to-code example."""
    def __init__(self, idx, source, target):
        self.idx = idx
        self.source = source
        self.target = target

def read_examples(split):
    # CodeXGLUE text-to-code (ConCode): natural-language description -> Java code.
    dataset = load_dataset('code_x_glue_tc_text_to_code')[split]
    examples = []
    for eg in dataset:
        examples.append(Example(idx=eg['id'], source=eg['nl'], target=eg['code']))
    return examples

examples = read_examples('validation')

from transformers import RobertaTokenizer, T5ForConditionalGeneration
import torch
from tqdm import tqdm
import os

os.environ["CUDA_VISIBLE_DEVICES"]="0"

# Load the CodeT5-base tokenizer and model, then the fine-tuned ConCode weights.
tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')
model.load_state_dict(torch.load('finetuned_models_concode_codet5_base.bin'))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

# Generate code for each validation example with beam search.
preds = []
for eg in tqdm(examples):
    input_ids = tokenizer(eg.source, return_tensors="pt").input_ids.to(device)
    generated_ids = model.generate(input_ids, max_length=200, num_beams=5)
    preds.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

import numpy as np
from bleu import _bleu  # bleu.py from the CodeT5/CodeXGLUE evaluation scripts

# Write predictions and references to disk, then compute exact match and BLEU.
accs = []
with open("test.output", 'w') as f, open("test.gold", 'w') as f1:
    for pred, eg in zip(preds, examples[:len(preds)]):
        f.write(pred + '\n')
        f1.write(eg.target + '\n')
        accs.append(pred.strip().split() == eg.target.split())

print(np.mean(accs), _bleu('test.gold', 'test.output'))
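
For comparison, below is a variant of the same generation loop with an explicit cap on the source length and on the generated length. The 320-token source cap, 150-token target cap, and beam size of 10 are assumptions about the fine-tuning configuration, not values confirmed in this thread.

# Sketch only: same loop as above, but with explicit truncation and decoding limits.
# The length caps and beam size below are assumptions, not confirmed settings.
MAX_SOURCE_LENGTH = 320   # assumed source cap
MAX_TARGET_LENGTH = 150   # assumed target cap

preds_truncated = []
for eg in tqdm(examples):
    enc = tokenizer(eg.source, max_length=MAX_SOURCE_LENGTH, truncation=True,
                    return_tensors="pt")
    generated_ids = model.generate(enc.input_ids.to(device),
                                   max_length=MAX_TARGET_LENGTH,
                                   num_beams=10, early_stopping=True)
    preds_truncated.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))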
yuewang-cuhk commented 2 years ago

Hi, before we get time to examine your code and figure out where the problem comes from, we suggest you first use the run_gen.py script to reproduce the results. You can pass the do_test argument here and specify the fine-tuned checkpoint to load here.
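
For reference, a test-mode invocation of run_gen.py might look roughly like the sketch below. The flag names (--do_test, --task, --load_model_path, --model_name_or_path) are assumptions based on the description above rather than a verified command line, and the checkpoint path is the one used earlier in this thread; check the script's argument parser (or the provided shell scripts) before running it.

# Sketch only: launching the repo's run_gen.py in test mode from Python.
# All flag names are assumptions; consult run_gen.py's argument parser for
# the exact interface.
import subprocess

subprocess.run([
    "python", "run_gen.py",
    "--do_test",
    "--task", "concode",
    "--model_name_or_path", "Salesforce/codet5-base",
    "--load_model_path", "finetuned_models_concode_codet5_base.bin",
], check=True)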

sindhura97 commented 2 years ago

Thanks, that worked. I got a test BLEU of 39.6, which is close to but not exactly the number reported in the paper. Perhaps the paper used a different checkpoint?