getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License

JAX support on the horizon? #254

Open · EdwardRaff opened this issue 2 years ago

EdwardRaff commented 2 years ago

Love KeOps! Just wondering if y'all are planning JAX support, given the 2.0 release.

jeanfeydy commented 2 years ago

Hi @EdwardRaff,

Thanks for your interest in the library!

To be honest, I'm not too optimistic about a fully-fledged interface. We'd like to provide JAX bindings, but as far as I can tell, the JAX documentation still doesn't include clear documentation for external contributions at the C++ level. I have only found this unofficial repo, which relies on an undocumented internal API. Without a strong commitment from the JAX developers to a stable API for C++ extensions, it is hard to prioritize the development of an interface that could break without any warning.

The PyTorch environment is much more welcoming for community projects like KeOps. My understanding is that whereas PyTorch is designed as an open research framework, JAX is still very much designed as an internal Google tool. In the short term, our priority is to support the common array API. This may be enough to provide some compatibility with JAX and TensorFlow?

Needless to say, if you (or anyone else!) have thoughts on this topic, we'd be interested to hear them :-)

Best regards, Jean
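To make the "common array API" idea a little more concrete, here is a minimal sketch (not KeOps code) of a Gaussian kernel sum written only with functions that NumPy and `jax.numpy` share, with the array namespace passed in explicitly as `xp`. The name `gaussian_conv` and the explicit `xp` argument are illustrative assumptions. Note that this dense version materializes the full M-by-N kernel matrix, which is exactly the memory cost KeOps is designed to avoid.

```python
import numpy as np


def gaussian_conv(x, y, b, sigma, xp):
    """Dense Gaussian kernel sum: a_i = sum_j exp(-|x_i - y_j|^2 / (2 sigma^2)) b_j.

    `xp` is the array namespace (e.g. `numpy` or `jax.numpy`); only functions
    available in both are used, so the same code can run on either backend.
    """
    # (M, N) squared distances, built by broadcasting (M, 1, D) - (1, N, D)
    d2 = xp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    # Dense (M, N) kernel matrix times (N, E) weights gives an (M, E) result
    return xp.exp(-d2 / (2 * sigma**2)) @ b
```

With NumPy arrays, call it as `gaussian_conv(x, y, b, 1.0, np)`; passing JAX arrays together with `xp=jax.numpy` runs the same code under JAX.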

adam-hartshorne commented 1 year ago

OpenAI's Triton language (a language and compiler for writing highly optimized CUDA kernels) has now been hooked into JAX: https://github.com/jax-ml/jax-triton Perhaps this could provide some insight into how to achieve this?

JAX now also has built-in support for external callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
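As a rough illustration of the callback route, the sketch below wraps a host-side NumPy function with `jax.pure_callback` so that it can be called from jitted JAX code. The function `gaussian_conv_numpy` is a hypothetical stand-in for a KeOps reduction (it computes a Gaussian kernel sum densely); a real wrapper would call KeOps at that point instead. Because JAX cannot trace through the callback, `jax.ShapeDtypeStruct` declares the shape and dtype of the result up front.

```python
import numpy as np
import jax
import jax.numpy as jnp


def gaussian_conv_numpy(x, y, b, sigma):
    """Hypothetical host-side stand-in for a KeOps reduction:
    a_i = sum_j exp(-|x_i - y_j|^2 / (2 sigma^2)) b_j, computed densely with NumPy."""
    x, y, b = np.asarray(x), np.asarray(y), np.asarray(b)
    d2 = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # (M, N) squared distances
    k = np.exp(-d2 / (2 * sigma**2))                      # (M, N) kernel matrix
    return (k @ b).astype(b.dtype)                        # (M, E), dtype must match the spec


def gaussian_conv(x, y, b, sigma=1.0):
    # Declare the output shape/dtype, since JAX cannot infer it from the callback
    out_spec = jax.ShapeDtypeStruct((x.shape[0], b.shape[1]), b.dtype)
    return jax.pure_callback(
        lambda x_, y_, b_: gaussian_conv_numpy(x_, y_, b_, sigma),
        out_spec, x, y, b,
    )
```

Usage: `jax.jit(gaussian_conv)(x, y, b)` with JAX arrays of shapes (M, D), (N, D) and (N, E). Since the callback runs on the host, GPU inputs are copied to the CPU and back on every call, and `pure_callback` is not differentiable on its own.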

jeanfeydy commented 1 year ago

Hi @adam-hartshorne,

Thanks a lot! I think the callback mechanism is exactly what we need to provide a minimal wrapper. We will look into it and keep you updated!

Best regards, Jean
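A minimal wrapper would also need gradients, since a plain `pure_callback` cannot be differentiated. One possible approach, sketched here under the same assumptions (a dense NumPy Gaussian kernel standing in for KeOps, and a fixed, hypothetical bandwidth `SIGMA`), is to register a `jax.custom_vjp` whose forward and backward passes are both host callbacks. The backward pass uses the closed-form gradients of a_i = sum_j k_ij b_j with k_ij = exp(-|x_i - y_j|^2 / (2 sigma^2)); each gradient is itself a sum over one index, i.e. the kind of reduction KeOps computes.

```python
import numpy as np
import jax
import jax.numpy as jnp

SIGMA = 1.0  # hypothetical fixed bandwidth, kept as a Python constant for simplicity


def _kernel(x, y):
    # Dense (M, N) Gaussian kernel matrix k_ij = exp(-|x_i - y_j|^2 / (2 SIGMA^2))
    d2 = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2 * SIGMA**2))


def _fwd_host(x, y, b):
    # Host-side forward pass (stand-in for a KeOps call): a = K b
    x, y, b = np.asarray(x), np.asarray(y), np.asarray(b)
    return (_kernel(x, y) @ b).astype(b.dtype)


def _bwd_host(x, y, b, g):
    # Host-side backward pass; g = dL/da has shape (M, E)
    x, y, b, g = map(np.asarray, (x, y, b, g))
    k = _kernel(x, y)                     # (M, N)
    w = (g @ b.T) * k / SIGMA**2          # (M, N): (g_i . b_j) k_ij / sigma^2
    diff = x[:, None, :] - y[None, :, :]  # (M, N, D): x_i - y_j
    dx = -(w[:, :, None] * diff).sum(1)   # dL/dx_i, sum over j
    dy = (w[:, :, None] * diff).sum(0)    # dL/dy_j, sum over i
    db = k.T @ g                          # dL/db_j, sum over i
    return dx.astype(x.dtype), dy.astype(y.dtype), db.astype(b.dtype)


@jax.custom_vjp
def gaussian_conv(x, y, b):
    spec = jax.ShapeDtypeStruct((x.shape[0], b.shape[1]), b.dtype)
    return jax.pure_callback(_fwd_host, spec, x, y, b)


def _fwd(x, y, b):
    # Save the inputs as residuals for the backward pass
    return gaussian_conv(x, y, b), (x, y, b)


def _bwd(res, g):
    x, y, b = res
    specs = tuple(jax.ShapeDtypeStruct(a.shape, a.dtype) for a in (x, y, b))
    return jax.pure_callback(_bwd_host, specs, x, y, b, g)


gaussian_conv.defvjp(_fwd, _bwd)
```

With this in place, `jax.grad` works through `gaussian_conv` as with any other JAX function, while the heavy lifting stays in the (hypothetical) host-side routines.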

adam-hartshorne commented 1 year ago

There is also this page in the docs on lower-level interoperation (essentially an official take on the "extending jax" repo mentioned above): https://jax--14158.org.readthedocs.build/en/14158/Custom_Operation_for_GPUs.html

adam-hartshorne commented 1 year ago

This appears to be the up-to-date way of connecting CUDA to JAX.

Toy package containing boilerplate for writing custom CUDA kernels for JAX. https://github.com/brentyi/jax_cuda_boilerplate/tree/master

adam-hartshorne commented 1 year ago

FYI, this might provide the basis for a solution to support KeOps in JAX:

https://github.com/rdyro/torch2jax