NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[BUG] Loss difference when training with FP8 vs. BF16 MoE #1152

Open viclzhu opened 1 week ago

viclzhu commented 1 week ago

Describe the bug
When enabling FP8 mixed precision while training a Mixtral model (SequentialMLP expert layer), we are observing that the training and validation loss differs from the BF16 run more than expected.

To Reproduce
Start with examples/mixtral/train_mixtral_8x7b_distributed.sh.

Using tokenizer.model from https://huggingface.co/mistralai/Mixtral-8x7B-v0.1.
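
The only intended difference between the two runs is the FP8 configuration. A minimal sketch of that delta, assuming the standard Transformer Engine FP8 arguments exposed by Megatron-LM (the values below are illustrative, not our exact config):

```bash
# BF16 baseline: the example script as-is (it already passes --bf16).
# FP8 run: append FP8 flags on top of --bf16. Flag names follow Megatron-LM's
# arguments; the values here are illustrative rather than our full config.
FP8_ARGS=(
    --fp8-format hybrid           # E4M3 for forward tensors, E5M2 for gradients
    --fp8-amax-history-len 1024   # amax history window used for scaling factors
    --fp8-amax-compute-algo max   # reduce amax over that window
)
# ...then add "${FP8_ARGS[@]}" to the existing pretrain_gpt.py launch line.
```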

Expected behavior
Training and validation loss across BF16 and FP8 MoE should be approximately the same.

Stack trace/logs

# BF16
3:  [2024-09-20 19:40:51] iteration        1/     100 | consumed samples:          256 | elapsed time per iteration (ms): 46555.7 | throughput per GPU (TFLOP/s/GPU): 58.4 | learning rate: 2.000000E-07 | global batch size:   256 | lm loss: 1.037833E+01 | loss scale: 1.0 | grad norm: 1.664 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 19:41:04] iteration        2/     100 | consumed samples:          512 | elapsed time per iteration (ms): 13477.5 | throughput per GPU (TFLOP/s/GPU): 201.6 | learning rate: 4.000000E-07 | global batch size:   256 | lm loss: 1.037832E+01 | loss scale: 1.0 | grad norm: 1.712 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 19:41:18] iteration        3/     100 | consumed samples:          768 | elapsed time per iteration (ms): 13434
...
3:  [2024-09-20 20:04:41] iteration       98/     100 | consumed samples:        25088 | elapsed time per iteration (ms): 15304.1 | throughput per GPU (TFLOP/s/GPU): 177.6 | learning rate: 1.960000E-05 | global batch size:   256 | lm loss: 7.280738E+00 | loss scale: 1.0 | grad norm: 0.965 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 20:04:57] iteration       99/     100 | consumed samples:        25344 | elapsed time per iteration (ms): 15364.6 | throughput per GPU (TFLOP/s/GPU): 176.9 | learning rate: 1.980000E-05 | global batch size:   256 | lm loss: 7.251996E+00 | loss scale: 1.0 | grad norm: 0.690 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 20:05:12] iteration      100/     100 | consumed samples:        25600 | elapsed time per iteration (ms): 15208.5 | throughput per GPU (TFLOP/s/GPU): 178.7 | learning rate: 2.000000E-05 | global batch size:   256 | lm loss: 7.260425E+00 | loss scale: 1.0 | grad norm: 0.632 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  validation loss at iteration 100 on validation set | lm loss value: 7.251174E+00 | lm loss PPL: 1.409760E+03 |
3:  validation loss at iteration 100 on test set | lm loss value: 7.255818E+00 | lm loss PPL: 1.416322E+03 |
# FP8
3:  [2024-09-20 19:41:08] iteration        1/     100 | consumed samples:          256 | elapsed time per iteration (ms): 62104.1 | throughput per GPU (TFLOP/s/GPU): 43.8 | learning rate: 2.000000E-07 | global batch size:   256 | lm loss: 1.037847E+01 | loss scale: 1.0 | grad norm: 0.534 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 19:41:21] iteration        2/     100 | consumed samples:          512 | elapsed time per iteration (ms): 13276.7 | throughput per GPU (TFLOP/s/GPU): 204.7 | learning rate: 4.000000E-07 | global batch size:   256 | lm loss: 1.037833E+01 | loss scale: 1.0 | grad norm: 0.571 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 19:41:34] iteration        3/     100 | consumed samples:          768 | elapsed time per iteration (ms): 13319.2 | throughput per GPU (TFLOP/s/GPU): 204.0 | learning rate: 6.000000E-07 | global batch size:   256 | lm loss: 1.037832E+01 | loss scale: 1.0 | grad norm: 0.568 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
...
3:  [2024-09-20 20:03:04] iteration       98/     100 | consumed samples:        25088 | elapsed time per iteration (ms): 13616.6 | throughput per GPU (TFLOP/s/GPU): 199.6 | learning rate: 1.960000E-05 | global batch size:   256 | lm loss: 7.739647E+00 | loss scale: 1.0 | grad norm: 4.661 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 20:03:18] iteration       99/     100 | consumed samples:        25344 | elapsed time per iteration (ms): 13650.1 | throughput per GPU (TFLOP/s/GPU): 199.1 | learning rate: 1.980000E-05 | global batch size:   256 | lm loss: 7.716366E+00 | loss scale: 1.0 | grad norm: 4.697 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 20:03:32] iteration      100/     100 | consumed samples:        25600 | elapsed time per iteration (ms): 13625.0 | throughput per GPU (TFLOP/s/GPU): 199.5 | learning rate: 2.000000E-05 | global batch size:   256 | lm loss: 7.721111E+00 | loss scale: 1.0 | grad norm: 4.632 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  validation loss at iteration 100 on validation set | lm loss value: 7.706149E+00 | lm loss PPL: 2.221969E+03 |
3:  validation loss at iteration 100 on test set | lm loss value: 7.708397E+00 | lm loss PPL: 2.226969E+03 |

moe_megatron_bf16_22455.log moe_megatron_fp8_22454.log
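
For convenience, the tails of the two attached logs can be compared directly; a quick sketch, assuming the log format shown above:

```bash
# Extract the per-iteration "lm loss" values from both attached logs and
# compare the last few iterations (assumes the "lm loss: <value> |" format above).
for f in moe_megatron_bf16_22455.log moe_megatron_fp8_22454.log; do
    echo "== $f =="
    grep -o 'lm loss: [0-9.E+-]*' "$f" | awk '{print $3}' | tail -n 3
done
```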

Environment (please complete the following information):

Additional context
We also experimented with enabling FP8 with the TEGroupedMLP module (padding inputs for FP8), and see some loss differences there as well.

lumosity4tpj commented 1 week ago

I also encountered this problem, but with a 1B dense model, and by step 200 the difference had reached 0.7.

lhb8125 commented 1 week ago

Try disabling the recomputation.
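
Concretely, that means not passing the activation recompute flags; a sketch in terms of the script (flag names as in Megatron-LM, the values your script uses may differ):

```bash
# With recomputation (roughly what the example script passes; values may differ).
# RECOMPUTE_ARGS here is just an illustrative variable appended to the launch line.
RECOMPUTE_ARGS="--recompute-granularity full --recompute-method uniform --recompute-num-layers 1"

# Without recomputation: do not pass any --recompute-* flags at all.
RECOMPUTE_ARGS=""
```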

lumosity4tpj commented 1 week ago

> Try disabling the recomputation.

Yes, I found the same. Disabling recomputation was effective for me on TE 1.10, similar to this issue.

viclzhu commented 6 days ago

Thanks for the responses!

I re-ran with recomputation disabled and also reduced num_layers from 32 to 16 (due to memory constraints), and I still observe a loss difference (though it is much smaller!).

Is this level of loss difference expected for BF16 vs. FP8? It appears to be around 2e-2 for the steps I've run.

Changes:

recompute_granularity: None # set by not passing --recompute_granularity
--num-layers 16
# BF16
3:  [2024-09-24 21:17:13] iteration        1/     100 | consumed samples:          256 | elapsed time per iteration (ms): 38421.2 | throughput per GPU (TFLOP/s/GPU): 35.7 | learning rate: 2.000000E-07 | global batch size:   256 | lm loss: 1.037923E+01 | loss scale: 1.0 | grad norm: 1.192 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-24 21:17:18] iteration        2/     100 | consumed samples:          512 | elapsed time per iteration (ms): 5103.0 | throughput per GPU (TFLOP/s/GPU): 268.8 | learning rate: 4.000000E-07 | global batch size:   256 | lm loss: 1.037914E+01 | loss scale: 1.0 | grad norm: 1.224 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-24 21:17:23] iteration        3/     100 | consumed samples:          768 | elapsed time per iteration (ms): 5034.2 | throughput per GPU (TFLOP/s/GPU): 272.5 | learning rate: 6.000000E-07 | global batch size:   256 | lm loss: 1.037856E+01 | loss scale: 1.0 | grad norm: 1.224 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |

...
3:  [2024-09-24 21:26:49] iteration       98/     100 | consumed samples:        25088 | elapsed time per iteration (ms): 5645.3 | throughput per GPU (TFLOP/s/GPU): 243.0 | learning rate: 1.960000E-05 | global batch size:   256 | lm loss: 7.265516E+00 | loss scale: 1.0 | grad norm: 1.910 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-24 21:26:55] iteration       99/     100 | consumed samples:        25344 | elapsed time per iteration (ms): 5967.9 | throughput per GPU (TFLOP/s/GPU): 229.9 | learning rate: 1.980000E-05 | global batch size:   256 | lm loss: 7.227739E+00 | loss scale: 1.0 | grad norm: 0.961 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-24 21:27:01] iteration      100/     100 | consumed samples:        25600 | elapsed time per iteration (ms): 5498.0 | throughput per GPU (TFLOP/s/GPU): 249.5 | learning rate: 2.000000E-05 | global batch size:   256 | lm loss: 7.227658E+00 | loss scale: 1.0 | grad norm: 0.812 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |

3:  validation loss at iteration 100 on validation set | lm loss value: 7.274964E+00 | lm loss PPL: 1.443699E+03 | 
3:  validation loss at iteration 100 on test set | lm loss value: 7.276310E+00 | lm loss PPL: 1.445644E+03 | 

# FP8
3:  [2024-09-24 20:41:24] iteration        1/     100 | consumed samples:          256 | elapsed time per iteration (ms): 52856.1 | throughput per GPU (TFLOP/s/GPU): 26.0 | learning rate: 2.000000E-07 | global batch size:   256 | lm loss: 1.037921E+01 | loss scale: 1.0 | grad norm: 0.940 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-24 20:41:29] iteration        2/     100 | consumed samples:          512 | elapsed time per iteration (ms): 4894.2 | throughput per GPU (TFLOP/s/GPU): 280.3 | learning rate: 4.000000E-07 | global batch size:   256 | lm loss: 1.037913E+01 | loss scale: 1.0 | grad norm: 1.203 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-24 20:41:33] iteration        3/     100 | consumed samples:          768 | elapsed time per iteration (ms): 4860.1 | throughput per GPU (TFLOP/s/GPU): 282.3 | learning rate: 6.000000E-07 | global batch size:   256 | lm loss: 1.037867E+01 | loss scale: 1.0 | grad norm: 1.216 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
...
3:  [2024-09-24 20:50:27] iteration       98/     100 | consumed samples:        25088 | elapsed time per iteration (ms): 5348.4 | throughput per GPU (TFLOP/s/GPU): 256.5 | learning rate: 1.960000E-05 | global batch size:   256 | lm loss: 7.236306E+00 | loss scale: 1.0 | grad norm: 0.881 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-24 20:50:32] iteration       99/     100 | consumed samples:        25344 | elapsed time per iteration (ms): 5404.4 | throughput per GPU (TFLOP/s/GPU): 253.8 | learning rate: 1.980000E-05 | global batch size:   256 | lm loss: 7.206497E+00 | loss scale: 1.0 | grad norm: 1.128 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-24 20:50:38] iteration      100/     100 | consumed samples:        25600 | elapsed time per iteration (ms): 5355.2 | throughput per GPU (TFLOP/s/GPU): 256.2 | learning rate: 2.000000E-05 | global batch size:   256 | lm loss: 7.221344E+00 | loss scale: 1.0 | grad norm: 1.499 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |

3:  validation loss at iteration 100 on validation set | lm loss value: 7.259491E+00 | lm loss PPL: 1.421533E+03 | 
3:  validation loss at iteration 100 on test set | lm loss value: 7.258996E+00 | lm loss PPL: 1.420830E+03 | 

moe_megatron_bf16_no_recompute_l16_22863.log moe_megatron_fp8_no_recompute_l16_22836.log

lhb8125 commented 6 days ago

I am not sure whether disabling recomputation is actually what helped in your case; you could double-check by re-enabling recomputation while keeping 16 layers. I think the loss diff is acceptable at an early training stage; we expect it to decrease to <1e-2 after several billion tokens.
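
For a rough sense of scale (the 4096 sequence length below is an assumption, not something stated in this thread; adjust for the actual --seq-length used):

```bash
# Back-of-the-envelope token budget at global batch size 256:
echo $((256 * 4096))                # ≈ 1.05M tokens per iteration
echo $((100 * 256 * 4096))          # ≈ 105M tokens after the 100 iterations above
echo $((2 * 10**9 / (256 * 4096)))  # ≈ 1900 iterations to reach ~2B tokens
```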

viclzhu commented 5 days ago

I see, sounds good, will try it!

I also ran the BF16/FP8 no-recompute jobs for a bit longer and observe the following:

# FP8
iteration     1000/   20000 | consumed samples:       256000 | elapsed time per iteration (ms): 5484.6 | throughput per GPU (TFLOP/s/GPU): 250.1 | learning rate: 9.999946E-05 | global batch size:   256 | lm loss: 3.147666E+00 | loss scale: 1.0 | grad norm: 0.514 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |

validation loss at iteration 200 | lm loss value: 6.496762E+00 | lm loss PPL: 6.629913E+02 |
validation loss at iteration 400 | lm loss value: 5.496616E+00 | lm loss PPL: 2.438654E+02 |
validation loss at iteration 600 | lm loss value: 4.583405E+00 | lm loss PPL: 9.784695E+01 |
validation loss at iteration 800 | lm loss value: 3.997547E+00 | lm loss PPL: 5.446438E+01 |
validation loss at iteration 1000 | lm loss value: 3.846950E+00 | lm loss PPL: 4.684997E+01 |

# BF16
iteration     1000/   20000 | consumed samples:       256000 | elapsed time per iteration (ms): 5770.8 | throughput per GPU (TFLOP/s/GPU): 237.7 | learning rate: 9.999946E-05 | global batch size:   256 | lm loss: 3.054385E+00 | loss scale: 1.0 | grad norm: 0.968 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |

validation loss at iteration 200 | lm loss value: 6.399312E+00 | lm loss PPL: 6.014314E+02 |
validation loss at iteration 400 | lm loss value: 5.297484E+00 | lm loss PPL: 1.998334E+02 |
validation loss at iteration 600 | lm loss value: 4.482264E+00 | lm loss PPL: 8.843467E+01 |
validation loss at iteration 800 | lm loss value: 3.942853E+00 | lm loss PPL: 5.156553E+01 |
validation loss at iteration 1000 | lm loss value: 3.890827E+00 | lm loss PPL: 4.895137E+01 |

My understanding, then, is that this difference in training and validation loss is acceptable given the low number of tokens processed so far, and that with longer training the two curves are expected to converge?