tawe141 opened 3 years ago
https://jax.readthedocs.io/en/latest/jax.html#jax.jit
jax.jit requires the function to be pure, but your fit and predict modify self, so they are not pure functions. That should explain the error message "The functions being transformed should not save traced values to global state".
On first look, you can change your fit function to return the values you want to save and cache them however you want. Just don't update state inside a function you want to jit.
While it is clear why the error is raised, I would still be curious about:

> My question is: what's the best way to cache something like `self.U` so I don't have to potentially compute it multiple times?

I assume that one would need to construct something with conditionals? Or what would you suggest?
Probably the best approach would be to pass the cached values around explicitly, e.g. something like this (leaving the object-oriented stuff aside for simplicity):
```python
def fit(data, cache=None):
    if cache is None:
        cache = _compute_cached_quantities(data)
    ...
    return outputs, cache
```
Then you can do something like this:
```python
results1, cache = fit(data)
results2, cache = fit(data, cache)
```
Or more directly:
```python
cache = _compute_cached_quantities(data)
results1, _ = fit(data, cache)
results2, _ = fit(data, cache)
```
Wouldn't that lead to re-tracing under jit (and complicate the gradients), and therefore need quite a lot of additional thought?
The second version would not lead to retracing (which is why I mentioned it).
The gradient is fine as long as you're explicit about which arguments you want the gradient to be taken with respect to (via closures or partials).
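For instance, a minimal sketch (with a made-up `loss`, `data`, and `cache`, just to illustrate the argnums/partial options):

```python
import functools
import jax
import jax.numpy as jnp

def loss(params, data, cache):
    # toy objective; the cached value is just another input here
    return jnp.sum((data - params * cache) ** 2)

data = jnp.arange(3.0)
cache = jnp.ones(3)

# option 1: pick the differentiated argument explicitly
g1 = jax.grad(loss, argnums=0)(2.0, data, cache)

# option 2: close over the cache (and data) with a partial, then differentiate
g2 = jax.grad(functools.partial(loss, data=data, cache=cache))(2.0)
```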
Agreed there are other considerations - this was just a simple skeletal example of the general approach, where you make the procedure pure by explicitly passing any global state that the functions need. An example of this can be seen in the `jax.random` functions, where the random state (often treated as a mutable global in Python) is instead passed explicitly.
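For reference, a minimal sketch of that explicit-state pattern:

```python
import jax

key = jax.random.PRNGKey(0)              # the "random state" is an explicit value
key, subkey = jax.random.split(key)      # derive new state instead of mutating a global
x = jax.random.normal(subkey, (3,))      # every call takes the state as an argument
```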
> The second version would not lead to retracing (which is why I mentioned it).

Indeed, it just requires more boilerplate, a pre-jit hook basically, but that is probably the tradeoff.
> The gradient is fine as long as you're explicit about which arguments you want the gradient to be taken with respect to (via closures or partials).

It depends. The cached value can still have a gradient; we may just not be minimizing with respect to it (for now). But then we may just need to recompute anyway; keeping track of this is probably simply a necessity.
> Agreed there are other considerations - this was just a simple skeletal example of the general approach, where you make the procedure pure by explicitly passing any global state that the functions need. An example of this can be seen in the `jax.random` functions, where the random state (often treated as a mutable global in Python) is instead passed explicitly.

Sure, the explicit passing makes it all nice and functional. My main worry is how to deal with the jit (tracing) of the functions.
But solution two and some smart bookkeeping together with a hook should do the trick indeed! Thanks!
I came across this old issue by chance. I had practically the same problem (object-oriented caching of the matrix decomposition for Gaussian process regression) and I patched together something that mostly works, though there are still some rough edges. Code, in case it might be useful: https://github.com/Gattocrucco/lsqfitgp/blob/a40240388dc8ea436f31fd8aabda7ec9417aeb3a/lsqfitgp/_linalg/_decomp.py#L193
Thanks a lot for this! It looks quite complicated to me, especially the messing with JAX internals, or am I misunderstanding that? Isn't there an easier way? What's your experience with this?
About the internals: in the last commits I removed a bit of the messing (it's a work in progress). By "messing" I mean applying `jax.stop_gradient` in the custom JVPs; now that's gone and they are honest custom JVPs.
Anyway, the code may seem a bit complicated at first sight because there are various things implemented, but the core idea is simple. JAX actually does not care that all the functions it traverses be pure; it only cares about the functions it is told to do something about. So if I do something like
```python
class Decomp:
    ...

@jax.jit
@jax.grad
def gp_marginal_likelihood(params, residuals):
    covmat = ...                   # some function of `params`
    decomp = Decomp(covmat)        # computes and caches decomposition of `covmat`
    logdet = decomp.logdet()       # log(det(M))
    chi2 = decomp.quad(residuals)  # r.T @ M^-1 @ r
    return -1/2 * logdet - 1/2 * chi2
```
from the point of view of JAX it is equivalent to something like
```python
@jax.jit
@jax.grad
def gp_marginal_likelihood(params, residuals):
    covmat = ...  # some function of `params`
    L = cholesky(covmat)
    logdet = 2 * sum(log(diag(L)))
    lr = solve_triangular(L, residuals)
    chi2 = lr.T @ lr
    return -1/2 * logdet - 1/2 * chi2
```
which is traceable.
The complications come from the fact that I also want to define custom derivatives for the methods of `Decomp`, `logdet` and `quad`. I do this by (a minimal sketch follows after the list):

1) applying `stop_gradient` to the matrix that I feed into the decomposition routine in `Decomp.__init__`, and saving the unstopped matrix as an object attribute;
2) wrapping each method with a function where the matrix appears explicitly as an argument, such that I can define custom derivatives w.r.t. the matrix;
3) wrapping that function again into a method that passes the matrix saved in `self` to the function.
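A minimal sketch of this pattern (hypothetical names, not the actual lsqfitgp code), for a single `quad`-like method:

```python
import jax
import jax.numpy as jnp

# step 2: the matrix is an explicit argument, so a custom JVP w.r.t. it can be attached
@jax.custom_jvp
def _quad(K, L, b):
    invKb = jax.scipy.linalg.cho_solve((L, True), b)
    return b @ invKb  # b.T @ K^-1 @ b

@_quad.defjvp
def _quad_jvp(primals, tangents):
    K, L, b = primals
    dK, _, db = tangents  # L comes from a gradient-stopped K, so its tangent is ignored
    invKb = jax.scipy.linalg.cho_solve((L, True), b)
    # d(b.T K^-1 b) = 2 db.T K^-1 b - (K^-1 b).T dK (K^-1 b)
    return b @ invKb, 2 * (db @ invKb) - invKb @ dK @ invKb

class CholDecomp:
    def __init__(self, K):
        self.K = K  # step 1: keep the unstopped matrix for the custom derivatives
        self.L = jnp.linalg.cholesky(jax.lax.stop_gradient(K))  # decompose the stopped copy

    def quad(self, b):
        # step 3: the method just forwards the saved matrix to the wrapped function
        return _quad(self.K, self.L, b)
```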
Mind that this is not working perfectly; there are still some problems with reverse-mode derivatives (which mostly work, still). This code was originally written for autograd, and instead of `stop_gradient` I used to hard-strip all the autograd tracers. I stopped doing that with JAX out of fear of breaking `vmap`, `jit`, etc., but I suspect using `stop_gradient` is not equivalent, because when going in reverse JAX will somehow see the products of the decomposition as involved in the derivatives and make a mess, since they popped out "impurely" in the innermost method.

Now that I have written it down, it occurs to me that maybe the remaining problems would be solved as well by mock-passing all the decomposition products as variadic arguments to the intermediate wrapper. Thanks for making me think about it!
Another solution: declare the class to be a JAX pytree, such that `self` can pass across jit and derivative boundaries:
```python
import jax
import functools
import jax.test_util


@jax.tree_util.register_pytree_node_class
class Decomp:

    def tree_flatten(self):
        return (self.l, self.a), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        self = super().__new__(cls)
        self.l, self.a = children
        return self

    # use __new__ instead of __init__ because __init__ does not return anything
    @functools.partial(jax.jit, static_argnums=0)
    def __new__(cls, a):
        self = super().__new__(cls)
        # stops a's gradient now since we are going to define custom gradients
        # w.r.t. a anyway
        self.l = jax.scipy.linalg.cholesky(jax.lax.stop_gradient(a), lower=True)
        self.a = a
        return self

    def solve(self, b):
        return self._solve(self, b)

    @jax.custom_jvp
    @staticmethod  # staticmethod otherwise custom_jvp won't see self
    @jax.jit
    def _solve(self, b):
        lb = jax.scipy.linalg.solve_triangular(self.l, b, lower=True)
        llb = jax.scipy.linalg.solve_triangular(self.l.T, lb, lower=False)
        return llb

    @_solve.defjvp
    @jax.jit
    def _solve_jvp(primals, tangents):
        self, b = primals
        self_dot, b_dot = tangents
        a_dot = self_dot.a
        primal = self.solve(b)
        tangent_a = -self.solve(a_dot) @ primal  # -a^-1 ∂a a^-1 b
        tangent_b = self.solve(b_dot)
        return primal, tangent_a + tangent_b


key = jax.random.PRNGKey(0)
m = jax.random.normal(key, (5, 5))
_, key = jax.random.split(key)
b = jax.random.normal(key, (5,))
a = m @ m.T

dec = Decomp(a)
x = dec.solve(b)
print(jax.numpy.linalg.norm(a @ x - b, 2))  # -> 3.095427e-06

# rev mode does not work, dunno why
jax.test_util.check_grads(lambda b: dec.solve(b), (b,), order=2, modes='fwd')
jax.test_util.check_grads(lambda m: Decomp(m @ m.T).solve(b), (m,), order=2, modes='fwd')
```
I'm currently making my own implementation of a Gaussian process, and I'd like to cache certain results like the correlation matrix (`K`) and the inverse of the correlation matrix multiplied by the training output (`K^-1 @ y`). As I understand it, JAX JIT requires everything to be nicely packaged into one function without any operations done on variables outside the function scope. This is likely what's causing my error; example code is pasted below (let me know if you want something runnable).

Setting `debug=True` (in other words, not JITing the `self.fit` or `self.predict_*` functions), this works as expected. JITing `self.fit`, however, results in the following error:

My question is: what's the best way to cache something like `self.U` so I don't have to potentially compute it multiple times?
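For concreteness, a rough hypothetical sketch of this kind of setup (not the original pasted code, which is omitted above; the class name, kernel, and `debug` wiring are made up here):

```python
import jax
import jax.numpy as jnp

class SimpleGP:
    def __init__(self, debug=True):
        # debug=True: plain Python methods; debug=False: jit-compiled methods,
        # in which case the assignments below save traced values on `self`
        self.fit = self._fit if debug else jax.jit(self._fit)
        self.predict = self._predict if debug else jax.jit(self._predict)

    def _fit(self, x, y):
        K = jnp.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)  # correlation matrix K
        K = K + 1e-6 * jnp.eye(x.size)                      # jitter for stability
        self.x = x
        self.U = jax.scipy.linalg.cholesky(K, lower=False)  # cached factorization
        self.Kinv_y = jax.scipy.linalg.cho_solve((self.U, False), y)  # K^-1 @ y

    def _predict(self, xs):
        Ks = jnp.exp(-0.5 * (xs[:, None] - self.x[None, :]) ** 2)
        return Ks @ self.Kinv_y

x = jnp.linspace(0.0, 1.0, 10)
y = jnp.sin(2 * jnp.pi * x)
gp = SimpleGP(debug=True)   # works as expected; with debug=False the cached
gp.fit(x, y)                # self.U / self.Kinv_y hold leaked tracers instead
print(gp.predict(jnp.array([0.5])))
```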