shoyer opened this issue 4 years ago (status: Open).
Thanks a lot for adding this 😄
There are two goals for type hints: catching errors and improving readability by documenting intended usage. I think readability is the most important and the easiest to achieve, e.g., just use PyTree everywhere a pytree is expected.
The title of this issue is about type checking, i.e., preventing type errors. I feel that this is much harder, especially in the parts of the code that were written to make heavy use of dynamic typing. The type checkers try hard, but this is a very murky area. We recently filed a bug with pytype in the presence of Union.
My thinking is that for PyTree we get most of the benefit from defining a type alias PyTree = Any, along with a comment explaining what pytrees are. Of course, for other cases that are closer to static typing we should give proper type definitions.
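A minimal sketch of the "alias plus explanatory comment" approach. The wording of the comment and the `sgd_step` helper are hypothetical illustrations, not JAX code; the helper assumes flat dicts of floats just to keep it runnable.

```python
from typing import Any

# PyTree: a nested structure built from containers that JAX knows how to
# traverse (e.g. list, tuple, dict, and classes registered with
# jax.tree_util), with arbitrary objects as leaves. Because any object can be
# a leaf, the alias is Any: it documents intent, but a type checker will
# accept any value here.
PyTree = Any


def sgd_step(params: PyTree, grads: PyTree, lr: float) -> PyTree:
    """Hypothetical helper; assumes params and grads are flat dicts of floats."""
    return {name: params[name] - lr * grads[name] for name in params}
```

The annotation tells a reader what shape of argument is expected, while the checker treats it exactly like Any.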
Yes, you're right about the benefit of documentation. That's why I'm already using PyTree everywhere in my code.
Could you link me to your bug with pytype? I'm curious.
I agree with PyTree = Any for now, but pretty soon np.ndarray and, as I understand it, jnp.ndarray are going to be type-checked, which means that my PyTree annotations will "wake up". At that point, will we get static type checking? We just need an easy way to get PyTreeLike to accept all of the user-defined classes. Someone mentioned Protocols, but I've never used them.
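One way to read the Protocol suggestion is as a structural type for user classes that follow the `tree_flatten`/`tree_unflatten` convention expected by `jax.tree_util.register_pytree_node_class`. A minimal, stdlib-only sketch, assuming that convention; the names `PyTreeNode` and `Point` are hypothetical, and nothing here is registered with JAX.

```python
from collections.abc import Hashable, Iterable
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class PyTreeNode(Protocol):
    """Structural type for classes written for register_pytree_node_class.

    jax.tree_util.register_pytree_node_class expects a tree_flatten() method
    returning (children, aux_data) and a tree_unflatten(aux_data, children)
    classmethod; this Protocol only describes that shape.
    """

    def tree_flatten(self) -> tuple[Iterable[Any], Hashable]: ...

    @classmethod
    def tree_unflatten(cls, aux_data: Hashable, children: Iterable[Any]) -> "PyTreeNode": ...


class Point:
    """A user class following the convention (not registered with JAX here)."""

    def __init__(self, x: float, y: float) -> None:
        self.x = x
        self.y = y

    def tree_flatten(self) -> tuple[Iterable[Any], Hashable]:
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data: Hashable, children: Iterable[Any]) -> "Point":
        return cls(*children)
```

Lists, dicts and plain numbers do not match this Protocol, so on its own it cannot stand for a whole pytree; at best it would be one arm of a larger (recursive) union.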
I can imagine two ways in which it might make sense to type check pytrees:
In theory, the first type could be checked with Generic and the second type could be checked (at least partially) with TypeVar.
Has there been any progress on this bug? What's the current best practice? Is it still to define PyTree = Any?
Yes, in the sense that I think we're fairly confident it cannot be done without a Python type checker that supports recursive types, which mypy and pytype (our current type checkers) do not. So, you should write Any.
mypy has supported recursive types since v0.981, and they are enabled by default from v0.990. E.g.:
from typing import Dict, List, Union
JSON = Union[Dict[str, 'JSON'], List['JSON'], str, int, float, bool, None]
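A sketch of how such a recursive alias can annotate a recursive traversal. With a recent mypy the string self-references are resolved and checked; at runtime they are just forward-reference strings. `count_leaves` is a hypothetical helper for illustration.

```python
from typing import Dict, List, Union

# Recursive type alias: "JSON" refers to itself inside Dict and List.
JSON = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None]


def count_leaves(value: JSON) -> int:
    """Count the non-container values in a JSON-like nested structure."""
    if isinstance(value, dict):
        return sum(count_leaves(v) for v in value.values())
    if isinstance(value, list):
        return sum(count_leaves(v) for v in value)
    return 1  # str, int, float, bool or None
```

This is the kind of self-referential type a pytree alias would need, but it only covers a fixed, closed set of container types.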
Given the above, are there any plans to add a standard jax.PyTree type soon?
How would you do better than PyTree = Any, given the fact that arbitrary types can be registered as pytrees at runtime?
FYI, if ABCMeta.register is ever supported by type checkers (e.g., mypy), then you could make PyTree a class that inherits from abc.ABC and register all of your pytree types, which would then be visible to type checkers.
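A sketch of the ABC idea, with a hypothetical marker class `PyTree` and example type `Point`. At runtime, `isinstance`/`issubclass` do see the registered virtual subclasses; the point of the comment is that static checkers such as mypy currently do not.

```python
import abc
from dataclasses import dataclass


class PyTree(abc.ABC):
    """Hypothetical marker ABC; pytree types are added as virtual subclasses."""


@PyTree.register  # register() returns the class, so it works as a decorator
@dataclass
class Point:
    x: float
    y: float


# Built-in containers can be registered as virtual subclasses too.
for container in (list, tuple, dict):
    PyTree.register(container)
```

Registration does not change the class's MRO, which is why a static checker has nothing to go on unless it special-cases `register`.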
That sounds like it could be a good solution someday.
At this point, I would suggest defining project-specific PyTree types, e.g.,
import jax
PyTree = dict[str, 'PyTree'] | list['PyTree'] | jax.Array
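A runnable sketch of what a project-specific alias buys you: a structure-preserving map whose leaf type is known. To avoid a JAX dependency, `float` stands in for `jax.Array` (an assumption for illustration), and the alias is written with `Union` so it also runs on Python 3.9. `tree_map_leaves` is a hypothetical helper, not a JAX API.

```python
from typing import Callable, Dict, List, Union

# Hypothetical project-specific pytree; float stands in for jax.Array.
Leaf = float
PyTree = Union[Dict[str, "PyTree"], List["PyTree"], Leaf]


def tree_map_leaves(f: Callable[[Leaf], Leaf], tree: PyTree) -> PyTree:
    """Apply f to every leaf, rebuilding the same dict/list structure."""
    if isinstance(tree, dict):
        return {key: tree_map_leaves(f, value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [tree_map_leaves(f, value) for value in tree]
    return f(tree)
```

Because the set of containers is closed, a checker can verify that `f` receives only leaves, which is not possible with the open-ended registry JAX uses.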
Project-specific types could (in principle) be handled with generics, e.g.,
from typing import Generic, TypeVar
import dataclasses
import jax
T = TypeVar('T')
@dataclasses.dataclass
class MyStruct(Generic[T]):
    x: T
    y: T
PyTree = dict[str, 'PyTree'] | list['PyTree'] | MyStruct['PyTree'] | jax.Array
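Since `MyStruct` is generic in its field type, a checker can track how a function changes the leaf type. A minimal sketch; `map_struct` is a hypothetical helper, not a JAX or Flax API. For example, mapping `str` over a `MyStruct[int]` should be inferred as `MyStruct[str]`.

```python
import dataclasses
from typing import Callable, Generic, TypeVar

T = TypeVar("T")
U = TypeVar("U")


@dataclasses.dataclass
class MyStruct(Generic[T]):
    x: T
    y: T


def map_struct(f: Callable[[T], U], s: MyStruct[T]) -> MyStruct[U]:
    """Apply f to both fields; the result's type parameter follows f."""
    return MyStruct(f(s.x), f(s.y))
```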
Potentially there's some room for libraries like Flax to define struct types that are compatible with this sort of type checking, but otherwise I don't think there's much to be done in JAX.
I think the only really useful definition of Pytree is:
Pytree = Any
because tree_map accepts both pytrees and leaves (any non-registered type is a leaf), so the following is valid:
import jax
mul2 = lambda x: x * 2
jax.tree_map(mul2, 3) # 6
jax.tree_map(mul2, "hi ") # "hi hi "
jax.tree_map(mul2, TypeThatImplementsMul()) # ???
# and so on...
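To see why the argument type cannot be narrowed below Any, here is a simplified, stdlib-only re-implementation (not JAX's actual code) in which only dicts and lists count as containers. Everything else falls through to `f` as a leaf, so any value is a valid input.

```python
from typing import Any, Callable


def mini_tree_map(f: Callable[[Any], Any], tree: Any) -> Any:
    """Simplified tree_map: only dicts and lists are treated as containers."""
    if isinstance(tree, dict):
        return {key: mini_tree_map(f, value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [mini_tree_map(f, value) for value in tree]
    # Anything else (int, str, custom objects, ...) is a leaf, so f is applied
    # directly; this is why the input can't be narrowed below Any.
    return f(tree)
```

With `mul2` as above, `mini_tree_map(mul2, 3)` returns `6` and `mini_tree_map(mul2, "hi ")` returns `"hi hi "`, mirroring the `jax.tree_map` calls.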
Note that jaxtyping has type annotations for pytrees that get around the above issues by only checking at runtime: that could be a good solution depending on your use case.
But for the reasons above, it appears that Python's static type-checking spec doesn't have much to say about PyTrees as currently implemented.
Right, but individual projects can probably guarantee that they are only going to use a restricted set of types in PyTree leaves. For example, every leaf that is a neural net parameter needs to be a (float) array.
In Flax we use something like this, but it's not informative about the leaf types at all:
Collection = Mapping[str, Any]
FrozenVariableDict = FrozenDict[str, Collection]
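When a project does restrict its leaves (e.g., every parameter leaf is a float array), that restriction can at least be enforced at runtime. A hand-rolled, stdlib-only sketch, not jaxtyping's or Flax's API: `iter_leaves` and `check_float_leaves` are hypothetical helpers, and `float` stands in for a float array. With JAX installed, `jax.tree_util.tree_leaves` would replace `iter_leaves`, and a real check would inspect the array dtype.

```python
from typing import Any, Iterator


def iter_leaves(tree: Any) -> Iterator[Any]:
    """Yield leaves of a nested dict/list/tuple structure (simplified)."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from iter_leaves(value)
    elif isinstance(tree, (list, tuple)):
        for value in tree:
            yield from iter_leaves(value)
    else:
        yield tree


def check_float_leaves(tree: Any) -> None:
    """Raise TypeError if any leaf is not a float (stand-in for a float array)."""
    for leaf in iter_leaves(tree):
        if not isinstance(leaf, float):
            raise TypeError(f"expected float leaf, got {type(leaf).__name__}")
```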
Glad to hear JAX will eventually have annotations! I was following this issue for updates. I actually started trying to add annotations myself before I realized how much work it was going to be.
One thing it would be nice to expose as soon as possible, though, is base types for tensors and pytrees:
Originally posted by @NeilGirdhar in https://github.com/google/jax/issues/1555#issuecomment-639554300