abertsch72 / unlimiformer

Public repo for the NeurIPS 2023 paper "Unlimiformer: Long-Range Transformers with Unlimited Length Input"
MIT License

Question: During training, the calculation of the top-k values' attn_weights is different from the classic transformer's multi-head attention. #45

Closed jjkk123456 closed 7 months ago

jjkk123456 commented 11 months ago

The top-k key attention calculation formula in Unlimiformer:

this_layer_prompt_keys: (batch, head, source_len, dim)

query: (batch, tgt, head, dim)

attn_weights = torch.matmul(this_layer_prompt_keys.unsqueeze(1), query.unsqueeze(-1)) \
    .reshape(batch_size, tgt_len, query.shape[-2], 1, this_layer_prompt_keys.shape[-2])
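To make the broadcasting explicit, here is a minimal shape-only sketch of this formula (the tensor sizes are made up for illustration, not taken from the repo):

```python
import torch

# Illustrative sizes only (not the repo's defaults).
batch, heads, src_len, tgt_len, head_dim = 2, 4, 10, 3, 8

this_layer_prompt_keys = torch.randn(batch, heads, src_len, head_dim)
query = torch.randn(batch, tgt_len, heads, head_dim)

# (batch, 1, heads, src, dim) @ (batch, tgt, heads, dim, 1) broadcasts over dim 1
# -> (batch, tgt, heads, src, 1): one dot product per (query position, head, key position).
attn_weights = torch.matmul(this_layer_prompt_keys.unsqueeze(1), query.unsqueeze(-1)) \
    .reshape(batch, tgt_len, heads, 1, src_len)

print(attn_weights.shape)  # torch.Size([2, 3, 4, 1, 10])
```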

The normal transformer multi-head attention mechanism:

key: (batch*head, source_len, dim)

query: (batch*head, target_len, dim)

att_weight = torch.matmul(query, key.transpose(1, 2))

Unlimiformer uses the first attention calculation to get the top-k keys, so I want to know why it uses this form, and whether I can use the normal transformer multi-head attention mechanism to get the top-k values.
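For reference, a minimal self-contained sketch of the standard layout described above, with a `torch.topk` over the resulting scores (made-up sizes; this is not the repository's retrieval code):

```python
import torch

batch, heads, src_len, tgt_len, head_dim, k = 2, 4, 10, 3, 8, 4  # illustrative sizes

key = torch.randn(batch * heads, src_len, head_dim)
query = torch.randn(batch * heads, tgt_len, head_dim)

# Standard multi-head attention scores: Q @ K^T -> (batch*heads, tgt, src)
att_weight = torch.matmul(query, key.transpose(1, 2))

# Top-k key positions per (head, query position) under this layout.
top_scores, top_idx = torch.topk(att_weight, k=k, dim=-1)
print(top_idx.shape)  # torch.Size([8, 3, 4])
```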

urialon commented 11 months ago

Hi @JKTang1, thank you for your interest in our work!

Our computation is mathematically equivalent to the standard transformer's torch.matmul(query, key.transpose(1, 2)) calculation. The variable that you see as query in our implementation is not the standard query; it has gone through processing here: https://github.com/abertsch72/unlimiformer/blob/main/src/unlimiformer.py#L721. Please see Section 2.3 in our paper.
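If it helps, here is a quick numerical check (a self-contained sketch with arbitrary sizes, not the repository's code) that the broadcasted formulation and the standard Q @ K^T layout produce the same scores:

```python
import torch

# Arbitrary sizes for the check.
batch, heads, src_len, tgt_len, head_dim = 2, 4, 10, 3, 8
keys = torch.randn(batch, heads, src_len, head_dim)   # stands in for this_layer_prompt_keys
query = torch.randn(batch, tgt_len, heads, head_dim)  # stands in for the already-processed query

# Unlimiformer-style broadcasted matmul -> (batch, tgt, heads, 1, src)
unl = torch.matmul(keys.unsqueeze(1), query.unsqueeze(-1)) \
    .reshape(batch, tgt_len, heads, 1, src_len)

# Standard multi-head scores -> (batch*heads, tgt, src)
q_std = query.permute(0, 2, 1, 3).reshape(batch * heads, tgt_len, head_dim)
k_std = keys.reshape(batch * heads, src_len, head_dim)
std = torch.matmul(q_std, k_std.transpose(1, 2))

# Rearrange the standard layout to match and compare: the scores are identical.
std_as_unl = std.reshape(batch, heads, tgt_len, src_len).permute(0, 2, 1, 3).unsqueeze(-2)
print(torch.allclose(unl, std_as_unl, atol=1e-6))  # True
```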

Best, Uri