"distributed" NCCL tests fail when having more than 3 GPUs #46248

Open Flamefire opened 4 years ago

Flamefire commented 4 years ago

🐛 Bug

When running the distributed NCCL tests on a system with more than 3 GPUs available, the tests fail with a "connection refused" error.

Tracing that down via NCCL_DEBUG=INFO reveals the warning "bootstrap.cc:190 NCCL WARN Bootstrap Root : mismatch in rank count from procs 6 : 3", which hints that 6 GPUs are available and hence 6 ranks are expected, but only 3 processes are started, which is indeed what happens. See https://github.com/pytorch/pytorch/blob/v1.7.0-rc1/test/run_test.py#L181

As a test I increased the WORLD_SIZE to 6 and the tests succeeded.

To Reproduce

Steps to reproduce the behavior:

  1. Use a system with at least 4 GPUs
  2. Run NCCL_DEBUG="INFO" TEMP_DIR=/tmp/tmp74prol6l BACKEND=nccl INIT_METHOD=env:// WORLD_SIZE=3 TEST_REPORT_SOURCE_OVERRIDE=dist-nccl python distributed/test_distributed_spawn.py --verbose TestDistBackendWithFork.test_DistributedDataParallel (extracted from run_test.py)

Relevant part of the output:

taurusml26:67298:67298 [0] NCCL INFO Launch mode Parallel
/tmp/install_pt/lib/python3.7/site-packages/torch/nn/parallel/distributed.py:448: UserWarning: Single-Process Multi-GPU is not the recommended mode for DDP. In this mode, each DDP instance operates on multiple devices and creates multiple module replicas within one process. The overhead of scatter/gather and GIL contention in every forward pass can slow down training. Please consider using one DDP instance per device or per module replica by explicitly setting device_ids or CUDA_VISIBLE_DEVICES. 
  "Single-Process Multi-GPU is not the recommended mode for "
taurusml26:67298:67352 [0] NCCL INFO Channel 00/04 :    0   1
taurusml26:67298:67353 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/64
taurusml26:67298:67352 [0] NCCL INFO Channel 01/04 :    0   1
taurusml26:67298:67353 [1] NCCL INFO Trees [0] -1/-1/-1->1->0|0->1->-1/-1/-1 [1] 0/-1/-1->1->-1|-1->1->0/-1/-1 [2] -1/-1/-1->1->0|0->1->-1/-1/-1 [3] 0/-1/-1->1->-1|-1->1->0/-1/-1
taurusml26:67298:67352 [0] NCCL INFO Channel 02/04 :    0   1
taurusml26:67298:67353 [1] NCCL INFO Setting affinity for GPU 1 to ffffff,ffffffff,ffffffff
taurusml26:67298:67352 [0] NCCL INFO Channel 03/04 :    0   1
taurusml26:67298:67352 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/64
taurusml26:67298:67352 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1|-1->0->1/-1/-1 [1] -1/-1/-1->0->1|1->0->-1/-1/-1 [2] 1/-1/-1->0->-1|-1->0->1/-1/-1 [3] -1/-1/-1->0->1|1->0->-1/-1/-1
taurusml26:67298:67352 [0] NCCL INFO Setting affinity for GPU 0 to ffffff,ffffffff,ffffffff
taurusml26:67298:67352 [0] NCCL INFO Channel 00 : 0[404000] -> 1[405000] via P2P/direct pointer
taurusml26:67298:67353 [1] NCCL INFO Channel 00 : 1[405000] -> 0[404000] via P2P/direct pointer
taurusml26:67298:67353 [1] NCCL INFO Channel 01 : 1[405000] -> 0[404000] via P2P/direct pointer
taurusml26:67298:67352 [0] NCCL INFO Channel 01 : 0[404000] -> 1[405000] via P2P/direct pointer
taurusml26:67298:67353 [1] NCCL INFO Channel 02 : 1[405000] -> 0[404000] via P2P/direct pointer
taurusml26:67298:67352 [0] NCCL INFO Channel 02 : 0[404000] -> 1[405000] via P2P/direct pointer
taurusml26:67298:67353 [1] NCCL INFO Channel 03 : 1[405000] -> 0[404000] via P2P/direct pointer
taurusml26:67298:67352 [0] NCCL INFO Channel 03 : 0[404000] -> 1[405000] via P2P/direct pointer
taurusml26:67298:67353 [1] NCCL INFO 4 coll channels, 4 p2p channels, 4 p2p channels per peer
taurusml26:67298:67352 [0] NCCL INFO 4 coll channels, 4 p2p channels, 4 p2p channels per peer
taurusml26:67298:67353 [1] NCCL INFO comm 0x200400000e00 rank 1 nranks 2 cudaDev 1 busId 405000 - Init COMPLETE
taurusml26:67298:67352 [0] NCCL INFO comm 0x200354000e00 rank 0 nranks 2 cudaDev 0 busId 404000 - Init COMPLETE
taurusml26:67298:67298 [0] NCCL INFO Launch mode Group/CGMD

taurusml26:67298:67367 [0] bootstrap.cc:190 NCCL WARN Bootstrap Root : mismatch in rank count from procs 6 : 3
taurusml26:67300:67371 [4] NCCL INFO Call to connect returned Connection refused, retrying
(the previous line repeats 18 more times)

taurusml26:67300:67371 [4] include/socket.h:403 NCCL WARN Connect to 10.1.148.186<38911> failed : Connection refused

Environment

cc @mruberry @VitalyFedyunin @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar @jiayisuse @agolynski

rohan-varma commented 4 years ago

Thanks for reporting! I can confirm I can repro this on a 4-GPU machine, and it indeed looks like we should have WORLD_SIZE = no. of GPUs used in the test. I'm not sure why we hardcode that value to '3' in run_test.py.

jiayisuse commented 4 years ago

We recently checked in a batch of NCCL changes. Not sure if this failure is related. @osalpekar @mingzhe09088

Flamefire commented 4 years ago

IMO the usage of only 3 ranks makes sense and PyTorch should handle this. Given that multi-GPU per process seems to be deprecated in DDP, a valid fix would include limiting the test to 1 GPU per process. But that should still work even when more GPUs are available.
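For illustration, a minimal sketch of what one-GPU-per-process DDP could look like in a spawned test, assuming the usual rank-to-device mapping (rank i uses cuda:i); the tiny model and names are placeholders, not the actual test code:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def run(rank, world_size):
    # Assumed single-node setup; address/port are placeholders.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)              # pin exactly one GPU to this rank
    model = torch.nn.Linear(10, 10).cuda(rank)
    ddp = DDP(model, device_ids=[rank])      # single-device DDP, no replicas
    ddp(torch.randn(4, 10, device=rank)).sum().backward()
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()   # one process per visible GPU
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```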

If you can point me at the commits, I can try whether they help.

mingzhe09088 commented 4 years ago

I don't think we are hardcoding the number of GPUs to 2 or 3. For the failed test, we do check all available GPUs and assign them evenly among processes here: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/distributed/distributed_test.py#L2655.

rohan-varma commented 4 years ago

I don't think we are hardcoding the number of GPUs to 2 or 3. For the failed test, we do check all available GPUs and assign them evenly among processes here: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/distributed/distributed_test.py#L2655.

We do assign GPUs evenly to processes, but isn't the issue that only 3 processes are started instead of the expected 6? It seems that the issue comes from here: https://github.com/pytorch/pytorch/blob/v1.7.0-rc1/test/run_test.py#L181 which sets the WORLD_SIZE (the number of processes that we need to spawn).
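If the intent is to use every visible GPU with one process each, the change being discussed would presumably look roughly like the sketch below (not the actual run_test.py code; the fallback value of 3 is an assumption):

```python
# Sketch only: derive WORLD_SIZE from the visible GPUs instead of
# hardcoding it. The fallback of 3 is an assumption, not the real default.
import os
import torch

world_size = torch.cuda.device_count() if torch.cuda.is_available() else 3
os.environ["WORLD_SIZE"] = str(world_size)
```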

Flamefire commented 4 years ago

We do assign GPUs evenly to processes, but isn't the issue that only 3 processes are started instead of the expected 6?

See my comment above (https://github.com/pytorch/pytorch/issues/46248#issuecomment-707946886): using 3 processes on a system with 6 GPUs should be fully valid. It can be discussed whether that should use only 3 GPUs (1 per process) or all 6.

It would be good to know where and why NCCL tries to use 6 ranks if only 3 are participating.

Flamefire commented 4 years ago

OK, some more (printf) debugging: I recorded the invocations of ProcessGroupNCCL::getNCCLComm and see that in my case the first process uses devices 0 and 1, the second uses device 2 and the third uses device 4. So for some reason only the first process "remembers" to use 2 devices and expects the same from the others, which somehow forgot it.
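For context, the block-wise device assignment that would produce this pattern can be approximated as follows (a sketch assuming 6 GPUs and 3 ranks, not the test's actual code; a single-tensor collective only touches the first device of each rank, i.e. 0, 2 and 4, which matches the trace above):

```python
# Approximation of a block-wise GPU assignment; values assume the 6-GPU
# machine from the report and WORLD_SIZE=3.
n_gpus, world_size = 6, 3
per_rank = n_gpus // world_size  # 2 GPUs per rank

for rank in range(world_size):
    devices = list(range(rank * per_rank, (rank + 1) * per_rank))
    # rank 0 -> [0, 1], rank 1 -> [2, 3], rank 2 -> [4, 5];
    # a single-tensor collective only uses devices[0], i.e. 0, 2 and 4
    print(rank, devices)
```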

This still happens on master, so it is not fixed, and I'm pretty sure it is a bug in the framework code. Continuing to investigate...

Flamefire commented 4 years ago

OK, I likely found the cause after a lot of tedious debugging. The TL;DR is that the DistributedDataParallel class can't handle the mix of local and distributed GPUs.

So what happens is:

So at some point a wrong communicator seems to be used. Going a bit back, I see that prior to the creation of the DDP instance an ALLREDUCE happens among all 3 processes (actually ProcessGroupNCCL::barrier), each using the GPU with the id equal to its rank. As communicators are keyed on the GPU id, rank 0 has a communicator for cuda:0, rank 1 for cuda:1, etc. But later on in https://github.com/pytorch/pytorch/blob/1f791c06f0d61f25aa2273ccf15cc65c3073d51c/torch/testing/_internal/distributed/distributed_test.py#L371-L376 the ranks are assigned GPUs with offsets, so rank 0 gets cuda:0 and cuda:1, rank 1 gets cuda:2 and cuda:3 ... When broadcasting a single tensor only 1 GPU, namely the first, is used. So rank 0 finds an already existing communicator but the others don't, which explains the blocking. See https://github.com/pytorch/pytorch/blob/1f791c06f0d61f25aa2273ccf15cc65c3073d51c/torch/lib/c10d/ProcessGroupNCCL.cpp#L1070-L1072 which shows that for a single tensor only a single device is taken, which results in the key used to look up the NCCL communicator.

My recommendation would be to completely disallow using DDP with more than 1 GPU (raise an exception) and change the test(s) to use only 1 GPU per rank.
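A guard along the lines of this recommendation might look like the following sketch (illustrative only; the function name and message are made up, this is not the actual DDP code):

```python
# Hypothetical guard sketch: reject multi-device DDP instances up front
# instead of letting them fail later inside NCCL.
def _check_single_device(device_ids):
    if device_ids is not None and len(device_ids) > 1:
        raise ValueError(
            "Single-process multi-GPU DDP is not supported here; "
            "use one process per GPU with device_ids=[local_rank]."
        )
```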

Edit: The first communicator is created by a call to barrier at the end of init_process_group, and as no devices have been used yet, the implementation "guesses" to use 1 device per rank. The problem then is that process 0 later reuses that communicator but the others cannot. IMO that is a serious logic error caused by keying the communicators by the local GPU used, while in fact they also depend on the remote GPUs used. This can hence still happen even when only 1 GPU is used per rank.
Example: 6 GPUs available. An NCCL process group with 3 ranks is created -> GPUs 0-2 are used for the barrier. The user then selects GPU 0 for rank 0 and GPUs 5 and 6 for ranks 1 and 2 -> same problem as before: rank 0 will try to reuse the existing communicator instead of creating a new one containing GPUs 0, 5 and 6.
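Based on this description, a minimal script that should hit the same mismatch might look like the sketch below (my own reconstruction, not a confirmed reproducer; it assumes a 6-GPU box, 3 ranks and a launcher such as python -m torch.distributed.launch --nproc_per_node=3 that sets the usual RANK/WORLD_SIZE/MASTER_* environment variables):

```python
# Hypothetical repro sketch of the cached-communicator mismatch described above.
import os
import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])

# init_process_group ends with a barrier; since no devices have been used
# yet, the NCCL backend guesses one device per rank (cuda:0, cuda:1, cuda:2)
# and caches communicators keyed on those device ids.
dist.init_process_group("nccl")

# Block-wise assignment: rank 0 -> cuda:0, rank 1 -> cuda:2, rank 2 -> cuda:4
device = torch.device("cuda", rank * 2)
tensor = torch.ones(1, device=device)

# Rank 0 finds the cached communicator for cuda:0 and proceeds, while ranks
# 1 and 2 try to create a new communicator for cuda:2/cuda:4, so the
# collective hangs or NCCL reports a rank-count mismatch during bootstrap.
dist.broadcast(tensor, src=0)
```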

Flamefire commented 4 years ago

@rohan-varma @mrshenli How to continue here? I'm specifically concerned there are deeper issues. For example, I traced the broadcasts to _sync_params_and_buffers, where the first process does a broadcast_coalesced and seemingly finishes without any other process being involved, which kind of makes sense as it doesn't actually need to do anything. However, it uses the wrong communicator (one with GPUs 0-2, one per process, while the others want to create one with GPUs 2 and 4).

mrshenli commented 4 years ago

Hey @Flamefire

Sorry for dropping the ball on this.

6 GPUs available. An NCCL process group with 3 ranks is created -> GPUs 0-2 are used for the barrier. The user then selects GPU 0 for rank 0 and GPUs 5 and 6 for ranks 1 and 2 -> same problem as before: rank 0 will try to reuse the existing communicator instead of creating a new one containing GPUs 0, 5 and 6.

This is a great discovery. I wonder for this use case, should the application create two different ProcessGroup objects? One for [0, 1, 2], and one for [0, 5, 6]?

My recommendation would be to completely disallow using DDP with more than 1 GPU (raise an exception) and change the test(s) to use only 1 GPU per rank.

We totally want to go there, but there are legacy applications that are preventing us from retiring that DDP mode. Hopefully, we can retire it in the next few releases.

How to continue here? I'm specifically concerned there are deeper issues. For example, I traced the broadcasts to _sync_params_and_buffers, where the first process does a broadcast_coalesced and seemingly finishes without any other process being involved, which kind of makes sense as it doesn't actually need to do anything. However, it uses the wrong communicator (one with GPUs 0-2, one per process, while the others want to create one with GPUs 2 and 4).

IIUC, users can create three process groups and call allreduce on GPUs [0, 1, 2] and then on [0, 5, 6] respectively, which will hit the error you mentioned above. It seems hard to tell process 0 not to use the cached communicator and to create a new one instead for the second call, without hurting perf. But since the second call will time out (please correct me if I am wrong), will it be sufficient if we provide a better error message (maybe by checking the store)?

cc @osalpekar does this belong to the NCCL reliability project?

osalpekar commented 4 years ago

Thanks for the investigation! We've been working on a number of fixes for NCCL reliability in general, so we should take this case into account as well. Let me dig into this and try to fix the incorrect communicator usage.

Flamefire commented 4 years ago

I wonder for this use case, should the application create two different ProcessGroup objects? One for [0, 1, 2], and one for [0, 5, 6]?

Not sure how that would help. The underlying problem is an implicit (from the user's view) barrier on creating the PG, which uses a heuristic to decide which GPUs to use for it. And that heuristic fails when the second and later ranks later use other GPUs. So one solution would be to specify the GPUs to use when creating the PG, or to not cache the communicator created for the implicit barrier, which is easier.

will it be sufficient that we provide better error message (maybe by checking the store)?

As the usage is perfectly valid from a user's perspective, having to create different PGs is not great. However, storing more info, e.g. the expected number of ranks, in the store would allow detecting and reporting the later issue, which is better than the entirely cryptic error. It could even suggest using a different GPU distribution (i.e. round-robin), but I guess not caching (or removing afterwards) the communicator created based on that heuristic should work around the problem.
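To make the round-robin suggestion concrete, here is a small sketch contrasting the two distributions (assuming 6 GPUs and 3 ranks; with round-robin, each rank's first device equals its rank, which matches the device the implicit barrier guessed, so the cached communicator key would line up on every rank):

```python
# Sketch comparing block-wise and round-robin GPU assignment for 3 ranks
# on a 6-GPU machine (illustrative values only).
n_gpus, world_size = 6, 3
per_rank = n_gpus // world_size

for rank in range(world_size):
    block = list(range(rank * per_rank, (rank + 1) * per_rank))
    round_robin = list(range(rank, n_gpus, world_size))
    print(rank, "block:", block, "round-robin:", round_robin)
    # rank 0 -> block [0, 1], round-robin [0, 3]
    # rank 1 -> block [2, 3], round-robin [1, 4]
    # rank 2 -> block [4, 5], round-robin [2, 5]
```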