patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Question: manual assertion #192

Closed JuanFMontesinos closed 4 months ago

JuanFMontesinos commented 4 months ago

Hi, as far as I understand, it should be possible to manually check (assert) the type of any array/tensor.

from typing import Union, Dict, Sequence

from torch import Tensor
import numpy as np
from jaxtyping import Float, jaxtyped

Array = Union[Tensor, np.ndarray]

DiscreteTrajectoryType = Float[Array, "T 3"]

with jaxtyped("context"):
    w = np.random.rand(100,3).astype(np.float32)
    flag = isinstance(w, DiscreteTrajectoryType)
    print(flag)

However, this prints False. What am I missing?

patrick-kidger commented 4 months ago

Ah, this is a bug in Python itself.

To explain: jaxtyping objects like Float are transparent to Union. That means DiscreteTrajectoryType is equivalent to Union[Float[Tensor, "T 3"], Float[np.ndarray, "T 3"]]. And unfortunately, isinstance(..., Union[...]) is buggy.
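To make "transparent to Union" concrete, here is a minimal, stdlib-only sketch. It is not jaxtyping's actual implementation: ShapeChecked, TensorLike and NDArrayLike are hypothetical stand-ins for Float, torch.Tensor and np.ndarray, the shape check is assumed (as a simplification) to live in a metaclass __instancecheck__, and named dimensions such as T are not bound consistently across arguments the way jaxtyping binds them.

```python
import typing
from typing import Union


class _ShapeMeta(type):
    """Metaclass whose isinstance() checks the base type and a 'T 3'-style shape."""

    def __instancecheck__(cls, obj):
        if not isinstance(obj, cls.base):
            return False
        tokens = cls.dims.split()
        shape = tuple(obj.shape)
        # Numeric tokens must match exactly; named tokens such as "T" accept any size.
        return len(shape) == len(tokens) and all(
            not t.isdigit() or int(t) == n for t, n in zip(tokens, shape)
        )


class ShapeChecked:
    """Hypothetical stand-in for jaxtyping's Float[base, dims]."""

    def __class_getitem__(cls, item):
        base, dims = item
        if typing.get_origin(base) is Union:
            # Distribute over the Union:
            # ShapeChecked[Union[A, B], d] -> Union[ShapeChecked[A, d], ShapeChecked[B, d]]
            return Union[tuple(cls[arg, dims] for arg in typing.get_args(base))]
        name = f"ShapeChecked[{base.__name__}, {dims!r}]"
        return _ShapeMeta(name, (), {"base": base, "dims": dims})


class TensorLike:
    """Hypothetical array type standing in for torch.Tensor."""

    def __init__(self, *shape):
        self.shape = shape


class NDArrayLike:
    """Hypothetical array type standing in for np.ndarray."""

    def __init__(self, *shape):
        self.shape = shape


DiscreteTrajectoryType = ShapeChecked[Union[TensorLike, NDArrayLike], "T 3"]

# The result is a plain typing.Union of two shape-checked classes,
# not a single shape-checked class wrapping a Union.
print(typing.get_origin(DiscreteTrajectoryType) is Union)  # True
for member in typing.get_args(DiscreteTrajectoryType):
    print(member.__name__)
```

Because the subscription hands back an ordinary typing.Union, any isinstance call on DiscreteTrajectoryType goes through typing.Union's own instance check rather than the shape-checked classes directly.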

See also #73.

When working with a Union, the fix is to check against the tuple of its members instead: isinstance(x, typing.get_args(union_type)).
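A minimal sketch of this workaround, using the same kind of hypothetical stand-ins (shape_checked, TensorLike and NDArrayLike are made-up names playing the roles of Float, torch.Tensor and np.ndarray; the real jaxtyping internals differ). It deliberately avoids calling isinstance on the Union itself, since how that behaves depends on the Python version; passing a tuple of classes to isinstance is always allowed, and it invokes each member's own __instancecheck__.

```python
import typing
from typing import Union


class _ShapeMeta(type):
    """Metaclass whose isinstance() checks the base type and a 'T 3'-style shape."""

    def __instancecheck__(cls, obj):
        if not isinstance(obj, cls.base):
            return False
        tokens = cls.dims.split()
        shape = tuple(obj.shape)
        # Numeric tokens must match exactly; named tokens such as "T" accept any size.
        return len(shape) == len(tokens) and all(
            not t.isdigit() or int(t) == n for t, n in zip(tokens, shape)
        )


def shape_checked(base, dims):
    """Hypothetical stand-in for Float[base, dims]."""
    name = f"ShapeChecked[{base.__name__}, {dims!r}]"
    return _ShapeMeta(name, (), {"base": base, "dims": dims})


class TensorLike:
    """Hypothetical array type standing in for torch.Tensor."""

    def __init__(self, *shape):
        self.shape = shape


class NDArrayLike:
    """Hypothetical array type standing in for np.ndarray."""

    def __init__(self, *shape):
        self.shape = shape


# What DiscreteTrajectoryType expands to, per the explanation above.
DiscreteTrajectoryType = Union[
    shape_checked(TensorLike, "T 3"), shape_checked(NDArrayLike, "T 3")
]

# Workaround: check against the tuple of Union members, not the Union itself.
members = typing.get_args(DiscreteTrajectoryType)
print(isinstance(NDArrayLike(100, 3), members))  # True
print(isinstance(NDArrayLike(100, 4), members))  # False: trailing dimension is 4, not 3
print(isinstance(TensorLike(7, 3), members))     # True
```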

JuanFMontesinos commented 4 months ago

I don't have much of a clue about how Python implements type/instance/subclass checking. After a bit of experimentation I've noticed that DiscreteTrajectoryType = Float[np.ndarray, "T 3"] | Float[torch.Tensor, "T 3"] works. So it seems that unions work in general, and the check only fails when the Union occurs inside Float.

patrick-kidger commented 4 months ago

This is because typing.Union[...] and the X | Y syntax are actually implemented as two different union types in Python. :/

JuanFMontesinos commented 4 months ago

Alright 🤣 Not gonna try to dig into the rabbit hole today. Thanks for the info!