Open huangjundashuaige opened 3 years ago
Hey @huangjundashuaige, you might need to hack into the C++ code and rebuild to get the real stream.
Curious, what is the reason for accessing that stream? And would it help if we provided additional args to allow users to configure the CUDA stream?
cc @agolynski
Thanks @mrshenli ! I was trying to overlap computation and communication at the stream level (if that doesn't work, I would have to go lower, to the kernel level), similar to gradient bucketing in DDP. It cannot be done if I do not have a NCCL stream handle. Yes, it would definitely help if PyTorch had additional args for the CUDA stream.
Hey @huangjundashuaige, thanks for sharing the context. In the short term, to unblock you, does it work if you wrap computations into a different stream and use asynchronous collective communications? Sth like:
work = dist.all_reduce(tensor, async_op=True)  # comm will wait for pending ops in the current stream
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())  # let the new stream wait for the current stream before launching computations
with torch.cuda.stream(s):
    ...  # do some computation in parallel
torch.cuda.current_stream().wait_stream(s)  # current stream waits for the computation
work.wait()  # current stream waits for the communication
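Putting it together, a minimal self-contained sketch of that pattern (assuming the process group is already initialized with the NCCL backend and the device is set per rank; grad, a, and b are placeholder tensors):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") has already run (e.g. under torchrun)
# and torch.cuda.set_device(local_rank) was called on each rank.
device = torch.device("cuda", torch.cuda.current_device())
grad = torch.randn(1024, 1024, device=device)  # tensor to all-reduce (placeholder)
a = torch.randn(1024, 1024, device=device)     # inputs for independent computation (placeholders)
b = torch.randn(1024, 1024, device=device)

# Launch the collective asynchronously; ProcessGroupNCCL runs it on its own
# internal stream after synchronizing with the current stream.
work = dist.all_reduce(grad, async_op=True)

# Run independent computation on a side stream so it can overlap with the communication.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    c = a @ b  # any computation that does not touch `grad`

# Re-join: the current stream waits for both the side stream and the collective.
torch.cuda.current_stream().wait_stream(s)
work.wait()
print(grad.sum().item(), c.sum().item())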
Hi @mrshenli , thanks for the advice. Yes, I tried this approach, but the behavior of such an arrangement cannot be controlled as well as expected. Since the NCCL stream and the new computation stream have no dependency between them in that case, I cannot really predict whether the NCCL op or the computation executes first (it is up to the GPU's scheduler, which depends on shared memory and register occupancy). So I guess I will still need to hack into the C++ code to get the stream handle.
BTW, I thought the async API for all_reduce is async with respect to message queuing. Does work.wait() mean it will block until the communication op has been successfully enqueued, rather than waiting for the communication to finish? Am I misunderstanding this?
Currently, we can easily get the elapsed_time of a computation stream by recording events on that stream, but for communication it's not easy.
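A minimal sketch of what that looks like for a computation stream (the matmul is just a placeholder workload):

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

stream = torch.cuda.current_stream()
start.record(stream)
x = torch.randn(4096, 4096, device="cuda")  # placeholder workload
y = x @ x
end.record(stream)

end.synchronize()                     # wait for the work between the events to finish
print(start.elapsed_time(end), "ms")  # elapsed GPU time in milliseconds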
Is there a workaround or hacky way to get the NCCL CUDA stream in Python?
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @cbalioglu @gcramer23