hikettei / Caten

[wip] Deep Learning Compiler based on Polyhedral Compiler and Light-weight IRs based on Optimizing Pattern Matcher
https://hikettei.github.io/Caten/

feat: Rotary Positional Encoding #215

Closed: abourramouss closed this 12 hours ago

abourramouss commented 1 week ago

This PR implements RoPE (rotary positional embeddings).

#195

abourramouss commented 3 days ago
test-rope
  dtype=FLOAT32
    ✓ Shapes match
    ✓ Satisfying (atol=1.9999694) <= 2
    ✓ Satisfying (rtol=2.4570446) <= 3
T
NIL

Ideally the error should be on the order of 1e-5. The problem is that the tensors aren't identical: some values differ.
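A 1e-5 target is easiest to reason about as an element-wise tolerance. The sketch below is a minimal stdlib version that assumes NumPy's isclose rule (|a - e| <= atol + rtol * |e|). How Caten's own test suite computes the atol/rtol figures printed above is not shown in this thread, so treat max_errors and all_close as hypothetical helpers.

```python
def max_errors(actual, expected):
    """Return (max absolute error, max relative error) between two flat sequences."""
    if len(actual) != len(expected):
        raise ValueError("shape mismatch")
    abs_errs = [abs(a - e) for a, e in zip(actual, expected)]
    # Relative error divides by |expected|; the 1e-12 floor avoids division by zero.
    rel_errs = [d / max(abs(e), 1e-12) for d, e in zip(abs_errs, expected)]
    return max(abs_errs), max(rel_errs)


def all_close(actual, expected, atol=1e-5, rtol=1e-5):
    """True if |a - e| <= atol + rtol * |e| holds for every element (NumPy isclose rule)."""
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


# Tiny floating-point noise passes; a flipped sign does not.
print(all_close([0.5, -0.25], [0.5000001, -0.25]))  # True
print(all_close([0.5, -0.25], [0.5, 0.25]))         # False
print(max_errors([0.5, -0.25], [0.5, 0.25]))        # (0.5, 2.0)
```

Combining an absolute and a relative term keeps the check meaningful both for values near zero and for large values.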

hikettei commented 2 days ago

Your code looks very clean! LGTM. Regarding the accuracy issue, let me run some experiments (yes, it should fit within the 1e-5 scale). This might be caused by a compiler issue.

hikettei commented 2 days ago

The code looks very good. I will merge this once the error is within the 1e-5 scale.

hikettei commented 2 days ago

Hmm, with JIT, shouldn't this be scheduled into a single kernel? (Don't worry about that, I'll take care of it!)

CATEN/TEST-SUITE> (call (Rope `(10)) (make-tensor `(10 10)))
{Tensor[float32] :shape (10 10) :id TID1450187
  :buffer nil
  :op #<RESHAPE {700BF75D83}>
  :requires-grad NIL
  :variables (TID1450176)
  :tracker #<TRACKER :order={row(0 1)} :shape=(10 10) :contiguous-p=T>}
CATEN/TEST-SUITE> (proceed *)
[graph-schedule] Schedule Graph:

FastGraph[seen=NIL, outputs=(val_39)] {
    { Allocate } : [ val_34 <- (1 10 5 2) where lowered-p=nil ]
    { Allocate } : [ val_32 <- (1 10 5) where lowered-p=nil ]
    {  KERNEL  } : [ val_35 <- val_23, val_22, val_21, val_32, val_34 where lowered-p=nil :name=FUSED_CONCATENATENODE1463528]
    { Allocate } : [ val_0 <- (1 10 5) where lowered-p=nil ]
    { Allocate } : [ val_23 <- (10 10) where lowered-p=nil ]
    { Allocate } : [ val_17 <- (10 5) where lowered-p=nil ]
    { Allocate } : [ val_15 <- (10) where lowered-p=nil ]
    { Allocate } : [ val_9 <- (5) where lowered-p=nil ]
    { Allocate } : [ val_7 <- NIL where lowered-p=nil ]
    { Allocate } : [ val_5 <- NIL where lowered-p=nil ]
    { Allocate } : [ val_3 <- NIL where lowered-p=nil ]
    { Allocate } : [ val_1 <- NIL where lowered-p=nil ]
    {  KERNEL  } : [ val_22, val_21, val_36 <- val_1, val_3, val_5, val_7, val_9, val_15, val_17, val_23, val_0, val_35 where lowered-p=nil :name=FUSED_CONCATENATENODE_COSNODE1463526]
    { Allocate } : [ val_37 <- (1 10 5 2) where lowered-p=nil ]
    {  KERNEL  } : [ val_38 <- val_37, val_36 where lowered-p=nil :name=FUSED_MOVE1463522]
    {   VMOP   } : [ val_39 <- val_38 where lowered-p=nil :name=FUSED_BACKWARD1463520]
}

[14:21:53, 11/20/2024 (GMT+9)] : JIT Compilation Start (AVM=MAIN1450194)

* (1/3) FUSED_CONCATENATENODE1463528
=====> Lowering to blueprint
{
  for (int _gid0=0;(_gid0<1);_gid0+=1) {
    for (int _gid1=0;(_gid1<10);_gid1+=1) {
      for (int _gid2=0;(_gid2<5);_gid2+=1) {
        val_27 = -((val_23[((_gid0+(10*_gid1))+(1+(2*_gid2)))]*val_22[((_gid0+(5*_gid1))+_gid2)]));
        val_35[(((100*_gid0)+(10*_gid1))+(2*_gid2))] = ((val_23[((_gid0+(10*_gid1))+(2*_gid2))]*val_21[((_gid0+(5*_gid1))+_gid2)])+val_27);
      } // _gid2
    } // _gid1
  } // _gid0
}
Compilation Time : 0.030271(sec)
* (2/3) FUSED_CONCATENATENODE_COSNODE1463526
=====> Lowering to blueprint
{
  for (int _gid0=0;(_gid0<1);_gid0+=1) {
    for (int _gid1=0;(_gid1<10);_gid1+=1) {
      for (int _gid2=0;(_gid2<5);_gid2+=1) {
        val_19 = (_gid1*exp2((((_gid2*-9.2103405)*0.1)*1.442695)));
        val_22[((5*_gid1)+_gid2)] = sin(val_19);
        val_21[((5*_gid1)+_gid2)] = sin((val_19+1.5707964));
        val_29 = (val_23[((_gid0+(10*_gid1))+(1+(2*_gid2)))]*val_21[((_gid0+(5*_gid1))+_gid2)]);
        val_36[((((100*_gid0)+(10*_gid1))+(2*_gid2))+1)] = ((val_23[((_gid0+(10*_gid1))+(2*_gid2))]*val_22[((_gid0+(5*_gid1))+_gid2)])+val_29);
      } // _gid2
    } // _gid1
  } // _gid0
}
Compilation Time : 0.04597(sec)
* (3/3) FUSED_MOVE1463522 

=====> Lowering to blueprint
{
  for (int _gid0=0;(_gid0<100);_gid0+=1) {
    val_38[_gid0] = val_36[_gid0];
  } // _gid0
}
Compilation Time : 0.001554(sec)
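The constants in the FUSED_CONCATENATENODE_COSNODE blueprint can be decoded as follows. 1.442695 is log2(e), so exp2(y * 1.442695) equals exp(y). 9.2103405 is ln(10000). sin(v + 1.5707964) is cos(v). The stdlib sketch below transcribes the blueprint's angle val_19 and, next to it, the theta_i = base^(-2i/dim) schedule that torchtune's RotaryPositionalEmbeddings uses. The function names are hypothetical. The log shows a per-pair exponent scale of 0.1, whereas 2/dim would give 0.2 for dim = 10. This is an observation from the log, not a confirmed diagnosis of the mismatch.

```python
import math

# Constants as they appear in the FUSED_CONCATENATENODE_COSNODE blueprint.
LN_BASE = 9.2103405   # ~ ln(10000)
LOG2_E = 1.442695     # ~ 1 / ln(2), so exp2(y * LOG2_E) == exp(y)
HALF_PI = 1.5707964   # sin(v + HALF_PI) == cos(v)


def blueprint_angle(pos, i):
    """val_19 = pos * exp2(((i * -LN_BASE) * 0.1) * LOG2_E), transcribed from the log."""
    return pos * 2.0 ** (((i * -LN_BASE) * 0.1) * LOG2_E)


def torchtune_angle(pos, i, dim, base=10000.0):
    """pos * theta_i with theta_i = base ** (-(2 * i) / dim)."""
    return pos * base ** (-(2 * i) / dim)


dim = 10  # last axis of the (10 10) tensor passed to Rope in the REPL session
for i in range(dim // 2):
    print(i, blueprint_angle(1, i), torchtune_angle(1, i, dim))

# The blueprint obtains cos via a phase-shifted sin.
print(math.isclose(math.sin(0.3 + HALF_PI), math.cos(0.3), abs_tol=1e-6))  # True
```

In this transcription, pair index i in the blueprint lands on the frequency that the 2i/dim schedule assigns to index i/2.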
hikettei commented 2 days ago

Will #245 fix anything for JIT=0?

abourramouss commented 2 days ago

I've been running some tests in torch and MLX. The issue is that MLX and PyTorch yield different results: my implementation is largely based on MLX's, while the test uses the torch implementation as its reference.

This is the code I've been using:

import numpy as np
import torch
import mlx.core as mx
from mlx.nn import RoPE
from torchtune.modules import RotaryPositionalEmbeddings

# Parameters
seq_len, num_heads, head_dim = 30, 30, 30
dim = head_dim
max_seq_len = seq_len
theta = 10000.0  # RoPE base (the default in both libraries)

# torchtune expects input of shape [batch, seq_len, num_heads, head_dim]
x = torch.rand(1, seq_len, num_heads, head_dim)
rope_torch = RotaryPositionalEmbeddings(dim=dim, max_seq_len=max_seq_len, base=theta)
rope_mlx = RoPE(dims=dim, base=theta)

# Feed the same data to MLX
numpy_array = x.detach().cpu().numpy()
x_mlx = mx.array(numpy_array)
mx_output = rope_mlx(x_mlx)

print("Output MLX:")
# Convert the MLX *output* back to torch; converting numpy_array here would
# compare torch's result against the unrotated input instead.
mlx_tensor = torch.tensor(np.array(mx_output))
print(type(mlx_tensor))

print("Output PyTorch:")
torch_tensor = rope_torch(x)
print(type(torch_tensor))

absolute_diff = torch.sum(torch.abs(torch_tensor - mlx_tensor))
relative_diff = absolute_diff / torch.sum(torch.abs(mlx_tensor))

# Output results
print(f"Total Absolute Difference: {absolute_diff.item()}")
print(f"Total Relative Difference: {relative_diff.item()}")

And the resulting absolute and relative differences:

Output MLX:
<class 'torch.Tensor'>
Output PyTorch:
<class 'torch.Tensor'>
Total Absolute Difference: 6711.056640625
Total Relative Difference: 0.49696406722068787
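One plausible contributor to MLX and torchtune disagreeing is the pairing convention. As far as I know, torchtune's RotaryPositionalEmbeddings rotates consecutive pairs (x[2i], x[2i+1]), which is also the layout the Caten blueprints above read (val_23 at 2*_gid2 and 2*_gid2+1). MLX's nn.RoPE with its default traditional=False instead rotates x[i] against x[i + d/2]. MLX also takes positions from the second-to-last axis, which for a (batch, seq, heads, head_dim) input is the heads axis. The stdlib sketch below uses hypothetical helper names. It shows that the two layouts give different outputs on the same vector but agree once the features are reordered even-then-odd.

```python
import math


def rotate_interleaved(x, angles):
    """Rotate consecutive pairs (x[2i], x[2i+1]) by angles[i]."""
    out = list(x)
    for i, a in enumerate(angles):
        c, s = math.cos(a), math.sin(a)
        x0, x1 = x[2 * i], x[2 * i + 1]
        out[2 * i], out[2 * i + 1] = x0 * c - x1 * s, x0 * s + x1 * c
    return out


def rotate_half_split(x, angles):
    """Rotate pairs (x[i], x[i + d/2]) by angles[i]."""
    half = len(x) // 2
    out = list(x)
    for i, a in enumerate(angles):
        c, s = math.cos(a), math.sin(a)
        x0, x1 = x[i], x[i + half]
        out[i], out[i + half] = x0 * c - x1 * s, x0 * s + x1 * c
    return out


def even_odd_order(d):
    """Feature order [0, 2, 4, ..., 1, 3, 5, ...] relating the two layouts."""
    return list(range(0, d, 2)) + list(range(1, d, 2))


x = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
angles = [0.5, 0.25, 0.125]
order = even_odd_order(len(x))

interleaved = rotate_interleaved(x, angles)
half_split = rotate_half_split(x, angles)
# Same input, different layouts: the outputs disagree.
print(max(abs(u - v) for u, v in zip(interleaved, half_split)))
# Reordering the features even-then-odd makes the half-split result match.
reordered = rotate_half_split([x[j] for j in order], angles)
print(all(math.isclose(interleaved[j], r, abs_tol=1e-12) for j, r in zip(order, reordered)))  # True
```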

I will reimplement the call function following the torch implementation to bring the rtol and atol down to the required values.
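For reference while reimplementing call, here is a rough stdlib sketch of the torchtune formulation as I understand it. It computes theta_i = base^(-2i/dim) and builds a (cos, sin) table indexed by position. It then rotates consecutive pairs and shares the same angles across all heads at a position. Nested lists stand in for tensors, the batch axis is omitted, and head_dim is assumed to be even. make_cos_sin_table and apply_rope are illustrative names, not torchtune's or Caten's API.

```python
import math


def make_cos_sin_table(seq_len, dim, base=10000.0):
    """table[pos][i] = (cos, sin) of pos * theta_i, with theta_i = base ** (-(2 * i) / dim)."""
    theta = [base ** (-(2 * i) / dim) for i in range(dim // 2)]
    return [[(math.cos(pos * t), math.sin(pos * t)) for t in theta] for pos in range(seq_len)]


def apply_rope(x, base=10000.0):
    """Interleaved RoPE on a nested list shaped [seq_len][num_heads][head_dim]."""
    seq_len, head_dim = len(x), len(x[0][0])
    table = make_cos_sin_table(seq_len, head_dim, base)
    out = []
    for pos, heads in enumerate(x):
        rotated_heads = []
        for vec in heads:  # every head at this position uses the same angles
            rot = [0.0] * head_dim
            for i, (c, s) in enumerate(table[pos]):
                x0, x1 = vec[2 * i], vec[2 * i + 1]
                rot[2 * i] = x0 * c - x1 * s
                rot[2 * i + 1] = x1 * c + x0 * s
            rotated_heads.append(rot)
        out.append(rotated_heads)
    return out


# Example: 2 positions, 1 head, head_dim = 4 (theta = [1.0, 0.01]).
x = [[[1.0, 0.0, 1.0, 0.0]], [[1.0, 0.0, 1.0, 0.0]]]
y = apply_rope(x)
print(y[0][0])  # position 0 is left unchanged
print(y[1][0])
```

Precomputing the table once per sequence length mirrors how the torch version caches cos/sin rather than recomputing them per head.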