patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Move equinox "tree_pformat" into jaxtyping or allow users to configure their own #194

Open: vasiliykarasev opened this issue 3 months ago

vasiliykarasev commented 3 months ago

By default, jaxtyping errors print pytree contents directly, which usually makes the errors long. If you depend on equinox (or just have it installed), you can opt in to pretty printing, and there is a TODO about cleaning up this dependency: https://github.com/patrick-kidger/jaxtyping/blob/main/jaxtyping/_decorator.py#L767-L770
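The opt-in behaviour described above is the usual optional-dependency pattern: use Equinox's pretty-printer when it can be imported, otherwise fall back to `repr`. Below is a minimal sketch, assuming only Equinox's public `equinox.tree_pformat` function; the helper name `pformat_pytree` is hypothetical, and this is not jaxtyping's actual internal code.

```python
from typing import Any


def pformat_pytree(obj: Any) -> str:
    """Pretty-print `obj` with Equinox if it is importable, else use repr.

    Hypothetical helper for illustration; not jaxtyping's internal code.
    """
    try:
        # Optional dependency. Importing Equinox also pulls in JAX, which is
        # exactly the transitive dependency the TODO wants to remove.
        import equinox as eqx
    except ImportError:
        # Without Equinox the full repr is used; for large arrays this is
        # what makes error messages long.
        return repr(obj)
    # Equinox's pytree pretty-printer (see its docs for formatting options).
    return eqx.tree_pformat(obj)
```

The design cost is visible in the `try` block: pretty output is only available to users who are willing to install JAX.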

This issue is meant to “+1” the TODO: our tensors are large, the errors are spammy, and I frequently have to drop into pdb to figure out what’s happening. I think pretty-printing would help a lot, and for now we have opted to make equinox a dependency.

If moving equinox’s tree_pformat into jaxtyping is difficult, one alternative could be to let the user register a pprint function (e.g. through a global jaxtyping config). For my use cases either approach is fine; I didn't see any shortcomings in equinox’s way of printing pytrees.
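To make the "register a pprint function" alternative concrete, here is a minimal sketch of what such a global hook might look like. Everything here is hypothetical: `set_pformat`, `format_for_error` and `short_array_pformat` are illustrative names, not part of jaxtyping's current API. The example printer is duck-typed on `.shape` and `.dtype`, so it summarises NumPy, JAX or PyTorch arrays without importing any of them.

```python
from typing import Any, Callable, Optional

# Hypothetical global config: NOT part of jaxtyping's current API.
_user_pformat: Optional[Callable[[Any], str]] = None


def set_pformat(fn: Optional[Callable[[Any], str]]) -> None:
    """Register a custom pretty-printer, or pass None to restore repr."""
    global _user_pformat
    _user_pformat = fn


def format_for_error(obj: Any) -> str:
    """Format `obj` for inclusion in a type-check error message."""
    if _user_pformat is not None:
        return _user_pformat(obj)
    return repr(obj)


def short_array_pformat(obj: Any) -> str:
    """Example user printer: show array-likes as dtype[shape], recursing
    into dicts, lists and tuples; everything else falls back to repr."""
    if hasattr(obj, "shape") and hasattr(obj, "dtype"):
        return f"{obj.dtype}{list(obj.shape)}"
    if isinstance(obj, dict):
        items = ", ".join(f"{k!r}: {short_array_pformat(v)}" for k, v in obj.items())
        return "{" + items + "}"
    if isinstance(obj, list):
        return "[" + ", ".join(short_array_pformat(v) for v in obj) + "]"
    if isinstance(obj, tuple):
        inner = ", ".join(short_array_pformat(v) for v in obj)
        # A one-element tuple needs a trailing comma, as in Python's repr.
        return f"({inner},)" if len(obj) == 1 else f"({inner})"
    return repr(obj)
```

With `set_pformat(short_array_pformat)`, a dict holding a 1024×1024 float32 NumPy array would be reported as `{'x': float32[1024, 1024]}` rather than as its full contents. A hook like this keeps jaxtyping free of any pretty-printing dependency while letting Equinox users pass `equinox.tree_pformat` directly.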

patrick-kidger commented 3 months ago

Right! Moving this over was deemed nontrivial because it depends on JAX's own pretty-printing, and I didn't really want to duplicate all of JAX pp + Equinox pp into jaxtyping. (Although, granted, that's not that hard either!)

FWIW I have recently started a PyTorch project, for which the dependency on JAX (through Equinox) is undesirable to me as well, so I am actually hoping to fix this up in the next month or two.

For now I'm going to mark this as a feature request, and please feel free to ping this thread if nothing gets fixed after a couple of months!