Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"

Higher-order functions in the API? #82

Open tomsmeding opened 1 year ago

tomsmeding commented 1 year ago

Loose Haskell lambdas that the user writes are invisible to the tracing done in this library, but the moment the user tries to store a function inside an array, their code will be rejected. A question we should ask is: do we even want to support this? If we don't need to support arrays of functions, then as long as we use a tracing-based API, the library will never see a function value to differentiate.

If we do want to support this, here are some things to think about:

@Mikolaj feel free to edit this if you want to add stuff

Mikolaj commented 1 year ago

@tomsmeding: thank you. I do want horde-ad to be able to accept, control, transform and differentiate as much Haskell code as possible. In particular, I think I understand the main goal: not to pre-evaluate all local function applications POPL-style and consequently end up with a huge trace to differentiate (this excludes the main objective function and the arguments to build, because these we already handle). So let me run a thought experiment about incorporating this "tracing stored functions" feature of yours.

So far, my test code only includes your f1 and f2 functions, so let me try to add an f3 that contains tensors of functions:

https://github.com/Mikolaj/horde-ad/blob/6e90b4fb563a9bf0c0693a5b171a30fa8dd06016/test/simplified/TestSimplified.hs#L177-L182

It fails much earlier than in AD (where the additional "rank" would be needed), even earlier than in the Ast definition, where the AST nodes don't have the required types. It fails in the Tensor class, in its method tbuild, which takes a function that produces a tensor:

https://github.com/Mikolaj/horde-ad/blob/6e90b4fb563a9bf0c0693a5b171a30fa8dd06016/simplified/HordeAd/Core/TensorClass.hs#L232-L234

In f3, the function produces another function. Can we insert that function into a, say, rank-0 tensor using tscalar? Not outright, unfortunately, because the argument of tscalar needs to be numeric, as stated in the header of Tensor. But we can make the function space an instance of all the numeric classes and so make a function an acceptable "scalar".
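For illustration, here is a minimal sketch of the pointwise lifting that makes this work (standard Haskell, not horde-ad code; it needs FlexibleInstances):

{-# LANGUAGE FlexibleInstances #-}

-- Functions into a Num type form a Num type themselves, pointwise.
-- This is one way to let the function space pass as a "scalar".
instance Num b => Num (a -> b) where
  f + g         = \x -> f x + g x
  f * g         = \x -> f x * g x
  f - g         = \x -> f x - g x
  negate f      = negate . f
  abs f         = abs . f
  signum f      = signum . f
  fromInteger n = const (fromInteger n)

Analogous instances for Fractional and Floating follow the same pattern.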

At this point, the test can pass through Ast with minor tweaks (such as adding a function application node to Ast) and the next major hiccup would be the orthotope tensor instance of Tensor, because it's defined using hmatrix (with BLAS in C underneath) and it accepts as scalars only a handful of types, not including the function space. This can be overcome by writing a new Tensor instance specifically for functions, with boxed orthotope tensors (not Storable ones, as now) and the vector package for implementing operations instead of hmatrix. A better solution would be to switch to another fast library that accepts function spaces as scalars. I think massiv does, ArrayFire does not (C under the hood, again) and I'm not sure about accelerate.

We also need to consider GPUs in the future, but I already have a sinking feeling this approach is a dead end, because C land doesn't seem to like boxed things, given that boxing implies a kind of heap and so a kind of GC. I expect GPUs don't like that either (even if they don't have to understand Haskell closures, because they'd just manipulate tensors of (pointers to) them and send them back to Haskell to be applied to values). What's the point of sending tensors of CPU pointers to a GPU anyway?
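To illustrate the boxing point with plain vector-package code (nothing horde-ad-specific): boxed vectors can hold closures, storable ones cannot, because functions have no Storable instance.

import qualified Data.Vector as VB          -- boxed: elements are heap pointers
import qualified Data.Vector.Storable as VS -- unboxed: requires a Storable instance

-- Boxed vectors hold closures without complaint:
funcs :: VB.Vector (Double -> Double)
funcs = VB.generate 10 (\i x -> x + fromIntegral i)

-- This would not compile: functions have no Storable instance,
-- precisely because a closure can't live as flat bytes in C-friendly memory.
-- funcs' :: VS.Vector (Double -> Double)
-- funcs' = VS.generate 10 (\i x -> x + fromIntegral i)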

The next hiccup would be cloning and changing the ADVal (dual numbers) code into a separate "rank", but I won't get into that, because it's similar to the problems above (the current Tensor instance of ADVal is also implemented with hmatrix).

If function spaces as "scalars" at runtime are a dead end, perhaps we should somehow eliminate the functions in Ast by rewriting the Ast so that the functions never reach a concrete Tensor instance? Is that what you had in mind? This comment is already too long, so I won't attempt another experiment: trying to decouple Tensor and Ast from the notion of runtime execution (currently they are closely coupled with, e.g., the storable tensor type [edit: actually Ast is, Tensor may be abstract enough]). Another thing to do in a future comment would be to consider how the current vectorization of Ast would need to change in order to make use of the extra captured functions. E.g., we don't handle function application at all in vectorization right now (we don't even have such an Ast node).

Mikolaj commented 1 year ago

Actually, this fixed version of test f3 type-checks without any tweaks to the rest of the code:

_f3 :: ( ADReady r, Tensor (r -> r), Tensor ((r -> r) -> (r -> r))
       , IntOf r ~ IntOf (r -> r), IntOf r ~ IntOf ((r -> r) -> (r -> r)) )
    => TensorOf 0 r -> TensorOf 0 r
_f3 arg =
  let -- a tensor of functions r -> r, one per index
      arr1 = tbuild [10] (\i -> tscalar $ \x ->
                            x + tunScalar (tfromIntOf0 (headIndex i)))
      -- a tensor of function transformers (r -> r) -> (r -> r)
      arr2 = tbuild [10] (\i -> tscalar $ \f -> (tunScalar $ arr1 ! i) . f)
      -- apply each transformer to the matching function, then to arg
      arr3 = tbuild [10] (\i -> tscalar $ (tunScalar $ arr2 ! i)
                                            (tunScalar $ arr1 ! i)
                                              (tunScalar arg))
  in tsum arr3

It's only the test harness that would not type-check without numeric instances for the function space, etc.

Mikolaj commented 1 year ago

And with very little effort the test got interpreted in Ast (or at least typed as Ast):

https://github.com/Mikolaj/horde-ad/blob/53730abe6a95effca248abe78076784f305a9a3e/test/simplified/TestSimplified.hs#L458

But, of course, the functions inside the test are not captured: Ast doesn't even have application or abstraction nodes yet. Your formalism would now have to step in to make it possible to capture the functions.

And we'd either need to prove that the functions vectorize away or find a way to evaluate them on the spot before we interpret the Ast (just as they get evaluated silently right now, taking Ast terms as arguments and producing other Ast terms). Or we could change the Ast interpretation, because it currently assumes too much about the scalars involved, etc.
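To make the "evaluate them on the spot" idea concrete, here is a toy, self-contained sketch (invented names, not horde-ad's actual Ast): a mini expression type with the missing abstraction and application nodes, plus a rewrite that beta-reduces applications away so no functions remain for the interpreter to see. It assumes all binders use distinct variable names, so capture is not an issue.

data AstToy
  = VarT Int                -- variable, named by Int for simplicity
  | ConstT Double
  | PlusT AstToy AstToy
  | LamT Int AstToy         -- abstraction node: \x -> body
  | AppT AstToy AstToy      -- application node
  deriving Show

-- Substitute e for variable v (assumes distinct binder names).
substT :: Int -> AstToy -> AstToy -> AstToy
substT v e (VarT w)    = if v == w then e else VarT w
substT _ _ (ConstT c)  = ConstT c
substT v e (PlusT a b) = PlusT (substT v e a) (substT v e b)
substT v e (LamT w b)  = if v == w then LamT w b else LamT w (substT v e b)
substT v e (AppT f a)  = AppT (substT v e f) (substT v e a)

-- Rewrite (\x -> body) arg to body[arg/x], everywhere.
betaT :: AstToy -> AstToy
betaT (AppT f a) =
  case betaT f of
    LamT v b -> betaT (substT v (betaT a) b)
    f'       -> AppT f' (betaT a)
betaT (PlusT a b) = PlusT (betaT a) (betaT b)
betaT (LamT v b)  = LamT v (betaT b)
betaT e           = e

-- (\x -> x + 1) 2 rewrites to 2 + 1; no LamT or AppT nodes remain:
exampleT :: AstToy
exampleT = betaT (AppT (LamT 0 (PlusT (VarT 0) (ConstT 1))) (ConstT 2))
-- == PlusT (ConstT 2) (ConstT 1)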

All in all, I take back my previous pessimism: if we can ensure no functions are left after rewriting the Ast, we don't need any extra "rank" and so the differentiation code can be unchanged and no big effort is required (except implementing your formalism).

tomsmeding commented 1 year ago

Sorry, I haven't looked closely yet at all you've written above.

"if we can ensure no functions are left after rewriting the Ast"

Just wanted to quickly note this down. I believe the following could work, as a transformation on the AST:

  1. Eta-expand everything.
  2. Make a list of all lambda expressions in the program, optionally deduplicated (but if you deduplicate, types need to be part of the material to compare -- syntactically equal terms with different (internal) types should still be considered distinct).
  3. For each of these lambdas, invent an ad-hoc product type that can store values for all its free variables. Call this the lambda's closure type.
  4. Create a big sum type where each constructor is one such closure type.
  5. Replace all lambda expressions in the program with a construction of their closure type (storing, of course, the values that would end up in the lambda's runtime closure), injected into the corresponding constructor of the big sum type.
  6. Create a single top-level function in the program, called e.g. apply, that takes a big sum type value as well as an argument and computes the result of calling the corresponding original lambda on that argument. Yes, this is a new function we're introducing into the program, so in the process of eliminating lambdas we're introducing a single one. But this one is closed! And maybe we can inline it or whatever, dunno.
  7. Replace all function application nodes in the program with calls to apply, passing the "function value" (really the big sum type value now) and the argument.
  8. Optional: optimise the big sum type by splitting it up into smaller sum types containing subsets of the full list of constructors, derived by doing data flow analysis on the program to find out which function values can flow where.

This is a way of describing defunctionalisation; you can get closure conversion from this by, instead of the last optional step, storing a (closed) function pointer in each closure type and calling that instead of apply.
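For concreteness, a minimal Haskell sketch of steps 3-7 on a toy program (invented example, not horde-ad code): the two lambdas \x -> x + n and \x -> x * m become constructors carrying their free variables, and apply dispatches on them.

-- The big sum type: one constructor per lambda, each storing the
-- lambda's free variables (its "closure type").
data Fun
  = AddN Double   -- stands for \x -> x + n; stores the free variable n
  | MulM Double   -- stands for \x -> x * m; stores the free variable m

-- The single top-level dispatcher: closed, hence no longer a problem.
apply :: Fun -> Double -> Double
apply (AddN n) x = x + n
apply (MulM m) x = x * m

-- Instead of storing functions in a structure, we store Fun values:
funs :: [Fun]
funs = [AddN 1, MulM 2]

-- Every former application site  f x  becomes  apply f x:
example :: Double
example = sum (map (\f -> apply f 10) funs)   -- (10 + 1) + (10 * 2) = 31

The closure-conversion variant mentioned above would instead store a closed function pointer in each closure type and call that directly, rather than dispatching through apply.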