patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

pycharm shows an incorrect dtype when assigning torch.randn to a variable #29

Open jamesdsmith99 opened 3 years ago

jamesdsmith99 commented 3 years ago

Hi,

I have recently started using this library, so I might be using it incorrectly, but linting seems to fail in PyCharm when assigning the result of torch.randn to a variable annotated as a TensorType with a float dtype.

Here is an example:

import torch
from torchtyping import TensorType
Matrix = TensorType['h', 'w', float]
x: Matrix = torch.randn(5, 3)

The second line gets underlined with the following error:

Expected type 'TensorType[Any, Any, float]', got 'Tensor' instead

If I modify the second line to:

x: Matrix = torch.randn(5, 3).float()

The error goes away, but I would rather not do that, as one of the plus sides of this library is removing extra typing-related code from my main logic. Having to add an explicit .float() call defeats the purpose of this library IMO.

From reading the docs, this should work: torch.randn returns a tensor with the default dtype, and a TensorType that has float in it should match the default dtype.

patrick-kidger commented 3 years ago

Without having things set up in PyCharm myself it'll be a fair bit of work to diagnose this.

If you remove the float specifier and have only TensorType['h', 'w'], does that still produce an error? I'm just trying to gather some data on what raises an error and what doesn't. More broadly, if you can track down what's causing the issue then I'd be happy to accept a PR.

spietras commented 2 years ago

> If you remove the float specifier and have only TensorType['h', 'w'] does that still produce an error?

I checked it and yes, there is still an error. Doing torch.randn(5, 3).float() works only because float() is untyped, so PyCharm can't assume anything about the return type and doesn't emit any warnings.

It seems that torchtyping doesn't work with PyCharm's type checker at all, because no matter what I do there is always a warning when assigning a Tensor to anything with a TensorType type hint.

And I guess it's not surprising, because TensorType is a subclass of Tensor, so PyCharm complains when we try to assign an instance of the parent class to something expecting the subclass.
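
To illustrate the subclass point without torchtyping itself, here is a minimal sketch using plain Python classes. Base and Derived are hypothetical stand-ins for Tensor and TensorType, make_base stands in for torch.randn, and make_untyped stands in for a method whose return type the checker cannot see (modelled here as Any). How a particular checker such as PyCharm reports each line is an assumption rather than something verified here.

```python
from typing import Any


class Base:
    """Hypothetical stand-in for torch.Tensor."""


class Derived(Base):
    """Hypothetical stand-in for TensorType, a subclass of Base."""


def make_base() -> Base:
    # Declared to return the parent class, like torch.randn returning Tensor.
    return Base()


def make_untyped() -> Any:
    # Return type is opaque to the checker, like an untyped .float().
    return Base()


# A static checker is expected to flag this line:
# a Base instance is not guaranteed to be a Derived.
x: Derived = make_base()

# A static checker is expected to accept this line:
# Any is treated as compatible with every type.
y: Derived = make_untyped()

# Annotations are not enforced at runtime, so both assignments succeed
# and both objects are plain Base instances.
print(type(x).__name__, type(y).__name__)
print(isinstance(x, Derived), issubclass(Derived, Base))
```

Both assignments behave identically at runtime; only the declared return types differ, which would explain why adding .float() silences the warning without changing anything about the tensor.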

fzyzcjy commented 1 year ago

Hi, are there any updates?

patrick-kidger commented 1 year ago

Yes! I'd recommend trying jaxtyping. Despite the name, it actually works equally well for PyTorch.

In particular, it's designed to play much better with static type checkers.

fzyzcjy commented 1 year ago

@patrick-kidger Interesting, thanks for the quick reply! (Never thought "jax" typing would work for "pytorch" before ;) )