Open jackd opened 3 years ago
Agreed this is strange... I'd expect to see an error in both cases! I'm not sure why the first one works, but you can fix both by marking the function argument as static:
@partial(jax.jit, static_argnums=(1,))
def f(x, fn):
    return fn(fn(x))
@jakevdp thanks for the quick response. What about if the function contains both static and non-static args? My actual use case is sparse implementations - simplified version below. I suppose this might be easier once your sparse support PR goes through with primitives, but in the meantime is there any way to achieve this?
from typing import Callable
from functools import partial
import jax
import jax.numpy as jnp
@jax.jit
def power_iteration(A_fun: Callable, x0, iterations):
    def cond_fun(state):
        return state[2] < iterations
    def body_fun(state):
        value, vec, it = state
        vec = A_fun(vec)
        value = jnp.linalg.norm(vec)
        vec /= value
        return value, vec, it + 1
    value = jnp.linalg.norm(x0)
    vec = x0 / value
    value, vec, _ = jax.lax.while_loop(cond_fun, body_fun, (value, vec, 0))
    return value, vec
@jax.jit
def dense_matvec(A, x):
    return A @ x
@partial(jax.jit, static_argnums=(3,))
def coo_matvec(data, row, col, nrows, v):
    assert v.ndim == 1
    dv = data * v[col]
    return jnp.zeros(nrows, dtype=dv.dtype).at[row].add(dv)
n = 16
iters = 100
sparsity = 0.1
dtype = jnp.float32
key = jax.random.PRNGKey(0)
vals_key, mask_key, x0_key = jax.random.split(key, 3)
x0 = jax.random.normal(x0_key, shape=(n,), dtype=dtype)
a = jax.random.normal(vals_key, shape=(n, n), dtype=dtype)
mask = jax.random.uniform(mask_key, shape=(n, n), dtype=dtype) < sparsity
# strengthen diagonal so eigvals are more real
a = a + n * jnp.eye(n, dtype=dtype)
mask = jax.ops.index_update(
    mask, jax.ops.index[jnp.arange(n), jnp.arange(n)], jnp.ones((n,), dtype=bool)
)
# get coo data
row, col = jnp.where(mask)
data = a[row, col]
# create masked a
a = jax.ops.index_update(jnp.zeros((n, n), dtype=dtype), jax.ops.index[row, col], data)
w, v = jax.jit(jnp.linalg.eig, backend="cpu")(a)
wi = jnp.argmax(jnp.abs(w))
true_value = w[wi]
true_vec = v[:, wi]
print("True:")
print(true_value)
print(true_vec)
# our dense implementation
dense_fun = jax.tree_util.Partial(dense_matvec, a)
dense_value, dense_vec = power_iteration(dense_fun, x0, iters)
print("Dense:")
print(dense_value)
print(dense_vec)
# our coo implementation
coo_fun = jax.tree_util.Partial(coo_matvec, data, row, col, n)
coo_value, coo_vec = power_iteration(coo_fun, x0, iters)
print(coo_value, coo_vec)
print("COO:")
print(coo_value)
print(coo_vec)
I think my fix works there as well. Change the first function definition to this:
@partial(jax.jit, static_argnums=(0,))
def power_iteration(A_fun: Callable, x0, iterations):
In other words, a callable passed to a jitted function should always be marked static in that jitted function. I'm surprised it would ever work otherwise.
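As a small illustration of the static-callable pattern described above, here is a minimal sketch. The helper names (apply_twice, double, add_one, trace_count) are made up for this example and are not from the thread. The counter is incremented by a Python side effect that only runs while JAX traces the function, so it shows when a re-trace happens.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX (re-)traces apply_twice


@partial(jax.jit, static_argnums=(0,))
def apply_twice(fn, x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    return fn(fn(x))


def double(x):
    return 2 * x


def add_one(x):
    return x + 1


x = jnp.ones(3)
y1 = apply_twice(double, x)   # first call: traces and compiles
y2 = apply_twice(double, x)   # same function object, same input shape: cache hit
count_after_double = trace_count
y3 = apply_twice(add_one, x)  # different static argument: traces again
count_after_add_one = trace_count
```

Static arguments are compared by hash and equality, so reusing the same function object avoids recompilation, while passing a different callable triggers a new trace.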
Digging a bit, I think I see what's going on here. With jax.tree_util.Partial, bound arguments become part of the pytree, and so they are traced in a jitted context: https://github.com/google/jax/blob/cd4138b83d75c20b287345c434017b09c99a9cc6/jax/tree_util.py#L301-L315
We could fix the issue you're seeing by leaving any static arguments out of the pytree produced by Partial; that would make things like this more consistent.
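To make the pytree behaviour of Partial concrete, here is a minimal sketch. The function scale_and_shift and the list seen_python_int are hypothetical names chosen for illustration. It flattens a Partial to show that its bound arguments are leaves, then records whether the bound integer is still a plain Python int when called eagerly versus under jit.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_leaves

seen_python_int = []  # records whether `n` was a plain Python int on each call


def scale_and_shift(a, n, x):
    seen_python_int.append(isinstance(n, int))
    return a * x + n


a = jnp.arange(3.0)
fn = Partial(scale_and_shift, a, 3)

# Bound positional arguments are pytree leaves; the wrapped function is not.
leaves = tree_leaves(fn)

x = jnp.ones(3)
eager = fn(x)                               # n is the Python int 3
jitted = jax.jit(lambda f, v: f(v))(fn, x)  # n arrives as a traced value
```

Because the bound integer is traced under jit, using it where a concrete value is required (for example as an array shape) fails, which matches the behaviour reported in this issue.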
Hmm... I would have thought you could set things up such that the outer function doesn't need to be recompiled so long as the passed function doesn't need to be - e.g. changing the data in a shouldn't require a recompile in either the sparse or dense case, though it would if the shape / dtype / nnz (for the sparse case) changed.
Yeah - I spent some time looking into this. JAX's current method of tracking static arguments makes it difficult to do things at this level of granularity. If you pass a jax.tree_util.Partial function as a non-static argument, all arguments to that function will be traced, and I don't think there's any way around this currently.
@jakevdp thanks for your investigation. I've got a work-around based on passing in a vector of size nrows with arbitrary values (sized in the code below). It's dirty, but it works for the moment...
@jax.jit
def coo_matvec(data, row, col, sized, v):
    assert v.ndim == 1
    dv = data * v[col]
    return jnp.zeros(sized.size, dtype=dv.dtype).at[row].add(dv)
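The work-around relies on array shapes being known at trace time even when the array's values are traced. Below is a minimal end-to-end sketch of the same idea on a tiny 3x3 matrix. The test data, the anonymous jitted caller, and the dense comparison are assumptions added for illustration, and this version reads sized.shape[0] rather than sized.size.

```python
import jax
import jax.numpy as jnp


@jax.jit
def coo_matvec(data, row, col, sized, v):
    nrows = sized.shape[0]  # static under jit: taken from the shape, not the values
    dv = data * v[col]
    return jnp.zeros(nrows, dtype=dv.dtype).at[row].add(dv)


# Hypothetical 3x3 sparse matrix in COO form.
row = jnp.array([0, 1, 2, 2])
col = jnp.array([1, 2, 0, 2])
data = jnp.array([1.0, 2.0, 3.0, 4.0])
v = jnp.array([1.0, 10.0, 100.0])
sized = jnp.zeros(3)  # only its shape matters; the values are arbitrary

# Every bound argument is an array, so tracing them all is harmless.
coo_fun = jax.tree_util.Partial(coo_matvec, data, row, col, sized)
out = jax.jit(lambda f, x: f(x))(coo_fun, v)

# Dense reference for comparison.
dense = jnp.zeros((3, 3)).at[row, col].set(data)
expected = dense @ v
```

The cost of this approach is an extra dummy array per matrix, and a change in nrows still triggers recompilation because the shape of sized changes.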
I'm looking to pass a function to a jitted function. The argument is itself a jitted function with a static argument set using jax.tree_util.Partial. This is similar to 1443. Am I missing something obvious? Is this intended behaviour? If not, is there a work-around?
Error: