jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

grad of grad fails when ravelling is involved #9412

Closed: gianlucadetommaso closed this issue 2 years ago

gianlucadetommaso commented 2 years ago

Consider the following toy function:

def fun(d):
    a, b = d['a'], d['b']
    return a + 2 * b

Suppose that for compatibility reasons we are interested in working with ravelled arguments. We then set some values for d and define an equivalent function that takes ravelled arguments as input:

d = dict(a=1., b=2.)
rav_d, unravel = ravel_pytree(d)
rav_fun = lambda r: fun(unravel(r))

Let's check function value and gradient:

rav_fun(rav_d), grad(rav_fun)(rav_d)
>>> (DeviceArray(5., dtype=float32), DeviceArray([1., 2.], dtype=float32))

All good so far. If we now want to compute an element-wise gradient of the gradient function (i.e. the diagonal of the Hessian), we could generally do the following:

vmap(grad(grad(rav_fun)))(rav_d)

If there were no ravelling involved, this would work just fine. In this case, however, we get the following error:

>>> IndexError: tuple index out of range

Why is this the case? For functions without ravelling it does appear to work.

To reproduce the previous code, import the following:

from jax import vmap, grad
from jax.flatten_util import ravel_pytree
jakevdp commented 2 years ago

This is expected: the issue is that grad only works for scalar-valued functions, but grad(rav_fun) is a vector-valued function.

If you want autodiff of a vector-valued function f, you could use vmap(grad(f)) or jacobian(f), depending on what flavor of vector gradient you're interested in.
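
For illustration, here is a minimal sketch of the two flavors on a simple elementwise function (a toy example, not the one from this issue):

import jax.numpy as jnp
from jax import grad, vmap, jacobian

f = jnp.sin                 # elementwise, so its Jacobian is diagonal
x = jnp.arange(3.0)

print(jacobian(f)(x))       # full 3x3 Jacobian matrix
print(vmap(grad(f))(x))     # only the per-element derivatives, i.e. cos(x)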

gianlucadetommaso commented 2 years ago

Thanks for the reply. I am trying to obtain second-order derivatives, specifically the diagonal of the Hessian of some real-valued function, without computing or storing the full Hessian matrix.

Any simple way to do this?

jakevdp commented 2 years ago

Funny, I just answered a very similar question on StackOverflow this morning: https://stackoverflow.com/q/70956578

Do you mean the diagonal of the Hessian? The Jacobian is a first-order gradient.

jakevdp commented 2 years ago

You can compute the diagonal of the Hessian using the first flavor of vector gradient I mentioned: vmap of grad. For example:

import jax.numpy as jnp
from jax import vmap, hessian, grad

def f(x):
  return jnp.dot(x, jnp.sin(x))

x = jnp.arange(3.0)

print(jnp.diagonal(hessian(f)(x)))
# [ 2.         0.2391336 -2.6508884]
print(vmap(grad(grad(f)))(x))
# [ 2.         0.2391336 -2.6508884]
gianlucadetommaso commented 2 years ago

vmap(grad(grad(f))) is exactly what I was doing in my original example, right? :)

This indeed works for normal functions, but when I have unravel in between, something goes wrong. Let me tweak your example a little to show this.

def f(c):
    a, b = c.values()
    return jnp.dot(a, jnp.sin(b))

c = dict(a=1., b=2.)
rav_c, unravel = ravel_pytree(c)
rav_f = lambda r: f(unravel(r))
vmap(grad(grad(rav_f)))(rav_c)

This gives the following error:

IndexError: tuple index out of range

I don't see why it does not work here.

jakevdp commented 2 years ago

Ah, thanks for the clarification.

I think it does not work here exactly because you've ravelled the pytree, and so your function no longer returns a scalar (or a pytree of scalars), so grad is not a valid transform.

Can you say more about what your expected output is? It's not clear to me from your code what exactly you are trying to compute (I think my mental interpreter raises the same error that JAX does 😁 )

gianlucadetommaso commented 2 years ago

Given a function $f:\mathbb{R}^n \to \mathbb{R}$, I'm trying to compute $\text{diag}(\nabla^2 f)$, that is, the diagonal of its Hessian matrix.

However, I define $f(x) = g(\text{unravel}(x))$, where $g:\mathcal{X} \to \mathbb{R}$ is some original function that takes any pytree as input.

Since $x \in \mathbb{R}^n$ and $f(x) \in \mathbb{R}$, I would expect $\text{diag}(\nabla^2 f) \in \mathbb{R}^n$. In principle, I don't see why this should not work.

I'm going through all of this ravelling business because, when I implement algorithms that involve algebraic operations, these algorithms should in principle know nothing about pytrees; they should take arrays as input and produce arrays as output.

Hence, even if I have a function g that takes a pytree as input (e.g. a neural network, where the inputs are the model parameters), I need to transform it into a function f in order to use it in combination with these algebraic algorithms.

I actually find the ravelling/unravelling to be the main difficulty in using JAX. Perhaps you have some suggestions here?

jakevdp commented 2 years ago

So it sounds like what you're looking for is this:

jnp.diagonal(jacobian(grad(rav_f))(rav_c))

but without computing the full Jacobian?

If so, I'm not aware of any general way to do this. But given a particular function, you could probably find a way to construct the calculation you need (see the stackoverflow question linked above for one example).
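
For instance, one possible construction (a sketch only, not the approach from the StackOverflow link, and reusing the rav_f / rav_c definitions from the example above): take one Hessian-vector product per basis vector. This never stores the full Hessian, although the total work is still one gradient pass per input dimension:

import jax
import jax.numpy as jnp
from jax import grad, jvp
from jax.flatten_util import ravel_pytree

def f(c):
    a, b = c.values()
    return jnp.dot(a, jnp.sin(b))

c = dict(a=1., b=2.)
rav_c, unravel = ravel_pytree(c)
rav_f = lambda r: f(unravel(r))

def hvp(x, v):
    # forward-over-reverse Hessian-vector product: returns H(x) @ v
    return jvp(grad(rav_f), (x,), (v,))[1]

# H_ii = e_i^T H e_i; lax.map runs the products sequentially, so only
# one gradient-sized vector is live at a time
basis = jnp.eye(rav_c.size)
hess_diag = jax.lax.map(lambda v: jnp.vdot(v, hvp(rav_c, v)), basis)
print(hess_diag)   # [0., -sin(2)] for this example, up to float32 rounding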

jakevdp commented 2 years ago

My first inclination is that it may be easier to do if you avoid the raveling, since the main problem here is that you're trying to construct gradients with respect to elements of arrays (which JAX does not support, except by generating full Jacobian/Hessian matrices and then extracting the desired portion). But I admit I don't understand your initial motivation for ravelling well enough to know if that's viable.

gianlucadetommaso commented 2 years ago

So it sounds like what you're looking for is this:

jnp.diagonal(jacobian(grad(rav_f))(rav_c))

but without computing the full Jacobian?

Exactly, I would like to avoid computing the full Jacobian and Hessian. When no ravelling is involved, you showed above that something like vmap(grad(grad(f)))(x) does the job. I would expect this to also work when ravelling is involved.

About the motivation: I'm not sure if I'm touching familiar ground here, but you could think of MCMC, variational inference, Laplace approximation, and so on. All of these algorithms involve algebraic operations over model parameters. If the model is a neural network, its parameters are expressed as dictionaries. If I want, say, MCMC to work in this case, I have two choices: (a) I tree-map all algebraic operations in MCMC, so that no ravelling is needed; (b) I leave the MCMC algorithm as it is, but pass it functions that take ravelled parameters as input.

While (a) is viable, (b) is less intrusive, so in general I would prefer approach (b).
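
As a minimal sketch of approach (b), with a hypothetical log_density standing in for the model (the names here are placeholders, not from this thread):

import jax.numpy as jnp
from jax import grad
from jax.flatten_util import ravel_pytree

def log_density(params):
    # hypothetical pytree-valued target, e.g. a Gaussian prior over parameters
    return -0.5 * sum(jnp.sum(p ** 2) for p in params.values())

params = dict(w=jnp.ones(3), b=jnp.zeros(2))
flat, unravel = ravel_pytree(params)

# the array-only algorithm never needs to know about the pytree structure
flat_log_density = lambda x: log_density(unravel(x))
flat_grad = grad(flat_log_density)

# e.g. a single gradient step carried out purely on flat arrays
flat = flat + 0.1 * flat_grad(flat)
print(unravel(flat))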

jakevdp commented 2 years ago

vmap(grad(grad(f))) does the job for functions that map a scalar to a scalar. Your function maps a vector to a scalar, so grad(f) maps one vector space to another vector space, in which case neither grad nor vmap apply (grad because it's only for scalar-valued functions, vmap because it can only map vector spaces onto themselves).

I think taking the diagonal of the jacobian will probably be your best bet.
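
Concretely, on the ravelled example above (a sketch reusing rav_f and rav_c from earlier, and still materializing the full Hessian, as discussed):

import jax.numpy as jnp
from jax import grad, jacobian

hess = jacobian(grad(rav_f))(rav_c)   # full 2x2 Hessian of the ravelled function
print(jnp.diagonal(hess))             # [0., -sin(2)] for this example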

jakevdp commented 2 years ago

(xref similar issues with similar outcomes: #1563 #3022 #3801)

gianlucadetommaso commented 2 years ago

I see. Thanks a lot!

hawkinsp commented 2 years ago

I think there's no further action for us to take here; closing. Feel free to reopen if there's something further for us to look into!