NVIDIA / nccl

Optimized primitives for collective multi-GPU communication

[Bug] NCCL all_reduce failed with A800 when NCCL_ALGO uses Ring #1055

Open zigzagcai opened 12 months ago

zigzagcai commented 12 months ago

TL/DR:

Set the environment variable NCCL_ALGO=Tree if you hit accuracy problems with NCCL on A800 hardware.
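For reference, a minimal sketch (my addition, not from the original report) of applying this in a PyTorch script; the variable must be visible before the NCCL communicator is created, or it can be passed on the command line as `NCCL_ALGO=Tree torchrun ...`:

```python
import os

# Must be set before NCCL initializes; "^Ring" would instead exclude only Ring.
os.environ["NCCL_ALGO"] = "Tree"

import torch.distributed as dist

# Assumes the usual torchrun launch, which provides RANK/WORLD_SIZE/MASTER_ADDR.
dist.init_process_group(backend="nccl")
```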


Hello,

We found a bug in all_reduce on A800 GPUs when NCCL_ALGO uses Ring, and we can provide minimal reproduction steps. We ran comparative experiments on the A100 and A800 platforms and found that the model converges on the A100 platform but not on the A800 platform.

The minimal reproduction steps are as follows:

Codebase: https://github.com/karpathy/nanoGPT
Reproduction steps:
1. Prepare one node with 8x A800 and one node with 8x A100, and set the same seed=1024 on both.
2. torchrun --nnodes=1 --nproc_per_node=8 train.py config/train_shakespeare_char.py

As expected, the loss on the A800 should match that on the A100. With the gloo backend we obtain identical losses, but with the nccl backend the loss outputs are inconsistent.

Furthermore, we found that with NCCL_ALGO=Tree the losses stay consistent, but with NCCL_ALGO=Ring (or with the variable unset) the losses cannot be kept consistent between A100 and A800.

Additionally, when we use 8 nodes connected via IB, with one GPU per node and NCCL_ALGO=Ring, the losses stay consistent.

Therefore, we suspect there might be a bug in the current all_reduce implementation when NCCL_ALGO=Ring on the A800 platform, and this bug might somehow be related to the number of NVLink channels.

Note: The A800 is a restricted version of the A100 GPU. The only difference between the two is the number of NVLink channels: the A100 has 24 channels; the A800 has 16.

PhdShi commented 12 months ago

Some reference: https://github.com/NVIDIA/nccl/issues/446

sjeaugey commented 12 months ago

There is no bug. Ring, by nature, has worse precision than, say, Tree, because of the order in which it performs the sums (see the bug linked above). So if your use case is at the limit of convergence, using Ring may cause higher imprecision, and that could make a difference for you (when it doesn't for others). Would it work for you to set NCCL_ALGO=^RING in the cases you know are sensitive to precision?

zigzagcai commented 12 months ago

> There is no bug. Ring, by nature, has worse precision than, say, Tree, because of the order in which it performs the sums (see the bug linked above). So if your use case is at the limit of convergence, using Ring may cause higher imprecision, and that could make a difference for you (when it doesn't for others). Would it work for you to set NCCL_ALGO=^RING in the cases you know are sensitive to precision?

Hi @sjeaugey

Thanks for your kind reply!

Sure. When we set NCCL_ALGO=Tree, our model converges normally on the A800 platform.

zigzagcai commented 11 months ago

I still have a question. On the A100 platform (which has 24 NVLink channels), why does limiting NCCL_MAX_NCHANNELS=16 (which makes it equivalent to the A800 platform, with only 16 NVLink channels) change the results of ring all_reduce? In other words, the imprecision of ring all_reduce appears to be related to the number of NVLink channels used.

From my observation, using an A800, or just limiting NCCL_MAX_NCHANNELS=16 on an A100, makes a LLaMA model (e.g. 7B parameters) fail to converge.

sjeaugey commented 11 months ago

To maximize bandwidth, each channel goes through a different path, and therefore performs operations in a different order. Changing the number of channels will change which offset uses which path, and could make things better or worse depending on how lucky (or unlucky) you are.
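A toy picture of that last point (my own sketch, not NCCL's actual scheduling logic): with a contiguous split of the buffer across channels, the same element offset lands on a different channel, and hence takes a different path with a different operation order, depending on the channel count:

```python
# Toy model only -- not NCCL code. Contiguous equal split of `count` elements
# across `nchannels` channels; the channel determines the path/order used.
def channel_of(offset: int, count: int, nchannels: int) -> int:
    per_channel = (count + nchannels - 1) // nchannels  # ceiling division
    return offset // per_channel

count = 1 << 20
for offset in (0, count // 3, count // 2, count - 1):
    print(offset, channel_of(offset, count, 16), channel_of(offset, count, 24))
# The same offset maps to different channels for 16 vs. 24 channels,
# so its values are summed along a different path, in a different order.
```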

Tuvie commented 11 months ago

I'm curious about this issue because we encountered the same problem before. Since the orders of Ring and Tree are different, I totally understand that the reduce results may differ. However, why do we think the precision of Ring is always worse than that of Tree? Is there any theory to explain this?

sjeaugey commented 11 months ago

Ring adds one value to the sum of n values. So if the values have the same order of magnitude, the later values will be very small compared to the sum so far, meaning part of their floating-point precision will be lost.

With a binary tree, we add two partial sums of equal weight together, so the precision is generally better.
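A quick numerical illustration of this (a self-contained sketch, not NCCL code; numpy float32 stands in for the GPU arithmetic):

```python
import numpy as np

rng = np.random.default_rng(1024)
xs = rng.random(1 << 16).astype(np.float32)  # values of the same magnitude
ref = xs.astype(np.float64).sum()            # high-precision reference

# Ring-like order: each value is added to the running sum of all values so far.
chain = np.float32(0.0)
for x in xs:
    chain = chain + x

# Tree-like order: repeatedly add pairs of partial sums of equal weight.
tree = xs.copy()
while tree.size > 1:
    half = tree.size // 2
    tree = tree[:half] + tree[half:]  # size is a power of two, so this pairs up evenly
tree = float(tree[0])

print("chain error:", abs(float(chain) - ref))
print("tree  error:", abs(tree - ref))
# Typically the chain (ring-like) error is noticeably larger than the tree error.
```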

zigzagcai commented 11 months ago

> I'm curious about this issue because we encountered the same problem before. Since the orders of Ring and Tree are different, I totally understand that the reduce results may differ. However, why do we think the precision of Ring is always worse than that of Tree? Is there any theory to explain this?

Hi Tuvie,

I found some papers and slides discussing the error analysis of floating-point summation. FYI: [attached image]

To give a simple example, consider the sum f1+f2+f3+f4, and assume that f1, f2, f3, f4 follow the same distribution (i.e., have the same magnitude). Then:

- Ring all_reduce: the summation order will be f4+(f3+(f1+f2)).
- Tree all_reduce: since the reduction is performed on a binary tree, the summation order will be (f1+f2)+(f3+f4).

It is clear that for Tree all_reduce the two operands always have the same magnitude, while for Ring all_reduce the operands are unbalanced.

Tuvie commented 11 months ago

That makes a lot of sense. Thank you both for the explanation. @sjeaugey @zigzagcai

Tuvie commented 11 months ago

But I still have another question, about the Tree algorithm on one DGX with 8 GPUs. According to my NCCL log, the Tree topology looks like this:

1/-1/-1->0->-1
2/-1/-1->1->0
3/-1/-1->2->1
4/-1/-1->3->2
5/-1/-1->4->3
6/-1/-1->5->4
7/-1/-1->6->5
-1/-1/-1->7->6

It seems this tree is equivalent to a ring (within one server), where the right child of every node is always -1. So the reduce order will be ((((((f7+f6)+f5)+f4)+f3)+f2)+f1)+f0, which is similar to the ring case: ((((((f0+f1)+f2)+f3)+f4)+f5)+f6)+f7. Is that right?

zigzagcai commented 11 months ago

> But I still have another question, about the Tree algorithm on one DGX with 8 GPUs. According to my NCCL log, the Tree topology looks like this:
>
> 1/-1/-1->0->-1
> 2/-1/-1->1->0
> 3/-1/-1->2->1
> 4/-1/-1->3->2
> 5/-1/-1->4->3
> 6/-1/-1->5->4
> 7/-1/-1->6->5
> -1/-1/-1->7->6
>
> It seems this tree is equivalent to a ring (within one server), where the right child of every node is always -1. So the reduce order will be ((((((f7+f6)+f5)+f4)+f3)+f2)+f1)+f0, which is similar to the ring case: ((((((f0+f1)+f2)+f3)+f4)+f5)+f6)+f7. Is that right?

I have the same question, since I have done the same experiment on one node with 8 GPUs and got the same result.

To the best of my knowledge of NCCL (please correct me if there is any misunderstanding), the tree structure is only used inter-node. Intra-node, it is actually a chain. So the chain (or, to put it another way, tilted tree) structure in your case is:

0
 \
  1
   \
    2
     \
      3
       \
        4
         \
          5
           \ 
            6
             \
              7

Tree all_reduce is implemented with a reduce+broadcast computation pattern, where the reduce order is 7->6->5->4->3->2->1->0 and the broadcast order is 0->1->2->3->4->5->6->7. Ring all_reduce, on the other hand, uses a reduce_scatter+all_gather pattern, where the structure is a ring, not a chain. Although the summation order is similar in the one-node/8-GPU case, the two are not equivalent, due to the different computation patterns.
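To make this concrete, here is a toy sketch (my own, not NCCL code; it assumes 4 ranks, one chunk per rank, and the common convention that chunk c ends up fully reduced on rank c) printing the chain order in which each chunk is summed during the ring reduce_scatter:

```python
# Toy sketch, not NCCL code: print the chain order in which each chunk's
# contributions are accumulated during a ring reduce_scatter with R ranks,
# assuming chunk c is fully reduced on rank c at the end.
R = 4
for c in range(R):
    # Chunk c starts at rank (c+1) % R and travels around the ring to rank c.
    order = [(c + 1 + i) % R for i in range(R)]
    expr = f"f{order[0]}"
    for r in order[1:]:
        expr = f"({expr}+f{r})"
    print(f"chunk {c}: {expr}")
# Output: every chunk is reduced along a chain, each starting at a different rank.
```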

Reference issues: https://github.com/NVIDIA/nccl/issues/672 https://github.com/NVIDIA/nccl/issues/448 https://github.com/NVIDIA/nccl/issues/790 https://github.com/NVIDIA/nccl/issues/256#issuecomment-534192899

Tuvie commented 11 months ago

So it seems the question becomes: why does reduce_scatter in a ring have lower precision than reduce in a chain?

sjeaugey commented 11 months ago

Your understanding is correct. And reduce should not have better precision than reduce_scatter. If you see it work better, it could just be that the reversed order works better out of random chance.

Tuvie commented 11 months ago

I think reduce_scatter in a ring is also equivalent to a series of chain reduces over different chunks. For example, suppose 4 ranks are doing a reduce_scatter over 4 chunks of data (c1, c2, c3, c4). That is equivalent to doing a reduce for each of these 4 chunks separately; the only difference is that the chain order differs per chunk. So it seems reduce_scatter should not have better precision than reduce. If this is true, the tree algorithm on one node should not have better precision than the ring algorithm. Is this analysis correct?

zigzagcai commented 11 months ago

> I think reduce_scatter in a ring is also equivalent to a series of chain reduces over different chunks. For example, suppose 4 ranks are doing a reduce_scatter over 4 chunks of data (c1, c2, c3, c4). That is equivalent to doing a reduce for each of these 4 chunks separately; the only difference is that the chain order differs per chunk. So it seems reduce_scatter should not have better precision than reduce. If this is true, the tree algorithm on one node should not have better precision than the ring algorithm. Is this analysis correct?

I don't think an experiment on one node with 8 GPUs can support the conclusion that the tree algorithm has better precision than the ring algorithm. In fact, in the one-node/8-GPU scenario, the summation orders of the two algorithms are similar.

The minimal reproduction steps we provided were just to show how to align the precision between A100 and A800, and to give some insight into solving the imprecision problem in multi-node LLaMA training on the A800 platform.

A large precision difference only appears in multi-node training scenarios (where tree all_reduce by nature has better precision than ring all_reduce), since the operands in tree all_reduce are more balanced than in ring all_reduce.

zigzagcai commented 11 months ago

BTW, I don't know whether there is any hard-coding related to the number of NVLink channels in the NVLink firmware.

I guess (perhaps incorrectly) that the reason some models cannot converge normally on the A800 platform with NCCL_ALGO=Ring might be hard-coding in the NVLink firmware, which might lead to silent data corruption or silent computation imprecision.

zigzagcai commented 2 months ago

Update:

I have written up an analysis of this issue on Zhihu: https://zhuanlan.zhihu.com/p/701623664

Hope it helps!

yupatrick22 commented 1 month ago

> To maximize bandwidth, each channel goes through a different path, and therefore performs operations in a different order. Changing the number of channels will change which offset uses which path, and could make things better or worse depending on how lucky (or unlucky) you are.

To be clearer, let us consider an allreduce across 4 GPUs; the allreduce result is simply (((a0+a1)+a2)+a3).

The offset of an element depends on the chunk index; for example, the elements with chunk_id 3 and rank_id 0 should be at position a0, while the elements with chunk_id 2 and rank_id 0 are at position a1.

Now the question becomes: how does the channel count change the chunk_id of an element?

According to the code (https://github.com/NVIDIA/nccl/blob/2ea4ee94bfb04c886c79ccae60ac9961000fdee2/src/include/device.h#L325), the data are partitioned contiguously across the comm channels. But the element counts of the chunks for the low, mid, and high channels are DIFFERENT. For example, if we only use 2 channels, there are two chunk counts: one for the low channel and one for the high channel. If we use 4 channels, there are three chunk counts: one for the low channel (channel 0), one for the mid channels (channels 1 and 2), and one for the high channel (channel 3). That explains why the channel count changes the chunk id of a specific element.
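If I read that correctly, a toy model of the effect would look like the following (entirely my own sketch with invented shares, not NCCL's actual partitioning code; it only shows that unequal low/mid/high chunk sizes move the chunk boundaries, so the chunk id of a fixed offset changes with the channel count):

```python
# Toy model only -- NOT NCCL's real partitioning; the 1.5x/1.0x/0.5x shares
# are invented for illustration. Unequal per-channel chunk sizes shift the
# chunk boundaries, so a fixed element offset gets a different chunk id
# (and summation order) when the channel count changes.
def chunk_id(offset: int, count: int, nchannels: int) -> int:
    mid = count / nchannels  # shares sum to nchannels: 1.5 + (n-2)*1.0 + 0.5
    acc = 0.0
    for ch in range(nchannels):
        share = 1.5 if ch == 0 else (0.5 if ch == nchannels - 1 else 1.0)
        acc += share * mid
        if offset < acc:
            return ch
    return nchannels - 1

count = 1 << 20
for off in (1000, count // 2, count - 1000):
    print(off, chunk_id(off, count, 16), chunk_id(off, count, 24))
```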

Is my understanding correct? @sjeaugey

Just a curious question: why do the chunk counts differ with respect to the channel index? Normally, all channels should have the same bandwidth, and they should take the same amount of work.