jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Type checking for pytrees #3340

Open shoyer opened 4 years ago

shoyer commented 4 years ago

Glad to hear JAX will eventually have annotations! I was following this issue for updates. I actually started trying to add annotations myself before I realized how much work it was going to be.

One thing that would be nice to expose as soon as possible, though, is base types for tensors and pytrees:

Tensor = Union[np.ndarray, jnp.ndarray]  # probably needs more things in the union like tracers and DeviceArray?
PyTree = Union[Tensor,
               'PyTreeLike',
               Tuple['PyTree', ...],
               List['PyTree'],
               Dict[Hashable, 'PyTree'],
               None]

Originally posted by @NeilGirdhar in https://github.com/google/jax/issues/1555#issuecomment-639554300

NeilGirdhar commented 4 years ago

Thanks a lot for adding this 😄

gnecula commented 4 years ago

There are two goals for type hints: to catch errors and to improve readability by documenting intended usage. I think that readability is the most important and the easiest to achieve, e.g., just use PyTree everywhere a pytree is expected.

The title of this issue is about type checking, i.e., preventing type errors. I feel that this is much harder, especially in the parts of the code that were written to make heavy use of dynamic typing. The type checkers try hard, but this is a very murky area. We have recently filed a bug with pytype in the presence of Union.

My thinking is that for PyTree we get most of the benefit from defining a type alias PyTree = Any, along with a comment explaining what pytrees are. Of course, for other cases that are closer to static typing we should give proper type definitions.
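
A minimal sketch of this suggestion, assuming a module-level alias. The helper `count_leaves` is hypothetical and only handles built-in containers; it is not a JAX function:

```python
from typing import Any

# Documentation-only alias. A pytree is a (possibly nested) structure of
# tuples, lists, dicts, None, or registered custom nodes, with arbitrary leaves.
# Type checkers treat this as Any, so it documents intent without checking it.
PyTree = Any


def count_leaves(tree: PyTree) -> int:
    """Count the leaves of a pytree built from built-in containers (illustrative only)."""
    if tree is None:
        return 0  # None is treated as a node with no children
    if isinstance(tree, (list, tuple)):
        return sum(count_leaves(child) for child in tree)
    if isinstance(tree, dict):
        return sum(count_leaves(child) for child in tree.values())
    return 1  # anything else counts as a leaf
```

The annotation tells a reader what the argument is meant to be, but a call such as `count_leaves(object())` still type-checks, which is the trade-off being described.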

NeilGirdhar commented 4 years ago

Yes, you're right about the benefit of documentation. That's why I'm already using PyTree everywhere in my code.

Could you link me to your bug with pytype? I'm curious.

My thinking is that for PyTree we get most of the benefit from defining a type alias PyTree = Any, along with a comment explaining what pytrees are. Of course, for other cases that are closer to static typing we should give proper type definitions.

I agree for now, but pretty soon np.ndarray and as I understand it jnp.ndarray are going to be type-checked, which means that my PyTree will "wake up". At that point, we will get static typing? We just need an easy way to get PyTreeLike to accept all of the user-defined classes. Someone mentioned Protocols, but I've never used them.
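
For reference, a Protocol-based `PyTreeLike` might look roughly like the sketch below. It assumes user classes expose a `tree_flatten` method (the convention used by `jax.tree_util.register_pytree_node_class`); the `Params` class is made up for illustration. Types registered through standalone flatten/unflatten functions would not match this protocol, so it is only a partial answer.

```python
from typing import Any, Protocol, Tuple, runtime_checkable


@runtime_checkable
class PyTreeLike(Protocol):
    """Structural type: anything with a tree_flatten() method (an assumption of this sketch)."""

    def tree_flatten(self) -> Tuple[Any, Any]:
        """Return (children, aux_data)."""
        ...


class Params:  # hypothetical user-defined container
    def __init__(self, w: float, b: float) -> None:
        self.w = w
        self.b = b

    def tree_flatten(self) -> Tuple[Any, Any]:
        return (self.w, self.b), None


# runtime_checkable protocols only check that the method exists, not its signature.
print(isinstance(Params(1.0, 2.0), PyTreeLike))  # True
print(isinstance(3, PyTreeLike))                 # False
```

A static checker would accept `Params` wherever `PyTreeLike` is expected without any inheritance or registration, which is what makes protocols attractive for user-defined classes.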

shoyer commented 4 years ago

I can imagine two ways in which it might make sense to type check pytrees:

In theory, the first type could be checked with Generic and the second type could be checked (at least partially) with TypeVar.

billmark commented 3 years ago

Has there been any progress on this bug? What's the current best known practice? Is it still to define PyTree = Any?

hawkinsp commented 3 years ago

Yes, in the sense that I think we're fairly confident it cannot be done without a Python type checker that supports recursive types, which mypy and pytype (our current type checkers) do not. So, you should write Any.

XuehaiPan commented 2 years ago

Yes, in the sense that I think we're fairly confident it cannot be done without a Python type checker that supports recursive types, which mypy and pytype (our current type checkers) do not. So, you should write Any.

mypy has supported recursive types since v0.981, and they will be enabled by default starting with v0.990. E.g.:

JSON = Union[Dict[str, 'JSON'], List['JSON'], str, int, float, bool, None]
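
As an illustration, such an alias can annotate a recursive function. This sketch assumes a mypy version with recursive aliases enabled; `depth` is a hypothetical helper. At runtime the alias is inert, so the code runs regardless of the checker:

```python
from typing import Dict, List, Union

# Recursive alias: the string "JSON" is a forward reference to the alias itself.
JSON = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None]


def depth(value: JSON) -> int:
    """Nesting depth: scalars are 0, each container adds one level."""
    if isinstance(value, dict):
        return 1 + max((depth(v) for v in value.values()), default=0)
    if isinstance(value, list):
        return 1 + max((depth(v) for v in value), default=0)
    return 0


print(depth({"a": [1, {"b": None}]}))  # 3
```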

carlosgmartin commented 1 year ago

Given the above, are there any plans to add a standard jax.PyTree type soon?

jakevdp commented 1 year ago

How would you do better than PyTree = Any, given the fact that arbitrary types can be registered as pytrees at runtime?

NeilGirdhar commented 1 year ago

given the fact that arbitrary types can be registered as pytrees at runtime?

FYI, if ABCMeta.register is ever supported by type checkers (e.g., by MyPy), then you could make PyTree a class that inherits from abc.ABC and register all of your pytree types, which would be visible to type checkers.
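
A minimal sketch of that idea, using only the standard library. `PyTreeBase` and `Point` are hypothetical names, not part of JAX. Today the registration is visible only at runtime, which is the limitation mentioned above:

```python
import abc


class PyTreeBase(abc.ABC):
    """Hypothetical marker base class for everything that counts as a pytree."""


class Point:  # hypothetical user-defined container
    def __init__(self, x: float, y: float) -> None:
        self.x = x
        self.y = y


# Virtual subclass registration: changes isinstance/issubclass results at
# runtime, but does not alter the class hierarchy itself.
PyTreeBase.register(Point)
PyTreeBase.register(dict)

print(isinstance(Point(1.0, 2.0), PyTreeBase))  # True
print(isinstance({}, PyTreeBase))               # True
print(isinstance(3, PyTreeBase))                # False
print(PyTreeBase in Point.__mro__)              # False
```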

jakevdp commented 1 year ago

That sounds like it could be a good solution someday

shoyer commented 1 year ago

At this point, I would suggest defining project-specific PyTree types, e.g.,

PyTree = dict[str, 'PyTree'] | list['PyTree'] | jax.Array

Project-specific types could (in principle) be handled with generics, e.g.,

from typing import Generic, TypeVar
import dataclasses

T = TypeVar('T')

@dataclasses.dataclass
class MyStruct(Generic[T]):
    x: T
    y: T

PyTree = dict[str, 'PyTree'] | list['PyTree'] | MyStruct['PyTree'] | jax.Array
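
A runnable variant of the same idea, as a sketch only. To avoid a JAX dependency it assumes `float` as a stand-in for `jax.Array` leaves, and it uses `typing.Union` so it also runs on Python versions older than 3.10; `total` is a hypothetical helper:

```python
import dataclasses
from typing import Dict, Generic, List, TypeVar, Union

T = TypeVar("T")


@dataclasses.dataclass
class MyStruct(Generic[T]):
    x: T
    y: T


Leaf = float  # stand-in for jax.Array in this sketch
PyTree = Union[Dict[str, "PyTree"], List["PyTree"], MyStruct["PyTree"], Leaf]


def total(tree: PyTree) -> float:
    """Sum every leaf, recursing through each container named in the alias."""
    if isinstance(tree, dict):
        return sum(total(v) for v in tree.values())
    if isinstance(tree, list):
        return sum(total(v) for v in tree)
    if isinstance(tree, MyStruct):
        return total(tree.x) + total(tree.y)
    return tree  # a leaf


print(total({"a": [1.0, 2.0], "b": MyStruct(3.0, [4.0])}))  # 10.0
```

Because the alias lists every allowed container, a checker that supports recursive aliases can in principle flag a tuple or a string passed where `PyTree` is expected.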

Potentially there's some room for libraries like Flax to define struct types that are compatible with this sort of type checking, but otherwise I don't think there's much to be done in JAX.

cgarciae commented 1 year ago

I think the only really useful definition of Pytree is:

Pytree = Any

because tree_map accepts both pytrees and leaves (which are all non-registered types), so the following is valid:

mul2 = lambda x: x * 2

jax.tree_map(mul2, 3)                        # 6
jax.tree_map(mul2, "hi ")                    # "hi hi "
jax.tree_map(mul2, TypeThatImplementsMul())  # ???
# and so on...
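
The leaf rule can be made concrete with a toy stand-in for tree_map. This is a sketch that handles only built-in containers and is not JAX's implementation; `simple_tree_map` is a hypothetical name:

```python
from typing import Any, Callable


def simple_tree_map(f: Callable[[Any], Any], tree: Any) -> Any:
    """Apply f to every leaf; only None, dict, list, and tuple are treated as nodes."""
    if tree is None:
        return None  # a node with no children
    if isinstance(tree, dict):
        return {k: simple_tree_map(f, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(simple_tree_map(f, v) for v in tree)
    return f(tree)  # every other object is a leaf, whatever its type


mul2 = lambda x: x * 2

print(simple_tree_map(mul2, 3))                       # 6
print(simple_tree_map(mul2, "hi "))                   # hi hi 
print(simple_tree_map(mul2, {"a": [1, (2, None)]}))   # {'a': [2, (4, None)]}
```

Since the final branch accepts any object, the argument type cannot be narrowed below `Any` without excluding inputs that are valid at runtime.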

jakevdp commented 1 year ago

Note that jaxtyping has type annotations for pytrees that get around the above issues by only checking at runtime: that could be a good solution depending on your use-case.

But for the reasons above, it appears that Python's static type-checking spec doesn't have much to say about PyTrees as currently implemented.

shoyer commented 1 year ago

I think the only really useful definition of Pytree is:

Pytree = Any

because tree_map accepts both pytrees and leaves (which are all non-registered types), so the following is valid:

Right, but individual projects can probably guarantee that they are only going to use a restricted set of types in PyTree leaves. For example, every leaf that is a neural net parameter needs to be a (float) array.

cgarciae commented 1 year ago

In Flax we use something like this, but it's not informative of the leaf types at all:

Collection = Mapping[str, Any]
FrozenVariableDict = FrozenDict[str, Collection]