Closed: Cranial-XIX closed this issue 10 months ago.
It's a bit complicated. The key idea is to make the normalized computation identical to the parallel representation. Because of the limited data range of fp16, the chunkwise forward has to normalize the data range in several places. After these modifications, the output is the same as the completely naive version.
We mention this problem in the Retention Score Normalization section of our paper.
Thanks a lot!
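For concreteness, here is a minimal sketch of the three normalization factors from that section, applied in a simplified parallel retention forward. The function, shapes, and names below are illustrative assumptions, not the torchscale code:

```python
import torch

def parallel_retention_normalized(q, k, v, decay_mask):
    """Simplified parallel retention with the paper's three normalizations.

    Assumed shapes (illustrative only):
      q, k: (batch, heads, len, d_k); v: (batch, heads, len, d_v)
      decay_mask: (len, len), D[n, m] = gamma**(n - m) for n >= m, else 0.
    """
    d_k = q.size(-1)

    # Factor 1: scale Q K^T by 1/sqrt(d), as in standard attention.
    scores = (q @ k.transpose(-1, -2)) / d_k ** 0.5

    # Factor 2: normalize the decay mask row-wise,
    # D~[n, m] = D[n, m] / sqrt(sum_i D[n, i]).
    decay_mask = decay_mask / decay_mask.sum(dim=-1, keepdim=True).sqrt()

    # Factor 3: divide each row of retention scores by max(|row sum|, 1);
    # abs() keeps the divisor positive even when a row sums to a negative
    # value, and the clamp avoids amplifying rows with tiny sums.
    retention = scores * decay_mask
    retention = retention / retention.sum(dim=-1, keepdim=True).abs().clamp(min=1)

    # The retention output is later passed through a per-token GroupNorm,
    # which is scale-invariant, so these rescalings leave the final result
    # unchanged while keeping intermediate values inside the fp16 range.
    return retention @ v
```

For a single head, `decay_mask = torch.tril(0.9 ** (torch.arange(L)[:, None] - torch.arange(L)[None, :]).float())` would be a valid input; scaling any row of the retention scores by a positive constant is then cancelled by the GroupNorm that follows.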
Dear authors,
Nice work! I have a few questions regarding the normalization in the RetNet implementation and would like to ask for your thoughts on them:

1. Here, why is there a normalization of the attention scores? Is it mentioned anywhere in the paper? And why do you choose `abs()`?
2. Here, for the decay term, why do you normalize the mask? Isn't the unnormalized mask correct?
3. Here, what is the role of the `scale` variable, and why do you divide both the `inner_mask` and the `query` by it?

A rough paraphrase of the snippets I am referring to is sketched below.

Thank you very much in advance!
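For context, the three spots I am asking about look roughly like the following. This is a paraphrase from the repo, so the exact names, shapes, and expressions may differ:

```python
import torch

# Dummy tensors so the fragments are runnable; shapes are guesses.
B, H, L, d = 1, 2, 6, 8
q = torch.randn(B, H, L, d)
qk_mat = torch.randn(B, H, L, L)                         # retention scores
idx = torch.arange(L).float()
mask = torch.tril(0.9 ** (idx[:, None] - idx[None, :]))  # decay mask D

# 1) Attention-score normalization with abs() (first question):
qk_mat = qk_mat * mask
qk_mat = qk_mat / qk_mat.abs().sum(dim=-1, keepdim=True).clamp(min=1)

# 2) Decay-mask normalization by the square root of the row sums
#    (second question):
scale = mask.sum(dim=-1, keepdim=True).sqrt()
inner_mask = mask / scale

# 3) In the chunkwise forward, the same `scale` also divides the query
#    (third question); this line is a paraphrase of the idea only:
qr = q / scale
```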