Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Deduplicate InterpretAst instances into a single polymorphic set of functions #90

Closed. Mikolaj closed this issue 1 year ago.

Mikolaj commented 1 year ago

Limitations of the Haskell type system force us to copy this identical long instance code 5 times.

https://github.com/Mikolaj/horde-ad/blob/37bd69171c02e8a3acd951527763c36cd29a39c2/simplified/HordeAd/Core/ADValTensor.hs#L397

Ideally, one copy would suffice, and then it could be a set of polymorphic functions rather than a class. The best bet seems to be quantified constraints, which, however, are illegal when they mention type families. These two issues list many workarounds for this limitation, and perhaps one of them would suffice:

We already use one of the workarounds in IsPrimalR and IsPrimalA, so perhaps these can be improved. There are also many other, smaller duplicated pieces of code, mostly for Double vs Float, that may be eliminated similarly if we find a really good workaround, a plugin or a TH trick.

tomsmeding commented 1 year ago

This doesn't actually work: my Evidence typeclass only generalises over Float and Double and won't permit an Ast0 instance, so we get overlapping instances here, because I can only merge the Float and Double instances into a single one.

However, I think that with a more extensive set of type equalities, giving the precise equalities the code needs instead of plain definitions in terms of bare types that only work for Float and Double, this idea could work. It's a bludgeon hammer, though; I'm not sure you want to do this. Just an idea.

This is a patch on 37bd69171c02e8a3acd951527763c36cd29a39c2. As with my earlier patches, I intentionally applied no code styling or proper naming; that's for you to decide. I just wanted to see the GHC errors.

```diff
diff --git a/simplified/HordeAd/Core/ADValTensor.hs b/simplified/HordeAd/Core/ADValTensor.hs
index e2b89516..211eee52 100644
--- a/simplified/HordeAd/Core/ADValTensor.hs
+++ b/simplified/HordeAd/Core/ADValTensor.hs
@@ -25,6 +25,9 @@ import qualified Data.Vector.Generic as V
 import GHC.TypeLits (KnownNat, type (+))
 import Numeric.LinearAlgebra (Numeric, Vector)
 
+import Data.Type.Equality
+import Data.Proxy
+
 import HordeAd.Core.Ast
 import HordeAd.Core.AstSimplify
 import HordeAd.Core.AstVectorize ()
@@ -394,14 +397,36 @@ class InterpretAst a where
   interpretAstDynamic
     :: AstEnv a -> AstDynamic (ScalarOf a) -> DynamicTensor a
 
+data Dict c a where
+  Dict :: c a => Dict c a
+
+class (TensorIsArray a, HasPrimal (ADVal a), Numeric a, IsPrimalR a, RealFloat a, Floating (Vector a), Tensor (ADVal a), Tensor a) => Evidence a where
+  evi1 :: forall n. Proxy a
+       -> (TensorOf n (Primal (ADVal a)) :~: OR.Array n (Primal (ADVal a))
+          ,TensorOf n (ADVal a) :~: ADVal (OR.Array n a)
+          ,BooleanOf (TensorOf n a) :~: Bool
+          ,Dict EqB (TensorOf n a)
+          ,Dict OrdB (TensorOf n a))
+  evi2 :: Proxy a
+       -> (Primal (ADVal a) :~: a
+          ,Primal (ADVal a) :~: a
+          ,ScalarOf (ADVal a) :~: a
+          ,IntOf (ADVal a) :~: Int
+          ,Primal a :~: a
+          ,IntOf a :~: Int
+          ,BooleanOf a :~: Bool)
+
+instance Evidence Float where evi1 _ = (Refl, Refl, Refl, Dict, Dict) ; evi2 _ = (Refl, Refl, Refl, Refl, Refl, Refl, Refl)
+instance Evidence Double where evi1 _ = (Refl, Refl, Refl, Dict, Dict) ; evi2 _ = (Refl, Refl, Refl, Refl, Refl, Refl, Refl)
+
 -- These are several copies of exactly the same code, past the instantiated
 -- interpretAst signature. See if any workaround from
 -- https://gitlab.haskell.org/ghc/ghc/-/issues/14860 and
 -- https://gitlab.haskell.org/ghc/ghc/-/issues/16365 work here
 -- (and elsewhere, where we copy code for similar reasons).
-instance InterpretAst (ADVal Double) where
+instance Evidence s => InterpretAst (ADVal s) where
   interpretAst
-    :: forall n0 a. (KnownNat n0, a ~ ADVal Double)
+    :: forall n0 a. (KnownNat n0, a ~ ADVal s)
     => AstEnv a -> Ast n0 (ScalarOf a) -> TensorOf n0 a
   interpretAst = interpretAstRec
    where
@@ -416,14 +441,14 @@ instance InterpretAst (ADVal Double) where
       :: forall n. KnownNat n
       => AstEnv a
       -> AstPrimalPart n (ScalarOf a) -> TensorOf n (Primal a)
-    interpretAstPrimal env (AstPrimalPart v) =
+    interpretAstPrimal env (AstPrimalPart v) | (Refl, Refl, _, _, _) <- evi1 @s @n Proxy, (Refl, Refl, Refl, Refl, Refl, Refl, _) <- evi2 @s Proxy =
       toArray $ tprimalPart $ interpretAstRec env v
     interpretAstRec
       :: forall n. KnownNat n
       => AstEnv a
       -> Ast n (ScalarOf a) -> TensorOf n a
-    interpretAstRec env = \case
+    interpretAstRec env | (Refl, Refl, _, _, _) <- evi1 @s @n Proxy, (Refl, Refl, Refl, Refl, Refl, Refl, _) <- evi2 @s Proxy = \case
       AstVar _sh (AstVarName var) -> case IM.lookup var env of
         Just (AstVarR d) -> tfromD d
         Just AstVarI{} ->
@@ -477,7 +502,7 @@ instance InterpretAst (ADVal Double) where
 
     interpretAstInt :: AstEnv a
                     -> AstInt (ScalarOf a) -> IntOf (Primal a)
-    interpretAstInt env = \case
+    interpretAstInt env | (Refl, Refl, Refl, Refl, Refl, Refl, Refl) <- evi2 @s Proxy = \case
       AstIntVar (AstVarName var) -> case IM.lookup var env of
         Just AstVarR{} ->
           error $ "interpretAstInt: type mismatch for Var" ++ show var
@@ -496,145 +521,24 @@ instance InterpretAst (ADVal Double) where
 
     interpretAstBool :: AstEnv a
                      -> AstBool (ScalarOf a) -> BooleanOf (Primal a)
-    interpretAstBool env = \case
+    interpretAstBool env | (Refl, Refl, Refl, Refl, Refl, Refl, Refl) <- evi2 @s Proxy = \case
       AstBoolOp opCodeBool args ->
         interpretAstBoolOp (interpretAstBool env) opCodeBool args
       AstBoolConst a -> if a then true else false
-      AstRel opCodeRel args ->
+      AstRel @n opCodeRel args | (_, _, Refl, Dict, Dict) <- evi1 @s @n Proxy ->
         let f v = interpretAstPrimal env (AstPrimalPart v)
         in interpretAstRelOp f opCodeRel args
       AstRelInt opCodeRel args ->
         let f = interpretAstInt env
         in interpretAstRelOp f opCodeRel args
 
-  interpretAstDynamic
-    :: forall a. a ~ ADVal Double
-    => AstEnv a -> AstDynamic (ScalarOf a) -> DynamicTensor a
-  interpretAstDynamic = interpretAstDynamicRec
-   where
-    interpretAstDynamicRec
-      :: AstEnv a
-      -> AstDynamic (ScalarOf a) -> DynamicTensor a
-    interpretAstDynamicRec env = \case
-      AstDynamicDummy -> error "interpretAstDynamic: AstDynamicDummy"
-      AstDynamicPlus v u ->
-        interpretAstDynamicRec env v `taddD` interpretAstDynamicRec env u
-      AstDynamicFrom w -> tfromR $ interpretAst env w
-
-instance InterpretAst (ADVal Float) where
-  interpretAst
-    :: forall n0 a. (KnownNat n0, a ~ ADVal Float)
-    => AstEnv a -> Ast n0 (ScalarOf a) -> TensorOf n0 a
-  interpretAst = interpretAstRec
-   where
--- We could duplicate interpretAst to save some time (sadly, we can't
--- interpret Ast uniformly in any Tensor and HasPrimal instance due to typing,
--- so we can't just use an instance of interpretation to OR.Array for that),
--- but it's not a huge saving, because all dual parts are gone before
--- we do any differentiation and they are mostly symbolic, so don't even
--- double the amount of tensor computation performed. The biggest problem is
--- allocation of tensors, but they are mostly shared with the primal part.
-    interpretAstPrimal
-      :: forall n. KnownNat n
-      => AstEnv a
-      -> AstPrimalPart n (ScalarOf a) -> TensorOf n (Primal a)
-    interpretAstPrimal env (AstPrimalPart v) =
-      toArray $ tprimalPart $ interpretAstRec env v
-    interpretAstRec
-      :: forall n. KnownNat n
-      => AstEnv a
-      -> Ast n (ScalarOf a) -> TensorOf n a
-    interpretAstRec env = \case
-      AstVar _sh (AstVarName var) -> case IM.lookup var env of
-        Just (AstVarR d) -> tfromD d
-        Just AstVarI{} ->
-          error $ "interpretAstRec: type mismatch for Var" ++ show var
-        Nothing -> error $ "interpretAstRec: unknown variable Var" ++ show var
-      AstOp opCode args ->
-        interpretAstOp (interpretAstRec env) opCode args
-      AstConst a -> tconst a
-      AstConstant a -> tconst $ interpretAstPrimal env a
-      AstConstInt i -> tfromIndex0 $ interpretAstInt env i
-      AstIndexZ v is -> tindex (interpretAstRec env v) (fmap (interpretAstInt env) is)
-        -- if index is out of bounds, the operations returns with an undefined
-        -- value of the correct rank and shape; this is needed, because
-        -- vectorization can produce out of bound indexing from code where
-        -- the indexing is guarded by conditionals
-      AstSum v -> tsum (interpretAstRec env v)
-        -- TODO: recognize when sum0 may be used instead, which is much cheaper
-        -- or should I do that in Delta instead? no, because tsum0R is cheaper, too
-        -- TODO: recognize dot0 patterns and speed up their evaluation
-      AstScatter sh v (vars, ix) ->
-        tscatter sh (interpretAstRec env v)
-                 (interpretLambdaIndexToIndex interpretAstInt env (vars, ix))
-      AstFromList l -> tfromList (map (interpretAstRec env) l)
-      AstFromVector l -> tfromVector (V.map (interpretAstRec env) l)
-      AstKonst k v -> tkonst k (interpretAstRec env v)
-      AstAppend x y -> tappend (interpretAstRec env x) (interpretAstRec env y)
-      AstSlice i k v -> tslice i k (interpretAstRec env v)
-      AstReverse v -> treverse (interpretAstRec env v)
-      AstTranspose perm v -> ttranspose perm $ interpretAstRec env v
-      AstReshape sh v -> treshape sh (interpretAstRec env v)
-      AstBuild1 k (var, AstConstant r) ->
-        tconst
-        $ OR.ravel . ORB.fromVector [k] . V.generate k
-        $ toArray . interpretLambdaI interpretAstPrimal env (var, r)
-      AstBuild1 k (var, v) -> tbuild1 k (interpretLambdaI interpretAstRec env (var, v))
-        -- to be used only in tests
-      AstGatherZ sh v (vars, ix) ->
-        tgather sh (interpretAstRec env v)
-                   (interpretLambdaIndexToIndex interpretAstInt env (vars, ix))
-        -- the operation accept out of bounds indexes,
-        -- for the same reason ordinary indexing does, see above
-        -- TODO: currently we store the function on tape, because it doesn't
-        -- cause recomputation of the gradient per-cell, unlike storing the build
-        -- function on tape; for GPUs and libraries that don't understand Haskell
-        -- closures, we cneck if the expressions involve tensor operations
-        -- too hard for GPUs and, if not, we can store the AST expression
-        -- on tape and translate it to whatever backend sooner or later;
-        -- and if yes, fall back to POPL pre-computation that, unfortunately,
-        -- leads to a tensor of deltas
-      AstFromDynamic t -> tfromD $ interpretAstDynamic env t
-
-    interpretAstInt :: AstEnv a
-                    -> AstInt (ScalarOf a) -> IntOf (Primal a)
-    interpretAstInt env = \case
-      AstIntVar (AstVarName var) -> case IM.lookup var env of
-        Just AstVarR{} ->
-          error $ "interpretAstInt: type mismatch for Var" ++ show var
-        Just (AstVarI i) -> i
-        Nothing -> error $ "interpretAstInt: unknown variable Var" ++ show var
-      AstIntOp opCodeInt args ->
-        interpretAstIntOp (interpretAstInt env) opCodeInt args
-      AstIntConst a -> a
-      AstIntFloor v -> let u = interpretAstPrimal env (AstPrimalPart v)
-                       in tfloor u
-      AstIntCond b a1 a2 -> ifB (interpretAstBool env b)
-                                (interpretAstInt env a1)
-                                (interpretAstInt env a2)
-      AstMinIndex1 v -> tminIndex0 $ interpretAstRec env v
-      AstMaxIndex1 v -> tmaxIndex0 $ interpretAstRec env v
-
-    interpretAstBool :: AstEnv a
-                     -> AstBool (ScalarOf a) -> BooleanOf (Primal a)
-    interpretAstBool env = \case
-      AstBoolOp opCodeBool args ->
-        interpretAstBoolOp (interpretAstBool env) opCodeBool args
-      AstBoolConst a -> if a then true else false
-      AstRel opCodeRel args ->
-        let f v = interpretAstPrimal env (AstPrimalPart v)
-        in interpretAstRelOp f opCodeRel args
-      AstRelInt opCodeRel args ->
-        let f = interpretAstInt env
-        in interpretAstRelOp f opCodeRel args
   interpretAstDynamic
-    :: forall a. a ~ ADVal Float
-    => AstEnv a -> AstDynamic (ScalarOf a) -> DynamicTensor a
+    :: AstEnv (ADVal s) -> AstDynamic (ScalarOf (ADVal s)) -> DynamicTensor (ADVal s)
   interpretAstDynamic = interpretAstDynamicRec
    where
     interpretAstDynamicRec
-      :: AstEnv a
-      -> AstDynamic (ScalarOf a) -> DynamicTensor a
+      :: AstEnv (ADVal s)
+      -> AstDynamic (ScalarOf (ADVal s)) -> DynamicTensor (ADVal s)
     interpretAstDynamicRec env = \case
       AstDynamicDummy -> error "interpretAstDynamic: AstDynamicDummy"
       AstDynamicPlus v u ->
```
Mikolaj commented 1 year ago

It is a bludgeon hammer. :D

So instead of the illegal quantified constraints over type families, we promise to supply dicts and proofs of type equality for the same assumptions and the code needs to distribute them properly? That distribution thing is rather annoying, but easier to maintain than code duplication. It's also good that this machinery does not leak beyond this code (neither up nor down the code dependency tree).
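To make the idea concrete, here is a minimal, self-contained sketch of the "promise dicts and equality proofs" pattern. It is not the horde-ad code: the families `Boxed` and `BoolOf`, the class `Evidence` and the function `isSorted` are hypothetical toy stand-ins for `TensorOf`, `BooleanOf` and friends. Each instance supplies the evidence once, and the single polymorphic definition unpacks it in a pattern guard, which is the "distribution" step mentioned above.

```haskell
{-# LANGUAGE ConstraintKinds, GADTs, KindSignatures, ScopedTypeVariables,
             TypeFamilies, TypeOperators #-}
import Data.Kind (Constraint, Type)
import Data.Proxy (Proxy (..))
import Data.Type.Equality ((:~:) (..))

-- Reified dictionary: matching on Dict brings the constraint `c a` into scope.
data Dict (c :: Type -> Constraint) a where
  Dict :: c a => Dict c a

-- Hypothetical toy families standing in for TensorOf, BooleanOf, etc.
type family Boxed a :: Type
type instance Boxed Float = [Float]
type instance Boxed Double = [Double]

type family BoolOf a :: Type
type instance BoolOf Float = Bool
type instance BoolOf Double = Bool

-- Each instance promises, once, the equalities and dictionaries that the
-- polymorphic code needs.
class Evidence a where
  evi :: Proxy a -> (Boxed a :~: [a], BoolOf a :~: Bool, Dict Ord a)

instance Evidence Float where evi _ = (Refl, Refl, Dict)
instance Evidence Double where evi _ = (Refl, Refl, Dict)

-- One polymorphic definition instead of one copy per scalar type; the guard
-- unpacks the evidence, so `xs` can be used as a list and `<=` is available.
isSorted :: forall a. Evidence a => Proxy a -> Boxed a -> BoolOf a
isSorted p xs
  | (Refl, Refl, Dict) <- evi p = and (zipWith (<=) xs (drop 1 xs))
```

For example, `isSorted (Proxy :: Proxy Double) [1, 2, 3]` evaluates to `True`, and the same definition serves `Float`.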

Mikolaj commented 1 year ago

I simplified the relevant type classes and tried to generalize this further, but got discouraged when I couldn't find an n to apply in

   interpretAstBool :: AstEnv a
                    -> AstBool (ScalarOf a) -> BooleanOf (Primal a)
   interpretAstBool env | (Refl, Refl, Refl, _, _) <- ev @a @n Proxy = \case

in order to provide BooleanOf (TensorOf n (Primal a)) :~: BooleanOf (IntOf a) from

  ev :: forall n. Proxy a
     -> ( TensorOf n a :~: TensorOf n a
        , BooleanOf (TensorOf n a) :~: BooleanOf (TensorOf n (Primal a))
        , BooleanOf (TensorOf n (Primal a)) :~: BooleanOf (IntOf a)
        , Dict EqB (TensorOf n a)
        , Dict OrdB (TensorOf n a) )

needed for a recursive call

     AstRel opCodeRel args ->
       let f v = interpretAstPrimal env (AstPrimalPart v)
       in interpretAstRelOp f opCodeRel args

because n is known only from AstRel, not from the signature of interpretAstBool. And indeed, the following works:

     AstRel @n opCodeRel args | (Refl, Refl, Refl, _, _) <- ev @a @n Proxy ->
       let f v = interpretAstPrimal env (AstPrimalPart v)
       in interpretAstRelOp f opCodeRel args

but that means I need to summon the evidence all over the code, for whatever n is in hand. Perhaps that's not going to be so bad; there are only so many existential constructors in the grammar. I guess I'd need to do this dance only once per type signature and per existential constructor?
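The dance for an existential constructor can be sketched in isolation. In this toy example (all names hypothetical: `T`, `Evidence`, `BoolExpr`, `evalBool`), the rank `n` of `Rel` is fixed by the constructor rather than by the signature of `evalBool`, so it has to be bound in the pattern before the rank-specific evidence can be requested. The sketch binds it with a pattern type signature, which is another way to name an existential; the `AstRel @n` pattern above does the same job.

```haskell
{-# LANGUAGE ConstraintKinds, DataKinds, GADTs, KindSignatures,
             ScopedTypeVariables, TypeApplications #-}
import Data.Kind (Constraint)
import Data.Proxy (Proxy (..))
import GHC.TypeLits (Nat)

data Dict (c :: Constraint) where
  Dict :: c => Dict c

-- Hypothetical rank-indexed tensor; the rank n is only a phantom index.
newtype T (n :: Nat) r = T [r] deriving (Eq, Ord, Show)

-- The evidence is rank-specific, so every caller must name the rank n.
class Evidence r where
  evi :: Proxy n -> Dict (Ord (T n r))

instance Evidence Double where
  evi _ = Dict

-- The rank n of Rel is existential: it comes from the constructor.
data BoolExpr r where
  Rel :: T n r -> T n r -> BoolExpr r
  And :: BoolExpr r -> BoolExpr r -> BoolExpr r

evalBool :: forall r. Evidence r => BoolExpr r -> Bool
evalBool (And a b) = evalBool a && evalBool b
evalBool (Rel (x :: T n r) y)            -- the pattern signature binds n
  | Dict <- evi @r (Proxy :: Proxy n) = x <= y
```

Only the `Rel` equation needs to summon evidence; `And` never mentions a rank, which matches the guess that the extra work is per existential constructor.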

tomsmeding commented 1 year ago

Yeah I feel like that'll be unavoidable; trying to get a quantified constraint in scope is wading into ImpredicativeTypes territory where I expect the troubles will be even worse than simply having a type class instance over a type family.

Perhaps, if most of the problematic calls are to the same functions, it might be possible to write wrapper functions for those that have more general type signatures and that construct the evidence for themselves in their function body. That may or may not reduce the number of times you have to write that evidence creation code. But I suspect the code will be much less readable for it; there is a lot to say for not making type signatures more complicated than they already are.

Mikolaj commented 1 year ago

I panicked way too soon. Look how simple it ended up being:

https://github.com/Mikolaj/horde-ad/blob/f07fec1b477794ec857d574c2b68d62ad6d9dbb2/simplified/HordeAd/Core/AstInterpret.hs#L76

I can't say I understand this success, but I certainly don't complain. :)

Mikolaj commented 1 year ago

Edit: even slightly simpler: https://github.com/Mikolaj/horde-ad/blob/24b22ddf3039725efda50a976c2dd466a8722f6a/simplified/HordeAd/Core/AstInterpret.hs

tomsmeding commented 1 year ago

What changed? Did you only end up needing it for AstRel? I like putting most of the equality constraints in the superclass list, that's a lot less Refl :)
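A small sketch of the superclass style, under the assumption that the evidence needed is just an equality plus a few ordinary classes. The family `BoolOf`, the class `Evidence` and `isPositive` are hypothetical stand-ins, not the definitions in AstInterpret.hs. Because the equality is a superclass, any `Evidence a` given brings `BoolOf a ~ Bool` into scope with no `Refl` to unpack.

```haskell
{-# LANGUAGE FlexibleContexts, TypeFamilies, TypeOperators #-}
import Data.Kind (Type)

-- Hypothetical family standing in for BooleanOf.
type family BoolOf a :: Type
type instance BoolOf Float = Bool
type instance BoolOf Double = Bool

-- Equalities and classes live in the superclass list; instances are empty.
class (BoolOf a ~ Bool, Ord a, Num a) => Evidence a
instance Evidence Float
instance Evidence Double

-- No pattern matching on evidence: the superclasses are simply available.
isPositive :: Evidence a => a -> BoolOf a
isPositive x = x > 0
```

The cost is paid once per instance rather than at every use site, which is presumably why this needs far fewer `Refl`s than the tuple-of-proofs version.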

Mikolaj commented 1 year ago

Yes, only needed in two places: the header of interpretAst (probably only for AstOp, which requires RealFloat (TensorOf n a)) and the constructor AstRel, which requires the rest. But that was a crazy roller-coaster. GHC asked me for a lot of evidence, e.g., IfB and a lot of type equations; the list kept growing, I was pulling my hair out, and then, eventually, it all proved spurious, somehow deducible from the final set there. ¯\_(ツ)_/¯

Mikolaj commented 1 year ago

Whoa, HordeAd.Core.AstInterpret.$fEvidenceADVal11 accounts for 5% of the allocations of the big Ast gradient Mnist test. That's certainly due to the lack of sharing of Ast terms during delta evaluation, which results in a costly Ast interpretation pass, but still, I hoped the dictionaries would get specialized or inlined somehow. I even added manual SPECIALIZE pragmas, and this is GHC 9.6.1, which is much better at specialization.

In any case, it works fine even under this artificial load (due to bad sharing), so let me close the ticket.

Edit: and 3.5% of runtime.

Mikolaj commented 1 year ago

@edsko just contributed a clear summary of the two known workarounds

https://gitlab.haskell.org/ghc/ghc/-/issues/14860#note_495352

so perhaps this time I can understand the crux of our problem. Are the workarounds not applicable to our use case because the TensorOf type family has two arguments, and is that why we need the explicit dictionaries, etc.?

https://github.com/Mikolaj/horde-ad/blob/c5160a0ded88716d45a19da41403729f196193a1/simplified/HordeAd/Core/AstInterpret.hs#L102-L126

This seems such a silly reason. Just have a pair construct in the language of types and the problem is gone, right?

tomsmeding commented 1 year ago

I don't think the point is the absolute value of the arity, the point is whether the quantification ranges over a variable that the type family dispatches on or not. In Edsko's first example, the quantified variable is a, but the type family (F x a) analyses only x. Hence we can "let-bind" F x outside of the quantification, which is workaround 1. Workaround 2 is, I guess, a "let-bind" on a higher level? I'm not 100% sure how to explain that one.
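Here is a minimal sketch of workaround 1 as described above, with hypothetical names (`Container`, `showPair`). The family dispatches only on `x` and returns a type constructor, so `Container x` can be "let-bound" to a plain variable `f` outside the quantifier; the quantified head `Show (f a)` then mentions no type family, which GHC accepts.

```haskell
{-# LANGUAGE QuantifiedConstraints, RankNTypes, TypeFamilies,
             TypeOperators #-}
import Data.Kind (Type)
import Data.Proxy (Proxy (..))

-- Hypothetical family that analyses only x; the result still awaits
-- its second argument.
type family Container x :: Type -> Type
type instance Container Int = Maybe
type instance Container Bool = []

-- Rejected by GHC (family application in the quantified head):
--   forall a. Show a => Show (Container x a)
-- Workaround 1: bind `Container x` to f outside the quantifier.
showPair :: forall x f. (f ~ Container x, forall a. Show a => Show (f a))
         => Proxy x -> f Int -> f Char -> String
showPair _ u v = show u ++ " " ++ show v
```

For instance, `showPair (Proxy :: Proxy Int) (Just 3) Nothing` yields `"Just 3 Nothing"`. The trick only helps because the quantified variable `a` is not one that `Container` dispatches on.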

The immediate reason workaround 1 doesn't work for Edsko's second example is that GHC doesn't support partial application of type synonyms/families, but even if it accepted the code for type checking, it would still not work, because the let-binding doesn't do anything. This is similar to how transforming f x y + f x z to let g = \a -> f x a in g y + g z is not a useful application of "common" subexpression elimination: we're not actually saving any computation. Similarly, abstracting out f ~ F x when F dispatches on a would also be useless, because f a would then still be a type family application.

The reason workaround 2 doesn't work for Edsko's second example is, I think, actually the core of the problem. GHC is asked to prove that for all a, we have P (f a) for some property P (namely, membership of the type class C) and some function f (namely F x). For an open type family F this would clearly not be provable, because what if someone adds a type instance later, in a different package, that violates the property? For closed type families it would seem possible, by exhaustively checking all the left-hand sides of F. But even then, an implementation of the dictionary for this quantified constraint would need to dispatch, at runtime, on the actual value of the type a that gets passed in. And since types are erased, that isn't going to work (without additional machinery like Typeable, or some other singleton representation).

EDIT: @Mikolaj, do you remember what the desired quantified constraint was?

Mikolaj commented 1 year ago

@Mikolaj, do you remember what the desired quantified constraint was?

I think we wanted to assert about a that forall n. RealFloat (TensorOf n a), etc.

tomsmeding commented 1 year ago

Yes, precisely. The reason why we might desire that to work is because for a particular concrete a, we are often going to have that TensorOf n a implements RealFloat for all n. This is not guaranteed to be so in all cases, but it will be so in all our cases. And even if it weren't so in all cases, one would still like to be able to satisfy the quantified constraint in the cases where it does work.

I would expect Edsko's workaround 1 to perhaps work if we flip the argument order of TensorOf, making it TensorOf r n, where it scrutinises the r only. I.e. type family TensorOf r :: Nat -> Type. Except this doesn't work because TensorOf also has a functional dependency, and it seems one cannot give both a functional dependency and a kind signature to a type family at the same time. I'm not sure if this is just a forgotten case in the syntax, or whether there are actual problems in supporting the combination.

tomsmeding commented 1 year ago

Actually maybe the second workaround just works? Or does this proof of concept not capture all the difficulties? https://play.haskell.org/saved/EBWCVstM

EDIT: Actually this is a slightly more complete demonstration: https://play.haskell.org/saved/Kl8A7A3e
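A self-contained sketch of workaround 2 in the shape used later in this thread. `Array` and `TensorOf` are hypothetical toy stand-ins, and the definition of `CTensorOf` is my assumption of its shape: a "class synonym" whose head mentions no type family, so it may appear under `forall n.`, while the family occurs only in its superclass. Rank-specific code gets `Eq (TensorOf n r)` from that superclass, and rank-polymorphic code instantiates the quantified constraint at each rank it needs.

```haskell
{-# LANGUAGE AllowAmbiguousTypes, ConstraintKinds, DataKinds,
             FlexibleContexts, FlexibleInstances, KindSignatures,
             MultiParamTypeClasses, QuantifiedConstraints, RankNTypes,
             ScopedTypeVariables, TypeApplications, TypeFamilies,
             UndecidableInstances, UndecidableSuperClasses #-}
import Data.Kind (Constraint, Type)
import GHC.TypeLits (Nat)

-- Hypothetical stand-ins for OR.Array and the TensorOf family.
newtype Array (n :: Nat) r = Array [r] deriving (Eq, Show)

type family TensorOf (n :: Nat) r :: Type
type instance TensorOf n Double = Array n Double
type instance TensorOf n Float = Array n Float

-- Class synonym: holds exactly when c (TensorOf n r) holds.
class c (TensorOf n r) => CTensorOf (c :: Type -> Constraint) (n :: Nat) r
instance c (TensorOf n r) => CTensorOf c n r

-- Rank-specific use: the superclass supplies Eq (TensorOf n r).
eqAt :: forall n r. CTensorOf Eq n r => TensorOf n r -> TensorOf n r -> Bool
eqAt = (==)

-- Rank-polymorphic use: legal, because the quantified head has no family.
eqBoth :: forall r. (forall n. CTensorOf Eq n r)
       => TensorOf 1 r -> TensorOf 1 r -> TensorOf 2 r -> TensorOf 2 r -> Bool
eqBoth a b c d = eqAt @1 @r a b && eqAt @2 @r c d
```

At a call such as `eqBoth @Double ...`, GHC discharges `forall n. CTensorOf Eq n Double` by reducing `TensorOf n Double` to `Array n Double`, which has an `Eq` instance for every `n`.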

Mikolaj commented 1 year ago

Yay, it probably works for RealFloat. It objects to forall n. CTensorOf EqB n (Primal a), saying

simplified/HordeAd/Core/AstInterpret.hs:134:1: error:
    • Illegal type synonym family application ‘Primal a’ in instance:
        CTensorOf EqB n (Primal a)
    • In the quantified constraint ‘forall (n :: GHC.TypeNats.Nat).
                                    CTensorOf EqB n (Primal a)’
      In the type synonym declaration for ‘Evidence’

and the following just doesn't fit the schema:

https://github.com/Mikolaj/horde-ad/blob/c5160a0ded88716d45a19da41403729f196193a1/simplified/HordeAd/Core/AstInterpret.hs#L124

tomsmeding commented 1 year ago

But for that we have workaround 1, because the quantifier doesn't quantify over a: https://play.haskell.org/saved/QhYMdcrn

Mikolaj commented 1 year ago

Wow, how strange. I need to name a bunch of such constraints

type Evi1 p = (Tensor p, forall n. CTensorOf Num n p, forall n. CTensorOf Eq n p)
type Evi2 a = Evi1 (Primal a)

and, strangely, it fails with

simplified/HordeAd/Core/AstInterpret.hs:145:27: error:
    • Expected a type, but ‘CTensorOf Num n p’ has kind ‘Constraint’
    • In the type ‘(forall n. CTensorOf Num n p,
                    forall n. CTensorOf Eq n p)’
      In the type declaration for ‘Evi1’
    |
145 | type Evi1 p = ( forall n. CTensorOf Num n p
    |                           ^^^^^^^^^^^^^^^^^

when I remove the Tensor p. It also fails as written with

simplified/HordeAd/Core/AstInterpret.hs:146:1: error:
    • Illegal type synonym family application ‘Primal a’ in instance:
        CTensorOf Num n (Primal a)
    • In the quantified constraint ‘forall (n :: GHC.TypeNats.Nat).
                                    CTensorOf Num n (Primal a)’
      In the type synonym declaration for ‘Evi2’
    |
146 | type Evi2 a = Evi1 (Primal a)
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Mikolaj commented 1 year ago

To justify my request, currently type signatures of my mutually recursive functions are

interpretAst
  :: forall n a. (KnownNat n, Evidence a)
  => AstEnv a -> AstMemo a
  -> Ast n (ScalarOf a) -> (AstMemo a, TensorOf n a)

using the hideous and costly Dict and Refl tricks. If I can't bundle the workaround constraints, I'd end up with

interpretAst
  :: forall n a p. (KnownNat n, Evidence a, p ~ Primal a, CTensorOf EqB n p, CTensorOf OrdB n p)
  => AstEnv a -> AstMemo a
  -> Ast n (ScalarOf a) -> (AstMemo a, TensorOf n a)

which is a downgrade, even if the runtime overhead is minimized. Multiply by 10 or 20 signatures. And the number of constraints tends to grow.

Mikolaj commented 1 year ago

And applying the workaround only to RealFloat almost works, and then infuriatingly fails at SPECIALIZE, demanding that I prove the hidden n is KnownNat. I tried to minimize the problem at https://play.haskell.org/saved/QhYMdcrn (and I didn't save :( ), but I failed. Here's the full version

https://github.com/Mikolaj/horde-ad/commit/f13e17e691fab86e1db197475107b106d81b3548

that fails with

simplified/HordeAd/Core/AstInterpret.hs:550:1: error:
    • No instance for (KnownNat n) for ‘interpretAstBool’
      Possible fix:
        add (KnownNat n) to the context of a quantified context
    • In the pragma:
        {-# SPECIALIZE interpretAstBool ::
                         AstEnv (ADVal Double)
                         -> AstMemo (ADVal Double)
                            -> AstBool Double -> (AstMemo (ADVal Double), Bool) #-}
    |
550 | {-# SPECIALIZE interpretAstBool
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^...

and the same for a few more SPECIALIZE pragmas.

Edit: all this with GHC 9.4.5.

Mikolaj commented 1 year ago

Thank you again for the interesting chat and for the solutions. The playgrounds we used: https://play.haskell.org/saved/qe5uIvvC and https://play.haskell.org/saved/RpZ1vgAj

BTW, is there an easier way to adjust datatypes coming from other libraries to the switched order than going via newtypes, as below?

type TensorOf (n :: Nat) r = Ranked r n
class Tensor r where
  type Ranked r = (t :: Nat -> Type) | t -> r

newtype ArraySwapped r n = ArraySwapped (OR.Array n r)

instance Tensor Double where
  -- fails: type Ranked Double n = OR.Array n r
  type Ranked Double = ArraySwapped Double

edsko commented 1 year ago

Not really, although you could use a general-purpose Flip newtype.

Mikolaj commented 1 year ago

Thank you. The Flip indeed works fine.

Before.

instance Tensor Double where
  type TensorOf n Double = OR.Array n Double

After.

instance Tensor Double where
  type Ranked Double = Flip OR.Array Double

Now yet another hiccup. Can we do better than introducing yet another newtype?

Before. (ADVal is a normal datatype with one type parameter.)

instance Tensor (ADVal Double) where
  type TensorOf n (ADVal Double) = ADVal (OR.Array n Double)

After.

instance Tensor (ADVal Double) where
  type Ranked (ADVal Double) = Compose ADVal (Flip OR.Array Double)

The bad thing about this case is that even if OR.Array had its arguments flipped upstream, we'd still need to wrap all class methods in Compose (though not in Flip). But I hope I'm missing something.
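The Flip and Compose encodings can be tried out in isolation. In this sketch, `Array`, `ADVal`, the class `Tensor` and its method `treplicate` are all hypothetical stand-ins (not horde-ad's or orthotope's definitions), and `Flip` is defined locally rather than assuming any particular library's version. It shows the element-type-first family `Ranked r :: Nat -> Type`, the `Flip` wrapper for the plain scalar instance, and the extra `Compose` wrapping that every method of the `ADVal` instance has to perform.

```haskell
{-# LANGUAGE DataKinds, FlexibleInstances, KindSignatures, PolyKinds,
             TypeFamilies #-}
import Data.Functor.Compose (Compose (..))
import Data.Kind (Type)
import GHC.TypeLits (Nat)

-- Hypothetical stand-in for OR.Array: rank first, element type second.
newtype Array (n :: Nat) r = Array [r] deriving (Eq, Show)

-- General-purpose Flip: swaps the last two type arguments.
newtype Flip f a b = Flip { runFlip :: f b a }

-- Hypothetical dual number: primal part and dual part.
data ADVal a = ADVal a a

class Tensor r where
  -- Element type first, so `Ranked r` alone has kind Nat -> Type.
  type Ranked r :: Nat -> Type
  treplicate :: Int -> r -> Ranked r n

instance Tensor Double where
  type Ranked Double = Flip Array Double
  treplicate k x = Flip (Array (replicate k x))

instance Tensor (ADVal Double) where
  type Ranked (ADVal Double) = Compose ADVal (Flip Array Double)
  -- every method result has to be wrapped in Compose
  treplicate k (ADVal p d) = Compose (ADVal (treplicate k p) (treplicate k d))
```

Note that the `ADVal` instance can reuse the `Double` methods underneath, but only after unwrapping and rewrapping through `Compose`, which is the boilerplate described above.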

Mikolaj commented 1 year ago

TLDR: I managed to reproduce the "add (KnownNat n) to the context of a quantified context" problem by just adding the first KnownNat to our running example

https://play.haskell.org/saved/U4IVUDjq

The full story:

No stone has been left unturned, but our TensorOf associated type family now has its arguments in reversed order. It's a pity we now need to use both the reversed and the normal order of arguments in tests, and occasionally the Compose functor, too. But perhaps most of that can be hidden from an ordinary user (a user who does not need to extend the set of differentiable primitives).

https://github.com/Mikolaj/horde-ad/commit/7451afa20d3241fc2d9ae13e01e48d1f39747410

The "add (KnownNat n) to the context of a quantified context" problem in SPECIALIZE recurred, and when I commented SPECIALIZE out, it occurred much later in the program, where I actually call the functions.

https://github.com/Mikolaj/horde-ad/commit/b14cc369d6df659618a85704765fb62efdbd95a3

The cause seems to be that I lied to you that the constraints don't depend on n. In fact, they do depend on n being KnownNat, because that's needed for a lot of instances, e.g., of class RealFloat. Where we have forall n. CRanked RealFloat a n, we should instead have forall n. KnownNat n implies CRanked RealFloat a n. I wonder if this can be worked around by the Dict trick or any variant of the quantified constraints of families workarounds. Perhaps the reversed order of arguments helps (not reflected in the updated playground example)?

tomsmeding commented 1 year ago

The cause seems to be that I lied to you that the constraints don't depend on n. In fact, they do depend on n being KnownNat, because that's needed for a lot of instances, e.g., of class RealFloat. Where we have forall n. CRanked RealFloat a n, we should instead have forall n. KnownNat n implies CRanked RealFloat a n.

Is there a problem with just doing this? If the classes require it, then surely the forall n. can't ever work. Can't we just add said KnownNat constraint? It seems to work in the playground (I only changed foo's type): https://play.haskell.org/saved/gbzbJKuq
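A simplified sketch of why the premise matters, using a plain type-constructor variable `t` instead of `CRanked` and a type family. `Vec` and `bumpBoth` are hypothetical. The `Num` instance of `Vec` needs `KnownNat n` (to know how long a vector `fromInteger` must build), so `forall n. Num (t n)` could never be satisfied by `Vec`, whereas `forall n. KnownNat n => Num (t n)` matches the instance exactly, and every use site supplies `KnownNat` for its concrete rank.

```haskell
{-# LANGUAGE DataKinds, KindSignatures, QuantifiedConstraints, RankNTypes,
             ScopedTypeVariables #-}
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical length-indexed vector whose Num instance needs KnownNat n.
newtype Vec (n :: Nat) = Vec [Integer] deriving (Eq, Show)

instance KnownNat n => Num (Vec n) where
  fromInteger k = Vec (replicate (fromIntegral (natVal (Proxy :: Proxy n))) k)
  Vec a + Vec b = Vec (zipWith (+) a b)
  Vec a * Vec b = Vec (zipWith (*) a b)
  abs (Vec a) = Vec (map abs a)
  signum (Vec a) = Vec (map signum a)
  negate (Vec a) = Vec (map negate a)

-- The quantified constraint carries the KnownNat premise; inside, GHC
-- instantiates it at n = 2 and n = 3 and discharges KnownNat for each.
bumpBoth :: (forall n. KnownNat n => Num (t n)) => t 2 -> t 3 -> (t 2, t 3)
bumpBoth a b = (a + 1, b + 1)
```

For example, `bumpBoth (Vec [1, 2] :: Vec 2) (Vec [0, 0, 0])` gives `(Vec [2,3], Vec [1,1,1])`.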

Mikolaj commented 1 year ago

Oh, wow, so not all is lost! I had only tried that by adding KnownNat n => to my bundle of constraints, and it fails as follows.

https://play.haskell.org/saved/yDKadtxO

Though it also works with a dummy class that only serves to bundle constraints. But I'd need to make a dummy instance of this class for each type I'm planning to use, so it's not ideal.

https://play.haskell.org/saved/fP3lZB19

Mikolaj commented 1 year ago

@phadej just made it work: https://play.haskell.org/saved/EdozFXhx

The secret ingredient was the type BundleOfConstraints :: Type -> Constraint signature and ImpredicativeTypes to make it compile (https://gitlab.haskell.org/ghc/ghc/-/issues/16140).
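A minimal sketch of that ingredient, with hypothetical names (`Bundle`, `describe`) and `Const Integer` as a convenient rank-indexed type. As far as I can tell, the standalone kind signature is what tells GHC that the right-hand side is a constraint tuple rather than a tuple type (avoiding the "Expected a type, but ... has kind Constraint" error quoted above), and ImpredicativeTypes is what allows the foralls inside the synonym, per GHC issue 16140.

```haskell
{-# LANGUAGE ConstraintKinds, DataKinds, ImpredicativeTypes,
             QuantifiedConstraints, RankNTypes, StandaloneKindSignatures #-}
import Data.Functor.Const (Const (..))
import Data.Kind (Constraint, Type)
import GHC.TypeLits (KnownNat, Nat)

-- The kind signature fixes the RHS as a Constraint; the synonym bundles
-- several quantified constraints under one name.
type Bundle :: (Nat -> Type) -> Constraint
type Bundle t = (forall n. KnownNat n => Num (t n), forall n. Show (t n))

-- One short constraint in each signature instead of the full list.
describe :: Bundle t => t 2 -> t 2 -> String
describe x y = show (x + y * 2)
```

With `t = Const Integer`, both quantified constraints are met by base's instances for `Const`, so `describe (Const 1 :: Const Integer 2) (Const 3)` shows the value 7.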

Mikolaj commented 1 year ago

Here's the full code in the wild.

https://github.com/Mikolaj/horde-ad/commit/391f407837993a3b1087db0685120a12f12a0ad8

It uses 3 copies of workaround 2. All tests pass, though they are 5% slower than with the Dict hack, which was a couple percent slower than the original five identical copies of the code, one for each type at which it's used. To be profiled, but I suspect specialization woes: it's on GHC 9.4.5, and specialization is known to be broken until 9.6, which sadly I can't use, because on GHC 9.6.1 our tests break for unclear reasons. Or is there a plausible reason for quantified constraints to degrade performance? (This is with -fexpose-all-unfoldings -fspecialise-aggressively and explicit SPECIALIZE pragmas, so I hope not.)

I wonder if the reversed order of arguments could let us get rid of the workarounds, or perhaps of explicit quantified constraints altogether. E.g., could we define a constraint Num2 such that Num2 (Ranked Double) holds iff for all n Num (Ranked Double n) holds? Could we define ForallN such that ForallN Num is Num2? (Edit: I see something similar in https://hackage.haskell.org/package/base-4.18.0.0/docs/Data-Functor-Classes.html#t:Eq1 and https://hackage.haskell.org/package/base-4.18.0.0/docs/Data-Functor-Classes.html#t:Eq2, but I don't know how our extra KnownNat affects that, and I couldn't find Num1 on hoogle.)

BTW, I've found a probable GHC 9.6.1 bug with this code. It spams Redundant constraint: KnownNat x a hundred times, probably whenever our quantified constraints are eventually specialized to a particular x (I suspect GHC 9.4.5 just gives up trying to specialize and so never gets to warn spuriously).

Mikolaj commented 1 year ago

Here's the maximally simplified variant that does not require ImpredicativeTypes, but is extremely fragile:

https://github.com/Mikolaj/horde-ad/commit/bc734c7dbe11daeb98b377872f6f25e63c73dd34

Funnily, changing BooleanOf r ~ b to b ~ BooleanOf r, so that the class definition and instance are now

class (b ~ BooleanOf r) => BooleanOfMatches b r where
instance (b ~ BooleanOf r) => BooleanOfMatches b r where

breaks it [Edit: but it turns out the old version was similarly fragile]. Similarly, anything relying on the transitivity of ~ breaks the typing. (Edit: just filed https://gitlab.haskell.org/ghc/ghc/-/issues/23333 and the minimal repro is at https://play.haskell.org/saved/ztYeoPjf and an even smaller one at https://play.haskell.org/saved/EtiAllWh.)

Edit: and it seems to have fully undone the performance regression that the previous attempts incurred. Probably both the new 5% slowdown and the old 2% from Dicts are now gone.

Mikolaj commented 1 year ago

The BooleanOf r ~ b vs b ~ BooleanOf r and related problems are now fixed in

https://gitlab.haskell.org/ghc/ghc/-/merge_requests/10389

and worked around in

https://github.com/Mikolaj/horde-ad/commit/4e3eee462922dd997af88af3e9fb577b1ad37d28