NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[QUESTION] glu activation with tensor parallel in GroupedMLP #985

Closed by Teng-xu 3 weeks ago

Teng-xu commented 1 month ago

Description:

When training with GroupedMLP with both Tensor Parallel (TP) and gated_linear_unit enabled, the activation function is applied to fc1_output. With a TP degree of 2, this intermediate output holds only the tensor values local to one TP rank, that is, half of the full tensor. Applying the GLU activation function to this output leads to a loss of information, because only half of the tensor values are involved in the activation.

Specifically, in the GLU function (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.7.0/megatron/core/transformer/moe/experts.py#L48): self.config.activation_func(x[0]) * x[1]

Both self.config.activation_func(x[0]) and x[1] contain half of the output tensor due to TP being enabled, resulting in an output that does not match the results from training without TP.

Steps to Reproduce:

  1. Enable gated_linear_unit in the GroupedMLP configuration.
  2. Train the model with Tensor Parallel (TP) enabled.
  3. Compare the intermediate outputs of the GLU activation function with and without TP enabled. (https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/experts.py#L176)

Expected Behavior:

The activation function should correctly handle the tensor values across all TP ranks to prevent any loss of information, ensuring consistency with results obtained without TP.

Actual Behavior:

The GLU activation function is applied to tensor values that only represent half of the full tensor due to TP, leading to inconsistent results.

ethanhe42 commented 1 month ago

"Both self.config.activation_func(x[0]) and x[1] contain half of the output tensor due to TP being enabled"

This is the expected behavior, not an error. TP adds the fc2_output from the different TP ranks to get the final result: https://github.com/NVIDIA/Megatron-LM/blob/203b463689bd322eb915afb3e4d1076bcc4783ba/megatron/core/transformer/moe/token_dispatcher.py#L227
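A minimal sketch of that point, assuming a single expert with plain weight matrices and a SiLU activation. The names w_gate, w_up, and w2 are illustrative and are not Megatron-LM identifiers. Each TP rank holds its own slice of the gate and up projections, applies GLU locally, and the partial fc2 outputs are summed, which reproduces the non-TP result when the full fc1 weight is laid out as [gate | up]:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, ffn, tp = 8, 16, 2
x = torch.randn(4, hidden, dtype=torch.float64)
w_gate = torch.randn(hidden, ffn, dtype=torch.float64)  # hypothetical gate projection
w_up = torch.randn(hidden, ffn, dtype=torch.float64)    # hypothetical up projection
w2 = torch.randn(ffn, hidden, dtype=torch.float64)      # hypothetical fc2 weight

def glu(t):
    # Same shape of computation as the GLU in the issue: act(x[0]) * x[1]
    a, b = torch.chunk(t, 2, dim=-1)
    return F.silu(a) * b

# No TP: fc1 output is [gate | up], then GLU, then fc2.
full = glu(x @ torch.cat([w_gate, w_up], dim=-1)) @ w2

# TP: rank r holds a column slice of gate and up, and the matching rows of fc2.
partials = []
for g, u, w2_r in zip(w_gate.chunk(tp, dim=-1),
                      w_up.chunk(tp, dim=-1),
                      w2.chunk(tp, dim=0)):
    fc1_out = x @ torch.cat([g, u], dim=-1)  # local fc1 output: [gate_r | up_r]
    partials.append(glu(fc1_out) @ w2_r)     # local partial fc2 output
tp_out = sum(partials)                       # stands in for the cross-rank reduction

match = torch.allclose(full, tp_out)
```

In this sketch each rank sees only half of the fc1 columns, yet the summed fc2 output equals the non-TP output, because GLU is elementwise and each rank pairs a gate column with its own up column.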

Teng-xu commented 1 month ago

I understand that TP adds the fc2_output from different TP ranks to get the final result. However, my concern is with the correctness of the intermediate output from the activation layer. If this intermediate output is incorrect, then the reduced final result will also be incorrect.

The following is the activation computation func (https://github.com/NVIDIA/Megatron-LM/blob/core_v0.7.0/megatron/core/transformer/moe/experts.py#L46):

x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]

which yields different results when TP is applied versus when it's not, even after reduction.

Consider the following example. Without TP (mat1):

import torch

mat1 = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]])
x1 = torch.chunk(mat1, 2, dim=-1)
r1 = x1[0] * x1[1]

Result: 
r1 = tensor([[  0,   9,  20,  33,  48,  65,  84, 105],
        [384, 425, 468, 513, 560, 609, 660, 713]])

With TP degree == 2 (mat2 and mat3 being inputs on TP rank 0 and 1):

mat2 = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],
        [16, 17, 18, 19, 20, 21, 22, 23]])

mat3 = torch.tensor([[8, 9, 10, 11, 12, 13, 14, 15],
        [24, 25, 26, 27, 28, 29, 30, 31]])

x2 = torch.chunk(mat2, 2, dim=-1)
r2 = x2[0] * x2[1]

x3 = torch.chunk(mat3, 2, dim=-1)
r3 = x3[0] * x3[1]

Results:
r2 =  tensor([[  0,   5,  12,  21],
        [320, 357, 396, 437]]) 
r3 = tensor([[ 96, 117, 140, 165],
        [672, 725, 780, 837]])

The reduced results from r2 and r3 do not match r1 because, when TP degree > 1, each TP rank is multiplying using incorrect tensor values compared to the non-TP case.

Teng-xu commented 1 month ago

The issue with the GLU activation in Tensor Parallel is causing correctness problems that are blocking training. An update on this or any suggestions for moving forward would be greatly appreciated.

ethanhe42 commented 1 month ago

I see your point. In your example, the results differ because the tensor layouts differ: TP sharding and GLU chunking are applied in a different order, so the columns end up shuffled. In practice, this shouldn't affect training, because the linear layers are learned. It might affect training when the parallelism strategy or model architecture is changed mid-training.
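The layout difference can be checked on the numbers from the example above. This is a sketch only: glu_no_act and perm are illustrative names, and the activation is left out as in the example. Reordering the columns of mat1 so that both ranks' first halves come first, followed by both ranks' second halves, makes the single-rank result equal to the concatenation of r2 and r3:

```python
import torch

mat1 = torch.arange(32).reshape(2, 16)     # full fc1 output from the example
mat2, mat3 = torch.chunk(mat1, 2, dim=-1)  # per-rank slices for TP degree 2

def glu_no_act(t):
    # GLU without the activation, as in the example: x[0] * x[1]
    a, b = torch.chunk(t, 2, dim=-1)
    return a * b

r1 = glu_no_act(mat1)
tp_result = torch.cat([glu_no_act(mat2), glu_no_act(mat3)], dim=-1)

# Column order implied by the TP=2 layout when viewed from a single rank:
# the first halves of both ranks, then the second halves of both ranks.
perm = torch.tensor([0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15])
r1_permuted = glu_no_act(mat1[:, perm])

same_as_tp = torch.equal(r1_permuted, tp_result)
differs_from_r1 = not torch.equal(r1, tp_result)
```

Here same_as_tp and differs_from_r1 are both True: the TP run computes a valid GLU, but over a column ordering that differs from the one a TP=1 run assumes for the same weights.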

Teng-xu commented 1 month ago

Thanks for your response. My primary concern is with fine-tuning. If we pretrain using TP 2 and then load the checkpoint to fine-tune with TP1 or any configs other than TP2, then we would see loss issues. Do you have any suggestions for addressing this, or are there any plans on your end to provide fixes for this?

ethanhe42 commented 1 month ago

A workaround is to manually convert the tensor layout when you switch to fine-tuning.
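One way such a conversion could look, as a rough sketch rather than a Megatron-LM utility: merge_glu_fc1_shards is a hypothetical helper, and it assumes a single expert whose per-rank fc1 shard is stored as [gate_r | up_r] along the output dimension. The actual GroupedMLP checkpoint layout (multiple experts packed in one weight, possible transposes) should be verified before applying anything like this.

```python
import torch

def merge_glu_fc1_shards(shards):
    """Merge per-rank fc1 shards laid out as [gate_r | up_r] into [gate | up].

    Hypothetical helper: `shards` is a list of (hidden, 2 * ffn_per_rank)
    tensors, ordered by TP rank.
    """
    gates, ups = zip(*(torch.chunk(s, 2, dim=-1) for s in shards))
    return torch.cat(list(gates) + list(ups), dim=-1)

def glu_no_act(t):
    a, b = torch.chunk(t, 2, dim=-1)
    return a * b

torch.manual_seed(0)
x = torch.randn(3, 4, dtype=torch.float64)
# Stand-ins for the fc1 shards of a TP=2 checkpoint.
shards = [torch.randn(4, 6, dtype=torch.float64) for _ in range(2)]

# Activation output as the TP=2 run computes it, concatenated in rank order.
tp_act = torch.cat([glu_no_act(x @ s) for s in shards], dim=-1)
# Activation output at TP=1 after converting the layout.
merged_act = glu_no_act(x @ merge_glu_fc1_shards(shards))
# Activation output at TP=1 if the shards are simply concatenated.
naive_act = glu_no_act(x @ torch.cat(shards, dim=-1))

layout_ok = torch.allclose(tp_act, merged_act)
naive_ok = torch.allclose(tp_act, naive_act)
```

With the converted layout the TP=1 activation matches the TP=2 one (layout_ok is True), while plain concatenation of the shards does not (naive_ok is False). Under the same assumptions, the fc2 shards can be concatenated along the input dimension in rank order, since the merged activation keeps the per-rank column order.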