jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Deferred pushforward evaluation in jvp #526

Closed Ajoo closed 5 years ago

Ajoo commented 5 years ago

Hi. I'm sorry if this question is due to a basic misunderstanding of how the AD in JAX works. I tried to work through the code but found it pretty tough to follow.

Basically, for vjps we get back a function to evaluate the pullback of a cotangent vector without re-evaluating the original function (and, I imagine, without recomputing intermediate values). Is there something similar for jvps? I'd like to get back a function to evaluate the pushforward of a tangent vector at a specific point, ideally making use of the intermediate values computed when evaluating the function at that point. My goal is to evaluate the pushforward at the same point for multiple tangent vectors that are not known a priori as they are in jacfwd, so vmap isn't an option. Is there a better way to do this than simply calling jvp every time?

I see that there is a linearize function that looks like it gives me the pushforward function I want, but it's not documented, so I'm wondering if it's meant for internal use.

Thank you in advance

mattjj commented 5 years ago

Thanks for asking this!

You're right, that's exactly what linearize is for. It composes forward-mode autodiff with partial evaluation, so that all the linearization points are stored (which costs memory, but means you don't have to re-do FLOPs for future JVP evaluations).

Everything in api.py (where the name doesn't start with an underscore) is public and meant to be used. We just haven't gotten to documenting some things.

@sschoenholz anything to add?

Maybe we can use this issue to add a docstring to linearize.

mattjj commented 5 years ago

Here's some quick example usage:

from __future__ import print_function

import jax.numpy as np
from jax import jvp, linearize

def f(x):
  return 3. * np.sin(x) + np.cos(x / 2.)

x = 2.
t1 = 1.
t2 = 3.

print(jvp(f, (x,), (t1,)))
print(jvp(f, (x,), (t2,)))  # nothing saved from first evaluation

y, f_jvp = linearize(f, x)
print(y)
print(f_jvp(t1))  # all the linearization work is already done!
print(f_jvp(t2))

mattjj commented 5 years ago

Here's the real fun: composing partial evaluation and forward mode gives two of the three steps in how JAX does reverse-mode. The last step is to transpose the (necessarily linear) computation.

In other words, linearize is just a composition of jvp with JAX's partial evaluation machinery, specifically partial_eval.trace_to_jaxpr. Then we have a general mechanism for transposing linear jaxprs. Reverse-mode vjp emerges for free as a combination of all three!

One thing that lets us do is compute Gauss-Newton vector products, of the form (x, v) -> J(x)^T J(x) v, while doing the linearization work only once, since we can transpose without re-linearizing. That's one of the things we plan to explain a bit more in our forthcoming "Autodiff Cookbook Part 2".
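
For concreteness, here's a minimal sketch (not from this thread; f, x, and v are placeholder examples) of that Gauss-Newton product written with the public jvp and vjp. Note that this version redoes the linearization work on every call, which is exactly the cost the linearize-then-transpose approach described above avoids:

import jax.numpy as np
from jax import jvp, vjp

def f(x):
  return np.sin(x) * x  # example residual function

def gnvp(x, v):
  _, u = jvp(f, (x,), (v,))   # forward mode: u = J(x) v
  _, f_vjp = vjp(f, x)        # reverse mode gives u -> J(x)^T u
  return f_vjp(u)[0]          # J(x)^T J(x) v

x = np.arange(3.)
v = np.ones(3)
print(gnvp(x, v))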

mattjj commented 5 years ago

EDIT: I just noticed that the OP already says the tangent vectors are not known a priori as they are in jacfwd, so vmap isn't an option; clearly @Ajoo already understood this! I'm leaving this comment here for posterity.

In writing the docstring in #527 I realized an important point to make: if you know all the tangent vectors ahead of time, you should use jvp + vmap instead of linearize (and calling the result multiple times). The reasons are twofold:

  1. With linearize the memory cost scales with the depth of the computation, since it saves all linearization points (just like reverse-mode's vjp does), whereas for jvp (and hence also vmap'd jvp) the memory cost does not scale with the depth of the computation.
  2. Using vmap rather than an outer loop means you get the benefits of vectorization: matrix-vector multiplies turn into matrix-matrix multiplies, and matrix-matrix multiplies turn into more general tensor contractions, which means much faster code.

The place to use linearize is when you don't know all the tangent vectors ahead of time, for example because you want to apply the linearized function iteratively.

Does that make sense?

EDIT: Here's an example of vmap + jvp, which I included in the docstring for linearize to be submitted in #527:

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
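
To make that snippet runnable on its own, here's a self-contained version; it reuses f and x from the first example above, and the stacked in_tangents array is just an illustration, not part of the #527 docstring:

from functools import partial

import jax.numpy as np
from jax import jvp, vmap

def f(x):
  return 3. * np.sin(x) + np.cos(x / 2.)

x = 2.
in_tangents = np.array([1., 3.])  # all tangent vectors known up front

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
print(y)             # primal output, computed once
print(out_tangents)  # one pushforward per row of in_tangents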

Ajoo commented 5 years ago

Wow. Thank you for the quick and detailed reply! It was very helpful.

Funny you should mention Gauss-Newton products, as that was just what I was trying to do (to optimize a function with a conjugate gradient method). I originally thought about also asking whether it was possible to compute both the vjp and jvp without linearizing twice, but I thought it might be asking too much :p Could you maybe point me in the right direction?

I actually hadn't noticed that the Autodiff Cookbook was out already. Really looking forward to that second part!

Thank you for the great work

mattjj commented 5 years ago

Thanks for the positive words!

The linearize-once Gauss-Newton vector product (GNVP) is possible, but it might take some mucking about in ad.py, which isn't easy to jump into. That said, in the internal vjp function (which is different from the wrapper exposed in api.py) you can see that it calls (the internal version of) linearize and then calls backward_pass, which does the transposing. One way to do it would be to return both vjp_ and a Python callable for the linearized function. Code like lift_linearized might show how to do the latter step.

That sounded a bit arcane, but let's discuss further! It would be so cool to get this.
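
For reference, here's a sketch of a linearize-once Gauss-Newton product built only from the public API, using jax.linear_transpose (added to JAX after this exchange) instead of the ad.py internals above; f, x, and v are placeholder examples:

import jax.numpy as np
from jax import linearize, linear_transpose

def f(x):
  return np.sin(x) * x  # example residual function

x = np.arange(3.)
v = np.ones(3)

y, f_jvp = linearize(f, x)          # linearization work happens once
f_vjp = linear_transpose(f_jvp, v)  # transpose of the linear map v -> J(x) v

def gnvp(v):
  return f_vjp(f_jvp(v))[0]         # J(x)^T J(x) v, without re-linearizing

print(gnvp(v))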

mattjj commented 5 years ago

Let's follow up in #529.

Ajoo commented 5 years ago

Thank you for the pointers. I'll definitely look deeper into how to do this over the weekend!