Closed: andyljones closed this issue 3 years ago.
Thanks for the thorough repro, that's really helpful. (Especially as I use neither VSCode nor pyright myself.)
The good news is that I think this can be fixed, and moreover it can be fixed in either of two different ways.
1. Change `TensorType` to `TensorType[...]`, with a literal `...`. This is equivalent as far as torchtyping is concerned, and this seems to make pyright happy.
2. Change `TensorType` to `torch.Tensor`, just for those cases where you don't need the explicit `TensorType[...stuff here...]`. The explicit `TensorType[...stuff here...]` should already be fine with pyright.

Does that help?
Yes, yes it does! `TensorType[...]` fixed it in the repro, and together with upgrading to 0.1.4 it's also fixed in prod. Thanks very much!
@patrick-kidger not working for me with nested types. `TensorType` works OK for variable `a`, as it gets the concrete type `Tensor` from the `torch.zeros()` part, but it doesn't work for the type `ModelInput`: there, `TensorType` gets replaced by `Unknown`.

torch 1.9.0, torchtyping 0.1.4, pylance 2021.10.1, python 3.7.10
OK, pyright doesn't understand `TensorType`s at all...
Hmm. So I seem to recall that things were fixed in the original issue precisely because pyright sees anything of the form `CustomType[...]` and just bails -- it doesn't even try to understand what's going on with a `[...]`'d custom type. See also example 18 of https://github.com/microsoft/pyright/issues/1537.

Ultimately I think this is an issue with pyright: ideally it should inspect the return value from classes with a custom `__class_getitem__`. But it doesn't; static type checkers are full of limitations like this one. (It's the reason a runtime type checker is recommended in the torchtyping README instead.)
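The `__class_getitem__` limitation can be illustrated with a self-contained toy (this is not torchtyping's actual implementation, just a sketch of the mechanism):

```python
class Shaped:
    """Toy class whose subscription builds a new type at runtime."""
    def __class_getitem__(cls, item):
        # Dynamically create a subclass carrying the shape metadata.
        name = f"{cls.__name__}[{item!r}]"
        return type(name, (cls,), {"shape_spec": item})

SeqEmb = Shaped["seq", "emb"]
print(SeqEmb.shape_spec)          # the metadata exists at runtime...
print(issubclass(SeqEmb, Shaped))
# ...but a static checker never executes __class_getitem__, so it cannot
# see that Shaped["seq", "emb"] is a Shaped subclass with extra fields.
```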
Some possible avenues forward:
1. One thing you could try (although I'm not hopeful) is to adjust the return annotation on this line to `torch.Tensor`. If that works I'd be happy to accept a PR changing it.
2. More broadly, if you're able to identify a work-around that makes pyright happy then I'd be happy to accept a PR on that too.
3. One final possibility that may or may not work (up to pyright, really) is to write your own function wrapper, to the effect of

   ```python
   def tensor_type(*args) -> torch.Tensor:
       return TensorType[args]

   Input = tensor_type("seq_size", "emb_dim")
   ```

   (or possibly some variation of the above). The use of the additional `torch.Tensor` annotation may or may not convince pyright.
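As a stdlib-only analog of the wrapper idea (hypothetical names: `Base` stands in for the tensor types, and the declared return type plays the role of the `torch.Tensor` annotation):

```python
class Base:
    def __class_getitem__(cls, item):
        # Build a subclass carrying the requested dimension names.
        return type(f"{cls.__name__}[{item!r}]", (cls,), {"dims": item})

def tensor_type_wrapper(*args) -> type:
    # A static checker sees only the declared return type (`type` here,
    # `torch.Tensor` in the suggestion); the richer runtime type is hidden.
    return Base[args]

Input = tensor_type_wrapper("seq_size", "emb_dim")
print(Input.dims)
```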
@patrick-kidger thanks for your help but unfortunately none of this worked for me :(
Ach, that's a shame.
Anyway, I'm not sure what more can be done on the torchtyping end, although I'm open to suggestions. Fundamentally I think this is something on pyright's end, in supporting `__class_getitem__`.
Using torchtyping in VSCode, I've found that passing a `Tensor` to a `TensorType` generates an error in the type checker. Tagging the `TensorType` import with `type: ignore`, as recommended in the FAQ for mypy compatibility, doesn't help. Is there any other way to suppress these errors, short of tagging every use of a tensor with a TensorType'd sig with `type: ignore`?

Reproduction
VSCode's Pylance language server backs onto the pyright project, so we can get an easier-to-examine reproduction by using pyright directly.

Here's a quick script to set up an empty conda env with just torch and torchtyping, and one more command to install pyright.
Then create two files: `pyrightconfig.json` and `test.py`, with the contents given. With that all done, running `pyright test.py` will give the error:
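For reference, a minimal `pyrightconfig.json` for a repro like this might look as follows (a hypothetical sketch; the original file's contents are not reproduced above, and only standard pyright option names are used):

```json
{
    "include": ["test.py"],
    "reportGeneralTypeIssues": "error"
}
```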