jafioti / luminal

Deep learning at the speed of light.
https://luminalai.com
Apache License 2.0
1.45k stars 90 forks source link

Better graph selection API #33

Open jafioti opened 6 months ago

jafioti commented 6 months ago

Currently, the graph selection API makes it difficult to write selectors for complex patterns like RoPE (rotary position embeddings): https://github.com/jafioti/luminal/blob/cb07523f02845e5a78b49f4c1fbf3f0705709ea9/crates/luminal_metal/src/unary.rs#L1278

Selectors should be built the same way primgraphs already are: through a GraphTensor-like API (but without compile-time shapes).

jafioti commented 6 months ago

I didn't end up taking that approach, but the new selector API makes complex patterns much easier to write. The RoPE selector went from:

// Old edge-chaining selector for the Metal RoPE pattern.
// `select_ty!(Op)` selects a node of op type `Op`. `.ptr(&mut x)` presumably
// writes the matched node into `x`. `.edge(next)` attaches `next` as a
// downstream node of the current selector. All `&mut` targets are declared
// outside this excerpt.
//
// Frequency branch: constant ln(1_000_000) -> MetalMul (theta_mul) -> Exp -> Recip.
// `theta_mul` also takes the head-dim chain built in the nested `.edge(...)`s.
let freqs = select_const!(1000000.0_f32.ln(), T)
            .ptr(&mut theta)
            .edge(
                select_ty!(MetalConstant<T>)
                    .ptr(&mut inv_head_dim)
                    .edge(
                        select_ty!(MetalConstant<T>)
                            .ptr(&mut two)
                            .edge(
                                // ARange over the head dimension, multiplied (mul_2)
                                // together with the `two` constant.
                                select_ty!(crate::other::MetalARange<T>)
                                    .ptr(&mut head_dim_arange)
                                    .edge(select_ty!(MetalMul<T>).ptr(&mut mul_2)),
                            )
                            .edge(select_ty!(MetalMul<T>).ptr(&mut head_dim_mul)),
                    )
                    .edge(select_ty!(MetalMul<T>).ptr(&mut theta_mul)),
            )
            .edge(select_ty!(MetalExp<T>).ptr(&mut exp))
            .edge(select_ty!(MetalRecip<T>).ptr(&mut recip));
// Sequence-position branch: constant (seq_expr) and ARange combined by an Add.
let seq = select_ty!(MetalConstant<T>).ptr(&mut seq_expr).edge(
    select_ty!(crate::other::MetalARange<T>)
        .ptr(&mut seq_arange)
        .edge(select_ty!(MetalAdd<T>).ptr(&mut seq_add)),
);
// Embedding angles: frequencies and positions joined by a MetalMul (freq_seq_mul).
let emb = freqs.edge(seq.edge(select_ty!(MetalMul<T>).ptr(&mut freq_seq_mul)));
// Input split: `SelectOp::new()` with no type constraint presumably matches any
// op. Its output feeds one Contiguous, which then feeds two further Contiguous
// nodes (x0 and x1).
let split = SelectOp::new()
    .ptr(&mut input)
    .edge(select_ty!(MetalContiguous<T>).ptr(&mut split_contig1));
let x0 = split
    .clone()
    .edge(select_ty!(MetalContiguous<T>).ptr(&mut split_contig2));
let x1 = split.edge(select_ty!(MetalContiguous<T>).ptr(&mut split_contig3));
// Four products: sin(emb) and cos(emb), each multiplied with x0 and with x1.
// Each sin/cos/mul is matched as a separate node.
let x0_sin = emb
    .clone()
    .edge(select_ty!(MetalSin<T>).ptr(&mut sin1))
    .edge(x0.clone().edge(select_ty!(MetalMul<T>).ptr(&mut out_mul1)));
let x0_cos = emb
    .clone()
    .edge(select_ty!(MetalCos<T>).ptr(&mut cos1))
    .edge(x0.edge(select_ty!(MetalMul<T>).ptr(&mut out_mul2)));
let x1_sin = emb
    .clone()
    .edge(select_ty!(MetalSin<T>).ptr(&mut sin2))
    .edge(x1.clone().edge(select_ty!(MetalMul<T>).ptr(&mut out_mul3)));
let x1_cos = emb
    .clone()
    .edge(select_ty!(MetalCos<T>).ptr(&mut cos2))
    .edge(x1.edge(select_ty!(MetalMul<T>).ptr(&mut out_mul4)));
// Rotated halves: one output is the Sub of (x1_sin, x0_cos), the other is
// the Add of (x0_sin, x1_cos).
let x0_out = x1_sin.edge(x0_cos.edge(select_ty!(MetalSub<T>).ptr(&mut out_sub)));
let x1_out = x0_sin.edge(x1_cos.edge(select_ty!(MetalAdd<T>).ptr(&mut out_add)));
// Pattern root: the final Add joining both halves. `.search(graph)` produces a
// searcher over `graph` for this pattern.
let mut searcher = x1_out
    .edge(x0_out.edge(select_ty!(MetalAdd<T>).ptr(&mut final_add)))
    .search(graph);

to:

// New combinator-style selector for the same Metal RoPE pattern.
// Assumed semantics (TODO confirm against the selector module):
//   `op::<Op>()` matches any node of op type `Op`;
//   `unary::<Op>(a)` / `binary::<Op>(a, b)` match an `Op` node whose inputs match `a` / `a`, `b`;
//   `constant::<T>(v)` matches a constant with value `v`;
//   `node()` matches any node.
//
// Frequencies: (ARange * 2) * <constant> * ln(1_000_000), then Exp, then Recip.
let freqs = binary::<MetalMul<T>>(op::<MetalARange<T>>(), constant::<T>(2.0));
let freqs = binary::<MetalMul<T>>(freqs, op::<MetalConstant<T>>());
// NOTE(review): `.abs()` is a no-op on this positive literal; it equals `1000000_f32.ln()`.
let freqs = binary::<MetalMul<T>>(freqs, constant::<T>((1000000_f32).abs().ln()));
let freqs = unary::<MetalRecip<T>>(unary::<MetalExp<T>>(freqs));
let prev_seq = op::<MetalConstant<T>>();
// Embedding angles: (ARange + prev_seq) * freqs.
let emb = binary::<MetalMul<T>>(
    binary::<MetalAdd<T>>(op::<MetalARange<T>>(), prev_seq.clone()),
    freqs,
);
// Input -> Contiguous -> Contiguous (x0).
let inp = node();
let split = unary::<MetalContiguous<T>>(inp.clone());
let x0 = unary::<MetalContiguous<T>>(split.clone());
// NOTE(review): the old selector reused `x1` and `emb` for the second operand.
// Here it uses fresh, unconstrained `op::<MetalContiguous<T>>()` / `op::<MetalCos<T>>()`
// leaves. `x1_out` likewise only constrains the op types of its two Mul inputs.
// This presumably works around selector graphs not being able to reference the
// same node twice yet.
let x0_out = binary::<MetalSub<T>>(
    binary::<MetalMul<T>>(x0, unary::<MetalSin<T>>(emb.clone())),
    binary::<MetalMul<T>>(op::<MetalContiguous<T>>(), op::<MetalCos<T>>()),
);
let x1_out = binary::<MetalAdd<T>>(op::<MetalMul<T>>(), op::<MetalMul<T>>());
// Pattern root: the final Add joining both halves.
// NOTE(review): `prev_seq`, `emb`, `inp`, `split` and `add` are each consumed only
// once in this excerpt. The `.clone()`s presumably keep the originals for use after
// the search (not shown) — TODO confirm.
let add = binary::<MetalAdd<T>>(x0_out, x1_out);
let mut s = add.clone().search(graph);

There are still some correctness bugs remaining: selector graphs currently cannot reference the same node twice, because the backtracking function doesn't support that yet.