Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Decide: Desugar vs keep (generalized) dot operator through vectorization? #85

Closed. Mikolaj closed this issue 1 year ago.

Mikolaj commented 1 year ago

The dot product operator (\u v -> sum (flatten (u * v))), when nested under n build operations, vectorizes to an operation that maps the dot product n ranks down inside a tensor (where the mapping is expressible using transposition). Is it worthwhile to define such a generalized dot operation and keep it recognizable through vectorization? Or should dot be desugared into sum, flatten and multiplication, with these vectorized separately, producing a lot of extra transpositions? After vectorization it may be much harder to recover the shape of the generalized dot operation. Potentially, it might involve non-local transformations and reintroducing already eliminated operations (transpositions that fused or cancelled out) and/or even whole forgotten constant tensors (I'm not sure about that).
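To make the rank-lowering concrete, here is a minimal, self-contained sketch in plain Haskell (nested lists rather than horde-ad's Tensor class; the names `dot` and `dot1` are illustrative, not library API). One level of build over the dot product corresponds, on nested lists, to `zipWith dot`, i.e. the dot product mapped one rank down:

```haskell
-- Illustrative sketch, not horde-ad code: the plain dot product and the
-- "generalized" dot pushed one rank down, which is what one level of
-- build/vectorization over dot produces.

dot :: Num a => [a] -> [a] -> a
dot u v = sum (zipWith (*) u v)

-- The dot product mapped one rank down: contract the inner dimension
-- pointwise along the outer one.
dot1 :: Num a => [[a]] -> [[a]] -> [a]
dot1 = zipWith dot

main :: IO ()
main = do
  print (dot [1,2,3] [4,5,6 :: Int])               -- 32
  print (dot1 [[1,2],[3,4]] [[5,6],[7,8 :: Int]])  -- [17,53]
```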

Is it worth keeping dot recognizable? E.g., is the generalized tensor operation a common primitive, considerably faster than ad-hoc mapping (via transposition, most probably) of sum, flatten and/or multiplication over tensor ranks? Or, conversely, is the ordinary dot product so much cheaper, and cyclic transposition (which results in mapping in the case of operations like sum and flatten) cheap enough, that it's worthwhile to prevent other fusion in order to keep those operations intact?

Presumably a similar question would need to be answered for matrix multiplication, but I can't generalize it properly yet, even to many ranks, let alone taking (multiple) vectorization into account.

Mikolaj commented 1 year ago

For the current status, see dot0, matmul1 and matmul2 in Tensor class and the pretty-printed code of their vectorization and transpose in TestAdaptorSimplified and TestMnistFCNNR.

A C implementation of sum of the innermost (as opposed to outermost) dimension helps a lot for the case of dot0, but can't be easily generalized for matmul*. WIP.
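As a rough illustration of why the innermost dimension is the cheap one, here is a hedged sketch (not the actual C kernel; it assumes a row-major, flat representation): each output element of an innermost-dimension sum is a fold over a contiguous slice, whereas summing the outermost dimension strides through memory.

```haskell
import qualified Data.Vector.Unboxed as V

-- Hedged sketch, not horde-ad's kernel: sum over the innermost dimension of a
-- row-major rank-2 tensor stored flat.  Each output element sums a contiguous
-- slice, which is what makes a simple C/SIMD loop effective for this case.
sumInner :: Int -> Int -> V.Vector Double -> V.Vector Double
sumInner rows cols t =
  V.generate rows (\i -> V.sum (V.slice (i * cols) cols t))

main :: IO ()
main = print (sumInner 2 3 (V.fromList [1,2,3,4,5,6]))
-- prints the row sums: [6.0,15.0]
```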

tomsmeding commented 1 year ago

I see that your gradient for matmul2 looks like this:

```
\s0 dret x3 x4 ->
  let x9 = ttranspose [1,2,0] (tkonst 3 dret)
  in (tfromList []
     ,tsum (ttranspose [1,0] (tkonst 2 (ttranspose [1,0] x4) * x9))
     ,tsum (ttranspose [0,2,1] (ttranspose [1,0] (tkonst 4 x3) * x9)))
```

Using only inlining (of x9) and reasoning about transpose in combination with konst and elementwise operators (here *), I can manually reduce that to the following:

```
\s0 dret x3 x4 ->
  (tfromList []
  ,tsum (ttranspose [2,0,1] (tkonst 2 x4 * ttranspose [1,0,2] (tkonst 3 dret)))
  ,tsum (ttranspose [1,2,0] (tkonst 4 x3 * ttranspose [2,1,0] (tkonst 3 dret))))
```

This looks to me like two vectorised matrix multiplications. Is this the form you were looking for? Is this really much cheaper than the original?

Derivation (can probably be simplified a lot):

```
\s0 dret x3 x4 ->
  let x9 = ttranspose [1,2,0] (tkonst 3 dret)
  in (tfromList []
     ,tsum (ttranspose [1,0] (tkonst 2 (ttranspose [1,0] x4) * x9))
     ,tsum (ttranspose [0,2,1] (ttranspose [1,0] (tkonst 4 x3) * x9)))

\s0 dret x3 x4 ->
  let x9 = ttranspose [1,2,0] (tkonst 3 dret)
  in (tfromList []
     ,tsum (ttranspose [1,0] (tkonst 2 (ttranspose [1,0] x4) * x9))
     ,tsum (ttranspose [0,2,1] (ttranspose [1,0] (tkonst 4 x3 * ttranspose [1,0] x9))))

\s0 dret x3 x4 ->
  let x9 = ttranspose [1,2,0] (tkonst 3 dret)
  in (tfromList []
     ,tsum (ttranspose [1,0] (tkonst 2 (ttranspose [1,0] x4) * x9))
     ,tsum (ttranspose [1,2,0] (tkonst 4 x3 * ttranspose [1,0] x9)))

\s0 dret x3 x4 ->
  (tfromList []
  ,tsum (ttranspose [1,0] (tkonst 2 (ttranspose [1,0] x4) * ttranspose [1,2,0] (tkonst 3 dret)))
  ,tsum (ttranspose [1,2,0] (tkonst 4 x3 * ttranspose [1,0] (ttranspose [1,2,0] (tkonst 3 dret)))))

\s0 dret x3 x4 ->
  (tfromList []
  ,tsum (ttranspose [1,0] (tkonst 2 (ttranspose [1,0] x4) * ttranspose [1,2,0] (tkonst 3 dret)))
  ,tsum (ttranspose [1,2,0] (tkonst 4 x3 * ttranspose [2,1,0] (tkonst 3 dret))))

\s0 dret x3 x4 ->
  (tfromList []
  ,tsum (ttranspose [1,0] (ttranspose [0,2,1] (tkonst 2 x4) * ttranspose [1,2,0] (tkonst 3 dret)))
  ,tsum (ttranspose [1,2,0] (tkonst 4 x3 * ttranspose [2,1,0] (tkonst 3 dret))))

\s0 dret x3 x4 ->
  (tfromList []
  ,tsum (ttranspose [1,0] (ttranspose [0,2,1] (tkonst 2 x4 * ttranspose [0,2,1] (ttranspose [1,2,0] (tkonst 3 dret)))))
  ,tsum (ttranspose [1,2,0] (tkonst 4 x3 * ttranspose [2,1,0] (tkonst 3 dret))))

\s0 dret x3 x4 ->
  (tfromList []
  ,tsum (ttranspose [2,0,1] (tkonst 2 x4 * ttranspose [1,0,2] (tkonst 3 dret)))
  ,tsum (ttranspose [1,2,0] (tkonst 4 x3 * ttranspose [2,1,0] (tkonst 3 dret))))
```
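The reduction above relies on two kinds of algebraic laws: transposition distributes over elementwise operators, and transposition can be interchanged with konst. A small, self-contained illustration on plain nested lists (Data.List.transpose and replicate standing in for the rank-2 ttranspose [1,0] and tkonst; this is an editorial sketch, not horde-ad code):

```haskell
import Data.List (transpose)

-- Elementwise product of two rank-2 "tensors" represented as nested lists.
mul2 :: Num a => [[a]] -> [[a]] -> [[a]]
mul2 = zipWith (zipWith (*))

main :: IO ()
main = do
  let a = [[1,2,3],[4,5,6 :: Int]]
      b = [[7,8,9],[10,11,12]]
      v = [1,2,3 :: Int]
  -- Transposition distributes over elementwise (*):
  print (transpose (mul2 a b) == mul2 (transpose a) (transpose b))
  -- Transposing a konst equals mapping konst (the [1,0] case):
  print (transpose (replicate 2 v) == map (replicate 2) v)
```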
Mikolaj commented 1 year ago

Huh, what you accomplished demonstrates how weak our simplification currently is, after sharing was added (which freezes x9) and after I stopped expanding transpose to gather (except when fusing with a gather and possibly in a couple of other cases).

I totally missed that this is, more or less, a pair of the primal parts of the constructed dual number, as pretty-printed just below the gradient, and so a pair of vectorized matmul2 operations. This is great, because now we can try to interpret each of these as tmatmul2R on OR.Array, implemented using C BLAS matrix multiplication via hmatrix (or a corresponding GPU primitive). There are already interpretation hacks like that in interpretAst (until we have the two phases of simplification, the second of which would try to contract terms into known fast primitives). That's going to be a 100x speedup even on CPU.

This actually matches the non-simplified horde-ad where the matmul2 primitive had gradient components that were matmuls with a transposed matrix: https://github.com/Mikolaj/horde-ad/blob/4e3eee462922dd997af88af3e9fb577b1ad37d28/src/HordeAd/Internal/Delta.hs#L800-L805

That was much saner than the gradient components of matmul1 (matrix-vector multiplication), which involved outer products, etc., and which I also couldn't mentally connect to the new simplified horde-ad gradients: https://github.com/Mikolaj/horde-ad/blob/4e3eee462922dd997af88af3e9fb577b1ad37d28/src/HordeAd/Internal/Delta.hs#L738-L741
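To make the matmul2 case above concrete: the classic gradient components of C = A · B are dA = dC · Bᵀ and dB = Aᵀ · dC, i.e. plain matrix multiplications with one argument transposed, which is exactly what a BLAS-backed primitive can evaluate directly. A hedged sketch using hmatrix (mentioned above for the CPU interpretation); matmul2Grad is an illustrative name, not horde-ad API:

```haskell
import Prelude hiding ((<>))
import Numeric.LinearAlgebra

-- Gradient components of C = a <> b given incoming cotangent dc:
-- both are ordinary matmuls with one argument transposed.
matmul2Grad :: Matrix Double -> Matrix Double -> Matrix Double
            -> (Matrix Double, Matrix Double)
matmul2Grad a b dc = (dc <> tr b, tr a <> dc)

main :: IO ()
main = do
  let a  = (2><3) [1..]   -- 2x3 matrix
      b  = (3><4) [1..]   -- 3x4 matrix
      dc = (2><4) [1..]   -- cotangent of the 2x4 result
      (da, db) = matmul2Grad a b dc
  print (size da)  -- (2,3), same shape as a
  print (size db)  -- (3,4), same shape as b
```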

Technical tasks before the 100x optimization can happen:

An alternative: make matmul2 a primitive and so add an Ast constructor for it. Making sure it does not break vectorization is probably the biggest challenge.

Mikolaj commented 1 year ago

This is now done in

https://github.com/Mikolaj/horde-ad/commit/20ce84b024b8090f1d8970fc706eaaba1fd659f2

I was wrong about the 100x speedup. It's a 10x speedup and that's probably it.

This is very ad hoc, so when we support GPU via MLIR, we will have to decide how to express the backend-specific simplification rewrites, as opposed to the general simplification that is always needed so that we don't spend ages manipulating huge terms. Still, even though it's ad hoc, this works well enough that adding new Ast constructs for dot, matvecmul and matmul, vectorizing them and ADing them is not necessary.

@tomsmeding: thank you for the crucial insight. I would not have recognized matmul2 in the strange terms resulting from AD. And it turns out there are only so many permutations of three elements, so we can always recover the matmul2 shape, regardless of how much the context mangles the terms (unless it genuinely simplifies them). It's quite possible the simplified-horde-ad approach really is not much less performant than the original horde-ad (even if it's currently still 10x to 100x slower).
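For what it's worth, the search space is indeed tiny: a rank-3 ttranspose can only use one of 3! = 6 permutations, so a contraction pass can simply enumerate or pattern-match on them (a trivial standalone check, not horde-ad code):

```haskell
import Data.List (permutations, sort)

-- There are only 3! = 6 permutations of three axes, so recognizing the
-- matmul2 shape only requires handling six ttranspose variants.
main :: IO ()
main = print (sort (permutations [0, 1, 2 :: Int]))
-- [[0,1,2],[0,2,1],[1,0,2],[1,2,0],[2,0,1],[2,1,0]]
```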

Closing.