Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Add `tupdate` to `Tensor` class and start simplifying `tscatter` #101

Open Mikolaj opened 1 year ago

Mikolaj commented 1 year ago

It should be such that `tupdate (tzero sh) ix v` is the transpose of `tindex v ix`. Also:

https://github.com/Mikolaj/horde-ad/blob/6f88617de23a4d9fb328b352cf43fcf4cffd97b8/simplified/HordeAd/Core/AstSimplify.hs#L433

Probably `tscatter` can then be simplified using `tupdate`, similarly to how `tgather` is simplified using `tindex` right now. I'm not sure how much of the current complex `tgather` simplification code would dualize, but at least the trivial cases should, and they offer great benefits whenever they apply.
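To make the transposition claim concrete, here is a hedged 1-D sketch, with lists standing in for tensors (`tzero1`, `tindex1`, `tupdate1` are illustrative stand-ins, not the actual class methods): updating a zero tensor at `ix` produces exactly the one-hot whose dot product with any `w` recovers `tindex w ix`, which is the usual dot-product test for transposition.

```haskell
-- 1-D toy model: lists for tensors, Int for indices (stand-in names).
tzero1 :: Int -> [Double]
tzero1 n = replicate n 0

tindex1 :: [Double] -> Int -> Double
tindex1 v i = v !! i

tupdate1 :: [Double] -> Int -> Double -> [Double]
tupdate1 base i x = take i base ++ [x] ++ drop (i + 1) base

-- Transposition check: <tupdate1 (tzero1 n) i x, w> == x * tindex1 w i.
main :: IO ()
main = do
  let w   = [1, 2, 3, 4] :: [Double]
      lhs = sum (zipWith (*) (tupdate1 (tzero1 4) 2 10) w)
      rhs = 10 * tindex1 w 2
  print (lhs, rhs)
```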

I suppose we'd also need an `Ast` term for the operation, vectorization rules, and forward-pass and transpose rules. A similar operation is already implemented at the low level, because it's needed to implement scatter:

https://github.com/Mikolaj/horde-ad/blob/6f88617de23a4d9fb328b352cf43fcf4cffd97b8/src/common/HordeAd/Internal/TensorOps.hs#L101-L117

This needs to be generalized to non-singleton indexes but, OTOH, it can be specialized to just one update, at least initially.
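For intuition, a minimal sketch of the "just one update" special case with a non-singleton index, on a rank-2 tensor modeled as nested lists (illustrative names, not the `TensorOps` code linked above):

```haskell
-- Apply a function to the element at position i of a list.
updateAt :: Int -> (a -> a) -> [a] -> [a]
updateAt i f xs = take i xs ++ [f (xs !! i)] ++ drop (i + 1) xs

-- A single update of one scalar at a length-2 index in a rank-2 tensor.
updateScalar :: [[Double]] -> (Int, Int) -> Double -> [[Double]]
updateScalar t (i, j) x = updateAt i (updateAt j (const x)) t

main :: IO ()
main = print (updateScalar [[1, 2], [3, 4]] (0, 1) 9)
```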

Overall, this ticket is a big chunk of work, but it's quite modular. A couple of its parts, probably intertwined with the others, are crucial for the performance of the simplified horde-ad.

tomsmeding commented 1 year ago

What would be the type of this new `tupdate`?

Mikolaj commented 1 year ago

I think the simplest one that agrees with

https://github.com/Mikolaj/horde-ad/blob/6f88617de23a4d9fb328b352cf43fcf4cffd97b8/simplified/HordeAd/Core/AstSimplify.hs#L433

which is

tupdate :: TensorOf (p + n) r -> IndexOf p r -> TensorOf n r -> TensorOf (p + n) r

which checks out with the type of the transpose of `tupdate (tzero sh) ix v`, which is `tindex v ix`:

tindex :: TensorOf (p + n) r -> IndexOf p r -> TensorOf n r

tomsmeding commented 1 year ago

Wouldn't then `tupdate base idx item` necessarily copy (almost) the entirety of `base`? This is basically the one-hot encoding for the transposition of indexing, slightly modified to compute `base + onehot i` instead of just `onehot i`. I struggle to see how this will ever be remotely efficient if you're doing more than one indexing operation on an array; surely you want to batch them up into a single scatter?

Mikolaj commented 1 year ago

The motivating example

let x11 = tscatter [1] (tfromList [tsum (x3 * x9)])
                       (\[i10] -> [0])
  in x11 ! [0]

has nothing interesting to batch in a single scatter. Similarly, the transpose of indexing has just one one-hot, not a collection of them. I guess a general rule for indexing of `tupdate` would permit us to perform the indexing from the motivating example early and not materialize any of the large tensors. In other cases, we can interpret/compile sequential `tupdate`s jointly; we can think of them as associative accumulators.
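For illustration, the early-indexing idea on the motivating example, in a hedged 1-D model (`scatterC` is a toy stand-in for `tscatter` with a concrete index function):

```haskell
-- Toy 1-D scatter: element j of the result sums the sources mapped to j.
scatterC :: Int -> [Double] -> (Int -> Int) -> [Double]
scatterC sh src f =
  [ sum [ src !! i | i <- [0 .. length src - 1], f i == j ]
  | j <- [0 .. sh - 1] ]

main :: IO ()
main = do
  let src = [42 :: Double]           -- stands in for tfromList [tsum (x3 * x9)]
      x11 = scatterC 1 src (const 0) -- tscatter [1] src (\[i10] -> [0])
  -- Indexing right after the scatter: a rule for indexing of the update
  -- operation would produce this value without materializing x11 at all.
  print (x11 !! 0)
```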

Even if we end up batching many things up in a single scatter, we have to represent them somehow while they are sprinkled in many places of the generated code. I'm guessing trivial cases of scatter may not be the best representation. Then we can transform the code to bring these things together and, eventually, batch them up.

tomsmeding commented 1 year ago

> Even if we end up batching many things up in a single scatter, we have to represent them somehow while they are sprinkled in many places of the generated code. I'm guessing trivial cases of scatter may not be the best way. Then we can transform the code to get these things together and then, eventually, batch them up.

Ah, I see, you want an easier-to-recognise representation for trivial scatters. Because I feel that your given trivial scatter won't really be much slower than the corresponding tupdate, simply because all the overhead is in the copying of the base tensor. But if your point with tupdate is not performance but recognisability and hence easier recombination later in an efficient single scatter, then yes that makes sense.

Though I wonder if it's necessary. Maybe we can find a way to combine (vectorise, essentially) more general forms of tscatter in a way that is not too hard to implement and subsumes the cases where tupdate would be useful.

But that depends on how they appear in the code to simplify, which in turn depends on how the indexing operations appear in the original program. If they appear easily batchable there already, then the problem doesn't even arise because the things are immediately vectorised to a gather anyway. Do you happen to have a motivating example here?

Mikolaj commented 1 year ago

> Ah, I see, you want an easier-to-recognise representation for trivial scatters.

Yes, that's the main point.

> Because I feel that your given trivial scatter won't really be much slower than the corresponding tupdate, simply because all the overhead is in the copying of the base tensor.

Sure, but if I have a rule

tupdate u ix v ! ix --> v

then this is faster than leaving the scatter be, materializing it and then projecting. But, again, the rule can just as well be written for scatter, not `tupdate`, so it's mostly about presentation.
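A hedged sketch of such a rule on a toy AST (these constructors are made up for illustration; the real `Ast` in horde-ad is shape-indexed and far richer):

```haskell
-- Minimal untyped AST with just the constructors the rule needs.
data Ast
  = Var String
  | Index Ast [Int]        -- t ! ix, with a concrete index for simplicity
  | Update Ast [Int] Ast   -- tupdate t ix v
  deriving (Eq, Show)

-- tupdate u ix v ! ix --> v, fired only when the two indices are
-- syntactically equal; everything else is left alone in this sketch.
simplify :: Ast -> Ast
simplify (Index (Update _u ix v) ix') | ix == ix' = simplify v
simplify t = t

main :: IO ()
main = print (simplify (Index (Update (Var "u") [0] (Var "v")) [0]))
```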

> Though I wonder if it's necessary. Maybe we can find a way to combine (vectorise, essentially) more general forms of tscatter in a way that is not too hard to implement and subsumes the cases where tupdate would be useful.

That would be great.

> But that depends on how they appear in the code to simplify, which in turn depends on how the indexing operations appear in the original program. If they appear easily batchable there already, then the problem doesn't even arise because the things are immediately vectorised to a gather anyway. Do you happen to have a motivating example here?

Not really. But once we construct `tscatter` in whatever smart way, we'd want to fuse `tscatter` and simplify it in other ways. What I have, somewhat tangentially, are the corresponding rules for `tgather`, e.g.,

https://github.com/Mikolaj/horde-ad/blob/35ea9186aa2e4ac90b9e0a0f855da67dc709ab0c/simplified/HordeAd/Core/AstSimplify.hs#L626-L645

that simplify tgather a lot and use indexing (astIndex). I can't write such rules for tscatter, because I don't have tupdate (and using tgather instead of tindex and tscatter instead of tupdate would quickly lead to insanity).
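For contrast, the flavor of those gather rules in a hedged 1-D model: a gather is a build over indexing, so a gather whose index function is constant collapses to indexing plus replication (toy names, not the linked `astIndex` machinery):

```haskell
-- Toy 1-D gather: result element i is t indexed at f i.
gather1 :: Int -> [Double] -> (Int -> Int) -> [Double]
gather1 sh t f = [ t !! f i | i <- [0 .. sh - 1] ]

main :: IO ()
main = do
  let t = [10, 20, 30] :: [Double]
  -- A constant index function reduces the gather to replicated indexing,
  -- the kind of rewrite the linked rules perform symbolically.
  print (gather1 3 t (const 1))
```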

Mikolaj commented 1 year ago

This is killing my CI, so I will have to at least add the update term so that it takes less memory than the special case of scatter. Then I'd either start simplifying indexing of update or fuse many updates into one. That's still very ad hoc and much easier than dualising the simplification and fusion of gather in general, if that's possible at all.

Mikolaj commented 1 year ago

Eventually I simplified the scatters that are the transpose of indexing, and I also started simplifying some special forms of scatters. This helped with test speed, but not nearly enough. All without introducing tupdate yet, which would probably be either `tupdate (c, ix) = AstScatter sh c (Z, ix)` (which seems to be precisely dual to indexing, both when transposing and when comparing scatter and gather simplification rules) or `tupdate t (c, ix) = t + AstScatter sh c (Z, ix)` (which may or may not fuse better in some cases). Other variants seem to have problems when getting vectorized.
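The two candidates can be contrasted in a hedged 1-D model, with `scatter1` standing in for `AstScatter sh c (Z, ix)` applied to a single value:

```haskell
-- One-hot scatter of a single value into a vector of the given length.
scatter1 :: Int -> Double -> Int -> [Double]
scatter1 sh v ix = [ if j == ix then v else 0 | j <- [0 .. sh - 1] ]

-- Candidate 1: tupdate as a bare one-hot scatter (dual to indexing).
tupdateA :: Int -> Double -> Int -> [Double]
tupdateA = scatter1

-- Candidate 2: tupdate adds the one-hot into a base tensor.
tupdateB :: [Double] -> Double -> Int -> [Double]
tupdateB t v ix = zipWith (+) t (scatter1 (length t) v ix)

main :: IO ()
main = do
  print (tupdateA 4 5 2)
  print (tupdateB [1, 1, 1, 1] 5 2)
```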

All in all, scatter can certainly be fused with other scatters and can be simplified a bit more, but I'm no longer certain we can just reverse arrows in the gather simplification code. Reversing arrows seems tricky.