igozali closed this pull request 2 months ago.
That's indeed a concern if you have collectives back-to-back without checking the abort flag right away. Do you see a performance improvement with your patch? If it does help in this case, we can port the patch to the next release.
My main concern with the patch is that it introduces a GPU-initiated PCI-E read on every invocation of that function. We added the poll after 1M spins to avoid the performance cost of generating a lot of PCI-E traffic, as well as the latency of the PCI-E reads themselves. This issue was also raised back in 2021: https://github.com/NVIDIA/nccl/issues/598
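The polling tradeoff described above can be made concrete with a small host-side C++ model. This is a sketch under stated assumptions, not NCCL's device code: std::atomic<int> stands in for the page-locked abortFlag, the wait is modeled as a do/while loop that runs the abort check at least once per wait, and the names kSpinsBeforeCheckAbort, PeriodicAbortChecker and waitForPeer are hypothetical.

```cpp
#include <atomic>
#include <cassert>

// Hypothetical name for the "1M spins" polling interval discussed above.
constexpr int kSpinsBeforeCheckAbort = 1000000;

// Host-side model of a GPU thread's abort check. On the device the flag lives
// in page-locked host memory, so each load below stands for one PCI-E read.
struct PeriodicAbortChecker {
  explicit PeriodicAbortChecker(const std::atomic<int>* flag) : abortFlag(flag) {
    assert(flag != nullptr);
  }

  // Called once per spin; reads the flag only every kSpinsBeforeCheckAbort
  // spins and caches a nonzero result so later calls return immediately.
  int check() {
    spins++;
    if (abort == 0 && spins == kSpinsBeforeCheckAbort) {
      abort = abortFlag->load(std::memory_order_relaxed);
      flagReads++;
      spins = 0;
    }
    return abort;
  }

  const std::atomic<int>* abortFlag;
  int spins = 0;
  int abort = 0;
  long flagReads = 0;  // number of (simulated) PCI-E reads issued
};

// Poll for the peer's data, running the abort check on every iteration.
// Returns the number of loop iterations performed.
long waitForPeer(const std::atomic<bool>& dataReady, PeriodicAbortChecker& checker) {
  long iterations = 0;
  bool ready;
  do {
    ready = dataReady.load(std::memory_order_acquire);
    iterations++;
    if (checker.check()) break;  // abort observed: give up on this wait
  } while (!ready);
  return iterations;
}
```

Under these assumptions, a wait whose peer never sends keeps spinning for the full 1,000,000 iterations even if abortFlag was already set before the wait began, while a wait whose data is already available issues no flag reads at all. Avoiding those reads on the healthy path is what the 1M-spin interval buys.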
@KaimingOuyang Yes, it does shorten the time it takes to flush the queued NCCL kernels after abortFlag is set. Thanks @AddyLaddy for sharing that issue; it's good to know that the tests in the linked issue confirm this as well.
We had similar concerns about reading page-locked memory on every NCCL kernel launch. I'll close this PR since it doesn't seem to be the best approach, and will ask in the issue for better ideas 😄
While trying to abort the PyTorch process groups involved in a hybrid-sharded data-parallel job, we noticed that calling ncclCommAbort() takes a long time (>2 min).

On further debugging, we found that when ncclCommAbort() was called, a number of NCCL AllGather kernels were already queued, and, strangely, these post-abort kernels took much longer to exit than nominal AllGathers. Concretely, a nominal AllGather completes in around 1 ms, but after ncclCommAbort() is called each post-abort AllGather takes about 200 ms, and many of them are stuck in the queue. Because ncclCommAbort() involves a stream synchronize, these queued kernels also cause ncclCommAbort() to hang until the stream is drained.

We traced this to the call at https://github-ap.tesla.com/AI/nccl/blob/master/src/device/prims_ll.h#L55-L62, which appears to examine abortFlag only after 1,000,000 spins.

This PR proposes examining abortFlag on the 0th spin, so that NCCL kernels have an opportunity to exit early if abortFlag is already set to 1, which in turn speeds up ncclCommAbort().

Is this approach reasonable, and does it have major downsides? cc @sjeaugey @KaimingOuyang