patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

vscode/pylance/pyright don't consider a Tensor to be compatible with TensorType #26

Closed: andyljones closed this issue 3 years ago

andyljones commented 3 years ago

Using torchtyping in vscode, I've found that passing a Tensor to a parameter annotated with TensorType generates an error in the type checker:

Tagging the TensorType import with # type: ignore, as recommended in the FAQ for mypy compatibility, doesn't help. Is there any other way to suppress these errors, short of tagging every call to a function with a TensorType-annotated signature with # type: ignore?

Reproduction

vscode's Pylance language server is built on the pyright project, so we can get an easier-to-examine reproduction by running pyright directly.

Here's a quick script to set up a fresh conda env with just Python, torch and torchtyping:

mkdir tmp
cd tmp
conda create -p ./.env python=3.9
conda activate ./.env
pip install torch==1.9.0 torchtyping==0.1.3

and one more command to install pyright

sudo npm install -g pyright

Then create two files, pyrightconfig.json with contents

{
    "useLibraryCodeForTypes": true,
    "exclude": [".env"]
}

and test.py with contents

import torch
from torchtyping import TensorType

def f(a: TensorType):
    pass

f(torch.zeros(3))

With that all done, running pyright test.py will give the error:

Loading configuration file at /Users/andy/code/tmp/pyrightconfig.json
Assuming Python version 3.9
Assuming Python platform Darwin
stubPath /Users/andy/code/tmp/typings is not a valid directory.
Searching for source files
Found 1 source file
/Users/andy/code/tmp/test.py
  /Users/andy/code/tmp/test.py:7:3 - error: Argument of type "Tensor" cannot be assigned to parameter "a" of type "TensorType" in function "f"
    "Tensor" is incompatible with "TensorType" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 infos 
Completed in 0.715sec

patrick-kidger commented 3 years ago

Thanks for the thorough repro, that's really helpful. (Especially as I use neither VSCode nor pyright myself.)

The good news is that I think this can be fixed, and moreover it can be fixed in either of two different ways.

Does that help?

andyljones commented 3 years ago

Yes, yes it does! TensorType[...] fixed it in the repro, and together with 'upgrading to 0.1.4' it's also fixed in prod. Thanks very much!

lainisourgod commented 2 years ago

@patrick-kidger not working for me with nested types

  1. Define a nested structure with TensorType. It works OK for variable a, as it gets the concrete type Tensor from the torch.zeros() part, but it doesn't work for the type ModelInput: there, TensorType gets replaced by Unknown.

torch 1.9.0 torchtyping 0.1.4 pylance 2021.10.1 python 3.7.10

lainisourgod commented 2 years ago

OK, pyright doesn't understand TensorTypes at all...

patrick-kidger commented 2 years ago

Hmm. So I seem to recall that things were fixed in the original issue precisely because pyright sees anything of the form CustomType[...] and just bails: it doesn't even try to understand what's going on with a [...]'d custom type. Cf. also example 18 of https://github.com/microsoft/pyright/issues/1537.

Ultimately I think this is an issue with pyright: ideally it should inspect the return value from classes with a custom __class_getitem__. But it doesn't; static type checkers are full of limitations like this one. (It's the reason a runtime type checker is recommended in the torchtyping README instead.)
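To make the __class_getitem__ point concrete, here is a minimal stdlib-only sketch. ShapedType and ModelInput are hypothetical stand-ins written for illustration; this is not torchtyping's actual TensorType implementation. Subscripting the class runs ordinary Python code that manufactures a new class at runtime, so a checker that does not evaluate __class_getitem__ has nothing declared to tell it what ModelInput is, which fits the Unknown reported above.

```python
from typing import Tuple


class ShapedType:
    """Hypothetical stand-in for a torchtyping-style annotation class."""

    dims: Tuple[str, ...] = ()

    def __class_getitem__(cls, item):
        # Subscripting calls this method at runtime; normalise one name to a 1-tuple.
        dims = item if isinstance(item, tuple) else (item,)
        # Manufacture a brand-new subclass that remembers the dimension names.
        name = f"{cls.__name__}[{', '.join(map(str, dims))}]"
        return type(name, (cls,), {"dims": dims})


# Hypothetical alias, mirroring the ModelInput example earlier in the thread.
ModelInput = ShapedType["batch", "seq", "emb"]

print(ModelInput.__name__)  # ShapedType[batch, seq, emb]
print(ModelInput.dims)  # ('batch', 'seq', 'emb')
print(issubclass(ModelInput, ShapedType))  # True
```

A runtime checker can simply read .dims off the resulting class, which is one way to see why checking at runtime sidesteps this limitation.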

Some possible avenues forward:

1.

One thing you could try (although I'm not hopeful) is to adjust the return annotation on this line:

https://github.com/patrick-kidger/torchtyping/blob/2292ad604b9d40234dedb73f87cb6f1b4e84625d/torchtyping/tensor_type.py#L81

to torch.Tensor.

If that works I'd be happy to accept a PR changing it.

2.

More broadly if you're able to identify a work-around that makes pyright happy then I'd be happy to accept a PR on that too.

3.

One final possibility that may or may not work (up to pyright, really) is to write your own function wrapper to the effect of:

import torch
from torchtyping import TensorType

def tensor_type(*args) -> torch.Tensor:
    # The torch.Tensor return annotation is what might persuade pyright.
    return TensorType[args]

Input = tensor_type("seq_size", "emb_dim")

(or possibly some variation of the above). The additional torch.Tensor annotation may or may not convince pyright.
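Another pattern in the spirit of option 2, not suggested in this thread and not tested against the pyright/pylance versions mentioned here, is to branch on typing.TYPE_CHECKING. Static checkers treat TYPE_CHECKING as True while it is False at runtime, so the checker sees a plain tensor alias and the runtime keeps the subscripted annotation. To keep the sketch self-contained, Tensor and ShapedType below are hypothetical stand-ins for torch.Tensor and torchtyping.TensorType, and embed is a hypothetical function.

```python
from typing import TYPE_CHECKING, Tuple


class Tensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without torch."""


class ShapedType:
    """Hypothetical stand-in for torchtyping.TensorType."""

    dims: Tuple[str, ...] = ()

    def __class_getitem__(cls, item):
        # Record the requested dimension names on a freshly built subclass.
        dims = item if isinstance(item, tuple) else (item,)
        return type("ShapedType", (cls,), {"dims": dims})


if TYPE_CHECKING:
    # Static checkers (mypy, pyright) treat TYPE_CHECKING as True, so they
    # only see a plain alias of the tensor class here.
    Input = Tensor
else:
    # At runtime TYPE_CHECKING is False, so the subscripted annotation is
    # what actually ends up in __annotations__ for a runtime checker.
    Input = ShapedType["seq_size", "emb_dim"]


def embed(x: Input) -> Tensor:
    # Hypothetical function; the body is irrelevant to the typing question.
    return Tensor()
```

With the real libraries the two branches would become Input = torch.Tensor and Input = TensorType["seq_size", "emb_dim"]; whether pyright and torchtyping's runtime checking are both satisfied by that is an assumption to verify, not something established above.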

lainisourgod commented 2 years ago

@patrick-kidger thanks for your help but unfortunately none of this worked for me :(

patrick-kidger commented 2 years ago

Ach, that's a shame. Anyway, I'm not sure what more can be done on the torchtyping end, although I'm open to suggestions. Fundamentally I think this is something on pyright's end, in supporting __class_getitem__.