patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

mypy type checking seems to break in strict mode -- a mypy bug? #52

Status: Open. bluenote10 opened this issue 1 year ago.

bluenote10 commented 1 year ago

Following up on https://github.com/patrick-kidger/torchtyping/issues/41, I'm trying the same things here. However, I'm not having much success with mypy so far. Am I doing something wrong?

import torch
from jaxtyping import Float

dim1 = "dim1"

# Expected to work but fails with:
# error: Returning Any from function declared to return "Tensor"
def simple_test_a(x: Float[torch.Tensor, "dim1"]) -> torch.Tensor:
    return x

# Expected to work but fails with:
# error: Returning Any from function declared to return "float"
def simple_test_b(x: Float[torch.Tensor, "dim1"]) -> float:
    return x.item()

# Expected to error, but passes type checking
def simple_test_c(x: Float[torch.Tensor, "dim1"]) -> None:
    x.asdfasdfasdf()

VSCode (pyright) seems to do a little better, but it apparently doesn't like the import:

[Screenshot: pyright flagging the jaxtyping import]

patrick-kidger commented 1 year ago

I can't replicate your issue, I'm afraid. Running mypy on:

import torch
from jaxtyping import Float

def simple_test_a(x: Float[torch.Tensor, "dim1"]) -> torch.Tensor:
    reveal_type(x)
    return x

def simple_test_b(x: Float[torch.Tensor, "dim1"]) -> float:
    reveal_type(x)
    return x.item()

def simple_test_c(x: Float[torch.Tensor, "dim1"]) -> None:
    reveal_type(x)
    x.asdfasdfasdf()

prints:

tmp.py:5: note: Revealed type is "torch._tensor.Tensor"
tmp.py:9: note: Revealed type is "torch._tensor.Tensor"
tmp.py:13: note: Revealed type is "torch._tensor.Tensor"
tmp.py:14: error: "Tensor" has no attribute "asdfasdfasdf"  [attr-defined]

This is with versions:

torch: 1.13.1
jaxtyping: 0.2.9
mypy: 0.991
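For anyone else trying to reproduce this, here is a small standard-library-only sketch that prints the installed versions in the same "name: version" format as the list above. The helper name `installed_versions` is hypothetical; the lookup uses `importlib.metadata.version`, which raises `PackageNotFoundError` for distributions that are not installed.

```python
# Sketch: report installed versions of the packages relevant to this issue.
# Standard library only (importlib.metadata, Python 3.8+).
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages: list[str]) -> dict[str, str | None]:
    """Map each distribution name to its installed version, or None if missing."""
    result: dict[str, str | None] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(["torch", "jaxtyping", "mypy"]).items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```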

As for VSCode, the import complaint is due to a now-resolved bug in pyright (https://github.com/microsoft/pyright/issues/4287). Try updating your pyright version.

bluenote10 commented 1 year ago

Interesting, it seems to be related to strict mode. I can replicate your output when I run mypy in non-strict mode:

test.py:6: note: Revealed type is "torch._tensor.Tensor"
test.py:11: note: Revealed type is "torch._tensor.Tensor"
test.py:16: note: Revealed type is "torch._tensor.Tensor"
test.py:17: error: "Tensor" has no attribute "asdfasdfasdf"  [attr-defined]

As soon as I add a mypy.ini containing

[mypy]
strict = True

the output becomes:

test.py:5: error: Name "dim1" is not defined  [name-defined]
test.py:6: note: Revealed type is "Any"
test.py:7: error: Returning Any from function declared to return "Tensor"  [no-any-return]
test.py:10: error: Name "dim1" is not defined  [name-defined]
test.py:11: note: Revealed type is "Any"
test.py:12: error: Returning Any from function declared to return "float"  [no-any-return]
test.py:15: error: Name "dim1" is not defined  [name-defined]
test.py:16: note: Revealed type is "Any"

Note that even the revealed types change, from "torch._tensor.Tensor" to "Any".
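To compare the two modes side by side, one option is to drive mypy from Python with `mypy.api.run`, which returns a `(stdout, stderr, exit_status)` tuple, and pull out the `reveal_type()` notes. This is a minimal sketch, assuming mypy is installed and that no mypy config file in the working directory already sets `strict = True` (otherwise both runs would be strict). The helper names `revealed_types` and `compare` and the default path `test.py` are hypothetical.

```python
# Sketch: run mypy on one file with and without --strict and collect the types
# reported by each reveal_type() call. Assumes mypy is installed and that no
# mypy config in the working directory enables strict mode for both runs.
from __future__ import annotations

import re
import sys

# Matches lines such as: test.py:6: note: Revealed type is "torch._tensor.Tensor"
REVEALED = re.compile(r'^(?P<file>.+?):(?P<line>\d+): note: Revealed type is "(?P<type>[^"]*)"$')


def revealed_types(report: str) -> dict[int, str]:
    """Extract {line number: revealed type} from a mypy text report."""
    found: dict[int, str] = {}
    for line in report.splitlines():
        match = REVEALED.match(line)
        if match:
            found[int(match["line"])] = match["type"]
    return found


def compare(path: str) -> dict[str, dict[int, str]]:
    from mypy import api  # lazy import: the parser above does not need mypy

    results: dict[str, dict[int, str]] = {}
    for flags in ([], ["--strict"]):
        stdout, _stderr, _status = api.run([*flags, path])
        results[" ".join(flags) or "default"] = revealed_types(stdout)
    return results


if __name__ == "__main__":
    target = sys.argv[1] if len(sys.argv) > 1 else "test.py"
    for mode, types in compare(target).items():
        print(mode, types)
```

Given the logs in this thread, the default run would be expected to map lines 6, 11 and 16 to "torch._tensor.Tensor" and the strict run to map them to "Any".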

This is with exactly the same package versions.

Looks like a mypy bug?
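Since `--strict` is a bundle of individual options, one way to narrow this down before filing against mypy is to run each option on its own and see which ones produce `Revealed type is "Any"`. A minimal sketch, assuming mypy is installed and no config file in the working directory enables strict mode. The option list is a subset of what `--strict` enabled around mypy 0.991 (the exact set varies by version, so check `mypy --help`), and the helper names are hypothetical.

```python
# Sketch: run mypy once per option bundled into --strict and report which
# options make reveal_type() show Any. Assumes mypy is installed and that no
# mypy config in the working directory already turns on strict mode.
from __future__ import annotations

import sys

# A subset of the options --strict enabled around mypy 0.991; varies by version.
STRICT_OPTIONS = [
    "--warn-unused-configs",
    "--disallow-any-generics",
    "--disallow-subclassing-any",
    "--disallow-untyped-calls",
    "--disallow-untyped-defs",
    "--disallow-incomplete-defs",
    "--check-untyped-defs",
    "--disallow-untyped-decorators",
    "--warn-redundant-casts",
    "--warn-unused-ignores",
    "--warn-return-any",
    "--no-implicit-reexport",
    "--strict-equality",
]

ANY_MARKER = 'Revealed type is "Any"'


def options_that_produce_any(reports: dict[str, str]) -> list[str]:
    """Return, in order, the options whose mypy report reveals an Any type."""
    return [option for option, report in reports.items() if ANY_MARKER in report]


def run_each_option(path: str) -> dict[str, str]:
    from mypy import api  # lazy import: only needed when actually running mypy

    reports: dict[str, str] = {}
    for option in STRICT_OPTIONS:
        stdout, _stderr, _status = api.run([option, path])
        reports[option] = stdout
    return reports


if __name__ == "__main__":
    target = sys.argv[1] if len(sys.argv) > 1 else "test.py"
    print(options_that_produce_any(run_each_option(target)))
```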