Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Eliminate the explosion of spurious IsScalarS constraints #26

Closed. Mikolaj closed this 2 years ago.

Mikolaj commented 2 years ago

Fixes #25. The main commit is "Inline PrimalS, hack the rest to hang on".

Removal of the spurious constraints is accomplished by defining a class IsDualS, as sketched by @simonpj, analogous to the existing IsDual class but for types parameterized by shape. Since the ~ type equality does not work for higher-kinded types, we can't create a type family PrimalS in class IsDualS analogous to Primal in the IsDual class. Therefore, we effectively inline PrimalS (but not Primal from the IsDual class!).

The trick is completed by defining an instance of IsDual constructed from IsDualS, again a mechanism suggested by @simonpj, though his early prototype was based on a quantified constraint, which doesn't seem to work, while the instance (barely) does. In general the construction would not be possible, because we lack PrimalS (we just inlined it); however, we only need it to work for shaped tensors and their dual numbers (the rest has manually written IsDual instances), and for those we can hardwire the correct Primal type family.
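For readers following along, here is a hedged, self-contained sketch of the shape of this construction. The names IsDual, IsDualS and Primal come from the discussion above, but DeltaS, Array and the zeroD/zeroDS methods are made-up placeholders, not horde-ad's actual API; the point is only to show a shape-parameterized class without a PrimalS family and a bridging IsDual instance with Primal hardwired.

{-# LANGUAGE DataKinds, FlexibleContexts, KindSignatures, TypeFamilies #-}
module DualSketch where

import Data.Kind (Type)
import GHC.TypeLits (Nat)

-- Toy stand-ins for a shaped primal array and a shaped dual component.
data Array (sh :: [Nat]) (r :: Type)
data DeltaS (r :: Type) (sh :: [Nat])

-- The existing style of class: fully applied types, with a Primal family.
class IsDual d where
  type Primal d :: Type
  zeroD :: d                     -- hypothetical method, for illustration only

-- The shape-parameterized variant. Note that there is no
-- `type PrimalS t :: [Nat] -> Type` member: PrimalS is "inlined",
-- i.e. the methods refer to the primal type directly.
class IsDualS (t :: [Nat] -> Type) where
  zeroDS :: t sh                 -- hypothetical method, for illustration only

-- The bridge: an IsDual instance for shaped duals built on top of IsDualS,
-- with Primal hardwired to the shaped-array case (per the thread, this is
-- the instance that needed IncoherentInstances in GHC 9.2.2).
instance IsDualS (DeltaS r) => IsDual (DeltaS r sh) where
  type Primal (DeltaS r sh) = Array sh r
  zeroD = zeroDS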

The instance requires IncoherentInstances to work in GHC 9.2.2, and seemingly unrelated code changes break it. In particular, I had to revert a simplification of HordeAd.Core.DualClass that broke type-checking due to this fragility. The instance seems to work perfectly in GHC HEAD, but type-level plugins don't, so I can't test the full codebase in HEAD to confirm.

Against all odds, not even one of the calamities I predicted materialized. The API of HordeAd.Core.DualClass was not split/cloned for the case of shaped tensors. In fact, it's even more regular than before (or would be, if the simplification had not been reverted), even though the implementation has some extra complexity. Surprisingly, no code duplication was needed in either user code or the implementation of the API, despite a new class and the need to define a newtype that swaps the order of type parameters, so that a type can be partially applied at its second parameter (the underlying scalar type) while leaving the shape unapplied.

We apparently managed to get the best of both worlds: the constraint in a type signature is given only once, using the higher-kinded IsDualS under the hood, while all the types are written using fully applied (to multiple shapes!) mechanisms from the regular classes IsDual and HasRanks that are also used for vectors, untyped tensors, etc. Even the inlining of PrimalS is not visible in the API, because IsDualS is not exported and the non-inlined Primal from IsDual agrees with it in all concrete cases and can't observe/express the higher-kinded cases.

goldfirere commented 2 years ago

I haven't been following closely, but I see this text above: "Since the ~ type equality does not work for higher-kinded types". But ~ does work on higher-kinded types! Maybe the problem is that you need to use ~ on a partially applied type family? That indeed does not work.

I should be able to attend our meeting on Wednesday and can explain more there.

Mikolaj commented 2 years ago

"Since the ~ type equality does not work for higher kind types,". But ~ does work on higher-kinded types! Maybe the problem would be that you need to use ~ on a partially applied type family? That indeed does not work.

Thank you for the clarification. I stand corrected. In fact, neither side of the equation we needed was a partially applied type family. The equality in question was

PrimalS (TensorS r) ~ RevArray (Primal r)

where PrimalS would be a fully applied type synonym family that takes a type of kind [Nat] -> Type to a type of the same kind, applied to an application of the type synonym family TensorS, which takes Type to [Nat] -> Type. The right-hand side is a partially applied newtype RevArray that wraps orthotope's Array datatype, swapping the original (shape, type) arguments. I don't remember whether precisely this failed in GHC 9.2.2, or with what error.
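For reference, here is a minimal sketch of the kinds just described, written with the names from the thread; the Array type below is a toy stand-in for orthotope's shaped array, and the families are only declared, with no instances, since their real definitions live in horde-ad.

{-# LANGUAGE DataKinds, KindSignatures, TypeFamilies #-}
module KindsSketch where

import Data.Kind (Type)
import GHC.TypeLits (Nat)

-- Toy stand-in for orthotope's shaped array, in its original
-- (shape, scalar) argument order.
data Array (sh :: [Nat]) (r :: Type)

-- The newtype that swaps the argument order: RevArray r has kind
-- [Nat] -> Type, with the shape left unapplied.
newtype RevArray (r :: Type) (sh :: [Nat]) = RevArray (Array sh r)

-- The families named in the equality  PrimalS (TensorS r) ~ RevArray (Primal r)
type family Primal  (r :: Type)          :: Type
type family TensorS (r :: Type)          :: [Nat] -> Type
type family PrimalS (t :: [Nat] -> Type) :: [Nat] -> Type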

Mikolaj commented 2 years ago

I'm going to merge this after tomorrow's call, after addressing any comments that emerge.

goldfirere commented 2 years ago

In the meeting today, Mikolaj asked about uncommenting the last part of IsScalar and removing IsScalarS in favor of IsScalar. He ran into some trouble. I have a solution. (Disclaimer: I don't fully understand what these are -- just following types and fixing errors.)

The problem is that the last line includes a new IsDualS (TensorS r) constraint. In the definition of parametersNew in OutdatedOptimizer.sgdBatchFastForward, we need to solve IsScalar Double, and thus IsDualS (TensorS Double). This expands to

[W] IsDualS (RevArray Double)

using the [W] notation to denote a Wanted constraint. Naturally, we use

instance (forall sh. Num (OS.Array sh r)) => IsDualS (RevArray r)

So we then have

[W] forall sh. Num (OS.Array sh Double)

This is solved using the instance in the orthotope library

instance (Shape sh, Num a) => Num (Array sh a)

So now we have

[W] Num Double
[W] forall sh. Shape sh

The first of these is easily dispatched. But the second is stuck -- this is what the error message is about. The solution is easy, though: change the IsDualS (RevArray r) instance to be

instance (forall sh. OS.Shape sh => Num (OS.Array sh r)) => IsDualS (RevArray r)

And I'm sure this is actually the right thing, given the Num instance for Array. After all, a Num (Array sh r) constraint would hold only when Shape sh holds -- not for any sh.
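To make this concrete, here is a small self-contained toy model that reproduces the situation; Shape, Array, RevArray and IsDualS below are simplified stand-ins, not the actual horde-ad or orthotope definitions. The commented-out instance leaves a stuck forall sh. Shape sh Wanted, while the one whose premise matches the Num instance type-checks.

{-# LANGUAGE DataKinds, FlexibleContexts, FlexibleInstances, KindSignatures,
             QuantifiedConstraints, UndecidableInstances #-}
module QCSketch where

import Data.Kind (Type)
import GHC.TypeLits (Nat)

-- Toy stand-ins for orthotope's Shape class and shaped Array type.
class Shape (sh :: [Nat])
instance Shape '[]

newtype Array (sh :: [Nat]) a = Array [a]

-- Mirrors the shape of orthotope's  instance (Shape sh, Num a) => Num (Array sh a).
instance (Shape sh, Num a) => Num (Array sh a) where
  Array xs + Array ys = Array (zipWith (+) xs ys)
  Array xs * Array ys = Array (zipWith (*) xs ys)
  negate (Array xs)   = Array (map negate xs)
  abs (Array xs)      = Array (map abs xs)
  signum (Array xs)   = Array (map signum xs)
  fromInteger n       = Array [fromInteger n]

-- Stand-in for the newtype that swaps the (shape, scalar) argument order.
newtype RevArray r (sh :: [Nat]) = RevArray (Array sh r)

class IsDualS (t :: [Nat] -> Type)

-- The version that gets stuck on  [W] forall sh. Shape sh:
-- instance (forall sh. Num (Array sh r)) => IsDualS (RevArray r)

-- The fixed version: the quantified premise matches the Num instance above.
instance (forall sh. Shape sh => Num (Array sh r)) => IsDualS (RevArray r)

-- A use site analogous to needing IsDualS (TensorS Double):
needsDual :: IsDualS t => t sh -> ()
needsDual _ = ()

solved :: ()
solved = needsDual (RevArray (Array [0]) :: RevArray Double '[])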

Mikolaj commented 2 years ago

@goldfirere, thanks a lot. I'm so glad to be wrong here. I thought this was GHC 9.2.2 acting up after being forced by IncoherentInstances to accept these types. However, you may be right that GHC <= 9.2.2 can process such types correctly and the only thing it lacks is correct overlapping-instance detection (fixed in HEAD, which is why HEAD does not require IncoherentInstances). So, once again, we avoided switching to HEAD.

Mikolaj commented 2 years ago

PrimalS (TensorS r) ~ RevArray (Primal r)

@goldfirere: you were right all along. The equality above works (I must have used the syntax of a multi-parameter type family and wrongly assumed I'd get the same errors from a similar single-parameter type family with a higher-kinded result). I'm inlining away PrimalS regardless (see the last commit) to avoid the RevArray newtype wrapper polluting the API, but it's great to know we have the option of introducing PrimalS if needed.

Merging the PR. Thank you everybody!