This repository contains the code and data for reproducing the results of "Scaling Relationship on Learning Mathematical Reasoning with Large Language Models" and "Query and Response Augmentation Cannot Help Out-of-domain Math Reasoning Generalization".
Setting | 7B | 7B-2 | 13B | 13B-2 | 33B | 65B | 70B-2 |
---|---|---|---|---|---|---|---|
ICL-8shot | 11.0/18.1 | 14.6/- | 17.8/29.3 | 28.7/- | 35.6/53.1 | 50.9/69.7 | 56.8/- |
SFT | 35.9/48.7 | 41.6/55.4 | 43.0/55.2 | 50.0/61.7 | 54.6/- | 59.3/- | 63.2/- |
RFT k=100 | 41.7/52.7 | 47.5/58.7 | 49.1/59.9 | 54.8/65.4 | 54.5/- | - | - |
RFT-U13B | 49.3/61.8 | 50.3/65.6 | 52.1/66.2 | 55.4/69.1 | 56.5/- | 59.0/- | 62.3/- |
RFT-U33B | 49.1/61.6 | 51.2/64.1 | 51.4/66.3 | 55.3/69.1 | 57.9/- | 59.7/- | 64.8/- |
Each cell reports maj1@1 / maj1@100.
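For readers unfamiliar with the metric, the following is a minimal sketch of how maj1@K can be computed: for each problem, take the most frequent answer among K sampled answers and compare it with the gold answer. The function names `majority_vote` and `maj1_at_k` are hypothetical and are not taken from eval.py; answer extraction and normalization are assumed to have happened already.

```python
from collections import Counter


def majority_vote(answers):
    """Return the most common answer among the samples, ignoring None entries."""
    valid = [a for a in answers if a is not None]
    if not valid:
        return None
    # Counter.most_common orders equal counts by first occurrence.
    return Counter(valid).most_common(1)[0][0]


def maj1_at_k(sampled_answers, gold_answers, k):
    """Fraction of problems whose majority answer over the first k samples is correct."""
    correct = 0
    for samples, gold in zip(sampled_answers, gold_answers):
        if majority_vote(samples[:k]) == gold:
            correct += 1
    return correct / len(gold_answers)
```

With K=1 this reduces to plain accuracy on a single decoded answer per problem.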
If you cannot reproduce our results, try Transformers <= 4.29 and test with a batch size of 1.
Use train_xb.sh for fine-tuning LLaMA and LLaMA-2.
bash train_xb.sh ./data/train_use.jsonl SAVE_PATH 3
LLaMA 7B / 13B
bash group_sample_7b_13b.sh SAVE_PATH
LLaMA 30B
bash group_sample_30b.sh SAVE_PATH
python collect_rejection_sampling.py
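As a rough illustration of the idea behind this collection step (not a reproduction of collect_rejection_sampling.py), the sketch below keeps only sampled responses whose final answer matches the gold answer and drops responses that repeat an already seen set of equations. It assumes GSM8K-style responses, where calculations are annotated as `<<...>>` and the final answer follows `####`; all function names are hypothetical.

```python
import re


def extract_final_answer(response):
    """Return the number after the last '####' marker, or None (GSM8K-style assumption)."""
    matches = re.findall(r"####\s*(-?[\d,\.]+)", response)
    if not matches:
        return None
    # Drop thousands separators and a trailing period.
    return matches[-1].replace(",", "").rstrip(".")


def extract_equations(response):
    """Collect the <<...>> calculator annotations as a de-duplication key (assumption)."""
    return tuple(re.findall(r"<<([^>]*)>>", response))


def collect_correct_paths(samples, gold):
    """Keep correct responses, one per distinct equation sequence."""
    kept, seen = [], set()
    for response in samples:
        if extract_final_answer(response) != gold:
            continue  # reject responses with a wrong or missing answer
        key = extract_equations(response)
        if key in seen:
            continue  # reject duplicates of an already kept reasoning path
        seen.add(key)
        kept.append(response)
    return kept
```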
For RFT using samples generated by LLaMA-7B/7B-2/13B/13B-2/33B with k=100.
bash train_xb.sh ./data/rft/llama_yb.jsonl SAVE_PATH 3
For RFT using U13B.
bash train_xb.sh ./data/rft/u13b.jsonl SAVE_PATH 3
For RFT using U33B.
bash train_xb.sh ./data/rft/u33b.jsonl SAVE_PATH 3
We use greedy decoding on the test set.
To evaluate 7B/13B models:
bash test_7b_13b.sh SAVE_PATH
To evaluate 30B models:
bash single_test_30b.sh SAVE_PATH 0 ./data/test_jsonl.sh
To evaluate 65B / 70B models:
bash single_test_65b.sh SAVE_PATH 0,1 ./data/test_jsonl.sh
Use eval.py to compute the scores; it also supports maj1@K.
Setting | 7B / 13B | 33B | 65B / 70B |
---|---|---|---|
SFT / RFT | 8 | 16 | 32 |
Minimal Inference | 1 | 1 | 2 |
Group Inference | 8 | 8 | 8 |
Setting | 7B | 7B2 | 13B | 13B2 | 33B |
---|---|---|---|---|---|
RFT k = 100 | OFA-Sys/gsm8k-rft-llama7b-sample100 | | | | |
RFT U13B | OFA-Sys/gsm8k-rft-llama7b-u13b | OFA-Sys/gsm8k-rft-llama7b2-u13b | OFA-Sys/gsm8k-rft-llama13b-u13b | OFA-Sys/gsm8k-rft-llama13b2-u13b | |
RFT U33B | | | | | OFA-Sys/gsm8k-rft-llama33b-u33b |
MuggleMATH is based on the LLaMA-2 models and is fully fine-tuned on the AugGSM8K and AugMATH datasets (https://github.com/OFA-Sys/gsm8k-ScRel/tree/main/data/MuggleMATH).
Prompting template:
'''
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
'''
We recommend using vllm to accelerate inference.
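As a minimal sketch of filling in the prompting template above, the snippet below builds the prompt string for a single question. The names `PROMPT_TEMPLATE` and `build_prompt` are hypothetical; the resulting string is what would be passed to whichever inference backend you use.

```python
# The template text is copied from the README; the names are illustrative.
PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)


def build_prompt(question: str) -> str:
    """Insert a math question into the instruction slot of the template."""
    return PROMPT_TEMPLATE.format(instruction=question.strip())


if __name__ == "__main__":
    print(build_prompt("Natalia has 3 apples and buys 2 more. How many apples does she have?"))
```

Stripping surrounding whitespace from the question is a choice made here for tidiness, not something the source specifies.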
Model | GSM8K | MATH |
---|---|---|
MuggleMATH-7B | 69.8 | 25.8 |
MuggleMATH-13B | 74.3 | 30.7 |
MuggleMATH-70B | 82.5 | 35.6 |
Model | Checkpoints |
---|---|
MuggleMATH-7B | https://huggingface.co/OFA-Sys/MuggleMath_7B |
MuggleMATH-13B | https://huggingface.co/OFA-Sys/MuggleMath_13B |
MuggleMATH-70B | https://huggingface.co/OFA-Sys/MuggleMath_70B |
@misc{yuan2023scaling,
title={Scaling Relationship on Learning Mathematical Reasoning with Large Language Models},
author={Zheng Yuan and Hongyi Yuan and Chengpeng Li and Guanting Dong and Keming Lu and Chuanqi Tan and Chang Zhou and Jingren Zhou},
year={2023},
eprint={2308.01825},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@article{li2023query,
title={Query and response augmentation cannot help out-of-domain math reasoning generalization},
author={Li, Chengpeng and Yuan, Zheng and Dong, Guanting and Lu, Keming and Wu, Jiancan and Tan, Chuanqi and Wang, Xiang and Zhou, Chang},
journal={arXiv preprint arXiv:2310.05506},
year={2023}
}