volcengine / veScale

A PyTorch Native LLM Training Framework
http://vescale.xyz
Apache License 2.0
553 stars, 26 forks

[DDP&DOptimizer] Open Source #12

Closed. Vremold closed this 5 months ago.

Vremold commented 5 months ago

In this PR, we open-source our DDP & DOptimizer. Yo~

veScale Distributed Data Parallel (DDP)

Distributed Data Parallel (DDP) is a distributed training strategy that partitions the input data across multiple devices, such as multiple GPUs, and replicates the model on each device. Various ZeRO features, which shard optimizer states, gradients, and eventually parameters across the data-parallel ranks, can be built on top of it.
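To make the idea concrete, here is a minimal pure-Python sketch showing why averaging per-rank gradients reproduces the full-batch gradient for a mean-reduced loss. No real devices are involved, and the scalar model `y = w * x`, the helper names, and the equal-sized data shards are all illustrative assumptions.

```python
# Illustrative sketch: data parallelism with a replicated scalar model y = w * x.
# "Ranks" are simulated in a plain Python loop; helper names are hypothetical.

def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def data_parallel_grad(w, xs, ys, world_size):
    """Split the batch across ranks, compute local grads, then average them (all-reduce)."""
    assert len(xs) % world_size == 0, "sketch assumes equal shard sizes"
    shard = len(xs) // world_size
    local_grads = []
    for rank in range(world_size):
        lo, hi = rank * shard, (rank + 1) * shard
        # Every rank holds the same replica of w but sees a different slice of the data.
        local_grads.append(grad_mse(w, xs[lo:hi], ys[lo:hi]))
    # All-reduce (sum) followed by division by the world size is an average.
    return sum(local_grads) / world_size


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
print(grad_mse(0.5, xs, ys), data_parallel_grad(0.5, xs, ys, world_size=2))  # -22.5 -22.5
```

The equivalence relies on equal shard sizes; with uneven shards, the local gradients would need to be weighted by shard size before averaging.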

veScale DDP is largely derived from Megatron-LM's DDP. We extend the DDP implementation to be compatible with our DTensor.

DDP is a module wrapper that creates a flattened grad buffer to store the gradients produced during the model's backward pass. It does this by registering a hook on the grad_fn of each parameter; the hook writes the DTensor gradient produced by the PyTorch autograd engine into the pre-allocated grad buffer. The grad buffer speeds up the gradient all-reduce in distributed training, because the all-reduce is performed once over the entire buffer rather than once per parameter.
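The grad-buffer mechanics can be illustrated with a small in-process simulation. It is a sketch under assumptions, not veScale's implementation: `GradBuffer`, `accumulate`, `grad_of`, and `all_reduce_mean` are hypothetical names, parameters are plain lists of floats, and `accumulate` stands in for the per-parameter backward hook.

```python
# Hypothetical sketch of a flattened grad buffer; not veScale's API.
# Ranks are simulated in-process and tensors are plain lists of floats.

class GradBuffer:
    """One flat gradient buffer per rank; each parameter owns a contiguous slice."""

    def __init__(self, param_sizes):
        self.ranges = {}
        offset = 0
        for name, size in param_sizes.items():
            self.ranges[name] = (offset, offset + size)
            offset += size
        self.data = [0.0] * offset

    def accumulate(self, name, grad):
        """Stand-in for the backward hook: add a parameter's gradient into its slice."""
        start, end = self.ranges[name]
        assert len(grad) == end - start
        for i, g in enumerate(grad):
            self.data[start + i] += g

    def grad_of(self, name):
        """Return a copy of the slice holding `name`'s gradient."""
        start, end = self.ranges[name]
        return self.data[start:end]


def all_reduce_mean(buffers):
    """A single collective over the whole flat buffer instead of one per parameter."""
    world_size = len(buffers)
    averaged = [sum(vals) / world_size for vals in zip(*(b.data for b in buffers))]
    for b in buffers:
        b.data[:] = averaged  # slice assignment copies the values into each rank's buffer


sizes = {"weight": 3, "bias": 1}
ranks = [GradBuffer(sizes) for _ in range(2)]
ranks[0].accumulate("weight", [1.0, 2.0, 3.0])
ranks[0].accumulate("bias", [1.0])
ranks[1].accumulate("weight", [3.0, 2.0, 1.0])
ranks[1].accumulate("bias", [3.0])
all_reduce_mean(ranks)
print(ranks[1].grad_of("weight"), ranks[1].grad_of("bias"))  # [2.0, 2.0, 2.0] [2.0]
```

Because every parameter's gradient lives in one contiguous buffer, the simulated all-reduce runs once regardless of how many parameters are registered.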

Building on this, the following optimizations become possible:

  1. Overlap the gradient all-reduce with the model's backward pass.
  2. Reduce-scatter the gradients instead of all-reducing them when a veScale DistributedOptimizer (a "ZeRO 2+" optimizer) is installed.
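The difference between the two collectives in item 2 can be sketched with simulated ranks (hypothetical helper names; buffers are plain lists whose length is assumed divisible by the world size). After a reduce-scatter, each rank holds only the averaged shard it owns, which is all a sharded optimizer needs; an all-gather of those shards reproduces the all-reduce result.

```python
# Simulated collectives over flat gradient buffers (one list per rank).
# Hypothetical helpers; assumes the buffer length is divisible by the world size.

def all_reduce_mean(bufs):
    """Every rank ends up with the full averaged buffer."""
    world_size = len(bufs)
    avg = [sum(vals) / world_size for vals in zip(*bufs)]
    return [list(avg) for _ in bufs]


def reduce_scatter_mean(bufs):
    """Rank r ends up with only the averaged values of shard r."""
    n, world_size = len(bufs[0]), len(bufs)
    assert n % world_size == 0
    shard = n // world_size
    out = []
    for r in range(world_size):
        lo, hi = r * shard, (r + 1) * shard
        out.append([sum(b[i] for b in bufs) / world_size for i in range(lo, hi)])
    return out


def all_gather(shards):
    """Concatenate every rank's shard and give the full result to every rank."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


bufs = [[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]]
shards = reduce_scatter_mean(bufs)
print(shards)  # [[2.0, 3.0], [4.0, 5.0]]
# All-reduce is equivalent to a reduce-scatter followed by an all-gather.
assert all_gather(shards) == all_reduce_mean(bufs)
```

With a sharded optimizer, the gradient all-gather half can be skipped, since the updated parameters are all-gathered after the step instead.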

veScale optimizers

We provide two kinds of optimizers for distributed training.

BasicOptimizer

A simple optimizer wrapper plus some utilities for distributed training, such as recovering gradients from DDP's flattened grad buffer and triggering the gradient all-reduce for LayerNorm (or other similar) blocks under Sequence Parallel. BasicOptimizer is not a ZeRO optimizer.
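The two utilities above can be sketched as follows. This is an illustrative simulation, not veScale's API: `unflatten_grads`, `sync_sequence_parallel_grads`, and the parameter names are hypothetical, and tensor-parallel (TP) ranks are simulated in-process. The sketch assumes a sum reduction: under Sequence Parallel, each TP rank computes the LayerNorm gradient from its own slice of the sequence, so those replicated gradients are summed across ranks, while TP-sharded weights are left alone.

```python
# Illustrative sketch only: hypothetical helper names, not veScale's BasicOptimizer API.
# Tensor-parallel (TP) ranks are simulated in-process; tensors are plain lists of floats.

def unflatten_grads(flat, ranges):
    """Recover per-parameter gradients {name: [floats]} from a flat DDP-style grad buffer."""
    return {name: flat[start:end] for name, (start, end) in ranges.items()}


def sync_sequence_parallel_grads(per_rank_grads, sp_param_names):
    """All-reduce (sum, an assumption here) the grads of params replicated across TP ranks."""
    for name in sp_param_names:
        # zip(*...) pairs up element i of this parameter's grad on every TP rank.
        summed = [sum(vals) for vals in zip(*(g[name] for g in per_rank_grads))]
        for g in per_rank_grads:
            g[name] = list(summed)
    return per_rank_grads


# Hypothetical layout: a TP-sharded linear weight (2 elements) and a replicated LayerNorm weight.
ranges = {"linear.weight": (0, 2), "ln.weight": (2, 3)}
tp_flat_buffers = [[0.1, 0.2, 1.0], [0.3, 0.4, 2.0]]  # one flat buffer per TP rank
grads = [unflatten_grads(buf, ranges) for buf in tp_flat_buffers]
sync_sequence_parallel_grads(grads, ["ln.weight"])
print(grads[0]["ln.weight"], grads[1]["ln.weight"])  # [3.0] [3.0]
```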

DistributedOptimizer

A "ZeRO 2+" optimizer. Similar to DDP, veScale DistributedOptimizer is largely derived from Megatron-LM's DistributedOptimizer. We extend its implementation to be compatible with our DTensor.

In DistributedOptimizer, the model parameters and gradients are further split across data-parallel (DP) ranks. Each DP rank obtains only its own portion of the gradients, updates only the corresponding parameters, and maintains only the corresponding optimizer states. A typical initialization and step of DistributedOptimizer therefore goes through the following stages:

  1. At initialization, split the model parameters across all DP ranks. This is not a physical split: each DP rank owns a partial view of the full model parameters. The optimizer's original param_groups are replaced with the DP-sharded parameters.

  2. At step, copy the main_grad that DDP attached to each original parameter into the DP-sharded parameters.

  3. Run optimizer.step().

  4. Copy the updated DP-sharded parameters into a dedicated param buffer, which reuses the storage of the grad buffer, for the subsequent parameter all-gather. One further optimization is possible here: overlapping the parameter all-gather with the model's forward pass.
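The four stages above can be walked through with a small in-process simulation. Everything here is an assumption for illustration: `shard_range`, `ShardedSGD`, and `all_gather_params` are hypothetical names, SGD with momentum stands in for whatever optimizer is wrapped, flat Python lists stand in for DTensor buffers, and the buffer length is assumed divisible by the DP world size.

```python
# Hypothetical simulation of a ZeRO-style sharded optimizer step; not veScale's API.
# Flat Python lists stand in for parameter and grad buffers; DP ranks are simulated in-process.

def shard_range(numel, rank, world_size):
    """Contiguous [start, end) slice of the flat parameter space owned by `rank`."""
    assert numel % world_size == 0, "sketch assumes the buffer is padded to a multiple of world_size"
    size = numel // world_size
    return rank * size, (rank + 1) * size


class ShardedSGD:
    """SGD with momentum whose optimizer state covers only this rank's shard."""

    def __init__(self, params, rank, world_size, lr=0.1, momentum=0.9):
        # Stage 1: own only a slice (here, a copy) of the full flat parameters.
        self.start, self.end = shard_range(len(params), rank, world_size)
        self.shard = params[self.start:self.end]
        self.momentum_buf = [0.0] * len(self.shard)  # state for the shard only
        self.lr, self.momentum = lr, momentum

    def step(self, main_grad, grad_buffer):
        # Stage 2: take this rank's slice of DDP's main_grad.
        g = main_grad[self.start:self.end]
        # Stage 3: the optimizer update, applied to the shard only.
        for i, gi in enumerate(g):
            self.momentum_buf[i] = self.momentum * self.momentum_buf[i] + gi
            self.shard[i] -= self.lr * self.momentum_buf[i]
        # Stage 4: grads are no longer needed, so reuse the grad buffer's storage for params.
        grad_buffer[self.start:self.end] = self.shard


def all_gather_params(param_buffers, world_size):
    """Each rank contributes its own slice; every rank ends up with the full updated model."""
    numel = len(param_buffers[0])
    full = []
    for r in range(world_size):
        start, end = shard_range(numel, r, world_size)
        full.extend(param_buffers[r][start:end])
    return full


world_size = 2
params = [1.0, 2.0, 3.0, 4.0]
main_grad = [0.5, 0.5, -0.5, -0.5]  # assumed already averaged by DDP, identical on every rank
opts = [ShardedSGD(params, r, world_size) for r in range(world_size)]
buffers = [list(main_grad) for _ in range(world_size)]  # each rank's grad buffer storage
for opt, buf in zip(opts, buffers):
    opt.step(main_grad, buf)
new_params = all_gather_params(buffers, world_size)
print(new_params)  # approximately [0.95, 1.95, 3.05, 4.05]
```

Each simulated rank stores momentum for only half of the parameters, which is where the memory saving of this style of optimizer comes from; the final all-gather is the step that stage 4 suggests overlapping with the next forward pass.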

Credit to the veScale DDP/DOptim Team

This endeavor would not have been possible without the contributions of our team, which includes, but is not limited to: @SerailHydra @Vremold @JsBlueCat @jc-bytedance @MingjiHan99 @lichen225 @MackZackA @leonardo0lyj.

We also thank the following for their great guidance and leadership: @liwenchangbdbz @pengyanghua @eric-haibin-lin @Meteorix