patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0
1.41k stars, 34 forks

NameError encountered in tutorial #47

Closed (ShashankSule closed this issue 1 year ago)

ShashankSule commented 1 year ago

I'm trying out the example in the README; in particular, I am running:

from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"],
         y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))  # expected to raise a TypeError (inconsistent "batch" size)

However, after executing func(rand(3), rand(3)) (which is supposed to work), I get:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/shashanksule/miniforge3/envs/pr/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/shashanksule/miniforge3/envs/pr/lib/python3.9/site-packages/torch_pesq/loss.py", line 320, in forward
    d_symm, d_asymm = self.raw(ref, deg)
  File "/Users/shashanksule/miniforge3/envs/pr/lib/python3.9/site-packages/torch_pesq/loss.py", line 174, in raw
    ) -> Tuple[TensorType["batch", "sample"], TensorType["batch", "sample"]]:
NameError: name 'batch' is not defined

I get the same error even if I define named tensors with names = ("batch",) and pass them to func. What's going wrong here?

patrick-kidger commented 1 year ago

Are you using typeguard v2.*? (Later versions of typeguard have broken things in a variety of odd ways.)

In any case, as mentioned at the top of the README, I now strongly recommend using jaxtyping instead of this package.

ShashankSule commented 1 year ago

Thank you! It works now after pinning typeguard==2.13.2. For reference, the reason I am using torchtyping is that torch_pesq, a package for computing perceptual metrics on audio, depends on it, so perhaps that package needs to be updated to use jaxtyping.
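
For anyone hitting the same NameError, here is a minimal sketch of how one might check which typeguard series is installed before calling patch_typeguard(). The helper names typeguard_major_version and is_supported_typeguard are hypothetical, made up for illustration; the only claim taken from this thread is that the 2.x series (2.13.2 specifically) worked.

```python
from importlib import metadata


def typeguard_major_version(version_string: str) -> int:
    """Return the leading integer of a version string such as '2.13.2'."""
    return int(version_string.split(".")[0])


def is_supported_typeguard(version_string: str) -> bool:
    """True for the 2.x series, which this thread reports as working."""
    return typeguard_major_version(version_string) == 2


# Look up the installed typeguard version, if any.
try:
    installed = metadata.version("typeguard")
except metadata.PackageNotFoundError:
    installed = None

if installed is None:
    print("typeguard is not installed")
elif is_supported_typeguard(installed):
    print(f"typeguard {installed}: 2.x series, reported to work with torchtyping")
else:
    # Assumption: pinning to the version reported in this thread is acceptable.
    print(f"typeguard {installed}: not 2.x; consider pinning typeguard==2.13.2")
```

If the check reports a non-2.x version, reinstalling with pip install "typeguard==2.13.2" is the fix that worked here.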

qmpzzpmq commented 1 year ago

Hi @ShashankSule, I am facing the same issue as you, also while using torch_pesq. I tried your method of pinning typeguard==2.13.2. Is there any other method?