jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Caching data best practice? #5344

Open tawe141 opened 3 years ago

tawe141 commented 3 years ago

I'm currently writing my own implementation of a Gaussian process, and I'd like to cache certain results such as the correlation matrix (K) and the inverse of the correlation matrix multiplied by the training output (K^-1 @ y). As I understand it, jax.jit requires everything to be packaged into one pure function, with no operations on variables outside the function scope. This is likely what's causing my error; example code is pasted below (let me know if you want something runnable).

class GP:
    def __init__(self, alpha: float = 1e-8, kernel: Kernel = RBF(), debug: bool = True):
        """
        Constructor for base Gaussian process class.

        :param alpha: (float) nugget parameter, increase if matrix inverses fail due to ill-conditioning. Default: 1e-8
        :param kernel: (Kernel)
        :param debug: (bool) if False, JIT compiles `self.predict_mu` and `self.predict_var`
        TODO: figure out why you can't JIT `self.fit`
        """
        self.alpha = alpha
        self.kernel = kernel
        self.x = None
        self.y = None
        self.K = None
        self.U = None
        self.predict_mu = vmap(self.predict_)
        self.predict_var = vmap(self.predict_var_)

        if debug is False:
            self.fit = jit(self.fit, static_argnums=0)  # JITing this results in a tracer error
            self.predict_mu = jit(self.predict_mu, static_argnums=0)
            self.predict_var = jit(self.predict_var, static_argnums=0)

    def fit(self, x: np.ndarray, y: np.ndarray, *args):
        """
        Pre-calculates the correlation matrix K and K^-1 @ y (denoted here as U)

        :param x: array shape (Nxd), independent variables
        :param y: array shape (Nx1), dependent variable
        :return:
        """
        self.x = x
        self.y = y
        self.K = self.kernel(x, x)
        self.K = self.K + self.alpha * np.eye(len(self.K))
        self.U = solve(self.K, self.y)

    def predict_(self, x: np.ndarray) -> float:
        """
        Returns prediction of a *single sample* `x`. See `self.predict_mu` for the corresponding vectorized function

        :param x: vector of shape (d,), where d is the dimensionality of the input
        :return: float
        """
        return self.kernel(x, self.x) @ self.U

    def predict_var_(self, x: np.ndarray) -> float:
        """
        Returns covariance of a *single sample* `x`. See `self.predict_var` for the corresponding vectorized function

        :param x: vector of shape (d,), where d is the dimensionality of the input
        :return: float
        """
        K_train = self.kernel(x, self.x)
        return self.kernel(x, x) - K_train.T @ solve(self.K, K_train)

    def predict(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Returns prediction and variance of the GP given inputs `x`. Simple wrapper calling `self.predict_mu` and `self.predict_var`

        :param x: vector of shape(N, d) inputs
        :return:
        """
        return self.predict_mu(x), self.predict_var(x)

With debug=True (in other words, not JIT-compiling self.fit or the self.predict_* functions), this works as expected. JIT-compiling self.fit, however, results in the following error:

Traceback (most recent call last):
  File "/Users/erictaw/gpgrad/mvp.py", line 170, in <module>
    mu, var = gp.predict(x)
  File "/Users/erictaw/gpgrad/mvp.py", line 131, in predict
    return self.predict_mu(x), self.predict_var(x)
  File "/Users/erictaw/gpgrad/mvp.py", line 112, in predict_
    return self.kernel(x, self.x) @ self.U
  File "/Users/erictaw/miniconda3/envs/gpgrad/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4960, in deferring_binary_op
    return binary_op(self, other)
  File "/Users/erictaw/miniconda3/envs/gpgrad/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 3423, in matmul
    out = lax.dot_general(
  File "/Users/erictaw/miniconda3/envs/gpgrad/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 665, in dot_general
    return dot_general_p.bind(lhs, rhs,
jax._src.traceback_util.FilteredStackTrace: jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line /Users/erictaw/gpgrad/mvp.py:103 (fit).
...

My question is: what's the best way to cache something like self.U so I don't have to potentially compute it multiple times?

zhangqiaorjc commented 3 years ago

https://jax.readthedocs.io/en/latest/jax.html#jax.jit

jax.jit requires the function to be pure, but your fit and predict modify self and are therefore not pure functions. That should explain the error message "The functions being transformed should not save traced values to global state".

At first glance, you can change your fit function to return the values you want to save and cache them however you want. Just don't update state inside a function you want to jit.
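A rough sketch of that idea (untested, with hypothetical free functions standing in for the methods above): fit stays pure by returning the cached arrays, and the caller stores them outside of any jitted code.

import functools
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve

@functools.partial(jax.jit, static_argnums=0)  # the kernel callable must be static
def fit(kernel, alpha, x, y):
    K = kernel(x, x) + alpha * jnp.eye(x.shape[0])
    U = solve(K, y)  # K^-1 @ y
    return K, U      # returned instead of written to self

@functools.partial(jax.jit, static_argnums=0)
def predict_mean(kernel, x_train, U, x_new):
    return kernel(x_new, x_train) @ U

# the caller keeps the cache, e.g. self.K, self.U = fit(self.kernel, self.alpha, x, y)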

jonas-eschle commented 2 years ago

While it is clear why the error is raised, I would still be curious about

My question is: what's the best way to cache something like self.U so I don't have to potentially compute it multiple times?

I assume that one would need to construct something with conditionals? Or what would you suggest?

jakevdp commented 2 years ago

Probably the best approach would be to pass the cached values around explicitly, e.g something like this (leaving object-oriented stuff aside for simplicity):

def fit(data, cache=None):
  if cache is None:
    cache = _compute_cached_quantities(data)
  ...
  return outputs, cache

Then you can do something like this:

results1, cache = fit(data)
results2, cache = fit(data, cache)

Or more directly:

cache = _compute_cached_quantities(data)
results1, _ = fit(data, cache)
results2, _ = fit(data, cache)
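
A self-contained toy version of this pattern (hypothetical names and sizes, with jit applied to the function that consumes the cache; the cache is just an array, so it can be passed straight through):

import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

def _compute_cached_quantities(K):
    # the expensive part, computed once: a (lower) Cholesky factor of K
    return jnp.linalg.cholesky(K)

@jax.jit
def fit(K, y, cache):
    # `cache` is an ordinary array argument, so it is fine inside jit
    U = cho_solve((cache, True), y)  # K^-1 @ y via the cached factor
    return U, cache

K = jnp.eye(3) + 0.1
y = jnp.arange(3.0)
cache = _compute_cached_quantities(K)
U1, _ = fit(K, y, cache)
U2, _ = fit(K, y, cache)  # reuses both the cached factor and the compiled function
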
jonas-eschle commented 2 years ago

Wouldn't that cause re-tracing (at least in the first version) and complicate taking gradients, and therefore need quite a lot of additional thought?

jakevdp commented 2 years ago

The second version would not lead to retracing (which is why I mentioned it).

The gradient is fine as long as you're explicit about which arguments you want the gradient to be taken with respect to (via closures or partials).

Agreed there are other considerations - this was just a simple skeletal example of the general approach, where you make the procedure pure by explicitly passing any global state that the functions need. An example of this can be seen in jax.random functions, where the random state (often treated as a mutable global in Python) is instead passed explicitly.
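For reference, that jax.random pattern looks roughly like this (a minimal sketch): the PRNG state is an explicit value that is split and passed around rather than mutated in place.

import jax

key = jax.random.PRNGKey(0)          # explicit random state, not a hidden global
key, subkey = jax.random.split(key)  # derive a fresh piece of state
x = jax.random.normal(subkey, (3,))  # consume that piece explicitly
key, subkey = jax.random.split(key)
y = jax.random.uniform(subkey, (3,))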

jonas-eschle commented 2 years ago

The second version would not lead to retracing (which is why I mentioned it).

Indeed, it just requires more boilerplate (a pre-jit hook, basically), but that is probably the tradeoff.

The gradient is fine as long as you're explicit about which arguments you want the gradient to be taken with respect to (via closures or partials).

It depends. The cached value can still have a gradient; we may just not be minimizing with respect to it (for now). But then we may need to recompute anyway; keeping track of this is probably simply a necessity.

Agreed there are other considerations - this was just a simple skeletal example of the general approach, where you make the procedure pure by explicitly passing any global state that the functions need. An example of this can be seen in jax.random functions, where the random state (often treated as a mutable global in Python) is instead passed explicitly.

Sure, the explicit passing makes it all nicely functional. My main worry is how to deal with the jit (tracing) of functions.

But solution two and some smart bookkeeping together with a hook should do the trick indeed! Thanks!

Gattocrucco commented 2 years ago

I encountered this old issue now by chance. I had practically the same problem (object-oriented caching of the matrix decomposition for Gaussian process regression) and patched together something that mostly works, though there are still some rough edges. Code, in case it might be useful: https://github.com/Gattocrucco/lsqfitgp/blob/a40240388dc8ea436f31fd8aabda7ec9417aeb3a/lsqfitgp/_linalg/_decomp.py#L193

jonas-eschle commented 2 years ago

Thanks a lot for this! It looks quite complicated to me, especially the parts that mess with JAX internals, or am I misunderstanding that? Isn't there an easier way? What's your experience with this?

Gattocrucco commented 2 years ago

About the internals: in the last few commits I removed a bit of the messing (it's a work in progress). By "messing" I mean applying jax.stop_gradient inside the custom JVPs; that's gone now, and they are honest custom JVPs.

Anyway, the code may seem a bit complicated at first sight because various things are implemented, but the core idea is simple. JAX does not actually require every function it traverses to be pure; it only cares about the functions it is explicitly told to transform. So if I do something like

class Decomp:
    ...

@jax.jit
@jax.grad
def gp_marginal_likelihood(params, residuals):
    covmat = ... # some function of `params`
    decomp = Decomp(covmat) # computes and caches decomposition of `covmat`
    logdet = decomp.logdet() # log(det(M))
    chi2 = decomp.quad(residuals) # r.T @ M^-1 @ r
    return -1/2 * logdet - 1/2 * chi2

from the point of view of JAX it is equivalent to something like

@jax.jit
@jax.grad
def gp_marginal_likelihood(params, residuals):
    covmat = ... # some function of `params`
    L = cholesky(covmat)
    logdet = 2 * sum(log(diag(L)))
    lr = solve_triangular(L, residuals)
    chi2 = lr.T @ lr
    return -1/2 * logdet - 1/2 * chi2

which is traceable.

The complications come from the fact that I also want to define custom derivatives for the methods of Decomp, logdet and quad. I do this by:

1) Applying stop_gradient to the matrix that I feed into the decomposition routine in Decomp.__init__, and saving the unstopped matrix as an object attribute.

2) Wrapping each method with a function in which the matrix appears explicitly as an argument, so that I can define custom derivatives w.r.t. the matrix.

3) Wrapping that function again into a method that passes it the matrix saved in self.

Mind that this does not work perfectly; there are still some problems with reverse-mode derivatives (which mostly work, though). This code was originally written for autograd, and instead of stop_gradient I used to hard-strip all the autograd tracers. I stopped doing that with JAX out of fear of breaking vmap, jit, etc., but I suspect using stop_gradient is not equivalent, because when going in reverse JAX will see the products of the decomposition as somehow involved in the derivatives and make a mess, since they popped out "impurely" in the innermost method.

Gattocrucco commented 2 years ago

Now that I have written it down, it occurs to me that maybe the remaining problems would also be solved by mock-passing all the decomposition products as variadic arguments to the intermediate wrapper. Thanks for making me think about it!

Gattocrucco commented 2 years ago

Another solution: declare the class to be a JAX pytree, so that self can be passed across jit and derivative boundaries:

import jax
import functools
import jax.test_util

@jax.tree_util.register_pytree_node_class
class Decomp:

    def tree_flatten(self):
        return (self.l, self.a), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        self = super().__new__(cls)
        self.l, self.a = children
        return self

    # use __new__ instead of __init__ because __init__ does not return anything
    @functools.partial(jax.jit, static_argnums=0)
    def __new__(cls, a):
        self = super().__new__(cls)
        # stops a's gradient now since we are going to define custom gradients
        # w.r.t. a anyway
        self.l = jax.scipy.linalg.cholesky(jax.lax.stop_gradient(a), lower=True)
        self.a = a
        return self

    def solve(self, b):
        return self._solve(self, b)

    @jax.custom_jvp
    @staticmethod # staticmethod otherwise custom_jvp won't see self
    @jax.jit
    def _solve(self, b):
        lb = jax.scipy.linalg.solve_triangular(self.l, b, lower=True)
        llb = jax.scipy.linalg.solve_triangular(self.l.T, lb, lower=False)
        return llb

    @_solve.defjvp
    @jax.jit
    def _solve_jvp(primals, tangents):
        self, b = primals
        self_dot, b_dot = tangents
        a_dot = self_dot.a
        primal = self.solve(b)
        tangent_a = -self.solve(a_dot) @ primal # -a^-1 ∂a a^-1 b
        tangent_b = self.solve(b_dot)
        return primal, tangent_a + tangent_b

key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (5, 5))
_, key = jax.random.split(key)
b = jax.random.normal(key, (5,))
a = m @ m.T
dec = Decomp(a)
x = dec.solve(b)
print(jax.numpy.linalg.norm(a @ x - b, 2)) # -> 3.095427e-06

# rev mode does not work, dunno why
jax.test_util.check_grads(lambda b: dec.solve(b), (b,), order=2, modes='fwd')
jax.test_util.check_grads(lambda m: Decomp(m @ m.T).solve(b), (m,), order=2, modes='fwd')