jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Support __cuda_array_interface__ on GPU #1100

Closed hawkinsp closed 8 months ago

hawkinsp commented 5 years ago

https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html

It would not be hard to make DeviceArray implement this interface on GPU.
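To illustrate what "implementing this interface" involves, here is a minimal sketch of a producer object that exposes a version-3 `__cuda_array_interface__` dictionary. The `FakeDeviceBuffer` class and its `ptr` argument are hypothetical stand-ins for a real device allocation (in JAX the pointer would come from the underlying device buffer). No GPU memory is touched, so the dictionary can be inspected on any machine.

```python
class FakeDeviceBuffer:
    """Hypothetical stand-in for a device array. `ptr` is an opaque integer
    playing the role of a CUDA device pointer; it is never dereferenced."""

    def __init__(self, ptr, shape, typestr, readonly=True):
        self.ptr = ptr
        self.shape = tuple(shape)
        self.typestr = typestr  # NumPy-style type string, e.g. "<f4"
        self.readonly = readonly

    @property
    def __cuda_array_interface__(self):
        # Keys follow version 3 of the CUDA Array Interface.
        return {
            "shape": self.shape,
            "typestr": self.typestr,
            "data": (self.ptr, self.readonly),  # (device pointer, read-only flag)
            "version": 3,
            "strides": None,  # None means C-contiguous
            "stream": None,  # None: the consumer need not synchronize
        }


buf = FakeDeviceBuffer(0x7F0000000000, (2, 3), "<f4")
print(buf.__cuda_array_interface__["data"])
```

A real exporter like JAX, which computes asynchronously, would normally report the stream its work was enqueued on instead of `None`.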

It would be slightly harder to support wrapping a DeviceArray around an existing CUDA array, but not that hard.
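On the import side, the consumer mostly has to read that dictionary and turn it into a pointer, shape, and byte strides it can wrap. The helper below is a hypothetical plain-Python sketch (`describe_cuda_array` and `c_contiguous_strides` are illustrative names, not JAX API). It only interprets the metadata; a real implementation would still have to hand the pointer to the runtime, keep the producer alive, and honor the `stream` field.

```python
def c_contiguous_strides(shape, itemsize):
    """Byte strides of a C-contiguous (row-major) array."""
    strides = []
    step = itemsize
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def describe_cuda_array(obj):
    """Read the CUDA Array Interface of `obj` and return what a consumer
    needs to wrap the memory without copying."""
    cai = obj.__cuda_array_interface__
    if cai.get("mask") is not None:
        raise NotImplementedError("masked arrays are not handled in this sketch")
    ptr, readonly = cai["data"]
    shape = tuple(cai["shape"])
    itemsize = int(cai["typestr"][2:])  # e.g. "<f4" -> 4 bytes, "|b1" -> 1 byte
    strides = cai.get("strides")
    if strides is None:  # None means C-contiguous
        strides = c_contiguous_strides(shape, itemsize)
    return {"ptr": ptr, "readonly": readonly, "shape": shape,
            "itemsize": itemsize, "strides": tuple(strides)}
```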

jonasrauber commented 5 years ago

I think even just supporting one of the directions (i.e. making DeviceArray implement this interface on GPU) would already be a great addition. I would be happy to help, but I am not sure where to find the pointer to the GPU memory or what else to pay attention to.

jonasrauber commented 4 years ago

TensorFlow now supports dlpack: https://github.com/VoVAllen/tf-dlpack/issues/3

hawkinsp commented 4 years ago

PR #2133 added __cuda_array_interface__ export. You'll need a jaxlib built from GitHub head or you'll need to wait for us to make another jaxlib wheel release.

Because of https://github.com/pytorch/pytorch/issues/32868 you can't directly import the resulting arrays into PyTorch. But because of https://github.com/cupy/cupy/issues/2616 you can "launder" the array through CuPy into PyTorch if you want.

(Another option for interoperability is DLPack, which JAX supports at GitHub head, in both directions.)
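In current JAX and NumPy releases, DLPack exchange goes through the `__dlpack__` protocol rather than explicit capsules (the 2020-era JAX API used `jax.dlpack.to_dlpack` and `jax.dlpack.from_dlpack` with capsules). A minimal CPU-only sketch of both directions, assuming NumPy >= 1.22 and a recent JAX:

```python
import numpy as np
import jax.numpy as jnp

# NumPy -> JAX: NumPy arrays implement __dlpack__, so jnp.from_dlpack can consume them.
x_np = np.arange(6, dtype=np.float32).reshape(2, 3)
x_jax = jnp.from_dlpack(x_np)

# JAX -> NumPy: JAX arrays implement __dlpack__ and __dlpack_device__ as well.
y_jax = x_jax * 2
y_np = np.from_dlpack(y_jax)
print(y_np)
```

On GPU the same calls apply with CuPy or PyTorch in place of NumPy, provided both sides are on the same device.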

kkraus14 commented 4 years ago

Could this be reopened until import support is added as well?

hawkinsp commented 4 years ago

I don't follow. We support both directions, I believe?

Edit: I am wrong, apparently we don't support imports.

jakirkham commented 4 years ago

> Edit: I am wrong, apparently we don't support imports.

Yeah, this need came up again recently (cc @leofang @quasiben).

hawkinsp commented 4 years ago

Although note that DLPack imports should work, so that's an option if the exporter supports DLPack.

leofang commented 4 years ago

Thanks John! Yeah we just finished a GPU Hackathon, and a few of our teams evaluating JAX asked us why JAX can't work with other libraries like CuPy and PyTorch bidirectionally. It'd be very useful, say, to do autograd in JAX, postprocess in CuPy, then bring it back to JAX.

hawkinsp commented 4 years ago

Also: I haven't tried this, but since CuPy supports both __cuda_array_interface__ and DLPack, you can most likely "launder" an array via CuPy into JAX (import it into CuPy via __cuda_array_interface__, then export it to JAX via DLPack).

(Obviously this isn't ideal, but it might unblock you.)
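The laundering route described in this comment could look roughly like the following. This is a hedged sketch, not the snippet from the original comment: it assumes CuPy and a CUDA-enabled JAX are installed, that the source object lives on the same GPU that JAX is using, and that both libraries are recent enough to implement the `__dlpack__` protocol. `cuda_array_to_jax` is a hypothetical name.

```python
def cuda_array_to_jax(obj):
    """Hypothetical helper: import any object exposing __cuda_array_interface__
    into JAX by viewing it as a CuPy array and handing that view to JAX via DLPack."""
    # Imported lazily so this module can be loaded on machines without CuPy or a GPU.
    import cupy
    import jax.numpy as jnp

    # cupy.asarray wraps the existing device memory described by
    # __cuda_array_interface__ instead of copying it.
    view = cupy.asarray(obj)
    # CuPy arrays implement __dlpack__, so JAX can import the view directly.
    return jnp.from_dlpack(view)
```

The result typically aliases the producer's memory, and JAX treats arrays as immutable, so the producer should not write to the buffer after handing it over.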

leofang commented 3 years ago

Hi @hawkinsp, I recently messaged @apaszke on an occasion where this support came up. It would be nice if JAX could prioritize bidirectional support for the CUDA Array Interface (and update to the latest v3 protocol, which specifies the synchronization semantics).

As you pointed out in a DLPack issue (https://github.com/dmlc/dlpack/issues/50), DLPack lacks support for complex numbers, and this is unlikely to be resolved in the foreseeable future. For array libraries this is simply not an acceptable workaround, and it is actually a blocker for several applications that I am aware of.

Thanks, and happy new year!
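The v3 synchronization semantics mentioned in this comment hinge on the optional `stream` entry. Below is a plain-Python sketch of how a consumer might interpret it, following the value meanings in the specification linked at the top of the thread. The function name is hypothetical, treating a missing key like `None` is an assumption of this sketch, and a real consumer would issue the corresponding CUDA wait or synchronize call rather than return a string.

```python
def stream_sync_action(cai):
    """Interpret the optional `stream` entry of a v3 CUDA Array Interface dict
    and describe what a consumer must do before touching the data."""
    stream = cai.get("stream")  # assumption: a missing key is treated like None
    if stream is None:
        return "no synchronization required"
    if stream == 0:
        # 0 is ambiguous (None vs. default stream, legacy vs. per-thread), so v3 forbids it.
        raise ValueError("stream=0 is disallowed by the v3 protocol")
    if stream == 1:
        return "synchronize with the legacy default stream"
    if stream == 2:
        return "synchronize with the per-thread default stream"
    # Any other integer is a cudaStream_t handle passed as a Python int.
    return f"synchronize with cudaStream_t {stream:#x}"
```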

miguelusque commented 3 years ago

> Could this be reopened until import support is added as well?

Hi!

I was wondering if there is any update on this. Thanks!

Miguel

dmenig commented 3 years ago

I'm very interested in this too

mjsML commented 1 year ago

@hawkinsp this is seeing internal traction. Given how some of the JAX internals have evolved (e.g. arrays, shmap, etc., in the context of MGMN), is there work being done here?

hypnopump commented 1 year ago

I would be interested in this.

pearu commented 10 months ago

Just FYI, a new Python package, pydlpack, has been released that supports bidirectional data exchange between many array providers and consumers.

Here is an example of zero-copy data exchange between jax and torch:

>>> import jax, torch
>>> from dlpack import asdlpack
>>> a1 = jax.numpy.array([[1, 2], [3, 4]])
>>> t1 = torch.from_dlpack(asdlpack(a1))
>>> t1
tensor([[1, 2],
        [3, 4]], device='cuda:0', dtype=torch.int32)
>>> t2 = torch.tensor([[5, 6], [7, 8]]).cuda()
>>> a2 = jax.numpy.from_dlpack(asdlpack(t2))
>>> a2
Array([[5, 6],
       [7, 8]], dtype=int32)

Another example is exchanging CUDA buffers between jax and numba:

>>> from dlpack import asdlpack
>>> import numba.cuda, numpy, jax
>>> a = numba.cuda.to_device(numpy.array([[1, 2], [3, 4]]))
>>> arr = jax.numpy.from_dlpack(asdlpack(a))
>>> arr
Array([[1, 2],
       [3, 4]], dtype=int32)