patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

`jaxtyped` Annotation fails #187

Open dxm447 opened 4 months ago

dxm447 commented 4 months ago

I am writing code with cupy, using jaxtyping for type hints, to calculate the Laplacian of Gaussian of an image. Here is my code:

import cupy as cp
import numpy as np
import cupyx.scipy.ndimage as csnd
import cucim.skimage.exposure as cexpose
import cupyx.scipy.signal as csisig
from typing import Tuple
from jaxtyping import Float, jaxtyped
from beartype import beartype as typechecker

@jaxtyped(typechecker=typechecker)
def laplacian_gaussian(
    image: Float[cp.ndarray, "dim1 dim2"],
    standard_deviation: int = 3,
    hist_stretch: bool = True,
    sampling: Float = 1,
) -> Tuple[
    Float[cp.ndarray, "dim3 dim4"],
    Float[cp.ndarray, "dim3 dim4"],
]:
    image: Float[cp.ndarray, "dim1 dim2"] = cp.asarray(image.astype(cp.float64))
    if sampling != 1:
        sampled_image: Float[cp.ndarray, "dim3 dim4"] = csnd.zoom(image, sampling)
    else:
        sampled_image: Float[cp.ndarray, "dim3 dim4"] = cp.copy(image)
    if hist_stretch:
        sampled_image: Float[cp.ndarray, "dim3 dim4"] = cexpose.equalize_hist(
            sampled_image
        )
    gauss_image: Float[cp.ndarray, "dim3 dim4"] = csnd.gaussian_filter(
        sampled_image, standard_deviation
    )
    positive_laplacian: Float[cp.ndarray, "3 3"] = cp.asarray(
        (
            (0.0, 1.0, 0.0),
            (1.0, -4.0, 1.0),
            (0.0, 1.0, 0.0),
        ),
        dtype=np.float64,
    )
    negative_laplacian: Float[cp.ndarray, "3 3"] = cp.asarray(
        (
            (0.0, -1.0, 0.0),
            (-1.0, 4.0, -1.0),
            (0.0, -1.0, 0.0),
        ),
        dtype=np.float64,
    )
    positive_filtered: Float[cp.ndarray, "dim3 dim4"] = csisig.convolve2d(
        gauss_image, positive_laplacian, mode="same", boundary="symm", fillvalue=0
    )
    negative_filtered: Float[cp.ndarray, "dim3 dim4"] = csisig.convolve2d(
        gauss_image, negative_laplacian, mode="same", boundary="symm", fillvalue=0
    )
    return (positive_filtered, negative_filtered)

Calling this raises the following error: AnnotationError: Do not use `isinstance(x, jaxtyping.Float)`. If you want to check just the dtype of an array, then use `jaxtyping.Float[jnp.ndarray, "..."]`.

The error is from:

File ~/anaconda3/envs/arm/lib/python3.10/site-packages/jaxtyping/_array_types.py:561, in _MetaAbstractDtype.__instancecheck__(cls, obj)

    560 def __instancecheck__(cls, obj: Any) -> NoReturn:
--> 561     raise AnnotationError(
    562         f"Do not use `isinstance(x, jaxtyping.{cls.__name__})`. If you want to "
    563         "check just the dtype of an array, then use "
    564         f'`jaxtyping.{cls.__name__}[jnp.ndarray, "..."]`.'
    565     )
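
The traceback shows the error being raised from a metaclass `__instancecheck__`, which means something called `isinstance(value, Float)` against the bare, unsubscripted `Float` class. In the code above, the only bare `Float` is the `sampling: Float = 1` parameter, so that annotation is a likely trigger. The sketch below reproduces the mechanism with the standard library only. It is a simplified stand-in and not jaxtyping's or beartype's actual implementation: `_MetaDtype`, the local `Float` and `AnnotationError` classes, and `naive_typecheck` are all hypothetical names invented for illustration.

```python
import inspect
from typing import Any, Callable


class AnnotationError(Exception):
    """Hypothetical stand-in for the error type seen in the traceback."""


class _MetaDtype(type):
    # Mirrors the behaviour in the traceback: any isinstance() check
    # against the bare dtype class is rejected outright.
    def __instancecheck__(cls, obj: Any) -> bool:
        raise AnnotationError(
            f"Do not use `isinstance(x, {cls.__name__})`; subscript it instead."
        )


class Float(metaclass=_MetaDtype):
    """Hypothetical stand-in for a bare, unsubscripted dtype class."""


def naive_typecheck(fn: Callable) -> Callable:
    """Toy checker: isinstance-checks each argument whose annotation is a class."""
    sig = inspect.signature(fn)

    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            hint = sig.parameters[name].annotation
            if hint is inspect.Parameter.empty or not isinstance(hint, type):
                continue  # skip unannotated or non-class hints
            if not isinstance(value, hint):
                raise TypeError(f"{name}={value!r} is not {hint.__name__}")
        return fn(*args, **kwargs)

    return wrapper


@naive_typecheck
def bad(sampling: Float = 1):  # bare dtype class used as a scalar annotation
    return sampling


@naive_typecheck
def good(sampling: float = 1.0):  # plain Python type for a Python scalar
    return sampling


error_message = None
try:
    bad(2)
except AnnotationError as exc:
    error_message = str(exc)  # the metaclass refuses the isinstance check

result = good(2.0)  # passes the toy check and returns the argument
```

If this is indeed the cause, annotating `sampling` with an ordinary Python type (for example `float` with a default of `1.0`, or `Union[int, float]`), or with a subscripted array annotation as the error message suggests, should avoid the `isinstance` call on the bare class.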