Hi @JKTang1, thank you for your interest in our work!
Our computation is mathematically equivalent to the standard transformer's torch.matmul(query, key) attention-score calculation. The variable that you see as query in our implementation is not the standard query; it has already gone through processing here: https://github.com/abertsch72/unlimiformer/blob/main/src/unlimiformer.py#L721.
Please see section 2.3 in our paper.
Best, Uri
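To make the section 2.3 reformulation concrete: an attention score (h_d W_q)(h_e W_k)^T can be rewritten as (h_d W_q W_k^T) h_e^T, so the per-head projections fold into the query and the index only needs the raw encoder hidden states. Below is a minimal sketch of that identity; the names hidden_dec, hidden_enc, W_q, and W_k are illustrative, not the repo's actual variables.

```python
import torch

d_model = 16
hidden_enc = torch.randn(100, d_model)  # encoder hidden states (what the index stores)
hidden_dec = torch.randn(1, d_model)    # one decoder hidden state
W_q = torch.randn(d_model, d_model)     # query projection of one head
W_k = torch.randn(d_model, d_model)     # key projection of the same head

# Standard attention scores: (h_d W_q) (h_e W_k)^T
scores_standard = (hidden_dec @ W_q) @ (hidden_enc @ W_k).T

# Reformulated: (h_d W_q W_k^T) h_e^T -- the same scores, but the
# projections are applied to the query before the kNN search.
scores_reformulated = (hidden_dec @ W_q @ W_k.T) @ hidden_enc.T

assert torch.allclose(scores_standard, scores_reformulated, atol=1e-4)
```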
The attention calculation over the top-k keys in Unlimiformer:

this_layer_prompt_keys: (batch, head, source_len, dim)
query: (batch, tgt_len, head, dim)

attn_weights = torch.matmul(this_layer_prompt_keys.unsqueeze(1), query.unsqueeze(-1)) \
    .reshape(batch_size, tgt_len, query.shape[-2], 1, this_layer_prompt_keys.shape[-2])
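Here is a small shape walkthrough of that broadcasted matmul, with made-up sizes (a sketch, not the repo's code):

```python
import torch

batch, head, src, tgt, dim = 2, 4, 10, 3, 8
this_layer_prompt_keys = torch.randn(batch, head, src, dim)
query = torch.randn(batch, tgt, head, dim)

# keys.unsqueeze(1):   (batch, 1,   head, src, dim)
# query.unsqueeze(-1): (batch, tgt, head, dim, 1)
# matmul broadcasts over (batch, tgt, head) -> (batch, tgt, head, src, 1),
# and the reshape then flips the trailing (src, 1) into (1, src).
attn_weights = torch.matmul(
    this_layer_prompt_keys.unsqueeze(1), query.unsqueeze(-1)
).reshape(batch, tgt, head, 1, src)

# Spot-check one entry against a plain key/query dot product.
b, t, h, s = 0, 1, 2, 3
expected = this_layer_prompt_keys[b, h, s] @ query[b, t, h]
assert torch.allclose(attn_weights[b, t, h, 0, s], expected, atol=1e-5)
```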
The normal transformer multi-head attention mechanism:

key: (batch*head, source_len, dim)
query: (batch*head, target_len, dim)

att_weight = torch.matmul(query, key.transpose(1, 2))
Unlimiformer uses the first formula to compute the attention scores for retrieving the top-k keys, so I want to know why it uses that formula, and whether I can use the normal transformer multi-head attention mechanism to get the top-k values instead.
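As the reply above says, the two formulations should produce identical scores, only arranged differently, so top-k over the source dimension comes out the same either way. A sketch checking this with made-up sizes:

```python
import torch

batch, head, src, tgt, dim = 2, 4, 10, 3, 8
keys = torch.randn(batch, head, src, dim)
query = torch.randn(batch, tgt, head, dim)

# Unlimiformer-style broadcasted form: (batch, tgt, head, 1, src)
unl = torch.matmul(keys.unsqueeze(1), query.unsqueeze(-1)) \
    .reshape(batch, tgt, head, 1, src)

# Standard multi-head form: (batch*head, tgt, dim) @ (batch*head, dim, src)
q_std = query.permute(0, 2, 1, 3).reshape(batch * head, tgt, dim)
k_std = keys.reshape(batch * head, src, dim)
std = torch.matmul(q_std, k_std.transpose(1, 2))  # (batch*head, tgt, src)

# Rearrange the standard result to match the broadcasted one and compare.
std_r = std.reshape(batch, head, tgt, src).permute(0, 2, 1, 3).unsqueeze(3)
assert torch.allclose(unl, std_r, atol=1e-5)

# Since the scores match, top-k over the source axis picks the same keys.
topk_indices = std_r.squeeze(3).topk(5, dim=-1).indices
```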