Closed: eee4017 closed this 14 hours ago
We have two operations that have a branching issue:

1. FP8 Weight Casting in Linear Layers
2. Update Scale Inverse in Amax and Scale Update

The `amax_and_scale_update_inplace` kernel has a data dependency on previous kernels, making it unsuitable for reordering: reordering would move many kernels out of the CUDAGraph, reducing the graph's coverage.

`update_weight_scale_inv` is a boolean parameter, and its value can be determined from the step ID. We only need to update `update_weight_scale_inv`; this can be addressed with the set parameter mechanism, where we reset this boolean each time the graph is launched.

/te-ci paddle
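As a rough illustration of the set parameter idea, here is a plain-Python simulation (not the actual Paddle API; `CapturedGraph` and the state names are hypothetical):

```python
class CapturedGraph:
    """Toy stand-in for a captured CUDA graph: a fixed list of kernels
    replayed without any Python-level branching."""

    def __init__(self):
        self.kernels = []

    def capture(self, kernel):
        self.kernels.append(kernel)

    def replay(self, state):
        for kernel in self.kernels:
            kernel(state)


def fp8_weight_cast(state):
    # The branch lives *inside* the kernel: it reads a flag buffer at
    # replay time, so the captured kernel sequence itself never branches.
    if state["update_weight_scale_inv"]:
        state["weight_scale_inv"] = 1.0 / state["amax"]


graph = CapturedGraph()
graph.capture(fp8_weight_cast)

state = {"amax": 4.0, "weight_scale_inv": None,
         "update_weight_scale_inv": True}

for step in range(3):
    # Set parameter mechanism: the host resets the boolean before every
    # launch, based on the step ID (here: first microbatch only).
    state["update_weight_scale_inv"] = (step == 0)
    graph.replay(state)

print(state["weight_scale_inv"])  # 0.25, written only on step 0
```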
The CI is failing during the build because it can't find glog:
In file included from /opt/transformerengine/transformer_engine/paddle/csrc/custom_ops.cu:13:
/usr/local/lib/python3.10/dist-packages/paddle/include/paddle/phi/backends/gpu/cuda/cuda_graph.h:30:10: fatal error: glog/logging.h: No such file or directory
30 | #include "glog/logging.h"
/te-ci paddle
/te-ci paddle
/te-ci paddle
Description
In this PR, we introduce support for CUDAGraph in TE-PaddlePaddle. The primary issue with CUDAGraph is managing branching: for example, when `weight_cache` is enabled, specific operations are required only in the first microbatch, yet branching within a CUDAGraph is undesirable.

Solutions to the Branching Problem in CUDAGraph
Solution 1: Utilizing Multiple Graphs (TE-PyTorch Solution)
TE-PyTorch addresses branching by recording separate graphs: one for the true branch and another for the false branch. This requires maintaining distinct CUDA graphs for each microbatch. However, it becomes complex, especially with the Pipeline Parallelism mechanism, since the number of required graphs doubles at each branching point: managing 2^N graphs for N branchings can lead to significant challenges.
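The graph-count growth is easy to enumerate (illustrative only; `graphs_needed` is a hypothetical helper, not part of TE):

```python
from itertools import product

def graphs_needed(num_branch_points):
    # Solution 1 records one graph per distinct combination of branch
    # outcomes, so N independent branch points require 2**N graphs.
    return len(list(product([False, True], repeat=num_branch_points)))

for n in range(1, 5):
    print(n, graphs_needed(n))  # 1 2, 2 4, 3 8, 4 16
```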
Solution 2: Reordering Kernel Sequences
To simplify the process, we reorder the kernel sequences outside the computational graph, keeping the branching outside the graph's scope. This approach avoids the complexity of managing multiple graphs and ensures efficient execution within CUDAGraph.
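A minimal sketch of this reordering, under assumed names (`run_step` and `first_microbatch_ops` are illustrative, not the actual TE-Paddle code):

```python
def run_step(step, replay_graph, first_microbatch_ops):
    # Branch-dependent kernels are hoisted outside the captured region,
    # so the graph body stays branch-free and one graph covers all steps.
    if step == 0:
        for op in first_microbatch_ops:
            op()        # issued eagerly, outside the CUDA graph
    replay_graph()      # replay the single branch-free graph

calls = []
for step in range(3):
    run_step(step,
             replay_graph=lambda: calls.append("graph"),
             first_microbatch_ops=[lambda: calls.append("fp8_weight_cast")])

print(calls)  # ['fp8_weight_cast', 'graph', 'graph', 'graph']
```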
Changes Introduced in this PR

- Handle the RNG state update (`set_rng_state`) using the set parameter mechanism in Paddle.
- Handle `amax_and_scale_update_inplace` with the set parameter mechanism in Paddle. This kernel is manually issued, preserving the legacy kernel for now while we explore other solutions.

Type of Change
Checklist
Testing
Although we have not yet implemented unit tests, integration tests have been validated. The following parallelism configurations are supported:
Performance Testing on GPT-3 1.3B on 4 H100 GPUs
Performance Testing on GPT-3 175B with 64 H100 GPUs
We achieved convergence within the first 1024 steps on an 8-node cluster.