Closed abourramouss closed 12 hours ago
test-rope
dtype=FLOAT32
✓ Shapes match
✓ Satisfying (atol=1.9999694) <= 2
✓ Satisfying (rtol=2.4570446) <= 3
T
NIL
Ideally the error should be on the 1e-5 scale; the issue is that the tensors aren't identical, some values are different.
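For reference, a minimal numpy sketch of the kind of comparison the test output above reports, assuming atol/rtol are the maximum absolute and relative element-wise errors (the actual check lives in the Lisp test suite, not here):

# Illustration only: not the Caten test code. Assumes "atol"/"rtol" in the
# test output are max absolute / relative element-wise errors vs. a threshold.
import numpy as np

def check_close(ours, reference, atol=2.0, rtol=3.0):
    assert ours.shape == reference.shape  # "Shapes match"
    abs_err = np.abs(ours - reference)
    rel_err = abs_err / np.maximum(np.abs(reference), 1e-12)
    print(f"Satisfying (atol={abs_err.max()}) <= {atol}: {abs_err.max() <= atol}")
    print(f"Satisfying (rtol={rel_err.max()}) <= {rtol}: {rel_err.max() <= rtol}")
    # For RoPE this should pass with thresholds on the 1e-5 scale.
    return abs_err.max() <= atol and rel_err.max() <= rtol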
Your code looks very clean! LGTM. Regarding the accuracy issue, let me run some experiments (yeah, it should fit within the 1e-5 scale); this might be due to a compiler issue.
The code looks very good. I will merge this after the error is fixed to the 1e-5 scale.
Hmm, with JIT it should be scheduled into a single kernel? (Don't worry about that, I will take this!)
CATEN/TEST-SUITE> (call (Rope `(10)) (make-tensor `(10 10)))
{Tensor[float32] :shape (10 10) :id TID1450187
:buffer nil
:op #<RESHAPE {700BF75D83}>
:requires-grad NIL
:variables (TID1450176)
:tracker #<TRACKER :order={row(0 1)} :shape=(10 10) :contiguous-p=T>}
CATEN/TEST-SUITE> (proceed *)
[graph-schedule] Schedule Graph:
FastGraph[seen=NIL, outputs=(val_39)] {
{ Allocate } : [ val_34 <- (1 10 5 2) where lowered-p=nil ]
{ Allocate } : [ val_32 <- (1 10 5) where lowered-p=nil ]
{ KERNEL } : [ val_35 <- val_23, val_22, val_21, val_32, val_34 where lowered-p=nil :name=FUSED_CONCATENATENODE1463528]
{ Allocate } : [ val_0 <- (1 10 5) where lowered-p=nil ]
{ Allocate } : [ val_23 <- (10 10) where lowered-p=nil ]
{ Allocate } : [ val_17 <- (10 5) where lowered-p=nil ]
{ Allocate } : [ val_15 <- (10) where lowered-p=nil ]
{ Allocate } : [ val_9 <- (5) where lowered-p=nil ]
{ Allocate } : [ val_7 <- NIL where lowered-p=nil ]
{ Allocate } : [ val_5 <- NIL where lowered-p=nil ]
{ Allocate } : [ val_3 <- NIL where lowered-p=nil ]
{ Allocate } : [ val_1 <- NIL where lowered-p=nil ]
{ KERNEL } : [ val_22, val_21, val_36 <- val_1, val_3, val_5, val_7, val_9, val_15, val_17, val_23, val_0, val_35 where lowered-p=nil :name=FUSED_CONCATENATENODE_COSNODE1463526]
{ Allocate } : [ val_37 <- (1 10 5 2) where lowered-p=nil ]
{ KERNEL } : [ val_38 <- val_37, val_36 where lowered-p=nil :name=FUSED_MOVE1463522]
{ VMOP } : [ val_39 <- val_38 where lowered-p=nil :name=FUSED_BACKWARD1463520]
}
[14:21:53, 11/20/2024 (GMT+9)] : JIT Compilation Start (AVM=MAIN1450194)
* (1/3) FUSED_CONCATENATENODE1463528
=====> Lowering to blueprint
{
  for (int _gid0=0;(_gid0<1);_gid0+=1) {
    for (int _gid1=0;(_gid1<10);_gid1+=1) {
      for (int _gid2=0;(_gid2<5);_gid2+=1) {
        val_27 = -((val_23[((_gid0+(10*_gid1))+(1+(2*_gid2)))]*val_22[((_gid0+(5*_gid1))+_gid2)]));
        val_35[(((100*_gid0)+(10*_gid1))+(2*_gid2))] = ((val_23[((_gid0+(10*_gid1))+(2*_gid2))]*val_21[((_gid0+(5*_gid1))+_gid2)])+val_27);
      } // _gid2
    } // _gid1
  } // _gid0
}
Compilation Time : 0.030271(sec)
* (2/3) FUSED_CONCATENATENODE_COSNODE1463526
=====> Lowering to blueprint
{
  for (int _gid0=0;(_gid0<1);_gid0+=1) {
    for (int _gid1=0;(_gid1<10);_gid1+=1) {
      for (int _gid2=0;(_gid2<5);_gid2+=1) {
        val_19 = (_gid1*exp2((((_gid2*-9.2103405)*0.1)*1.442695)));
        val_22[((5*_gid1)+_gid2)] = sin(val_19);
        val_21[((5*_gid1)+_gid2)] = sin((val_19+1.5707964));
        val_29 = (val_23[((_gid0+(10*_gid1))+(1+(2*_gid2)))]*val_21[((_gid0+(5*_gid1))+_gid2)]);
        val_36[((((100*_gid0)+(10*_gid1))+(2*_gid2))+1)] = ((val_23[((_gid0+(10*_gid1))+(2*_gid2))]*val_22[((_gid0+(5*_gid1))+_gid2)])+val_29);
      } // _gid2
    } // _gid1
  } // _gid0
}
Compilation Time : 0.04597(sec)
* (3/3) FUSED_MOVE1463522
=====> Lowering to blueprint
{
  for (int _gid0=0;(_gid0<100);_gid0+=1) {
    val_38[_gid0] = val_36[_gid0];
  } // _gid0
}
Compilation Time : 0.001554(sec)
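For readers following along, here is my reading of what the three kernels above compute for the (10 10) input, written as a plain numpy sketch (hand-derived from the blueprint, not produced by the compiler):

# Sketch of the fused kernels: each adjacent pair (x[:, 2i], x[:, 2i+1]) is
# rotated by angle pos * freq_i, and the final MOVE kernel just copies the result.
import numpy as np

def fused_rope_readback(x):
    seq_len, dim = x.shape                                # (10, 10) in the trace
    pos = np.arange(seq_len)[:, None]                     # _gid1
    i = np.arange(dim // 2)                               # _gid2
    # exp2(((i * -9.2103405) * 0.1) * 1.442695) == 10000 ** (-0.1 * i)
    freqs = np.exp2(i * -9.2103405 * 0.1 * 1.442695)
    angles = pos * freqs                                  # val_19
    sin = np.sin(angles)                                  # val_22
    cos = np.sin(angles + 1.5707964)                      # val_21, i.e. cos(angles)
    out = np.empty_like(x)
    out[:, 0::2] = x[:, 0::2] * cos - x[:, 1::2] * sin    # FUSED_CONCATENATENODE kernel (val_35)
    out[:, 1::2] = x[:, 0::2] * sin + x[:, 1::2] * cos    # FUSED_CONCATENATENODE_COSNODE kernel (val_36)
    return out                                            # FUSED_MOVE copies val_36 into val_38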
I've been doing some tests in torch and mlx. The issue is that MLX and PyTorch yield different results, and my implementation is basically inspired by MLX, while the test is using the torch implementation.
This is the code I've been using:
import torch
from torchtune.modules import RotaryPositionalEmbeddings
from mlx.nn import RoPE
import mlx.core as mx
import numpy as np

# Parameters
seq_len, num_heads, head_dim = 30, 30, 30
dim = head_dim
max_seq_len = seq_len
theta = 10000.0

# Same random input for both implementations: [batch, seq_len, num_heads, head_dim]
x = torch.rand(1, seq_len, num_heads, head_dim)

rope_torch = RotaryPositionalEmbeddings(dim=head_dim)
rope_mlx = RoPE(dims=head_dim)

# Feed identical data to MLX via numpy
numpy_array = x.detach().cpu().numpy()
x_mlx = mx.array(numpy_array)
mx_output = rope_mlx(x_mlx)

print("Output MLX:")
# Convert the MLX output back to a torch tensor so the two outputs can be compared
mlx_tensor = torch.tensor(np.array(mx_output))
print(type(mlx_tensor))

print("Output PyTorch:")
torch_tensor = rope_torch(x)
print(type(torch_tensor))

# Accumulated element-wise differences between the two outputs
absolute_diff = torch.sum(torch.abs(torch_tensor - mlx_tensor))
relative_diff = absolute_diff / torch.sum(torch.abs(mlx_tensor))

# Output results
print(f"Total Absolute Difference: {absolute_diff.item()}")
print(f"Total Relative Difference: {relative_diff.item()}")
and the resulting absolute/relative differences:
Output MLX:
<class 'torch.Tensor'>
Output PyTorch:
<class 'torch.Tensor'>
Total Absolute Difference: 6711.056640625
Total Relative Difference: 0.49696406722068787
I will reimplement the call function using the torch implementation in order to lower the rtol and atol to the required values.
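For what it's worth, the mismatch looks like a pairing-convention difference rather than a purely numerical one: torchtune's RotaryPositionalEmbeddings rotates interleaved pairs (x[..., 2i], x[..., 2i+1]), while mlx.nn.RoPE with its default traditional=False rotates the first and second halves of the head dimension. A hand-written numpy sketch of the two conventions (my reading of the defaults, not code taken from either library):

# Illustration of the two RoPE pairing conventions (assumed cause of the mismatch).
import numpy as np

def rope_angles(seq_len, dim, theta=10000.0):
    pos = np.arange(seq_len)[:, None]
    freqs = theta ** (-np.arange(dim // 2) * 2.0 / dim)
    return pos * freqs                                  # (seq_len, dim // 2)

def rope_interleaved(x, theta=10000.0):                 # torch-style pairing
    a = rope_angles(x.shape[0], x.shape[-1], theta)
    sin, cos = np.sin(a), np.cos(a)
    out = np.empty_like(x)
    out[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
    out[..., 1::2] = x[..., 0::2] * sin + x[..., 1::2] * cos
    return out

def rope_half_split(x, theta=10000.0):                  # mlx-default-style pairing
    a = rope_angles(x.shape[0], x.shape[-1], theta)
    sin, cos = np.sin(a), np.cos(a)
    d = x.shape[-1] // 2
    lo, hi = x[..., :d], x[..., d:]
    out = np.empty_like(x)
    out[..., :d] = lo * cos - hi * sin
    out[..., d:] = lo * sin + hi * cos
    return out

With identical angles the two conventions differ only in which elements are paired, which is enough to produce large element-wise differences like the ones above; matching the interleaved (torch) pairing should make the comparison against the torch reference meaningful again.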
This PR is for the implementation of RoPE.