pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Why doesn't torch.distributed expose the NCCL CUDA stream? #59612

Open huangjundashuaige opened 3 years ago

huangjundashuaige commented 3 years ago

Is there a workaround or hacky way to get the NCCL CUDA stream in Python?

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @agolynski @SciPioneer @H-Huang @mrzzd @cbalioglu @gcramer23

mrshenli commented 3 years ago

Hey @huangjundashuaige, you might need to hack into the C++ code and rebuild PyTorch to get the real stream.

https://github.com/pytorch/pytorch/blob/1faba1e4cce2c6dfb3d89a19e7d709fd9af5fe67/torch/lib/c10d/ProcessGroupNCCL.cpp#L906-L909

Out of curiosity, what is the reason for accessing that stream? And would it help if we provided additional args to allow users to configure the CUDA stream?

cc @agolynski

huangjundashuaige commented 3 years ago

Thanks @mrshenli! I was trying to overlap computation and communication at the stream level (if that does not work, I will have to go lower, to the kernel level), similar to gradient bucketing in DDP. That cannot be done without a handle to the NCCL stream. Yes, it would definitely help if PyTorch provided additional args for configuring the CUDA stream.
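
For concreteness, a sketch of the kind of control such a handle would enable. get_nccl_stream and its placement are hypothetical; no such API exists in torch.distributed today:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl", ...) has already run.
tensor = torch.ones(1024, device="cuda")
x = torch.randn(1024, 1024, device="cuda")

nccl_stream = dist.get_nccl_stream()  # hypothetical handle; no such API exists today

work = dist.all_reduce(tensor, async_op=True)  # enqueued on the backend's NCCL stream

compute_stream = torch.cuda.Stream()
compute_stream.wait_stream(nccl_stream)  # explicit dependency: compute after comm
with torch.cuda.stream(compute_stream):
    y = x @ x  # now guaranteed to launch after the collective's kernels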

mrshenli commented 3 years ago

Hey @huangjundashuaige, thanks for sharing the context. In the short term, to unblock you, does it work if you wrap the computations in a different stream and use asynchronous collective communication? Something like:

import torch
import torch.distributed as dist

# `tensor` is assumed to be a CUDA tensor that is ready on the current stream.
work = dist.all_reduce(tensor, async_op=True)  # comm waits for pending ops on the current stream
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())  # new stream waits for the current stream before launching computations
with torch.cuda.stream(s):
    ...  # do some computation in parallel with the communication

torch.cuda.current_stream().wait_stream(s)  # current stream waits for the computation
work.wait()  # current stream waits for the communication

huangjundashuaige commented 3 years ago

Hi @mrshenli, thanks for the advice. Yes, I tried this approach, but the behavior of such an arrangement cannot be controlled as well as expected. Since the NCCL stream and the new computation stream have no dependency on each other in that case, I cannot really predict whether the NCCL op or the computation executes first (the GPU's scheduler decides, and the outcome depends on shared memory and register occupancy). So I guess I will still need to hack into the C++ code to get the stream handle.
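
A possible middle ground, assuming the NCCL backend attaches work.wait()'s stream dependency to whichever CUDA stream is current at the call site: the ordering can then be expressed without the raw stream handle. A minimal sketch, assuming an initialized NCCL process group:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl", ...) has already run.
tensor = torch.ones(1 << 20, device="cuda")
work = dist.all_reduce(tensor, async_op=True)

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    work.wait()       # assumption: the wait lands on stream s, the current stream
    tensor.mul_(2.0)  # this kernel on s now runs after the all_reduce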

BTW, I thought the async API for all_reduce was asynchronous only with respect to enqueuing. Does work.wait() mean it blocks until the communication op has been successfully enqueued, rather than waiting for the communication to finish? Am I misunderstanding this?
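
For reference, a small host-side probe of the semantics in question. It assumes the default NCCL configuration, where work.wait() enqueues a device-side wait on the current stream and returns without blocking the host until the collective completes, and an initialized NCCL process group:

import time
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl", ...) has already run.
tensor = torch.ones(1 << 24, device="cuda")  # large enough to take measurable time
work = dist.all_reduce(tensor, async_op=True)

t0 = time.perf_counter()
work.wait()  # assumption: enqueues a stream-level wait; returns without host blocking
t1 = time.perf_counter()
torch.cuda.current_stream().synchronize()  # the host actually blocks here until completion
t2 = time.perf_counter()
print(f"wait() took {t1 - t0:.6f}s, synchronize() took {t2 - t1:.6f}s")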

cdzhan commented 6 months ago

Currently, we can easily get the elapsed time of work on a computation stream by recording events on that stream, but for communication it's not easy.
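
A sketch of this asymmetry, assuming an initialized NCCL process group: CUDA events can only be recorded on the caller's stream, not on the backend's internal NCCL stream, so event timing brackets the collective from the outside and only yields an upper bound:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl", ...) has already run.
tensor = torch.ones(1 << 24, device="cuda")
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()  # recorded on the current stream, not the internal NCCL stream
work = dist.all_reduce(tensor, async_op=True)
work.wait()     # makes the current stream wait for the collective
end.record()
torch.cuda.synchronize()
print(start.elapsed_time(end))  # milliseconds; an upper bound on the comm time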