NVIDIA / nccl

Optimized primitives for collective multi-GPU communication

Why is the ReduceScatter Ring algo's result better than NVLS's result on one H100 server? #1047

Open haobaba5353 opened 1 year ago

haobaba5353 commented 1 year ago

Hello, I have a question about ReduceScatter and NVLS on one node. I expected ReduceScatter with the NVLS algo to be better than Ring, but the results show otherwise. Please refer to the test results below.

Reducescatter (H100: Ring)
Rank 7 Group 0 Pid 64380 on node47 device 7 [0xdb] NVIDIA H100 80GB HBM3

                                                   out-of-place                        in-place
       size         count    type   redop    root     time   algbw   busbw  #wrong     time   algbw   busbw  #wrong
        (B)    (elements)                             (us)  (GB/s)  (GB/s)             (us)  (GB/s)  (GB/s)
  134217728       4194304   float     sum      -1    377.2  355.81  311.34       0    371.9  360.91  315.80       0
  268435456       8388608   float     sum      -1    709.7  378.22  330.94       0    708.2  379.05  331.67       0
  536870912      16777216   float     sum      -1   1377.8  389.65  340.94       0   1374.7  390.54  341.73       0
 1073741824      33554432   float     sum      -1   2689.3  399.27  349.36       0   2689.2  399.27  349.36       0
 2147483648      67108864   float     sum      -1   5295.9  405.50  354.81       0   5301.3  405.09  354.45       0
 4294967296     134217728   float     sum      -1    10505  408.86  357.76       0    10452  410.92  359.56       0
 8589934592     268435456   float     sum      -1    20754  413.89  362.16       0    20709  414.80  362.95       0
Out of bounds values : 0 OK
Avg bus bandwidth    : 344.487

Reducescatter (H100: NVLS)

                                                   out-of-place                        in-place
       size         count    type   redop    root     time   algbw   busbw  #wrong     time   algbw   busbw  #wrong
        (B)    (elements)                             (us)  (GB/s)  (GB/s)             (us)  (GB/s)  (GB/s)
  134217728       4194304   float     sum      -1    422.0  318.02  278.27       0    422.6  317.61  277.91       0
  268435456       8388608   float     sum      -1    792.8  338.60  296.27       0    793.9  338.13  295.87       0
  536870912      16777216   float     sum      -1   1535.3  349.70  305.98       0   1538.0  349.07  305.44       0
 1073741824      33554432   float     sum      -1   3023.0  355.19  310.80       0   3020.6  355.47  311.04       0
 2147483648      67108864   float     sum      -1   5975.6  359.38  314.46       0   5979.9  359.12  314.23       0
 4294967296     134217728   float     sum      -1    11871  361.81  316.59       0    11872  361.79  316.56       0
 8589934592     268435456   float     sum      -1    23641  363.34  317.92       0    23641  363.35  317.93       0
Out of bounds values : 0 OK
Avg bus bandwidth    : 305.662
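For readers comparing the two tables: a quick sanity check of how nccl-tests derives these columns, assuming 8 GPUs in the node, GB = 1e9 bytes, and the nccl-tests convention that reduce_scatter bus bandwidth is algbw scaled by (n-1)/n. Small differences from the printed values are rounding.

```python
# Sanity-check the benchmark columns above (assumptions: 8 ranks on one
# H100 node, GB = 1e9 bytes, nccl-tests busbw convention for
# reduce_scatter: busbw = algbw * (n - 1) / n).

def algbw_gbs(size_bytes, time_us):
    """Algorithm bandwidth in GB/s: bytes processed divided by elapsed time."""
    return size_bytes / (time_us * 1e-6) / 1e9

def busbw_gbs(algbw, nranks):
    """Bus bandwidth for reduce_scatter per the nccl-tests convention."""
    return algbw * (nranks - 1) / nranks

n = 8
# First Ring row: 134217728 B in 377.2 us -> ~355.8 GB/s algbw, ~311.3 busbw
a = algbw_gbs(134217728, 377.2)
print(round(a, 2), round(busbw_gbs(a, n), 2))
```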

KaimingOuyang commented 1 year ago

Because, algorithm-wise, NVLS reduce_scatter transfers one extra chunk (its own) through the switch, which causes lower perf compared with Ring.

haobaba5353 commented 1 year ago

Thanks for your reply. From my understanding, NVLS can accelerate collective communication, and each message should be sent to the NVSwitch only once. So why does NVLS reduce_scatter transfer one more chunk to itself?

KaimingOuyang commented 1 year ago

Because load-reduce involves all ranks on the node, a rank needs to send its own data to the switch as well, so the total data each rank sends is nRanks * count. Ring, however, only sends (nRanks - 1) * count in total, which makes it faster.
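A back-of-envelope check of this explanation against the numbers above (assumption: per-rank bytes on the wire are the bottleneck, 8 ranks on one H100 node). The model predicts an NVLS/Ring bandwidth ratio of (nRanks - 1) / nRanks = 0.875; the measured average bus bandwidths give a ratio in the same ballpark.

```python
# Predicted vs observed NVLS/Ring ratio (assumption: bandwidth scales
# inversely with bytes each rank must send, nRanks = 8).
n = 8
ring_sends = n - 1   # Ring: each rank sends (nRanks - 1) * count in total
nvls_sends = n       # NVLS: the rank's own chunk also goes through the switch

predicted_ratio = ring_sends / nvls_sends   # 0.875
observed_ratio = 305.662 / 344.487          # NVLS / Ring avg busbw from above

print(predicted_ratio, round(observed_ratio, 3))
```

The observed ratio (~0.887) is close to the predicted 0.875, consistent with the extra self-chunk being the main overhead.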

yanminjia commented 1 year ago

From the kernel function in reduce_scatter.h, the ReduceScatter collective is performed by a scatter (prims.scatter) and a recv (prims.recv). I guess the data on each GPU/rank is scattered and reduced in the intra-node NVSwitches and then sent back to the GPUs/ranks. If my logic is correct, to complete ReduceScatter, each GPU/rank needs only one send and one receive. By the way, could you please let me know how nvls->down and nvls->up are used in the kernel function? It looks like the values of the elements in nvls->down and nvls->up are greater than comm->nranks. Thanks.
