jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Integrating TVM into JAX #10277

Open chaoming0625 opened 2 years ago

chaoming0625 commented 2 years ago


TVM can generate highly optimized operators. Is it possible to integrate operators generated by TVM into JAX's JIT system?

hawkinsp commented 2 years ago

Do you have a particular use case in mind?

In general: sure, we could either use TVM in place of XLA, or we could use TVM to generate individual kernels inside an otherwise XLA-compiled program. For the first approach JAX has some early support for plugging in alternate compilers and runtimes, and for the second case, it's possible to mix XLA-compiled code and other code via mechanisms like dlpack and XLA CustomCall operators. But it would probably be a reasonably large project.
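To illustrate the second approach (non-XLA code running inside an otherwise XLA-compiled program), here is a minimal sketch. It uses `jax.pure_callback`, which is a host-callback mechanism and not the XLA CustomCall route mentioned above. A CustomCall needs a compiled native handler registered with XLA, so it is not shown. The function `external_scale_shift` is hypothetical: it is a NumPy stand-in for a kernel that another toolchain (such as TVM) might produce.

```python
import numpy as np
import jax
import jax.numpy as jnp


def external_scale_shift(x):
    """Hypothetical stand-in for a kernel compiled by another toolchain (e.g. TVM).

    It is plain NumPy running on the host. JAX passes it host arrays and
    expects arrays of the declared shape and dtype back.
    """
    x = np.asarray(x)
    return (2.0 * x + 1.0).astype(x.dtype)


@jax.jit
def model(x):
    # XLA-compiled work before the external call.
    h = jnp.sin(x)
    # XLA must know the callback's output shape and dtype at trace time.
    out_spec = jax.ShapeDtypeStruct(h.shape, h.dtype)
    # Hand `h` to the non-XLA kernel and get its result back as a traced value.
    y = jax.pure_callback(external_scale_shift, out_spec, h)
    # XLA-compiled work after the external call.
    return jnp.sum(y)


x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)
print(model(x))
```

This approach has two costs. The callback runs on the host, so device data makes a round trip through host memory. It is also not differentiable unless you wrap it with `jax.custom_jvp` or `jax.custom_vjp`. A TVM integration that targets accelerators directly would likely avoid the round trip by using a CustomCall instead. For exchanging buffers outside `jit`, `jax.Array` implements the DLPack protocol (`__dlpack__`), so DLPack-aware runtimes can consume JAX arrays.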

I suspect this is probably in the "contributions welcome" category, but I'd be interested to know if there is a particular program or use case that motivates the question.

BDHU commented 2 years ago

My understanding is that TVM supports more hardware accelerators than XLA does. It would be interesting to run compiled JAX programs on something like an FPGA.

kaiyang-code commented 7 months ago

Hi, I would like to work on this. However, I'm new to JAX. Could you give me some guidance? Thanks!