Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

enable loading `universal checkpointing` checkpoint in `DeepSpeedStrategy` #20065

Open · zhoubay opened this issue 1 month ago

zhoubay commented 1 month ago

Description & Motivation

After training a model for a while on some number of GPUs, say 8, it is difficult to resume from that checkpoint on 16 GPUs with the optimizer and model states unchanged. DeepSpeed has developed Universal Checkpointing to solve exactly this problem, but pytorch-lightning does not seem to expose this feature.

Pitch

I would like pytorch-lightning to support this feature.

Alternatives

Add `universal_checkpoint` as a parameter of `DeepSpeedStrategy` and modify the class following https://www.deepspeed.ai/tutorials/universal-checkpointing/. A rough sketch of what the resulting API could look like is shown below.
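For illustration, here is a minimal usage sketch of the proposed API from the user's side. The `universal_checkpoint` argument is hypothetical and does not exist in `DeepSpeedStrategy` today; it would simply translate into the `{"checkpoint": {"load_universal": True}}` entry of the DeepSpeed config described in the tutorial above.

```python
# Hypothetical usage sketch: `universal_checkpoint` is the proposed flag,
# not an existing DeepSpeedStrategy argument.
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

trainer = Trainer(
    accelerator="gpu",
    devices=16,
    strategy=DeepSpeedStrategy(
        stage=2,
        universal_checkpoint=True,  # proposed: sets {"checkpoint": {"load_universal": True}}
    ),
)
```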

Additional context

No response

cc @borda @awaelchli

zhoubay commented 1 month ago

I've checked that simply adding `self.config["checkpoint"] = {"load_universal": True}` after the call to `self._create_default_config` might work.

So the solution might be to add this config inside the `_create_default_config` function, for example as sketched below.
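Until this is supported upstream, a minimal workaround sketch (untested, and relying on the private `_create_default_config` method, so it may break across Lightning versions) could be a small subclass that injects the flag into the generated DeepSpeed config:

```python
# Untested workaround sketch: subclass DeepSpeedStrategy and add the
# universal-checkpointing entry to the default DeepSpeed config.
from lightning.pytorch.strategies import DeepSpeedStrategy


class UniversalCheckpointDeepSpeedStrategy(DeepSpeedStrategy):
    def _create_default_config(self, *args, **kwargs):
        # Build the regular default config, then enable loading of
        # "universal" checkpoints (see the DeepSpeed universal-checkpointing tutorial).
        config = super()._create_default_config(*args, **kwargs)
        config["checkpoint"] = {"load_universal": True}
        return config
```

A built-in `universal_checkpoint=True` argument on `DeepSpeedStrategy` that does exactly this would avoid reaching into private APIs.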