microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License
3.01k stars 202 forks

recurrent_forward in MultiScaleRetention #44

Closed Anker-ZX-AI closed 1 year ago

Anker-ZX-AI commented 1 year ago

In MultiScaleRetention, the forward function uses the input x: [B, tgt_len, embed_dim] to construct q: [B, tgt_len, embed_dim], k: [B, tgt_len, embed_dim], v: [B, tgt_len, factor*embed_dim], and g: [B, tgt_len, factor*embed_dim]. If "recurrent_forward" is used, line 102 reshapes v with v = v.view(bsz, self.num_heads, self.head_dim, 1). Since self.head_dim = self.embed_dim * self.factor // num_heads, this splits factor*embed_dim into num_heads and head_dim, but where does tgt_len go? The target shape [B, num_heads, head_dim, 1] is invalid for an input of size [B, tgt_len, factor*embed_dim].
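
For reference, a minimal runnable sketch of the shape arithmetic in question (the concrete sizes here are made up for illustration; only head_dim = embed_dim * factor // num_heads is taken from the module):

```python
import torch

# Illustrative sizes (hypothetical); head_dim follows MultiScaleRetention.
bsz, tgt_len, embed_dim, factor, num_heads = 2, 4, 512, 2, 8
head_dim = embed_dim * factor // num_heads  # 128

# v is projected to the widened value dimension factor * embed_dim.
v = torch.randn(bsz, tgt_len, factor * embed_dim)

# The view from line 102 fails whenever tgt_len != 1, since the
# element counts no longer match: tgt_len * factor * embed_dim
# != num_heads * head_dim * 1.
try:
    v.view(bsz, num_heads, head_dim, 1)
except RuntimeError as e:
    print(e)  # shape '[2, 8, 128, 1]' is invalid for input of size 8192
```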

Anker-ZX-AI commented 1 year ago

Never mind, I noticed that "recurrent_forward" processes one token per step, so tgt_len = 1 in this mode and v can be reshaped to [B, num_heads, head_dim, 1] without losing a dimension.
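
To make the resolution concrete, here is the same sketch with tgt_len = 1, as in recurrent decoding (sizes again hypothetical):

```python
import torch

bsz, embed_dim, factor, num_heads = 2, 512, 2, 8
head_dim = embed_dim * factor // num_heads  # 128

# In recurrent mode the model decodes one token per step, so the value
# projection for a step has shape [B, 1, factor * embed_dim].
v_step = torch.randn(bsz, 1, factor * embed_dim)

# Now 1 * factor * embed_dim == num_heads * head_dim * 1, so the view succeeds.
v_step = v_step.view(bsz, num_heads, head_dim, 1)
print(v_step.shape)  # torch.Size([2, 8, 128, 1])
```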