ailzhang opened 2 years ago
Thanks for writing this up! My initial thought was triggered by a question someone asked: can they use Taichi's GPU data with numpy without incurring a D2H copy? If we can interoperate with either numba or JAX, we can provide users with GPU numpy for free. Not sure whether numba or JAX is more similar to numpy, though :-)
@k-ye Yeah, sharing CPU data with numpy (or GPU data with numba) is possible since we have the physical pointer address https://github.com/taichi-dev/taichi/blob/master/taichi/program/program.cpp#L567 for the CPU and CUDA backends. Implementing the above array interfaces (or dlpack) should do the trick.
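For the CPU side, a minimal sketch of what that wrapping could look like. None of this is existing Taichi API: `ptr`, `shape` and `dtype` are placeholders for whatever the runtime actually exposes.

```python
import numpy as np

def wrap_cpu_pointer(ptr: int, shape, dtype=np.float32):
    """Create a zero-copy numpy view over an existing CPU buffer at `ptr`.

    Nothing here extends the buffer's lifetime, so the caller must keep the
    owning Taichi object alive for as long as the returned array is used.
    """
    class _Holder:
        pass

    holder = _Holder()
    holder.__array_interface__ = {        # the numpy array interface protocol
        'shape': tuple(shape),
        'typestr': np.dtype(dtype).str,   # e.g. '<f4'
        'data': (ptr, False),             # (address, read-only flag)
        'version': 3,
    }
    return np.asarray(holder)             # view, not a copy; base is `holder`
```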
Side note: currently our to_numpy is a deep copy by default; note that torch's numpy() shares the underlying storage if the source and target devices are both CPU. We could probably add a to_numpy(..., copy=False) so that users can control whether it shares the underlying storage or not, and the implementation shouldn't be hard.
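To make the proposal concrete, here is roughly what the user-facing side could look like. The copy= keyword is the hypothetical addition discussed above, not something Taichi has today:

```python
import taichi as ti

ti.init(arch=ti.cpu)
x = ti.field(ti.f32, shape=4)
x.fill(1.0)

a = x.to_numpy()                # today: always a deep copy
a[0] = 42.0                     # does not affect x

# Proposed (hypothetical, not implemented): skip the copy on CPU, mirroring
# torch.Tensor.numpy(), which shares storage when both sides live on the CPU.
# b = x.to_numpy(copy=False)    # would alias the field's storage
# b[0] = 42.0                   # ...so this write would be visible from Taichi
```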
I wonder how the lifetime ownership problem is resolved in this case?
@k-ye I believe the ownership is shared in this case: when you call numpy(), PyTorch actually creates a new Python tensor from the storage, and numpy steals its reference and sets it as the array's base. https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_numpy.cpp#L160-L165 In other words, creating a numpy array out of a tensor's existing storage should increase the reference count of the tensor.
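A quick way to see the sharing from plain PyTorch, nothing Taichi-specific:

```python
import torch

t = torch.arange(4, dtype=torch.float32)   # CPU tensor
a = t.numpy()                              # zero-copy view of the same storage

a[0] = 99.0
print(t[0].item())    # 99.0 -- both sides see the write

# tensor_numpy.cpp sets the tensor as the numpy array's base, which is what
# keeps the storage alive even if the Python-side tensor goes out of scope:
print(a.base)         # the originating torch.Tensor
```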
I see, I wasn't aware that torch used the CPython-layer API.
Hi @ailzhang @k-ye
I just wanted to ask if there is any update on this issue? I didn't find any information on dlpack in the Taichi docs. In particular I'm interested in using Taichi with JAX. JAX has already implemented dlpack support. Do you maybe plan to implement support for JAX arrays via taichi.ndarrays or dlpack?
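For reference, this is the DLPack round trip JAX already supports today, shown with torch on the other end since Taichi doesn't export DLPack capsules yet (newer JAX versions also accept objects implementing __dlpack__ directly):

```python
import jax.dlpack
import torch
import torch.utils.dlpack

t = torch.arange(4, dtype=torch.float32)   # use device='cuda' for the GPU case

# torch -> JAX without a copy: export a DLPack capsule, then import it.
x = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))

# JAX -> torch works the same way in the other direction.
y = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))
```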
Are there any updates on this?
By common array interfaces, I mean numpy's __array_interface__ and the __cuda_array_interface__ protocol used by numba, CuPy, etc.
Note that an alternative is DLPack (https://dmlc.github.io/dlpack/latest/python_spec.html#syntax-for-data-interchange-with-dlpack), and we already have an issue for it: https://github.com/taichi-dev/taichi/issues/4534. Tbh I haven't explored the pros and cons of these two interfaces myself, but it's something to consider before implementing. Reference: https://data-apis.org/array-api/latest/design_topics/data_interchange.html
Implementation-wise it shouldn't be too hard; one reference is https://github.com/pytorch/pytorch/pull/11984/files.
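For the __cuda_array_interface__ route, the dict itself is small. A sketch of what a CUDA-backed Taichi ndarray could expose; the ptr/shape/dtype fields are stand-ins for whatever the runtime actually provides, not existing API:

```python
import numpy as np

class CudaNdarrayView:
    """Sketch only: the dict a CUDA-backed Taichi ndarray would need to expose."""

    def __init__(self, ptr: int, shape, dtype=np.float32):
        # Hypothetical inputs coming from the Taichi runtime.
        self._ptr, self._shape, self._dtype = ptr, tuple(shape), np.dtype(dtype)

    @property
    def __cuda_array_interface__(self):
        return {
            'shape': self._shape,
            'typestr': self._dtype.str,  # e.g. '<f4'
            'data': (self._ptr, False),  # (device pointer, read-only flag)
            'strides': None,             # None => C-contiguous
            'version': 3,
        }
```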
I believe jax and pytorch don't support importing from __cuda_array_interface__ yet. In other words, if you create a torch tensor/jax DeviceArray and then use its __cuda_array_interface__ in Taichi, that's totally fine. But if you create a Taichi ndarray and want to use its __cuda_array_interface__ in torch/jax, I believe that's not yet supported. Numba does support that tho. https://github.com/google/jax/issues/1100
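For completeness, the direction that already works with numba today: any object exposing __cuda_array_interface__ (a torch CUDA tensor here) can be viewed without a copy.

```python
import torch
from numba import cuda

t = torch.arange(4, dtype=torch.float32, device='cuda')

# torch tensors expose __cuda_array_interface__, so numba can view them
# without copying anything:
d = cuda.as_cuda_array(t)

t += 1                          # mutate through torch...
print(d.copy_to_host())         # ...and the numba view sees it: [1. 2. 3. 4.]
```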
Also, once we support these interfaces in Taichi, we should use them when numpy/torch/paddle tensors are passed as external arrays to Taichi kernels, to clean things up.
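A hypothetical sketch of that unification; the helper name and return shape are illustrative, not existing Taichi internals:

```python
def extract_buffer(arr):
    """Pull (pointer, shape, typestr) out of any external array via the
    standard interfaces instead of per-framework special cases."""
    if hasattr(arr, '__cuda_array_interface__'):      # torch/cupy/numba on GPU
        iface = arr.__cuda_array_interface__
    elif hasattr(arr, '__array_interface__'):         # numpy and friends on CPU
        iface = arr.__array_interface__
    else:
        raise TypeError(f"unsupported external array type: {type(arr)}")
    return iface['data'][0], iface['shape'], iface['typestr']
```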
cc: @k-ye