The division here is to control numerical stability, and it is invariant to LayerNorm. I believe your modification also works.
@wangmengzhi The normalization doesn't affect the final results, because GN is scale-invariant. But it helps to rescale the numerical values. Both should work fine in practice.
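To illustrate the scale-invariance argument, here is a minimal sketch (my own toy example, not code from the repo) showing that GroupNorm gives essentially the same output when its input is rescaled, so an extra division before the norm only changes the numerical range of the intermediate values:

```python
import torch

# GroupNorm without affine parameters; eps is the only thing that breaks exact scale-invariance.
gn = torch.nn.GroupNorm(num_groups=2, num_channels=8, eps=1e-6, affine=False)
x = torch.randn(4, 8, 3)

# Rescaling the input by a constant leaves the normalized output (almost) unchanged.
print((gn(x) - gn(1000.0 * x)).abs().max().item())  # close to zero
```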
I wrote a piece of code to check consistency:
```python
from multiscale_retention import MultiScaleRetention
from retnet import RetNetRelPos
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--layernorm_eps', type=float, default=1e-5)
parser.add_argument('--decoder_embed_dim', type=int, default=512)
parser.add_argument('--decoder_retention_heads', type=int, default=2)
parser.add_argument('--recurrent_chunk_size', type=int, default=0)
args = parser.parse_args()

torch.manual_seed(0)
B, T, D = 32, 100, 512

pos = RetNetRelPos(args)
pos_para = pos(T)
ret = MultiScaleRetention(args, args.decoder_embed_dim, args.decoder_retention_heads)
ret.eval()
x = torch.rand(B, T, D)

# Parallel form over the full sequence.
y1 = ret(x, pos_para)

# Recurrent form, one token at a time, carrying the incremental state.
y2 = []
state = {}
for t in range(T):
    pos_rec = pos(t + 1, activate_recurrent=True)
    y = ret(x[:, t:t + 1], pos_rec, incremental_state=state)
    y2.append(y.squeeze(1))
y2 = torch.stack(y2).permute(1, 0, 2)

print((y1 - y2).abs().sum().item())
```
The output is 22011.2265625. If layernorm_eps is set to 1e-10, the output is 254.89244079589844.
But if I use

```python
kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
```

the output is 0.021843837574124336. The difference is much smaller. I don't think GN is enough.
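As a side note on why the eps value matters here: group normalization is only exactly scale-invariant when eps is zero. Below is a minimal sketch (my own toy example, using a plain torch.nn.GroupNorm as a stand-in for the group normalization applied after retention) showing that once the input values become very small, eps dominates the variance term and the output is no longer independent of the input scale:

```python
import torch

gn = torch.nn.GroupNorm(num_groups=2, num_channels=8, eps=1e-5, affine=False)
x = torch.randn(4, 8, 3)

# Scaling the input down until its variance is comparable to eps changes the output a lot,
# which is why a recurrent state divided by a large scale factor can hurt consistency.
print((gn(x) - gn(1e-4 * x)).abs().max().item())  # large discrepancy
print((gn(x) - gn(1e+4 * x)).abs().max().item())  # tiny discrepancy
```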
You are right! The .sqrt() here can avoid numerical underflow, thus improving consistency and performance. We are looking forward to your pull request.
@wangmengzhi Awesome! Thanks for your insightful suggestion. The improvement looks terrific. We will also do more local tests to compare them.
I'm glad to be of help, and thank you all for proposing such a great RetNet model.
Currently, line 108 in multiscale_retention.py reads

```python
kv = prev_kv * (1 - 1 / scale).view(self.num_heads, 1, 1) + kv / scale.view(self.num_heads, 1, 1)
```

It should be

```python
kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
```

because line 65 of retnet.py applies a square root when normalizing the decay mask:

```python
mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
```
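For completeness, here is a small self-contained sketch (my own toy reconstruction of the recurrence, not the repo code; single head, scalar decay) checking that the sqrt-scaled update reproduces exactly the state implied by the sqrt-normalized decay mask of the parallel form:

```python
import torch

torch.manual_seed(0)
T, d = 16, 4
decay = 0.9
k = torch.randn(T, d)
v = torch.randn(T, d)

# Parallel form: lower-triangular decay mask, each row divided by the sqrt of its sum,
# mirroring `mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()` in retnet.py.
idx = torch.arange(T)
mask = torch.where(idx[:, None] >= idx[None, :],
                   decay ** (idx[:, None] - idx[None, :]).float(),
                   torch.zeros(T, T))
mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
# State seen by the query at step n: sum_j mask[n, j] * outer(k_j, v_j).
kv_parallel = torch.einsum('nj,ja,jb->nab', mask, k, v)

# Recurrent form with the proposed sqrt-scaled update:
#   scale_n = decay * scale_{n-1} + 1
#   kv_n    = kv_{n-1} * sqrt(scale_{n-1}) * decay / sqrt(scale_n) + outer(k_n, v_n) / sqrt(scale_n)
kv = torch.zeros(d, d)
prev_scale = torch.tensor(0.0)
states = []
for n in range(T):
    scale = prev_scale * decay + 1.0
    kv = kv * (prev_scale.sqrt() * decay / scale.sqrt()) + torch.outer(k[n], v[n]) / scale.sqrt()
    prev_scale = scale
    states.append(kv.clone())
kv_recurrent = torch.stack(states)

print((kv_parallel - kv_recurrent).abs().max().item())  # ~1e-6, i.e. the two forms agree
```

With the original `1 - 1 / scale` update, the state is instead divided by the full decay sum rather than its square root; that only differs by a per-step constant factor, so the group norm recovers the same result in exact arithmetic, but, as discussed above, the values become small enough for eps and underflow to matter.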