AmedeoSapio opened 4 years ago
Hi, thanks for the interest in the CollNet plugin!
Regarding question 1: which NCCL version and how many GPUs per node are you using? There is a known tuning issue in versions >= 2.7.5 that selects the Ring algorithm over the CollNet algorithm in the 8-GPU-per-node case. As a temporary workaround, you can set NCCL_ALGO=CollNet to force NCCL to use the CollNet algorithm, together with the NCCL_COLLNET_ENABLE=1 you already have. We are trying to fix this issue in upcoming releases so that you won't need to set NCCL_ALGO anymore. Sorry about the inconvenience!
Regarding question 2: the "regMr" function is called on NCCL's internal intermediate buffers, which include a buffer space for the Simple protocol (4M) and a buffer space for the Low Latency (LL) protocol (256K). For now, the "regMr" function is not called on the user buffer.
Regarding question 3: indeed that's the requirement for now, and there was some design consideration for requiring so. Please let us know if this is blocking your development.
Thanks for the quick and detailed answers!
I am using NCCL version 2.7.6+cuda10.2 on 2 servers, each with a single P100 GPU. Setting NCCL_ALGO=CollNet worked, thanks!
However, I am surprised to see that, when running the allreduce test to aggregate 100MB vectors, my allreduce function is called with count=131072, i.e. 512KB chunks. Without the plugin, I can see from the debug log that allreduce is called with count=26214400, which is correctly 100MB. Do you have any insight into why this is happening?
The full allreduce operation is 100MB, but NCCL will still reduce data inside the node going through all the GPUs, then send the reduced data to the network (collnet allreduce), then broadcast it back inside the node. That's what's using chunks of 512K, within FIFOs of 4MB (for the Simple protocol).
The collnet plugin only takes care of the inter-node communication.
Thanks @sjeaugey So the plugin allreduce is always invoked with chunks of 512K elements, no matter the size of the aggregated tensor, right? Is there a way to change the size of the chunks?
You may be able to use a larger chunk size by setting NCCL_BUFFSIZE to a larger value (for the Simple protocol, the current default is 4MB). https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-buffsize
Thanks @kwen2501 How does the chunk size change with the buffer size? I checked and it does not seem to change linearly.
It should be linear, at least when using the Simple protocol.
I tested with different buffer sizes, using the Simple protocol, and these are the values of the count parameter passed to the plugin's allreduce:
| NCCL_BUFFSIZE | count |
|---|---|
| 8MB | 262144 (1MB) |
| 16MB | 524288 (2MB) |
| 32MB | 1048576 (4MB) |
| 64MB | 1048576 (4MB) |
| 128MB | 1048576 (4MB) |
| 800MB | 1638400 (6MB) |
This is strange... the chunk size is usually computed as SLICESTEPS * comm->buffSizes[protocol] / NCCL_STEPS. NCCL_STEPS is defined as 8, SLICESTEPS is 2 for ring and 1 for tree, and comm->buffSizes[NCCL_PROTO_SIMPLE] should be what we get from the environment. Not sure why that's not what you see.
Hi, I'm also working on a collnet plugin for an FPGA NIC and hit this issue.
Though I increased the buffer size with NCCL_BUFFSIZE, the chunkSize didn't increase linearly, since chunkSize is trimmed down inside computeColl(). Adding an env parameter (NCCL_COLLNET_CHUNKSIZE) like below to forcibly set the chunk size worked as a workaround.
Could you explain why we need to reduce the chunk size here? With the increased chunk size, all-reduce performance did improve on our platform. Would it be possible to add a parameter like the one below, or a mechanism to negotiate the appropriate chunk size with the collnet plugin?
```diff
diff --git a/src/enqueue.cc b/src/enqueue.cc
index 43d0ba1..2ec8494 100644
--- a/src/enqueue.cc
+++ b/src/enqueue.cc
@@ -1274,6 +1274,8 @@ static ncclResult_t getLoopInfo(struct ncclInfo* info) {
   return ncclSuccess;
 }
 
+NCCL_PARAM(CollnetChunkSize, "COLLNET_CHUNKSIZE", 0);
+
 static ncclResult_t computeColl(struct ncclInfo* info /* input */, int* workFuncIndex, struct ncclWorkElem* work, struct ncclProxyOp* proxyOp /* output */) {
   int collNetTypeSupport = 0;
   // Check whether algo and proto have been preset (as in aggregation case)
@@ -1334,6 +1336,10 @@ comp_next:
   while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*64 && chunkSize > 131072) chunkSize /= 2;
   while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*8 && chunkSize > 65536) chunkSize /= 2;
   while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2;
+  int64_t env = ncclParamCollnetChunkSize();
+  if (env != 0) {
+    chunkSize = env;
+  }
   work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype);
```
@sjeaugey Could you take a look at my previous post?
I'm also curious if it is reasonable to extend the tuning plugin API to tell NCCL the desired chunk size from the collnet plugin point of view.
I thought we already had a chunk size param, but I could be wrong. W.r.t. the tuning plugin, it's not how it works today so it would need to be studied. It doesn't seem hard to do from a high level perspective, but you never know until you actually try to write the code.
Thank you very much for your response.
We have the chunk size param, but it is only used for P2P operations.
Also, I'd like to know why NCCL reduces the chunkSize in computeColl() in the following part.
https://github.com/NVIDIA/nccl/blob/master/src/enqueue.cc#L1352-L1358
Even if we extend the chunk size param to cover collective operations, this part would still trim down the chunkSize eventually.
> It doesn't seem hard to do from a high level perspective, but you never know until you actually try to write the code.
I see. Yes, I thought it wouldn't be difficult to do it. But it looks like I need to study the code more.
I am working on a plugin to use a different algorithm for allreduce. While I have been able to understand most of the code required, I still have a few questions:
1) I defined my plugin and ran the allreduce test with NCCL_COLLNET_ENABLE=1. My plugin is loaded and some of its functions are called, but the allreduce function is not being used. I found out that the tuning code is preferring other algorithms. The "reduceSupport" function is called at each iteration (and sets supported=1), but my allreduce function is never called. How is the algorithm/protocol time calculated? Is there a way to force the use of my function?
2) I am running the allreduce test to aggregate 100MB vectors, but I see that the "regMr" function is called 4 times with sizes 4M, 256K, 4M, 256K. What memory is being registered here? Isn't it supposed to be the vector memory?
3) Even though I am only interested in writing a collectives plugin, the current library interface requires writing both a send/recv plugin and a collectives plugin. I changed the NCCL library loading code slightly, so that if the NCCL_PLUGIN_SYMBOL (from the .so) contains all NULLs, NCCL instantiates ncclNet_t as usual and then copies it into the NCCL_PLUGIN_SYMBOL in the library. A cleaner approach would be to allow defining NCCL_PLUGIN_SYMBOL and NCCL_COLLNET_PLUGIN_SYMBOL independently, and to provide a function to pass the ncclNet_t to the library if it does not define NCCL_PLUGIN_SYMBOL. This solution would require changing the plugin APIs, though. Is there any other way to write a plugin for collectives only?
Thanks a lot for the help!