lee-ny / teaching_arithmetic

MIT License

Add plain3 method: \n as the start symbol for testing #3

Open: james016 opened this pull request 1 year ago

Introduction

This pull request introduces a new method, plain3 (selected with --data_format='plain3'). It uses a newline (\n) as the start symbol in the 3-digit addition baseline tests. The results below compare plain3 against the baseline plain method.
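Below is a minimal sketch of how the two prompt formats might differ. It rests on three assumptions:

* each test line has the form a+b=c;
* the evaluation prompt is the text up to and including =;
* the training text is the examples joined by newlines.

make_prompt is a hypothetical helper written for illustration, not the repository's code. One plausible motivation for the start symbol is that it marks where the first operand begins. The example checks this against a toy training text.

```python
def make_prompt(line: str, data_format: str) -> str:
    """Hypothetical helper: turn a test line 'a+b=c' into an evaluation prompt."""
    question, _, _ = line.partition("=")
    prompt = question + "="  # keep 'a+b=' and drop the answer
    if data_format == "plain":
        return prompt
    if data_format == "plain3":
        # Assumption: plain3 prepends a newline as the start symbol.
        return "\n" + prompt
    raise ValueError(f"unknown data_format: {data_format!r}")


# Assumed training layout: one example per line, joined by newlines.
train_text = "\n".join(["123+456=579", "12+7=19", "305+98=403"])

# Without a start symbol, the prompt for 23+456 also matches the tail of 123+456.
print(make_prompt("23+456=579", "plain") in train_text)   # True
# With the newline prefix, it only matches an example that starts with 23.
print(make_prompt("23+456=579", "plain3") in train_text)  # False
print(make_prompt("12+7=19", "plain3") in train_text)     # True
```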

Summary of Results

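As the appendix logs show, plain3 outperforms the baseline in test accuracy at every evaluation from iteration 250 onward. At iteration 5000, plain3 reaches 95.62% test accuracy (99.70% train accuracy), compared with 86.93% (87.19%) for plain. Its best checkpoint, at iteration 4750, reaches 95.84%. However, it still falls short of a 97% accuracy rate.

The two runs use the same command except for --data_format and --out_dir. They also log identical train and validation losses at every iteration. This suggests the gain comes from how the test prompts are formed rather than from any change in training.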

Conclusion

The plain3 method improves on the baseline plain method in test accuracy, by about 8.7 percentage points at iteration 5000. Further analysis is needed to explain the remaining gap to a 97% accuracy rate.

Your review and feedback on this pull request would be highly appreciated.

Appendix

plain3 results, produced with the following command (test_acc and train_acc are in percent):

python train.py config2/addition/plain/train_addition_bal.py \
    --ckpt_path_name="ckpt_10000.pt" \
    --out_dir="out/addition_plain3/$model_root" \
    --data_type='text' --data_format='plain3' \
    --dataset='bal' --train_data_path="train_3digit_10000.txt" \
    --eval_addition=True --start='FILE:data/bal/test_10000.txt' \
    --wandb_log=False \
    --exp_name="$model_root"

iter,train_loss,val_loss,val_ppl,test_acc,train_acc,test_acc_ar,test_acc_other
0,4.558202743530273,4.5570502281188965,,0.0,0.0,,
250,1.7185370922088623,1.708314061164856,,0.1111111111111111,0.18,,
500,1.301753044128418,1.2898467779159546,,46.939393939393945,46.77,,
750,1.1975243091583252,1.1963804960250854,,82.46464646464646,77.44,,
1000,1.1228467226028442,1.2068736553192139,,90.93939393939394,88.61,,
1250,0.6004476547241211,1.566043734550476,,93.22222222222221,93.45,,
1500,0.0927983745932579,2.7950005531311035,,93.73737373737374,96.22,,
1750,0.0727611780166626,3.4491138458251953,,93.61616161616162,96.66,,
2000,0.06619961559772491,3.7007741928100586,,94.44444444444444,97.53,,
2250,0.06289065629243851,3.888205051422119,,94.54545454545455,97.89,,
2500,0.06084860861301422,4.060454368591309,,94.56565656565657,98.07000000000001,,
2750,0.057863060384988785,4.177984714508057,,94.83838383838383,98.52,,
3000,0.055862732231616974,4.30571174621582,,95.2020202020202,98.76,,
3250,0.05437888577580452,4.407271385192871,,95.24242424242424,98.9,,
3500,0.053212303668260574,4.461126804351807,,95.77777777777777,98.96000000000001,,
3750,0.05198404937982559,4.569876670837402,,95.41414141414141,99.27,,
4000,0.051169995218515396,4.614325046539307,,95.60606060606061,99.38,,
4250,0.05025502294301987,4.683193683624268,,95.36363636363636,99.41,,
4500,0.04967493191361427,4.768005847930908,,95.67676767676768,99.53999999999999,,
4750,0.04910096526145935,4.813632488250732,,95.83838383838383,99.7,,
5000,0.048811234533786774,4.795279026031494,,95.61616161616162,99.7,,
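The log above can also be summarized programmatically. This is a minimal sketch, assuming the CSV layout shown: accuracies are in percent, and some columns are left empty. summarize_log is a hypothetical helper, not part of the repository. The example feeds it only the last three rows of the plain3 log.

```python
import csv
import io


def summarize_log(csv_text: str) -> dict:
    """Hypothetical helper: summarize a training log in the CSV layout above.

    Only iter, train_acc and test_acc are parsed; empty columns such as
    val_ppl are ignored. Accuracies are percentages.
    """
    rows = list(csv.DictReader(io.StringIO(csv_text.strip())))
    iters = [int(r["iter"]) for r in rows]
    test = [float(r["test_acc"]) for r in rows]
    train = [float(r["train_acc"]) for r in rows]
    best = max(range(len(rows)), key=lambda i: test[i])
    return {
        "best_iter": iters[best],
        "best_test_acc": test[best],
        "final_test_acc": test[-1],
        "final_train_acc": train[-1],
        # Positive gap: the model scores higher on training prompts than on test prompts.
        "final_train_test_gap": train[-1] - test[-1],
    }


# Last three rows of the plain3 log above.
plain3_tail = """
iter,train_loss,val_loss,val_ppl,test_acc,train_acc,test_acc_ar,test_acc_other
4500,0.04967493191361427,4.768005847930908,,95.67676767676768,99.53999999999999,,
4750,0.04910096526145935,4.813632488250732,,95.83838383838383,99.7,,
5000,0.048811234533786774,4.795279026031494,,95.61616161616162,99.7,,
"""

summary = summarize_log(plain3_tail)
print(summary["best_iter"], round(summary["best_test_acc"], 2))  # 4750 95.84
print(round(summary["final_train_test_gap"], 2))                # 4.08
```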

Baseline plain results, produced with the following command:

python train.py config2/addition/plain/train_addition_bal.py \
    --ckpt_path_name="ckpt_10000.pt" \
    --out_dir="out/addition_plain/$model_root" \
    --data_type='text' --data_format='plain' \
    --dataset='bal' --train_data_path="train_3digit_10000.txt" \
    --eval_addition=True --start='FILE:data/bal/test_10000.txt' \
    --wandb_log=False \
    --exp_name="$model_root"

iter,train_loss,val_loss,val_ppl,test_acc,train_acc,test_acc_ar,test_acc_other
0,4.558202743530273,4.5570502281188965,,0.0,0.0,,
250,1.7185370922088623,1.708314061164856,,0.050505050505050504,0.13999999999999999,,
500,1.301753044128418,1.2898467779159546,,40.07070707070707,40.739999999999995,,
750,1.1975243091583252,1.1963804960250854,,74.70707070707071,69.31,,
1000,1.1228467226028442,1.2068736553192139,,78.76767676767676,74.45,,
1250,0.6004476547241211,1.566043734550476,,82.37373737373737,79.46,,
1500,0.0927983745932579,2.7950005531311035,,83.53535353535354,81.95,,
1750,0.0727611780166626,3.4491138458251953,,83.07070707070707,82.28,,
2000,0.06619961559772491,3.7007741928100586,,84.58585858585859,83.48,,
2250,0.06289065629243851,3.888205051422119,,84.65656565656565,83.73,,
2500,0.06084860861301422,4.060454368591309,,85.14141414141415,83.74000000000001,,
2750,0.057863060384988785,4.177984714508057,,85.67676767676767,83.81,,
3000,0.055862732231616974,4.30571174621582,,85.97979797979798,85.45,,
3250,0.05437888577580452,4.407271385192871,,86.1919191919192,85.04,,
3500,0.053212303668260574,4.461126804351807,,86.64646464646465,85.35000000000001,,
3750,0.05198404937982559,4.569876670837402,,86.57575757575758,85.21,,
4000,0.051169995218515396,4.614325046539307,,86.60606060606061,86.05000000000001,,
4250,0.05025502294301987,4.683193683624268,,86.55555555555556,85.63,,
4500,0.04967493191361427,4.768005847930908,,86.91919191919192,86.28,,
4750,0.04910096526145935,4.813632488250732,,86.79797979797979,86.66,,
5000,0.048811234533786774,4.795279026031494,,86.92929292929293,87.19,,
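The two logs can be compared with a minimal sketch like the one below, again assuming the CSV layout shown. load and compare are hypothetical helpers. The example uses only the rows for iterations 1000, 3000 and 5000, copied from both logs. Losses are compared as the logged strings, on the assumption that both runs print them with the same formatting.

```python
import csv
import io

HEADER = "iter,train_loss,val_loss,val_ppl,test_acc,train_acc,test_acc_ar,test_acc_other"


def load(csv_text: str) -> dict:
    """Hypothetical helper: map iteration -> logged row (a dict of strings)."""
    reader = csv.DictReader(io.StringIO(csv_text.strip()))
    return {int(row["iter"]): row for row in reader}


def compare(plain3_text: str, plain_text: str):
    """Return (test-accuracy gain of plain3 per shared iteration,
    iterations whose logged train/val losses differ between the runs)."""
    a, b = load(plain3_text), load(plain_text)
    shared = sorted(a.keys() & b.keys())
    gains = {it: float(a[it]["test_acc"]) - float(b[it]["test_acc"]) for it in shared}
    loss_mismatch = [
        it for it in shared
        if (a[it]["train_loss"], a[it]["val_loss"]) != (b[it]["train_loss"], b[it]["val_loss"])
    ]
    return gains, loss_mismatch


plain3_rows = HEADER + """
1000,1.1228467226028442,1.2068736553192139,,90.93939393939394,88.61,,
3000,0.055862732231616974,4.30571174621582,,95.2020202020202,98.76,,
5000,0.048811234533786774,4.795279026031494,,95.61616161616162,99.7,,
"""
plain_rows = HEADER + """
1000,1.1228467226028442,1.2068736553192139,,78.76767676767676,74.45,,
3000,0.055862732231616974,4.30571174621582,,85.97979797979798,85.45,,
5000,0.048811234533786774,4.795279026031494,,86.92929292929293,87.19,,
"""

gains, mismatch = compare(plain3_rows, plain_rows)
print({it: round(g, 2) for it, g in gains.items()})  # {1000: 12.17, 3000: 9.22, 5000: 8.69}
print(mismatch)                                      # []
```

The mismatch list is also empty for the full logs above, because every logged train and validation loss matches between the two runs.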