microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

Questions about the implementation of deepnorm #16

Closed · jiaohuix closed this issue 1 year ago

jiaohuix commented 1 year ago

I have a doubt about deepnorm. In the paper, the deepnorm_init function uses xavier_normal_(x, gain=beta) for "ffn", "v_proj", and "out_proj" (see the initialization table in the DeepNet paper). However, the torchscale source code uses xavier_normal_(x, gain=1) / beta:

```python
for name, p in self.named_parameters():
    if (
        "fc1" in name
        or "fc2" in name
        or "out_proj" in name
        or "v_proj" in name
    ):
        p.data.mul_(init_scale)
```

Although I know that if $X \sim N(0, \sigma^2)$ then $aX \sim N(0, (a\sigma)^2)$, I plotted the distribution of both methods as a histogram, and the results show some differences between the two methods:

(histogram: the two weight distributions visibly differ)

```python
import torch
import matplotlib.pyplot as plt
from torch.nn.init import xavier_normal_

torch.manual_seed(1)

init_scale = 0.343
linear1 = torch.nn.Linear(4096, 512)  # method 1: xavier_normal_(x, gain=beta)
linear2 = torch.nn.Linear(4096, 512)  # method 2: xavier_normal_(x, gain=1) / beta
xavier_normal_(linear1.weight, gain=init_scale)
xavier_normal_(linear2.weight, gain=1)

linear1_weight = linear1.weight.detach().numpy().reshape((-1,))
linear2_weight = linear2.weight.detach().numpy().reshape((-1,)) / init_scale

plt.figure(figsize=(10, 6))
plt.hist([linear1_weight, linear2_weight], bins=100, rwidth=0.8, histtype="step")
plt.xlabel("value")
plt.ylabel("count")
plt.legend(["1 xavier_normal_(x, gain=beta)", "2 xavier_normal_(x, gain=1) / beta"])
plt.show()
```
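For context, the overall pattern I believe torchscale follows is sketched below; the num_layers value and the beta formula are my own assumptions from the DeepNet paper (decoder-only case), not code copied from the repo:

```python
import torch.nn as nn
from torch.nn.init import xavier_normal_

num_layers = 6                           # hypothetical decoder depth, my assumption
init_scale = (8 * num_layers) ** -0.25   # beta = (8N)^(-1/4) for a decoder-only DeepNet

proj = nn.Linear(512, 512)
xavier_normal_(proj.weight, gain=1.0)  # draw with gain = 1 ...
proj.weight.data.mul_(init_scale)      # ... then scale in place, as in the loop above
```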

Is my implementation wrong? Which method should I use? I hope someone can enlighten me. Thank you!!!

shumingma commented 1 year ago

Hi @MiuGod0126

$\beta$ is a multiplier: xavier_normal_(x, gain=beta) draws from $N(0, (\beta\sigma)^2)$ with $\sigma = \sqrt{2 / (\text{fan\_in} + \text{fan\_out})}$, which is exactly the distribution you get by drawing with gain=1 and then multiplying by $\beta$ (that is what p.data.mul_(init_scale) does). So it should be:

```python
linear2_weight = linear2.weight.detach().numpy().reshape((-1,)) * init_scale
```

instead of

```python
linear2_weight = linear2.weight.detach().numpy().reshape((-1,)) / init_scale
```
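
A quick sanity check (a minimal sketch, reusing the shapes from your snippet) shows the two methods match in distribution once the multiplier is used:

```python
import torch
from torch.nn.init import xavier_normal_

torch.manual_seed(1)
init_scale = 0.343

w1 = torch.empty(4096, 512)
w2 = torch.empty(4096, 512)
xavier_normal_(w1, gain=init_scale)  # method 1: gain = beta
xavier_normal_(w2, gain=1.0)         # method 2: gain = 1 ...
w2.mul_(init_scale)                  # ... then multiply by beta, as torchscale does

# both stds come out around init_scale * sqrt(2 / (4096 + 512)) ~= 0.0071
print(w1.std().item(), w2.std().item())
```

With the division, method 2's std would instead be $1/\beta^2 \approx 8.5\times$ that of method 1, which is the gap visible in the histogram above.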
jiaohuix commented 1 year ago

@shumingma Ooooh! Sorry, I was careless and read the mul as a division. Thank you for the correction! I now understand deepnorm_init more deeply, and the corrected distribution is as follows:

(histogram of the corrected, matching distributions)