jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Allow jax.pure_callback to infer result_shape_dtypes #16613

Open carlosgmartin opened 1 year ago

carlosgmartin commented 1 year ago

I was wondering if it would be possible to allow jax.pure_callback to infer result_shape_dtypes by calling the specified function once at trace time and using the structure, shapes, and dtypes of the result. This would save users the trouble of having to specify result_shape_dtypes manually. Possible implementation:

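import jax
import numpy as np

def to_numpy(tree):
    # Replace each leaf with a NumPy array of zeros of the same shape and dtype.
    return jax.tree_util.tree_map(lambda x: np.zeros(x.shape, x.dtype), tree)

def pure_callback_auto(f, *args, **kwargs):
    # Call f once at trace time on dummy inputs; its output stands in for result_shape_dtypes.
    tree = f(*to_numpy(args), **to_numpy(kwargs))
    return jax.pure_callback(f, tree, *args, **kwargs)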
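
For comparison, here is a minimal sketch of the manual specification this proposal aims to avoid. The callback name `host_sin` and the wrapper `f` are hypothetical, chosen only for illustration; the sketch assumes the output has the same shape and dtype as the input.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    # Hypothetical host-side callback: runs in NumPy, outside of JAX tracing.
    return np.sin(np.asarray(x))

@jax.jit
def f(x):
    # The caller must state the output shape and dtype up front.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, result_shape, x)

y = f(jnp.arange(3.0))
```

Here the caller has to know in advance that `host_sin` preserves shape and dtype, which is exactly the information an inferring wrapper would try to discover on its own.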

carlosgmartin commented 1 year ago

@jakevdp Would you like me to create a PR for this?

jakevdp commented 1 year ago

I'm interested to hear what @sharadmv thinks of the idea

carlosgmartin commented 7 months ago

@jakevdp Now that pure_callback and io_callback are switching from NumPy arrays to JAX arrays, this seems even easier to achieve: Just call jax.eval_shape(callback, *args, **kwargs) inside pure_callback and io_callback to infer the result_shape_dtypes automatically. Is my understanding correct?

jakevdp commented 7 months ago

@carlosgmartin I think that would work in the simplest cases where the callback functions are implemented in terms of JAX operations. But I think this will be a rare case: why do a callback if your operations are entirely expressible in terms of JAX operations? More frequently the callback will likely call out to some other non-JAX API, for which eval_shape will fail.

carlosgmartin commented 7 months ago

I've deleted my previous comment, since I was confused.

I think jax.eval_shape can work even if the function is not entirely expressible in terms of JAX operations. For example, consider timeit.default_timer:

import timeit
import jax

print(jax.eval_shape(timeit.default_timer))
# ShapeDtypeStruct(shape=(), dtype=float32)

This allows one to do things like the following:

import timeit
import jax
from jax.experimental import io_callback

def io_callback_auto(callback, *args, **kwargs):
    # Infer the result shape and dtype by abstractly evaluating the callback.
    shape = jax.eval_shape(callback, *args, **kwargs)
    return io_callback(callback, shape, *args, **kwargs)

def f(h, _):
    return h, io_callback_auto(timeit.default_timer)

_, xs = jax.lax.scan(f, None, None, length=10**6)

print(xs)
# [881338.3 881338.3 881338.3 ... 881340.  881340.  881340. ]

Could you give a concrete example where this wouldn't work, if you had one in mind?

jakevdp commented 7 months ago

I don’t disagree that some useful callbacks are compatible with eval_shape, I’m just pointing out that it’s likely the majority will not be. It’s not clear to me that it’s helpful to have a default that will fail in the majority of cases.
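
A minimal sketch of that failure mode, assuming a hypothetical callback `host_sort` that forces its argument into a concrete NumPy array (as code calling into a non-JAX library typically does). The function name and input values are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sort(x):
    # Hypothetical callback: np.asarray needs concrete values, which the
    # abstract tracer passed in by jax.eval_shape cannot provide.
    return np.sort(np.asarray(x))

x = jnp.array([3.0, 1.0, 2.0])

try:
    jax.eval_shape(host_sort, x)
    inferred = True
except jax.errors.TracerArrayConversionError:
    # Shape inference fails because the callback cannot be traced.
    inferred = False

# Supplying result_shape_dtypes explicitly avoids tracing the callback.
out = jax.pure_callback(host_sort, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
```

In this sketch the explicit call succeeds while inference does not, so an inferring default would presumably need a fallback to a user-supplied result_shape_dtypes.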