NIFTy-PPL / JAXbind

Bind any function written in another language to JAX with support for JVP/VJP/batching/jit compilation
BSD 2-Clause "Simplified" License

JAX's external callbacks #31

Open Joshuaalbert opened 4 months ago

Joshuaalbert commented 4 months ago

This package is very similar to the external callback functionality that JAX provides, which is still being improved.

I also urge you to show support for https://github.com/google/jax/issues/20701 to get JAX to provide output buffers.
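For context, here is a minimal sketch of how JAX's external callback API works today, assuming a recent JAX version. The names `_host_sin` and `external_sin` are hypothetical, and `np.sin` stands in for foreign code. The host callback has to allocate and return its own result array, matching the declared `jax.ShapeDtypeStruct`. It cannot write into a buffer that XLA has already allocated, which is what the linked issue asks for.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_sin(x):
    # Runs on the host with a NumPy array (stand-in for foreign code).
    # The callback allocates and returns a fresh array; pure_callback offers
    # no way to write into an output buffer provided by XLA.
    return np.sin(np.asarray(x))


def external_sin(x):
    # Declare the shape/dtype of the result so the call can be staged under jit.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_sin, out_spec, x)


x = jnp.linspace(0.0, 1.0, 5, dtype=jnp.float32)
y = jax.jit(external_sin)(x)
```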

mreineck commented 4 months ago

If I understand the discussion in https://github.com/NIFTy-PPL/JAXbind/issues/10 correctly, JAXbind's functionality goes beyond what JAX's external callbacks provide.

Being able to specify your own output buffers is certainly an attractive feature; if it can be made to work with JAXbind, I hope we can support it soon. For the moment, however, the focus is on finishing the code paper.

Edenhofer commented 4 months ago

JAXbind already passes on the pre-allocated buffer from XLA to the user. IMO this is a non-issue here. Please see the paper for a discussion of the differences between JAXbind and JAX's external callback API (most importantly the support for customizing both JVP and VJP in JAXbind).

I agree that https://github.com/google/jax/issues/20701 would be neat to have.

Joshuaalbert commented 4 months ago

JAX's external callbacks also allow specifying a custom JVP and VJP via jax.custom_jvp/jax.custom_vjp. See here for an example of specifying a JVP.
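A minimal sketch of that pattern, assuming a recent JAX version: `jax.pure_callback` evaluates a host function (NumPy's `sin`/`cos` stand in for foreign code, and the helper names are hypothetical), and `jax.custom_jvp` supplies its derivative. Because the callback inside the JVP rule only touches the primal `x`, and the tangent enters through an ordinary multiplication, JAX can still transpose the rule, so `jax.grad` works as well as `jax.jvp`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_sin(x):
    return np.sin(np.asarray(x))  # stand-in for a foreign forward function


def _host_cos(x):
    return np.cos(np.asarray(x))  # stand-in for a foreign derivative routine


def _call(host_fn, x):
    # Hypothetical helper: run a host function whose output matches x.
    return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)


@jax.custom_jvp
def external_sin(x):
    return _call(_host_sin, x)


@external_sin.defjvp
def external_sin_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = external_sin(x)
    # d/dx sin(x) = cos(x). The callback sees only the primal x; the tangent t
    # enters via a plain multiply, which JAX knows how to transpose.
    return y, _call(_host_cos, x) * t


x = jnp.array([0.0, 0.5, 1.0], dtype=jnp.float32)
t = jnp.ones_like(x)
y, y_dot = jax.jvp(external_sin, (x,), (t,))          # forward mode
g = jax.grad(lambda v: external_sin(v).sum())(x)      # reverse mode via transposition
```

This only works because the tangent never passes through the callback; if the foreign code itself had to be applied to `t`, the transposition step would no longer be possible.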

Joshuaalbert commented 4 months ago

Ah, I see, so it's related to allowing both JVP and VJP https://github.com/NIFTy-PPL/JAXbind/issues/10#issuecomment-2027414580. Nice. My only comment, then, is that it would be better to follow the API of JAX's external callbacks as closely as possible, since there is then a better chance they'll incorporate it into JAX's source, which I would argue should be the goal of JAXbind. Otherwise, there are just competing offerings and redundant efforts.

Edenhofer commented 4 months ago

We reached out to the JAX developers about this in https://github.com/google/jax/issues/19179.

I am not sure whether a unified interface is easily possible, given that JAXbind needs to combine pure_callback, custom_jvp and custom_vjp in one. I can imagine a future where JAXbind will simply be superseded by tools in JAX itself (e.g. if custom transpositions were added and issues like the one you raised were solved). However, we are not there yet; otherwise, we wouldn't have felt the need to write JAXbind.
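To make the limitation concrete: in plain JAX a function carries either a `custom_jvp` or a `custom_vjp` rule, not both. The sketch below (hypothetical names; a small NumPy matrix `A` stands in for a foreign linear operator) needs the foreign code to act on cotangents, so it uses `jax.custom_vjp`. Reverse mode then works, but forward mode is rejected. The mirror-image approach, a `custom_jvp` whose rule applies a `pure_callback` to the tangent, fails under `jax.grad` instead, because JAX cannot transpose a `pure_callback`. Supplying both rules for one function is the gap JAXbind fills, as described in this thread.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for foreign code: a linear operator applied on the host.
A = np.array([[2.0, 1.0], [0.0, 3.0]], dtype=np.float32)


def _host_apply(v):
    return A @ np.asarray(v)


def _host_apply_T(v):
    return A.T @ np.asarray(v)


def _call(host_fn, v):
    return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(v.shape, v.dtype), v)


@jax.custom_vjp
def apply_op(x):
    return _call(_host_apply, x)


def apply_op_fwd(x):
    return apply_op(x), None  # a linear map needs no residuals


def apply_op_bwd(_, ct):
    # The VJP of x -> A x is ct -> A^T ct, again computed by "foreign" code.
    return (_call(_host_apply_T, ct),)


apply_op.defvjp(apply_op_fwd, apply_op_bwd)

x = jnp.array([1.0, 2.0], dtype=jnp.float32)
y = apply_op(x)                                   # A @ x
g = jax.grad(lambda v: apply_op(v).sum())(x)      # A^T @ [1, 1]

try:
    # JAX refuses forward-mode autodiff of a custom_vjp function.
    jax.jvp(apply_op, (x,), (jnp.ones_like(x),))
    forward_ok = True
except Exception:
    forward_ok = False
```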