Motivation
Original Paper: RoFormer: Enhanced Transformer with Rotary Position Embedding
Rotary Positional Embeddings (RoPEs) are a common positional embedding type used in many transformer models today.
RoPEs work by applying a unique rotation transformation to the vectors that represent each token within our q and k tensors, based on that token's position $$m$$ in the sequence.
To compute attention, we must first compute $$\text{matmul}(Q, K^T)$$, which effectively takes the dot product between the vector embeddings of tokens in $$Q$$ and $$K$$. Given two tokens at positions $$i$$ and $$j$$, the closer $$i$$ and $$j$$ are to each other, the more similarly their vector embeddings are rotated, so the dot product between the two embeddings is largely unchanged. The further apart the tokens are, the more the rotations applied to the two embeddings diverge, which causes the dot product to decay. As the dot product decays, so does the attention weight between the two tokens, and this effectively leads the model to learn that, for a given token, nearby tokens should be paid more attention than tokens much further away.
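This position dependence can be made precise for a single 2-dimensional chunk (see the more detailed explanation below). Writing $$R_m$$ for the rotation applied at position $$m$$,

$$(R_i q)^\top (R_j k) = q^\top R_i^\top R_j\, k = q^\top R_{j-i}\, k,$$

so after RoPE the attention score between two tokens depends on their positions only through the relative offset $$j - i$$; this is the key property established in the RoFormer paper.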
A more detailed explanation
Fundamentally, RoPEs work by dividing the embedding space of our q and k vectors (the $$\text{head}$$ $$\text{dim}$$, $$d$$) into many chunks of two. Each 2-dimensional chunk can be thought of as a vector subcomponent of q and k projected onto a 2-dimensional plane that exists within the higher-dimensional space of the q and k embedding. RoPE "rotates" these planar chunks of our q and k vectors uniquely based on the index of the token in the sequence. Chunk $$i \in \{1, \dots, d/2\}$$ is rotated by an angle $$m\theta_i$$, which depends both on the token's index $$m$$ in the sequence and on which pair of dimensions of q and k (out of the head dim $$d$$) is being rotated.
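As a minimal sketch of that per-chunk rotation (plain PyTorch for clarity, not the Triton kernel discussed below; `rope_reference` is a hypothetical helper name, the frequencies follow the RoFormer schedule $$\theta_i = 10000^{-2(i-1)/d}$$, and the half-split pairing convention is assumed here, whereas real implementations may interleave adjacent dimensions instead):

```python
import torch

def rope_reference(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply RoPE to x of shape (seq_len, head_dim); head_dim must be even."""
    seq_len, head_dim = x.shape
    half = head_dim // 2
    # One frequency per 2-dimensional chunk: theta_i = base^(-2(i-1)/d)
    theta = base ** (-2 * torch.arange(half, dtype=torch.float32) / head_dim)
    m = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(m, theta)          # angles[m, i] = m * theta_i
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, :half], x[:, half:]       # the two components of each chunk
    # Standard 2-D rotation applied independently to every (x1_i, x2_i) pair
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

# Usage: rotate q and k, then form position-aware attention logits.
q = torch.randn(16, 64)   # (seq_len, head_dim)
k = torch.randn(16, 64)
q_rot, k_rot = rope_reference(q), rope_reference(k)
scores = q_rot @ k_rot.T  # logits before scaling and softmax
```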
Implementation
RoPE is applied to Q and K at every attention layer. For developing a kernel there are two options:
1. Rotate Q and K using one kernel, then pass the newly rotated Q and K into our flash_attn_kernel
2. Fuse RoPE into our flash_attn_kernel
Since Tri Dao already had a functional separate RoPE kernel, I implemented approach 1 first.
Separate RoPE and FlashAttention Kernels
We import `apply_rotary_emb`:
from flash_attn.layers.rotary import apply_rotary_emb
Within `class _attention(torch.autograd.Function)`, before calling `splitk_flash_attn`, we rotate `q` and `input_metadata.k_new` by making a call to `apply_rotary_emb`, which in turn launches a Triton kernel.
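A hedged sketch of what that call can look like (the tensor shapes, the `k_new` stand-in for `input_metadata.k_new`, and the cos/sin construction are illustrative assumptions; check the signature of `apply_rotary_emb` in the installed flash-attn version):

```python
import torch
from flash_attn.layers.rotary import apply_rotary_emb

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k_new = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

# Per-chunk frequencies theta_i and per-position angles m * theta_i.
theta = 10000.0 ** (-torch.arange(0, headdim, 2, device="cuda", dtype=torch.float32) / headdim)
angles = torch.outer(torch.arange(seqlen, device="cuda", dtype=torch.float32), theta)
cos, sin = angles.cos().to(q.dtype), angles.sin().to(q.dtype)  # each (seqlen, headdim // 2)

# Rotate q and the new K entries, then hand the rotated tensors to the attention kernel.
q = apply_rotary_emb(q, cos, sin)
k_new = apply_rotary_emb(k_new, cos, sin)
```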
Fused RoPE into FlashAttention
TODO
More Notes
More notes can be found at the following issue: https://github.com/ROCm/triton-internal/issues/33