naver-ai / rope-vit

[ECCV 2024] Official PyTorch implementation of RoPE-ViT "Rotary Position Embedding for Vision Transformer"
https://arxiv.org/abs/2403.13298

Question about learnable components in mixed RoPE #11

Closed liamhebert closed 6 days ago

liamhebert commented 1 week ago

Hi! Interesting work!

In your paper, you mention that you use two different frequencies for each axis in the rotation matrix

[Image: Screenshot 2024-10-29 at 12 20 27 PM]

and then later set them to be learnable. In the code, this is done here https://github.com/naver-ai/rope-vit/blob/2e801b6a5cdb4a7a8d95ed555aef5645b58044df/self-attn/rope_self_attn.py#L110-L114

where freq is seeded with the appropriate cos and sin values, as in Euler's formula. https://github.com/naver-ai/rope-vit/blob/2e801b6a5cdb4a7a8d95ed555aef5645b58044df/self-attn/rope_self_attn.py#L11-L24

Where I'm confused is that a core part of RoPE is that attention is computed from relative positions, a property that comes from building the rotation matrix out of cos and sin, which are periodic. In contrast, your code makes the rotation matrix completely learnable, only using the correct cos and sin values as starting seeds, which would destroy this property.

Is this intended? I think a more correct implementation would be to learn the mag tensor for each axis, rather than the whole matrix: https://github.com/naver-ai/rope-vit/blob/2e801b6a5cdb4a7a8d95ed555aef5645b58044df/self-attn/rope_self_attn.py#L14

Let me know if I'm thinking of this the wrong way!

Liam

bhheo commented 6 days ago

Hi Liam,

Thank you for your interest in our paper.

The frequencies in this code are just for initialization. Because I used a random rotation for the initial values, it may have caused some confusion:

https://github.com/naver-ai/rope-vit/blob/2e801b6a5cdb4a7a8d95ed555aef5645b58044df/self-attn/rope_self_attn.py#L11-L24

The code for the forward part is here. I wrote it using Euler's formula, i.e. torch.polar:

https://github.com/naver-ai/rope-vit/blob/2e801b6a5cdb4a7a8d95ed555aef5645b58044df/self-attn/rope_self_attn.py#L32-L39

Note that I used learnable freqs and fixed t, as shown here:

https://github.com/naver-ai/rope-vit/blob/2e801b6a5cdb4a7a8d95ed555aef5645b58044df/self-attn/rope_self_attn.py#L114-L118
https://github.com/naver-ai/rope-vit/blob/2e801b6a5cdb4a7a8d95ed555aef5645b58044df/self-attn/rope_self_attn.py#L132-L136
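
To illustrate the split between learnable freqs and fixed t, here is a minimal sketch. It is not the repository's code: the sizes, the random initialization, and the helper name mixed_cis are assumptions made for this example. It shows that only the frequencies are parameters, while torch.polar always produces unit-magnitude rotations whatever values the frequencies take.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
grid = 4      # 4x4 patch grid -> 16 tokens
n_freq = 8    # number of complex channels per head

# Fixed (non-trainable) token coordinates t_x, t_y.
idx = torch.arange(grid * grid, dtype=torch.float32)
t_x = idx % grid
t_y = torch.div(idx, grid, rounding_mode="floor")

# Learnable per-axis frequencies; the initial values here are arbitrary.
freqs = torch.nn.Parameter(torch.randn(2, n_freq))


def mixed_cis(freqs, t_x, t_y):
    """Return exp(i * (t_x * freqs[0] + t_y * freqs[1])) for every token."""
    angles = torch.outer(t_x, freqs[0]) + torch.outer(t_y, freqs[1])
    # torch.polar(abs, angle) = abs * (cos(angle) + i * sin(angle))
    return torch.polar(torch.ones_like(angles), angles)


cis = mixed_cis(freqs, t_x, t_y)  # complex tensor of shape (16, 8)

# Every entry stays on the unit circle, so it is a pure rotation.
max_dev = (cis.abs() - 1).abs().max().item()

# Gradients flow back to the frequencies through torch.polar.
q = torch.randn(grid * grid, n_freq, dtype=torch.complex64)
loss = (q * cis).real.sum()
loss.backward()
grad_ok = freqs.grad is not None and bool(freqs.grad.abs().sum() > 0)
```

Training can therefore change which angle each position maps to, but it cannot turn the rotation into an arbitrary matrix, because the magnitude passed to torch.polar is fixed at one.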

I hope this answers your questions about our implementation.

Best,
Heo

liamhebert commented 6 days ago

That makes sense! I had completely missed the usage of torch.polar and the semantics involved. Thanks!
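
The relative-position question raised in this thread can also be checked numerically. The following is a small standard-library sketch, assuming a single complex channel and arbitrary frequency values standing in for learned ones. The names theta_x, theta_y, rot, and score are hypothetical, and cmath.rect is used as the scalar counterpart of torch.polar.

```python
import cmath
import random

random.seed(0)

# Arbitrary per-axis frequencies standing in for learned values (assumption).
theta_x = random.uniform(-3.0, 3.0)
theta_y = random.uniform(-3.0, 3.0)


def rot(px, py):
    """Unit complex rotation for a token at position (px, py)."""
    return cmath.rect(1.0, theta_x * px + theta_y * py)


def score(q, k, pos_q, pos_k):
    """Real part of the product of a rotated query and a conjugated rotated key."""
    return (q * rot(*pos_q) * (k * rot(*pos_k)).conjugate()).real


q, k = complex(0.3, -1.2), complex(0.7, 0.5)

# Two query/key pairs with the same offset (2, -1) but different absolute positions.
s1 = score(q, k, (3, 1), (1, 2))
s2 = score(q, k, (7, 5), (5, 6))

# The scores agree up to floating-point error.
same_score = abs(s1 - s2) < 1e-9
```

Because the angle is linear in position, the two rotations combine into one that depends only on the offset between the tokens, so the score is the same for any values of theta_x and theta_y.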