Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Vectorize AstScatter #89

Closed Mikolaj closed 1 year ago

Mikolaj commented 1 year ago

The easier half is already done. The harder half is simplifying indexing of scatter to expose a constructor that vectorization can work on (it can't work on an indexing term).

If it's helpful, sum v = scatter sh v (\(_i :. ZI) -> ZI) and we can simplify indexing of sum just fine. However, the problem is in the places where the indexes are empty for sum.
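To make the sum-as-scatter identity concrete, here is a minimal sketch in a toy model. It is not the horde-ad API: the names Arr, indices, scatter and sumViaScatter are hypothetical, tensors are modelled as a shape plus an indexing function, and scatter is simplified to work on full (scalar-level) indices only.

```haskell
type Ix = [Int]
type Sh = [Int]

-- Toy tensor (an assumption of this sketch): a shape plus an indexing function.
data Arr = Arr { shape :: Sh, at :: Ix -> Double }

-- All full indices of a shape, in row-major order; the empty shape has one index, [].
indices :: Sh -> [Ix]
indices = mapM (\n -> [0 .. n - 1])

-- scatter sh v f: every element v[i] is added at position f i of a
-- zero-initialised tensor of shape sh; positions outside sh are dropped.
scatter :: Sh -> Arr -> (Ix -> Ix) -> Arr
scatter sh v f = Arr sh (\j -> sum [at v i | i <- indices (shape v), f i == j])

-- Summing a vector: every index is sent to the empty index of a rank-0 result.
sumViaScatter :: Arr -> Arr
sumViaScatter v = scatter [] v (const [])

-- A rank-1 example with elements 1, 2, 3, 4.
example :: Arr
example = Arr [4] (\[i] -> fromIntegral i + 1)
```

In GHCi, at (sumViaScatter example) [] evaluates the single element of the rank-0 result.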

tomsmeding commented 1 year ago

So you want to simplify index (scatter sh v f) ix to something that doesn't have index at the head?

This by definition involves inverting f (more precisely: computing preimages), which is not computable in general. :)

Some fs you can invert, like the one for sum: the preimage of ZI is simply the whole of sh. This is why sum works. More generally, if you can invert f, then you should be able to reduce index-of-scatter to something that has either a smaller indexing operation or the sum of a gather, I think. Not sure what happens if ix is shorter than the codomain of f.

tomsmeding commented 1 year ago

More fundamentally, I think the problem you're running into here is push arrays vs pull arrays, or producers vs consumers. Everything expressed in terms of build is a consumer, and reductions like fold and scan are also consumers. Importantly, indexing is also a consumer. Consumers fuse into consumers, hence indexing into a pull array yields another pull array.

However, scatter is a producer, yielding a push array. Producers do not fuse into consumers, because that involves inverting data flow direction of the producer. Hence reducing index-of-scatter is hard.
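The asymmetry can be sketched on one-dimensional lists. This is only an illustration with hypothetical names (gather1D, scatter1D, indexGather1D, indexScatter1D), not library code: indexing a gather needs a single lookup, while indexing a scatter needs either the materialised result or a search through the whole source for the preimage of the requested position.

```haskell
-- Pull style: each output position asks for the input element it needs.
gather1D :: Int -> [Double] -> (Int -> Int) -> [Double]
gather1D n v f = [v !! f j | j <- [0 .. n - 1]]

-- Push style: each input element decides where it is added.
-- Elements sent outside [0, n) are silently dropped.
scatter1D :: Int -> [Double] -> (Int -> Int) -> [Double]
scatter1D n v f = foldl push (replicate n 0) (zip [0 ..] v)
  where
    push acc (i, x) = [if j == f i then a + x else a | (j, a) <- zip [0 ..] acc]

-- Indexing fuses into a gather: one lookup, no intermediate array.
indexGather1D :: [Double] -> (Int -> Int) -> Int -> Double
indexGather1D v f j = v !! f j

-- Indexing a scatter without materialising it means computing the preimage
-- of j under f, here by brute force over every source position.
indexScatter1D :: [Double] -> (Int -> Int) -> Int -> Double
indexScatter1D v f j = sum [x | (i, x) <- zip [0 ..] v, f i == j]
```

The brute-force preimage is only possible because the source is finite, and it costs a full pass over the source for every index requested.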

Mikolaj commented 1 year ago

Huh, that probably means I need to remove scatter from the user-accessible part of the Tensor class. It also means my ability to simplify the code emitted by AD is limited: it gets stuck at scatter, which appears where there was gather in the (vectorized form of the) objective function. E.g., normally I can fuse any gather contained inside any other gather. Now I can't if there's a scatter in the way.

Assuming we ban scatter, I wonder if there are any other special cases (other than sum) that are both vectorizable and useful.

Alternatively to banning scatter, we could ban build inside scatter or just not vectorize some builds inside some scatters.

Or perhaps, for producers, there is a way to vectorize inside index(producer) without eliminating the indexing. With rewriting systems, if you are stuck reducing, you make the problematic form a normal form and see if this blocks all other reduction rules or if they can be extended to operate on the new normal form specially. We already have special normal forms for fromList, both wrapped in index and wrapped in gather. I wonder if fromList is a producer.

tomsmeding commented 1 year ago

Really, going from push-style to pull-style (i.e. scattering and then indexing into the result) requires you, operationally, to materialise the intermediate array. You can't get around that, unless you can feasibly invert the index mapping function f and turn the scatter into a gather in the first place.

So since, as we discussed over email, we're going to need let anyway, what about the following?

build n (\i -> index (scatter sh v f) ix)
~>
let a = scatter (n ::: sh) (build n (\l -> v)) (\(j1 ::: js) -> j1 ::: f js)
in gather n a (\i -> ix)

I hope I got that index arithmetic right. Code size increases somewhat, but with a bounded growth factor.

Mikolaj commented 1 year ago

I managed to parse your proposal and I fixed some minor scope-checking problems (and simplified the list notation and renamed build to build1 and gather to gather1):

build1 n (\i -> index (scatter sh v f) ix)
~>
let a = scatter (n : sh) (build1 n (\i -> v)) (\(i : js) -> i : f js)
in gather1 n a (\i -> ix)

Let me know if the fix is spurious, because that would mean I misunderstand our formalism.

Then I mechanically applied backwards the scatter vectorization rule (see the Overleaf)

build1 k (var, scatter sh t (vars, ix)) -->
  scatter (k : sh) (build1 k (var, t)) (var : vars, var : ix)

and I arrived at

build1 n (\i -> index (scatter sh v f) ix) -->
  let a = build1 n (\i -> scatter sh v f)
  in gather1 n a (\i -> ix)

which can be distilled to rule (D0)

build1 n (\i -> index t ix) -->
  gather1 n (build1 n (\i -> t)) (\i -> ix)

This looks valid and reminds me of the gross inefficiency of vectorizing fromList, so let me try to distil a similar rule from the two vectorization rules for fromList, see the Overleaf. Apologies for mixing in the pair notation instead of sticking to the visible lambda notation. Let's start with the rule containing indexing

build1 k (var, index (fromList [t1 .. tn]) [i]) -->
  gather (k : tshape t1)
         (fromList [build1 k (var, t1) .. build1 k (var, tn)])
         ([var], [i, var])

and apply backwards the other rule

build1 k (var, fromList [t1 .. tn]) -->
  tr (fromList [build1 k (var, t1) .. build1 k (var, tn)])

where tr == transpose {0<-1, 1<-0} and is its own inverse, obtaining

build1 k (var, index (fromList [t1 .. tn]) [i]) -->
  gather (k : tshape t1)
         (tr (build1 k (var, fromList [t1 .. tn])))
         ([var], [i, var])

and distil again

build1 k (var, index t [i]) -->
  gather1 k (tr (build1 k (var, t))) (var, [i, var])

and present in the visible lambda notation, with some renaming

build1 n (\i -> index t [i2]) -->
  gather1 n (tr (build1 n (\i -> t))) (\i -> [i2, i])

So close to rule (D0)

build1 n (\i -> index t ix) -->
  gather1 n (build1 n (\i -> t)) (\i -> ix)

and so far at the same time. But there is a simple rule for fusion of a long gather with a short transpose, not needed for the vectorization rewriting system, so only present in the code

AstGather ...
  AstTranspose perm v | valueOf @p' >= length perm ->
    AstGather sh4 v (vars4, permutePrefixIndex perm ix4)

and after applying it, we obtain

build1 n (\i -> index t [i2]) -->
  gather1 n (build1 n (\i -> t)) (\i -> [i, i2])

which is almost the same and must indicate a bug, and one that a shape-checker (and even a rank-checker) would catch. Let me venture that rule (D0) should be amended to rule (D)

build1 n (\i -> index t ix) -->
  gather1 n (build1 n (\i -> t)) (\i -> i : ix)

and this fix should be propagated upwards to the other rules. If that's correct, your solution generalizes and we've got ourselves a new vectorization rule (D) (instead of a suitable simplification rule, as I originally hoped, which probably doesn't exist). That's a rule of last resort, so an implementation should apply it only when stuck. However, perhaps it eliminates the need for any simplification rules in our presentation, if we are content with presenting an inefficient vectorization procedure and only mention that a different procedure with simplification rules fares much better and falls back to (D) only in some cases of scatter and fromList.
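The difference between (D0) and (D) can be checked in a small executable model. This is a sketch under assumptions, not the horde-ad implementation: Arr, indices, elemsOf and the example values t and ix are hypothetical, tensors are a shape plus an indexing function, and the dependence of t and ix on the build variable i is made explicit by passing them as functions of i.

```haskell
type Ix = [Int]
type Sh = [Int]

-- Toy tensor (an assumption of this sketch): a shape plus an indexing function.
data Arr = Arr { shape :: Sh, at :: Ix -> Double }

-- All full indices of a shape, in row-major order.
indices :: Sh -> [Ix]
indices = mapM (\n -> [0 .. n - 1])

elemsOf :: Arr -> [Double]
elemsOf a = map (at a) (indices (shape a))

-- Stack n tensors along a new outermost dimension (assumes n > 0).
build1 :: Int -> (Int -> Arr) -> Arr
build1 n g = Arr (n : shape (g 0)) (\(i : rest) -> at (g i) rest)

-- Index with a (possibly partial) prefix index.
index :: Arr -> Ix -> Arr
index a ix' = Arr (drop (length ix') (shape a)) (\rest -> at a (ix' ++ rest))

gather1 :: Int -> Arr -> (Int -> Ix) -> Arr
gather1 n a f = build1 n (\i -> index a (f i))

-- Example term and index, both depending on the build variable i.
t :: Int -> Arr
t i = Arr [2, 3] (\[a, b] -> fromIntegral (100 * i + 10 * a + b))

ix :: Int -> Ix
ix i = [i `mod` 2]

lhs, viaD0, viaD :: Arr
lhs   = build1 4 (\i -> index (t i) (ix i))      -- left-hand side of both rules
viaD0 = gather1 4 (build1 4 t) ix                -- rule (D0): index lacks the leading i
viaD  = gather1 4 (build1 4 t) (\i -> i : ix i)  -- rule (D)
```

In this model, lhs and viaD have shape [4, 3] and the same elements, while viaD0 has shape [4, 2, 3], which is the kind of mismatch a shape-checker or rank-checker would flag.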

tomsmeding commented 1 year ago

Interesting! Yes, all you say looks correct, including your fix to (D0) to produce (D). Indeed, I did not get the indices completely right :)

It would be interesting if this eliminates the need for simplification rules. And if it does, I'd be interested to see whether there is a decent set of simplification rules that can be applied after vectorisation (with eager rule (D)) that in the end has the same effect as performing simplification during vectorisation (with reluctant rule (D)). If such a set exists, presenting the algorithm with that set almost certainly results in a nicer presentation.

Mikolaj commented 1 year ago

Let's try eager (D) and simplifying afterwards (apologies for the temporary pairs instead of lambdas). To vectorize and simplify fromList we need (D), the simple vectorization rule

build1 k (var, fromList [t1 .. tn]) -->
  tr (fromList [build1 k (var, t1) .. build1 k (var, tn)])

and additionally the gather of transpose rule from above

AstGather ...
  AstTranspose perm v | valueOf @p' >= length perm ->
    AstGather sh4 v (vars4, permutePrefixIndex perm ix4)

This lets us mimic the existing "hard part" vectorization rule

build1 k (var, index (fromList [t1 .. tn]) [i]) -->
  gather (k : tshape t1)
         (fromList [build1 k (var, t1) .. build1 k (var, tn)])
         ([var], [i, var])

which seems to offer the least expensive result with build1 pushed down. I think the general case with ix in place of [i] would work the same and can be subsequently simplified by an existing gather simplification rule that pushes some indexes of gather inside fromList, where they become new gather terms, except for the first index, which needs to stay in the outermost gather.

For scatter, let's start with

build1 n (\i -> index (scatter sh v f) ix)

With (D) we are getting

gather1 n (build1 n (\i -> (scatter sh v f))) (\i -> i : ix)

with the simple vectorization rule

build1 k (var, scatter sh t (vars, ix)) -->
  scatter (k : sh) (build1 k (var, t)) (var : vars, var : ix)

we obtain

gather1 n (scatter (n : sh) (build1 n (\i -> v)) (\(i : is) -> i : f is)) (\i -> i : ix)

which is the same as the fixed version of the result of your original scatter vectorization rule

    let a = scatter (n : sh) (build1 n (\i -> v)) (\(i : js) -> i : f js)
    in gather1 n a (\i -> i : ix)

and no extra rules are required to simplify anything further.
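As a sanity check, the two sides of this scatter derivation can be compared in a toy model. Everything here is an assumption of the sketch rather than library code: Arr, indices, elemsOf and the example values v, f and ix are hypothetical, scatter works on full (scalar-level) indices only, and the dependence of v, f and ix on the build variable i is made explicit.

```haskell
type Ix = [Int]
type Sh = [Int]

-- Toy tensor (an assumption of this sketch): a shape plus an indexing function.
data Arr = Arr { shape :: Sh, at :: Ix -> Double }

indices :: Sh -> [Ix]
indices = mapM (\n -> [0 .. n - 1])

elemsOf :: Arr -> [Double]
elemsOf a = map (at a) (indices (shape a))

-- Stack n tensors along a new outermost dimension (assumes n > 0).
build1 :: Int -> (Int -> Arr) -> Arr
build1 n g = Arr (n : shape (g 0)) (\(i : rest) -> at (g i) rest)

index :: Arr -> Ix -> Arr
index a ix' = Arr (drop (length ix') (shape a)) (\rest -> at a (ix' ++ rest))

gather1 :: Int -> Arr -> (Int -> Ix) -> Arr
gather1 n a g = build1 n (\i -> index a (g i))

-- Each element v[i] is added at position f i; positions outside sh are dropped.
scatter :: Sh -> Arr -> (Ix -> Ix) -> Arr
scatter sh v' f' = Arr sh (\j -> sum [at v' i | i <- indices (shape v'), f' i == j])

-- Example data, all depending on the build variable i.
v :: Int -> Arr
v i = Arr [4] (\[a] -> fromIntegral (10 * i + a))

f :: Int -> Ix -> Ix
f i [a] = [(a + i) `mod` 2]
f _ js  = js  -- unused fallback, keeps f total

ix :: Int -> Ix
ix i = [i `mod` 2]

-- build1 n (\i -> index (scatter sh v f) ix)
direct :: Arr
direct = build1 3 (\i -> index (scatter [2] (v i) (f i)) (ix i))

-- let a = scatter (n : sh) (build1 n (\i -> v)) (\(i : js) -> i : f js)
-- in gather1 n a (\i -> i : ix)
vectorized :: Arr
vectorized = gather1 3 a (\i -> i : ix i)
  where
    a = scatter [3, 2] (build1 3 v) (\(i : js) -> i : f i js)
```

Both direct and vectorized are rank-1 tensors of length 3 with the same elements in this example; one agreeing example is of course not a proof of the rule.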

Let's now try the first of the "hard part" vectorization rules for the remaining normal forms of simplification

build1 k (var, index x ix) -->
  gather (k : tshape (index x ix)) x ([var], ix)

The result of (D) applied to the left hand side is (mixing up variable names)

gather1 n (build1 n (\i -> x)) (\i -> i : ix)

and simple vectorization turns this into

gather1 n (konst n x) (\i -> i : ix)

To proceed we require the simplification rule for gather of konst

AstKonst _k v -> astGather sh4 v (vars4, rest4)

which turns this into

gather1 n x (\i -> ix)

which is indeed equal to

gather (k : tshape (index x ix)) x ([var], ix)

from above.
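The gather-of-konst step used above can also be checked in a toy model. This is a sketch with hypothetical names (Arr, indices, elemsOf, and the example x and ix), not the AstKonst rule from the implementation; konst is modelled as a build1 whose body ignores the build variable.

```haskell
type Ix = [Int]
type Sh = [Int]

-- Toy tensor (an assumption of this sketch): a shape plus an indexing function.
data Arr = Arr { shape :: Sh, at :: Ix -> Double }

indices :: Sh -> [Ix]
indices = mapM (\n -> [0 .. n - 1])

elemsOf :: Arr -> [Double]
elemsOf a = map (at a) (indices (shape a))

-- Stack n tensors along a new outermost dimension (assumes n > 0).
build1 :: Int -> (Int -> Arr) -> Arr
build1 n g = Arr (n : shape (g 0)) (\(i : rest) -> at (g i) rest)

index :: Arr -> Ix -> Arr
index a ix' = Arr (drop (length ix') (shape a)) (\rest -> at a (ix' ++ rest))

gather1 :: Int -> Arr -> (Int -> Ix) -> Arr
gather1 n a g = build1 n (\i -> index a (g i))

-- n copies of the same tensor: a build1 that ignores its variable.
konst :: Int -> Arr -> Arr
konst n x' = build1 n (const x')

-- Example: x does not mention the build variable, ix does.
x :: Arr
x = Arr [3, 2] (\[a, b] -> fromIntegral (10 * a + b))

ix :: Int -> Ix
ix i = [i `mod` 3]

beforeSimpl, afterSimpl :: Arr
beforeSimpl = gather1 4 (konst 4 x) (\i -> i : ix i)  -- gather1 n (konst n x) (\i -> i : ix)
afterSimpl  = gather1 4 x ix                          -- gather1 n x (\i -> ix)
```

In this example both terms have shape [4, 2] and the same elements: the leading i only selects one of the identical copies, so dropping it together with the konst changes nothing.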

Let's try one more case, where we need to simplify index(gather) in order to apply the simple vectorization rule for gather

build1 n (\i -> gather sh v (\vars -> ix)) -->
  gather (n : sh) (build1 n (\i -> v)) (\(i : vars) -> i : ix)

Let's assume we start with

build1 n (\i -> index (gather sh v (\(var : vars) -> ix)) [m])

which via indexing simplification rules easily reduces to the left hand side (potentially with some substitutions applied, but never mind that), which then results in the right hand side. Rule (D), OTOH, gives

gather1 n (build1 n (\i -> gather sh v (\(var : vars) -> ix))) (\i -> i : [m])

and the simple vectorization rule for gather reduces the inner build1, giving

gather1 n (gather (n : sh) (build1 n (\i -> v)) (\(i : var : vars) -> i : ix)) (\i -> i : [m])

and it surely all ends well, but it's obvious by this point that to resolve this we need fusion of gathers.

Concluding, it seems we have two alternative sets of rules that complement the simple vectorization rules

I'm not 100% sure these sets are disjoint (e.g., in the implementation, some gather simplification rules do "call" indexing simplification rules, but I don't know if these are just shortcuts or if this affects the final results and, if the latter, whether the extra simplification is required to expose redexes for further vectorization or only optional). In general, gather simplification rules are several times more complex than indexing simplification rules.

There's also the third variant you mentioned, "reluctant rule (D)", which consists of simple vectorization, indexing simplification and rule (D) restricted to the normal forms of simplification (either by instantiating the rule to these forms or by side conditions to the rule). It's most probably the shortest rule set of the three. I will try to implement it (with the usual tweak that I simplify reluctantly, as well) and then, if tests are positive, we could base the first attempt at the Overleaf presentation on that.

Thoughts?

Mikolaj commented 1 year ago

And indeed the code got shorter, scatter is now vectorized, and the limited testing shows no problems with the "reluctant rule (D)" method.

https://github.com/Mikolaj/horde-ad/commit/5e5c763d1a309ef2e35a38d533069eb84705d34d

Unfortunately, I can't get even a portion of indexing inside scatter, unlike with fromList, so the new normal form is the fully general index (scatter...) ix, except for ix == ZI. I also don't fuse scatter with anything yet, but there's no rush with fusion, it's not needed for vectorization.

Another success.

Mikolaj commented 1 year ago

I've written it all down in the Overleaf as best I could. I'm leaving fromInt untouched until we decide in https://github.com/Mikolaj/horde-ad/issues/91 whether it should be dropped or not.