google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
6k stars, 633 forks

[nnx] add Using Filters guide #4028

cgarciae closed this 3 months ago

cgarciae commented 3 months ago

What does this PR do?

cgarciae commented 3 months ago

Hey @IvyZX, something weird happened with the other PR so I am continuing it here.

Thanks for reviewing the PR. Sadly, I hadn't pushed all my changes, so the guide was only halfway done. Check out the new content: I focused the guide on making it clear how filters work and how state is filtered. Let me know if you still think I should add more examples.

jlperla commented 3 months ago

@cgarciae this is beautiful! It makes a big difference in forming a mental model of NNX.

Is a new release that includes these changes planned sometime in the next month or so? I would love to prepare some of these notes for fall teaching to show the NNX workflow.

cgarciae commented 3 months ago

@jlperla coincidentally, I just created a new release (I needed it for a bug fix), and I can create another one soon. However, all the features shown here are stable, so a new release should not change the information in the guide much (except for some types being exposed at the top level).

jlperla commented 3 months ago

Thanks @cgarciae! If the interfaces are stable, that is good enough for me to experiment with. I will try some of this out over the next month and then update my lecture notes in August (at which point I can point to the new docs), so there is no time pressure on my side.

Looking forward to trying out the library. I am also keen to see, in August, the "training loop" pattern with metrics, optax, etc. that I noticed you were working on, since I could teach it as boilerplate. Lots of movement!