linkedin / Liger-Kernel

Efficient Triton Kernels for LLM Training
https://arxiv.org/pdf/2410.10989
BSD 2-Clause "Simplified" License
3.39k stars, 193 forks

[feat] on-paper form of RoPE #61

Open yundai424 opened 2 months ago

yundai424 commented 2 months ago

🚀 The feature, motivation and pitch

Right now, our RoPE implementation assumes the rotation matrix is created and applied the way the HuggingFace model code does it. That is, instead of the rotation matrix described in the original RoPE paper (https://arxiv.org/pdf/2104.09864), we assume a matrix that pairs dimension i with dimension i + d/2 and looks something like this:

\begin{pmatrix}
\cos m \theta_0 & 0 & 0 & \dots & 0 & -\sin m \theta_0 & 0 & 0 & \dots & 0 \\
0 & \cos m \theta_1 & 0 & \dots & 0 & 0 & -\sin m \theta_1 & 0 & \dots & 0 \\
0 & 0 & \cos m \theta_2 & \dots & 0 & 0 & 0 & -\sin m \theta_2 & \dots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\
0 & 0 & 0 & \dots & \cos m \theta_{d/2-1} & 0 & 0 & 0 & \dots & -\sin m \theta_{d/2-1} \\
\sin m \theta_0 & 0 & 0 & \dots & 0 & \cos m \theta_0 & 0 & 0 & \dots & 0 \\
0 & \sin m \theta_1 & 0 & \dots & 0 & 0 & \cos m \theta_1 & 0 & \dots & 0 \\
0 & 0 & \sin m \theta_2 & \dots & 0 & 0 & 0 & \cos m \theta_2 & \dots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\
0 & 0 & 0 & \dots & \sin m \theta_{d/2-1} & 0 & 0 & 0 & \dots & \cos m \theta_{d/2-1}
\end{pmatrix}
\times
\begin{pmatrix}
q_0 \\
q_1 \\
q_2 \\
\vdots \\
q_{d/2-1} \\
q_{d/2} \\
q_{d/2+1} \\
q_{d/2+2} \\
\vdots \\
q_{d-1}
\end{pmatrix}
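
For comparison, here is a sketch of the form given in the paper, which pairs adjacent dimensions (2i, 2i+1) in a block-diagonal matrix. The paper indexes the angles from 1; the version below is re-indexed from 0 to match the notation above, so it assumes \theta_i = 10000^{-2i/d} for i = 0, \dots, d/2-1.

```latex
\begin{pmatrix}
\cos m \theta_0 & -\sin m \theta_0 & 0 & 0 & \dots & 0 & 0 \\
\sin m \theta_0 & \cos m \theta_0 & 0 & 0 & \dots & 0 & 0 \\
0 & 0 & \cos m \theta_1 & -\sin m \theta_1 & \dots & 0 & 0 \\
0 & 0 & \sin m \theta_1 & \cos m \theta_1 & \dots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & 0 & 0 & \dots & \cos m \theta_{d/2-1} & -\sin m \theta_{d/2-1} \\
0 & 0 & 0 & 0 & \dots & \sin m \theta_{d/2-1} & \cos m \theta_{d/2-1}
\end{pmatrix}
\times
\begin{pmatrix}
q_0 \\
q_1 \\
q_2 \\
q_3 \\
\vdots \\
q_{d-2} \\
q_{d-1}
\end{pmatrix}
```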

We should also support use cases where people create their RoPE cos & sin buffers following the original formula.
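
To make the difference concrete, here is a minimal pure-Python sketch of the two layouts for a single head vector at position `m`. It is an illustration only, not the Liger-Kernel Triton code: the function names (`rope_half_split`, `rope_interleaved`, and the two reorder helpers) are hypothetical, and it assumes the usual frequencies `theta_i = base**(-2i/d)` with `base = 10000`.

```python
import math


def rope_angles(position, dim, base=10000.0):
    """Angles m * theta_i for i in [0, dim/2), assuming theta_i = base**(-2i/dim)."""
    return [position * base ** (-2.0 * i / dim) for i in range(dim // 2)]


def rope_half_split(q, position, base=10000.0):
    """HuggingFace-style layout: channel i is rotated together with channel i + d/2."""
    d = len(q)
    half = d // 2
    out = [0.0] * d
    for i, angle in enumerate(rope_angles(position, d, base)):
        c, s = math.cos(angle), math.sin(angle)
        a, b = q[i], q[i + half]
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out


def rope_interleaved(q, position, base=10000.0):
    """Paper-style layout: channel 2i is rotated together with channel 2i + 1."""
    d = len(q)
    out = [0.0] * d
    for i, angle in enumerate(rope_angles(position, d, base)):
        c, s = math.cos(angle), math.sin(angle)
        a, b = q[2 * i], q[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def interleaved_to_half_split(x):
    """Reorder [x0, x1, x2, x3, ...] into [x0, x2, ..., x1, x3, ...]."""
    return x[0::2] + x[1::2]


def half_split_to_interleaved(x):
    """Inverse of interleaved_to_half_split."""
    half = len(x) // 2
    out = [0.0] * len(x)
    out[0::2] = x[:half]
    out[1::2] = x[half:]
    return out
```

Under these assumptions the two layouts differ only by a fixed permutation of the channels: `rope_interleaved(q, m)` should match `half_split_to_interleaved(rope_half_split(interleaved_to_half_split(q), m))`. One possible way to support the paper form would therefore be to permute the inputs, or to index the cos & sin buffers differently inside the kernel.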

Alternatives

We may also need to support the complex form, i.e. what the official Meta Llama code does: https://github.com/meta-llama/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L64-L74
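
As a rough illustration of the complex form, the sketch below treats each adjacent pair of channels as one complex number and multiplies it by `e^{i m theta_i}`. This is a pure-Python stand-in, not the linked Meta code; the name `rope_complex` is hypothetical, and both the adjacent-pair layout and `theta_i = base**(-2i/d)` are assumptions.

```python
import cmath


def rope_complex(q, position, base=10000.0):
    """Rotate adjacent channel pairs (q[2i], q[2i+1]) as complex numbers."""
    d = len(q)
    out = []
    for i in range(d // 2):
        theta = base ** (-2.0 * i / d)
        # Unit-magnitude complex number e^{i * m * theta_i}.
        rotation = cmath.rect(1.0, position * theta)
        z = complex(q[2 * i], q[2 * i + 1]) * rotation
        out.extend([z.real, z.imag])
    return out
```

Multiplying `(a + ib)` by `(cos φ + i sin φ)` gives `(a cos φ - b sin φ) + i(a sin φ + b cos φ)`, which is the same 2x2 rotation applied to adjacent pairs. So, under the assumptions above, the complex form should coincide with the paper's interleaved layout rather than the half-split one.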

Additional context

No response

Himanshunitrr commented 2 months ago

take @ByronHsu I would like to implement this. Can you assign it to me?