lucidrains / reformer-pytorch

Reformer, the efficient Transformer, in Pytorch
MIT License

Is the unsort logits necessary? #156

Open Croooooow opened 6 months ago

Croooooow commented 6 months ago

In the following code:

  # per-chunk attention output, then flatten the chunks back into one axis
  bo = torch.einsum('buij,buje->buie', dropped_dots, bv)
  so = torch.reshape(bo, (batch_size, -1, dim))
  slogits = torch.reshape(dots_logsumexp, (batch_size, -1,))

  # undo the bucket sort so entries line up with original sequence positions
  o = batched_index_select(so, undo_sort)
  logits = slogits.gather(1, undo_sort)

  o = torch.reshape(o, (batch_size, total_hashes, seqlen, dim))
  logits = torch.reshape(logits, (batch_size, total_hashes, seqlen, 1))

  if query_len != seqlen:
      query_slice = (slice(None), slice(None), slice(0, query_len))
      o, logits = o[query_slice], logits[query_slice]

  # combine hash rounds: softmax over the hash dimension, then a weighted sum
  probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True))
  out = torch.sum(o * probs, dim=1)

The code is clear up until the undo_sort over o and logits. It looks like the undo_sort here is unnecessary: o and logits are in the same order as each other, and the subsequent weighted sum over total_hashes does not depend on that order. Correct me if I am wrong.
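
For context, here is a minimal sketch (toy tensors, not the library's actual bucketing code) of what the unsort step mechanically does: `undo_sort` is the inverse permutation of the bucket sort, so gathering with it restores original sequence order. Whether the per-round orders already agree before unsorting is exactly what this question hinges on.

```python
import torch

total_hashes, seqlen = 2, 4

# Stand-in for original sequence positions, flattened over (hashes * seqlen),
# mimicking the shape of `so` / `slogits` after the reshape.
x = torch.arange(total_hashes * seqlen).view(1, -1)

# A random permutation stands in for the LSH bucket sort order;
# undo_sort is its inverse, as in the library.
perm = torch.randperm(total_hashes * seqlen).view(1, -1)
undo_sort = perm.argsort(dim=-1)

sorted_x = x.gather(1, perm)              # data as it sits after bucket sort
restored = sorted_x.gather(1, undo_sort)  # what `slogits.gather(1, undo_sort)` does

# Unsorting recovers the original order, so the later reshape to
# (batch, total_hashes, seqlen, ...) aligns the same query position
# across hash rounds before the weighted sum over dim=1.
assert torch.equal(restored, x)
```

The point of contention: each hash round can produce a different permutation, so if the per-round sorted orders differ, the sum over the hash dimension would mix different positions unless both tensors are unsorted first.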