Closed. Mikolaj closed this 2 years ago.
I've tried to overhaul the Delta definition using a type class, but I'm stuck.
I've got a little further, but it seems to be a blind alley (scale doesn't type-check):
class Delta d s r where
  scaleDelta :: s -> d r -> d r
instance Delta DeltaScalar r r where
  scaleDelta = ScaleScalar
instance Delta DeltaVector (Vector r) r where
  scaleDelta = ScaleVector
instance Delta DeltaMatrix (Matrix r) r where
  scaleDelta = ScaleMatrix

data DualNumber s = forall d r. D s (d s r)

scale :: (Num s, Delta d s r) => s -> DualNumber s -> DualNumber s
scale s (D u u') = D (s * u) (scaleDelta s u')
Edit: Here it's even more obvious why what I'm trying is a blind alley (this fails too):
sumElements1 :: Numeric r
             => DualNumber (Vector r) -> DualNumber r
sumElements1 (D u u') = D (HM.sumElements u) (SumElements1 u' (V.length u))
Edit2: Anyway, I give up. I'm out of ideas. I need further guidance. :)
Does this help?
class Delta f where
  scaleDelta :: ScaleType f r -> f r -> f r
  type ScaleType f r
instance Delta DeltaScalar where
  scaleDelta = ScaleScalar
  type ScaleType DeltaScalar r = r
instance Delta DeltaVector where
  scaleDelta = ScaleVector
  type ScaleType DeltaVector r = Vector r
instance Delta DeltaMatrix where
  scaleDelta = ScaleMatrix
  type ScaleType DeltaMatrix r = Matrix r
It sure does help. Now I'm stuck with both operations not type-checking (due to a silly definition of DualNumber, but it should be used more or less as in the operations; it is used like that right now):
But this one got much closer!
class Delta a where
  scaleDelta :: a -> DeltaExpression a -> DeltaExpression a
  type DeltaExpression a
instance Delta Double where
  scaleDelta = ScaleScalar
  type DeltaExpression Double = DeltaScalar Double
instance Delta Float where
  scaleDelta = ScaleScalar
  type DeltaExpression Float = DeltaScalar Float
instance Delta (Vector r) where
  scaleDelta = ScaleVector
  type DeltaExpression (Vector r) = DeltaVector r
instance Delta (Matrix r) where
  scaleDelta = ScaleMatrix
  type DeltaExpression (Matrix r) = DeltaMatrix r

data DualNumber a = forall d. D a (DeltaExpression a)

scale :: (Num a, Delta a) => a -> DualNumber a -> DualNumber a
scale a (D u u') = D (a * u) (scaleDelta a u')

sumElements1 :: (Numeric r, Delta r) => DualNumber (Vector r) -> DualNumber r
sumElements1 (D u u') = D (HM.sumElements u) (SumElements1 u' (V.length u))
Line 70 is the definition of the sumElements1 function:
src/HordeAd/Core/Delta.hs:70:47: error:
• Couldn't match expected type ‘DeltaExpression r’
with actual type ‘DeltaScalar r’
• In the second argument of ‘D’, namely
‘(SumElements1 u' (V.length u))’
In the expression:
D (HM.sumElements u) (SumElements1 u' (V.length u))
In an equation for ‘sumElements1’:
sumElements1 (D u u')
= D (HM.sumElements u) (SumElements1 u' (V.length u))
• Relevant bindings include
u' :: DeltaExpression (Vector r)
(bound at src/HordeAd/Core/Delta.hs:70:19)
u :: Vector r (bound at src/HordeAd/Core/Delta.hs:70:17)
sumElements1 :: DualNumber (Vector r) -> DualNumber r
(bound at src/HordeAd/Core/Delta.hs:70:1)
Yay, the following works (for the two functions at least):
type family DeltaExpression a where
  DeltaExpression (Vector r) = DeltaVector r
  DeltaExpression (Matrix r) = DeltaMatrix r
  DeltaExpression a = DeltaScalar a

class Delta a where
  scaleDelta :: a -> DeltaExpression a -> DeltaExpression a
instance DeltaExpression r ~ DeltaScalar r => Delta r where
  scaleDelta = ScaleScalar
instance Delta (Vector r) where
  scaleDelta = ScaleVector
instance Delta (Matrix r) where
  scaleDelta = ScaleMatrix

data DualNumber a = D a (DeltaExpression a)

scale :: (Num a, Delta a) => a -> DualNumber a -> DualNumber a
scale a (D u u') = D (a * u) (scaleDelta a u')

sumElements1 :: (Numeric r, DeltaExpression r ~ DeltaScalar r)
             => DualNumber (Vector r) -> DualNumber r
sumElements1 (D u u') = D (HM.sumElements u) (SumElements1 u' (V.length u))
Huh, it looked good despite a few extra constraints here and there, but eventually, when all types are instantiated, I'm getting overlapping instances errors and I'm stuck again. So I may define instances for Double, Float and a couple of others, instead of the general fallback case, and call it done.
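For illustration, the "instances for Double, Float and a couple of others" route might look roughly like the sketch below. It is only a sketch under assumptions, not the code that ended up in the repository: Vec is a hypothetical list-backed stand-in for hmatrix's Vector, the delta types are cut down to a few constructors, and the Matrix case is omitted. The point is that an open type family with one equation per concrete scalar has no catch-all equation, so there is no general `Delta r` instance left to overlap with `Delta (Vec r)`.

```haskell
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- Hypothetical stand-in for hmatrix's Vector, to keep the sketch dependency-free.
newtype Vec r = Vec [r]

-- Cut-down delta expression types; the real ones have many more constructors.
data DeltaScalar r
  = ZeroScalar
  | ScaleScalar r (DeltaScalar r)
  | SumElements1 (DeltaVector r) Int

data DeltaVector r
  = ZeroVector
  | ScaleVector (Vec r) (DeltaVector r)

-- Open type family: one equation per concrete type, no fallback equation.
type family DeltaExpression a
type instance DeltaExpression Double = DeltaScalar Double
type instance DeltaExpression Float = DeltaScalar Float
type instance DeltaExpression (Vec r) = DeltaVector r

class Delta a where
  scaleDelta :: a -> DeltaExpression a -> DeltaExpression a

-- One instance per scalar type replaces the general fallback instance,
-- so no instance head overlaps with the Vec one.
instance Delta Double where scaleDelta = ScaleScalar
instance Delta Float where scaleDelta = ScaleScalar
instance Delta (Vec r) where scaleDelta = ScaleVector

data DualNumber a = D a (DeltaExpression a)

scale :: (Num a, Delta a) => a -> DualNumber a -> DualNumber a
scale a (D u u') = D (a * u) (scaleDelta a u')

-- The equality constraint is still needed: for an abstract r the compiler
-- cannot reduce DeltaExpression r on its own.
sumElements1 :: (Num r, DeltaExpression r ~ DeltaScalar r)
             => DualNumber (Vec r) -> DualNumber r
sumElements1 (D (Vec u) u') = D (sum u) (SumElements1 u' (length u))
```

The cost is one type instance and one class instance per supported scalar type, which is boilerplate but keeps instance resolution unambiguous.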
I'm confused about how a type class could possibly help here. There is a fixed, known number of instances (3) so I don't think one can do better with a type class than with two separate families of functions.
[EDIT: changed 2 to 3]
@tomjaguarpaw: There are quite a few instances: Float, Double, CFloat, Vector Float, Matrix Double, etc. (not sure about Complex). Functions such as scale need to work with all of them. And I've just managed to get it to compile (not pushed yet).
Unless I've misunderstood something, there are only these three instances below, therefore I don't understand how a type class helps, as opposed to three functions: scaleDeltaScalar, scaleDeltaVector and scaleDeltaMatrix.
instance Delta DeltaScalar r r where
  scaleDelta = ScaleScalar
instance Delta DeltaVector (Vector r) r where
  scaleDelta = ScaleVector
instance Delta DeltaMatrix (Matrix r) r where
  scaleDelta = ScaleMatrix
(I said 2 instances earlier, I should have said 3)
Yes, these are the three distinct instances. Thanks to the class you can write
scale :: (Num a, Delta a) => a -> DualNumber a -> DualNumber a
scale a (D u u') = D (a * u) (scaleDelta a u')
and not 3 different scale functions. That's even more important for Num, etc.
Hmm, I'm not following. I don't understand how
scale :: (Num a, Delta a) => a -> DualNumber a -> DualNumber a
is a valid thing to write. Delta takes three type arguments, not one.
Oh, it seems to be different in the linked version
Sorry for broken examples. I will push to master in a quarter or two, so everything should be clear. Suggestions will be very much welcome!
OK, thanks. Could you post a direct link to the code in question once you've done so?
Uh-oh, I ended up with UndecidableInstances. Help! Pushed to master anyway.
@tomjaguarpaw: the new delta expression definitions: https://github.com/Mikolaj/horde-ad/blob/master/src/HordeAd/Core/Delta.hs
And the Num and other definitions that don't need to be duplicated: https://github.com/Mikolaj/horde-ad/blob/master/src/HordeAd/Core/DualNumber.hs
And the bit that requires UndecidableInstances:
https://github.com/Mikolaj/horde-ad/blob/master/src/HordeAd/Core/Engine.hs#L50
Error message is
src/HordeAd/Core/Engine.hs:50:10: error:
• The constraint ‘IsTensor (Vector r)’
is no smaller than
the instance head ‘DeltaMonad r (DeltaMonadValue r)’
(Use UndecidableInstances to permit this)
• In the instance declaration for
‘DeltaMonad r (DeltaMonadValue r)’
I see, so the type class in use is now IsTensor? I agree that makes sense as a type class.
Yes. The only other class defined in the codebase is DeltaMonad, one of whose instances needs UndecidableInstances.
Ah, OK, I think I understand. You are defining the class of monads in which we can write programs. DeltaMonadValue just gives you the original value of the program. DeltaMonadGradient implements the algorithm from our paper.
Precisely. DeltaMonadValue is an order of magnitude faster if you are only interested in the value (that is, in actually using the trained neural net, say). I'm not sure there are going to be any other instances, so if DeltaMonadGradient were fast enough, we could drop the class, but even then it's a good way to ensure we are not using implementation details of DeltaMonadGradient.
Edit: except where we do use them, which is where we mention DeltaMonadGradient directly; there are a couple of such places, and probably not all of them make sense.
OK, I don't think UndecidableInstances is too worrying for now. It tends to be a fairly benign extension.
phew I take your word for it.
Other issues having been resolved, I'll just jump in to say that I agree that UndecidableInstances is benign. It means that GHC's type checker might not terminate, but if it does, your program is fine.
Thank you. I tried to use the new IsTensor class to get a simpler DeltaBinding
data DeltaBinding = forall a. DeltaBinding (DeltaId a) (DeltaExpression a)
and only one operation
returnLet :: DualNumber a -> m (DualNumber a)
instead of the three we have now, one for each scalar rank, but I failed. The trouble is that it's impractical to try and make eval1, eval2 and
eval0 :: r -> DeltaScalar r -> ST s ()
become methods of the IsTensor class. So I tried adding
type Rank a :: Nat
to the class and dispatching on that, but my type-fu failed me and the compiler was not able to deduce that rank 1 of a binding implies that it contains a DeltaVector delta expression. Probably too deep down the type-level hole. Or is there a trick worth exploring?
I'm a bit lost here. What's the goal in your last question? Just to reduce the number of return functions? In general, your DeltaBinding type does not strike me as useful, at least without an IsTensor constraint in there. More likely, though, you might want DeltaBinding to have three constructors, one for each supported tensor level.
Yes, the goal is to reduce the number of distinct delta-let functions from 3, currently, to 1, as in the paper (returnLet is a slight tweak of delta-let). At the middle level of abstraction that the library currently implements, the delta-lets are inserted manually by the user, so it's a mild nuisance to use 3 of them. Also, most monadic functions that should be parametric over tensor rank are not, because the delta-let that is used forces the rank.
Indeed, currently DeltaBinding has three constructors (see the link to file Delta.hs above). In itself it's not a problem.
Doh, of course you are right, I don't need to change DeltaBinding and make eval a method of the class. It's enough to, kind of, make each of the three DeltaBinding constructors a method of IsTensor. I will implement and show what I mean.
This is what I had in mind and it works and now monadic functions on dual numbers are polymorphic over tensor rank, just as non-monadic always were!
https://github.com/Mikolaj/horde-ad/commit/2fdbcbed9bcc1e91f6b001735527f16079e5ec8f
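To make the idea concrete without opening the commit, here is a minimal sketch of "each DeltaBinding constructor becomes a method of IsTensor". It is an assumption-laden toy, not the code from the commit: Vec is a hypothetical list-backed stand-in for hmatrix's Vector, the delta types are reduced to a single constructor, the matrix rank is left out, and the names bindInState and recordBinding are made up for the example. IsTensor, DeltaBinding, DeltaId, DeltaExpression and ScalarOfTensor are the names used in this thread.

```haskell
{-# LANGUAGE TypeFamilies #-}

-- Hypothetical stand-in for hmatrix's Vector.
newtype Vec r = Vec [r]

newtype DeltaId a = DeltaId Int

-- Cut-down delta expressions; the real types have many constructors.
data DeltaScalar r = ZeroScalar
data DeltaVector r = ZeroVector

-- DeltaBinding keeps one constructor per supported tensor rank.
data DeltaBinding r
  = DScalar (DeltaId r) (DeltaScalar r)
  | DVector (DeltaId (Vec r)) (DeltaVector r)

class IsTensor a where
  type DeltaExpression a
  type ScalarOfTensor a
  -- Each instance picks the DeltaBinding constructor for its own rank.
  bindInState :: DeltaId a -> DeltaExpression a
              -> DeltaBinding (ScalarOfTensor a)

instance IsTensor Double where
  type DeltaExpression Double = DeltaScalar Double
  type ScalarOfTensor Double = Double
  bindInState = DScalar

instance IsTensor (Vec r) where
  type DeltaExpression (Vec r) = DeltaVector r
  type ScalarOfTensor (Vec r) = r
  bindInState = DVector

data DualNumber a = D a (DeltaExpression a)

-- A pure stand-in for the monadic returnLet: a single rank-polymorphic
-- function that records the binding for a dual number of any rank.
recordBinding :: IsTensor a
              => DeltaId a -> DualNumber a
              -> [DeltaBinding (ScalarOfTensor a)]
              -> [DeltaBinding (ScalarOfTensor a)]
recordBinding i (D _ u') bindings = bindInState i u' : bindings
```

With this shape the evaluator can keep pattern matching on the three constructors, so eval never has to become a class method; only the choice of constructor is dispatched through the class.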
Is there a way to simplify this machinery, e.g., to get rid of the associated type synonym ScalarOfTensor?
Full success: no performance difference (all +RTS -s memory stats near-identical, all times similar) vs before the overhaul, and a Show instance can now be defined for delta expressions, producing such ugly printouts as the one on the left here: https://github.com/Mikolaj/mostly-harmless/discussions/16#discussioncomment-2200709
I've got rid of the UndecidableInstances and -Wno-orphans in one go, at the cost of some silly boilerplate.
In other news, there is a performance cost after all (a sixfold slowdown) to the associated type synonyms. I had to revert an overzealous abstraction that used them: https://github.com/Mikolaj/horde-ad/commit/131584577bce752e270ef8199f1931e5741ed557
I wonder if it's known that associated type synonyms can prevent specialization (that's my guess) even in the presence of -fexpose-all-unfoldings -fspecialise-aggressively. The output of -Wall-missed-specialisations certainly is different with and without that workaround commit, but only in 9.2.1, which is terribly verbose, so I can't tell which warning is a red flag. I wonder what 9.2.2 or head would do.
What if you use SPECIALIZE? Will that help?
It surprises me that the type families get in the way here. :(
When I do
{-# SPECIALIZE makeDualNumberVariables ::
     ( Vector Double
     , Data.Vector.Vector (Vector Double)
     , Data.Vector.Vector (Matrix Double) )
  -> ( Data.Vector.Vector (Delta0 Double)
     , Data.Vector.Vector (Delta1 Double)
     , Data.Vector.Vector (Delta2 Double) )
  -> DualNumberVariables Double #-}
makeDualNumberVariables
  :: ( Vector r
     , Data.Vector.Vector (Vector r)
     , Data.Vector.Vector (Matrix r) )
  -> ( Data.Vector.Vector (DeltaExpression r)
     , Data.Vector.Vector (DeltaExpression (Vector r))
     , Data.Vector.Vector (DeltaExpression (Matrix r)) )
  -> DualNumberVariables r
(newer names of types than in the original workaround commit, but it's equivalent) it tells me
src/HordeAd/Core/PairOfVectors.hs:29:1: warning:
SPECIALISE pragma for non-overloaded function ‘makeDualNumberVariables’
|
29 | {-# SPECIALIZE makeDualNumberVariables ::
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^...
and indeed there's no speedup. I don't know what else to specialize, because all functions from this culprit to the final monomorphic function are polymorphic.
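As background to the warning above: SPECIALIZE works by fixing a function's class dictionaries at a concrete type, and the signature of makeDualNumberVariables shown here has no class constraint, so GHC has nothing to specialise. The toy below illustrates the distinction; sumSquares and pairUp are hypothetical names invented for this sketch and have nothing to do with the horde-ad code.

```haskell
-- Overloaded: the Num a constraint is passed as a dictionary at run time
-- unless GHC specialises the function at a concrete type.
sumSquares :: Num a => [a] -> a
sumSquares = foldr (\x acc -> x * x + acc) 0
-- Ask for a copy of sumSquares with the Num Double dictionary fixed.
{-# SPECIALIZE sumSquares :: [Double] -> Double #-}

-- Polymorphic but not overloaded: there is no class constraint and hence
-- no dictionary to remove. A SPECIALIZE pragma on pairUp would draw the
-- same "non-overloaded function" warning quoted above.
pairUp :: [a] -> [b] -> [(a, b)]
pairUp = zip
```

So in a chain of polymorphic functions, the candidates for SPECIALIZE are the ones that carry a constraint such as IsTensor or Num, not the unconstrained ones in between.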
Actually, I'm wrong. The slowdown is on the code path where the function's value is computed, so (1) there are only a few functions in that chain and (2) one of them uses an instance of IsTensor and then completely ignores it (we don't compute gradients at all).
Amazing, adding the two SPECIALIZE pragmas (neither alone is enough) in this commit
https://github.com/Mikolaj/horde-ad/commit/229a9b834c80a1a98ece769ee49a634951f25405
gets the code from running 6 times slower and allocating twice as much as the version without this commit to, wait for it, allocating one hundred times less and running three times faster than the version without this commit.
Confirmed on 8.10.7, 9.0.2 and 9.2.1. Smells like a GHC bug. Also, no warning about missed specialization despite -Wall-missed-specialisations, except perhaps in 9.2.1, but it spams too much to verify. Is this mess expected/known or should I search the GHC bug tracker and report it?
:)
@simonpj has more informed expectations for when SPECIALIZE should make a difference.
The SPECIALIZE problem is now handled in #14.
In short, the compiler is so confused that it can't eliminate all the unhandled cases in buildVectors in https://github.com/Mikolaj/horde-ad/blob/master/src/HordeAd/Core/Delta.hs that the GADT was specifically written to rule out. Humans tend to be confused as well, e.g., not grasping immediately that Delta (Vector (Vector Float)) and Delta (Delta Double) are unintended. Recent futile attempts at deriving a Show instance for the type confirmed the confusion and made me ask for help.
Ideally, the solution would not unduly complicate the DualNumber type nor any other type further up the abstraction ladder. Nor should too much explicit coercion be required, because the r type (e.g., Double) really is, at the same time, the underlying scalar type and the level 0 on the ladder of tensor ranks, for each of which we have specialized Delta operations (but for specific underlying scalars we don't have special operations).
@goldfirere: this is one of the (two, so far) hard (for me) typing problems I'm stumbling over that I mentioned to you.