Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Add a fusion transformation partially inverse to vectorization #80

Status: open. Mikolaj opened this issue 1 year ago.

Mikolaj commented 1 year ago

Once a program is fully vectorized, so that no higher-order operations remain except gather, we should run a fusion transformation. However, the fusion should never create new higher-order operations (except gather), so it won't be able to fuse everything. It's interesting and not obvious how much can be fused under this restriction.
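To make the restriction concrete, here is a minimal sketch assuming a toy 1-D array language, not horde-ad's actual Ast. Gather is the only higher-order operation, and the fusion only ever rebuilds gathers: it composes the index maps of nested gathers and pushes a gather through elementwise addition. The names `Expr`, `fuse`, `rev`, and `countGathers` are hypothetical.

```haskell
-- Toy sketch, not horde-ad's Ast: 1-D vectors where gather is the only
-- higher-order operation.
module Main (main) where

data Expr
  = Lit [Double]                 -- a concrete vector
  | Gather Int (Int -> Int) Expr -- output length n, index map f: out[i] = src[f i]
  | Add Expr Expr                -- elementwise addition (first-order)

eval :: Expr -> [Double]
eval (Lit xs) = xs
eval (Gather n f e) = let v = eval e in [v !! f i | i <- [0 .. n - 1]]
eval (Add a b) = zipWith (+) (eval a) (eval b)

countGathers :: Expr -> Int
countGathers (Lit _) = 0
countGathers (Gather _ _ e) = 1 + countGathers e
countGathers (Add a b) = countGathers a + countGathers b

-- Fusion that never introduces a new higher-order operation:
--   gather f (gather g e) ==> gather (g . f) e
--   gather f (add a b)    ==> add (gather f a) (gather f b)
fuse :: Expr -> Expr
fuse (Gather n f e) = case fuse e of
  Gather _ g e' -> Gather n (g . f) e'
  Add a b -> Add (fuse (Gather n f a)) (fuse (Gather n f b))
  e' -> Gather n f e'
fuse (Add a b) = Add (fuse a) (fuse b)
fuse e = e

-- Reversal of a length-n vector, expressed as a gather.
rev :: Int -> Expr -> Expr
rev n = Gather n (\i -> n - 1 - i)

-- The even positions of rev (rev xs + ys).
example :: Expr
example =
  Gather 2 (\i -> 2 * i)
    (rev 4 (Add (rev 4 (Lit [1, 2, 3, 4])) (Lit [10, 20, 30, 40])))

main :: IO ()
main = do
  print (eval example, countGathers example)               -- ([41.0,23.0],3)
  print (eval (fuse example), countGathers (fuse example)) -- ([41.0,23.0],2)
```

After fusion every gather sits directly on a leaf, but the gather pushed through `Add` is duplicated. That is why not everything can be fused away without introducing builds.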

Edit: If we additionally performed AD on the Ast terms rather than on their Haskell interpretation, we would also be able to fuse the result of differentiation (and then interpret the Ast not necessarily in Haskell, but, for instance, in MLIR). However, this requires investigation first, and perhaps Ast would need to be extended (vectorization would not need changes, though fusion would).

Mikolaj commented 1 year ago

The non-build-introducing fusion is probably implemented completely now, except for fusion of scatters. However, it generates huge terms, mostly due to how reshape is expressed as gather with huge indexes and due to a lack of sharing (WIP) that then explodes the indexes when they substitute into each other.
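To see why expressing reshape as a gather inflates index terms, here is a sketch assuming row-major layout and a hypothetical symbolic integer type `IExpr` (not horde-ad's `AstInt`). A 2-D reshape maps the output index (i, j) to the flat offset k = i*q + j and then back to input coordinates (k div m, k mod m). Both coordinates mention all of k, so fusing a chain of reshapes by plain substitution, with no sharing, roughly doubles the term size at each step.

```haskell
-- Sketch: index terms produced by fusing a chain of reshapes-as-gathers
-- by plain substitution, with no sharing. Row-major layout assumed.
module Main (main) where

data IExpr
  = Var String
  | Lit Int
  | Add IExpr IExpr
  | Mul IExpr IExpr
  | Div IExpr IExpr
  | Mod IExpr IExpr

size :: IExpr -> Int
size (Var _) = 1
size (Lit _) = 1
size (Add a b) = 1 + size a + size b
size (Mul a b) = 1 + size a + size b
size (Div a b) = 1 + size a + size b
size (Mod a b) = 1 + size a + size b

-- Evaluate with the outermost output index bound to (i, j).
evalI :: (Int, Int) -> IExpr -> Int
evalI (i, _) (Var "i") = i
evalI (_, j) (Var "j") = j
evalI _ (Var v) = error ("unbound " ++ v)
evalI _ (Lit n) = n
evalI env (Add a b) = evalI env a + evalI env b
evalI env (Mul a b) = evalI env a * evalI env b
evalI env (Div a b) = evalI env a `div` evalI env b
evalI env (Mod a b) = evalI env a `mod` evalI env b

-- Index map of a reshape from input shape (n, m) to output shape (p, q):
-- output index (i, j) reads input index (k `div` m, k `mod` m), k = i*q + j.
reshapeIx :: (Int, Int) -> (Int, Int) -> (IExpr, IExpr) -> (IExpr, IExpr)
reshapeIx (_, m) (_, q) (i, j) =
  let k = Add (Mul i (Lit q)) j
  in (Div k (Lit m), Mod k (Lit m))

-- Fuse a chain of reshapes: the head is the final output shape and each
-- following shape is the input of the reshape before it.
fuseChain :: [(Int, Int)] -> (IExpr, IExpr)
fuseChain shapes0 = go shapes0 (Var "i", Var "j")
  where
    go (out : inp : rest) ix = go (inp : rest) (reshapeIx inp out ix)
    go _ ix = ix

shapes :: [(Int, Int)]
shapes = [(2, 12), (3, 8), (4, 6), (6, 4), (8, 3), (12, 2)]

main :: IO ()
main = do
  -- Size of the first input coordinate after 0..5 fused reshapes.
  print [size (fst (fuseChain (take k shapes))) | k <- [1 .. 6]]
  -- [1,7,19,43,91,187]
  let (a, b) = fuseChain shapes
  print (evalI (1, 5) a, evalI (1, 5) b) -- (8,1): flat offset 17 is preserved
```

Each step turns a coordinate of size c into one of size 2c + 5, so the growth is exponential in the number of fused reshapes even though the terms compute something trivial.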

The fusion that creates build operations may now be attempted, because we indeed produce Ast as the result of AD. However, this should be optional, because in some contexts fully vectorized code may be faster than builds that are then partially vectorized by OpenCL or other compilers and optimizers.

tomsmeding commented 1 year ago

> However, it generates huge terms, mostly due to how reshape is expressed as gather with huge indexes and due to a lack of sharing (WIP) that then explodes the indexes when they substitute into each other.

Once sharing is fully solved, will this still generate huge terms?

Mikolaj commented 1 year ago

> Once sharing is fully solved, will this still generate huge terms?

Not any more. However, the sharing needed here is of a fourth kind: sharing of integer terms. We already have sharing of delta terms and sharing of Ast terms that represent Delta terms. Two kinds are still missing: sharing of Ast terms that represent operations performed during transpose, and sharing of integer expressions. The latter would need an AstIntLet, and AstIntLet is somewhat intrusive when trying to simplify the arithmetic modulo n.

If ranks 100 or 1000 were common, even the shared terms could be big, but ranks are usually 3--5, I guess.
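A sketch of why sharing integer terms keeps indexes small, assuming a hypothetical let-spine representation `SharedIx` over a toy `IExpr` (horde-ad's actual AstIntLet may look quite different). The index terms come from a chain of 2-D row-major reshapes expressed as gathers, but each flat offset k = i*q + j is let-bound once and both coordinates refer to the bound variable. So each fused reshape adds a constant amount of term size instead of doubling it.

```haskell
-- Sketch: each flat offset k is let-bound once (a hypothetical stand-in
-- for an integer let such as AstIntLet) instead of being copied.
module Main (main) where

data IExpr
  = Var String
  | Lit Int
  | Add IExpr IExpr
  | Mul IExpr IExpr
  | Div IExpr IExpr
  | Mod IExpr IExpr

size :: IExpr -> Int
size (Var _) = 1
size (Lit _) = 1
size (Add a b) = 1 + size a + size b
size (Mul a b) = 1 + size a + size b
size (Div a b) = 1 + size a + size b
size (Mod a b) = 1 + size a + size b

-- A spine of let-bindings, in order, scoping over both index coordinates.
data SharedIx = SharedIx [(String, IExpr)] (IExpr, IExpr)

-- Each binding counts as one node plus its right-hand side.
sharedSize :: SharedIx -> Int
sharedSize (SharedIx binds (a, b)) =
  sum [1 + size e | (_, e) <- binds] + size a + size b

-- Reshape from input shape (n, m) to output shape (p, q), row-major,
-- binding k = i*q + j to a fresh variable instead of duplicating it.
reshapeIxShared :: Int -> (Int, Int) -> (Int, Int) -> SharedIx -> SharedIx
reshapeIxShared step (_, m) (_, q) (SharedIx binds (i, j)) =
  let x = "k" ++ show step
      k = Add (Mul i (Lit q)) j
  in SharedIx (binds ++ [(x, k)]) (Div (Var x) (Lit m), Mod (Var x) (Lit m))

-- The head is the final output shape; each following shape is the input
-- of the reshape before it.
fuseChainShared :: [(Int, Int)] -> SharedIx
fuseChainShared shapes0 = go 0 shapes0 (SharedIx [] (Var "i", Var "j"))
  where
    go s (out : inp : rest) ix =
      go (s + 1) (inp : rest) (reshapeIxShared s inp out ix)
    go _ _ ix = ix

-- Evaluate the bindings in order, then both coordinates.
evalShared :: (Int, Int) -> SharedIx -> (Int, Int)
evalShared (i, j) (SharedIx binds (a, b)) = (ev env a, ev env b)
  where
    env = foldl (\rho (x, e) -> (x, ev rho e) : rho) [("i", i), ("j", j)] binds
    ev rho (Var v) = maybe (error ("unbound " ++ v)) id (lookup v rho)
    ev _ (Lit n) = n
    ev rho (Add p r) = ev rho p + ev rho r
    ev rho (Mul p r) = ev rho p * ev rho r
    ev rho (Div p r) = ev rho p `div` ev rho r
    ev rho (Mod p r) = ev rho p `mod` ev rho r

shapes :: [(Int, Int)]
shapes = [(2, 12), (3, 8), (4, 6), (6, 4), (8, 3), (12, 2)]

main :: IO ()
main = do
  -- Total size after 0..5 fused reshapes: grows by a constant per step.
  print [sharedSize (fuseChainShared (take k shapes)) | k <- [1 .. 6]]
  -- [2,12,22,32,42,52]
  print (evalShared (1, 5) (fuseChainShared shapes)) -- (8,1)
```

With rank r instead of 2, each binding would grow roughly linearly in r, which matches the remark that even shared terms could get big only for very high ranks.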

tomsmeding commented 1 year ago

> AstIntLet

Is there a reason why AstLet cannot be generic in the type of the thing that it's binding, and the type of its body expression? Then there would be no need for a separate integer let-binding.

Or is this in the separate AstInt data type? Can we make do with throwing all the integer operations into Ast? Or do we need a guarantee somewhere that some integer computation does not contain any scalar computation inside?

> AstIntLet is somewhat intrusive when trying to simplify the arithmetic modulo n.

Well, yes, welcome to the world of optimisers; this is why optimisers don't always produce perfect code: they have to stop reasoning at some point. :P If we have examples of integer terms that are hard to simplify but that we would like to, perhaps it's worth collecting some examples so that we can think more concretely about those cases, improving the situation as the need arises.

> If ranks 100 or 1000 were common, even the shared terms could be big, but ranks are usually 3--5, I guess.

Yeah, no, I don't see anyone using a 100-dimensional array and assuming that to be efficient.

Mikolaj commented 1 year ago

> Is there a reason why AstLet cannot be generic in the type of the thing that it's binding, and the type of its body expression? Then there would be no need for a separate integer let-binding.

> Or is this in the separate AstInt data type? Can we make do with throwing all the integer operations into Ast?

Yes, it's in the separate AstInt data type. To unify it with Ast, we'd need to convert to an Accelerate-style fixed set of scalar types, as described in https://github.com/Mikolaj/horde-ad/issues/91#issuecomment-1494432248, make integers one of these, and then perhaps it would work. The biggest remaining obstacle is that these three operations probably don't vectorize, though I haven't re-verified that right now:

https://github.com/Mikolaj/horde-ad/blob/fff6dd554be2b9bdf2ceef3d4accaab12f2743aa/simplified/HordeAd/Core/Ast.hs#L155-L157
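A sketch of what an Accelerate-style fixed set of scalars could look like, assuming a GADT-indexed toy `Ast` that is hypothetical and far smaller than horde-ad's. Scalar types come from a closed `ScalarType` witness type and integers are just one of them. A single `Let` then binds a term of any scalar type in a body of any scalar type, so no separate integer let is needed. The body of `Let` uses higher-order abstract syntax (a Haskell function) only to keep the sketch short. It does not address whether such operations vectorize.

```haskell
{-# LANGUAGE GADTs #-}
-- Sketch: one generic let over a closed set of scalar types.
module Main (main) where

-- Closed set of scalar types; a value of ScalarType a is evidence that
-- a is one of the allowed scalars.
data ScalarType a where
  TInt :: ScalarType Int
  TDouble :: ScalarType Double

data Ast a where
  Const :: ScalarType a -> a -> Ast a
  Plus :: ScalarType a -> Ast a -> Ast a -> Ast a
  FromInt :: Ast Int -> Ast Double -- integer term inside a float term
  Floor :: Ast Double -> Ast Int   -- float term inside an integer term
  -- A single let, generic in the bound type a and the body type b.
  Let :: ScalarType a -> Ast a -> (Ast a -> Ast b) -> Ast b

-- Recover the Num instance by matching on the type witness.
plusScalar :: ScalarType a -> a -> a -> a
plusScalar TInt = (+)
plusScalar TDouble = (+)

eval :: Ast a -> a
eval (Const _ x) = x
eval (Plus t a b) = plusScalar t (eval a) (eval b)
eval (FromInt a) = fromIntegral (eval a)
eval (Floor a) = floor (eval a)
eval (Let t e body) = eval (body (Const t (eval e)))

-- An integer let inside a Double-valued term...
example1 :: Ast Double
example1 = Let TInt (Floor (Const TDouble 2.7)) (\n -> FromInt (Plus TInt n n))

-- ...and a Double let inside an Int-valued term, with the same constructor.
example2 :: Ast Int
example2 = Let TDouble (Const TDouble 1.5) (\x -> Floor (Plus TDouble x x))

main :: IO ()
main = print (eval example1, eval example2) -- (4.0,3)
```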

A minor obstacle is that Ast operations mix up scalars and rank-0 tensors, which would add yet more syntactic noise to the code that simplifies Peano modulo-n arithmetic terms.
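As a sketch of one rule that a simplifier for index arithmetic modulo n typically needs: with Haskell's floored `div` and `mod`, the identity (x `div` m) * m + x `mod` m == x holds for every m /= 0. The sketch applies it bottom-up to the unshared index terms of a chain of 2-D row-major reshapes, over a hypothetical toy `IExpr` rather than horde-ad's terms. The rule fires only when both occurrences of x are syntactically equal. Here it collapses the whole chain to a single reshape, since consecutive row-major reshapes compose to one.

```haskell
-- Sketch: one simplification rule for index arithmetic, applied to the
-- unshared index terms of a chain of row-major reshapes.
module Main (main) where

data IExpr
  = Var String
  | Lit Int
  | Add IExpr IExpr
  | Mul IExpr IExpr
  | Div IExpr IExpr
  | Mod IExpr IExpr
  deriving (Eq, Show)

size :: IExpr -> Int
size (Var _) = 1
size (Lit _) = 1
size (Add a b) = 1 + size a + size b
size (Mul a b) = 1 + size a + size b
size (Div a b) = 1 + size a + size b
size (Mod a b) = 1 + size a + size b

-- Bottom-up pass applying (x `div` m) * m + (x `mod` m) ==> x for a
-- literal m /= 0; both occurrences of x must be syntactically equal.
simplify :: IExpr -> IExpr
simplify (Add a b) = case (simplify a, simplify b) of
  (Mul (Div x (Lit m)) (Lit m'), Mod x' (Lit m''))
    | m /= 0 && m == m' && m == m'' && x == x' -> x
  (a', b') -> Add a' b'
simplify (Mul a b) = Mul (simplify a) (simplify b)
simplify (Div a b) = Div (simplify a) (simplify b)
simplify (Mod a b) = Mod (simplify a) (simplify b)
simplify e = e

-- Reshape from input shape (n, m) to output shape (p, q), row-major:
-- output index (i, j) reads input index (k `div` m, k `mod` m), k = i*q + j.
reshapeIx :: (Int, Int) -> (Int, Int) -> (IExpr, IExpr) -> (IExpr, IExpr)
reshapeIx (_, m) (_, q) (i, j) =
  let k = Add (Mul i (Lit q)) j
  in (Div k (Lit m), Mod k (Lit m))

-- The head is the final output shape; each following shape is the input
-- of the reshape before it.
fuseChain :: [(Int, Int)] -> (IExpr, IExpr)
fuseChain shapes0 = go shapes0 (Var "i", Var "j")
  where
    go (out : inp : rest) ix = go (inp : rest) (reshapeIx inp out ix)
    go _ ix = ix

shapes :: [(Int, Int)]
shapes = [(2, 12), (3, 8), (4, 6), (6, 4), (8, 3), (12, 2)]

main :: IO ()
main = do
  let (a, b) = fuseChain shapes
  print (size a, size (simplify a)) -- (187,7)
  -- Both coordinates reduce to k `div` 2 and k `mod` 2 with k = i*12 + j.
  print (simplify a, simplify b)
```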

> Or do we need a guarantee somewhere that some integer computation does not contain any scalar computation inside?

We don't, so far, and integer computations already do contain potentially big tensor expressions.

> If we have examples of integer terms that are hard to simplify but that we would like to, perhaps it's worth collecting some examples so that we can think more concretely about those cases, improving the situation as the need arises.

There are a few tests where these huge terms emerge, especially if one disables some ad-hoc rules that only work for reshape but stop working as soon as reshape is expressed as a gather:

https://github.com/Mikolaj/horde-ad/blob/fff6dd554be2b9bdf2ceef3d4accaab12f2743aa/test/simplified/TestGatherSimplified.hs#L260-L267

I'm going to code a pretty-printer for terms and then it should become presentable.