huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Does the default messages collator for SFTTrainer handle multi-turn? #1282

Closed andysalerno closed 2 weeks ago

andysalerno commented 7 months ago

I see that you can now pass in a dataset to SFTTrainer that is formatted like messages:

{ "messages": [{"role": "user", "content": "hi..."}, {"role": assistant", "content": "hi, how can I help?"}] }

And the docs make it sound like this will be treated as "completion only" training, where the assistant's response is what we train on as a completion, while the user's message is not.

But what happens if the chat in "messages" is multi-turn? Will all "assistant" messages in the history be trained on, or only the final one?

In my case, I actually just want it to be the final one. I am using the Nectar dataset, which has some multi-turn chats where a weaker model was used for the first few turns, and the true completion is only the final assistant message, selected as the highest-scoring response from among several models (gpt4, gpt35-turbo, llama-70b, etc.).

younesbelkada commented 7 months ago

Hi @andysalerno, thanks for the issue! Per my understanding it correctly trains on multi-turn data if you pass multi-turn chats - @philschmid can confirm, as I am not 100% sure about it

philschmid commented 7 months ago

And the docs make it sound like this will be treated as a "completion only" training,

By default the SFTTrainer does not train on completions only. For that you need to use the DataCollatorForCompletionOnlyLM, which currently does not support multi-turn conversations.

This means that if you have a dataset with multiple turns, it will simply train on all tokens as regular next-token prediction; no labels are masked.
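For reference, the completion-only setup looks roughly like this (a minimal sketch; the model, response template, and toy dataset here are placeholders, not recommendations):

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

# Toy single-turn dataset; the "### Answer:" marker is an arbitrary choice.
dataset = Dataset.from_dict({
    "text": ["### Question: What is 2+2?\n### Answer: 4"],
})

model_name = "facebook/opt-350m"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Everything up to and including the response template gets label -100,
# so only the answer tokens contribute to the loss.
collator = DataCollatorForCompletionOnlyLM(
    response_template="### Answer:",
    tokenizer=tokenizer,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    data_collator=collator,
    packing=False,  # the completion-only collator is incompatible with packing
)
trainer.train()
```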

younesbelkada commented 7 months ago

Thanks for the clarification @philschmid !

andysalerno commented 7 months ago

Thanks for the reply!

But what about the "dataset format support" that the docs mention? Here: https://huggingface.co/docs/trl/sft_trainer#dataset-format-support

Instead of using DataCollatorForCompletionOnlyLM explicitly, it seems like SFTTrainer will detect those common formats and then magically understand how to interpret them. And I'm wondering, for the "conversational format", what's the magic? From the discussion so far, I assume it just trains on all messages with "role": "assistant"?
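For what it's worth, my working assumption is that the "magic" is just the tokenizer's chat template, so one way to see the flattened text the trainer would end up with is to apply it directly (a sketch; Zephyr is used only because its tokenizer ships a chat template):

```python
from transformers import AutoTokenizer

# Any tokenizer with a chat template works; Zephyr is just for illustration.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

messages = [
    {"role": "user", "content": "hi..."},
    {"role": "assistant", "content": "hi, how can I help?"},
]

# This flattened string, user turns included, is what the trainer sees;
# if no labels are masked, every token of it is a next-token target.
print(tokenizer.apply_chat_template(messages, tokenize=False))
```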

andysalerno commented 7 months ago

Oh, I see, I think you have answered that already. Sounds like even if you use the "conversation format" as shown in the docs, it will not perform any masking, and will train on all tokens. (Except "role": "user" I presume?)

github-actions[bot] commented 6 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

janphilippfranken commented 2 months ago

Oh, I see, I think you have answered that already. Sounds like even if you use the "conversation format" as shown in the docs, it will not perform any masking, and will train on all tokens. (Except "role": "user" I presume?)

does anyone know if this is true? ie if i feed it a list of messages like:

{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}

will it make sure that the user turns are masked from the loss (i.e., labels set to -100 for those tokens) while the assistant completions are the target for SFT?

thanks!
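In case it helps anyone else, one way to check directly is to push a sample through the collator and see which labels come out as -100 (a sketch assuming Zephyr's chat template, where assistant turns start with <|assistant|>; the marker differs per model, and the known tokenization-boundary caveats apply):

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

messages = [
    {"role": "system", "content": "You are helpful"},
    {"role": "user", "content": "What's the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]
text = tokenizer.apply_chat_template(messages, tokenize=False)

collator = DataCollatorForCompletionOnlyLM(
    response_template="<|assistant|>",  # Zephyr-specific; adjust per model
    tokenizer=tokenizer,
)
batch = collator([tokenizer(text)])

# Label -100 means the token is excluded from the loss.
for tok_id, label in zip(batch["input_ids"][0].tolist(),
                         batch["labels"][0].tolist()):
    status = "mask" if label == -100 else "TRAIN"
    print(f"{status}  {tokenizer.decode([tok_id])!r}")
```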

RonanKMcGovern commented 1 month ago

By default the SFTTrainer does not train on completions only. For that you need to use the DataCollatorForCompletionOnlyLM, which currently does not support multi-turn conversations.

This means that if you have a dataset with multiple turns, it will simply train on all tokens as regular next-token prediction; no labels are masked.

Hi @philschmid , just to understand this comment in more detail. If I train using DataCollatorForCompletionOnlyLM on multi-turn data, are you saying it will detect the data is multi-turn and disable the mask?

OR are you saying that it will just detect the first occurrence of response_template_ids and leave everything after that first occurrence unmasked (which is like training on almost all of the messages [and could be useful]!)?

I'm trying to understand what happens if I try this. Thanks
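One way to answer this empirically, rather than digging through the code, is to push a two-turn chat through the collator and decode only the tokens left in the loss (same assumptions as the snippet above: Zephyr's template with <|assistant|> as the response marker):

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

messages = [
    {"role": "user", "content": "First question?"},
    {"role": "assistant", "content": "First answer."},
    {"role": "user", "content": "Second question?"},
    {"role": "assistant", "content": "Second answer."},
]
text = tokenizer.apply_chat_template(messages, tokenize=False)

collator = DataCollatorForCompletionOnlyLM(
    response_template="<|assistant|>",
    tokenizer=tokenizer,
)
batch = collator([tokenizer(text)])

# Whatever prints here is what actually gets trained on: if it contains
# both answers, everything after the first marker is unmasked; if only
# the second, the collator keyed on the last occurrence.
kept = [t for t, l in zip(batch["input_ids"][0].tolist(),
                          batch["labels"][0].tolist()) if l != -100]
print(tokenizer.decode(kept))
```

Separately, the collator does accept an instruction_template argument which, per the TRL docs, re-masks the human turns between assistant responses; that may be the closest thing to proper multi-turn masking currently available.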

edbeeching commented 1 month ago

Hi @RonanKMcGovern, the related issues in #1550 may help you. At the moment there is no robust implementation for multi-turn masking, particularly when packing is used.

RonanKMcGovern commented 1 month ago

Thanks.

Your other issue highlights that the masking is not even robust when there is just a single turn (which yeah, I have seen that too).

I was just curious, on the current issue, what actually happens when you use completion-only with multi-turn data. Does it detect the first instance of the response template tokens and unmask everything afterwards? Or only after the second instance? I guess I could dig into the code to find out.


github-actions[bot] commented 3 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.