Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Try equality saturation for our simplification/vectorization rewriting system #103

Open Mikolaj opened 1 year ago

Mikolaj commented 1 year ago

This is the hammer and it may just succeed: https://github.com/alt-romes/hegg

Mikolaj commented 1 year ago

The typeset rewrite rules in our Overleaf document may be a good start, but they don't cover the whole equational theory.

alt-romes commented 1 year ago

Feel free to ask me any questions; I think the tutorial post might be a bit confusing at points. I would like to write about hegg again in the near future, perhaps in a longer post, but for now do let me know if anything doesn't make sense.

Mikolaj commented 1 year ago

@alt-romes: thank you. Once a volunteer tackles that, we'll make sure to communicate.

Actually, I think we should start with the simplification pass that happens after the transpose pass (AD backprop reverse pass). At this point vectorization is no longer a concern, but fusion and other kinds of simplification are, and all transformations are permitted as long as in the end we get tensor code that is faster to execute than what we started with.

tomsmeding commented 1 year ago

I believe we cannot use hegg directly for the same reason that we cannot use e.g. data-reify directly: our AST is type-indexed, so there is no single base functor.
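To make the obstacle concrete, here is a minimal sketch of what a "single base functor" looks like. The language below is a hypothetical untyped toy (`ExprF`, `Fix`, `cata`, `NodeId` are illustrative names, not the horde-ad AST or the hegg API). Every recursive position has the same type `r`, so a graph library can replace all children uniformly by identifiers. In a type-indexed AST the children of one constructor can carry different indices, so no single `r` fits.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Hypothetical untyped toy language, not the horde-ad AST.
-- The parameter r marks every recursive position.
data ExprF r
  = Lit Double
  | Add r r
  | Neg r
  deriving (Functor, Show, Eq)

-- Tying the knot gives ordinary expression trees.
newtype Fix f = Fix (f (Fix f))

-- Generic fold, available because ExprF is a plain Functor.
cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg (Fix x) = alg (fmap (cata alg) x)

-- Evaluate a tree by folding it with a one-layer algebra.
eval :: Fix ExprF -> Double
eval = cata alg
  where
    alg (Lit d)   = d
    alg (Add a b) = a + b
    alg (Neg a)   = negate a

-- Graph-style node: children replaced by untyped identifiers.
-- This only works because all children share one type.
type NodeId = Int

graphNode :: ExprF NodeId
graphNode = Add 0 1

-- 1.5 + negate 0.5
example :: Fix ExprF
example = Fix (Add (Fix (Lit 1.5)) (Fix (Neg (Fix (Lit 0.5)))))
```

With a GADT such as `Expr a`, a node like `ToF` has an `Int`-indexed child but a `Double`-indexed result, which is exactly what the single parameter `r` above cannot express.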

alt-romes commented 1 year ago

Perhaps there's a generalization that could be made over the base functor, such that hegg accepts a type-indexed AST? Sounds hard. I haven't seen the type-indexed AST in question (where is it?)

Mikolaj commented 1 year ago

Here it is: https://github.com/Mikolaj/horde-ad/blob/4e3eee462922dd997af88af3e9fb577b1ad37d28/simplified/HordeAd/Core/Ast.hs#L73

tomsmeding commented 1 year ago

@alt-romes here's a test case, if you can support this in hegg then we party:

data Expr a where
  Let :: Expr a -> (Expr a -> Expr b) -> Expr b
  LitF :: Double -> Expr Double
  LitI :: Int -> Expr Int
  ToF :: Expr Int -> Expr Double
  Floor :: Expr Double -> Expr Int
  Add :: Expr a -> Expr a -> Expr a

All the rest is, I think, reducible to something of this form.

alt-romes commented 1 year ago

I will try branching hegg to use indexed fixed points, with which I think we could represent the AST you've shown -- this seems the way to go: https://oleg.fi/gists/posts/2020-08-28-indexed-fixpoint.html.
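A rough sketch of what the indexed-fixpoint encoding could look like for the test case, assuming the names `IFix`, `IFunctor`, `imap` and `icata` (these are illustrative, not an existing hegg interface). `Let` is left out on purpose: its higher-order binder `Expr a -> Expr b` has a recursive occurrence in negative position, so it cannot be mapped over covariantly and would presumably need a first-order variable representation first.

```haskell
{-# LANGUAGE GADTs, KindSignatures, RankNTypes #-}

import Data.Functor.Const (Const (..))
import Data.Kind (Type)

-- Indexed fixpoint: the index i is threaded through the recursion.
newtype IFix (f :: (Type -> Type) -> Type -> Type) (i :: Type)
  = IFix (f (IFix f) i)

-- Indexed analogue of Functor: map an index-preserving function
-- over every recursive position.
class IFunctor (f :: (Type -> Type) -> Type -> Type) where
  imap :: (forall i. r i -> s i) -> f r j -> f s j

-- One layer of the test-case AST; r stands for the recursive positions.
-- Let is omitted (assumption: binders get a first-order encoding).
data ExprF (r :: Type -> Type) (a :: Type) where
  LitF  :: Double -> ExprF r Double
  LitI  :: Int -> ExprF r Int
  ToF   :: r Int -> ExprF r Double
  Floor :: r Double -> ExprF r Int
  Add   :: r a -> r a -> ExprF r a

instance IFunctor ExprF where
  imap _ (LitF d)  = LitF d
  imap _ (LitI n)  = LitI n
  imap f (ToF e)   = ToF (f e)
  imap f (Floor e) = Floor (f e)
  imap f (Add x y) = Add (f x) (f y)

-- Indexed fold.
icata :: IFunctor f => (forall i. f r i -> r i) -> IFix f j -> r j
icata alg (IFix x) = alg (imap (icata alg) x)

-- Example algebra: count nodes, ignoring the index.
sizeAlg :: ExprF (Const Int) i -> Const Int i
sizeAlg (LitF _)                  = Const 1
sizeAlg (LitI _)                  = Const 1
sizeAlg (ToF (Const n))           = Const (n + 1)
sizeAlg (Floor (Const n))         = Const (n + 1)
sizeAlg (Add (Const x) (Const y)) = Const (x + y + 1)

size :: IFix ExprF a -> Int
size = getConst . icata sizeAlg

-- ToF (LitI 2) + LitF 0.5, at index Double
example :: IFix ExprF Double
example = IFix (Add (IFix (ToF (IFix (LitI 2)))) (IFix (LitF 0.5)))
```

The carrier `Const Int` plays the role that an untyped identifier plays in the non-indexed setting: it forgets the index, which is roughly what an e-graph would need in order to store children of different types side by side.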

Eventually I can even try using the more general representation by default and represent the non-indexed base functor through it, but I'll have to see if it is a good thing :)

alt-romes commented 1 year ago

What type of rewrite rules would we apply to these expressions? One or two examples would be good, to test in the implementation.
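As a placeholder until real rules are posted, here is a sketch of two purely illustrative, type-preserving rules on the toy `Expr` from the test case (without `Let`). They are assumptions made up for this example and are not taken from the horde-ad rule set. Note also that this is plain destructive bottom-up rewriting; equality saturation would instead record both sides of each rule in an e-graph and extract the best term afterwards.

```haskell
{-# LANGUAGE GADTs #-}

-- Toy AST from the test case, minus Let.
data Expr a where
  LitF  :: Double -> Expr Double
  LitI  :: Int -> Expr Int
  ToF   :: Expr Int -> Expr Double
  Floor :: Expr Double -> Expr Int
  Add   :: Expr a -> Expr a -> Expr a

-- Illustrative rules, applied at the root only. Both sides of each
-- rule have the same index, so the result type is unchanged.
rewriteRoot :: Expr a -> Expr a
rewriteRoot (Floor (ToF e))  = e  -- assumes the Int is exactly representable as a Double
rewriteRoot (Add e (LitI 0)) = e  -- right unit of integer addition
rewriteRoot (Add (LitI 0) e) = e  -- left unit of integer addition
rewriteRoot e                = e

-- Bottom-up pass: simplify the children, then try the rules at the root.
simplify :: Expr a -> Expr a
simplify e = rewriteRoot $ case e of
  ToF x   -> ToF (simplify x)
  Floor x -> Floor (simplify x)
  Add x y -> Add (simplify x) (simplify y)
  _       -> e

-- Small printer so results can be inspected.
render :: Expr a -> String
render (LitF d)  = "LitF " ++ show d
render (LitI n)  = "LitI " ++ show n
render (ToF e)   = "ToF (" ++ render e ++ ")"
render (Floor e) = "Floor (" ++ render e ++ ")"
render (Add x y) = "Add (" ++ render x ++ ") (" ++ render y ++ ")"
```

For instance, `render (simplify (Add (Floor (ToF (LitI 3))) (LitI 0)))` reduces the inner `Floor (ToF ...)` first and then drops the zero, leaving `LitI 3`.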

Mikolaj commented 1 year ago

Here are the rules

Horde_flow9.pdf

and here are their implementations (the first half in the first file, the second in the second)

https://github.com/Mikolaj/horde-ad/blob/master/simplified/HordeAd/Core/AstVectorize.hs

https://github.com/Mikolaj/horde-ad/blob/master/simplified/HordeAd/Core/AstSimplify.hs

Mikolaj commented 1 year ago

Look only at the ranked variant (types indexed with n, not with sh), because the shaped variant is only half-done (the second file has only stubs). Let me know if that's too cryptic.