DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Feature Request] Temporal Convolutional network #1984

Open tty666 opened 1 month ago

tty666 commented 1 month ago

🚀 Feature

Hello guys, after watching this video: https://www.youtube.com/watch?v=WoLlZLdoEQk I had the idea to extend the NatureCNN into a NatureTCN1D (temporal convolutional network) features extractor, like this:

from typing import Any, Dict, List, Optional, Type, Union

import gymnasium as gym
import torch as th
from gymnasium import spaces
from torch import nn

from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.type_aliases import Schedule


class Chomp1d(nn.Module):
    """Trim the trailing padding so the convolution stays causal."""

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    """Two causal dilated Conv1d layers with dropout and a residual connection."""

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, activation_fn, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.activation1 = activation_fn()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.activation2 = activation_fn()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.activation1, self.dropout1,
                                    self.conv2, self.chomp2, self.activation2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.activation = activation_fn()

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.activation(out + res)

class NatureTCN1D(BaseFeaturesExtractor):
    """
    TCN features extractor for 2D Box observations of shape (length, features).

    :param observation_space: The observation space
    :param features_dim: Number of features extracted by the final linear layer
    :param dataset_dim: Number of feature channels kept from the observation
    :param activation_fn: Activation function used inside the temporal blocks
    :param dropout: Dropout probability inside the temporal blocks
    """

    def __init__(
        self,
        observation_space: gym.spaces.Box,
        features_dim: int = 256,
        dataset_dim: int = 21,
        activation_fn=nn.SiLU,
        dropout=0.2,
    ) -> None:
        super().__init__(observation_space, features_dim)
        self.dataset_dim = dataset_dim
        n_input_channels = self.dataset_dim

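        # Causal padding scheme: each block uses padding = (kernel_size - 1) * dilation
        # and Chomp1d trims the same amount on the right, so no layer sees future timesteps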
        self.tcn = nn.Sequential(
            TemporalBlock(n_input_channels, 32, kernel_size=5, stride=1, dilation=1, padding=(5-1) * 1, activation_fn=activation_fn, dropout=dropout),
            TemporalBlock(32, 64, kernel_size=7, stride=1, dilation=2, padding=(7-1) * 2, activation_fn=activation_fn, dropout=dropout),
            TemporalBlock(64, 128, kernel_size=3, stride=1, dilation=4, padding=(3-1) * 4, activation_fn=activation_fn, dropout=dropout),
            nn.Flatten(),
        )

        # Infer the flattened output size with one forward pass
        with th.no_grad():
            sample_observation = th.as_tensor(observation_space.sample()[None, :, :self.dataset_dim]).float()
            # (batch, length, channels) -> (batch, channels, length)
            sample_observation = sample_observation.permute(0, 2, 1)
            n_flatten = self.tcn(sample_observation).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), activation_fn())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        observations = observations[:, :, :self.dataset_dim]
        # (batch, length, channels) -> (batch, channels, length)
        observations = observations.permute(0, 2, 1)
        return self.linear(self.tcn(observations))

class CombinedExtractorT1D(BaseFeaturesExtractor):
    """
    Combined features extractor for Box observation spaces.
    The input is fed through two separate submodules (the TCN and a Flatten layer),
    and the output features are concatenated and fed through an additional MLP network ("combined").

    :param observation_space: The observation space
    :param cnn_output_dim: Number of features to output from the TCN submodule. Defaults to
        256 to avoid exploding network sizes.
    :param dataset_dim: Number of feature channels kept from the observation by the TCN.
    """

    def __init__(
        self,
        observation_space: spaces.Box,
        cnn_output_dim: int = 256,
        dataset_dim: int = 5,
    ) -> None:
        assert isinstance(observation_space, spaces.Box), (
            "NatureTCN1D must be used with a gym.spaces.Box "
            f"observation space, not {observation_space}"
        )
        super().__init__(observation_space, features_dim=1)
        # We assume (length, channels) inputs: time steps first, feature channels last
        assert len(observation_space.shape) == 2, (
            "You should use NatureTCN1D only with 2D inputs (length, channels)"
        )
        self.cnn_extractor = NatureTCN1D(observation_space, features_dim=cnn_output_dim, dataset_dim=dataset_dim)
        self.raw_extractor = nn.Flatten()

        cnn_output_size = cnn_output_dim
        raw_output_size = get_flattened_obs_dim(observation_space)

        # Update the features dim manually
        self._features_dim = cnn_output_size + raw_output_size

    def forward(self, observations: th.Tensor) -> th.Tensor:
        cnn_encoded = self.cnn_extractor(observations)
        raw_encoded = self.raw_extractor(observations)

        return th.cat([raw_encoded, cnn_encoded], dim=1)

class ActorCriticTCN1DPolicy(ActorCriticPolicy):
    """
    TCN policy class for actor-critic algorithms (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Features extractor to use.
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureTCN1D,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Add a default value if optimizer_kwargs is None
        # if optimizer_kwargs is None:
        #     optimizer_kwargs = {}
        #     optimizer_kwargs["eps"] = 1e-5
        # print(f"optimizer_kwargs: {optimizer_kwargs}")  # debug message
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            share_features_extractor,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )

class MultiInputActorCriticPolicyT1D(ActorCriticPolicy):
    """
    Multi-input actor-critic policy class (has both policy and value prediction).
    Used by A2C, PPO and the likes.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param net_arch: The specification of the policy and value networks.
    :param activation_fn: Activation function
    :param ortho_init: Whether to use or not orthogonal initialization
    :param use_sde: Whether to use State Dependent Exploration or not
    :param log_std_init: Initial value for the log standard deviation
    :param full_std: Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: Uses the CombinedExtractor
    :param features_extractor_kwargs: Keyword arguments
        to pass to the features extractor.
    :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.
    :param normalize_images: Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Box,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractorT1D,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            use_expln,
            squash_output,
            features_extractor_class,
            features_extractor_kwargs,
            share_features_extractor,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )

It's a quick addition and I am pretty sure I could work on it more... But maybe it's a good addition and sometimes a replacement for LSTM/RNN? I am using SiLU in my context, but in a more "general" setting ReLU could be used as the activation function... What do you think? Should I propose it as a pull request for the contrib repo? Or does it not make sense to you?
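For reference, here is a minimal sketch of how I plug the policy into PPO; "MyTimeSeriesEnv-v0" is a placeholder for an env whose observation space is a 2D Box of shape (length, features), and the kwargs are just my values:

from stable_baselines3 import PPO

# The policy defaults to NatureTCN1D as its features extractor;
# features_extractor_kwargs only overrides its parameters
model = PPO(
    ActorCriticTCN1DPolicy,
    "MyTimeSeriesEnv-v0",
    policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=256, dataset_dim=21)),
    verbose=1,
)
model.learn(total_timesteps=100_000)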

Motivation

RecurrentPPO wasn't fitting my needs, so I did some research on other possibilities...

Pitch

No response

Alternatives

No response

Additional context

No response

araffin commented 1 month ago

But maybe it's a good addition and sometimes replacement for LSTM/RNN ?

That would be more for SB3 contrib, I guess. And without any benchmark, it's hard to say if it's a good addition. For instance, for recurrent PPO: https://wandb.ai/sb3/no-vel-envs/reports/PPO-vs-RecurrentPPO-aka-PPO-LSTM-on-environments-with-masked-velocity--VmlldzoxOTI4NjE4

(the gain is marginal with respect to frame stacking on several envs, but it is substantial on others, like LunarLander without velocity)
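For context, the frame-stacking baseline referenced above is just a wrapper around the vectorized env; a minimal sketch ("CartPole-v1" and n_stack=4 are illustrative choices, not from the benchmark):

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack

# Stack the last 4 observations so a plain feed-forward policy
# can infer temporal information such as velocities
vec_env = make_vec_env("CartPole-v1", n_envs=8)
vec_env = VecFrameStack(vec_env, n_stack=4)
model = PPO("MlpPolicy", vec_env, verbose=1)
model.learn(total_timesteps=100_000)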