araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)

Custom env with FrameStack wrapper causes invalid actions to be passed to `env.step` #17

Closed · capnspacehook closed this issue 11 months ago

capnspacehook commented 11 months ago

🤖 Custom Gym Environment

Describe the bug

When using gymnasium.wrappers.frame_stack.FrameStack with a simple custom env, an exception is raised inside step: the action the algorithm passes in arrives as an array rather than an integer scalar, so it cannot be used as a list index.

Code example

import itertools
from typing import Any, List, Tuple

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers.frame_stack import FrameStack
from sbx import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import DummyVecEnv

class MyEnv(gym.Env):
    def __init__(self) -> None:
        self.actions, self.action_space = self.actionSpace()
        self.observation_space = Box(0, 1, shape=(1,))

        super().__init__()

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, dict]:
        # The action is used as a list index; this is the line that raises.
        chosenAction = self.actions[action]

        return self.obs(), 0.0, False, False, {}

    def reset(
        self, *, seed: int | None = None, options: dict | None = None
    ) -> Tuple[Any, dict]:
        super().reset(seed=seed, options=options)
        return self.obs(), {}

    def obs(self):
        return np.array([0.5], dtype=np.float32)

    def render(self) -> Any | List[Any] | None:
        pass

    def actionSpace(self):
        # 5 base actions plus every unordered pair of distinct base actions:
        # 5 + C(5, 2) = 15 discrete actions in total.
        baseActions = [0, 1, 2, 3, 4]

        totalActionsWithRepeats = list(itertools.permutations(baseActions, 2))
        withoutRepeats = []

        # Keep only one ordering of each pair.
        for combination in totalActionsWithRepeats:
            reversedCombination = combination[::-1]
            if reversedCombination not in withoutRepeats:
                withoutRepeats.append(combination)

        filteredActions = [[action] for action in baseActions] + withoutRepeats

        return filteredActions, Discrete(len(filteredActions))

if __name__ == "__main__":
    env = MyEnv()
    check_env(env)

    env = FrameStack(env, 4)
    env = DummyVecEnv([lambda: env])

    algo = PPO("MlpPolicy", env)
    algo.learn(total_timesteps=1000)
Traceback (most recent call last):
  File "/home/user/sbx_ppo_repro.py", line 61, in <module>
    algo.learn(total_timesteps=1000)
  File "/home/user/jax-venv/lib/python3.10/site-packages/sbx/ppo/ppo.py", line 315, in learn
    return super().learn(
  File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 259, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/home/user/jax-venv/lib/python3.10/site-packages/sbx/common/on_policy_algorithm.py", line 152, in collect_rollouts
    new_obs, rewards, dones, infos = env.step(clipped_actions)
  File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/vec_env/base_vec_env.py", line 197, in step
    return self.step_wait()
  File "/home/user/jax-venv/lib/python3.10/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 58, in step_wait
    obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step(
  File "/home/user/jax-venv/lib/python3.10/site-packages/gymnasium/wrappers/frame_stack.py", line 179, in step
    observation, reward, terminated, truncated, info = self.env.step(action)
  File "/home/user/sbx_ppo_repro.py", line 21, in step
    chosenAction = self.actions[action]
TypeError: only integer scalar arrays can be converted to a scalar index

 System Info

sbx at the latest commit was installed using pip: pip install git+https://github.com/araffin/sbx


araffin commented 11 months ago

Hello, thanks for the bug report. I guess the issue comes from a flatten layer which is not applied in SBX.
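
For context on where the extra dimension comes from, here is a minimal sketch (not from the thread; TinyEnv is a hypothetical stand-in for the MyEnv above). FrameStack prepends the stack axis, so a (1,)-shaped observation becomes a (4, 1) stack; without a flatten layer, the policy network would presumably emit one action per stacked frame instead of a single scalar, which matches the "only integer scalar arrays" error:

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers.frame_stack import FrameStack

class TinyEnv(gym.Env):
    # Hypothetical stand-in for the MyEnv above: (1,)-shaped observation.
    def __init__(self) -> None:
        self.observation_space = Box(0, 1, shape=(1,), dtype=np.float32)
        self.action_space = Discrete(15)

    def step(self, action):
        return np.array([0.5], dtype=np.float32), 0.0, False, False, {}

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed, options=options)
        return np.array([0.5], dtype=np.float32), {}

env = FrameStack(TinyEnv(), 4)
obs, _ = env.reset()
print(np.asarray(obs).shape)  # (4, 1): FrameStack prepends the stack axis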

A quick fix is to use a VecFrameStack instead (it stacks on the last axis instead of the first):

from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

vec_env = DummyVecEnv([lambda: env])
vec_env = VecFrameStack(vec_env, 4)
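
Applied to the reproduction above, the workaround would look roughly like this (a sketch assuming the MyEnv class from the bug report is in scope; VecFrameStack concatenates along the last axis, so the per-env observation stays 1-D with shape (4,) rather than FrameStack's 2-D (4, 1)):

from sbx import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

# MyEnv is the custom env from the bug report above.
vec_env = DummyVecEnv([lambda: MyEnv()])
vec_env = VecFrameStack(vec_env, 4)  # per-env obs shape: (4,), still 1-D

algo = PPO("MlpPolicy", vec_env)
algo.learn(total_timesteps=1000)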

To reproduce with even more minimal code:

from typing import Any, List, Tuple

import gymnasium as gym
from gymnasium.spaces import Box, Discrete
from sbx import PPO

class MyEnv(gym.Env):
    def __init__(self) -> None:
        self.observation_space = Box(0, 1, shape=(2, 1), dtype="float32")
        self.action_space = Discrete(15)

    def step(self, action: Any) -> Tuple[Any, float, bool, bool, dict]:
        return self.observation_space.sample(), 0.0, False, False, {}

    def reset(
        self, *, seed: int | None = None, options: dict | None = None
    ) -> Tuple[Any, dict]:
        super().reset(seed=seed, options=options)
        return self.observation_space.sample(), {}

    def render(self) -> Any | List[Any] | None:
        pass

PPO("MlpPolicy", MyEnv()).learn(total_timesteps=1000)
araffin commented 11 months ago

I've pushed a fix in https://github.com/araffin/sbx/pull/18; you should be able to upgrade to sbx 0.9.0 soon =)