jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Function with `static_argnums` as `jit`ed function argument. #5609

Open jackd opened 3 years ago

jackd commented 3 years ago

I'm looking to pass a function as an argument to a jitted function. The argument is itself a jitted function with a static argument bound using jax.tree_util.Partial. This is similar to #1443. Am I missing something obvious? Is this intended behaviour? If not, is there a work-around?

from functools import partial
import jax
import jax.numpy as jnp

@jax.jit
def f(x, fn):
    return fn(fn(x))

@partial(jax.jit, static_argnums=(0,))
def fn_with_static_arg(p, x):
    return jnp.tile(x, (p,))

@jax.jit
def fn_simple(p, x):
    return x ** p

x = jnp.arange(3)
p = 2
print(f(x, jax.tree_util.Partial(fn_simple, p)))  # works fine
print(f(x, jax.tree_util.Partial(fn_with_static_arg, p)))  # ValueError

Error:

  File ".../jax/api_util.py", line 101, in argnums_partial_except
    "Non-hashable static arguments are not supported, as this can lead "
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses.
Static argument (index 0) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function
fn_with_static_arg is non-hashable.
jakevdp commented 3 years ago

Agreed this is strange... I'd expect to see an error in both cases! I'm not sure why the first one works, but you can fix both by marking the function argument as static:

@partial(jax.jit, static_argnums=(1,))
def f(x, fn):
    return fn(fn(x))
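
For example (a sketch reusing the definitions above), both calls then go through. One caveat, if I'm reading jit's caching right: static arguments are keyed by hash, and Partial (like functools.partial) hashes by object identity, so each freshly constructed wrapper triggers a recompile:

print(f(x, jax.tree_util.Partial(fn_simple, p)))            # works
print(f(x, jax.tree_util.Partial(fn_with_static_arg, p)))   # now also works

fn = jax.tree_util.Partial(fn_with_static_arg, p)
f(x, fn)  # compiles
f(x, fn)  # cache hit: same object
f(x, jax.tree_util.Partial(fn_with_static_arg, p))  # new object, so it recompiles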
jackd commented 3 years ago

@jakevdp thanks for the quick response. What if the function takes both static and non-static arguments? My actual use case is sparse matrix implementations; a simplified version is below. I suppose this might become easier once your sparse-support PR lands with proper primitives, but in the meantime, is there any way to achieve this?

from typing import Callable
from functools import partial
import jax
import jax.numpy as jnp

@jax.jit
def power_iteration(A_fun: Callable, x0, iterations):
    def cond_fun(state):
        return state[2] < iterations

    def body_fun(state):
        value, vec, it = state
        vec = A_fun(vec)
        value = jnp.linalg.norm(vec)
        vec /= value
        return value, vec, it + 1

    value = jnp.linalg.norm(x0)
    vec = x0 / value
    value, vec, _ = jax.lax.while_loop(cond_fun, body_fun, (value, vec, 0))
    return value, vec

@jax.jit
def dense_matvec(A, x):
    return A @ x

@partial(jax.jit, static_argnums=(3,))
def coo_matvec(data, row, col, nrows, v):
    assert v.ndim == 1
    dv = data * v[col]
    return jnp.zeros(nrows, dtype=dv.dtype).at[row].add(dv)

n = 16
iters = 100
sparsity = 0.1
dtype = jnp.float32
key = jax.random.PRNGKey(0)
vals_key, mask_key, x0_key = jax.random.split(key, 3)
x0 = jax.random.normal(x0_key, shape=(n,), dtype=dtype)
a = jax.random.normal(vals_key, shape=(n, n), dtype=dtype)
mask = jax.random.uniform(mask_key, shape=(n, n), dtype=dtype) < sparsity
# strengthen diagonal so eigvals are more real
a = a + n * jnp.eye(n, dtype=dtype)
mask = jax.ops.index_update(
    mask, jax.ops.index[jnp.arange(n), jnp.arange(n)], jnp.ones((n,), dtype=bool)
)
# get coo data
row, col = jnp.where(mask)
data = a[row, col]

# create masked a
a = jax.ops.index_update(jnp.zeros((n, n), dtype=dtype), jax.ops.index[row, col], data)

w, v = jax.jit(jnp.linalg.eig, backend="cpu")(a)
wi = jnp.argmax(jnp.abs(w))
true_value = w[wi]
true_vec = v[:, wi]
print("True:")
print(true_value)
print(true_vec)

# our dense implementation
dense_fun = jax.tree_util.Partial(dense_matvec, a)
dense_value, dense_vec = power_iteration(dense_fun, x0, iters)
print("Dense:")
print(dense_value)
print(dense_vec)

# our coo implementation
coo_fun = jax.tree_util.Partial(coo_matvec, data, row, col, n)
coo_value, coo_vec = power_iteration(coo_fun, x0, iters)
print("COO:")
print(coo_value)
print(coo_vec)
jakevdp commented 3 years ago

I think my fix works there as well. Change the first function definition to this:

@partial(jax.jit, static_argnums=(0,))
def power_iteration(A_fun: Callable, x0, iterations):

In other words, a callable passed to a jitted function should always be marked static in that jitted function. I'm surprised it would ever work otherwise.
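
If I'm not mistaken, with that change both variants from the script above then run; the trade-off is that the arrays bound inside the static Partial get baked into the compiled computation as constants, so a freshly constructed Partial (or new bound values) means a retrace:

dense_value, dense_vec = power_iteration(jax.tree_util.Partial(dense_matvec, a), x0, iters)
coo_value, coo_vec = power_iteration(jax.tree_util.Partial(coo_matvec, data, row, col, n), x0, iters)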

jakevdp commented 3 years ago

Digging a bit, I think I see what's going on here. With jax.tree_util.Partial, bound arguments become part of the pytree, and so they are traced in a jitted context: https://github.com/google/jax/blob/cd4138b83d75c20b287345c434017b09c99a9cc6/jax/tree_util.py#L301-L315

We could fix the issue you're seeing by leaving any static arguments out of the pytree produced by Partial; that would make things like this more consistent.
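
One way to see the current behaviour (a small sketch): the bound arguments show up as leaves of the Partial's pytree, which is exactly what jit traces when the Partial is passed as a non-static argument:

import jax
import jax.numpy as jnp

part = jax.tree_util.Partial(jnp.add, jnp.arange(3))
print(jax.tree_util.tree_leaves(part))  # the bound array appears as a pytree leaf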

jackd commented 3 years ago

Hmm... I would have thought you could set things up such that the outer function doesn't need to be recompiled so long as the passed function doesn't need to be. For example, changing the values in a shouldn't require a recompile in either the sparse or dense case, though it should if the shape, dtype, or (in the sparse case) nnz changed.
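
For what it's worth, the dense case already behaves this way when the Partial is left non-static, under the original @jax.jit definition of power_iteration (a sketch reusing the script above); only the leaf values change, so the cached compilation should be reused:

f1 = jax.tree_util.Partial(dense_matvec, a)
f2 = jax.tree_util.Partial(dense_matvec, a + 1.0)  # same shape and dtype, new values
power_iteration(f1, x0, iters)
power_iteration(f2, x0, iters)  # same pytree structure, so no retrace (as far as I can tell)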

jakevdp commented 3 years ago

Yeah - I spent some time looking into this. JAX's current method of tracking static arguments makes it difficult to do things at this level of granularity. If you pass a jax.tree_util.Partial function as a non-static argument, all arguments to that function will be traced, and I don't think there's any way around this currently.
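
A quick sketch of what that looks like: inside a jitted function, the Partial's bound arguments arrive as tracers, which is why they can't be forwarded as static arguments:

import jax
import jax.numpy as jnp

@jax.jit
def g(fn):
    print(type(fn.args[0]))  # a tracer at trace time, hence non-hashable
    return fn(jnp.ones(3))

g(jax.tree_util.Partial(jnp.add, jnp.arange(3.0)))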

jackd commented 3 years ago

@jakevdp thanks for your investigation. I've got a work-around based on passing in a vector of size nrows with arbitrary values (the `sized` argument in the code below). It's dirty, but it works for the moment...

@jax.jit
def coo_matvec(data, row, col, sized, v):
    # `sized` is a dummy array whose values are ignored; only its size is used,
    # and shapes are static under jit, so nrows no longer needs to be a static argument.
    assert v.ndim == 1
    dv = data * v[col]
    return jnp.zeros(sized.size, dtype=dv.dtype).at[row].add(dv)
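
Usage is then the same as before, except a dummy array stands in for nrows (a sketch under the same setup as above):

sized = jnp.zeros(n)  # values are irrelevant; only its size matters
coo_fun = jax.tree_util.Partial(coo_matvec, data, row, col, sized)
coo_value, coo_vec = power_iteration(coo_fun, x0, iters)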