patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Inconsistent shape checking for lists of tensors #234

Open seanroelofs opened 2 months ago

seanroelofs commented 2 months ago

I am trying to type check the shape of Tensors inside a list. It seems like the shape wildcards are not enforced consistently throughout the list. I wrote an example to help describe the issue.

from beartype import beartype
from beartype.typing import List, Optional
from jaxtyping import Float, jaxtyped
import torch
from torch import Tensor
import unittest

@jaxtyped(typechecker=beartype)
def foo(x: Float[Tensor, "B C"], feat_list: List[Float[Tensor, "B C N"]], y: Optional[Float[Tensor, "N"]] = None):
    pass

class TestJaxtyping(unittest.TestCase):
    def test_variable_list(self):
        B = 8
        C = 32
        N_1 = 3
        N_2 = 4
        x = torch.zeros((B, C)) 
        feats_a = [torch.zeros((B, C, N_1)), torch.zeros((B, C, N_2))]
        feats_b = [torch.zeros((B, C, N_1)), torch.zeros((C, B, N_2))]
        y = torch.zeros(N_1)

        # I feel like this should error, but it doesn't
        foo(x, feats_a)

        try:
            foo(x, feats_b)
        except Exception:  # a bare `except:` would also swallow KeyboardInterrupt/SystemExit
            print("Jaxtyping caught the B C switch here")

        # most concerning of all: this passes on some runs and fails on others, depending on whether N gets matched against N_1 or N_2
        foo(x, feats_a, y)

if __name__ == "__main__":
    unittest.main()

Is there any way to enforce tensor shapes inside a list correctly?

patrick-kidger commented 2 months ago

I think this is a beartype thing: it only checks one element of a list.

CC @leycec

leycec commented 2 months ago

...heh. @beartype woes, huh? It's all true. @beartype guarantees constant-time O(1) complexity by only pseudo-randomly type-checking one item of each list. This is both a bad thing and a good thing. On the bright side, @beartype scales to arbitrarily large lists (and all other kinds of containers); @beartype is guaranteed to never Denial-of-Service (DoS) your workflow when a disturbingly large list (or other kind of container) inevitably gets passed in. On the dark side, non-deterministic type-checking kinda sucks. I get that and sympathize with your pain.
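To make the tradeoff concrete, here is a toy model of the two strategies in plain Python. It is only a sketch and not @beartype's actual implementation: the helper names (`check_one_random_item`, `check_every_item`, `has_b_c_leading_axes`) are hypothetical, and plain shape tuples stand in for tensors.

```python
import random

def check_one_random_item(items, is_valid, rng):
    """O(1) toy check: validate a single pseudo-randomly chosen item."""
    if not items:
        return True  # nothing to sample, so nothing can fail
    return is_valid(items[rng.randrange(len(items))])

def check_every_item(items, is_valid):
    """O(n) toy check: validate every item in the list."""
    return all(is_valid(item) for item in items)

def has_b_c_leading_axes(shape):
    """Hypothetical validity rule: the first two axes must be (B, C) = (8, 32)."""
    return shape[:2] == (8, 32)

# Two well-formed shapes and one with B and C swapped.
shapes = [(8, 32, 3), (8, 32, 3), (32, 8, 4)]

rng = random.Random(0)  # seeded only so the demo is repeatable
sampled_results = [
    check_one_random_item(shapes, has_b_c_leading_axes, rng) for _ in range(1000)
]

# The sampled check passes on some calls and fails on others;
# the exhaustive check fails every time.
print("sampled checks that passed:", sum(sampled_results), "of", len(sampled_results))
print("exhaustive check passed:", check_every_item(shapes, has_b_c_leading_axes))
```

In this model the bad item is only caught on the calls where it happens to be the one sampled, which is the same kind of run-to-run variation reported above.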

You are now thinking: "I hate @beartype." You're not wrong, @seanroelofs. But... fear not! An upcoming release of @beartype will provide the sort of linear-time O(n) type-checking you want and need. If your use case can't wait until then, no judgement, bro. typeguard is a valid alternative to @beartype that might be a better fit here in the meantime, e.g.:

from typeguard import typechecked  # <-- this fills me with sadness
from beartype.typing import List, Optional
from jaxtyping import Float, jaxtyped
import torch
from torch import Tensor
import unittest

@jaxtyped(typechecker=typechecked)  # <-- sadness intensifies
def foo(x: Float[Tensor, "B C"], feat_list: List[Float[Tensor, "B C N"]], y: Optional[Float[Tensor, "N"]] = None):
    pass

Of course, typeguard comes with the opposite tradeoff. It type-checks everything and thus fails to scale to large problem domains. But... maybe that's not a problem here?
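As a rough illustration of what checking every element implies for the example in this thread, the sketch below models named axes that must keep one size for the duration of a call. This is a toy under stated assumptions, not jaxtyping's or typeguard's real code: `matches` and `check_call` are hypothetical helpers, and shape tuples stand in for tensors.

```python
def matches(shape, names, bindings):
    """Check one shape against named axes, binding each name to a size on first use."""
    if len(shape) != len(names):
        return False
    for size, name in zip(shape, names):
        # setdefault records the size the first time a name is seen;
        # later uses of the same name must agree with it.
        if bindings.setdefault(name, size) != size:
            return False
    return True

def check_call(x_shape, feat_shapes):
    """Toy model of foo(x: "B C", feat_list: List["B C N"]) with every item checked."""
    bindings = {}  # shared across all arguments of one call
    if not matches(x_shape, ("B", "C"), bindings):
        return False
    return all(matches(shape, ("B", "C", "N"), bindings) for shape in feat_shapes)

# All items agree on B, C and N.
print(check_call((8, 32), [(8, 32, 3), (8, 32, 3)]))
# N differs between items (like feats_a): rejected once every item is checked.
print(check_call((8, 32), [(8, 32, 3), (8, 32, 4)]))
# B and C swapped in the second item (like feats_b): also rejected.
print(check_call((8, 32), [(8, 32, 3), (32, 8, 4)]))
```

Under this model, a list whose items disagree on N is rejected as soon as all items are checked, whereas a checker that samples a single item binds N to whichever item it happened to pick.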

And thanks so much to @patrick-kidger for pinging me. Hope you're having an amazing summer! It's been so long since we've GitHub chatted. How about that CrowdStrike fiasco, huh? Yikes. Oh – and this issue can (probably) be safely closed. As everyone has surmised, this is almost certainly @beartype's fault. jaxtyping is blameless in this and all things.