openai / grade-school-math


How can I replicate these results with the GPT-3 API? #3

Closed: ofirpress closed this issue 2 years ago

ofirpress commented 2 years ago

Hi! I'm trying to replicate your results with the GPT-3 API. This is how I've preprocessed the train file:

{"prompt": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\n\n###\n\n", "completion": " 72"}

And here's the first line of the valid file:

{"prompt": "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\n\n###\n\n", "completion": " 18"}
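For context, a minimal sketch of this preprocessing, assuming the raw GSM8K files store each example as JSON with "question" and "answer" fields and the final answer following "####"; the file paths and helper name are placeholders:

```python
import json

def to_answer_only_jsonl(in_path, out_path):
    # Convert raw GSM8K examples into answer-only prompt/completion pairs:
    # prompt = question + "\n\n###\n\n", completion = " " + final answer.
    with open(in_path) as fin, open(out_path, "w") as fout:
        for line in fin:
            ex = json.loads(line)
            final = ex["answer"].split("####")[-1].strip()  # text after "####"
            fout.write(json.dumps({
                "prompt": ex["question"] + "\n\n###\n\n",
                "completion": " " + final,
            }) + "\n")

# Placeholder paths; the valid file here is assumed to come from the GSM8K test split.
to_answer_only_jsonl("grade_school_math/data/train.jsonl", "train.jsonl")
to_answer_only_jsonl("grade_school_math/data/test.jsonl", "valid.jsonl")
```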

I'm trying to replicate the result from the paper that doesn't use intermediate steps:

If we instead finetune a 6B model to directly output the final answer without any intermediate steps, performance drops drastically from 20.6% to 5.2%.

I run this command: openai api fine_tunes.create -t train.jsonl -v valid.jsonl -m curie --n_epochs 20

The training loss and accuracy look great, but the validation accuracy just goes crazy throughout training:

step,elapsed_tokens,elapsed_examples,training_loss,training_sequence_accuracy,training_token_accuracy,validation_loss,validation_sequence_accuracy,validation_token_accuracy
13488,10571584,107904,0.006697522398264077,1.0,1.0,0.2868101323364219,0.0,0.0
13567,10633656,108536,0.004670373035388367,1.0,1.0,0.40770788124678037,0.0,0.0
13646,10695280,109168,0.006482512512579658,1.0,1.0,0.23502722387660283,0.125,0.125
13722,10755664,109776,0.007619243074462266,1.0,1.0,0.2898588183223495,0.125,0.125
13800,10817728,110400,0.007864901396759683,1.0,1.0,0.3764918240812498,0.0,0.0
13878,10877360,111024,0.007001778227947788,1.0,1.0,0.34287615764819,0.0,0.0
13955,10938520,111640,0.005789751697015372,1.0,1.0,0.343928185654087,0.125,0.125
14033,11000392,112264,0.005374998155119428,1.0,1.0,0.2626242895447271,0.0,0.0
14110,11061552,112880,0.007522114686767156,1.0,1.0,0.303700584687286,0.0,0.0
14189,11122920,113512,0.007042250184718232,1.0,1.0,0.25505893241963656,0.25,0.25
14267,11184344,114136,0.005404598385505561,1.0,1.0,0.4671648049487169,0.0,0.0
14344,11245504,114752,0.00575376992666942,1.0,1.0,0.3507942277975161,0.125,0.125
14424,11306624,115392,0.00572749277984494,1.0,1.0,0.23907356193368778,0.25,0.25
14502,11368880,116016,0.006820225469565155,1.0,1.0,0.321503115529883,0.25,0.25
14579,11428760,116632,0.006588759870975103,1.0,1.0,0.33556188792293684,0.0,0.0
14657,11489160,117256,0.008451912472405503,1.0,1.0,0.26277955851476575,0.125,0.125
14735,11550648,117880,0.005669759222337177,1.0,1.0,0.29857896823688596,0.0,0.0
14813,11613672,118504,0.0038765392963223673,1.0,1.0,0.2643208104990911,0.0,0.0
14893,11674920,119144,0.005361676043338958,1.0,1.0,0.3359258272212552,0.125,0.125
14971,11735512,119768,0.005775294550614619,1.0,1.0,0.3953682584280405,0.0,0.1111111111111111
15049,11796040,120392,0.006509018494322124,1.0,1.0,0.38131386485963575,0.0,0.0
15128,11857408,121024,0.006104460320729444,1.0,1.0,0.36498191312040756,0.125,0.125
15205,11918056,121640,0.008038989705175963,1.0,1.0,0.34087099317847736,0.0,0.0
15282,11980816,122256,0.0049923646901917085,1.0,1.0,0.3232976954536787,0.125,0.125
15361,12042376,122888,0.004565948148458752,1.0,1.0,0.43752457097146086,0.0,0.0
15437,12103336,123496,0.008198144564371599,1.0,1.0,0.3357829815657358,0.125,0.125
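Averaging the batch-level validation columns over rows like these gives a rough overall estimate; a minimal sketch, assuming the log is saved as results.csv with the header shown above:

```python
import pandas as pd

# results.csv: the batch-level fine-tune log shown above, including the header row
df = pd.read_csv("results.csv")
print("mean validation_sequence_accuracy:", df["validation_sequence_accuracy"].mean())
print("mean validation_token_accuracy:", df["validation_token_accuracy"].mean())
```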

I'm also not sure how, in some iterations, validation_sequence_accuracy and validation_token_accuracy can disagree, even though all validation completions are of length 1 (I pruned all the prompts where the answer was longer).

I'd be grateful for any advice here. I've tried different learning rates, but that didn't seem to help.

Thanks!

kcobbe commented 2 years ago

At first glance these results don't look unreasonable. This slice of metrics is clearly from late in training when you're well past the point of overfitting. Presumably the train/validation losses are much closer during the first epoch.

Note that, from the API docs, "we periodically calculate metrics on batches of validation data during training time". Since this is done at the batch level, the high variance in these validation metrics is not surprising -- they don't look crazy to me. In fact, based on this batch data, it looks like the overall validation accuracy would come out to something close to the ~5% reported in the paper. I recommend evaluating the accuracy across the whole validation set to see if this is the case.
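A minimal sketch of such a full-validation-set evaluation, using the legacy (pre-1.0) openai-python Completion API that the fine-tunes workflow relied on at the time; the fine-tuned model name is a placeholder and greedy decoding is assumed:

```python
import json
import openai  # legacy (<1.0) openai-python client

FINE_TUNED_MODEL = "curie:ft-..."  # placeholder: your fine-tuned model name

def exact_match_accuracy(valid_path):
    correct, total = 0, 0
    with open(valid_path) as f:
        for line in f:
            ex = json.loads(line)
            resp = openai.Completion.create(
                model=FINE_TUNED_MODEL,
                prompt=ex["prompt"],
                max_tokens=8,
                temperature=0,  # greedy decoding
            )
            text = resp["choices"][0]["text"].strip()
            # Without a trained stop sequence the model may keep generating,
            # so compare only the first whitespace-separated token.
            pred = text.split()[0] if text else ""
            correct += int(pred == ex["completion"].strip())
            total += 1
    return correct / total

print(exact_match_accuracy("valid.jsonl"))
```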

I'm not sure what would cause the occasional discrepancy between validation sequence/token accuracy. You're correct that if you pruned to single token completions, they should agree.

ofirpress commented 2 years ago

Thanks!