bionicles closed this 4 months ago
Assigning @superbobry because he has thought a lot about this kind of typing issue, and may have good feedback!
I don't mind making DuckTypedArray runtime-checkable, but the guarantees these runtime checks give are pretty weak. In particular, they do not do any type checking:
>>> from typing import Protocol, runtime_checkable
>>> @runtime_checkable
... class P(Protocol):
...     def f(self) -> int: ...
...
>>> class A:
...     f = 42
...
>>> isinstance(A(), P)
True
Note also that DuckTypedArray only has two methods, so the runtime check is equivalent to hasattr(..., "shape") and hasattr(..., "dtype"), which means that, e.g., a NumPy array will be considered an instance of DuckTypedArray. Is that the behavior you expect?
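To make the point concrete without depending on JAX or NumPy, here is a minimal sketch using a hypothetical protocol, ShapeAndDtype, with the same two members (declared here as properties). It is a stand-in, not JAX's actual DuckTypedArray definition. Any object that merely has attributes named shape and dtype passes isinstance, whatever their values:

```python
from types import SimpleNamespace
from typing import Any, Protocol, runtime_checkable


# Hypothetical stand-in: a runtime-checkable protocol with two members,
# `shape` and `dtype`, declared as read-only properties.
@runtime_checkable
class ShapeAndDtype(Protocol):
    @property
    def shape(self) -> Any: ...

    @property
    def dtype(self) -> Any: ...


# isinstance() only checks that the attribute names exist; it never looks at
# their values or types, so this unrelated object passes.
not_an_array = SimpleNamespace(shape="not a shape", dtype=12345)
print(isinstance(not_an_array, ShapeAndDtype))  # True

# An object missing one of the two attributes fails the check.
missing_dtype = SimpleNamespace(shape=(3,))
print(isinstance(missing_dtype, ShapeAndDtype))  # False
```

Any real array library whose arrays expose shape and dtype (NumPy included) passes the same check, which is why it cannot tell JAX arrays apart from other arrays.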
Oh, you're right, that wouldn't work, since it would make code think it sees JAX arrays when it actually sees NumPy arrays.
I just want to round-trip serialize/deserialize all the things, and I hit a situation where I needed to check whether a type is jax.Array, but T is jax.Array didn't work for ArrayImpl. The TypeGuard function solved my problem, so it could be fine to just stick with that.
It seems like a lot of Python typing is still a work in progress. If there's a solution that Python's type system supports better than protocols, like dataclasses or abstract base classes, that's great.
Mainly, my objective was to ensure we can support both type-level and value-level checks: in the abstract case, checking whether a type T is a jax Array type, and in the concrete case, checking whether a value V is an instance of jax.Array.
I'm agnostic about the implementation details, so whatever you think is best given the current state of Python, I'm down to test it.
jax.Array is a base class, so to check if x is a JAX array you can use isinstance(x, jax.Array), or to check if T = type(x) is an array type, you can use issubclass(T, jax.Array).
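The difference between the two checks, and why an identity comparison like T is jax.Array misses the concrete class, can be sketched with plain Python classes. BaseArray and ConcreteArray below are hypothetical stand-ins for jax.Array and its concrete array class, so the example runs without JAX installed:

```python
# Hypothetical stand-ins: `BaseArray` plays the role of the public base class
# (jax.Array) and `ConcreteArray` the concrete class that array values have.
class BaseArray:
    pass


class ConcreteArray(BaseArray):
    pass


x = ConcreteArray()
T = type(x)

# Identity only matches the exact same class object, so it misses subclasses.
print(T is BaseArray)  # False

# Type-level check: is T an array type?
print(issubclass(T, BaseArray))  # True

# Value-level check: is x an array?
print(isinstance(x, BaseArray))  # True

# issubclass() needs a class as its first argument; passing a value raises.
try:
    issubclass(x, BaseArray)
except TypeError:
    print("issubclass(x, BaseArray) raised TypeError")
```

The last case is the reason a type-level helper usually guards issubclass with try/except when its input may not be a class.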
Sounds good! The only thing I'd add: make sure to wrap issubclass in try/except, since it can raise TypeError when its first argument is an annotation such as tuple[int, ...] rather than a class (oof).
Given that, here's a working solution with issubclass. Thank you! Closing.
from typing import Any, TypeGuard  # TypeGuard requires Python 3.10+

import jax
import jax.numpy as jnp

# Concrete array class. This import path matches older jaxlib releases and may
# move between versions, so fall back to the runtime type of a concrete array.
try:
    from jaxlib.xla_extension import ArrayImpl
except ImportError:
    ArrayImpl = type(jnp.ones(3))


def is_jax_array_subclass(x: Any) -> bool:
    "check if a type is a subclass of jax.Array"
    try:
        return issubclass(x, jax.Array)
    except TypeError:
        # issubclass raises TypeError when x is not a class,
        # e.g. an annotation like tuple[int, ...] or an array value
        return False


def is_jax_array(x: Any) -> TypeGuard[jax.Array]:
    "check if a value is an instance of jax.Array"
    # isinstance accepts any value as its first argument, so no try/except
    return isinstance(x, jax.Array)


def test_is_jax_array_subclass():
    assert is_jax_array_subclass(jax.Array)
    assert is_jax_array_subclass(ArrayImpl)
    assert not is_jax_array_subclass(tuple[int, ...])
    assert not is_jax_array_subclass(jnp.ones(3))


def test_is_jax_array():
    assert not is_jax_array(jax.Array)
    assert not is_jax_array(ArrayImpl)
    assert not is_jax_array(tuple[int, ...])
    assert is_jax_array(jnp.ones(3))


test_is_jax_array_subclass()
test_is_jax_array()
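Since the goal here is round-trip deserialization, the following is a minimal sketch of where a type-level check like is_jax_array_subclass could plug in: picking a decoder from the declared type of a field. Everything in it is hypothetical (BaseArray and ConcreteArray stand in for jax.Array and its concrete class; is_subclass_safe, DECODERS, and decode are illustrative names, not part of JAX), and it uses only the standard library:

```python
from typing import Any, Callable


# Hypothetical stand-ins for a public array base class and a concrete class.
class BaseArray:
    pass


class ConcreteArray(BaseArray):
    def __init__(self, data: list) -> None:
        self.data = data


def is_subclass_safe(tp: Any, base: type) -> bool:
    """Like issubclass, but returns False when tp is not a class."""
    try:
        return issubclass(tp, base)
    except TypeError:
        return False


# Hypothetical decoder registry: base class -> function that rebuilds a value.
DECODERS: dict[type, Callable[[Any], Any]] = {
    BaseArray: lambda payload: ConcreteArray(list(payload)),
}


def decode(payload: Any, declared_type: Any) -> Any:
    """Rebuild payload with the first decoder whose base class matches."""
    for base, decoder in DECODERS.items():
        if is_subclass_safe(declared_type, base):
            return decoder(payload)
    # No match, including annotations that are not classes: pass through.
    return payload


restored = decode([1.0, 2.0], ConcreteArray)
print(type(restored).__name__, restored.data)  # ConcreteArray [1.0, 2.0]
print(decode([1.0, 2.0], tuple[int, ...]))  # [1.0, 2.0]
```

Keying the lookup on issubclass rather than on identity means a field declared with the base class or any of its subclasses reaches the same decoder, which is what T is jax.Array cannot do.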
Goal: Ensure "isinstance" works to check whether a type (not a value) is a jax Array type.
Problem: When deserializing data containing jax arrays, one must check whether the types are jax Array types. However, "jaxlib.xla_extension.ArrayImpl is jax.Array" evaluates to False (normal, but annoying).
Opportunity: Could we make these typing protocols runtime-checkable to facilitate serialization/deserialization logic?
TLDR fix: from typing import runtime_checkable, and use the @runtime_checkable decorator on the Protocol subclasses.
Reproduction
Result
System:
References / Links (not strictly needed)
[1] From the Python array API standard, but it's really long to paste here.
[2] Another resource for protocols is the DLPack C API, which seems clear on the sizes of data types.
Here are the short notes on those. I'll add class attributes to tree_plus eventually, but it can get noisy.
Workaround
Cheers, love me some jax!