For MultiScaleRetention, the forward function takes an input X of shape [B, tgt_len, embed_dim] and constructs q: [B, tgt_len, embed_dim], k: [B, tgt_len, embed_dim], v: [B, tgt_len, factor*embed_dim], and g: [B, tgt_len, factor*embed_dim]. For v, if "recurrent_forward" is chosen, line 102 does v = v.view(bsz, self.num_heads, self.head_dim, 1). Since self.head_dim = self.embed_dim * self.factor // num_heads, this splits the factor*embed_dim dimension into num_heads and head_dim, but where does tgt_len go? The target shape [B, num_heads, head_dim, 1] is invalid for an input of size [B, tgt_len, factor*embed_dim].
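To make the shape question concrete, here is a minimal sketch (with made-up dimensions, using NumPy's reshape, which has the same element-count rule as torch.Tensor.view) showing that the view only goes through when tgt_len == 1, i.e. when decoding a single token per step:

```python
import numpy as np

# Hypothetical dimensions mirroring MultiScaleRetention (values are assumptions)
bsz, tgt_len, embed_dim, factor, num_heads = 2, 5, 8, 2, 4
head_dim = embed_dim * factor // num_heads  # 8 * 2 // 4 = 4

# v as produced by forward: [B, tgt_len, factor*embed_dim]
v = np.zeros((bsz, tgt_len, factor * embed_dim))

# view/reshape requires equal element counts:
#   bsz * tgt_len * factor * embed_dim  vs  bsz * num_heads * head_dim * 1
# These differ by a factor of tgt_len, so reshape fails here (tgt_len = 5).
try:
    v.reshape(bsz, num_heads, head_dim, 1)
    failed = False
except ValueError:
    failed = True
print("reshape fails for tgt_len > 1:", failed)

# With a single decoding step (tgt_len == 1) the counts match and it succeeds.
v_step = np.zeros((bsz, 1, factor * embed_dim))
out = v_step.reshape(bsz, num_heads, head_dim, 1)
print(out.shape)  # (2, 4, 4, 1)
```

So the view as written is only shape-consistent if recurrent_forward is fed one token at a time; for tgt_len > 1 it raises a shape error, which appears to be the inconsistency being asked about.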