Closed krasserm closed 1 year ago
Never mind, I'll use the lower-level API, e.g.
...
q_len = q.shape[2]
k_len = k.shape[2]
start_pos = k_len - q_len
freq_q = rot(torch.arange(start_pos, start_pos + q_len))
freq_k = rot(torch.arange(k_len))
q_rot = apply_rotary_emb(freq_q, q)
k_rot = apply_rotary_emb(freq_k, k)
...
It just would have been more convenient with a start_pos argument for rotate_queries_or_keys().
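The offset arithmetic in the snippet above can be checked without torch. The following is only a minimal sketch: the helper name right_aligned_positions is hypothetical (it is not part of the library), and it assumes q_len <= k_len, which the snippet above implies but does not state.

```python
def right_aligned_positions(q_len, k_len):
    """Return (query positions, key positions) so the last query sits on the last key.

    Hypothetical helper for illustration only; assumes q_len <= k_len.
    """
    if q_len > k_len:
        raise ValueError("expected q_len <= k_len")
    start_pos = k_len - q_len  # same offset as in the snippet above
    q_pos = list(range(start_pos, start_pos + q_len))
    k_pos = list(range(k_len))
    return q_pos, k_pos


q_pos, k_pos = right_aligned_positions(q_len=2, k_len=5)
print(q_pos)  # [3, 4]
print(k_pos)  # [0, 1, 2, 3, 4]
```

These are the index sequences that the two torch.arange calls produce in the snippet above for q_len = 2 and k_len = 5.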
This library seems to assume that queries and keys are left-aligned position-wise, i.e. both sequences start at the same position p_0, where p_i are the corresponding positions. This is enforced by always starting the sequence of positions from 0 with torch.arange(seq_len) here. Applications like Perceiver AR, however, require a position-wise right-alignment, i.e. the last query and the last key share the same position.

This pull request allows specifying a start position for queries and/or keys to enable alignments other than left-alignment. For example, with a start position that right-aligns the queries, the resulting relative position embedding has its diagonal of 1s right-aligned, whereas the default behavior yields a diagonal of 1s that is left-aligned.
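To make the left-aligned versus right-aligned diagonal concrete, here is a small pure-Python sketch. It is not the code from this pull request and does not use the library: rotate_2d, scores, and the single frequency theta are all assumptions made for illustration, reducing rotary embeddings to one 2-D frequency pair.

```python
import math


def rotate_2d(vec, pos, theta=0.5):
    """Rotate a 2-D vector by the angle pos * theta (one rotary frequency pair)."""
    x, y = vec
    c, s = math.cos(pos * theta), math.sin(pos * theta)
    return (x * c - y * s, x * s + y * c)


def dot(a, b):
    """Dot product of two 2-D vectors."""
    return a[0] * b[0] + a[1] * b[1]


def scores(qs, ks, start_pos=0):
    """Score matrix with queries at positions start_pos, start_pos + 1, ...
    and keys at positions 0, 1, ... (start_pos is a hypothetical argument)."""
    q_rot = [rotate_2d(q, start_pos + i) for i, q in enumerate(qs)]
    k_rot = [rotate_2d(k, j) for j, k in enumerate(ks)]
    return [[dot(q, k) for k in k_rot] for q in q_rot]


qs = [(1.0, 0.0)] * 2  # two identical queries
ks = [(1.0, 0.0)] * 5  # five identical keys

left = scores(qs, ks)                                # default: queries start at 0
right = scores(qs, ks, start_pos=len(ks) - len(qs))  # queries start at 3
```

With identical unit vectors, each score equals cos((m - n) * theta) for query position m and key position n, so a score of 1 marks a relative position of zero. In left those entries are at (0, 0) and (1, 1), in right at (0, 3) and (1, 4), which corresponds to the left-aligned and right-aligned diagonals described above.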