microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

Retnet parameter dimension #57

Closed allanj closed 1 year ago

allanj commented 1 year ago

I wonder why we need twice the dimension for $\mathbf{W}_V$.

Yuxin-CV commented 1 year ago

Please note that the MSR block includes an additional swish gate compared to the MHSA block in the vanilla Transformer. If we do not double the dimension of $v$, the MSR block has $5d^2$ parameters, while the MHSA block in the vanilla Transformer has only $4d^2$. In that case it becomes difficult to choose the width and depth of a RetNet for a fair comparison against a baseline vanilla Transformer of the same size. Therefore, the authors decided to double the output dimension of $\mathbf{W}_V$ and halve $d_{\text{ffn}}$ (from $4d$ to $2d$), so that each RetNet block keeps $12d^2$ parameters in total.
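
For concreteness, here is a minimal sketch of that parameter accounting (biases and normalization layers ignored; the layer names are illustrative, not the actual torchscale module names):

```python
def mhsa_block_params(d):
    attn = 4 * d * d           # W_Q, W_K, W_V, W_O, each d x d
    ffn = 2 * d * (4 * d)      # standard MLP: d -> 4d -> d
    return attn + ffn          # = 12 d^2

def msr_block_params(d):
    qk = 2 * d * d             # W_Q, W_K, each d x d
    vgo = 3 * d * (2 * d)      # W_V, W_G (swish gate), W_O with value dim doubled to 2d
    ffn = 2 * d * (2 * d)      # d_ffn halved to 2d
    return qk + vgo + ffn      # = 12 d^2

d = 512
assert mhsa_block_params(d) == msr_block_params(d) == 12 * d * d
```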

Another option is to keep $\mathbf{W}_V$ the same size as $\mathbf{W}_K$ and set $d_{\text{ffn}} = 3.5d$. However, it is preferable to have a wider swish gate rather than a wider MLP as the FFN; for more details, please refer to https://arxiv.org/abs/2202.10447. I believe it is even better to use the MSR block only and set $d_v = 3.33d$.
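
A quick sanity check of those two alternatives under the same counting assumptions (biases ignored; $3.33d$ read as exactly $10d/3$):

```python
def msr_block_params_alt(d):
    # W_V kept at d x d like W_K; gate and output also d x d; FFN widened to 3.5d
    retention = 5 * d * d
    ffn = 2 * d * int(3.5 * d)
    return retention + ffn            # = 12 d^2

def msr_only_block_params(d):
    # MSR-only block with no FFN: value/gate/output dims set to ~3.33d
    d_v = 10 * d // 3                 # exact when d is divisible by 3
    return 2 * d * d + 3 * d * d_v    # W_Q, W_K plus W_V, W_G, W_O = 12 d^2

d = 768
assert msr_block_params_alt(d) == 12 * d * d
assert msr_only_block_params(d) == 12 * d * d
```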

allanj commented 1 year ago

Cool, that pretty much makes sense to me. Thanks for the thorough explanation.