google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Efficient ways of implementing the Divergence operator (avoid computing the jacobian) #3022

Open franciscovargas opened 4 years ago

franciscovargas commented 4 years ago

Hi, first off I'm not sure that the issues section is the best place for this, so I apologise beforehand. There's not much of a forum community online for JAX yet (no mailing list?), and this can almost be seen as a feature request:

I am trying to implement the div operator (trace of the Jacobian). For a given parametric map (vector field) of the form:

$$ f_{\theta} : \mathbb{R}^n \rightarrow \mathbb{R}^{n} $$

I want to apply the div operator to it, which is:

$$ \nabla \cdot f_{\theta} = \sum_{i=1}^{n} \partial_{x_i} f_{\theta}^{(i)} $$

Note the derivatives are with respect to the inputs of the function, $\partial_{x_i} f_{\theta}(x_1, \dots, x_n)$, rather than $\theta$ ($\theta$ is constant).

I spent some time reading the documentation and thinking about the problem. Using 'grad' and 'vmap' does not seem feasible, since by definition this function takes an input of size $n$ and produces an output of the same size; there's no way of making it a scalar nicely. You would have to create $n$ functions which each return index $i$ of $f$, which is quite tedious (a sketch of what that would look like is below).
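For reference, a minimal sketch of that tedious approach (the helper name divergence_via_grad is my own, and it assumes f maps a single point x in R^n to R^n):

import jax
import jax.numpy as jnp

def divergence_via_grad(f, x):
    # one scalar function per output index: f_i(x) = f(x)[i];
    # grad of f_i with respect to x is row i of the Jacobian, and we keep entry i
    partials = [jax.grad(lambda x_, i=i: f(x_)[i])(x)[i] for i in range(x.shape[0])]
    return sum(partials)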

As a Jacobian-vector product, I can't see how there's a vector $v$ that would produce the trace of the Jacobian (a single pullback won't give you the trace).

So the best I could come up with was:

import jax
import jax.numpy as np

X = np.arange(500.0).reshape(50, 10)  # 50 datapoints in R^10 (float so it can be differentiated)

theta = np.eye(10)

def f(theta, X):
    return X.dot(theta)

def divergence(f, theta_, X_):

    def my_div(f_):
        # Jacobian with respect to the second argument (the input x), then its trace
        jac = jax.jacrev(f_, argnums=1)
        return lambda t, x_: np.trace(jac(t, x_))

    # vmap over the rows of X_, keeping theta_ fixed
    return jax.vmap(my_div(f), in_axes=(None, 0))(theta_, X_)

div = divergence(f, theta, X)
print(div.shape)  # (50,): one scalar divergence per datapoint (row) of X

This is fairly inefficient, as it performs roughly $n(n+1)/2$ operations when it could be doing just about $n$.

Any thoughts? (I'm new to JAX, so maybe I am missing an obvious way of using JVPs.)

shoyer commented 4 years ago

I don’t think there exists any shortcut for computing the diagonal of a Jacobian for a general function in a single autodiff pass. You might find this paper on designing neural nets that satisfy this property interesting: https://papers.nips.cc/paper/9187-neural-networks-with-cheap-differential-operators.pdf

It’s not an asymptotic improvement, but one minor improvement to your approach would be to use forward mode (JVP) over reverse mode, because it requires much less memory.
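For concreteness, a minimal sketch of the forward-mode version (just an illustration, assuming a function of a single point x; the helper name divergence_fwd is made up):

import jax
import jax.numpy as jnp

def divergence_fwd(f):
    def div(x):
        basis = jnp.eye(x.shape[0], dtype=x.dtype)
        # pushing forward basis vector e_i gives column i of the Jacobian;
        # the trace then sums the (i, i) entries
        pushfwd = lambda e: jax.jvp(f, (x,), (e,))[1]
        return jnp.trace(jax.vmap(pushfwd)(basis))
    return div

This still costs n passes over f, but avoids storing reverse-mode residuals.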

franciscovargas commented 4 years ago

@shoyer why would forward-mode JVP help in the case of a square Jacobian? (I'm just basing myself on the docs, where they state that either has an advantage over the other depending on the Jacobian being tall or fat.)

My other query is: why is there no shortcut? Is it unnatural (unnatural as in reverse/forward-mode diffs don't lend themselves to an element-wise application to a function)? This might be a very silly analogy, but if I were to sit down and compute derivatives manually it would be much easier for me to compute the div (diagonal of the Jacobian) than to compute the full Jacobian.

I understand autodiff frameworks have JVPs as their backbone, so everything they do relies on computing the Jacobian, but carrying out element-wise derivatives of a function should be something that can be adapted to, i.e. for reverse mode start with $\partial_{x_i} f$ as the root and then apply the chain rule going backwards from there (i.e. $\partial_{x_i} a \cdot \partial_{a} f(a)$).

Divs seem to pop up a lot in the natural sciences and also in stochastic processes, so it could be a useful feature to have.

Thanks for the reference btw! It's very useful.

shoyer commented 4 years ago

> why would forward-mode JVP help in the case of a square Jacobian? (I'm just basing myself on the docs, where they state that either has an advantage over the other depending on the Jacobian being tall or fat.)

All things being equal, forward mode is simpler than reverse mode. Reverse mode requires building up a computation graph and storing all intermediate outputs in memory. For square Jacobians it's the same number of floating-point operations, but the reduced memory traffic makes forward mode cheaper.
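Concretely, in the snippet above that amounts to swapping jax.jacrev for jax.jacfwd (same call signature):

jac = jax.jacfwd(f_, argnums=1)  # forward mode: no residuals stored for a backward pass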

> I understand autodiff frameworks have JVPs as their backbone, so everything they do relies on computing the Jacobian, but carrying out element-wise derivatives of a function should be something that can be adapted to, i.e. for reverse mode start with $\partial_{x_i} f$ as the root and then apply the chain rule going backwards from there (i.e. $\partial_{x_i} a \cdot \partial_{a} f(a)$).

I would not be surprised if XLA can already perform exactly these sorts of operations under a jit. For example, f(x)[i] and f(x[i]) for an expensive element-wise function f() run in exactly the same time.
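A quick way to poke at that claim (the element-wise f below is just a stand-in):

import jax
import jax.numpy as jnp

f = lambda x: jnp.exp(jnp.sin(jnp.cos(x)))  # stand-in for an expensive element-wise function

index_after = jax.jit(lambda x: f(x)[0])   # apply f everywhere, then index
index_before = jax.jit(lambda x: f(x[0]))  # index first, then apply f

x = jnp.linspace(0.0, 1.0, 1_000_000)
# After compilation, timing both calls (e.g. with block_until_ready()) should show
# comparable runtimes if XLA prunes the unused element-wise work.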

That said, if you can come up with an efficient way to compute the diagonal of the Jacobian (using this technique or any other), we could consider including it in JAX, or maybe it would make sense in the autodiff cookbook in the docs.

act65 commented 2 years ago

I also want a div op.

It seems the authors of local_kinetic_energy in here implement a div operator, kinda (to compute the Laplacian).
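For reference, a minimal sketch of that idea (my own code, not theirs): the Laplacian of a scalar function is the divergence of its gradient, i.e. the trace of the Hessian, accumulated with one forward-over-reverse JVP per basis direction.

import jax
import jax.numpy as jnp

def laplacian(g):
    # g is a scalar-valued function R^n -> R
    def lap(x):
        grad_g = jax.grad(g)
        basis = jnp.eye(x.shape[0], dtype=x.dtype)
        # JVP of grad_g along e_i gives column i of the Hessian
        hess_col = lambda e: jax.jvp(grad_g, (x,), (e,))[1]
        return jnp.trace(jax.vmap(hess_col)(basis))
    return lap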

Also, on the topic of shortcuts for computing Jacobians, there is this?

cisprague commented 2 months ago

Hi! I noticed this thread is one of the first Google search results for "jax divergence" but found limited explicit implementations directly addressing this. I'd like to share my approach for implementing the divergence operator in JAX, hoping it might assist others looking for a solution or improve upon existing ones scattered across various repositories.

Here’s the implementation, which includes support for different modes. @jakevdp @shoyer, are there some ways it could be improved?

import jax
import jax.numpy as jnp
from typing import Callable

def divergence(f: Callable, n: int, gaussian: bool):
    """
    Compute the divergence of a vector field using JAX.

    Args:
    f : Callable
        The vector field function R^n -> R^n.
    n : int
        Mode of divergence computation. -1 for exact trace, 0 for efficient exact, 
        and positive integers for stochastic estimation using Hutchinson's trace estimator.
    gaussian : bool
        Flag to use Gaussian (True) or Rademacher (False) vectors for stochastic estimation.

    Returns:
    Callable
        A function that computes the divergence at a point.
    """

    # Exact calculation using the trace of the Jacobian
    if n == -1:
        return jax.jit(lambda x, key: jnp.trace(jax.jacobian(f)(x)))

    # Efficient exact calculation using gradients
    if n == 0:
        def div(x, key):
            fi = lambda i, *y: f(jnp.stack(y))[i]
            dfidxi = lambda i, y: jax.grad(fi, argnums=i+1)(i, *y)
            return sum(dfidxi(i, x) for i in range(x.shape[0]))
            # vmap over i doesn't work here: grad's argnums must be a static
            # Python integer, but under vmap i would be a traced value.
            # return jax.vmap(dfidxi, in_axes=(0, None))(jnp.arange(x.shape[0]), x)
        return jax.jit(div)

    # Hutchinson's trace estimator for stochastic estimation
    if n > 0:
        def div(x, key):
            def vJv(key):
                _, vjp = jax.vjp(f, x)
                v = jax.random.normal(key, x.shape, dtype=x.dtype) if gaussian else jax.random.rademacher(key, x.shape, dtype=x.dtype)
                return jnp.dot(vjp(v)[0], v)
            return jax.vmap(vJv)(jax.random.split(key, n)).mean()
        return jax.jit(div)
Example usage:

>>> f = lambda x: jnp.dot(jnp.diag(jnp.array([1.0, 2.0, 3.0])), x)
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (3,))
>>> divergence(f, -1, True)(x, key)
Array(6., dtype=float32)
>>> divergence(f, 0, True)(x, key)
Array(6., dtype=float32)
>>> divergence(f, 1000, True)(x, key)
Array(6.1945624, dtype=float32)
>>> divergence(f, 1000, False)(x, key)
Array(6.0000005, dtype=float32)