microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

Question about the normalization in attention #81

Closed: Cranial-XIX closed this issue 10 months ago.

Cranial-XIX commented 10 months ago

Dear authors,

Nice work! I have a few questions about the normalization in the RetNet implementation and would like to hear your thoughts on them:

Here,

qk_mat = qk_mat / qk_mat.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)

Why is there a normalization of the attention scores? Is it mentioned anywhere in the paper? Why do you use abs()?
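One plausible reading of the qk_mat line, sketched below, is that it is harmless to the final output and only serves to keep magnitudes in range. This relies on an assumption: the retention output of each head is followed by a scale-invariant normalization (the RetNet paper applies GroupNorm per head; here it is modeled as a per-row LayerNorm without affine parameters or eps). Dividing a row by a positive factor then cancels out, whereas a negative factor would flip the row's sign, which the normalization does not undo; that is one reason abs() matters. The clamp(min=1) avoids dividing by a tiny sum, which would amplify values instead of shrinking them. The `layer_norm` helper and the random data are hypothetical.

```python
import numpy as np

def layer_norm(x):
    # Hypothetical stand-in for the per-head output normalization:
    # normalize each row over the last axis, no affine parameters, no eps.
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / x.std(axis=-1, keepdims=True)

def normalize_scores(r):
    # Mirrors qk_mat / qk_mat.sum(dim=-1, keepdim=True).abs().clamp(min=1):
    # divide each row by max(|row sum|, 1).
    denom = np.maximum(np.abs(r.sum(axis=-1, keepdims=True)), 1.0)
    return r / denom

rng = np.random.default_rng(0)
n, d, gamma = 6, 4, 0.9
q = rng.standard_normal((n, d)) * 10   # deliberately large activations
k = rng.standard_normal((n, d)) * 10
v = rng.standard_normal((n, d))
idx = np.arange(n)
# Causal decay matrix D[n, m] = gamma^(n - m) for m <= n, else 0
decay = np.where(idx[:, None] >= idx[None, :],
                 gamma ** np.abs(idx[:, None] - idx[None, :]), 0.0)
r = (q @ k.T) * decay                  # raw retention scores R = (Q K^T) * D

out_raw = layer_norm(r @ v)
out_normed = layer_norm(normalize_scores(r) @ v)
print(np.allclose(out_raw, out_normed))  # True: a positive per-row factor cancels

# A negative divisor flips the row's sign, which the normalization keeps,
# so the divisor has to be positive; abs() guarantees that.
x = r @ v
print(np.allclose(layer_norm(-x), layer_norm(x)))  # False
```

In a forward pass the detach() in the original line has no effect on the values; it only stops gradients from flowing through the divisor.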

Here,

value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)

For the decay term, why do you normalize the mask here? Wouldn't the unnormalized mask be correct?
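A minimal single-head sketch of why normalizing the value-side decay need not change the result: the last row of the intra-chunk mask holds gamma^(chunk - 1 - j), and dividing it by its sum scales the whole carried key-value state by a constant. If the query-side cross-chunk decay is multiplied by the same sum, the constant cancels exactly while the stored state stays smaller. The thread only shows the value side, so that compensating query factor is an assumption here, marked hypothetical in the code; all function names are illustrative.

```python
import numpy as np

def parallel_retention(q, k, v, gamma):
    # Reference: o_n = sum_{m <= n} gamma^(n - m) (q_n . k_m) v_m
    idx = np.arange(q.shape[0])
    decay = np.where(idx[:, None] >= idx[None, :],
                     gamma ** np.abs(idx[:, None] - idx[None, :]), 0.0)
    return ((q @ k.T) * decay) @ v

def chunkwise_retention(q, k, v, gamma, chunk, normalize):
    b = np.arange(chunk)
    # Intra-chunk decay mask: gamma^(i - j) for j <= i, else 0
    mask = np.where(b[:, None] >= b[None, :],
                    gamma ** np.abs(b[:, None] - b[None, :]), 0.0)
    value_decay = mask[-1]              # last row: gamma^(chunk - 1 - j)
    query_decay = gamma ** (b + 1)      # decay from the chunk start: gamma^(i + 1)
    if normalize:
        total = value_decay.sum()
        value_decay = value_decay / total   # analogue of value_inner_decay
        query_decay = query_decay * total   # hypothetical compensating factor
    state = np.zeros((k.shape[1], v.shape[1]))  # carried key-value summary
    outputs = []
    for s in range(0, q.shape[0], chunk):
        qc, kc, vc = q[s:s + chunk], k[s:s + chunk], v[s:s + chunk]
        inner = ((qc @ kc.T) * mask) @ vc              # within the chunk
        cross = (qc * query_decay[:, None]) @ state    # from earlier chunks
        outputs.append(inner + cross)
        state = state * gamma ** chunk + kc.T @ (vc * value_decay[:, None])
    return np.concatenate(outputs)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((12, 4)) for _ in range(3))
reference = parallel_retention(q, k, v, gamma=0.9)
plain = chunkwise_retention(q, k, v, gamma=0.9, chunk=4, normalize=False)
normalized = chunkwise_retention(q, k, v, gamma=0.9, chunk=4, normalize=True)
print(np.allclose(reference, plain), np.allclose(reference, normalized))  # True True
```

The sequence length is assumed to be a multiple of the chunk size to keep the sketch short.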

Here,

scale = mask.sum(dim=-1, keepdim=True).sqrt()

What is the role of this scale variable, and why are both inner_mask and the query divided by it?
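A sketch of why dividing both parts by the same scale is consistent, under the same assumption as before (a scale-invariant per-head normalization after retention, modeled as a per-row LayerNorm). Within a chunk, each output row is the sum of an intra-chunk term (weighted by the mask) and a cross-chunk term (weighted by the query-side decay). Dividing both by scale[i] multiplies row i by one positive constant, which the normalization removes; dividing only the intra-chunk term changes the balance between the two and therefore the output. The carried `state` and the `query_decay` form are hypothetical illustrations.

```python
import numpy as np

def layer_norm(x):
    # Hypothetical stand-in for per-head GroupNorm without affine parameters
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / x.std(axis=-1, keepdims=True)

rng = np.random.default_rng(1)
chunk, dk, dv, gamma = 4, 3, 5, 0.8
b = np.arange(chunk)
mask = np.where(b[:, None] >= b[None, :],
                gamma ** np.abs(b[:, None] - b[None, :]), 0.0)
scale = np.sqrt(mask.sum(axis=-1, keepdims=True))  # like mask.sum(dim=-1, keepdim=True).sqrt()

q = rng.standard_normal((chunk, dk))
k = rng.standard_normal((chunk, dk))
v = rng.standard_normal((chunk, dv))
state = rng.standard_normal((dk, dv))      # hypothetical state from earlier chunks
query_decay = (gamma ** (b + 1))[:, None]  # gamma^(i + 1) for the cross-chunk term

def chunk_output(inner_mask, q_decay):
    inner = ((q @ k.T) * inner_mask) @ v   # intra-chunk contribution
    cross = (q * q_decay) @ state          # cross-chunk contribution
    return inner + cross

reference = layer_norm(chunk_output(mask, query_decay))
both = layer_norm(chunk_output(mask / scale, query_decay / scale))
inner_only = layer_norm(chunk_output(mask / scale, query_decay))
print(np.allclose(reference, both))        # True
print(np.allclose(reference, inner_only))  # False
```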

Thank you very much in advance!

sunyt32 commented 10 months ago

It's a little complicated. The key idea is to make the normalized computation identical to the parallel representation. Because of the limited range of fp16, in the chunkwise forward pass we have to normalize the data range in several places. After these modifications, the output is the same as that of the fully naive version.
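To make the fp16 range limitation concrete, here is a small numeric sketch with hypothetical values (2048 positions, gamma = 0.999, every q.k equal to 300; none of these numbers come from the thread). Each decayed score fits in fp16, but their accumulated sum exceeds the largest finite fp16 value, while the row-normalized scores sum to roughly 1.

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # largest finite fp16 value

# Hypothetical setting: 2048 positions, gamma = 0.999, every q.k equal to 300.
n, gamma, qk = 2048, 0.999, 300.0
# Decayed scores of the last query against all keys: qk * gamma^(n - 1 - m)
scores = qk * gamma ** np.arange(n - 1, -1, -1, dtype=np.float64)
row_sum = scores.sum()

# Each score fits in fp16, but their accumulated sum does not (overflows to inf).
raw_fp16 = scores.astype(np.float16).sum(dtype=np.float16)

# Dividing by max(|row sum|, 1), as in the qk_mat line, keeps the sum near 1.
normed = scores / max(abs(row_sum), 1.0)
normed_fp16 = normed.astype(np.float16).sum(dtype=np.float16)

print(row_sum > FP16_MAX, np.isinf(raw_fp16), np.isfinite(normed_fp16))  # True True True
```

NumPy may emit an overflow RuntimeWarning for the fp16 sum; that warning is the point of the example.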

We discuss this issue in the Retention Score Normalization section of our paper.

Cranial-XIX commented 10 months ago

Thanks a lot!