alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0

[PERF] Handling cases like (2, 1024, ...) when chunks = 4 in local allgather #350

Closed · zhisbug closed this 2 years ago

zhisbug commented 2 years ago

In cross-mesh resharding, instead of rewriting the sharding spec so that the split happens solely on the second dim (1024), which would be very slow for communication, we should let the split happen first on the first dim (2) and put the remaining chunks on the second dim.
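A minimal sketch of the two rewrites, expressed with the `jax.interpreters.pxla` ShardingSpec types that Alpa builds on (exact import paths vary across JAX versions; these specs are illustrative, not Alpa's actual rewrite code):

```python
from jax.interpreters.pxla import ShardingSpec, Chunked, NoSharding, ShardedAxis

# Tensor of shape (2, 1024) to be split into 4 chunks.

# Slow rewrite: all 4 chunks on dim 1 -> shards of shape (2, 256).
s2_dim1_only = ShardingSpec(
    sharding=(NoSharding(), Chunked([4])),
    mesh_mapping=(ShardedAxis(0),),
)

# Preferred rewrite: 2 chunks on dim 0, the remaining 2 on dim 1
# -> shards of shape (1, 512).
s2_both_dims = ShardingSpec(
    sharding=(Chunked([2]), Chunked([2])),
    mesh_mapping=(ShardedAxis(0), ShardedAxis(1)),
)
```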

ZYHowell commented 2 years ago

The local allgather problem can be formalized as:

Given a ShardingSpec s1 whose mesh_mapping has some Replicated dimension, we:

  1. rewrite the sharding spec as s2, removing all Replicated dimensions by converting them to chunks;
  2. insert the minimum number of allgathers to restore the tensor's sharding to s1.

The first sub-problem (finding such an s2) is easy: we just add some chunk values to s1's sharding and add the corresponding sharded axes to its mesh_mapping. (Fortunately, in XLA a tensor dimension can have multiple chunks.) Although they are very rare, the cases below should be considered (see the sketch after this list):

  1. a replicated mesh dimension is split and mapped to two tensor dimensions (as in the case you mentioned). Even worse, some replicas may still remain. This changes the logical mesh shape;
  2. a tensor dimension has multiple chunk values, some but not all of which come from s1, while the others are created along with s2.
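A sketch of the two corner cases in ShardingSpec terms (again illustrative; the concrete shapes and mesh sizes are assumptions, not Alpa code):

```python
from jax.interpreters.pxla import (
    ShardingSpec, Chunked, NoSharding, Replicated, ShardedAxis)

# Case 1: s1 replicates over one mesh dim of size 8; s2 splits that dim
# across both tensor dims (2 x 2) with 2 replicas left over, so the
# logical mesh shape changes from (8,) to (2, 2, 2).
s1_case1 = ShardingSpec(
    sharding=(NoSharding(), NoSharding()),
    mesh_mapping=(Replicated(8),),
)
s2_case1 = ShardingSpec(
    sharding=(Chunked([2]), Chunked([2])),
    mesh_mapping=(ShardedAxis(0), ShardedAxis(1), Replicated(2)),
)

# Case 2: dim 0 already has a chunk of 2 in s1; s2 appends another
# chunk of 2 to the same dim, so Chunked([2, 2]) mixes a chunk from
# s1 with a chunk created for s2.
s1_case2 = ShardingSpec(
    sharding=(Chunked([2]), NoSharding()),
    mesh_mapping=(ShardedAxis(0), Replicated(2)),
)
s2_case2 = ShardingSpec(
    sharding=(Chunked([2, 2]), NoSharding()),
    mesh_mapping=(ShardedAxis(0), ShardedAxis(1)),
)
```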

The two cases make the second sub-problem (inserting allgathers to bring the tensor from s2 back to s1) harder. Although it is solved in GShard, we currently do it manually for performance reasons (we allocate the whole destination buffer instead of a single shard, which saves a memcpy). Although this is still WIP, I think we can solve the problem by doing an allgather along each rewritten tensor dimension.

If the rewritten logical mesh has N dimensions (N may differ from 2 because of case 1), a device's indices are represented as (x_0, x_1, ..., x_{N-1}). If a tensor dimension is rewritten and mapped to a set of mesh dimensions D = (i_0, i_1, ..., i_n) (to handle case 2), we create an allgather among devices that have the same indices on the mesh dimensions outside D and different indices on the dimensions in D.
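A hypothetical pure-Python helper illustrating that grouping rule (the function name and interface are made up for illustration; Alpa's actual implementation differs):

```python
import itertools
from collections import defaultdict

def allgather_groups(mesh_shape, D):
    """mesh_shape: sizes of the rewritten logical mesh, e.g. (2, 2, 2).
    D: the mesh axes a rewritten tensor dim maps to, e.g. {0, 2}."""
    groups = defaultdict(list)
    for coords in itertools.product(*(range(s) for s in mesh_shape)):
        # Key on the coordinates outside D; devices that share this key
        # (and thus differ only on the axes in D) allgather together.
        key = tuple(c for axis, c in enumerate(coords) if axis not in D)
        groups[key].append(coords)
    return list(groups.values())

# Example: mesh (2, 2, 2), tensor dim mapped to axes {0, 2}
# -> two groups of four devices each, split by their index on axis 1.
print(allgather_groups((2, 2, 2), {0, 2}))
```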

ZYHowell commented 2 years ago

https://github.com/alpa-projects/alpa/pull/356 implements the algorithm discussed above. For further enhancements, see https://github.com/alpa-projects/alpa/issues/416.