jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Computing the diagonal elements of a Hessian #3801

Closed Marius1311 closed 4 years ago

Marius1311 commented 4 years ago

Hi all,

I would like to use JAX to compute the diagonal elements of a Hessian matrix, i.e. the second partial derivatives \partial^2 y / \partial x_j^2. What's the most efficient way to do this? I know that for columns of the Hessian I could use Hessian-vector products, but what can I do in this case to avoid computing full Hessians?

jekbradbury commented 4 years ago

Unfortunately, computing the diagonal elements of a Hessian is fundamentally just as expensive as computing the full Hessian (i.e., there aren't any tricks that JAX or any other library could use). See also https://github.com/google/jax/issues/564 and https://github.com/HIPS/autograd/issues/445 for more discussion.

jekbradbury commented 4 years ago

(Assuming there isn't any additional structure to your problem that would make it easier, such as having a component with a diagonal or otherwise sparse Jacobian)

clemisch commented 4 years ago

If my jacobian is diagonal, what is the most elegant way to get the diagonal of the hessian? In autograd it was elementwise_grad twice, but for some reason I can't wrap my head around vmap and jacobian for a general form in jax...

Marius1311 commented 4 years ago

Thanks @jekbradbury! My Jacobian does not have any sparsity structure I can exploit, unfortunately.

jekbradbury commented 4 years ago

@clemisch if you have a function with a diagonal Jacobian, I believe that means it must act elementwise (or act elementwise up to an additive vector constant). For such a function, the equivalent of Autograd elementwise_grad is vmap(grad(f)) where f is the version that acts on a scalar.
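
For concreteness, a minimal sketch of that pattern (the elementwise function below is just an illustrative choice, not something from this thread):

import jax
import jax.numpy as jnp

# An elementwise (scalar -> scalar) function.
def f(x):
    return jnp.tanh(x) ** 2

x = jnp.linspace(-1.0, 1.0, 5)
dydx = jax.vmap(jax.grad(f))(x)              # elementwise first derivative
d2ydx2 = jax.vmap(jax.grad(jax.grad(f)))(x)  # diagonal of the (diagonal) Hessian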

mattjj commented 4 years ago

If my jacobian is diagonal, what is the most elegant way to get the diagonal of the hessian?

Whether you can use vmap may depend on whether your function is rank-polymorphic. Let's assume it's not, so that we have a function f modeling a function f : R^n -> R which we're promised has a diagonal Hessian.

Mathematically, if we can compute a Hessian-vector product (HVP), then we can reveal the diagonal entries of a diagonal Hessian by applying an HVP to an all-ones vector. Here's one way to do it:

from jax import jvp, grad, hessian
import jax.numpy as jnp
import numpy.random as npr

rng = npr.RandomState(0)
a = rng.randn(4)
x = rng.randn(4)

# function with diagonal Hessian that isn't rank-polymorphic
def f(x):
  assert x.ndim == 1
  return jnp.sum(jnp.tanh(a * x))

def hvp(f, x, v):
  return jvp(grad(f), (x,), (v,))[1]

print(hessian(f)(x))
print(jnp.diag(hessian(f)(x)))
print(hvp(f, x, jnp.ones_like(x)))
$ python issue3801.py
[[-0.03405464  0.          0.          0.        ]
 [ 0.          0.10269941  0.          0.        ]
 [ 0.          0.         -0.65265197  0.        ]
 [ 0.          0.          0.          2.9311912 ]]
[-0.03405464  0.10269941 -0.65265197  2.9311912 ]
[-0.03405464  0.10269941 -0.65265197  2.9311912 ]

That hvp implementation is in the autodiff cookbook.

clemisch commented 4 years ago

Thank you @jekbradbury and @mattjj for the explanation!

ibulu commented 4 years ago

Computing the Hessian diagonal using Hessian-vector products: eqn. 11?
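
It's unclear which paper's eqn. 11 is meant here, but one standard way to estimate the Hessian diagonal from HVPs alone is a Hutchinson-style stochastic diagonal estimator. A hedged sketch (the name estimate_hessian_diag and the probe count are illustrative), reusing the hvp helper shown above:

import jax
import jax.numpy as jnp
from jax import grad, jvp, random

def hvp(f, x, v):
    # Hessian-vector product, as in the autodiff cookbook.
    return jvp(grad(f), (x,), (v,))[1]

def estimate_hessian_diag(f, x, key, num_samples=100):
    # diag(H) is approximately mean_k v_k * (H v_k) for Rademacher probes v_k,
    # since E[v * (H v)] = diag(H) when the entries of v are independent +-1.
    def one_sample(k):
        v = random.rademacher(k, x.shape).astype(x.dtype)
        return v * hvp(f, x, v)
    keys = random.split(key, num_samples)
    return jax.vmap(one_sample)(keys).mean(axis=0)

# Example: quadratic with known Hessian diag(a).
a = jnp.arange(1.0, 5.0)
f = lambda x: 0.5 * jnp.dot(x, a * x)
print(estimate_hessian_diag(f, jnp.ones(4), random.key(0)))  # exactly [1. 2. 3. 4.] since H is diagonal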

Marius1311 commented 4 years ago

thanks @ibulu!

yaroslavvb commented 2 years ago

It is cheaper to compute the diagonal than the full Hessian for the common special case of a network consisting of linear layers and pointwise non-linearities. For instance, if you have ReLU activation and a softmax final layer with 10 classes, you can get the exact Hessian diagonal using 10 backprop passes. If the network is fairly confident in one class, then you can do a single backprop for a good approximation. I use this trick in the PyTorch autograd-lib package -- https://github.com/cybertronai/autograd-lib#example-2-hessian-quantities

mattjj commented 2 years ago

@yaroslavvb interesting! Do you mean the true Hessian or the Gauss-Newton matrix as an approximation to the Hessian?

yaroslavvb commented 2 years ago

When you use linear/ReLU layers, the layers' second derivatives are zero, so the chain rule for the Hessian of f(g(x)), which in general is g'^T f'' g' + sum_k f'_k g''_k, reduces to g'^T f'' g'; hence the Hessian diagonal equals the Gauss-Newton diagonal.

In that situation, computation graph of the Hessian looks like below. Unlabeled edges are indices to sum over.

[screenshot: computation graph of the Hessian]

In a standard reverse AD system you then do two sets of backprop passes (from top to bottom) for the left branch + the right branch.

Or, use symmetry by noting that we can factor f'' as AA' (Cholesky, for instance), so the two backprop passes are the same.

[screenshot: symmetrized computation graph after the AA' factorization]

Finally, the diagonal is the Hadamard product of the two resulting vectors. For one output class it is the `diag(u u'^T) = u \odot u'` transformation.

[screenshot: extracting the diagonal as a Hadamard product]

For K output classes, you square your final backprop vectors pointwise and add them up over K backward passes (starting with one row of the Cholesky factorization for each pass).

[screenshot: K-pass version for K output classes]

Doing K passes is useful for checking correctness against old-fashioned autograd, but incidentally, I did not see a difference between doing K passes and doing 1 pass (fixing a random output class) in some MNIST experiments. I suppose the Hessian diagonal converges at a similar rate as the gradient, so doing 10x more passes gives diminishing returns in accuracy.
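
A hedged JAX sketch of that K-pass recipe (not code from this thread), under illustrative assumptions: a toy one-hidden-layer ReLU net, softmax cross-entropy at the output, and a single example; an eigendecomposition square root stands in for the Cholesky factor, since the softmax Hessian is PSD but singular and a plain Cholesky can fail:

import jax
import jax.numpy as jnp

def net(params, x):
    # Toy piecewise-linear network: one hidden ReLU layer, K logits out.
    h = jax.nn.relu(x @ params["W1"])
    return h @ params["W2"]

def hessian_diag(params, x):
    # For linear layers + ReLU, the diagonal of the Hessian of softmax
    # cross-entropy w.r.t. `params` equals (almost everywhere) the diagonal of
    # the Gauss-Newton matrix J^T H_out J, i.e. sum_k (J^T a_k)**2 where
    # H_out = A A^T has columns a_k.
    logits, vjp_fn = jax.vjp(lambda p: net(p, x), params)
    probs = jax.nn.softmax(logits)
    H_out = jnp.diag(probs) - jnp.outer(probs, probs)  # CE Hessian w.r.t. logits
    w, V = jnp.linalg.eigh(H_out)
    A = V * jnp.sqrt(jnp.clip(w, 0.0))                 # H_out = A @ A.T
    def one_pass(a):                                   # one backprop per column of A
        (g,) = vjp_fn(a)
        return jax.tree_util.tree_map(jnp.square, g)
    per_column = jax.vmap(one_pass, in_axes=1)(A)      # K backward passes, vmapped
    return jax.tree_util.tree_map(lambda s: s.sum(axis=0), per_column)

# Example with made-up shapes: 8 inputs, 16 hidden units, 10 classes.
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
params = {"W1": jax.random.normal(k1, (8, 16)), "W2": jax.random.normal(k2, (16, 10))}
x = jax.random.normal(k3, (8,))
diag = hessian_diag(params, x)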

Now suppose you have more complicated non-linearities like sigmoid. Your Hessian computation graph looks more like this

[screenshot: Hessian computation graph with additional per-layer Hessian terms]

You can compute the diagonals of the 3 individual terms above (3 = the number of layers with a nontrivial Hessian) and add them up, which may be cheaper for shallow networks. The cost of the diagonal will be the equivalent of K + w*d backprop passes, where K is the number of classes, d is the number of pointwise-nonlinearity layers in the network, and w is their width.

tillahoffmann commented 4 months ago

I've also been trying to evaluate diagonal Hessians in an effort to find good diagonal preconditioners for optimizing deterministic loss functions using L-BFGS in jaxopt. The following code has worked well for me. If jit-compiled, it seems to function even if computing the full Hessian would exceed the available memory (jax optimization magic!).

import jax
from jax import numpy as jnp
from jax import random
import numpy as np

def normalize_path(path):
    """
    Normalize keys returned by `jax.tree_util.tree_flatten_with_path`.
    """
    return tuple(
        key.key if isinstance(key, jax.tree_util.DictKey)
        else key.idx if isinstance(key, jax.tree_util.SequenceKey)
        else key
        for key in path
    )

def tree_set(tree, path, value, strict=False):
    """
    Set the `value` at `path` in `tree`. Raise a `KeyError` if the path does not exist
    and `strict`.
    """
    path = normalize_path(path)
    items, treedef = jax.tree_util.tree_flatten_with_path(tree)
    values = []
    replaced = False
    for keys, original_value in items:
        if normalize_path(keys) == path:
            values.append(value)
            replaced = True
        else:
            values.append(original_value)
    if strict and not replaced:
        raise KeyError(path)
    return jax.tree.unflatten(treedef, values)

def tree_get(tree, path):
    """
    Get the value at `path` in `tree`. Raise an error if the path does not exist.
    """
    for key in normalize_path(path):
        tree = tree[key]
    return tree

def hessdiag(func):
    """
    Compute diagonal elements of the Hessian. Should be `jax.jit`ed to avoid large memory use.
    """
    def _hessdiag_wrapper(x, *args, **kwargs):

        def _leaf_hessdiag(path, value):
            """
            Compute the diagonal Hessian for `path` in `x`.
            """
            hessian = jax.jacfwd(jax.jacrev(lambda y: func(tree_set(x, path, y), *args, **kwargs)))(value)
            assert value.size ** 2 == hessian.size
            # Ravel, extract diagonal, reshape to the target shape.
            return jnp.diagonal(hessian.reshape((value.size, value.size))).reshape(value.shape)

        # Map the diagonal Hessian helper function over all leaves of the pytree.
        return jax.tree_util.tree_map_with_path(_leaf_hessdiag, x)

    return _hessdiag_wrapper

And two examples.

# Example: Compute scaled quadratic form.
key1, key2, key3, key4 = random.split(random.key(17), 4)

arg = {
    "params": {
        "loc": random.normal(key1, (30,)), 
        "scale": random.gamma(key2, 3, (50, 1)),
    }, 
    "x": random.normal(key3, (50, 30)),
}

def func(arg):
    z = (arg["x"] - arg["params"]["loc"]) / arg["params"]["scale"]
    return jnp.square(z).sum()

# Compute the diagonal and full Hessians.
hd = jax.jit(hessdiag(func))(arg)
h = jax.jacfwd(jax.jacrev(func))(arg)

# Compare the results.
for path, value in jax.tree_util.tree_flatten_with_path(hd)[0]:
    # Get the corresponding full Hessian and extract diagonal elements.
    reference = tree_get(h, path * 2)
    reference = jnp.diagonal(reference.reshape((value.size, value.size))).reshape(value.shape)
    np.testing.assert_allclose(value, reference, rtol=1e-6)

# Example 2: Large Hessian diagonal (this breaks if not `jit`-compiled).
jax.jit(hessdiag(lambda x: jnp.square(x).sum()))(random.normal(key4, (1_000_000,)))

mattjj commented 4 months ago

Thanks for sharing that!

tillahoffmann commented 1 month ago

I've been pondering this a little more because out-of-memory errors arise for slightly more complex models. The following code snippet works for larger models because it iterates over the elements of each leaf node (rather than taking the diagonal of the full Hessian of each leaf). It is, however, rather slow. Do you have an idea how one might be able to accelerate it, @mattjj?

def elementwise_hessian(func):
    """
    Return a function to compute the elementwise Hessian.
    """

    def _inner(value):
        def _leaf_hessdiag(path, x):
            """
            Evaluate the Hessian of `func` with respect to the leaf at `path` with value
            `x`.
            """
            # These functions evaluate `func` and its gradient at `x` but replace the
            # value at `path` by the single argument `y`.
            func_at_path = lambda y: func(replace_at_path(value, {path: y}))
            grad_at_path = jax.grad(func_at_path)

            def _element_hessdiag(i):
                """
                Evaluate the Hessian of `func` with respect to the element `i` of the
                leaf at `path`.
                """
                # This function evaluates `grad[i]` at `x` but replaces the value at 
                # `path[i]` by its single argument `y`.
                i = jnp.unravel_index(i, x.shape)
                grad_at_index = lambda y: grad_at_path(x.at[i].set(y))[i]
                return jax.grad(grad_at_index)(x[i])

            indices = jnp.arange(x.size)
            # Can't use vmap because it allocates too much memory (cf. https://github.com/google/jax/issues/23358).
            return jax.lax.map(_element_hessdiag, indices).reshape(x.shape)

        return jax.tree_util.tree_map_with_path(_leaf_hessdiag, value)

    return _inner
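
For reference, a hypothetical way to close the loop on the undefined replace_at_path helper and exercise the function, reusing tree_set, func, and arg from the earlier snippets (this is an assumption about replace_at_path's intended behaviour, not code from the thread):

# Hypothetical glue (not from the thread): assume `replace_at_path` replaces each
# leaf at a given path, implemented via the `tree_set` helper from the earlier snippet.
def replace_at_path(tree, replacements):
    for path, new_value in replacements.items():
        tree = tree_set(tree, path, new_value)
    return tree

# Reusing `func` and `arg` from the first example; jit keeps memory use modest.
hd = jax.jit(elementwise_hessian(func))(arg)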