Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

The 2 remaining vectorization cases #83

Closed. Mikolaj closed this issue 1 year ago.

Mikolaj commented 1 year ago

We don't have enough tests with build/map/zipWith, so this may be premature optimism, but it seems only 2 vectorization cases remain unimplemented and, although they seem hard and would require lots of work, both seem doable. More precisely, it seems that our code can vectorize 100% of the terms that would type-check with shaped tensors, provided only the 2 remaining cases are solved. I'd need a design review to either show gaps in the plan to tackle the two or, ideally, propose a simpler way.

Both cases are in function build1VectorizeIndexVar, which vectorizes the term AstBuildPair k (var, AstIndexN v is), where variable var is known to occur in v. The function pushes the indexing down term v until, in recursive calls, it no longer contains var, at which point the term is rewritten to AstGatherPair (var, is) v k and we are done. Sometimes the term simplifies to a form where var occurs neither in v nor in is, which makes it possible to simplify it even further and not store on the tape, later on, even the simple non-differentiated integer function (the one encoded in (var, is)). [Edit: this is slightly contrived. In the special case of vectorizing AstGatherPair that resulted from a previous vectorization pass, an even more complex term arises, AstGatherPairN, with an integer function of many variables, which thankfully vectorizes to itself.]
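The push-down loop described above can be modelled on a toy rank-1 language. This is a minimal sketch under stated assumptions, not horde-ad's Ast: the types `IExp` and `Vec`, the constructor `VGather` and the function `build1VectorizeIndex` are hypothetical, tensors are plain lists, and all bound variable names are assumed distinct, so substitution ignores capture. Indexing a gather is pushed through by substituting the index into the gather's own index expression; once the build variable no longer occurs in the indexed term, the whole build becomes a single gather.

```haskell
-- Integer index expressions over named variables.
data IExp = IVar String | ILit Int | IAdd IExp IExp
  deriving (Show, Eq)

-- Rank-1 terms: a literal vector, or VGather k j ix v, whose element number
-- i (for i < k) is v indexed at ix with the bound variable j set to i.
data Vec = VLit [Double] | VGather Int String IExp Vec
  deriving Show

substI :: String -> IExp -> IExp -> IExp
substI x e (IVar y) | x == y = e
substI _ _ (IVar y) = IVar y
substI _ _ (ILit n) = ILit n
substI x e (IAdd a b) = IAdd (substI x e a) (substI x e b)

freeI :: IExp -> [String]
freeI (IVar y) = [y]
freeI (ILit _) = []
freeI (IAdd a b) = freeI a ++ freeI b

-- j is bound only in the index expression, not in the gathered vector.
freeV :: Vec -> [String]
freeV (VLit _) = []
freeV (VGather _ j ix v) = filter (/= j) (freeI ix) ++ freeV v

-- Vectorize build k (var, v ! ix): push the indexing down v until v no
-- longer mentions var, then emit a single gather. Assumes distinct names.
build1VectorizeIndex :: Int -> String -> Vec -> IExp -> Vec
build1VectorizeIndex k var v ix
  | var `notElem` freeV v = VGather k var ix v
  | otherwise = case v of
      VLit _ -> error "unreachable: literals have no free variables"
      -- (gather (j, ix') v') ! ix  ==  v' ! ix'[j := ix]
      VGather _ j ix' v' -> build1VectorizeIndex k var v' (substI j ix ix')

evalI :: [(String, Int)] -> IExp -> Int
evalI env (IVar x) = maybe (error ("unbound " ++ x)) id (lookup x env)
evalI _ (ILit n) = n
evalI env (IAdd a b) = evalI env a + evalI env b

evalV :: [(String, Int)] -> Vec -> [Double]
evalV _ (VLit xs) = xs
evalV env (VGather k j ix v) =
  let xs = evalV env v
  in [xs !! evalI ((j, i) : env) ix | i <- [0 .. k - 1]]

-- Reference semantics of the unvectorized build k (var, v ! ix).
evalBuild :: [(String, Int)] -> Int -> String -> Vec -> IExp -> [Double]
evalBuild env k var v ix =
  [evalV env' v !! evalI env' ix | i <- [0 .. k - 1], let env' = (var, i) : env]
```

Loaded in GHCi, `evalV [] (build1VectorizeIndex 3 "i" (VGather 3 "j" (IAdd (IVar "j") (IVar "i")) (VLit [10, 20, 30, 40, 50])) (ILit 0))` should give `[10.0,20.0,30.0]`, the same as the reference `evalBuild` on the unvectorized term.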

  1. Transpose. I think I should be able to extend build1VectorizeIndexVar to also do transposition in addition to indexing. With some luck, one transposition should be enough per call. However, the function is already complex, so any simpler idea would be invaluable.

https://github.com/Mikolaj/horde-ad/blob/612e0c6a06f7981c26ca96693ff905a72f4d9228/simplified/HordeAd/Core/TensorClass.hs#L636-L659

  2. Reshape [edit: this is now done according to Tom's suggestion]. This is hard and I abhor the thought of touching it. One abortive idea is to let the AstSlice term have integer variable expressions in its offset parameter, the same expressions that indexing contains. This keeps shapes static, because only the offset is made variable, while the size of the slice remains fixed. However, vectorization of such a slice, in turn, gives rise to an operation that is similar to gather, but probably incomparable even for the most general versions of both that we may need.

https://github.com/Mikolaj/horde-ad/blob/612e0c6a06f7981c26ca96693ff905a72f4d9228/simplified/HordeAd/Core/TensorClass.hs#L665-L686

I hope there is a simpler way, but I can't find it. There is a simpler way to vectorize AstFlatten, but AstReshape not only arises when vectorizing AstFlatten, it is also a very useful operation for the user (and has a dirt-cheap gradient calculation).

Mikolaj commented 1 year ago

Tom recommends generalising projections to [Maybe (IntOf r)] for Transpose (this is harder than what I imagined) and using Gather instead of variable Slice for Reshape (this is unexpectedly easy, if inefficient, but we can optimise later on; for now, completeness is what we traded dynamic shapes for, so let’s reap it). Let’s try!
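Tom's suggestion for Reshape can be illustrated on dense row-major tensors. A minimal sketch, assuming a hypothetical `Tensor` record and a full-index `gather` (every result element reads one scalar of the input); horde-ad's `AstGather` is more general. Reshape becomes a gather whose index function linearizes the output index and then de-linearizes it in the input shape, which is one natural source of quot and rem terms in index expressions.

```haskell
-- A dense tensor: a shape and its elements in row-major order.
data Tensor = Tensor { shape :: [Int], elems :: [Double] }
  deriving (Show, Eq)

-- Row-major strides, e.g. strides [2, 3, 4] == [12, 4, 1].
strides :: [Int] -> [Int]
strides sh = drop 1 (scanr (*) 1 sh)

toLinear :: [Int] -> [Int] -> Int
toLinear sh ix = sum (zipWith (*) (strides sh) ix)

-- Inverse of toLinear for in-range offsets; uses quot and rem.
fromLinear :: [Int] -> Int -> [Int]
fromLinear sh n = zipWith (\s d -> (n `quot` s) `rem` d) (strides sh) sh

-- All multi-indices of a shape, in row-major order.
allIndices :: [Int] -> [[Int]]
allIndices = mapM (\d -> [0 .. d - 1])

-- Full-index gather: element ix of the result (shape sh) is element (f ix) of v.
gather :: [Int] -> ([Int] -> [Int]) -> Tensor -> Tensor
gather sh f (Tensor shv xs) = Tensor sh [xs !! toLinear shv (f ix) | ix <- allIndices sh]

-- Reshape as a gather: linearize in the output shape, de-linearize in the input shape.
reshapeAsGather :: [Int] -> Tensor -> Tensor
reshapeAsGather sh' v = gather sh' (fromLinear (shape v) . toLinear sh') v
```

The gather form is indeed inefficient (it recomputes offsets that a real reshape gets for free), but it needs no new primitive, which matches the "completeness first, optimise later" trade-off.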

Mikolaj commented 1 year ago

AstGather in the role of AstSlice works great, so the first case, AstReshape, is solved.

For the AstTranspose case I tried a solution touching less code than Tom's, namely keeping a normal form AstBuild1 k (var, AstIndexZ (AstTranspose perm v) ix) and it works fine for AstTranspose and elsewhere, but it gets very complex when pushing inside AstFromList (the permutation needs to be factored into a single swap and a permutation that does not touch the first element, etc.; that's some hairy combinatorics). I guess Tom’s idea is simpler overall, despite the need to extend Ast indexing grammar, Delta expressions and everything else except the Tensor class (to keep the user's API unchanged, though perhaps the user would benefit from the stronger indexing, at the cost of syntactic noise for all indexing?). Could there be a solution that is as simple as Tom’s, but touches as little code as mine?

Edit: And I worry about what the gradient (the transpose in the other sense, of a linear transformation) of that generalized indexing would be. Currently the gradient is a zero tensor with the subtensor at the indexing path updated to the cotangent accumulator value. What to do when the path contains Nothing components? Update all cells at that rank? That wouldn't be too complex (even though it would be slow). But that's one more complex area getting more complex. In "simplified" horde-ad.
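One possible reading of "update all cells at that rank" can be sketched on a hypothetical dense row-major `Tensor`: `Just i` fixes a dimension, `Nothing` keeps it, and the transpose scatters each cotangent entry back to the cell it was read from, which covers the whole extent of every `Nothing` dimension. The names `indexMaybe` and `indexMaybeT` are illustrative only, and the sketch assumes one index entry per dimension. The defining property a correct gradient must satisfy is the adjoint identity `dot (indexMaybe ms x) c == dot x (indexMaybeT sh ms c)`.

```haskell
data Tensor = Tensor { shape :: [Int], elems :: [Double] }
  deriving (Show, Eq)

strides :: [Int] -> [Int]
strides sh = drop 1 (scanr (*) 1 sh)

toLinear :: [Int] -> [Int] -> Int
toLinear sh ix = sum (zipWith (*) (strides sh) ix)

allIndices :: [Int] -> [[Int]]
allIndices = mapM (\d -> [0 .. d - 1])

-- Put the free indices rs into the Nothing slots, left to right.
fill :: [Maybe Int] -> [Int] -> [Int]
fill (Just i : ms) rs = i : fill ms rs
fill (Nothing : ms) (r : rs) = r : fill ms rs
fill (Nothing : _) [] = error "fill: too few free indices"
fill [] _ = []

-- Generalized indexing (one entry per dimension): Just i fixes the
-- dimension at i, Nothing keeps the whole dimension in the result.
indexMaybe :: [Maybe Int] -> Tensor -> Tensor
indexMaybe ms (Tensor sh xs) =
  Tensor shR [xs !! toLinear sh (fill ms r) | r <- allIndices shR]
  where
    shR = [d | (Nothing, d) <- zip ms sh]

-- Transpose of indexMaybe for an input of shape sh: start from zeros and
-- write each cotangent entry back where indexMaybe read it. Every cell along
-- a Nothing dimension receives its own cotangent entry.
indexMaybeT :: [Int] -> [Maybe Int] -> Tensor -> Tensor
indexMaybeT sh ms (Tensor shR cs) =
  Tensor sh [maybe 0 id (lookup n hits) | n <- [0 .. product sh - 1]]
  where
    hits = [(toLinear sh (fill ms r), c) | (r, c) <- zip (allIndices shR) cs]

dot :: Tensor -> Tensor -> Double
dot a b = sum (zipWith (*) (elems a) (elems b))
```

The list lookup makes this quadratic; it only pins down the semantics, not an efficient scatter.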

Edit2: Would representing permutations as a list of swaps suddenly simplify everything (probably not)? This is also related to our next task after vectorization: fusion+simplification of vectorized code. We need to be able to group transpositions somehow in order to fuse them. It’s not always possible to push them down the terms nor up (e.g., in AstFromList both often fail). Any tricks? Any plausible normal forms (e.g., transpose of AstFromList could be one or we could permit only a single swap, if that helps)?
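For the swaps question, a small sketch of the bookkeeping, assuming a permutation `p` means "output dimension i is input dimension `p !! i`" (a convention chosen here, not necessarily horde-ad's). Under it, fusing two transposes is plain composition, `transpose p (transpose q v) == transpose (composePerm p q) v`, and any permutation of n dimensions decomposes into at most n - 1 swaps and back. Whether the swap form makes the normal forms any simpler is the open question; this only shows that converting between the two representations is cheap.

```haskell
-- Apply a permutation to a shape: output dim i gets input dim (p !! i).
permuteShape :: [Int] -> [Int] -> [Int]
permuteShape p sh = map (sh !!) p

-- Fusing two transposes: transpose p (transpose q v) == transpose (composePerm p q) v.
composePerm :: [Int] -> [Int] -> [Int]
composePerm p q = map (q !!) p

-- Exchange the elements at positions i and j.
swap :: Int -> Int -> [a] -> [a]
swap i j xs = [pick k x | (k, x) <- zip [0 ..] xs]
  where
    pick k x
      | k == i = xs !! j
      | k == j = xs !! i
      | otherwise = x

-- Decompose a permutation into swaps, selection-sort style: fix position
-- i by swapping in the element that p wants there.
toSwaps :: [Int] -> [(Int, Int)]
toSwaps p = go 0 [0 .. length p - 1]
  where
    go i cur
      | i >= length p = []
      | cur !! i == p !! i = go (i + 1) cur
      | otherwise =
          let j = length (takeWhile (/= p !! i) cur)
          in (i, j) : go (i + 1) (swap i j cur)

-- Rebuild the permutation by replaying the swaps on the identity.
fromSwaps :: Int -> [(Int, Int)] -> [Int]
fromSwaps n = foldl (\cur (i, j) -> swap i j cur) [0 .. n - 1]
```

For example, `toSwaps [2, 0, 3, 1]` should give `[(0,2),(1,2),(2,3)]`, and `fromSwaps 4` of that list gives back `[2,0,3,1]`.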

Mikolaj commented 1 year ago

I failed, at the last lap, in the attempt to vectorize indexing via the method "push down indexes applied to a transpose". As expected, the code got absolutely undebuggable: https://github.com/Mikolaj/horde-ad/commit/ae745e781d47b082a25ae09875d4a6743b651639

In the last case, GatherN, it turned out I need exactly the generalized indexing that Tom mentioned. Only one potential solution now remains, but it still scares me, especially since it will probably require touching all the code, down to Delta expressions. I guess desugaring the generalized indexing to indexing composed with mapping would generate a lot of delta expressions, which is why we need to reach that deep. I haven't verified that, though. Is there a cheap way to express the generalized indexing using the remaining operations?

I'm still looking for an out of the box solution.

Mikolaj commented 1 year ago

In the end, I have probably salvaged the "push down indexes applied to a transpose" method, though I'm still not able to type-check it and it will most probably require half a dozen unsafeCoerce calls.

The method is way too complex, but something similar will be needed to fuse multiple stacked GatherN into one, so we'd better get the tools ready. If the typing plugins were more complete, that would help a lot. The presburger arithmetic plugin can't cope with so many additions and the normalise plugin has some gaps. But even without that, the combinatorics of permutations and of rank juggling is more than I bargained for. That's definitely not code that writes itself guided by types.

Mikolaj commented 1 year ago

That was a terrible experience, this is surely buggy and none of our tests exercises any of the multiple huge terrible parts of the code. Let's try to live-demo and count on the demo effect.

https://github.com/Mikolaj/horde-ad/commit/9d4d5ac4d99fa40221a4bfd7ea041ae1456563a1

Mikolaj commented 1 year ago

And by the end of it I forgot what it was about: the commit mentions GatherN, but the missing piece was really a fully general Transpose (applied to GatherN, with projections applied on top; that's the perfect storm).

tomsmeding commented 1 year ago

I wonder if what you need here is generating random programs :p

Mikolaj commented 1 year ago

Yes, I think so. Dimensions only 1, 2 or 3, ranks up to 5, I guess. Maybe 7.

Mikolaj commented 1 year ago

The randomized testing discussion moved to https://github.com/Mikolaj/horde-ad/issues/86.

Mikolaj commented 1 year ago

It turns out "transpose can be expressed as gather, doh" is the "out of the box" idea that would prevent the creation of this horrible mess of a code (pushing both indexing and transposition down any terms, especially gather) that closed this ticket. Sadly, it came a couple of days too late.

Here this code is finally excised and GHC 9.2 copes with type-checking again (and vectorization is split into simplification and vectorization proper):

https://github.com/Mikolaj/horde-ad/commit/61a414d3ce311b98ec803e7dcaef9aed908debd8
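The "transpose is a gather" idea fits in a few lines on a hypothetical dense row-major `Tensor` with a full-index `gather` (names assumed, not horde-ad's API). With the convention that output dimension i is input dimension `perm !! i`, the gather's index function reads input dimension d from component `invPerm perm !! d` of the output index, so no dedicated transpose machinery is needed when pushing indexing down.

```haskell
data Tensor = Tensor { shape :: [Int], elems :: [Double] }
  deriving (Show, Eq)

strides :: [Int] -> [Int]
strides sh = drop 1 (scanr (*) 1 sh)

toLinear :: [Int] -> [Int] -> Int
toLinear sh ix = sum (zipWith (*) (strides sh) ix)

allIndices :: [Int] -> [[Int]]
allIndices = mapM (\d -> [0 .. d - 1])

-- Full-index gather: element ix of the result (shape sh) is element (f ix) of v.
gather :: [Int] -> ([Int] -> [Int]) -> Tensor -> Tensor
gather sh f (Tensor shv xs) = Tensor sh [xs !! toLinear shv (f ix) | ix <- allIndices sh]

-- invPerm p !! d is the position of d in p.
invPerm :: [Int] -> [Int]
invPerm p = [length (takeWhile (/= d) p) | d <- [0 .. length p - 1]]

-- Output dim i is input dim (perm !! i), so input dim d reads
-- component (invPerm perm !! d) of the output index.
transposeAsGather :: [Int] -> Tensor -> Tensor
transposeAsGather perm v =
  gather (map (shape v !!) perm) (\ix -> map (ix !!) (invPerm perm)) v
```

For example, `transposeAsGather [1, 0] (Tensor [2, 3] [1 .. 6])` should give `Tensor {shape = [3,2], elems = [1.0,4.0,2.0,5.0,3.0,6.0]}`.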

tomsmeding commented 1 year ago

@Mikolaj So much -, I love it! So what you do now is just weaken transposition to gather, and vectorise using that?

I expect that you'll want to reintroduce transpose later by recognising gathers that are secretly just transposes, as an optimisation step.
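Recognising such gathers can be purely syntactic. A sketch, assuming a hypothetical gather representation in which the output index variables are numbered 0..n-1 and the index function is a list with one `IExp` per input dimension: in that representation the gather is a transpose when every entry is a bare variable, the variables form a permutation, and the output shape equals the permuted input shape. `asTranspose` and `IExp` are illustrative names only.

```haskell
import Control.Monad (guard)
import Data.List (sort)

-- Index expressions over output index variables numbered 0, 1, ...
data IExp = IVar Int | ILit Int | IAdd IExp IExp
  deriving (Show, Eq)

-- A gather with output shape shOut over an input of shape shIn, whose index
-- function gives one expression per input dimension. Returns the permutation
-- p (output dim i = input dim p !! i) if the gather is just a transpose.
asTranspose :: [Int] -> [Int] -> [IExp] -> Maybe [Int]
asTranspose shOut shIn es = do
  vars <- mapM bareVar es
  let n = length shIn
  guard (length es == n && length shOut == n && sort vars == [0 .. n - 1])
  -- Input dim d reads output var (vars !! d), so output dim k is the input
  -- dim at which k occurs in vars.
  let perm = [length (takeWhile (/= k) vars) | k <- [0 .. n - 1]]
  guard (shOut == map (shIn !!) perm)
  pure perm
  where
    bareVar (IVar k) = Just k
    bareVar _ = Nothing
```

For instance, `asTranspose [3, 2] [2, 3] [IVar 1, IVar 0]` should give `Just [1,0]`, while any entry that is not a bare variable, such as `IAdd (IVar 0) (ILit 1)`, makes it return `Nothing`.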

Mikolaj commented 1 year ago

> @Mikolaj So much -, I love it! So what you do now is just weaken transposition to gather, and vectorise using that?

Yes, though a bit more delicately: whenever I try to vectorize build .. (var, index (transpose perm v) ix) (and var appears in v, so I can't just create a gather outright), I translate transpose to gather, in order to push the indexing down, etc. More generally, I do one step of index simplification. This outermost strategy level (which rules to apply when) is still WIP.

> I expect that you'll want to reintroduce transpose later by recognising gathers that are secretly just transposes, as an optimisation step.

Yes, not translating transpose to gather and instead, e.g., fusing two transposes, helps a lot, but I'm not yet principled enough about that. Regarding translating gather back to various simpler stuff, I will, but so far I never do. That would help both when term rewriting and later on when interpreting it in ADVal (faster tensor ops and smaller terms to AD).

Mikolaj commented 1 year ago

And, to be fair, some of that removed terrible code was paid for by the new simplification code added recently. However, the simplification code would be needed anyway, because of our vectorize -> simplify -> AD pipeline. Vectorization contains just enough of the simplification calls to prevent getting stuck and also to simplify away any code that is being added or changed (this is tricky, because simplifying redexes uncovers new redexes and it's hard to say if they were there already in the original code; we'd need to watermark the whole Ast).

However, how much simplification is done in the dedicated simplification step is up to the user. It may very well be none, because whatever vectorization spoiled, it tried to simplify on the spot, and simplification is idempotent (ensured by tests). But if the user wants to also simplify the original user's code, that's fine (and no money back if it turns out slower :).

tomsmeding commented 1 year ago

> However, how much simplification is done in the dedicated simplification step is up to the user. It may very well be none, because whatever vectorization spoiled, it tried to simplify on the spot, and simplification is idempotent (ensured by tests). But if the user wants to also simplify the original user's code, that's fine (and no money back if it turns out slower :).

Why would you ever want to not simplify? What rules are there in the simplify step that potentially make performance worse?

Also, do I understand correctly that you need to simplify during the vectorisation step in order to eliminate certain code patterns that vectorisation can't handle? If so, that feels like a fragile system: one doesn't typically assume that the simplifier needs to avoid introducing some particular code patterns because that's relied on by some other part of the code. Could there be a dedicated preprocessing step before vectorisation, so preproc -> vectorise -> simplify -> AD, where preproc does the right simplifications so that vectorise can do its job, and nothing more?

EDIT: presumably actually preproc -> vectorise -> simplify -> AD -> simplify.

Mikolaj commented 1 year ago

Edit: since writing this, I reworked the step-by-step simplification and, in particular, it now halts after one step (or more) not only for the constructors that have many subterms. Other than that, it's all more or less as described.

> Why would you ever want to not simplify? What rules are there in the simplify step that potentially make performance worse?

I haven't benchmarked and it's all WIP (and sharing for Ast is not yet implemented), but the more I extend the simplification code, the slower the tests get. ;)

The obvious suspects are the rules that take a term and make many copies of it and then recursively simplify in each subterm (somehow composed with the copy), which are those guarded by stepOnly in the code of astIndexZOrStepOnly. Example:

https://github.com/Mikolaj/horde-ad/blob/61a414d3ce311b98ec803e7dcaef9aed908debd8/simplified/HordeAd/Core/AstSimplify.hs#L227-L229

though the really unavoidable example is above, for AstOp (arithmetic operations). They have, fortunately, at most two subtrees, but two are enough to cause exponential blowup (if that's what happens).
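The blowup from rules that copy a term into both operands can be reproduced with a toy term type (`Term`, `Ix`, `pushIndex` and `tower` are hypothetical, not horde-ad's `AstOp` rules). `tower n` is built with `let`, so in memory it is a DAG of n + 1 distinct nodes, but pushing one indexing operation through every binary `Op` yields a term that, viewed as a tree (which is how a tree-walking simplifier or `show` traverses it), has 2^n `Index` nodes, one per leaf. This is the kind of growth that sharing in the Ast would be meant to keep in check.

```haskell
-- Index expressions and terms of a toy language.
data Ix = IxVar String | IxAdd Ix Ix
  deriving Show

data Term = Leaf String | Op Term Term | Index Term Ix
  deriving Show

-- index (a `op` b) ix ~> index a ix `op` index b ix, applied recursively.
pushIndex :: Term -> Term
pushIndex (Index (Op a b) ix) = Op (pushIndex (Index a ix)) (pushIndex (Index b ix))
pushIndex t = t

-- Both operands are the same let-bound subterm at every level.
tower :: Int -> Term
tower 0 = Leaf "x"
tower n = let t = tower (n - 1) in Op t t

-- Counts Index nodes when the term is traversed as a tree.
countIndex :: Term -> Int
countIndex (Leaf _) = 0
countIndex (Op a b) = countIndex a + countIndex b
countIndex (Index t _) = 1 + countIndex t
```

With `ix = IxVar "i"`, `countIndex (Index (tower 10) ix)` is 1, while `countIndex (pushIndex (Index (tower 10) ix))` is 1024.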

However, right now probably the biggest slowdown comes from fusing gathers, which is just unfinished and unguarded and so unnecessarily turns reshape into gather (these are terrible to simplify, because they have lots of quot and rem terms and I haven't yet implemented the b * (a quot b) + a rem b -> a rule, if it even helps) and then potentially complicates such terms even more by substitution while fusing. I'm just removing AstGather1 as we speak in order to focus on finishing simplification of AstGatherN.
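The identity behind such a rule is `b * (a quot b) + a rem b == a`, which Haskell's `quot` and `rem` satisfy for every nonzero `b`, negative operands included. A minimal sketch of it as one bottom-up rewrite on a hypothetical integer-expression type (`IExp`, `simplify` and `evalI` are illustrative names, not horde-ad's); only this one operand order is matched, and `b` is assumed nonzero, since the rewrite would otherwise remove a division by zero.

```haskell
data IExp
  = IVar String
  | ILit Int
  | IAdd IExp IExp
  | IMul IExp IExp
  | IQuot IExp IExp
  | IRem IExp IExp
  deriving (Show, Eq)

-- Bottom-up pass applying b * (a `quot` b) + a `rem` b ~> a (b assumed nonzero).
simplify :: IExp -> IExp
simplify e = case e of
  IAdd x y -> rule (IAdd (simplify x) (simplify y))
  IMul x y -> IMul (simplify x) (simplify y)
  IQuot x y -> IQuot (simplify x) (simplify y)
  IRem x y -> IRem (simplify x) (simplify y)
  _ -> e
  where
    rule (IAdd (IMul b (IQuot a b')) (IRem a' b''))
      | b == b' && b' == b'' && a == a' = a
    rule t = t

-- Reference semantics, used to check that the rule preserves meaning.
evalI :: [(String, Int)] -> IExp -> Int
evalI env e = case e of
  IVar x -> maybe (error ("unbound " ++ x)) id (lookup x env)
  ILit n -> n
  IAdd a b -> evalI env a + evalI env b
  IMul a b -> evalI env a * evalI env b
  IQuot a b -> evalI env a `quot` evalI env b
  IRem a b -> evalI env a `rem` evalI env b
```

Because children are simplified before the rule is tried, a redex exposed by simplifying a subterm is still caught in the same pass, which is what makes a property like idempotence testable.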

> Also, do I understand correctly that you need to simplify during the vectorisation step in order to eliminate certain code patterns that vectorisation can't handle?

Yes. The forbidden code pattern is just one: indexing (except a few simple cases of indexing). Depending on how you look at it, gather may be another one, but we handle gather by calling the indexing simplification for the index contained within gather, so it's more of the same.

> If so, that feels like a fragile system: one doesn't typically assume that the simplifier needs to avoid introducing some particular code patterns because that's relied on by some other part of the code.

I must have miscommunicated. We don't need to avoid introducing anything. We could simplify fully and vectorization would be very happy. However:

  1. the user knows best, so code should not be touched unless necessary or requested
  2. vectorization changes code, introducing patterns that vectorization can't handle
  3. so vectorization needs to call simplification repeatedly, regardless of whether we simplified all at the start
  4. vectorization changes code into too complex and easily reducible stuff, so simplification needs to be called to clean this up, as well.

> Could there be a dedicated preprocessing step before vectorisation, so preproc -> vectorise -> simplify -> AD, where preproc does the right simplifications so that vectorise can do its job, and nothing more?

Hmm. You are actually right. Vectorization introduces indexing (the bad pattern) or gather where there was none, so any vectorization for the outer build operation occurrence then encounters them and needs to call simplification again. But, it turns out, vectorization never introduces the bad patterns into the arguments to its own recursive calls. So, unless I'm missing something, we could do the following:

depth-first-build (preproc; vectorise) -> simplify -> AD

where depth-first-build f applies f to each build subterm bottom-up, replacing the subterm by the result, before moving upwards. I use it already, but with vectorization freely mixed with simplification. Is that a traverse or a tree fold? And the preproc would just be full simplify, or we could try to be smart: start at the root and stop in each recursive subtree call whenever we've eliminated enough occurrences of the build variable so that the remaining indexing (and gather) are the easy cases. I'm not sure if this stop condition can be characterized without mimicking a lot of the vectorization logic.

This looks simpler than what I (intend to) do, but the drawback is either the stop condition or simplifying much deeper in the term than necessary (e.g., simplifying whole huge subterms that never mention the build variable). Respect for user code aside, at this point I'm not sure if simplification can incur term size blowup and whether sharing makes it harmless (by representing a huge tree as a small DAG and enabling traversal over that one). So, simplifying too much may be questionable or an outright disaster. E.g., if simplification is really bad, with the smart approach we could at least guarantee that the badness is bounded by the number of build occurrences or, with lots of luck, by the maximum rank appearing in the program, because builds that fuse with each other (or that produce gathers that fuse) don't contribute to the disaster.
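The depth-first-build traversal can be sketched on a toy Ast (`Ast`, `hasBuild` and `depthFirstBuild` are hypothetical). The body of each `Build` is rewritten first and only then is `f` applied to the rebuilt node, so `f` always sees a build whose body contains only already-processed builds. Structurally this is a bottom-up transformation, i.e. a fold whose algebra rebuilds each node and applies `f` at `Build` nodes, similar in shape to uniplate's bottom-up `transform` restricted to one constructor.

```haskell
data Ast = Var String | Lit Double | Add Ast Ast | Build Int String Ast
  deriving (Show, Eq)

hasBuild :: Ast -> Bool
hasBuild t = case t of
  Build {} -> True
  Add a b -> hasBuild a || hasBuild b
  _ -> False

-- Apply f to every Build node, innermost first: the body is processed
-- before f sees the enclosing Build.
depthFirstBuild :: (Ast -> Ast) -> Ast -> Ast
depthFirstBuild f = go
  where
    go t = case t of
      Add a b -> Add (go a) (go b)
      Build k var body -> f (Build k var (go body))
      _ -> t
```

For instance, with `f` replacing a `Build` by `Lit 0` whenever its body has no builds, the nested term `Build 2 "i" (Add (Build 3 "j" (Var "j")) (Var "i"))` collapses to `Lit 0`, which a single top-down pass would not achieve. In the pipeline above, `f` would be `preproc` followed by `vectorise`.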

> EDIT: presumably actually preproc -> vectorise -> simplify -> AD -> simplify.

Simplify after AD is tricky, because we don't AD the Ast form, but concrete dual numbers (after interpreting Ast in ADVal). Another story and, yes, it would be cool to AD the Ast.

Mikolaj commented 1 year ago

I'm making progress simplifying the terms: a term has just grown beyond my 64G RAM on its way to being fully simplified. ;D

To be fair, it probably blew up while running show on the term after a failed test (testing idempotence of simplification).

Mikolaj commented 1 year ago

It turns out the computer-killing terms are those with map expressed via build. They blow up whenever gradients are computed for them without first vectorizing them (for test purposes, to compare results with all the other several configurations). However, if they are simplified before AD, the computation is fast.

Which suggests our simplification procedure is relatively fine now and the most urgent need is sharing (though I don't have tests that exhibit the need when vectorized or when simplified, so perhaps bigger tests are more urgent).

Vectorization is also in good shape, because it's much simplified via decoupling it from simplification. Simplification, in particular the fusion of two gather terms, is now the most complex part of the code:

https://github.com/Mikolaj/horde-ad/blob/f3ab346243faa4f6316e1fede42fe384958d00a2/simplified/HordeAd/Core/AstSimplify.hs#L611-L631

though it can't compare with the previous monstrosity that at once vectorized gather and fused it with a projection composed with a transposition (that was before @tomsmeding came up with the brutal but effective idea of expressing transposition as gather).
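At the level of meaning, fusing two gathers is composition of their index functions: `gather sh f (gather sh' g v) == gather sh (g . f) v`. A sketch on a hypothetical dense `Tensor` with full-index gathers (names assumed, not horde-ad's API); in the Ast the composition becomes substituting one gather's index expressions into the other's, and the rank bookkeeping for gathers that produce subtensors, omitted here, is plausibly where much of the real complexity lies.

```haskell
data Tensor = Tensor { shape :: [Int], elems :: [Double] }
  deriving (Show, Eq)

strides :: [Int] -> [Int]
strides sh = drop 1 (scanr (*) 1 sh)

toLinear :: [Int] -> [Int] -> Int
toLinear sh ix = sum (zipWith (*) (strides sh) ix)

allIndices :: [Int] -> [[Int]]
allIndices = mapM (\d -> [0 .. d - 1])

-- Full-index gather: element ix of the result (shape sh) is element (f ix) of v.
gather :: [Int] -> ([Int] -> [Int]) -> Tensor -> Tensor
gather sh f (Tensor shv xs) = Tensor sh [xs !! toLinear shv (f ix) | ix <- allIndices sh]

-- Fusion rule: gather sh f (gather sh' g v) == gather sh (g . f) v.
-- The inner shape sh' disappears; only the composed index function remains.
fuseGather :: [Int] -> ([Int] -> [Int]) -> ([Int] -> [Int]) -> Tensor -> Tensor
fuseGather sh f g v = gather sh (g . f) v
```

For example, with `v = Tensor [6] [10, 20, 30, 40, 50, 60]`, an inner gather reading every other element (`g ix = map (* 2) ix`, shape `[3]`) followed by an outer one reading it backwards (`f ix = map (2 -) ix`, shape `[2]`) yields elements `[50.0,30.0]`, and `fuseGather [2] f g v` produces the same tensor in one step.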