Mikolaj opened this issue 1 year ago
What would be the type of this new `tupdate`?
I think the simplest one is

```haskell
tupdate :: TensorOf (p + n) r -> IndexOf p r -> TensorOf n r -> TensorOf (p + n) r
```

which checks out with the type of the transpose of `update (tzero sh 0) ix v`, which is `tindex v ix`:

```haskell
tindex :: TensorOf (p + n) r -> IndexOf p r -> TensorOf n r
```
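To make the transposition claim concrete, here is a minimal rank-1 sketch on plain lists; `indexL` and `updateL` are hypothetical stand-ins for `tindex` and `tupdate`, not horde-ad functions:

```haskell
-- Minimal rank-1 sketch with plain lists; `indexL`/`updateL` are
-- illustrative stand-ins, not the horde-ad API.

-- Forward direction: tindex v ix, specialized to rank 1.
indexL :: [Double] -> Int -> Double
indexL v ix = v !! ix

-- Transpose of indexing: tupdate (tzero sh) ix c, i.e. a one-hot
-- deposit of the incoming cotangent c into a zero tensor.
updateL :: [Double] -> Int -> Double -> [Double]
updateL base ix c =
  [ if i == ix then c else x | (i, x) <- zip [0 ..] base ]

main :: IO ()
main = do
  print (indexL [10, 20, 30] 1)   -- 20.0
  print (updateL [0, 0, 0] 1 5)   -- [0.0,5.0,0.0]
```

The sketch also shows where the cost concern comes from: `updateL` rebuilds the whole `base` list just to place one element.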
Wouldn't `tupdate base idx item` then necessarily copy (almost) the entirety of `base`? This is basically the one-hot encoding for the transposition of indexing, slightly modified to compute `base + onehot i` instead of just `onehot i`. I struggle to see how this will ever be remotely efficient if you're doing more than one indexing operation on an array; surely you want to batch them up into a single scatter?
The motivating example

```haskell
let x11 = tscatter [1] (tfromList [tsum (x3 * x9)])
                       (\[i10] -> [0])
in x11 ! [0]
```

has nothing interesting to batch into a single scatter. Similarly, a transpose of indexing has just one one-hot, not a collection of them. I guess a general rule for indexing of `tupdate` would permit us to perform the indexing from the motivating example early and not materialize any of the large tensors. In other cases, we can interpret/compile sequential `tupdate`s jointly. We can think of them as associative accumulators.
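The "associative accumulators" view can be sketched on rank-1 lists: because additive updates commute and associate, a chain of sequential updates folds into one batched scatter. The helper names here (`updateAddL`, `scatterAddL`) are hypothetical, not horde-ad operations:

```haskell
-- One additive update: base + onehot ix c, on rank-1 lists.
updateAddL :: [Double] -> (Int, Double) -> [Double]
updateAddL base (ix, c) =
  [ if i == ix then x + c else x | (i, x) <- zip [0 ..] base ]

-- Batched form: a single scatter that accumulates all contributions,
-- which is what a chain of sequential updates should fuse into.
scatterAddL :: Int -> [(Int, Double)] -> [Double]
scatterAddL n = foldl updateAddL (replicate n 0)

main :: IO ()
main = do
  let hits = [(0, 1), (2, 3), (0, 2)]
  print (foldl updateAddL (replicate 4 0) hits)  -- [3.0,0.0,3.0,0.0]
  print (scatterAddL 4 hits)                     -- same result
```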
Even if we end up batching many things up in a single scatter, we have to represent them somehow while they are sprinkled in many places of the generated code. I'm guessing trivial cases of scatter may not be the best way. Then we can transform the code to get these things together and then, eventually, batch them up.
> Even if we end up batching many things up in a single scatter, we have to represent them somehow while they are sprinkled in many places of the generated code. I'm guessing trivial cases of scatter may not be the best way. Then we can transform the code to get these things together and then, eventually, batch them up.
Ah, I see, you want an easier-to-recognise representation for trivial scatters. Because I feel that your given trivial scatter won't really be much slower than the corresponding `tupdate`, simply because all the overhead is in the copying of the base tensor. But if your point with `tupdate` is not performance but recognisability, and hence easier recombination later into an efficient single scatter, then yes, that makes sense.

Though I wonder if it's necessary. Maybe we can find a way to combine (vectorise, essentially) more general forms of `tscatter` in a way that is not too hard to implement and subsumes the cases where `tupdate` would be useful.

But that depends on how they appear in the code to simplify, which in turn depends on how the indexing operations appear in the original program. If they appear easily batchable there already, then the problem doesn't even arise, because they are immediately vectorised to a `gather` anyway. Do you happen to have a motivating example here?
> Ah, I see, you want an easier-to-recognise representation for trivial scatters.
Yes, that's the main point.
> Because I feel that your given trivial scatter won't really be much slower than the corresponding `tupdate`, simply because all the overhead is in the copying of the base tensor.
Sure, but if I have a rule

```haskell
tupdate u v ix ! ix --> v
```

then this is faster than leaving the scatter be, materializing it and then projecting. But, again, the rule can just as well be written for `scatter`, not `tupdate`, so it's mostly about presentation.
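As a toy illustration of such a rule (a hypothetical mini-AST, not the horde-ad `Ast` types): indexing an update at the very index that was written can be rewritten to the written value, without materializing the updated tensor at all.

```haskell
-- Toy rank-1 expression language; names are illustrative only.
data Expr
  = Lit [Double]
  | Update Expr Int Expr  -- update of a base at an index with a value
  | Index Expr Int        -- base ! ix
  deriving (Show, Eq)

-- The rule: Index (Update u ix v) ix --> v.
-- (A fuller simplifier would also push Index into the base when
-- the two indices are known to differ; omitted here.)
simplify :: Expr -> Expr
simplify (Index (Update _u ix' v) ix)
  | ix == ix' = simplify v
simplify e = e

main :: IO ()
main =
  print (simplify (Index (Update (Lit [0, 0, 0]) 1 (Lit [5])) 1))
```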
> Though I wonder if it's necessary. Maybe we can find a way to combine (vectorise, essentially) more general forms of `tscatter` in a way that is not too hard to implement and subsumes the cases where `tupdate` would be useful.
That would be great.
> But that depends on how they appear in the code to simplify, which in turn depends on how the indexing operations appear in the original program. If they appear easily batchable there already, then the problem doesn't even arise because the things are immediately vectorised to a `gather` anyway. Do you happen to have a motivating example here?
Not really. But once we construct `tscatter` in whatever smart way, we'd want to fuse `tscatter`s and simplify them in other ways. What I have, somewhat tangentially, are the corresponding rules for `tgather`, e.g., rules that simplify `tgather` a lot and use indexing (`astIndex`). I can't write such rules for `tscatter`, because I don't have `tupdate` (and using `tgather` instead of `tindex` and `tscatter` instead of `tupdate` would quickly lead to insanity).
This is killing my CI, so I will have to at least add the `update` term so that it takes less memory than the special case of `scatter`. Then I'd either start simplifying indexing of `update` or fuse many `update`s into one. That's still very ad hoc and much easier than dualising the simplification and fusion of `gather` in general, if that's possible at all.
Eventually I simplified the scatters that are the transpose of indexing, and I also started simplifying some special forms of scatters. This helped with test speed, but not nearly enough. All without introducing `tupdate` yet, which would probably be either

```haskell
tupdate (c, ix) = AstScatter sh c (Z, ix)
```

(which seems to be precisely dual to indexing, both when transposing and when comparing scatter and gather simplification rules) or

```haskell
tupdate t (c, ix) = t + AstScatter sh c (Z, ix)
```

(which may or may not fuse better in some cases). Other variants seem to have problems when getting vectorized.

All in all, scatter can certainly be fused with other scatters and can be simplified a bit more, but I'm no longer certain we can just reverse the arrows in the gather simplification code. Reversing arrows seems tricky.
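The "reversing arrows" intuition can be made concrete on rank-1 lists (hypothetical helpers, not horde-ad code): gather reads through an index map, while its dual, scatter, writes through the same map, summing whenever two sources hit the same destination. That summing is exactly what has no counterpart on the gather side, which is one reason the rules don't dualize mechanically.

```haskell
-- Rank-1 sketch of the gather/scatter duality; names are illustrative.

-- Gather: read src through an index map.
gatherL :: [Double] -> [Int] -> [Double]
gatherL src ixs = map (src !!) ixs

-- Scatter (the transpose): write src through the same index map into
-- a length-n zero vector, adding on collisions.
scatterL :: Int -> [Double] -> [Int] -> [Double]
scatterL n src ixs =
  [ sum [ c | (c, ix) <- zip src ixs, ix == i ] | i <- [0 .. n - 1] ]

main :: IO ()
main = do
  print (gatherL [10, 20, 30] [2, 0, 2])   -- [30.0,10.0,30.0]
  print (scatterL 3 [1, 1, 1] [2, 0, 2])   -- [1.0,0.0,2.0] (collision at 2)
```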
It should be such that `tupdate (tzero sh) ix v` is the transpose of `tindex v ix`. Also https://github.com/Mikolaj/horde-ad/blob/6f88617de23a4d9fb328b352cf43fcf4cffd97b8/simplified/HordeAd/Core/AstSimplify.hs#L433

Probably `tscatter` can then be simplified using `tupdate` similarly to how `tgather` simplifies using `tindex` right now. I'm not sure how much of the current complex `tgather` simplification code would dualize, but at least the trivial cases should, and they offer great benefits whenever they apply. I suppose we'd also need an Ast term for the operation, vectorization rules, and forward pass and transpose rules. A similar operation is already implemented at the low level, because it's needed to implement `scatter`: https://github.com/Mikolaj/horde-ad/blob/6f88617de23a4d9fb328b352cf43fcf4cffd97b8/src/common/HordeAd/Internal/TensorOps.hs#L101-L117

This needs to be generalized to non-singleton indexes but, OTOH, it can be specialized to just one update, at least initially.

Overall, this ticket is a big chunk of work, but quite modular. A couple of parts, probably intertwined with others, are crucial for the performance of the simplified horde-ad.