NVIDIA / nccl

Optimized primitives for collective multi-GPU communication

all_reduce_perf hangs on DGX-H100 #1184

Open itzsimpl opened 7 months ago

itzsimpl commented 7 months ago

We have two systems, a DGX-H100 and a DGX-A100, both with the latest firmware, DGX OS and updates. We use Slurm 23.11.1 and pyxis/enroot 0.16.1/3.4.1. Both systems have NVIDIA driver 535.154.05. Each system is connected through a single NIC with both ports bonded over 100GbE: the slot 3 NIC on the DGX-H100 and the slot 4 NIC on the DGX-A100.

Tested with NCCL 2.18.3-1+cuda12.2, 2.19.3-1+cuda12.2 and 2.20.3-1+cuda12.3. All commands were run with:

NCCL_DEBUG=INFO NCCL_SOCKET_IFNAME=bond0 UCX_NET_DEVICES=bond0

With nccl 2.18.3-1+cuda12.2

With nccl 2.19.3-1+cuda12.2

With nccl 2.20.3-1+cuda12.3

a100:1590425:1591728 [0] bootstrap.cc:77 NCCL WARN Message truncated : received 2048 bytes instead of 256
a100:1590425:1591728 [0] NCCL INFO bootstrap.cc:561 -> 3
a100:1590425:1591728 [0] NCCL INFO transport.cc:140 -> 3
a100:1590425:1591728 [0] NCCL INFO init.cc:1220 -> 3
a100:1590425:1591728 [0] NCCL INFO init.cc:1500 -> 3
a100:1590425:1591728 [0] NCCL INFO group.cc:64 -> 3 [Async thread]
a100:1590425:1591733 [5] NCCL INFO bootstrap.cc:557 -> 3
a100:1590425:1591731 [3] NCCL INFO bootstrap.cc:557 -> 3
a100:1590425:1591735 [7] NCCL INFO bootstrap.cc:557 -> 3
a100:1590425:1591735 [7] NCCL INFO transport.cc:246 -> 3
a100:1590425:1591729 [1] NCCL INFO bootstrap.cc:557 -> 3
a100:1590425:1591735 [7] NCCL INFO init.cc:1220 -> 3
a100:1590425:1591729 [1] NCCL INFO transport.cc:139 -> 3
a100:1590425:1591734 [6] NCCL INFO bootstrap.cc:557 -> 3
a100:1590425:1591735 [7] NCCL INFO init.cc:1500 -> 3
a100:1590425:1591730 [2] NCCL INFO bootstrap.cc:557 -> 3
a100:1590425:1591729 [1] NCCL INFO init.cc:1220 -> 3
a100:1590425:1591730 [2] NCCL INFO transport.cc:139 -> 3
a100:1590425:1591731 [3] NCCL INFO transport.cc:244 -> 3
a100:1590425:1591734 [6] NCCL INFO transport.cc:246 -> 3
a100:1590425:1591735 [7] NCCL INFO group.cc:64 -> 3 [Async thread]
a100:1590425:1591733 [5] NCCL INFO transport.cc:139 -> 3
a100:1590425:1591731 [3] NCCL INFO init.cc:1220 -> 3
a100:1590425:1591732 [4] NCCL INFO bootstrap.cc:557 -> 3
a100:1590425:1591734 [6] NCCL INFO init.cc:1220 -> 3
a100:1590425:1591731 [3] NCCL INFO init.cc:1500 -> 3
a100:1590425:1591734 [6] NCCL INFO init.cc:1500 -> 3
a100:1590425:1591731 [3] NCCL INFO group.cc:64 -> 3 [Async thread]
a100:1590425:1591734 [6] NCCL INFO group.cc:64 -> 3 [Async thread]
a100:1590425:1591733 [5] NCCL INFO init.cc:1230 -> 3
a100:1590425:1591729 [1] NCCL INFO init.cc:1500 -> 3
a100:1590425:1591729 [1] NCCL INFO group.cc:64 -> 3 [Async thread]
a100:1590425:1591733 [5] NCCL INFO init.cc:1500 -> 3
a100:1590425:1591730 [2] NCCL INFO init.cc:1220 -> 3
a100:1590425:1591733 [5] NCCL INFO group.cc:64 -> 3 [Async thread]
a100:1590425:1591732 [4] NCCL INFO transport.cc:140 -> 3
a100:1590425:1591730 [2] NCCL INFO init.cc:1500 -> 3
a100:1590425:1591730 [2] NCCL INFO group.cc:64 -> 3 [Async thread]
a100:1590425:1591732 [4] NCCL INFO init.cc:1230 -> 3
a100:1590425:1591732 [4] NCCL INFO init.cc:1500 -> 3
a100:1590425:1591732 [4] NCCL INFO group.cc:64 -> 3 [Async thread]
a100:1590425:1590425 [7] NCCL INFO group.cc:418 -> 3
a100:1590425:1590425 [7] NCCL INFO group.cc:95 -> 3

a100: Test NCCL failure common.cu:961 'internal error - please report this issue to the NCCL developers / '
 .. a100 pid 1590425: Test failure common.cu:844
h100:3546566:3547692 [0] NCCL INFO Channel 03/0 : 8[0] -> 15[7] via P2P/direct pointer
...


Adding `NCCL_NVLS_ENABLE=0` resolves the crash, but the behaviour is then the same as with the older versions of NCCL. Only the combination `NCCL_NVLS_ENABLE=0 NCCL_PROTO=SIMPLE NCCL_ALGO=Tree` works. In fact, with these settings even the older versions of NCCL work over 2 nodes with 8 GPUs each.

We know this is not a traditional setup, but we do not see any such issues when running the DGX-A100 together with other A100-based nodes (A100 PCIe or 4x A100 SXM4). The issue only appears when the DGX-H100 is involved.

Any explanation will be appreciated.
sjeaugey commented 7 months ago

Thanks for the report. In theory, NCCL should detect that some nodes have features that others don't have (e.g. NVLS) and all ranks should do the same thing. Here it looks like we're not doing that properly and we have trouble synchronizing the algorithm config between nodes, causing various issues.

Would you be able to share a full log (as an attached txt file), ideally with a different file for each rank, so that we can see what each rank is computing in terms of rings/trees/etc. and what they try to connect. Please run with NCCL_DEBUG_SUBSYS=INIT,ENV,GRAPH NCCL_DEBUG=INFO NCCL_DEBUG_FILE=out.%h.%p.

itzsimpl commented 7 months ago

@sjeaugey thank you for the reply. I've attached the full logs, let me know if you need anything else.

out.h100.558869.txt out.a100.2506996.txt

sjeaugey commented 7 months ago

Ok, understood. It seems there are a couple of places in ncclTopoPostSet where we modify the number of channels based on factors that may not be the same between two nodes, causing trouble. [One such occurrence is here: https://github.com/NVIDIA/nccl/blob/master/src/graph/connect.cc#L456, but I'm not sure it impacts us because bwIntra for Ring is 24 (from the log), so we should be fine. EDIT: this one is not a problem because comm->minCompCap should be consistent between nodes.] Another occurrence is this one: https://github.com/NVIDIA/nccl/blob/master/src/graph/connect.cc#L472-L475, which effectively duplicates the rings on the H100 node to use 16 channels instead of 2. We could solve that in multiple ways, one being to set comm->nvlsChannels = 0 if comm->nvlsSupport is 0. Can you try to apply this patch and see what happens:

diff --git a/src/init.cc b/src/init.cc
index b61068a3b..042fd2a91 100644
--- a/src/init.cc
+++ b/src/init.cc
@@ -1162,7 +1162,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p
       graphs[a]->typeInter = std::max(allGather3Data[i].graphInfo[a].typeInter, graphs[a]->typeInter);
     }
     if (graphs[NCCL_ALGO_COLLNET_CHAIN]->nChannels == 0) comm->collNetSupport = 0;
-    if (graphs[NCCL_ALGO_NVLS]->nChannels == 0) comm->nvlsSupport = 0;
+    if (graphs[NCCL_ALGO_NVLS]->nChannels == 0) comm->nvlsSupport = comm->nvlsChannels = 0;
   }

   comm->nChannels = treeGraph.nChannels = ringGraph.nChannels = std::min(treeGraph.nChannels, ringGraph.nChannels);
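To see why zeroing nvlsSupport alone is not enough, here is a deliberately simplified model (not NCCL's actual code). It assumes that the channel duplication linked above amounts to raising nChannels to at least nvlsChannels; the struct, the function name finalizeChannels and that max() rule are hypothetical, while the 2 vs 16 channel figures come from the comment above.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical per-node view of the fields touched by the patch (illustrative values only).
struct NodeComm {
  int nvlsGraphChannels;  // NVLS graph channels found locally (0 on a node without NVLS)
  int nvlsSupport;        // local NVLS capability before the nodes agree
  int nvlsChannels;       // channels NVLS would like to use on this node
  int nChannels;          // ring/tree channels, already identical on all nodes
};

// Simplified model: agree on NVLS across nodes, then "duplicate" channels up to nvlsChannels.
// The max() duplication rule is an assumption standing in for the connect.cc code linked above.
void finalizeChannels(std::vector<NodeComm>& nodes, bool applyPatch) {
  assert(!nodes.empty());
  int minNvls = nodes[0].nvlsGraphChannels;
  for (const NodeComm& n : nodes) minNvls = std::min(minNvls, n.nvlsGraphChannels);
  for (NodeComm& n : nodes) {
    if (minNvls == 0) {
      n.nvlsSupport = 0;                   // what the original line already did
      if (applyPatch) n.nvlsChannels = 0;  // the extra assignment added by the patch
    }
    n.nChannels = std::max(n.nChannels, n.nvlsChannels);
  }
}
```

With an H100-like node {16, 1, 16, 2} and an A100-like node {0, 0, 0, 2}, the unpatched model leaves the nodes at 16 and 2 channels, while applyPatch keeps both at 2, which is the consistency the patch is after.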

Thanks!

itzsimpl commented 7 months ago

The patch does fix the crash, but now the test gets stuck before size 8 (see files that end with _patch). I left it running for approximately 1h; all I noticed was that all the GPUs were stuck at 100% utilisation.

out.h100.1234336_patch.txt out.a100.2957751_patch.txt

If I add just NCCL_PROTO=Simple it gets stuck as well (see files that end with _proto=simple).

out.h100.1257816_proto=simple.txt out.a100.2977687_proto=simple.txt

If I add just NCCL_ALGO=Tree, the test crashes at 1M (see files that end with _algo=tree):

      524288        131072     float     sum      -1   1102.0    0.48    0.89      0   1097.4    0.48    0.90      0
h100: Test NCCL failure common.cu:303 'invalid usage (run with NCCL_DEBUG=WARN for details) / '
 .. h100 pid 1245713: Test failure common.cu:401
 .. h100 pid 1245713: Test failure common.cu:414
 .. h100 pid 1245713: Test failure common.cu:603
 .. h100 pid 1245713: Test failure all_reduce.cu:90
 .. h100 pid 1245713: Test failure common.cu:615
 .. h100 pid 1245713: Test failure common.cu:1019
 .. h100 pid 1245713: Test failure common.cu:844
     1048576        262144     float     sum      -1
a100: Test NCCL failure common.cu:303 'remote process exited or there was a network error / '
 .. a100 pid 2967540: Test failure common.cu:401
 .. a100 pid 2967540: Test failure common.cu:414
 .. a100 pid 2967540: Test failure common.cu:603
 .. a100 pid 2967540: Test failure all_reduce.cu:90
 .. a100 pid 2967540: Test failure common.cu:615
 .. a100 pid 2967540: Test failure common.cu:1019
 .. a100 pid 2967540: Test failure common.cu:844
srun: error: h100: task 1: Exited with exit code 3

out.h100.1245713_algo=tree.txt out.a100.2967540_algo=tree.txt

If I add both NCCL_ALGO=Tree and NCCL_PROTO=Simple, the test succeeds (see files that end with _algo=tree+proto=simple).

out.h100.1253759_algo=tree+proto=simple.txt out.a100.2973888_algo=tree+proto=simple.txt

sjeaugey commented 7 months ago

Thanks for the feedback. It looks like the tuning differs between the two nodes, so they don't use the same algorithm for a given size, causing hangs and errors. Forcing a single Proto+Algo solves the problem, as we're then sure all ranks are consistent in their decisions.

Can you run with NCCL_DEBUG=INFO NCCL_DEBUG_SUBSYS=TUNING? That would print the tuning for each rank and we can look for differences.

itzsimpl commented 7 months ago

I'm attaching the logs when run with just NCCL_DEBUG_SUBSYS=TUNING (files that end with _tuning) and with NCCL_DEBUG_SUBSYS=TUNING NCCL_ALGO=Tree (files that end with _tuning+algo=tree).

out.a100.3594394_tuning.txt out.h100.2102563_tuning.txt out.a100.3603123_tuning+algo=tree.txt out.h100.2113054_tuning+algo=tree.txt

sjeaugey commented 7 months ago

Ah, my apologies. Only rank 0 prints the tuning, due to this line: https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc#L341

Can you replace comm->rank == 0 with 1 and dump the tuning again for all ranks? Sorry about that.

itzsimpl commented 7 months ago

I've changed the line to comm->rank >= 0 to dump on all ranks. I'm attaching both logs.

out.h100.2172385_tuning.txt out.a100.3649053_tuning.txt out.a100.3654218_tuning+algo=tree.txt out.h100.2178475_tuning+algo=tree.txt

itzsimpl commented 6 months ago

@sjeaugey did you have a chance to look into this?

sjeaugey commented 6 months ago

Sorry I missed it. So indeed, both the Tree bandwidth and the Ring latency differ between the nodes.

For the ring latency, it could be because the CPU is different on the different nodes, causing NCCL to compute a different net overhead. As a workaround, setting NCCL_NET_OVERHEAD=2000 should harmonize the latencies for both nodes.
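The workaround can be pictured with a small sketch of a per-node overhead choice. It assumes, as described later in this thread, that an AMD host gets a netOverhead of 2 and every other host gets 1; the function netOverhead, its envValue parameter and the 0.001 scale (picked so that NCCL_NET_OVERHEAD=2000 matches the AMD value) are illustrative assumptions, not NCCL's API.

```cpp
#include <cassert>
#include <cstdlib>

enum class CpuVendor { Intel, Amd, Other };

// Sketch of the per-node netOverhead choice: AMD hosts get 2, every other host gets 1.
// envValue is the raw NCCL_NET_OVERHEAD string, e.g. std::getenv("NCCL_NET_OVERHEAD"),
// or nullptr when unset. The 0.001 scale is an assumption (2000 -> 2).
double netOverhead(CpuVendor localVendor, const char* envValue) {
  if (envValue != nullptr) return std::atof(envValue) * 0.001;  // override wins on every node
  return localVendor == CpuVendor::Amd ? 2.0 : 1.0;             // otherwise depends on local CPU only
}
```

Because the override short-circuits the vendor check, every node given the same string ends up with the same value, which is what harmonizes the Ring latency.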

For the Tree bandwidth, I still haven't found where the difference comes from. I have found another place where we apply a different tuning for AMD CPUs, but it doesn't explain the Tree/Simple BW difference. We'd need to track how the Tree BW is computed in src/graph/tuning.cc and see where things differ.

itzsimpl commented 6 months ago

The systems are busy at the moment; I will run with NCCL_NET_OVERHEAD=2000 as soon as they free up. Let me know what more I can do to help (e.g. track down the computation in src/graph/tuning.cc).

itzsimpl commented 6 months ago

@sjeaugey I've added NCCL_NET_OVERHEAD=2000, but there is no difference. As soon as I run with 2 or more GPUs, the test gets stuck before size 8.

I've modified line https://github.com/NVIDIA/nccl/blob/b6475625fbcaa2c3c0e50eed2fa1255d7514d4a2/src/enqueue.cc#L1498 to print the info on all ranks and collected the logs when running with 2 GPUs. The A100 selects algorithm 1, protocol 2, while the H100 selects algorithm 0, protocol 2.

out.h100.1812047_tuning+2000_2.txt out.a100.619343_tuning+2000_2.txt

sjeaugey commented 6 months ago

Yes that's expected. As I mentioned the NET_OVERHEAD should fix the difference in ring latency (and the logs you attached confirm it does!), but I still don't know why we get different Tree bandwidth:

h100 [0] NCCL INFO     AllReduce |    18.0/   5.4 |    33.5/   0.0 |   112.0/  20.7 |    21.4/   4.0 |    34.0/   0.0 |    76.4/  16.0 |     0.8/   0.0 |     0.8/   0.0 |    39.2/   0.0 |
a100 [0] NCCL INFO     AllReduce |    18.0/   3.1 |    33.5/   0.0 |   112.0/  11.8 |    21.4/   4.0 |    34.0/   0.0 |    76.4/  16.0 |     0.8/   0.0 |     0.8/   0.0 |    39.2/   0.0 |

(3.1 vs 5.4 with Tree/LL or 11.8 vs 20.7 with Tree/Simple)
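These two lines are enough to show how a hang can follow from a bandwidth mismatch. The sketch assumes that the first three cells are Tree LL/LL128/Simple and the next three Ring LL/LL128/Simple (consistent with the Tree/LL and Tree/Simple figures quoted above), and it uses a simplified cost model, time = lat + bytes / (1000 * bw), which approximates but is not NCCL's full cost function. pickAlgo and the table names are hypothetical.

```cpp
#include <cassert>
#include <limits>
#include <string>
#include <vector>

struct AlgoProto {
  std::string name;
  double lat;  // latency column of the TUNING dump
  double bw;   // bandwidth column of the TUNING dump (0 means unavailable)
};

// AllReduce Tree/Ring entries copied from the two rank-0 lines above.
const std::vector<AlgoProto> kH100 = {{"Tree/LL", 18.0, 5.4}, {"Tree/LL128", 33.5, 0.0},
                                      {"Tree/Simple", 112.0, 20.7}, {"Ring/LL", 21.4, 4.0},
                                      {"Ring/LL128", 34.0, 0.0}, {"Ring/Simple", 76.4, 16.0}};
const std::vector<AlgoProto> kA100 = {{"Tree/LL", 18.0, 3.1}, {"Tree/LL128", 33.5, 0.0},
                                      {"Tree/Simple", 112.0, 11.8}, {"Ring/LL", 21.4, 4.0},
                                      {"Ring/LL128", 34.0, 0.0}, {"Ring/Simple", 76.4, 16.0}};

// Simplified cost model (assumption): time = lat + bytes / (1000 * bw); the cheapest entry wins.
std::string pickAlgo(const std::vector<AlgoProto>& table, double bytes) {
  assert(!table.empty());
  std::string best = "none";
  double bestTime = std::numeric_limits<double>::infinity();
  for (const AlgoProto& e : table) {
    if (e.bw <= 0.0) continue;  // skip disabled combinations
    double t = e.lat + bytes / (1000.0 * e.bw);
    if (t < bestTime) {
      bestTime = t;
      best = e.name;
    }
  }
  return best;
}
```

In this model both nodes agree on Tree/LL at 1024 bytes, but at 65536 bytes the H100 table still favours Tree/LL while the A100 table switches to Ring/LL, so the two nodes would launch different algorithms for the same operation.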

sjeaugey commented 6 months ago

Could you apply this patch and see what we get?

diff --git a/src/graph/tuning.cc b/src/graph/tuning.cc
index 7ca5922..ff5812e 100644
--- a/src/graph/tuning.cc
+++ b/src/graph/tuning.cc
@@ -216,6 +216,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
           if (a == NCCL_ALGO_RING) ratio *= (1.0 * nRanks) / nsteps;
           else if (a == NCCL_ALGO_NVLS || a == NCCL_ALGO_NVLS_TREE) ratio *= 5.0/6.0;
           else ratio *= .5;
+printf("Coll %d Algo %d proto %d Channels %d Bw intra %g inter %g busBw %g ratio %g index1 %d index2 %d compCapIndex %d perChMaxTreeBw %g\n", coll, a, p, graphs[a]->nChannels, graphs[a]->bwIntra, graphs[a]->bwInter, busBw, ratio, index1, index2, compCapIndex, perChMaxTreeBw);
           busBw *= ratio;
         }
         comm->bandwidths[coll][a][p] = busBw;
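When reading the printf output, it may help that the ratio applied right after it depends only on the algorithm, nRanks and nsteps, which should be identical on all ranks. The sketch below reproduces just the three branches visible in the diff context (the enum and function names are hypothetical, and the enclosing condition is omitted). For Tree the ratio is a constant 0.5, so a Tree bandwidth difference between nodes presumably comes from the inputs printed alongside it, such as compCapIndex or perChMaxTreeBw.

```cpp
#include <cassert>

// Illustrative algorithm ids; only the three branches from the diff context are modelled.
enum Algo { ALGO_TREE, ALGO_RING, ALGO_NVLS, ALGO_NVLS_TREE, ALGO_COLLNET_CHAIN };

// Same branch structure as the context lines of the patch.
double busBwRatio(Algo a, int nRanks, int nsteps) {
  assert(nsteps > 0);
  double ratio = 1.0;
  if (a == ALGO_RING) ratio *= (1.0 * nRanks) / nsteps;                 // Ring: nRanks / nsteps
  else if (a == ALGO_NVLS || a == ALGO_NVLS_TREE) ratio *= 5.0 / 6.0;  // NVLS variants: 5/6
  else ratio *= .5;                                                      // everything else, incl. Tree
  return ratio;
}
```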
itzsimpl commented 6 months ago

@sjeaugey find the logs attached out.h100.3901699_tuning+2000_2_p4.txt out.a100.849595_tuning+2000_2_p4.txt

sjeaugey commented 6 months ago

Ok. Weird. Looks like minCompCap is not actually the min of the CUDA compute capabilities of all ranks, resulting in compCapIndex being computed differently between the different nodes.

However, minCompCap/maxCompCap are supposed to be a min/max of all compCap of all ranks and should be computed early enough that by the time we reach ncclTopoTuneModel, they are set to correct values.

Could you add printfs to the min/maxCompCap computation code above to see if anything is not working as expected?

Edit: nevermind, there is an obvious bug, as we're checking comm->peerInfo[rank].compCap instead of comm->peerInfo[i].compCap... :facepalm:
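The bug pattern is easy to reproduce in isolation. In the hypothetical sketch below (PeerInfo, compCapRangeBuggy, compCapRange and compCapIndex are illustrative, not NCCL's code; 80 and 90 are the A100 and H100 compute capabilities), the buggy loop reads the local rank's entry on every iteration, so each node reports its own compute capability as both min and max, and any table index derived from minCompCap then differs between nodes. The bucket thresholds in compCapIndex are an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

struct PeerInfo { int compCap; };  // e.g. 80 for A100, 90 for H100

// Buggy pattern: the loop runs over all peers but always reads peerInfo[rank],
// so each rank ends up with its own compute capability as both min and max.
void compCapRangeBuggy(const std::vector<PeerInfo>& peerInfo, int rank, int* minCC, int* maxCC) {
  assert(rank >= 0 && static_cast<std::size_t>(rank) < peerInfo.size());
  *minCC = std::numeric_limits<int>::max();
  *maxCC = std::numeric_limits<int>::min();
  for (std::size_t i = 0; i < peerInfo.size(); i++) {
    *minCC = std::min(*minCC, peerInfo[rank].compCap);
    *maxCC = std::max(*maxCC, peerInfo[rank].compCap);
  }
}

// Fixed pattern: read peerInfo[i], so every rank computes the same global range.
void compCapRange(const std::vector<PeerInfo>& peerInfo, int* minCC, int* maxCC) {
  assert(!peerInfo.empty());
  *minCC = std::numeric_limits<int>::max();
  *maxCC = std::numeric_limits<int>::min();
  for (std::size_t i = 0; i < peerInfo.size(); i++) {
    *minCC = std::min(*minCC, peerInfo[i].compCap);
    *maxCC = std::max(*maxCC, peerInfo[i].compCap);
  }
}

// Hypothetical bucketing of minCompCap into a tuning-table index (0 = Volta, 1 = Ampere, 2 = Hopper).
int compCapIndex(int minCompCap) { return minCompCap >= 90 ? 2 : minCompCap >= 80 ? 1 : 0; }
```

With peers {80, 80, 90, 90}, the buggy version returns 80/80 on an A100 rank and 90/90 on an H100 rank, so compCapIndex comes out as 1 on one node and 2 on the other; the fixed version returns 80/90 everywhere.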

itzsimpl commented 6 months ago

@sjeaugey I can confirm this resolves the hang, but only with NCCL_NET_OVERHEAD=2000 set. When it is not set, the test starts but hangs at the 4096/8192-byte mark. I'm attaching the tuning logs (see *_tuning_2000_p5.txt for the NCCL_NET_OVERHEAD=2000 case and *_tuning_p5.txt without).

out.h100.235978_tuning_2000_2_p5.txt out.a100.24035_tuning+2000_2_p5.txt out.h100.240808_tuning_2_p5.txt out.a100.28181_tuning_2_p5.txt

sjeaugey commented 6 months ago

Yes, that's expected. Without NET_OVERHEAD set, the different CPU will also cause a tuning mismatch.

The minCompCap/maxCompCap bug is easy to fix. The difference in CPU is harder (we need to share that between all ranks and harmonize).

itzsimpl commented 6 months ago

I see. From what I can understand from the code, there are only two locations where the CPU vendor is taken into account: to compute llMaxBw and netOverhead. This happens in tuning.cc at lines 145 and 111-118. In both cases the value is based on the local CPU vendor.

I have a very limited understanding of the logic, but in a multi-node setting, wouldn't it be more reasonable to consider the "weakest link", i.e. check whether an AMD CPU is present on any of the ranks? Currently AMD is the only vendor treated as "special"; it is also the only one that sets netOverhead to 2 (all others get 1), which leads to a misconfiguration in a mixed cluster. Is the peers' CPU information available anywhere, or is only local information at hand?

sjeaugey commented 6 months ago

Yes, indeed, we should share the CPU info between all ranks and take the slowest option across all of them. But if some future code reverses the situation (AMD faster than Intel, or another vendor entirely), then taking the "min" isn't that easy.

One option would be to compute the "CPU global vendor", and have a new value "MIXED", which will be handled everywhere as the worst case.
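One way the "MIXED" idea could look, as a hypothetical sketch: a gathered list of per-rank vendors that disagrees collapses to Mixed, and Mixed maps to the worst-case value. The enum, both function names and the rule "any disagreement means Mixed" are assumptions; the 2 (AMD) vs 1 (other) netOverhead values come from the comment above.

```cpp
#include <cassert>
#include <vector>

enum class CpuVendor { Intel, Amd, Other, Mixed };

// One possible definition (an assumption): the shared vendor if every rank reports
// the same one, otherwise Mixed.
CpuVendor globalCpuVendor(const std::vector<CpuVendor>& perRank) {
  assert(!perRank.empty());
  for (CpuVendor v : perRank)
    if (v != perRank[0]) return CpuVendor::Mixed;
  return perRank[0];
}

// Mixed is treated as the worst case; for netOverhead that is the AMD value from the thread.
double netOverheadFor(CpuVendor v) {
  return (v == CpuVendor::Amd || v == CpuVendor::Mixed) ? 2.0 : 1.0;
}
```

Because every rank sees the same gathered list, every rank derives the same global vendor, so the tuning stays consistent even if the per-vendor values change later; only the Mixed mapping needs revisiting.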

itzsimpl commented 6 months ago

What defines the "CPU global vendor", i.e. when is the system considered "MIXED"? Is one CPU of a "specific" vendor enough?

Yet another possibility is to go the same route as with compCap: instead of (or alongside) passing the "CPU vendor", you could compute netOverhead locally and pass that value to the other nodes. The per-variable route has the advantage that the nodes can then settle on the min, max or average on a per-variable basis.
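The per-variable route can be sketched as an allgather of locally computed values followed by the same reduction on every rank. The struct LocalTuning, agreeOnTuning and the choice of max for netOverhead and min for llMaxBw (the worst case for each) are illustrative assumptions, the allgather is simulated with a plain vector, and the example values are made up.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical per-rank tuning inputs computed from local hardware only.
struct LocalTuning {
  double netOverhead;  // e.g. 2 on AMD hosts, 1 otherwise (per the thread)
  double llMaxBw;      // local LL bandwidth cap
};

// After an allgather of LocalTuning (simulated here by a vector), every rank applies the
// same per-variable reduction, so all ranks end up with identical values.
LocalTuning agreeOnTuning(const std::vector<LocalTuning>& all) {
  assert(!all.empty());
  LocalTuning agreed = all[0];
  for (const LocalTuning& t : all) {
    agreed.netOverhead = std::max(agreed.netOverhead, t.netOverhead);  // worst case: highest overhead
    agreed.llMaxBw = std::min(agreed.llMaxBw, t.llMaxBw);              // worst case: lowest cap
  }
  return agreed;
}
```

Compared with a single vendor flag, each variable gets its own reduction rule, at the cost of exchanging one value per variable.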