Open · ufotalent opened this issue 9 months ago
Same opinion here. The current implementation of the `_RowLinearAsyncCommunication` backward pass produces exactly the same input gradient on all TP shards, which is definitely not the expected behavior. https://github.com/huggingface/nanotron/blob/c81551700fe77fc0e9ba75cbba4bd3d8b490557c/src/nanotron/parallel/tensor_parallel/functional.py#L439
Also, the tensor parallel test somehow does not cover the gradient of the input for this case: https://github.com/huggingface/nanotron/blob/c81551700fe77fc0e9ba75cbba4bd3d8b490557c/tests/test_tensor_parallel.py#L238-L264
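For what it's worth, here is a hedged sketch of the kind of assertion that test could add. The names `sharded_input`, `reference_input`, and `pg` are my assumptions about the test's fixtures, not the actual code in `test_tensor_parallel.py`:

```python
# Hypothetical extra check (fixture names assumed): each TP rank's input
# gradient should equal the matching column slice of the unsharded
# reference input gradient.
rank = dist.get_rank(pg)
in_per_rank = reference_input.shape[-1] // pg.size()
torch.testing.assert_close(
    sharded_input.grad,
    reference_input.grad[..., rank * in_per_rank : (rank + 1) * in_per_rank],
)
```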
@xrsrke @NouamaneTazi
Hello, I agree with @C-TC and @ufotalent.
To add a clarifying comment for the maintainers: as explained by @ufotalent, in theory we expect the gradients to look like this:
dX1 = dY1 W1_t = concat([dR1 W1_t], [dR2 W1_t])
dX2 = dY2 W2_t = concat([dR1 W2_t], [dR2 W2_t])
In the current implementation, what we get is:
dX1 = dY1 W1_t = concat([dR1 W1_t], [dR2 W2_t])
dX2 = dY2 W2_t = concat([dR1 W1_t], [dR2 W2_t])
because computing the local/sharded gradient with `grad_tensor = grad_output.matmul(weight)` and then gathering the results with `dist.all_gather_into_tensor(total_grad_tensor, grad_tensor, group=group, async_op=True)` means each sharded output gradient is only ever multiplied with the corresponding local/sharded weight matrix, so every rank ends up with the same gathered tensor.
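To make this concrete, here is a minimal single-process sketch (plain PyTorch, no nanotron code; the shapes and names `W1`, `W2`, `dR1`, `dR2` simply mirror the notation above). It shows that matmul-then-gather yields the same `dX` on every rank, whereas gathering `dY` first gives each rank its own `dX_i`:

```python
import torch

torch.manual_seed(0)
B, H, O = 4, 8, 6  # batch, hidden (split over 2 "ranks"), output dims

# Row-parallel weight shards: W1 on rank 0, W2 on rank 1.
W1 = torch.randn(H // 2, O)
W2 = torch.randn(H // 2, O)

# dR1, dR2: gradient shards of the reduce-scattered output (B/2 rows each).
dR1 = torch.randn(B // 2, O)
dR2 = torch.randn(B // 2, O)
dY = torch.cat([dR1, dR2])  # reduce-scatter's backward is an all-gather, so dY1 = dY2 = dY

# Current pattern: each rank multiplies its dR_i by its own weight shard, then
# the results are all-gathered, so both ranks get the same concatenation.
dX_current = torch.cat([dR1 @ W1.T, dR2 @ W2.T])

# Expected: gather dY first, then multiply by the rank's own weight shard.
dX1_expected = dY @ W1.T  # rank 0: dX1 = dY W1_t
dX2_expected = dY @ W2.T  # rank 1: dX2 = dY W2_t

print(torch.allclose(dX_current, dX1_expected))  # False in general
print(torch.allclose(dX_current, dX2_expected))  # False in general
```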
Hi nanotron authors. Thanks for this work that accelerates LLM training for the community. I'm currently developing a feature based on nanotron that needs some tweaks to the tensor parallel / sequence parallel part. However, I found something that seems problematic.
In this code piece, https://github.com/huggingface/nanotron/blob/8c1a49588d0745a6404644a86547c2dd6a63640e/src/nanotron/parallel/tensor_parallel/functional.py#L417, the backward of row linear does the following:
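(Paraphrased sketch, not the literal file contents: the backward first multiplies the local `grad_output` shard by the local weight shard, then all-gathers only the result across the TP group.)

```python
# Condensed paraphrase of the current backward pattern:
grad_tensor = grad_output.matmul(weight)  # dR_i @ W_i_t, local shards only
handle = dist.all_gather_into_tensor(
    total_grad_tensor, grad_tensor, group=group, async_op=True
)
# ... total_grad_tensor is then used as the input gradient,
# which is identical on every rank.
```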
However, from my understanding this causes `grad_input` to be identical across TP shards. The correct way seems to be something like this:
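A sketch of one possible fix, assuming the backward's `grad_output`, `weight`, and `group` variables (with `torch` and `torch.distributed as dist` imported; buffer names here are mine, not nanotron's): all-gather the sharded `grad_output` first so every rank holds the full `dY`, then multiply by the local weight shard.

```python
# Gather the output-gradient shards (dR_i) across the TP group first...
unsharded_batch_size = grad_output.shape[0] * dist.get_world_size(group=group)
total_grad_output = torch.empty(
    unsharded_batch_size,
    *grad_output.shape[1:],
    dtype=grad_output.dtype,
    device=grad_output.device,
)
handle = dist.all_gather_into_tensor(
    total_grad_output, grad_output, group=group, async_op=True
)
# (the weight gradient could be computed here to overlap with communication)
handle.wait()
# ...then multiply the full dY by the *local* weight shard:
# dX_i = dY W_i_t in the notation above, different on each rank.
grad_input = total_grad_output.matmul(weight)
```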
The basic idea behind this: row linear in TP works like this.

Forward:
Rank 0: X1(B, H/2) W1(H/2, O) = Y1(B, O) -> reduce scatter -> R1(B/2, O)
Rank 1: X2(B, H/2) W2(H/2, O) = Y2(B, O) -> reduce scatter -> R2(B/2, O)

The sharding splits X along the hidden dimension (columns) and W along its rows, and the reduce-scatter splits the summed output along the batch dimension.
So in the backward, theoretically:

dX1 = dY1 W1_t = concat([dR1 W1_t], [dR2 W1_t])
dX2 = dY2 W2_t = concat([dR1 W2_t], [dR2 W2_t])
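As a quick sanity check of these formulas, here is a standalone sketch that compares autograd on the unsharded linear against the per-shard expressions (shapes are arbitrary):

```python
import torch

torch.manual_seed(0)
B, H, O = 4, 8, 6
X = torch.randn(B, H, requires_grad=True)
W = torch.randn(H, O)

# Row-parallel sharding over 2 ranks: X = [X1 | X2] (columns), W = [W1; W2] (rows).
X1, X2 = X[:, : H // 2], X[:, H // 2 :]
W1, W2 = W[: H // 2], W[H // 2 :]

# Forward identity: Y = X W = X1 W1 + X2 W2 (the reduce-scatter then splits Y row-wise).
Y = X1 @ W1 + X2 @ W2
dY = torch.randn(B, O)  # = concat(dR1, dR2)
Y.backward(dY)

# Autograd agrees with the per-shard formulas dX1 = dY W1_t and dX2 = dY W2_t.
print(torch.allclose(X.grad[:, : H // 2], dY @ W1.T))  # True
print(torch.allclose(X.grad[:, H // 2 :], dY @ W2.T))  # True
```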
The current implementation seems problematic. Could you help verify this, or point out my mistake? Thanks!