databricks / megablocks


[BUG] Optimizer Weights Not Reloaded When Training with bf16 Pretrained Weights #80

Open RookieHong opened 6 months ago

RookieHong commented 6 months ago

While working with the load_checkpoint function in the file third_party/Megatron-LM/megatron/checkpointing.py, I noticed that the condition on line 585:

if args.fp16 and optimizer is not None:

should be modified to:

if (args.fp16 or args.bf16) and optimizer is not None:

Without this change, when training with bf16 and loading pretrained model weights to continue training, the optimizer's master weights are never refreshed from the pretrained weights. As a result, training is effectively indistinguishable from training from scratch.
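To make the failure mode concrete, here is a minimal, self-contained sketch (plain PyTorch, not Megatron code; names like `model_param` and `master_param` are hypothetical) of what happens when a mixed-precision optimizer holds fp32 master copies of the parameters and a checkpoint is loaded into the bf16 model only:

```python
import torch

# Toy illustration of the failure mode (plain PyTorch, not Megatron code;
# names like `model_param` / `master_param` are hypothetical).

# bf16 model parameter, as in bf16 training.
model_param = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))

# fp32 "master" copy that the mixed-precision optimizer actually updates.
master_param = torch.nn.Parameter(model_param.detach().float().clone())
opt = torch.optim.SGD([master_param], lr=0.1)

# "Load" pretrained weights into the model only, which is what happens when
# the fp16-only guard skips the optimizer refresh for bf16 runs.
with torch.no_grad():
    model_param.copy_(torch.full((4,), 2.0, dtype=torch.bfloat16))

print(master_param.data)  # tensor([0., 0., 0., 0.]) -- still the stale weights

# On the next step, the stale masters are updated and then cast back into the
# model, silently discarding the pretrained weights that were just loaded.
master_param.grad = torch.zeros_like(master_param)
opt.step()
with torch.no_grad():
    model_param.copy_(master_param.to(torch.bfloat16))

print(model_param.data)   # ~0 everywhere: effectively training from scratch

# The fix corresponds to refreshing the master copies right after loading:
#     with torch.no_grad():
#         master_param.copy_(model_param.float())
```

The proposed condition change makes load_checkpoint take the same branch for bf16 as for fp16, so that (presumably via the guarded call, which copies the freshly loaded model weights into the optimizer's fp32 master parameters) the pretrained weights survive the first optimizer step.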

I am aware that the proper course of action would be to submit a Pull Request to the Megatron-LM fork referenced here. However, I wanted to raise this issue in this repository as well to alert others to the problem in the codebase.

tgale96 commented 6 months ago

Hi! Sorry for the delay!

Wow! Good catch and thanks for posting here. We'd welcome a PR on the Megatron-LM fork to fix this!