pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

Question about custom CUDA operators for tensor parallelism #434

Open vermouth1992 opened 3 months ago

vermouth1992 commented 3 months ago

We are currently trying to apply torchtitan to MoE models. Our MoE models need grouped_gemm (https://github.com/fanshiqing/grouped_gemm). GroupedGemm ops basically follow the same sharding rules as ColumnLinear and RowLinear. Is there any way to make custom ops DTensor-compatible? Many thanks for the help!
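For context, a grouped GEMM multiplies consecutive row groups of one input by different weight matrices in a single call. A minimal pure-Python sketch of those semantics follows. It is not the grouped_gemm package's API; the function name and the batch_sizes layout are illustrative assumptions.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Dense (m x k) @ (k x n) -> (m x n)."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def grouped_gemm_reference(x: Matrix, weights: List[Matrix], batch_sizes: List[int]) -> Matrix:
    """Reference semantics of a grouped GEMM (hypothetical helper).

    x holds tokens already sorted by expert; batch_sizes[e] is how many
    consecutive rows of x belong to expert e, and weights[e] is that expert's
    (k x n) weight. Each row group is multiplied by its own expert weight and
    the results are concatenated back in the same order.
    """
    if sum(batch_sizes) != len(x):
        raise ValueError("batch_sizes must cover every row of x")
    out: Matrix = []
    start = 0
    for w, count in zip(weights, batch_sizes):
        out.extend(matmul(x[start:start + count], w))
        start += count
    return out
```

For example, with two experts whose weights are the identity and twice the identity, `grouped_gemm_reference([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]], [1, 2])` returns `[[1.0, 2.0], [6.0, 8.0], [10.0, 12.0]]`. A real kernel does the same work in one launch instead of a Python loop.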

yifuwang commented 3 months ago

I think there are two high-level approaches:

I think both approaches are equally viable if the permute input and output have the same sharding. If you are performing expert parallelism across devices, the second approach is probably easier at the moment.

fegin commented 3 months ago

For the second method mentioned by @yifuwang, both ColwiseParallel and RowwiseParallel have an option (use_local_output) to convert their outputs to local tensors. Both also convert their inputs to DTensors if the inputs are plain torch.Tensors. So you can combine non-DTensor-compatible ops with ColwiseParallel and RowwiseParallel.
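To see why a plain-tensor op can sit between the two, here is a pure-Python simulation of a column-parallel linear, an op applied to each rank's local shard, and a row-parallel linear across a hypothetical world_size. It does not use torch.distributed; the helper names are made up for illustration, and the final sum stands in for the all-reduce that RowwiseParallel performs.

```python
from typing import Callable, List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Dense (m x k) @ (k x n) -> (m x n)."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def column_shards(w: Matrix, world_size: int) -> List[Matrix]:
    """Split a (k x n) weight along its output (column) dim, like ColwiseParallel."""
    step = len(w[0]) // world_size
    return [[row[r * step:(r + 1) * step] for row in w] for r in range(world_size)]


def row_shards(w: Matrix, world_size: int) -> List[Matrix]:
    """Split a (k x n) weight along its input (row) dim, like RowwiseParallel."""
    step = len(w) // world_size
    return [w[r * step:(r + 1) * step] for r in range(world_size)]


def all_reduce_sum(partials: List[Matrix]) -> Matrix:
    """Elementwise sum over ranks, standing in for the all-reduce."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


def tp_ffn(x: Matrix, w1: Matrix, w2: Matrix, world_size: int,
           local_op: Callable[[float], float]) -> Matrix:
    # Column-parallel w1: each rank sees the full (replicated) x and produces
    # its own slice of the hidden dimension.
    hidden = [matmul(x, w1_r) for w1_r in column_shards(w1, world_size)]
    # Plain (non-DTensor) op on each rank's local slice. This is safe only
    # because the op is elementwise; an op that mixes values across the
    # hidden dim would need the full tensor.
    hidden = [[[local_op(v) for v in row] for row in h] for h in hidden]
    # Row-parallel w2: each hidden slice meets the matching rows of w2,
    # producing a partial sum of the final output.
    partials = [matmul(h, w2_r) for h, w2_r in zip(hidden, row_shards(w2, world_size))]
    return all_reduce_sum(partials)
```

With `relu = lambda v: max(v, 0.0)`, `tp_ffn(x, w1, w2, 2, relu)` gives the same result as the unsharded `tp_ffn(x, w1, w2, 1, relu)`, which is the property that makes the local-tensor approach work for elementwise ops.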

kwen2501 commented 3 months ago

Curious: regarding this figure in your repo, which value(s) do you want to turn into DTensors? The input? Each of the experts? Or the whole group of experts represented as one big DTensor (each expert being a subset of its rows or columns)? I would appreciate your comment.

vermouth1992 commented 3 months ago

We would like both the input and the expert weights to be DTensors, just like the FFN layer in dense models. Specifically, the gate can be parallelized via SequenceParallel, and each expert can use the standard Column + Row Parallel. However, this is computationally inefficient because we have to loop over all the experts. With GroupedGemm, the experts' weights can be concatenated and we only need to run a single large GroupedGemm.

The weights of all the experts, taken together.

We can apply the same techniques as ColumnParallel to w1 and w3, and RowParallel to w2 (essentially a batched version of each). However, the GroupedGemm operator then needs to know the TP sharding propagation rules, just as torch.matmul does.
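The propagation rules needed here can be written down as a small table. A sketch, assuming all experts' weights are stacked into one [E, K, N] tensor and tokens are pre-sorted by expert; placements are plain strings standing in for DTensor's Replicate/Shard/Partial, and the functions are hypothetical, not a DTensor registration API.

```python
from typing import Dict, Optional, Tuple

# Toy placements named after DTensor's: "R" = Replicate, "S(d)" = Shard(d),
# "P" = Partial (values must still be summed across ranks).
# Shapes: x is [T, K] (tokens pre-sorted by expert), w is [E, K, N]
# (all experts' weights stacked), out is [T, N].
GROUPED_GEMM_RULES: Dict[Tuple[str, str], str] = {
    ("R", "R"): "R",        # nothing sharded: every rank computes the full output
    ("R", "S(2)"): "S(1)",  # column-parallel (w1, w3): shard N -> output sharded on its last dim
    ("S(1)", "S(1)"): "P",  # row-parallel (w2): shard K on both -> partial sums
}


def grouped_gemm_placement(x_pl: str, w_pl: str) -> Optional[str]:
    """Output placement for the given input placements, or None when no rule
    applies and the inputs would first have to be redistributed."""
    return GROUPED_GEMM_RULES.get((x_pl, w_pl))


def ffn_output_placement() -> Optional[str]:
    """Chain the rules through an expert FFN: replicated tokens meet the
    column-sharded w1 (w3 behaves the same), an elementwise activation keeps
    the placement, and the row-sharded w2 yields partial sums."""
    hidden = grouped_gemm_placement("R", "S(2)")   # -> "S(1)"
    if hidden is None:
        return None
    return grouped_gemm_placement(hidden, "S(1)")  # -> "P"
```

As with RowwiseParallel, the final "P" output would be reduced (all-reduce to a replicated output, or reduce-scatter to a sequence shard). Sharding tokens on dim 0 is deliberately left out of this toy table, since the per-expert group sizes would generally not line up with an even shard.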

Note that this is still tensor parallelism. We can further apply expert parallelism by distributing experts across different expert-parallel (EP) groups. This further shards the weights to

In expert parallelism, we simply send the tokens routed to a particular expert to the EP group that holds it, and each group discards the rest. The inputs are hard to express as DTensors because they are unevenly sharded across the EP group. Within each EP group, however, we can use DTensor to represent the inputs.
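The uneven split is easy to see in a toy dispatch. A pure-Python sketch, assuming top-1 routing and that each EP rank owns a contiguous block of experts (both assumptions for illustration; real MoE layers often route each token to its top-k experts and exchange tokens with an all-to-all):

```python
from typing import Dict, List


def owner_rank(expert: int, num_experts: int, ep_size: int) -> int:
    """EP rank that holds `expert`, assuming each rank owns a contiguous block."""
    return expert // (num_experts // ep_size)


def dispatch(token_experts: List[int], num_experts: int, ep_size: int) -> Dict[int, List[int]]:
    """Map each EP rank to the indices of the tokens it must process.

    token_experts[t] is the (top-1) expert chosen for token t. A rank keeps
    only the tokens routed to its own experts; the rest go to other ranks.
    """
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by ep_size")
    per_rank: Dict[int, List[int]] = {r: [] for r in range(ep_size)}
    for token, expert in enumerate(token_experts):
        per_rank[owner_rank(expert, num_experts, ep_size)].append(token)
    return per_rank
```

For example, `dispatch([0, 3, 1, 1, 2, 0, 1], num_experts=4, ep_size=2)` returns `{0: [0, 2, 3, 5, 6], 1: [1, 4]}`: rank 0 receives five tokens and rank 1 only two, and the split changes with every batch, which is why an evenly sharded DTensor layout does not fit across the EP group.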

The Megatron-LM implementation is a good reference; it uses raw torch.Tensor and torch.distributed: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/experts.py#L25

XilunWu commented 3 months ago

You can also try DTensor's local_map, which is how we enabled FusedRMSNorm in torchtitan (#404). This is the second approach in @yifuwang's comment.
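The local_map idea can be illustrated without a GPU. Roughly, the real API takes a plain-tensor function plus the expected input and output placements, runs the function on local shards, and rewraps the result. The toy sketch below mimics that with hypothetical stand-ins (ToyShardedTensor and toy_local_map are not torch APIs), wrapping a plain RMSNorm so it runs on each rank's rows. This is valid here because the input is sharded along the sequence (row) dimension while RMSNorm normalizes over the hidden dimension, which every shard holds in full.

```python
import math
from dataclasses import dataclass
from typing import Callable, List

Rows = List[List[float]]


@dataclass
class ToyShardedTensor:
    """Stand-in for a DTensor: per-rank local rows plus a placement label."""
    shards: List[Rows]  # shards[r] holds rank r's rows (sequence positions)
    placement: str      # "S(0)" = sharded along rows, "R" = replicated

    def full(self) -> Rows:
        """Reassemble the global tensor from row shards."""
        return [row for shard in self.shards for row in shard]


def toy_local_map(fn: Callable[[Rows], Rows], in_placement: str,
                  out_placement: str) -> Callable[[ToyShardedTensor], ToyShardedTensor]:
    """Run a plain-tensor function on each local shard and relabel the result."""
    def wrapped(t: ToyShardedTensor) -> ToyShardedTensor:
        if t.placement != in_placement:
            raise ValueError(f"expected placement {in_placement}, got {t.placement}")
        return ToyShardedTensor([fn(shard) for shard in t.shards], out_placement)
    return wrapped


def rms_norm_rows(rows: Rows, eps: float = 1e-6) -> Rows:
    """Plain RMSNorm (no learnable weight) over each row's hidden dimension."""
    out: Rows = []
    for row in rows:
        scale = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v * scale for v in row])
    return out


# The "fused" kernel never sees sharding; the wrapper handles placements.
sharded_rms_norm = toy_local_map(rms_norm_rows, in_placement="S(0)", out_placement="S(0)")
```

Applying `sharded_rms_norm` to a row-sharded input produces the same rows as normalizing the reassembled tensor, and passing an input with an unexpected placement raises instead of silently computing on the wrong layout.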