Open Joshuaalbert opened 6 months ago
Also, is it safe to reuse the buffer of input values, e.g. like
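import jax
import numpy as np

def square_in_place(x):  # wrapper name added here so the snippet is complete
    def _fn(x):
        # Try to overwrite the input buffer with the result.
        np.square(x, out=x)
        return x

    result_shape_dtype = jax.ShapeDtypeStruct(
        shape=np.shape(x),
        dtype=x.dtype
    )
    return jax.pure_callback(_fn, result_shape_dtype, x, vectorized=True)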
EDIT: That would be a nope as it gives jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: ValueError: output array is read-only
There is a preallocated buffer on the XLA side, but it is not currently passed to the pure_callback.
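For reference, the read-only error above goes away if the callback writes into a freshly allocated array instead of its input. This is only a minimal sketch of the current workaround, and it pays exactly the extra allocation this issue is trying to avoid; the names square_via_callback and x_host are made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def square_via_callback(x):
    def _fn(x_host):
        # The callback argument is read-only, so allocate a separate
        # host buffer and write the result there.
        x_host = np.asarray(x_host)
        out = np.empty_like(x_host)
        np.square(x_host, out=out)
        return out

    result_shape_dtype = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    return jax.pure_callback(_fn, result_shape_dtype, x)


x = jnp.arange(4.0)
y = square_via_callback(x)
```

The input x is left untouched; the result lives in the array created inside the callback.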
@hawkinsp should we have an API for this, wdyt?
Also, the main reason this would be helpful is that the output buffers for the science applications I'm working on are really large, so if XLA has already allocated one, using it would save a lot of memory.
@superbobry @hawkinsp any update available for this?
Hey @Joshuaalbert, sorry for the silence. Here is a quick update:
- jax.Arrays which are currently immutable (and thus cannot be used for output buffers);
- dlpack.callback, a new callback API using DLPack capsules for both inputs and outputs.
Another quick update: I prototyped dlpack.callback, but after discussing it with a few JAX team members, I decided not to move forward with it, as JAX has too many callback APIs already.
Instead, the plan is to change existing callback APIs to support mutable_results=. I am waiting on a few changes in XLA FFI, but once they land, it should be fairly straightforward to implement this.
Hi @superbobry, any news?
No news yet, in the sense that none of my attempts have landed. Hopefully, in the coming weeks :)
@superbobry can you check out JAXbind, which offers a way to specify both JVP and VJP for external callbacks? I would argue that it should also be possible to specify both within JAX itself.
@Joshuaalbert — JAX doesn't currently offer a public API for this (customizing multiple transforms for a single callable), but it's on our radar. I'd say that this is off topic for this issue thread, but feel free to open another with more info about your use cases for using both JVP and VJP for one callback!
Re original topic: @yashk2810 suggested passing the output buffer as an argument and donating it to the callback, i.e.
jax.jit(partial(jax.pure_callback, ...), donate_argnums=...)
this will ensure that the runtime doesn't allocate a separate buffer for the output.
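A rough sketch of what that donation pattern could look like, filling in the elided arguments with illustrative choices. The names _fn, square_with_out and the choice donate_argnums=(1,) are assumptions, not something stated in this thread. Whether the runtime really reuses the donated buffer for the callback result depends on the backend and JAX version (JAX may warn that a donated buffer was not usable), and the callback still cannot write into out from Python.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _fn(x, out):
    # `out` is only passed so that it can be donated; it arrives
    # read-only here, so the result is still computed into a new array.
    return np.square(np.asarray(x))


def square_with_out(x, out):
    result_shape_dtype = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    return jax.pure_callback(_fn, result_shape_dtype, x, out)


# Donate the second argument (the would-be output buffer) to the computation.
square_donated = jax.jit(square_with_out, donate_argnums=(1,))

x = jnp.arange(4.0)
out = jnp.zeros_like(x)
result = square_donated(x, out)
# `out` should not be used after this call, since it was donated.
```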
Unfortunately, though, there isn't a great way for creating a mutable view of a jax.Array. I was hoping we can do jax.Array -> DLPack -> NumPy, but NumPy imports DLPack capsules as readonly.
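The jax.Array -> DLPack -> NumPy route can be inspected with a few lines. This is a minimal sketch assuming a CPU backend; it prints the writeable flag rather than assuming its value, since that may differ across NumPy and JAX versions. The variable names are illustrative.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(4.0)

# Import the jax.Array into NumPy through DLPack (CPU backend assumed).
view = np.from_dlpack(x)
print("writeable:", view.flags.writeable)

# An explicit copy is always writable, but it costs a second allocation,
# which is what a mutable view was meant to avoid.
buf = np.array(view, copy=True)
np.square(buf, out=buf)
```

Writing into buf leaves x untouched because buf owns its own memory.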
Would donating the input avoid allocating the memory twice? I can imagine it working if the sequence is something like: invoke pure callback > free input > allocate inside the pure callback > use that output array in JAX without reallocation. We're memory-limited in this use case (ducc.wgridder).
Is there any value in supplying a pure_callback (one with a return value) with a buffer into which to place its output? Something like the code below. I have some science applications where I need to use bindings to underlying C++ code that require an output buffer into which to place the results. I currently need to create the array inside the callback function, but I suspect that JAX might already have a pre-allocated buffer waiting for the result anyway, and could provide that to the function.