google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Calculate `jvp` for intermediate variables in a function #12782

Open exenGT opened 1 year ago

exenGT commented 1 year ago

Hi JAX developers,

I'm wondering if it's possible to calculate jvp for intermediate variables defined in a function. Here's an example:

import jax.numpy as jnp

def f(x, params_1, params_2, params_3):

    W1, b1 = params_1
    z1 = jnp.tanh(jnp.dot(x, W1) + b1)

    W2, b2 = params_2
    z2 = jnp.tanh(jnp.dot(z1, W2) + b2)

    W3, b3 = params_3
    y = jnp.dot(z2, W3) + b3

    return y

The quantity I would like to calculate is the jvp (∂z2/∂z1)v. I suspect this may be available as some intermediate step in calculating the jvp of f(), but I'm not sure how to retrieve it.

JW

mattjj commented 1 year ago

Thanks for the question!

It's not exactly computed as an intermediate, but how about this?

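import jax
import jax.numpy as jnp

def f(x, zero_like_v, params_1, params_2, params_3):

    W1, b1 = params_1
    # zero_like_v is all zeros, so z1's value is unchanged, but its tangent
    # (v in the call below) is added to z1's tangent.
    z1 = jnp.tanh(jnp.dot(x, W1) + b1) + zero_like_v

    W2, b2 = params_2
    z2 = jnp.tanh(jnp.dot(z1, W2) + b2)

    W3, b3 = params_3
    y = jnp.dot(z2, W3) + b3

    # Also return z2 so that its tangent is part of the output.
    return y, z2

# jax.jvp returns (primal outputs, tangent outputs). dz2_dz1_v only equals
# (∂z2/∂z1)v if x_dot, params1_dot and params2_dot are zero; otherwise their
# contributions are added in.
(y, z2), (ydot, dz2_dz1_v) = jax.jvp(
    f,
    (x, jnp.zeros_like(v), params1, params2, params3),
    (x_dot, v, params1_dot, params2_dot, params3_dot))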

If I understood your notation correctly, I think dz2_dz1_v is the quantity you're after, assuming you meant v to be an arbitrary vector with commensurate shape (i.e. shaped like z1).

WDYT? If this isn't right, maybe you could unpack your notation a bit more?
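As a sanity check of the zero-input trick, here is a minimal sketch that compares it against a direct `jax.jvp` of the second layer evaluated at the forward-pass value of z1. The shapes, the random parameters and the `zeros` helper are illustrative assumptions; every tangent except v is set to zero so that only v is pushed forward.

```python
import jax
import jax.numpy as jnp

def f(x, zero_like_v, params_1, params_2, params_3):
    W1, b1 = params_1
    z1 = jnp.tanh(jnp.dot(x, W1) + b1) + zero_like_v  # zero input gives z1 a tangent slot
    W2, b2 = params_2
    z2 = jnp.tanh(jnp.dot(z1, W2) + b2)
    W3, b3 = params_3
    y = jnp.dot(z2, W3) + b3
    return y, z2

# Illustrative sizes and random parameters (assumptions, not from the thread).
k = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(k[0], (3,))
params_1 = (jax.random.normal(k[1], (3, 4)), jnp.zeros(4))
params_2 = (jax.random.normal(k[2], (4, 5)), jnp.zeros(5))
params_3 = (jax.random.normal(k[3], (5, 2)), jnp.zeros(2))
v = jax.random.normal(k[4], (4,))  # shaped like z1

# Hypothetical helper: a zero tangent with the same pytree structure as `tree`.
zeros = lambda tree: jax.tree_util.tree_map(jnp.zeros_like, tree)

# Push forward only v: every other tangent is zero.
(y, z2), (y_dot, dz2_dz1_v) = jax.jvp(
    f,
    (x, jnp.zeros_like(v), params_1, params_2, params_3),
    (zeros(x), v, zeros(params_1), zeros(params_2), zeros(params_3)))

# Direct check: JVP of the second layer alone, at the forward-pass z1.
W1, b1 = params_1
z1 = jnp.tanh(jnp.dot(x, W1) + b1)
layer2 = lambda z: jnp.tanh(jnp.dot(z, params_2[0]) + params_2[1])
_, expected = jax.jvp(layer2, (z1,), (v,))
```

The two results should agree up to floating-point error; the direct version is simpler when the layer is easy to isolate, while the zero-input version works without restructuring f.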

exenGT commented 1 year ago

Thank you so much @mattjj! I'll admit I don't fully understand your code yet, so I'll give myself a bit of time to digest it. You are right that in my notation, v is an arbitrary vector with the same shape as z1. In fact, I'm trying to implement an algorithm that requires explicitly calculating the jvp for every adjacent pair of layers z_out and z_in in a neural network (different from the usual backpropagation, which only requires the jvp of f with respect to the arguments). I'm wondering if there is a general (i.e. less ad hoc) way to do this. The network model I have in mind is the simple MLP under the section "Example: ML model parameters" in this link. Hopefully this is a bit clearer than my previous post!

JW

patrick-kidger commented 1 year ago

Note that the usual backpropagation is a vjp, not a jvp. Mathematically you get the derivative d(loss)/d(parameters) either way, but for typical neural network problems the vjp is much more efficient.
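To make the efficiency point concrete, here is a minimal sketch with a toy scalar loss (an assumption for illustration): one vjp with cotangent 1.0 yields the whole gradient, whereas each jvp yields only one directional derivative, so recovering the full gradient in forward mode takes n jvps, one per basis vector.

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Toy scalar loss of an n-dimensional parameter vector (illustrative).
    return jnp.sum(jnp.tanh(w) ** 2)

n = 6
w = jnp.linspace(-1.0, 1.0, n)

# Reverse mode: one VJP with cotangent 1.0 gives the full gradient.
value, loss_vjp = jax.vjp(loss, w)
(grad_rev,) = loss_vjp(jnp.ones_like(value))

# Forward mode: one JVP per basis vector (batched here with vmap), n in total.
basis = jnp.eye(n)
grad_fwd = jax.vmap(lambda e: jax.jvp(loss, (w,), (e,))[1])(basis)
```

With millions of parameters and a single scalar loss, that factor of n is why reverse mode is the default for training.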

exenGT commented 1 year ago

Thank you @patrick-kidger ! Yeah, I realized that if you want to propagate backwards, then vjp is preferred over jvp. Either way, I'm trying to find a way to calculate the quantity for any two adjacent layers in the neural network.

mattjj commented 1 year ago

Well, whether you want to go forward or backward may just come down to whether for your algorithm you need to evaluate things of the form v -> v'J or v -> J v, where J represents the Jacobian between some z_in and z_out. The former is naturally adapted to reverse-mode (i.e. it's by definition a VJP) and the latter to forward-mode (a JVP). We only have the intuition that VJPs are more efficient for neural network problems because we most often use autodiff to optimize a scalar-valued loss function, so we're naturally pulling back a covector from T*R to T*R^n to get the locally linearized level sets of the loss function in parameter space.
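A minimal sketch of the two maps for a single hypothetical layer (the weights are made up): `jax.jvp` evaluates v -> J v with v in the input space, and `jax.vjp` evaluates u -> u'J with u in the output space. The explicit Jacobian is built only to check both results.

```python
import jax
import jax.numpy as jnp

def layer(z_in):
    # Hypothetical layer mapping R^4 -> R^3 (illustrative weights).
    W = jnp.arange(12.0).reshape(4, 3) / 10.0
    return jnp.tanh(z_in @ W)

z_in = jnp.array([0.1, -0.2, 0.3, 0.4])
J = jax.jacfwd(layer)(z_in)  # explicit 3x4 Jacobian, only for checking

# v -> J v (forward mode): v has the shape of z_in.
v = jnp.array([1.0, 0.0, -1.0, 2.0])
_, Jv = jax.jvp(layer, (z_in,), (v,))

# u -> u'J (reverse mode): u has the shape of z_out.
u = jnp.array([0.5, -1.0, 2.0])
_, layer_vjp = jax.vjp(layer, z_in)
(uJ,) = layer_vjp(u)
```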

In other words, the algorithm you want to implement probably determines whether you want JVPs or VJPs. Or both! For example, if you're doing anything Gauss-Newton like, you probably want both.
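For instance, Gauss-Newton-type methods need products with J'J, where J is the Jacobian of a residual function. A minimal sketch, assuming a made-up residual function, of how one JVP followed by one VJP gives that product without ever forming J:

```python
import jax
import jax.numpy as jnp

def residuals(w):
    # Hypothetical residual function R^3 -> R^4 for a small least-squares fit.
    t = jnp.array([0.0, 1.0, 2.0, 3.0])
    target = jnp.array([1.0, 2.0, 2.5, 4.0])
    return w[0] + w[1] * t + w[2] * jnp.sin(t) - target

def gauss_newton_vp(f, w, v):
    # Forward mode pushes v to J v; reverse mode pulls it back to J'(J v).
    _, Jv = jax.jvp(f, (w,), (v,))
    _, f_vjp = jax.vjp(f, w)
    (JtJv,) = f_vjp(Jv)
    return JtJv

w = jnp.array([0.5, 1.0, -0.5])
v = jnp.array([1.0, 2.0, 3.0])
gv = gauss_newton_vp(residuals, w, v)
```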

If you could say more about the algorithm, we might be able to figure out more together.

patrick-kidger commented 1 year ago

Gauss-Newton like

Tangential nit: the least-squares subproblems in GN are usually better solved via e.g. a QR solver than via the normal equations. :)
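For a small dense problem, the difference can be sketched as follows (the data are a toy assumption). Both routes solve min ||J d - r||, but forming J'J squares the condition number of J, which is why the QR route is usually preferred when J is ill-conditioned.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Illustrative overdetermined system: J is 5x2, r has 5 entries.
J = jnp.array([[1.0, 0.0],
               [1.0, 1.0],
               [1.0, 2.0],
               [1.0, 3.0],
               [1.0, 4.0]])
r = jnp.array([1.0, 2.0, 2.9, 4.2, 5.1])

# QR route: J = Q R (reduced), so the problem reduces to R d = Q' r.
Q, R = jnp.linalg.qr(J)
d_qr = solve_triangular(R, Q.T @ r, lower=False)

# Normal equations: (J'J) d = J' r.
d_normal = jnp.linalg.solve(J.T @ J, J.T @ r)
```

On this well-conditioned example both agree; the gap shows up as J approaches rank deficiency.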

mattjj commented 1 year ago

Sure, if the problem is small enough for direct methods to be possible, but usually for neural net optimizations we have to rely on implicit access to these Jacobians, e.g. for a Krylov solve or other iterative scheme, or to build tractable approximations which then are amenable to direct methods, or to choose step sizes in a 1D subspace. As a concrete example of the latter, consider the technique described in Sec 6.4 of the K-FAC paper.

My point is just: I'm not making any claims about how to solve any linear systems, but rather observing that both JVPs and VJPs may be useful, for example in Gauss-Newton ish schemes like K-FAC.
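As an illustration of "implicit access" for a Krylov solve, here is a minimal sketch of a damped Gauss-Newton step computed with conjugate gradients, where J is only touched through one JVP and one VJP per matrix-vector product. This is plain matrix-free CG, not K-FAC; the residual function and the damping value are assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def residuals(w):
    # Hypothetical residual function R^3 -> R^6 (toy quadratic fit to exp).
    t = jnp.linspace(-1.0, 1.0, 6)
    return w[0] + w[1] * t + w[2] * t ** 2 - jnp.exp(t)

w = jnp.zeros(3)
damping = 1e-3  # assumed Levenberg-Marquardt-style damping

def gn_matvec(v):
    # (J'J + damping * I) v via one JVP and one VJP; J is never formed.
    _, Jv = jax.jvp(residuals, (w,), (v,))
    _, vjp_fn = jax.vjp(residuals, w)
    return vjp_fn(Jv)[0] + damping * v

# Right-hand side -J' r, also via a VJP.
r, vjp_fn = jax.vjp(residuals, w)
rhs = -vjp_fn(r)[0]

step, _ = cg(gn_matvec, rhs, maxiter=50)
```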

exenGT commented 1 year ago

Thanks for your detailed explanation @mattjj ! I think it would be great if we could discuss the algorithm in more detail. What I have in mind is an algorithm to optimize the output, which involves calculating the quantity in question. As this would lead to a more problem-specific discussion, is it okay if we communicate via email or private messaging? Really appreciate your help!

jbuckman commented 6 months ago

What solution did you end up settling on?

I'm interested in doing this in order to gain more visibility into my model. I find it useful to occasionally measure gradient magnitudes, activations, etc., during my training process, in order to confirm that everything is healthy. It's easy enough to measure dl/dw for any given weight (where l is the loss, computed from y). But I also want to track things like dz1/dw, dl/dz1, dz2/dz1. This is quite difficult currently, especially since these terms might be scattered across multiple functions at different levels of abstraction. Is there a good solution?

(Since this monitoring is only occasional, and not part of the core update loop, it doesn't really matter to me how fast it is; I just want to be able to see the answer, regardless of whether it is a jvp or vjp or something else under the hood.)
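One reverse-mode pattern for the dl/dz terms, mirroring mattjj's zero-input trick above: add zero-valued inputs to each activation and differentiate the loss with respect to them, while returning the activations as auxiliary output. A minimal sketch; the MLP, the shapes and the names `model` and `perturbs` are hypothetical.

```python
import jax
import jax.numpy as jnp

def model(params, perturbs, x):
    # Adding zeros leaves the values unchanged, but the gradient with respect
    # to perturbs[i] equals dl/dz_i.
    (W1, b1), (W2, b2), (W3, b3) = params
    z1 = jnp.tanh(x @ W1 + b1) + perturbs[0]
    z2 = jnp.tanh(z1 @ W2 + b2) + perturbs[1]
    y = z2 @ W3 + b3
    return y, (z1, z2)

def loss(params, perturbs, x, target):
    y, acts = model(params, perturbs, x)
    return jnp.mean((y - target) ** 2), acts  # activations returned as aux

k = jax.random.split(jax.random.PRNGKey(0), 4)
params = [(jax.random.normal(k[0], (3, 4)), jnp.zeros(4)),
          (jax.random.normal(k[1], (4, 5)), jnp.zeros(5)),
          (jax.random.normal(k[2], (5, 2)), jnp.zeros(2))]
x = jax.random.normal(k[3], (3,))
target = jnp.ones(2)
perturbs = [jnp.zeros(4), jnp.zeros(5)]

# One reverse pass gives dl/dw and dl/dz_i together; has_aux returns activations.
(l, (z1, z2)), (dl_dparams, dl_dz) = jax.value_and_grad(
    loss, argnums=(0, 1), has_aux=True)(params, perturbs, x, target)
dl_dz1, dl_dz2 = dl_dz

# Simple health metrics: norms of each parameter gradient.
grad_norms = [float(jnp.linalg.norm(g)) for g in jax.tree_util.tree_leaves(dl_dparams)]
```

The drawback is that the perturbation arguments have to be threaded through every function that produces an activation you want to inspect.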

exenGT commented 6 months ago

Hi @jbuckman,

Thanks for your question! The most straightforward way I can think of is to write the "big" function (L(w)) as a composition of predefined small functions of the intermediate variables (z1(w), z2(z1), L(z2)); that way, the big function can also return the Jacobians (or jvps) of each small function. Would that work in your situation?
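A minimal sketch of this suggestion, with hypothetical small functions: the big function runs the forward pass and also records the Jacobian of each small function at the point where it is actually evaluated. The chain rule then reassembles dL/dw from the pieces.

```python
import jax
import jax.numpy as jnp

# Hypothetical "small" functions; each maps one quantity to the next.
def z1_fn(w, x):
    return jnp.tanh(x @ w)

def z2_fn(z1):
    return jnp.sin(z1)

def loss_fn(z2):
    return jnp.sum(z2 ** 2)

def big_fn(w, x):
    z1 = z1_fn(w, x)
    z2 = z2_fn(z1)
    loss = loss_fn(z2)
    jacs = {
        "dz1_dw": jax.jacfwd(z1_fn)(w, x),  # shape z1.shape + w.shape
        "dz2_dz1": jax.jacfwd(z2_fn)(z1),
        "dL_dz2": jax.grad(loss_fn)(z2),
    }
    return loss, jacs

w = jnp.full((3, 2), 0.1)
x = jnp.array([1.0, 2.0, 3.0])
loss, jacs = big_fn(w, x)

# Chain rule: dL/dw = dL/dz2 . dz2/dz1 . dz1/dw
dL_dz1 = jacs["dL_dz2"] @ jacs["dz2_dz1"]
dL_dw = jnp.tensordot(dL_dz1, jacs["dz1_dw"], axes=1)
```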

jbuckman commented 6 months ago

I don't believe so. Let's say I want to compute the Jacobian of the output with respect to an intermediate variable. Here is my current understanding of how I would go about it:


import jax.numpy as jnp
from jax import jacfwd

# These are my small functions (illustrative stand-ins so the snippet runs)
def f1(x): return jnp.sin(x)
def f2(x): return x * jnp.cos(x)
def f3(x): return jnp.cumsum(x ** 2)

# This is my big function
def g(x0):
  x1 = f1(x0)
  x2 = f2(x1)
  x3 = f3(x2)
  return x3

# This is my concrete input
concrete_x0 = jnp.array([1., 2., 4.])

# I can straightforwardly compute dx3/dx0
dx3_dx0 = jacfwd(g)(concrete_x0)

# I can compute dx{i+1}/dxi for any i, but it is awkward, because
# in order to construct the input I need to partially duplicate the logic
# from g.
dx1_dx0 = jacfwd(f1)(concrete_x0)
dx2_dx1 = jacfwd(f2)(f1(concrete_x0))
dx3_dx2 = jacfwd(f3)(f2(f1(concrete_x0)))

# I can compute dx3/dxi for any i, but it is even more awkward, because
# now I need to duplicate the logic from g even further in order to construct
# the portion of the function that "remains".

dx3_dx0 = jacfwd(g)(concrete_x0)
dx3_dx1 = jacfwd(lambda x1: f3(f2(x1)))(f1(concrete_x0))
dx3_dx2 = jacfwd(f3)(f2(f1(concrete_x0)))

You can see how, in order to compute e.g. dx3_dx1, I essentially need to re-implement g. This is true for each intermediate-variable Jacobian I want to look at. This example has just 4 sets of activations, x0 to x3, but a real NN might have hundreds.

In practice, this pattern quickly becomes unmaintainable. If g is complex, implementing each Jacobian like this is a ton of work, and every time I change g I need to go around and manually update every single Jacobian computation to match. It is a big opportunity for bugs: it is easy to accidentally introduce a small difference between the true g and the part-of-g-after-x1 that I had to construct in order to apply jacfwd.

A few other things that would further add complexity:

But ultimately, I feel that it should be fairly easy to compute these partials dx3/dx0, dx3/dx1, dx3/dx2 that I am after: after all, a function like jacrev is already computing them the moment I invoke jacrev(g), as intermediate steps.

Is there a cleaner way to do this?
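One possible approach, sketched under assumptions (the stand-in f1/f2/f3 and the extra arguments `p1`/`p2` are hypothetical): give g zero-valued inputs that are added to each intermediate, then ask `jax.jacrev` for the Jacobian with respect to all of them at once. The cotangents that reverse mode already computes for x1 and x2 are then returned instead of discarded, and g remains the single source of truth for the forward logic.

```python
import jax
import jax.numpy as jnp

# Stand-in small functions (assumptions, for illustration only).
def f1(x): return jnp.sin(x)
def f2(x): return x * jnp.cos(x)
def f3(x): return jnp.cumsum(x ** 2)

def g(x0, p1=0.0, p2=0.0):
    # p1 and p2 default to zero, so g(x0) behaves as before; passing zero
    # arrays shaped like x1 and x2 lets jacrev differentiate with respect to them.
    x1 = f1(x0) + p1
    x2 = f2(x1) + p2
    return f3(x2)

x0 = jnp.array([1.0, 2.0, 4.0])
p = (jnp.zeros(3), jnp.zeros(3))  # shaped like x1 and x2

# Jacobians of x3 with respect to x0, x1 and x2 from one jacrev call.
dx3_dx0, dx3_dx1, dx3_dx2 = jax.jacrev(g, argnums=(0, 1, 2))(x0, *p)
```

The same caveat as with any zero-input trick applies: if the intermediates live inside nested functions, the extra arguments have to be plumbed down to them.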