ROCm / flash-attention

Fast and memory-efficient exact attention

Integrated Rotary Positional Embeddings (RoPEs) into flash_attn_kvcache #83

Closed alexkranias-amd closed 5 days ago

alexkranias-amd commented 2 months ago

Motivation

Original Paper: RoFormer: Enhanced Transformer with Rotary Position Embedding

Rotary Positional Embeddings (RoPEs) are a common positional embedding type used in many transformer models today.

RoPE works by applying a unique rotation transformation to the vector representing each token within our q and k tensors, where the rotation depends on that token's position $$m$$ in the sequence.

To compute attention, we must first compute $$QK^T$$, which effectively takes the dot product between the vector embeddings of tokens in $$Q$$ and $$K$$. Given two tokens at positions $$i$$ and $$j$$: the closer $$i$$ and $$j$$ are to each other, the more similarly their vector embeddings are rotated, and the dot product between them is largely unchanged. The further apart the tokens are, the more the transformations applied to the two embeddings diverge, which causes the dot product to decay. As the dot product decays, so does the attention weight between the two tokens, so the model effectively learns that, for a given token, nearby tokens should be paid more attention than tokens much further away.

Dot Product Decay
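As a quick reference, the key identity from the RoFormer paper can be summarized as follows: if $$R_{\Theta, m}$$ denotes the block-diagonal rotation applied at position $$m$$ (with the paper's angle schedule $$\theta_i = 10000^{-2(i-1)/d}$$), then the post-RoPE dot product depends only on the relative offset between the two positions:

$$\langle R_{\Theta, m}\, q,\; R_{\Theta, n}\, k \rangle = q^\top R_{\Theta, m}^\top R_{\Theta, n}\, k = q^\top R_{\Theta, n-m}\, k$$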

A more detailed explanation

Fundamentally, RoPE works by dividing the embedding space of our q and k vectors (the $$\text{head}$$ $$\text{dim}$$) into many chunks of two dimensions. Each 2-dimensional chunk can be thought of as a vector subcomponent of q or k projected onto a 2-dimensional plane within the higher-dimensional embedding space of q and k. RoPE "rotates" these planar chunks of our q and k vectors uniquely based on the index of the token in the sequence. Each "chunk" is rotated by some unique angle $$\theta_{m, i}$$, where $$m$$ is the index of the token in the sequence and $$i \in \{1, \dots, d/2\}$$ indexes which 2-dimensional chunk of the head dimension $$d$$ is being rotated.

RoPE Implementation Details
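A minimal PyTorch sketch of this chunked rotation is shown below. It uses the half-split pairing convention and operates on a single head; actual kernels may pair dimensions differently (e.g. interleaved) and work on batched, multi-head tensors.

```python
import torch

def rope_reference(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply RoPE to a single head's q or k vectors.

    x: (seqlen, head_dim), head_dim assumed even.
    Half-split pairing: chunk i is the pair (x[:, i], x[:, i + head_dim // 2]).
    """
    seqlen, head_dim = x.shape
    half = head_dim // 2
    # One angle scale per 2-D chunk: theta_i = base^(-2i / head_dim)
    inv_freq = base ** (-2.0 * torch.arange(half, dtype=torch.float32) / head_dim)
    # Token m, chunk i is rotated by the angle m * theta_i
    angles = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, :half].float(), x[:, half:].float()
    # Standard 2-D rotation applied to each (x1_i, x2_i) pair
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return rotated.to(x.dtype)

# Example: rotate a 16-token sequence with head_dim = 8
q_head = torch.randn(16, 8)
q_rot = rope_reference(q_head)
```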

Implementation

RoPE is applied to Q and K at every attention layer. For developing a kernel, there are two options:

  1. Rotate Q and K using a separate kernel, then pass the rotated Q and K tensors into our flash_attn_kernel
  2. Fuse RoPE into our flash_attn_kernel

Since Tri Dao already had a functional separate RoPE kernel, I implemented approach 1 first.

Separate RoPE and FlashAttention Kernels

We import `apply_rotary_emb` via `from flash_attn.layers.rotary import apply_rotary_emb`.

Within `class _attention(torch.autograd.Function)`, before calling `splitk_flash_attn`, we rotate `q` and `input_metadata.k_new` by calling `apply_rotary_emb`, which dispatches to a Triton kernel.
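A rough sketch of what approach 1 looks like at the Python level is below. The tensor shapes, the cos/sin table construction, and the trailing `splitk_flash_attn` call are illustrative placeholders rather than the exact code in this PR; argument names and layouts should be checked against the kernel being used.

```python
import torch
from flash_attn.layers.rotary import apply_rotary_emb

batch, seqlen, nheads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, head_dim, device="cuda", dtype=torch.float16)
k_new = torch.randn(batch, seqlen, nheads, head_dim, device="cuda", dtype=torch.float16)

# Precompute cos/sin tables of shape (seqlen, head_dim // 2),
# matching the angles theta_{m, i} described above.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
angles = torch.outer(torch.arange(seqlen, device="cuda").float(), inv_freq)
cos, sin = angles.cos().to(q.dtype), angles.sin().to(q.dtype)

# Rotate q and the new k entries with the separate Triton RoPE kernel,
# then hand the rotated tensors to the unmodified flash-attention kernel.
q_rot = apply_rotary_emb(q, cos, sin, interleaved=False)
k_rot = apply_rotary_emb(k_new, cos, sin, interleaved=False)
# out = splitk_flash_attn(q_rot, k_rot, ...)  # attention kernel called as before
```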

Fused RoPE into FlashAttention

TODO

More Notes

Can be found at the following issue: https://github.com/ROCm/triton-internal/issues/33