patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

mypy not compatible with any named axes? #35

Open · zplizzi opened this issue 2 years ago

zplizzi commented 2 years ago

When I specify a type like TensorType["batch_size", "num_channels", "x", "y"], mypy reports error: Name "batch_size" is not defined, and the same error for each of the other named axes. Is this expected? Am I doing something wrong? This is with the most recent mypy, 0.950.

patrick-kidger commented 2 years ago

This is expected: mypy treats each string as a forward reference (the name of a type to be looked up) rather than as a literal string. Python's typing system can be a bit of a mess in edge cases like this.
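The same forward-reference mechanism can be observed at runtime with the standard library's typing module. This sketch uses typing.List as a stand-in for TensorType (an assumption made so it runs without torch; torchtyping's TensorType consumes the strings itself at runtime, so the failure there only shows up in the static checker). A string passed as a type argument is wrapped in typing.ForwardRef, and resolving it means looking the string up as a name.

```python
import typing
from typing import List


def f(x: List["batch"]) -> None:
    """The annotation passes a bare string, like TensorType["batch"] does."""


# Stand-in for TensorType: subscripting a typing generic with a string
# wraps it in a ForwardRef, a type that is named but not yet looked up.
arg = typing.get_args(f.__annotations__["x"])[0]
print(type(arg).__name__, arg.__forward_arg__)

# Resolving the annotation looks "batch" up as a name in this module,
# which fails because nothing called `batch` has been defined.
try:
    typing.get_type_hints(f)
    resolved = True
except NameError as exc:
    resolved = False
    print("NameError:", exc)
```

A static checker performs the equivalent lookup without running the code, which is where the "Name ... is not defined" error comes from.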

One solution is to actually define some objects with the name of these strings. Another is to use the appropriate annotations to have mypy ignore the error.
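A minimal sketch of the first suggestion, assuming "define some objects" means module-level placeholder classes named after each axis. The class names are hypothetical, and classes (rather than plain variables) are an assumption: a checker is likely to reject a variable used as a type argument with a different error. typing.List again stands in for TensorType so the snippet runs without torch; once the names exist, the forward references resolve.

```python
import typing
from typing import List


# Hypothetical placeholder types, one per named axis. They exist only so
# that the strings "batch" and "channels" refer to defined names.
class batch: ...


class channels: ...


def g(x: List["batch"], y: List["channels"]) -> None:
    """Same string annotations as before (List stands in for TensorType)."""


# The lookup that failed before now succeeds: each string resolves to
# the placeholder class with the matching name.
hints = typing.get_type_hints(g)
print(hints["x"], hints["y"])
```

Because the annotation still contains the string itself, torchtyping should continue to see the dimension name "batch" at runtime; the placeholder classes would only matter to the static checker.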

zplizzi commented 2 years ago

Got it - no worries, I understand the constraints here. It might be helpful to update the mypy section of the documentation to explain this more clearly, though. When I read that mypy was "mostly" supported, I expected this core feature to work without hacks.

stvhuang commented 2 years ago

The latest version of Pyright (1.1.262) has started throwing similar errors ("batch_size" is not defined).

bluenote10 commented 1 year ago

Another is to use the appropriate annotations to have mypy ignore the error.

How exactly is this supposed to work? Even with the following code:

from torchtyping import TensorType  # type: ignore

def batch_outer_product(
    x: TensorType[
        "batch",  # type: ignore
        "x_channels",  # type: ignore
    ],
    y: TensorType[
        "batch",  # type: ignore
        "y_channels",  # type: ignore
    ],
) -> TensorType[
    "batch",  # type: ignore
    "x_channels",  # type: ignore
    "y_channels",  # type: ignore
]:

    return x.unsqueeze(-1) * y.unsqueeze(-2)

I'm getting:

test.py:6: error: Name "batch" is not defined  [name-defined]
test.py:6: error: Name "x_channels" is not defined  [name-defined]
test.py:7: error: Unused "type: ignore" comment
test.py:8: error: Unused "type: ignore" comment
test.py:10: error: Name "batch" is not defined  [name-defined]
test.py:10: error: Name "y_channels" is not defined  [name-defined]
test.py:11: error: Unused "type: ignore" comment
test.py:12: error: Unused "type: ignore" comment
test.py:14: error: Name "batch" is not defined  [name-defined]
test.py:14: error: Name "x_channels" is not defined  [name-defined]
test.py:14: error: Name "y_channels" is not defined  [name-defined]
test.py:15: error: Unused "type: ignore" comment
test.py:16: error: Unused "type: ignore" comment
test.py:17: error: Unused "type: ignore" comment

I assume this line in the documentation is no longer valid, right?

Additionally mypy has a bug which causes it to crash on any file using the str: int or str: ... notation, as in TensorType["batch": 10].

The underlying issue (https://github.com/python/mypy/issues/10266) has been closed.
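For context on why this notation is unusual for a type checker: inside square brackets, "batch": 10 is Python slice syntax, so the subscript receives a slice object rather than a string and an int. The sketch below uses a hypothetical Probe class (not part of torchtyping) that simply returns whatever it is subscripted with.

```python
class Probe:
    """Hypothetical class that echoes back whatever it is subscripted with."""

    def __class_getitem__(cls, item):
        # __class_getitem__ is implicitly a classmethod; just return the key.
        return item


# "name": size is slice syntax, so a single axis arrives as a slice object.
single = Probe["batch": 10]

# Several axes arrive as a tuple of slices; "name": ... uses Ellipsis as stop.
multi = Probe["batch": 10, "channels": ...]

print(single)
print(multi)
```

Annotations normally contain only types, so a checker has to handle slice expressions in this position as a special case.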