NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[BUG] GroupedMLP calculation problem. #839

Open Baibaifan opened 4 months ago

Baibaifan commented 4 months ago

Describe the bug

[image]

As shown in the figure above, using view when calculating w1 in this part mixes up the elements.

[image]

As the second figure shows, view is the wrong operation here; split or chunk should be used for the conversion instead. Because w2 does not have the ffn_hidden_size dimension, using view on it causes no problem.

Test code

import torch
from torch.nn.parameter import Parameter
import math
torch.manual_seed(123)

def scaled_init_method_normal(sigma=0.02, num_layers=10):
    """Init method based on N(0, sigma/sqrt(2*num_layers))."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_

weight2 = Parameter(torch.empty(4, 10, dtype=torch.bfloat16))
init_weight2 = scaled_init_method_normal()
init_weight2(weight2)

print(weight2)
weight2_v = weight2.view(5, 4, -1)  # weight2: (4, 10) -> (5, 4, 2)
weight2_v = weight2_v[0, :, :]  # slice for the first expert: (4, 2)
print(weight2_v)
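
A variant of the test above with integer entries makes the difference easier to see. This is a minimal sketch, not Megatron-LM code: the names hidden_size, num_experts and ffn_hidden_size, and the (hidden_size, num_experts * ffn_hidden_size) layout for w1, are assumptions chosen to match the (4, 10) example.

```python
import torch

# Assumed sizes, matching the (4, 10) tensor used in the test code.
hidden_size, num_experts, ffn_hidden_size = 4, 5, 2

# Integer entries make it easy to see where each element ends up.
weight1 = torch.arange(hidden_size * num_experts * ffn_hidden_size).reshape(
    hidden_size, num_experts * ffn_hidden_size
)

# view reinterprets the row-major buffer, so "expert 0" is simply the
# first hidden_size * ffn_hidden_size elements of the flattened tensor.
expert0_view = weight1.view(num_experts, hidden_size, -1)[0]

# chunk along dim=1 takes the first ffn_hidden_size columns of every row.
expert0_chunk = weight1.chunk(num_experts, dim=1)[0]

print(expert0_view)
print(expert0_chunk)
print(torch.equal(expert0_view, expert0_chunk))
```

With these sizes the view slice holds elements 0 to 7 of the flattened buffer, while the chunk slice holds the first two columns of each row, so the two slices differ.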

Expected behavior

Determine whether there is a problem with GroupedMLP.

ethanhe42 commented 4 months ago

This is a minor issue and it does not affect correctness. If you want to load weights, you just need to make sure their layout matches the view format.
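
One possible reading of "the layout is in the view format" is sketched below. The helper pack_experts_view_layout is hypothetical and is not part of Megatron-LM; it assumes each expert's w1 is a separate (hidden_size, ffn_hidden_size) matrix and that the fused weight is later split with view(num_experts, hidden_size, -1).

```python
import torch


def pack_experts_view_layout(expert_weights):
    """Hypothetical helper: pack per-expert (hidden, ffn) matrices so that
    packed.view(num_experts, hidden, -1)[e] returns expert e unchanged."""
    hidden_size = expert_weights[0].shape[0]
    stacked = torch.stack(expert_weights, dim=0)  # (E, H, F), contiguous
    return stacked.reshape(hidden_size, -1)       # (H, E * F), same buffer order


# Assumed sizes, matching the (4, 10) example in the issue.
num_experts, hidden_size, ffn_hidden_size = 5, 4, 2

# Expert e is filled with the value e so it is easy to recognize.
experts = [
    torch.full((hidden_size, ffn_hidden_size), float(e))
    for e in range(num_experts)
]

packed = pack_experts_view_layout(experts)
recovered = packed.view(num_experts, hidden_size, -1)
round_trip_ok = all(
    torch.equal(recovered[e], experts[e]) for e in range(num_experts)
)

# For contrast: concatenating along columns gives the chunk layout,
# which view does not split back into the original experts.
concat = torch.cat(experts, dim=1)
concat_view = concat.view(num_experts, hidden_size, -1)
concat_matches = all(
    torch.equal(concat_view[e], experts[e]) for e in range(num_experts)
)

print(round_trip_ok, concat_matches)
```

Under these assumptions, weights packed by stacking survive the view round trip, whereas weights concatenated column-wise would need chunk or split to be recovered.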

Baibaifan commented 3 months ago

@ethanhe42 Thank you for your answer. I have made some modifications for the weight-loading scenario. My understanding is that view is used to make more efficient use of contiguous memory. If split or chunk is used instead, it returns a sequence of tensors such as [tensor1, tensor2]. Would that cause any CPU scheduling issues?
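
The memory side of this question can be checked directly in PyTorch. The sketch below uses the same assumed (4, 10) shape as the test code and only inspects storage sharing and contiguity; it does not measure scheduling or kernel-launch overhead.

```python
import torch

# Same assumed shape as the test code: 4 rows, 5 experts x 2 columns.
weight = torch.arange(40, dtype=torch.float32).reshape(4, 10)

# chunk returns a tuple of views; no data is copied at this point.
chunks = weight.chunk(5, dim=1)
shares_storage = chunks[0].data_ptr() == weight.data_ptr()

# Each column chunk is strided, so it is not contiguous in memory.
chunk_contiguous = chunks[0].is_contiguous()

# Building a single (E, H, F) tensor from the chunks allocates a new buffer.
stacked = torch.stack(chunks, dim=0)
stacked_is_copy = stacked.data_ptr() != weight.data_ptr()

# view, by comparison, yields one contiguous (E, H, F) tensor with no copy.
viewed = weight.view(5, 4, -1)
view_shares_storage = viewed.data_ptr() == weight.data_ptr()

print(len(chunks), shares_storage, chunk_contiguous)
print(stacked_is_copy, view_shares_storage, viewed.is_contiguous())
```

So chunk itself is cheap, but its outputs are separate non-contiguous views, and combining them into one batched tensor requires a copy, which view avoids.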

github-actions[bot] commented 1 month ago

Marking as stale. No activity in 60 days.