sail-sg / sdft

[ACL 2024] The official codebase for the paper "Self-Distillation Bridges Distribution Gap in Language Model Fine-tuning".
https://arxiv.org/abs/2402.13669

Issue Replicating Paper Results with scripts/gsm8k/sdft.sh #2

Closed: pldlgb closed this issue 3 months ago

pldlgb commented 3 months ago

Hello,

I've been trying to replicate the results presented in your paper using the provided script, scripts/gsm8k/sdft.sh. Although I followed all the instructions and made sure my setup matches the recommended configuration, I can't reproduce the results reported in the paper. Here is what I get:

Fine-tuning using sdft

Evaluation on gsm8k:
Accuracy for math: 387 / 1319 = 29.34%

Evaluation on multiarith:
Accuracy for math: 146 / 180 = 81.11%

Evaluation on OpenFunctions:
Accuracy for openfunction: 25 / 112 = 22.32%

Could you please provide any insights or suggestions that might help in correctly replicating the results? Am I missing an update or a crucial step in the process?

Thank you for your assistance.

Best regards

rickyang1114 commented 3 months ago

Thanks for your interest!

As stated in the Acknowledge section of the README, "The main branch has undergone refactoring. To accurately replicate the results presented in the paper, switching to the reproduce branch is recommended."

In the reproduce branch, the versions of the modules used are pinned precisely, which may help in achieving the results reported in our paper.

rickyang1114 commented 3 months ago

Please also make sure to run the command pip install -e bigcode-evaluation-harness.

pldlgb commented 3 months ago

Thank you. Do you know what could have caused such a significant discrepancy, with GSM8K accuracy dropping from 34.4 to 29.3?

rickyang1114 commented 3 months ago

Thank you for bringing up this question. I will need to investigate this further before I can give a definitive answer.

pldlgb commented 3 months ago

Thank you, looking forward to your reply!

rickyang1114 commented 3 months ago

Upon investigation, I discovered the root cause: the original response was inadvertently omitted during the distillation process due to a bug. This issue has been addressed in the latest commit. I kindly suggest pulling the updated version and rerunning the experiment. You should expect results similar to those reported in the paper, although minor differences may still arise from variations in dependencies.

The results you obtained can be interpreted as an ablation study (SDFT with the original response omitted from distillation), illustrating how effectively the model learns the downstream task with SDFT in that setting.
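To make the bug concrete, here is a minimal, hypothetical Python sketch of how a self-distillation prompt might be assembled with and without the original response. The template text, `DISTILL_TEMPLATE`, and `build_distill_prompt` are invented for illustration and are not the sdft repository's code; this thread does not show how the actual prompt is built. With the original response dropped, the seed LM can only answer the instruction from scratch, which is why the earlier run behaves like an ablation.

```python
# Hypothetical sketch: DISTILL_TEMPLATE and build_distill_prompt are invented
# for illustration and do not reproduce the sdft repository's actual code.

DISTILL_TEMPLATE = (
    "Below is an instruction and a reference response.\n"
    "Rewrite the reference response in your own words.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Reference response:\n{response}\n\n"
    "### Rewritten response:\n"
)


def build_distill_prompt(instruction: str, response: str, include_response: bool = True) -> str:
    """Return the prompt the seed LM would complete to produce a distilled response."""
    if include_response:
        # Intended behaviour: the seed LM sees the original (reference) response
        # and rewrites it, so the distilled target stays anchored to it.
        return DISTILL_TEMPLATE.format(instruction=instruction, response=response)
    # Bug-like behaviour: the original response is dropped, so the seed LM
    # answers the instruction on its own and the target is just its own output.
    return f"### Instruction:\n{instruction}\n\n### Response:\n"


# Toy GSM8K-style example (made up for this sketch, not taken from the dataset).
example = {
    "instruction": "Tom has 3 apples and buys 5 more. How many apples does he have now?",
    "response": "Tom has 3 + 5 = 8 apples. The answer is 8.",
}

fixed_prompt = build_distill_prompt(example["instruction"], example["response"])
buggy_prompt = build_distill_prompt(
    example["instruction"], example["response"], include_response=False
)

print("original response in fixed prompt:", example["response"] in fixed_prompt)
print("original response in buggy prompt:", example["response"] in buggy_prompt)
```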

Thank you again for bringing up this question.

rickyang1114 commented 3 months ago

I reran the experiments using the dependencies specified in the main branch as of the most recent commit (7ec7ab2), and the results closely match those reported in the paper.

For scripts/test_seed_LM.sh, the results are:

Evaluation on seed LM.

Evaluation on gsm8k:
Accuracy for math: 373 / 1319 = 28.28%

Evaluation on multiarith:
Accuracy for math: 127 / 180 = 70.56%

Evaluation on OpenFunctions:
Accuracy for openfunction: 17 / 112 = 15.18%

Evaluation on HumanEval:
Accuracy for HumanEval: 14.02%

After vanilla FT (scripts/gsm8k/sft.sh), the results are:

Fine-tuning using sft

Evaluation on gsm8k:
Accuracy for math: 445 / 1319 = 33.74%

Evaluation on multiarith:
Accuracy for math: 149 / 180 = 82.78%

Evaluation on OpenFunctions:
Accuracy for openfunction: 19 / 112 = 16.96%

Evaluation on HumanEval:
Accuracy for HumanEval: 10.98%

After our SDFT (scripts/gsm8k/sdft.sh), the results are:

Fine-tuning using sdft

Evaluation on gsm8k:
Accuracy for math: 449 / 1319 = 34.04%

Evaluation on multiarith:
Accuracy for math: 140 / 180 = 77.78%

Evaluation on OpenFunctions:
Accuracy for openfunction: 26 / 112 = 23.21%

Evaluation on HumanEval:
Accuracy for HumanEval: 14.63%
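As a quick sanity check, the percentages above can be recomputed from the raw counts in the logs. The Python sketch below (the dictionary layout and the helper names `accuracy` and `table` are just for illustration) reproduces each accuracy and the gap between SDFT and vanilla FT. HumanEval is reported only as a percentage, so its values are entered directly.

```python
# (correct, total) per benchmark, copied from the logs in this thread.
RESULTS = {
    "seed LM": {"gsm8k": (373, 1319), "multiarith": (127, 180), "OpenFunctions": (17, 112)},
    "vanilla FT": {"gsm8k": (445, 1319), "multiarith": (149, 180), "OpenFunctions": (19, 112)},
    "SDFT": {"gsm8k": (449, 1319), "multiarith": (140, 180), "OpenFunctions": (26, 112)},
}
# HumanEval is only logged as a percentage.
HUMANEVAL = {"seed LM": 14.02, "vanilla FT": 10.98, "SDFT": 14.63}
MODELS = ("seed LM", "vanilla FT", "SDFT")


def accuracy(correct: int, total: int) -> float:
    """Accuracy in percent, rounded to two decimals as in the logs."""
    return round(100 * correct / total, 2)


def table():
    """Rows of (benchmark, seed, vanilla FT, SDFT, SDFT minus vanilla FT)."""
    rows = []
    for bench in ("gsm8k", "multiarith", "OpenFunctions"):
        seed, ft, sdft = (accuracy(*RESULTS[m][bench]) for m in MODELS)
        rows.append((bench, seed, ft, sdft, round(sdft - ft, 2)))
    seed, ft, sdft = (HUMANEVAL[m] for m in MODELS)
    rows.append(("HumanEval", seed, ft, sdft, round(sdft - ft, 2)))
    return rows


print(f"{'benchmark':<14}{'seed':>8}{'FT':>8}{'SDFT':>8}{'SDFT-FT':>9}")
for bench, seed, ft, sdft, delta in table():
    print(f"{bench:<14}{seed:>8.2f}{ft:>8.2f}{sdft:>8.2f}{delta:>+9.2f}")
```

On these numbers, SDFT is ahead of vanilla FT on gsm8k, OpenFunctions and HumanEval, and behind it on multiarith.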

These results closely match those presented in the paper, with deviations within 1%. I believe this minor deviation stems from differences in dependency versions, including LLaMA-Factory and modules such as transformers.