google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

[nnx] add pmap #3969

Closed: cgarciae closed this 3 weeks ago

cgarciae commented 1 month ago

What does this PR do?

cgarciae commented 3 weeks ago

Integrated in https://github.com/google/flax/commit/78d85af67e5ac6dd1dc40398a3d451ef9f7fcc76