Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Unravel the confusion of parameterization by scalar and by rank in the Delta GADT #10

Closed: Mikolaj closed this issue 2 years ago

Mikolaj commented 2 years ago

In short, the compiler is so confused that it can't eliminate all the unhandled cases in buildVectors in https://github.com/Mikolaj/horde-ad/blob/master/src/HordeAd/Core/Delta.hs that the GADT was specifically written to rule out. Humans tend to be confused as well, e.g., not grasping immediately that Delta (Vector (Vector Float)) and Delta (Delta Double) are unintended. Recent futile attempts at deriving a Show instance for the type confirmed the confusion and made me ask for help.
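
For illustration, the problematic shape is roughly the following (a cut-down, hypothetical sketch, not the real Delta definition):

{-# LANGUAGE GADTs #-}
import Numeric.LinearAlgebra (Vector)

-- A single GADT indexed by the type its delta expression represents.
data Delta a where
  ScaleScalar  :: a -> Delta a -> Delta a
  ScaleVector  :: Vector a -> Delta (Vector a) -> Delta (Vector a)
  SumElements1 :: Delta (Vector a) -> Int -> Delta a

-- Nothing above rules out Delta (Vector (Vector Float)) or
-- Delta (Delta Double), and exhaustive matches (e.g., in buildVectors)
-- cannot automatically discharge such cases.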

Ideally, the solution would not unduly complicate the DualNumber type, nor any other type further up the abstraction ladder. Nor should too much explicit coercion be required, because the r type (e.g., Double) really is, at the same time, the underlying scalar type and rank 0 on the ladder of tensor ranks; we have specialized Delta operations for each rank, but no special operations for specific underlying scalars.

@goldfirere: this is one of the (two, so far) hard (for me) typing problems I'm stumbling over that I mentioned to you.

Mikolaj commented 2 years ago

I've tried to overhaul the Delta definition using a type class, but I'm stuck:

https://github.com/Mikolaj/horde-ad/commit/a78236a8fd0d23dac687f0a354567cc7a6cf9a8b#diff-d9c2c6cd672992a85fefc4a8e29abf8de2e45d59779917a819b121f230b8d423R45

Mikolaj commented 2 years ago

I've got a little further, but it seems a blind alley (scale doesn't type-check):

class Delta d s r where
  scaleDelta :: s -> d r -> d r

instance Delta DeltaScalar r r where
  scaleDelta = ScaleScalar

instance Delta DeltaVector (Vector r) r where
  scaleDelta = ScaleVector

instance Delta DeltaMatrix (Matrix r) r where
  scaleDelta = ScaleMatrix

data DualNumber s = forall d r. D s (d s r)

scale :: (Num s, Delta d s r) => s -> DualNumber s -> DualNumber s
scale s (D u u') = D (s * u) (scaleDelta s u')
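
-- Why scale can't type-check here, as far as I can tell: d and r occur
-- only in the constraint of scale's signature, so they are ambiguous to
-- begin with, and the d and r hidden existentially inside DualNumber
-- carry no Delta dictionary, so the two could never be connected anyway.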

Edit: here it's even more obvious why what I'm trying is a blind alley (this fails too):

sumElements1 :: Numeric r
             => DualNumber (Vector r) -> DualNumber r
sumElements1 (D u u') = D (HM.sumElements u) (SumElements1 u' (V.length u))

Edit2: Anyway, I give up. I'm out of ideas. I need further guidance. :)

goldfirere commented 2 years ago

Does this help?

class Delta f where
  scaleDelta :: ScaleType f r -> f r -> f r
  type ScaleType f r

instance Delta DeltaScalar where
  scaleDelta = ScaleScalar
  type ScaleType DeltaScalar r = r

instance Delta DeltaVector where
  scaleDelta = ScaleVector
  type ScaleType DeltaVector r = Vector r

instance Delta DeltaMatrix where
  scaleDelta = ScaleMatrix
  type ScaleType DeltaMatrix r = Matrix r
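
Then a dual number indexed by the delta functor could look something like the following sketch (DualNumberF and scaleF are made-up names, and I'm assuming pointwise Num instances for the primal types):

data DualNumberF f r = DF (ScaleType f r) (f r)

scaleF :: (Delta f, Num (ScaleType f r))
       => ScaleType f r -> DualNumberF f r -> DualNumberF f r
scaleF s (DF u u') = DF (s * u) (scaleDelta s u')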
Mikolaj commented 2 years ago

It sure does help. Now I'm stuck on both of the operations failing to type-check (due to a silly definition of DualNumber, but it should be used more or less as in these operations; it is used like that right now):

https://github.com/Mikolaj/horde-ad/blob/0f8ee16838c859acccf1a20e660c83b47f1fbfdd/src/HordeAd/Core/Delta.hs#L44-L66

Mikolaj commented 2 years ago

But this one got much closer!

class Delta a where
  scaleDelta :: a -> DeltaExpression a -> DeltaExpression a
  type DeltaExpression a

instance Delta Double where
  scaleDelta = ScaleScalar
  type DeltaExpression Double = DeltaScalar Double

instance Delta Float where
  scaleDelta = ScaleScalar
  type DeltaExpression Float = DeltaScalar Float

instance Delta (Vector r) where
  scaleDelta = ScaleVector
  type DeltaExpression (Vector r) = DeltaVector r

instance Delta (Matrix r) where
  scaleDelta = ScaleMatrix
  type DeltaExpression (Matrix r) = DeltaMatrix r

data DualNumber a = forall d. D a (DeltaExpression a)

scale :: (Num a, Delta a) => a -> DualNumber a -> DualNumber a
scale a (D u u') = D (a * u) (scaleDelta a u')

sumElements1 :: (Numeric r, Delta r) => DualNumber (Vector r) -> DualNumber r
sumElements1 (D u u') = D (HM.sumElements u) (SumElements1 u' (V.length u))

Line 70 is the definition of the sumElements1 function:

src/HordeAd/Core/Delta.hs:70:47: error:
    • Couldn't match expected type ‘DeltaExpression r’
                  with actual type ‘DeltaScalar r’
    • In the second argument of ‘D’, namely
        ‘(SumElements1 u' (V.length u))’
      In the expression:
        D (HM.sumElements u) (SumElements1 u' (V.length u))
      In an equation for ‘sumElements1’:
          sumElements1 (D u u')
            = D (HM.sumElements u) (SumElements1 u' (V.length u))
    • Relevant bindings include
        u' :: DeltaExpression (Vector r)
          (bound at src/HordeAd/Core/Delta.hs:70:19)
        u :: Vector r (bound at src/HordeAd/Core/Delta.hs:70:17)
        sumElements1 :: DualNumber (Vector r) -> DualNumber r
          (bound at src/HordeAd/Core/Delta.hs:70:1)
Mikolaj commented 2 years ago

Yay, the following works (for the two functions at least):


type family DeltaExpression a where
  DeltaExpression (Vector r) = DeltaVector r
  DeltaExpression (Matrix r) = DeltaMatrix r
  DeltaExpression a = DeltaScalar a

class Delta a where
  scaleDelta :: a -> DeltaExpression a -> DeltaExpression a

instance DeltaExpression r ~ DeltaScalar r => Delta r where
  scaleDelta = ScaleScalar

instance Delta (Vector r) where
  scaleDelta = ScaleVector

instance Delta (Matrix r) where
  scaleDelta = ScaleMatrix

data DualNumber a = D a (DeltaExpression a)

scale :: (Num a, Delta a) => a -> DualNumber a -> DualNumber a
scale a (D u u') = D (a * u) (scaleDelta a u')

sumElements1 :: (Numeric r, DeltaExpression r ~ DeltaScalar r)
             => DualNumber (Vector r) -> DualNumber r
sumElements1 (D u u') = D (HM.sumElements u) (SumElements1 u' (V.length u))
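
For the record, here are the language extensions this needs, as far as I can reconstruct them (take the exact set with a grain of salt):

{-# LANGUAGE TypeFamilies #-}         -- the closed family and the ~ constraint
{-# LANGUAGE FlexibleInstances #-}    -- the bare-variable instance head Delta r
{-# LANGUAGE UndecidableInstances #-} -- the context is no smaller than the head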
Mikolaj commented 2 years ago

Huh, it looked good despite a few extra constraints here and there, but eventually, when all the types are instantiated, I get overlapping-instances errors and I'm stuck again. So I may define instances for Double, Float and a couple of others, instead of the general fallback case, and call it done.

tomjaguarpaw commented 2 years ago

I'm confused about how a type class could possibly help here. There is a fixed, known number of instances (3), so I don't think one can do better with a type class than with three separate families of functions.

[EDIT: changed 2 to 3]

Mikolaj commented 2 years ago

@tomjaguarpaw: There are quite a few instances: Float, Double, CFloat, Vector Float, Matrix Double, etc. (not sure about Complex). Functions such as scale need to work with all of them. And I've just managed to get it to compile (not pushed yet).

tomjaguarpaw commented 2 years ago

Unless I've misunderstood something, there are only the three instances below, so I don't understand how a type class helps, as opposed to three functions: scaleDeltaScalar, scaleDeltaVector and scaleDeltaMatrix.


instance Delta DeltaScalar r r where
  scaleDelta = ScaleScalar

instance Delta DeltaVector (Vector r) r where
  scaleDelta = ScaleVector

instance Delta DeltaMatrix (Matrix r) r where
  scaleDelta = ScaleMatrix

(I said 2 instances earlier, I should have said 3)

Mikolaj commented 2 years ago

Yes, these are the three distinct instances. Thanks to the class you can write

scale :: (Num a, Delta a) => a -> DualNumber a -> DualNumber a
scale a (D u u') = D (a * u) (scaleDelta a u')

and not three different scale functions. That's even more important for Num, etc.
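
For example, given a hypothetical addDelta method next to scaleDelta in the class, a single Num instance covers every rank at once (a sketch only; fromInteger, abs and signum would additionally need something like a zero delta):

instance (Num a, Delta a) => Num (DualNumber a) where
  D u u' + D v v' = D (u + v) (addDelta u' v')
  D u u' * D v v' = D (u * v)
                      (addDelta (scaleDelta v u') (scaleDelta u v'))
  negate (D u u') = D (negate u) (scaleDelta (-1) u')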

tomjaguarpaw commented 2 years ago

Hmm, I'm not following. I don't understand how

scale :: (Num a, Delta a) => a -> DualNumber a -> DualNumber a

is a valid thing to write. Delta takes three type arguments, not one.

tomjaguarpaw commented 2 years ago

Oh, it seems to be different in the linked version

https://github.com/Mikolaj/horde-ad/commit/a78236a8fd0d23dac687f0a354567cc7a6cf9a8b#diff-d9c2c6cd672992a85fefc4a8e29abf8de2e45d59779917a819b121f230b8d423R44-R54

Mikolaj commented 2 years ago

Sorry for the broken examples. I will push to master in a quarter of an hour or two, so everything should become clear. Suggestions will be very much welcome!

tomjaguarpaw commented 2 years ago

OK, thanks. Could you post a direct link to the code in question once you've done so?

Mikolaj commented 2 years ago

Uh-oh, I ended up with UndecidableInstances. Help! Pushed to master anyway.

@tomjaguarpaw: the new delta expression definitions: https://github.com/Mikolaj/horde-ad/blob/master/src/HordeAd/Core/Delta.hs

Mikolaj commented 2 years ago

And the Num and other definitions that don't need to be duplicated: https://github.com/Mikolaj/horde-ad/blob/master/src/HordeAd/Core/DualNumber.hs

Mikolaj commented 2 years ago

And the bit that requires UndecidableInstances:

https://github.com/Mikolaj/horde-ad/blob/master/src/HordeAd/Core/Engine.hs#L50

The error message is:

src/HordeAd/Core/Engine.hs:50:10: error:
    • The constraint ‘IsTensor (Vector r)’
        is no smaller than
        the instance head ‘DeltaMonad r (DeltaMonadValue r)’
      (Use UndecidableInstances to permit this)
    • In the instance declaration for
        ‘DeltaMonad r (DeltaMonadValue r)’
tomjaguarpaw commented 2 years ago

I see, so now the type class in use is IsTensor? I agree that makes sense as a type class.

Mikolaj commented 2 years ago

Yes. The only other class defined in the codebase is DeltaMonad, one of the instances of which needs UndecidableInstances.

tomjaguarpaw commented 2 years ago

Ah, OK, I think I understand. You are defining the class of monads in which we can write programs. DeltaMonadValue just gives you the original value of the program. DeltaMonadGradient implements the algorithm from our paper.

Mikolaj commented 2 years ago

Precisely. DeltaMonadValue is an order of magnitude faster if you are only interested in the value (that is, in actually using the trained neural net, say). I'm not sure if there are going to be any other instances, so if DeltaMonadGradient were fast enough, we could dump the class, but even then it's a good way to ensure we are not depending on implementation details of DeltaMonadGradient.

Edit: except where we do depend on them, which is wherever we mention DeltaMonadGradient directly; there are a couple of such places, and probably not all of them make sense.

tomjaguarpaw commented 2 years ago

OK, I don't think UndecidableInstances is too worrying for now. It tends to be a fairly benign extension.

Mikolaj commented 2 years ago

Phew, I'll take your word for it.

goldfirere commented 2 years ago

Other issues having been resolved, I'll just jump in to say that I agree that UndecidableInstances is benign. It means that GHC might not terminate, but if it does, your program is fine.
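
For instance, this toy example (nothing from this codebase) is accepted with the extension:

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}

class Loop a
instance Loop a => Loop a
-- Trying to solve any Loop t constraint now recurses until GHC's
-- reduction-depth limit, but only at compile time; run-time behaviour
-- of accepted programs is unaffected.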

Mikolaj commented 2 years ago

Thank you. I tried to use the new IsTensor class to get a simpler @DeltaBinding@

data DeltaBinding = forall a. DeltaBinding (DeltaId a) (DeltaExpression a)

and only one operation

returnLet :: DualNumber a -> m (DualNumber a)

instead of the three we have now, one for each tensor rank, but I failed. The trouble is that it's impractical to try and make eval1, eval2 and

eval0 :: r -> DeltaScalar r -> ST s ()

become methods of the IsTensor class. So I tried adding

type Rank a :: Nat

to the class and dispatching on that, but my type-fu failed me: the compiler was not able to deduce that rank 1 of a binding implies it contains a DeltaVector delta expression. Probably the information is buried too deep down the type-level hole. Or is there a trick worth exploring?

goldfirere commented 2 years ago

I'm a bit lost here. What's the goal in your last question? Just to reduce the number of return functions? In general, your DeltaBinding type does not strike me as useful, at least without an IsTensor constraint in there. More likely, though, you might want DeltaBinding to have three constructors -- one for each supported tensor level.

Mikolaj commented 2 years ago

Yes, the goal is to reduce the number of distinct delta-let functions from 3, currently, to 1, as in the paper (returnLet is a slight tweak of delta-let). At the middle level of abstraction that the library currently implements, the delta-lets are inserted manually by the user, so it's a mild nuisance to use three of them; moreover, most monadic functions that should be parametric over tensor rank are not, because the delta-let they use forces the rank.
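
Concretely, the current shape is roughly this (simplified; DualNumber.hs has the real signatures):

class Monad m => DeltaMonad r m where
  returnLet  :: DualNumber r -> m (DualNumber r)                    -- rank 0
  returnLetV :: DualNumber (Vector r) -> m (DualNumber (Vector r))  -- rank 1
  returnLetL :: DualNumber (Matrix r) -> m (DualNumber (Matrix r))  -- rank 2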

Indeed, DeltaBinding currently has three constructors; see the link to the file Delta.hs above. In itself that's not a problem.

Mikolaj commented 2 years ago

Doh, of course you are right: I don't need to change DeltaBinding and make eval a method of the class. It's enough to, kind of, make each of the three DeltaBinding constructors a method of IsTensor. I will implement it and show what I mean.

Mikolaj commented 2 years ago

This is what I had in mind. It works, and now monadic functions on dual numbers are polymorphic over tensor rank, just as the non-monadic ones always were!

https://github.com/Mikolaj/horde-ad/commit/2fdbcbed9bcc1e91f6b001735527f16079e5ec8f
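
In outline, the trick is the following (a simplified sketch with approximate names; the commit has the real definitions). IsTensor supplies, for each tensor type, the matching DeltaBinding constructor together with its underlying scalar, so a single returnLet serves every rank:

class IsTensor a where
  type ScalarOfTensor a
  -- picks the matching one of the three DeltaBinding constructors
  deltaBinding :: DeltaId a -> DeltaExpression a
               -> DeltaBinding (ScalarOfTensor a)

returnLet :: (IsTensor a, DeltaMonad (ScalarOfTensor a) m)
          => DualNumber a -> m (DualNumber a)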

Is there a way to simplify this machinery, e.g., to get rid of the associated type synonym ScalarOfTensor?

Mikolaj commented 2 years ago

Full success: there is no performance difference (all +RTS -s memory stats near-identical, all times similar) vs. before the overhaul, and a Show instance can now be defined for delta expressions, producing such ugly printouts as the one on the left here: https://github.com/Mikolaj/mostly-harmless/discussions/16#discussioncomment-2200709

Mikolaj commented 2 years ago

I've got rid of UndecidableInstances and -Wno-orphans in one go, at the cost of some silly boilerplate.

In other news, it turns out there is a performance cost (a six-fold slowdown) to the associated type synonyms after all. I had to revert an overzealous abstraction that used them: https://github.com/Mikolaj/horde-ad/commit/131584577bce752e270ef8199f1931e5741ed557

I wonder if it's known that associated type synonyms can prevent specialization (that's my guess), even in the presence of -fexpose-all-unfoldings -fspecialise-aggressively. The output of -Wall-missed-specialisations certainly differs with and without that workaround commit, but only in 9.2.1, which is terribly verbose, so I can't tell which warning is the red flag. I wonder what 9.2.2 or HEAD would do.

goldfirere commented 2 years ago

What if you use SPECIALIZE? Will that help?

It surprises me that the type families get in the way here. :(

Mikolaj commented 2 years ago

When I do

{-# SPECIALIZE makeDualNumberVariables ::
     ( Vector Double
     , Data.Vector.Vector (Vector Double)
     , Data.Vector.Vector (Matrix Double) )
  -> ( Data.Vector.Vector (Delta0 Double)
     , Data.Vector.Vector (Delta1 Double)
     , Data.Vector.Vector (Delta2 Double) )
  -> DualNumberVariables Double #-}

makeDualNumberVariables
  :: ( Vector r
     , Data.Vector.Vector (Vector r)
     , Data.Vector.Vector (Matrix r) )
  -> ( Data.Vector.Vector (DeltaExpression r)
     , Data.Vector.Vector (DeltaExpression (Vector r))
     , Data.Vector.Vector (DeltaExpression (Matrix r)) )
  -> DualNumberVariables r

(these are newer type names than in the original workaround commit, but it's equivalent), it tells me

src/HordeAd/Core/PairOfVectors.hs:29:1: warning:
    SPECIALISE pragma for non-overloaded function ‘makeDualNumberVariables’
   |
29 | {-# SPECIALIZE makeDualNumberVariables ::
   | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^...

and indeed there's no speedup. I don't know what else to specialize, because all the functions from this culprit down to the final monomorphic function are polymorphic.

Mikolaj commented 2 years ago

Actually, I'm wrong. The slowdown is on the code path where the function's value is computed, so 1. there are only a few functions in that chain and 2. one of them takes an instance of IsTensor and then completely ignores it (we don't compute gradients at all).

Mikolaj commented 2 years ago

Amazing: adding the two SPECIALIZE pragmas (neither alone is enough) in this commit

https://github.com/Mikolaj/horde-ad/commit/229a9b834c80a1a98ece769ee49a634951f25405

takes the code from running six times slower and allocating twice as much as the version without this commit to, wait for it, allocating one hundred times less and running three times faster than the version without it.

Confirmed on 8.10.7, 9.0.2 and 9.2.1. Smells like a GHC bug. Also, there is no warning about the missed specialization despite -Wall-missed-specialisations, except perhaps in 9.2.1, but it spams too much to verify. Is this mess expected/known, or should I search the GHC bug tracker and report it?

goldfirere commented 2 years ago

:)

@simonpj has more informed expectations for when SPECIALIZE should make a difference.

Mikolaj commented 2 years ago

The SPECIALIZE problem is now handled in #14.