huggingface / alignment-handbook

Robust recipes to align language models with human and AI preferences
https://huggingface.co/HuggingFaceH4
Apache License 2.0

A question about the SFTTrainer (also a theoretical question about SFT in general) #74

Open PradeepKadubandi opened 7 months ago

PradeepKadubandi commented 7 months ago

I have a general question about Supervised Fine Tuning (SFT) for Dialogue applications.

Should the SFT process use the same LM objective (next-token prediction) that is used in pre-training a language model?

The "Dialogue" task is predicting "assistant" tokens, right? Shouldn't the objective be predicting only those tokens? Is one way to do this is to set labels for only assistant tokens and ignore the labels on others?

The SFTTrainer implementation does not set labels explicitly - as far as I understand, this means "input_ids" get cloned into "labels" and shifted right (within the transformers code), so the usual next-token prediction objective is applied over the whole sequence.
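To make the question concrete, here is a rough sketch of the kind of masking I have in mind (a minimal, hypothetical example, not the actual SFTTrainer code path; the model and prompt are just placeholders):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "User: What is the capital of France?\nAssistant:"
reply = " The capital of France is Paris."

prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
full_ids = tokenizer(prompt + reply, return_tensors="pt").input_ids

# Clone input_ids into labels (what the default LM collator does), then mask
# the prompt tokens with -100 so the cross-entropy loss ignores them.
# In practice the assistant span should be located more robustly (e.g. via a
# response template), since tokenization at the boundary can differ.
labels = full_ids.clone()
labels[:, : prompt_ids.shape[1]] = -100

# The loss is now computed only on the assistant's reply; without the masking
# above, labels would simply equal input_ids and every token would count.
loss = model(input_ids=full_ids, labels=labels).loss
print(loss.item())
```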

On a more philosophical note - if SFT uses the same objective as pre-training, why isn't it just called "Fine Tuning" the model (on a dialogue dataset, of course) rather than "Supervised Fine Tuning"? What am I missing? Is there a reference paper that explains this well, and the right way to do SFT for dialogue applications?

It is not obvious, hence the question. For example, the InstructGPT paper mentions SFT but mainly points to the (seemingly) first attempt at SFT in this paper, which addresses a "Summarization" task rather than a "Dialogue" task.

In that paper, human labelers are asked to write summaries, and when the paper says "Behavioral Cloning" is used to fine-tune the LLM for this task, I'd imagine that only the "Summary" section is treated as the label, not the entire prompt/document. Following that principle, for "Dialogue" tasks, intuitively, I'd imagine that only the "assistant" turns should be part of the labels.

(By the way, I already asked this in the trl repository as well, but I'm not sure which is the best repository for this question; since this repository is for alignment tasks, of which SFT is one step, I posted it here too.)

alexvishnevskiy commented 7 months ago

I think you're right in some sense. TRL doesn't mask user queries by default, but you can enable this with a specific DataCollator. Also, the authors of the original DPO paper train their SFT model only on the assistant's completions. Here is the repo: https://github.com/eric-mitchell/direct-preference-optimization
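In case it helps, this is roughly what that looks like in trl (a sketch adapted from the trl documentation, assuming the collator in question is DataCollatorForCompletionOnlyLM; the model and dataset below are just placeholders, and exact arguments may vary across trl versions):

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

model_name = "facebook/opt-350m"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")

def formatting_prompts_func(example):
    # Turn each record into a single prompt/answer string.
    output_texts = []
    for i in range(len(example["instruction"])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts

# Everything up to and including the response template is masked with -100,
# so the loss is computed only on the answer tokens.
response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)
trainer.train()
```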

PradeepKadubandi commented 7 months ago

Thank you for the pointer and the acknowledgement!

qinchuanhui commented 6 months ago

Documentation on using the SFTTrainer and training only on the responses with a DataCollator: https://huggingface.co/docs/trl/main/en/sft_trainer#train-on-completions-only
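For multi-turn dialogue data, the same collator can also take an instruction template, so that every user turn is masked and only the assistant turns contribute to the loss. A rough sketch (the templates and tokenizer are placeholders and depend on your chat format):

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")  # placeholder

# With both templates set, every "### Human:" segment is masked out and only
# the text following each "### Assistant:" marker is kept as labels.
collator = DataCollatorForCompletionOnlyLM(
    instruction_template="### Human:",
    response_template="### Assistant:",
    tokenizer=tokenizer,
    mlm=False,
)
```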