NVIDIA / nccl

Optimized primitives for collective multi-GPU communication

Add collective kernel & use NCCL internal memory buffer #670

Open XiaoSong9905 opened 2 years ago

XiaoSong9905 commented 2 years ago

Hi NCCL Developers,

I'm trying to extend NCCL to support some additional functionality. Specifically, I have 4 GPUs connected (half of a DGX-1 V100) and I'm trying to build a communicator with only GPUs 0, 1, 2. With GPUs 0, 1, 2, NCCL would choose either a tree or a ring.

0 -- 3
| \ / |
| / \ |
1 -- 2

My Plan:

I want to add another path 1 -> 3, 3 -> 2, 3 -> 0 (and vice versa for the complementary tree) to communicate part of the data (for a communicator with only 0, 1, 2, so that 3 does not have input/output buffers and is only used to increase bandwidth).

This can be considered as adding an additional tree with its root at 1, an intermediate node at 3, and leaves at 0 and 2. Node 3 (unlike how NCCL implements the tree algorithm) does not have input or output data, and only uses the NCCL internal buffer (defined by BUFFER_SIZE) to take data from 1 and pass it to 0 and 2. (At least this is what I plan to achieve, but I'm not sure. Please see the questions below.)

0 -- 3
    / |
  /   |
1      2

I'm planning to add support for broadcast and allreduce. Based on what I understand from other GitHub issues, tree allreduce in NCCL is implemented as reduce-to-root + broadcast, with different channels having different roots. Given the tree mentioned above, leaves 0 and 2 will send data to 3, 3 will store the incoming data in its internal buffer and send it to 1, and 1 will do the reduction and broadcast the data back.

I'm planning to run two communicators: one with GPUs 0, 1, 2 running the standard NCCL algorithm, and the other with GPUs 0, 1, 2, 3 running the side-channel tree algorithm mentioned above. So the code would look like:

comm3 = communicator with GPUs 0,1,2
comm4 = communicator with GPUs 0,1,2,3
group_start()
for (gpu in 0,1,2)
    all_reduce(comm3, 70% of data, stream0)
for (gpu in 0,1,2,3)
    all_reduce_customized(comm4, 30% of data, stream1)  // will call our customized kernel
group_end()
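
For concreteness, here is a minimal sketch of that plan using the single-process NCCL API (ncclCommInitAll + ncclGroupStart/ncclGroupEnd). The 70/30 split is just the illustration from the plan above, and all_reduce_customized is the hypothetical custom collective (not an existing NCCL function), so its call is left commented out; error checking and buffer setup are omitted.

#include <cuda_runtime.h>
#include <nccl.h>

// Sketch: one process drives all four GPUs with two communicators and
// two sets of streams. Error checking is omitted for brevity.
void twoCommAllReduce(float* sendbuf[4], float* recvbuf[4], size_t count,
                      cudaStream_t stream0[3], cudaStream_t stream1[4]) {
  int devs3[3] = {0, 1, 2};
  int devs4[4] = {0, 1, 2, 3};
  ncclComm_t comm3[3], comm4[4];
  ncclCommInitAll(comm3, 3, devs3);   // communicator over GPUs 0,1,2
  ncclCommInitAll(comm4, 4, devs4);   // communicator over GPUs 0,1,2,3

  size_t n70 = count / 10 * 7;        // ~70% of the elements

  ncclGroupStart();
  // Standard NCCL allreduce over GPUs 0,1,2 on stream0.
  for (int i = 0; i < 3; i++)
    ncclAllReduce(sendbuf[i], recvbuf[i], n70, ncclFloat, ncclSum,
                  comm3[i], stream0[i]);
  // Hypothetical side-channel collective over GPUs 0,1,2,3 on stream1;
  // GPU 3 would pass no user buffers and rely on the internal buffer.
  // for (int i = 0; i < 4; i++)
  //   all_reduce_customized(sendbuf[i] + n70, recvbuf[i] + n70, count - n70,
  //                         ncclFloat, ncclSum, comm4[i], stream1[i]);
  ncclGroupEnd();
}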

What I know

  1. Given the 4-GPU communicator, its tree structure can be set by an external file (or computed based on the topology, which is how NCCL does it in most cases). I have changed the logic here so that it only considers the single tree mentioned above.

    What I'm not sure about

  2. Will two allreduce calls to different communicators & streams block each other (i.e. does the first allreduce need to finish before the second can run)? I'm planning to send 70% of the data using the communicator and algorithm provided by NCCL and send 30% of the data using the side-channel tree idea mentioned above, with a 4-GPU communicator and a customized kernel (explained in the code above).

  3. For the tree to run correctly, GPU 3 needs to realize it does not have an input buffer or an output buffer and should use the internal buffer to store data. Also, GPU 3 should only do the reduction with input from 0 and 2 (not itself, since it has no input buffer) and send the reduction result to 1. I'm planning to have an if branch inside the allreduce tree kernel, roughly:

    __device__ void allreduce_kernel_customized()
    {
        if (rank == 3) {
            // use the NCCL internal buffer to recv and send data;
            // the reduction result is saved to the NCCL internal buffer
        } else {
            // do the operation regularly
        }
    }

    I'm currently stuck on having GPU 3 use the NCCL internal buffer to send and receive data and do the reduction with the result saved to the internal buffer. I'm not sure how to implement this.

I think I should change the primitives class, but all of its APIs use the user-provided input/output buffers, which makes me even more unsure about where to make the change.

  4. Where is the NCCL internal buffer (defined by BUFFER_SIZE) used during the collective call?

  5. How should the graph.txt file be changed to represent the above topology? I got this graph.txt file from another GitHub issue. I think pattern=2 indicates using a tree, but I'm not sure what speedintra, speedinter, typeintra, and typeinter should be set to, or how the GPU list inside <channel></channel> should be set.

<graphs version="1">
  <graph id="0" pattern="2" crossnic="0" nchannels="1" speedintra="9" speedinter="9" typeintra="SYS" typeinter="SYS" samechannels="1">
    <channel>
      <net dev="1"/>
      <gpu dev="3"/>
      <gpu dev="0"/>
      <gpu dev="2"/>
    </channel>
  </graph>
</graphs>

I'm using NCCL v2.7.8, which is slightly easier to add code on top of.

Thank you so much for your help, Xiao

sjeaugey commented 2 years ago

I'm not sure I understand why you want to do that. Do you think you would get better performance? Can you explain why?

What you describe here seems to me to be a LOT of work. Sure, the tree algorithm already has most of that functionality, but you'd need to change the XML representation to be able to represent trees and not just chains, and more importantly you'd need to add support for intermediate nodes which are not part of the communicator, which is absolutely not supported today, since ranks outside the communicator are not even in the topology graph.

Two concurrent NCCL allreduces can work in parallel if they use different CUDA asynchronous streams. They're not guaranteed to run in parallel though, and could even deadlock each other if different GPUs run them in a different order and they block each other. Most of the time it works, but there is no way to guarantee it will always work.

XiaoSong9905 commented 2 years ago

Thank you so much for the response.

  1. We do this to get better bandwidth. We have a situation where 4 GPUs are available but only the first 3 are actually running the model; we want to utilize the idle bandwidth between GPUs 0, 1, 2 and GPU 3 to improve the bandwidth of the collective operations over GPUs 0, 1, 2.

  2. I don't think there will be an issue with ranks outside the communicator not being in the topology graph. We'll just create two communicators, one with GPUs 0, 1, 2 and one with GPUs 0, 1, 2, 3. The first 70% of the data will be sent by the communicator with GPUs 0, 1, 2, and the remaining 30% by the communicator with GPUs 0, 1, 2, 3. The tree idea mentioned above can be done with a modified tree kernel + the communicator with GPUs 0, 1, 2, 3.

2'. Can you explain a little bit more what you mean by "XML representation to be able to represent trees and not just chains"? Does NCCL_GRAPH_FILE not support inputting a tree?

2''. Given that I'll use the 4-GPU communicator, GPU 3 would now be part of the communicator. Is there a way to add a customized kernel that only uses the NCCL internal buffer?

  3. Can you explain a little bit more why they're not guaranteed to run in parallel? Is the parallelism here only a software abstraction, not actual hardware parallelism?
sjeaugey commented 2 years ago
  1. Can you explain what performance you get today with 3 and 4 GPUs and how much performance you expect to reach with your strategy?
  2. If a GPU is in the communicator, it has to be part of all operations, including providing values. Sure, anything is possible, but that's going to require a lot of work; providing zeroes instead seems way easier to me. NCCL_GRAPH_FILE does not support inputting a tree, just a chain. Adding a "customized kernel where one GPU only uses the internal buffer" is not going to be easy: you'd still need to call NCCL on that rank to launch the NCCL kernel on that GPU, but without a buffer ...
  3. They will run in parallel in general, but CUDA makes no guarantee they won't block each other. In corner cases they could block each other and cause deadlocks.
XiaoSong9905 commented 2 years ago
  1. We expect around a 100% bandwidth increase for allreduce. Given a DGX-1 V100 and using GPUs 0, 1, 2, the bandwidth is limited by the single NVLink on 0-1 and 0-2, which gives a max bandwidth of 22 GB/s. If we add the tree idea mentioned above (1->3, 3->2, 3->0), the max bandwidth for the tree is also limited by the single NVLink between 1 and 3, which gives a max tree bandwidth of 22 GB/s. Adding this tree would increase the max bandwidth from 22 to 44 GB/s (please correct me if anything is wrong). I understand this idea can be hard to generalize to the clusters of hundreds of GPUs used commercially, but it is a pretty common case in a research lab setting.
0 ---- 3
| \  / |
| /  \ |
1 ---- 2
  2. After going through the source code multiple times, I think providing an all-zero user input buffer on GPU 3 might be a better idea.

2'. Based on the feedback you provided in https://github.com/NVIDIA/nccl/issues/672, it seems that NCCL currently does not support intra-node trees and the user cannot input an intra-node tree using the graph.txt file. Do you have any suggestions on what I should modify to enable an intra-node tree (just a single tree, not packing multiple binary trees to achieve max bandwidth)?

sjeaugey commented 2 years ago

Creating a 4-GPU communicator, you would get 44 GB/s already. You only need to have GPU 3 set its buffer to all zeroes and you have your 3-way allreduce at 44 GB/s. That's less complicated than creating the second, 4-GPU communicator and should have better performance.
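
For illustration, a minimal sketch of that suggestion, assuming a single process drives all four GPUs, float data, and a sum reduction (so an all-zero contribution from GPU 3 leaves the result unchanged); names and error handling are simplified.

#include <cuda_runtime.h>
#include <nccl.h>

// Sketch: 3-way allreduce over GPUs 0,1,2 using a single 4-GPU communicator,
// with GPU 3 contributing an all-zero buffer. Error checking omitted.
void threeWayAllReduceVia4Gpus(size_t count) {
  int devs[4] = {0, 1, 2, 3};
  ncclComm_t comms[4];
  ncclCommInitAll(comms, 4, devs);

  float* sendbuf[4];
  float* recvbuf[4];
  cudaStream_t streams[4];
  for (int i = 0; i < 4; i++) {
    cudaSetDevice(devs[i]);
    cudaMalloc((void**)&sendbuf[i], count * sizeof(float));
    cudaMalloc((void**)&recvbuf[i], count * sizeof(float));
    cudaStreamCreate(&streams[i]);
    if (i == 3)
      cudaMemset(sendbuf[i], 0, count * sizeof(float)); // GPU 3 contributes zeroes
    // GPUs 0,1,2 would copy their real input into sendbuf[i] here.
  }

  ncclGroupStart();
  for (int i = 0; i < 4; i++)
    ncclAllReduce(sendbuf[i], recvbuf[i], count, ncclFloat, ncclSum,
                  comms[i], streams[i]);
  ncclGroupEnd();
  // After synchronizing the streams, recvbuf[0..2] hold the sum of the three
  // real inputs; recvbuf[3] can simply be ignored.
}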