google-research / dex-lang

Research language for array processing in the Haskell/ML family
BSD 3-Clause "New" or "Revised" License

Q: What types can be used with autodiff? #454

Closed srush closed 3 years ago

srush commented 3 years ago

I'm curious what types of functions one should be able to take grad over, i.e., what is the implicit restriction on a in:

def grad (f:a->Float) (x:a) : a = snd (vjp f x) 1.0

Currently it seems to work for tables and tuples, but other things crash for me (for instance custom data types).

Was playing around with something Flax-like for grouping params and functions, but I think this might be the wrong path given https://github.com/google-research/dex-lang/issues/331 and because I am not smart enough to figure out how to unpack a tuple of Params into a Param of tuples.

data Param a =
  AsParam a
  Init ParamInit

def Layer (inp:Type) (out:Type) (param:Type) : Type =
  {forward:({key: Key & params:param} -> inp -> out) &
   params: param &
   init: Bool }

apaszke commented 3 years ago

Yeah, base types (Float), tables, tuples and records are the only things that will work at the moment I think. Adding support for general ADTs is on our roadmap, but it is a bit subtle, because the type of e.g. jvp really is

jvp :: (a -> b) -> (a, Tangent a) -> (b, Tangent b)

with Tangent being a function that takes a to its tangent type. Now, for Float we have Tangent a == a. For Int and Bool we have Tangent a == Unit. For tables, tuples and records, you apply the Tangent map to each element or member type, so a table of floats is its own tangent type. For example, Tangent (Int & Float & n=>Int & m=>Float) = (Unit & Float & n=>Unit & m=>Float)
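To make the Tangent mapping described above concrete, here is a minimal Python sketch that computes tangent types over a toy description of Dex types. The classes `Base`, `Table` and `Prod` and the function `tangent` are hypothetical names invented for this illustration; they are not part of Dex.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Base:
    """A base type, identified by name: "Float", "Int", "Bool" or "Unit"."""
    name: str


@dataclass(frozen=True)
class Table:
    """A table type `index => elem` (hypothetical encoding)."""
    index: str
    elem: "Ty"


@dataclass(frozen=True)
class Prod:
    """A tuple or record type, given by its member types."""
    members: Tuple["Ty", ...]


Ty = Union[Base, Table, Prod]


def tangent(ty):
    """Map a type description to the description of its tangent type."""
    if isinstance(ty, Base):
        # Float is its own tangent type; Int, Bool and Unit have no tangent content.
        return ty if ty.name == "Float" else Base("Unit")
    if isinstance(ty, Table):
        # Tables keep their index set and map the element type.
        return Table(ty.index, tangent(ty.elem))
    if isinstance(ty, Prod):
        # Tuples and records map each member type.
        return Prod(tuple(tangent(member) for member in ty.members))
    raise TypeError(f"no tangent rule for {ty!r}")


# The example from the comment: Int & Float & n=>Int & m=>Float
example = Prod((Base("Int"), Base("Float"),
                Table("n", Base("Int")), Table("m", Base("Float"))))
example_tangent = tangent(example)
```

The missing case is exactly the one discussed next: there is no obvious rule for a user-defined ADT, which is why `tangent` raises for anything else.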

Supporting ADTs is on the roadmap, but is a bit more involved. Below are two reasons that we know of. If you have any thoughts or suggestions do let us know!

ADTs have named constructors

Unlike tuples and records, ADTs have named constructors, with names given by the users. This is a problem, because the Tangent we need for an ADT might no longer fit the same definition. Consider this definition:

data MyADT = MyConstructor Float Int

Now, what should its tangent type be? We could "downgrade it" into a tuple of (Float & Unit), but that seems quite arbitrary and would get very ugly with nested ADTs. Alternatively, we could make it so that each data definition additionally introduces another ADT (e.g. MyADTTangent), but that seems quite ugly.

In short, we'd like the tangent types of ADTs to have some correspondence to the names given by the user to the original ADT. On the other hand, we don't want to put the burden of writing another, extremely similar ADT by hand on the user.

ADTs can be sum types

Differentiating through sum types is even more involved, because then the Tangent mapping actually becomes dependent. That is, the type of jvp would have to be

jvp :: (a -> b) -> (x:a ** Tangent x) -> (y:b ** Tangent y)

where ** is the dependent pair constructor. This is because the Tangent one provides for a sum-type argument has to use exactly the same constructor/case as that argument.

Of course an alternative here is to make jvp partial (i.e. allow any case, but throw an error when the wrong one is given), which we might do initially for implementation simplicity.
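As a rough illustration of the "partial jvp" alternative, the Python sketch below differentiates a function over a two-case sum type and raises an error when the tangent uses a different constructor than the primal value. The types `Circle` and `Square` and the function `jvp_area` are hypothetical, and for simplicity the tangent reuses the same constructors as the primal type; none of this reflects how Dex implements it.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Circle:
    radius: float


@dataclass(frozen=True)
class Square:
    side: float


def area(shape):
    """A function defined by cases over the sum type Circle | Square."""
    if isinstance(shape, Circle):
        return math.pi * shape.radius ** 2
    return shape.side ** 2


def jvp_area(shape, tangent):
    """Partial jvp: any tangent is accepted statically, mismatches fail at run time."""
    if type(shape) is not type(tangent):
        raise ValueError("tangent must use the same constructor as the primal value")
    if isinstance(shape, Circle):
        return area(shape), 2.0 * math.pi * shape.radius * tangent.radius
    return area(shape), 2.0 * shape.side * tangent.side
```

With a dependent pair type the mismatch check would be ruled out by the type checker instead of being a run-time error.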

srush commented 3 years ago

As always this answer is so good, and taught me so much more than the question I asked. Your time is highly appreciated.

dan-zheng commented 3 years ago

I'd like to extend @apaszke's excellent answer about associated Tangent types!

To also address the question: what types work with differentiation?

An analogy

There's an analogy between types in programming languages and mathematical spaces. Consider the type of jvp from above:

jvp :: (a -> b) -> (x:a ** Tangent x) -> (y:b ** Tangent y)

I'll reuse the (x:a ** B x) syntax for dependent sum types from above. For those unfamiliar: a dependent sum type can be thought of as a "generalized pair type", where the second type (B x) can depend on a value in the first type (x:a).

Programming languages

Differential geometry (math)

def gradient_descent(parameters, gradient, learning_rate=0.01):
  """Performs gradient descent, updating parameters given their gradients.

  Gradient descent effectively has the type signature of the exponential map.
  This implementation updates parameters in-place, as done in modern deep learning frameworks.
  You could imagine an equivalent functional implementation that returns new parameters.

  Args:
    parameters: bundle of arrays. Like `model.parameters()` from PyTorch.
    gradient: bundle of tangent arrays. Like `[p.grad for p in model.parameters()]` from PyTorch.
    learning_rate: step size; the default value here is arbitrary.
  """
  parameters -= learning_rate * gradient

Putting it together

Let's visualize this analogy.

"First-class Differentiable Programming" @ Probabilistic & Differentiable Programming Summit, June 2019

Visualization: differentiable manifolds, differentiable function and differential function, exponential map.

We can think of _types in programming languages_ as _mathematical spaces_.

* A function from type `T` to type `U` can be drawn as a function between two manifolds.
* Points on the manifolds have a tangent vector space, and a differential function can be drawn between tangent vector spaces.
* We can see the types of the differential function and the exponential map operation in the slides below.

![Differentiable manifolds and differential function](https://user-images.githubusercontent.com/5590046/105540666-8e8eca00-5cc4-11eb-88e0-2d0635afeefd.png)

![Differentiable manifolds and exponential map](https://user-images.githubusercontent.com/5590046/105540674-90f12400-5cc4-11eb-8ba2-0fdf22a083e3.png)

"Intro to Differentiable Swift" @ Swift for TensorFlow Open Design Meeting, March 2020

Animated visualization ✨: what types work with differentiation? From differentiable manifolds to a Differentiable interface.

Simple script for the animation:

* Manifolds have points, each point has an associated tangent vector space, and we can move a point along a tangent vector to get a new point (via the exponential map).
* This helps us answer the question: what types work with differentiation?
* We can design a `Differentiable` interface (a Swift _protocol_) with two requirements: a `TangentVector` associated type and an exponential map operation called `move(along:)`.

![differentiable-manifolds-to-differentiable-protocol](https://user-images.githubusercontent.com/5590046/105580878-0f99a000-5d5d-11eb-96c9-8102cba3be75.gif)

If you liked this animation, it comes from [these slides](http://bit.ly/swift-autodiff-intro), which provide more context and detail, including a [differentiable physics demo](https://colab.sandbox.google.com/github/tensorflow/swift/blob/main/notebooks/talk_demos/Differentiable_Physics.ipynb)!
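The two requirements of such an interface can be sketched in Python as well. This is only an analogy to the Swift protocol: the names `Differentiable`, `Line`, `LineTangent`, `scaled` and `sgd_step` are hypothetical, and Python's `Protocol` cannot express an associated type, so the tangent type is simply a separate class.

```python
from dataclasses import dataclass
from typing import Protocol


class Differentiable(Protocol):
    """Anything that can be moved along one of its tangent vectors."""

    def move(self, along):
        ...


@dataclass(frozen=True)
class LineTangent:
    """Tangent vector for Line (plays the role of the associated tangent type)."""
    d_slope: float
    d_intercept: float

    def scaled(self, factor):
        return LineTangent(self.d_slope * factor, self.d_intercept * factor)


@dataclass(frozen=True)
class Line:
    """A point in parameter space: the parameters of y = slope * x + intercept."""
    slope: float
    intercept: float

    def move(self, along):
        # Exponential map for a flat space: plain component-wise addition.
        return Line(self.slope + along.d_slope, self.intercept + along.d_intercept)


def sgd_step(point, gradient, learning_rate):
    """Functional gradient descent: returns a new point instead of mutating."""
    return point.move(gradient.scaled(-learning_rate))
```

For example, `sgd_step(Line(1.0, 2.0), LineTangent(4.0, -2.0), 0.5)` moves the point against the gradient and returns a new `Line`, leaving the original untouched.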

Definition of "Pushforward (differential)" @ Wikipedia

Visualization: a differentiable map between two manifolds, and its pushforward.


Image showing two differentiable manifolds: a sphere and a spheroid, from https://en.wikipedia.org/wiki/Pushforward_(differential).
If a map, φ, carries every point on manifold M to manifold N, then the pushforward of φ carries vectors in the tangent space at every point in M to a tangent space at every point in N.

Reverse-mode differentiation?

A long aside

So far, we've talked about differentiable functions and differential functions, which are _pushforward linear approximation functions_. Differential functions implement forward-mode differentiation. What about reverse-mode differentiation? Reverse-mode seems to require _backpropagator_ [pullback functions](https://en.wikipedia.org/wiki/Pullback_(differential_geometry)): starting with partial derivatives at outputs and ending by computing partial derivatives at inputs. There is a correspondence between forward-mode and reverse-mode differentiation. We can view the two modes as different _associations_ of the multiplications in the [chain rule of differentiation](https://en.wikipedia.org/wiki/Chain_rule):

Visualization: chain rule
From the ["Automatic differentiation" section of the Swift Differentiable Programming Manifesto](https://github.com/apple/swift/blob/main/docs/DifferentiableProgramming.md#automatic-differentiation):

> Mathematically, forward-mode AD corresponds to a fully-right association of the chain rule of differentiation, and reverse-mode AD corresponds to a fully-left association. Different associations of the chain rule produce the same result but may differ in computational complexity†.


Top: fully-right association of chain rule, starting from partial derivative of input; "forward-mode".
Bottom: fully-left association of chain rule, starting from output; "reverse-mode".
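The two associations can be compared on a small numeric example. The Python sketch below multiplies three made-up Jacobians (for a hypothetical chain of maps R^4 -> R^3 -> R^2 -> R^1) in both orders and counts scalar multiplications; the matrices and helper names are illustrative assumptions only.

```python
def matmul(a, b):
    """Multiply matrices given as lists of rows; also return the multiplication count."""
    rows, inner, cols = len(a), len(b), len(b[0])
    product = [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
               for i in range(rows)]
    return product, rows * inner * cols


# Jacobians of three composed steps; the full Jacobian is J3 * J2 * J1.
J1 = [[1.0, 2.0, 0.0, 1.0],
      [0.0, 1.0, 3.0, 0.0],
      [2.0, 0.0, 1.0, 1.0]]   # 3 x 4, step closest to the input
J2 = [[1.0, 0.0, 2.0],
      [0.0, 1.0, 1.0]]        # 2 x 3
J3 = [[3.0, 1.0]]             # 1 x 2, step closest to the output


def fully_right(j3, j2, j1):
    """J3 * (J2 * J1): starts at the input side, as in forward mode."""
    inner, cost_inner = matmul(j2, j1)
    result, cost_outer = matmul(j3, inner)
    return result, cost_inner + cost_outer


def fully_left(j3, j2, j1):
    """(J3 * J2) * J1: starts at the output side, as in reverse mode."""
    inner, cost_inner = matmul(j3, j2)
    result, cost_outer = matmul(inner, j1)
    return result, cost_inner + cost_outer


forward_result, forward_cost = fully_right(J3, J2, J1)
reverse_result, reverse_cost = fully_left(J3, J2, J1)
```

Both orders produce the same 1 x 4 Jacobian, but with a single output the fully-left order needs fewer multiplications (18 versus 32 here), which is the usual argument for reverse mode on scalar-valued losses.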

> †: Finding the optimal association of the chain rule of differentiation is analogous to the [matrix chain multiplication](https://en.wikipedia.org/wiki/Matrix_chain_multiplication) problem and can be solved in `O(n^3)` time. More efficient algorithms also exist.
It is possible to write a program transformation that _transposes_ linear maps, turning differentials into pullbacks and pullbacks into differentials. This _transposition_ transformation has been explored before; check out the resources below to learn more!

* ["A new trick for calculating Jacobian vector products"](https://j-towns.github.io/2017/06/12/A-new-trick.html), @j-towns's 2017 blog post
  > Later in the thread we were discussing another very specific use case for forward mode, that of computing generalised Gauss Newton matrix-vector products, when we happened upon a new trick: _a method for calculating jvps by composing two reverse mode vjps_! This could render specialised code for forward mode redundant. The trick is simple. I’ll demonstrate it first mathematically and then with Theano code.
* ["Decomposing reverse-mode automatic differentiation"](https://popl21.sigplan.org/details/lafi-2021-papers/9/Decomposing-reverse-mode-automatic-differentiation) @ [LAFI 2021](https://popl21.sigplan.org/home/lafi-2021): abstract and slides available.
* JAX and Dex both implement reverse-mode autodiff via a transposition transformation on linear maps.
* Are there formalisms relating "transposition the program transformation" to _"the denotation of linear maps"_?
* This could pave the road for "The Simple Essence of Linear Maps", mentioning duals/transposes.


Sorry this got so long, it condenses two years of my learnings - many unmentioned acknowledgements.

Maybe I should post it elsewhere.

conal commented 3 years ago

@dan-zheng asked me for some feedback on this thread.

The answers above seem to embody an unfortunate and popular perspective, namely that forward mode and reverse mode AD are different questions requiring different vocabulary and techniques. Instead, a single, simple notion of differentiation and a single API suffice, and a single simple AD algorithm can handle forward, reverse, and other mixed modes with ease and without complicated operational details like graphs, mutation, and “backpropagation”. Instead of changing the algorithm, choose a suitable representation of linear maps. A good choice for low-dimensional domains is functions that are linear, while a good choice for low-dimensional codomains is the dual of such functions, where the fundamental building blocks of functions are defined dually: composition is reversed, projections become injections and vice versa, duplication and combination (addition) trade places, and curried scalar multiplication remains itself. See The simple essence of automatic differentiation for details, including proofs. The algorithm is calculated from a simple, precise specification by solving a standard collection of algebra problems. The Microsoft Research talk is probably the most accessible explanation.

Another unfortunate choice in the first formulation above is the type of jvp, in which A × Tangent A → B × Tangent B suggests that (a) the result primal value can depend on the input vector, which it must not, and (b) the derivative of a function at a point might not be a linear map (which it must be by definition). Both shortcomings are easily fixed as follows:

  1. Move the result vector out of dependency on input value: A → B × (Tangent A → Tangent B).
  2. Restrict the second general function dependency to be linear: A → B × (Tangent A ⊸ Tangent B).

With these two changes, your API would become more precise, i.e., you’ve statically eliminated many invalid representations. The remaining invalid representations can be eliminated via dependent types.
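A minimal Python sketch of the first change, for scalar functions: each differentiable function returns its result together with a closure for the derivative, so the primal result visibly cannot depend on the input tangent. The names `d_sin`, `d_square`, `compose` and `jvp` are hypothetical, and Python has no way to enforce the second change (linearity of the returned closure); that remains a convention here.

```python
import math


def d_sin(x):
    """sin as a function of type A -> B x (Tangent A -> Tangent B)."""
    return math.sin(x), lambda dx: math.cos(x) * dx


def d_square(x):
    return x * x, lambda dx: 2.0 * x * dx


def compose(g, f):
    """Chain rule: the derivative of g . f is the composition of the derivatives."""
    def composed(x):
        y, f_linear = f(x)
        z, g_linear = g(y)
        return z, lambda dx: g_linear(f_linear(dx))
    return composed


def jvp(f, x, dx):
    """The original jvp signature, recovered as a small helper."""
    y, f_linear = f(x)
    return y, f_linear(dx)


# sin(x^2) and its directional derivative at x = 1.5 in direction 1.0
value, derivative = jvp(compose(d_sin, d_square), 1.5, 1.0)
```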

These changes also lead to fixing the first serious problem I mentioned above of treating various “modes” of differentiation as if they were different questions (specifications), rather than different answers (algorithms). The key is in realizing that there are many valid linear map representations you can use for Tangent A ⊸ Tangent B. Two very simple choices give you correct-by-construction (see the paper link above), intrinsically parallel-friendly AD good for different dimensionalities. Another choice when input and output dimensions are similar is matrices, preferably modernized to be safe and compositional (not arrays/“tensors”) for the post-Fortran era.

srush commented 3 years ago

Thanks @dan-zheng and @conal this is really interesting. Since there seem to be a bunch of people following along, going to make a study thread here to discuss this paper. https://github.com/google-research/dex-lang/discussions/494

High-level: it sounds like many of these ideas may be out-of-scope for the type system of Dex? And there is a more practical question of "how to auto-define and name simple tangent types". However, it still feels really important.

apaszke commented 3 years ago

@conal Thanks for the feedback. Before I get to the technical part of my answer, I wanted to ask you to refrain from passing judgement on anyone discussing any topic on our issue tracker. We’re trying to build an inclusive community and welcome people from many backgrounds. In particular we don’t care if they want to understand the process of differentiation in terms of graph traversals, backpropagation, or category theory. I’m sure you too could learn a lot from them, if only you open yourself to their perspective. I’m aware that it’s quite easy to get misread as you post things online and I’m sure that you’re writing your comments in good faith, but please be careful about how your message can be understood by others.

Moving on to technical material. We completely agree with your suggestions (1. and 2.) and it is in fact how AD is implemented in Dex. The builtin functions we expose for that purpose are:

linearize : (a -> b) -> a -> (b & Tangent a -o Tangent b)
transpose  : (a -o b) -> b -o a

Note that our type system even features a linear arrow that can verify that user-defined functions are truly linear and transposable. jvp is just a little helper defined in lib/prelude.dx, because it’s a well-known function with a pretty convenient signature.

Finally, I think it is worth noting that this approach is no silver bullet, which is likely the reason why many AD systems that do care a lot about forward-mode performance cannot take the path you’re outlining. In particular, just like one can prove theorems that forward- and reverse-mode AD can be carried out in the same order of computational complexity as the input program, forward-mode has the additional benefit of being able to preserve the same order of memory complexity. But, this is conditional on being able to produce a program where the evaluation of the non-linear function is interleaved with the linear part, which is far from easy when linearization is considered fundamental (it would require whole program optimization and very aggressive code motion in many program representations). See our LAFI abstract for an outline of how the ideas you’re suggesting can be pushed even further to alleviate this issue (the gist of it is that we actually do make jvp the fundamental operation, and recover linearize from it).

conal commented 3 years ago

@apaszke Thanks for this response. Message received about tone. I originally wrote these notes just for @dan-zheng as a response to his inquiry and a follow-on to some of our past conversations. I regret sending the notes as they were to a group with whom I don’t have such a shared context.

We completely agree with your suggestions (1. and 2.) and it is in fact how AD is implemented in Dex. The builtin functions we expose for that purpose are:

linearize : (a -> b) -> a -> (b & Tangent a -o Tangent b)
transpose  : (a -o b) -> b -o a

Great. It sounds like we’re closer than I thought. Seeing this explicit transpose, however, I suspect that we are not quite talking about the same thing here, so I’ll ask some clarifying questions. Are you assuming a particular representation of linear maps (T a -o T b), and if so what? Correspondingly, by transposition, do you mean an operation on a particular representation or something more abstract/mathematical/algebraic, say in terms of linear maps rather than matrices?

I think I’m suggesting something different, which is to have only what you call “linearize”—which I call constructing a computably differentiable function—but parametrized over the representation for linear maps (-o). (The effect of transposition is instead achieved when desired by choice of linear map representation.) Forward mode can then represent T a -o T b as the linear subset of T a -> T b, while reverse mode can represent it as the linear subset of T b -> T a, which itself really represents the linear subset of (T b -o s) -> (T a -o s) where s is the scalar field shared by T a and T b. Of course matrices can also be suitable for some domain+codomain types. All linear map representations use the exact same algebraic vocabulary for building linear maps (which happens to be exactly the general vocabulary of biproduct categories plus curried scalar multiplication). The definitions of that vocabulary for forward and reverse modes are dual to each other. All choices of representation & definitions must satisfy the associated laws, as necessary and sufficient for correctness. Then forward, reverse, and mixed modes are all the same algorithm, but with different linear map representations and no need for an explicit transposition step.
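One way to picture this proposal is a single differentiation routine that takes the linear-map representation as a parameter. The Python sketch below is written in the spirit of that idea and is not code from the paper: the classes `Forward` and `Reverse`, the three-operation vocabulary (`scale`, `compose`, `join`) and the example function are all assumptions chosen to keep it short.

```python
import math


class Forward:
    """Represent a linear map T a -o T b as a function T a -> T b."""

    @staticmethod
    def scale(s):
        return lambda v: s * v

    @staticmethod
    def compose(g, f):
        return lambda v: g(f(v))

    @staticmethod
    def join(f, g):
        # Combine maps out of each component of a pair by adding their results.
        return lambda pair: f(pair[0]) + g(pair[1])


class Reverse:
    """Represent the same linear map by its dual, a function T b -> T a."""

    @staticmethod
    def scale(s):
        return lambda u: s * u

    @staticmethod
    def compose(g, f):
        # Composition is reversed in the dual representation.
        return lambda u: f(g(u))

    @staticmethod
    def join(f, g):
        # Dually, combination (addition) becomes duplication.
        return lambda u: (f(u), g(u))


def d_mul(lin):
    def run(pair):
        x, y = pair
        return x * y, lin.join(lin.scale(y), lin.scale(x))
    return run


def d_sin(lin):
    def run(x):
        return math.sin(x), lin.scale(math.cos(x))
    return run


def d_compose(lin, g, f):
    def run(x):
        y, f_linear = f(x)
        z, g_linear = g(y)
        return z, lin.compose(g_linear, f_linear)
    return run


def sin_of_product(lin):
    """(x, y) -> sin(x * y), differentiated with the chosen representation."""
    return d_compose(lin, d_sin(lin), d_mul(lin))


_, pushforward = sin_of_product(Forward)((2.0, 3.0))
_, pullback = sin_of_product(Reverse)((2.0, 3.0))
```

Here `pushforward((1.0, 0.0))` gives one directional derivative, while `pullback(1.0)` returns both partial derivatives at once. The differentiation code (`d_mul`, `d_sin`, `d_compose`) is identical in both cases; only the representation class differs.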

… forward-mode has the additional benefit of being able to preserve the same order of memory complexity. But, this is conditional on being able to produce a program where the evaluation of the non-linear function is interleaved with the linear part, which is far from easy when linearization is considered fundamental (it would require whole program optimization and very aggressive code motion in many program representations).

The reason I’m aware of for combining the primal function (of type a -> b) and tangent function (of type a -> (T a -o T b)) into a single function (of type a -> b × (T a -o T b), which is isomorphic to (a -> b) × (a -> (T a -o T b))) is exactly to be able to share computation. I found this sharing to be very simple, without any optimization effort on my part (although I am piggybacking on a fairly good Haskell compiler). I note, however, that you’re specifically talking about memory complexity, and I wonder if you’ve noticed something important that I haven’t.

apaszke commented 3 years ago

Thanks for explaining! I now see how what you are proposing is slightly different than what we do. I'll try to paraphrase your point and describe how it compares to our approach, but of course please do point out any inaccuracies in my comparison.

If I understand your point correctly, you say the vocabulary that transforms and composes the linear maps in both forward- and reverse-mode is the same, and I agree that it can be made so (as you carefully outline in your paper). I would be tempted to say that there is a type-class your linear map representation has to implement in order to be compatible with the process of differentiation.

Going down that path has the benefit of using a single program transformation for both modes, but the downside of using two sets of rules (as you have to implement the type-class twice, once for each of the two linear map types). But, this largely misses out on the close correspondence between the rules used to perform forward- and reverse-mode. In our own jargon, we like to say that forward-mode rules implement JVPs (Jacobian-vector products), while reverse-mode rules implement VJPs (vector-Jacobian products). I like those names a lot, because they highlight their relation: each one is a transposed version of the other. Because of that, in Dex the only AD mode we really support (and have to implement rules for) is forward-mode, while reverse-mode is obtained not via reusing the same differentiation process with a different rule set, but through a program transposition transformation that is always valid to perform on the functions produced by forward-mode AD.

So yes, for the purpose of differentiation we do assume a particular representation of linear maps, which in our case is encoded in what we call a structurally linear program (this also has some interesting connections to linear logic as @dougalm wonderfully explained in one of our issues). But this doesn't prevent us from getting reverse-mode in the end, because we've simply found a path that doesn't require us to reengage the AD machinery, with the added benefit of having a significantly smaller rule set than necessary for both AD modes.
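To give a feel for what transposing a structurally linear program means, here is a toy Python sketch. It assumes a made-up instruction set with only two linear operations on a vector of floats (`scale` and `add`); Dex's actual intermediate representation and transposition rules are richer, and nothing below is taken from its implementation.

```python
def run(program, vector):
    """Execute a list of linear instructions on a copy of the input vector."""
    state = list(vector)
    for op in program:
        if op[0] == "scale":
            _, index, factor = op
            state[index] *= factor
        elif op[0] == "add":
            _, source, target = op
            state[target] += state[source]
        else:
            raise ValueError(f"unknown instruction {op!r}")
    return state


def transpose(program):
    """Transpose by reversing the instruction order and transposing each step."""
    transposed = []
    for op in reversed(program):
        if op[0] == "scale":
            # Scaling one component is its own transpose.
            transposed.append(op)
        elif op[0] == "add":
            # state[target] += state[source] transposes to state[source] += state[target].
            _, source, target = op
            transposed.append(("add", target, source))
        else:
            raise ValueError(f"unknown instruction {op!r}")
    return transposed


program = [("scale", 0, 3.0), ("add", 0, 1), ("scale", 1, 2.0)]
pushed = run(program, [1.0, 0.0])            # applies the linear map
pulled = run(transpose(program), [0.0, 1.0])  # applies its transpose
```

The defining property of the transpose, `dot(run(P, v), u) == dot(v, run(transpose(P), u))`, holds here for every `v` and `u`, and no differentiation rules are needed to obtain it.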


About your second question, this is mostly not about sharing (which is critical too, but not precluded by the signature a -> b × (T a -o T b)), but about the order in which the different operations are carried out. Consider the implementation of jvp via linearization:

linearize :: (a -> b) -> (a -> (b, T a -o T b))

jvp :: (a -> b) -> (a, T a) -> (b, T b)
jvp f (x, t) =
  let (y, f') = linearize f x
   in (y, f' t)

If the function jvp were to be executed in a language with eager evaluation, this would mean that the original mapping (a -> b) would have to be computed in its entirety before the evaluation of f' could even begin. Then, precisely because of the sharing that you've mentioned, many of the intermediate values computed in the body of f would have to be kept in memory until f' consumes them. Because the number of those intermediates cannot be bounded (it depends on f and can be arbitrarily high), we cannot derive any good bound on the memory consumption of a program differentiated in this way in such a language. But, we know that a memory bound can be proven (to be at most twice what f requires), if only we interleave the evaluation of the original computation with the linearization of each step. A sufficiently smart compiler might be able to re-optimize this code to achieve that, but it is not guaranteed.
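The memory argument can be seen on a small example in an eagerly evaluated language. The Python sketch below differentiates `sin` applied `n` times in two ways; the function names are hypothetical and the example is only meant to show where the stored intermediates come from.

```python
import math


def linearize_iterated_sin(x, n):
    """Linearize first: run the whole primal computation, saving one residual per step."""
    residuals = []
    for _ in range(n):
        residuals.append(math.cos(x))  # needed later by the linear function
        x = math.sin(x)

    def f_linear(t):
        for c in residuals:
            t = c * t
        return t

    # The residual count is returned only to make the memory use visible.
    return x, f_linear, len(residuals)


def jvp_interleaved(x, t, n):
    """Interleave primal and tangent steps: only the current (x, t) pair is live."""
    for _ in range(n):
        # The right-hand side is evaluated with the old x before either name is rebound.
        x, t = math.sin(x), math.cos(x) * t
    return x, t


primal_a, f_linear, stored = linearize_iterated_sin(0.5, 100)
tangent_a = f_linear(1.0)
primal_b, tangent_b = jvp_interleaved(0.5, 1.0, 100)
```

Both versions compute the same primal and tangent values, but the linearized one keeps `n` residuals alive until `f_linear` runs, while the interleaved one needs a constant amount of extra state.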