mehdiataei closed this issue 1 year ago.
I think you just wanted:
```python
import jax
import jax.dlpack
import jax.numpy as jnp
import numpy as np

x = jnp.arange(10)
# Pass the JAX array itself: np.from_dlpack calls its __dlpack__ method.
yy = np.from_dlpack(x)
```
`np.from_dlpack` wants an object on which it can call `__dlpack__`, not the DLPack capsule itself.
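A minimal, CPU-only sketch of that distinction, assuming only NumPy 1.22 or later (where `np.from_dlpack` and `ndarray.__dlpack__` were added). A NumPy array stands in for the JAX array here, since both expose the `__dlpack__` producer method. Passing the producer works; passing the capsule that `__dlpack__()` returns does not.

```python
import numpy as np

# A NumPy array stands in for the JAX array `x`; both implement __dlpack__.
x = np.arange(10)

# Correct: hand np.from_dlpack the producer object itself.
yy = np.from_dlpack(x)

# Incorrect: a raw PyCapsule has no __dlpack__ method for NumPy to call.
capsule = x.__dlpack__()
try:
    np.from_dlpack(capsule)
    capsule_error = None
except (AttributeError, TypeError) as err:
    capsule_error = type(err).__name__

print(yy)                       # [0 1 2 3 4 5 6 7 8 9]
print(np.shares_memory(yy, x))  # True: the import is zero-copy
print(capsule_error)            # name of the exception raised for the capsule
```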
Thanks @hawkinsp. I think my confusion was because `yy = jax.dlpack.from_dlpack(y)` works on the capsule itself.
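If code already holds a capsule (as `y` does here), one possible workaround is to wrap it in an object that implements the producer protocol. This is a sketch, not a NumPy or JAX API: `CapsuleAdapter` is a hypothetical helper, and a NumPy-produced capsule stands in for one from `jax.dlpack.to_dlpack`. A DLPack capsule can be consumed only once, so each adapter is single-use.

```python
import numpy as np


class CapsuleAdapter:
    """Hypothetical helper exposing an existing DLPack capsule via __dlpack__."""

    def __init__(self, capsule, device):
        self._capsule = capsule
        self._device = device  # (device_type, device_id), e.g. (1, 0) for CPU

    def __dlpack__(self, **kwargs):
        # Ignore stream/version/copy negotiation and return the capsule we hold;
        # the DLPack protocol allows a producer to return an unversioned capsule.
        return self._capsule

    def __dlpack_device__(self):
        return self._device


src = np.arange(5.0)
capsule = src.__dlpack__()  # stands in for jax.dlpack.to_dlpack(...)
out = np.from_dlpack(CapsuleAdapter(capsule, src.__dlpack_device__()))
print(out)  # [0. 1. 2. 3. 4.]
```

Passing the array directly, as in the answer above, remains the simpler route; NumPy can only wrap CPU-accessible memory either way.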
Description
repro:
Gives:
jax 0.4.8 jaxlib 0.4.7+cuda12.cudnn88 numpy 1.24.2
The only reference to this error that I found is https://github.com/triton-inference-server/server/issues/3944, which refers to a change in the DLPack protocol in NumPy.
What jax/jaxlib version are you using?
No response
Which accelerator(s) are you using?
GPU
Additional system info
No response
NVIDIA GPU info
No response