google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

[nnx] refactor vmap #3949

Closed: cgarciae closed this 3 months ago

cgarciae commented 4 months ago

What does this PR do?

Continues the refactoring started in #3927, extending it to `vmap`.

cgarciae commented 3 months ago

Continued in #3969.