microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

typo in normalization denominator in parallel retention? #78

Closed · XintianHan closed this issue 8 months ago

XintianHan commented 8 months ago

In the parallel retention code, the normalization denominator uses .sum(dim=-1, keepdim=True).abs():

def parallel_forward(self, qr, kr, v, mask):
    bsz, tgt_len, embed_dim = v.size()
    vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
    qk_mat = qr @ kr.transpose(-1, -2)  # bsz * m * tgt_len * tgt_len
    qk_mat = qk_mat * mask
    # invariant after normalization
    qk_mat = qk_mat / qk_mat.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)
    output = torch.matmul(qk_mat, vr)
    output = output.transpose(1, 2)
    return output

However, the chunkwise retention code uses .abs().sum() for the same normalization. From my perspective, .abs().sum() is the better denominator than .sum().abs(), since positive and negative scores can cancel during the summation and shrink the denominator (see the sketch below). So is this a typo?
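[Editor's note: a minimal sketch of the cancellation concern, not from the issue; the tensor values are made up for illustration.]

import torch

# Retention scores after decayed masking can be negative, so summing before
# taking the absolute value lets opposite-signed entries cancel.
scores = torch.tensor([[3.0, -2.9, 0.1]])

sum_then_abs = scores.sum(dim=-1, keepdim=True).abs().clamp(min=1)  # ~0.2 before clamp -> 1.0
abs_then_sum = scores.abs().sum(dim=-1, keepdim=True).clamp(min=1)  # 6.0

With .sum().abs() the opposite-signed scores nearly cancel and the clamp takes over, whereas .abs().sum() reflects the total magnitude of the row.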

donglixp commented 8 months ago

We use .abs().sum() in our internal code base. I fixed this issue in https://github.com/microsoft/torchscale/commit/fdd8838a756c7c435d7f8a1e4303e150dfac7442. Thanks for pointing this out.
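[Editor's note: a sketch of what the corrected denominator presumably looks like, assuming the fix mirrors the chunkwise code path; this is not a verbatim quote of the commit.]

# Take absolute values elementwise before summing over the key dimension.
qk_mat = qk_mat / qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)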