google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Fixes for batch norm docs #3982

Closed: jkarwowski closed this pull request 3 weeks ago

jkarwowski commented 3 weeks ago

Changes:

The documentation for batch norm is written in such a way that copying it directly introduces a subtle bug: the params used in eval_step refer to the initial parameters of the network rather than the trained ones. As a result, the metrics suggest that the model is not generalising to the test set at all. This change will hopefully spare someone else learning Flax a similar head-scratcher.

Checklist

google-cla[bot] commented 3 weeks ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.