Beomi / InfiniTransformer

Unofficial PyTorch/🤗Transformers (Gemma/Llama3) implementation of "Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention"
https://arxiv.org/abs/2404.07143
MIT License

question about norm_term_broadcastable #8

Closed: ictzyqq closed this issue 4 months ago

ictzyqq commented 5 months ago

I think 'norm_term_broadcastable' at https://github.com/Beomi/InfiniTransformer/blob/d3659c3c2f50038ba8e64d29139c0aa3701964dc/modeling_gemma.py#L837 should be multiplied by 'query_states'.

Beomi commented 4 months ago

Thanks for your report! Fixed at https://github.com/Beomi/InfiniTransformer/commit/a82dbe9aec1b5426a4f6ef7c5093380f5850ef7a

ictzyqq commented 4 months ago

> Thanks for your report! Fixed at a82dbe9

@Beomi Sorry, I don't quite follow the change in this commit. According to the paper, the denominator of A_mem should also include σ(Q), so I think 'norm_term_broadcastable' needs a matmul operation rather than an expand operation.
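For reference, as I read the paper, the memory-retrieval step is (with M_{s-1} the memory and z_{s-1} the normalization term):

$$
A_{\mathrm{mem}} = \frac{\sigma(Q)\, M_{s-1}}{\sigma(Q)\, z_{s-1}}
$$

so the denominator is a product with σ(Q), not just a broadcast copy of the norm term.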

What's more, my understanding of the paper is that self.memory should be [batch_size, num_heads, head_dim, value_dim], and self.norm_term should be [batch_size, num_heads, head_dim]. Is this right?
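Concretely, here is a minimal shape-level sketch of what I mean (illustrative sizes and standalone tensors, not your actual code; I'm assuming σ is ELU + 1, as in the linear-attention line of work the paper follows):

```python
import torch
import torch.nn.functional as F

# illustrative sizes only
batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64
value_dim = head_dim

# query states already transposed to [batch_size, num_heads, seq_len, head_dim]
query_states = torch.randn(batch_size, num_heads, seq_len, head_dim)

# memory M_{s-1}: [batch_size, num_heads, head_dim, value_dim]
memory = torch.randn(batch_size, num_heads, head_dim, value_dim)
# norm term z_{s-1}: [batch_size, num_heads, head_dim], kept with a trailing 1 so matmul works
norm_term = torch.rand(batch_size, num_heads, head_dim, 1) + 1e-6

sigma_q = F.elu(query_states) + 1.0  # σ(Q)

# numerator σ(Q) M_{s-1}: [batch_size, num_heads, seq_len, value_dim]
numerator = torch.matmul(sigma_q, memory)
# denominator σ(Q) z_{s-1}: [batch_size, num_heads, seq_len, 1] (a matmul, not an expand)
denominator = torch.matmul(sigma_q, norm_term)

a_mem = numerator / denominator  # [batch_size, num_heads, seq_len, value_dim]
print(a_mem.shape)  # torch.Size([2, 8, 16, 64])
```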

Beomi commented 4 months ago

@ictzyqq Sorry, as you said, that was the wrong commit. I've updated the code; check this: c2f0746

As for the memory/norm size, my guess is: the initial memory is torch.matmul(key_states.transpose(-2, -1), value_states), and key_states.transpose(-2, -1) has [batch_size, seq_len, head_dim, num_heads] while value_states has [batch_size, seq_len, num_heads, value_dim], so self.memory should be [batch_size, seq_len, head_dim, value_dim], right?

*Note that Gemma uses GQA for K/V, so num_heads here is actually num_key_value_heads, which is 1 for Gemma-2B.

ictzyqq commented 4 months ago

@Beomi Thanks for your reply! Your latest commit is correct. There's another small mistake in the shape comments at https://github.com/Beomi/InfiniTransformer/blob/c2f07466074f00efbbee79c5a4345510463a5e31/modeling_gemma.py#L821 and https://github.com/Beomi/InfiniTransformer/blob/c2f07466074f00efbbee79c5a4345510463a5e31/modeling_gemma.py#L853-L854: the shape should actually be [batch_size, num_heads, seq_len, head_dim], because QKV have already been transposed at that point, so the current comments are misleading. Since the matmul contracts over seq_len, the shapes of 'memory' and 'norm_term' do not depend on seq_len at all.
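A quick shape check of what I mean (standalone sketch with made-up sizes, not the repo code):

```python
import torch

# illustrative sizes; a single K/V head, per the GQA note above
batch_size, num_kv_heads, seq_len, head_dim = 2, 1, 16, 64

# K/V already transposed to [batch_size, num_kv_heads, seq_len, head_dim]
key_states = torch.randn(batch_size, num_kv_heads, seq_len, head_dim)
value_states = torch.randn(batch_size, num_kv_heads, seq_len, head_dim)

# the matmul contracts over seq_len, so the memory has no seq_len dimension
memory = torch.matmul(key_states.transpose(-2, -1), value_states)
print(memory.shape)  # torch.Size([2, 1, 64, 64]) = [batch, kv_heads, head_dim, value_dim]

# likewise the norm term sums over seq_len (σ omitted here, only the shapes matter)
norm_term = key_states.sum(dim=-2)
print(norm_term.shape)  # torch.Size([2, 1, 64]) = [batch, kv_heads, head_dim]
```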

Beomi commented 4 months ago

@ictzyqq Thanks for your fine-grained investigation! Applied at cffcaa2