from transformers import AutoTokenizer
import torch
# Load the Mistral-7B tokenizer; requires network access or a local HF cache.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
def encode_with_prompt_completion_format(example, tokenizer, max_seq_length):
    '''
    Tokenize one {'prompt', 'completion'} example for causal-LM fine-tuning.

    Prompt and completion are concatenated and tokenized together because
    tokenizing them separately would let the prompt be padded/truncated on
    its own, and the completion would no longer be a direct continuation.

    The prompt portion of ``labels`` is set to -100 so it is ignored by the
    loss. The masked prefix is computed from the *exact* prefix string that
    appears in the concatenated text (including any separator space we
    insert), and the token-id prefix is verified against the full
    tokenization so that subword merges across the prompt/completion
    boundary cannot shift the mask relative to ``input_ids``.

    Returns a dict of 1-D tensors: 'input_ids', 'labels', 'attention_mask'.
    '''
    # If prompt doesn't end with whitespace and completion doesn't start
    # with whitespace, insert a single space between them.
    if not example['prompt'].endswith((' ', '\n', '\t')) and not example['completion'].startswith((' ', '\n', '\t')):
        prompt_text = example['prompt'] + ' '
    else:
        prompt_text = example['prompt']
    example_text = prompt_text + example['completion'] + tokenizer.eos_token
    tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()
    # BUG FIX: the original masked len(tokenizer(example['prompt'])) tokens,
    # ignoring the separator space added above, so the label mask could be
    # misaligned with input_ids. Tokenize the actual prefix string instead.
    tokenized_prompt = tokenizer(prompt_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    # The tokenizer may still merge subwords across the prompt/completion
    # boundary; only mask the leading tokens that actually match the
    # prompt's own tokenization.
    limit = min(tokenized_prompt.input_ids.shape[1], input_ids.shape[1])
    num_prompt_tokens = 0
    while (num_prompt_tokens < limit
           and input_ids[0, num_prompt_tokens].item()
           == tokenized_prompt.input_ids[0, num_prompt_tokens].item()):
        num_prompt_tokens += 1
    # Mask the prompt part so it contributes no loss.
    labels[:, :num_prompt_tokens] = -100
    attention_mask = torch.ones_like(input_ids)
    return {
        'input_ids': input_ids.flatten(),
        'labels': labels.flatten(),
        'attention_mask': attention_mask.flatten(),
    }
# A single multiple-choice example in prompt/completion form; the prompt
# ends with "A: " so the completion is the bare answer letter.
prompt_text = (
    "What does someone need when they're feeling hunger?\n\n"
    "A. starvation\nB. eat hamburger\nC. eating\nD. pizza\nE. discomfort\n\nA: "
)
example = {"prompt": prompt_text, "completion": "C"}
print(example["prompt"])
encode_with_prompt_completion_format(example, tokenizer, 512)
And this is the output:
This is how huggingface would handle input with labels:
It seems that we are just using the answer "C" to predict the EOS token. Do I understand this correctly?
Hi, it seems that the code here would cause an incorrect mapping between input_ids and labels. https://github.com/allenai/open-instruct/blob/9ebcb582cfc243a6dab75b4302fa432784db26c2/open_instruct/finetune.py#L238
This is my code:
And this is the output:![image](https://github.com/allenai/open-instruct/assets/38466901/286570c4-a17b-4700-978e-0a79c8a61060)
This is how huggingface would handle input with labels:![image](https://github.com/allenai/open-instruct/assets/38466901/63b4798a-a77b-491b-b9ba-26a7cd8d5b17)
It seems that we are just using the answer "C" to predict the EOS token. Do I understand this correctly? Thank you!