coax-dev / coax

Modular framework for Reinforcement Learning in python
https://coax.readthedocs.io
MIT License

Sharing parameters between actor and critic? #23

Closed: mhr closed this issue 2 years ago.

mhr commented 2 years ago

I'd like to share some parameters between my actor and critic networks, so that during training the shared parameters are updated whenever the critic's weights are updated, and likewise whenever the actor's weights are updated.

At first, https://coax.readthedocs.io/en/latest/examples/getting_started/third_agent.html?highlight=shared#ppo-on-pong seemed relevant because it defines a shared function used by both func_pi and func_v, but I realized this only reuses the architecture; the parameters themselves are not kept in sync during training.

In Haiku more generally, I found that https://dm-haiku.readthedocs.io/en/latest/api.html#multi-transform is the intended way to share parameters between two networks, but it requires you to transform the module functions yourself, whereas coax insists on doing the transform for you. See https://colab.research.google.com/gist/tomhennigan/a9f434bf64d132f1734310bdc36d7281/example-of-multi_transform-for-actor-critic-with-shared-layer.ipynb#scrollTo=gInsTcK9TN_x for how this is used in an actor-critic setup.

Another use case I have in mind is training the parameters of a separate module with its own Optax optimizer in a different training loop, so that the module's parameters are updated by both the RL loss and another loss. This might be useful for, e.g., model-based RL, where we want to train a dynamics module separately from the policy-gradient training loop and then use that dynamics module within an actor inside the policy-gradient loop.

What is an easy way to allow for parameter-sharing between actor and critic networks in Coax?

KristianHolsheimer commented 2 years ago

Hi there! Thanks for your interest in coax!

You're right, the shared function in that example is somewhat misleading, since its params aren't actually shared between the policy and the value function.

There wasn't a clean way to do this, so I added a utility for it: coax.utils.sync_shared_params. This function looks for top-level keys that occur in more than one of the provided params and replaces every occurrence with the (weighted) average across those occurrences. Because matching is done purely by key name, you need to be careful with parameter name scopes: modules that should be shared must end up under the same top-level key, and modules that shouldn't must not. The easiest way to control this is to wrap your network definitions in a hk.experimental.name_scope.
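To make the averaging rule concrete, here is a pure-Python sketch of the behaviour described above. It is not coax's implementation: sync_shared_params_sketch is a hypothetical name, params are assumed to be plain nested dicts of floats, shared entries are assumed to have identical leaf names, and the weights are assumed to be normalized over the dicts that actually contain a given key.

```python
from collections import Counter
from typing import Dict, List, Optional, Sequence

Params = Dict[str, Dict[str, float]]


def sync_shared_params_sketch(*params: Params,
                              weights: Optional[Sequence[float]] = None) -> List[Params]:
    """Hypothetical illustration of weighted averaging over shared top-level keys."""
    if weights is None:
        weights = [1.0] * len(params)
    if len(weights) != len(params):
        raise ValueError("need exactly one weight per params dict")

    # A top-level key is "shared" if more than one params dict contains it.
    counts = Counter(key for p in params for key in p)
    shared = [key for key, n in counts.items() if n > 1]

    # Copy the inputs so the caller's dicts are left untouched.
    out = [{key: dict(leaves) for key, leaves in p.items()} for p in params]

    for key in shared:
        owners = [i for i, p in enumerate(params) if key in p]
        total = sum(weights[i] for i in owners)  # assumed normalization
        for leaf in params[owners[0]][key]:
            avg = sum(weights[i] * params[i][key][leaf] for i in owners) / total
            for i in owners:
                out[i][key][leaf] = avg
    return out


# Usage: three params dicts sharing 'torso', with equal (default) weights.
a = {"torso": {"w": 0.0}, "pi_head": {"w": 5.0}}
b = {"torso": {"w": 3.0}, "v_head": {"w": 7.0}}
c = {"torso": {"w": 6.0}}
a, b, c = sync_shared_params_sketch(a, b, c)
# Each 'torso' w is now 3.0; the head entries are unchanged.
```

Keys that appear in only one dict pass through untouched, which is why an accidental name collision (or a missing one) silently changes what gets averaged.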

Here's an example of how to use it: A2C stub. The relevant line is:

pi.params, v.params = coax.utils.sync_shared_params(pi.params, v.params)

To be more explicit, here's a simple example:

import coax
import haiku as hk

params1 = hk.data_structures.to_haiku_dict({'torso': {'w': 1.0}, 'head1': {'w': 10.0}})
params2 = hk.data_structures.to_haiku_dict({'torso': {'w': 2.0}, 'head2': {'w': 20.0}})

We sync the params using relative weights, e.g.

params1, params2 = coax.utils.sync_shared_params(params1, params2, weights=[0.2, 0.8])

Then, after we sync the shared params, we have:

print(params1)  # {'torso': {'w': 1.8}, 'head1': {'w': 10.0}}
print(params2)  # {'torso': {'w': 1.8}, 'head2': {'w': 20.0}}
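The torso value can be checked by hand. Assuming the weights are normalized by their sum (here they already sum to 1, so normalization does not change the result), the shared entry becomes:

```latex
w_{\text{torso}} = \frac{0.2 \cdot 1.0 + 0.8 \cdot 2.0}{0.2 + 0.8} = \frac{0.2 + 1.6}{1.0} = 1.8
```

The head1 and head2 entries each appear in only one params dict, so they are left unchanged.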

Let me know if this works for you.