hawkinsp closed this issue 8 months ago
I think even just supporting one of the directions (i.e. making DeviceArray implement this interface on GPU) would already be a great addition.
I would be happy to help, but I am not sure where to find the pointer to GPU memory / what else to pay attention to.
TensorFlow now supports dlpack: https://github.com/VoVAllen/tf-dlpack/issues/3
PR #2133 added __cuda_array_interface__ export. You'll need a jaxlib built from GitHub head, or you'll need to wait for us to make another jaxlib wheel release.
Because of https://github.com/pytorch/pytorch/issues/32868 you can't directly import the resulting arrays into PyTorch. But because of https://github.com/cupy/cupy/issues/2616 you can "launder" the array via CuPy into PyTorch if you want.
(Another option for interoperability is DLPack, which JAX supports at Github head, in both directions.)
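For concreteness, the laundering path into PyTorch might look something like the following. This is an untested sketch: it assumes a GPU-enabled jaxlib that exports __cuda_array_interface__, and the exact CuPy/PyTorch helper names (toDlpack, from_dlpack) vary by version.
import cupy
import jax.numpy as jnp
from torch.utils.dlpack import from_dlpack

x = jnp.arange(4.0)             # JAX array on GPU (with a CUDA backend)
cx = cupy.asarray(x)            # CuPy consumes __cuda_array_interface__ (zero-copy view)
tx = from_dlpack(cx.toDlpack()) # CuPy re-exports as DLPack, PyTorch imports it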
Could this be reopened until import support is added as well?
I don't follow. We support both directions, I believe?
Edit: I am wrong, apparently we don't support imports.
Yeah, this need came up again recently (cc @leofang @quasiben).
Although note that DLPack imports should work, so that's an option if the exporter supports DLPack.
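For example, importing a PyTorch CUDA tensor via DLPack might look like this (a sketch only; API names are roughly those of that era of JAX and PyTorch):
import torch
from torch.utils.dlpack import to_dlpack
import jax.dlpack

t = torch.arange(4, device="cuda")        # PyTorch CUDA tensor
x = jax.dlpack.from_dlpack(to_dlpack(t))  # JAX consumes the DLPack capsule, zero-copy on the same device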
Thanks John! Yeah we just finished a GPU Hackathon, and a few of our teams evaluating JAX asked us why JAX can't work with other libraries like CuPy and PyTorch bidirectionally. It'd be very useful, say, to do autograd in JAX, postprocess in CuPy, then bring it back to JAX.
Also: I haven't tried this, but since CuPy supports both __cuda_array_interface__ and DLPack, you can most likely "launder" an array into JAX via CuPy: __cuda_array_interface__ to CuPy, then DLPack from CuPy into JAX. (Obviously this isn't ideal, but it might unblock you.)
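Something along these lines might work (an untested sketch; it assumes a GPU-enabled jaxlib, uses a Numba device array as the non-DLPack exporter, and the toDlpack helper name varies by CuPy version):
import cupy
import numba.cuda
import numpy as np
import jax.dlpack

a = numba.cuda.to_device(np.arange(4.0))  # exporter with __cuda_array_interface__ but no DLPack
ca = cupy.asarray(a)                      # CuPy wraps it zero-copy via the interface
x = jax.dlpack.from_dlpack(ca.toDlpack()) # CuPy re-exports as DLPack, JAX imports it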
Hi @hawkinsp, I recently pm'd @apaszke on an occasion where this support was mentioned. It'd be nice if JAX could prioritize bi-directional support for the CUDA Array Interface (and update to the latest v3 protocol, in which the synchronization semantics are specified).
As you pointed out in a DLPack issue (https://github.com/dmlc/dlpack/issues/50), DLPack lacks support for complex numbers, and that is unlikely to be resolved in the foreseeable future. For array libraries this is simply not an acceptable workaround, and it is actually a blocker for several applications that I am aware of.
Thanks, and happy new year!
Hi!
I was wondering if there is any update on this. Thanks!
Miguel
I'm very interested in this too
@hawkinsp this is seeing internal traction. Given how some of the JAX internals have evolved (e.g. arrays, shmaps, etc., in the context of MGMN), is there work being done here?
I would be interested in this.
Just FYI, a new Python package, pydlpack, has been released that supports bidirectional data exchange among many array providers and consumers.
Here is an example of zero-copy data exchange between jax and torch:
>>> import jax, torch
>>> from dlpack import asdlpack
>>> a1 = jax.numpy.array([[1, 2], [3, 4]])
>>> t1 = torch.from_dlpack(asdlpack(a1))
>>> t1
tensor([[1, 2],
[3, 4]], device='cuda:0', dtype=torch.int32)
>>> t2 = torch.tensor([[5, 6], [7, 8]]).cuda()
>>> a2 = jax.numpy.from_dlpack(asdlpack(t2))
>>> a2
Array([[5, 6],
[7, 8]], dtype=int32)
Another example is exchanging CUDA buffers between jax and numba:
>>> from dlpack import asdlpack
>>> import numba.cuda, numpy, jax
>>> a = numba.cuda.to_device(numpy.array([[1, 2], [3, 4]]))
>>> arr = jax.numpy.from_dlpack(asdlpack(a))
>>> arr
Array([[1, 2],
[3, 4]], dtype=int32)
https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html
It would not be hard to make DeviceArray implement this interface on GPU. It would be slightly harder to support wrapping a DeviceArray around an existing CUDA array, but not that hard.
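For reference, the exporter side of the protocol is essentially one property returning a small dict. The sketch below is purely illustrative (not JAX's actual implementation); the pointer, shape, and typestr fields stand in for whatever the underlying device buffer actually provides, and the key names follow the CUDA Array Interface spec (v2; v3 adds an optional "stream" key for synchronization).
class GpuBufferView:
    """Illustrative only: a wrapper exposing the CUDA Array Interface."""
    def __init__(self, device_ptr, shape, typestr):
        self._ptr = device_ptr    # integer CUDA device pointer
        self._shape = shape       # e.g. (2, 2)
        self._typestr = typestr   # e.g. "<f4" for little-endian float32

    @property
    def __cuda_array_interface__(self):
        return {
            "shape": self._shape,
            "typestr": self._typestr,
            "data": (self._ptr, True),  # (pointer, read_only flag)
            "strides": None,            # None means C-contiguous
            "version": 2,
        }
A consumer such as cupy.asarray or numba.cuda.as_cuda_array would then wrap that memory zero-copy.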