DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

Custom actor and critic network #1985

Closed: krishdotn1 closed this issue 3 weeks ago

krishdotn1 commented 1 month ago

❓ Question

Can anyone explain how I can replace the default actor and critic networks with my own network? Here is what I have done, step by step:

  1. Created a custom network.
  2. Overrode _build_mlp_extractor so that self.mlp_extractor = CustomNetwork(self.features_dim) uses my custom network. Is this enough to run stable-baselines3 as in a default run?


fracapuano commented 1 month ago

Hey @krishdotn1 👋

Check out this ⬇️ -- it directly contains the answer to your question 😊

from typing import Callable, Tuple

from gymnasium import spaces
import torch as th
from torch import nn

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy

class CustomNetwork(nn.Module):
    """
    Custom network for policy and value function.
    It receives as input the features extracted by the features extractor.

    :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
        )
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
        )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_net(features)

class CustomActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        # Disable orthogonal initialization
        kwargs["ortho_init"] = False
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )

    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomNetwork(self.features_dim)

model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
model.learn(5000)

You should just subclass ActorCriticPolicy (or any other policy you might need, depending on your problem; hard to tell without more context) and add your custom network as the mlp_extractor. Feel free to share more details if anything is unclear, happy to help further 🤗

krishdotn1 commented 1 month ago

Thank you @fracapuano. I have done the same thing and it works well, but when I use SubprocVecEnv it never starts training; it just shows "Using Cuda device" and keeps loading. P.S. I'm using an LNN model for the actor and critic.

fracapuano commented 1 month ago

Great to hear that you are following the doc! Would you be able to upload a minimal example to reproduce your issue here?

Thank you!

araffin commented 3 weeks ago

@fracapuano thanks for helping out =)

> I have never used SubprocVecEnv with a custom model, but my understanding of this is that it should not change much

Yes, SubprocVecEnv should not influence the result, it only spawns multiple processes to collect data in parallel. If something goes wrong at that step, it's probably a limitation of the env.
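For reference, here is a minimal sketch of how SubprocVecEnv is typically combined with the custom policy defined above (CartPole-v1 and n_envs=4 are just placeholders for illustration). Since SubprocVecEnv spawns subprocesses, on spawn-based start methods the training code has to live under an if __name__ == "__main__": guard, otherwise the script can appear to hang before training starts:

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

if __name__ == "__main__":
    # Collect rollouts with 4 environments, each running in its own process
    vec_env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=SubprocVecEnv)
    # CustomActorCriticPolicy is the class defined earlier in this thread
    model = PPO(CustomActorCriticPolicy, vec_env, verbose=1)
    model.learn(5000)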