huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

DataCollatorForCompletionOnlyLM instruction token masking fails if first occurrence of instruction is marked differently #1184

Closed: mgerstgrasser closed this issue 10 months ago

mgerstgrasser commented 10 months ago

The current code in DataCollatorForCompletionOnlyLM assumes that the first detected occurrence of instruction_template comes before the first detected occurrence of response_template. This is reasonable, since in current applications conversations are initiated by the user, not the assistant. However, the assumption breaks if the first instruction is marked differently from all the other instructions, which can happen when a context-sensitive tokenizer such as the Llama-2 tokenizer tokenizes instruction_template differently at the start of a string than in the middle.

In particular, this happens in practice with TinyLlama: <|user|> gets tokenized as 529, 29989, 1792, 29989, 29958 at the start of a conversation, but as 29966, 29989, 1792, 29989, 29958 in later messages.
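To see why a differently tokenized first turn breaks masking, here is a simplified pure-Python sketch of zip-based instruction/response pairing. It is not TRL's source: `find_occurrences`, `mask_non_assistant` and the placeholder content IDs (1001, 2001, ...) are hypothetical; only the `<|user|>` and response-template token IDs come from this issue.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # label value the loss ignores

def find_occurrences(ids: Sequence[int], template: Sequence[int]) -> List[int]:
    """Start index of every occurrence of `template` inside `ids`."""
    n = len(template)
    return [i for i in range(len(ids) - n + 1) if list(ids[i:i + n]) == list(template)]

def mask_non_assistant(ids, instruction, response):
    """Hypothetical, simplified version of zip-based instruction/response pairing."""
    labels = list(ids)
    human_idxs = find_occurrences(ids, instruction)
    # Masking runs up to the end of each response template (start of assistant text).
    response_ends = [i + len(response) for i in find_occurrences(ids, response)]
    # The k-th instruction is paired with the k-th response, which silently
    # assumes the first detected instruction precedes the first detected response.
    for k, (start, end) in enumerate(zip(human_idxs, response_ends)):
        if k == 0:
            start = 0  # first pair masks everything before the first response
        # If start > end, the slice is empty and nothing gets masked.
        labels[start:end] = [IGNORE_INDEX] * max(0, end - start)
    return labels

USER_FIRST = [529, 29989, 1792, 29989, 29958]   # <|user|> at start of string
USER_LATER = [29966, 29989, 1792, 29989, 29958]  # <|user|> in later messages
RESPONSE = [29966, 29989, 465, 22137, 29989, 29958, 13]  # response template

# Three user/assistant turns; 1001-1003 are user content and
# 2001-2003 assistant content (placeholder IDs).
ids = (USER_FIRST + [1001] + RESPONSE + [2001]
       + USER_LATER + [1002] + RESPONSE + [2002]
       + USER_LATER + [1003] + RESPONSE + [2003])

# Only the mid-conversation form of <|user|> is searched for.
labels = mask_non_assistant(ids, USER_LATER, RESPONSE)
print(labels[ids.index(1002)], labels[ids.index(1003)])  # prints "1002 1003"
```

Because the first user turn is never detected, the second detected instruction is paired with the first response, the third with the second, and so on; every later pair has `start > end`, so user messages 2 and 3 keep their labels.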

Reproduction snippet:

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

chat = [
    {"role": "user", "content": "Which is bigger, the moon or the sun?"},
    {"role": "assistant", "content": "The sun."},
    {"role": "user", "content": "Really?"},
    {"role": "assistant", "content": "Yes."},
    {"role": "user", "content": "I don't believe you."},
    {"role": "assistant", "content": "That's okay."},
]

# Both templates are passed as token IDs taken from their
# mid-conversation tokenization.
collator = DataCollatorForCompletionOnlyLM(
    response_template=[29966, 29989, 465, 22137, 29989, 29958, 13],
    instruction_template=[29871, 13, 29966, 29989, 1792, 29989, 29958],
    tokenizer=tokenizer,
)

# apply_chat_template tokenizes by default and returns a list of token IDs.
collator([tokenizer.apply_chat_template(chat)])
# This doesn't mask user messages 2 & 3 (their labels are not set to -100).
```

PR #1185 fixes this, and makes the above snippet mask out all the user messages correctly.
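One way to relax the ordering assumption, sketched in pure Python under the same simplifications (hypothetical helper names, placeholder content IDs), is to treat index 0 as an implicit instruction start whenever the first response template is detected before any instruction template. This is an illustration only and not necessarily the change made in PR #1185.

```python
from typing import List, Sequence

IGNORE_INDEX = -100

def find_occurrences(ids: Sequence[int], template: Sequence[int]) -> List[int]:
    """Start index of every occurrence of `template` inside `ids`."""
    n = len(template)
    return [i for i in range(len(ids) - n + 1) if list(ids[i:i + n]) == list(template)]

def mask_non_assistant_fixed(ids, instruction, response):
    """Hypothetical sketch: align instructions and responses before pairing them."""
    labels = list(ids)
    human_idxs = find_occurrences(ids, instruction)
    response_starts = find_occurrences(ids, response)
    response_ends = [i + len(response) for i in response_starts]
    # A response template found before any instruction template means the first
    # user turn was tokenized differently; treat index 0 as where it starts.
    if response_starts and (not human_idxs or response_starts[0] < human_idxs[0]):
        human_idxs = [0] + human_idxs
    for start, end in zip(human_idxs, response_ends):
        labels[start:end] = [IGNORE_INDEX] * max(0, end - start)
    return labels

USER_FIRST = [529, 29989, 1792, 29989, 29958]   # <|user|> at start of string
USER_LATER = [29966, 29989, 1792, 29989, 29958]  # <|user|> in later messages
RESPONSE = [29966, 29989, 465, 22137, 29989, 29958, 13]

ids = (USER_FIRST + [1001] + RESPONSE + [2001]
       + USER_LATER + [1002] + RESPONSE + [2002]
       + USER_LATER + [1003] + RESPONSE + [2003])

labels = mask_non_assistant_fixed(ids, USER_LATER, RESPONSE)
# Only the assistant replies keep their labels now.
print([t for t in labels if t != IGNORE_INDEX])  # prints [2001, 2002, 2003]
```

Prepending 0 also behaves sensibly if a conversation genuinely starts with an assistant turn: the first pair then masks only the response template itself, and the following pairs line up as before.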

mgerstgrasser commented 10 months ago

@younesbelkada While we're talking about this anyway: would it make sense to add a more flexible system as an alternative to matching strings here? I have in mind something like passing a function to the data collator that maps each example to a list of (first_token_idx, last_token_idx) pairs marking each part of the example that is assistant-generated, i.e. that should be unmasked.
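A minimal sketch of what such an interface might look like, assuming plain Python lists rather than tensors. `span_fn` stands for the user-supplied function returning inclusive `(first_token_idx, last_token_idx)` pairs, and `mask_outside_spans` is a hypothetical helper, not a TRL API.

```python
from typing import Callable, List, Sequence, Tuple

IGNORE_INDEX = -100
Span = Tuple[int, int]  # (first_token_idx, last_token_idx), both inclusive

def mask_outside_spans(input_ids: Sequence[int],
                       span_fn: Callable[[Sequence[int]], List[Span]]) -> List[int]:
    """Hypothetical helper: keep labels only for tokens inside the spans that
    span_fn marks as assistant-generated; everything else gets IGNORE_INDEX."""
    labels = [IGNORE_INDEX] * len(input_ids)
    for first, last in span_fn(input_ids):
        labels[first:last + 1] = input_ids[first:last + 1]
    return labels

def fixed_spans(ids: Sequence[int]) -> List[Span]:
    """Placeholder span function that hard-codes two assistant spans."""
    return [(3, 4), (8, 8)]

example = [10, 11, 12, 20, 21, 30, 31, 32, 40]
labels = mask_outside_spans(example, fixed_spans)
print(labels)  # prints [-100, -100, -100, 20, 21, -100, -100, -100, 40]
```

In a real collator this would presumably run per example before padding; the hard-coded spans stand in for a chat-template-specific span function.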

I'm asking for two reasons:

  1. There are still many edge cases remaining, given the wide variety of chat templates. (Case in point: even the code snippet above is still not 100% correct. The chat template inserts a newline after the final EOS token, which IMHO should be masked if the final turn is from the assistant, but isn't, and I don't see a way of doing that with the template system.)
  2. Doing that would allow reuse of that function in other parts of a project, e.g. for reward or RL training, unit testing, etc.
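For the trailing-newline case in point 1, a span function could end each assistant span at the EOS token, so anything the chat template appends afterwards stays masked. This is a hypothetical sketch: `assistant_spans` is not part of TRL, the response template IDs come from the reproduction snippet, EOS = 2 and newline = 13 are assumed Llama-style token IDs, and the other content IDs are placeholders.

```python
from typing import List, Sequence, Tuple

EOS_ID = 2       # assumed Llama-style </s> token ID
NEWLINE_ID = 13  # assumed newline token ID
RESPONSE = [29966, 29989, 465, 22137, 29989, 29958, 13]  # response template from the snippet

def assistant_spans(ids: Sequence[int], response_template: Sequence[int],
                    eos_id: int) -> List[Tuple[int, int]]:
    """Hypothetical span function: each assistant turn runs from just after the
    response template up to and including the next EOS (or the last token if no
    EOS follows). Tokens after that EOS fall outside every span."""
    spans = []
    n = len(response_template)
    i = 0
    while i <= len(ids) - n:
        if list(ids[i:i + n]) == list(response_template):
            first = last = i + n
            while last < len(ids) and ids[last] != eos_id:
                last += 1
            last = min(last, len(ids) - 1)  # no EOS: stop at the final token
            if first <= last:
                spans.append((first, last))
            i = last + 1
        else:
            i += 1
    return spans

# Two turns; the chat template appends a newline after the final EOS.
# 1001/1002 are user content, 2001-2003 assistant content (placeholder IDs).
ids = ([1001] + RESPONSE + [2001, 2002, EOS_ID]
       + [1002] + RESPONSE + [2003, EOS_ID, NEWLINE_ID])

spans = assistant_spans(ids, RESPONSE, EOS_ID)
print(spans)  # prints [(8, 10), (19, 20)]; the trailing newline at index 21 is excluded
```

Because it returns spans rather than labels, a function like this could be unit-tested on its own and reused outside the collator, in line with point 2.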

I will probably implement something like this for a project I'm working on - I'd be happy to open a PR once it's done, if you think that would be useful to have in the library.

(Asking this here, as it's broadly related to the issue, but not to the fix in the PR.)