microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

about gamma/decay in RetNet #79

Closed rouniuyizu closed 8 months ago

rouniuyizu commented 8 months ago

hello,

Could someone enlighten me on the rationale behind this line of code, i.e. why "1 - 2 ** (-5 -" etc.? Thank you,

https://github.com/microsoft/torchscale/blob/881d03079da7b0c52ba0a473c70faac47042efc8/torchscale/architecture/retnet.py#L27

radarFudan commented 8 months ago

I guess this is manually constructing slow-decay initialisations. It's like $h_{k+1} = \gamma \, h_k + x_k$: a decay $\gamma$ close to 1 guarantees the model keeps information in the hidden states. The particular form comes more from smoothness and stability considerations. Similar practice has been adopted in papers such as S4/Safari/S5.
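To make the retention point concrete, here is a minimal scalar sketch of the recurrence (not the actual RetNet code, which works on vector states per head): an input injected at step 0 survives $t$ steps with weight $\gamma^t$, so a decay close to 1 preserves old information far longer.

```python
def recurrent_step(h: float, x: float, gamma: float) -> float:
    # One step of h_{k+1} = gamma * h_k + x_k.
    return gamma * h + x

def impulse_after(steps: int, gamma: float) -> float:
    # Inject an impulse at step 0 and run the recurrence with zero input;
    # the surviving magnitude is exactly gamma ** steps.
    h = 1.0
    for _ in range(steps):
        h = recurrent_step(h, 0.0, gamma)
    return h

# Hypothetical comparison values: a fast decay vs. the slowest RetNet head decay.
fast = impulse_after(32, 0.5)        # ~2.3e-10, information essentially gone
slow = impulse_after(32, 1 - 1/32)   # ~0.36, information largely retained
```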

rouniuyizu commented 8 months ago

> I guess this is manually constructing slow-decay initialisations. It's like $h_{k+1} = \gamma \, h_k + x_k$: a decay $\gamma$ close to 1 guarantees the model keeps information in the hidden states. The particular form comes more from smoothness and stability considerations. Similar practice has been adopted in papers such as S4/Safari/S5.

Thanks, and agreed. I get the smooth-decay part but am still curious why it starts at 1/32... anyway, it could be the result of experiments that led to this particular formula.
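For what it's worth, the formula spreads the heads over a geometric range of decay rates. A plain-Python sketch of the linked line (without the `torch.log`, to show the raw $\gamma$ values) makes the 1/32 starting point visible: head 0 leaks $2^{-5} = 1/32$ of its state per step, and each subsequent head halves the leakage, roughly doubling its effective memory horizon $1/(1-\gamma)$.

```python
def retnet_head_decays(num_heads: int) -> list[float]:
    # gamma_i = 1 - 2^(-5 - i) for head i = 0..num_heads-1, as in the
    # linked retnet.py line (which stores log(gamma_i) instead).
    return [1 - 2 ** (-5 - i) for i in range(num_heads)]

gammas = retnet_head_decays(8)
# Rough effective memory lengths per head: 32, 64, 128, ..., 4096 steps.
horizons = [1 / (1 - g) for g in gammas]
```

So with 8 heads the horizons span 32 to 4096 tokens; starting lower than 1/32 would make the fastest head forget within a couple of dozen tokens, which may be why the offset of 5 was chosen (the source thread leaves this to experiment).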