Status: Open — jafioti opened this issue 6 months ago
I didn't end up taking that approach, but the new selector API makes complex patterns much easier to write. The RoPE selector went from:
// Original (pre-rewrite) RoPE selector, built with the chained `.ptr()` / `.edge()` API.
// Each `select_ty!(Op)` / `select_const!(value, T)` node matches a single op of that type.
// NOTE(review): `.ptr(&mut x)` presumably stores the matched node's id into `x`, and
// `a.edge(b)` presumably means "the output of `a` feeds `b`" — confirm against SelectOp.
//
// Frequency branch: constants (ln(1e6) -> `theta`, `inv_head_dim`, `two`) and a MetalARange
// (`head_dim_arange`) are combined through three MetalMul nodes (`mul_2`, `head_dim_mul`,
// `theta_mul`), then passed through MetalExp (`exp`) and MetalRecip (`recip`).
let freqs = select_const!(1000000.0_f32.ln(), T)
.ptr(&mut theta)
.edge(
select_ty!(MetalConstant<T>)
.ptr(&mut inv_head_dim)
.edge(
select_ty!(MetalConstant<T>)
.ptr(&mut two)
.edge(
select_ty!(crate::other::MetalARange<T>)
.ptr(&mut head_dim_arange)
.edge(select_ty!(MetalMul<T>).ptr(&mut mul_2)),
)
.edge(select_ty!(MetalMul<T>).ptr(&mut head_dim_mul)),
)
.edge(select_ty!(MetalMul<T>).ptr(&mut theta_mul)),
)
.edge(select_ty!(MetalExp<T>).ptr(&mut exp))
.edge(select_ty!(MetalRecip<T>).ptr(&mut recip));
// Sequence-position branch: a MetalConstant (`seq_expr`) and a MetalARange (`seq_arange`)
// feeding a MetalAdd (`seq_add`).
let seq = select_ty!(MetalConstant<T>).ptr(&mut seq_expr).edge(
select_ty!(crate::other::MetalARange<T>)
.ptr(&mut seq_arange)
.edge(select_ty!(MetalAdd<T>).ptr(&mut seq_add)),
);
// `emb` joins the frequency and sequence branches in one MetalMul (`freq_seq_mul`).
let emb = freqs.edge(seq.edge(select_ty!(MetalMul<T>).ptr(&mut freq_seq_mul)));
// The input is matched by an untyped `SelectOp::new()` and goes through MetalContiguous
// `split_contig1`, whose output is shared by `split_contig2` (x0) and `split_contig3` (x1).
let split = SelectOp::new()
.ptr(&mut input)
.edge(select_ty!(MetalContiguous<T>).ptr(&mut split_contig1));
let x0 = split
.clone()
.edge(select_ty!(MetalContiguous<T>).ptr(&mut split_contig2));
let x1 = split.edge(select_ty!(MetalContiguous<T>).ptr(&mut split_contig3));
// Four products: sin(emb) and cos(emb), each multiplied with x0 and with x1
// (`out_mul1`..`out_mul4`). `emb` is cloned into all four branches, and `x0`/`x1` into two each.
let x0_sin = emb
.clone()
.edge(select_ty!(MetalSin<T>).ptr(&mut sin1))
.edge(x0.clone().edge(select_ty!(MetalMul<T>).ptr(&mut out_mul1)));
let x0_cos = emb
.clone()
.edge(select_ty!(MetalCos<T>).ptr(&mut cos1))
.edge(x0.edge(select_ty!(MetalMul<T>).ptr(&mut out_mul2)));
let x1_sin = emb
.clone()
.edge(select_ty!(MetalSin<T>).ptr(&mut sin2))
.edge(x1.clone().edge(select_ty!(MetalMul<T>).ptr(&mut out_mul3)));
let x1_cos = emb
.clone()
.edge(select_ty!(MetalCos<T>).ptr(&mut cos2))
.edge(x1.edge(select_ty!(MetalMul<T>).ptr(&mut out_mul4)));
// x1*sin and x0*cos meet in a MetalSub (`out_sub`); x0*sin and x1*cos meet in a MetalAdd
// (`out_add`). Both results feed `final_add`, and the whole pattern is searched in `graph`.
let x0_out = x1_sin.edge(x0_cos.edge(select_ty!(MetalSub<T>).ptr(&mut out_sub)));
let x1_out = x0_sin.edge(x1_cos.edge(select_ty!(MetalAdd<T>).ptr(&mut out_add)));
let mut searcher = x1_out
.edge(x0_out.edge(select_ty!(MetalAdd<T>).ptr(&mut final_add)))
.search(graph);
To:
// Rewritten RoPE selector using the expression-style helpers (`op`, `unary`, `binary`,
// `constant`, `node`), composed like tensor operations instead of chained `.edge()` calls.
// Frequency branch: ARange * 2, then * a MetalConstant, then * ln(1e6), then Exp and Recip.
let freqs = binary::<MetalMul<T>>(op::<MetalARange<T>>(), constant::<T>(2.0));
let freqs = binary::<MetalMul<T>>(freqs, op::<MetalConstant<T>>());
// `.abs()` is a no-op on this positive literal: the value equals `1000000_f32.ln()`,
// the same constant the older selector matched with `select_const!(1000000.0_f32.ln(), T)`.
let freqs = binary::<MetalMul<T>>(freqs, constant::<T>((1000000_f32).abs().ln()));
let freqs = unary::<MetalRecip<T>>(unary::<MetalExp<T>>(freqs));
// Sequence-position offset; the name suggests the count of previously processed
// positions — TODO confirm.
let prev_seq = op::<MetalConstant<T>>();
// emb = (ARange + prev_seq) * freqs
let emb = binary::<MetalMul<T>>(
binary::<MetalAdd<T>>(op::<MetalARange<T>>(), prev_seq.clone()),
freqs,
);
// NOTE(review): `node()` presumably matches any op (the analogue of `SelectOp::new()` in the
// older selector) — confirm.
let inp = node();
let split = unary::<MetalContiguous<T>>(inp.clone());
let x0 = unary::<MetalContiguous<T>>(split.clone());
// NOTE(review): this pattern is looser than the older selector.
// - The second Sub operand multiplies placeholder `op::<MetalContiguous<T>>()` and
//   `op::<MetalCos<T>>()` nodes.
// - Both Add operands are placeholder `op::<MetalMul<T>>()` nodes.
// - No x1 branch is built from `split`, and `emb` is not reused for the cos/sin products.
// This is presumably because selector graphs cannot yet reference the same node twice.
// Also, the older selector paired x1 with sin and x0 with cos inside the Sub; here x0 is
// paired with sin. The x0 and x1 branches were structurally identical there (both Contiguous
// of `split`), so this may still match — re-check once node reuse is supported.
let x0_out = binary::<MetalSub<T>>(
binary::<MetalMul<T>>(x0, unary::<MetalSin<T>>(emb.clone())),
binary::<MetalMul<T>>(op::<MetalContiguous<T>>(), op::<MetalCos<T>>()),
);
let x1_out = binary::<MetalAdd<T>>(op::<MetalMul<T>>(), op::<MetalMul<T>>());
// Search `graph` for the combined pattern, built up to the final MetalAdd.
let add = binary::<MetalAdd<T>>(x0_out, x1_out);
let mut s = add.clone().search(graph);
Some correctness bugs remain: selector graphs currently cannot reference the same node twice, because the backtracking function doesn't support that yet.
Currently, the graph selection API makes it difficult to write selectors for complex patterns such as RoPE: https://github.com/jafioti/luminal/blob/cb07523f02845e5a78b49f4c1fbf3f0705709ea9/crates/luminal_metal/src/unary.rs#L1278
Selectors should be built the same way primgraphs already are: with a GraphTensor-like API (though without compile-time shapes).