google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

[nnx] support pure dicts #4352

Closed: cgarciae closed this pull request 2 weeks ago

cgarciae commented 3 weeks ago

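What does this PR do?

Lets NNX work with pure dicts: `State.to_pure_dict()` returns the state as a plain nested dict of arrays with the `Variable` metadata removed, and `nnx.merge` accepts such a pure dict in place of a `State`.

Example: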

```python
from flax import nnx

m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(m)

pure_state = state.to_pure_dict()  # remove leaf metadata, keep only the raw arrays

m2 = nnx.merge(graphdef, pure_state)  # merge from the pure state
```