Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Don't use rule (D) for AstAppend and let simplification pass through #98

Open Mikolaj opened 1 year ago

Mikolaj commented 1 year ago

It turns out our handling of AstAppend was broken

https://github.com/Mikolaj/horde-ad/commit/d3ddcf76572f49406a5d729570c6424b88f391c0#diff-89b28a3b50641a4f5bea4f760c491a369d1245c655cb471abec544c7438bc90dR312-R316

and because in our system append can be expressed neither as gather nor as build, in desperation we now use rule (D), and we neither simplify indexing of append nor fuse gathers from outside and inside of append.

If some tenets of our system need relaxing to accommodate more efficient handling of append, it may be worth it. If rule (D) has to be used nevertheless, we should at least try to simplify more, just as AstFromList is simplified maximally despite rule (D) being required for its vectorization.

tomsmeding commented 1 year ago

Re your comment in the diff, why is using astCond for append not allowed because of possible errors in the untaken branch, but we are apparently allowed to turn regular conditionals into astCond? The user can write append in terms of build.

Mikolaj commented 1 year ago

That's a good question and a delicate point that needs to be broached in the paper. We don't provide true conditionals, but eager conditionals. So our ifB class instance is not lawful, though it satisfies some weaker laws. A user of our eager conditionals has to take the weaker semantics into account and avoid any crashes (e.g., division by zero) in the untaken branch. That includes crashes that do not occur in the original program (even when all impossible branches are materialized), but only appear after vectorization, after extra simplification, or after AD.
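A toy model of the distinction (hypothetical names, plain Haskell values rather than the library's Ast terms): an eager conditional forces both branches before selecting one, so a crash in the untaken branch still fires, unlike Haskell's built-in lazy if.

```haskell
import Control.Exception (ArithException, evaluate, try)

-- Lazy conditional: the untaken branch is never evaluated.
lazyCond :: Bool -> a -> a -> a
lazyCond c t e = if c then t else e

-- Eager conditional (a model of the weaker ifB semantics):
-- both branches are forced to WHNF before one is selected.
eagerCond :: Bool -> a -> a -> a
eagerCond c t e = t `seq` e `seq` (if c then t else e)

main :: IO ()
main = do
  -- Lazy: division by zero in the untaken branch is harmless.
  print (lazyCond True (1 :: Int) (1 `div` 0))
  -- Eager: the same program crashes; here we catch the exception.
  r <- try (evaluate (eagerCond True (1 :: Int) (1 `div` 0)))
       :: IO (Either ArithException Int)
  print r
```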

To help, we've changed the semantics of indexing, gather and scatter (and we are ready to change the semantics of slice once we start simplifying gather to slice) so that indexing out of bounds does not crash. We haven't augmented division not to crash when dividing by zero (now that I think about it, it's an option). It would be helpful to think about whether there are any other sources of crashes besides these two.
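A list-level sketch of that change (hypothetical names; the library's actual default value and semantics may differ): indexing becomes total by returning zero out of bounds, so index valuations introduced under untaken eager branches cannot crash.

```haskell
-- Hypothetical model of no-crash indexing: out-of-range reads yield 0
-- instead of failing, which makes gather total as well.
safeIndex :: Num a => [a] -> Int -> a
safeIndex xs i
  | i < 0 || i >= length xs = 0
  | otherwise               = xs !! i

-- A total gather built on safeIndex.
safeGather :: Num a => [a] -> [Int] -> [a]
safeGather xs = map (safeIndex xs)
```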

We, as a user of our own eager conditionals, have to be careful in the same way, but we are in a worse position, because we don't know the whole program. To avoid introducing new crashes into some user programs, our rules can't introduce valuations of integer variables that are not possible in the original program, even guarded by our eager conditional, even manifesting only after our transformations of the program. In particular, the bad rule for AstAppend did introduce possibly negative valuations for some indexes (under an impossible conditional branch), which are obviously impossible in the original program, because all built-in tensor operations assign only non-negative integers to the indexes.

We need to characterize the precautions the user needs to take and the guarantees we make. I hope our transformations no longer produce valuations of integer variables that are impossible in the original program. That should be enough to guarantee they don't cause crashes. I'm not sure, though, and perhaps our guarantees are weaker (e.g., we never produce new negative valuations, but we do sometimes multiply integers by a positive constant).

The user can write append in terms of build.

You mean, with a conditional that assigns negative integers to an indexing variable? That should be fine as long as the user verifies that any index expressions in the surrounding program that our transformations can possibly substitute inside that conditional don't contain crashing operations, such as division by an expression that becomes zero under those valuations. In one of my new tests this does not hold for the conditional that AstAppend introduces, so the rule is not applicable to such programs, and since I don't check programs before applying the rule, I had to remove it. In all other tests it worked fine.

tomsmeding commented 1 year ago

Ah, I see! That makes sense.

I wonder if it would be useful to introduce a multigather primitive, or whether that would make our lives very sad.

```haskell
tmultigather
  :: (KnownNat m, KnownNat n, KnownNat p)
  => ShapeInt (m + n) -> [TensorOf (p + n) r]
  -> (IndexOf m r -> (Int, IndexOf p r))
  -> TensorOf (m + n) r
tmultigather sh l f = tbuild sh (\i -> let (j, ix) = f i in (l !! j) ! ix)
  -- implementation in terms of tbuild just to provide semantics;
  -- this should be a primitive
```

or even one that doesn't require all the p values to be uniform over the input arrays, which will result in horrible HList-like things.

This would very neatly express append:

```haskell
append a b =
  let n ::: rest = tshape a
      m ::: _ = tshape b
  in tmultigather ((n + m) ::: rest) [a, b]
       (\(i ::: ix) -> ifB (i <* n) (0, i ::: ix) (1, (i - n) ::: ix))
```

but I fear that having multigather in the core language may make vectorisation/simplification much harder again, due to the generality of the primitive.
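As a sanity check on the intended semantics, here is a list-level model of the primitive at rank 1 (hypothetical names, flattening the rank-polymorphic types; the real primitive would operate on tensors):

```haskell
-- List-level model of tmultigather at rank 1: for each output position,
-- the index function picks which input list to read and at what offset.
multigatherL :: Int -> [[a]] -> (Int -> (Int, Int)) -> [a]
multigatherL sz l f =
  [ (l !! j) !! ix | i <- [0 .. sz - 1], let (j, ix) = f i ]

-- append via the model, a list-level analogue of the ifB-based definition.
appendL :: [a] -> [a] -> [a]
appendL a b =
  let n = length a
  in multigatherL (n + length b) [a, b]
       (\i -> if i < n then (0, i) else (1, i - n))
```

Here `appendL [1,2] [3,4]` yields `[1,2,3,4]`, so a single gather-like primitive indeed suffices to express append.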

tomsmeding commented 1 year ago

You mean, with a conditional that assigns negative integers to an indexing variable?

And yes, indeed I meant that, but you're correct that this works because now the burden of proof is on the user, not on us. This will need to be very prominent in the documentation, though -- users will not expect ifB to be eager.

Mikolaj commented 1 year ago

Yes, that tmultigather looks promising, giving us the first candidate for solving this ticket. It side-steps the restriction that all tensors in fromList have equal shape (if not for that restriction, it would be expressible using fromList). I suppose, when compiling to GPU, this list could be efficiently represented as a ragged array or normal array of the shape determined from the largest element (with other elements padded with zeroes).

Your append is safe, because the negative indexes it introduces are inside ifB of AstInt r, and that one is lazy, unlike ifB of Ast n r. And it stays lazy, because our transformations don't mess with it (at most, they substitute and simplify inside it, but they don't change laziness nor move expressions out of it, unless the condition can be determined ahead of time).

I can't easily guess whether this new primitive vectorizes/simplifies well. Probably only the "ragged" aspect can cause any problems.

tomsmeding commented 1 year ago

I suppose, when compiling to GPU, this list could be efficiently represented as a ragged array or normal array of the shape determined from the largest element (with other elements padded with zeroes).

Given that shapes are known statically, I can even imagine that if the arrays used as input to multigather are not used anywhere else, their data can be stored (destination-passing style) in the buffer where they will be used, i.e. in the concatenated buffers of the inputs to multigather -- that is, in appended form. :P

Assuming that the list that is an input to multigather is short, I would actually expect its elements to be separate buffers in GPU memory. Same holds for fromList. I would not expect the lists given to multigather/fromList to compile to an array -- then the input should have been an array. Will fromList ever have super long input lists?

Mikolaj commented 1 year ago

I think you are right. I probably mixed up meta-levels.

Will fromList ever have super long input lists?

I guess not. For large constants, the user has tconst.