jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add documentation on how to use PyType with Jax (and also common add-on libraries such as Flax) #8224

Open billmark opened 2 years ago

billmark commented 2 years ago

There are a bunch of tricks that one needs to know to use PyType with JAX (esp. also in combination with Flax). For example, a PyTree needs to be treated as "Any".

Since it's very common to want to use PyType with JAX, it would be useful to have a section of the JAX documentation summarizing these tricks and best practices. I'm not sure what the best way is to handle the JAX/Flax interactions, but it's important for someone to figure out how to document those best practices too.

jakevdp commented 2 years ago

Thanks for raising this: I think at this point the best practices generally involve using Any for anything non-trivial (particularly pytrees and arrays), owing to some general missing features in Python's type-checking ecosystem. See for example #943 & #3340.

Additionally, type-checking of JAX code is hampered by the fact that mypy and pytype have poor support in general for decorators (see e.g. https://github.com/python/mypy/issues/1927), which JAX code tends to use extensively.

Beyond the recommendation to use Any for arrays and pytrees, and the caveat that type-checking is fundamentally broken as soon as you use a function transformation, do you have other thoughts about what would belong in a best practices guide?

mattjj commented 2 years ago

@superbobry any thoughts or references to point to? (@billmark is a Googler.)

superbobry commented 2 years ago

I think we can have a proper type for arrays in JAX, and I have a draft commit doing that internally.

Trees, however, are harder, because their type is fundamentally recursive and generic, and pytype does not currently support that. I think mypy does have some support for recursive generics, but I'm sure it too has its limits.
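
As a rough illustration of what "recursive and generic" means here, the sketch below writes a tree type as a recursive alias in plain Python. The `Leaf` and `PyTree` aliases and the `count_leaves` helper are hypothetical, and the alias only covers lists, tuples, and string-keyed dicts; real JAX pytrees are extensible with user-registered node types, which this does not capture.

```python
from typing import Mapping, Sequence, Union

# Hypothetical recursive alias: a tree is a leaf, or a container of trees.
Leaf = Union[int, float]
PyTree = Union[Leaf, Sequence["PyTree"], Mapping[str, "PyTree"]]


def count_leaves(tree: PyTree) -> int:
    """Count leaves by recursing through dicts, lists, and tuples."""
    if isinstance(tree, dict):
        return sum(count_leaves(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(count_leaves(v) for v in tree)
    return 1  # anything else is treated as a leaf


params = {"dense": {"w": [1.0, 2.0], "b": 0.5}, "scale": (3.0,)}
```

The alias evaluates fine at runtime because the self-reference is a string; whether a static checker accepts the recursion is exactly the limitation described above.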

billmark commented 2 years ago

There are a bunch of confusing comments in the Flax documentation and code regarding what to do for type annotations on PyTrees. For example: https://flax.readthedocs.io/en/latest/flax.struct.html says "Note: Inherit from PyTreeNode instead to avoid type checking issues when using PyType". These comments imply that there is some best practice for type annotations, but I can't find any coherent explanation of what that best practice is. Maybe it really is to use "Any" everywhere because proper type checking is impossible, but if so that needs to be stated clearly instead of having confusing half-explanations that imply some other solution scattered elsewhere in the docs.

This seems to be an area of general confusion, as this discussion is already uncovering. For example, there are bugs like this one: https://github.com/google/flax/issues/620

patrick-kidger commented 2 years ago

I don't know if it's been considered, but one other option is run-time type-checkers. Personally I never use static type-checkers, as I find jumping through their hoops more pain than it's worth. Instead my usual pattern is to enable a run-time type-checker just during tests. (And not otherwise, to avoid any performance penalty in user code.) Pro: No hoop-jumping. Pro: Can handle all kinds of complicated cases like recursive types, custom value-parameterised types, custom instance checks etc. Con: Only tests the code pathways executed during your tests.

The two main options I know about are typeguard and beartype. See also torchtyping for PyTorch tensor annotations as an example of what you can do with them.
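
A minimal sketch of the "check at run time, but only during tests" pattern, written with the standard library only. The `runtime_checked` decorator and the `MYPKG_RUNTIME_TYPECHECK` environment variable are hypothetical names made up for this illustration, and the check handles only plain classes; it is not how any particular run-time type-checking library works.

```python
import functools
import inspect
import os
from typing import get_type_hints

# Hypothetical switch: a test runner would set this before importing the code.
CHECK_TYPES = os.environ.get("MYPKG_RUNTIME_TYPECHECK", "0") == "1"


def runtime_checked(fn, enabled=None):
    """Check plain-class annotations on each call; return fn as-is if disabled."""
    if enabled is None:
        enabled = CHECK_TYPES
    if not enabled:
        return fn  # no wrapper, so no overhead outside tests

    hints = get_type_hints(fn)
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            expected = hints.get(name)
            # Only plain classes are checked; generics are skipped in this sketch.
            if isinstance(expected, type) and not isinstance(value, expected):
                raise TypeError(f"{name}={value!r} is not {expected.__name__}")
        return fn(*args, **kwargs)

    return wrapper


def scale(x: float, n: int) -> float:
    return x * n


# Force checking on for the demonstration, as a test suite would.
checked_scale = runtime_checked(scale, enabled=True)
```

With checking enabled, `checked_scale(2.0, "3")` raises `TypeError`; with it disabled, the decorator hands back the original function untouched.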

billmark commented 2 years ago

@jakevdp. I realize I didn't directly reply to your question:

Beyond the recommendation to use Any for arrays and pytrees, and the caveat that type-checking is fundamentally broken as soon as you use a function transformation, do you have other thoughts about what would belong in a best practices guide?

If those are the best practices, then that's mostly sufficient. Minor additions would include: (1) please remove or amend the other confusing comments about type checking in the Flax docs. (2) Explain why anything better than "Any" isn't possible with the current type checkers. (3) Possibly discuss the use of explicit type annotations every time you perform a functional transformation as a workaround for the fact that functional transformations break type checking. (i.e. my_var: actual_type = jax.jit(blah, blah)).
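
A short sketch of workaround (3), assuming JAX is installed. The `Array = Any` alias follows the placeholder convention mentioned in this thread, and `scale` / `scale_jit` are hypothetical names used only for illustration.

```python
from typing import Any, Callable

import jax
import jax.numpy as jnp

Array = Any  # placeholder alias; carries no real type information


def scale(x: Array, factor: float) -> Array:
    return x * factor


# Re-state the signature explicitly on the transformed function, since the
# checker may not be able to recover it from the transformation itself.
scale_jit: Callable[[Array, float], Array] = jax.jit(scale)

result = scale_jit(jnp.ones(3), 2.0)
```

The annotation is a promise from the programmer rather than something the checker verifies, so it needs to be kept in sync with the wrapped function by hand.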

However, it seems from other comments (and also from the Flax docs) that there's no consensus, even among JAX/Flax maintainers, on what the best practices should be. I am not prepared to weigh in on that discussion.

leycec commented 2 years ago

@beartype bro maintainer here. Thanks so much for the gracious namedrop nearly a year ago, @billmark. Has it really been that long? :face_exhaling:

@beartype was originally gestated out of a multiphysics biology simulator, where runtime type-checking tamed the million-line code beast that nothing else could. We're still as devoted to big data science now as we were back then – and JAX is directly in that wheelhouse.

@beartype only currently provides explicit support for NumPy type hints like numpy.typing.NDArray[...]. We'd love to extend support to JAX types, however. Does JAX provide an equivalent API for expressing JAX constraints as type hints? Clearly, typing.Any is a poor substitute for properly constrained types. Static type checkers like PyType and mypy may fail to grok dynamic runtime semantics and the deep Pythonic magic that JAX internally performs, but @beartype is a different breed altogether. If it runs at runtime, we can type-check it.

Let @beartype know if we can do anything for JAX. Until then, thanks for all the efficient transforms, wonderful JAX team!

jakevdp commented 2 years ago

Hi - JAX does not currently do much with static typing, beyond some scattered uses of jnp.ndarray and frequent annotations with aliases like Array = Any. We're currently exploring doing something more systematic (discussion at https://github.com/google/jax/pull/11859) but the roadmap is not yet finalized. In particular, @patrick-kidger and I have been going back and forth about whether we should follow numpy's lead and define both jnp.ndarray for instance checks and jax.NDArray for annotations, or on the other hand push to use a single type like jax.Array for both instance checks and annotations. The latter unification is a nice idea, but ends up being somewhat challenging because of some fundamental limitations of Python's type system, along with the fact that JAX makes heavy use of duck-typing within transformations like jit.

patrick-kidger commented 2 years ago

FWIW, once the jaxtyping rewrite goes in ¹ then jaxtyping will actually be PEP-compliant. It shouldn't actually need any special support from either runtime type checkers or static type checkers.

'tis a thing of beauty, if I say so myself.

¹ Once @jakevdp and I have settled our differences regarding that pesky ND at the start of the annotation.

leycec commented 2 years ago

jaxtyping will actually be PEP-compliant.

:partying_face:

It shouldn't actually need any special support from either runtime type checkers...

__instancecheck__() in the metaclass, huh? Classic trick. @beartype gives a clawed thumbs up.
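
For readers unfamiliar with the trick, here is a minimal pure-Python sketch of a metaclass that defines `__instancecheck__`, so that `isinstance()` inspects the value rather than its class. `FloatList` and `_FloatListMeta` are hypothetical names invented for this example; this is not jaxtyping's implementation.

```python
class _FloatListMeta(type):
    """Metaclass whose isinstance() check inspects the value, not its class."""

    def __instancecheck__(cls, obj):
        return isinstance(obj, list) and all(isinstance(v, float) for v in obj)


class FloatList(metaclass=_FloatListMeta):
    """Hypothetical annotation type meaning 'a list containing only floats'."""


def total(xs: FloatList) -> float:
    return sum(xs)


print(isinstance([1.0, 2.0], FloatList))  # True
print(isinstance([1, 2.0], FloatList))    # False: contains an int
```

Because the check goes through the ordinary `isinstance()` protocol, any run-time checker that relies on `isinstance()` can use such a type without special support.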

...we should follow numpy's lead and define both jnp.ndarray for instance checks and jax.NDArray for annotations...

Sadly, this is the way. Exactly as you suggest, @jakevdp, Python's typing ecosystem imposes hard fundamental limitations. Type hints were never intended to be used at runtime, really. Many older PEPs admit as much. It was a wild and lawless time back then. Let us never go back there.

NumPy circumvents this by cleverly piggybacking its numpy.typing.NDArray type hint on top of CPython's C-based types.GenericAlias class. Since types.GenericAlias drives the entirety of PEP 585 (e.g., type hints like list[str] rather than typing.List[str]), the fact that numpy.typing.NDArray builds on types.GenericAlias makes NumPy type hints implicitly compatible with static type checkers – usually the main obstacle.
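
A small standard-library sketch of what a PEP 585 hint looks like at runtime, assuming Python 3.9+. The class behind these hints is exposed as `types.GenericAlias`; the example uses `list[str]` only and makes no claims about NumPy internals.

```python
import types
from typing import get_args, get_origin

alias = list[str]  # a PEP 585 generic alias

print(isinstance(alias, types.GenericAlias))  # True
print(get_origin(alias) is list)              # True: the unsubscripted class
print(get_args(alias) == (str,))              # True: the type parameters
```

Anything that presents itself through this same alias machinery can be introspected with `get_origin` and `get_args` in the same way, which is what makes it digestible for tooling.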

Cue the sign for victory. \o/