Closed · mayank31398 closed this 1 year ago
Hi @mayank31398 ! Thanks for creating this.
I think it'd be best to keep a single way of saving the optimizer, so let's not re-introduce the previous way of saving the distributed optimizer.
As for loading old checkpoints that were written with the distributed optimizer, that can indeed be useful. However, one probably wouldn't need to load the optimizer state: if training continues, it will most likely be on a different dataset. In that case a previous checkpoint can be loaded easily by using `--no_load_optim` or `--finetune` and renaming the checkpoint files to `model_optim_rng.pt`. I have not tested this yet, but I think it could work, and it would avoid adding more complexity to `checkpointing.py`.
What do you think?
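The rename-and-resume workaround above could be sketched roughly as below. Note the checkpoint directory layout (`mp_rank_*` sub-directories), the old file name `model_rng.pt`, and the helper name are all assumptions for illustration; only the target name `model_optim_rng.pt` and the `--no_load_optim` flag come from the discussion.

```python
from pathlib import Path

def rename_old_checkpoints(ckpt_root: Path) -> list[Path]:
    """Rename per-rank checkpoint files to the name the current loader
    expects. The source name 'model_rng.pt' and the 'mp_rank_*' layout
    are hypothetical; adapt them to the actual checkpoint on disk."""
    renamed = []
    for rank_dir in sorted(ckpt_root.glob("mp_rank_*")):
        old = rank_dir / "model_rng.pt"          # assumed old file name
        new = rank_dir / "model_optim_rng.pt"    # name expected by the loader
        if old.exists() and not new.exists():
            old.rename(new)
            renamed.append(new)
    return renamed

# After renaming, resume training with the optimizer state skipped, e.g.:
#   pretrain_gpt.py ... --load <checkpoint-dir> --no_load_optim
```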
So, my only concern with the current way of saving the optimizer is that it requires a lot of CPU memory. I expect it to start breaking at around 80-90B parameters, when the required system memory approaches 1.2 TB (which I think is what most DGX servers have today). I leave the decision up to you though :)
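A rough back-of-the-envelope check of that memory figure, assuming Adam-style optimizer state of about 16 bytes per parameter (fp32 master weights plus two fp32 moments plus an fp32 gradient copy). The bytes-per-parameter value is an assumption for illustration, not a measured number:

```python
def optimizer_cpu_bytes(n_params: float, bytes_per_param: int = 16) -> float:
    """Rough upper bound on host memory needed to materialize the full
    optimizer state on one node before writing it to disk. The default
    of 16 bytes/param assumes Adam: fp32 master weights (4 B) + two fp32
    moments (8 B) + an fp32 gradient or padding (4 B)."""
    return n_params * bytes_per_param

for billions in (80, 90):
    tb = optimizer_cpu_bytes(billions * 1e9) / 1e12
    print(f"{billions}B params -> ~{tb:.2f} TB host memory")
# 80B params -> ~1.28 TB, 90B -> ~1.44 TB: consistent with the ~1.2 TB
# ceiling mentioned above.
```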
The current rebase against NVIDIA has broken the saving and loading of old checkpoints; this PR fixes that.