Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Vectorize AstBuild (AstConstInt), which corresponds to tbuild (tfromIndex0) #91

Closed Mikolaj closed 1 year ago

Mikolaj commented 1 year ago

This is not needed in order to avoid tensors of deltas, because AstConstInt terms have zero dual parts. However, it's needed if we vectorize in order to make GPU code faster. Vectorization would probably amount to eliminating AstConstInt and vectorizing the result. But while currently we simplify inside AstConstInt, we don't attempt to fuse with the outside of AstConstInt, so that even tfromIndex0 (fromIntegral N) is not simplified to fromIntegral N. I'm not yet sure if eliminating AstConstInt (except any number of base cases) is possible nor what performance implications it may have (nor whether vectorizing or simplifying terms that have zero dual parts is beneficial on CPU).

Relevant snippets:

https://github.com/Mikolaj/horde-ad/blob/774bde22fd0e06f0ad77338c0a59b03dabb963f7/simplified/HordeAd/Core/AstVectorize.hs#L244-L249

https://github.com/Mikolaj/horde-ad/blob/774bde22fd0e06f0ad77338c0a59b03dabb963f7/simplified/HordeAd/Core/AstSimplify.hs#L831

Mikolaj commented 1 year ago

I've illustrated some of the problems by scratching at the surface on Overleaf.

Mikolaj commented 1 year ago

And here's my blurb from the last email:

In short, my only idea of how to remove build from build(fromInt e) involves a lot of code that embeds the integer algebra into the float algebra and in this way eliminates fromInt completely. I'm not sure if it's going to work, given that integer expressions may contain float tensor expressions, etc. Perhaps I'm missing a different solution or a good argument why it's obviously not possible at all?

Mikolaj commented 1 year ago

Phew, I've redone the Overleaf document according to our conclusions and, in particular, I've removed any non-trivial reduction of fromInt from the document. Let me store it here for when we want to revisit it (probably using Tom's idea of constructing integer tensors and vectorizing integer expressions recursively, instead of eliminating them by an iffy rewrite into floats).

For special form (4) use the following rules, which are not applicable to any terms not of this form.

\begin{verbatim}
build1 k (var, fromInt var) -->       -- TS: Perhaps have a iota ~primitive~ language element in the real implementation? Not necessary here
  fromList [0 .. k - 1]               -- MK: right, orthotope has one, but we need it for terms, etc.; let's see
build1 k (var, fromInt var2) -->      -- TS: "terms"?
  konst k (fromInt var2)              -- MK: I mean, we may need to have 'iota' in the Ast grammar; otherwise, it's cumbersome and costly to recognize such terms when interpreting them in order to use the orthotope iota tensor-producing function
\end{verbatim}     

With $\cc{fromInt}$ as the redex: %\quad \emph{TS: Why can we not have tensors of integers?} \quad \emph{MK: Yes, we very much could. A Tensor instance for Int would be identical to the one for Double, which is identical to the one for Float. Why?} \\
Note: if fromInt takes a non-constant argument, we'll need to generalise the whole core language to also work on integer tensors. Then we can generalise the vectorisation rules as well, and all will be fine: no need to express integer arithmetic in float arithmetic somehow.
\begin{verbatim}
fromInt (floor t) -->
  floor t
fromInt (rem a b) -->
  round (fromInt b * (fromInt a / fromInt b - floor (fromInt a / fromInt b)))  -- MK: and this is probably subtly wrong
...
\end{verbatim}
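
A toy model of the two build1 rules for special form (4) can make them easier to check. This is only a sketch under stated assumptions: Term, IntExpr, Iota, Konst, vectorizeFromInt and eval are hypothetical stand-ins, not horde-ad's Ast, and Iota k plays the role of fromList [0 .. k - 1].

```haskell
-- Toy integer expressions: only variables and literals (hypothetical).
data IntExpr = IVar String | ILit Int
  deriving (Eq, Show)

-- Toy tensor terms of rank 0 or 1 (hypothetical, not horde-ad's Ast).
data Term
  = FromInt IntExpr         -- rank-0 float built from an integer expression
  | Build1 Int String Term  -- build1 k (var, body)
  | Iota Int                -- stands for fromList [0 .. k - 1]
  | Konst Int Term          -- konst k t
  deriving (Eq, Show)

-- The two rules for special form (4); every other term is left untouched.
vectorizeFromInt :: Term -> Term
vectorizeFromInt (Build1 k var body@(FromInt (IVar v)))
  | v == var  = Iota k        -- build1 k (var, fromInt var)
  | otherwise = Konst k body  -- build1 k (var, fromInt var2)
vectorizeFromInt t = t

type Env = [(String, Int)]

evalInt :: Env -> IntExpr -> Int
evalInt env (IVar v) = maybe (error ("unbound variable " ++ v)) id (lookup v env)
evalInt _   (ILit n) = n

-- Reference semantics: a term evaluates to the flat list of its cells.
eval :: Env -> Term -> [Double]
eval env (FromInt e)         = [fromIntegral (evalInt env e)]
eval env (Build1 k var body) = concat [eval ((var, i) : env) body | i <- [0 .. k - 1]]
eval _   (Iota k)            = map fromIntegral [0 .. k - 1]
eval env (Konst k t)         = concat (replicate k (eval env t))
```

In GHCi, eval [] (vectorizeFromInt (Build1 3 "i" (FromInt (IVar "i")))) gives the same cells as evaluating the unvectorized build1 term, and the rewritten term no longer contains Build1.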
Mikolaj commented 1 year ago

I've tried the "tensors of integers" approach and I'm currently stuck at this

data Ast :: Nat -> Type -> Type where
  AstConstInt :: Ast n (AstInt r) -> Ast n r

which leads to potential values of type Ast n (AstInt (AstInt (AstInt r))), the existence of which breaks, e.g., my Show instances. The AstConstInt constructor is needed for the transformation build1 k (\i -> AstConstInt0 i) --> AstConstInt (build1 k (\i -> i)).
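
The nesting can be reproduced in a few lines. The following is a minimal sketch, assuming a phantom-typed stand-in for AstInt and a made-up AstLit constructor (neither matches the real definitions); it only shows that the signature of AstConstInt lets well-typed terms of type Ast n (AstInt (AstInt r)) exist.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures #-}

import Data.Kind (Type)
import GHC.TypeLits (Nat)

-- Hypothetical stand-in for integer terms; r is phantom here.
data AstInt r = AstIntLit Int

data Ast :: Nat -> Type -> Type where
  AstLit      :: r -> Ast 0 r                  -- hypothetical leaf constructor
  AstConstInt :: Ast n (AstInt r) -> Ast n r   -- the signature from this comment

-- Nothing rules out integer terms "about" integer terms.
nested :: Ast 0 (AstInt (AstInt Double))
nested = AstLit (AstIntLit 1)

-- Two conversions are needed to get back to a Double-based term.
twiceConverted :: Ast 0 Double
twiceConverted = AstConstInt (AstConstInt nested)

-- Counts how many AstConstInt wrappers surround the leaf.
-- The recursive call is at a different type index, so the signature is required.
constIntDepth :: Ast n r -> Int
constIntDepth (AstLit _)      = 0
constIntDepth (AstConstInt t) = 1 + constIntDepth t
```

Any function that walks such a term, a Show instance included, has to be prepared for an arbitrary tower of AstInt applications in the type index.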

tomsmeding commented 1 year ago

What is the type of your AstConstInt0? Is that AstInt r -> Ast 0 r, i.e. an alternative name would be intToFloat? If so I don't see how AstConstInt cannot have the same shape, i.e. Ast n (AstInt r) -> Ast n (Ast r).

EDIT: currently AstConstInt as written moves between AST nesting levels, which is supremely weird. Surely you don't have a part of the computation that is written in ASTs of ASTs, embedded in a computation which is ASTs of Doubles?

Mikolaj commented 1 year ago

Yes, that's precisely the type: AstConstInt0 :: AstInt r -> Ast 0 r. Regarding the next type signature you provided, I can write AstConstInt :: TensorOf n (AstInt r) -> TensorOf n (Ast0 r), but that's equal to the original problematic type I gave. As such, Ast r doesn't make sense, because it lacks the Nat argument.

Yes, nesting levels are precisely what I want to avoid, but these are pretend-nesting levels only. The meaning of AstInt (AstInt (AstInt r)) is "the int type that goes together with the int type that goes together with the int type that goes together with r", which is always equal to just AstInt r. But the type-checker doesn't know that and I probably can't express that as a single constraint.

tomsmeding commented 1 year ago

Then what you're writing here doesn't correspond to master where AstInt is a GADT, not a type family. I'm having difficulty coming up with anything coherent without the full definitions of all names involved. :p

Mikolaj commented 1 year ago

Oh, you are right. AstInt (AstInt (AstInt r)) is equal to AstInt r only up to interpretation. Doh. So the nesting is not that shallow. All the more reason to avoid it.

But I still don't understand "AstConstInt as written moves between AST nesting levels". AstConstInt :: Ast n (AstInt r) -> Ast n r moves from tensor terms with base (non-tensor) int expression in leaves to tensor terms with base float expressions in leaves (faked as Tensor 0 r tensor expressions, because that's how Ast is cheating). The cheat is that instead of a hypothetical Ast n (Ast0 r), I do Ast n r, where the Ast0 r expressions are represented as Ast 0 r and folded at the bottom level of Ast n r.

Is it this cheat that causes the problem I have? I'd be surprised.

Mikolaj commented 1 year ago

BTW, AstInt (AstInt r) is iffy [edit: even up to interpretation] also because I have (simplifying)

AstIntFloor :: r -> AstInt r

which would instantiate to

AstIntFloor :: AstInt r -> AstInt (AstInt r)

but taking a floor of an integer is a very artificial concept.

tomsmeding commented 1 year ago

Is it this cheat that causes the problem I have? I'd be surprised.

Well, I dunno, but if you'd have AstConstInt :: Ast n (AstInt r) -> Ast n (Ast 0 r) or AstConstInt :: AstInt r -> Ast n r then you wouldn't be having the problem that I think I understand you have. :P

Another option: why not fold the integer terms into Ast just like you did with the float terms?

-- Perhaps GHC will accept it if you say 'type ActuallyIntOf r = IntOf r', but if not,
-- we need a wrapper. Because this is a type-level wrapper only, with new GHC this
-- could then be 'type data ActuallyIntOf a', but that's newfangled -XTypeData.
data ActuallyIntOf r

data Ast where
  -- ... the existing Ast constructors
  AstIntConst :: Int -> Ast 0 (ActuallyIntOf r)
  AstIntFloor :: Ast 0 r -> Ast 0 (ActuallyIntOf r)
  AstToFloating :: Ast 0 (ActuallyIntOf r) -> Ast 0 r
  -- etc.
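
A compilable variant of this idea, for experimentation. It is a sketch under simplifying assumptions: r is fixed to Double, and the AstConst constructor, the Value type family and eval are hypothetical additions that are not part of the proposal above.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeFamilies #-}

import Data.Kind (Type)
import GHC.TypeLits (Nat)

-- Type-level tag only; it has no values.
data ActuallyIntOf r

-- Hypothetical: maps a type index to the Haskell type a term evaluates to.
type family Value a where
  Value (ActuallyIntOf r) = Int
  Value Double            = Double

data Ast :: Nat -> Type -> Type where
  AstConst      :: Double -> Ast 0 Double   -- hypothetical float literal
  AstIntConst   :: Int -> Ast 0 (ActuallyIntOf Double)
  AstIntFloor   :: Ast 0 Double -> Ast 0 (ActuallyIntOf Double)
  AstToFloating :: Ast 0 (ActuallyIntOf Double) -> Ast 0 Double

-- One evaluator covers both the float and the integer terms.
eval :: Ast n a -> Value a
eval (AstConst d)      = d
eval (AstIntConst i)   = i
eval (AstIntFloor t)   = floor (eval t)
eval (AstToFloating t) = fromIntegral (eval t)
```

Keeping r polymorphic, as in the original sketch, would require eval to carry the numeric constraints for r, which is where the dictionary-passing questions discussed later in this thread come from.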

By the way that AstD looks fishy and an intrusion of AD into code that should not be aware of AD. But I don't actually know what's going on there. :)

Mikolaj commented 1 year ago

Is it this cheat that causes the problem I have? I'd be surprised.

Well, I dunno, but if you'd have AstConstInt :: Ast n (AstInt r) -> Ast n (Ast 0 r) or AstConstInt :: AstInt r -> Ast n r then you wouldn't be having the problem that I think I understand you have. :P

Yes, I can see. I think the whole Ast n (AstInt r) idea is wrong, because Ast n r means "tensor terms with r as the underlying concrete scalar that will end up in the physical tensor cells" and not "tensor terms with r things, or their interpretations, in the tensor cells (or term leaves)". So what I really want is Ast n (IntOf r), where IntOf is the associated type family from Tensor class. But so far the Ast module was almost independent from the TensorClass module, which made me reluctant. I will end up with IntOf (IntOf r), but this time, I can add constraints IntOf (IntOf r) ~ IntOf r wherever needed.

Another option: why not fold the integer terms into Ast just like you did with the float terms?

That makes sense when expressing tensors of integers, but when expressing integer indexes into tensors, using sized lists of 0-rank tensors of integers is going to feel weird. However, the Tensor API can probably convert between rank 0 tensors of integers and actual integers on the fly and so hide the internal Ast representation.


-- Perhaps GHC will accept it if you say 'type ActuallyIntOf r = IntOf r', but if not,

Yes, it's all fine. I just had to enable AllowAmbiguousTypes for the sake of

AstIntConst0 :: Int -> Ast 0 (IntOf r)

where the r is only seen wrapped in a type family.

data Ast where
  -- ... the existing Ast constructors
  AstIntConst :: Int -> Ast 0 (ActuallyIntOf r)
  AstIntFloor :: Ast 0 r -> Ast 0 (ActuallyIntOf r)
  AstToFloating :: Ast 0 (ActuallyIntOf r) -> Ast 0 r
  -- etc.

Oh yes, let me try something like that. [Edit: so far, so good, though deriving Show can't cope with IntOf not being injective.]

By the way that AstD looks fishy and an intrusion of AD into code that should not be aware of AD. But I don't actually know what's going on there. :)

Hah, wait until you see the lower half of class Tensor, the one below "DO NOT LOOK HERE". ;D

tomsmeding commented 1 year ago

Hah, wait until you see the lower half of class Tensor, the one below "DO NOT LOOK HERE". ;D

Good luck explaining that to the others :P

Also good lord that ADReady synonym. Surely that can be compressed using something like this?

type Many f c r = (f (c (TensorOf 0 r)), f (c (TensorOf 1 r)), ..., f (c (TensorOf 12 r)))
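
A compilable miniature of such a constraint synonym, limited to ranks 0 to 2. This is a sketch that reads f as a class and c as a type-level wrapper, which is only one plausible interpretation of the line above; the TensorOf newtype and showRanks01 are hypothetical stand-ins.

```haskell
{-# LANGUAGE ConstraintKinds, DataKinds, FlexibleContexts, KindSignatures #-}

import Data.Functor.Identity (Identity (..))
import Data.Kind (Constraint, Type)
import GHC.TypeLits (Nat)

-- Hypothetical stand-in for a rank-indexed tensor type.
newtype TensorOf (n :: Nat) r = TensorOf [r]
  deriving (Eq, Show)

-- One synonym bundles the same constraint at every supported rank.
type Many (f :: Type -> Constraint) (c :: Type -> Type) r =
  ( f (c (TensorOf 0 r))
  , f (c (TensorOf 1 r))
  , f (c (TensorOf 2 r)) )

-- The bundled constraints are usable like individually listed ones.
showRanks01 :: Many Show Identity r => TensorOf 0 r -> TensorOf 1 r -> [String]
showRanks01 t0 t1 = [show (Identity t0), show (Identity t1)]
```

Type equalities such as IntOf (TensorOf n r) ~ IntOf r are not of the form f (c x), so a synonym shaped like this one would not capture them.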
Mikolaj commented 1 year ago

Surely that can be compressed using something like this?

Thank you very much. That will let me write even longer constraints.

However, I didn't manage to compress the type equalities with a type synonym family, but that's probably just hard: https://github.com/Mikolaj/horde-ad/commit/e80d5feca3576fe82ce92dc398165f4dbec401dc#diff-c8a91a9784657ab492383ca410522de722902fff8e38a9e0106a70bd57fd5124R274

Regarding the main problem, I got myself confused and fed up trying to adjust the substitution functions. Once we have two underlying scalars at once, r and IntOf r, we have two sets of integer variables and two sets of possible types of terms that may be substituted for the variables. The r terms compute, e.g., argmin of a vector of floats, while the IntOf r terms compute argmin of a vector of integers. So we can't unsafeCoerce, sadly.

Some of the time I thought this can be fixed by using a constructor AstIntConvert :: AstInt r -> AstInt (IntOf r), or similarly for Ast n r, but then the usual GADT Show deriving woes strike, because from Show (IntOf r) one can't derive Show r (the type family is not injective). And the same problem re-appears for a lot of other constraints later on, so even adding Show r to the constructor doesn't help. And even if AstIntConvert worked, it's obnoxious (much like AstIntLet), because simplification of Peano modulo n suddenly has to dive into AstIntConvert terms to decide if a redex is available or if associativity needs to be used to search for it elsewhere (and the associativity shuffles need to be aware of AstIntConvert as well).

This is related to an even more general problem of how to write some part of the objective function using Float and another using Double. I can think of no way to pass around enough constraints to make this possible. Unless we hard-wire a couple of float (and int?) types and mix only when we know the concrete types we need, not at the stage "we'd need two floats, long and short, and one integer".

Mikolaj commented 1 year ago

I missed one obvious simplification. The r in Ast r can only be Float, Double, ADVal Float and ADVal Double. And for each of those, IntOf r is CInt. Only IntOf (Ast0 r) is AstInt r, but Ast (Ast0 r) is not supported.

So we can get rid of the type family in the codomain, which however doesn't solve the problem that we can't deduce r from the codomain. We couldn't deduce r from IntOf r and we can't deduce r from CInt just the same. And I still have two cases of AstMinIndex1 constructor application: one applied to a vector of r (say, Double) and another to a vector of CInt and they are going to get two different specialized codes during the same compilation so I can't just unsafeCoerce one to another.

This implies that when simplifying

AstIndex (AstToFloat v) ix --> AstToFloat (AstIndex v (convertIndexes ix))

I still have to use convertIndexes :: [AstInt r] -> [AstInt CInt], which wraps individual indexes in the AstInt constructor that needs to carry all the relevant dictionaries.

Mikolaj commented 1 year ago

I think I have a solution with AstIota. You were onto something, @tomsmeding, with your off-hand Overleaf comment.

tomsmeding commented 1 year ago

Some of the time I thought this can be fixed by using a constructor AstIntConvert :: AstInt r -> AstInt (IntOf r)

I'm confused, why would you need this and why would this be sensible in the first place? What is that r there -- is it, for example, Double? What does AstInt Double even mean? And if it's CInt, then why are we generating IntOf CInt in the first place?

So we can get rid of the type family in the codomain, which however doesn't solve the problem that we can't deduce r from the codomain.

No, indeed. From the information that something is a vector of integers we cannot know whether it contains some subterms with Doubles, some with Floats, or both.

Is there a problem with having an Ast Double that contains some integral subcomputations, i.e. some Ast Int subterms, that in turn contain some scalar subcomputations, i.e. some Ast Double subterms? I don't see how that prevents you from doing anything.

EDIT: Also, oh no you used my Many. /me scared

Mikolaj commented 1 year ago

I think I have a solution with AstIota.

Here it is:

https://github.com/Mikolaj/horde-ad/commit/cd5d64ef14710ff53052e9b882a14fe6ddd7988c

It's sick and I can't quite explain why it works, but it does, it doesn't introduce any AstBuild terms, it's short and probably optimal. Sadly, it triggers https://gitlab.haskell.org/ghc/ghc/-/issues/23109, so can't be used with GHC 9.6.1.

Mikolaj commented 1 year ago

Some of the time I thought this can be fixed by using a constructor AstIntConvert :: AstInt r -> AstInt (IntOf r)

I'm confused, why would you need this and why would this be sensible in the first place? What is that r there -- is it, for example, Double? What does AstInt Double even mean? And if it's CInt, then why are we generating IntOf CInt in the first place?

Let me cite myself, because I managed to simplify this recently: "The r in Ast r can only be Float, Double, ADVal Float and ADVal Double. And for each of those, IntOf r is CInt. Only IntOf (Ast0 r) is AstInt r, but Ast (Ast0 r) is not supported." So IntOf is old news, it's not considered any more.

AstInt Double is this datatype, in which the big float tensors (AstPrimalPart, a newtype over Ast that represents primal parts of Ast tensors) have doubles in them:

https://github.com/Mikolaj/horde-ad/blob/fff6dd554be2b9bdf2ceef3d4accaab12f2743aa/simplified/HordeAd/Core/Ast.hs#L149-L157

For AstInt CInt, the tensors have CInt in them. AstInt (ADVal Double) would not have pairs of tensors, because of the PrimalPart bit, so these would be tensors with doubles just the same.

So we can get rid of the type family in the codomain, which however doesn't solve the problem that we can't deduce r from the codomain.

No, indeed. From the information that something is a vector of integers we cannot know whether it contains some subterms with Doubles, some with Floats, or both.

Right. :(

Is there a problem with having an Ast Double that contains some integral subcomputations, i.e. some Ast Int subterms, that in turn contain some scalar subcomputations, i.e. some Ast Double subterms? I don't see how that prevents you from doing anything.

There is no problem with Ast and AstInt nested 100 times, with different type of the elements of the tensors in each of them. Except for technical problems, and possibly just one: functions throughout the codebase case-analyze these and encounter the 100 types of elements of tensors. They require various properties of them, expressed as constraints. They have nowhere to take the dictionaries from nor can they determine what the type at hand is so that they could take a publicly available dictionary from top level. And we can't pack the 100 dictionaries in Ast, e.g., a set of a few dictionaries accompanying each subterm with a particular type of tensor elements, because in Ast we don't know what dictionaries all future code using Ast will need. Trying to think ahead would lead to recursive modules or a Types.hs module with all datatypes used in the program. shudder

EDIT: Also, oh no you used my Many. /me scared

:D I still have to limit myself, though, because parts of the constraint did not compress well.

tomsmeding commented 1 year ago

or a Types.hs module with all datatypes used in the program

This is what Accelerate does. /shrug

In general, these are two kind of dual encodings of the same thing: either encode the types in the AST directly, defining operations separately as functions of those types (that's the Accelerate encoding); or encode the operations in the AST directly, defining the types that implement those operations separately as instances of those classes (that's the encoding you talk about, used e.g. here).

Neither is particularly great but it is very reminiscent of the expression problem, and there's no perfect answer. Personally, from experience working on Accelerate and personal projects, I much prefer the data Type t where approach -- it makes code transformations much easier to add precisely because you aren't commonly adding new types to your supported set, but you are commonly adding optimisations / transformations.

A project I inherited used the type class based approach at first, and it was hell to work on.

So unless you think the current state is preferable over this, I would advise biting the bullet, defining Types.hs, and putting Type r in all the relevant AST constructors. :) Then it can all be one GADT (AstIntConvert is gone) and there are no more problems with type inference (though you'll need a manual show instance, although I'm probably able to cook up some TH code that does this for you).
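
A minimal sketch of what the data Type t where encoding could look like. All names here (ScalarType, the Ast constructors, eval, showAst) are hypothetical and are neither Accelerate's nor horde-ad's API; the point is only that matching on the type witness recovers every needed instance, so no class dictionaries are stored in the AST and the Show function is written by hand.

```haskell
{-# LANGUAGE GADTs #-}

-- Closed universe of supported scalar types (the "Types.hs" part).
data ScalarType t where
  TDouble :: ScalarType Double
  TFloat  :: ScalarType Float
  TInt    :: ScalarType Int

-- One GADT for float and integer terms; constructors carry type witnesses.
data Ast t where
  Const   :: ScalarType t -> t -> Ast t
  Add     :: ScalarType t -> Ast t -> Ast t -> Ast t
  Floor   :: ScalarType a -> Ast a -> Ast Int
  FromInt :: ScalarType t -> Ast Int -> Ast t

-- Matching on the witness fixes t, which brings Num, RealFrac, etc. into scope.
eval :: Ast t -> t
eval (Const _ x)  = x
eval (Add ty a b) = case ty of
  TDouble -> eval a + eval b
  TFloat  -> eval a + eval b
  TInt    -> eval a + eval b
eval (Floor ty a) = case ty of
  TDouble -> floor (eval a)
  TFloat  -> floor (eval a)
  TInt    -> eval a
eval (FromInt ty a) = case ty of
  TDouble -> fromIntegral (eval a)
  TFloat  -> fromIntegral (eval a)
  TInt    -> eval a

-- Manual replacement for a derived Show instance.
showAst :: Ast t -> String
showAst (Const ty x) = case ty of
  TDouble -> show x
  TFloat  -> show x
  TInt    -> show x
showAst (Add _ a b)   = "(" ++ showAst a ++ " + " ++ showAst b ++ ")"
showAst (Floor _ a)   = "floor (" ++ showAst a ++ ")"
showAst (FromInt _ a) = "fromInt (" ++ showAst a ++ ")"
```

For example, eval (FromInt TDouble (Floor TFloat (Const TFloat 2.7))) mixes Float, Int and Double subterms in one expression. The cost is the repeated three-way case, which grows with every scalar type added to the universe.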

tomsmeding commented 1 year ago

I just saw this, sorry for missing before:

tfromIndex0 i = AstConstant $ AstPrimalPart
                  $ AstIndexZ AstIota (singletonIndex i)

Surely that's not optimal? Or is the idea that you special-case "index of iota" in evaluation and differentiation, which makes this go alright? But then why is this better than just defining AstIndexZAstIotaOfSingletonIndex, which is just FromIndex0?

Mikolaj commented 1 year ago

it is very reminiscent of the expression problem, and there's no perfect answer

Yes, indeed. Good to know my problem is not unique. I don't care about tensors with integers, but I do about using many types of floats in the same program #97, so your pointers are very valuable. After we are done with the single-float-type version, I will try the Accelerate approach that you recommend (or earlier if I hit any snags it may help with). There's going to be plenty of open polymorphism and constraints due to many instances of Tensor, anyway. Also, we already duplicate a lot of our code for Double and for Float, effectively reiterating in many modules the set of approved floating types, so we can just as well do it for Ast. I much prefer Types.hs with all possible instantiations of type variables to Types.hs with all constraints that are ever going to be needed and that terms have to carry along.

Or is the idea that you special-case "index of iota" in evaluation and differentiation

I do special-case it in Ast evaluation, basically kicking the can down the road as long as possible. I don't need to do anything with differentiation, because iota is an AD-constant (dual part is dZero), so I didn't have to touch anything.

why is this better than just defining AstIndexZAstIotaOfSingletonIndex

Not for any deep reasons. I just piggy-back on the existing code for AstIndexZ and extend it in a way that is easier to understand in the context of the existing AstIndexZ handling. E.g., I don't need to do anything in vectorization, because the existing code for AstIndexZ handles this case fully (indexing of something that doesn't contain integer variables).
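
The special-casing of "index of iota" during evaluation can be illustrated with a toy evaluator. This is a sketch only: Term, fromIndex0 and evalScalar are hypothetical, indices are plain Ints rather than AstInt terms, and Iota is modelled as the conceptually unbounded vector [0, 1, 2, ...], which is an assumption about AstIota rather than something stated in this thread.

```haskell
-- Toy rank-1 terms and scalar indexing (hypothetical, not horde-ad's Ast).
data Term
  = Iota              -- the vector [0, 1, 2, ...]
  | FromList [Double] -- an explicit vector
  | Index Term Int    -- index into a rank-1 term, yielding a scalar
  deriving (Eq, Show)

-- Counterpart of tfromIndex0: no dedicated constructor, just index of iota.
fromIndex0 :: Int -> Term
fromIndex0 i = Index Iota i

-- Evaluates scalar-valued terms; Nothing for out-of-bounds or non-scalar terms.
evalScalar :: Term -> Maybe Double
evalScalar (Index Iota i)
  | i >= 0    = Just (fromIntegral i)  -- special case: iota is never materialized
  | otherwise = Nothing
evalScalar (Index (FromList xs) i)
  | i >= 0, i < length xs = Just (xs !! i)
  | otherwise             = Nothing
evalScalar _ = Nothing
```

Because fromIndex0 reuses Index, any pass that already handles indexing applies to it unchanged; only the evaluator needs the extra Iota clause.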

tomsmeding commented 1 year ago

Cool!

Mikolaj commented 1 year ago

Let me close this and move the discussion of multiple underlying scalars and unification with AstInt to #80.