DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

model.learn() does not properly handle gym.spaces.Discrete spaces where start !=0 #1797

Jaage closed this issue 8 months ago

Jaage commented 9 months ago

šŸ› Bug

I have a spaces.Discrete(19, 2) observation as part of my observation space. The documentation for the Discrete space lists the set of possible values as {a, a+1, ..., a+n-1}, so the observation space for this value should be {2, 3, ..., 20}. However, model.learn() does not seem to account for this non-zero start: it errors during one-hot encoding, because torch.nn.functional.one_hot builds a tensor with classes 0 to observation_space.n - 1 (return F.one_hot(obs.long(), num_classes=int(observation_space.n)).float()). Torch therefore creates a tensor of length 19 and errors when asked to one-hot encode the value 20.
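For illustration, the same error can be triggered with torch alone (a minimal standalone sketch using the same numbers as my space):

import torch
import torch.nn.functional as F

# num_classes=19 accepts class indices 0..18 only, so the value 20
# (the top of a {2, ..., 20} space) is rejected.
F.one_hot(torch.tensor(20).long(), num_classes=19)
# RuntimeError: Class values must be smaller than num_classes.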

Am I doing something wrong, is there an automatic way to handle this somewhere that I have missed, or do I need to manually map my Discrete observation to start from 0 before?

Thank you.

Code example

import polars as pl
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.vec_env import DummyVecEnv
import os

class simple_env(gym.Env):
    """Custom Environment that follows gym interface."""

    def __init__(self, df: pl.DataFrame):
        super().__init__()
        self.df = df
        self.cur_row = 0
        self.action_space = spaces.Box(low=-1, high=1, shape=(1,))
        self.observation_space = spaces.Dict({
            "a": spaces.Box(0, 1, (1,), np.float64),
            "b": spaces.Discrete(19, 2)
        })
        self.metadata = {"render_modes": ["human"], "render_fps": 30}
        self.render_mode = "human"

    def step(self, action):
        terminated = False
        info = {}
        reward = float(action[0])
        if self.cur_row == self.df.shape[0]-1:
            self.cur_row = -1
        self.cur_row+=1
        observation = {key:
                       self.df[self.cur_row].select(pl.col(f"{key}")).to_numpy().flatten()[0]
                       if isinstance(value, gym.spaces.Discrete)
                       else self.df[self.cur_row].select(pl.col(f"{key}")).to_numpy().flatten()
                       for key, value in self.observation_space.items()}
        return observation, reward, terminated, False, info

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.cur_row=0
        observation = {key:
                       self.df[self.cur_row].select(pl.col(f"{key}")).to_numpy().flatten()[0]
                       if isinstance(value, gym.spaces.Discrete)
                       else self.df[self.cur_row].select(pl.col(f"{key}")).to_numpy().flatten()
                       for key, value in self.observation_space.items()}
        info = {}
        return observation, info

    def render(self, mode="human"):
        pass

    def close(self):
        pass

def make_env(env_config, rank, seed=0):
    """
    Utility function for multiprocessed env.
    :param env_config: (dict) the dictionary of environment parameters
    :param seed: (int) the initial seed for RNG
    :param rank: (int) index of the subprocess
    """
    def _init():
        env = simple_env(**env_config)
        env.reset(seed=(seed + rank))
        return env
    set_random_seed(seed)
    pl.set_random_seed(rank+seed)
    return _init

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
df = pl.DataFrame({"a": [0.0, 0.12, 0.17, 0.99, 0.86], "b": [7, 8, 5, 15, 20]})
env = simple_env(df)
check_env(env)
env_config = {"df": df}
num_cpu = 1
vec_env = DummyVecEnv([make_env(env_config, i) for i in range(num_cpu)])
model = PPO("MultiInputPolicy", vec_env, device="cpu")
obs = vec_env.reset()
model.learn(10)

Relevant log output / Error message

File "c:test\tests.py", line 231, in <module>
    action, _states = model.predict(obs, deterministic=True)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\base_class.py", line 553, in predict
    return self.policy.predict(observation, state, episode_start, deterministic)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\policies.py", line 366, in predict
    actions = self._predict(obs_tensor, deterministic=deterministic)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\policies.py", line 715, in _predict
    return self.get_distribution(observation).get_actions(deterministic=deterministic)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\policies.py", line 748, in get_distribution
    features = super().extract_features(obs, self.pi_features_extractor)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\policies.py", line 130, in extract_features
    preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\preprocessing.py", line 113, in preprocess_obs
    preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
  File "C:\test\.venv\lib\site-packages\stable_baselines3\common\preprocessing.py", line 125, in preprocess_obs
    return F.one_hot(obs.long(), num_classes=int(observation_space.n)).float()
RuntimeError: Class values must be smaller than num_classes.

Jaage commented 9 months ago

Hi @araffin, if this is a duplicate, can you please provide the link to the original? And what checkboxes need more clarification?

araffin commented 9 months ago

> I have checked that there is no similar issue in the repo

Try harder next time =)

Duplicate of https://github.com/DLR-RM/stable-baselines3/issues/1509, #1295 and #913

but it seems we need to update the env checker to warn when users are using dict obs space.

> Am I doing something wrong, is there an automatic way to handle this somewhere that I have missed, or do I need to manually map my Discrete observation to start from 0 before?

Yes, or you can use a wrapper; see the sketch below.
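A minimal sketch of such a wrapper, assuming a Dict observation space where one Discrete key really does use start != 0 (the class name and the key argument are illustrative, not SB3 API):

import gymnasium as gym
from gymnasium import spaces

class ShiftDiscreteObs(gym.ObservationWrapper):
    """Shift a Discrete sub-space with start != 0 so its values begin at 0."""

    def __init__(self, env, key="b"):
        super().__init__(env)
        self.key = key
        old = env.observation_space[key]
        self.start = old.start
        # Re-declare the sub-space as starting from 0; other keys are unchanged.
        new_spaces = dict(env.observation_space.spaces)
        new_spaces[key] = spaces.Discrete(old.n)
        self.observation_space = spaces.Dict(new_spaces)

    def observation(self, obs):
        # Remap the offending key into the zero-based range SB3 expects.
        obs = dict(obs)
        obs[self.key] = obs[self.key] - self.start
        return obs

# Usage: wrap the env before vectorizing, e.g.
# vec_env = DummyVecEnv([lambda: ShiftDiscreteObs(simple_env(df))])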

araffin commented 8 months ago

> but it seems we need to update the env checker to warn when users are using dict obs space.

The env checker was up to date; you were using the Discrete constructor the wrong way (the second positional argument is seed, not start).
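For reference, a minimal sketch of the difference, based on gymnasium's Discrete(n, seed=None, start=0) signature:

from gymnasium import spaces

# The second positional argument is `seed`, not `start`:
# this space is {0, 1, ..., 18}, seeded with 2.
wrong = spaces.Discrete(19, 2)
assert wrong.start == 0

# Passing `start` by keyword gives the intended {2, 3, ..., 20}.
right = spaces.Discrete(19, start=2)
assert right.start == 2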