pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[DTensor] use P2P for complicated transformation when redistributing tensor #134646

Open botbw opened 3 months ago

botbw commented 3 months ago

🚀 The feature, motivation and pitch

Motivation

For complicated DTensor redistributions (e.g. [S(0), S(1)] -> [S(1), S(0)]), it's likely that only a subset of ranks needs to exchange data. For example, when both the tensor and the mesh are square (e.g. a 2*2 mesh), only GPU1 and GPU2 need to communicate, and this can be done with P2P operations.

The current implementation only applies rule-based redistribution. For the above case, it does the following:

  1. S(1) -> R on mesh dim 1
  2. S(0) -> S(1) on mesh dim 0
  3. R -> S(0) on mesh dim 1
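You can see the buffer cost of this path by tracking the local shard shape after each step. Below is a minimal pure-Python sketch, not DTensor code. It assumes even divisibility and only R/S(d) placements, and `local_shape` is a hypothetical helper:

```python
def local_shape(global_shape, mesh_shape, placements):
    """Local shard shape for R/S(d) placements, assuming even divisibility.

    placements[k] is "R" or "S(d)" for mesh dim k, mirroring the notation above.
    """
    shape = list(global_shape)
    for mesh_dim, p in enumerate(placements):
        if p != "R":
            d = int(p[2:-1])  # "S(d)" -> d
            # Shard(d) on this mesh dim splits tensor dim d into mesh_shape[mesh_dim] chunks
            shape[d] //= mesh_shape[mesh_dim]
    return tuple(shape)


# Rule-based path for [S(0), S(1)] -> [S(1), S(0)] on a 2*2 mesh with an 8*8 tensor
steps = [
    ("start", ("S(0)", "S(1)")),
    ("after 1. S(1) -> R on mesh dim 1", ("S(0)", "R")),
    ("after 2. S(0) -> S(1) on mesh dim 0", ("S(1)", "R")),
    ("after 3. R -> S(0) on mesh dim 1", ("S(1)", "S(0)")),
]
for label, placements in steps:
    print(label, local_shape((8, 8), (2, 2), placements))
```

This prints the local shapes (4, 4), (4, 8), (8, 4) and (4, 4). The intermediate buffers after steps 1 and 2 are twice the size of the final shard.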

Instead, P2P does:

  1. rank1 sends local_tensor to rank2
  2. rank2 sends local_tensor to rank1

The two sends can run concurrently because there is no data dependency between them. This reduces both the communication volume and the size of the intermediate tensor buffers.
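To find which ranks must exchange which blocks, intersect the region each rank owns under the source placements with the region it needs under the target placements. The sketch below does this in pure Python for a 2-D tensor. It assumes even divisibility, row-major rank numbering, a fully sharded source (no R) and DTensor-style nested sharding. All helper names are hypothetical, and this is not the implementation in the fork:

```python
from itertools import product


def shard_region(global_shape, mesh_shape, placements, coord):
    """Half-open [lo, hi) range per tensor dim held by the rank at mesh `coord`.

    Assumes even divisibility and nested sharding: if several mesh dims shard
    the same tensor dim, the earlier mesh dim splits first.
    """
    lo, hi = [0] * len(global_shape), list(global_shape)
    for mesh_dim, p in enumerate(placements):
        if p == "R":
            continue
        d = int(p[2:-1])  # "S(d)" -> d
        chunk = (hi[d] - lo[d]) // mesh_shape[mesh_dim]
        lo[d] += coord[mesh_dim] * chunk
        hi[d] = lo[d] + chunk
    return tuple(zip(lo, hi))


def overlap(a, b):
    """Intersection of two regions, or None if they do not intersect."""
    r = tuple((max(l1, l2), min(h1, h2)) for (l1, h1), (l2, h2) in zip(a, b))
    return r if all(lo < hi for lo, hi in r) else None


def numel(region):
    n = 1
    for lo, hi in region:
        n *= hi - lo
    return n


def p2p_plan(global_shape, mesh_shape, src, dst):
    """List of (sender, receiver, region) transfers for src -> dst placements."""
    assert "R" not in src, "sketch assumes every element has exactly one owner"
    coords = list(product(*(range(s) for s in mesh_shape)))
    rank = {c: i for i, c in enumerate(coords)}  # row-major rank numbering
    plan = []
    for rc in coords:
        need = shard_region(global_shape, mesh_shape, dst, rc)
        for sc in coords:
            if sc == rc:
                continue  # data the receiver already owns never moves
            piece = overlap(need, shard_region(global_shape, mesh_shape, src, sc))
            if piece is not None:
                plan.append((rank[sc], rank[rc], piece))
    return plan


for sender, receiver, region in p2p_plan((8, 8), (2, 2), ("S(0)", "S(1)"), ("S(1)", "S(0)")):
    print(f"rank{sender} -> rank{receiver}: rows {region[0]}, cols {region[1]}")
```

For the transpose case, this prints exactly the two transfers listed above: rank2 -> rank1 and rank1 -> rank2, with 16 elements each for an 8*8 tensor. Calling `p2p_plan` with [S(0), S(0)] -> [S(1), S(1)] instead yields 12 transfers, because every rank needs a 2*2 block from each of the other three ranks.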

Experiment

As discussed in #132751, one major concern is that this method cannot use collective communication and might perform poorly when communicating across 2 nodes. Based on some simple experiments, I believe it still reduces communication time thanks to the smaller communication volume.

Setup

The experiments were run on a 4*4 mesh (2 nodes with 8 GPUs each, NV8 fully connected within a node, InfiniBand between nodes). "rule" refers to the implementation on main, and "p2p" refers to the method above. The x-axis is the size of the 2-D square tensor, and the y-axis is the execution time.

Result

[S(0), S(1)] -> [S(1), S(0)]

[Figure: case1_16_gpus_p2p]

[S(0), S(0)] -> [S(1), S(1)]

[Figure: case2_16_gpus_p2p]

Implementation: combining P2P and rule-based redistribution (hybrid)

benchmark file

I did observe P2P performing worse in some cases, especially when the redistribution can be done with a single collective. So I implemented a draft redistribute function that uses collectives whenever possible and uses P2P to handle the rest.
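The issue does not spell out how the draft chooses between the two paths. The following is only one hypothetical dispatch rule, assuming R/S(d) placements only (no Partial). It stays rule-based when a single mesh dim changes, or when every changed mesh dim is a pure all-gather or a pure local chunk. Otherwise it falls back to P2P. `choose_strategy` and the op labels are illustrative names, not DTensor APIs:

```python
# Hypothetical labels for the operation needed on one mesh dim (Partial is not modeled).
SINGLE_DIM_OP = {
    ("S", "R"): "all_gather",
    ("S", "S"): "all_to_all",
    ("R", "S"): "local_chunk",  # no communication needed
}


def kind(placement):
    """Collapse "S(d)" to "S" so only the placement type matters."""
    return "R" if placement == "R" else "S"


def choose_strategy(src, dst):
    """Return ("noop" | "rule_based" | "p2p", ops) for R/S(d) placements."""
    changed = [(dim, s, d) for dim, (s, d) in enumerate(zip(src, dst)) if s != d]
    if not changed:
        return "noop", []
    if len(changed) == 1:
        # one mesh dim changes: a single collective (or local chunk) is enough
        dim, s, d = changed[0]
        return "rule_based", [(dim, SINGLE_DIM_OP[(kind(s), kind(d))])]
    if all(d == "R" for _, _, d in changed):
        # every changed dim goes S -> R: a sequence of all-gathers
        return "rule_based", [(dim, "all_gather") for dim, _, _ in changed]
    if all(s == "R" for _, s, _ in changed):
        # every changed dim goes R -> S: purely local chunking
        return "rule_based", [(dim, "local_chunk") for dim, _, _ in changed]
    # e.g. [S(0), S(1)] -> [S(1), S(0)]: hand the whole transformation to P2P
    return "p2p", []


print(choose_strategy(("S(0)", "S(1)"), ("S(1)", "S(0)")))  # ('p2p', [])
print(choose_strategy(("S(0)", "S(1)"), ("S(0)", "R")))  # ('rule_based', [(1, 'all_gather')])
```

A real hybrid could be finer-grained, for example by running collectives on some mesh dims and P2P on the rest. This sketch only makes a per-redistribution decision.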

I roughly tested the implementation with different mesh settings: 2*2, 4*2, 8*2 and 4*4 (1 node for the first 2 settings and 2 nodes for the last 2). The microbenchmark redistributes a (32, 8192, 4096) tensor between every combination of source and target placements drawn from [R, S(0), S(1), S(2)] (4 ** 4 = 256 combinations in total). The performance is as follows (green dots indicate that this implementation reduces communication time):

[Figures: torch.Size([2, 2])_optimization, torch.Size([4, 2])_optimization, torch.Size([2, 8])_optimization, torch.Size([4, 4])_optimization]
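The 4 ** 4 count comes from choosing one of 4 placements per mesh dim, for both the source and the target, on a 2-D mesh. A short Python sketch of that enumeration (the variable names are illustrative):

```python
from itertools import product

PLACEMENTS = ["R", "S(0)", "S(1)", "S(2)"]  # candidates for the (32, 8192, 4096) tensor

# one placement per mesh dim of a 2-D mesh -> 4**2 = 16 layouts
layouts = list(product(PLACEMENTS, repeat=2))
# every (source, target) pair of layouts -> 16 * 16 = 4**4 = 256 cases
cases = list(product(layouts, repeat=2))
# identity cases (source == target) need no communication at all
non_trivial = [(src, dst) for src, dst in cases if src != dst]
print(len(layouts), len(cases), len(non_trivial))  # 16 256 240
```

Sixteen of the 256 cases are identity redistributions, so at most 240 can show a difference between the two implementations.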

The hybrid implementation also keeps the gains shown in the Experiment section:

[S(0), S(1)] -> [S(1), S(0)]

[Figure: case1_16_gpus_hybrid]

[S(0), S(0)] -> [S(1), S(1)]

[Figure: case2_16_gpus_hybrid]

Other

To keep the implementation simple, I used additional buffers for P2P, so I did not evaluate memory savings. If you think P2P makes sense given the experiments above, let me know and I'm happy to work on this.

You can find the above experiment code and the draft implementation in this fork. The P2P implementation passed test_redistribute_p2p.py, a modified version of test_redistribute.py.

cc: @wanchaol

Alternatives

No response

Additional context

No response

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

awgu commented 3 months ago

cc: @tianyu-l since @wanchaol is out

tianyu-l commented 2 months ago

Thank you for the exploratory effort! I think this is quite valuable.

I consulted teammates on this and here's what I got:

  1. It's good to know that P2P comm is (at least sometimes) more efficient than the current default behavior, likely because it communicates less data, as you mentioned. When @wanchaol said in #132751 that "iirc the major bottleneck of send/recv is that its poor scalability would show up when #GPUs goes beyond a certain scale", I'm not sure at what scale we would see the slowdown. 2*8 GPUs sounds like a small scale.
  2. AFAIK there are more reasons why we don't have it today. First, I was told that initializing a ProcessGroup is a heavy operation, and each initialization incurs an overhead that scales with the world size. E.g. when creating a channel between [0, 1] and [1, 0] (for potential P2P comm) in a world of 8 * 8 GPUs, the overhead would scale with 64. On top of that, you would need to create many of these (maybe m * n many, given an m-by-n grid). So the init overhead is far from negligible at large scale, even if the actual transfer takes less time.
  3. Maybe because of a lack of use cases, the DeviceMesh abstraction currently only supports orthogonal ProcessGroups, not the diagonal ones needed for P2P comm in [S(0), S(1)] -> [S(1), S(0)]. If we know there would be many use cases for the latter, we can consider adding support for it to DeviceMesh.
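The pure-Python sketch below makes the orthogonal-vs-diagonal distinction concrete. It lists the rank groups along each mesh dim of a 2-D mesh, assuming row-major rank numbering (the layout `init_device_mesh` produces by default). `mesh_dim_groups` is a hypothetical helper:

```python
def mesh_dim_groups(mesh_shape):
    """Rank groups along each dim of a 2-D mesh with row-major rank numbering."""
    rows, cols = mesh_shape
    grid = [[r * cols + c for c in range(cols)] for r in range(rows)]
    along_dim0 = [[grid[r][c] for r in range(rows)] for c in range(cols)]  # mesh columns
    along_dim1 = [list(row) for row in grid]                               # mesh rows
    return {0: along_dim0, 1: along_dim1}


groups = mesh_dim_groups((2, 2))
print(groups)  # {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}

# ranks 1 and 2 swap shards in [S(0), S(1)] -> [S(1), S(0)], but they are diagonal
# on the mesh, so no single mesh-dim group contains both of them
print(any({1, 2} <= set(g) for gs in groups.values() for g in gs))  # False
```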

Happy to hear what you think.

cc: @fegin @wz337

botbw commented 2 months ago

@tianyu-l

Thanks for sharing the information!

Regarding points 2 and 3, I'm not sure why ProcessGroup initialization affects P2P comm. Could you please explain further? I thought that all P2P comms use the default ProcessGroup (i.e. World), and that different pairs of peers only init new communicators, which doesn't depend on DeviceMesh (please correct me if I'm wrong).

Regarding point 1, unfortunately I won't be able to run the experiment at a larger scale. The reason I think P2P comm might help is that inter-node bandwidth becomes the bottleneck of overall performance as the number of GPUs grows. The advantage of a lower comm volume might then be even greater (at least I saw this going from 8 GPUs to 16 GPUs).
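One rough way to see the inter-node argument is to count which P2P transfers of [S(0), S(1)] -> [S(1), S(0)] cross a node boundary. The sketch below assumes row-major rank numbering, with consecutive ranks packed onto the same node. `cross_node_transfers` is a hypothetical helper, and it says nothing about the rule-based path:

```python
def cross_node_transfers(n, gpus_per_node):
    """Count block transfers of [S(0), S(1)] -> [S(1), S(0)] on an n*n mesh that
    cross a node boundary, assuming consecutive ranks share a node."""

    def node(i, j):
        # row-major rank of mesh coordinate (i, j), then its node index
        return (i * n + j) // gpus_per_node

    # rank (i, j) needs the block owned by rank (j, i); diagonal ranks keep theirs
    transfers = [((j, i), (i, j)) for i in range(n) for j in range(n) if i != j]
    crossing = sum(node(*src) != node(*dst) for src, dst in transfers)
    return crossing, len(transfers)


print(cross_node_transfers(4, 8))  # (8, 12)
```

For the 4*4 mesh on 2 nodes of 8 GPUs used above, 8 of the 12 block transfers cross nodes. Each block is 1/16 of the tensor, so half of the tensor goes over InfiniBand, with each element sent at most once.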

wconstab commented 2 months ago

> Regarding points 2 and 3, I'm not sure why ProcessGroup initialization affects P2P comm. Could you please explain further?

The way the NCCL library currently works, you need to create a new NCCL communicator object and a new CUDA stream to do P2P operations with a given destination rank if you want to overlap the P2P with other communication operations. If we only want to do one P2P operation at a time, we could likely use one NCCL communicator for everything. We should examine the usage pattern here more carefully to see whether we need the extra communicator in this case to get overlap and good performance.

> in a world of 8 * 8 GPUs, it would incur an overhead scaling up with 64, ...

I think we can put guard-rails up for this feature. A simple heuristic could be to only use P2P mode if we are working on a TP dim of size <= 8. We still need to evaluate the cost of creating new communicators, but it isn't clear that it's definitely too expensive.
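The guard-rail could look roughly like the check below. The threshold constant and `allow_p2p` are hypothetical. A real version would also need to weigh the communicator creation cost mentioned above:

```python
P2P_MAX_MESH_DIM_SIZE = 8  # hypothetical knob, taken from the "<= 8" heuristic above


def allow_p2p(mesh_shape, changed_mesh_dims, max_size=P2P_MAX_MESH_DIM_SIZE):
    """Use P2P only when every mesh dim being redistributed is small, which bounds
    how many peers (and lazily created communicators) a rank can touch."""
    return bool(changed_mesh_dims) and all(
        mesh_shape[d] <= max_size for d in changed_mesh_dims
    )


print(allow_p2p((4, 4), [0, 1]))   # True
print(allow_p2p((16, 8), [0, 1]))  # False
```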

> Maybe because of a lack of use cases, the DeviceMesh abstraction currently only supports orthogonal ProcessGroups, not the diagonal ones needed for P2P comm in [S(0), S(1)] -> [S(1), S(0)]. If we know there would be many use cases for the latter, we can consider adding support for it to DeviceMesh.

We shouldn't need to create new PyTorch ProcessGroups or DeviceMesh dimensions in order to use P2P ops from DTensor. DTensor can issue P2P ops using the 'TP' group, and the ProcessGroup manages the NCCL communicators it uses under the hood. Today, issuing a send/recv op on a PG creates a new NCCL communicator the first time, then caches and reuses that communicator later.
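A toy class can model this lazy create-then-cache behavior. This is not ProcessGroupNCCL's code. It assumes one communicator per unordered pair of ranks and only counts how often a communicator would be created:

```python
class ToyProcessGroup:
    """Toy model of lazily created, cached P2P communicators inside one process group.

    Not the real ProcessGroupNCCL: "creating" a communicator here only bumps a counter.
    """

    def __init__(self):
        self._comms = {}
        self.num_created = 0

    def _comm_for(self, a, b):
        key = (min(a, b), max(a, b))  # assumed: one communicator per unordered rank pair
        if key not in self._comms:
            self.num_created += 1  # the expensive setup happens only on first use
            self._comms[key] = f"comm{key}"
        return self._comms[key]

    def send(self, src, dst):
        return self._comm_for(src, dst)

    def recv(self, dst, src):
        return self._comm_for(dst, src)


pg = ToyProcessGroup()
for _ in range(3):  # three redistributions with the same P2P pattern
    pg.send(1, 2)
    pg.recv(2, 1)
print(pg.num_created)  # 1
```

Repeating the same send/recv pattern creates the communicator only once. The setup cost is paid on the first redistribution and amortized afterwards.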