XuehaiPan opened 1 year ago
Here's another possible improvement to the typing. Consider tree_map:
tree_map(func: Callable[..., U], tree: PyTree[T], ...) -> PyTree[U]
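Type-checking this snippet:
from optree import tree_map
a: int = 1
a = tree_map(lambda x: x + 1, 1)
will fail, even though a evaluates to 2 at runtime, because the declared return type PyTree[int] is a Union that is not assignable to int.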
Fixing this is achievable (although quite annoying) with overloads. E.g.:
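from typing import Any, Callable, Dict, Tuple, TypeVar, overload
T = TypeVar("T")
U = TypeVar("U")
# other parameters of tree_map omitted for brevity
@overload
def tree_map(func: Callable[..., U], tree: Tuple[T, ...]) -> Tuple[U, ...]: ...
@overload
def tree_map(func: Callable[..., U], tree: Dict[Any, T]) -> Dict[Any, U]: ...
@overload
def tree_map(func: Callable[..., U], tree: T) -> U: ...  # leaf fallback
and so on. The catch-all leaf overload has to be listed last: type checkers pick the first matching overload, so a bare T listed first would shadow the container overloads. Going one level deep is probably enough; two levels would be more thorough.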
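To check that the overload approach behaves as intended, here is a minimal runnable sketch. It assumes a hypothetical toy_tree_map (not optree.tree_map) that takes a single tree and only understands tuples, lists, and dicts; the overloads describe one level of nesting only, so deeper trees would need more overloads.

```python
from typing import Any, Callable, Dict, List, Tuple, TypeVar, overload

T = TypeVar("T")
U = TypeVar("U")


# Hypothetical toy_tree_map, NOT optree.tree_map: a single tree, and only
# tuple, list, and dict containers are recognized; everything else is a leaf.
@overload
def toy_tree_map(func: Callable[[T], U], tree: Tuple[T, ...]) -> Tuple[U, ...]: ...
@overload
def toy_tree_map(func: Callable[[T], U], tree: List[T]) -> List[U]: ...
@overload
def toy_tree_map(func: Callable[[T], U], tree: Dict[Any, T]) -> Dict[Any, U]: ...
@overload
def toy_tree_map(func: Callable[[T], U], tree: T) -> U: ...  # leaf fallback, listed last


def toy_tree_map(func: Callable[[Any], Any], tree: Any) -> Any:
    # Single runtime implementation behind all overloads: recurse into
    # containers, apply func at the leaves, and rebuild each container.
    if isinstance(tree, tuple):
        return tuple(toy_tree_map(func, child) for child in tree)
    if isinstance(tree, list):
        return [toy_tree_map(func, child) for child in tree]
    if isinstance(tree, dict):
        return {key: toy_tree_map(func, child) for key, child in tree.items()}
    return func(tree)


a: int = 1
a = toy_tree_map(lambda x: x + 1, 1)  # resolved via the leaf overload
```

A type checker resolves the last call to the leaf overload (T = int, U = int), so assigning the result to the int-annotated a is accepted, while the shared implementation handles every case at runtime.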
To expand on this, the type hints for return values are usually too loose, and I would appreciate an improvement here as well:
def f(x: PyTree[T]) -> PyTree[T]:
    return x
Here, we know that the return value of f(x) necessarily has the same type and layout as x itself. However, while PyTree[T] is appropriately permissive as an argument type, it is unspecific as a return type. This means we necessarily lose type information through the invocation of f. Subsequent code may need to rely on the concrete type of f(x), which is no longer possible:
x: list[int] = [1, 2, 3]
y = f(x) # static type checker warns here due to OP's issue
s = sum(y) # static type checker warns here due to my issue, but we know that y is a list[int]
Normally, one would avoid this by using a TypeVar:
from typing import TypeVar
T = TypeVar("T")
def f(x: T) -> T:
    return x
However, no such concept exists for the layout and type of a PyTree. A potential fix could look like this:
T = TypeVar("T")
S = SpecVar("S")  # hypothetical: a new kind of type variable capturing the tree structure (spec)
def f(x: PyTree[T, S]) -> PyTree[T, S]:
    return x
This way, we can annotate that the returned value follows the same spec as the input argument. Unfortunately, subscripts do not accept keyword arguments, so we cannot name the spec parameter to avoid locking it to a fixed positional slot:
PyTree[T, spec=S] # SyntaxError
A possible workaround, as shown in PEP 472, would be to use slices or dictionaries:
PyTree[T, "spec": S]  # __class_getitem__ receives (T, slice('spec', S, None))
PyTree[T, {"spec": S}]  # __class_getitem__ receives (T, {'spec': S})
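What __class_getitem__ actually receives for the two PEP 472-style spellings can be checked at runtime with a throwaway class. SubscriptProbe is a hypothetical helper, and S is an ordinary TypeVar standing in for the proposed SpecVar, which does not exist today. Static type checkers are not expected to accept these subscripts in annotations; this only shows the runtime behavior.

```python
from typing import Any, TypeVar

T = TypeVar("T")
S = TypeVar("S")  # stand-in only: SpecVar is a proposal, not an existing construct


class SubscriptProbe:
    """Hypothetical helper that returns whatever __class_getitem__ receives."""

    def __class_getitem__(cls, item: Any) -> Any:
        return item


# Slice syntax: "spec": S inside the brackets becomes a slice object.
slice_form = SubscriptProbe[T, "spec": S]
# Dict syntax: an ordinary dict literal is passed through unchanged.
dict_form = SubscriptProbe[T, {"spec": S}]

print(slice_form)  # (~T, slice('spec', ~S, None))
print(dict_form)   # (~T, {'spec': ~S})
```

Both forms arrive as a tuple, so a real PyTree.__class_getitem__ would have to pick the spec out of the second element by inspecting its type.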
Motivation
OpTree uses a generic custom class that returns a Union alias in the __class_getitem__ method. For example:
will expand to:
at runtime.
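The pattern described above can be approximated in a few lines. This is a minimal sketch assuming a hypothetical ToyPyTree class, not OpTree's actual implementation: subscripting returns an ordinary typing.Union whose container members point back to the alias through strings, which typing stores as unevaluated ForwardRef objects. A static analyzer only sees that Union and cannot resolve the self-reference into a recursive type.

```python
from typing import Any, Dict, List, Tuple, TypeVar, Union, get_args

T = TypeVar("T")


class ToyPyTree:
    """Hypothetical stand-in for OpTree's PyTree; not its real implementation."""

    def __class_getitem__(cls, param: Any) -> Any:
        # The self-reference is a plain string; typing wraps it in a ForwardRef
        # that is stored but never evaluated.
        ref = f"ToyPyTree[{getattr(param, '__name__', repr(param))}]"
        return Union[param, Tuple[ref, ...], List[ref], Dict[Any, ref]]


alias = ToyPyTree[int]  # Union of int, Tuple[...], List[...], Dict[Any, ...]
members = get_args(alias)
generic_alias = ToyPyTree[T]  # also works for a TypeVar; the ref reads "ToyPyTree[T]"
```

The ForwardRef inside each container is only a string tag, which is why the generic ToyPyTree[T] gives a checker nothing recursive to match an argument such as list[int] against.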
The typing linter mypy is a static analyzer, which does not actually run the code. In addition, Python does not yet support generic recursive type annotations. For function type annotations, the generic version of the typing substitution PyTree[T] will be substituted to:
while the ForwardRef('PyTree[T]') will never be evaluated. This causes mypy to fail to infer the argument/return types: it raises either an arg-type or an assignment error. Using typing.cast alleviates this issue, but cast has a non-zero overhead at runtime. Function signature:
cast: mypy infers T = int and requires the input to be an exact PyTree[int] object rather than a Union[...] type. mypy refuses to add type assignments for PyTree[int].
Related issues:
typing.ForwardRef to support generic recursive types
Checklist