volcengine / veScale

A PyTorch Native LLM Training Framework
http://vescale.xyz
Apache License 2.0

[RFC] Single-Device-Abstract DDP #52

Open lllukehuang opened 2 weeks ago

lllukehuang commented 2 weeks ago

Single-Device-Abstract DDP

Motivation

In current PyTorch DDP, training a model that contains Dropout operations does not reproduce the results of single-device training. This is mainly because the RNG state offset is copied across DP workers, so every DP worker computes the identical Dropout mask. On a single device, by contrast, the sequentially processed micro-batches consume a single RNG stream and therefore receive different Dropout masks, which leads to a mismatch between DP Dropout results and single-device results. We resolve this issue in veScale through a deep understanding of how GPUs generate random numbers in parallel and a patch to torch's CUDA random number generation implementation.
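
To make the mismatch concrete, the following minimal sketch (not veScale code; it assumes a CUDA device is available and emulates two DP workers with a loop) contrasts a single device drawing two micro-batches from one RNG stream with two DP workers that each start from the same copied RNG state.

import torch
import torch.nn.functional as F

x = torch.ones(32, 32, device="cuda")

# Single device: two micro-batches consume one RNG stream sequentially,
# so their Dropout masks differ.
torch.manual_seed(0)
single_masks = [F.dropout(x, p=0.5) for _ in range(2)]

# Emulated DP: each worker starts from the same copied RNG state/offset,
# so all workers produce the identical Dropout mask.
dp_masks = []
for _dp_rank in range(2):
    torch.manual_seed(0)  # same seed and offset replicated to every DP worker
    dp_masks.append(F.dropout(x, p=0.5))

print(torch.equal(single_masks[0], single_masks[1]))  # False: masks differ
print(torch.equal(dp_masks[0], dp_masks[1]))          # True: masks coincide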


We have validated the prototype on several open-source models, including Llama2, Llama3, GPT2, and Mixtral, and it ensures that the loss curve with DP enabled remains consistent with single-device training.

We welcome any and all feedback on this effort!

Design

To ensure consistency in random number generation between the distributed and single-device settings, veScale proposed the ThreadBasedRNGTracker, which regulates the thread IDs used during CUDA random number generation so that they are identical in the single-device and parallel cases. However, the existing random-op handling only considers thread ID adjustments for Tensor Parallelism and overlooks the consistency issues in random number generation caused by Data Parallelism.

Following veScale's existing approach to handling TP random operations, we can inject additional DP-related information into the torch CUDA RNG state. During CUDA random number generation, this DP information is retrieved from the RNG state, enabling each worker to generate the correct local random results.
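
The exact state layout is internal to the implementation, but the idea can be sketched as packing extra DP fields next to the seed and offset in the state blob that the patched generation path later reads back. The field names and layout below are assumptions for illustration, not veScale's actual format.

import torch

def pack_dp_into_rng_state(seed: int, offset: int, dp_size: int, dp_rank: int) -> torch.Tensor:
    # Hypothetical layout: [seed, offset, dp_size, dp_rank] as int64 fields,
    # reinterpreted as a byte tensor like torch.cuda.get_rng_state() returns.
    fields = torch.tensor([seed, offset, dp_size, dp_rank], dtype=torch.int64)
    return fields.view(torch.uint8)

def unpack_dp_from_rng_state(state: torch.Tensor):
    # Inverse of the packing above; a patched CUDA generation path would read
    # dp_size and dp_rank here to shift the Philox thread IDs accordingly.
    seed, offset, dp_size, dp_rank = state.view(torch.int64).tolist()
    return seed, offset, dp_size, dp_rank

state = pack_dp_into_rng_state(seed=1234, offset=16, dp_size=2, dp_rank=1)
print(unpack_dp_from_rng_state(state))  # (1234, 16, 2, 1)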


For example, consider a scenario with 4 GPUs and a parallel setup of DP=2 and TP=2. In this case, veScale manually adjusts the thread IDs on each GPU according to its position in the parallel configuration to ensure consistency with the single-device state.
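
To illustrate the kind of index arithmetic involved (the formula below is an assumption for exposition, not the exact veScale computation): suppose a Dropout is applied to a logical tensor whose batch dimension is split 2-way across DP and whose hidden dimension is split 2-way across TP. Each GPU then maps every local element to the CUDA thread ID the same element would have used on a single device.

# Hypothetical mapping for a logical [global_batch, hidden] tensor where DP
# splits the batch dimension and TP splits the hidden dimension. Each GPU holds
# a [global_batch // dp_size, hidden // tp_size] shard and must reproduce the
# single-device thread ID for every element it owns.
def global_thread_id(local_row, local_col, dp_rank, tp_rank,
                     global_batch, hidden, dp_size, tp_size):
    row = dp_rank * (global_batch // dp_size) + local_row  # row in the unsharded tensor
    col = tp_rank * (hidden // tp_size) + local_col        # column in the unsharded tensor
    return row * hidden + col                              # flattened single-device thread ID

# DP=2, TP=2, global batch 4, hidden 8: the GPU at (dp_rank=1, tp_rank=1)
# computes the same thread ID for its local element (0, 0) as a single device
# would for element (2, 4) of the full tensor.
print(global_thread_id(0, 0, dp_rank=1, tp_rank=1,
                       global_batch=4, hidden=8, dp_size=2, tp_size=2))  # 20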


RNGTracker API

To ensure the correct generation of single-device-abstracted random numbers when DP is enabled, users need to wrap the random number generation code with the _distribute_region context manager. The DTensorSpec provided to this context manager includes information related to TP, while the dp_size and dp_rank parameters specify the DP-related information.

import vescale.dtensor.random as random

with random._rng_tracker._distribute_region(DTensorSpec, dp_size, dp_rank):
    # process random-number-generating ops here
    ...

Compared with the current master branch, the _distribute_region input now includes dp_size and dp_rank. When these two values are not provided, dp_size defaults to 1 and dp_rank defaults to 0, so the DP-based adjustment of random number generation is not enabled by default.
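
A hypothetical end-to-end usage sketch follows. The mesh helpers and the mesh, spec, and hidden variables are illustrative assumptions standing in for surrounding training code, not part of the proposed API; only _distribute_region and its dp_size/dp_rank arguments come from this RFC.

import torch.nn.functional as F
import vescale.dtensor.random as random

# Assume `mesh` is a 2-D device mesh laid out as (DP, TP) and `spec` is the
# DTensorSpec of the activation produced by the random op.
dp_size = mesh.size(0)            # size of the DP dimension (assumed layout)
dp_rank = mesh.get_local_rank(0)  # this worker's DP coordinate (assumed helper)

with random._rng_tracker._distribute_region(spec, dp_size, dp_rank):
    hidden = F.dropout(hidden, p=0.1, training=True)

# Omitting dp_size and dp_rank falls back to dp_size=1, dp_rank=0, i.e. the
# existing TP-only behavior on the current master branch.
with random._rng_tracker._distribute_region(spec):
    hidden = F.dropout(hidden, p=0.1, training=True)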


cc @leonardo0lyj @MackZackA @JsBlueCat

leonardo0lyj commented 1 week ago

@lllukehuang Great work, indeed!