hbin0701 / Self-Explore

[EMNLP Findings 2024 & ACL 2024 NLRSE Oral] Enhancing Mathematical Reasoning in Language Models with Fine-grained Rewards
https://arxiv.org/abs/2404.10346

Difficulty reproducing performance on GSM-8k #1

Closed: ars22 closed this issue 6 months ago

ars22 commented 6 months ago

Hello,

Thank you for making your code public!

I ran through the suggested steps exactly for the deepseek-math-7b model on GSM8K without modifying any hyperparameters. I observed SFT performance of 67%, RFT of 70.5%, and Self-Explore performance of 73.6%. As suggested in the paper, I trained the SFT model for 2 epochs, RFT for 1, and DPO for 3 epochs. My gpair file contains only about 3.5k examples, but the corresponding number in the paper's Appendix is much larger. I noticed some logic in the code that reduces the data size for RFT and Self-Explore data generation (e.g., this logic), and I imagine my file is smaller because of this.

Can you please help me correct any errors so that I can reproduce the 78% number reported in the paper? Thanks!

Amrith

hbin0701 commented 6 months ago

Hi! Thanks for the comment. I deeply apologize for the error in lines 81-83. We have quickly updated several parts of the code, including this one.

To explain the logic behind RFT generation: to generate in parallel, each GPU handles a portion of the questions in the dataset.

    # Each GPU process (id 0..3) takes a contiguous slice of GSM8K's 7,473 training questions.
    if task == "GSM8K":
        if id == 0:
            start, end = 0, 1870
        elif id == 1:
            start, end = 1870, 3740
        elif id == 2:
            start, end = 3740, 5610
        elif id == 3:
            start, end = 5610, 7473

Because we observed that the generation process occasionally crashed, we initially generated only 1,000 problems at a time, saved the results, and re-ran the process to complete the RFT generation.

    # Resume support: if a partial result file already exists, skip the
    # inputs/answers that were generated in a previous (possibly crashed) run.
    try:
        already_done = [json.loads(x) for x in open(result_file)]
    except:  # no result file yet (or an unreadable partial file): start from scratch
        already_done = []

    if len(already_done) != 0:
        inputs = inputs[len(already_done):]
        answers = answers[len(already_done):]

    if len(inputs) == 0 and len(answers) == 0:
        print("Already completed. Exiting.")
        return 

After refactoring the code, this part is no longer used because we run the whole portion at once. We resolved the crashing issue by increasing the swap_space of vLLM. (But if it occurs again, we might have to roll back to the previous setting.)
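For reference, increasing the swap space is just a constructor argument on the vLLM engine. A minimal sketch (the model path and swap_space value here are placeholders, not the repo's actual settings):

    # Minimal sketch: raise vLLM's CPU swap space so long generation runs are
    # less likely to crash under memory pressure. Values below are placeholders.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="deepseek-ai/deepseek-math-7b-base",  # or your SFT/RFT checkpoint
        swap_space=16,            # GiB of CPU swap per GPU (vLLM's default is 4)
        tensor_parallel_size=1,
    )
    params = SamplingParams(temperature=0.7, max_tokens=512, n=4)
    outputs = llm.generate(["Question: ..."], params)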

However, even after taking this into consideration, gpair should be much larger than that. Please make sure that your "..result_gen.jsonl" file has a length of 7,473. Then we expect gpair to be at least around 30-40k (a quick way to count the lines is sketched below). Please let us know if there are any other difficulties. Thanks, and sorry for the inconvenience!
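The paths in this sketch are placeholders; point them at your own output files:

    # Sanity check: count the examples in the RFT generation output and the gpair file.
    def count_jsonl(path):
        with open(path) as f:
            return sum(1 for line in f if line.strip())

    print(count_jsonl("result_gen.jsonl"))     # expect 7473 for GSM8K
    print(count_jsonl("samples_gpair.jsonl"))  # expect roughly 30-40k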

ars22 commented 6 months ago

Thank you for your quick reply. After pulling the new changes, I have an RFT file of size 52k. For the self-explore generation, there is similar logic in the code (here and here). I am assuming that these need to be removed as well?

Finally, please let me know if I am doing this correctly:

  1. Run SFT for 2 epochs.
  2. Generate RFT and DPO samples from the SFT checkpoint at the end of training.
  3. Run RFT on the pretrained model for 1 epoch using the RFT samples from the previous step (which do not include SFT data).
  4. Generate self-explore data using the DPO file from step 2 and the RFT checkpoint at the end of RFT training from step 3.
  5. Run DPO with beta=0.01 for 3 epochs using the self-explore data, with the initialization/reference model being the RFT checkpoint from the end of RFT training in step 3.

hbin0701 commented 6 months ago

Hello! For that logic, it is better to leave it as is, because exploration takes a long time to complete and it is safer to iterate N rounds of 1,000 problems each; this will take care of it, though you might have to change N (the number of iterations) accordingly. For instance, if your DPO file contains 40k lines, that is 1,000 problems x 4 GPUs x 10 rounds = 40k, so N = 10. A rough sketch of such a driver loop is below.
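Roughly something like this (the script name is a placeholder, not necessarily the repo's actual exploration script; the numbers just mirror the example above):

    # Rough sketch: re-run the exploration script N times so that
    # N rounds x 1,000 problems per GPU x 4 GPUs covers the whole DPO file.
    import math
    import subprocess

    dpo_lines = 40_000                            # size of your DPO file
    per_round = 1_000 * 4                         # 1,000 problems per GPU x 4 GPUs
    n_rounds = math.ceil(dpo_lines / per_round)   # -> 10

    for i in range(n_rounds):
        print(f"exploration round {i + 1}/{n_rounds}")
        subprocess.run(["python", "explore.py"], check=True)  # placeholder script name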

And indeed, your steps are correct. Please let me know if there are any other difficulties. Thanks :)

ars22 commented 6 months ago

Got it, thanks! Unfortunately, the performance after running DPO on the gpair file tanks: it drops to 60%, which is worse than the SFT model.

With the updated code for RFT generation, I generated RFT data and trained an RFT model. Then, I generated the following gpair file with about 44k entries: samples_gpair.jsonl.zip.

I used the following arguments to train DPO on this gpair file.


model_name_or_path: /mnt/shared/syth-checkpoints/deepseek_GSM8K_RFT_EP1 # put model_name_to_train_on_here
run_name: deepseek_GSM8K_SxpDPO  # put_wandb_run_name_here
dataset_mixer:
  HuggingFaceH4/ultrafeedback_binarized: 1.0
dataset_splits:
- train_prefs
- test_prefs
preprocessing_num_workers: 12
train_data_file: /mnt/shared/syth-checkpoints/deepseek_GSM8K_RFT_EP1/samples_gpair.jsonl # put the training data here (used for both train and test).
test_data_file: /mnt/shared/syth-checkpoints/deepseek_GSM8K_RFT_EP1/samples_gpair.jsonl # the test data (i.e. eval set) is dummy data (a subset of training), because we don't use eval data.
bf16: true
beta: 0.01
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 1
gradient_checkpointing: true
hub_model_id: zephyr-7b-dpo-full
learning_rate: 1.0e-6 # use 1.0e-7 for Mistral, 1.0e-6 for others.
log_level: info
logging_steps: 10
lr_scheduler_type: linear
max_length: 512
max_prompt_length: 384
num_train_epochs: 3
optim: rmsprop
output_dir: /mnt/shared/syth-checkpoints/deepseek_GSM8K_SxpDPO 
per_device_train_batch_size: 8
per_device_eval_batch_size: 8
push_to_hub: false
save_strategy: "epoch"
save_only_model: true
save_total_limit: 3
seed: 42
warmup_ratio: 0.1

hbin0701 commented 6 months ago

Hi, can you share the performance of the RFT model with us? Meanwhile, I'll be running the experiment from scratch again to check if there's any error. I'll let you know before Sunday. Sorry for the inconvenience.

ars22 commented 6 months ago

I observe Final Acc: 0.6929492039423806 for the RFT model trained for 1 epoch. Please find attached the RFT data and the eval file for RFT: eval.jsonl.zip, samples_rft.jsonl.zip

hbin0701 commented 6 months ago

Hm.. I see. For now, it seems the files are okay, as the RFT sample numbers also almost match. I'll run the experiment again and see if a similar phenomenon is observed.

ars22 commented 6 months ago

It would be great if you could run DPO on the gpair file I shared with you and confirm. Thanks a lot for your help; I really appreciate it. Also, does the config for the DPO training above look OK?

ars22 commented 6 months ago

If it helps, this is the file that the self-explore code outputs, from which the gpair file is generated: sxp_samples.jsonl.zip

ars22 commented 6 months ago

Also, when you run SFT for 2 epochs / RFT for 1 epoch, are you training for 5 epochs and taking the model at the first/second checkpoint respectively, or do you only train for that many epochs? I imagine this makes a difference because of the linear learning rate schedule.

This is the command I am using to train the RFT model:

model_name_or_path=deepseek-ai/deepseek-math-7b-base # model path to train on.
save_generator_id=deepseek_GSM8K_RFT_EP1 # model name to be saved.

save_dir=/mnt/shared/syth-checkpoints/${save_generator_id}/
export WANDB_NAME=${save_generator_id}    

# lr: 1e-6 for Mistral and 1e-5 for Others.

CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
  --config_file scripts/gsm8k/sft/config.yaml \
  --main_process_port=40993 \
  sft/train_generator.py \
  --model_name_or_path ${model_name_or_path} \
  --data_dir /mnt/shared/syth-checkpoints/deepseek_GSM8K_FT_EP2/samples_rft.jsonl \
  --target_set train \
  --save_dir ${save_dir} \
  --num_train_epoches 1 \
  --save_strategy epoch \
  --max_length 384 \
  --per_device_train_batch_size 8 \
  --per_device_eval_batch_size 2 \
  --gradient_accumulation_steps 2 \
  --gradient_checkpointing True \
  --learning_rate 1e-5 \
  --weight_decay 0 \
  --lr_scheduler_type "linear" \
  --warmup_steps 0 \
  --save_best False \
  --save_total_limit 5 \
  --logging_dir ./wandb \
  --logging_steps 8 \
  --seed 42 \
  --save_model_only True \
  --mode "rft_GSM8K" # mode is one of "ft_GSM8K", "ft_MATH", "rft_GSM8K", "rft_MATH"
hbin0701 commented 6 months ago

Yes, as far as I remember this made a performance difference for Mistral, but for other models I believe the difference was trivial. I did the latter for Mistral and the former for the other models. The RFT config looks fine, and so does the DPO config. Could you also share your RFT model on Hugging Face so that I can run DPO on that model with the gpair file you shared? Otherwise, I'll just run the gpair file on my own RFT model.

ars22 commented 6 months ago

Got it, thanks! I am uploading my RFT model here: https://huggingface.co/ars22/deepseek_RFT_model/. The upload speed is slow, but it should be ready in the next few hours. Meanwhile, if you could confirm results on your model, that would be very helpful too. Thanks again!

ars22 commented 6 months ago

Finally, if possible, would you be willing to share any of the following: rft, dpo, self_explore samples, and gpair files for the deepseek model you trained?

hbin0701 commented 6 months ago

Here it is :) deepseek_gsm8k.zip. I am able to run experiments tomorrow, so I will try to produce the results before Sunday!

ars22 commented 6 months ago

Thanks a lot @hbin0701

Using your RFT file, I trained a model for 5 epochs. The checkpoint at the end of the first epoch had a performance of 68.91, but the best-performing model was checkpoint 3 (71.41, which matches the result in the paper). Now, I took the first RFT checkpoint and ran DPO for 3 epochs using the above config and the gpair file that you shared with me.

The best checkpoint for DPO on gpair is the second, with a performance of 71.43. The performance of the first and third checkpoints is poor. Here is my wandb for the DPO run: https://api.wandb.ai/links/ars22/738uk11d

hbin0701 commented 6 months ago

Hi @ars22. Upon re-running the code, I have confirmed multiple issues.

In fact, at the time I ran the experiment (which was in February), 71.41 was achieved at the 1st-epoch RFT checkpoint, so it's kind of strange... (This is the model: https://huggingface.co/hbin0701/deepseek_brft). Since your best-performing model achieves a similar result, I believe there is an error in the current training config, or I used a different library version. I'll try to figure this out as soon as possible. Meanwhile, looking back at some code artifacts, I found that for GSM8K I used per_device_train_batch_size=16 and gradient_accumulation_steps=1, and also transformers==4.34.1. Changing these and applying my RFT file yields a 1st-checkpoint performance a little less than 71.

Besides that, for all DPO runs I used beta: 0.1; beta: 0.01 was a typo. I have modified these errors accordingly. I'll also be running additional experiments to find any additional errors. Again, sorry for the inconvenience, and thank you for pointing out the difficulties.
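For context, beta is the coefficient on the implicit reward margin in the standard DPO objective (with the RFT checkpoint as $\pi_{\mathrm{ref}}$ in this setup), so beta = 0.1 regularizes toward the reference model ten times more strongly than 0.01:

$$\mathcal{L}_{\mathrm{DPO}}(\theta) = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]$$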

ars22 commented 6 months ago

Hi @hbin0701

Really appreciate all your help with this, and for the super quick replies as well :)

As it turns out, yesterday I left some DPO runs running at different values of beta. I was able to recover 77.5% performance at beta=0.2, and beta=0.1 was about 77% as well.

Thanks for the information on the RFT checkpoint, and also for uploading the RFT model to Hugging Face (that was very helpful). I am currently using transformers==4.38.1, with per_device_batch_size=8, grad accumulation=2, and a total of 4 GPUs for SFT and RFT training. I have pulled in your changes now.

I am rerunning with these changes and hoping that I will be able to reproduce the results this time. It seems the biggest issue was the value of beta. I will let you know if I run into more issues, and if things turn out as expected, I will close this GitHub issue.

Thanks again for everything!

hbin0701 commented 6 months ago

Hi @ars22! Were you able to successfully replicate the results? Please let me know if there were additional issues. Thanks!

ars22 commented 6 months ago

Hello @hbin0701,

With $\beta=0.1$, I got much better results, close to what is reported in the paper. So, thanks a lot for the clarification, and apologies for not replying sooner. Closing the issue now, and once again thanks for the help and super prompt response.

One clarification: in the "generate self explore" step, where you generate the gpair file by first identifying the "first pit" in the "rejected" response from the DPO file, it seems the chosen response for gpair is the same as the "chosen" in the DPO file, while the "rejected" in gpair is the first pit. At least this is my understanding. Was there a particular reason to pick that value of "chosen"? Did you also consider sampling alternate correct responses from the "rejected" in the DPO file, maybe from a step before the first pit, and using that as the new "chosen" in the gpair file?

hbin0701 commented 6 months ago

Yes! We also tried various alternatives, including that one, but we did not get better results in that scenario.

Assume the rejected sample comprises 4 steps, s1 -> s2 -> s3 -> s4, and the first pit occurs at s3. Our hypothesis was that if we want to prevent the model from making logical mistakes from s3 onward, we should sample s1 -> s2 -> correct (s3') -> correct (s4'), just as you suggested. However, we did not observe a performance increase in that case; we believe the chosen sample should have been trained on at least once (via SFT) by the reference model to stabilize the DPO learning. Maybe if we had sampled such a correct instance and then run RFT again before training with the DPO objective (so the overall process would have been RFT -> Self-Explore to gather step-level labels -> RFT with the newly collected chosen -> DPO training), the result could have been different, but we did not go that far.
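To make the two pairings concrete, here is a rough illustration (the field names and step strings are made up for this example and do not match the repo's actual jsonl schema):

    # Hypothetical illustration of the two pairing strategies discussed above.
    dpo_sample = {
        "prompt": "Q: ...",
        "chosen": "c1 -> c2 -> c3 -> c4   (fully correct solution)",
        "rejected": "s1 -> s2 -> s3 -> s4  (first pit at s3)",
    }

    # What the gpair file uses, as discussed above: the DPO "chosen" (already seen
    # by the reference model during RFT) against the rejected response truncated
    # at its first pit.
    gpair_used = {
        "prompt": dpo_sample["prompt"],
        "chosen": dpo_sample["chosen"],
        "rejected": "s1 -> s2 -> s3  (first pit)",
    }

    # The alternative discussed: keep the shared prefix s1 -> s2 and sample a
    # corrected continuation s3' -> s4' as the new chosen; this did not help
    # without an extra RFT pass on the newly sampled chosen.
    gpair_alternative = {
        "prompt": dpo_sample["prompt"],
        "chosen": "s1 -> s2 -> s3' -> s4'  (resampled, correct)",
        "rejected": "s1 -> s2 -> s3  (first pit)",
    }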

Besides that, we've also tried tree-based sampling (where we had difficulty merging duplicate states), iterative training (where back then we observed a marginal increase, though it could be improved with DPO + NLL as suggested by Iterative Reasoning Preference Optimization, https://arxiv.org/abs/2404.19733), and dynamic step formulation (i.e., heuristic or probability-based step selection), but in none of these cases did we observe performance gains. I think there are concurrent works that utilize similar ideas with MCTS-based methods (https://arxiv.org/pdf/2404.12253).

If you want to know more about our experimental attempts or discuss this topic, feel free to email me at hbin0701@kaist.ac.kr. 😄 We could probably schedule a Zoom meeting to talk about this, and I could share some more of the insights we have gathered. Thanks!