EdwardRaff opened this issue 2 years ago.
Hi @EdwardRaff,
Thanks for your interest in the library!
To be honest, I'm not too optimistic about a fully-fledged interface. We'd like to provide JAX bindings, but as far as I can tell, the JAX documentation still doesn't include clear guidance for external contributions at the C++ level. I have only found this unofficial repo that relies on an undocumented internal API. Without a strong commitment from the JAX devs to a stable API for C++ extensions, it is hard to prioritize the development of an interface that could break without any warning.
The PyTorch environment is much more welcoming for community projects like KeOps. My understanding is that whereas PyTorch is designed as an open research framework, JAX is still very much designed as an internal Google tool. In the short term, our priority is to support the common array API. This may be enough to provide some compatibility with JAX and TensorFlow?
Needless to say, if you (or anyone else!) have thoughts on this topic, we'd be interested to hear about them :-) Best regards, Jean
OpenAI's Triton language (a language and compiler for writing highly optimized CUDA kernels) has now been hooked into JAX: https://github.com/jax-ml/jax-triton Perhaps this could provide some insight into how to achieve this?
JAX now also has a built-in callback mechanism: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
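As a rough illustration of the callback route, here is a minimal sketch that uses `jax.pure_callback` to call a host-side NumPy function from inside a jitted JAX function. The function `gaussian_kernel_sum_numpy` is a hypothetical stand-in for a KeOps reduction: it computes a_i = Σ_j exp(-|x_i - y_j|²) b_j with dense NumPy rather than KeOps, and all names here are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp


def gaussian_kernel_sum_numpy(x, y, b):
    # Hypothetical host-side stand-in for a KeOps reduction, written in dense NumPy:
    # a_i = sum_j exp(-|x_i - y_j|^2) * b_j
    x, y, b = np.asarray(x), np.asarray(y), np.asarray(b)
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # (M, N)
    return (np.exp(-sq_dists) @ b).astype(x.dtype)            # (M, E)


@jax.jit
def gaussian_kernel_sum(x, y, b):
    # Declare the output shape/dtype so JAX can trace through the opaque callback.
    out_spec = jax.ShapeDtypeStruct((x.shape[0], b.shape[1]), x.dtype)
    return jax.pure_callback(gaussian_kernel_sum_numpy, out_spec, x, y, b)
```

Note that `pure_callback` on its own defines no derivative rule, so `jax.grad` cannot differentiate through this wrapper, and the inputs are transferred to the host for every call.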
Hi @adam-hartshorne ,
Thanks a lot! I think the callback mechanism is exactly what we need to provide a minimal wrapper. We will look into it and keep you updated!
Best regards, Jean
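Since KeOps is mostly used inside gradient-based pipelines, a minimal wrapper would likely need to pair the callback with a custom derivative rule. The sketch below assumes a Gaussian kernel reduction a = K b with K_ij = exp(-|x_i - y_j|²), and uses `jax.custom_vjp` with a second `pure_callback` for the backward pass. The NumPy functions `_fwd_np` and `_bwd_np` are hypothetical placeholders for the corresponding KeOps reductions, written in dense NumPy here; this is not KeOps's actual API.

```python
import numpy as np
import jax
import jax.numpy as jnp


def _fwd_np(x, y, b):
    # Hypothetical host-side forward reduction: a = K @ b, K_ij = exp(-|x_i - y_j|^2).
    x, y, b = np.asarray(x), np.asarray(y), np.asarray(b)
    K = np.exp(-((x[:, None, :] - y[None, :, :]) ** 2).sum(-1))  # (M, N)
    return (K @ b).astype(x.dtype)                               # (M, E)


def _bwd_np(x, y, b, g):
    # Hypothetical host-side vector-Jacobian product for the same reduction.
    x, y, b, g = (np.asarray(t) for t in (x, y, b, g))
    diff = x[:, None, :] - y[None, :, :]        # (M, N, D)
    K = np.exp(-(diff ** 2).sum(-1))            # (M, N)
    C = (g @ b.T) * K                           # C_ij = <g_i, b_j> * K_ij
    dx = -2.0 * (C[:, :, None] * diff).sum(1)   # (M, D)
    dy = 2.0 * (C[:, :, None] * diff).sum(0)    # (N, D)
    db = K.T @ g                                # (N, E)
    return tuple(t.astype(x.dtype) for t in (dx, dy, db))


@jax.custom_vjp
def kernel_sum(x, y, b):
    spec = jax.ShapeDtypeStruct((x.shape[0], b.shape[1]), x.dtype)
    return jax.pure_callback(_fwd_np, spec, x, y, b)


def kernel_sum_fwd(x, y, b):
    # Save the inputs as residuals for the backward callback.
    return kernel_sum(x, y, b), (x, y, b)


def kernel_sum_bwd(res, g):
    x, y, b = res
    specs = (jax.ShapeDtypeStruct(x.shape, x.dtype),
             jax.ShapeDtypeStruct(y.shape, y.dtype),
             jax.ShapeDtypeStruct(b.shape, b.dtype))
    # One cotangent per primal input, all computed on the host.
    return jax.pure_callback(_bwd_np, specs, x, y, b, g)


kernel_sum.defvjp(kernel_sum_fwd, kernel_sum_bwd)
```

The backward pass is itself a kernel reduction (for example, the gradient with respect to `b` is Kᵀg), which is why, in this sketch, it can go through the same callback mechanism as the forward pass.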
There is also this page in the docs on lower-level interactions (essentially an official take on the "extending jax" repo mentioned above): https://jax--14158.org.readthedocs.build/en/14158/Custom_Operation_for_GPUs.html
This appears to be the up-to-date way of connecting CUDA to JAX.
Toy package containing boilerplate for writing custom CUDA kernels for JAX. https://github.com/brentyi/jax_cuda_boilerplate/tree/master
FYI, this might provide the basis of a solution for supporting KeOps in JAX.
Love KeOps! Just wondering whether y'all are planning JAX support, given the 2.0 release.