jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

dlpack conversion fails #15686

Closed: mehdiataei closed this issue 1 year ago

mehdiataei commented 1 year ago

Description

Repro:

import jax
import jax.dlpack
import jax.numpy as jnp
import numpy as np

x = jnp.arange(10)
y = jax.dlpack.to_dlpack(x)  # y is a raw DLPack capsule (a PyCapsule)
yy = np.from_dlpack(y)       # raises the AttributeError below

Gives:

AttributeError: type object 'PyCapsule' has no attribute '__dlpack__'

Versions: jax 0.4.8, jaxlib 0.4.7+cuda12.cudnn88, numpy 1.24.2

The only other reference to this error I found is https://github.com/triton-inference-server/server/issues/3944, which points to a change in NumPy's DLPack protocol.

What jax/jaxlib version are you using?

No response

Which accelerator(s) are you using?

GPU

Additional system info

No response

NVIDIA GPU info

No response

hawkinsp commented 1 year ago

I think you just wanted:

import jax
import jax.dlpack
import jax.numpy as jnp
import numpy as np

x = jnp.arange(10)
yy = np.from_dlpack(x)  # pass the jax.Array itself, not a capsule

np.from_dlpack expects a producer object whose __dlpack__ method it can call to obtain a capsule; it does not accept the DLPack capsule itself.
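To make the producer/capsule distinction concrete, here is a minimal NumPy-only sketch. DLPackProducer is a hypothetical class (not part of NumPy or JAX) that implements __dlpack__ by delegating to a wrapped NumPy array. np.from_dlpack accepts the wrapper but rejects a capsule obtained from the same array. The exact exception type for the capsule case may differ between NumPy versions, so the sketch catches both AttributeError and TypeError.

```python
import numpy as np


class DLPackProducer:
    """Hypothetical wrapper illustrating the producer side of DLPack.

    np.from_dlpack(obj) asks obj for a capsule by calling obj.__dlpack__();
    a bare capsule has no such method, so it cannot be passed directly.
    """

    def __init__(self, array: np.ndarray) -> None:
        self._array = array

    def __dlpack__(self, *args, **kwargs):
        # Delegate capsule creation to the wrapped NumPy array, forwarding
        # whatever arguments (e.g. stream, max_version) the consumer passes.
        return self._array.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self):
        # Report the device of the underlying buffer (CPU for NumPy arrays).
        return self._array.__dlpack_device__()


src = np.arange(10)

# Passing a producer works: NumPy calls __dlpack__ on it to get a capsule.
out = np.from_dlpack(DLPackProducer(src))

# Passing a capsule does not: the capsule object has no __dlpack__ method.
capsule = src.__dlpack__()
try:
    np.from_dlpack(capsule)
    capsule_rejected = False
except (AttributeError, TypeError):
    capsule_rejected = True
```

A jax.Array plays the same role as DLPackProducer here, which is why np.from_dlpack(x) works in the suggested fix while np.from_dlpack(jax.dlpack.to_dlpack(x)) does not.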

mehdiataei commented 1 year ago

Thanks @hawkinsp. My confusion came from the fact that

yy = jax.dlpack.from_dlpack(y)

does accept the capsule directly.