huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

How to Instruction Tune with SFTTrainer? #426

Closed jenkspt closed 1 year ago

jenkspt commented 1 year ago

With the SFTTrainer it's unclear to me how to instruction tune. I might be missing relevant details, but the examples I've seen look like they are fine-tuning on the prompt and response rather than just the response.

specifically looking at: https://github.com/lvwerra/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py

meanwhile alpaca code explicitly creates a supervised dataset to train on responses https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py

Are there any examples for instruction tuning with SFTTrainer or am I just missing something?

younesbelkada commented 1 year ago

Hi @jenkspt, thanks for the issue! For the SFTTrainer, you might be interested in first creating an instruction dataset, or using an existing one, and then passing that dataset to the trainer out of the box. Please see the example below of how we used the SFT Trainer to fine-tune Falcon 7B/40B on the Guanaco dataset: https://gist.github.com/pacman100/1731b41f7a90a87b457e8c5415ff1c14. Let me know if anything else is unclear.
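
For reference, a minimal sketch of that flow with the Guanaco dataset (the checkpoint and training arguments here are placeholders; the gist above has the full script):

```python
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer

# The guanaco dataset stores whole conversations in a single "text" column.
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

trainer = SFTTrainer(
    "tiiuae/falcon-7b",  # any causal LM checkpoint works here
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
    args=TrainingArguments(output_dir="sft-falcon", per_device_train_batch_size=4),
)
trainer.train()
```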

jenkspt commented 1 year ago

For example, take the dataset from the Falcon script, 'timdettmers/openassistant-guanaco'. The responses are prefixed with ### Human and ### Assistant. Does the SFTTrainer split on these to optimize only on the responses after ### Assistant, or does it optimize on the entire 'text' field?
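
For anyone following along, a row in that dataset looks roughly like this (illustrative, not an actual record):

```python
example = {
    "text": "### Human: What is the capital of France?"
            "### Assistant: The capital of France is Paris."
}
```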

younesbelkada commented 1 year ago

@jenkspt I see now. Per my understanding, the SFTTrainer does not do that: it optimizes on the entire text chunk. From what I know (but maybe I am wrong), that is also how it is done for all instruction fine-tuned models.
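
Concretely, that default corresponds to a plain causal-LM collator, where the labels are just a copy of the input ids, so the loss covers prompt and response tokens alike. A sketch with the stock transformers collator (which is, as far as I can tell, what SFTTrainer falls back to when none is passed):

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
batch = collator([tokenizer("### Human: Hi### Assistant: Hello!")])

# Labels equal the input ids (padding aside), so every token -- prompt
# included -- contributes to the cross-entropy loss.
print(batch["labels"])
```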

jenkspt commented 1 year ago

Dolly does completion-only loss: https://github.com/databrickslabs/dolly/blob/master/training/trainer.py#L48-L77, and I'm pretty sure the Stanford Alpaca code does the same: https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py#L127-L153
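
The trick in both is the same: tokenize the full prompt+response string, then set the label ids for the prompt portion to -100 so the cross-entropy loss ignores them. A hedged sketch of that masking (the function name and prompt format are mine, not from either repo):

```python
import torch
from transformers import AutoTokenizer

IGNORE_INDEX = -100  # tokens with this label are skipped by the CE loss

def mask_prompt_labels(prompt: str, response: str, tokenizer):
    """Tokenize prompt+response and mask the prompt portion of the labels."""
    full = tokenizer(prompt + response, return_tensors="pt").input_ids[0]
    prompt_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
    labels = full.clone()
    labels[:prompt_len] = IGNORE_INDEX  # loss is computed on the response only
    return {"input_ids": full, "labels": labels}

tokenizer = AutoTokenizer.from_pretrained("gpt2")
batch = mask_prompt_labels("### Human: Hi\n### Assistant: ", "Hello!", tokenizer)
```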

younesbelkada commented 1 year ago

I see, that makes sense. Thanks a lot for the pointers! It looks like it is a matter of adding a new data collator to SFTTrainer. Let me know if you want to give it a try and contribute it to TRL! Otherwise I'm happy to do it.

PhilDakin commented 1 year ago

Alpaca indicates they include an input in roughly 40% of their training data here.
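
For context, Alpaca's train.py picks one of two templates per example depending on whether the optional input field is non-empty. Roughly (paraphrased; the real templates live in PROMPT_DICT in their train.py):

```python
# Paraphrase of the template selection in stanford_alpaca/train.py.
def build_prompt(example: dict, prompt_input: str, prompt_no_input: str) -> str:
    has_input = example.get("input", "") != ""
    return (prompt_input if has_input else prompt_no_input).format_map(example)
```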

younesbelkada commented 1 year ago

Hi everyone,

Thanks all for your pointers! I opened https://github.com/lvwerra/trl/pull/445, which will hopefully be merged soon.
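
For anyone landing here later: that PR adds a completion-only collator to TRL. If I have the API right, usage should look roughly like this once it's merged (the response_template string must match how your dataset marks the assistant turn):

```python
from datasets import load_dataset
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

# Everything up to and including this marker gets label -100, so only the
# assistant's response contributes to the loss.
collator = DataCollatorForCompletionOnlyLM("### Assistant:", tokenizer=tokenizer)

trainer = SFTTrainer(
    "tiiuae/falcon-7b",
    train_dataset=dataset,
    dataset_text_field="text",
    data_collator=collator,
    packing=False,  # the completion-only collator needs unpacked examples
)
trainer.train()
```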

vwxyzjn commented 1 year ago

Hey @jenkspt, just saying hi :) It was great learning from your gpt jax implementation https://github.com/jenkspt/gpt-jax/issues/2. Glad our paths crossed again.

jenkspt commented 1 year ago

@vwxyzjn congrats on HuggingFace!