NIFTy-PPL / JAXbind

Bind any function written in another language to JAX with support for JVP/VJP/batching/jit compilation
BSD 2-Clause "Simplified" License

[JOSS 6532] Relationship to JAX's built-in "external callback" functionality #10

Closed: dfm closed this issue 7 months ago

dfm commented 7 months ago

If I understand the implementation correctly (the docs are not very clear about the implementation details), this library works by acquiring the Python GIL and executing a Python callback from within an XLA custom call. As far as I can tell, this is the same behavior offered by JAX's built-in external callback functionality, which can be combined with the jax.custom_jvp and/or jax.custom_vjp decorators to achieve the same results as JAXbind. I'd be interested to understand the difference between these two approaches, and a clear comparison should be added to the docs (and the JOSS paper).
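To make the comparison concrete, here is a minimal sketch of the external-callback route, assuming a recent JAX release that provides jax.pure_callback. The host function (np.sin) and the wrapper name ext_sin are hypothetical stand-ins for foreign code. The bare callback compiles under jax.jit, but JAX cannot differentiate through it until a custom derivative rule is attached.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Runs on the host in plain NumPy; a hypothetical stand-in for foreign code.
    x = np.asarray(x)
    return np.sin(x).astype(x.dtype)


def ext_sin(x):
    # Declare the output shape/dtype so JAX can trace through the callback.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_type, x)


x = jnp.linspace(0.0, 1.0, 5)
y = jax.jit(ext_sin)(x)  # works: the callback is embedded in the compiled program

try:
    jax.grad(lambda v: ext_sin(v).sum())(x)
    grad_failed = False
except (ValueError, NotImplementedError):
    # A bare pure_callback has no differentiation rule, so autodiff fails.
    # The exact exception type may differ between JAX versions.
    grad_failed = True
```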

As a smaller side note, it might also be interesting to compare to the Pallas submodule, which takes a slightly different approach to similar ends on GPUs and TPUs.

ref: https://github.com/openjournals/joss-reviews/issues/6532

Edenhofer commented 7 months ago

Good point! We should definitely compare to both! I'll add a discussion of the differences to the documentation and the paper!

Yes, we briefly need to acquire the GIL to call back into Python, so as to allow for a convenient Python user interface. For any reasonably expensive call, the time during which the C++ code holds the GIL will in all likelihood be negligible. Within the Python callback, the GIL will likely be released again very quickly, assuming that most Python code soon hands off to C/C++/Rust/... (e.g., by calling SciPy/NumPy).

The going-back-to-Python part is indeed very similar to JAX's host callback, which can be combined with jax.custom_jvp OR jax.custom_vjp (per the JAX docs: "Forward-mode autodiff cannot be used on the jax.custom_vjp function and will raise an error"). The "or" is exactly what sets JAXbind apart: it registers not only a custom JVP but also a custom VJP. This means that we can, e.g., access the Fisher metric of a model that contains a call bound with JAXbind without relying on JAX's transposition backend. For linear functions, it means that we can differentiate them arbitrarily many times and transpose them at will.
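The following sketch illustrates the "or", again assuming jax.pure_callback with hypothetical NumPy stand-ins (np.sin for the forward pass, np.cos for its derivative). Wrapping the callback with jax.custom_vjp makes reverse mode work, while forward mode through the same function raises an error, matching the JAX documentation quoted above.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host(fn):
    # Wrap a NumPy function as a shape/dtype-preserving host callback.
    def call(x):
        out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
        return jax.pure_callback(
            lambda a: fn(np.asarray(a)).astype(a.dtype), out_type, x)
    return call


host_sin = _host(np.sin)
host_cos = _host(np.cos)


@jax.custom_vjp
def ext_sin(x):
    return host_sin(x)


def ext_sin_fwd(x):
    # Primal output plus the residual (the input) needed by the backward pass.
    return host_sin(x), x


def ext_sin_bwd(x, g):
    # VJP of sin: cotangent times cos(x), again evaluated on the host.
    return (g * host_cos(x),)


ext_sin.defvjp(ext_sin_fwd, ext_sin_bwd)

x = jnp.linspace(0.0, 1.0, 5)
grad = jax.grad(lambda v: ext_sin(v).sum())(x)  # reverse mode works

try:
    jax.jvp(ext_sin, (x,), (jnp.ones_like(x),))
    jvp_failed = False
except TypeError:
    # JAX refuses forward-mode autodiff through a custom_vjp function.
    jvp_failed = True
```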

dfm commented 7 months ago

I'm not sure I totally follow the argument about why it is so crucial to define both a JVP and a VJP. Wouldn't the "linear" point also hold for an implementation using custom_jvp?

Regardless, I think it would be great for the docs to make a very clear case about this comparison!

Side note: as far as I can tell, it's not too hard to define a custom primitive that uses the external callback mechanism under the hood (e.g., lowering with mlir.lower_fun), so the compiled component of JAXbind could possibly be replaced by jax.pure_callback while keeping the same front end. This could reduce the maintenance burden.

Edenhofer commented 7 months ago

For the kind of optimization we usually do with our models (arXiv:1901.11033 and arXiv:2105.10470), we need both the JVP and the VJP. Together, they serve as a positive semidefinite estimator of the local "curvature", which we use in a second-order optimization.
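A minimal sketch of why both directions are needed, assuming a Gaussian likelihood with unit noise covariance so that the metric reduces to J^T J (a generic Gauss-Newton/Fisher-type product, not necessarily the exact metric used in the cited papers). The toy forward_model and the helper metric_vp are hypothetical. Applying the metric to a vector v takes a JVP to get J v and then a VJP to get J^T (J v); the result is positive semidefinite because v^T J^T J v = |J v|^2 >= 0.

```python
import jax
import jax.numpy as jnp


def forward_model(xi):
    # Hypothetical toy model mapping 2 latent parameters to 3 data points.
    return jnp.array([jnp.sin(xi[0]) * xi[1], xi[0] ** 2, jnp.exp(0.5 * xi[1])])


def metric_vp(f, xi, v):
    """Return J^T (J v), where J is the Jacobian of f at xi."""
    # Forward mode: push v through the linearized model (needs a JVP).
    _, jv = jax.jvp(f, (xi,), (v,))
    # Reverse mode: pull J v back to parameter space (needs a VJP).
    _, vjp_fn = jax.vjp(f, xi)
    (jtjv,) = vjp_fn(jv)
    return jtjv


xi = jnp.array([0.3, -1.2])
v = jnp.array([1.0, 2.0])
mv = metric_vp(forward_model, xi, v)
```

If forward_model contained a function wrapped with only one of jax.custom_jvp or jax.custom_vjp around a host callback, one of the two calls in metric_vp would fail.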

Unfortunately, custom_jvp is not enough. It lets you define a custom JVP, but you must then rely on JAX to derive the VJP for you automatically. For example, if you wrap the FFT via a host callback, you can define the JVP to be the FFT itself (since the FFT is linear), but you cannot define the VJP. JAX will try to derive it by transposing the JVP and will fail, because it cannot transpose a host callback. Custom transpositions in JAX would solve exactly this problem, but so far they have various limitations that make them unsuitable for the types of problems we want to solve (see, e.g., https://github.com/google/jax/issues/13298 and https://github.com/google/jax/issues/13283).
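A sketch of this failure mode, assuming jax.pure_callback as the host callback. To stay real-valued, it swaps the FFT for np.cumsum, another linear operator; the wrapper name ext_cumsum is hypothetical. The custom JVP simply reapplies the operator, so forward mode works, but jax.grad has to transpose that JVP and runs into the callback, which JAX cannot transpose.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_cumsum(a):
    # Host-side linear operator; a real-valued stand-in for the FFT.
    a = np.asarray(a)
    return np.cumsum(a).astype(a.dtype)


@jax.custom_jvp
def ext_cumsum(x):
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_cumsum, out_type, x)


@ext_cumsum.defjvp
def ext_cumsum_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # The operator is linear, so its JVP is the operator applied to the tangent.
    return ext_cumsum(x), ext_cumsum(t)


x = jnp.arange(4.0)
_, jvp_out = jax.jvp(ext_cumsum, (x,), (jnp.ones_like(x),))  # forward mode works

try:
    # Reverse mode needs the transpose of the JVP, i.e. of the host callback.
    jax.grad(lambda v: ext_cumsum(v).sum())(x)
    grad_failed = False
except (ValueError, NotImplementedError):
    # JAX cannot transpose the callback; the exception type may vary by version.
    grad_failed = True
```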

Very interesting idea! We started with C++ because we were wrapping individual C++ functions before generalizing the code. In hindsight, we could probably replace all of the C++ code with something along the lines of https://github.com/google/jax/discussions/17875 for jax.pure_callback. That said, we haven't settled internally whether we want to expose a version of JAXbind for C/C++ that is free of the Python GIL, which would require our custom C++ backend.

dfm commented 7 months ago

That all makes sense, thanks! With respect to the JOSS review, I'd just say that all of this should be discussed clearly in the docs. Otherwise, all sounds good.

dfm commented 6 months ago

Another follow-up on this. The paper has a paragraph that starts "To the best of our knowledge no other code currently exists for connecting generic functions to JAX. ..." I'd argue that this is no longer true given the discussion here. In particular, the JAX core library provides an interface for this! JAXbind does provide some extra features, but I think this should be clarified in the text.

roth-jakob commented 6 months ago

Thanks for the comment!

You may have looked at an old version of the paper; unfortunately, I forgot to update the PDF in the JOSS review issue. Most of the changes had already been incorporated into the paper. Additionally, I have now updated the summary section of the paper with PR #22.

Let me know if you believe the relation to the external callback functionality is still not clear in the paper.