Open patrick-kidger opened 2 years ago
Why don't you just use single dispatch? That's what I do to pretty-print pytrees in tjax.
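For readers unfamiliar with the single-dispatch approach, here is a minimal sketch using only the standard library's functools.singledispatch. It is not tjax's actual implementation; the names pformat, Shape, and Linear are hypothetical and exist only for illustration.

```python
from dataclasses import dataclass, fields, is_dataclass
from functools import singledispatch


@singledispatch
def pformat(obj, indent=0):
    """Fallback: expand dataclass instances field by field, else use repr()."""
    if is_dataclass(obj) and not isinstance(obj, type):
        pad = " " * (indent + 2)
        body = "".join(
            f"{pad}{f.name}={pformat(getattr(obj, f.name), indent + 2)},\n"
            for f in fields(obj)
        )
        return f"{type(obj).__name__}(\n{body}{' ' * indent})"
    return repr(obj)


@pformat.register
def _(obj: dict, indent=0):
    # One key per line, values formatted recursively.
    pad = " " * (indent + 2)
    body = "".join(
        f"{pad}{k!r}: {pformat(v, indent + 2)},\n" for k, v in obj.items()
    )
    return "{\n" + body + " " * indent + "}"


@pformat.register
def _(obj: list, indent=0):
    # One element per line, elements formatted recursively.
    pad = " " * (indent + 2)
    body = "".join(f"{pad}{pformat(v, indent + 2)},\n" for v in obj)
    return "[\n" + body + " " * indent + "]"


class Shape:
    """Hypothetical third-party type that wants its own compact format."""

    def __init__(self, *dims):
        self.dims = dims


@pformat.register
def _(obj: Shape, indent=0):
    # A library can add this registration without touching pformat itself.
    return "Shape[" + "x".join(map(str, obj.dims)) + "]"


@dataclass
class Linear:
    weight: Shape
    bias: object


print(pformat({"layer": Linear(weight=Shape(2, 3), bias=None)}))
```

The extensibility lives in pformat.register: any package can attach a formatter for its own type, but every library that wants this has to agree on which pformat function to register against, which is the coordination problem raised in the reply below.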
Of course one can design separate systems for this. I'm just commenting that it would be nice to have some broader extensible way of doing this in JAX; i.e. make it a JAX-notion or a PyTrees(-beyond-JAX)-notion. Rather than just a library-specific thing we each reinvent.
I really like the idea of being able to pretty print pytrees; my fingers get RSI from typing jax.tree_map(lambda x: x.shape, some_large_pytree) everywhere, and yet the resulting string is typically not nice to read in a log.
cc @mattjj and @hawkinsp for the extensibility aspect of this proposal
What does serialization have to do with Jax? Do you serialize the static fields differently than the dynamic ones?
As for pretty-printing, the only advantage that I can see to having registration within Jax is that Jax errors can make use of it. That would be useful, since whenever I hit an error I currently add a call to my pretty-printer just before it and re-run the program to hit the same error.
Our usual starting position is that anything supporting more than some basic tree manipulation is out of scope for jax's pytree library. Namely, jax.tree_util is scoped narrowly to correspond to the jit (etc.) API contract and not much else.
I do think it would be nice for another tree library to do this. (I'd likely try and reach for such a thing in my work!) What are the main disadvantages to that? I can imagine wanting to check/ensure some correspondence with jax's pytree node registry.
In JAX a PyTree node defines two things: flattening and unflattening. There are some other things that I think would make sense to (optionally) define here. The ones I've come across so far are how to pretty print it (analogous to pprint.pprint) and how to serialise and deserialise it to disk. As you can see, I've linked to the existing Equinox implementations, which are basically defined for all built-in PyTrees + dataclasses. It would be nice if there were a hook to have these integrate with arbitrary custom PyTree nodes as well.
This is a bit speculative and probably not a priority. More of an RFC / just something I'm thinking about keeping in mind going forward.