Closed ictzyqq closed 4 months ago
Thanks for your report! Fixed at https://github.com/Beomi/InfiniTransformer/commit/a82dbe9aec1b5426a4f6ef7c5093380f5850ef7a
@Beomi Sorry, I actually didn't understand the changes in this commit. According to the paper, the denominator of A_mem should also include σ(Q), so I think norm_term_broadcastable needs a matmul operation rather than an expand operation.
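For reference, here is a minimal sketch of the retrieval step being discussed, A_mem = (σ(Q) M) / (σ(Q) z), where the denominator is a matmul with σ(Q) rather than an expand. Tensor names and dimensions are illustrative, not the repo's actual variables:

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim, value_dim = 2, 4, 8, 16, 16

# sigma(Q): ELU + 1 keeps the activation strictly positive, as in the paper
sigma_q = F.elu(torch.randn(batch, heads, seq_len, head_dim)) + 1.0
# M: accumulated sigma(K)^T V memory (random stand-in here)
memory = torch.randn(batch, heads, head_dim, value_dim)
# z: accumulated sum of sigma(K), kept positive for a valid denominator
norm_term = torch.randn(batch, heads, head_dim, 1).abs() + 1.0

# Numerator: sigma(Q) @ M -> [batch, heads, seq_len, value_dim]
numerator = torch.matmul(sigma_q, memory)
# Denominator: sigma(Q) @ z -> [batch, heads, seq_len, 1]; a matmul, not an expand
denominator = torch.matmul(sigma_q, norm_term)
a_mem = numerator / denominator
```

The division then broadcasts the [..., seq_len, 1] denominator over the value dimension.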
What's more, my understanding of the paper is that self.memory should be [batch_size, num_head, head_dim, value_dim] and self.norm_term should be [batch_size, num_head, head_dim]. Is that right?
@ictzyqq Sorry, as you said, that was the wrong commit. I updated the code for it; check this: c2f0746
And for the memory/norm size, I guess: the starting memory is torch.matmul(key_states.transpose(-2, -1), value_states), and key_states.transpose(-2, -1) has shape [batch_size, seq_len, head_dim, num_heads] while value_states has [batch_size, seq_len, num_heads, value_dim], so self.memory should have shape [batch_size, seq_len, head_dim, value_dim], right?
*Note that Gemma uses GQA for K/V, so num_heads here will be num_key_value_heads, which is 1 for Gemma-2B.
@Beomi Thanks for your reply! Your latest commit is correct. There's another small mistake at https://github.com/Beomi/InfiniTransformer/blob/c2f07466074f00efbbee79c5a4345510463a5e31/modeling_gemma.py#L821 and https://github.com/Beomi/InfiniTransformer/blob/c2f07466074f00efbbee79c5a4345510463a5e31/modeling_gemma.py#L853-L854 The shape should actually be [batch_size, num_heads, seq_len, head_dim] because you transposed QKV earlier, which makes those comments misleading. So the shapes of 'memory' and 'norm_term' are not related to seq_len.
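A quick shape check supports this point: with K/V already in the [batch_size, num_heads, seq_len, head_dim] layout after the transpose, the matmul contracts seq_len away, so neither the memory nor the norm term depends on it. Dimensions below are illustrative:

```python
import torch
import torch.nn.functional as F

batch, num_kv_heads, seq_len, head_dim, value_dim = 2, 1, 8, 16, 16

# After the usual transpose in the attention forward, K/V are
# [batch, num_heads, seq_len, head_dim] (num_key_value_heads under GQA)
key_states = torch.randn(batch, num_kv_heads, seq_len, head_dim)
value_states = torch.randn(batch, num_kv_heads, seq_len, value_dim)

# memory = sigma(K)^T V: the seq_len dimension is contracted away
sigma_k = F.elu(key_states) + 1.0
memory = torch.matmul(sigma_k.transpose(-2, -1), value_states)

# norm term: sum of sigma(K) over seq_len, also seq_len-independent
norm_term = sigma_k.sum(dim=-2, keepdim=True)
```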
https://github.com/Beomi/InfiniTransformer/blob/d3659c3c2f50038ba8e64d29139c0aa3701964dc/modeling_gemma.py#L837 I think 'norm_term_broadcastable' should be multiplied by 'query_states'.