google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

[nnx] Optimizer uses variables #4347

Open · cgarciae opened 3 weeks ago

cgarciae commented 3 weeks ago

What does this PR do?

Extracts the trainable params on __init__ using nnx.variables into a self.params attribute. This effectively fixes the set of trainable params: in the current implementation, if the model adds new params after the optimizer is created, update will fail. In a future PR we can make the following changes if we want to continue in this direction:

Example:

# Proposed direction: build the optimizer from the extracted Param state rather than the model.
params = nnx.variables(model, nnx.Param)
optimizer = nnx.Optimizer(params, tx=optax.adam(0.01))
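
For reference, here is a minimal training-step sketch using the existing nnx.Optimizer API that this PR modifies; the tiny nnx.Linear model and the squared-error loss are placeholders chosen for illustration, not part of the PR:

import optax
import jax.numpy as jnp
from flax import nnx

# Tiny illustrative model; the PR applies to any nnx.Module.
model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(0.01))

def train_step(model, optimizer, x, y):
    def loss_fn(model):
        return ((model(x) - y) ** 2).mean()
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    # The set of trainable Params is effectively fixed when the optimizer is
    # created; per the description above, update fails if the model has gained
    # new Params since then.
    optimizer.update(grads)
    return loss

loss = train_step(model, optimizer, jnp.ones((4, 2)), jnp.ones((4, 1)))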

Discussion

Currently there is some benefit to keeping a reference to the model inside Optimizer: you can pass just the optimizer to some functions and access the model from it. On the other hand, holding a pure Variables structure is in line with how PyTorch / MLX represent optimizers.
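
To make the trade-off concrete, a small sketch contrasting what each design requires a step function to receive; the function and loss names are made up for illustration, and how updates would propagate back to the model under the pure-Variables design is not specified by this PR:

from flax import nnx

def loss_fn(model, batch):
    # Illustrative loss; any scalar function of the model works here.
    return ((model(batch["x"]) - batch["y"]) ** 2).mean()

# Current design: the optimizer holds a model reference, so a step function
# can take just the optimizer and reach the model through it.
def train_step_with_model_ref(optimizer, batch):
    grads = nnx.grad(lambda m: loss_fn(m, batch))(optimizer.model)
    optimizer.update(grads)

# Pure-Variables design (PyTorch / MLX style): the model is passed alongside
# the optimizer, which only knows about the extracted params.
def train_step_with_variables(model, optimizer, batch):
    grads = nnx.grad(lambda m: loss_fn(m, batch))(model)
    optimizer.update(grads)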

jlperla commented 3 weeks ago

I often find myself splitting via nnx.split(state.model, self.wrt), etc. If this is storing the differentiable variables, would also saving the static graphdef make workflows cleaner, with less splitting required?
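
For readers unfamiliar with that workflow, a minimal sketch of the split/merge pattern being described, assuming wrt is nnx.Param; if Optimizer also stored the static graphdef next to self.params, a function receiving only the optimizer could reconstruct the model with nnx.merge instead of splitting it again:

import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))

# Split the module into its static structure and its Param state.
graphdef, params = nnx.split(model, nnx.Param)

@jax.jit
def forward(graphdef, params, x):
    # Rebuild the module from graphdef + params inside the jitted function.
    model = nnx.merge(graphdef, params)
    return model(x)

out = forward(graphdef, params, jnp.ones((1, 2)))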