salesforce / CodeT5

Home of CodeT5: Open Code LLMs for Code Understanding and Generation
https://arxiv.org/abs/2305.07922
BSD 3-Clause "New" or "Revised" License

Transmitting the Input to the Decoder #118

Closed antonio-mastropaolo closed 1 year ago

antonio-mastropaolo commented 1 year ago

Hi all, I was looking at the code released for generating predictions with CodeT5+:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

checkpoint = "Salesforce/instructcodet5p-16b"
device = "cuda" # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
                                              torch_dtype=torch.float16,
                                              low_cpu_mem_usage=True,
                                              trust_remote_code=True).to(device)

encoding = tokenizer("def print_hello_world():", return_tensors="pt").to(device)
encoding['decoder_input_ids'] = encoding['input_ids'].clone()
outputs = model.generate(**encoding, max_length=15)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

and I was wondering what the difference would be if we removed the following instruction: encoding['decoder_input_ids'] = encoding['input_ids'].clone(). What changes under the hood of the model?

Thanks in advance for any help you can provide on this.

yuewang-cuhk commented 1 year ago

Hi there, by setting encoding['decoder_input_ids'] = encoding['input_ids'].clone(), we also feed the text prompt to the decoder so that it has the prefix context. We find this very helpful for CodeT5+ models of 2B parameters and larger: these models have a deep decoder initialized from a frozen GPT-style LLM, so passing the prompt to the decoder is more compatible with the default behaviour of GPT models. Note that CodeT5+ 220M and 770M do not need this additional prefix prompt, as they are pretrained from scratch.
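To make the "under the hood" difference concrete, here is a minimal toy sketch in plain Python, not the Hugging Face implementation. It assumes a greedy loop in which the decoder either starts from a single start token (no decoder_input_ids) or continues after a given prefix (decoder_input_ids set to the prompt), which is a simplified picture of how an encoder-decoder generate loop treats a decoder prefix. The functions greedy_generate and needs_decoder_prefix, the toy next-token function, and all token ids are hypothetical illustrations; the 2B threshold comes from the reply above, and parsing the size from the checkpoint name is an assumption about the naming scheme.

```python
import re
from typing import Callable, List, Optional

# Hypothetical stand-in for the model: (encoder ids, decoder ids so far) -> next token id.
NextTokenFn = Callable[[List[int], List[int]], int]


def greedy_generate(
    encoder_ids: List[int],
    next_token: NextTokenFn,
    max_length: int,
    decoder_start_id: int,
    eos_id: int,
    decoder_prefix: Optional[List[int]] = None,
) -> List[int]:
    """Toy greedy decoding for an encoder-decoder model.

    Without a prefix the decoder starts from a single start token; with a
    prefix (e.g. the prompt ids) it continues generating after those tokens,
    and the returned sequence begins with the prefix.
    """
    if decoder_prefix is not None:
        decoder_ids = list(decoder_prefix)
    else:
        decoder_ids = [decoder_start_id]
    # max_length bounds the whole decoder sequence, prefix included.
    while len(decoder_ids) < max_length:
        token = next_token(encoder_ids, decoder_ids)
        decoder_ids.append(token)
        if token == eos_id:
            break
    return decoder_ids


def needs_decoder_prefix(checkpoint: str) -> bool:
    """Hypothetical helper: True for CodeT5+ checkpoints of 2B parameters or more.

    Assumes the checkpoint name ends in a size such as "220m" or "16b";
    raises ValueError otherwise.
    """
    match = re.search(r"(\d+)([mb])$", checkpoint.lower())
    if match is None:
        raise ValueError(f"cannot infer model size from {checkpoint!r}")
    size, unit = int(match.group(1)), match.group(2)
    params = size * (10**9 if unit == "b" else 10**6)
    return params >= 2 * 10**9


if __name__ == "__main__":
    # Toy model: emits 100 + current decoder length, then EOS (id 2) once 5 tokens exist.
    def toy_next_token(encoder_ids: List[int], decoder_ids: List[int]) -> int:
        return 2 if len(decoder_ids) >= 5 else 100 + len(decoder_ids)

    prompt_ids = [11, 12, 13]
    print(greedy_generate(prompt_ids, toy_next_token, 15, 0, 2))
    print(greedy_generate(prompt_ids, toy_next_token, 15, 0, 2, decoder_prefix=prompt_ids))
    print(needs_decoder_prefix("Salesforce/instructcodet5p-16b"))
```

In the toy run, the first call starts from the start token and produces only new tokens, while the second starts from the prompt ids and generates after them. In the snippet from the question, the decoder_input_ids line could then be guarded with a check such as needs_decoder_prefix(checkpoint), so that the 220M and 770M checkpoints skip the prefix.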

antonio-mastropaolo commented 1 year ago

@yuewang-cuhk Crystal clear! Many thanks :)