tomsmeding opened this issue 1 year ago
@tomsmeding: thank you. I do want horde-ad to be able to accept, control, transform and differentiate as much Haskell code as possible. In particular, I understand the main goal (I think), which is not to pre-evaluate all local function applications POPL-style and consequently end up with a huge trace to differentiate (this excludes the main objective function and the arguments to `build`, because those we already handle). So let me try a thought experiment of incorporating this "tracing stored functions" feature of yours.
So far, my test code only includes your `f1` and `f2` functions, so let me try to include `f3`, which contains tensors of functions:

It fails much earlier than in AD (where the additional "rank" would be needed), even earlier than in the Ast definition, where the AST nodes don't have the required types. It fails in the `Tensor` class and its method `tbuild`, which takes a function that produces a tensor:
In `f3`, the function produces another function. Can we insert that function into a, say, rank-0 tensor using `tscalar`? Unfortunately not outright, because the argument of `tscalar` needs to be numeric, as stated in the header of `Tensor`. But we can make the function space an instance of all the numeric classes and so make a function an acceptable "scalar".
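For concreteness, the standard way to do this is to lift the numeric operations pointwise. A minimal sketch (standalone and hypothetical, not horde-ad's actual classes), showing a `Num` instance for functions:

```haskell
-- Sketch: lifting Num operations pointwise makes any function space
-- whose codomain is numeric count as a "numeric scalar" itself.
-- (This is an orphan instance; in a real library it would live in a
-- dedicated module or behind a newtype.)
instance Num b => Num (a -> b) where
  f + g         = \x -> f x + g x
  f - g         = \x -> f x - g x
  f * g         = \x -> f x * g x
  abs f         = abs . f
  signum f      = signum . f
  fromInteger n = const (fromInteger n)
```

With this instance, an expression like `(\x -> x + 1) * 2 :: Double -> Double` type-checks, since the literal `2` elaborates to `const 2`.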
At this point, the test can pass through Ast with minor tweaks (such as adding a function application node to Ast), and the next major hiccup would be the `orthotope` tensor instance of `Tensor`, because it's defined using `hmatrix` (with BLAS in C underneath), which accepts as scalars only a handful of types, not including the function space. This could be overcome by writing a new `Tensor` instance specifically for functions, with boxed `orthotope` tensors (not storable, as now) and the `vector` package for implementing operations instead of `hmatrix`. A better solution would be to switch to another fast library that accepts the function space as scalars. I think `massiv` does; `ArrayFire` does not (C under the hood, again) and I'm not sure about `accelerate`. We also need to consider GPUs in the future, but I already have a sinking feeling this approach is a dead end, because C land doesn't seem to like boxed things, given that boxing implies a kind of heap and so a kind of GC. I expect GPUs don't like that either (even if they don't have to understand Haskell closures, because they'd just manipulate tensors of (pointers to) them and send them back to Haskell to be applied to values). What's the point of sending tensors of CPU pointers to a GPU, anyway?
The next hiccup would be cloning and changing the `ADVal` (dual numbers) code into a separate "rank", but I won't get into that, because it's similar to the problems above (the current `Tensor` instance of `ADVal` is also implemented with `hmatrix`).
If function spaces as "scalars" at runtime are a dead end, perhaps we should somehow eliminate the functions in Ast by rewriting the Ast so that the functions never reach a concrete `Tensor` instance? Is that what you had in mind? This comment is already too long, so I won't make another experiment --- trying to decouple `Tensor` and Ast from the notion of runtime execution (currently, they are closely coupled with, e.g., the storable tensor type [edit: actually Ast is; `Tensor` may be abstract enough]). Another thing to do in a next comment would be to consider how the current vectorization of Ast would need to change in order to make use of the extra captured functions. E.g., we don't handle function application at all in vectorization right now (we don't even have such an Ast node).
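To make the "eliminate functions by rewriting the Ast" idea concrete, here is a toy sketch. This is a hypothetical mini-AST, not horde-ad's actual `Ast` type, and the substitution is naive (it assumes no variable shadowing); it only illustrates what adding abstraction/application nodes and beta-reducing them away before interpretation could look like:

```haskell
-- Hypothetical mini-AST with the abstraction/application nodes that
-- the real Ast currently lacks.
data Ast
  = Var String
  | Lam String Ast
  | App Ast Ast        -- the application node missing from vectorization
  | IntConst Int
  | Add Ast Ast
  deriving (Eq, Show)

-- Naive substitution; adequate under the no-shadowing assumption.
subst :: String -> Ast -> Ast -> Ast
subst v arg (Var w) | v == w    = arg
                    | otherwise = Var w
subst v arg (Lam w body) | v == w    = Lam w body
                         | otherwise = Lam w (subst v arg body)
subst v arg (App f a)  = App (subst v arg f) (subst v arg a)
subst v arg (Add a b)  = Add (subst v arg a) (subst v arg b)
subst _ _   (IntConst n) = IntConst n

-- Beta-reduce applications of known lambdas away, so that (for terms
-- where every App head reduces to a Lam) no functions survive to reach
-- a concrete Tensor instance.
reduce :: Ast -> Ast
reduce (App f a) = case reduce f of
  Lam v body -> reduce (subst v (reduce a) body)
  f'         -> App f' (reduce a)
reduce (Lam v b) = Lam v (reduce b)
reduce (Add a b) = Add (reduce a) (reduce b)
reduce t = t
```

The interesting question for the real library is exactly the one raised above: whether vectorization can guarantee that every `App` head is a known `Lam`, so the residual case in `reduce` never fires.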
Actually, this fixed version of test `f3` type-checks without any tweaks to the rest of the code:
```haskell
_f3 :: ( ADReady r, Tensor (r -> r), Tensor ((r -> r) -> (r -> r))
       , IntOf r ~ IntOf (r -> r), IntOf r ~ IntOf ((r -> r) -> (r -> r)) )
    => TensorOf 0 r -> TensorOf 0 r
_f3 arg =
  let arr1 = tbuild [10] (\i -> tscalar $ \x ->
               x + tunScalar (tfromIntOf0 (headIndex i)))
      arr2 = tbuild [10] (\i -> tscalar $ \f -> (tunScalar $ arr1 ! i) . f)
      arr3 = tbuild [10] (\i -> tscalar $ (tunScalar $ arr2 ! i)
                                            (tunScalar $ arr1 ! i)
                                            (tunScalar arg))
  in tsum arr3
```
It's only the test harness that would not type-check without numeric instances for the function space, etc.
And with very little effort the test got interpreted in Ast (or at least typed as Ast):
But, of course, the functions inside the test are not captured: Ast doesn't yet even have application or abstraction nodes. Your formalism would now have to step in to make it possible to capture the functions.
And we'd either need to prove the functions vectorize away, or find a way to evaluate them on the spot before we interpret the Ast (just as they get evaluated silently right now, taking Ast terms as arguments and producing other Ast terms), or change the Ast interpretation, because currently it assumes too much about the scalars involved, etc.
All in all, I take back my previous pessimism: if we can ensure no functions are left after rewriting the Ast, we don't need any extra "rank", so the differentiation code can stay unchanged and no big effort is required (except implementing your formalism).
Sorry, I haven't looked closely yet at all you've written above.
> if we can ensure no functions are left after rewriting the Ast
Just wanted to quickly note this down. I believe the following could work (as a transformation on the AST):

- Replace each lambda in the program with a value of one big sum type, with one constructor per original lambda, holding that lambda's captured variables.
- Introduce a function `apply` that takes a big-sum-type value as well as an argument and computes the result of calling the corresponding original lambda on that argument. Yes, this is a new function we're introducing in the program, so in the process of eliminating lambdas we're introducing a single one. But this one is closed! And maybe we can inline it or whatever, dunno.
- Replace each application with a call to `apply`, passing the "function value" (really the big-sum-type value now) and the argument.

This is a way of describing defunctionalisation; you can get closure conversion from this by, instead of the last optional step, storing a (closed) function pointer in each closure type and calling that instead of `apply`.
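A tiny self-contained illustration of that scheme (hypothetical names, not tied to horde-ad's Ast), assuming the program originally contained just two lambdas:

```haskell
-- Suppose the original program had two lambdas:
--   \x -> x + c   (captures c)
--   \x -> x * 2   (captures nothing)

-- One constructor per original lambda, holding its captured variables.
data Fun = AddC Double   -- represents \x -> x + c, capturing c
         | Times2        -- represents \x -> x * 2, capturing nothing

-- The single closed function introduced by the transformation:
-- dispatch on the tag and run the corresponding original body.
apply :: Fun -> Double -> Double
apply (AddC c) x = x + c
apply Times2   x = x * 2
```

Every former application site now calls `apply funValue arg`. Crucially, values of `Fun` are plain first-order data, so storing them in tensors raises none of the boxed-function-space problems discussed above.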
Loose Haskell lambdas that the user writes are invisible to the tracing done in this library, but the moment the user tries to store a function inside an array, their code will be rejected. A question we should ask is: do we even want to support this? If we don't need to support arrays of functions, then as long as we use a tracing-based API, the library will never see a function value to differentiate.
If we do want to support this, here are some things to think about:
@Mikolaj feel free to edit this if you want to add stuff