Open exenGT opened 1 year ago
Thanks for the question!
It's not exactly computed as an intermediate, but how about this?
```python
def f(x, zero_like_v, params_1, params_2, params_3):
    W1, b1 = params_1
    z1 = jnp.tanh(jnp.dot(x, W1) + b1) + zero_like_v
    W2, b2 = params_2
    z2 = jnp.tanh(jnp.dot(z1, W2) + b2)
    W3, b3 = params_3
    y = jnp.dot(z2, W3) + b3
    return y, z2

(y, z2), (y_dot, dz2_dz1_v) = jax.jvp(
    f,
    (x, jnp.zeros_like(v), params1, params2, params3),
    (x_dot, v, params1_dot, params2_dot, params3_dot))
```
If I understood your notation correctly, I think `dz2_dz1_v` is the quantity you're after, assuming you meant `v` to be an arbitrary vector with commensurate shape (i.e. shaped like `z1`).

WDYT? If this isn't right, maybe you could unpack your notation a bit more?
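For concreteness, here's a self-contained version of the same idea with made-up layer sizes (4 → 5 → 6 → 2; the parameter shapes and random initialization are just assumptions for illustration). Zeroing every tangent except the `z1`-shaped slot isolates the term `(∂z2/∂z1) v`:

```python
import jax
import jax.numpy as jnp

def f(x, zero_like_v, params_1, params_2, params_3):
    W1, b1 = params_1
    z1 = jnp.tanh(jnp.dot(x, W1) + b1) + zero_like_v
    W2, b2 = params_2
    z2 = jnp.tanh(jnp.dot(z1, W2) + b2)
    W3, b3 = params_3
    y = jnp.dot(z2, W3) + b3
    return y, z2

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
# Made-up sizes: 4 -> 5 -> 6 -> 2
params1 = (jax.random.normal(k1, (4, 5)), jnp.zeros(5))
params2 = (jax.random.normal(k2, (5, 6)), jnp.zeros(6))
params3 = (jax.random.normal(k3, (6, 2)), jnp.zeros(2))

x = jnp.ones(4)
v = jnp.ones(5)  # arbitrary vector shaped like z1

# Zero tangents everywhere except the z1-shaped slot, so the tangent of
# z2 is exactly (dz2/dz1) v.
zeros = lambda p: jax.tree_util.tree_map(jnp.zeros_like, p)
(y, z2), (ydot, dz2_dz1_v) = jax.jvp(
    f,
    (x, jnp.zeros_like(v), params1, params2, params3),
    (jnp.zeros_like(x), v, zeros(params1), zeros(params2), zeros(params3)))
print(dz2_dz1_v.shape)  # (6,)
```

Note that if the other tangents (`x_dot`, `params1_dot`, ...) were nonzero, their contributions to `z1`'s tangent would also flow into `dz2_dz1_v`.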
Thank you so much @mattjj ! I have to say that I'm not fully understanding your code yet, so I'll give myself a bit of time to digest it. You are right that in my notation, `v` is an arbitrary vector with the same shape as `z1`. In fact, I'm trying to implement an algorithm which requires explicitly calculating the `jvp` for every adjacent pair of layers `z_out` and `z_in` in a neural network (different from the usual backpropagation, which only requires the `jvp` of `f` with respect to the arguments). I'm wondering if there is a more general (i.e. less "ad hoc") way to do this. The network model I have in mind is the simple MLP under the section "Example: ML model parameters" in this link. Hopefully this is a bit clearer than my previous post!
JW
Note that usual backpropagation is a `vjp`, not a `jvp`. Mathematically you get the derivative `d(loss)/d(parameters)` either way, but for typical neural network problems the `vjp` is much more efficient.
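To illustrate the efficiency point: for a scalar-valued loss, a single reverse pass recovers the whole gradient, whereas forward mode needs one JVP per input dimension. A minimal sketch with a toy loss (the function and sizes are made up):

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Toy scalar-valued loss; w plays the role of all the parameters.
    return jnp.sum(jnp.tanh(w) ** 2)

w = jnp.arange(5.0)

# Reverse mode: ONE VJP with the scalar cotangent 1.0 yields d(loss)/dw.
_, vjp_fn = jax.vjp(loss, w)
grad_rev, = vjp_fn(1.0)

# Forward mode: one JVP per basis vector, i.e. n passes for n parameters.
grad_fwd = jnp.stack([jax.jvp(loss, (w,), (e,))[1]
                      for e in jnp.eye(5)])
print(jnp.allclose(grad_rev, grad_fwd))  # True
```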
Thank you @patrick-kidger ! Yeah, I realized that if you want to propagate backwards, then `vjp` is preferred over `jvp`. Either way, I'm trying to find a way to calculate the quantity for any two adjacent layers in the neural network.
Well, whether you want to go forward or backward may just come down to whether for your algorithm you need to evaluate things of the form `v -> v'J` or `v -> J v`, where `J` represents the Jacobian between some `z_in` and `z_out`. The former is naturally adapted to reverse-mode (i.e. it's by definition a VJP) and the latter to forward-mode (a JVP). We only have the intuition that VJPs are more efficient for neural network problems because we most often use autodiff to optimize a scalar-valued loss function, so we're naturally pulling back a covector from `T*R` to `T*R^n` to get the locally linearized level sets of the loss function in parameter space.
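To make the two directions concrete, here is a sketch with a toy layer-to-layer map (the shapes are made up for illustration), showing that `jax.jvp` evaluates `v -> J v` and `jax.vjp` evaluates `u -> u'J`, neither of which materializes `J`:

```python
import jax
import jax.numpy as jnp

# A toy map z_in -> z_out between adjacent layers (made-up shapes: 3 -> 4).
W = jnp.arange(12.0).reshape(3, 4)
layer = lambda z: jnp.tanh(z @ W)

z_in = jnp.ones(3)
J = jax.jacfwd(layer)(z_in)  # the full 4x3 Jacobian, only for checking

v = jnp.ones(3)  # shaped like z_in
u = jnp.ones(4)  # shaped like z_out

# Forward mode computes v -> J v:
_, Jv = jax.jvp(layer, (z_in,), (v,))

# Reverse mode computes u -> u'J:
_, vjp_fn = jax.vjp(layer, z_in)
uTJ, = vjp_fn(u)

print(jnp.allclose(Jv, J @ v), jnp.allclose(uTJ, u @ J))
```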
In other words, the algorithm you want to implement probably determines whether you want JVPs or VJPs. Or both! For example, if you're doing anything Gauss-Newton like, you probably want both.
If you could say more about the algorithm, we might be able to figure out more together.
> Gauss-Newton like
Tangential nit that the least-squares subproblems in GN are usually better solved via e.g. a QR solver, than via the normal equations. :)
Sure if the problem is small enough for direct methods to be possible, but usually for neural net optimizations we have to rely on implicit access to these Jacobians, e.g. for a Krylov solve or other iterative scheme, or to build tractable approximations which then are amenable to direct methods, or to choose step sizes in a 1D subspace. As a concrete example of the latter, consider the technique described in Sec 6.4 of the K-FAC paper.
My point is just: I'm not making any claims about how to solve any linear systems, but rather observing that both JVPs and VJPs may be useful, for example in Gauss-Newton ish schemes like K-FAC.
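As a sketch of why both modes can show up together: a Gauss-Newton matrix-vector product `J'J v` composes one JVP with one VJP, never forming `J` explicitly. This is a toy illustration of that composition (the residual function and shapes are made up), not K-FAC itself:

```python
import jax
import jax.numpy as jnp

# Toy residual function r(w); in Gauss-Newton the curvature proxy is J'J,
# where J = dr/dw.  Shapes here are made up for illustration.
def residuals(w):
    return jnp.tanh(w[:3] * w[3:]) - 0.5

w = jnp.linspace(0.1, 0.6, 6)
v = jnp.ones(6)

def gn_matvec(w, v):
    # J v via forward mode, then J'(J v) via reverse mode.
    _, Jv = jax.jvp(residuals, (w,), (v,))
    _, vjp_fn = jax.vjp(residuals, w)
    JTJv, = vjp_fn(Jv)
    return JTJv

# Check against the explicitly materialized Jacobian.
J = jax.jacfwd(residuals)(w)
print(jnp.allclose(gn_matvec(w, v), J.T @ J @ v))
```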
Thanks for your detailed explanation @mattjj ! I think it would be great if we could discuss the algorithm in more detail. What I have in mind is an algorithm to optimize the output, which involves calculating the quantity in question. As this would lead to a more problem-specific discussion, is it okay if we communicate via email or private messaging? Really appreciate your help!
What solution did you end up settling on?
I'm interested in doing this in order to gain more visibility into my model. I find it useful to occasionally measure gradient magnitudes, activations, etc., during my training process, in order to confirm that everything is healthy. It's easy enough to measure `dl/dw` for any given weight (where `l` is the loss, computed from `y`). But I also want to track things like `dz1/dw`, `dl/dz1`, `dz2/dz1`. This is quite difficult currently, especially since these terms might be scattered across multiple functions at different levels of abstraction. Is there a good solution?
(Since this monitoring is only occasional, and not part of the core update loop, it doesn't really matter to me how fast it is; I just want to be able to see the answer, regardless of whether it is a jvp or vjp or something else under the hood.)
Hi @jbuckman,
Thanks for your question! The most straightforward way I can think of is to write the "big" function (`L(w)`) as a collection of predefined small functions relying on the intermediate variables (`z1(w)`, `z2(z1)`, `L(z2)`); this way, you can let the big function return the Jacobians (or `jvp`s) of each "small" function. Would that work in your situation?
I do not believe so. Let's say I want to compute the Jacobian of the output with respect to an intermediate variable. Here is my current understanding of how I would go about it:
```python
import jax.numpy as jnp
from jax import jacfwd

# These are my small functions
def f1(x0): ...
def f2(x1): ...
def f3(x2): ...

# This is my big function
def g(x0):
    x1 = f1(x0)
    x2 = f2(x1)
    x3 = f3(x2)
    return x3

# This is my concrete input
concrete_x0 = jnp.array([1., 2., 4.])

# I can straightforwardly compute dx3/dx0
dx3_dx0 = jacfwd(g)(concrete_x0)

# I can compute dx{i+1}/dxi for any i, but it is awkward, because
# in order to construct the input I need to partially duplicate the logic
# from g.
dx1_dx0 = jacfwd(f1)(concrete_x0)
dx2_dx1 = jacfwd(f2)(f1(concrete_x0))
dx3_dx2 = jacfwd(f3)(f2(f1(concrete_x0)))

# I can compute dx3/dxi for any i, but it is even more awkward, because
# now I need to duplicate the logic from g even further in order to construct
# the portion of the function that "remains".
dx3_dx0 = jacfwd(g)(concrete_x0)
dx3_dx1 = jacfwd(lambda x1: f3(f2(x1)))(f1(concrete_x0))
dx3_dx2 = jacfwd(f3)(f2(f1(concrete_x0)))
```
You can see how, in order to compute e.g. `dx3_dx1`, I essentially need to entirely re-implement `g`. This is true for each intermediate-variable Jacobian I want to look at. This example has just 4 sets of activations `x0` to `x3`, but a real NN might have hundreds.

In practice, this pattern quickly becomes unmaintainable. If `g` is complex, implementing each Jacobian like this is a ton of work, and every time I want to make a change to `g` I need to go around and manually update every single Jacobian computation to match. It is a big opportunity for bugs: it's easy to accidentally have a small difference between the true `g` and the part-of-`g`-after-`x1` that I needed to construct to apply `jacrev`.
A few other things would further add complexity: here there are just three top-level functions `f1, f2, f3`, but in practice, each of these would itself be a composition of functions (e.g. a Transformer is the composition of many TransformerBlock layers, a TransformerBlock is a composition of an Attention function and an MLP function, an MLP function is the composition of many Dense layers...). If `f1` itself has intermediate activations, say `z1`, I might also want to take `dx3/dz1`, which would require taking `jacrev` of a function constructed from the composition of the "second half" of `f1` together with `f2` and `f3`.

But ultimately, I feel that it should be fairly easy to compute these partials `dx3/dx0`, `dx3/dx1`, `dx3/dx2` that I am after: after all, a function like `jacrev` is already computing them, as intermediate steps, the moment I invoke `jacrev(g)`.

Is there a cleaner way to do this?
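One partial improvement I can sketch (with toy stand-ins for `f1`-`f3`; this is just a pattern, not an official JAX feature) is to keep the layers in a list, so the pipeline, its intermediates, and every Jacobian are all derived from a single definition rather than hand-written compositions:

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the small functions above.
f1 = lambda x: jnp.tanh(x)
f2 = lambda x: x ** 2
f3 = lambda x: jnp.sum(x) * x

layers = [f1, f2, f3]
concrete_x0 = jnp.array([1., 2., 4.])

# Run the pipeline once, recording every intermediate x0, x1, x2, x3.
xs = [concrete_x0]
for f in layers:
    xs.append(f(xs[-1]))

# dx{i+1}/dxi for every adjacent pair, with no duplicated logic:
adjacent = [jax.jacfwd(f)(x) for f, x in zip(layers, xs)]

# dx3/dxi for every i, by composing the "suffix" of the pipeline:
def suffix(i):
    return lambda x: functools.reduce(lambda acc, f: f(acc), layers[i:], x)

to_output = [jax.jacfwd(suffix(i))(xs[i]) for i in range(len(layers))]
print([J.shape for J in adjacent], [J.shape for J in to_output])
```

This still re-traces each suffix, but at least a change to the architecture only has to be made in one place (the `layers` list). It does not by itself solve the nested-composition problem (intermediates inside `f1`), though the same trick could be applied recursively.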
Hi JAX developers,

I'm wondering if it's possible to calculate the `jvp` for intermediate variables defined in a function. Here's an example:

The quantity I would like to calculate is the jvp `(∂z2/∂z1)v`. I guess that this may be available as some intermediate step in calculating the `jvp` of `f()`, but am not sure how to retrieve it.

JW