Open chsigg opened 2 years ago
I believe what you are requesting is that NCCL remember the context which was current during each NCCL call within a ncclGroupStart/End() region, and then when the group executes (ends), batch the work such that each op happens in its respective context.
Note that NCCL has the following limitation: the context that is current when a communicator is created must also be the current context for every operation involving that communicator. Because of this constraint, each operation posted within the group must already have its comm's context as current, which eliminates the need to "remember" anything more than the comm. So I agree, this could be as easy as adding context remembrance to comm creation and context-swapping code to ncclGroupEnd().
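To make the idea concrete, here is a minimal sketch of "remember the context at comm creation, swap it in at group end". It is not NCCL code: Context, getCurrent/setCurrent, ScopedContext, Comm, Group, and runDemo are all hypothetical stand-ins, and the thread-local variable only models the per-thread current context that cuCtxGetCurrent()/cuCtxSetCurrent() would read and write, so the sketch runs without a GPU.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// Stand-in for CUcontext (assumption: a real version would use the driver type).
using Context = int;

// Models the driver's per-thread current context; a real version would call
// cuCtxGetCurrent()/cuCtxSetCurrent() instead of touching this variable.
thread_local Context g_current = 0;
Context getCurrent() { return g_current; }
void setCurrent(Context c) { g_current = c; }

// RAII guard: makes `c` current and restores the previous context on scope exit.
class ScopedContext {
 public:
  explicit ScopedContext(Context c) : saved_(getCurrent()) { setCurrent(c); }
  ~ScopedContext() { setCurrent(saved_); }
  ScopedContext(const ScopedContext&) = delete;
  ScopedContext& operator=(const ScopedContext&) = delete;
 private:
  Context saved_;
};

// Hypothetical communicator: remembers the context current at creation time.
struct Comm {
  Context ctx;
  Comm() : ctx(getCurrent()) {}
};

// Hypothetical group: queues ops and runs each under its comm's context at end().
class Group {
 public:
  void enqueue(const Comm& comm, std::function<void()> op) {
    pending_.push_back({&comm, std::move(op)});
  }
  void end() {
    for (auto& p : pending_) {
      ScopedContext guard(p.comm->ctx);  // swap in the comm's context
      p.op();
    }                                    // guard restores the caller's context
    pending_.clear();
  }
 private:
  struct Pending { const Comm* comm; std::function<void()> op; };
  std::vector<Pending> pending_;
};

// Two comms created under different contexts; records which context each
// queued op observes, then the context seen by the caller after end().
std::vector<Context> runDemo() {
  setCurrent(1); Comm a;
  setCurrent(2); Comm b;
  std::vector<Context> seen;
  Group g;
  g.enqueue(a, [&] { seen.push_back(getCurrent()); });
  g.enqueue(b, [&] { seen.push_back(getCurrent()); });
  setCurrent(7);                 // whatever the caller has current at group end
  g.end();
  seen.push_back(getCurrent());  // caller's context after the group ran
  return seen;
}
```

Calling runDemo() shows each op running under the context its comm was created in, with the caller's context restored afterwards. The guard is the part that would map onto the cudaGet/SetDevice() save-and-restore pattern; whether a real change is this small is exactly the open question in this thread.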
> I believe what you are requesting is that NCCL remember the context which was current during each NCCL call within a ncclGroupStart/End() region, and then when the group executes (ends), batch the work such that each op happens in its respective context.
Yes, that's correct.
I will give this a try and send a PR (or come back with questions ;-)). Thanks!
I want to caution you that merging this enhancement in the near term would be a low priority. A point that came up during internal discussion is that if NCCL were to start making driver API calls, we would want to go "all in" to avoid the unforeseen terrors of mixing the two APIs. This could end up being a lot of work: not just swapping out all memory allocation calls, but also dealing with kernel loading and I don't even know what else. So at best, the "easy-mode" solution of swapping cudaGet/SetDevice() for cuCtxGet/SetCurrent() would be a proof of concept that maybe mixing the APIs isn't so bad.
Sure. I can't really say how involved the change is going to be before I actually try. But generally mixing APIs is a supported use case, because CUDA libraries should not force one or the other API on the user.
NCCL currently uses a device's primary context for API calls that are wrapped in ncclGroupStart/End(). It would be better if NCCL used the driver context that is current at the time of the API calls instead.
Is this something you would consider worth fixing if I provided a PR? I haven't started, but I think the changes wouldn't be very involved. Basically, cudaGet/SetDevice() would need to be replaced with cuCtxGet/SetCurrent(). The change would be transparent to the users of the runtime API.
Cheers! Christian