Closed srush closed 3 years ago
Yeah, base types (`Float`), tables, tuples and records are the only things that will work at the moment I think. Adding support for general ADTs is on our roadmap, but it is a bit subtle, because the type of e.g. `jvp` really is
jvp :: (a -> b) -> (a, Tangent a) -> (b, Tangent b)
with `Tangent` being a function that takes `a` to its tangent type. Now, for floats and tables we have `Tangent a == a`. For `Int`, `Bool` we have `Tangent a == Unit`. For tuples and records, you apply the `Tangent` map to each member type. For example `Tangent (Int & Float & n=>Int & m=>Float) = (Unit & Float & n=>Unit & m=>Float)`.
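The mapping described above can be sketched as a small recursive function. This is only an illustration in Python, not Dex code: the tuple-based type encoding and the name `tangent_type` are assumptions made up for this example, and mapping `Unit` to `Unit` is my own guess rather than something stated above.

```python
# Hypothetical encoding of a few Dex-like types as Python values:
#   "Float", "Int", "Bool", "Unit"      base types
#   ("pair", left, right)               tuple type  (left & right)
#   ("table", index_name, element)      table type  (index_name => element)

def tangent_type(ty):
    """Map a type to its tangent type, following the rules described above."""
    if ty == "Float":
        return "Float"                  # Tangent Float == Float
    if ty in ("Int", "Bool", "Unit"):
        return "Unit"                   # discrete types carry no tangent information
    if isinstance(ty, tuple) and ty[0] == "pair":
        # Apply the Tangent map to each member type.
        return ("pair", tangent_type(ty[1]), tangent_type(ty[2]))
    if isinstance(ty, tuple) and ty[0] == "table":
        # The index set is unchanged; only the element type is mapped.
        return ("table", ty[1], tangent_type(ty[2]))
    raise TypeError(f"no tangent type known for {ty!r}")

# Int & Float & n=>Int & m=>Float, written as nested pairs.
example = ("pair", "Int",
           ("pair", "Float",
            ("pair", ("table", "n", "Int"), ("table", "m", "Float"))))
print(tangent_type(example))
```

A user-defined ADT has no entry in this function, which is the gap the rest of this answer is about.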
Supporting ADTs is on the roadmap, but is a bit more involved. Below are two reasons that we know of. If you have any thoughts or suggestions do let us know!
Unlike tuples and records, ADTs have named constructors, with names given by the users. This is a problem, because the `Tangent` we need for an ADT might no longer fit the same definition. Consider this definition:
data MyADT = MyConstructor Float Int
Now, what should its tangent type be? We could "downgrade it" into a tuple of `(Float & Unit)`, but that seems quite arbitrary and would get very ugly with nested ADTs. Alternatively, we could make it so that each `data` definition additionally introduces another ADT (e.g. `MyADTTangent`), but that seems quite ugly.
In short, we'd like the tangent types of ADTs to have some correspondence to the names given by the user to the original ADT. On the other hand, we don't want to burden users with writing another, extremely similar ADT by hand.
Differentiating through sum types is even more involved, because then the `Tangent` mapping actually becomes dependent. That is, the type of `jvp` would have to be
jvp :: (a -> b) -> (x:a ** Tangent x) -> (y:b ** Tangent y)
where `**` is the dependent pair constructor. This is because the `Tangent` one provides for a sum-type argument has to use exactly the same constructor/case as that argument.
Of course an alternative here is to make `jvp` partial (i.e. allow any case, but throw an error when the wrong one is given), which we might do initially for implementation simplicity.
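A minimal Python sketch of the partial alternative: values of a sum type are modeled as `(constructor_name, payload)` pairs, and a hypothetical helper `check_tangent` raises when the tangent uses a different case than the primal. The encoding and all names here are assumptions for illustration, not how Dex implements it.

```python
def check_tangent(primal, tangent):
    """Return both payloads if the two values use the same constructor, else fail."""
    p_case, p_payload = primal
    t_case, t_payload = tangent
    if p_case != t_case:
        raise ValueError(
            f"tangent uses case {t_case!r} but primal uses case {p_case!r}")
    return p_payload, t_payload

def jvp_scale(primal, tangent, factor=3.0):
    """JVP of a function that multiplies the Float payload of either case by `factor`."""
    x, dx = check_tangent(primal, tangent)
    case = primal[0]
    # Scaling is linear, so the tangent is scaled by the same factor.
    return (case, factor * x), (case, factor * dx)

print(jvp_scale(("Left", 2.0), ("Left", 1.0)))    # matching cases are accepted
try:
    jvp_scale(("Left", 2.0), ("Right", 1.0))      # mismatched cases fail at run time
except ValueError as err:
    print("error:", err)
```

With a dependent pair type the mismatched call would be rejected by the type checker instead of at run time.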
As always this answer is so good, and taught me so much more than the question I asked. Your time is highly appreciated.
I'd like to extend @apaszke's excellent answer about associated `Tangent` types!
To also address the question: what types work with differentiation?
There's an analogy between types in programming languages and mathematical spaces.
Consider the type of `jvp` from above:
jvp :: (a -> b) -> (x:a ** Tangent x) -> (y:b ** Tangent y)
I'll reuse the `(x:a ** B x)` syntax for dependent sum types from above. For those unfamiliar: a dependent sum type can be thought of as a "generalized pair type", where the second type (`B x`) can depend on a value in the first type (`x:a`).
* Differentiable types have an associated `Tangent` type. `jvp` takes a `(x:a) -> (y:b)` function and returns a `Tangent x -> Tangent y` differential function (aka pushforward, rationale).
* For every point (`x`) on a differentiable manifold, there is an associated tangent vector space (`Tangent x`). Tangent vectors in the space represent small rates of change from the point.
* The exponential map takes a point `x` and a tangent vector `v: Tangent x`, and moves the point along the tangent vector to get a new point `x'`.
* The exponential map has type `(x: a ** v: Tangent x) -> a`. This is actually the type of gradient descent as a function!

def gradient_descent(parameters, gradient, learning_rate=0.01):
    """Performs gradient descent, updating parameters given their gradients.

    Gradient descent effectively has the type signature of exponential map.
    This implementation updates parameters in-place, as done in modern deep learning frameworks.
    You could imagine an equivalent functional implementation that returns new parameters.

    Args:
        parameters: bundle of arrays. Like `model.parameters()` from PyTorch.
        gradient: bundle of tangent arrays. Like `[p.grad for p in model.parameters()]` from PyTorch.
        learning_rate: step size; the default value is only a placeholder.
    """
    parameters -= learning_rate * gradient
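The docstring above mentions an equivalent functional implementation. Here is a minimal sketch of one, assuming parameters and gradients are plain Python lists of floats; the names `move` and `gradient_descent_step` are hypothetical, with `move` playing the role of the exponential map on a flat space.

```python
def move(point, tangent):
    """Exponential map on a flat space: shift each coordinate by its tangent component."""
    return [p + t for p, t in zip(point, tangent)]

def gradient_descent_step(parameters, gradient, learning_rate=0.1):
    """Return new parameters instead of mutating the old ones."""
    # The descent direction is the gradient scaled by -learning_rate.
    direction = [-learning_rate * g for g in gradient]
    return move(parameters, direction)

params = [1.0, -2.0]
grads = [2.0, -4.0]    # gradient of x**2 + y**2 at (1, -2)
new_params = gradient_descent_step(params, grads, learning_rate=0.25)
print(new_params)      # the original `params` list is left untouched
```

On a curved manifold `move` would not be plain addition, which is why it is worth keeping as a separate operation.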
Let's visualize this analogy.
We can think of _types in programming languages_ as _mathematical spaces_.
* A function from type `T` to type `U` can be drawn as a function between two manifolds.
* Points on the manifolds have a tangent vector space, and a differential function can be drawn between tangent vector spaces.
* We can see the types of the differential function and the exponential map operation in the slides below.

![Differentiable manifolds and differential function](https://user-images.githubusercontent.com/5590046/105540666-8e8eca00-5cc4-11eb-88e0-2d0635afeefd.png)
![Differentiable manifolds and exponential map](https://user-images.githubusercontent.com/5590046/105540674-90f12400-5cc4-11eb-8ba2-0fdf22a083e3.png)
`Differentiable` interface.
Simple script for the animation:
* Manifolds have points, each point has an associated tangent vector space, we can move a point along a tangent vector to get a new point (via exponential map).
* This helps us answer the question: what types work with differentiation?
* We can design a `Differentiable` interface (a Swift _protocol_) with two requirements: a `TangentVector` associated type and an exponential map operation called `move(along:)`.

![differentiable-manifolds-to-differentiable-protocol](https://user-images.githubusercontent.com/5590046/105580878-0f99a000-5d5d-11eb-96c9-8102cba3be75.gif)
If you liked this animation, it comes from [these slides](http://bit.ly/swift-autodiff-intro), which provide more context and detail, including a [differentiable physics demo](https://colab.sandbox.google.com/github/tensorflow/swift/blob/main/notebooks/talk_demos/Differentiable_Physics.ipynb)!
Image showing two differentiable manifolds: a sphere and a spheroid, from
https://en.wikipedia.org/wiki/Pushforward_(differential).
If a map, φ, carries every point on manifold M to manifold N, then the
pushforward of φ carries vectors in the tangent space at every point in M to
a tangent space at every point in N.
So far, we've talked about differentiable functions and differential functions, which are _pushforward linear approximation functions_.
Differential functions implement forward-mode differentiation. What about reverse-mode differentiation? Reverse-mode seems to require _backpropagator_ [pullback functions](https://en.wikipedia.org/wiki/Pullback_(differential_geometry)): starting with partial derivatives at outputs and ending by computing partial derivatives at inputs.
There is a correspondence between forward-mode and reverse-mode differentiation. We can view the two modes as different _associations_ of the multiplications in the [chain rule of differentiation](https://en.wikipedia.org/wiki/Chain_rule):
Visualization: chain rule
From the ["Automatic differentiation" section of the Swift Differentiable Programming Manifesto](https://github.com/apple/swift/blob/main/docs/DifferentiableProgramming.md#automatic-differentiation):
> Mathematically, forward-mode AD corresponds to a fully-right association of the chain rule of differentiation, and reverse-mode AD corresponds to a fully-left association. Different associations of the chain rule produce the same result but may differ in computational complexity†.
Top: fully-right association of chain rule, starting from partial derivative of input; "forward-mode".
Bottom: fully-left association of chain rule, starting from output; "reverse-mode".
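To make the two associations concrete, here is a small pure-Python sketch. The Jacobian matrices `J1`, `J2`, `J3` hold made-up numbers and the helper names are assumptions; the point is only that pushing a tangent vector through from the input side and pulling a cotangent vector back from the output side agree.

```python
def matvec(m, v):
    """Matrix times column vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def vecmat(w, m):
    """Row vector times matrix."""
    return [sum(w[i] * m[i][j] for i in range(len(m))) for j in range(len(m[0]))]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

# Jacobians of three composed functions f3 . f2 . f1 (arbitrary example values).
J1 = [[1, 2], [3, 4], [5, 6]]    # 3x2: f1 maps R^2 -> R^3
J2 = [[1, 0, 2], [0, 1, 1]]      # 2x3: f2 maps R^3 -> R^2
J3 = [[2, 1]]                    # 1x2: f3 maps R^2 -> R^1

v = [1, 1]    # tangent vector at the input
w = [1]       # cotangent vector at the output

# Forward mode: fully-right association, J3 (J2 (J1 v)).
jvp = matvec(J3, matvec(J2, matvec(J1, v)))
# Reverse mode: fully-left association, ((w J3) J2) J1.
vjp = vecmat(vecmat(vecmat(w, J3), J2), J1)

print(jvp, vjp, dot(w, jvp), dot(vjp, v))
```

Because the output here is a scalar, the single reverse pass already yields the whole gradient, while forward mode would need one pass per input dimension to recover it.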
Sorry this got so long, it condenses two years of my learnings - many unmentioned acknowledgements.
@dan-zheng asked me for some feedback on this thread.
The answers above seem to embody an unfortunate and popular perspective, namely that forward mode and reverse mode AD are different questions requiring different vocabulary and techniques. Instead, a single, simple notion of differentiation and a single API suffice, and a single simple AD algorithm can handle forward, reverse, and other mixed modes with ease and without complicated operational details like graphs, mutation, and “backpropagation”. Instead of changing the algorithm, choose a suitable representation of linear maps. A good choice for low-dimensional domains is functions that are linear, while a good choice for low-dimensional codomains is the dual of such functions, where the fundamental building blocks of functions are defined dually: composition is reversed, projections become injections and vice versa, duplication and combination (addition) trade places, and curried scalar multiplication becomes itself. See The simple essence of automatic differentiation for details, including proofs. The algorithm is calculated from a simple, precise specification by solving a standard collection of algebra problems. The Microsoft Research talk is probably the most accessible explanation.
Another unfortunate choice in the first formulation above is the type of `jvp`, in which `A × Tangent A → B × Tangent B` suggests that (a) the result primal value can depend on the input vector, which it must not, and (b) the derivative of a function at a point might not be a linear map (which it must be by definition). Both shortcomings are easily fixed as follows:
1. Curry the tangent argument, so that the primal result depends only on the input point: `A → B × (Tangent A → Tangent B)`.
2. Make the derivative a linear map: `A → B × (Tangent A ⊸ Tangent B)`.

With these two changes, your API would become more precise, i.e., you’ve statically eliminated many invalid representations. The remaining invalid representations can be eliminated via dependent types.
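A rough Python sketch of the second signature, under the assumption that a differentiable function is represented as a callable returning a pair `(value, derivative)`, where `derivative` is an ordinary Python function that is linear only by convention (Python cannot check linearity). The names `d_square`, `d_sin` and `compose` are invented for this example.

```python
import math

def d_square(x):
    """x -> x**2, together with its derivative at x as a linear map on tangents."""
    return x * x, lambda dx: 2.0 * x * dx

def d_sin(x):
    """x -> sin(x), together with its derivative at x as a linear map on tangents."""
    return math.sin(x), lambda dx: math.cos(x) * dx

def compose(g, f):
    """Chain rule: the derivative of g . f at x is dg(f(x)) composed with df(x)."""
    def h(x):
        y, df = f(x)
        z, dg = g(y)
        return z, lambda dx: dg(df(dx))
    return h

sin_of_square = compose(d_sin, d_square)
value, derivative = sin_of_square(1.5)
# The primal value is computed from the point alone; a tangent only enters `derivative`.
print(value, derivative(1.0))
```

Since the tangent is supplied after the primal value is returned, the shape of the API itself rules out shortcoming (a).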
These changes also lead to fixing the first serious problem I mentioned above of treating various “modes” of differentiation as if they were different questions (specifications), rather than different answers (algorithms). The key is in realizing that there are many valid linear map representations you can use for `Tangent A ⊸ Tangent B`. Two very simple choices give you correct-by-construction (see the paper link above), intrinsically parallel-friendly AD good for different dimensionalities. Another choice when input and output dimensions are similar is matrices, preferably modernized to be safe and compositional (not arrays/“tensors”) for the post-Fortran era.
Thanks @dan-zheng and @conal this is really interesting. Since there seem to be a bunch of people following along, going to make a study thread here to discuss this paper. https://github.com/google-research/dex-lang/discussions/494
High-level: it sounds like many of these ideas may be out-of-scope for the type system of Dex? And there is a more practical question of "how to auto-define and name simple tangent types". However, it still feels really important.
@conal Thanks for the feedback. Before I get to the technical part of my answer, I wanted to ask you to limit yourself from expressing judgements on anyone discussing any topic on our issue tracker. We’re trying to build an inclusive community and welcome people from many backgrounds. In particular we don’t care if they want to understand the process of differentiation in terms of graph traversals, backpropagation, or category theory. I’m sure you too could learn a lot from them, if only you open yourself to their perspective. I’m aware that it’s quite easy to get misread as you post things online and I’m sure that you’re writing your comments in good faith, but please be careful about how your message can be understood by others.
Moving on to technical material. We completely agree with your suggestions (1. and 2.) and it is in fact how AD is implemented in Dex. The builtin functions we expose for that purpose are:
linearize : (a -> b) -> a -> (b & Tangent a -o Tangent b)
transpose : (a -o b) -> b -o a
Note that our type system even features a linear arrow that can verify that user-defined functions are truly linear and transposable. `jvp` is just a little helper defined in `lib/prelude.dx`, because it’s a well-known function with a pretty convenient signature.
Finally, I think it is worth noting that this approach is no silver bullet, which is likely the reason why many AD systems that do care a lot about forward-mode performance cannot take the path you’re outlining. In particular, just like one can prove theorems that forward- and reverse-mode AD can be carried out in the same order of computational complexity as the input program, forward-mode has the additional benefit of being able to preserve the same order of memory complexity. But, this is conditional on being able to produce a program where the evaluation of the non-linear function is interleaved with the linear part, which is far from easy when linearization is considered fundamental (it would require whole program optimization and very aggressive code motion in many program representations). See our LAFI abstract for an outline of how the ideas you’re suggesting can be pushed even further to alleviate this issue (the gist of it is that we actually do make `jvp` the fundamental operation, and recover `linearize` from it).
@apaszke Thanks for this response. Message received about tone. I originally wrote these notes just for @dan-zheng as a response to his inquiry and a follow-on to some of our past conversations. I regret sending the notes as they were to a group with whom I don’t have such a shared context.
> We completely agree with your suggestions (1. and 2.) and it is in fact how AD is implemented in Dex. The builtin functions we expose for that purpose are:
>
> linearize : (a -> b) -> a -> (b & Tangent a -o Tangent b)
>
> transpose : (a -o b) -> b -o a
Great. It sounds like we’re closer than I thought. Seeing this explicit `transpose`, however, I suspect that we are not quite talking about the same thing here, so I’ll ask some clarifying questions. Are you assuming a particular representation of linear maps (`T a -o T b`), and if so what? Correspondingly, by transposition, do you mean an operation on a particular representation or something more abstract/mathematical/algebraic, say in terms of linear maps rather than matrices?
I think I’m suggesting something different, which is to have only what you call “`linearize`” (which I call constructing a computably differentiable function), but parametrized over the representation for linear maps (`-o`). (The effect of transposition is instead achieved when desired by choice of linear map representation.) Forward mode can then represent `T a -o T b` as the linear subset of `T a -> T b`, while reverse mode can represent it as the linear subset of `T b -> T a`, which itself really represents the linear subset of `(T b -o s) -> (T a -o s)` where `s` is the scalar field shared by `T a` and `T b`. Of course matrices can also be suitable for some domain+codomain types. All linear map representations use the exact same algebraic vocabulary for building linear maps (which happens to be exactly the general vocabulary of biproduct categories plus curried scalar multiplication). The definitions of that vocabulary for forward and reverse modes are dual to each other. All choices of representation & definitions must satisfy the associated laws, as necessary and sufficient for correctness. Then forward, reverse, and mixed modes are all the same algorithm, but with different linear map representations and no need for an explicit transposition step.
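As a rough illustration of "same algorithm, different linear map representation", the Python sketch below writes differentiable multiplication once against a tiny vocabulary (`scale`, `compose`, `cross`, `dup`, `add`) and then instantiates it with two representations. All names and the encoding are my own assumptions for this example; this is not code from the paper or from Dex.

```python
from types import SimpleNamespace

# A linear map T a -o T b stored as a function T a -> T b.
forward = SimpleNamespace(
    scale=lambda s: (lambda d: s * d),
    compose=lambda g, f: (lambda d: g(f(d))),             # g after f
    cross=lambda f, g: (lambda d: (f(d[0]), g(d[1]))),
    dup=lambda d: (d, d),                                  # duplication
    add=lambda d: d[0] + d[1],                             # addition
)

# The same linear map stored as its dual, a function T b -> T a.
reverse = SimpleNamespace(
    scale=lambda s: (lambda d: s * d),                     # self-dual
    compose=lambda g, f: (lambda d: f(g(d))),              # composition order reversed
    cross=lambda f, g: (lambda d: (f(d[0]), g(d[1]))),
    dup=lambda d: d[0] + d[1],                             # dual of duplication is addition
    add=lambda d: (d, d),                                  # dual of addition is duplication
)

def d_mul(rep):
    """Differentiable multiplication, written once for any representation `rep`."""
    def mul(point):
        a, b = point
        # d(a*b) = b*da + a*db, built only from the shared vocabulary.
        return a * b, rep.compose(rep.add, rep.cross(rep.scale(b), rep.scale(a)))
    return mul

def d_square(rep):
    """x -> x*x as multiplication after duplication."""
    def square(x):
        y, dmul = d_mul(rep)((x, x))
        return y, rep.compose(dmul, rep.dup)
    return square

_, fwd = d_mul(forward)((3.0, 4.0))
_, rev = d_mul(reverse)((3.0, 4.0))
print(fwd((1.0, 0.0)))    # one directional derivative per call
print(rev(1.0))           # both partial derivatives in one call
```

Neither `d_mul` nor `d_square` mentions a mode; the choice is made entirely by which vocabulary is passed in.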
> … forward-mode has the additional benefit of being able to preserve the same order of memory complexity. But, this is conditional on being able to produce a program where the evaluation of the non-linear function is interleaved with the linear part, which is far from easy when linearization is considered fundamental (it would require whole program optimization and very aggressive code motion in many program representations).
The reason I’m aware of for combining the primal function (of type `a -> b`) and tangent function (of type `a -> (T a -o T b)`) into a single function (of type `a -> b × (T a -o T b)`, which is isomorphic to `(a -> b) × (a -> (T a -o T b))`) is exactly to be able to share computation. I found this sharing to be very simple, without any optimization effort on my part (although I am piggybacking on a fairly good Haskell compiler). I note, however, that you’re specifically talking about memory complexity, and I wonder if you’ve noticed something important that I haven’t.
Thanks for explaining! I now see how what you are proposing is slightly different than what we do. I'll try to paraphrase your point and describe how it compares to our approach, but of course please do point out any inaccuracies in my comparison.
If I understand your point correctly, you say the vocabulary that transforms and composes the linear maps in both forward- and reverse-mode is the same, and I agree that it can be made so (as you carefully outline in your paper). I would be tempted to say that there is a type-class your linear map representation has to implement in order to be compatible with the process of differentiation.
Going down that path has the benefit of using a single program transformation for both modes, but the downside of using two sets of rules (as you have to implement the type-class twice, once for each of the two linear map types). But, this largely misses out on the close correspondence between the rules used to perform forward- and reverse-mode. In our own jargon, we like to say that forward-mode rules implement JVPs (Jacobian-vector products), while reverse-mode rules implement VJPs (vector-Jacobian products). I like those names a lot, because they highlight their relation: each one is a transposed version of the other. Because of that, in Dex the only AD mode we really support (and have to implement rules for) is forward-mode, while reverse-mode is obtained not via reusing the same differentiation process with a different rule set, but through a program transposition transformation that is always valid to perform on the functions produced by forward-mode AD.
So yes, for the purpose of differentiation we do assume a particular representation of linear maps, which in our case is encoded in what we call a structurally linear program (this also has some interesting connections to linear logic as @dougalm wonderfully explained in one of our issues). But this doesn't prevent us from getting reverse-mode in the end, because we've simply found a path that doesn't require us to reengage the AD machinery, with the added benefit of having a significantly smaller rule set than necessary for both AD modes.
About your second question, this is mostly not about sharing (which is critical too, but not precluded by the signature `a -> b × (T a -o T b)`), but about the order in which the different operations are carried out. Consider the implementation of `jvp` via linearization:
linearize :: (a -> b) -> (a -> (b, T a -o T b))
jvp :: (a -> b) -> (a, T a) -> (b, T b)
jvp f (x, t) =
let (y, f') = linearize f x
in (y, f' t)
If the function `jvp` were to be executed in a language with eager evaluation, this would mean that the original mapping (`a -> b`) had to be computed in its entirety before the evaluation of `f'` could even begin. Then, precisely because of the sharing that you've mentioned, many of the intermediate values computed in the body of `f` would have to be kept in memory until `f'` consumes them. Because the number of those intermediates cannot be bounded (it depends on `f` and can be arbitrarily high), we cannot derive any good bound on the memory consumption of a program differentiated in this way in such a language. But, we know that a memory bound can be proven (to be at most twice what `f` requires), if only we interleave the evaluation of the original computation with the linearization of each step. A sufficiently smart compiler might be able to re-optimize this code to achieve that, but it is not guaranteed.
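The memory argument above can be seen on a toy example. The sketch below is my own illustration in Python (eagerly evaluated), with hypothetical names: it differentiates `n` nested applications of `sin`, once by linearizing first and applying the tangent afterwards, and once by interleaving primal and tangent updates.

```python
import math

def linearize_chain(x, n):
    """Apply sin n times; return the result, a linear map, and how many residuals it holds."""
    residuals = []                  # one saved cos(x_i) per step, kept alive until f' runs
    for _ in range(n):
        residuals.append(math.cos(x))
        x = math.sin(x)

    def f_prime(t):
        for c in residuals:
            t = c * t
        return t

    return x, f_prime, len(residuals)

def jvp_chain_interleaved(x, t, n):
    """Same derivative, but each step updates the primal and the tangent together."""
    for _ in range(n):
        # Only the current (x, t) pair is live; nothing is stored per step.
        x, t = math.sin(x), math.cos(x) * t
    return x, t

y1, f_prime, saved = linearize_chain(0.5, 1000)
y2, t2 = jvp_chain_interleaved(0.5, 1.0, 1000)
print(saved, abs(f_prime(1.0) - t2) < 1e-12, y1 == y2)
```

The linearized version holds a number of residuals that grows with `n`, while the interleaved version keeps a constant amount of state, which is the factor-of-two bound mentioned above.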
I'm curious what type of functions one should be able to `grad` over, i.e. what is the implicit restriction on `a`?
def grad (f:a->Float) (x:a) : a = snd (vjp f x) 1.0
Currently it seems to work for tables and tuples, but other things crash for me (for instance custom `data` types).
Was playing around with something Flax-like for grouping params and functions, but I think this might be the wrong path given https://github.com/google-research/dex-lang/issues/331 and because I am not smart enough to figure out how to unpack a tuple of Params to a Param of tuples.