Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

Create custom recurrent policy #163

Closed: kundan-kumarr closed this issue 1 year ago

kundan-kumarr commented 1 year ago

I am trying to create a custom LSTM policy. It seems that BasePolicy is missing. How can we create a custom LSTM policy that we can pass to PPO or A2C? If that is not possible, can we modify the LSTM layers in the current setting to customize the results?

```python
import gym
import torch.nn as nn

from sb3_contrib.common.policies import BasePolicy
from sb3_contrib.ppo_recurrent import RecurrentPPO


class CustomPolicy(BasePolicy):
    def __init__(self, *args, **kwargs):
        super(CustomPolicy, self).__init__(*args, **kwargs)

    def make_lstm_layer(self, n_lstm_layers: int) -> nn.Module:
        # Create your custom LSTM layer here
        # For example:
        lstm_layer = nn.LSTM(input_size=self.features_dim, hidden_size=64, num_layers=n_lstm_layers)
        return lstm_layer


env_name = "CartPole-v1"
env = gym.make(env_name)
model = RecurrentPPO(CustomPolicy, env, verbose=1)
```

Please help
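
A minimal sketch of one possible approach, assuming the current sb3-contrib layout: the recurrent base class is `RecurrentActorCriticPolicy` in `sb3_contrib.common.recurrent.policies` (there is no `BasePolicy` in `sb3_contrib.common.policies`), and the LSTM shape is controlled by its `lstm_hidden_size` and `n_lstm_layers` constructor arguments. The hyperparameter values below are illustrative, not recommendations.

```python
import gym

from sb3_contrib import RecurrentPPO
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticPolicy


class CustomLstmPolicy(RecurrentActorCriticPolicy):
    """Recurrent policy with a wider, two-layer LSTM (illustrative values)."""

    def __init__(self, *args, **kwargs):
        # Pin the LSTM architecture; all other arguments pass through unchanged.
        super().__init__(*args, lstm_hidden_size=64, n_lstm_layers=2, **kwargs)


env = gym.make("CartPole-v1")
model = RecurrentPPO(CustomLstmPolicy, env, verbose=1)
model.learn(total_timesteps=5_000)
```

Subclassing like this only pins defaults; the same constructor arguments can also be passed through `policy_kwargs` without defining a class, as shown further down the thread.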

sasdfsaad commented 1 year ago

@kundan-kumarr Did you ever find a solution to this? I'm also trying to implement more layers in my RecurrentPPO policy, as the problem I am trying to solve is quite complex. I'd love to know if you found a straightforward way of customizing the hidden size and number of LSTM layers and passing that to the PPO.

araffin commented 1 year ago

Probably duplicate of https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/183

The parameters are also documented.
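
Concretely, a sketch of the documented route, assuming the `RecurrentActorCriticPolicy` constructor arguments (`lstm_hidden_size`, `n_lstm_layers`, `shared_lstm`, `enable_critic_lstm`) in current sb3-contrib; the values are illustrative.

```python
from sb3_contrib import RecurrentPPO

# Illustrative values; tune for your task.
policy_kwargs = dict(
    lstm_hidden_size=128,     # width of each LSTM layer (default: 256)
    n_lstm_layers=2,          # number of stacked LSTM layers (default: 1)
    enable_critic_lstm=True,  # give the critic its own LSTM (the default)
)

model = RecurrentPPO(
    "MlpLstmPolicy",
    "CartPole-v1",
    policy_kwargs=policy_kwargs,
    verbose=1,
)
model.learn(total_timesteps=10_000)
```

If parameter count is a concern, `shared_lstm=True` together with `enable_critic_lstm=False` shares the actor's LSTM with the critic instead of giving each network its own.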