jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.41k stars 2.79k forks

Allow custom PyTree nodes to define pretty-prints, serialisation, deserialisation #11210

Open patrick-kidger opened 2 years ago

patrick-kidger commented 2 years ago

In JAX, a PyTree node defines two things: flattening and unflattening. There are some other things that I think would make sense to (optionally) define here. The ones I've come across so far are how to pretty-print it (analogous to pprint.pprint) and how to serialise and deserialise it to disk.

As you can see, I've linked to the existing Equinox implementations, which are basically defined for all built-in PyTrees + dataclasses. It would be nice if there were a hook that let these integrate with arbitrary custom PyTree nodes as well.

This is a bit speculative and probably not a priority. More of an RFC / just something I'm thinking about keeping in mind going forward.
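For readers unfamiliar with the two things a PyTree node defines today, here is a minimal sketch using jax.tree_util.register_pytree_node. The Interval class and its field names are hypothetical and exist only for illustration; the point is that flattening returns (children, aux_data) and unflattening reverses it, with no hook for printing or serialisation.

```python
import jax.numpy as jnp
from jax import tree_util


class Interval:
    """Hypothetical container, used only to illustrate node registration."""

    def __init__(self, lower, upper, name):
        self.lower = lower
        self.upper = upper
        self.name = name  # static metadata, carried as aux_data


def flatten_interval(interval):
    # Children are the dynamic leaves; aux_data is the static part.
    children = (interval.lower, interval.upper)
    aux_data = interval.name
    return children, aux_data


def unflatten_interval(aux_data, children):
    lower, upper = children
    return Interval(lower, upper, aux_data)


tree_util.register_pytree_node(Interval, flatten_interval, unflatten_interval)

interval = Interval(jnp.zeros(3), jnp.ones(3), "unit")
leaves, treedef = tree_util.tree_flatten(interval)
rebuilt = tree_util.tree_unflatten(treedef, leaves)
doubled = tree_util.tree_map(lambda x: 2 * x, interval)
```

Nothing in this registration tells JAX how Interval should be displayed or written to disk, which is the gap the proposal is about.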

NeilGirdhar commented 2 years ago

Why don't you just use single dispatch? That's what I do to pretty-print pytrees in tjax.
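A rough sketch of the single-dispatch approach, using functools.singledispatch from the standard library. This is not tjax's actual code; the pretty function, the Point class, and the output layout are all assumptions made for illustration.

```python
from functools import singledispatch


@singledispatch
def pretty(node, indent=0):
    """Fallback: render anything unregistered as a leaf via repr."""
    return " " * indent + repr(node)


@pretty.register(list)
@pretty.register(tuple)
def _pretty_sequence(node, indent=0):
    pad = " " * indent
    left, right = ("[", "]") if isinstance(node, list) else ("(", ")")
    items = ",\n".join(pretty(item, indent + 2) for item in node)
    return f"{pad}{left}\n{items}\n{pad}{right}"


class Point:
    """Hypothetical custom container, standing in for a custom PyTree node."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


@pretty.register(Point)
def _pretty_point(node, indent=0):
    # Each library (or user) registers its own printer for its own type.
    pad = " " * indent
    return f"{pad}Point(x={node.x!r}, y={node.y!r})"


print(pretty([1, Point(2, 3)]))
```

The dispatch table here lives entirely outside JAX, so JAX's own error messages would have no way to use it.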

patrick-kidger commented 2 years ago

Of course one can design separate systems for this. I'm just commenting that it would be nice to have some broader extensible way of doing this in JAX; i.e. make it a JAX-notion or a PyTrees(-beyond-JAX)-notion, rather than just a library-specific thing we each reinvent.

zhangqiaorjc commented 2 years ago

I really like the idea of being able to pretty print pytrees; my fingers get RSI from typing jax.tree_map(lambda x: x.shape, some_large_pytree) everywhere, and yet the string is typically not nice to read in a log.
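For reference, the shape-summary idiom mentioned above looks roughly like this. The params dictionary is made up for illustration, and the sketch uses jax.tree_util.tree_map because, as far as I know, the top-level jax.tree_map alias was deprecated in later JAX releases. Passing the result through pprint is one stopgap for readability, not a JAX feature.

```python
import pprint

import jax.numpy as jnp
from jax import tree_util

# Hypothetical parameter tree, standing in for some_large_pytree.
params = {
    "dense": {"kernel": jnp.zeros((4, 8)), "bias": jnp.zeros((8,))},
    "scale": jnp.ones(()),
}

# Replace every array leaf with its shape tuple.
shapes = tree_util.tree_map(lambda x: x.shape, params)

# pprint gives a more log-friendly layout than the raw repr.
summary = pprint.pformat(shapes)
print(summary)
```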

cc @mattjj and @hawkinsp for the extensibility aspect of this proposal

NeilGirdhar commented 2 years ago

What does serialization have to do with JAX? Do you serialize the static fields differently than the dynamic ones?
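One way the static/dynamic split could matter, sketched under assumptions: only the array leaves are written to disk, and the static structure comes from a template tree supplied at load time. The helper names save_leaves and load_leaves are hypothetical, and this is not a description of how Equinox or any other library implements serialisation.

```python
import io

import jax.numpy as jnp
import numpy as np
from jax import tree_util


def save_leaves(tree, file):
    """Write only the dynamic leaves; np.savez names them arr_0, arr_1, ..."""
    leaves = tree_util.tree_leaves(tree)
    np.savez(file, *[np.asarray(leaf) for leaf in leaves])


def load_leaves(like, file):
    """Rebuild a tree using the static structure of the template `like`."""
    treedef = tree_util.tree_structure(like)
    with np.load(file) as data:
        leaves = [jnp.asarray(data[f"arr_{i}"]) for i in range(treedef.num_leaves)]
    return tree_util.tree_unflatten(treedef, leaves)


tree = {"w": jnp.arange(6.0).reshape(2, 3), "b": jnp.ones(3)}
template = {"w": jnp.zeros((2, 3)), "b": jnp.zeros(3)}

buffer = io.BytesIO()  # in-memory stand-in for a file on disk
save_leaves(tree, buffer)
buffer.seek(0)
restored = load_leaves(template, buffer)
```

This works for built-in containers because their flattening is known; a custom node with unusual static data is where a per-node hook would come in.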

As for pretty-printing, the only advantage that I can see to having registration within JAX is that JAX errors can make use of it. That would be useful, since whenever I have an error, I add a call to my pretty-printer before it and re-run the program to hit the same error.

froystig commented 2 years ago

Our usual starting position is that anything supporting more than some basic tree manipulation is out of scope for jax's pytree library. Namely, jax.tree_util is scoped narrowly to correspond to the jit (etc.) API contract and not much else.

I do think it would be nice for another tree library to do this. (I'd likely try and reach for such a thing in my work!) What are the main disadvantages to that? I can imagine wanting to check/ensure some correspondence with jax's pytree node registry.