Closed kundan-kumarr closed 1 year ago
@kundan7kumar Did you ever find a solution to this? I'm also trying to add more layers to my RecurrentPPO policy, since the problem I'm trying to solve is quite complex. I'd love to know if you found a straightforward way to customize the hidden size, the number of layers, and the number of LSTM layers, and to pass those settings to PPO.
Probably a duplicate of https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/183
The parameters are also documented.
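For anyone landing here, a minimal sketch of what that could look like, assuming a recent sb3-contrib release where `MlpLstmPolicy` accepts `lstm_hidden_size`, `n_lstm_layers` and a dict-style `net_arch` through `policy_kwargs`. The helper names `build_policy_kwargs` and `make_model` and the default sizes are made up for illustration; they are not part of the library.

```python
from typing import Any, Dict, List, Optional


def build_policy_kwargs(
    lstm_hidden_size: int = 128,
    n_lstm_layers: int = 2,
    net_arch: Optional[List[int]] = None,
) -> Dict[str, Any]:
    """Collect LSTM and MLP sizes into a dict meant to be passed as policy_kwargs.

    Hypothetical helper: the keys are the recurrent policy parameters,
    the default values are arbitrary examples.
    """
    if lstm_hidden_size < 1 or n_lstm_layers < 1:
        raise ValueError("lstm_hidden_size and n_lstm_layers must be >= 1")
    layers = list(net_arch) if net_arch is not None else [64, 64]
    return {
        "lstm_hidden_size": lstm_hidden_size,  # units per LSTM layer
        "n_lstm_layers": n_lstm_layers,        # number of stacked LSTM layers
        # Fully connected layers after the LSTM, separate for actor (pi) and critic (vf).
        # Assumption: dict-style net_arch, as used by recent SB3 versions.
        "net_arch": {"pi": list(layers), "vf": list(layers)},
    }


def make_model(env_id: str = "CartPole-v1"):
    """Create a RecurrentPPO model with the customized LSTM (needs sb3-contrib installed)."""
    from sb3_contrib import RecurrentPPO  # imported lazily so the helper above works on its own

    return RecurrentPPO(
        "MlpLstmPolicy",
        env_id,
        policy_kwargs=build_policy_kwargs(),
        verbose=1,
    )
```

With sb3-contrib installed, `make_model().learn(total_timesteps=10_000)` should train with the larger LSTM. If that is enough for your use case, no custom policy class is needed.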
I am trying to create a custom LSTM policy, but it seems that BasePolicy is missing. How can we create a custom LSTM policy that we can pass to PPO or A2C? If that is not possible, can we modify the LSTM layers in the current setup to customize the results?
import gym
import torch.nn as nn

# This is the import that fails for me: BasePolicy seems to be missing from this module
from sb3_contrib.common.policies import BasePolicy
from sb3_contrib.ppo_recurrent import RecurrentPPO


class CustomPolicy(BasePolicy):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


env_name = "CartPole-v1"
env = gym.make(env_name)
model = RecurrentPPO(CustomPolicy, env, verbose=1)
Please help