uclaml / SPIN

The official implementation of Self-Play Fine-Tuning (SPIN)
https://uclaml.github.io/SPIN/
Apache License 2.0

Question about which datasets are used for each iteration #11

Closed: lewtun closed this issue 9 months ago

lewtun commented 9 months ago

Hello, thank you for open sourcing the code behind SPIN - it's very clean!

I'm currently working on porting this to trl in https://github.com/huggingface/trl/pull/1344 and am validating everything works on a small Qwen-1.5-0.5b model.

On p.9 of your paper, you state that you combine datasets across each iteration:

In multiple iterations, we leverage the synthetic data from the most recent iteration and add to the newly generated synthetic data, therefore resulting in a synthetic dataset size of 50k at iteration 0 and 100k at iteration 1, 2 and 3. At each iteration, we train our model for 2 epochs.
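The sizes in the quoted passage follow from simple arithmetic. The sketch below assumes that each generation round contributes exactly 50k synthetic examples, which is the figure the quote gives for iteration 0. The constant and function names are illustrative only and do not come from the SPIN codebase.

```python
# Sketch of the dataset-size arithmetic in the quoted passage (paper, p.9).
# Assumption: each generation round contributes exactly 50k synthetic examples.
PROMPTS_PER_ROUND = 50_000
EPOCHS = 2  # "At each iteration, we train our model for 2 epochs."


def synthetic_dataset_size(iteration: int) -> int:
    """Synthetic examples available for training at a given SPIN iteration."""
    if iteration < 0:
        raise ValueError("iteration must be non-negative")
    if iteration == 0:
        # No earlier round exists yet, so only the new generations are used.
        return PROMPTS_PER_ROUND
    # Most recent round's data plus the newly generated data.
    return 2 * PROMPTS_PER_ROUND


def examples_seen(iteration: int) -> int:
    """Total training examples processed over all epochs of one iteration."""
    return EPOCHS * synthetic_dataset_size(iteration)


if __name__ == "__main__":
    for t in range(4):
        print(f"iter {t}: {synthetic_dataset_size(t):,} examples, "
              f"{examples_seen(t):,} processed over {EPOCHS} epochs")
```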

My question concerns which combination of datasets you used for each SPIN iteration:

  1. Was zephyr-7b-sft-full-SPIN-iter0 trained on UCLA-AGI/SPIN_iter0 (50k samples)?
  2. Was zephyr-7b-sft-full-SPIN-iter1 trained on UCLA-AGI/SPIN_iter0 and UCLA-AGI/SPIN_iter1 (100k samples)?
  3. Was zephyr-7b-sft-full-SPIN-iter2 trained on UCLA-AGI/SPIN_iter1 and UCLA-AGI/SPIN_iter2 (100k samples)?
  4. Etc.

In other words, do you combine the generations from the model trained on iteration t with those from t-1?

A related question is whether you always run generation on the same 50k prompts at each iteration or do you generate over 100k prompts for iterations 1-3?

Thanks!

yihedeng9 commented 9 months ago

Hi, thank you very much for your interest! Yes, your understanding is correct. zephyr-7b-sft-full-SPIN-iter0 is trained on UCLA-AGI/SPIN_iter0 (50k samples). Iteration 1 is then trained on data from iterations 0 and 1 (100k samples), and iteration 2 is similarly trained on data from iterations 1 and 2 (100k samples).
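The schedule confirmed above can be written as a sliding window over the hub datasets. This is a minimal sketch: the `UCLA-AGI/SPIN_iter{t}` naming is taken from the question, extending it past iteration 2 is an assumption, and `training_datasets` is a hypothetical helper, not part of the SPIN or trl code.

```python
HUB_PREFIX = "UCLA-AGI/SPIN_iter"  # naming taken from the question above


def training_datasets(iteration: int) -> list[str]:
    """Hub dataset IDs used to train zephyr-7b-sft-full-SPIN-iter{iteration}.

    Iteration 0 uses only its own round of generations; every later
    iteration combines the previous round's data with the newest round.
    """
    if iteration < 0:
        raise ValueError("iteration must be non-negative")
    if iteration == 0:
        return [f"{HUB_PREFIX}0"]
    return [f"{HUB_PREFIX}{iteration - 1}", f"{HUB_PREFIX}{iteration}"]


if __name__ == "__main__":
    # Iterations 0-2 are the ones confirmed in this thread.
    for t in range(3):
        print(f"zephyr-7b-sft-full-SPIN-iter{t}:", " + ".join(training_datasets(t)))
```

With the Hugging Face `datasets` library, the returned IDs could be loaded with `load_dataset` and joined with `concatenate_datasets`. The split names to load are not stated in this thread.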

We always run generations on the same 50k prompts at each iteration. Since 100k prompts would simply contain each prompt twice, we just reuse the same 50k prompts.
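A toy illustration of why reusing the prompt set still yields 100k examples. Each round generates one response per prompt from a different checkpoint, so combining two rounds pairs every prompt with two distinct synthetic responses. The `generate` function is a hypothetical stand-in for model sampling, and the records are simplified to a `prompt` and a `generated` field. Neither detail reflects the actual SPIN data format.

```python
from collections import Counter


def generate(checkpoint: str, prompt: str) -> str:
    """Hypothetical stand-in for sampling one response from a checkpoint."""
    return f"[{checkpoint}] response to: {prompt}"


def generation_round(checkpoint: str, prompts: list[str]) -> list[dict]:
    """One round: a single synthetic response per prompt from `checkpoint`."""
    return [{"prompt": p, "generated": generate(checkpoint, p)} for p in prompts]


# The same fixed prompt set is reused in every round
# (50k prompts in the thread; three toy prompts here).
prompts = ["prompt A", "prompt B", "prompt C"]

previous_round = generation_round("older checkpoint", prompts)
current_round = generation_round("newer checkpoint", prompts)
combined = previous_round + current_round

# Every prompt appears exactly twice, each time with a response from a
# different checkpoint: 2 x 50k = 100k examples without any new prompts.
prompt_counts = Counter(example["prompt"] for example in combined)
print(len(combined), set(prompt_counts.values()))  # 6 {2}
```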

I hope the above clarifies the questions!

lewtun commented 9 months ago

Thank you @yihedeng9 this is very helpful!

We always run generations on the same 50k prompts at each iteration. Since 100k prompts would simply contain each prompt twice, we just reuse the same 50k prompts.

For my understanding, is the reason for this to ensure the model at iteration t doesn't drift too far away from the model at iteration t-1?

yihedeng9 commented 9 months ago

Yes, we observed in experiments that incorporating data from previous iterations helps stabilize model performance at later iterations. We consider it a form of regularization that, as you said, keeps the model from deviating significantly from its performance in the previous iteration.

lewtun commented 9 months ago

Thank you! Closing the issue since everything is now clear :)