awf closed this issue 2 years ago.
Arguments are tupled, could just as easily be curried. -awf
Actually, neither can work (or I don't know how), because a tuple is not a functor (a tuple of elements of the same type would be, but that's not what Haskell defines for the parens-with-commas notation). Would a list [x,y,z] or a shaped tensor OS.fromList [x,y,z] :: OS.Array [3] a suffice?
The best I can do for a tuple is the following. But in addition to not matching exactly what's above (the last lines of testFoo), it requires the verbose glue code just below (what I defined here for 3-tuples would be needed for each tuple size and for any other datatype, including user-defined ones; for Traversable datatypes I could do something general, but tuples and other non-functors are not Traversable).
@awf, @tomjaguarpaw: any help appreciated, at your leisure. I think this is a Haskell limitation, not our library's, but I'm not sure I'm identifying it correctly (is the problem that a tuple is not a functor, or that the homogeneous tuple is not a named type?).
Isn't this where you need something like argAdaptor?
Isn't this where you need something like argAdaptor?
Right. That's precisely my question: why your adaptor idea, which I (mis)used in this example (and attributed to you :)), does not quite work for tuples. The point is, the adaptor should not be visible to the user. Ideally, we'd have exactly the syntax Andrew uses and the adaptor hidden in some instance in the library code.
Ah, I see. I didn't read carefully. I think there are two distinct issues:
1. Why can't the last lines of testFoo be written as in AWF's example?
2. How can the rest of testFoo match AWF's simple example?
I wouldn't worry about 2. I'm confident it can be solved by some existing or new generics library. I don't understand 1 though. It seems to have more to do with assertEqualUpToEps than grad. Is that right? In which case isn't the answer that assertEqualUpToEps also needs an argAdaptor-style thing?
I'm confident it can be solved by some existing or new generics library.
I was afraid that would be the answer. At least not TH, though. :)
Re 1, you are right, assertEqualUpToEps would need to be generalized. But first a bigger problem: dReverseFun gives me a result of type Domains r, which is four records. What I need to encode in the adaptor instance for the 3-tuple is to pick the first of the vectors (the vector of scalars) and pack its first three elements into a 3-tuple. Let's call that operation fromDomains. So, copying unchanged what's already there, I would have
-- Inspired by adapters from @tomjaguarpaw's branch.
class Adaptable r fdr where
toDomains :: fdr -> Domains r
fromDualNumberInputs :: DualNumberInputs 'DModeGradient r -> fdr
fromDomains :: Domains r -> (r, r, r)
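For the 3-tuple, fromDomains might look roughly like the following sketch; domainsScalars is a hypothetical accessor (not the real library name) for the first of the four records, the vector of scalars, and I assume a qualified import of Data.Vector.Storable as V:

-- A hedged sketch only; @domainsScalars@ and the Storable-vector representation
-- are assumptions about the shape of @Domains r@, not the actual API.
fromDomains3 :: Storable r => Domains r -> (r, r, r)
fromDomains3 domains =
  let v = domainsScalars domains  -- hypothetical: extract the vector of scalars
  in (v V.! 0, v V.! 1, v V.! 2)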
But I need the adaptor to be usable with all types, not just the 3-tuple, and I have no idea how to generalize the result type of fromDomains to accomplish that. Copying again, the instance I'm using (perhaps it's a wrong instance? [edit: but that's what the function is instantiated to, in order to be usable in dReverseFun]) is:
instance IsScalar 'DModeGradient r
=> Adaptable r ( DualNumber 'DModeGradient r
, DualNumber 'DModeGradient r
, DualNumber 'DModeGradient r ) where
So I would need a type operation that, applied to (DualNumber 'DModeGradient r, DualNumber 'DModeGradient r, DualNumber 'DModeGradient r), produces (r, r, r).
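One conceivable way to express such a type operation, sketched here only as an illustration (a closed type family that strips DualNumber from every tuple component; one equation per tuple size would still be needed):

{-# LANGUAGE DataKinds, TypeFamilies #-}
-- A sketch, assuming DualNumber and 'DModeGradient from the library are in scope.
type family Scalars fdr where
  Scalars (DualNumber 'DModeGradient r) = r
  Scalars (a, b)    = (Scalars a, Scalars b)
  Scalars (a, b, c) = (Scalars a, Scalars b, Scalars c)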
@sfindeisen: perhaps you would have an idea?
Hmm, I'm not sure I'm following. The argAdaptor approach would be to have a data type, not a class. The class would only be introduced at the very end to "infer" the correct value of the data type. So we would have something like
data Adaptor r fdr rs = Adaptor
  { toDomains :: fdr -> Domains r
  , fromDualNumberInputs :: DualNumberInputs 'DModeGradient r -> fdr
  , fromDomains :: Domains r -> rs
  }
This data type has nice operations like the ones below, which are the ones that will be used by the conjectured generics library.
adaptPair :: Adaptor r fdr1 rs1 -> Adaptor r fdr2 rs2
          -> Adaptor r (fdr1, fdr2) (rs1, rs2)

invmap :: (rs -> rs') -> (rs' -> rs) -> (fdr -> fdr') -> (fdr' -> fdr)
       -> Adaptor r fdr rs
       -> Adaptor r fdr' rs'
Then there would be a class that uses the conjectured generics library to infer the appropriate value of type Adaptor for any given situation:
class Adaptable r fdr rs where
adapt :: Adaptor r fdr rs
-- Here's an example of a "generics library"
instance (Adaptable r fdr1 rs1, Adaptable r fdr2 rs2)
  => Adaptable r (fdr1, fdr2) (rs1, rs2) where
  adapt = adaptPair adapt adapt
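As an illustration of how the 3-tuple case could then be assembled from these conjectured combinators (a sketch only, assuming adaptPair and invmap exactly as above):

adaptTriple :: Adaptor r fdr1 rs1 -> Adaptor r fdr2 rs2 -> Adaptor r fdr3 rs3
            -> Adaptor r (fdr1, fdr2, fdr3) (rs1, rs2, rs3)
adaptTriple a1 a2 a3 =
  -- build the nested-pair adaptor and then re-associate both the argument
  -- and the result shapes into flat 3-tuples
  invmap (\(r1, (r2, r3)) -> (r1, r2, r3)) (\(r1, r2, r3) -> (r1, (r2, r3)))
         (\(f1, (f2, f3)) -> (f1, f2, f3)) (\(f1, f2, f3) -> (f1, (f2, f3)))
         (adaptPair a1 (adaptPair a2 a3))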
Does that make sense?
Does that make sense?
Almost --- what would be the type of grad (I'm worried about the third argument of Adaptable)?
I guess it can be
grad :: (HasDelta r, Adaptable r x anything)
=> (x -> DualNumber 'DModeGradient r) -> x -> Domains r
But I'm not really sure where fromDomains is supposed to be used, so perhaps there's a better formulation, for example
data Adaptor2 r fdr fdr' = Adaptor2
  { toDomains :: fdr -> Domains r
  , fromDualNumberInputs :: DualNumberInputs 'DModeGradient r -> fdr'
  }
class Adaptable2 r fdr fdr' where
adapt2 :: Adaptor2 r fdr fdr'
(and we have an entirely different data type for fromDomains). Then the type of grad would be
grad :: (HasDelta r, Adaptable2 r x x)
=> (x -> DualNumber 'DModeGradient r) -> x -> Domains r
Oh, I also forgot to mention earlier this piece which is essential for the "generics library":
adaptR :: Adaptable2 r r r
But I'm not really sure where fromDomains is supposed to be used
It would be used in grad, as below (using a mixture of the old and the new types):
grad :: (HasDelta r, Adaptable r x rs)
=> (x -> DualNumber 'DModeGradient r) -> x -> rs --- in our case rs would instantiate to (r, r, r)
grad f x =
let g inputs = f $ fromDualNumberInputs inputs
in fromDomains $ fst $ dReverseFun 1 g (toDomains x)
My worry is that GHC would complain that rs is not determined and, indeed, it has no simple relation to any other type in this signature. In our example, x is the triple of DualNumbers and rs is (r, r, r). I don't think GHC would be able to guess the rs type given the x type in this case. And x always needs to be some container of dual numbers, because that's the domain of the objective function. And rs is always going to be a container of scalars, because that's what the gradient is.
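To make the worry concrete, here is a self-contained toy example (not library code) of the same inference problem with a plain multi-parameter class:

{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-}
class Convert x rs where
  convert :: x -> rs

instance Convert (Int, Int, Int) (Double, Double, Double) where
  convert (a, b, c) = (fromIntegral a, fromIntegral b, fromIntegral c)

x3 :: (Int, Int, Int)
x3 = (1, 2, 3)

-- bad = convert x3               -- rejected: the result type rs is not determined
good :: (Double, Double, Double)
good = convert x3                 -- accepted only thanks to the explicit signature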
OK, so in this formulation rs can't be inferred from x. But there are plenty of ways of dealing with that, for example
class Adaptable2 r fdr fdr' | fdr -> fdr' where
adapt2 :: Adaptor2 r fdr fdr'
(probably the functional dependency can be bidirectional, for even better inference)
(see Opaleye's runSelectI for an example of this adaptor thing working in practice.)
Thanks a lot, @tomjaguarpaw. I guess that's enough hints to start experimenting. Any volunteers?
(probably the functional dependency can be bidirectional, for even better inference)
@tomjaguarpaw, indeed the other direction is necessary to make this simple example work. I wonder if we can instead use an associated type family and get less fragile type inference. No idea which one plays better with generics libraries.
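For example, a hedged sketch of the associated-type-family variant (the names are made up, not the library's):

class Adaptable2 r fdr where
  type Scalarized fdr            -- plays the role of the fdr' parameter
  adapt2 :: Adaptor2 r fdr (Scalarized fdr)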
@awf: done (almost exactly your code, albeit with some cheating behind the scenes, the biggest of which is that this only works for 3-tuples currently, though in theory this could be generalized in library code):
@awf: but the question still stands: do you think a list of arguments [x,y,z] would make as good an impression on newcomers, or are tuples essential? E.g., the library ad only admits lists and other Traversable types, but perhaps we should do better? There is some extra effort required (#66 and more), extra code and rather heavy machinery as dependencies (a generics library) to maintain if we want to support tuples. It's not yet clear to me whether supporting tensors, which I suppose is absolutely essential, subsumes most of the work needed for tuples or not.
Another consideration is that for efficiency, especially in examples with rank 0 inputs, functions should not be polymorphic and neither should they take a vector (or list or whatever) of dual numbers as arguments. Instead they should take an unboxed vector of primal components of dual numbers and a boxed vector of precomputed dual components (that are very special in this case, being inputs), which are the same for each gradient descent run. I may again mistake "too complex" for "impossible", but I think rewriting user functions from their naive polymorphic forms to the efficient forms is too hard. BTW, the code of the functions I presented in the old READMEs was obscured by the machinery for the efficient style, so we've already had a taste of the opaqueness of it.
Is it a good outreach strategy to publicise a clear and familiar style that is, in fact, impractical for real examples? Or perhaps we don't care about examples with rank 0 inputs and the rare real life examples in which there are huge numbers of little tensors as opposed to a small number of big ones [edit: this only concerns inputs; inside the computation, we handle huge numbers of tensors equally efficiently regardless]? If so, the cute style may be performant enough.
Actually, the polymorphic way of writing functions breaks down already for vectors; we don't even have to go to shaped tensors. That's because, while there is a + operation that works both on Float and DualNumber Float, there are no vector-specific operations that would be similarly generic.
Arithmetic operations are generic enough, so + would work both on vectors of numbers and on vectors of dual numbers, but arithmetic is not nearly enough. We need sums of vector elements, appends, flattening arrays into vectors, etc. Not to mention the mapping, zipping and folding over vectors we have been discussing recently, which has additional problems and is even less likely to have a generic implementation.
@awf: do you think it's worth experimenting with classes similar to Num (thanks to which + works on vectors) that would let us use the same operation names for vectors of numbers and for vectors of dual numbers? I have no idea how far we would get with it and, in any case, the user would be limited to what is in the classes. This is more problematic than with the numeric classes, because there is a vague canon of numeric operations, while each vector library does it differently (and there are add-on libraries in Haskell that add operations without changing the vector type, e.g., hmatrix; I'm sure GPU libraries do this differently again). Also, @tomjaguarpaw has valid arguments about why even the numeric classes for all ranks are an excess, though I'm not certain the arguments carry over to this case (this would not be a rank-polymorphic class, since it would only work for rank 1, except perhaps for a couple of operations (sum of elements, etc.) that could work for any rank).
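A minimal sketch of the kind of class I mean (names invented here, not an existing API): one set of operation names usable both on plain vectors of numbers and on vectors of dual numbers.

{-# LANGUAGE FunctionalDependencies #-}
class VectorOf r v | v -> r where
  vsumElements :: v -> r          -- sum of elements
  vappend      :: v -> v -> v     -- concatenation
  vkonst       :: Int -> r -> v   -- constant vector of a given length
-- one instance would cover, say, Data.Vector.Storable.Vector Double, another the
-- dual-number counterpart; the user is then limited to whatever the class offers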
Of course, the user is free to drop the polymorphic bling we advertise in the README and start writing dual number code non-generically (with polymorphism only over the underlying scalar, not over numbers vs dual numbers). I'm just wondering whether we should focus exclusively on the monomorphic mode, which offers full expressiveness, performance and higher extensibility (though it's still limited by the set of implemented Delta constructors).
BTW, the current scary README touches on the horrors of rank 1 operations, though we should certainly simplify it and just present a trivial rank 1 function and its gradient. Currently, to avoid mentioning DualNumber in such an objective function, we'd need to restrict it to element-wise numeric operations (effectively, exclusively bulk zipping and mapping over vectors). I think that would be cheating. Or we could forego vectors and present the same example with rank one shaped tensors. Exactly the same problem.
A code illustration of the point above. This could potentially go into the README. Unless shaped tensors are too much, in which case we might do the same for vectors, to at least introduce the DualNumber noise, because it's unavoidable. Unless we go all in and add a class that has indexS, fromS0, etc., etc., and an instance for normal numbers (just some orthotope operations) and another for dual numbers (implemented using delta expressions).
Edit: this is an updated version of the code.
Coming in late, sorry. Programmers using JAX or Julia enjoy differentiation of essentially arbitrary datatypes:
foo :: S -> Num
grad foo :: S -> dS
foo :: S -> T
rev foo :: (S, dT) -> dS
fwd foo :: (S, dS) -> dT
etc. So it would be nice to say that Haskell offers the same. It almost certainly needs some template Haskell to offer this, at which point, the underlying signatures are kinda moot: the user will see the clean interface, so it doesn't matter that every underlying function is reduced to e.g. a list of Tensor.
So I think I'm saying: don't worry about complicated non-template-Haskell machinery to support tuples, you're almost certainly going to need to drop it when you bite the bullet and go to template Haskell (or whatever the chosen transformation framework is).
In terms of the concrete question above re tuples vs lists: sure, this looks fine
grad foo [1.1, 2.2, 3.3]
or even
grad foo (T.fromList [1.1, 2.2, 3.3])
But it's the definition site that matters to many programmers, and
bar :: ADFloat a => (a,a,a) -> a
bar (x,y,z) =
let w = foo (x,y,z) * sin y
in atan2 z w + z * w
is nicer than
bar :: ADFloat a => T a -> a
bar t =
  let x = t ! 0 -- Please correct this syntax if wrong
      y = t ! 1
      z = t ! 2
      w = foo (x,y,z) * sin y
  in atan2 z w + z * w
And of course we should not special-case for one example. I assume on reading bar above that I could just as easily write
bar :: ADFloat a => (a, Tensor2 a, [TensorS a]) -> [TensorS a]
bar (s, w, xs) =
  map (\x -> s * dot w x) xs
and generate
$rev bar :: ADFloat a => ((a, Tensor2 a, [TensorS a]), [TensorS da]) -> (da, Tensor2 da, [TensorS da])
where da is maybe spelled (D a).
Notes on the above:
- ADFloat because I don't think it matters if it's RealFloat, Num, or something else, just so long as float literals are easily passed to it.
- da is left unspecified, but it should certainly be possible to declare D s for a type s without too much palaver.

So I think I'm saying: don't worry about complicated non-template-Haskell machinery to support tuples, you're almost certainly going to need to drop it when you bite the bullet and go to template Haskell (or whatever the chosen transformation framework is).
TH is quite likely, either hidden inside the generics library we are going to use, or employed to automate any boilerplate the library would require (e.g., to go from 2-tuples all the way up to 10-tuples). However, I'd love the TH to be easy to understand (e.g., via mentally expanding the boilerplate it encodes). If we do full source transformation of objective functions, that is not only error-prone in our library code (because it's untyped), but our library user also won't know which code snippets at hand are translated and whether the translation contexts are compatible or not. The compiler no longer helps, most of the time.
If we fail to avoid full TH, I'd love to understand exactly where and why. Additionally, I don't think we can source-transform to the efficient forms (with unboxed vectors of input floats and re-used boxed vectors of Input delta-expressions) anyway. If I'm wrong and we can do that, this may well be worth full meta-programming with TH.
Actually, the polymorphic way of writing functions breaks down already for vectors, we don't even have to go to shaped tensors. That's because, while there is a + operation that works both on Float and DualNumber Float, there are no vector-specific operations that would be similarly generic.
Huh, now I see that in the comment starting with the above, I'm effectively arguing that full meta-programming TH is unavoidable for ranks higher than 0, if we want to differentiate a wide range of user programs written without knowledge of our library. Rank 0 is special, because there are canonical enough numeric classes (that look familiar even to non-Haskellers) that we can hide our machinations behind. Not so for vector, matrix and tensor operations.
But if the user, unaware of our library, writes a function from Double to Double, we lose even for rank 0. The trick with numeric type classes works only for polymorphic programs. Similarly if the user writes a polymorphic function on vectors with a Storable constraint (which is crucial for efficiency, but can also be needed to interoperate with FFI without prohibitively expensive conversions). In fact, it's hard to write an interesting Haskell function on vectors without such constraints, except for boxed vectors. The same goes for matrices, tensors and all data structures usable in ML.
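A tiny self-contained illustration of this point (not library code): the polymorphic rank-0 function can be reused at a dual-number type, while the monomorphic one cannot be without rewriting or source transformation.

fPoly :: Floating a => a -> a
fPoly x = x * sin x    -- usable at Double and at any dual-number Floating instance

fMono :: Double -> Double
fMono x = x * sin x    -- stuck at Double; nothing outside the function can instrument it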
So, @awf, you are of course right, it's either TH or the user needs to write even the objective functions against our library API. Still, it's worth exploring how readable and familiar we can keep such API-obfuscated objective functions, while still making them easily differentiable with our library. (Let's put aside the performance trick with unboxed/shared vectors of inputs, which the advanced user that experiences a bottleneck may choose to enable at the cost of further pollution with API operations.)
Certainly, lists make things much simpler, and explicitly mentioning DualNumber is necessary for ranks > 0. You are right, the next hurdle is mixing ranks (which one can't do in a list, because the types don't agree). I have to think some more about that one and come back.
Anyway, here's how foo would look with lists and explicit DualNumber (only to illustrate the level of obfuscation; the original foo, being rank 0 and unconstrained polymorphic, doesn't need DualNumber). This is still polymorphic in d (which can be gradient mode, derivative mode or value computation mode) and in r (the underlying scalar type), but a function with either or both of these instantiated would work exactly the same (probably with no change to any other code needed).
Yet another POV is how the user can verify whether a particular language feature is supported by horde-ad. With the objective functions written against the library API, the compiler errors out (probably with a clueless message, not knowing this is not a user error but an extension needed in the library) and the user can also browse the API haddocks (or inspect live modules in ghci) beforehand.
With TH or a plugin, we need to implement messages like, using a contrived example, "atan2 needs to be added to the horde-ad library". Such messages have the potential of being more readable. Extra API docs for the user are going to be needed as well; namely a special API for restricted orthotope (fewer shaped tensor operations and some of the operations having less general types, e.g., slicing or generalized transposing), another for restricted hmatrix (there's just too much and it's not really compositional due to FFI), and so on for each library we target. Edit: also the imports, module names, package names in cabal files, etc., will need to change, because we will only mimic the libraries, not use them directly.
A desperate power-user will end up translating an objective function's code manually, linking to the real libraries using a special version of the cabal file, tweaking it with low-level horde-ad operations until it compiles, then capturing the tricks as a new horde-ad dual number operation, adding it to the library, extending the plugin and docs, and translating the working example back to the pretty form.
Programmers using JAX or Julia enjoy differentiation of essentially arbitrary datatypes
I suppose that does not extend to arbitrary functions? E.g., taken from C via FFI? Nor, say, a different implementation of the datatype of unboxed vectors (in Haskell, that's likely to use GHC primitives, in Python just C code; not sure about JAX and Julia)? [Edit: or sparse matrices with particular performance trade-offs?] Are these languages extensible by programmers, or only by the language creators? Our library, just like Haskell, is fully extensible by the programmer. A TH or plugin outer layer would make extending it more expensive, though.
I think list of tensor is an acceptable general type
For untyped tensors, yes; in fact, that's what I do internally in the library (boxed vectors of untyped tensors, to be precise). However, shaped tensors of different shapes can't be put in a homogeneous list, because they have different types. Internally, I convert shaped tensors to untyped tensors to store them in a vector and convert them back when looking them up. But that's not an acceptable surface UI due to obfuscation and, most of all, due to the lost type safety when the shaped tensors are converted to untyped tensors.
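For concreteness, a small sketch of why that is (an illustration only, not the library's internal representation): the shape is part of the type, so a homogeneous list is impossible, and the usual existential wrapper is exactly where the shape information, and with it the type safety, disappears.

{-# LANGUAGE DataKinds, ExistentialQuantification #-}
import qualified Data.Array.ShapedS as OS

-- wrap a shaped tensor, forgetting its shape at the type level
data SomeTensor r = forall sh. OS.Shape sh => SomeTensor (OS.Array sh r)

-- [OS.constant 1 :: OS.Array '[2, 3] Double, OS.constant 2 :: OS.Array '[5] Double]
--   -- does not typecheck: the two elements have different types
-- [SomeTensor (OS.constant 1 :: OS.Array '[2, 3] Double), SomeTensor (OS.constant 2 :: OS.Array '[5] Double)]
--   -- typechecks, but the shapes are no longer visible in the list's type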
to which to translate arguments such as tuples of lists of tuples of tensor or float (like JAX PyTrees).
I think something like that is the way forward, but this needs to be prototyped. It doesn't require a generics library, but it does require lots of typeclasses with associated type families. I think the simplest way to do this is to have arbitrary nested tuples (a heterogeneous tree, basically), with Traversable (lists and boxed vectors) and MonoTraversable (unboxed vectors) containers of values of the same rank and shape in the leaves.
If that works, this admits the list-based fooD, but also the original tuple-based foo amended to explicitly mention DualNumber. A pity that the simplification from tuples to lists fails for shaped tensors, but this is still simpler than using a generics library to handle arbitrary types (which we might still attempt, but I think we should start small; alternatively, the next step can be TH/plugin).
I realize my aversion to TH/plugins in this case is just a preference for shallowly embedded DSLs (https://wiki.haskell.org/Embedded_domain_specific_language). Actually, with TH/plugin it's a mix between shallow and deep embedding, because the produced abstract syntax objects (the actual code calling the horde-ad API) are in Haskell syntax and usually immediately interpreted. But, morally, there is yet another code object between the user-written code (the pretty code ignorant of the horde-ad API) and the value the user obtains.
And to call foo I use
-- from `R^3` to `R`.
fooD :: IsScalar d r => [DualNumber d r] -> DualNumber d r
fooD [x, y, z] =
let w = x * sin y
in atan2 z w + z * w
fooD _ = error "wrong number of arguments"
foo = value fooD
foo [1.1,2.2,3.3]
It feels like, after the discussion today, this is a possible target:
barS :: (ADModeAndNum d r, OS.Shape sh)
     => StaticNat n1 -> StaticNat n2
     -> ( ADVal d r
        , ADVal d (OS.Array '[n1, n2] r)
        , [ADVal d (OS.Array (n2 ': sh) r)]
        )
     -> [ADVal d (OS.Array (n1 ': sh) r)]
barS MkSN MkSN (s, w, xs) =
  map (\x -> s * dot w x) xs

bar_3_75 = AD.value (barS (MkSN @3) (MkSN @75))
bar_vjp_3_75 = AD.vjp (barS (MkSN @3) (MkSN @75))
x =
let args= (1.1, random (3, 75), [random (75, 5), random (75, 5)])
dbar = [random (75, 5), random (75, 5)]
in
bar_vjp_3_75 args dbar
>>> x
(3.3, Tensor [....], ...)
Names like ADModeAndNum and ADVal are easily [GH]oogleable/Haddockable
Names like ADModeAndNum and ADVal are easily [GH]oogleable/Haddockable
Right. If the user wants to do the dot above, the user searches the haddocks of all functions operating on ADVal. The user finds three variants of dot: one on hmatrix vectors, another on orthotope tensors, and another, by a third party, on their shaped sparse matrices. The user picks one of the three, which also fixes the choice of the datatypes to be used.
That's quite different from the original foo code (whose style could be maintained for higher ranks only with TH/plugin), where the dot would be somehow automatically mapped to a default tensor implementation (or somebody would need to maintain a common typeclass unifying all tensor implementations).
@awf, I'm confused by vjp, after all. That's neither our old gradient computation, because that would require a scalar codomain of the objective function, nor our old forward derivative computation, because that would be jvp, wouldn't it? Is that the third operation, the one that constructs the whole Jacobian, which I said we could possibly get for free given our existing machinery needed for ranks > 0? That would be a generalization of the gradient computation to non-scalar codomains? If so, I'd need to open a ticket to, in fact, obtain it for free. :)
Edit: In fact, even our forward derivative operations assume the objective function has a rank 0 codomain, though in that case it should be easy to generalize. So some extra work is needed throughout the engine to permit dual number codomains of any rank, and still some more (not sure how hard it would be; possibly related to #66) to permit arbitrary codomains. There isn't even a ticket about it, so I think we haven't considered it yet, except for my "full Jacobian for free" idea for extending the paper, to which there's been no reaction so far. Do you think opening the ticket and trying to implement it, in case it's easy, is worthwhile at this point?
Edit2: this is now moved to its own ticket #69.
@awf, and what is the type of dot at the value level? Does it only take a rank 2 tensor as its first argument, as in (omitting type constraints)
dot :: OS.Array '[n1, n2] r -> OS.Array (n2 ': sh) r -> OS.Array (n1 ': sh) r
or is it something more general (I've transposed the first argument to keep sanity and adhere to "ops act on the leftmost/outermost dimension(s)"), like
dot :: OS.Array (n2 ': sh1) r -> OS.Array (n2 ': sh2) r -> OS.Array (sh1 '++ sh2) r
Frankly, I don't have an idea how to implement even the less general variant, even only at the value level (exclusively the primal part, no dual) with https://hackage.haskell.org/package/orthotope-0.1.2.0/docs/Data-Array-ShapedS.html. Probably do some stretching that copies the second argument n1 times and the first argument sh times, transpose to bring them to exactly the same shape and put the n2 dimension rightmost (meaning innermost), at which point apply https://hackage.haskell.org/package/orthotope-0.1.2.0/docs/Data-Array-ShapedS.html#v:rerank2 with ordinary vector dot product? I'm just listening to https://www.youtube.com/watch?v=rtc_j8Hnzac&list=PLq1pyM--m7oDmKK6zG0TvYqfaBEH1ODjM&index=6, so perhaps ideas will come up. I guess a good start would be sh equal to '[n3], so that a human has a chance to imagine what's going on.
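Just to pin down the intended semantics at the value level for that starting case of sh equal to '[n3], ignoring shapes and efficiency (plain lists, a sketch rather than the real orthotope-based implementation):

import Data.List (transpose)

-- dotLists a b: a is n1 x n2, b is n2 x n3; the result is n1 x n3,
-- i.e., ordinary matrix multiplication.
dotLists :: Num r => [[r]] -> [[r]] -> [[r]]
dotLists a b = [ [ sum (zipWith (*) row col) | col <- transpose b ] | row <- a ]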
To then lift it to dual numbers, we'd need dual number operations corresponding to stretch [edit: or rather broadcast? no idea what window does], rerank, etc., which is #28 and is hard. The variant with sh equal to '[n2] may be easier, with correspondingly less general stretch, rerank, etc. And we need to keep track of whether we drop to rank 0 when implementing any of these (iterating over each element of the arrays, instead of over whole n-1 dimensional slices), thus degrading performance. IIRC, that's a part of the difficulty of #28, completely absent from value-level coding, because at the value level, eventually each element of the arrays is going to be used, so there is no problem with iterating over them (at least on CPU), while with dual numbers we need to make operations explicitly bulk (I imagine it may be similar with GPU coding?), or we create a delta expression for each array element, ending in a disaster.
Edit: it turns out we already had the version with sh equal to '[n3] and it was just matrix multiplication implemented with hmatrix under the hood:
(<>$) :: (ADModeAndNum d r, KnownNat m, KnownNat n, KnownNat p)
=> ADVal d (OS.Array '[m, n] r)
-> ADVal d (OS.Array '[n, p] r)
-> ADVal d (OS.Array '[m, p] r)
(<>$) d e = from2S $ fromS2 d <>! fromS2 e
So should it just repeat the matrix multiplication along all the extra dimensions of the second argument? I'm no longer sure generalizing this to extra dimensions of the first argument makes sense nor, if it does, whether the extra dimensions should be concatenated, multiplied or somehow spliced together. Possibly no such generalizations are commonly used? I might have used the generalization of the second argument to three dimensions myself due to mini-batches, but I probably did it by extracting the dimension into a list and mapping over the list, because that's easier for a person new to tensors.
Anyway, here's as much as I could represent of the barS example with the current engine. The cheating here is not as shallow as before, but there is a chance it's solvable and #65 and #66 are enough to straighten it all up (no generics libraries required). To permit [edit: arbitrary codomains of] objective functions, a different kind of work is needed and I'm not yet sure if it's going to be trivial or if we need to conceptually extend the paper first to get the proper design.
Edit: Updated example, with fwd and value generalized to objective functions with an arbitrary tensor codomain (but not yet tuples of tensors, etc.). Generalizing rev is harder and the best next step may be permitting a Domains r codomain, which holds a collection of scalars, vectors, matrices and untyped tensors (that's the same type that all objective functions have as their domain, internally).
Edit2: Another update, with rev now accepting objective functions with codomains of arbitrary ranks (but not tuples, etc.). Nested tuples in codomains are harder and will be handled in #68 (and cheats for nested tuples in domains will be eliminated in #66).
barS :: (ADModeAndNum d r, OS.Shape sh)
=> StaticNat n1 -> StaticNat n2
-> ( ADVal d r
, ADVal d (OS.Array '[n1, n2] r)
, [ADVal d (OS.Array (n2 ': sh) r)] )
-> [ADVal d (OS.Array (n1 ': sh) r)]
barS MkSN MkSN (s, w, xs) =
map (\x -> konstS s * (dot w x)) xs
-- konstS is needed, after all, because @s@ is a differentiable quantity
-- with a given type, and not a constant that would be interpreted according
-- to the inferred type
-- TODO: this is a fake implementation of the medium-general variant
dot :: (ADModeAndNum d r, OS.Shape sh, KnownNat n1)
=> ADVal d (OS.Array '[n1, n2] r)
-> ADVal d (OS.Array (n2 ': sh) r)
-> ADVal d (OS.Array (n1 ': sh) r)
dot _ _ = konstS 42
bar_3_75
:: ( ADModeAndNum 'ADModeValue r
, KnownNat k, OS.Shape sh)
=> ( r
, OS.Array '[3, 75] r
, [OS.Array (75 ': sh) r] )
-> OS.Array (k ': 3 ': sh) r
bar_3_75 = value (ravelFromListS . barS (MkSN @3) (MkSN @75))
-- @ravelFromListS@ is needed, because @valueFun@ expects the objective
-- function to have a dual number codomain and here we'd have a list
-- of dual numbers. The same problem is worked around with @head@ below.
testBarV :: Assertion
testBarV =
assertEqualUpToEpsVF @'[2, 3, 337] (1e-12 :: Double)
(bar_3_75
( 1.1
, OS.constant 17.3 -- TODO: create more interesting test data
, [ OS.constant 2.4
, OS.constant 3.6 ] ))
(OS.constant 46.2)
bar_vjp_3_75
:: forall sh r.
( ADModeAndNum 'ADModeDerivative r, Dual 'ADModeDerivative r ~ r
, OS.Shape sh )
=> ( r
, OS.Array '[3, 75] r
, [OS.Array (75 ': sh) r] )
-> ( r
, OS.Array '[3, 75] r
, [OS.Array (75 ': sh) r] )
-> OS.Array (3 ': sh) r
bar_vjp_3_75 = fwd (head . barS (MkSN @3) (MkSN @75))
-- TODO: implement real vjp
-- TODO: @head@ is required, because our engine so far assumes
-- objective functions have dual number codomains (though they may be
-- of arbitrary rank). The same problem is worked around with
-- @ravelFromListS@ above.
testBarF :: Assertion
testBarF =
assertEqualUpToEpsVF (1e-7 :: Double)
(bar_vjp_3_75
( 1.1
, OS.constant 17.3 -- TODO: create more interesting test data
, [ OS.constant 2.4 :: OS.Array [75, 12, 2, 5, 2] Double
, OS.constant 3.6 ] ) -- input
( 2.1
, OS.constant 18.3
, [ OS.constant 3.4
, OS.constant 4.6 ] )) -- ds
(OS.constant 88.2)
bar_rev_3_75
:: forall sh r.
( HasDelta r
, OS.Shape sh)
=> ( r
, OS.Array '[3, 75] r
, [OS.Array (75 ': sh) r] )
-> ( r
, OS.Array '[3, 75] r
, [OS.Array (75 ': sh) r] )
bar_rev_3_75 = rev ((head :: [ADVal 'ADModeGradient (OS.Array (n1 ': sh) r)]
-> ADVal 'ADModeGradient (OS.Array (n1 ': sh) r))
. barS (MkSN @3) (MkSN @75))
-- TODO: @head@ is required, because our engine so far assumes
-- objective functions have dual number codomains (though they may be
-- of arbitrary rank)
testBarR :: Assertion
testBarR =
assertEqualUpToEpsR @'[2, 3, 341, 1, 5] (1e-7 :: Double)
(bar_rev_3_75
( 1.1
, OS.constant 17.3 -- TODO: create more interesting test data
, [ OS.constant 2.4
, OS.constant 3.6 ] )) -- input
( 1288980.0
, OS.constant 0
, [ OS.constant 0
, OS.constant 0 ] )
-- * Operations required to express the tests above (#66)
value :: ( ADModeAndNum 'ADModeValue r
, Adaptable 'ADModeValue r advals vals )
=> (advals -> ADVal 'ADModeValue a) -> vals -> a
value f vals =
let g inputs = f $ fromADInputs inputs
in valueFun g (toDomains vals)
rev :: ( HasDelta r, IsPrimalAndHasFeatures 'ADModeGradient a r
, Adaptable 'ADModeGradient r advals vals )
=> (advals -> ADVal 'ADModeGradient a) -> vals -> vals
rev f vals =
let g inputs = f $ fromADInputs inputs
in fromDomains $ fst $ revFun 1 g (toDomains vals)
fwd :: ( Numeric r, Dual 'ADModeDerivative r ~ r
, Dual 'ADModeDerivative a ~ a
, Adaptable 'ADModeDerivative r advals vals )
=> (advals -> ADVal 'ADModeDerivative a) -> vals -> vals -> a
fwd f x ds =
let g inputs = f $ fromADInputs inputs
in fst $ fwdFun (toDomains x) g (toDomains ds)
-- Inspired by adaptors from @tomjaguarpaw's branch.
type Adaptable d r advals vals =
(AdaptableDomains r vals, AdaptableInputs d r advals)
@awf, may I repeat the question: what is the intended type of dot in https://github.com/Mikolaj/horde-ad/issues/64#issuecomment-1245374857?
I have a version ported from […] at […], but it can't be that simple, can it?
So I believe we should have an operation called dot_general, following https://www.tensorflow.org/xla/operation_semantics#dotgeneral; but my original example should maybe use matmul.
This is now mostly done, with results discussed in #75 (please discuss some more and open more tickets). Thank you! Closing.
Part of README #53