THUDM / ChatGLM3

ChatGLM3 series: Open Bilingual Chat LLMs
Apache License 2.0

Different implementations of RMSNorm #1240

Open · trundleyrg opened this issue 2 months ago

trundleyrg commented 2 months ago

System Info

torch version: 2.12

Who can help?

No response

Information

Reproduction

import torch


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        # Learnable per-channel scale; torch.empty allocates the tensor without
        # initializing it, so its contents are arbitrary until weights are loaded.
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        # Mean square over the last dimension, computed in float32 for stability.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # Scale by the reciprocal root mean square (no mean subtraction, unlike LayerNorm).
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        # Apply the learned scale and cast back to the input dtype.
        return (self.weight * hidden_states).to(input_dtype)
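
For reference, a minimal usage sketch (the hidden size and the explicit all-ones reset are illustrative assumptions, not part of the original code):

norm = RMSNorm(normalized_shape=8, dtype=torch.float16)
torch.nn.init.ones_(norm.weight)  # the empty-initialized weight holds arbitrary values until reset or loaded
x = torch.randn(2, 4, 8, dtype=torch.float16)
y = norm(x)
print(y.shape, y.dtype)  # torch.Size([2, 4, 8]) torch.float16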

In ChatGLM's RMSNorm implementation, the weight is created with torch.empty, i.e., left uninitialized (its initial values are arbitrary). LLaMA's RMSNorm instead initializes the weight to all ones with torch.ones. Is ChatGLM's use of torch.empty a deliberate choice, i.e., is there some consideration behind starting from random scale coefficients?
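
For comparison, a minimal sketch of the LLaMA-style variant (modeled on the LlamaRMSNorm in Hugging Face transformers; the class name here is illustrative):

class LlamaStyleRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # All-ones initialization: at creation time the layer is a plain RMSNorm
        # with an identity scale, well-defined even before any weights are loaded.
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

With torch.ones the module computes a sensible RMSNorm out of the box, whereas a torch.empty weight only becomes meaningful after a checkpoint load or an explicit reset; once pretrained weights are loaded, both initializations are overwritten and the two forward passes behave the same.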

Expected behavior

Could you explain the advantages and disadvantages of the two implementations?