wasiahmad / PLBART

Official code of our work, Unified Pre-training for Program Understanding and Generation [NAACL 2021].
https://arxiv.org/abs/2103.06333
MIT License

HuggingFace Checkpoint Configurations #50

Closed: nandovallec closed this issue 1 month ago.

nandovallec commented 1 year ago

Hello, I am replicating some of the experiments with the checkpoints from PLBART on HuggingFace.

The Drive document states that the model 'uclanlp/plbart-refine-java-small' does not need a language token. However, unless I pass decoder_start_token_id=tokenizer.lang_code_to_id["java"] to generate, the model does not work properly.

Would that be an error in the documentation or am I confusing terms?

Thanks.
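You can see what passing decoder_start_token_id changes without downloading any checkpoint. For an encoder-decoder model, `generate` feeds this id to the decoder as its first input, so every generated sequence begins with it and every later token is conditioned on it. The sketch below uses a tiny, randomly initialised PLBart. The model sizes and the stand-in id LANG_ID = 5 are arbitrary assumptions. They are not values from the real 'uclanlp/plbart-refine-java-small' checkpoint.

```python
import torch
from transformers import PLBartConfig, PLBartForConditionalGeneration

torch.manual_seed(0)

# Tiny, randomly initialised PLBart so no checkpoint needs to be downloaded.
# All sizes are arbitrary assumptions chosen only to keep the example fast.
config = PLBartConfig(
    vocab_size=64,
    d_model=16,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    max_position_embeddings=64,
)
model = PLBartForConditionalGeneration(config).eval()

# Hypothetical source sequence: bos (0), three "code" tokens, eos (2)
input_ids = torch.tensor([[0, 10, 11, 12, 2]])

# Hypothetical stand-in for a language-token id such as tokenizer.lang_code_to_id["java"]
LANG_ID = 5

with torch.no_grad():
    out = model.generate(
        input_ids,
        decoder_start_token_id=LANG_ID,  # first token fed to the decoder
        max_length=8,
        num_beams=1,
        do_sample=False,
    )

# Position 0 of the generated sequence is the requested start token
print(out[0].tolist())
```

Because the weights are random, the tokens after position 0 carry no meaning. The sketch only shows where the start token ends up.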

wasiahmad commented 1 year ago

Yes, when we fine-tuned PLBART on the code refinement task, we didn't use the language token, so I am not sure why the token is needed to generate refined Java code.

nandovallec commented 1 year ago

The fine-tuned model from the GitHub repository handles these cases correctly. In case it helps, here is a small script that reproduces the problem I mentioned:

from datasets import load_dataset
from transformers import PLBartForConditionalGeneration, PLBartTokenizer

# Test split of the CodeXGLUE code refinement benchmark ("small" subset)
dataset = load_dataset("code_x_glue_cc_code_refinement", name="small", split="test")

tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-refine-java-small")
model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-refine-java-small")

# Two buggy Java methods from the test split
example_1 = dataset["buggy"][0]
example_2 = dataset["buggy"][1]

inputs_1 = tokenizer(example_1, return_tensors="pt", padding=True)
inputs_2 = tokenizer(example_2, return_tensors="pt", padding=True)

# No decoder_start_token_id (i.e. no language token) is passed, as the documentation suggests.
# Without it, example_1 is refined correctly but example_2 is not.
translated_tokens_correct = model.generate(**inputs_1, max_length=512, num_beams=5, do_sample=True)
translated_tokens_error = model.generate(**inputs_2, max_length=512, num_beams=5, do_sample=True)

output_correct = tokenizer.batch_decode(translated_tokens_correct, skip_special_tokens=True)
output_error = tokenizer.batch_decode(translated_tokens_error, skip_special_tokens=True)

print(output_correct)
print(output_error)