microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

scale.sqrt() in the recurrent_forward function of the multiscale_retention module #47

Closed wangmengzhi closed 1 year ago

wangmengzhi commented 1 year ago

Line 108 of multiscale_retention.py,

kv = prev_kv * (1 - 1 / scale).view(self.num_heads, 1, 1) + kv / scale.view(self.num_heads, 1, 1)

should be

kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)

because line 65 of retnet.py normalizes the decay mask with a square root:

mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
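
For what it's worth, here is a minimal standalone sketch (independent of torchscale; gamma, q, k, v and the shapes are illustrative) showing that this normalized recurrence reproduces the parallel form with the sqrt-normalized mask:

import torch

torch.manual_seed(0)
T, d = 6, 4
gamma = 0.9                                    # per-head decay
q, k, v = torch.randn(T, d), torch.randn(T, d), torch.randn(T, d)

# Parallel form: decay mask D[n, m] = gamma**(n - m) for m <= n,
# each row divided by the sqrt of its row sum (as in retnet.py).
idx = torch.arange(T)
D = (gamma ** (idx[:, None] - idx[None, :]).float()) * (idx[:, None] >= idx[None, :])
D = D / D.sum(dim=-1, keepdim=True).sqrt()
out_parallel = (q @ k.t() * D) @ v

# Recurrent form with the proposed update: the state holds S_n / sqrt(scale_n),
# where S_n = gamma * S_{n-1} + k_n^T v_n and scale_n = gamma * scale_{n-1} + 1.
out_recurrent = []
prev_kv = torch.zeros(d, d)
prev_scale = torch.zeros(1)
for n in range(T):
    scale = prev_scale * gamma + 1
    kv = prev_kv * (prev_scale.sqrt() * gamma / scale.sqrt()) + torch.outer(k[n], v[n]) / scale.sqrt()
    out_recurrent.append(q[n] @ kv)
    prev_kv, prev_scale = kv, scale
out_recurrent = torch.stack(out_recurrent)

print((out_parallel - out_recurrent).abs().max())  # ≈ 0 up to float32 rounding: the two forms agree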

sunyt32 commented 1 year ago

The division here is for numerical stability and is invariant under LayerNorm. I believe your modification also works.

donglixp commented 1 year ago

@wangmengzhi The normalization doesn't affect the final results, because GN is scale-invariant. But it helps to rescale the numerical values. Both should work fine in practice.
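
As a quick standalone illustration of the scale-invariance argument (the shapes and num_groups here are assumed, not the library's configuration):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, 16)                      # (batch, channels, length)
gn = lambda t: F.group_norm(t, num_groups=2, eps=1e-10)
print((gn(x) - gn(100.0 * x)).abs().max())     # ≈ 0 up to float32 rounding: a positive rescaling cancels out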

wangmengzhi commented 1 year ago

I wrote a piece of code to check consistency:

from multiscale_retention import MultiScaleRetention
from retnet import RetNetRelPos
import argparse
import torch

# Minimal args that MultiScaleRetention and RetNetRelPos expect.
parser = argparse.ArgumentParser()
parser.add_argument('--layernorm_eps', type=float, default=1e-5)
parser.add_argument('--decoder_embed_dim', type=int, default=512)
parser.add_argument('--decoder_retention_heads', type=int, default=2)
parser.add_argument('--recurrent_chunk_size', type=int, default=0)
args = parser.parse_args()

torch.manual_seed(0)
B, T, D = 32, 100, 512

# Parallel forward over the whole sequence.
pos = RetNetRelPos(args)
pos_para = pos(T)
ret = MultiScaleRetention(args, args.decoder_embed_dim, args.decoder_retention_heads)
ret.eval()
x = torch.rand(B, T, D)
y1 = ret(x, pos_para)

# Recurrent forward, one token at a time, reusing the incremental state.
y2 = []
state = {}
for t in range(T):
    pos_rec = pos(t + 1, activate_recurrent=True)
    y = ret(x[:, t:t + 1], pos_rec, incremental_state=state)
    y2.append(y.squeeze(1))
y2 = torch.stack(y2).permute(1, 0, 2)  # (T, B, D) -> (B, T, D)

# Total absolute difference between parallel and recurrent outputs.
print((y1 - y2).abs().sum().item())

The output is 22011.2265625. With layernorm_eps set to 1e-10, the output is 254.89244079589844. But if I use

kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)

the output is 0.021843837574124336. The difference is much smaller, so I don't think GN alone is enough.
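
The gap appears consistent with the eps term: GroupNorm is only scale-invariant when the variance dominates eps, and the un-sqrt'd update leaves the state much smaller in magnitude. A toy check (assumed shapes, standalone of torchscale):

import torch
import torch.nn.functional as F

# Tiny inputs, like an over-normalized recurrent state: the variance (~1e-8)
# is dominated by eps, so GroupNorm no longer cancels an input rescaling.
x = 1e-4 * torch.randn(2, 8, 16)
gn = lambda t, eps: F.group_norm(t, num_groups=2, eps=eps)
print((gn(x, 1e-5) - gn(50 * x, 1e-5)).abs().max())    # large gap: eps dominates the variance
print((gn(x, 1e-10) - gn(50 * x, 1e-10)).abs().max())  # far smaller, as with layernorm_eps=1e-10 above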

sunyt32 commented 1 year ago

You are right! .sqrt() here avoids numerical underflow, thus improving consistency and performance. We look forward to your pull request.

donglixp commented 1 year ago

@wangmengzhi Awesome! Thanks for your insightful suggestion. The improvement looks terrific. We will also do more local tests to compare them.

wangmengzhi commented 1 year ago

> @wangmengzhi Awesome! Thanks for your insightful suggestion. The improvement looks terrific. We will also do more local tests to compare them.

I'm glad to be of help, and thank you all for proposing such a great RetNet model.