google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

nnx.pmap documentation seems to be for vmap #4330

Status: Open · hrbigelow opened this issue 1 month ago

hrbigelow commented 1 month ago

Hi,

The docstring for nnx.pmap in the source code appears to describe nnx.vmap instead.

cgarciae commented 1 month ago

This is correct. Sorry, it was adapted very quickly from vmap for an experiment. Will update soon.

hrbigelow commented 1 month ago

No problem, thank you @cgarciae.