patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Support runtime type-checking of generic functions #130

Open davnn opened 11 months ago

davnn commented 11 months ago

Hi,

do you think it would be feasible to support type checks for generics, e.g. generic array types or generic dtypes? See the examples below.

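```python
from typing import TypeVar

from beartype import beartype
from jax import Array as JaxArray
from jaxtyping import Float16, Float32, Float64, Shaped, jaxtyped
from numpy import ndarray as NumpyArray
from torch import Tensor as TorchArray

# Generic over the array library: JAX, NumPy, or PyTorch.
GenericArray = TypeVar("GenericArray", JaxArray, NumpyArray, TorchArray)
# Generic over the floating-point dtype.
GenericFloat = TypeVar("GenericFloat", Float16, Float32, Float64)

# Desired: `a` and the return value are the *same* array type, both of shape (n,).
@jaxtyped
@beartype
def f(a: Shaped[GenericArray, "n"]) -> Shaped[GenericArray, "n"]:
    return a

# Desired: `a` and the return value share the same float dtype.
# Proposed syntax only: a plain TypeVar is not subscriptable, so this
# annotation raises a TypeError when the function is defined.
@jaxtyped
@beartype
def g(a: GenericFloat[NumpyArray, "n"]) -> GenericFloat[NumpyArray, "n"]:
    return a
```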

I would be happy to contribute, but I am unsure whether this is even possible.

patrick-kidger commented 11 months ago

Yup, I think this should be possible! I'd be happy to take a PR on this.
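To pin down what runtime checking of a generic would have to enforce, here is a minimal, stdlib-only sketch of the core rule: within one call, every argument and the return value annotated with the same constrained `TypeVar` must resolve to the *same* constraint. The decorator name `check_constrained_typevars` is hypothetical and this is not jaxtyping's or beartype's API; `list`/`tuple` stand in for `JaxArray`/`NumpyArray`/`TorchArray` so it runs without those libraries, and shapes and dtypes are ignored entirely.

```python
import functools
import inspect
from typing import TypeVar, get_type_hints


def check_constrained_typevars(fn):
    """Hypothetical decorator (not part of jaxtyping or beartype).

    Every argument and the return value annotated with the same constrained
    TypeVar must resolve to the same constraint within a single call.
    """
    hints = get_type_hints(fn)
    sig = inspect.signature(fn)

    def bind(tv, value, bindings, where):
        # Pick the first constraint (in declaration order) that `value` is an instance of.
        matches = [c for c in tv.__constraints__ if isinstance(value, c)]
        if not matches:
            raise TypeError(
                f"{where}: {type(value).__name__} matches no constraint of {tv.__name__}"
            )
        chosen = matches[0]
        # The first occurrence fixes the binding; later occurrences must agree.
        bound = bindings.setdefault(tv, chosen)
        if bound is not chosen:
            raise TypeError(
                f"{where}: {tv.__name__} is bound to {bound.__name__}, got {chosen.__name__}"
            )

    def is_constrained(hint):
        return isinstance(hint, TypeVar) and bool(hint.__constraints__)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bindings = {}  # fresh TypeVar -> constraint map for each call
        for name, value in sig.bind(*args, **kwargs).arguments.items():
            hint = hints.get(name)
            if is_constrained(hint):
                bind(hint, value, bindings, f"argument {name!r}")
        result = fn(*args, **kwargs)
        ret = hints.get("return")
        if is_constrained(ret):
            bind(ret, result, bindings, "return value")
        return result

    return wrapper


# Stand-ins for JaxArray / NumpyArray / TorchArray so the sketch needs only the stdlib.
GenericSeq = TypeVar("GenericSeq", list, tuple)


@check_constrained_typevars
def pick(a: GenericSeq, b: GenericSeq) -> GenericSeq:
    return a


@check_constrained_typevars
def to_tuple(a: GenericSeq) -> GenericSeq:
    return tuple(a)  # violates the annotation when `a` is a list


print(pick([1], [2]))
try:
    pick([1], (2,))  # list vs tuple: rejected before the body runs
except TypeError as err:
    print(err)
```

Bindings are created fresh for each call because a function-scoped `TypeVar` may resolve differently on different calls. A real integration would also need to pull the array type and dtype out of annotations such as `Shaped[GenericArray, "n"]` and combine this rule with jaxtyping's shape checks, which this sketch does not attempt.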