DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] Creating a Custom Network #1644

Closed · T4KUT00 closed 1 year ago

T4KUT00 commented 1 year ago

❓ Question

I am currently doing research with a custom environment. I want to use the PPO algorithm with a custom network that takes one image and two numbers as input. I followed the documentation to build the network, but it does not work. Specifically, I combined "Multiple Inputs and Dictionary Observations" with the "Advanced Example" for on-policy algorithms from the policies page of the documentation, but it did not work. Can someone please help me? (A sketch of the observation layout I am describing is below.)
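
For reference, here is a minimal sketch of the kind of Dict observation space being described. This is an illustration only: the actual airgym environment defines its own space, the keys "image", "goal_dis", and "angle" are taken from the code further down, and the bounds here are assumed.

import numpy as np
from gym import spaces

# Hypothetical Dict observation space: one grayscale image plus two scalars.
observation_space = spaces.Dict(
    {
        # (H, W, C) image; VecTransposeImage later moves channels first
        "image": spaces.Box(low=0, high=255, shape=(256, 144, 1), dtype=np.uint8),
        # distance to the goal (bounds are assumed)
        "goal_dis": spaces.Box(low=0.0, high=np.inf, shape=(1,), dtype=np.float32),
        # angle to the goal (bounds are assumed)
        "angle": spaces.Box(low=-np.pi, high=np.pi, shape=(1,), dtype=np.float32),
    }
)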

Checklist

araffin commented 1 year ago

Hello, what have you tried so far and what were the errors? (please keep the code minimal and working, see links in the checklist)

T4KUT00 commented 1 year ago

from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import gym
from gym import spaces
import airgym
import time
import torch
import torch.nn as nn

from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import MultiInputActorCriticPolicy

env = DummyVecEnv(
    [
        lambda: Monitor(
            gym.make(
                "airgym:airsim-drone-sample-v0",
                ip_address="127.0.0.1",
                image_shape=(256, 144, 1),
                step_length=3,
            )
        )
    ]
)

env = VecTransposeImage(env)  # transpose image observations from (H, W, C) to (C, H, W)

class CustomCombinedExtractors(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict):
        super().__init__(observation_space, features_dim=1)

        extractors = {}
        total_concat_size = 0

        for key, subspace in observation_space.items():
            if key == "image":
                print(subspace.shape[0], subspace.shape[1], subspace.shape[2])
                image_channels = subspace.shape[0]
                extractors[key] = nn.Sequential(
                    nn.Conv2d(image_channels, 32, kernel_size=3, stride=1, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(kernel_size=2, stride=2),
                    nn.Flatten(),
                )
                # padding=1 preserves H and W; the 2x2 max-pool halves them
                total_concat_size += (subspace.shape[1] // 2) * (subspace.shape[2] // 2) * 32
            elif key == "goal_dis":
                print("subspace.shape[0]:",subspace.shape[0])
                print("subspace:",subspace)
                print(subspace.shape[0])
                extractors[key] = nn.Sequential(
                    nn.Flatten()
                )

                total_concat_size += subspace.shape[0]

            elif key == "angle":
                print("subspace.shape[0]:",subspace.shape[0])
                print("subspace:",subspace)
                extractors[key] = nn.Sequential(
                    nn.Flatten()
                )

                total_concat_size += subspace.shape[0]

        self.extractors = nn.ModuleDict(extractors)

        self._features_dim = total_concat_size

    def forward(self, observations) -> torch.Tensor:
        encoded_tensor_list = []
        for key, extractor in self.extractors.items():
            encoded_tensor_list.append(extractor(observations[key]))
        return torch.cat(encoded_tensor_list, dim=1)

class CustomNetwork(nn.Module):

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 9,
        last_layer_dim_vf: int = 1,
    ):
        super().__init__()

        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        self.policy_net = nn.Sequential(
            nn.Linear(294914, last_layer_dim_pi),  # hardcoded input size
            nn.ReLU(),
        )

        self.value_net = nn.Sequential(
            nn.Linear(294914, last_layer_dim_vf)  # hardcoded input size
        )

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: torch.Tensor) -> torch.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: torch.Tensor) -> torch.Tensor:
        return self.value_net(features)

class CustomActorCriticPolicy(MultiInputActorCriticPolicy):

    def __init__(
        self,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        kwargs["ortho_init"] = False,
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            *args,
            **kwargs,
        )

    def _build_mlp_extractor(self) -> None:
        self.features_extractor = CustomCombinedExtractors(self.observation_space.spaces)
        self.mlp_extractor = CustomNetwork(self.features_dim)

model = PPO(
    CustomActorCriticPolicy,
    env,
    gamma=0.99,
    learning_rate=0.00025,
    gae_lambda=0.95,
    n_steps=128,
    batch_size=128,
    n_epochs=10,
    clip_range=0.2,
    normalize_advantage=True,
    ent_coef=0.05,
    vf_coef=0.5,
    max_grad_norm=0.5,
    tensorboard_log="./tb_logs/",
    seed=100,
    device="cuda",
    use_sde=False,
    verbose=1,
)

callbacks = []
eval_callback = EvalCallback(
    env,
    callback_on_new_best=None,
    n_eval_episodes=5,
    best_model_save_path=".",
    log_path="./tb_logs/",
    eval_freq=10000,
)
callbacks.append(eval_callback)

kwargs = {}
kwargs["callback"] = callbacks

model.learn(
    total_timesteps=int(5e5),
    tb_log_name="ppo_airsim_drone_run_" + str(time.time()),
    **kwargs,
)

model.save("ppo_airsim_drone_policy")

This is code to run PPO on a drone simulator called AirSim. The input is a single image and two numbers. The code is based on the following example: https://github.com/microsoft/AirSim/tree/main/PythonClient/reinforcement_learning

When I run it, it works at first, but once it reaches n_steps, I get the following error:

Traceback (most recent call last):
  File "C:\PPO\ppo_drone2.py", line 227, in <module>
    model.learn(
  File "C:\Users\name\anaconda3\envs\air\lib\site-packages\stable_baselines3\ppo\ppo.py", line 308, in learn
    return super().learn(
  File "C:\Users\name\anaconda3\envs\air\lib\site-packages\stable_baselines3\common\on_policy_algorithm.py", line 250, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "C:\Users\name\anaconda3\envs\air\lib\site-packages\stable_baselines3\common\on_policy_algorithm.py", line 213, in collect_rollouts
    values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))
  File "C:\Users\name\anaconda3\envs\air\lib\site-packages\stable_baselines3\common\policies.py", line 721, in predict_values
    latent_vf = self.mlp_extractor.forward_critic(features)
  File "C:\PPO\ppo_drone2.py", line 142, in forward_critic
    return self.value_net(features)
  File "C:\Users\name\anaconda3\envs\air\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\name\anaconda3\envs\air\lib\site-packages\torch\nn\modules\container.py", line 217, in forward
    input = module(input)
  File "C:\Users\name\anaconda3\envs\air\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\MaejimaTakuto\anaconda3\envs\air\lib\site-packages\torch\nn\modules\linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x258 and 294914x1)

The part I believe is problematic is the network definition, i.e., the CustomCombinedExtractors, CustomNetwork, and CustomActorCriticPolicy classes shown above.

araffin commented 1 year ago

The error is pretty clear:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x258 and 294914x1)

    self.value_net = nn.Sequential(
        nn.Linear(294914, last_layer_dim_vf)
    )

You need to double-check those values; you should be using feature_dim (see the doc), as sketched below. You should also be using Gymnasium with the latest version of SB3.
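
For illustration, here is a minimal sketch of what that change could look like, following the "Advanced Example" pattern from the docs: the linear layers take feature_dim (whatever size the features extractor reports) instead of the hardcoded 294914, so the latent networks always match the extractor output. This is a sketch, not a verified fix for the AirSim setup.

from typing import Tuple

import torch
import torch.nn as nn

class CustomNetwork(nn.Module):
    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 9,
        last_layer_dim_vf: int = 1,
    ):
        super().__init__()

        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Size the layers from feature_dim instead of a hardcoded constant,
        # so they always match what the features extractor actually outputs.
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi),
            nn.ReLU(),
        )
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf),
        )

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: torch.Tensor) -> torch.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: torch.Tensor) -> torch.Tensor:
        return self.value_net(features)

Note also that in the documented pattern the custom features extractor is passed via policy_kwargs (features_extractor_class=CustomCombinedExtractors) rather than assigned inside _build_mlp_extractor; otherwise parts of the policy such as predict_values may still go through the default extractor, which would explain the 1x258 in the traceback (256 from the default CNN plus the two scalar inputs).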

This is not an SB3 issue but rather a PyTorch issue: you have a mismatch between shapes. I recommend debugging with breakpoints or ipdb.
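
For example, one way to locate such a mismatch (just an illustration; any debugger works) is to print the shapes right before the failing layer:

    def forward_critic(self, features: torch.Tensor) -> torch.Tensor:
        # Temporary debugging aid: compare the incoming feature size with
        # the input size the first linear layer expects.
        print("features shape:", features.shape)  # e.g. torch.Size([1, 258])
        print("expected in_features:", self.value_net[0].in_features)  # 294914 here
        # import ipdb; ipdb.set_trace()  # or drop into a debugger instead
        return self.value_net(features)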