lucidrains / rotary-embedding-torch

Implementation of Rotary Embeddings, from the Roformer paper, in Pytorch
MIT License

Custom position offset when rotating queries or keys #2

Closed krasserm closed 1 year ago

krasserm commented 1 year ago

This library seems to assume that queries and keys are left-aligned position-wise, e.g.

q = [p_0, p_1, p_2]
k = [p_0, p_1, p_2, p_3, p_4]

where p_i are the corresponding positions. This is enforced by always starting the sequence of positions from 0 with torch.arange(seq_len) here. Applications like Perceiver AR, however, require a position-wise right-alignment, e.g.

q =           [p_2, p_3, p_4]
k = [p_0, p_1, p_2, p_3, p_4]
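
Concretely, the two alignments correspond to different position index ranges for the queries. A minimal sketch (variable names here are illustrative, not library API):

```python
import torch

k_len, q_len = 5, 3

# Keys always cover positions 0 .. k_len - 1
k_pos = torch.arange(k_len)                         # tensor([0, 1, 2, 3, 4])

# Default (left-aligned): query positions start at 0
q_pos_left = torch.arange(q_len)                    # tensor([0, 1, 2])

# Right-aligned (e.g. Perceiver AR): query positions end at the last key position
offset = k_len - q_len
q_pos_right = torch.arange(offset, offset + q_len)  # tensor([2, 3, 4])
```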

This pull request allows specifying a start position for queries and/or keys to enable alignments other than left-alignment. For example

import torch
from rotary_embedding_torch.rotary_embedding_torch import RotaryEmbedding

rot = RotaryEmbedding(dim=32)

# unit vectors of shape (batch, heads, seq_len, dim)
q = torch.ones(1, 8, 4, 32)
k = torch.ones(1, 8, 6, 32)

q = q / torch.norm(q, dim=-1, keepdim=True)
k = k / torch.norm(k, dim=-1, keepdim=True)

# right-align query positions with the end of the key positions
q_rot = rot.rotate_queries_or_keys(q, start_pos=k.shape[2] - q.shape[2])
k_rot = rot.rotate_queries_or_keys(k)

attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
print(attn[0, 0])

prints the following attention scores, which depend only on relative positions

tensor([[0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
        [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581],
        [0.7288, 0.7670, 0.8581, 0.9571, 1.0000, 0.9571],
        [0.7361, 0.7288, 0.7670, 0.8581, 0.9571, 1.0000]])

(diagonal of 1s right-aligned) whereas the default behavior

...

q_rot = rot.rotate_queries_or_keys(q)
k_rot = rot.rotate_queries_or_keys(k)

attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
print(attn[0, 0])

would print

tensor([[1.0000, 0.9571, 0.8581, 0.7670, 0.7288, 0.7361],
        [0.9571, 1.0000, 0.9571, 0.8581, 0.7670, 0.7288],
        [0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
        [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581]])

(diagonal of 1s left-aligned).
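Both patterns follow from RoPE's defining property: the dot product between a rotated query at position m and a rotated key at position n depends only on m − n, so shifting the query positions by an offset shifts the diagonal of 1s by the same amount. A self-contained sketch of that property (hand-rolled pairwise rotation with the usual 10000^(-2i/d) frequency schedule, not the library's implementation):

```python
import torch

dim = 32
# standard RoPE inverse frequencies for each feature pair
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))

def rotate(x, pos):
    # rotate consecutive feature pairs of x by angle pos * inv_freq
    angles = pos * inv_freq
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

v = torch.ones(dim) / dim ** 0.5  # unit vector

def score(m, n):
    # attention score between the same vector placed at positions m and n
    return torch.dot(rotate(v, m), rotate(v, n)).item()

# scores depend only on the relative position m - n
assert abs(score(2, 0) - score(7, 5)) < 1e-5
# at zero relative distance the rotation preserves the unit norm, giving 1.0
assert abs(score(0, 0) - 1.0) < 1e-5
```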

krasserm commented 1 year ago

Never mind, I'll use the lower-level API instead, e.g.

from rotary_embedding_torch.rotary_embedding_torch import apply_rotary_emb

...

q_len = q.shape[2]
k_len = k.shape[2]
start_pos = k_len - q_len

freq_q = rot(torch.arange(start_pos, start_pos + q_len))
freq_k = rot(torch.arange(k_len))

q_rot = apply_rotary_emb(freq_q, q)
k_rot = apply_rotary_emb(freq_k, k)

...

It would just have been more convenient to have a start_pos argument on rotate_queries_or_keys().