google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
6.15k stars, 648 forks

Add support for storing arbitrary PyTrees with `Module.perturb()` #4348

Closed: copybara-service[bot] closed this pull request 3 weeks ago

copybara-service[bot] commented 3 weeks ago

Add support for storing arbitrary PyTrees with `Module.perturb()`
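For context, `Module.perturb()` in Flax Linen adds a zero-valued variable (stored by default in the `'perturbations'` collection) to an intermediate value. The forward result is therefore unchanged, and the gradient with respect to that perturbation variable equals the gradient with respect to the intermediate value. This change lets the intermediate be a PyTree rather than a single array. The sketch below shows the same trick in plain JAX for a dict-shaped intermediate. It is not Flax's implementation: `perturb_tree` and `forward` are hypothetical names used only for illustration, and the perturbations are passed explicitly instead of living in a Flax variable collection.

```python
import jax
import jax.numpy as jnp


def perturb_tree(value, perturbation):
    # Hypothetical helper: add a perturbation to every leaf of a PyTree.
    # With zero perturbations the forward value is unchanged.
    return jax.tree_util.tree_map(lambda v, p: v + p, value, perturbation)


def forward(params, perturbations, x):
    # The intermediate value is a PyTree (a dict of arrays), not one array.
    hidden = {"scaled": params["w"] * x, "shifted": x + params["b"]}
    hidden = perturb_tree(hidden, perturbations["hidden"])
    return jnp.sum(hidden["scaled"] * hidden["shifted"])


params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
x = jnp.array([1.0, 2.0, 3.0])

# Zero perturbations with the same tree structure as the intermediate.
perturbations = {
    "hidden": {"scaled": jnp.zeros_like(x), "shifted": jnp.zeros_like(x)}
}

loss = forward(params, perturbations, x)

# Differentiating with respect to the zero perturbations yields the
# gradient of the loss with respect to each leaf of the intermediate.
grads = jax.grad(forward, argnums=1)(params, perturbations, x)
print(grads["hidden"]["scaled"])   # d loss / d scaled  = shifted = x + 1
print(grads["hidden"]["shifted"])  # d loss / d shifted = scaled  = 2 * x
```

Because each perturbation leaf is zero, `loss` matches the unperturbed computation, while `grads["hidden"]` mirrors the structure of `hidden` leaf for leaf. That structural match is what supporting arbitrary PyTrees in `Module.perturb()` requires.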