taichi-dev / taichi

Productive, portable, and performant GPU programming in Python.
https://taichi-lang.org
Apache License 2.0

Fusing Taichi with JAX #6367

Open chaoming0625 opened 2 years ago

chaoming0625 commented 2 years ago

We have already seen some examples that use Taichi as part of a PyTorch program. For example,

However, is it possible to integrate Taichi into JAX?

Taichi is able to generate highly optimized operators and is particularly well suited to implementing operators that involve sparse computation. If Taichi kernels could be used inside a JAX program, that would be of interest to a broad range of programmers.

I think the key to the integration is getting hold of the address of the compiled Taichi kernel. There are already examples that launch a GPU kernel compiled by Triton from within JAX; perhaps the same approach would be straightforward for Taichi too.

ailzhang commented 2 years ago

@chaoming0625 I'm by no means a JAX expert, so my guess could be wrong. IIUC, JAX device arrays don't expose a raw pointer to their underlying storage the way PyTorch tensors do, which makes a torch-like zero-copy integration between Taichi and JAX rather hard. If you have to copy the device array out of JAX anyway, copying it into a numpy array or a torch tensor that Taichi can operate on fairly efficiently could be a possible workaround.

Note that Taichi's sparse computation requires a specific data layout (depending on your SNode structure) in a root buffer managed by Taichi; dense numpy arrays/torch tensors are still the recommended way to exchange data for those sparse fields with other libraries.
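A minimal sketch of that copy-based workaround, assuming a JAX version that provides jax.pure_callback and a Taichi kernel that takes numpy arrays as ndarray arguments (the scale kernel below is purely illustrative):

import numpy as np
import jax
import jax.numpy as jnp
import taichi as ti

ti.init(arch=ti.cpu)

@ti.kernel
def scale(x: ti.types.ndarray(), y: ti.types.ndarray(), k: ti.f32):
  # element-wise y = k * x on Taichi ndarray arguments
  for i in range(x.shape[0]):
    y[i] = k * x[i]

def scale_host(x_np):
  # host callback: JAX hands us a numpy array, Taichi fills the output buffer
  x_np = np.asarray(x_np, dtype=np.float32)
  y_np = np.empty_like(x_np)
  scale(x_np, y_np, 2.0)
  return y_np

@jax.jit
def f(x):
  # pure_callback lets the (non-JAX) Taichi kernel run inside a jitted function
  return jax.pure_callback(scale_host, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(f(jnp.arange(4, dtype=jnp.float32)))  # [0. 2. 4. 6.]

This goes through host memory, so it is not zero-copy, but it keeps the JAX side functional and works under jit; a zero-copy path would need something like the DLPack route discussed below.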

chaoming0625 commented 2 years ago

Dear @ailzhang, one way to interoperate between JAX data and Taichi is to use DLPack:

import jax.dlpack
import torch
import torch.utils.dlpack

def j2t(x_jax):
  # zero-copy: JAX array -> DLPack capsule -> torch tensor
  x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
  return x_torch

def t2j(x_torch):
  # make sure the tensor is contiguous before exporting it
  x_torch = x_torch.contiguous()
  # zero-copy: torch tensor -> DLPack capsule -> JAX array
  x_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x_torch))
  return x_jax

This gives a zero-copy conversion from JAX data to a PyTorch tensor. The PyTorch tensor can then be used in Taichi kernels, and the tensors returned from the Taichi kernel can likewise be zero-copied back to JAX (see the sketch below).
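A rough round-trip example built on the j2t/t2j helpers above (the double kernel is hypothetical, and the Taichi backend has to match the device the JAX arrays live on, e.g. ti.cuda for GPU arrays):

import jax.numpy as jnp
import taichi as ti

ti.init(arch=ti.cpu)  # use ti.cuda if the JAX arrays live on the GPU

@ti.kernel
def double(x: ti.types.ndarray(), y: ti.types.ndarray()):
  for i in range(x.shape[0]):
    y[i] = 2.0 * x[i]

x_jax = jnp.arange(8, dtype=jnp.float32)
x_torch = j2t(x_jax)                  # zero-copy JAX -> torch
y_torch = torch.empty_like(x_torch)   # output buffer on the same device
double(x_torch, y_torch)              # Taichi operates on the torch tensors directly
y_jax = t2j(y_torch)                  # zero-copy torch -> JAX

Whether the GPU round trip is truly zero-copy depends on the JAX/PyTorch/Taichi versions and on all three agreeing on device and stream, so it is worth validating on the target setup.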

I think this may be one possible solution.

chaoming0625 commented 2 years ago

We are just wondering where we can get the address of the compiled Taichi kernels. Thanks.

ailzhang commented 2 years ago

@chaoming0625 Sounds good! Taichi ndarrays are just contiguous memory, so supporting the DLPack format should be pretty straightforward (although it isn't supported yet). Compiled Taichi kernels live here: https://github.com/taichi-dev/taichi/blob/master/python/taichi/lang/kernel_impl.py#L574.
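Until DLPack support lands, a Taichi ndarray can still be exchanged with other libraries through an explicit copy; a minimal sketch assuming the current ti.ndarray / from_numpy / to_numpy API:

import numpy as np
import taichi as ti

ti.init(arch=ti.cpu)

x = ti.ndarray(dtype=ti.f32, shape=(8,))      # contiguous buffer managed by Taichi
x.from_numpy(np.arange(8, dtype=np.float32))  # copy in
x_np = x.to_numpy()                           # copy out; DLPack would make this zero-copy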

chaoming0625 commented 2 years ago

Dear @ailzhang , that's wonderful. Thanks very much!

salykova commented 1 year ago

Hi @ailzhang

I just wanted to ask whether there is any update on this issue or an alternative to @chaoming0625's solution. Do you plan to implement support for JAX arrays via taichi.ndarray, as was done for PyTorch?

maedoc commented 1 year ago

Also curious about this, since I'd like to use some packages written for JAX (numpyro specifically) and try out the Taichi autodiff system.

jarmitage commented 11 months ago

As further motivation, I would love to be able to tap into these JAX projects with Taichi:

chaoming0625 commented 10 months ago

See examples in https://github.com/brainpy/BrainPy/pull/553