What does this PR do?

Extracts the trainable params on `__init__` using `nnx.variables` into a `self.params` attribute. This effectively fixes the set of trainable params: in the current implementation, if the `model` adds new params after construction, `update` will fail. In a future PR we can make the following changes if we want to continue in this direction:

- Have `__init__` accept a `State` object with the `Variable`s directly.
- Remove the `wrt` argument.
- Remove the `model` attribute.
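Example: a minimal sketch of the usage this implies, assuming the `nnx.Optimizer(model, tx, wrt=...)` constructor and the `optimizer.update(grads)` signature (the exact `update` signature has varied across Flax versions); the toy model and shapes are illustrative:

```python
import jax.numpy as jnp
import optax
from flax import nnx

# Illustrative toy model; any nnx.Module works the same way.
model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# Per this PR, the trainable params are extracted once at construction
# time (via nnx.variables) into `optimizer.params`; the set is fixed.
optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param)

def loss_fn(model):
    x = jnp.ones((4, 2))
    return jnp.mean(model(x) ** 2)

grads = nnx.grad(loss_fn)(model)
optimizer.update(grads)  # would fail if `model` gained new params after __init__
```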
Discussion

Currently there is some benefit to having a reference to `model` inside `Optimizer`: you could potentially just pass the `optimizer` to some of the functions and use the model from there. On the other hand, a pure `Variable`s structure is in line with how PyTorch / MLX represent their optimizers.
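For comparison, a sketch of the PyTorch convention the paragraph refers to: the optimizer receives only the parameter tensors and keeps no reference to the module itself:

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 3)
# The optimizer holds a plain iterable of parameters, not the module.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

loss = model(torch.ones(4, 2)).pow(2).mean()
loss.backward()
optimizer.step()       # updates the stored parameter tensors in place
optimizer.zero_grad()
```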
I find myself often splitting via `nnx.split(state.model, self.wrt)`, etc. If this is storing the differentiable variables, would saving the static graphdef also make workflows cleaner, with less splitting required?
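A sketch of the splitting pattern described in this comment, and of how a stored graphdef could avoid re-splitting at every call site (assuming the standard `nnx.split` / `nnx.merge` API; the filter choice here stands in for `self.wrt`):

```python
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
wrt = nnx.Param  # stand-in for the filter the optimizer stores

# The recurring pattern: split into a static graphdef, the
# differentiable variables, and everything else.
graphdef, params, rest = nnx.split(model, wrt, ...)

# If the optimizer also stored `graphdef` next to `self.params`,
# the model could be rebuilt without splitting again:
restored = nnx.merge(graphdef, params, rest)
```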