chaoming0625 opened this issue 2 years ago
@chaoming0625 I'm by no means a JAX expert, so my guess could be wrong. IIUC, JAX device arrays don't expose a raw pointer to their underlying storage the way PyTorch tensors do, which makes a torch-like (zero-copy) integration between JAX and Taichi kinda hard. If you have to copy the device array out of JAX anyway, could copying it into NumPy arrays or torch tensors, which Taichi can operate on pretty efficiently, be a possible workaround?
Note that Taichi's sparse computation requires a specific data layout (depending on your SNode structure) in a root buffer managed by Taichi, so dense NumPy arrays/torch tensors are still the recommended way to exchange data with other libraries for those sparse fields.
Dear @ailzhang, one way to interoperate JAX data with Taichi is to use DLPack:
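```python
import jax.dlpack
import torch
import torch.utils.dlpack


def j2t(x_jax):
    # Export the JAX array as a DLPack capsule and wrap it as a torch.Tensor (no copy).
    x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
    return x_torch


def t2j(x_torch):
    # Make the tensor contiguous (dense, row-major) before handing it to JAX.
    x_torch = x_torch.contiguous()
    x_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x_torch))
    return x_jax
```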
This gives a zero-copy conversion from JAX arrays to PyTorch tensors. The PyTorch tensors can then be passed to Taichi kernels. Finally, the tensors returned from the Taichi kernel can also be converted back to JAX without a copy.
I think this may be one possible solution.
We are just wondering where we can get the address of the compiled Taichi kernels. Thanks.
@chaoming0625 Sounds good! Taichi ndarrays are just contiguous memory, so it should be pretty straightforward to support the DLPack format (although it doesn't yet). Compiled Taichi kernels are handled here: https://github.com/taichi-dev/taichi/blob/master/python/taichi/lang/kernel_impl.py#L574.
Dear @ailzhang, that's wonderful. Thanks very much!
Hi @ailzhang,
I just wanted to ask whether there is any update on this issue, or an alternative to @chaoming0625's solution. Do you plan to support JAX arrays via taichi.ndarray, as was done for PyTorch?
Also curious about this, since I'd like to use some packages written for JAX (NumPyro specifically) and try out the Taichi AD system.
As further motivation, I would love to be able to tap into these JAX projects with Taichi:
See examples in https://github.com/brainpy/BrainPy/pull/553
We have already seen some examples that use Taichi as part of a PyTorch program. For example,
However, is it possible to integrate Taichi into JAX?
Taichi can generate highly optimized operators, and it is well suited to implementing operators that involve sparse computation. If Taichi kernels could be used in a JAX program, that would interest a broad range of programmers.
I think the key to the integration is the address of the compiled kernel in Taichi. There are examples that launch a GPU kernel compiled by Triton from JAX, so perhaps the same approach would be straightforward for Taichi too.