Open carlosgmartin opened 1 year ago
@jakevdp Would you like me to create a PR for this?
I'm interested to hear what @sharadmv thinks of the idea
@jakevdp Now that pure_callback and io_callback are switching from NumPy arrays to JAX arrays, this seems even easier to achieve: just call jax.eval_shape(callback, *args, **kwargs) inside pure_callback and io_callback to infer the result_shape_dtypes automatically. Is my understanding correct?
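For illustration only, here is a minimal sketch of what that inference step would return, assuming a callback written entirely in terms of jax.numpy. The function `callback` and its outputs are made up for this example and are not from the thread. The snippet only traces the function abstractly with jax.eval_shape; it does not execute it as a host callback.

```python
import jax
import jax.numpy as jnp

def callback(x):
    # Hypothetical callback, written purely with jax.numpy so it can be traced.
    return {"sum": jnp.sum(x), "outer": jnp.outer(x, x)}

x = jnp.arange(3.0)

# Abstract evaluation: no real computation, only shapes and dtypes.
inferred = jax.eval_shape(callback, x)

# `inferred` is a pytree of jax.ShapeDtypeStruct mirroring the callback's output.
for name, struct in inferred.items():
    print(name, struct.shape, struct.dtype)
```

A pytree like `inferred` has the same form as the value users currently pass by hand as result_shape_dtypes.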
@carlosgmartin I think that would work in the simplest cases where the callback functions are implemented in terms of JAX operations. But I think this will be a rare case: why do a callback if your operations are entirely expressible in terms of JAX operations? More frequently the callback will likely call out to some other non-JAX API, for which eval_shape will fail.
I've deleted my previous comment, since I was confused.
I think jax.eval_shape can work even if the function is not entirely expressible in terms of JAX operations. For example, consider timeit.default_timer:
import timeit
import jax
print(jax.eval_shape(timeit.default_timer))
# ShapeDtypeStruct(shape=(), dtype=float32)
This allows one to do things like the following:
import timeit
import jax
from jax.experimental import io_callback

def io_callback_auto(callback, *args, **kwargs):
    # Infer the result shape/dtype by abstractly evaluating the callback.
    shape = jax.eval_shape(callback, *args, **kwargs)
    return io_callback(callback, shape, *args, **kwargs)

def f(h, _):
    return h, io_callback_auto(timeit.default_timer)

_, xs = jax.lax.scan(f, None, None, length=10**6)
print(xs)
# [881338.3 881338.3 881338.3 ... 881340. 881340. 881340. ]
Could you give a concrete example where this wouldn't work, if you had one in mind?
I don't disagree that some useful callbacks are compatible with eval_shape; I'm just pointing out that the majority likely will not be. It's not clear to me that it's helpful to have a default that will fail in the majority of cases.
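As a hedged illustration of the failure mode described above, here is a minimal sketch, assuming a callback that hands its argument to NumPy. The function `numpy_callback` is invented for this example. Under jax.eval_shape the argument is an abstract tracer, so converting it to a NumPy array is expected to raise an error (typically a TracerArrayConversionError), whereas passing the result shape explicitly works.

```python
import jax
import jax.numpy as jnp
import numpy as np

def numpy_callback(x):
    # Hypothetical callback: np.asarray needs concrete values,
    # which an abstract tracer cannot provide.
    return np.sort(np.asarray(x))

x = jnp.array([2.0, 0.0, 1.0])

# Attempting to infer the output shape by tracing the callback.
try:
    jax.eval_shape(numpy_callback, x)
    failed = False
except Exception as err:  # typically jax.errors.TracerArrayConversionError
    failed = True
    print("eval_shape failed with:", type(err).__name__)

# Specifying the result shape and dtype by hand avoids tracing the callback.
result = jax.pure_callback(
    numpy_callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x
)
print(result)
```

The callback itself is unchanged between the two calls; only the way the output shape is obtained differs.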
I was wondering if it would be possible to allow jax.pure_callback to infer result_shape_dtypes by calling the specified function once at trace time and using the result's treedef. This saves users the trouble of having to manually specify result_shape_dtypes. Possible implementation: