jafioti / luminal

Deep learning at the speed of light.
https://luminalai.com
Apache License 2.0

Better Symbolic Algebra Library #47

Open jafioti opened 3 months ago

jafioti commented 3 months ago

Currently luminal uses a small symbolic algebra library I wrote to do expressions in src/shape/symbolic.rs.

The pro is that it's very simple and easy to reason about. The con is that it's very bad at simplifying complex expressions, so we get index expressions that look like this:

#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device float* input0 [[buffer(0)]], device float* input1 [[buffer(1)]], device float* input2 [[buffer(2)]], device float *out [[buffer(3)]], device uint& n_elements [[buffer(4)]], uint idx [[thread_position_in_grid]], device int& s [[buffer(5)]], device int& t [[buffer(6)]], device int& p [[buffer(7)]]) {
    if (idx < n_elements) {
        float intermediate0 = (((int)((((int)idx/(128*t))%s) < s) != 0) ? (float)input1[(((((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))%2)+(((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))/2)%64))+((((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))/128)%s)*64))+((((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))/(128*s))%32)*(64*s)))] + (float)input2[((((((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))%2)-1)+(((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))/2)%64))+((((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))/128)%s)*64))+((((((((int)idx%128)+((((int)idx/(128*t))%s)*128))+((((int)idx/((128*t)*s))%4)*(128*s)))+((((int)idx/(((128*t)*s)*4))%8)*((128*s)*4)))/(128*s))%32)*(64*s)))] : 0.0);
        out[idx] = (float)(intermediate0 * (float)input0[((((int)idx%128)+((((int)idx/128)%(p+max((int)s, (int)0)))*128))+((((int)idx/(((128*(p+max((int)s, (int)0)))*s)*4))%8)*(128*(p+max((int)s, (int)0)))))]);
    }
}

We need to balance power and simplicity. I don't want to write a huge symbolic algebra library in the core, but it needs to be more capable than the current system. Ideally we can 80/20 this and get 80% of the simplifications with 20% of the code. Doesn't need to be perfect.

So the options are:

I think the last one is the correct approach. Symbolic libraries can get very complex, but all that complexity can be effectively bottlenecked by the Expression type, and they can be unit tested quite well. Having a separate crate will allow for more complexity headroom in the simplification logic, while keeping the core of luminal clean and allowing us to write extensive unit tests to ensure the library works.

Why not use another crate? I haven't found a crate that has all the needed ops (max, less than, mod, etc.) and has very few other ops (which is needed to keep functions like expr_to_metal_string simple).

jafioti commented 3 months ago

Relevant links:

jafioti commented 3 months ago

I've fed the index expression above into sympy (likely one of the most mature symbolic algebra libraries out there) and it wasn't really able to minimize it:

[image]

Which tells me we need to construct these equations better, rather than build them naively and hope the symbolic algebra library can reduce them.

jafioti commented 2 months ago

Dimension combination has been added, which slightly improves the generated equations. Adding ranges to the terms will allow things like i % 15, where i has a min of 0 and an exclusive max of 15, to be reduced to just i.
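To illustrate the range idea, here is a minimal sketch in plain Rust. It is not luminal's implementation: the `Bounds` and `Expr` types and the `bounds` and `simplify` functions are hypothetical names made up for this example, and bounds are assumed to be inclusive. A `Mod` node is dropped whenever the left operand is provably within `0..k`.

```rust
/// Inclusive value range attached to a term (hypothetical helper type).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Bounds { min: i64, max: i64 }

/// A tiny expression tree, just large enough to show the rewrite.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Var(&'static str, Bounds),
    Const(i64),
    Mod(Box<Expr>, Box<Expr>),
}

/// Conservative bounds of an expression.
fn bounds(e: &Expr) -> Bounds {
    match e {
        Expr::Var(_, b) => *b,
        Expr::Const(c) => Bounds { min: *c, max: *c },
        Expr::Mod(a, m) => match (bounds(a), &**m) {
            // For a >= 0 and constant k > 0, a % k lies in [0, min(a.max, k - 1)].
            (ab, Expr::Const(k)) if *k > 0 && ab.min >= 0 => {
                Bounds { min: 0, max: ab.max.min(*k - 1) }
            }
            // Otherwise we know nothing.
            _ => Bounds { min: i64::MIN, max: i64::MAX },
        },
    }
}

/// Rewrites `a % k` to `a` when `a` is known to lie in `0..k`.
fn simplify(e: Expr) -> Expr {
    match e {
        Expr::Mod(a, m) => {
            let a = simplify(*a);
            let m = simplify(*m);
            if let Expr::Const(k) = &m {
                let b = bounds(&a);
                if *k > 0 && b.min >= 0 && b.max < *k {
                    return a;
                }
            }
            Expr::Mod(Box::new(a), Box::new(m))
        }
        other => other,
    }
}
```

With `i` bounded to `0..=14`, `simplify` turns `i % 15` into `i`. If `i` may reach 15, the `Mod` is kept, since `15 % 15` is 0 rather than 15.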

NewBornRustacean commented 2 months ago

Hello @jafioti! I've been looking into this issue. Is there any progress?

As I understand it, you're planning to build a new, distilled symbolic computation crate (am I getting that right?). If so, would the desired output be a reduced (simpler, more fully evaluated) form of the input equation?

jafioti commented 2 months ago

Yeah that's correct. The goal is to reduce the expressions to a minimum mathematically equivalent form so it's most efficient to compute many times over.

I've been working on this the past few days, hope to push soon

jafioti commented 2 months ago

I'm using cas-rs for simplification. The remaining issue is that it doesn't support Mod, which we need. Looking at implementing that now.

jafioti commented 2 months ago

This is being worked on in the cas branch

genos commented 2 months ago

I've had good luck with egg before, especially if your needs are a little more specialized/bespoke than, say, cas-rs.

jafioti commented 2 months ago

@genos Does egg do symbolic algebra reductions? I didn't see much example code / documentation

genos commented 2 months ago

@jafioti You write the simplification rules yourself, like in the docs.rs example. The rest of those docs were helpful to me in a previous project, though they required some digging. The website is a good resource as well.

YichengDWu commented 2 months ago

What is this complicated indexing coming from?

jafioti commented 2 months ago

The shape tracker can do zero cost movements like transpose or slicing, and then it generates an indexing expression to convert logical indexes to physical indexes.
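As a rough illustration of where the `%` and `/` chains come from, here is a minimal sketch of a strided view in plain Rust. It is not luminal's ShapeTracker: the `View` type and its `contiguous`, `permute`, and `physical_index` methods are hypothetical names for this example. A transpose only reorders dims and strides, and the cost shows up later as index arithmetic.

```rust
/// Hypothetical strided view: logical dimensions plus the physical stride of each.
struct View {
    dims: Vec<usize>,
    strides: Vec<usize>,
}

impl View {
    /// Row-major view over a freshly laid out buffer.
    fn contiguous(dims: &[usize]) -> Self {
        let mut strides = vec![1; dims.len()];
        for i in (0..dims.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * dims[i + 1];
        }
        View { dims: dims.to_vec(), strides }
    }

    /// Zero-cost axis permutation (e.g. transpose): no data is moved.
    fn permute(&self, axes: &[usize]) -> Self {
        View {
            dims: axes.iter().map(|&a| self.dims[a]).collect(),
            strides: axes.iter().map(|&a| self.strides[a]).collect(),
        }
    }

    /// Converts a flat logical index into a physical buffer offset.
    /// Each dimension contributes one `(idx / ..) % dim * stride` term.
    fn physical_index(&self, mut idx: usize) -> usize {
        let mut phys = 0;
        for (&d, &s) in self.dims.iter().zip(&self.strides).rev() {
            phys += (idx % d) * s;
            idx /= d;
        }
        phys
    }
}
```

For a 2x3 buffer viewed through `permute(&[1, 0])`, logical indexes 0 through 5 map to physical offsets 0, 3, 1, 4, 2, 5. Stacking several such movements on top of each other nests these terms, which is roughly how expressions like the kernel in the first comment grow.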

YichengDWu commented 2 months ago

You can mostly get rid of them with nested views. See what I did here: https://github.com/tinygrad/tinygrad/pull/3988

jafioti commented 2 months ago

@genos egg is amazing! I've been looking at it yesterday and today, and switched from cas-rs over to it since it's much more flexible and can easily reach the same level of reductions cas-rs did but in a more robust way. Thanks for the suggestion!

There are still a few bugs with it (conv2d still fails for some reason) and compile times are longer since it's not that fast, but it's definitely the right approach in this case.

jafioti commented 2 months ago

> You can mostly get rid of them with nested views. See what I did here: https://github.com/tinygrad/tinygrad/pull/3988

@YichengDWu This looks very interesting, did you base it off a paper or something? I'd like to read more

YichengDWu commented 2 months ago

Cutlass/CuTe

genos commented 2 months ago

> compile times are longer since it's not that fast

@jafioti two things come to mind concerning optimizing egg usage from previous experience:

  1. If you don't mind the extra dep (though it's transitively required by egg), building the SExp directly rather than using .parse() can offer a speed up.
  2. Carefully monitoring the rules you create and their usage, and trimming down as much as possible may help. For instance, full commutativity and associativity can blow up the search space a lot, though I admit it's hard to live without them.

jafioti commented 2 months ago

@genos Are there any examples for constructing an SExp and producing a RecExpr? Is there a way to directly build a RecExpr?

genos commented 2 months ago

@jafioti not that I recall; I went through the code (linked from the docs.rs for egg) by hand and found what .parse does.

jafioti commented 1 month ago

Egg is great, I think we'll stick with it for the expressions. Seems much better than other cas systems. Only thing remaining before I close this issue is more efficient conversion from luminal::Expression to and from egg::Expression.

genos commented 1 month ago

Glad my suggestion was helpful! If you want more speed, I still recommend building SExprs by hand. If I get some free time before you manage it, I’ll see if I can be of use.

jafioti commented 1 month ago

@genos I've added the code to build SExprs, .parse is no longer used. The next step is to build egg RecExpr directly, to avoid going through SExpr altogether. The other thing that needs to be done is to work out a more efficient way to go from RecExpr -> luminal::Expression, because right now that function creates a ton of vectors which are immediately consumed.

genos commented 1 month ago

@jafioti it turns out my memory was a bit hazy; we used egg at first to get as much simplification as we could, but eventually chucked it and instead simplified everything by hand because it was faster. We may not have known what simplifications were needed, however, without first turning to egg.

genos commented 1 month ago

@jafioti it occurs to me that with the way expressions are represented in luminal (with vectors), you've got an efficient way to compare expressions by size, looking at their .len(). As such, simplifying by hand (though perhaps tedious to write and test) will be quite speedy. A recursive function simplify(e: Expression, fuel: usize) -> Expression similar to what we wrote in the above PR, with fuel ticking down towards zero and greedily applying the first entry in a list of possible simplifications that shrinks the size of your expression, should do the trick.
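A minimal sketch of that fuel-based greedy idea, in plain Rust. Several things here are assumptions for illustration only: the boxed-tree `Expr` (luminal stores expressions as vectors, not trees), the four rewrite rules, and all helper names. Size is measured by node count, standing in for `.len()`.

```rust
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Num(i64),
    Var(&'static str),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

fn add(a: Expr, b: Expr) -> Expr { Expr::Add(Box::new(a), Box::new(b)) }
fn mul(a: Expr, b: Expr) -> Expr { Expr::Mul(Box::new(a), Box::new(b)) }

/// Node count, standing in for the `.len()` of a vector-backed expression.
fn size(e: &Expr) -> usize {
    match e {
        Expr::Num(_) | Expr::Var(_) => 1,
        Expr::Add(a, b) | Expr::Mul(a, b) => 1 + size(a) + size(b),
    }
}

// Each rule inspects only the root and returns a rewritten expression if it matches.
fn add_zero(e: &Expr) -> Option<Expr> {
    match e {
        Expr::Add(a, b) if **b == Expr::Num(0) => Some((**a).clone()),
        Expr::Add(a, b) if **a == Expr::Num(0) => Some((**b).clone()),
        _ => None,
    }
}
fn mul_one(e: &Expr) -> Option<Expr> {
    match e {
        Expr::Mul(a, b) if **b == Expr::Num(1) => Some((**a).clone()),
        Expr::Mul(a, b) if **a == Expr::Num(1) => Some((**b).clone()),
        _ => None,
    }
}
fn mul_zero(e: &Expr) -> Option<Expr> {
    match e {
        Expr::Mul(a, b) if **a == Expr::Num(0) || **b == Expr::Num(0) => Some(Expr::Num(0)),
        _ => None,
    }
}
fn fold_consts(e: &Expr) -> Option<Expr> {
    match e {
        Expr::Add(a, b) => match (&**a, &**b) {
            (Expr::Num(x), Expr::Num(y)) => Some(Expr::Num(x + y)),
            _ => None,
        },
        Expr::Mul(a, b) => match (&**a, &**b) {
            (Expr::Num(x), Expr::Num(y)) => Some(Expr::Num(x * y)),
            _ => None,
        },
        _ => None,
    }
}

const RULES: [fn(&Expr) -> Option<Expr>; 4] = [add_zero, mul_one, mul_zero, fold_consts];

/// Simplifies children first, then greedily applies the first rule that
/// shrinks the root, spending one unit of fuel per applied rewrite.
fn simplify(e: Expr, fuel: usize) -> Expr {
    let e = match e {
        Expr::Add(a, b) => add(simplify(*a, fuel), simplify(*b, fuel)),
        Expr::Mul(a, b) => mul(simplify(*a, fuel), simplify(*b, fuel)),
        leaf => leaf,
    };
    if fuel == 0 {
        return e;
    }
    for rule in RULES.iter() {
        if let Some(next) = rule(&e) {
            if size(&next) < size(&e) {
                return simplify(next, fuel - 1);
            }
        }
    }
    e
}
```

For example, `(x * 1 + 0) + 2 * 3` reduces to `x + 6` with any positive fuel and is returned unchanged with `fuel = 0`. Because a rewrite is accepted only when it shrinks the expression, the search cannot take a temporary detour through a larger form, which is one place an e-graph can do better.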

jafioti commented 1 month ago

@genos Yeah that's somewhat similar to what we had, but I really liked the idea of an e-graph based solution with simple rewrite rules that can compose together to do complex rewrites. It isn't obvious to me that greedily choosing the next rewrite would wind up at simplifications nearly as good as the ones e-graphs yield.

Or do you mean like recursively search down a tree of each rewrite and the shortest equation bubbles up? Also another issue we have is that it's harder to do rewrites in reverse polish notation (though definitely possible)

genos commented 1 month ago

Agreed, e-graphs seem much more likely to find good simplifications than simplistic greedy search. And I thought I was the only one having trouble working with the RPN setup! 😆

jafioti commented 1 month ago

I ended up using RPN just cause it was super simple to store, and I want ShapeTracker to be stored on the stack so Expression had to be stored on the stack (no recursive types). I'm thinking of switching it to be more like the postfix egg uses, where you can have terms in rpn that reference other terms so that you can use common subexpressions, which should greatly speed up the translation from egg to luminal
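To make the difference concrete, here is a minimal sketch in plain Rust comparing pure RPN with a flat list whose operators refer to earlier terms by index (the style egg's RecExpr uses). The `Rpn` and `Term` enums and the helper functions are hypothetical and are not luminal's Expression type. The scratch buffers use `Vec` for brevity, whereas a stack-only version would need fixed-size arrays.

```rust
/// Pure RPN: operators take their operands implicitly from the stack.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Rpn { Num(i64), Var, Add, Mul, Div, Mod }

fn eval_rpn(tokens: &[Rpn], idx: i64) -> i64 {
    let mut stack = Vec::new();
    for t in tokens {
        match *t {
            Rpn::Num(n) => stack.push(n),
            Rpn::Var => stack.push(idx),
            op => {
                let b = stack.pop().expect("missing right operand");
                let a = stack.pop().expect("missing left operand");
                stack.push(match op {
                    Rpn::Add => a + b,
                    Rpn::Mul => a * b,
                    Rpn::Div => a / b,
                    Rpn::Mod => a % b,
                    Rpn::Num(_) | Rpn::Var => unreachable!(),
                });
            }
        }
    }
    stack.pop().expect("empty expression")
}

/// Flat term list: operators name earlier terms by index, so a common
/// subexpression is stored (and evaluated) once.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Term { Num(i64), Var, Add(usize, usize), Mul(usize, usize), Div(usize, usize), Mod(usize, usize) }

fn eval_terms(terms: &[Term], idx: i64) -> i64 {
    let mut vals: Vec<i64> = Vec::with_capacity(terms.len());
    for t in terms {
        let v = match *t {
            Term::Num(n) => n,
            Term::Var => idx,
            Term::Add(a, b) => vals[a] + vals[b],
            Term::Mul(a, b) => vals[a] * vals[b],
            Term::Div(a, b) => vals[a] / vals[b],
            Term::Mod(a, b) => vals[a] % vals[b],
        };
        vals.push(v);
    }
    *vals.last().expect("empty expression")
}

/// (idx / 128) * 128 + (idx / 128) % 4, with `idx / 128` spelled out twice.
fn example_rpn() -> Vec<Rpn> {
    use Rpn::*;
    vec![Var, Num(128), Div, Num(128), Mul, Var, Num(128), Div, Num(4), Mod, Add]
}

/// The same expression, with `idx / 128` stored once as term 2.
fn example_terms() -> Vec<Term> {
    use Term::*;
    vec![Var, Num(128), Div(0, 1), Mul(2, 1), Num(4), Mod(2, 4), Add(3, 5)]
}
```

Both forms evaluate to the same value, but the RPN version needs 11 tokens and the indexed version 7 terms, because `idx / 128` and the constant 128 are shared. The saving grows with the amount of repetition, and the kernel in the first comment repeats its largest subexpression many times.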

asukaminato0721 commented 3 weeks ago

In egg, it's possible to define the cost function. By default it's AstSize, but it can be any other function.

For example, if we want to eliminate some op, set its cost to a huge number.

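use egg::{Id, Language}; // `Language` provides `fold`

// `Math` is the language enum defined with `define_language!` in egg's math test.
pub struct MathCostFn;
impl egg::CostFunction<Math> for MathCostFn {
    type Cost = usize;
    fn cost<C>(&mut self, enode: &Math, mut costs: C) -> Self::Cost
    where
        C: FnMut(Id) -> Self::Cost,
    {
        // Penalize the ops the extractor should avoid.
        let op_cost = match enode {
            Math::Diff(..) => 1000,
            Math::Integral(..) => 1000,
            _ => 1,
        };
        // Total cost: this node's cost plus the cost of each child e-class.
        enode.fold(op_cost, |sum, i| sum + costs(i))
    }
}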

Taken from egg's tests.