Closed: mmorinag127 closed this issue 1 year ago
Note: to simplify the problem, I changed the original code (MultiScaleRetention, MSR) as below.
```python
class MultiScaleRetention(nn.Module):
    ...

    def forward(
        self,
        x,
        rel_pos,
        chunkwise_recurrent=False,
        incremental_state=None
    ):
        bsz, tgt_len, _ = x.size()
        (sin, cos), inner_mask = rel_pos

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        g = self.g_proj(x)

        k *= self.scaling
        q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
        k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)

        qr = theta_shift(q, sin, cos)
        kr = theta_shift(k, sin, cos)

        if incremental_state is not None:
            output, incremental_state = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
        elif chunkwise_recurrent:
            output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
        else:
            output = self.parallel_forward(qr, kr, v, inner_mask)

        # output = self.group_norm(output)
        output = output.reshape(bsz, tgt_len, self.head_dim * self.num_heads)
        # output = self.gate_fn(g) * output
        # output = self.out_proj(output)
        return output, incremental_state
```
I notice that you commented out the following line:
`# output = self.group_norm(output)`
It is essential to keep this normalization so that the different forms stay identical. Besides, to check the consistency we usually set `eps=0` in `group_norm`. The inconsistency comes from the small weight values combined with `eps`; the small initialization is no longer an issue after training.
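As a toy illustration of the point about small weights and `eps` (a minimal sketch, not the repository's code): an RMS-style normalization with `eps=0` is exactly scale-invariant, while a nonzero `eps` makes the result depend on the magnitude of the activations, which is what shows up as a mismatch at initialization.

```python
import torch

def rms_norm(x, eps):
    # RMS-style normalization over the last dim, no affine parameters
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

# Small activations, as produced by small weights at initialization
x = 1e-3 * torch.randn(2, 8)
y = 0.5 * x  # the same activations at a different scale

# With eps=0 the normalization is scale-invariant, so both inputs normalize to
# the same output; with a nonzero eps the small magnitudes make them differ.
print(torch.allclose(rms_norm(x, 0.0), rms_norm(y, 0.0)))    # True
print(torch.allclose(rms_norm(x, 1e-5), rms_norm(y, 1e-5)))  # False
```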
Thanks a lot for your comment. I see. Now I understand what's going on.
Hi. Did you check the consistency again? I ran your code and fixed the group_norm thing. There is still a large diff between parallel forward and chunkwise forward.
Hi, I confirmed that the three outputs are exactly the same after the group_norm.
You should use
`self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=0, elementwise_affine=False))  # check consistency`
instead of
`self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))`.
Then all the outputs will be `True`.
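For reference, a consistency check along these lines might look like the sketch below. This is not the original poster's script; it assumes `retention` is a `MultiScaleRetention` built with the `eps=0` `group_norm` above (and with the group norm, gate, and output projection left enabled), that `embed_dim` equals the model width it was built with, and that `rel_pos_parallel` / `rel_pos_chunkwise` come from the repo's relative-position module for the corresponding modes.

```python
import torch

# `retention`, `embed_dim`, `rel_pos_parallel`, and `rel_pos_chunkwise` are
# assumed to be set up beforehand as described above; they are not defined here.
torch.manual_seed(0)
x = torch.randn(1, 16, embed_dim)  # (batch, sequence length, model width)

# Parallel form vs. chunkwise form; the recurrent form would additionally need
# a token-by-token loop that threads incremental_state through the calls.
out_parallel, _ = retention(x, rel_pos_parallel)
out_chunkwise, _ = retention(x, rel_pos_chunkwise, chunkwise_recurrent=True)

# With eps=0 in group_norm, the two outputs should agree up to float error.
print(torch.allclose(out_parallel, out_chunkwise, atol=1e-5))
```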
Same question. When `eps=0`, the consistency is good, but training stability becomes a problem. Luckily, `RMSNorm` is much more stable than `LayerNorm`, so `eps` can be much smaller.
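To spell out where `eps` sits in the two norms being contrasted, here is a bare-bones sketch (no affine parameters, not the library implementations):

```python
import torch

def layer_norm(x, eps):
    # centers and rescales; eps guards the variance
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mu) * torch.rsqrt(var + eps)

def rms_norm(x, eps):
    # only rescales; eps guards the mean square
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
```

The observation in the thread is that the RMS variant trains stably with a much smaller `eps`, so the consistency/stability trade-off is milder.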
In the paper it says GroupNorm replaces LayerNorm, but the code uses RMSNorm. It looks like the goal is to normalize each head separately? Could you clarify that part?
@Dao007forever We used LN in the experiments of our arXiv paper. In our latest experiments, we found that RMSNorm is more stable, especially with respect to the LN `eps`. So we switched the default to RMSNorm.
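In case it helps make the per-head point concrete, here is a small sketch (illustrative shapes only, not the repository's module): normalizing the last dimension of a `(bsz, tgt_len, num_heads, head_dim)` tensor treats each head independently, which is the same grouping a GroupNorm with `num_heads` groups gives, whether the per-group statistic is LayerNorm-style or RMS-style.

```python
import torch

bsz, tgt_len, num_heads, head_dim = 2, 5, 4, 8
out = torch.randn(bsz, tgt_len, num_heads, head_dim)  # illustrative per-head output
eps = 0.0  # eps=0 variant used for the consistency check

# Normalizing over head_dim (the last dim) treats every head independently,
# the same grouping as a GroupNorm with num_heads groups.
per_head = out * torch.rsqrt(out.pow(2).mean(dim=-1, keepdim=True) + eps)
per_head = per_head.reshape(bsz, tgt_len, num_heads * head_dim)
print(per_head.shape)  # torch.Size([2, 5, 32])
```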
Hello authors,
I'm really happy to see this great work! I have one question (or request) about the consistency of the output from each forward mode. I have been comparing the three outputs using the simple code below.
And I got the result below.
As you can see, there is good agreement between the parallel and recurrent results, but the chunkwise output does not agree with either the parallel or the recurrent one after the 2nd chunk. Could you give me a hint to understand this?
(I have already pulled the latest main branch)
Thanks a lot, Masahiro Morinaga