Closed: Mikolaj closed this issue 1 year ago.
This doesn't actually work, because my Evidence typeclass just generalises over Float and Double, and it won't permit an Ast0 instance: we get overlapping instances here, because I can only reduce the Float and Double instances to a single one.
However, I think this idea could work with a more extended set of type equalities, giving the precise equalities the code needs instead of just the plain definitions in terms of bare types that only work for Float and Double. It's a bludgeon hammer though; not sure if you want to do this. Just an idea.
This is a patch on 37bd69171c02e8a3acd951527763c36cd29a39c2. As before with these patches, intentionally no code styling or appropriate naming has been applied; that's for you to decide. I just wanted to see the GHC errors.
It is a bludgeon hammer. :D
So, instead of the illegal quantified constraints over type families, we promise to supply dicts and proofs of type equality for the same assumptions, and the code needs to distribute them properly? That distribution thing is rather annoying, but easier to maintain than code duplication. It's also good that this machinery does not leak beyond this code (neither up nor down the code dependency tree).
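For concreteness, a minimal self-contained sketch of this dict-and-proof-passing pattern, with toy stand-ins (a hypothetical family F and a toy equality F a :~: Int, not the actual horde-ad types):

{-# LANGUAGE ConstraintKinds, GADTs, ScopedTypeVariables, TypeFamilies,
             TypeOperators #-}

import Data.Proxy (Proxy (..))
import Data.Type.Equality ((:~:) (..))

-- A Dict as in the constraints package: a box holding a constraint.
data Dict c where
  Dict :: c => Dict c

type family F a

-- The class method hands back explicit evidence about F a.
class Evidence a where
  ev :: Proxy a -> (F a :~: Int, Dict (Num (F a)))

-- Pattern-matching on the evidence brings the type equality and the Num
-- dictionary into scope exactly where the body needs them:
double :: forall a. Evidence a => Proxy a -> F a -> F a
double p x | (Refl, Dict) <- ev p = x + x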
I simplified the relevant type classes and tried to generalize this further, but got discouraged by not being able to find an n to apply in

interpretAstBool :: AstEnv a
                 -> AstBool (ScalarOf a) -> BooleanOf (Primal a)
interpretAstBool env | (Refl, Refl, Refl, _, _) <- ev @a @n Proxy = \case

in order to provide BooleanOf (TensorOf n (Primal a)) :~: BooleanOf (IntOf a)
from
ev :: forall n. Proxy a
   -> ( TensorOf n a :~: TensorOf n a
      , BooleanOf (TensorOf n a) :~: BooleanOf (TensorOf n (Primal a))
      , BooleanOf (TensorOf n (Primal a)) :~: BooleanOf (IntOf a)
      , Dict EqB (TensorOf n a)
      , Dict OrdB (TensorOf n a) )
needed for a recursive call
AstRel opCodeRel args ->
  let f v = interpretAstPrimal env (AstPrimalPart v)
  in interpretAstRelOp f opCodeRel args
because the n is known from AstRel, not from the signature of interpretAstBool. And indeed, the following works
AstRel @n opCodeRel args | (Refl, Refl, Refl, _, _) <- ev @a @n Proxy ->
  let f v = interpretAstPrimal env (AstPrimalPart v)
  in interpretAstRelOp f opCodeRel args
but that means I need to summon the evidence all over the code, for various n in hand. But perhaps that's not going to be so bad; there are only so many existential constructors in the grammar. I guess I'd need to do this dance only for each type signature and existential constructor?
Yeah I feel like that'll be unavoidable; trying to get a quantified constraint in scope is wading into ImpredicativeTypes territory where I expect the troubles will be even worse than simply having a type class instance over a type family.
Perhaps, if most of the problematic calls are to the same functions, it might be possible to write wrapper functions for those that have more general type signatures and that construct the evidence for themselves in their function body. That may or may not reduce the number of times you have to write that evidence creation code. But I suspect the code will be much less readable for it; there is a lot to say for not making type signatures more complicated than they already are.
I panicked way too soon. Look how simple it ended up being:
I can't say I understand this success, but I certainly don't complain. :)
What changed? Did you only end up needing it for AstRel?
I like putting most of the equality constraints in the superclass list, that's a lot less Refl :)
Yes, only needed in two places: the header of interpretAst (probably only for AstOp, which requires RealFloat (TensorOf n a)) and the constructor AstRel, which requires the rest. But that was a crazy roller-coaster. GHC asked me for a lot of evidence, e.g., IfB, and a lot of type equations; the list kept growing, I was pulling my hair out, and then, eventually, it all proved spurious, somehow deducible from the final set there. ¯\_(ツ)_/¯
Whoa, HordeAd.Core.AstInterpret.$fEvidenceADVal11 makes up 5% of allocations of the big Ast gradient Mnist test. That's certainly due to the lack of sharing of Ast terms during delta evaluation, which then results in a costly Ast interpretation pass, but still, I hoped the dictionaries would get specialized/inlined somehow. I even added manual SPECIALIZE pragmas, and this is on GHC 9.6.1, which is much better at specialization.
In any case, it works fine even under this artificial load (due to bad sharing), so let me close the ticket.
Edit: and 3.5% of runtime.
@edsko just contributed a clear summary of the two known workarounds
https://gitlab.haskell.org/ghc/ghc/-/issues/14860#note_495352
so perhaps this time I can understand the crux of our problem. Are the workarounds not applicable to our use case because the TensorOf type family has two arguments, and that's why we need the explicit dictionaries, etc.?
This seems such a silly reason. Just have the pair construct in the language of types and the problem is gone, right?
I don't think the point is the absolute value of the arity; the point is whether the quantification ranges over a variable that the type family dispatches on or not. In Edsko's first example, the quantified variable is a, but the type family (F x a) analyses only x. Hence we can "let-bind" F x outside of the quantification, which is workaround 1. Workaround 2 is, I guess, a "let-bind" on a higher level? I'm not 100% sure how to explain that one.
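To make workaround 1 concrete, here is a minimal, self-contained sketch (hypothetical names), under the assumption that F can be declared with arity 1, so that F x is a fully saturated family application of kind Type -> Type:

{-# LANGUAGE QuantifiedConstraints, RankNTypes, TypeFamilies #-}

import Data.Kind (Type)
import Data.Proxy (Proxy)

-- F dispatches only on x, so it can be given arity 1, returning an
-- unsaturated constructor:
type family F (x :: Type) :: Type -> Type

-- Illegal:  g :: (forall a. Eq (F x a)) => ...
-- Legal, with F x "let-bound" to the ordinary type variable f:
g :: forall x f. (f ~ F x, forall a. Eq (f a))
  => Proxy x -> f Int -> f Int -> Bool
g _ = (==)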
The immediate reason workaround 1 doesn't work for Edsko's second example is that GHC doesn't support partial application of type synonyms/families, but even if it did accept the code for type checking, it would still not work, because the let-binding doesn't do anything. This is similar to how transforming f x y + f x z to let g = \a -> f x a in g y + g z is not a useful application of "common" subexpression elimination, because we're not actually saving any computation. Similarly, abstracting out f ~ F x if F dispatches on a would also be useless, because then f a would still be a type family application.
The reason workaround 2 doesn't work for Edsko's second example is, I think, actually the core of the problem. GHC is asked to prove that for all a, we have P (f a) for some property P (namely, membership of the type class C) and some function f (namely F x). For an open type family F this would clearly not be provable, because what if someone adds a type instance later, in a different package, that violates the property? For closed type families it would seem possible, by exhaustively checking all the left-hand sides of F. But even then, an implementation of the dictionary for this quantified constraint would need to dispatch, at runtime, on the actual value of the type a that gets passed in. And since types are erased, that isn't going to work (without additional machinery like Typeable, or some other singleton representation).
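For reference, a hedged, self-contained sketch of the shape workaround 2 takes in the case where it does apply, i.e. when the family dispatches on a variable (here r) other than the quantified one (here n). The wrapper class mirrors the CTensorOf that shows up later in this thread; the toy TensorOf instance and addT/example are illustrative assumptions:

{-# LANGUAGE AllowAmbiguousTypes, ConstraintKinds, DataKinds,
             FlexibleContexts, FlexibleInstances, MultiParamTypeClasses,
             QuantifiedConstraints, RankNTypes, TypeApplications,
             TypeFamilies, UndecidableInstances, UndecidableSuperClasses #-}

import Data.Kind (Constraint, Type)
import GHC.TypeNats (Nat)

type family TensorOf (n :: Nat) (r :: Type) :: Type

-- The superclass carries the family application, so the quantified
-- constraint below mentions only the class in its head:
class c (TensorOf n r) => CTensorOf (c :: Type -> Constraint) (n :: Nat) r
instance c (TensorOf n r) => CTensorOf c n r

-- Illegal: (forall n. Num (TensorOf n r)) => ...; legal:
addT :: forall r n. (forall m. CTensorOf Num m r)
     => TensorOf n r -> TensorOf n r
addT t = t + t  -- Num (TensorOf n r) via instantiation at n, then superclass

-- A toy instance, to show the constraint is solvable at a concrete type:
type instance TensorOf n Double = Double

example :: Double
example = addT @Double @3 5.0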
EDIT: @Mikolaj, do you remember what the desired quantified constraint was?
@Mikolaj, do you remember what the desired quantified constraint was?
I think we wanted to assert about a that forall n. RealFloat (TensorOf n a), etc.
Yes, precisely. The reason why we might desire that to work is that for a particular concrete a, we are often going to have that TensorOf n a implements RealFloat for all n. This is not guaranteed to be so in all cases, but it will be so in all our cases. And even if it weren't so in all cases, one would still like to be able to satisfy the quantified constraint in the cases where it does work.
I would expect Edsko's workaround 1 to perhaps work if we flip the argument order of TensorOf, making it TensorOf r n, where it scrutinises the r only, i.e. type family TensorOf r :: Nat -> Type. Except this doesn't work, because TensorOf also has a functional dependency, and it seems one cannot give both a functional dependency and a kind signature to a type family at the same time. I'm not sure if this is just a forgotten case in the syntax, or whether there are actual problems in supporting the combination.
Actually maybe the second workaround just works? Or does this proof of concept not capture all the difficulties? https://play.haskell.org/saved/EBWCVstM
EDIT: Actually this is a slightly more complete demonstration: https://play.haskell.org/saved/Kl8A7A3e
Yay, it probably works for RealFloat. It objects to forall n. CTensorOf EqB n (Primal a), saying
simplified/HordeAd/Core/AstInterpret.hs:134:1: error:
    • Illegal type synonym family application ‘Primal a’ in instance:
        CTensorOf EqB n (Primal a)
    • In the quantified constraint ‘forall (n :: GHC.TypeNats.Nat).
                                    CTensorOf EqB n (Primal a)’
      In the type synonym declaration for ‘Evidence’
and the following just doesn't fit the schema:
But for that we have workaround 1, because the quantifier doesn't quantify over a: https://play.haskell.org/saved/QhYMdcrn
Wow, how strange. I need to name a bunch of such constraints
type Evi1 p = (Tensor p, forall n. CTensorOf Num n p, forall n. CTensorOf Eq n p)
type Evi2 a = Evi1 (Primal a)
and, strangely, it fails with
simplified/HordeAd/Core/AstInterpret.hs:145:27: error:
    • Expected a type, but ‘CTensorOf Num n p’ has kind ‘Constraint’
    • In the type ‘(forall n. CTensorOf Num n p,
                    forall n. CTensorOf Eq n p)’
      In the type declaration for ‘Evi1’
    |
145 | type Evi1 p = ( forall n. CTensorOf Num n p
    |                 ^^^^^^^^^^^^^^^^^
when I remove the Tensor p. It also fails as written, with
simplified/HordeAd/Core/AstInterpret.hs:146:1: error:
    • Illegal type synonym family application ‘Primal a’ in instance:
        CTensorOf Num n (Primal a)
    • In the quantified constraint ‘forall (n :: GHC.TypeNats.Nat).
                                    CTensorOf Num n (Primal a)’
      In the type synonym declaration for ‘Evi2’
    |
146 | type Evi2 a = Evi1 (Primal a)
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
To justify my request: currently the type signatures of my mutually recursive functions are
interpretAst
  :: forall n a. (KnownNat n, Evidence a)
  => AstEnv a -> AstMemo a
  -> Ast n (ScalarOf a) -> (AstMemo a, TensorOf n a)
using the hideous and costly Dict and Refl tricks. If I can't bunch the workaround constraints, I'd end up with
interpretAst
  :: forall n a. (KnownNat n, Evidence a, p ~ Primal a, CTensorOf EqB n p, CTensorOf OrdB n p)
  => AstEnv a -> AstMemo a
  -> Ast n (ScalarOf a) -> (AstMemo a, TensorOf n a)
which is a downgrade, even if the runtime overhead is minimized. Multiply by 10 or 20 signatures. And the number of constraints tends to grow.
And applying the workaround only to RealFloat almost works, and then infuriatingly fails at SPECIALIZE, demanding that I prove the hidden n is KnownNat. I tried at https://play.haskell.org/saved/QhYMdcrn (and I didn't save :( ) to minimize the problem, but I failed. Here's the full version
https://github.com/Mikolaj/horde-ad/commit/f13e17e691fab86e1db197475107b106d81b3548
that fails with
simplified/HordeAd/Core/AstInterpret.hs:550:1: error:
    • No instance for (KnownNat n) for ‘interpretAstBool’
      Possible fix:
        add (KnownNat n) to the context of a quantified context
    • In the pragma:
        {-# SPECIALIZE interpretAstBool ::
              AstEnv (ADVal Double)
              -> AstMemo (ADVal Double)
              -> AstBool Double -> (AstMemo (ADVal Double), Bool) #-}
    |
550 | {-# SPECIALIZE interpretAstBool
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^...
and the same for a few more SPECIALIZE pragmas.
Edit: all this is with GHC 9.4.5.
Thank you again for the interesting chat and for the solutions. The playgrounds we used: https://play.haskell.org/saved/qe5uIvvC and https://play.haskell.org/saved/RpZ1vgAj
BTW, is there an easier way to adapt datatypes coming from other libraries to the switched argument order than going via newtypes, as below?
type TensorOf (n :: Nat) r = Ranked r n

class Tensor r where
  type Ranked r = (t :: Nat -> Type) | t -> r

newtype ArraySwapped r n = ArraySwapped (OR.Array n r)

instance Tensor Double where
  -- fails: type Ranked Double n = OR.Array n Double
  type Ranked Double = ArraySwapped Double
Not really, although you could use a general-purpose Flip newtype.
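For reference, a minimal sketch of such a Flip (the field name runFlip is an arbitrary choice here; variants of this definition exist in several Hackage packages):

{-# LANGUAGE PolyKinds #-}

import Data.Kind (Type)

-- A general-purpose argument-flipping newtype for a binary type
-- constructor:
newtype Flip (p :: k2 -> k1 -> Type) (a :: k1) (b :: k2)
  = Flip { runFlip :: p b a }

-- With OR.Array :: Nat -> Type -> Type, Flip OR.Array Double has kind
-- Nat -> Type, which is exactly what the flipped Ranked family requires.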
Thank you. The Flip indeed works fine.
Before.

instance Tensor Double where
  type TensorOf n Double = OR.Array n Double

After.

instance Tensor Double where
  type Ranked Double = Flip OR.Array Double
Now yet another hiccup. Can we do better than introducing yet another newtype?
Before. (ADVal is a normal datatype with one type parameter.)

instance Tensor (ADVal Double) where
  type TensorOf n (ADVal Double) = ADVal (OR.Array n Double)

After.

instance Tensor (ADVal Double) where
  type Ranked (ADVal Double) = Compose ADVal (Flip OR.Array Double)
The bad thing about this case is that even if the OR.Array type was flipped upstream, we'd still need to wrap all class methods in Compose (though not in Flip). But I hope I'm missing something.
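For the record, the kinds in the After version appear to line up as follows (assuming ADVal :: Type -> Type and OR.Array :: Nat -> Type -> Type):

--   Flip OR.Array Double                  :: Nat -> Type
--   Compose ADVal (Flip OR.Array Double)  :: Nat -> Type
--
-- so, up to the newtype wrappers,
--
--   Compose ADVal (Flip OR.Array Double) n  =  ADVal (Flip OR.Array Double n)
--                                           ≅  ADVal (OR.Array n Double)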
TLDR: I managed to reproduce the "add (KnownNat n) to the context of a quantified context" problem by just adding the first KnownNat to our running example:
https://play.haskell.org/saved/U4IVUDjq
The full story:
No stone has been left unturned, but our TensorOf associated type family now has reversed order of arguments. It's a pity we now need to use both reversed and normal order of arguments in tests and occasionally the Compose functor, too. But perhaps most of that can be hidden from an ordinary user (a user that does not need to extend the set of differentiable primitives).
https://github.com/Mikolaj/horde-ad/commit/7451afa20d3241fc2d9ae13e01e48d1f39747410
The "add (KnownNat n) to the context of a quantified context" problem in SPECIALIZE recurred and when I commented SPECIALIZE out, it occured much later in the program, where I actually call the functions.
https://github.com/Mikolaj/horde-ad/commit/b14cc369d6df659618a85704765fb62efdbd95a3
The cause seems to be that I lied to you that the constraints don't depend on n. In fact, they do depend on n being KnownNat, because that's needed for a lot of instances, e.g., of class RealFloat. Where we have forall n. CRanked RealFloat a n, we should instead have forall n. KnownNat n implies CRanked RealFloat a n. I wonder if this can be worked around by the Dict trick or by some variant of the workarounds for quantified constraints over type families. Perhaps the reversed order of arguments helps (not reflected in the updated playground example)?
The cause seems to be that I lied to you that the constraints don't depend on n. In fact, they do depend on n being KnownNat, because that's needed for a lot of instances, e.g., of class RealFloat. Where we have forall n. CRanked RealFloat a n, we should instead have forall n. KnownNat n implies CRanked RealFloat a n.
Is there a problem with just doing this? If the classes require it, then surely the bare forall n. can't ever work. Can't we just add said KnownNat constraint? It seems to work in the playground (I only changed foo's type): https://play.haskell.org/saved/gbzbJKuq
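A minimal self-contained sketch of what that change amounts to (toy TensorOf/CTensorOf standing in for the playground's types; foo is a hypothetical stand-in): quantified constraints may carry their own contexts, so the KnownNat requirement can move inside the forall:

{-# LANGUAGE AllowAmbiguousTypes, ConstraintKinds, DataKinds,
             FlexibleContexts, FlexibleInstances, MultiParamTypeClasses,
             QuantifiedConstraints, RankNTypes, TypeFamilies,
             UndecidableInstances, UndecidableSuperClasses #-}

import Data.Kind (Constraint, Type)
import GHC.TypeNats (KnownNat, Nat)

-- Same toy setup as the earlier workaround 2 sketch:
type family TensorOf (n :: Nat) (r :: Type) :: Type
class c (TensorOf n r) => CTensorOf (c :: Type -> Constraint) (n :: Nat) r
instance c (TensorOf n r) => CTensorOf c n r

-- The KnownNat requirement lives inside the forall; the given KnownNat n
-- at the use site discharges it when the constraint is instantiated at n:
foo :: forall r n. (KnownNat n, forall m. KnownNat m => CTensorOf Num m r)
    => TensorOf n r -> TensorOf n r
foo t = t + t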
Oh, wow, so not all is lost! I only tried that by adding the KnownNat n => into my bundle of constraints, and it fails as follows.
https://play.haskell.org/saved/yDKadtxO
Though it also works with a dummy class that only serves to bundle constraints. But I'd need to make a dummy instance of this class for each type I'm planning to use, so it's not ideal.
@phadej just made it work: https://play.haskell.org/saved/EdozFXhx
The secret ingredient was the standalone kind signature type BundleOfConstraints :: Type -> Constraint and ImpredicativeTypes to make it compile (https://gitlab.haskell.org/ghc/ghc/-/issues/16140).
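A hypothetical reconstruction of the shape of that trick, reusing the same toy TensorOf/CTensorOf stand-ins (the constraints bundled in the real playground may differ):

{-# LANGUAGE AllowAmbiguousTypes, ConstraintKinds, DataKinds,
             FlexibleContexts, FlexibleInstances, ImpredicativeTypes,
             MultiParamTypeClasses, QuantifiedConstraints, RankNTypes,
             StandaloneKindSignatures, TypeFamilies, UndecidableInstances,
             UndecidableSuperClasses #-}

import Data.Kind (Constraint, Type)
import GHC.TypeNats (KnownNat, Nat)

type family TensorOf (n :: Nat) (r :: Type) :: Type
class c (TensorOf n r) => CTensorOf (c :: Type -> Constraint) (n :: Nat) r
instance c (TensorOf n r) => CTensorOf c n r

-- The standalone kind signature lets the bundle be an ordinary constraint
-- synonym; ImpredicativeTypes is what makes the tuple of quantified
-- constraints acceptable (see ghc#16140):
type BundleOfConstraints :: Type -> Constraint
type BundleOfConstraints r =
  ( forall n. KnownNat n => CTensorOf Num n r
  , forall n. KnownNat n => CTensorOf Eq n r )

-- Signatures can now mention one bundle instead of each constraint:
bar :: forall r n. (KnownNat n, BundleOfConstraints r)
    => TensorOf n r -> TensorOf n r
bar t = t + t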
Here's the full code in the wild.
https://github.com/Mikolaj/horde-ad/commit/391f407837993a3b1087db0685120a12f12a0ad8
It uses 3 copies of workaround 2. All tests pass, though they are 5% slower than with the Dict hack, which was itself a couple percent slower than the original five identical copies of the code, one for each type at which it's used. To be profiled, but I suspect specialization woes (it's on GHC 9.4.5, and specialization is known to be broken until 9.6, which sadly I can't use, because on GHC 9.6.1 our tests break for unclear reasons). Or is there a plausible justification for quantified constraints degrading performance? (This is with -fexpose-all-unfoldings -fspecialise-aggressively and explicit SPECIALIZE pragmas, so I'd hope they shouldn't.)
I wonder if the reversed order of arguments could let us get rid of the workarounds, or perhaps of explicit quantified constraints altogether. E.g., could we define a constraint Num2 such that Num2 (Ranked Double) holds iff Num (Ranked Double n) holds for all n? Could we define ForallN such that ForallN Num is Num2? (Edit: I see something similar in https://hackage.haskell.org/package/base-4.18.0.0/docs/Data-Functor-Classes.html#t:Eq1 and https://hackage.haskell.org/package/base-4.18.0.0/docs/Data-Functor-Classes.html#t:Eq2, but I don't know how our extra KnownNat affects that, and I couldn't find Num1 on Hoogle.)
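A purely hypothetical sketch, not code from this thread, of one possible shape such a class could take in the spirit of Eq1 from Data.Functor.Classes, with the extra KnownNat baked into every method:

{-# LANGUAGE DataKinds, KindSignatures #-}

import Data.Kind (Type)
import GHC.TypeNats (KnownNat, Nat)

class Num1 (t :: Nat -> Type) where
  liftAdd    :: KnownNat n => t n -> t n -> t n
  liftNegate :: KnownNat n => t n -> t n
  -- ...and so on for the remaining Num operations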
BTW, I've found a probable GHC 9.6.1 bug with this code. It spams Redundant constraint: KnownNat x a hundred times, probably whenever our quantified constraints are eventually specialized to a particular x (I suspect GHC 9.4.5 just gives up trying to specialize and so doesn't get to warn spuriously).
Here's the maximally simplified variant that does not require ImpredicativeTypes, but is extremely fragile:
https://github.com/Mikolaj/horde-ad/commit/bc734c7dbe11daeb98b377872f6f25e63c73dd34
Funnily, changing BooleanOf r ~ b to b ~ BooleanOf r, so that the class definition and instance are now

class (b ~ BooleanOf r) => BooleanOfMatches b r where
instance (b ~ BooleanOf r) => BooleanOfMatches b r where

breaks it [Edit: but it turns out the old version was similarly fragile]. Similarly, anything relying on transitivity of ~ breaks the typing. (Edit: just filed https://gitlab.haskell.org/ghc/ghc/-/issues/23333; the minimal repro is at https://play.haskell.org/saved/ztYeoPjf and an even smaller one at https://play.haskell.org/saved/EtiAllWh.)
Edit: and it seems to have fully undone the performance breakage that the previous attempts incurred. Probably both the new 5% slowdown and the old 2% from Dicts are now gone.
The BooleanOf r ~ b vs b ~ BooleanOf r and related problems are now fixed in
https://gitlab.haskell.org/ghc/ghc/-/merge_requests/10389
and worked around in
https://github.com/Mikolaj/horde-ad/commit/4e3eee462922dd997af88af3e9fb577b1ad37d28
Limitations of the Haskell type system force us to copy this identical long instance code 5 times.
https://github.com/Mikolaj/horde-ad/blob/37bd69171c02e8a3acd951527763c36cd29a39c2/simplified/HordeAd/Core/ADValTensor.hs#L397
Ideally, one copy would suffice and then it can be a set of polymorphic functions, not a class. The best bet seems to be quantified constraints, which are however illegal for type families. These two issues list many workarounds for this limitation and perhaps one of these would suffice:
We already use one of the workarounds in IsPrimalR and IsPrimalA, so perhaps these can be improved. There are also many other, smaller, duplicated pieces of code, mostly for Double vs Float, that may be eliminated similarly, if we find a really good workaround or a plugin or a TH trick.