jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Mark jax typing Protocols as `@runtime_checkable` #22144

Closed: bionicles closed this issue 4 months ago

bionicles commented 4 months ago

Goal: Ensure that `isinstance` works for checking whether a type (not a value) is a JAX Array type.

Problem: When deserializing data that contains JAX arrays, one must check whether the types are JAX Array types. However, `jaxlib.xla_extension.ArrayImpl is jax.Array` evaluates to False (normal, but annoying).

Opportunity: Could we make these typing Protocols runtime-checkable to facilitate serialization/deserialization logic?

TL;DR fix: import `runtime_checkable` from `typing` and apply the `@runtime_checkable` decorator to the Protocol subclasses.
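A minimal sketch of what the proposed fix does. `HasShapeAndDtype` is a hypothetical stand-in that mirrors the two members (shape and dtype) that DuckTypedArray is described as having later in this thread; it is not JAX's actual definition.

```python
from typing import Any, Protocol, runtime_checkable

import jax.numpy as jnp


# Hypothetical stand-in for jax._src.typing.DuckTypedArray (not JAX's own code).
@runtime_checkable
class HasShapeAndDtype(Protocol):
    @property
    def shape(self) -> Any: ...

    @property
    def dtype(self) -> Any: ...


arr = jnp.ones(3)
# With @runtime_checkable, isinstance() checks only that the members exist.
print(isinstance(arr, HasShapeAndDtype))  # True
print(isinstance(3, HasShapeAndDtype))    # False: int has no .shape or .dtype
```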

Reproduction

import jax, jax.numpy as jnp
from jax._src.typing import DuckTypedArray
from jaxlib.xla_extension import ArrayImpl

def test_jaxlib__extensions_arraylike_is_array():
    try:
        assert ArrayImpl is jax.Array, f"{ArrayImpl=} is not {jax.Array=}"
    except Exception as e:
        print(f"{e=}")
    print("This is normal.")

def test_duck_typed_array_is_runtime_checkable():
    arr = jnp.ones(3, int)
    print("arr", arr)
    assert isinstance(
        arr,
        DuckTypedArray,
    ), "jax._src.typing.DuckTypedArray is not runtime_checkable"

test_jaxlib__extensions_arraylike_is_array()
test_duck_typed_array_is_runtime_checkable()

Result: (screenshot of the test output, not preserved)

System:

../jax (main) $
> python -c 'import jax; jax.print_environment_info()'
jax:    0.4.31.dev20240627+4c70a94c8
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.12.0 | packaged by conda-forge | (main, Oct  3 2023, 08:43:22) [GCC 12.3.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='WIN-QVRBL09D89C', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')

$ nvidia-smi
Thu Jun 27 11:25:54 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.52.01              Driver Version: 555.99         CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:01:00.0  On |                  Off |
|  0%   50C    P2             29W /  450W |   19984MiB /  24564MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      4192      C   /python3.12                                 N/A      |
|    0   N/A  N/A    198001      C   /python3.12                                 N/A      |
+-----------------------------------------------------------------------------------------+


Workaround

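from typing import Any, TypeGuard

from jaxlib.xla_extension import ArrayImpl

import jax, jax.numpy as jnp

DEBUG = False  # set to True to print each classification decision

# Exact-type check: accepts only ArrayImpl (the concrete class) or jax.Array itself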
def is_jax_array_type(type_x: type, *, debug: bool = DEBUG) -> bool:
    if type_x is ArrayImpl:
        decision = True
    elif type_x is jax.Array:
        decision = True
    else:
        decision = False

    if debug:
        print(f"is {type_x=} a jax.Array? {decision=}")

    return decision

def is_jax_array_value(x: Any, *, debug: bool = DEBUG) -> TypeGuard[jax.Array]:
    type_x = type(x)
    return is_jax_array_type(type_x, debug=debug)

def test_typeguard_workaround():
    arr = jnp.ones(3, int)
    print("arr", arr)
    arr_is_jax_array = is_jax_array_value(arr)
    print(f"{arr_is_jax_array=}")
    assert (
        arr_is_jax_array
    ), "the is_jax_array TypeGuard[jax.Array] failed to classify a jax Array"
    assert not is_jax_array_value(1)
    assert not is_jax_array_value(False)
    assert not is_jax_array_value((1, 2, 3))

test_typeguard_workaround()


Cheers, love me some jax!

jakevdp commented 4 months ago

Assigning @superbobry because he has thought a lot about this kind of typing issue, and may have good feedback!

superbobry commented 4 months ago

I don't mind making DuckTypedArray runtime-checkable, but the guarantees these runtime checks give are pretty weak. In particular, they do not do any type checking:

>>> from typing import Protocol, runtime_checkable
>>> @runtime_checkable
... class P(Protocol):
...   def f(self) -> int: ...
...
>>> class A:
...   f = 42
...
>>> isinstance(A(), P)
True

Note also that DuckTypedArray only has two members (shape and dtype), so the runtime check is equivalent to hasattr(..., "shape") and hasattr(..., "dtype"). That means, e.g., a NumPy array will be considered an instance of DuckTypedArray. Is that the behavior you expect?
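
To illustrate the point about NumPy arrays, here's a sketch that again uses a hypothetical stand-in protocol in place of DuckTypedArray: the runtime check behaves like the two hasattr calls, so a NumPy array passes even though it is not a jax.Array.

```python
from typing import Any, Protocol, runtime_checkable

import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical runtime-checkable stand-in for DuckTypedArray.
@runtime_checkable
class HasShapeAndDtype(Protocol):
    @property
    def shape(self) -> Any: ...

    @property
    def dtype(self) -> Any: ...


def has_shape_and_dtype(x: object) -> bool:
    # What the runtime protocol check amounts to: attribute presence only.
    return hasattr(x, "shape") and hasattr(x, "dtype")


np_arr = np.ones(3)
jax_arr = jnp.ones(3)

for x in (np_arr, jax_arr):
    # Both checks accept both arrays, so this prints "True True" twice.
    print(isinstance(x, HasShapeAndDtype), has_shape_and_dtype(x))

# The NumPy array passed the protocol check, yet it is not a jax.Array.
print(isinstance(np_arr, jax.Array))  # False
```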

bionicles commented 4 months ago

Oh, you're right, that wouldn't work, since it would make code think it sees jax arrays when it sees numpy arrays.

I just want to round-trip serialize/deserialize all the things, and I hit a situation where I needed to check whether a type is jax.Array, but `T is jax.Array` is False when T is ArrayImpl. The TypeGuard function solved my problem, so it could be fine to just stick with that.

It seems like a lot of Python typing is still a work in progress. If there's a solution that Python's type system supports better than protocols, like dataclasses or abstract base classes, that's great.

Mainly, my objective was to make sure we can support both type-level and value-level checks: the abstract case, checking whether a type T is a jax Array type, and the concrete case, checking whether a value V is an instance of a jax Array.

I'm agnostic about the implementation details, so whatever you think is best given the state of Python, I'm down to test it.

jakevdp commented 4 months ago

jax.Array is a base class, so to check if x is a JAX array you can use isinstance(x, jax.Array), or to check if T = type(x) is an array type, you can use issubclass(T, jax.Array).
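
A short sketch of both checks side by side. It uses type(x) rather than importing ArrayImpl, so it doesn't depend on where jaxlib exposes that class.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.ones(3)
T = type(x)  # concrete array class (ArrayImpl in this thread), not jax.Array itself

value_check = isinstance(x, jax.Array)  # value level: is x a JAX array?
type_check = issubclass(T, jax.Array)   # type level: is T a JAX array type?
identity_check = T is jax.Array         # the identity test from the report

print(value_check, type_check, identity_check)  # True True False

# NumPy arrays are not JAX arrays at either level.
print(isinstance(np.ones(3), jax.Array), issubclass(np.ndarray, jax.Array))  # False False
```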

bionicles commented 4 months ago

Sounds good! The only thing I'd add is to wrap issubclass in try/except: it's pretty flaky and crashes on annotations such as tuple[int, ...] (oof).

Given that, here's a working solution with issubclass. Thank you! Closing.

from typing import TypeGuard, Any

from jaxlib.xla_extension import ArrayImpl
import jax, jax.numpy as jnp

def is_jax_array_subclass(x: type) -> bool:
    "check if a type is a subclass of jax.Array"
    try:
        return issubclass(x, jax.Array)
    except Exception as _e:
        return False

def is_jax_array(x: Any) -> TypeGuard[jax.Array]:
    "check if a value is an instance of jax.Array"
    try:
        return is_jax_array_subclass(x.__class__)
    except Exception as _e:
        return False

def test_is_jax_array_subclass():
    assert is_jax_array_subclass(jax.Array)
    assert is_jax_array_subclass(ArrayImpl)
    assert not is_jax_array_subclass(tuple[int, ...])
    assert not is_jax_array_subclass(jnp.ones(3))

def test_is_jax_array():
    assert not is_jax_array(jax.Array)
    assert not is_jax_array(ArrayImpl)
    assert not is_jax_array(tuple[int, ...])
    assert is_jax_array(jnp.ones(3))

test_is_jax_array_subclass()
test_is_jax_array()
