Mikolaj closed this issue 1 year ago.
So you want to simplify `index (scatter sh v f) ix` to something that doesn't have `index` at the head? This by definition involves inverting `f` (more precisely: computing preimages), which is not computable in general. :)
Some `f`s you can invert, like the one for `sum`: the preimage of `ZI` is simply the whole of `sh`. This is why `sum` works. More generally, if you can invert `f`, then you should be able to reduce index-of-scatter to something that has either a smaller indexing operation or the sum of a gather, I think. Not sure what happens if `ix` is shorter than the codomain of `f`.
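To make the preimage view concrete, here is a minimal sketch in a toy model, assuming arrays are finite maps and `f` maps a whole source index to a whole target index (in horde-ad, `f` maps only an index prefix). The helper names `scatterMap`, `indexScatter`, `preimage` and `indexViaPreimage` are hypothetical, not horde-ad functions. Indexing a scatter at `ix` sums the source elements in the preimage of `ix`; the sketch finds that preimage by enumerating the finite source domain, which is exactly what a symbolic simplifier cannot do for an arbitrary `f`.
```haskell
import qualified Data.Map.Strict as M

type Ix = [Int]

-- All index tuples of a shape, in row-major order; the empty shape []
-- (a scalar) has exactly one index, [].
indices :: [Int] -> [Ix]
indices = mapM (\n -> [0 .. n - 1])

-- Toy scatter: the source element at index is (read through v) is added
-- at target index f is. The result must be materialised before reading.
scatterMap :: [Int] -> (Ix -> Double) -> (Ix -> Ix) -> M.Map Ix Double
scatterMap shV v f = M.fromListWith (+) [(f is, v is) | is <- indices shV]

-- index (scatter ...) ix, computed by materialising the scatter first.
indexScatter :: [Int] -> (Ix -> Double) -> (Ix -> Ix) -> Ix -> Double
indexScatter shV v f ix = M.findWithDefault 0 ix (scatterMap shV v f)

-- The preimage of ix under f, by exhaustive search of the source domain
-- (no inverse of f is ever computed).
preimage :: [Int] -> (Ix -> Ix) -> Ix -> [Ix]
preimage shV f ix = [is | is <- indices shV, f is == ix]

-- The same value as indexScatter, computed as a sum over the preimage.
indexViaPreimage :: [Int] -> (Ix -> Double) -> (Ix -> Ix) -> Ix -> Double
indexViaPreimage shV v f ix = sum (map v (preimage shV f ix))

-- sum as a scatter: every source index maps to the empty index (ZI in
-- horde-ad, [] here), so the preimage of [] is the whole source shape.
sumViaScatter :: [Int] -> (Ix -> Double) -> Double
sumViaScatter shV v = indexScatter shV v (const []) []
```
With `v = \is -> fromIntegral (sum is + 1)` over shape `[4]` (values 1 to 4) and `f = map (\k -> mod k 2)`, both `indexScatter` and `indexViaPreimage` give 6.0 at index `[1]`, and `sumViaScatter [4] v` gives 10.0.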
More fundamentally, I think the problem you're running into here is push arrays vs pull arrays, or producers vs consumers. Everything expressed in terms of `build` is a consumer, and reductions like `fold` and `scan` are also consumers. Importantly, indexing is also a consumer. Consumers fuse into consumers, hence indexing into a pull array yields another pull array.
However, `scatter` is a producer, yielding a push array. Producers do not fuse into consumers, because that would involve inverting the data-flow direction of the producer. Hence reducing index-of-scatter is hard.
Huh, that probably means I need to remove `scatter` from the user-accessible part of the `Tensor` class. It also means my ability to simplify the code emitted by AD is limited: it gets stuck at `scatter`, which appears where there was `gather` in the (vectorized form of the) objective function. E.g., normally I can fuse any `gather` contained inside any other `gather`. Now I can't if there's a `scatter` in the way.
Assuming we ban `scatter`, I wonder if there are any other special cases (other than `sum`) that are both vectorizable and useful.
As an alternative to banning `scatter`, we could ban `build` inside `scatter`, or just not vectorize some `build`s inside some `scatter`s.
Or perhaps, for producers, there is a way to vectorize inside `index(producer)` without eliminating the indexing. With rewriting systems, if you are stuck reducing, you make the problematic form a normal form and see if this blocks all other reduction rules or if they can be extended to operate on the new normal form specially. We already have special normal forms for `fromList`, both wrapped in `index` and wrapped in `gather`. I wonder if `fromList` is a producer.
Really, going from push-style to pull-style (i.e. scattering and then indexing into the result) requires you, operationally, to materialise the intermediate array. You can't get around that, unless you can feasibly invert the index mapping function f and turn the scatter into a gather in the first place.
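The materialisation argument, and its one escape hatch, can be seen in a one-dimensional sketch (hypothetical helpers `scatterVec` and `gatherVec` over plain lists, not horde-ad tensors): when the index map of a scatter is a bijection, the same result is obtained by a gather through the inverse map, which stays in pull style.
```haskell
import qualified Data.Map.Strict as M

-- Push style: source position k is written (added) to target position f k.
-- Writes landing outside [0 .. n - 1] are dropped in this toy version.
scatterVec :: Int -> (Int -> Int) -> [Double] -> [Double]
scatterVec n f xs = [M.findWithDefault 0 j writes | j <- [0 .. n - 1]]
  where
    writes = M.fromListWith (+) (zip (map f [0 ..]) xs)

-- Pull style: target position j reads source position g j.
gatherVec :: Int -> (Int -> Int) -> [Double] -> [Double]
gatherVec n g xs = [xs !! g j | j <- [0 .. n - 1]]

-- A rotation is a bijection on [0 .. n - 1], so it has an inverse and the
-- scatter through it can be rewritten as a gather through the inverse.
rotate, unrotate :: Int -> Int -> Int
rotate n k = (k + 1) `mod` n
unrotate n j = (j - 1) `mod` n
```
For example, `scatterVec 4 (rotate 4) [10, 20, 30, 40]` and `gatherVec 4 (unrotate 4) [10, 20, 30, 40]` both evaluate to `[40.0, 10.0, 20.0, 30.0]`. With a non-injective map such as `const 0`, `scatterVec` sums the colliding writes into `[100.0, 0.0, 0.0, 0.0]`, and no single-source gather reproduces that.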
So since, as we discussed over email, we're going to need `let` anyway, what about the following?
```
build n (\i -> index (scatter sh v f) ix)
~>
let a = scatter (n ::: sh) (build n (\l -> v)) (\(j1 ::: js) -> j1 ::: f js)
in gather n a (\i -> ix)
```
I hope I got that index arithmetic right. Code size increases somewhat, but with a bounded growth factor.
I managed to parse your proposal and fixed some minor scope-checking problems (and simplified the list notation and renamed `build` to `build1` and `gather` to `gather1`):
```
build1 n (\i -> index (scatter sh v f) ix)
~>
let a = scatter (n : sh) (build1 n (\i -> v)) (\(i : js) -> i : f js)
in gather1 n a (\i -> ix)
```
Let me know if the fix is spurious, because that would mean I misunderstand our formalism.
Then I mechanically applied the `scatter` vectorization rule (see the Overleaf) backwards
```
build1 k (var, scatter sh t (vars, ix)) -->
scatter (k : sh) (build1 k (var, t)) (var : vars, var : ix)
```
and I arrived at
```
build1 n (\i -> index (scatter sh v f) ix) -->
let a = build1 n (\i -> scatter sh v f)
in gather1 n a (\i -> ix)
```
which can be distilled to rule (D0)
```
build1 n (\i -> index t ix) -->
gather1 n (build1 n (\i -> t)) (\i -> ix)
```
This looks valid and reminds me of the gross inefficiency of vectorizing `fromList`, so let me try to distil a similar rule from the two vectorization rules for `fromList` (see the Overleaf). Apologies for mixing in the pair notation instead of sticking to the visible lambda notation. Let's start with the rule containing indexing:
```
build1 k (var, index (fromList [t1 .. tn]) [i]) -->
gather (k : tshape t1)
       (fromList [build1 k (var, t1) .. build1 k (var, tn)])
       ([var], [i, var])
```
and apply the other rule backwards
```
build1 k (var, fromList [t1 .. tn]) -->
tr (fromList [build1 k (var, t1) .. build1 k (var, tn)])
```
where `tr == transpose {0<-1, 1<-0}`, which is its own inverse, obtaining
```
build1 k (var, index (fromList [t1 .. tn]) [i]) -->
gather (k : tshape t1)
       (tr (build1 k (var, fromList [t1 .. tn])))
       ([var], [i, var])
```
and distil again
```
build1 k (var, index t [i]) -->
gather1 k (tr (build1 k (var, t))) (var, [i, var])
```
and present it in the visible lambda notation, with some renaming
```
build1 n (\i -> index t [i2]) -->
gather1 n (tr (build1 n (\i -> t))) (\i -> [i2, i])
```
So close to rule (D0)
```
build1 n (\i -> index t ix) -->
gather1 n (build1 n (\i -> t)) (\i -> ix)
```
and so far at the same time. But there is a simple rule for fusion of a long `gather` with a short `transpose`, not needed for the vectorization rewriting system, so only present in the code:
```
AstGather ...
  AstTranspose perm v | valueOf @p' >= length perm ->
    AstGather sh4 v (vars4, permutePrefixIndex perm ix4)
```
and after applying it, we obtain
```
build1 n (\i -> index t [i2]) -->
gather1 n (build1 n (\i -> t)) (\i -> [i, i2])
```
which is almost the same as (D0), and the remaining difference must indicate a bug, one that a shape checker (and even a rank checker) would catch. Let me venture that rule (D0) should be amended to rule (D)
```
build1 n (\i -> index t ix) -->
gather1 n (build1 n (\i -> t)) (\i -> i : ix)
```
and this fix should be propagated upwards to the other rules. If that's correct, your solution generalizes and we've got ourselves a new vectorization rule (D) (instead of a suitable simplification rule, as I originally hoped for, which probably doesn't exist). That's a rule of last resort, so the implementation should apply it only when stuck. However, perhaps it eliminates the need for any simplification rules in our presentation, if we are content with presenting an inefficient vectorization procedure and only mentioning that a different procedure with simplification rules fares much better and falls back to (D) only in some cases of `scatter` and `fromList`.
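The derivation of (D) can be sanity-checked numerically in a toy model of pull arrays, where an array is a shape plus an indexing function. This is a sketch under that assumption; `Arr`, `build1`, `index`, `gather1` and `tr` here are hypothetical stand-ins, not horde-ad's AST or API. It compares `build1 n (\i -> index t [i2])` with the transposed form distilled from the `fromList` rules, with rule (D) and with rule (D0), letting both `t` and `i2` depend on `i`.
```haskell
-- Toy pull arrays (hypothetical model, not horde-ad's AST):
-- a shape together with an indexing function.
data Arr = Arr {shape :: [Int], at :: [Int] -> Double}

-- All index tuples of a shape, in row-major order.
indices :: [Int] -> [[Int]]
indices = mapM (\len -> [0 .. len - 1])

-- Two arrays are equal if their shapes and all their elements agree.
sameArr :: Arr -> Arr -> Bool
sameArr a b =
  shape a == shape b
    && map (at a) (indices (shape a)) == map (at b) (indices (shape b))

-- build1 len body stacks body 0 .. body (len - 1), assumed to share a shape.
build1 :: Int -> (Int -> Arr) -> Arr
build1 len body = Arr (len : shape (body 0)) go
  where
    go (i : is) = at (body i) is
    go [] = error "build1: index too short"

-- index arr idx drops the outermost (length idx) dimensions.
index :: Arr -> [Int] -> Arr
index arr idx = Arr (drop (length idx) (shape arr)) (\is -> at arr (idx ++ is))

-- gather1 len arr fn: element i is arr indexed by fn i (all fn i equally long).
gather1 :: Int -> Arr -> (Int -> [Int]) -> Arr
gather1 len arr fn = Arr (len : drop (length (fn 0)) (shape arr)) go
  where
    go (i : is) = at arr (fn i ++ is)
    go [] = error "gather1: index too short"

-- tr swaps the two outermost dimensions, i.e. transpose {0<-1, 1<-0}.
tr :: Arr -> Arr
tr arr = Arr (swap2 (shape arr)) (at arr . swap2)
  where
    swap2 (a : b : rest) = b : a : rest
    swap2 _ = error "tr: rank below 2"

-- Example: t i has shape [3, 2] and depends on the build variable i;
-- i2 i indexes the outermost dimension of t i.
n :: Int
n = 4

t :: Int -> Arr
t i = Arr [3, 2] (\is -> fromIntegral (100 * i + 10 * (is !! 0) + is !! 1))

i2 :: Int -> Int
i2 i = (2 * i) `mod` 3

-- build1 n (\i -> index t [i2])
lhs :: Arr
lhs = build1 n (\i -> index (t i) [i2 i])

-- Result distilled from the fromList rules, before gather/transpose fusion.
viaTranspose :: Arr
viaTranspose = gather1 n (tr (build1 n t)) (\i -> [i2 i, i])

-- Rule (D): the build variable is prepended to the index.
ruleD :: Arr
ruleD = gather1 n (build1 n t) (\i -> i : [i2 i])

-- Rule (D0): the index is used unchanged, which yields the wrong rank.
ruleD0 :: Arr
ruleD0 = gather1 n (build1 n t) (\i -> [i2 i])
```
In GHCi, `sameArr lhs viaTranspose` and `sameArr lhs ruleD` are `True`, while `shape ruleD0` is `[4, 3, 2]` against `[4, 2]` for `lhs`: exactly the rank mismatch a rank checker would report.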
Interesting! Yes, all you say looks correct, including your fix to (D0) to produce (D). Indeed, I did not get the indices completely right :)
It would be interesting if this eliminates the need for simplification rules. And if it does, I'd be interested to see whether there is a decent set of simplification rules that can be applied after vectorisation (with eager rule (D)) that in the end has the same effect as performing simplification during vectorisation (with reluctant rule (D)). If such a set exists, presenting the algorithm with that set almost certainly results in a nicer presentation.
Let's try eager (D) and simplifying afterwards (apologies for the temporary pairs instead of lambdas). To vectorize and simplify `fromList` we need (D), the simple vectorization rule
```
build1 k (var, fromList [t1 .. tn]) -->
tr (fromList [build1 k (var, t1) .. build1 k (var, tn)])
```
and, additionally, the gather-of-transpose rule from above
```
AstGather ...
  AstTranspose perm v | valueOf @p' >= length perm ->
    AstGather sh4 v (vars4, permutePrefixIndex perm ix4)
```
This lets us mimic the existing "hard part" vectorization rule
```
build1 k (var, index (fromList [t1 .. tn]) [i]) -->
gather (k : tshape t1)
       (fromList [build1 k (var, t1) .. build1 k (var, tn)])
       ([var], [i, var])
```
which seems to offer the least expensive result with `build1` pushed down. I think the general case with `ix` in place of `[i]` would work the same and can subsequently be simplified by an existing `gather` simplification rule that pushes some indexes of the gather inside `fromList`, where they become new `gather` terms, except for the first index, which needs to stay in the outermost `gather`.
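The `fromList` case can be spot-checked for a single index `[i]` in the same kind of toy pull-array model (a sketch; all names, including `stacked` for `fromList [build1 k (var, t1) .. build1 k (var, tn)]`, are hypothetical stand-ins, not horde-ad's API). It checks the simple `fromList` rule, then that eager (D) followed by that rule leaves a gather with index `[var, i]` under `tr`, and that this agrees with the "hard part" rule's gather with index `[i, var]`, as gather/transpose fusion predicts.
```haskell
-- Toy pull arrays (hypothetical model): a shape plus an indexing function.
data Arr = Arr {shape :: [Int], at :: [Int] -> Double}

-- All index tuples of a shape, in row-major order.
indices :: [Int] -> [[Int]]
indices = mapM (\len -> [0 .. len - 1])

-- Two arrays are equal if their shapes and all their elements agree.
sameArr :: Arr -> Arr -> Bool
sameArr a b =
  shape a == shape b
    && map (at a) (indices (shape a)) == map (at b) (indices (shape b))

-- build1 len body stacks body 0 .. body (len - 1), assumed to share a shape.
build1 :: Int -> (Int -> Arr) -> Arr
build1 len body = Arr (len : shape (body 0)) go
  where
    go (i : is) = at (body i) is
    go [] = error "build1: index too short"

-- index arr idx drops the outermost (length idx) dimensions.
index :: Arr -> [Int] -> Arr
index arr idx = Arr (drop (length idx) (shape arr)) (\is -> at arr (idx ++ is))

-- gather1 len arr fn: element i is arr indexed by fn i (all fn i equally long).
gather1 :: Int -> Arr -> (Int -> [Int]) -> Arr
gather1 len arr fn = Arr (len : drop (length (fn 0)) (shape arr)) go
  where
    go (i : is) = at arr (fn i ++ is)
    go [] = error "gather1: index too short"

-- tr swaps the two outermost dimensions, i.e. transpose {0<-1, 1<-0}.
tr :: Arr -> Arr
tr arr = Arr (swap2 (shape arr)) (at arr . swap2)
  where
    swap2 (a : b : rest) = b : a : rest
    swap2 _ = error "tr: rank below 2"

-- fromList stacks equally shaped arrays along a new outermost dimension.
fromList :: [Arr] -> Arr
fromList arrs = case arrs of
    [] -> error "fromList: empty list"
    (a0 : _) -> Arr (length arrs : shape a0) go
  where
    go (j : is) = at (arrs !! j) is
    go [] = error "fromList: index too short"

-- Example: t1 .. t3, each of shape [2], depend on the build variable var.
k :: Int
k = 4

ts :: [Int -> Arr]
ts =
  [ \var -> Arr [2] (\is -> fromIntegral (100 * j + 10 * var + is !! 0))
  | j <- [0 .. 2 :: Int]
  ]

-- The single index i of the rule; it may depend on var.
iOf :: Int -> Int
iOf var = (var + 1) `mod` 3

-- fromList [build1 k (var, t1) .. build1 k (var, tn)]
stacked :: Arr
stacked = fromList [build1 k tj | tj <- ts]

-- build1 k (var, index (fromList [t1 .. tn]) [i])
lhs :: Arr
lhs = build1 k (\var -> index (fromList (map ($ var) ts)) [iOf var])

-- The "hard part" rule's right-hand side: index [i, var].
hardPart :: Arr
hardPart = gather1 k stacked (\var -> [iOf var, var])

-- Eager (D) plus the simple fromList rule: index [var, i] under tr.
eagerD :: Arr
eagerD = gather1 k (tr stacked) (\var -> [var, iOf var])

-- Both sides of the simple fromList vectorization rule.
simpleLhs, simpleRhs :: Arr
simpleLhs = build1 k (\var -> fromList (map ($ var) ts))
simpleRhs = tr stacked
```
Here `k = 4` differs from the number of list elements (3), so swapping the two indices of a gather would not pass unnoticed.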
For `scatter`, let's start with
```
build1 n (\i -> index (scatter sh v f) ix)
```
With (D) we get
```
gather1 n (build1 n (\i -> scatter sh v f)) (\i -> i : ix)
```
with the simple vectorization rule
```
build1 k (var, scatter sh t (vars, ix)) -->
scatter (k : sh) (build1 k (var, t)) (var : vars, var : ix)
```
we obtain
```
gather1 n (scatter (n : sh) (build1 n (\i -> v)) (\(i : is) -> i : f is)) (\i -> i : ix)
```
which is the same as the fixed version of the result of your original scatter vectorization rule
```
let a = scatter (n : sh) (build1 n (\i -> v)) (\(i : js) -> i : f js)
in gather1 n a (\i -> i : ix)
```
and no extra rules are required to simplify anything further.
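The claim that eager (D) plus the simple `scatter` rule reproduces the fixed `let`-based rewrite can be spot-checked in a toy model. Assumptions: arrays are pull arrays (a shape plus an indexing function), every name below is a hypothetical stand-in rather than horde-ad's API, and the toy `scatter` maps whole indices instead of index prefixes. `v`, `f` and `ix` are all allowed to depend on the build variable `i`.
```haskell
import qualified Data.Map.Strict as M

-- Toy pull arrays (hypothetical model): a shape plus an indexing function.
data Arr = Arr {shape :: [Int], at :: [Int] -> Double}

-- All index tuples of a shape, in row-major order.
indices :: [Int] -> [[Int]]
indices = mapM (\len -> [0 .. len - 1])

-- Two arrays are equal if their shapes and all their elements agree.
sameArr :: Arr -> Arr -> Bool
sameArr a b =
  shape a == shape b
    && map (at a) (indices (shape a)) == map (at b) (indices (shape b))

-- build1 len body stacks body 0 .. body (len - 1), assumed to share a shape.
build1 :: Int -> (Int -> Arr) -> Arr
build1 len body = Arr (len : shape (body 0)) go
  where
    go (i : is) = at (body i) is
    go [] = error "build1: index too short"

-- index arr idx drops the outermost (length idx) dimensions.
index :: Arr -> [Int] -> Arr
index arr idx = Arr (drop (length idx) (shape arr)) (\is -> at arr (idx ++ is))

-- gather1 len arr fn: element i is arr indexed by fn i (all fn i equally long).
gather1 :: Int -> Arr -> (Int -> [Int]) -> Arr
gather1 len arr fn = Arr (len : drop (length (fn 0)) (shape arr)) go
  where
    go (i : is) = at arr (fn i ++ is)
    go [] = error "gather1: index too short"

-- Toy scatter: the element of src at index js is added at index fn js of a
-- zero array of shape shOut (whole indices, an assumption of this sketch);
-- writes landing outside shOut are simply never read.
scatter :: [Int] -> Arr -> ([Int] -> [Int]) -> Arr
scatter shOut src fn = Arr shOut (\is -> M.findWithDefault 0 is writes)
  where
    writes = M.fromListWith (+) [(fn js, at src js) | js <- indices (shape src)]

-- Example data; ks is a one-element index, so sum ks is that element.
n :: Int
n = 3

sh :: [Int]
sh = [4, 2]

v :: Int -> Arr
v i = Arr [5] (\ks -> fromIntegral (10 * i + sum ks))

f :: Int -> [Int] -> [Int]
f i ks = [(sum ks + i) `mod` 4, sum ks `mod` 2]

ix :: Int -> [Int]
ix i = [i `mod` 4]

-- build1 n (\i -> index (scatter sh v f) ix)
lhs :: Arr
lhs = build1 n (\i -> index (scatter sh (v i) (f i)) (ix i))

-- gather1 n (scatter (n : sh) (build1 n (\i -> v)) (\(i : is) -> i : f is)) (\i -> i : ix)
rhs :: Arr
rhs = gather1 n (scatter (n : sh) (build1 n v) fn') (\i -> i : ix i)
  where
    fn' (i : is) = i : f i is
    fn' [] = error "fn': index too short"
```
Evaluating `sameArr lhs rhs` gives `True`; for instance, both sides hold 24.0 at index `[1, 0]`, where the writes 10 and 14 of `v 1` collide.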
Let's now try the first of the "hard part" vectorization rules for the remaining normal forms of simplification
```
build1 k (var, index x ix) -->
gather (k : tshape (index x ix)) x ([var], ix)
```
The result of (D) applied to the left-hand side is (mixing up variable names)
```
gather1 n (build1 n (\i -> x)) (\i -> i : ix)
```
and simple vectorization turns this into
```
gather1 n (konst n x) (\i -> i : ix)
```
To proceed, we require the simplification rule for gather of `konst`
```
AstKonst _k v -> astGather sh4 v (vars4, rest4)
```
which turns this into
```
gather1 n x (\i -> ix)
```
which is indeed equal to
```
gather (k : tshape (index x ix)) x ([var], ix)
```
from above.
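The `konst` step can likewise be spot-checked in a toy pull-array model (a sketch; `Arr`, `build1`, `index`, `gather1` and `konst` are hypothetical stand-ins, not horde-ad's API). Because `x` does not mention the build variable, `build1 n (\i -> x)` is `konst n x`, and gathering from `konst n x` with index `i : ix` reads the same elements as gathering from `x` with `ix`.
```haskell
-- Toy pull arrays (hypothetical model): a shape plus an indexing function.
data Arr = Arr {shape :: [Int], at :: [Int] -> Double}

-- All index tuples of a shape, in row-major order.
indices :: [Int] -> [[Int]]
indices = mapM (\len -> [0 .. len - 1])

-- Two arrays are equal if their shapes and all their elements agree.
sameArr :: Arr -> Arr -> Bool
sameArr a b =
  shape a == shape b
    && map (at a) (indices (shape a)) == map (at b) (indices (shape b))

-- build1 len body stacks body 0 .. body (len - 1), assumed to share a shape.
build1 :: Int -> (Int -> Arr) -> Arr
build1 len body = Arr (len : shape (body 0)) go
  where
    go (i : is) = at (body i) is
    go [] = error "build1: index too short"

-- index arr idx drops the outermost (length idx) dimensions.
index :: Arr -> [Int] -> Arr
index arr idx = Arr (drop (length idx) (shape arr)) (\is -> at arr (idx ++ is))

-- gather1 len arr fn: element i is arr indexed by fn i (all fn i equally long).
gather1 :: Int -> Arr -> (Int -> [Int]) -> Arr
gather1 len arr fn = Arr (len : drop (length (fn 0)) (shape arr)) go
  where
    go (i : is) = at arr (fn i ++ is)
    go [] = error "gather1: index too short"

-- konst len arr: len copies of arr, i.e. build1 len (\_ -> arr).
konst :: Int -> Arr -> Arr
konst len arr = Arr (len : shape arr) go
  where
    go (_ : is) = at arr is
    go [] = error "konst: index too short"

-- Example: x does not mention the build variable; ix does.
n :: Int
n = 5

x :: Arr
x = Arr [3, 2] (\is -> fromIntegral (10 * (is !! 0) + is !! 1))

ix :: Int -> [Int]
ix i = [i `mod` 3]

-- Left-hand side of the "hard part" rule: build1 n (\i -> index x ix)
lhs :: Arr
lhs = build1 n (\i -> index x (ix i))

-- After eager (D) and simple vectorization: gather1 n (konst n x) (\i -> i : ix)
afterD :: Arr
afterD = gather1 n (konst n x) (\i -> i : ix i)

-- After the gather-of-konst simplification: gather1 n x (\i -> ix)
afterKonst :: Arr
afterKonst = gather1 n x ix
```
All three of `lhs`, `afterD` and `afterKonst` have shape `[5, 2]` and agree elementwise; e.g. `at afterKonst [4, 1]` is 11.0.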
Let's try one more case, where we need to simplify `index(gather)` in order to apply the simple vectorization rule for `gather`:
```
build1 n (\i -> gather sh v (\vars -> ix)) -->
gather (n : sh) (build1 n (\i -> v)) (\(i : vars) -> i : ix)
```
Let's assume we start with
```
build1 n (\i -> index (gather sh v (\(var : vars) -> ix)) [m])
```
which, via indexing simplification rules, easily reduces to the left-hand side (potentially with some substitutions applied, but never mind that), which then results in the right-hand side. Rule (D), OTOH, gives
```
gather1 n (build1 n (\i -> gather sh v (\(var : vars) -> ix))) (\i -> i : [m])
```
and the simple vectorization rule for `gather` reduces the inner `build1`, giving
```
gather1 n (gather (n : sh) (build1 n (\i -> v)) (\(i : var : vars) -> i : ix)) (\i -> i : [m])
```
and it surely all ends well, but it's obvious by this point that to resolve this we need fusion of gathers.
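For the one-variable case, fusing two nested gathers amounts to composing their index functions. The sketch below assumes the same toy pull-array model as a stand-in (hypothetical names, not horde-ad's API) and only handles `gather1` with a non-empty outer index; the inner gather above binds several variables, which would need the same idea generalised.
```haskell
-- Toy pull arrays (hypothetical model): a shape plus an indexing function.
data Arr = Arr {shape :: [Int], at :: [Int] -> Double}

-- All index tuples of a shape, in row-major order.
indices :: [Int] -> [[Int]]
indices = mapM (\len -> [0 .. len - 1])

-- Two arrays are equal if their shapes and all their elements agree.
sameArr :: Arr -> Arr -> Bool
sameArr a b =
  shape a == shape b
    && map (at a) (indices (shape a)) == map (at b) (indices (shape b))

-- gather1 len arr fn: element i is arr indexed by fn i (all fn i equally long).
gather1 :: Int -> Arr -> (Int -> [Int]) -> Arr
gather1 len arr fn = Arr (len : drop (length (fn 0)) (shape arr)) go
  where
    go (i : is) = at arr (fn i ++ is)
    go [] = error "gather1: index too short"

-- Fuse gather1 len (gather1 m arr g) h into a single gather1 from arr.
-- Assumption: every outer index h i is non-empty; its head selects an
-- element of the inner gather and its tail continues into arr.
fuseGather1 :: Int -> Arr -> (Int -> [Int]) -> (Int -> [Int]) -> Arr
fuseGather1 len arr g h = gather1 len arr gh
  where
    gh i = case h i of
      (j : rest) -> g j ++ rest
      [] -> error "fuseGather1: empty outer index"

-- Example: u has shape [4, 3, 2].
u :: Arr
u = Arr [4, 3, 2] (\is -> fromIntegral (100 * (is !! 0) + 10 * (is !! 1) + is !! 2))

inner :: Int -> [Int]
inner j = [j `mod` 4]

outer :: Int -> [Int]
outer i = [i `mod` 5, i `mod` 3]

unfused :: Arr
unfused = gather1 6 (gather1 5 u inner) outer

fused :: Arr
fused = fuseGather1 6 u inner outer
```
Here `sameArr unfused fused` is `True`, and both have shape `[6, 2]`.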
Concluding, it seems we have two alternative sets of rules that complement the simple vectorization rules
I'm not 100% sure these sets are disjoint (e.g., in the implementation, some gather simplification rules do "call" indexing simplification rules, but I don't know if these are just shortcuts or if this affects the final results and, if the latter, whether the extra simplification is required to expose redexes for further vectorization or only optional). In general, gather simplification rules are several times more complex than indexing simplification rules.
There's also the third variant you mentioned, "reluctant rule (D)", which consists of simple vectorization, indexing simplification and rule D restricted to the normal forms of simplification (either by instantiating the rule to these forms or by side conditions to the rule). It's most probably the shortest rule set of the three. I will try to implement it (with the usual tweak that I simplify reluctantly, as well) and then, if tests are positive, we could base the first attempt at the Overleaf presentation on that.
Thoughts?
And indeed, the code got shorter, `scatter` is now vectorized, and limited testing shows no problems with the "reluctant rule (D)" method.
https://github.com/Mikolaj/horde-ad/commit/5e5c763d1a309ef2e35a38d533069eb84705d34d
Unfortunately, I can't get even a portion of the indexing inside `scatter`, unlike with `fromList`, so the new normal form is the fully general `index (scatter ...) ix`, except for `ix == ZI`. I also don't fuse `scatter` with anything yet, but there's no rush with fusion; it's not needed for vectorization.
Another success.
I've written it all down in the Overleaf as best I could. I'm leaving `fromInt` untouched until we decide in https://github.com/Mikolaj/horde-ad/issues/91 whether it should be dropped or not.
The easier half is already done. The harder half is simplifying indexing of scatter to expose a constructor that vectorization can work on (it can't work on an indexing term).
If it's helpful,
```
sum v = scatter sh v (\(_i :. ZI) -> ZI)
```
and we can simplify indexing of `sum` just fine. However, the problem is in the places where the indexes are empty for `sum`.