jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Provide output buffer for pure_callback result #20701

Open Joshuaalbert opened 6 months ago

Joshuaalbert commented 6 months ago

Is there any value in supplying a pure_callback (one with a return value) with a buffer into which to place its output? Something like the code below. I have some science applications where I need to use bindings to underlying C++ code that require an output buffer into which to place the results. I currently need to create the array inside the callback function, but I suspect that JAX might already have a pre-allocated buffer waiting for the result anyway, and could provide that to the function.

def _fn(x, _output_buffer):
    # Instead of _output_buffer=np.zeros_like(x)
    np.square(x, out=_output_buffer)
    return _output_buffer

result_shape_dtype = jax.ShapeDtypeStruct(
    shape=np.shape(x),
    dtype=x.dtype
)

return jax.pure_callback(_fn, result_shape_dtype, x, vectorized=True)

Joshuaalbert commented 6 months ago

Also, is it safe to reuse the buffer of input values, e.g. like

def _fn(x):
    np.square(x, out=x)
    return x

result_shape_dtype = jax.ShapeDtypeStruct(
    shape=np.shape(x),
    dtype=x.dtype
)

return jax.pure_callback(_fn, result_shape_dtype, x, vectorized=True)

EDIT: No, that is not safe. It fails with jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: ValueError: output array is read-only
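
For reference, the workaround described so far (allocating the result inside the callback and leaving the input untouched) could look roughly like the sketch below. The names _square_host and square_via_callback are made up for illustration, and the np.asarray call is a defensive assumption in case the callback argument is not already a NumPy array in a given JAX version.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _square_host(x):
    # Treat the argument as read-only: writing into it fails with the
    # "output array is read-only" error reported above.
    x = np.asarray(x)
    out = np.empty_like(x)  # fresh, writable host buffer for the result
    np.square(x, out=out)
    return out


@jax.jit
def square_via_callback(x):
    result_shape_dtype = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    return jax.pure_callback(_square_host, result_shape_dtype, x)


y = square_via_callback(jnp.arange(4.0))
```

This still allocates a second buffer on the host, which is exactly the cost the issue is asking to avoid.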

superbobry commented 6 months ago

There is a preallocated buffer on the XLA side, but it is not currently passed to the pure_callback.

@hawkinsp should we have an API for this, wdyt?

Joshuaalbert commented 6 months ago

Also, the main reason this would be helpful is that the output buffers for the science applications I'm working on are really large, so if XLA has already allocated one, using it would save a lot of memory.

Joshuaalbert commented 6 months ago

@superbobry @hawkinsp any update available for this?

superbobry commented 6 months ago

Hey @Joshuaalbert, sorry for the silence. Here is a quick update

superbobry commented 5 months ago

Another quick update: I prototyped dlpack.callback, but after discussing it with a few JAX team members, I decided not to move forward with it as JAX has too many callback APIs already.

Instead, the plan is to change existing callback APIs to support mutable_results=. I am waiting on a few changes in XLA FFI, but once they land, it should be fairly straightforward to implement this.

Joshuaalbert commented 4 months ago

Hi @superbobry, any news?

superbobry commented 4 months ago

No news yet, in the sense that none of my attempts have landed. Hopefully in the coming weeks :)

Joshuaalbert commented 4 months ago

@superbobry can you check out JAXbind, which offers a way to specify both JVP and VJP for external callbacks? I'd argue that it should also be possible to specify both within JAX.

dfm commented 4 months ago

@Joshuaalbert — JAX doesn't currently offer a public API for this (customizing multiple transforms for a single callable), but it's on our radar. I'd say that this is off topic for this issue thread, but feel free to open another with more info about your use cases for using both JVP and VJP for one callback!

superbobry commented 3 months ago

Re original topic: @yashk2810 suggested passing the output buffer as an argument and donating it to the callback, i.e.

jax.jit(partial(jax.pure_callback, ...), donate_argnums=...)

this will ensure that the runtime doesn't allocate a separate buffer for the output.

Unfortunately, though, there isn't a great way to create a mutable view of a jax.Array. I was hoping we could do jax.Array -> DLPack -> NumPy, but NumPy imports DLPack capsules as read-only.
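
A minimal sketch of the donation pattern suggested above, under stated assumptions: the function names are hypothetical, and whether the runtime actually reuses the donated buffer for the result depends on the backend and JAX version. It also shows the limitation just mentioned, since the callback cannot write into out and still has to produce its own result array.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def _square_host(x, out):
    # `out` is not writable from here (see the read-only error earlier in
    # the thread), so this sketch ignores it and returns a new array.
    return np.square(np.asarray(x))


@partial(jax.jit, donate_argnums=(1,))
def square_with_donated_output(x, out):
    # Donating `out` lets the runtime reuse its buffer for the result,
    # if the backend supports donation.
    result_shape_dtype = jax.ShapeDtypeStruct(shape=out.shape, dtype=out.dtype)
    return jax.pure_callback(_square_host, result_shape_dtype, x, out)


x = jnp.arange(4.0)
out = jnp.zeros_like(x)
y = square_with_donated_output(x, out)  # do not use `out` after this call
```

On backends where the donation cannot be used, JAX may emit a warning and fall back to a regular allocation.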

Joshuaalbert commented 3 months ago

Would donating the input avoid double allocating memory? I can imagine if it's something like: invoke pure callback > free input > allocate inside the pure callback > use that output array in JAX without reallocation. We're memory limited in this use case, ducc.wgridder.