metaopt / optree

OpTree: Optimized PyTree Utilities
https://optree.readthedocs.io
Apache License 2.0

[Feature Request] Better Type Annotation Support for PyTrees #6

Open XuehaiPan opened 1 year ago

XuehaiPan commented 1 year ago

Motivation

OpTree's PyTree is a generic custom class whose __class_getitem__ method returns a Union alias.

For example:

from optree.typing import PyTree

TreeOfInts = PyTree[int]

expands to:

TreeOfInts = typing.Union[
    int, 
    typing.Tuple[ForwardRef('PyTree[int]'), ...], 
    typing.List[ForwardRef('PyTree[int]')], 
    typing.Dict[typing.Any, ForwardRef('PyTree[int]')], 
    typing.Deque[ForwardRef('PyTree[int]')],
    optree.typing.CustomTreeNode[ForwardRef('PyTree[int]')]
]

at runtime.
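The runtime expansion above can be reproduced with a small hypothetical class, MiniPyTree. It is not OpTree's actual code and is simplified (it omits CustomTreeNode), but its __class_getitem__ builds a Union that refers back to the subscripted alias through a ForwardRef in the same way:

```python
from typing import Any, Deque, Dict, ForwardRef, List, Tuple, TypeVar, Union, get_args


class MiniPyTree:
    """Hypothetical, simplified stand-in for optree.typing.PyTree."""

    def __class_getitem__(cls, item: Any) -> Any:
        # Refer back to the alias by name, e.g. ForwardRef('MiniPyTree[int]').
        name = getattr(item, '__name__', repr(item))
        ref = ForwardRef(f'{cls.__name__}[{name}]')
        # Subscripting returns a plain typing.Union, not a MiniPyTree subclass.
        return Union[item, Tuple[ref, ...], List[ref], Dict[Any, ref], Deque[ref]]


T = TypeVar('T')
TreeOfInts = MiniPyTree[int]
leaf_type, *containers = get_args(TreeOfInts)
# leaf_type is int; every container's element type is ForwardRef('MiniPyTree[int]').
```

Because the result is an ordinary Union, it only exists once the code runs; a static checker that never executes __class_getitem__ has no way to see it.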

However, the type checker mypy is a static analyzer: it does not run the code, so it never sees this runtime expansion and instead treats PyTree[int] as an ordinary generic class.

In addition, Python does not yet support generic recursive type aliases. In function annotations, the generic form PyTree[T] expands to:

typing.Union[~T,
    typing.Tuple[ForwardRef('PyTree[T]'), ...],
    typing.List[ForwardRef('PyTree[T]')],
    typing.Dict[typing.Any, ForwardRef('PyTree[T]')],
    typing.Deque[ForwardRef('PyTree[T]')],
    optree.typing.CustomTreeNode[ForwardRef('PyTree[T]')]
]

where ForwardRef('PyTree[T]') is never evaluated. As a result, mypy cannot infer the argument and return types and reports either an arg-type or an assignment error. Using typing.cast alleviates the issue, but cast has a non-zero runtime overhead.
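To see where that overhead comes from: typing.cast neither converts nor checks its argument and simply returns it, but it is still an ordinary function call, and its type argument is evaluated on every call. The following standard-library timing sketch illustrates this (timings vary by machine, so no numbers are claimed here):

```python
import timeit
from typing import Any, Dict, Tuple, cast

tree = (0, {'a': 1, 'b': (2, [3, 4])})

# cast() performs no conversion or check: it returns the very same object.
assert cast(Tuple[int, Dict[str, Any]], tree) is tree

# Baseline: no cast at all.
t_plain = timeit.timeit(lambda: tree, number=200_000)
# Subscripted type: the subscription expression and the call both run every time.
t_subscript = timeit.timeit(lambda: cast(Tuple[int, Dict[str, Any]], tree), number=200_000)
# String type: skips building the subscripted type, but the call itself remains.
t_string = timeit.timeit(lambda: cast('Tuple[int, Dict[str, Any]]', tree), number=200_000)

print(f'plain {t_plain:.4f}s, subscripted {t_subscript:.4f}s, string {t_string:.4f}s')
```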

Function signature:

def tree_leaves(
    tree: PyTree[T],
    is_leaf: Optional[Callable[[T], bool]] = None,
    *,
    none_is_leaf: bool = False,
) -> List[T]: ...
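To make this signature concrete: T is the leaf type, and the result lists the leaves in traversal order. The hypothetical simple_tree_leaves below is a pure-Python sketch of that contract, not OpTree's implementation. It handles only tuples, lists, deques, dicts, and None, and it visits dict values in sorted-key order by assumption:

```python
from collections import deque
from typing import Any, Callable, List, Optional


def simple_tree_leaves(
    tree: Any,
    is_leaf: Optional[Callable[[Any], bool]] = None,
    *,
    none_is_leaf: bool = False,
) -> List[Any]:
    """Hypothetical pure-Python sketch of a leaf-flattening function."""
    leaves: List[Any] = []

    def visit(node: Any) -> None:
        if is_leaf is not None and is_leaf(node):
            leaves.append(node)
        elif node is None:
            # Unless none_is_leaf=True, None is an internal node with no children.
            if none_is_leaf:
                leaves.append(node)
        elif isinstance(node, (tuple, list, deque)):
            for child in node:
                visit(child)
        elif isinstance(node, dict):
            # This sketch visits dict values in sorted-key order.
            for key in sorted(node):
                visit(node[key])
        else:
            leaves.append(node)

    visit(tree)
    return leaves


print(simple_tree_leaves((0, {'a': 1, 'b': (2, [3, 4])})))  # [0, 1, 2, 3, 4]
```

The sketch is typed with Any throughout; annotating it with PyTree[T] and List[T] runs into exactly the inference problem that the mypy runs below demonstrate.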

# test.py

from typing import List, cast

import optree
from optree.typing import PyTree

list_of_ints: List[int]

tree_of_ints = cast(PyTree[int], (0, {'a': 1, 'b': (2, [3, 4])}))
list_of_ints = optree.tree_leaves(tree_of_ints)
$ mypy test.py
Success: no issues found in 1 source file

# test1.py

from typing import List

import optree

list_of_ints: List[int]

tree_of_ints = (0, {'a': 1, 'b': (2, [3, 4])})
list_of_ints = optree.tree_leaves(tree_of_ints)
$ mypy test1.py
test1.py:10: error: Argument 1 to "tree_leaves" has incompatible type "Tuple[int, Dict[str, object]]"; expected "PyTree[int]"  [arg-type]
    list_of_ints = optree.tree_leaves(tree_of_ints)
                                      ^~~~~~~~~~~~
Found 1 error in 1 file (checked 1 source file)

# test2.py

from typing import List

import optree
from optree.typing import PyTree

list_of_ints: List[int]

tree_of_ints: PyTree[int] = (0, {'a': 1, 'b': (2, [3, 4])})
list_of_ints = optree.tree_leaves(tree_of_ints)
$ mypy test2.py
test2.py:10: error: Incompatible types in assignment (expression has type "Tuple[int, Dict[str, object]]", variable has type "PyTree[int]")  [assignment]
    tree_of_ints: PyTree[int] = (0, {'a': 1, 'b': (2, [3, 4])})
                                ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Found 1 error in 1 file (checked 1 source file)

Related issues:


rhaps0dy commented 1 year ago

Here's another possible improvement to the typing. Consider tree_map:

tree_map(func: Callable[..., U], tree: PyTree[T], ...) -> PyTree[U]

Type-checking:

a: int = 1
a = tree_map(lambda x: x+1, 1)

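fails to type-check, because the declared return type PyTree[int] is a Union of int and container types and is therefore not assignable to int, even though a evaluates to 2.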

This can be fixed with overloads, although doing so is quite tedious. For example:

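from typing import Any, Callable, Dict, Tuple, TypeVar, overload

T = TypeVar('T')
U = TypeVar('U')

# Overloads are tried in order, so the catch-all leaf case goes last.
# (*args/**kwargs stand in for the remaining tree_map parameters.)
@overload
def tree_map(func: Callable[..., U], tree: Tuple[T, ...], *args: Any, **kwargs: Any) -> Tuple[U, ...]: ...
@overload
def tree_map(func: Callable[..., U], tree: Dict[Any, T], *args: Any, **kwargs: Any) -> Dict[Any, U]: ...
@overload
def tree_map(func: Callable[..., U], tree: T, *args: Any, **kwargs: Any) -> U: ...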

and so on. Going one level deep is probably enough; for extra thoroughness, two levels.
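The overload idea can be tried end to end with a hypothetical simple_tree_map. It is not OpTree's tree_map: it takes a single tree, handles only tuples, lists, and dicts, and has no is_leaf. The overloads are ordered so that the catch-all leaf case comes last; depending on the checker and its settings, overlapping-overload diagnostics may still need tuning:

```python
from typing import Any, Callable, Dict, List, Tuple, TypeVar, overload

T = TypeVar('T')
U = TypeVar('U')


@overload
def simple_tree_map(func: Callable[[T], U], tree: Tuple[T, ...]) -> Tuple[U, ...]: ...
@overload
def simple_tree_map(func: Callable[[T], U], tree: List[T]) -> List[U]: ...
@overload
def simple_tree_map(func: Callable[[T], U], tree: Dict[Any, T]) -> Dict[Any, U]: ...
@overload
def simple_tree_map(func: Callable[[T], U], tree: T) -> U: ...


def simple_tree_map(func: Callable[[Any], Any], tree: Any) -> Any:
    # Runtime implementation: recurse into containers, apply func to leaves.
    if isinstance(tree, tuple):
        return tuple(simple_tree_map(func, child) for child in tree)
    if isinstance(tree, list):
        return [simple_tree_map(func, child) for child in tree]
    if isinstance(tree, dict):
        return {key: simple_tree_map(func, value) for key, value in tree.items()}
    return func(tree)


# Only the leaf overload matches a bare int, so a checker can infer U = int here.
a: int = simple_tree_map(lambda x: x + 1, 1)
```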

LarsKue commented 1 month ago

To expand on this, return type hints are usually too loose, and I would appreciate an improvement here as well:

def f(x: PyTree[T]) -> PyTree[T]:
    return x

Here, we know that f(x) necessarily has the same type and layout as x itself. However, while PyTree[T] is suitably permissive as an argument type, it is unspecific as a return type, so type information is lost through the call to f. Subsequent code may need to rely on the concrete type of f(x), which is no longer possible:

x: list[int] = [1, 2, 3]
y = f(x)  # static type checker warns here due to OP's issue
s = sum(y)  # static type checker warns here due to my issue, but we know that y is a list[int]
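At runtime, the claim that y has "the same type and layout" as x can be checked directly by comparing tree structures. OpTree has its own structure objects for this (PyTreeSpec, e.g. from tree_structure); the hypothetical layout_of helper below is only a dependency-free sketch that erases every leaf and keeps container types and dict keys:

```python
from typing import Any

LEAF = '*'  # placeholder standing in for any leaf value


def layout_of(tree: Any) -> Any:
    """Hypothetical helper: the container skeleton of a tree, with leaves erased."""
    if isinstance(tree, (tuple, list)):
        return (type(tree).__name__, tuple(layout_of(child) for child in tree))
    if isinstance(tree, dict):
        return ('dict', tuple((key, layout_of(tree[key])) for key in sorted(tree)))
    return LEAF


def f(x: Any) -> Any:
    return x


x = [1, 2, 3]
y = f(x)
assert layout_of(y) == layout_of(x)                  # f preserves the layout
assert layout_of([1, 2, 3]) != layout_of((1, 2, 3))  # list and tuple layouts differ
```

What a static annotation should express is exactly this invariant, but checked before the code runs rather than by an assertion.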

Normally, one would avoid this by using TypeVar:

from typing import TypeVar

T = TypeVar("T")

def f(x: T) -> T:
    return x

However, no such concept exists for the layout and type of PyTree. A potential fix could look like this:

T = TypeVar("T")
S = SpecVar("S")  # new

def f(x: PyTree[T, S]) -> PyTree[T, S]:
    return x

This way, we can annotate that the returned value follows the same spec as the input argument. Unfortunately, subscripts do not accept keyword arguments, so we cannot pass the spec by keyword to avoid locking PyTree into a fixed positional order:

PyTree[T, spec=S]  # SyntaxError

A possible workaround, discussed in PEP 472, is to use slices or dictionaries:

PyTree[T, "spec": S]  # passes (T, slice('spec', S, None)) to __class_getitem__
PyTree[T, {"spec": S}]  # passes (T, {'spec': S}) to __class_getitem__
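Both workarounds are valid subscript syntax at runtime. The hypothetical SpecAwareTree below only records what its __class_getitem__ receives, with an ordinary TypeVar standing in for the proposed SpecVar (which does not exist). A static type checker would not accept these forms in annotations without dedicated support:

```python
from typing import Any, Dict, TypeVar

T = TypeVar('T')
S = TypeVar('S')  # stand-in for the proposed SpecVar


class SpecAwareTree:
    """Hypothetical class that shows how subscript arguments arrive at runtime."""

    def __class_getitem__(cls, params: Any) -> Dict[str, Any]:
        if not isinstance(params, tuple):
            params = (params,)
        parsed: Dict[str, Any] = {'leaf': params[0]}
        for extra in params[1:]:
            if isinstance(extra, slice):
                # SpecAwareTree[T, "spec": S] arrives as slice('spec', S, None).
                parsed[extra.start] = extra.stop
            elif isinstance(extra, dict):
                # SpecAwareTree[T, {"spec": S}] arrives as a plain dict.
                parsed.update(extra)
            else:
                raise TypeError(f'unsupported subscript argument: {extra!r}')
        return parsed


print(SpecAwareTree[T, "spec": S])    # {'leaf': ~T, 'spec': ~S}
print(SpecAwareTree[T, {"spec": S}])  # {'leaf': ~T, 'spec': ~S}
```

The slice form is terser, while the dict form scales more naturally to several named parameters.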