intelligent-machine-learning / dlrover

DLRover: An Automatic Distributed Deep Learning System
Other
1.14k stars 144 forks source link

megatron distributed optimizer slow down the flash checkpoint #985

Closed FKyms closed 6 months ago

FKyms commented 6 months ago

hi

在 megatron 中使用分布式优化器时,优化器需要执行内存操作和通信同步,这使得flash checkpoint可能会降级到分钟级别的效率。请问,对于这个问题,我们有什么优化方案和计划吗?

when use distributed optimizer in megatron, the optimizer needs to perform memory operations and communication synchronization, this makes flash checkpoints on the second level possible degrade to the minute level.

Do we have any optimization solutions and plans for this issue?

workingloong commented 6 months ago

是的,这里的解决方案应该是每个 rank 存自己的优化器分片。DeepSpeed 是已经实现了这个方案的, Flash CKPT 也支持 DeepSpeed。我们训练LLM 是用的 FSDP,针对 FSDP 也实现了各个 rank 保存自己的模型分片。我们后续会调研下,能否支持 distributed optimizer 的分开保存。

Yes, each rank should save its shard of the optimizer. DeepSpeed has already implemented this solution. We train LLM using FSDP and have also implemented saving individual model shards for each rank with FSDP. We will look into whether the each rank saves its shard of distributed optimizer in Megatron-LM.

workingloong commented 6 months ago

Megatron-LM 从各个 rank 上获取 optimizer 的 tensor 时使用的是 gloo 作为通信后端,这个比 NCCL 会慢很多。你们可以修改下这个通信的 group,使用 nccl 来做通信后端,应该会快很多。

Megatron-LM rank 0 gathers all optimizer tensors from other ranks using gloo as the communication backend. It is much slower than NCCL. You could set the communication group with NCCL to gather all optimizer tensors wich may be mush faster than gloo.

data_parallel_group_gloo = mpu.get_data_parallel_group_gloo(with_context_parallel=True)

workingloong commented 6 months ago

We have released dlrover[torch]==0.3.4 to save and load the distributed optimizer in parallel. You can pip install dlrover[torch]==0.3.4 -U and use the APIs in megatron/training.py.

# from megatron.checkpointing import load_checkpoint
# from megatron.checkpointing import save_checkpoint

from dlrover.trainer.torch.flash_checkpoint.megatron_dist_ckpt import save_checkpoint
from dlrover.trainer.torch.flash_checkpoint.megatron_dist_ckpt import load_checkpoint