UCDvision / sima

Official implementation for "SimA: Simple Softmax-free Attention for Vision Transformers"
MIT License

Cross Attention support? #5

Open HardysJin opened 2 years ago

HardysJin commented 2 years ago

Hi,

Thank you for sharing this work. I replaced DETR's encoder MultiHeadAttention with SimA, and it works pretty well. I am wondering whether it is also possible to replace the cross attention in the decoder? If so, any clue how to do it?

soroush-abbasi commented 2 years ago

Hi,

Thanks for your interest in our work. You can normalize Q and K just as in self-attention, then compute the QKV dot product. Did you encounter any specific challenge?
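
Roughly, something like this (just a rough sketch, not the code in sima.py; the function name is made up here, and the ℓ1 normalization over the token dimension follows the self-attention case):

```python
import torch
import torch.nn.functional as F

def sima_cross_attention(q, k, v):
    """Softmax-free cross-attention sketch in the spirit of SimA.

    q: [B, heads, M, d]  decoder queries
    k: [B, heads, N, d]  encoder keys
    v: [B, heads, N, d]  encoder values
    returns: [B, heads, M, d]
    """
    # l1-normalize Q and K across their own token dimensions,
    # as in SimA self-attention (sketch; see sima.py for the actual code).
    q = F.normalize(q, p=1, dim=-2)
    k = F.normalize(k, p=1, dim=-2)
    # Change the multiplication order: K^T V is only [d, d],
    # so the cost is linear in both M and N.
    return q @ (k.transpose(-2, -1) @ v)
```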

HardysJin commented 2 years ago

Thanks for the quick reply. In order to replace the MultiheadAttention, I made some modifications (https://github.com/UCDvision/sima/blob/main/sima.py#L239). Instead of computing q, k, v from linear layers of x, I forward q, k, v directly. In self-attention, since q, k, v all have the same shape, SimA works very well. But for cross attention, where k and v have shape [key_size, batch, embed_size] (from the DETR encoder output) and q has shape [query_size, batch, embed_size] (from the DETR decoder's self-attention), SimA no longer works because the matrix multiplication (https://github.com/UCDvision/sima/blob/main/sima.py#L248) receives incompatible shapes.

soroush-abbasi commented 2 years ago

So, assuming that:

Q [B, M, D] : M is the number of queries
K [B, N, D] : N is the number of keys
V [B, N, D] : N is the number of values

(QK)V : MxN @ NxD -> MxD
Q(KV) : MxD @ DxD -> MxD

So I guess there should not be a problem. Since you said that q has shape [M, B, D], maybe you need to permute it to [B, M, D]?
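
Continuing the sketch above, the reshaping for DETR-style [seq_len, batch, embed_dim] tensors would look roughly like this (the sizes are illustrative only):

```python
# DETR-style nn.MultiheadAttention tensors are [seq_len, batch, embed_dim].
B, M, N, D, heads = 2, 100, 850, 256, 8   # illustrative sizes

q = torch.randn(M, B, D)   # decoder queries
k = torch.randn(N, B, D)   # keys from the encoder output
v = torch.randn(N, B, D)   # values from the encoder output

def to_heads(t):
    # [seq, B, D] -> [B, heads, seq, D // heads]
    seq = t.shape[0]
    return t.permute(1, 0, 2).reshape(B, seq, heads, D // heads).transpose(1, 2)

out = sima_cross_attention(to_heads(q), to_heads(k), to_heads(v))
print(out.shape)  # torch.Size([2, 8, 100, 32])
```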

HardysJin commented 2 years ago

Thanks for the suggestion. I will try this adjustment.