bigcode-project / Megatron-LM

Ongoing research training transformer models at scale

fix saving and loading of old checkpoints #67

Closed mayank31398 closed 1 year ago

mayank31398 commented 1 year ago

The current rebase against NVIDIA has broken saving and loading of old checkpoints. This PR fixes that.

RaymondLi0 commented 1 year ago

Hi @mayank31398! Thanks for creating this. I think it'd be best to keep a single way of saving the optimizer, so let's not re-introduce the previous way of saving the distributed optimizer.

As for loading old checkpoints that used the distributed optimizer, that can indeed be useful. However, I think one probably wouldn't need to load the optimizer state (if continuing training, that would most likely be on a different dataset). In that case a previous checkpoint can be loaded easily by using --no_load_optim or --finetune and renaming the checkpoints to model_optim_rng.pt. I have not tested this yet, but I think it could work. That way we'd avoid adding more complexity to checkpointing.py. What do you think?
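Something along these lines (untested sketch; `checkpoints/iter_0050000` and the old per-rank file name `model_rng.pt` are just placeholders for whatever your checkpoint directory actually contains):

```python
# Untested sketch: rename old per-rank checkpoint files so the current loader
# finds them under the expected name, then resume with --no_load_optim or --finetune.
from pathlib import Path

checkpoint_root = Path("checkpoints/iter_0050000")  # hypothetical iteration directory

for rank_dir in checkpoint_root.glob("mp_rank_*"):
    old_file = rank_dir / "model_rng.pt"        # placeholder for the old file name
    new_file = rank_dir / "model_optim_rng.pt"  # name expected by the current loader
    if old_file.exists() and not new_file.exists():
        old_file.rename(new_file)
        print(f"renamed {old_file} -> {new_file}")
```

Then launch with the flags above so the missing distributed-optimizer state is simply skipped.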

mayank31398 commented 1 year ago

So, my only concern with the current way of saving the optimizer is that it requires a lot of CPU memory. In practice, this will start breaking at around 80-90B parameters, when the required system memory approaches 1.2 TB (I think that's what most DGX servers have today). I leave the decision up to you though :)
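To put rough numbers on that (back-of-envelope, assuming Adam in mixed precision: fp32 master weights, momentum and variance at 4 bytes each per parameter, plus 2-byte fp16/bf16 model weights):

```python
# Back-of-envelope estimate of the host memory needed to materialize the full
# checkpoint state on one node before writing it to disk.
# Assumes Adam in mixed precision: fp32 master params + exp_avg + exp_avg_sq
# (3 * 4 bytes per parameter) plus fp16/bf16 model weights (2 bytes per parameter).
def checkpoint_host_memory_tb(num_params: float) -> float:
    optimizer_bytes = num_params * 3 * 4   # fp32 master weights, exp_avg, exp_avg_sq
    weight_bytes = num_params * 2          # fp16/bf16 model weights
    return (optimizer_bytes + weight_bytes) / 1e12

for billions in (80, 90):
    print(f"{billions}B params -> ~{checkpoint_host_memory_tb(billions * 1e9):.2f} TB")
# ~1.12 TB at 80B and ~1.26 TB at 90B, i.e. right around the ~1.2 TB of host RAM
# available on a typical DGX node.
```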