bo = torch.einsum('buij,buje->buie', dropped_dots, bv)
so = torch.reshape(bo, (batch_size, -1, dim))
slogits = torch.reshape(dots_logsumexp, (batch_size, -1,))
o = batched_index_select(so, undo_sort)
logits = slogits.gather(1, undo_sort)
o = torch.reshape(o, (batch_size, total_hashes, seqlen, dim))
logits = torch.reshape(logits, (batch_size, total_hashes, seqlen, 1))
if query_len != seqlen:
    query_slice = (slice(None), slice(None), slice(0, query_len))
    o, logits = o[query_slice], logits[query_slice]
probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True))
out = torch.sum(o * probs, dim=1)
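To make the `undo_sort` step concrete, here is a minimal sketch (toy tensors and a made-up permutation, not the library's actual bucketing) of how gathering with the inverse permutation restores the original sequence order after a bucket sort:

```python
import torch

# Toy sketch: sorting by bucket permutes the sequence axis; gathering with
# the inverse permutation (undo_sort) restores the original position order.
seqlen = 6
x = torch.arange(seqlen, dtype=torch.float32)  # values at positions 0..5
perm = torch.randperm(seqlen)                  # order induced by the bucket sort
sorted_x = x[perm]                             # values in "sorted" layout

undo_sort = perm.argsort()                     # inverse permutation
restored = sorted_x[undo_sort]                 # gather back to position order

assert torch.equal(restored, x)
```

Here `undo_sort = perm.argsort()` plays the role of the precomputed `undo_sort` index in the snippet above, which is applied to `so` and `slogits` via `batched_index_select` and `gather`.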
The code is clear until the `undo_sort` over `o` and `logits`. It looks like the `undo_sort` here is unnecessary, as `o` and `logits` are in the same order and the following weighted sum over `total_hashes` does not care about the order. Correct me if I am wrong.
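One way to probe this claim numerically (a toy sketch with hypothetical per-round permutations, not the actual hash bucketing): each hash round sorts the sequence by its own buckets, so the per-round outputs can sit in different orders, and reducing over the hash dimension only lines up per-position values after each round has been un-sorted.

```python
import torch

# Toy check: two hash rounds, each with its own sort order over the sequence.
seqlen, n_hashes = 5, 2
vals = torch.arange(seqlen, dtype=torch.float32)

perms = [torch.randperm(seqlen) for _ in range(n_hashes)]
sorted_rounds = torch.stack([vals[p] for p in perms])  # (n_hashes, seqlen)

# With undo_sort: restore each round to position order before reducing
# over hashes, so index i means the same position in every round.
undone = torch.stack(
    [sorted_rounds[h][perms[h].argsort()] for h in range(n_hashes)]
)
aligned = undone.mean(dim=0)

assert torch.equal(aligned, vals)  # every round contributes position i at index i
```

Without the `argsort` gather, index `i` of each round would generally refer to a different original position, so the mean over rounds would mix values belonging to different positions whenever the two permutations differ.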