google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0
2.14k stars 234 forks source link

Optimizer with MultiTransform throws ValueError #480

Closed varunagrawal closed 2 months ago

varunagrawal commented 2 months ago

I want to freeze parts of my network for training, so to do this, I modified the ppo.train function to accept an optimizer object.

I then define the optimizer as

def zero_grads():
      def init_fn(_):
          return ()
      def update_fn(updates, state, params=None):
          return jax.tree_map(jp.zeros_like, updates), ()
      return optax.GradientTransformation(init_fn, update_fn)

op = optax.multi_transform(
    {
        "adam": optax.adam(3.0e-4),
        "zero": zero_grads()
    }, {
        "Encoder_0": "zero",
        "BCModel_0": "zero",
        "ResidualPolicy_0": "adam"
    })

However, when the optimizer calls init as optimizer.init(init_params) when defining the TrainingState, it throws the following error:

ValueError: Expected dict, got PPONetworkParams(policy={'params': ...

This does not happen when I define the optimizer as op = optax.adam(3.0e-4). Since both variants are optax.GradientTransformationExtraArgs, can someone please explain what is happening and how can I resolve this?

Here's the full stack trace:

Traceback (most recent call last):
  File "/Users/scripts/train.py", line 181, in <module>
    main()
  File "/Users/scripts/train.py", line 79, in main
    tx = optimizer.init(init_params)
  File "/Users/.pyenv/versions/3.10.12/lib/python3.10/site-packages/optax/_src/combine.py", line 214, in init_fn
    inner_states = {
  File "/Users/.pyenv/versions/3.10.12/lib/python3.10/site-packages/optax/_src/combine.py", line 217, in <dictcomp>
    mask_compatible_extra_args=mask_compatible_extra_args).init(params)
  File "/Users/.pyenv/versions/3.10.12/lib/python3.10/site-packages/optax/_src/wrappers.py", line 545, in init_fn
    masked_params = mask_pytree(params, mask_tree)
  File "/Users/.pyenv/versions/3.10.12/lib/python3.10/site-packages/optax/_src/wrappers.py", line 509, in mask_pytree
    return jax.tree_util.tree_map(
  File "/Users/.pyenv/versions/3.10.12/lib/python3.10/site-packages/jax/_src/tree_util.py", line 311, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/Users/.pyenv/versions/3.10.12/lib/python3.10/site-packages/jax/_src/tree_util.py", line 311, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Expected dict, got PPONetworkParams(policy={'params': {'ResidualPolicy_0': {'Dense_0': {'kernel': Array([[-0.081,  0.018, -0.062, ...,  0.081,  0.015,  0.01 ],
       [ 0.018,  0.075,  0.039, ...,  0.035,  0.072, -0.018],
       [-0.046,  0.078,  0.027, ...,  0.038,  0.04 , -0.08 ],
       ...,
       [ 0.012,  0.029, -0.009, ...,  0.01 , -0.017,  0.022],
       [ 0.007,  0.034,  0.002, ..., -0.087,  0.059,  0.08 ],
       [ 0.084,  0.024, -0.017, ..., -0.061, -0.087,  0.007]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}, 'Dense_1': {'kernel': Array([[ 0.079,  0.092,  0.13 , ...,  0.01 , -0.024, -0.124],
       [ 0.037, -0.038,  0.072, ...,  0.13 ,  0.051, -0.032],
       [-0.087, -0.007, -0.026, ...,  0.051,  0.071, -0.025],
       ...,
       [ 0.033, -0.073, -0.038, ...,  0.154,  0.025,  0.062],
       [-0.182, -0.127,  0.011, ..., -0.003, -0.08 ,  0.029],
       [-0.115, -0.125,  0.094, ..., -0.041, -0.002, -0.139]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}, 'Dense_2': {'kernel': Array([[-0.008,  0.126, -0.005, ..., -0.137, -0.017,  0.149],
       [ 0.07 , -0.039, -0.172, ..., -0.043, -0.088,  0.138],
       [ 0.146, -0.02 ,  0.142, ...,  0.044,  0.013,  0.092],
       ...,
       [-0.132, -0.006,  0.021, ...,  0.028, -0.068, -0.086],
       [-0.065,  0.124,  0.024, ..., -0.073, -0.087,  0.068],
       [ 0.154, -0.026, -0.121, ...,  0.008, -0.054,  0.035]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0.], dtype=float32)}}}}, value={'params': {'hidden_0': {'kernel': Array([[ 0.061, -0.025,  0.026, ..., -0.067,  0.02 , -0.025],
       [ 0.027, -0.046, -0.015, ..., -0.003,  0.04 ,  0.057],
       [-0.04 ,  0.069, -0.004, ..., -0.06 ,  0.067,  0.046],
       ...,
       [ 0.032,  0.055,  0.014, ..., -0.068,  0.02 , -0.005],
       [-0.016,  0.048, -0.002, ..., -0.07 , -0.062, -0.077],
       [-0.004, -0.064,  0.032, ...,  0.004,  0.001,  0.002]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.], dtype=float32)}, 'hidden_1': {'kernel': Array([[ 0.003,  0.021, -0.014, ..., -0.06 ,  0.055,  0.096],
       [-0.028, -0.01 , -0.013, ...,  0.025,  0.016,  0.019],
       [ 0.02 ,  0.026,  0.029, ...,  0.078,  0.099, -0.078],
       ...,
       [ 0.071, -0.071,  0.03 , ...,  0.104, -0.084,  0.103],
       [ 0.011,  0.095, -0.06 , ..., -0.   ,  0.071, -0.007],
       [-0.039,  0.024, -0.053, ...,  0.051, -0.02 ,  0.008]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.], dtype=float32)}, 'hidden_2': {'kernel': Array([[-0.047,  0.041, -0.093, ...,  0.084,  0.103,  0.034],
       [-0.018, -0.038,  0.016, ...,  0.066,  0.052,  0.105],
       [ 0.002,  0.044,  0.059, ..., -0.008, -0.037, -0.102],
       ...,
       [-0.045, -0.085,  0.017, ...,  0.024, -0.07 , -0.034],
       [ 0.1  , -0.042,  0.026, ...,  0.051, -0.103,  0.056],
       [ 0.002,  0.095, -0.073, ..., -0.049,  0.018, -0.077]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.], dtype=float32)}, 'hidden_3': {'kernel': Array([[ 0.08 , -0.093, -0.103, ..., -0.04 ,  0.063, -0.052],
       [-0.028, -0.087, -0.056, ..., -0.017, -0.072, -0.107],
       [ 0.039, -0.024, -0.021, ..., -0.052,  0.028, -0.012],
       ...,
       [ 0.036,  0.087, -0.033, ...,  0.073, -0.05 ,  0.044],
       [-0.102,  0.069,  0.002, ...,  0.075,  0.095, -0.098],
       [-0.066, -0.019,  0.053, ...,  0.043,  0.025, -0.092]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.], dtype=float32)}, 'hidden_4': {'kernel': Array([[ 0.08 , -0.102, -0.031, ...,  0.009, -0.012,  0.017],
       [ 0.038,  0.081,  0.033, ...,  0.059,  0.056,  0.018],
       [ 0.06 , -0.024,  0.034, ..., -0.046,  0.044,  0.021],
       ...,
       [-0.048, -0.052, -0.099, ..., -0.05 ,  0.016,  0.003],
       [-0.101,  0.042, -0.023, ..., -0.027, -0.016,  0.061],
       [ 0.014, -0.084, -0.028, ...,  0.079,  0.049,  0.104]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.], dtype=float32)}, 'hidden_5': {'kernel': Array([[ 0.091],
       [-0.073],
       [ 0.103],
       [ 0.062],
       [-0.028],
       [ 0.078],
       [-0.066],
       [ 0.047],
       [ 0.066],
       [-0.095],
       [ 0.004],
       [ 0.017],
       [ 0.038],
       [-0.031],
       [ 0.108],
       [-0.048],
       [ 0.065],
       [ 0.099],
       [ 0.031],
       [ 0.023],
       [ 0.014],
       [ 0.025],
       [ 0.086],
       [ 0.087],
       [ 0.052],
       [ 0.029],
       [-0.079],
       [ 0.087],
       [-0.016],
       [-0.02 ],
       [-0.1  ],
       [ 0.086],
       [ 0.018],
       [-0.021],
       [-0.025],
       [ 0.068],
       [-0.068],
       [ 0.071],
       [ 0.013],
       [ 0.073],
       [ 0.076],
       [ 0.098],
       [ 0.035],
       [ 0.074],
       [-0.052],
       [ 0.044],
       [-0.073],
       [-0.006],
       [-0.033],
       [-0.091],
       [-0.037],
       [-0.006],
       [ 0.003],
       [ 0.082],
       [ 0.001],
       [-0.001],
       [ 0.083],
       [ 0.003],
       [ 0.088],
       [ 0.023],
       [-0.031],
       [ 0.027],
       [-0.016],
       [ 0.046],
       [ 0.038],
       [-0.062],
       [-0.011],
       [-0.037],
       [-0.07 ],
       [-0.065],
       [ 0.033],
       [ 0.1  ],
       [-0.084],
       [ 0.055],
       [ 0.024],
       [-0.083],
       [ 0.031],
       [ 0.048],
       [-0.1  ],
       [ 0.098],
       [-0.037],
       [ 0.084],
       [-0.004],
       [ 0.004],
       [-0.029],
       [ 0.071],
       [ 0.031],
       [ 0.069],
       [ 0.079],
       [-0.064],
       [ 0.068],
       [-0.055],
       [-0.085],
       [ 0.102],
       [-0.032],
       [-0.08 ],
       [-0.094],
       [-0.098],
       [-0.097],
       [ 0.041],
       [-0.015],
       [ 0.032],
       [ 0.048],
       [-0.073],
       [ 0.071],
       [-0.098],
       [ 0.072],
       [ 0.051],
       [-0.031],
       [ 0.078],
       [-0.001],
       [-0.052],
       [-0.011],
       [-0.003],
       [-0.003],
       [-0.01 ],
       [ 0.013],
       [-0.058],
       [-0.076],
       [ 0.107],
       [-0.014],
       [ 0.102],
       [ 0.054],
       [-0.047],
       [-0.095],
       [-0.041],
       [-0.049],
       [ 0.043],
       [-0.092],
       [ 0.016],
       [ 0.026],
       [ 0.098],
       [-0.101],
       [ 0.065],
       [-0.027],
       [ 0.085],
       [ 0.093],
       [-0.105],
       [ 0.079],
       [-0.036],
       [ 0.089],
       [-0.008],
       [-0.05 ],
       [-0.072],
       [ 0.094],
       [-0.001],
       [-0.042],
       [ 0.049],
       [-0.065],
       [-0.011],
       [-0.083],
       [-0.008],
       [ 0.011],
       [-0.032],
       [-0.052],
       [ 0.052],
       [ 0.026],
       [-0.069],
       [-0.01 ],
       [ 0.059],
       [-0.079],
       [-0.071],
       [-0.019],
       [-0.041],
       [-0.052],
       [-0.053],
       [-0.072],
       [-0.083],
       [ 0.017],
       [ 0.071],
       [ 0.067],
       [ 0.002],
       [-0.042],
       [-0.085],
       [-0.006],
       [-0.016],
       [-0.086],
       [-0.101],
       [ 0.06 ],
       [-0.067],
       [-0.052],
       [ 0.004],
       [ 0.076],
       [ 0.075],
       [-0.106],
       [-0.044],
       [-0.066],
       [ 0.086],
       [ 0.05 ],
       [-0.083],
       [ 0.105],
       [ 0.08 ],
       [ 0.103],
       [ 0.072],
       [ 0.024],
       [ 0.   ],
       [ 0.065],
       [ 0.025],
       [ 0.047],
       [-0.083],
       [ 0.014],
       [ 0.059],
       [ 0.072],
       [-0.058],
       [ 0.091],
       [-0.033],
       [-0.011],
       [-0.097],
       [-0.077],
       [ 0.049],
       [-0.058],
       [-0.053],
       [-0.081],
       [-0.032],
       [ 0.06 ],
       [-0.093],
       [-0.087],
       [-0.064],
       [-0.008],
       [-0.052],
       [-0.058],
       [ 0.072],
       [-0.031],
       [ 0.07 ],
       [-0.068],
       [-0.102],
       [ 0.045],
       [-0.104],
       [ 0.097],
       [-0.054],
       [-0.035],
       [ 0.021],
       [ 0.015],
       [ 0.064],
       [-0.006],
       [-0.008],
       [ 0.068],
       [-0.014],
       [ 0.065],
       [ 0.034],
       [ 0.002],
       [ 0.014],
       [-0.103],
       [ 0.082],
       [-0.088],
       [-0.029],
       [-0.006],
       [ 0.083],
       [ 0.063],
       [-0.049],
       [ 0.099],
       [-0.012],
       [-0.018],
       [ 0.077],
       [ 0.054],
       [-0.095]], dtype=float32), 'bias': Array([0.], dtype=float32)}}}).
varunagrawal commented 2 months ago

NOTE: I got the idea for using the optax.multi_transform from https://github.com/google/flax/discussions/1706.

btaba commented 2 months ago

Hi @varunagrawal , I'm not familiar with optax.multi_transform, but optimizer.init is expecting a dict, and PPONetworkParams is not a dict. Maybe try optimizer.init(init_params.__dict__)? Will move this to a discussion for now