hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License
4.16k stars 725 forks source link

Dimension mismatch with when using custom Feature Extractor #1184

Closed yassinetb closed 1 year ago

yassinetb commented 1 year ago

Hi!

There seems to be an issue with the matching of dimensions between the fully connected MLP after a custom feature extractor. It seems that the MLP dimensions are not dynamically computed based on the dimensions of the feature extractor output. Also, for some reason, it seems like the batching is not functioning as expected, as what gets sent to the custom Feature Extractor are single observations rather than batches of observations - I assume this might also cause one of the dimensionality issues I am experiencing. Below is a Minimal Working Example with a dummy custom environment (checked with sb3.check_env) where the observation space matches the one for my use case, a custom Features Extractor that resembles a lot the one from the SB3 documentation (for the sake of simplicity).

I am not sure what is wrong here (I hope it's not me being completely stupid, in which case I sincerely apologize), I imagine it may be related to the "features_dim" parameter or something similar. If so, it probably is necessary to update the documentation on Custom Features Extractor to clarify how to avoid dimensionality issues between the custom feature extractor and the following MLP.

Dummy Environment

import numpy as np
import gymnasium as gym
from gymnasium import spaces

class MyDummyEnvironment(gym.Env):
    """
    Dummy environment with 2 grids in the observation space, and 4 actions. The goal is to have both central 
    value on the grid being equal to 5 in a given observation. Stupid, I know, but doesn't matter.
    """
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, render_mode=None):
        super(MyDummyEnvironment, self).__init__()
        self.action_space = spaces.Discrete(4)  # Two possible actions: 0 or 1
        self.observation_space = spaces.Dict(
            {"my_first_frame": spaces.Box(-10, 10, shape=(30,30), dtype=int),
             "my_second_frame": spaces.Box(-10, 10, shape=(30,30), dtype=int)})
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

    def reset(self, seed = None, options = None):
        self.my_first_frame = np.zeros((30,30), dtype=int)
        self.my_second_frame = np.zeros((30,30), dtype=int)
        return self.get_obs(), self.get_info()

    def step(self, action): 
        if action == 0:
            self.my_first_frame[14,14] += 1
        elif action == 1:
            self.my_second_frame[14,14] += 1
        elif action == 2:
            self.my_first_frame[14,14] -= 1
        elif action == 3:
            self.my_second_frame[14,14] -= 1
        reward = 0
        done = False
        if self.my_first_frame[14,14] == 5 and self.my_second_frame[14,14] == 5:
            reward = 10
            done = True
        if np.abs(self.my_first_frame[14,14]) >= 10 or np.abs(self.my_second_frame[14,14] == 10):
            reward = -10
            done = True
        return self.get_obs(), reward, done, False, self.get_info()

    def get_obs(self):
        return {"my_first_frame": self.my_first_frame, "my_second_frame": self.my_second_frame}

    def get_info(self):
        return {}

    def render(self):
        print("Not much to render here")

Custom Feature Extractor

import torch as th
from torch import nn
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class CustomCombinedExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.Dict, features_dim):
        # We do not know features-dim here before going over all the items,
        # so put something dummy for now. PyTorch requires calling
        # nn.Module.__init__ before adding modules
        super().__init__(observation_space, features_dim=features_dim)
        extractors = {}
        total_concat_size = 0
        # We need to know size of the output of this extractor,
        # so go over all the spaces and compute output feature sizes
        for key, subspace in observation_space.spaces.items():
            if key == "my_first_frame":
                # We will just downsample one channel of the image by 4x4 and flatten.
                # Assume the image is single-channel (subspace.shape[0] == 0)
                extractors[key] = nn.Sequential(nn.MaxPool2d(4), nn.Flatten())
                total_concat_size += subspace.shape[0] // 4 * subspace.shape[1] // 4
            elif key == "my_second_frame":
                # Run through a simple MLP
                extractors[key] = nn.Sequential(nn.MaxPool2d(4), nn.Flatten())
                total_concat_size += 16

        self.extractors = nn.ModuleDict(extractors)
        # Update the features dim manually
        self._features_dim = total_concat_size

    def forward(self, observations) -> th.Tensor:
        encoded_tensor_list = []
        # self.extractors contain nn.Modules that do all the processing.
        for key, extractor in self.extractors.items():
            encoded_tensor_list.append(extractor(observations[key]))
        # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension.
        return th.cat(encoded_tensor_list, dim=1)

Code snippet to instantiate (and check) the env and start the model.

from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
import os

from stable_baselines3.common.env_checker import check_env
env = MyDummyEnvironment(render_mode = 'rgb_array')
check_env(env)

log_path = os.path.join('RL training') 

env = Monitor(env, log_path)
env = DummyVecEnv([lambda: env])

policy_kwargs = dict(
    features_extractor_class=CustomCombinedExtractor,
    features_extractor_kwargs=dict(features_dim=68),
)
model = PPO("MultiInputPolicy", env, batch_size = 64, policy_kwargs=policy_kwargs, verbose=1)
model.learn(1000)

System Info