patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

Pyright reports an error with named axis #38

Status: Open. Luceurre opened this issue 2 years ago.

Luceurre commented 2 years ago

Setup

Code Example

from torchtyping import TensorType

# Runs fine at runtime; Pyright flags the string "batch" as an undefined name.
def example(foo: TensorType["batch"]):
    pass

Problem

Pyright reports the following error: "batch" is not defined
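For illustration, the same failure can be reproduced with only the standard library. This sketch uses typing.List as a stand-in for TensorType (an assumption made so it runs without torch). Python stores a string inside a subscripted annotation as a forward reference, and typing.get_type_hints, which resolves forward references much as a static checker does, then fails to find anything named batch.

```python
import typing

# typing.List stands in for TensorType so this runs without torch installed.
# The string "batch" inside the subscript is stored as ForwardRef('batch'):
# a name to be looked up later, not a free-form label.
def example(foo: typing.List["batch"]):
    pass

# get_type_hints evaluates forward references, much as a static checker does,
# and fails because no name `batch` exists in the module.
try:
    typing.get_type_hints(example)
    error = None
except NameError as exc:
    error = exc

print(repr(error))
```

This is why Pyright reports "batch" as undefined: to the checker, it is a name lookup rather than an axis label.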

Related issue

The same error is reported by mypy when -1 is omitted: https://github.com/patrick-kidger/torchtyping/issues/35

patrick-kidger commented 2 years ago

This is expected. Static type checking is (unfortunately) fundamentally incompatible with annotating arrays.

You should either disable Pyright's check for this line (for example with a # pyright: ignore comment), or define batch = None elsewhere in the file so that Pyright can resolve "batch" as a forward reference.
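A minimal sketch of the second workaround, again with typing.List standing in for TensorType so it runs without torch. Binding batch at module level gives the forward reference something to resolve to; the value itself is irrelevant. Note that this demonstrates runtime resolution through typing.get_type_hints, not Pyright's own behaviour, and the rule name in the final comment is an assumption.

```python
import typing

# Workaround from the reply above: give the axis name a module-level binding.
# The value is irrelevant; it only lets the forward reference "batch" resolve.
batch = None

# typing.List stands in for TensorType so this runs without torch installed.
def example(foo: typing.List["batch"]):
    pass

# Resolves without error now that `batch` exists in the module namespace.
hints = typing.get_type_hints(example)
print(hints)

# Alternative (static checking only): suppress the diagnostic on that line,
# assuming Pyright files it under reportUndefinedVariable, e.g.
#   def example(foo: TensorType["batch"]):  # pyright: ignore[reportUndefinedVariable]
```

The dummy binding keeps the annotation itself unchanged, at the cost of one meaningless variable per axis name.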

wookayin commented 2 years ago

Could we use something like a TypeVar, rather than a string literal, in the type annotation?

e.g.

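import typing
from torchtyping import TensorType

Batch = typing.TypeVar("Batch")

# Proposed usage; torchtyping may not currently accept a TypeVar here.
def example(foo: TensorType[Batch]):
    pass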

or

import torchtyping
from torchtyping import TensorType

Batch = torchtyping.AxisVar("Batch")   # a hypothetical API, not part of torchtyping

def example(foo: TensorType[Batch]):
    pass
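A rough sketch of why a TypeVar sidesteps the error: Batch is a real, defined name, so there is nothing for a checker to report as undefined, and reusing it across a signature ties the axes together. Tensor1D is a hypothetical generic class standing in for TensorType, since it is not established that TensorType accepts a TypeVar.

```python
import typing
from typing import Generic, TypeVar

# Batch is an ordinary, defined name, so a static checker has nothing to
# report as undefined (unlike the string "batch").
Batch = TypeVar("Batch")

# Hypothetical stand-in for TensorType: a generic class with one axis parameter.
class Tensor1D(Generic[Batch]):
    pass

# Using the same TypeVar for argument and return ties the two axes together.
def example(foo: Tensor1D[Batch]) -> Tensor1D[Batch]:
    return foo

hints = typing.get_type_hints(example)
print(hints)
```

A single TypeVar covers one axis; annotating an arbitrary number of axes this way is what PEP 646 adds.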

UPDATE: this seems related to #37 (PEP 646, variadic generics).