zhisbug closed this issue 2 years ago.
The local allgather problem can be formalized as: given a `ShardingSpec` (s1) whose `mesh_mapping` has some `Replicated` dimension, we:
- rewrite the sharding spec as s2, removing all replicated dimensions by turning them into chunks;
- insert the minimum number of allgathers to bring the tensor's sharding back to s1.
The first sub-problem (finding such an s2) is easy: we just add some chunk values in s1's sharding and add the corresponding sharded axes to its `mesh_mapping`. (Fortunately, in XLA a tensor dimension can have multiple chunks.)
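A minimal sketch of this rewrite, using hypothetical data structures (not Alpa's actual `ShardingSpec` class): every `Replicated` mesh dimension is converted into an extra chunk on some tensor dimension, here always dim 0 for simplicity.

```python
# Hypothetical sketch: rewrite a sharding spec so that every Replicated
# logical mesh dimension becomes an extra Chunked split on a tensor dim.
# (XLA allows one tensor dimension to carry multiple chunks.)
from dataclasses import dataclass

@dataclass
class ShardingSpec:
    chunks: list         # chunks[i] = chunk factors applied to tensor dim i
    mesh_mapping: list   # per mesh dim: ("Sharded", tensor_dim) or "Replicated"
    mesh_shape: list     # size of each logical mesh dimension

def remove_replicated(s1: ShardingSpec) -> ShardingSpec:
    """Return s2: each Replicated mesh dim is turned into a chunk on
    tensor dim 0 (a real implementation would pick dims more carefully)."""
    chunks = [list(c) for c in s1.chunks]
    mapping = []
    for j, m in enumerate(s1.mesh_mapping):
        if m == "Replicated":
            chunks[0].append(s1.mesh_shape[j])  # add a chunk of that size
            mapping.append(("Sharded", 0))      # mesh dim j now shards dim 0
        else:
            mapping.append(m)
    return ShardingSpec(chunks, mapping, s1.mesh_shape)

# Example: a 2x4 mesh where mesh dim 1 is replicated.
s1 = ShardingSpec(chunks=[[2], []],
                  mesh_mapping=[("Sharded", 0), "Replicated"],
                  mesh_shape=[2, 4])
s2 = remove_replicated(s1)
print(s2.chunks)        # [[2, 4], []]
print(s2.mesh_mapping)  # [('Sharded', 0), ('Sharded', 0)]
```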
Although they are very rare, the following cases should be considered:
The two cases make the second sub-problem (inserting allgathers to move the tensor from s2 back to s1) harder. Although it is solved in GShard, we currently do it manually for performance reasons: we allocate the whole buffer instead of a shard, which saves a memcpy. This is still WIP, but I think we can solve the problem by doing an allgather along each rewritten tensor dimension.
If the rewritten logical mesh has N dimensions (N != 2 because of case 1), a device's indices are represented as (x0, x1, ..., x_{N-1}). If a tensor dimension is rewritten and mapped to a set of mesh dimensions D = (i0, i1, ..., in) (consider case 2), we create an allgather among devices that have the same indices in the mesh dimensions outside D and different indices in D.
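The grouping rule above can be sketched as follows; the function name and mesh shape are illustrative, not from the Alpa codebase. Devices that agree on every index outside D and range over all indices inside D form one allgather group.

```python
# Sketch: enumerate allgather groups for a logical mesh of N dimensions,
# given a set D of mesh dims that a rewritten tensor dimension maps to.
import itertools

def allgather_groups(mesh_shape, D):
    """mesh_shape: sizes of the N logical mesh dims; D: set of mesh dims."""
    groups = {}
    for idx in itertools.product(*(range(s) for s in mesh_shape)):
        # Key = the device's indices outside D; devices sharing a key
        # form one allgather group (they differ only inside D).
        key = tuple(x for j, x in enumerate(idx) if j not in D)
        groups.setdefault(key, []).append(idx)
    return list(groups.values())

# A (2, 2, 2) mesh; allgather along mesh dims D = {0, 2}:
for g in allgather_groups((2, 2, 2), {0, 2}):
    print(g)
# Each group fixes x1 and varies (x0, x2):
# [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 0, 1)]
# [(0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1)]
```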
PR https://github.com/alpa-projects/alpa/pull/356 implements the algorithm discussed above. For further enhancements, see https://github.com/alpa-projects/alpa/issues/416.
In cross-mesh resharding, instead of rewriting the sharding spec so that the split happens solely on the second dim (of size 1024), which would be very slow for communication, we should let the split happen first on the first dim (of size 2) and put the rest on the second dim.
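The arithmetic behind this preference can be sketched as below; the (2, 1024) shape and the helper are assumptions for illustration. Splitting the first dim produces larger contiguous slices per shard (in row-major layout), which is friendlier to communication than many thin column slices.

```python
# Illustrative sketch: shard shapes for a (2, 1024) tensor split 4 ways.
def split_shapes(shape, dim_factors):
    """dim_factors[d] = number of ways tensor dim d is split."""
    return tuple(s // dim_factors.get(d, 1) for d, s in enumerate(shape))

shape = (2, 1024)
# Splitting solely on the second dim into 4 pieces:
print(split_shapes(shape, {1: 4}))        # (2, 256)
# Splitting the first dim by 2 first, the rest on the second dim:
print(split_shapes(shape, {0: 2, 1: 2}))  # (1, 512)
```

In the second layout each shard is a single contiguous half-row of the tensor, so the send can be one large copy instead of strided pieces.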