Joshuaalbert opened this issue 4 months ago.
If I understand the discussion in https://github.com/NIFTy-PPL/JAXbind/issues/10 correctly, `JAXbind`'s functionality goes beyond what JAX's external callbacks provide.
Being able to specify your own output buffers is certainly an attractive feature; if it can be made to work with `JAXbind`, I hope we can support this soon. For the moment, however, the focus is on finishing the code paper.
JAXbind already passes the pre-allocated buffer from XLA on to the user, so IMO this is a non-issue here. Please see the paper for a discussion of the differences between JAXbind and JAX's external callback API (most importantly, JAXbind's support for customizing both the JVP and the VJP).
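For contrast, here is a minimal sketch of the plain `jax.pure_callback` route, assuming NumPy as the host library. The callback is not handed an output buffer, so a buffer-writing routine has to allocate its own output and return it; that extra allocation is what JAX-provided output buffers would avoid. `host_sin_into` is a hypothetical stand-in for any external routine with an `out=`-style argument; it is not JAXbind's API.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin_into(out, x):
    # Hypothetical buffer-writing host routine: fills `out` in place.
    # This is the calling convention that pre-allocated output buffers enable.
    np.sin(x, out=out)


def host_sin_alloc(x):
    # jax.pure_callback does not pass an output buffer, so the callback
    # allocates one itself and returns it to JAX.
    out = np.empty(x.shape, dtype=x.dtype)
    host_sin_into(out, x)
    return out


def ext_sin(x):
    # result_shape_dtypes tells JAX the shape/dtype the callback will return.
    return jax.pure_callback(
        host_sin_alloc, jax.ShapeDtypeStruct(x.shape, x.dtype), x
    )


x = jnp.linspace(0.0, 1.0, 4)
print(jax.jit(ext_sin)(x))
```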
I agree that https://github.com/google/jax/issues/20701 would be neat to have.
JAX's external callbacks also allow specifying a custom JVP and VJP using `jax.custom_jvp`/`jax.custom_vjp`. See here for an example of specifying a JVP.
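A minimal sketch of that pattern, assuming NumPy stands in for the external library: both the primal value and its derivative come from host callbacks, while the tangent is applied with an ordinary JAX multiplication. Because the rule is linear in the tangent through plain JAX ops, JAX can also transpose it, so `jax.grad` works too. The names `ext_sin`, `_host_sin` and `_host_cos` are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_sin(x):
    # Primal computation done outside JAX (stand-in for an external library).
    return np.sin(x).astype(x.dtype)


def _host_cos(x):
    # Derivative of the primal, also computed outside JAX.
    return np.cos(x).astype(x.dtype)


def _spec(x):
    return jax.ShapeDtypeStruct(x.shape, x.dtype)


@jax.custom_jvp
def ext_sin(x):
    return jax.pure_callback(_host_sin, _spec(x), x)


@ext_sin.defjvp
def ext_sin_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = ext_sin(x)
    # The callback only sees primal values; the tangent enters through a
    # plain JAX multiply, which keeps the rule transposable for reverse mode.
    dydx = jax.pure_callback(_host_cos, _spec(x), x)
    return y, dydx * t


x = jnp.array([0.0, 0.5, 1.0], dtype=jnp.float32)
print(jax.jvp(ext_sin, (x,), (jnp.ones_like(x),)))
print(jax.grad(lambda v: ext_sin(v).sum())(x))
```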
Ah, I see, so it's about allowing both a JVP and a VJP (https://github.com/NIFTy-PPL/JAXbind/issues/10#issuecomment-2027414580). Nice. My only comment, then, is that it would be better to follow the API of JAX's external callbacks as much as possible, since that gives a better chance of it being incorporated into JAX's source, which I would argue should be the goal of JAXbind. Otherwise, there are just competing offerings and redundant efforts.
We reached out to the JAX developers about this in https://github.com/google/jax/issues/19179.
I am not sure whether a unified interface is easily possible, given that JAXbind needs to combine `pure_callback`, `custom_jvp`, and `custom_vjp` in one. I can imagine a future where JAXbind is simply superseded by tools in JAX itself (e.g. if custom transpositions were added and issues like the one you raised were solved). However, we are not there yet; otherwise we wouldn't have felt the need to write JAXbind.
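The complementary half can be sketched with `jax.custom_vjp`, again assuming NumPy as a stand-in for the external code: here a hypothetical host linear operator `_A` whose adjoint is also applied on the host. Reverse mode works, but a `custom_vjp` function cannot also carry a custom JVP rule, which illustrates why supplying both directions for one external routine takes more than either decorator alone.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical external linear operator A, applied on the host.
_A = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)


def _host_apply(x):
    # Forward application y = A @ x, computed outside JAX.
    return (_A @ x).astype(np.float32)


def _host_adjoint(ct):
    # Adjoint application A^T @ ct, also computed outside JAX.
    return (_A.T @ ct).astype(np.float32)


def _spec(x):
    return jax.ShapeDtypeStruct(x.shape, jnp.float32)


@jax.custom_vjp
def ext_linear(x):
    return jax.pure_callback(_host_apply, _spec(x), x)


def _ext_linear_fwd(x):
    # The operator is linear, so the backward pass needs no residuals.
    return ext_linear(x), None


def _ext_linear_bwd(_res, ct):
    # VJP of a linear map: apply the adjoint to the output cotangent.
    return (jax.pure_callback(_host_adjoint, _spec(ct), ct),)


ext_linear.defvjp(_ext_linear_fwd, _ext_linear_bwd)

x = jnp.array([1.0, -1.0], dtype=jnp.float32)
print(jax.grad(lambda v: ext_linear(v).sum())(x))  # reverse mode works
```

Calling `jax.jvp(ext_linear, ...)` on this function raises an error instead, because JAX does not apply forward-mode autodiff to `custom_vjp` functions.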
This package is very similar to JAX's external callback functionality, which JAX already provides and is actively improving.
I also urge you to show support for https://github.com/google/jax/issues/20701 to get JAX to provide output buffers.