google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
6.15k stars 648 forks

Add nnx.shard_map #4261

Open cgarciae opened 1 month ago

8bitmp3 commented 1 month ago

Thanks @cgarciae :rocket:

TODO: Then we can update the Transforms guide.