NVIDIA / nccl-tests

NCCL Tests
BSD 3-Clause "New" or "Revised" License

busbw exceeds network bandwidth (2 nodes, 16 GPUs, 100 Gbps Intel NIC, no NVSwitch): what algorithm is used? #196

Closed. ofilip closed this issue 6 months ago.

ofilip commented 6 months ago

We are running all-reduce tests with our GPU nodes.

HW setup:

When running the all_reduce test on 2 nodes, the reported busbw exceeds 100 Gbps. After a little thinking we concluded that all_reduce is done in some hierarchical fashion: it seems that each chunk is first reduced inside the nodes to spare network bandwidth, so the inter-node traffic is only algbw * (2*(n_nodes-1)/n_nodes) = algbw (for n_nodes = 2) instead of algbw * (2*(n-1)/n), where n is the total number of GPUs. After a little more research we learnt about the CollNet/SHARP algorithm, which seems related but is implemented only with NVSwitch, which is not present in our setup.

Any hints about what exactly is going on with all_reduce in our setup? How can I find details of the algorithm used for all_reduce?
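The bandwidth argument in the question can be made concrete with a back-of-the-envelope model. This is a sketch under stated assumptions, not NCCL's implementation: it uses the nccl-tests all-reduce relation busbw = algbw * 2(n-1)/n, assumes a single NIC per node carries all inter-node traffic, and ignores latency and protocol overhead. The function names are hypothetical.

```python
"""Idealized busbw bounds for all-reduce across nodes (toy model, not NCCL's code)."""


def allreduce_bus_factor(n: int) -> float:
    """nccl-tests all-reduce correction factor: busbw = algbw * 2(n-1)/n."""
    return 2 * (n - 1) / n


def max_busbw_flat_ring(nic_gbytes_per_s: float, n_gpus: int) -> float:
    """Upper bound on busbw when one flat ring spans all GPUs.

    Every ring link, including the one leaving the node through the NIC,
    carries algbw * 2(n-1)/n, so the NIC caps busbw itself.
    """
    max_algbw = nic_gbytes_per_s / allreduce_bus_factor(n_gpus)
    return max_algbw * allreduce_bus_factor(n_gpus)


def max_busbw_hierarchical(nic_gbytes_per_s: float, n_gpus: int, n_nodes: int) -> float:
    """Upper bound on busbw when data is reduced inside each node first.

    Only algbw * 2(n_nodes-1)/n_nodes crosses the network per node, so the NIC
    caps algbw rather than busbw (assuming intra-node links keep up).
    """
    max_algbw = nic_gbytes_per_s / allreduce_bus_factor(n_nodes)
    return max_algbw * allreduce_bus_factor(n_gpus)


if __name__ == "__main__":
    nic = 100 / 8  # assumed single 100 Gbps NIC per node, in GB/s
    print(f"flat ring bound:    {max_busbw_flat_ring(nic, 16):.2f} GB/s")
    print(f"hierarchical bound: {max_busbw_hierarchical(nic, 16, 2):.2f} GB/s")
```

For 16 GPUs on 2 nodes this prints 12.50 GB/s for the flat ring and 23.44 GB/s for the hierarchical case, i.e. an idealized busbw of up to 1.875x the NIC rate when the intra-node side is not the bottleneck.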

sjeaugey commented 6 months ago

On two nodes, NCCL uses the tree (or nvlstree) algorithm which can indeed go beyond the network bottleneck (provided the intra-node bandwidth can sustain that higher speed). If you want to benchmark your network performance through NCCL on 2 nodes, you may want to force NCCL_ALGO=RING.
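A sketch of how the default run and the forced-ring run might be launched side by side. It assumes Open MPI's `mpirun` (whose `-x` option exports an environment variable to every rank), an nccl-tests build at `./build/all_reduce_perf`, and hypothetical host names `node1`/`node2` with 8 GPUs each; the helper `all_reduce_cmd` is illustrative only. Adding `NCCL_DEBUG=INFO` makes NCCL log its initialization details, which may help when checking what it set up.

```python
"""Build mpirun command lines for nccl-tests (hypothetical helper, Open MPI syntax)."""

import shlex
from typing import List, Optional


def all_reduce_cmd(hosts: str, n_ranks: int, algo: Optional[str] = None,
                   debug: bool = False,
                   binary: str = "./build/all_reduce_perf") -> List[str]:
    cmd = ["mpirun", "-np", str(n_ranks), "-H", hosts]
    # Open MPI's -x exports the variable to all ranks on all hosts.
    if algo is not None:
        cmd += ["-x", f"NCCL_ALGO={algo}"]
    if debug:
        cmd += ["-x", "NCCL_DEBUG=INFO"]
    # nccl-tests: sweep 8 B to 8 GB, doubling each step, 1 GPU per rank.
    cmd += [binary, "-b", "8", "-e", "8G", "-f", "2", "-g", "1"]
    return cmd


if __name__ == "__main__":
    hosts = "node1:8,node2:8"  # hypothetical hosts, 8 slots (GPUs) each
    print(shlex.join(all_reduce_cmd(hosts, 16)))               # NCCL's own choice
    print(shlex.join(all_reduce_cmd(hosts, 16, algo="RING")))  # forced ring
```

Comparing the busbw columns of the two runs shows how much of the reported bandwidth comes from the algorithm choice rather than from the network itself.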

ofilip commented 6 months ago

Thanks!

skalyan commented 3 months ago

> On two nodes, NCCL uses the tree (or nvlstree) algorithm which can indeed go beyond the network bottleneck (provided the intra-node bandwidth can sustain that higher speed). If you want to benchmark your network performance through NCCL on 2 nodes, you may want to force NCCL_ALGO=RING.

Is this a change in recent NCCL versions? For two A100 nodes with InfiniBand HDR, we didn't have to set NCCL_ALGO=RING to benchmark network performance.

sjeaugey commented 3 months ago (Mon, Apr 8, 2024)

It's been there since NCCL 2.4. But depending on the platform and the number of GPUs per node, NCCL may or may not select the Tree algorithm on 2 nodes, depending on its performance. In particular, if you have 8 NICs, Tree will struggle to get better performance than Ring without using 32 SMs (which we disallow). But if you have fewer NICs, you'll probably see the "Tree effect".

On H100 we have NVLink SHARP so we can use the NVLSTree algorithm between 2 nodes which will give you higher BW and doesn't require that many SMs since the reductions are offloaded.
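The trade-off described here (Tree only wins if the intra-node side can sustain the extra bus bandwidth) can be illustrated with a toy model. This is not NCCL's tuner: `intra_cap` is a hypothetical stand-in for the reduction throughput the intra-node side can sustain (limited by the SMs NCCL allows, or raised when NVLink SHARP offloads reductions), and the numbers in the usage are made up for illustration.

```python
"""Toy Ring-vs-Tree comparison on 2 nodes (illustrative, not NCCL's tuning model)."""

from typing import Tuple


def bus_factor(k: int) -> float:
    """All-reduce factor 2(k-1)/k relating algbw to busbw."""
    return 2 * (k - 1) / k


def ring_bound(n_nics: int, nic_gbs: float) -> float:
    # Ring: busbw is capped by the aggregate NIC bandwidth of a node.
    return n_nics * nic_gbs


def tree_bound(n_nics: int, nic_gbs: float, n_gpus: int, n_nodes: int,
               intra_cap: float) -> float:
    # Tree-like: the network caps algbw, busbw then scales by 2(n-1)/n,
    # but never beyond what the intra-node side can sustain.
    network_limited = n_nics * nic_gbs / bus_factor(n_nodes) * bus_factor(n_gpus)
    return min(network_limited, intra_cap)


def pick(n_nics: int, nic_gbs: float, n_gpus: int, n_nodes: int,
         intra_cap: float) -> Tuple[str, float]:
    ring = ring_bound(n_nics, nic_gbs)
    tree = tree_bound(n_nics, nic_gbs, n_gpus, n_nodes, intra_cap)
    return ("tree", tree) if tree > ring else ("ring", ring)


if __name__ == "__main__":
    # One slow NIC per node: Tree's extra bus bandwidth is easy to sustain.
    print(pick(n_nics=1, nic_gbs=12.5, n_gpus=16, n_nodes=2, intra_cap=200.0))
    # Eight fast NICs per node: the hypothetical intra-node cap now limits Tree.
    print(pick(n_nics=8, nic_gbs=50.0, n_gpus=16, n_nodes=2, intra_cap=300.0))
```

The point of the `min()` is the condition from the first reply: Tree only goes beyond the network bottleneck while intra-node bandwidth can keep up.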

skalyan commented 3 months ago

That might explain why we see higher-than-network bandwidth with two H100 nodes, each with 8 IB NICs on InfiniBand NDR.

I tried NCCL_ALGO=ring, which brings the busbw down to 3xx GB/s from 4xx GB/s.
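A rough sanity check of these numbers, assuming each of the 8 NICs runs at the nominal InfiniBand NDR rate of 400 Gb/s and ignoring encoding and protocol overhead: the aggregate per-node network bandwidth is 400 GB/s, so a ring result in the 3xx GB/s range stays below it, while a 4xx GB/s result is consistent with an algorithm such as NVLSTree that sends less data over the network. The helper names are hypothetical.

```python
"""Nominal per-node network bandwidth vs. a measured busbw (rough check)."""


def aggregate_network_gbytes(n_nics: int, gbits_per_nic: float) -> float:
    """Nominal per-node network bandwidth in GB/s (8 bits per byte)."""
    return n_nics * gbits_per_nic / 8


def exceeds_network(busbw_gbytes: float, n_nics: int, gbits_per_nic: float) -> bool:
    """True if a measured busbw is above the nominal network bandwidth."""
    return busbw_gbytes > aggregate_network_gbytes(n_nics, gbits_per_nic)


if __name__ == "__main__":
    print(aggregate_network_gbytes(8, 400))  # 8 x NDR (400 Gb/s): 400.0 GB/s
    print(aggregate_network_gbytes(1, 100))  # 1 x 100 Gb/s NIC: 12.5 GB/s
```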
