google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

feature_request: support for tabulate/summary in the NNX API #3962

Status: Open · AshishKumar4 opened this issue 1 month ago

AshishKumar4 commented 1 month ago

Hey! I see that the Flax NNX APIs have visualization support via Penzai, but is there a Module.tabulate, a summary, or something similar for nnx.Module objects? If not, is something already in the works?

Such functionality is really helpful for quickly checking the number of parameters, their distributions, and the structure of the network, as well as for debugging.
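Until such an API exists, a rough version of this kind of summary can be produced by hand. The sketch below assumes the parameters have already been extracted from the model into a plain nested dict of arrays. For an nnx.Module, nnx.state(model, nnx.Param) is a reasonable starting point, but how to turn that into a plain dict differs between Flax versions, so treat that step as an assumption. The helpers summarize_params and tabulate are hypothetical and are not part of Flax.

```python
import numpy as np


def summarize_params(params, prefix=""):
    """Walk a nested dict of arrays and collect per-leaf statistics.

    Returns (rows, total): rows is a list of (path, shape, size, mean, std)
    tuples, one per leaf array; total is the overall parameter count.
    """
    rows = []
    total = 0
    for name, value in params.items():
        path = f"{prefix}/{name}" if prefix else name
        if isinstance(value, dict):
            # Recurse into submodules, keeping a slash-separated path.
            sub_rows, sub_total = summarize_params(value, path)
            rows.extend(sub_rows)
            total += sub_total
        else:
            arr = np.asarray(value)
            rows.append((path, arr.shape, arr.size, float(arr.mean()), float(arr.std())))
            total += arr.size
    return rows, total


def tabulate(params):
    """Render the statistics from summarize_params as a fixed-width text table."""
    rows, total = summarize_params(params)
    lines = [f"{'path':<20}{'shape':<12}{'size':>8}{'mean':>10}{'std':>10}"]
    for path, shape, size, mean, std in rows:
        lines.append(f"{path:<20}{str(shape):<12}{size:>8}{mean:>10.4f}{std:>10.4f}")
    lines.append(f"Total parameters: {total}")
    return "\n".join(lines)


# Hypothetical two-layer parameter tree, used only for illustration.
params = {
    "linear1": {"kernel": np.ones((2, 3)), "bias": np.zeros(3)},
    "linear2": {"kernel": np.ones((3, 1)), "bias": np.zeros(1)},
}
print(tabulate(params))  # last line: "Total parameters: 13"
```

The per-leaf mean and standard deviation give a quick look at the parameter distributions mentioned above; a real implementation would also want module types, dtypes, and output shapes, which require access to the module graph rather than just the weights.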

cgarciae commented 1 month ago

Hey! Having a tabulate API is definitely on the horizon, since a top-tier debugging experience is very important for NNX. It's not in the works yet, but it is something we would add sooner rather than later.

Thanks for bringing it up, we can track the issue here and hopefully we can spend some cycles on it soon.