huggingface / nanotron

Minimalistic large language model 3D-parallelism training
Apache License 2.0

[Question] Correctness of backward pass of RowLinear #46

Open ufotalent opened 9 months ago

ufotalent commented 9 months ago

Hi nanotron authors. Thanks for this work that accelerates LLM training for the community. I've recently been developing a feature based on nanotron that needs some tweaks to the tensor parallel / sequence parallel part. However, I found something that seems problematic:

In this piece of code, https://github.com/huggingface/nanotron/blob/8c1a49588d0745a6404644a86547c2dd6a63640e/src/nanotron/parallel/tensor_parallel/functional.py#L417, the backward of RowLinear does the following:

  1. Start an async all-gather to collect unsharded grad_outputs
  2. At the same time, calculate the partial grad_input using grad_output * local weights
  3. After 1 finishes, all-gather the partial grad_input and use it as the result (see the sketch below).
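
Roughly, this is what I understand the grad_input path of the current backward to look like (a simplified sketch, not the actual nanotron code, ignoring the grad_weight path; I assume the nn.Linear weight layout (O, H/tp) and a grad_output shard of shape (B/tp, O) left by the forward reduce-scatter):

```python
import torch
import torch.distributed as dist

def row_linear_backward_current(grad_output, weight, group):
    # partial grad_input from the *local* grad_output shard and the *local* weight shard
    grad_tensor = grad_output.matmul(weight)                       # (B/tp, H/tp)
    total_grad_tensor = torch.empty(
        grad_tensor.shape[0] * dist.get_world_size(group), grad_tensor.shape[1],
        dtype=grad_tensor.dtype, device=grad_tensor.device,
    )
    # the gather stitches together chunks that were each computed with a
    # *different* weight shard, so every rank ends up with the same tensor
    dist.all_gather_into_tensor(total_grad_tensor, grad_tensor, group=group)
    return total_grad_tensor                                       # (B, H/tp), identical on all ranks
```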

However, from my understanding, this causes grad_input to be identical across all TP shards. The correct way seems to be the following (a rough sketch follows the list):

  1. Start an async all-gather to collect unsharded grad_outputs
  2. At the same time, calculate the partial grad_input using grad_output * local weights
  3. After 1 finishes, calculate the remaining parts of grad_input from the grad_output chunks fetched from the other ranks: grad_output_remaining * local_weights
  4. Concatenate remaining_grad_input and grad_input to form the total grad_input
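
A rough sketch of those four steps, under the same assumptions as above (illustrative names, not the actual nanotron code):

```python
import torch
import torch.distributed as dist

def row_linear_backward_proposed(grad_output, weight, group):
    tp, rank = dist.get_world_size(group), dist.get_rank(group)
    # 1. async all-gather of the sharded grad_output along the batch dim
    total_grad_output = torch.empty(
        grad_output.shape[0] * tp, grad_output.shape[1],
        dtype=grad_output.dtype, device=grad_output.device,
    )
    handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)
    # 2. overlap: partial grad_input for the local batch slice, using the local weight
    local_grad_input = grad_output.matmul(weight)                  # (B/tp, H/tp)
    handle.wait()
    # 3. remaining batch slices: gathered grad_output chunks times the *same* local weight
    chunks = total_grad_output.chunk(tp, dim=0)
    pieces = [local_grad_input if i == rank else c.matmul(weight) for i, c in enumerate(chunks)]
    # 4. concatenate into the full grad_input for this rank's input shard
    return torch.cat(pieces, dim=0)                                # (B, H/tp), differs per rank
```

Mathematically this is just total_grad_output.matmul(weight) with the local weight shard; splitting the work only serves to overlap the local chunk with the all-gather.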

The basic idea behind this is that RowLinear in TP looks like the following.

Forward:

Rank 0: X1(B, H/2) W1(H/2, O) = Y1(B, O) -> reduce-scatter -> R1(B/2, O)
Rank 1: X2(B, H/2) W2(H/2, O) = Y2(B, O) -> reduce-scatter -> R2(B/2, O)

The sharding is like:

X = [X1, X2]    W = [W1]    Y = [Y1 + Y2]    R = [R1]
                    [W2]                         [R2]

So in the backward pass, theoretically:

dX1 = dY1 W1_t = concat([dR1 W1_t], [dR2 W1_t])
dX2 = dY2 W2_t = concat([dR1 W2_t], [dR2 W2_t])
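
For concreteness, here is a tiny single-process check of these formulas against autograd (TP=2 simulated; names follow the notation above):

```python
import torch

torch.manual_seed(0)
B, H, O = 4, 6, 3
X = torch.randn(B, H, requires_grad=True)
W = torch.randn(H, O)
dY = torch.randn(B, O)

# dense reference: Y = X @ W, so dX = dY @ W_t
(X @ W).backward(dY)
dX_ref1, dX_ref2 = X.grad.chunk(2, dim=1)   # reference grads of the column shards X1, X2

# sharded quantities, following the notation above
W1, W2 = W.chunk(2, dim=0)                  # (H/2, O) weight shards
dR1, dR2 = dY.chunk(2, dim=0)               # (B/2, O) grad shards after the reduce-scatter

dX1 = torch.cat([dR1 @ W1.t(), dR2 @ W1.t()], dim=0)
dX2 = torch.cat([dR1 @ W2.t(), dR2 @ W2.t()], dim=0)

print(torch.allclose(dX1, dX_ref1), torch.allclose(dX2, dX_ref2))  # True True
```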

The current implementation seems problematic. Could you help verify this or point out my mistake? Thanks!

C-TC commented 6 months ago

Same opinion here. The current implementation of the _RowLinearAsyncCommunication backward pass produces exactly the same grad of input on all TP shards, which is definitely not the expected behavior. https://github.com/huggingface/nanotron/blob/c81551700fe77fc0e9ba75cbba4bd3d8b490557c/src/nanotron/parallel/tensor_parallel/functional.py#L439

Also, the tensor parallel test somehow does not cover the grad of input for this case: https://github.com/huggingface/nanotron/blob/c81551700fe77fc0e9ba75cbba4bd3d8b490557c/tests/test_tensor_parallel.py#L238-L264
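
For reference, the missing assertion would be something along these lines (an illustrative helper, not the actual test utilities in the repo): each rank's grad of the sharded input should match the corresponding column slice of a dense reference gradient.

```python
import torch

def assert_row_linear_input_grad(sharded_input_grad: torch.Tensor,
                                 reference_input_grad: torch.Tensor,
                                 tp_rank: int, tp_size: int) -> None:
    # sharded_input_grad: (B, H/tp) grad of the input shard on this rank
    # reference_input_grad: (B, H) grad of the input from a non-parallel reference linear
    expected = reference_input_grad.chunk(tp_size, dim=1)[tp_rank]
    torch.testing.assert_close(sharded_input_grad, expected)
```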

@xrsrke @NouamaneTazi

HaroldBenoit commented 6 months ago

Hello, I agree with @C-TC and @ufotalent.

To add a clarifying comment for the maintainers: as explained by @ufotalent, in theory we expect the gradients to look like this:

dX1 = dY1 W1_t = concat([dR1 W1_t], [dR2 W1_t])
dX2 = dY2 W2_t = concat([dR1 W2_t], [dR2 W2_t])

In the current implementation, what we get is:

dX1 = concat([dR1 W1_t], [dR2 W2_t])
dX2 = concat([dR1 W1_t], [dR2 W2_t])

because computing the local/sharded gradient with grad_tensor = grad_output.matmul(weight) and then gathering it with dist.all_gather_into_tensor(total_grad_tensor, grad_tensor, group=group, async_op=True) means each local/sharded grad_output is only ever multiplied with its own local/sharded weight matrix.
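
A toy single-process illustration of this point (TP=2 simulated, random shards, illustrative names):

```python
import torch

torch.manual_seed(0)
B, H, O = 4, 6, 3
W1, W2 = torch.randn(H // 2, O), torch.randn(H // 2, O)     # per-rank weight shards
dR1, dR2 = torch.randn(B // 2, O), torch.randn(B // 2, O)   # per-rank grad_output shards

# what the current backward produces after the all-gather (the same tensor on both ranks)
dX_current = torch.cat([dR1 @ W1.t(), dR2 @ W2.t()], dim=0)

# what each rank should actually receive
dX1 = torch.cat([dR1 @ W1.t(), dR2 @ W1.t()], dim=0)
dX2 = torch.cat([dR1 @ W2.t(), dR2 @ W2.t()], dim=0)

print(torch.allclose(dX_current, dX1), torch.allclose(dX_current, dX2))  # False False
```

Both ranks hold dX_current, which matches neither dX1 nor dX2.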