alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0

Pipeline Stage Composition Algorithm #113

Closed zhuohan123 closed 2 years ago

zhuohan123 commented 2 years ago

Hao's algorithm in #59 is composed of two nested loops, a small loop inside a big loop, where:

  1. The big loop: We first enumerate different sub-mesh shapes. Then, for each shape (e.g. a 1x2 mesh), we try to cover the full mesh with sub-meshes of this specific shape. For each shape, we find one specific covering solution.
  2. The small loop: For each covering solution, we find a way to assign different pipeline layers to different pipeline stages that minimizes the communication and computation costs.

However, the algorithm in the big loop is slightly misaligned with the goal of the project: for normal GPU clusters, we can always find a greedy way to cover the full mesh. For example, for 8-GPU nodes, we can place four 1x2 meshes in a node and repeat this pattern for each node. In addition, for most neural networks, the cost of communication in pipeline parallelism is smaller than the cost of communication in data parallelism, so the greedy solution should perform well in most cases.
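To make the greedy covering concrete, here is a minimal sketch; the function name and the (node, gpu) representation are hypothetical, not from the Alpa codebase:

  def greedy_cover(n_nodes, gpus_per_node, submesh_cols=2):
      """Tile an (n_nodes x gpus_per_node) cluster with 1 x submesh_cols sub-meshes.

      Each sub-mesh is a list of (node_id, gpu_id) pairs, mirroring the
      'four 1x2 meshes per 8-GPU node, repeated for every node' pattern.
      """
      assert gpus_per_node % submesh_cols == 0
      submeshes = []
      for node in range(n_nodes):
          for start in range(0, gpus_per_node, submesh_cols):
              submeshes.append([(node, start + i) for i in range(submesh_cols)])
      return submeshes

  # Example: 2 nodes with 8 GPUs each -> 8 sub-meshes of shape 1x2.
  print(greedy_cover(2, 8))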

In addition, this algorithm doesn't consider that we can mix multiple different sub-mesh shapes. This is important for networks with non-uniform layer shapes (e.g. ResNet).


However, finding the optimal way to slice a 2D mesh while minimizing pipeline execution latency is a very hard problem, and we can't find a polynomial-time DP algorithm (polynomial w.r.t. the total mesh size) that directly solves it. Our current proposal is to put some constraints on both the cluster mesh shape and the sub-mesh shapes. Specifically, we have:

  1. For cluster mesh shape, we assume it's of the shape n x 2^m.
  2. The possible sub-mesh shapes are 1 x 1, 1 x 2, 1 x 4, 1 x 8, ... 1 x 2^m, 2 x 2^m, 3 x 2^m, ... n x 2^m.
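As an illustration, a small sketch (a hypothetical helper, not from the Alpa code) that enumerates the sub-mesh shapes allowed by constraint 2 for an n x 2^m cluster:

  def feasible_submesh_shapes(n, m):
      """Sub-mesh shapes for an (n x 2**m) cluster mesh:
      1x1, 1x2, ..., 1x2^m, then 2x2^m, 3x2^m, ..., nx2^m."""
      shapes = [(1, 2 ** i) for i in range(m + 1)]
      shapes += [(rows, 2 ** m) for rows in range(2, n + 1)]
      return shapes

  # Example: a 4 x 8 cluster (n=4, m=3).
  print(feasible_submesh_shapes(4, 3))
  # [(1, 1), (1, 2), (1, 4), (1, 8), (2, 8), (3, 8), (4, 8)]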

Then we can utilize a 1D DP algorithm to get the solution. More specifically, we transform the original problem into the following one: assume we have n*2^m devices in total, and find an optimal way to assign layers to device groups, where each device group can have 1, 2, 4, 8, ..., 2^m, 2*2^m, ..., n*2^m devices. This can be solved by defining the DP state DP[i][k] as the optimal cost of putting layers 0 to i-1 on k devices. Then we can derive

DP[i][k] = min_{j < i, s <= k, s is a feasible device group size} {DP[j][k - s] + computation cost of putting layers j to i-1 on s devices + communication cost}.

Because of our specific selection of sub-mesh shapes, we can guarantee that we can map the 1D solution back to the 2D mesh.
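A minimal runnable sketch of this 1D DP, assuming the feasible device-group sizes are given and the cost functions are supplied from outside; compute_cost and comm_cost here are hypothetical placeholders for profiled costs:

  def solve_1d_dp(n_layers, n_devices, group_sizes, compute_cost, comm_cost):
      """dp[i][k]: minimal cost of placing layers 0..i-1 on k devices.

      group_sizes: feasible device-group sizes (1, 2, 4, ..., 2^m, 2*2^m, ..., n*2^m).
      compute_cost(j, i, s): cost of running layers j..i-1 on a group of s devices.
      comm_cost(j, s): cost of sending layer j's inputs to a group of s devices.
      """
      INF = float("inf")
      dp = [[INF] * (n_devices + 1) for _ in range(n_layers + 1)]
      dp[0][0] = 0.0
      for i in range(1, n_layers + 1):
          for k in range(1, n_devices + 1):
              for j in range(i):  # layers j..i-1 form the last device group
                  for s in group_sizes:
                      if s <= k and dp[j][k - s] < INF:
                          cost = dp[j][k - s] + compute_cost(j, i, s) + comm_cost(j, s)
                          dp[i][k] = min(dp[i][k], cost)
      return dp[n_layers][n_devices]

  # Toy usage with made-up costs: 4 layers on a 1 x 4 mesh (n=1, m=2).
  print(solve_1d_dp(4, 4, [1, 2, 4],
                    lambda j, i, s: (i - j) / s,
                    lambda j, s: 0.1))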

Cons of this method:

  1. Constraints on the cluster mesh shape. We might be able to relax this constraint by generalizing to n x m meshes and making sure the size of each small sub-mesh is a factor of m.
  2. We still assume that only consecutive layers can be put on a sub-mesh. This doesn't cover the case in the updated megatron-lm paper.
  3. We assume the cost of communication in pipeline parallelism is smaller than the cost of communication in data parallelism, which might not be true for all networks.

The issue with the above algorithm is that it only minimizes the sum of the stage costs, while the total runtime of a pipeline is determined by the following formula:

pipeline time = sum of all stages' times + (#microbatches - 1) * maximal stage time
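For example, with four stages whose times are 2, 3, 2, and 3 (arbitrary units, purely illustrative) and 8 microbatches, the pipeline time is (2 + 3 + 2 + 3) + (8 - 1) * 3 = 31. The second term is dominated by the slowest stage, which the DP above ignores; this is why the pseudocode below enumerates the maximum stage time.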

Some other points:

Enumerate all possible maximum stage times; for each candidate maximum:
  # f[i][j]: minimal total stage time for placing layers 0..i-1 on j devices,
  #          subject to every stage's time being at most the current maximum.
  for i in range(1, n_layers + 1):
    for j in range(1, n_devices + 1):
      for k in range(i):
        for s in possible sub-mesh sizes (s <= j):
          if compute cost + communication cost <= current maximum:
            f[i][j] = min(f[i][j], f[k][j - s] + compute cost of layers k to i-1 on mesh s
                                               + communication cost for layer k on mesh s)
  Cost for this maximum stage time = f[n_layers][n_devices] + (B - 1) * current maximum stage time

How to get communication cost:

  1. Only count the receiver's receiving bandwidth; assume the sender has infinite bandwidth.
  2. Use a greedy solution: directly use the mesh shape that is optimal for each DP subproblem.
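As a rough illustration of point 1, a sketch of a receiver-bound cost model; the byte count, bandwidth value, and the assumption that the tensor is received in parallel across the sub-mesh are all hypothetical:

  def recv_bound_comm_cost(tensor_bytes, n_receiver_devices, per_device_recv_bandwidth):
      """Estimate communication time assuming the sender has infinite bandwidth
      and the tensor is received in parallel across the receiving sub-mesh."""
      return tensor_bytes / (n_receiver_devices * per_device_recv_bandwidth)

  # Example: a 1 GB activation received by a 1x4 sub-mesh at 10 GB/s per GPU.
  print(recv_bound_comm_cost(1e9, 4, 10e9))  # ~0.025 s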

Some other issues:

  1. Right now, profiling the computation cost of a 10-layer BERT on a single node with 4 GPUs takes 10 minutes without any parallelism.
zhuohan123 commented 2 years ago

121