google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Add direct penzai.treescope support for NNX objects. #3948

Closed: copybara-service[bot] closed this 1 month ago

copybara-service[bot] commented 1 month ago


This change implements the __penzai_repr__ protocol on most NNX objects, so they can be visualized directly with the standard penzai.treescope configuration, without an extra conversion step. Modules, GraphDefs, and States can all be visualized.
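To illustrate the idea behind a repr protocol like this one, here is a minimal, self-contained toy model that does not use penzai at all. A renderer checks whether an object's type defines a `__penzai_repr__` hook; if so, the object controls its own layout and hands its children back to the renderer, otherwise the renderer falls back to `repr()`. The hook signature `(path, render_child) -> str` and the `ToyModule`/`ToyState` classes are hypothetical simplifications; penzai's real protocol has its own signature and produces treescope rendering parts rather than plain strings.

```python
from typing import Any, Callable

# Hypothetical, simplified hook signature: (obj, path, render_child) -> str.
# This is NOT penzai's actual __penzai_repr__ signature; it only models the idea.
RenderFn = Callable[[Any, str], str]


def render(obj: Any, path: str = "") -> str:
    """Toy renderer: use the type's __penzai_repr__ hook if present, else repr()."""
    # Look the hook up on the type, the way Python resolves dunder protocols.
    hook = getattr(type(obj), "__penzai_repr__", None)
    if hook is not None:
        # The object decides its own layout and delegates children back here.
        return hook(obj, path, render)
    return repr(obj)


class ToyState:
    """Hypothetical stand-in for an NNX State: a flat mapping of names to leaves."""

    def __init__(self, leaves: dict[str, Any]):
        self.leaves = leaves

    def __penzai_repr__(self, path: str, render_child: RenderFn) -> str:
        # Render each leaf through the shared renderer, extending the path.
        items = ", ".join(
            f"{name}={render_child(value, f'{path}.{name}')}"
            for name, value in self.leaves.items()
        )
        return f"ToyState({items})"


class ToyModule:
    """Hypothetical stand-in for an NNX Module that holds a nested state."""

    def __init__(self, name: str, state: ToyState):
        self.name = name
        self.state = state

    def __penzai_repr__(self, path: str, render_child: RenderFn) -> str:
        # Delegate the nested state to the renderer instead of formatting it here.
        inner = render_child(self.state, f"{path}.state")
        return f"ToyModule[{self.name}](state={inner})"


model = ToyModule("linear", ToyState({"kernel": [1.0, 2.0], "bias": 0.0}))
print(render(model))  # ToyModule[linear](state=ToyState(kernel=[1.0, 2.0], bias=0.0))
```

The design point the toy captures is that no separate conversion step is needed: because each object type carries its own rendering hook, any entry point that goes through the shared renderer can display it directly.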

If Penzai is installed, the nnx.display function is no longer needed: after calling pz.ts.basic_interactive_setup(), NNX objects can be visualized with IPython.display.display, with pz.show, or simply by returning them from an IPython cell.

This change also fixes the GraphDef repr to use "leaves" instead of "variables".