huggingface / nanotron

Minimalistic large language model 3D-parallelism training
Apache License 2.0

Fix _RowLinearAsyncCommunication #172

Closed C-TC closed 2 months ago

C-TC commented 4 months ago

As discussed in #46, the implementation of the row-parallel linear backward pass appears to be wrong when the TP mode is reduce-scatter and async communication is enabled. To be more specific, the gradient of the input tensor is produced by:

  1. compute the local slice of the input gradient
  2. all-gather that slice of the input gradient

This produces the same input gradient on all TP shards, which is not the correct behavior of a row-parallel linear layer. The check on the input gradient for row-parallel is also missing from the test cases. After adding the test, the result for test_tensor_parallel.py::test_row_linear[True-TensorParallelLinearMode.REDUCE_SCATTER-4-1-1] is as follows: [screenshot of failing test output] Only 1/4 of the input gradient values are correct, because that is the locally computed part. A simplified sketch of the buggy vs. corrected computation is shown below.
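For clarity, here is a simplified, hypothetical sketch of the difference (illustrative names and shapes only, not the actual nanotron code; it assumes the forward pass reduce-scatters the partial outputs along the sequence dimension, so grad_output arrives sharded along that dimension):

```python
import torch
import torch.distributed as dist

def buggy_backward(grad_output_local, weight_local, tp_group):
    # grad_output_local: [seq/tp, batch, out_features] (sharded along sequence)
    # weight_local:      [out_features, in_features/tp] (row-parallel shard)
    # Bug: the locally computed slice of grad_input is all-gathered, so every
    # rank ends up with the same grad_input even though each rank owns a
    # different weight shard. Only the chunk computed locally is correct.
    grad_input_slice = grad_output_local @ weight_local  # [seq/tp, batch, in/tp]
    tp_size = dist.get_world_size(tp_group)
    gathered = [torch.empty_like(grad_input_slice) for _ in range(tp_size)]
    dist.all_gather(gathered, grad_input_slice, group=tp_group)
    return torch.cat(gathered, dim=0)  # identical on every rank -> wrong

def fixed_backward(grad_output_local, weight_local, tp_group):
    # Correct order: first all-gather grad_output along the sequence dimension
    # (the backward of the forward reduce-scatter), then multiply by the local
    # weight shard, so each rank gets its own slice of grad_input.
    tp_size = dist.get_world_size(tp_group)
    gathered = [torch.empty_like(grad_output_local) for _ in range(tp_size)]
    dist.all_gather(gathered, grad_output_local, group=tp_group)
    grad_output_full = torch.cat(gathered, dim=0)  # [seq, batch, out_features]
    return grad_output_full @ weight_local         # [seq, batch, in_features/tp]
```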

This PR does the following:

  1. Fixes the backward pass of _RowLinearAsyncCommunication, in a way similar to the forward pass of _ColumnLinearAsyncCommunication, overlapping communication with part of the computation (a rough sketch of the overlap idea is given after the next paragraph).
  2. Adds the missing test that checks the correctness of the input gradient in _RowLinearAsyncCommunication.

This bug affects convergence and is triggered when the TP mode is reduce-scatter and async communication is enabled (which is a common setup for users).
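For reference, a rough sketch of the overlap idea (hypothetical helper, not the PR's exact code; it assumes grad_output is sharded along the sequence dimension and uses PyTorch's async collectives):

```python
import torch
import torch.distributed as dist

def backward_with_overlap(grad_output_local, weight_local, tp_group):
    tp_size = dist.get_world_size(tp_group)
    rank = dist.get_rank(tp_group)
    seq_chunk = grad_output_local.shape[0]

    # Buffer for the full grad_output, filled by an async all-gather.
    gathered = torch.empty(
        (seq_chunk * tp_size, *grad_output_local.shape[1:]),
        dtype=grad_output_local.dtype,
        device=grad_output_local.device,
    )
    handle = dist.all_gather_into_tensor(
        gathered, grad_output_local, group=tp_group, async_op=True
    )

    # Overlap: the chunk of grad_input that corresponds to the local
    # grad_output shard can be computed before the all-gather completes.
    grad_input = torch.empty(
        (seq_chunk * tp_size, *grad_output_local.shape[1:-1], weight_local.shape[-1]),
        dtype=grad_output_local.dtype,
        device=grad_output_local.device,
    )
    local_slice = slice(rank * seq_chunk, (rank + 1) * seq_chunk)
    grad_input[local_slice] = grad_output_local @ weight_local

    # Wait for the remaining grad_output chunks, then finish grad_input.
    handle.wait()
    for r in range(tp_size):
        if r == rank:
            continue
        sl = slice(r * seq_chunk, (r + 1) * seq_chunk)
        grad_input[sl] = gathered[sl] @ weight_local
    return grad_input
```

The point of the overlap is that the locally owned chunk of grad_input does not depend on the all-gather at all, so its matmul can hide part of the communication latency, mirroring what the column-linear forward already does.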

Please fix this, thanks!

3outeille commented 2 months ago

Thank you for your PR. There was indeed a bug in the way we compute the backward pass of the async RowLinear (for every rank, the local shard of grad_output was multiplied by that rank's local weight shard, which is not correct).

Just waiting for the CI to run, then I will merge this.

However, I would have expected a bigger difference before/after the PR in terms of loss (cf. https://api.wandb.ai/links/bouteille/2pe2otwy). I trained a 1B Llama on fineweb-edu for 500 steps as a sanity check.