wjakob closed this issue 9 months ago.
Thanks for the report. I think that when I implemented the original DLPack type conversion, bool support didn't exist in the DLPack spec. It should be easy to add.
I have a pending PR to do this, although it will take a day or two to land.
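DLPack describes an element type as a (code, bits, lanes) triple, so a dtype conversion like the one mentioned above is essentially a lookup table, and adding bool support amounts to adding one entry. The following is a minimal Python sketch, assuming NumPy-style dtype names. DTYPE_TO_DL, to_dl, and from_dl are hypothetical names and are not the actual conversion code in jaxlib/XLA. The type codes mirror DLDataTypeCode in dlpack.h, where kDLBool is 6 and a bool is stored in 8 bits.

```python
from enum import IntEnum
from typing import NamedTuple


class DLDataTypeCode(IntEnum):
    # Values mirror the DLDataTypeCode enum in dlpack.h.
    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLOpaqueHandle = 3
    kDLBfloat = 4
    kDLComplex = 5
    kDLBool = 6


class DLDataType(NamedTuple):
    code: int
    bits: int
    lanes: int = 1


def _dl(code: DLDataTypeCode, bits: int) -> DLDataType:
    # Store plain ints so lookups work with raw values read from a tensor.
    return DLDataType(int(code), bits, 1)


# Hypothetical table keyed by NumPy-style dtype names.
DTYPE_TO_DL = {
    "bool": _dl(DLDataTypeCode.kDLBool, 8),  # bool occupies 8 bits
    "int8": _dl(DLDataTypeCode.kDLInt, 8),
    "int32": _dl(DLDataTypeCode.kDLInt, 32),
    "int64": _dl(DLDataTypeCode.kDLInt, 64),
    "uint8": _dl(DLDataTypeCode.kDLUInt, 8),
    "float16": _dl(DLDataTypeCode.kDLFloat, 16),
    "float32": _dl(DLDataTypeCode.kDLFloat, 32),
    "float64": _dl(DLDataTypeCode.kDLFloat, 64),
    "bfloat16": _dl(DLDataTypeCode.kDLBfloat, 16),
    "complex64": _dl(DLDataTypeCode.kDLComplex, 64),
}
DL_TO_DTYPE = {dl: name for name, dl in DTYPE_TO_DL.items()}


def to_dl(name: str) -> DLDataType:
    """Map a dtype name to its DLPack (code, bits, lanes) triple."""
    return DTYPE_TO_DL[name]


def from_dl(code: int, bits: int, lanes: int = 1) -> str:
    """Map a DLPack (code, bits, lanes) triple back to a dtype name."""
    key = DLDataType(int(code), int(bits), int(lanes))
    if key not in DL_TO_DTYPE:
        # Without the "bool" entry above, code 6 would end up here.
        raise ValueError(
            f"unsupported DLPack dtype: code={code}, bits={bits}, lanes={lanes}"
        )
    return DL_TO_DTYPE[key]
```

For example, from_dl(6, 8) returns "bool", while a table that lacks the bool entry raises for the same triple, which matches the failure reported in this issue.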
Excellent, thank you!
A random note: you passed x.__dlpack__() to from_dlpack. Per the spec, you should pass x itself, not x.__dlpack__() (https://dmlc.github.io/dlpack/latest/python_spec.html#syntax-for-data-interchange-with-dlpack), because we need to be able to call __dlpack_device__ on the same object.
We might drop support for passing a raw DLPack capsule at some point, because it's not what the spec says we should accept.
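The call order behind this note can be illustrated with a toy producer and consumer. This is a minimal pure-Python sketch, assuming the consumer behavior described in the DLPack Python spec: query __dlpack_device__ first, pick a stream for that device, then call __dlpack__. ToyArray and toy_from_dlpack are hypothetical names, not JAX or NumPy APIs, and a tuple stands in for the real "dltensor" PyCapsule.

```python
from typing import Any, Optional, Tuple

# Device type codes from the DLDeviceType enum in dlpack.h.
KDL_CPU = 1
KDL_CUDA = 2


class ToyArray:
    """Hypothetical producer: any array type implementing the DLPack protocol."""

    def __init__(self, device_type: int, device_id: int = 0) -> None:
        self._device = (device_type, device_id)

    def __dlpack_device__(self) -> Tuple[int, int]:
        return self._device

    def __dlpack__(self, stream: Optional[int] = None) -> Any:
        # A real producer returns a PyCapsule named "dltensor";
        # a plain tuple stands in for it here.
        return ("capsule", self._device, stream)


def toy_from_dlpack(x: Any) -> Any:
    """Hypothetical consumer following the call order in the DLPack Python spec."""
    if not (hasattr(x, "__dlpack__") and hasattr(x, "__dlpack_device__")):
        # A raw capsule has neither method, so the consumer cannot
        # learn the device or choose a stream before export.
        raise TypeError(
            "expected an object implementing __dlpack__ and __dlpack_device__"
        )
    device_type, _device_id = x.__dlpack_device__()
    # For CUDA, stream 1 means the legacy default stream;
    # CPU has no stream concept, so None is passed.
    stream = 1 if device_type == KDL_CUDA else None
    return x.__dlpack__(stream=stream)
```

Passing the array itself works: toy_from_dlpack(ToyArray(KDL_CPU)) returns ("capsule", (1, 0), None). Passing the already-exported result, as in toy_from_dlpack(ToyArray(KDL_CPU).__dlpack__()), raises TypeError, because the consumer never gets the chance to query the device.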
Right, my bad! (Fortunately, that's something I only did in the repro, not in the original application.)
This was fixed by https://github.com/google/jax/commit/c4368351d2377c36b75a38dc40266f9e31ea0830 and https://github.com/openxla/xla/commit/19dc13a5e54ccaecf1a8700d08a7a8726376b59d !
(Needs a fresh jaxlib build, obviously.)
Awesome, thank you!
Description
While developing a program that exchanges tensors with another framework, I noticed that the exchange fails when some of the arrays are boolean-valued. Here is a program that demonstrates the failure in a CPU/NumPy session; the same issue also occurs with the CUDA backend.
Incidentally, that reported ID is the DLPack encoding of kDLBool, type code 6 (see https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h#L157).
What jax/jaxlib version are you using?
jax 0.4.23, jaxlib 0.4.23
Which accelerator(s) are you using?
CPU and GPU (CUDA)
Additional system info?
macOS and Linux
NVIDIA GPU info