Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

How to design AD to permit multiple scalar types in the objective function expression? #106

Closed. Mikolaj closed this issue 1 year ago.

Mikolaj commented 1 year ago

This seems disappointing: to implement multiple scalar types, it appears we can't use ordinary polymorphism in the scalar type. That's because the map of cotangent contributions corresponding to Delta expression nodes needs a single element type (keys are just identifiers of the nodes). We want different nodes to have different scalar types, so we need either an existential or a variant type covering all floats (and ints) we will ever handle.

The delta evaluation state is currently

data EvalState ranked shaped r = EvalState
  { iMap        :: EM.EnumMap (InputId (DynamicOf ranked r))
                              (DynamicOf ranked r)
      -- ^ eventually, cotangents of objective function inputs
      -- (eventually copied to the vector representing the gradient
      -- of the objective function);
      -- the identifiers need to be contiguous and start at 0
  , dMap        :: EM.EnumMap NodeId (DynamicOf ranked r)
      -- ^ eventually, cotangents of non-input subterms indexed
      -- by their node identifiers
  , nMap        :: EM.EnumMap NodeId (DeltaBinding ranked shaped r)
      -- ^ nodes left to be evaluated
  , astBindings :: [(AstVarId, DynamicOf ranked r)]
  }

and we instead need

data EvalState ranked shaped = EvalState
  { iMap        :: EM.EnumMap (InputId (ExistsDynamicOf ranked))
                              (ExistsDynamicOf ranked)
      -- ^ eventually, cotangents of objective function inputs
      -- (eventually copied to the vector representing the gradient
      -- of the objective function);
      -- the identifiers need to be contiguous and start at 0
  , dMap        :: EM.EnumMap NodeId (ExistsDynamicOf ranked)
      -- ^ eventually, cotangents of non-input subterms indexed
      -- by their node identifiers
  , nMap        :: EM.EnumMap NodeId (ExistsDeltaBinding ranked shaped)
      -- ^ nodes left to be evaluated
  , astBindings :: [(AstVarId, ExistsDynamicOf ranked)]
  }

If we have

data ExistsDynamicOf ranked = forall r. ExistsDynamicOf (DynamicOf ranked r)

then the only way to convert between scalar types would be via the polymorphic realToFrac, because we wouldn't know which particular scalar type we are handling (unless we do tricks with reflection and unsafeCoerce). Otherwise, we could do

data ExistsDynamicOf ranked =
    ExistsDynamicOfDouble (DynamicOf ranked Double)
  | ExistsDynamicOfInt (DynamicOf ranked Int)
  -- etc.

in which case we can use specific conversions that work only between particular scalar types, etc.

In either case, many other operations can no longer happily assume they work on an unknown fixed r, but have to unpack and pack the existential or pack and unpack the variant type to get the (pseudo-)tensor inside.
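For instance, with the variant encoding, even adding two cotangent contributions requires this kind of wrapping. A sketch, where the helper addDynamic is hypothetical and stands in for whatever tensor addition the backend provides:

-- Addition on the variant type: unpack, check the tags agree, add, repack.
-- Assumed given by the backend:
-- addDynamic :: Num r => DynamicOf ranked r -> DynamicOf ranked r -> DynamicOf ranked r
addExists :: ExistsDynamicOf ranked -> ExistsDynamicOf ranked
          -> ExistsDynamicOf ranked
addExists (ExistsDynamicOfDouble a) (ExistsDynamicOfDouble b) =
  ExistsDynamicOfDouble (addDynamic a b)
addExists (ExistsDynamicOfInt a) (ExistsDynamicOfInt b) =
  ExistsDynamicOfInt (addDynamic a b)
addExists _ _ = error "addExists: mismatched scalar types"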

Am I missing alternatives or further problems? Which of the two hacks above looks less painful?

tomsmeding commented 1 year ago

An alternative: moar dependent types.

import "dependent-map" Data.Dependent.Map (DMap)
-- type DMap :: (Type -> Type) -> (Type -> Type) -> Type
-- e.g. (!) :: GCompare k => DMap k f -> k v -> f v
-- e.g. insert :: GCompare k => k v -> f v -> DMap k f -> DMap k f
-- with GCompare from here: https://hackage.haskell.org/package/some-1.0.1/docs/Data-GADT-Compare.html#t:GCompare
import "dependent-sum" Data.Dependent.Sum (DSum)
-- type DSum :: (Type -> Type) -> (Type -> Type) -> Type
-- data DSum f g = forall a. !(f a) :=> (g a)

data Ty a where
  TInt :: Ty Int
  TDouble :: Ty Double
  TFloat :: Ty Float
  -- etc.
  TDynamic :: ranked -> Ty r -> Ty (DynamicOf ranked r)

-- include a Ty in an InputId. This Ty is the SCALAR type, not the array type.
data InputId r = InputId (Ty r) Int
  deriving (Show)

-- do the same for NodeId; this now also begets a type parameter that it didn't have before
data NodeId r = NodeId (Ty r) Int
  deriving (Show)

-- do the same for AstVarId
data AstVarId r = AstVarId (Ty r) Int
  deriving (Show)

-- Note the loss of the `r` parameter everywhere!
data EvalState ranked shaped = EvalState
  { iMap        :: DMap InputId (DynamicOf ranked)
      -- ^ eventually, cotangents of objective function inputs
      -- (eventually copied to the vector representing the gradient
      -- of the objective function);
      -- the identifiers need to be contiguous and start at 0
  , dMap        :: DMap NodeId (DynamicOf ranked)
      -- ^ eventually, cotangents of non-input subterms indexed
      -- by their node identifiers
  , nMap        :: DMap NodeId (DeltaBinding ranked shaped)
      -- ^ nodes left to be evaluated
  , astBindings :: [DSum AstVarId (DynamicOf ranked)]
  }

It would be possible for InputId etc. to be indexed by DynamicOf ranked r instead of just by r; then EvalState would look like this:

data Foo a where
  Foo :: DeltaBinding ranked shaped r -> Foo (DynamicOf ranked r)

-- Note the loss of all type parameters here
data EvalState = EvalState
  { iMap        :: DMap InputId Identity
  , dMap        :: DMap NodeId Identity
  , nMap        :: DMap NodeId Foo
  , astBindings :: [DSum AstVarId Identity]
  }

Sorry for being too lazy to figure out a name for Foo.

I'm not sure this is the best of the three alternatives. It is the one that seems most natural to me. But I'm postponing judgement for now.

EDIT: The point of all of the above is that Ty can be an instance of GCompare, and hence NodeId etc. can also be, hence you can use them as an index of a DMap.
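For reference, a minimal sketch of those instances, assuming the Ty and NodeId definitions above (and ignoring TDynamic, which doesn't kindcheck as written and is dropped later in the thread):

import Data.GADT.Compare (GCompare (..), GEq (..), GOrdering (..))
import Data.Type.Equality ((:~:) (Refl))

instance GEq Ty where
  geq TInt TInt       = Just Refl
  geq TDouble TDouble = Just Refl
  geq TFloat TFloat   = Just Refl
  geq _ _             = Nothing

instance GCompare Ty where
  gcompare TInt TInt       = GEQ
  gcompare TInt _          = GLT
  gcompare _ TInt          = GGT
  gcompare TDouble TDouble = GEQ
  gcompare TDouble _       = GLT
  gcompare _ TDouble       = GGT
  gcompare TFloat TFloat   = GEQ

instance GEq NodeId where
  geq (NodeId t1 i1) (NodeId t2 i2)
    | i1 == i2  = geq t1 t2  -- an equal Int must also mean an equal scalar type
    | otherwise = Nothing

instance GCompare NodeId where
  gcompare (NodeId t1 i1) (NodeId t2 i2) = case compare i1 i2 of
    LT -> GLT
    GT -> GGT
    EQ -> gcompare t1 t2

With these, a lookup such as DMap.lookup (NodeId TDouble 42) (dMap s) comes back directly as Maybe (DynamicOf ranked Double), with no unpacking step.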

Mikolaj commented 1 year ago

I won't know until I try implementing this and start understanding, but let me attempt some coherent remarks.

This is cool, because it permits me to use the old Delta evaluations code that thinks the r is fixed.

This is not cool in that I need to list all supported scalar types upfront, unlike with the existential type method (with a Real constraint in the context).

This is awkward, because I'd need to stop implementing the domains type (used all over the code as the flattened domain of the objective function and then as the flattened gradient value, with Ast variants in between) as a boxed vector

https://github.com/Mikolaj/horde-ad/blob/3283f7d867c3cee04e99704045d67502cb5a5d42/simplified/HordeAd/Core/Adaptor.hs#L29

so the relevant old code everywhere outside Delta.hs would need to be changed, after all, and also the dependent maps would proliferate quite widely (in particular the adaptor, flattening and unflattening such vectors, would need to be rewritten).

BTW, what is TDynamic :: ranked -> Ty r -> Ty (DynamicOf ranked r) doing? Is it nesting tensors as elements of other tensors so that we can handle the flavour of orthotope that permits such nesting (at the cost of boxing every element)?

tomsmeding commented 1 year ago

BTW, what is TDynamic :: ranked -> Ty r -> Ty (DynamicOf ranked r) doing? Is it nesting tensors as elements of other tensors so that we can handle the flavour of orthotope that permits such nesting (at the cost of boxing every element)?

Oops, at the very least that doesn't kindcheck. With my code as written you should remove the whole TDynamic constructor from Ty. With the alternative version where InputId etc. are indexed by the full array type, you'd need this:

  TDynamic :: Ty ranked -> Ty r -> Ty (DynamicOf ranked r)
  TORArray :: Ty OR.Array
  -- etc.

Now Ty has become polykinded, which I believe is fine with the appropriate extensions.

This is not cool in that I need to list all supported scalar types upfront, unlike the existential type method (with a Real constraints context).

An alternative here is to use Typeable instead of an explicit Ty, as in

data InputId r = Typeable r => InputId Int

using eqTypeRep.
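A sketch of the resulting GEq instance (assuming ScopedTypeVariables and TypeApplications; the Typeable dictionary recovered by matching on the constructor is what feeds eqTypeRep):

import Data.GADT.Compare (GEq (..))
import Data.Type.Equality ((:~:) (Refl), (:~~:) (HRefl))
import Type.Reflection (eqTypeRep, typeRep)

instance GEq InputId where
  geq (InputId i1 :: InputId a) (InputId i2 :: InputId b)
    | i1 == i2 =
        -- The packed Typeable dictionaries let us compare the types at runtime.
        case eqTypeRep (typeRep @a) (typeRep @b) of
          Just HRefl -> Just Refl
          Nothing    -> Nothing
    | otherwise = Nothing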

Some further notes:

Mikolaj commented 1 year ago

Oh, yes, I can now see how the existential type would force me to use unsafeCoerce, because there is no way to know whether the type unpacked from the existential agrees with the type expected by the context. I'm not sure I can even check this at runtime, unless the type from the context is Typeable (via constraints of the eval* functions) and the packed type carries the Typeable dictionary (among many others). Keeping the types in the indexes of the map side-steps the problem.
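Spelled out, the runtime check is only possible with the dictionary packed in, along these lines (a sketch over a generic wrapper, not the actual horde-ad types; ExistentialQuantification, ScopedTypeVariables and TypeApplications assumed):

import Data.Type.Equality ((:~~:) (HRefl))
import Type.Reflection (Typeable, eqTypeRep, typeRep)

-- A wrapper over an arbitrary tensor representation f that also packs
-- the Typeable dictionary of the scalar type.
data Exists f = forall r. Typeable r => Exists (f r)

-- Recover the value at an expected scalar type, or fail the runtime check;
-- without the packed dictionary this function could not be written at all.
unpackAs :: forall r f. Typeable r => Exists f -> Maybe (f r)
unpackAs (Exists (x :: f r')) =
  case eqTypeRep (typeRep @r) (typeRep @r') of
    Just HRefl -> Just x
    Nothing    -> Nothing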

It was not so bad with keeping ranked and shaped tensors on the map, packed as dynamic tensors, because the dynamic tensor contained the physical shape, which enabled runtime check and recovery of the type. In the worst case I could store a shaped tensor and retrieve a ranked one, which would be morally wrong (the user did not request a conversion), but with no runtime crash risk.

So I need to store the scalar type somehow: 1. as a type of the index, 2. as a tag from an enumeration of all legal scalar types, or 3. as a Typeable dictionary. Your solution seems to combine 1 with 2 or 3, where 1 prevents the possibility of failure of runtime checks and 2/3 is a way to implement comparison of indexes (so that identical indexes with different types don't get disastrously merged). A less typed approach is to have only 2 or 3, with runtime checks that either fail or recover a type.

Given that I currently use the less typed method for storing tensors of different ranks/shapes in the same map, I may start with the same low-tech approach. Probably with Typeable, because I'd prefer not to hardwire the set of legal scalars, so that the computation backends (hmatrix, GPU, etc.) can determine them.

BTW, I think your approach stores the tags (or Typeable stuff) in the index, while the low-tech approach keeps indexes as Int and stores the tags in the values. The latter approach severely limits even the ability to runtime check things, e.g., lookups or stores. I guess the only thing that can be done is to, at each update, compare the tag of the old and the new value. That's an extra runtime cost, sadly. I guess I don't do this currently with shapes of tensors and I should, at least as expensive assertions to be turned off for production.

Edit: phew, actually, I mostly do the checks, because most of my updates add to the old value, and the addition runs the sameShape stuff you suggested in the past.

tomsmeding commented 1 year ago

A less typed approach is to have only 2 or 3, with runtime checks that either fail or recover a type.

And a more dangerous approach is to only have 1, in which case you need unsafeCoerce Refl to implement GEq and GCompare for InputId etc., as I described in my previous post here.
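Concretely, that dangerous version would look something like this (a sketch, with IDs shrunk to a bare Int under a phantom type):

import Data.GADT.Compare (GEq (..))
import Data.Type.Equality ((:~:) (Refl))
import Unsafe.Coerce (unsafeCoerce)

newtype InputId r = InputId Int

-- Equal Ints are *trusted* to imply equal scalar types; nothing checks it,
-- so a bug in ID generation becomes memory corruption, not a type error.
instance GEq InputId where
  geq (InputId i1) (InputId i2)
    | i1 == i2  = Just (unsafeCoerce Refl)
    | otherwise = Nothing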

If you use what you call the low-tech way for scalar types as well, then I think that means dropping the type index from IDs entirely. Which I guess might be what you were moving towards anyway.

Mikolaj commented 1 year ago

Oh, yes, I forgot there's a phantom type there currently. I could keep only ranked there to differentiate indexes (and maps) of terms vs concrete tensors vs dual numbers, but r would be gone, surely.

Mikolaj commented 1 year ago

Closing. This seems to work fine right now. It took some workarounds for quantified constraints (some are not fully resolved yet, but I'm optimistic), but I'm sure it would have taken much more backtracking if not for the thorough design discussion we had. Thank you, Tom.

Mikolaj commented 1 year ago

@tomsmeding , FYI, the existential types approach to having different subtrees of the AST operate on different underlying scalars works fine, but specializing code with such existentials is iffy:

https://github.com/Mikolaj/horde-ad/blob/5393574a36292d133aa92ee2b11a223785a72ad9/src/HordeAd/Core/AstInterpret.hs#L115-L120

I gave a presentation about that hack to the kind WT folk and they gave me a better solution without the offending dictionary in the existential constructors (actually a Typeable constraint would suffice, but GoodScalar contains Typeable)

https://github.com/Mikolaj/horde-ad/blob/5393574a36292d133aa92ee2b11a223785a72ad9/src/HordeAd/Core/Ast.hs#L224-L227

and instead with a singleton type denoting which known underlying scalar is used. That means we will end up having an enumeration type listing all possible underlying scalars. That's unfortunate, because one of the few benefits of existential types was that we could leave the list of scalars open and not keep the codebase in sync with it as it grows.

tomsmeding commented 1 year ago

Ouch, that nested testEquality stuff in your first snippet is really painful indeed. At that point it makes sense that one can just as well have an enumeration type to dispatch on.

I haven't thought this through and perhaps it's nonsense, but maybe the set of allowed scalars can be open anyhow by adding an additional enumeration item TOther :: Typeable t => Ty t to the enumeration Ty. I don't know if that actually works in practice, and if it's worth it in the first place.
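If it does work, the extension might look like this sketch (imports as in the earlier sketches); one wrinkle is that TOther must never be used at a type that already has a dedicated tag, or geq of TInt against a TOther at Int would wrongly report inequality:

data Ty a where
  TInt :: Ty Int
  TDouble :: Ty Double
  TOther :: Typeable t => Ty t  -- escape hatch for scalars without a tag

instance GEq Ty where
  geq TInt TInt       = Just Refl
  geq TDouble TDouble = Just Refl
  geq (TOther :: Ty a) (TOther :: Ty b) =
    -- For two TOthers, fall back to comparing the packed TypeReps.
    case eqTypeRep (typeRep @a) (typeRep @b) of
      Just HRefl -> Just Refl
      Nothing    -> Nothing
  geq _ _             = Nothing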

Yay more singletons. I like them, they're nice and predictable. I hope performance doesn't suffer.

Mikolaj commented 1 year ago

Actually, this nice trick from @sheaf and @phadej

{-# LANGUAGE GHC2021, QuantifiedConstraints, UndecidableInstances, ViewPatterns, StrictData, GADTs, TypeFamilies, TypeFamilyDependencies, FunctionalDependencies, RecordWildCards, MultiWayIf, LambdaCase, DefaultSignatures, PatternSynonyms, NoStarIsType, DataKinds #-}

import           Data.Type.Equality (type (~~), (:~~:) (HRefl))
import           Data.Kind (Type)
import           Type.Reflection (TypeRep, Typeable, eqTypeRep, typeRep)

pattern TypeRepOf :: forall {k1 :: Type} {k2 :: Type} {a :: k1} (b :: k2).
                     Typeable b => (b ~~ a) => TypeRep a
pattern TypeRepOf <- ( eqTypeRep @k2 @k1 @b @a ( typeRep @b ) -> Just HRefl )

foo :: forall a result. Typeable a => ( forall b. Num b => b -> result ) -> a -> result
foo f x = case typeRep @a of
  TypeRepOf @Double -> f @Double x
  TypeRepOf @Int    -> f @Int    x
  _                 -> error "Not specialised"

permits pattern-matching instead of nested conditionals. Alas, it doesn't work in GHC 9.4, probably due to https://gitlab.haskell.org/ghc/ghc/-/merge_requests/9844, so I decided not to experiment with it but to go straight to the singletons instead.

But I'm again confused about what precisely I gain by removing the Typeable dictionary from AstLet. Perhaps nothing, after all. I kept the whole big GoodScalar (which contains Typeable) there because I initially permitted arbitrary scalars in the last case instead of erroring out, and the GoodScalar dictionary described how to work with the unexpected scalar. That was bad, because a GoodScalar dictionary in a constructor probably breaks specialization of the interpretAst calls at concrete types (this is still unclear, but https://gitlab.haskell.org/ghc/ghc/-/issues/23874#note_521871 makes me think it's a possibility) and it certainly makes terms larger. [Edit: and it's even less low-tech than existential types]

But if we close the list of scalars, I can leave just Typeable in AstLet, and I wonder how a singleton tag is better than a Typeable dictionary. Certainly tools like https://github.com/nomeata/inspection-testing won't complain about the Typeable constraint not being specialized, which is a boon, but other than that? I avoid the runtime error at an unknown scalar, because I have a static guarantee it's known. What else? The drawback of a singleton is that it's less readable than an existential type (if we can get TypeRepOf to work).

I haven't thought this through and perhaps it's nonsense, but maybe the set of allowed scalars can be open anyhow by adding an additional enumeration item TOther :: Typeable t => Ty t to the enumeration Ty. I don't know if that actually works in practice, and if it's worth it in the first place.

The problem is, you need to conjure the GoodScalar dictionary for that type from thin air. For scalars of (runtime-)known concrete types, we take the dictionary statically known for the type, but if the only thing we know is Typeable and we don't compare with a concrete type, we don't know how to construct GoodScalar. [Edit: but TOther :: GoodScalar t => Ty t would probably work? I hope it doesn't break specialization of unrelated things] [Edit2: actually Show and Storable from GoodScalar are also needed to derive Show for the AST types, and removing GoodScalar means that I need to use the Typeable trick in absolutely all the little functions that work on the AST, which I'm fine not specializing, not just the interpreter functions that take 80% of the runtime]

tomsmeding commented 1 year ago

Actually, this nice trick [...] permits pattern-matching instead of nested conditionals.

Which, even if it works, is just a syntactic improvement, right? After pattern synonym and guards desugaring that's just going to be a chain of conditionals again.

I'm slightly out of the loop here; let me try to describe the situation as far as I understand it (if I misunderstand please correct me).

There are two separate issues here, both stemming from the fact that the AST now allows programs that use more than one type:

  1. Firstly, any function processing the AST now becomes polymorphically recursive, and in particular the interpretAst* family, which performs significant computations on values of the described types, and thus must (without the nested typerep conditionals tricks etc.) perform those computations using methods from dictionaries that become available only at runtime, from the AST. This is slow because of too much data movement and too little action.
  2. Furthermore, we thus need to store those dictionaries in the AST, which looks like bloat. I'd guess that those dictionaries are only ever allocated once (because they are for scalar types, thus at the point of introduction you necessarily know precisely what type it is -- this is contrary to for example a dictionary for C (Maybe Int), which could be distinct from another C (Maybe Int) dictionary if it was created using the C a => C (Maybe a) function on some unknown type a that happened to be Int at runtime). Thus the AST bloat should just be pointers to those allocated dictionaries in the places where a subcomputation has a type that is not already fixed just by its position in the AST (e.g. a Let RHS; let's call this a "type branching point").

A priori, these two problems seem independent. For (2.), one can indeed replace the dictionaries in the AST by an indexed enumeration of scalar types, but then you're replacing a few pointers to type classes by one pointer to an enumeration constructor -- a memory saving that could also have been achieved by making a single type class that bundles all necessary classes in its superclass list, hence also replacing the multiple pointers in the AST constructor by a single pointer.
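For example, a single bundling class along these lines (hypothetical name; horde-ad's actual GoodScalar differs) would already collapse the several dictionary pointers into one:

{-# LANGUAGE FlexibleInstances, UndecidableInstances #-}
import Data.Typeable (Typeable)
import Foreign.Storable (Storable)

-- One dictionary (with superclass pointers inside) instead of several
-- separate dictionary fields in the AST constructor.
class (Show r, Storable r, Typeable r) => ScalarBundle r
instance (Show r, Storable r, Typeable r) => ScalarBundle r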

However, the choice of representation of the type info in the AST does impact the possible solutions to (1.). Fixing (1.) involves ensuring that at every type branching point, we actually branch on the type we stumble upon and then call one of various specialised variants of the interpretation function. This means that (a.) we need to have something to branch on, and (b.) we need to do a multi-way branch. (Why does markdown not have alpha lists?)

a. Thus part of the info contained in the AST node needs to be an encoding of the precise scalar type of the subcomputation. Basically, this is either Typeable or an indexed enumeration. An indexed enumeration is more compact (Typeable is just a TypeRep, and TypeRep contains some more fields), but on the other hand there are probably already TypeReps around for all scalar types, so really Typeable is more compact.

b. On the other hand, a multi-way branch on an indexed enumeration is just... a multi-way branch, i.e. a case. A multi-way branch on TypeReps, however, is going to be a chain of conditionals, whether written using case-like syntactic sugar or not. When suitably optimised and when the chain is short enough, a chain of conditionals is actually more efficient than a jump table (hence gcc compiling small switch statements to chains of conditionals, lol, at least when last I looked), but that depends on all of the TypeRep matching stuff simplifying away to little more than a single comparison instruction, which I doubt.

So based on this, the indexed enumeration approach sounds like the superior option, fixing both (1.) and (2.). Of course there is the drawback of closing the space of scalars, but then, the current chain of TypeRep conditionals is no better.
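To make (b.) concrete, here is a self-contained toy of the dispatch; the real interpretAst* family would take the place of the eval* functions:

{-# LANGUAGE GADTs #-}

data Ty a where
  TInt :: Ty Int
  TDouble :: Ty Double

-- A toy AST standing in for the real one.
data Ast r = AstConst r | AstAdd (Ast r) (Ast r)

-- Monomorphic, fully specialisable evaluators.
evalDouble :: Ast Double -> Double
evalDouble (AstConst x) = x
evalDouble (AstAdd a b) = evalDouble a + evalDouble b

evalInt :: Ast Int -> Int
evalInt (AstConst x) = x
evalInt (AstAdd a b) = evalInt a + evalInt b

-- The type branching point: one case on the tag, then a monomorphic call,
-- instead of a chain of TypeRep comparisons.
eval :: Ty r -> Ast r -> r
eval TDouble = evalDouble
eval TInt    = evalInt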

Am I missing stuff?

Mikolaj commented 1 year ago

I'm slightly out of the loop here; let me try to describe the situation as far as I understand it (if I misunderstand please correct me).

You seem to understand this better than me, since I have trouble seeing the forest for the trees. :)

There are two separate issues here, both stemming from the fact that the AST now allows programs that use more than one type:

Well put. I had no idea I was using the dreaded polymorphic recursion, but, yes, I am.

A priori, these two problems seem independent. For (2.), one can indeed replace the dictionaries in the AST by an indexed enumeration of scalar types, but then you're replacing a few pointers to type classes by one pointer to an enumeration constructor -- a memory saving that could also have been achieved by making a single type class that bundles all necessary classes in its superclass list, hence also replacing the multiple pointers in the AST constructor by a single pointer.

Actually GoodScalar is such a bundle but GHC, at least sometimes, helpfully expands it to a tuple of constraints. :D

But I think your point that it's all just a couple of pointers is crucial, if true. Looking at Core, I've seen all kinds of strange dictionary expressions that mean the same thing, but differ and even refer to some local things. Either you are right that these are all just shared pointers to a few dictionaries for the same thing, only constructed differently, which is fine. Or this is a problem with GHC, which struggles to realise the dictionaries are fully determined and can be moved to the top level and so shared (https://gitlab.haskell.org/ghc/ghc/-/issues/23874 may describe such cases, but it's over my head), and if so, it's going to be fixed in GHC RSN, I'm sure.

a. Thus part of the info contained in the AST node needs to be an encoding of the precise scalar type of the subcomputation. Basically, this is either Typeable or an indexed enumeration. An indexed enumeration is more compact (Typeable is just a TypeRep, and TypeRep contains some more fields), but on the other hand there are probably already TypeReps around for all scalar types, so really Typeable is more compact.

Oh, good, so it's all the same.

b. On the other hand, a multi-way branch on an indexed enumeration is just... a multi-way branch, i.e. a case. A multi-way branch on TypeReps, however, is going to be a chain of conditionals, written using case-using syntactic sugar or not. When suitably optimised and when the chain is short enough, a chain of conditionals is actually more efficient than a jump table (hence gcc compiling small switch statements to chains of conditionals, lol, at least when last I looked), but that depends on all of the TypeRep matching stuff simplifying away to little more than a single comparison instruction, which I doubt.

Ben said yesterday it's just comparison of integers (Integers or Ints? I don't know) and so it's cheap, but OTOH, my profiles show 10% of runtime taken by the nested conditional alone. But, yes, matching on an enum is probably not going to matter much, even if it slashes that 10% to 5%. A better optimization would be to handle some common cases specially or to make the number of constructors with existential types smaller (there are over a dozen now).

So based on this, the indexed enumeration approach sounds like the superior option, fixing both (1.) and (2.). Of course there is the drawback of closing the space of scalars, but then, the current chain of TypeRep conditionals is no better.

It could be trivially extended to unexpected scalars just by replacing the error with a polymorphic call to the same interpreter function. But, in practice, the 20 scalars ought to be enough for everybody, and people would not appreciate getting their program accepted and running at a snail's pace (and people don't read warnings).

Am I missing stuff?

You missed my Edit2 above, written after I tried to actually remove GoodScalar from the existential constructors and replace it, for now, with Typeable only:

[Edit2: actually Show and Storable from GoodScalar are also needed to derive Show for the AST types, and removing GoodScalar means that I need to use the Typeable trick in absolutely all the little functions that work on the AST, which I'm fine not specializing, not just the interpreter functions that take 80% of the runtime]

which is an argument for using Typeable, after all, and even the whole GoodScalar containing it. But you convinced me the two options are similar performance-wise, so I'm at ease now. I think the WT guys preferred the singleton indexed enumeration approach on type-safety grounds, because indeed, Typeable is open and we morally want a closed enumeration, so there's a mismatch of semantics and types. But in all except the performance-sensitive functions I want to pretend the collection of scalars is open, because this simplifies the code a lot. E.g., even stupid pretty-printing would need to dispatch on the singleton over a dozen times.

The story is even worse with term simplification, which totally doesn't care about scalars, except in the rare but important cases where it simplifies operations on constants (which is why we can't ignore the scalar dictionaries and have to pass them along everywhere). This is a huge blob of code that will grow forever, and complicating it further would be a high price for, for now, minor performance gains (though once I optimize the interpreter, the simplifier may turn out to be the whale).

I hope I'm not missing any stuff the kind WT people suggested.