jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JAX DLPack tensor exchange fails for boolean-valued tensors #19352

Closed wjakob closed 9 months ago

wjakob commented 9 months ago

Description

While developing a program that exchanges tensors with another framework, I noticed that the exchange fails when some of the arrays are boolean-valued. The following session reproduces the failure on CPU with a NumPy array; the same error also occurs with the CUDA backend.

>>> import numpy as np
>>> import jax
>>> a = np.array([True])
>>> jax.dlpack.from_dlpack(a.__dlpack__())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/homebrew/lib/python3.12/site-packages/jax/_src/dlpack.py", line 123, in from_dlpack
    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: Unknown or invalid DLPack type code 6

Incidentally, the reported type code (6) is the DLPack encoding of kDLBool (see https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h#L157).
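For readers who want to decode such error codes, here is a small sketch that mirrors the `DLDataTypeCode` enum from `dlpack.h` in plain Python. The names `LEGACY_SUPPORTED` and `convert_type_code` are hypothetical and are not JAX or XLA internals; they only illustrate how a converter written before `kDLBool` existed could end up rejecting code 6.

```python
from enum import IntEnum


class DLDataTypeCode(IntEnum):
    """Python mirror of the DLDataTypeCode enum in dlpack.h."""
    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLOpaqueHandle = 3
    kDLBfloat = 4
    kDLComplex = 5
    kDLBool = 6


# Hypothetical table for a converter written before kDLBool was added.
LEGACY_SUPPORTED = {
    DLDataTypeCode.kDLInt: "int",
    DLDataTypeCode.kDLUInt: "uint",
    DLDataTypeCode.kDLFloat: "float",
    DLDataTypeCode.kDLBfloat: "bfloat",
    DLDataTypeCode.kDLComplex: "complex",
}


def describe_type_code(code: int) -> str:
    """Return the enum name for a raw DLPack type code, if it is known."""
    try:
        return DLDataTypeCode(code).name
    except ValueError:
        return f"unknown ({code})"


def convert_type_code(code: int, supported=LEGACY_SUPPORTED) -> str:
    """Hypothetical stand-in for a consumer-side type conversion."""
    if code not in supported:
        raise NotImplementedError(
            f"Unknown or invalid DLPack type code {code}"
        )
    return supported[code]


print(describe_type_code(6))
```

Adding an entry for `DLDataTypeCode.kDLBool` to the table is all this toy converter needs in order to accept boolean tensors.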

What jax/jaxlib version are you using?

jax 0.4.23, jaxlib 0.4.23

Which accelerator(s) are you using?

CPU and GPU (CUDA)

Additional system info?

macOS and Linux

NVIDIA GPU info


```
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06              Driver Version: 545.23.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:01:00.0 Off |                  Off |
|  0%   46C    P8              18W / 450W |      2MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
```

hawkinsp commented 9 months ago

Thanks for the report. I think bool support didn't exist in the DLPack spec when I implemented the original DLPack type conversion. It should be easy to add.

hawkinsp commented 9 months ago

I have a pending PR to do this, although it will take a day or two to land.

wjakob commented 9 months ago

Excellent, thank you!

hawkinsp commented 9 months ago

A random note: you passed x.__dlpack__() to from_dlpack. Per the spec, you should pass x itself, not x.__dlpack__() (https://dmlc.github.io/dlpack/latest/python_spec.html#syntax-for-data-interchange-with-dlpack), because we need to be able to call __dlpack_device__ on the same object.

We might drop support for passing a raw DLPack capsule at some point because it's not what the spec says we should accept.
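To make the point about the protocol concrete, here is a toy, standard-library-only sketch of the handshake. `ToyProducer` and `toy_from_dlpack` are hypothetical names, and the tuple returned by `__dlpack__` is only a stand-in for a real PyCapsule. The sketch shows why a consumer wants the producer object itself: it queries `__dlpack_device__` first, and a bare capsule has no such method.

```python
class ToyProducer:
    """Hypothetical array-like exposing both DLPack protocol methods."""

    def __init__(self, data):
        self._data = list(data)

    def __dlpack_device__(self):
        # (device_type, device_id); device type 1 is kDLCPU in dlpack.h.
        return (1, 0)

    def __dlpack__(self, *, stream=None):
        # A real producer returns a PyCapsule; this tuple is a stand-in.
        return ("toy-capsule", self._data)


def toy_from_dlpack(producer):
    """Hypothetical consumer: ask for the device, then for the data."""
    if not hasattr(producer, "__dlpack_device__"):
        raise TypeError(
            "expected an object with __dlpack_device__, not a raw capsule"
        )
    device = producer.__dlpack_device__()
    _, payload = producer.__dlpack__()
    return device, payload


x = ToyProducer([True, False])
print(toy_from_dlpack(x))            # pass x itself
try:
    toy_from_dlpack(x.__dlpack__())  # pass the capsule: no device info
except TypeError as exc:
    print("rejected:", exc)
```

In the repro above, the corresponding change would be to call `jax.dlpack.from_dlpack(a)` rather than `jax.dlpack.from_dlpack(a.__dlpack__())`.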

wjakob commented 9 months ago

Right, my bad! (fortunately that's something I just did in the repro but not the original application)

hawkinsp commented 9 months ago

This was fixed by https://github.com/google/jax/commit/c4368351d2377c36b75a38dc40266f9e31ea0830 and https://github.com/openxla/xla/commit/19dc13a5e54ccaecf1a8700d08a7a8726376b59d !

(Needs a fresh jaxlib build, obviously.)

wjakob commented 9 months ago

Awesome, thank you!