NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

DP-independent checkpoint format for distributed Adam optimizer #1704

Closed. timmoon10 closed this pull request 11 months ago

timmoon10 commented 11 months ago

A major limitation of the distributed optimizer is that checkpointing assumes the saving and loading models have exactly the same parallel configuration (e.g. the data-parallel size) and optimizer options (e.g. the bucket size). This is because we blindly save and load the optimizer state, which is sensitive to the bucketing scheme. This PR changes the checkpoint format by repacking the optimizer state into per-parameter CPU buffers and discarding the bucket-dependent information. When loading the checkpoint, the optimizer can simply copy from these buffers using its own bucketing scheme.
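
To illustrate the idea, here is a minimal sketch, not the actual apex implementation: the helper names `pack_state` and `unpack_state` and the shard-range bookkeeping are hypothetical. On save, each rank's flat shard of optimizer state is repacked into full per-parameter CPU buffers; on load, the buffers are copied back into whatever bucketing scheme the loading optimizer uses.

```python
import torch

def pack_state(param_shapes, local_shard, shard_ranges):
    """Repack this rank's flat state shard into per-parameter CPU tensors.

    param_shapes: {name: torch.Size} for every parameter
    local_shard:  1-D tensor holding this rank's slice of the flat state
    shard_ranges: {name: (start, end, offset)} mapping each parameter's
                  elements [start:end] to [offset:offset + end - start]
                  in local_shard
    """
    packed = {name: torch.zeros(shape).view(-1)
              for name, shape in param_shapes.items()}
    for name, (start, end, offset) in shard_ranges.items():
        packed[name][start:end] = local_shard[offset:offset + (end - start)].cpu()
    # In the real distributed setting every rank contributes its slice
    # (e.g. via a gather), so each per-parameter buffer ends up complete.
    return {name: buf.view(param_shapes[name]) for name, buf in packed.items()}

def unpack_state(packed, shard_ranges, local_shard):
    """Copy from per-parameter buffers back into this rank's (possibly
    different) flat shard when loading a checkpoint."""
    for name, (start, end, offset) in shard_ranges.items():
        flat = packed[name].reshape(-1)
        local_shard[offset:offset + (end - start)].copy_(flat[start:end])
```

Because the packed buffers are keyed by parameter rather than by bucket, a checkpoint written with one data-parallel size can be unpacked under another.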

We include checks so that checkpoints in the old format can still be loaded (assuming the parallel configuration and optimizer options are unchanged). A simple way to convert an old checkpoint is to load it and immediately save a new one, as sketched below.
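
A minimal sketch of that conversion path, assuming it runs under the same distributed launch, parallel configuration, and optimizer options as the run that wrote the old checkpoint. The `"optimizer"` key and the file paths are assumptions about the surrounding checkpoint layout, not part of apex.

```python
import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

def convert_checkpoint(model, old_path, new_path):
    # Construct the optimizer with the same options (bucket size, etc.)
    # as the run that produced the old checkpoint.
    optimizer = DistributedFusedAdam(model.parameters())
    state = torch.load(old_path, map_location="cpu")
    optimizer.load_state_dict(state["optimizer"])  # old, bucket-dependent format
    state["optimizer"] = optimizer.state_dict()    # re-serialized in the new format
    torch.save(state, new_path)
```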

I have added unit tests, and I have also confirmed that loading NeMo-Megatron GPT with different data-parallel sizes produces consistent results.

This is heavily influenced by https://github.com/NVIDIA/NeMo/pull/7140. Some work may be needed to make sure that these two PRs are compatible.

timmoon10 commented 11 months ago

This is not quite compatible with https://github.com/NVIDIA/NeMo/pull/7140, but that is still preliminary and could be simplified with this checkpoint format. Pinging @mikolajblaz.