Closed varunagrawal closed 2 months ago
NOTE: I got the idea of using optax.multi_transform from https://github.com/google/flax/discussions/1706.
Hi @varunagrawal, I'm not familiar with optax.multi_transform, but optimizer.init is expecting a dict, and PPONetworkParams is not a dict. Maybe try optimizer.init(init_params.__dict__)? Will move this to a discussion for now.
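The suggestion above can be sketched as follows. This is a minimal, hedged illustration, assuming PPONetworkParams is a plain dataclass with policy and value fields (the field names here are hypothetical stand-ins): a plain dataclass is treated by jax.tree_util as a single opaque leaf, so optax.multi_transform has no structure to match its label pytree against, whereas the instance's __dict__ exposes the fields as an ordinary dict:

```python
from dataclasses import dataclass

# Hypothetical stand-in for PPONetworkParams: a plain dataclass, which
# jax.tree_util would treat as one opaque leaf, not a container of params.
@dataclass
class PPONetworkParams:
    policy: dict
    value: dict

init_params = PPONetworkParams(policy={"w": 1.0}, value={"w": 2.0})

# Converting to a dict exposes the fields, so a label pytree with the
# same top-level structure can be built for multi_transform:
params_dict = init_params.__dict__
labels = {name: ("frozen" if name == "value" else "trainable")
          for name in params_dict}

# With optax installed, the frozen/trainable split would then look like:
# op = optax.multi_transform(
#     {"trainable": optax.adam(3.0e-4), "frozen": optax.set_to_zero()},
#     param_labels=labels,
# )
# state = op.init(params_dict)
```

An alternative to __dict__ is to make the params class a registered pytree (e.g. a flax.struct.dataclass), so optimizer.init can traverse it directly.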
I want to freeze parts of my network for training, so I modified the ppo.train function to accept an optimizer object. I then define the optimizer as:

However, when the optimizer's init is called as optimizer.init(init_params) while defining the TrainingState, it throws the following error:

This does not happen when I define the optimizer as op = optax.adam(3.0e-4). Since both variants are optax.GradientTransformationExtraArgs instances, can someone please explain what is happening and how I can resolve this? Here's the full stack trace: