Closed: WreckItTim closed this issue 1 year ago
I had a similar problem and wrote a wrapper subclass of stable_baselines3.common.policies.ActorCriticPolicy to achieve this:
Now, I am pretty new to Stable Baselines3 and do not claim this is the best way to do it, but in some limited experiments it at least achieved performance comparable to the SB3-native solution when given the same networks and hyperparameters.
With this code, you can set up your policy and value networks as torch MLPs and then pass them like this (example for PPO):
model = PPO(env=train_env, policy=TorchWrapperActorCriticPolicy, policy_kwargs={ "torch_policy_network": <put your custom policy network here>, "torch_value_network": <put your custom value network here> })
You should then be able to freely re-assign model.policy._policy_network and model.policy._value_network. Feedback on this solution is welcome!
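For context, here is a hedged sketch of what such a wrapper could look like (this is not necessarily the exact code referenced above; the torch_latent_dim_pi / torch_latent_dim_vf kwargs are assumptions, and the custom networks are assumed to map the extracted features to latent vectors of those sizes, from which SB3 then builds the action and value heads):

```python
# Hedged sketch, assuming the interface ActorCriticPolicy expects from mlp_extractor.
from typing import Tuple

import torch as th
from torch import nn

from stable_baselines3.common.policies import ActorCriticPolicy


class _TorchNetworkExtractor(nn.Module):
    """Wraps two user-supplied torch networks behind the mlp_extractor interface."""

    def __init__(self, policy_network: nn.Module, value_network: nn.Module,
                 latent_dim_pi: int, latent_dim_vf: int):
        super().__init__()
        self.policy_network = policy_network
        self.value_network = value_network
        # SB3 reads these attributes to size the action and value heads.
        self.latent_dim_pi = latent_dim_pi
        self.latent_dim_vf = latent_dim_vf

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_network(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_network(features)


class TorchWrapperActorCriticPolicy(ActorCriticPolicy):
    def __init__(self, *args, torch_policy_network: nn.Module,
                 torch_value_network: nn.Module,
                 torch_latent_dim_pi: int = 64, torch_latent_dim_vf: int = 64,
                 **kwargs):
        # Stash the custom modules in a plain tuple: assigning nn.Module attributes
        # before super().__init__() would raise inside nn.Module.__setattr__.
        self._custom_networks = (torch_policy_network, torch_value_network)
        self._custom_latent_dims = (torch_latent_dim_pi, torch_latent_dim_vf)
        super().__init__(*args, **kwargs)

    def _build_mlp_extractor(self) -> None:
        # Called from ActorCriticPolicy._build(); replaces the default MlpExtractor,
        # so the optimizer built afterwards already covers the custom parameters.
        self.mlp_extractor = _TorchNetworkExtractor(
            *self._custom_networks, *self._custom_latent_dims
        )
```

With this particular layout the custom networks live under model.policy.mlp_extractor, so re-assignment would go through model.policy.mlp_extractor.policy_network / .value_network rather than model.policy._policy_network.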
I have read the documentation
TL;DR: you can use set_parameters or load_state_dict() (you will probably need to rename some keys); a sketch of both options follows below.
I have checked that there is no similar issue in the repo
Probably a duplicate of https://github.com/DLR-RM/stable-baselines3/issues/1411
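A hedged sketch of both options (the SB3 parameter names and the key mapping below are illustrative; inspect model.policy.state_dict().keys() to see which names actually need renaming):

```python
import torch as th
from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1")

# Stand-in for weights obtained from supervised pretraining; shapes here happen to
# match the first layer of the default MlpPolicy on CartPole-v1, i.e. Linear(4, 64).
pretrained = {"fc1.weight": th.randn(64, 4), "fc1.bias": th.randn(64)}

# Hypothetical mapping from SB3 parameter names to the pretrained names.
key_mapping = {
    "mlp_extractor.policy_net.0.weight": "fc1.weight",
    "mlp_extractor.policy_net.0.bias": "fc1.bias",
}
renamed = {sb3_key: pretrained[src_key] for sb3_key, src_key in key_mapping.items()}

# Option 1: load directly into the torch module; strict=False tolerates missing keys.
model.policy.load_state_dict(renamed, strict=False)

# Option 2: use the SB3 helper; exact_match=False allows a partial update.
model.set_parameters({"policy": renamed}, exact_match=False)
```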
Slightly a duplicate; the difference is that I am also trying to collect a replay buffer using the already-trained model before online training starts, as opposed to randomly sampling the action space.
The code that updates the replay buffer is here: https://github.com/DLR-RM/stable-baselines3/blob/5a70af8abddfc96eac3911e69b20f19f5220947c/stable_baselines3/common/off_policy_algorithm.py#L416
In the examples, we have one about saving/loading a replay buffer.
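For reference, a short sketch along the lines of that example (algorithm, environment, and path are placeholders):

```python
from stable_baselines3 import SAC

model = SAC("MlpPolicy", "Pendulum-v1")
model.learn(total_timesteps=5_000)

# Save the transitions collected so far...
model.save_replay_buffer("sac_replay_buffer")

# ...and load them into a new model before its own online training starts.
new_model = SAC("MlpPolicy", "Pendulum-v1")
new_model.load_replay_buffer("sac_replay_buffer")
new_model.learn(total_timesteps=5_000)
```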
@araffin Perfect! This might be better than the hacky solution I was using in the original post. Thank you for the responses =)
@BeFranke Yes, this is similar to what I want to do! Thank you. I was trying to avoid making a custom class, because you never know for sure whether there are potential logic errors or other steps you are skipping/overwriting when working with a large external repo. I think it's a better solution than overriding/hacking the built-in classes, though.
❓ Question
Hello,
I saw the previous post here https://github.com/DLR-RM/stable-baselines3/issues/543 with the corresponding paper and Google Colab notebooks. These helped for sure, thank you!
I am doing something slightly different, and rather than replacing the entire policy with a custom policy object, I want to just replace the nn model in a new SB3 model object - after doing some supervised learning with a surrogate model. To be specific...
Solution (pseudo-)code:
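A hedged sketch of that kind of in-place replacement (not the exact snippet from this issue; the attribute names assume PPO's default MlpPolicy on CartPole-v1, and the pretrained networks are placeholders):

```python
import torch.nn as nn
from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1")

# Placeholder pretrained networks: input size must match the features dim (4 for
# CartPole-v1) and output size must match the latent dims the heads expect (64 here).
pretrained_policy_net = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 64), nn.Tanh())
pretrained_value_net = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 64), nn.Tanh())

# Swap the modules in place, keeping them on the policy's device.
model.policy.mlp_extractor.policy_net = pretrained_policy_net.to(model.policy.device)
model.policy.mlp_extractor.value_net = pretrained_value_net.to(model.policy.device)

# The optimizer was built over the old parameters, so rebuild it over the new ones.
model.policy.optimizer = model.policy.optimizer_class(
    model.policy.parameters(), lr=model.lr_schedule(1), **model.policy.optimizer_kwargs
)

model.learn(total_timesteps=5_000)
```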
Note that I set the actual network here rather than using load_state_dict() because I will be using custom network layers as well. So I do not want to just set weights. I want to replace the modules with my own sequential model, which also has pretrained weights.
I am a little nervous when playing with such a big repo as SB3, and hoping I did not miss anything that may break training. Everything compiles and looks right, but are there any logic errors or other issues I am missing? Thank you so much!
Checklist