DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] loading pretrained model weights into a new model #1414

Closed · WreckItTim closed this issue 1 year ago

WreckItTim commented 1 year ago

❓ Question

Hello,

I saw the previous post here https://github.com/DLR-RM/stable-baselines3/issues/543 with the corresponding paper and Google Colab notebooks. These helped for sure, thank you!

I am doing something slightly different: rather than replacing the entire policy with a custom policy object, I want to replace just the nn modules inside a new SB3 model object, after doing some supervised learning with a surrogate model. To be specific:

  1. I have a custom torch.nn model with pretrained weights for both the actor and critic, and I want to create a new TD3 policy to use these models.
  2. As the previously mentioned paper suggests, I want to collect some data into the replay buffer with the pretrained actor/critic before any online training occurs during fine-tuning.

Solution (pseudo-code):

import copy
import stable_baselines3

n_episodes_b4 = 100   # first collect 100 episodes using the pretrained networks
n_episodes_after = 1  # then train every n_episodes_after episodes
sb3model = stable_baselines3.TD3(..., learning_starts=0,
                                 train_freq=(n_episodes_b4, 'episode'))

# fetch the pretrained actor and critic torch.nn models, e.g.
# pretrained_actor = torch.nn.Sequential(...)
# pretrained_critics = [torch.nn.Sequential(...), ...]
pretrained_actor, pretrained_critics = load_pretrained_networks()  # placeholder loader

# replace the actor network (and its target) with the pretrained actor
sb3model.actor.mu = copy.deepcopy(pretrained_actor)
sb3model.actor_target.mu = copy.deepcopy(pretrained_actor)

# replace all critic networks (and their targets) with the pretrained critics;
# add_module() overwrites the default qf{idx} modules
for idx in range(len(sb3model.critic.q_networks)):
    q_net = copy.deepcopy(pretrained_critics[idx])
    sb3model.critic.q_networks[idx] = q_net
    sb3model.critic.add_module(f"qf{idx}", q_net)
    q_net_target = copy.deepcopy(pretrained_critics[idx])
    sb3model.critic_target.q_networks[idx] = q_net_target
    sb3model.critic_target.add_module(f"qf{idx}", q_net_target)

# use the pretrained model (not random actions) to collect the first
# n_episodes_b4 episodes into the replay buffer before any gradient step
sb3model.learn(...)

# once the env has seen n_episodes_b4 reset() calls, switch to the
# normal training schedule
sb3model.train_freq = (n_episodes_after, 'episode')
sb3model._convert_train_freq()

Note that I assign the actual networks here rather than using load_state_dict(), because I will also be using custom network layers. So I do not want to just set weights; I want to replace the modules with my own sequential models, which already carry the pretrained weights.

I am a little nervous about poking around inside a repo as big as SB3, and I hope I have not missed anything that might break training. Everything runs and looks right, but are there any logic errors or other issues I am missing? Thank you so much!


BeFranke commented 1 year ago

I had a similar problem and wrote a wrapper subclass of stable_baselines3.common.policies.ActorCriticPolicy to achieve this:

Gist

Now, I am pretty new to Stable Baselines3 and do not claim this is a good way to do it, but in some limited experiments it at least achieves performance comparable to the SB3-native solution when given the same networks and hyperparameters.

With this code, you can set up your policy and value networks as torch MLPs and then pass them like this (example for PPO; see the sketch at the end of this comment):

model = PPO(
    env=train_env,
    policy=TorchWrapperActorCriticPolicy,
    policy_kwargs={
        "torch_policy_network": <put your custom policy network here>,
        "torch_value_network": <put your custom value network here>,
    },
)

You should then be able to freely re-assign model.policy._policy_network and model.policy._value_network. Feedback on this solution is welcome!
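
In rough outline, the wrapper follows the custom-network pattern from the SB3 docs. This is only a minimal sketch, not the exact Gist code: the TorchWrapperExtractor helper and the latent_dim_pi / latent_dim_vf arguments are assumptions, and the custom networks are assumed to map flattened observation features to latent vectors of those sizes.

import torch as th
from torch import nn
from stable_baselines3.common.policies import ActorCriticPolicy

class TorchWrapperExtractor(nn.Module):
    # Exposes user-supplied policy/value networks with the interface
    # ActorCriticPolicy expects from its mlp_extractor.
    def __init__(self, policy_network: nn.Module, value_network: nn.Module,
                 latent_dim_pi: int, latent_dim_vf: int):
        super().__init__()
        self.policy_net = policy_network
        self.value_net = value_network
        # ActorCriticPolicy reads these to size its action and value heads
        self.latent_dim_pi = latent_dim_pi
        self.latent_dim_vf = latent_dim_vf

    def forward(self, features: th.Tensor):
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_net(features)

class TorchWrapperActorCriticPolicy(ActorCriticPolicy):
    def __init__(self, *args, torch_policy_network=None, torch_value_network=None,
                 latent_dim_pi=64, latent_dim_vf=64, **kwargs):
        # stash the custom nets in a plain tuple: nn.Module forbids assigning
        # sub-modules before its own __init__ has run
        self._custom_networks = (torch_policy_network, torch_value_network)
        self._latent_dims = (latent_dim_pi, latent_dim_vf)
        super().__init__(*args, **kwargs)

    def _build_mlp_extractor(self) -> None:
        # plug the user networks in where SB3 would build its default MlpExtractor
        self.mlp_extractor = TorchWrapperExtractor(*self._custom_networks,
                                                   *self._latent_dims)

If the networks come with pretrained weights, ortho_init=False should probably also be passed in policy_kwargs, since ActorCriticPolicy otherwise applies orthogonal initialization to the mlp_extractor it builds.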

araffin commented 1 year ago

> I have read the documentation

https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#accessing-and-modifying-model-parameters

TL;DR: you can use set_parameters() or load_state_dict() (you will probably need to rename some keys); a rough sketch is at the end of this comment.

> I have checked that there is no similar issue in the repo

Probably a duplicate of https://github.com/DLR-RM/stable-baselines3/issues/1411
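
To illustrate the set_parameters() route, a minimal sketch (not copied from the docs; it assumes the pretrained networks match TD3's default MlpPolicy architecture, and the key names should be checked against sb3model.policy.state_dict()):

import stable_baselines3

# env, pretrained_actor and pretrained_critics are those from the original post
sb3model = stable_baselines3.TD3("MlpPolicy", env)

# 1. Inspect the parameter names SB3 expects,
#    e.g. 'actor.mu.0.weight', 'critic.qf0.0.weight', ...
print(list(sb3model.policy.state_dict().keys()))

# 2. Build a state dict with renamed keys from the pretrained networks
#    (targets get the same weights so they start in sync)
new_state = {}
for prefix in ("actor.mu", "actor_target.mu"):
    new_state.update({f"{prefix}.{k}": v
                      for k, v in pretrained_actor.state_dict().items()})
for idx, critic in enumerate(pretrained_critics):
    for prefix in (f"critic.qf{idx}", f"critic_target.qf{idx}"):
        new_state.update({f"{prefix}.{k}": v
                          for k, v in critic.state_dict().items()})

# 3. Load only those entries, leaving the rest of the policy untouched
sb3model.set_parameters({"policy": new_state}, exact_match=False)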

WreckItTim commented 1 year ago

> Probably a duplicate of #1411

Slightly a duplicate; the difference is that I am also trying to fill the replay buffer using the already-trained model before online training starts, as opposed to randomly sampling the action space.

araffin commented 1 year ago

The code that updates the replay buffer is here: https://github.com/DLR-RM/stable-baselines3/blob/5a70af8abddfc96eac3911e69b20f19f5220947c/stable_baselines3/common/off_policy_algorithm.py#L416

In the examples, we have one about saving and loading a replay buffer; a rough sketch is below.
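
Something along these lines (a minimal sketch, not copied from the examples page; the environment and file name are placeholders):

from stable_baselines3 import TD3

model = TD3("MlpPolicy", "Pendulum-v1")
model.learn(total_timesteps=1_000)
model.save_replay_buffer("td3_replay_buffer")  # placeholder path

# later, e.g. when fine-tuning in a fresh run
new_model = TD3("MlpPolicy", "Pendulum-v1")
new_model.load_replay_buffer("td3_replay_buffer")
print(new_model.replay_buffer.size())  # transitions carried over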

WreckItTim commented 1 year ago

@araffin Perfect! This might be better than the hacky solution I was using in the original post. Thank you for the responses =)

@BeFranke Yes, this is similar to what I want to do, thank you! I was trying to avoid writing a custom class, because you never know for sure whether there are potential logic errors, or other steps you are skipping or overwriting, when working with a large external repo. I do think it is a better solution than overriding/hacking the built-in classes, though.