mlfoundations / open_lm

A repository for research on medium-sized language models.
MIT License

Figure out why AdamW + gradient accumulation leads to different results for test case #126

Open achalddave opened 11 months ago

achalddave commented 11 months ago

In https://github.com/mlfoundations/open_lm/pull/125, we had to switch our gradient accumulation tests from SGD to AdamW to make them pass. It's unclear why this is the case; anecdotally, when training models with AdamW, training curves look similar with and without gradient accumulation. This could be a numerical issue, or some specific issue with AdamW that makes gradient accumulation behave differently.
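For concreteness, here is a minimal sketch of this kind of equivalence check (toy model and illustrative settings, not the actual test from #125): take one optimizer step on a full batch and on the same batch split into micro-batches, then compare the resulting parameters.

```python
# Minimal sketch of a grad-accum equivalence check (toy model, illustrative
# settings; not the actual test from #125). One optimizer step on a full batch
# vs. the same batch split into micro-batches, then compare parameters.
import copy
import torch

def step_full_batch(model, optimizer, x, y, loss_fn):
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    optimizer.step()

def step_accumulated(model, optimizer, x, y, loss_fn, accum_freq):
    optimizer.zero_grad()
    for xb, yb in zip(x.chunk(accum_freq), y.chunk(accum_freq)):
        # Scale each micro-batch loss so the summed gradients match the full-batch mean.
        (loss_fn(model(xb), yb) / accum_freq).backward()
    optimizer.step()

torch.manual_seed(0)
model_a = torch.nn.Linear(16, 4)
model_b = copy.deepcopy(model_a)
opt_a = torch.optim.AdamW(model_a.parameters(), lr=1e-3)  # swap in SGD to compare
opt_b = torch.optim.AdamW(model_b.parameters(), lr=1e-3)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss_fn = torch.nn.MSELoss()

step_full_batch(model_a, opt_a, x, y, loss_fn)
step_accumulated(model_b, opt_b, x, y, loss_fn, accum_freq=2)

for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
    print(torch.allclose(p_a, p_b, atol=1e-7))
```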

sedrick-keh-tri commented 11 months ago

The issue might be with the forward pass.

I created this test to compare the forward pass on one batch vs. multiple batches, and it seems there are some numerical discrepancies there: https://github.com/sedrick-keh-tri/open_lm/commit/f0aa4e72bf038efc2dd793ca29b8a5bd7727f83f

Update: this forward test passes in FP32 at threshold=1e-5 but fails from 1e-6 onwards.
Update 2: the same pattern holds for FP16, i.e. it works at 1e-5 but not from 1e-6 onwards.
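For reference, the comparison is roughly of the following shape (a toy stand-in model below, not the exact test from that commit): run the same inputs through the model as one batch and as micro-batches, then check the outputs at several tolerances. Note that on CPU in FP32 the two paths may match bitwise; discrepancies typically show up with reduced precision or different GPU kernels.

```python
# Toy stand-in for the forward-only comparison: same inputs as one batch and
# as micro-batches, compared at several tolerances. Depending on hardware and
# dtype, the two paths may or may not match bitwise.
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(32, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 32),
).eval()

x = torch.randn(16, 32)

with torch.no_grad():
    out_full = model(x)
    out_chunks = torch.cat([model(xb) for xb in x.chunk(4)], dim=0)

for atol in (1e-5, 1e-6, 1e-7):
    status = "pass" if torch.allclose(out_full, out_chunks, atol=atol) else "fail"
    print(f"atol={atol:.0e}: {status}")
```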

achalddave commented 11 months ago

hm, it's strange then that the accumulation tests pass with SGD at tol=1e-7 if the forward pass is failing checks at threshold=1e-6...

sedrick-keh-tri commented 11 months ago

I am very confused at this point. Seems like the backward is generally more precise than the forward... is this something that's even possible?

I tried taking your exact branch and just editing train_one_epoch to return the forward outputs. The forward output check fails at precision 1e-7, but the backward tests pass fine at 1e-7.

The trend of backward being more precise than forward seems to hold for some reason. For bf16, I tried using a single layer and turning off the norm (this seems to affect the precision somehow). With 1 layer, the forward threshold is 1e-2 but the backward threshold is 1e-4.
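A toy version of the backward-side comparison (not the open_lm model or test, just the shape of the check): accumulate scaled micro-batch gradients and compare them to the full-batch gradients at several tolerances, mirroring the forward check above.

```python
# Toy version of the backward-side comparison (not the open_lm model/test):
# accumulate scaled micro-batch gradients and compare to full-batch gradients.
import copy
import torch

torch.manual_seed(0)
model_full = torch.nn.Linear(32, 32)
model_accum = copy.deepcopy(model_full)
loss_fn = torch.nn.MSELoss()

x, y = torch.randn(16, 32), torch.randn(16, 32)

# Full-batch gradients.
loss_fn(model_full(x), y).backward()

# Gradients accumulated over 4 micro-batches (losses scaled to match the full-batch mean).
for xb, yb in zip(x.chunk(4), y.chunk(4)):
    (loss_fn(model_accum(xb), yb) / 4).backward()

for atol in (1e-4, 1e-6, 1e-7):
    ok = all(
        torch.allclose(pf.grad, pa.grad, atol=atol)
        for pf, pa in zip(model_full.parameters(), model_accum.parameters())
    )
    print(f"atol={atol:.0e}: {'pass' if ok else 'fail'}")
```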

mitchellnw commented 11 months ago

A useful test here could also just be a short training run with and without grad accum, set up so that we'd expect the curves to be identical. If the model with grad accum is clearly worse, then we know something is wrong. If they achieve very similar loss, that supports the differences being numerical noise.

sedrick-keh-tri commented 10 months ago

Ran some 160M experiments here: Wandb report

Loss curves for accum_freq=1, 2, 4 look basically identical. [Loss curve plot from the Wandb report.]

Validation perplexities are as follows:
- accum=1 evaluation perplexity: 13.26832732791897
- accum=2 evaluation perplexity: 13.302758042158114
- accum=4 evaluation perplexity: 13.271497238202906

For reference, the evaluation perplexity of the other accum=4 experiment is 13.414325512942533, so there seems to be a fair bit of variation even between runs with identical settings.
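For scale, assuming evaluation perplexity is computed as exp of the mean loss, the spread between these runs corresponds to a loss difference of only about 0.01 nats:

```python
# Perplexity spread expressed as a loss gap, assuming perplexity = exp(mean loss).
import math
print(math.log(13.414325512942533) - math.log(13.26832732791897))  # ~0.011 nats
```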

achalddave commented 10 months ago

This is great, thanks @sedrick-keh-tri! I feel comfortable with our grad accum implementation for now, then. Let's leave this issue open in case anyone wants to dig into why the test fails, but continue using grad accum as implemented.