Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

EvalCallback crashes Maskable PPO without error #226

Closed: icheered closed this issue 5 months ago.

icheered commented 5 months ago

🐛 Bug

I'm still investigating, but EvalCallback seems to crash MaskablePPO. Whenever I add EvalCallback(eval_env=env, eval_freq=10) as a callback to model.learn, training works until the 10th step, after which the environment's action_masks() method is no longer called.

Code example


import numpy as np
from gymnasium import Env
from gymnasium.spaces import Box, Discrete
from sb3_contrib import MaskablePPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor

class GridEnv(Env):
    def __init__(self, max_episode_steps=1000):
        super().__init__()
        self.gridsize = 10

        # 0 = north, 1 = east, 2 = south, 3 = west
        self.action_space = Discrete(4)
        self.observation_space = Box(low=0, high=255, shape=(1, self.gridsize, self.gridsize), dtype=np.uint8)

        # Reset to initialize the state
        self.max_episode_steps = max_episode_steps
        self.episode_step = 0
        self.reset()

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)

        # Create a square grid of ones (stand-in for the original
        # src.grid.utils.generate_ones_grid helper, assumed to do the same)
        self.grid = np.ones((self.gridsize, self.gridsize), dtype=np.uint8)

        # Add current position in the center
        self.grid[self.gridsize//2, self.gridsize//2] = 3
        self.episode_step = 0

        # Add a channel axis and cast to uint8 to match observation_space
        obs = np.expand_dims(self.grid.astype(np.uint8), axis=0)  # Shape becomes [1, 10, 10], dtype uint8
        info = {}
        return obs, info

    def step(self, action):

        # Increment the simulation duration
        self.episode_step += 1
        print(f"Received action: {action}. On step: {self.episode_step}")

        # Apply movement
        self._update_state(action)
        done = False

        # Check if won or lose
        if np.sum(self.grid == 1) == 0:
            done = True
            reward = (self.max_episode_steps - self.episode_step)*3 + 100
        elif self.episode_step >= self.max_episode_steps:
            # Lost, took too long to cover the grid. Reward is -1 for every missed spot
            done = True
            reward = -1 * np.sum(self.grid == 1)
        else:
            done = False
            reward = 1 # One point for surviving

        # Score is number of covered spots
        info = {"episode_step": self.episode_step, "score": np.sum(self.grid == 2), "done": True} if done else {}
        truncated = False

        # Convert state to numpy array
        obs = np.expand_dims(self.grid.astype(np.uint8), axis=0)
        return obs, reward, done, truncated, info

    def action_masks(self):
        # Get the current position of the head, position of the grid where value is 3
        head_position = np.argwhere(self.grid == 3)[0]

        # Check whether each possible move would leave the grid
        # (rows are axis 0, columns are axis 1)
        up_border = head_position[0] == 0
        right_border = head_position[1] == self.grid.shape[1] - 1
        down_border = head_position[0] == self.grid.shape[0] - 1
        left_border = head_position[1] == 0

        borders = [not up_border, not right_border, not down_border, not left_border]
        print(f"Action mask: {borders}")
        return borders

    def _update_state(self, action: np.int64):
        # Get the current position of the head, position of the grid where value is 3
        head_position = np.argwhere(self.grid == 3)[0]

        # Movement offsets (row, column), indexed by action
        movement = [(-1, 0), (0, 1), (1, 0), (0, -1)]
        new_head_position = head_position + movement[action]

        # Check if the new head position is out of bounds
        if new_head_position[0] < 0 or new_head_position[0] >= self.grid.shape[0] or new_head_position[1] < 0 or new_head_position[1] >= self.grid.shape[1]:
            raise ValueError("New head position is out of bounds")

        # Update the grid. Set the previous head to 2, and the new head to 3
        self.grid[head_position[0], head_position[1]] = 2
        self.grid[new_head_position[0], new_head_position[1]] = 3

        return 1

env = Monitor(GridEnv(max_episode_steps=200))
check_env(env, warn=True)  # Check if the environment is valid

eval_callback = EvalCallback(
    eval_env=env,
    eval_freq=10,
)

model = MaskablePPO("MlpPolicy", env, verbose=2)
model.learn(
    total_timesteps=1_000_000,
    callback=[eval_callback],
)
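The mask logic can be checked on its own, without any training. Below is a minimal sketch, assuming the same action order and movement offsets as GridEnv; the function name `grid_action_mask` is hypothetical. It derives the mask from the same offset table that `_update_state` uses, so the mask and the out-of-bounds check cannot disagree, and it also works for non-square grids.

```python
import numpy as np

# Same action order as GridEnv: 0 = north, 1 = east, 2 = south, 3 = west
MOVES = np.array([(-1, 0), (0, 1), (1, 0), (0, -1)])


def grid_action_mask(head, shape):
    """Return a bool array: True where moving from `head` stays inside a grid of `shape`."""
    targets = np.asarray(head) + MOVES  # candidate (row, col) positions, shape (4, 2)
    rows_ok = (targets[:, 0] >= 0) & (targets[:, 0] < shape[0])
    cols_ok = (targets[:, 1] >= 0) & (targets[:, 1] < shape[1])
    return rows_ok & cols_ok


print(grid_action_mask((0, 0), (10, 10)))  # top-left corner: only east and south allowed
print(grid_action_mask((0, 4), (3, 5)))    # top-right of a non-square grid: south and west
```

Returning a NumPy bool array rather than a Python list is a choice, not a requirement; the point is that a corner case like the one in the log can be unit-tested directly.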

Relevant log output / Error message

Action mask: [True, True, True, True]
Received action: 2. On step: 1
Action mask: [True, True, True, True]
Received action: 0. On step: 2
Action mask: [True, True, True, True]
Received action: 1. On step: 3
Action mask: [True, True, True, True]
Received action: 3. On step: 4
Action mask: [True, True, True, True]
Received action: 2. On step: 5
Action mask: [True, True, True, True]
Received action: 2. On step: 6
Action mask: [True, True, True, True]
Received action: 3. On step: 7
Action mask: [True, True, True, True]
Received action: 2. On step: 8
Action mask: [True, True, True, True]
Received action: 1. On step: 9
Action mask: [True, True, True, True]
Received action: 1. On step: 10
Received action: 3. On step: 1
Received action: 3. On step: 2
Received action: 3. On step: 3
Received action: 3. On step: 4
Received action: 3. On step: 5
Received action: 3. On step: 6
Traceback (most recent call last):
  File "../../min_working_reproduction.py", line 112, in <module>
    model.learn(
  File "../../.venv/lib/python3.10/site-packages/sb3_contrib/ppo_mask/ppo_mask.py", line 542, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, self.n_steps, use_masking)
  File "../../.venv/lib/python3.10/site-packages/sb3_contrib/ppo_mask/ppo_mask.py", line 318, in collect_rollouts
    if not callback.on_step():
  File "../../.venv/lib/python3.10/site-packages/stable_baselines3/common/callbacks.py", line 114, in on_step
    return self._on_step()
  File "../../.venv/lib/python3.10/site-packages/stable_baselines3/common/callbacks.py", line 221, in _on_step
    continue_training = callback.on_step() and continue_training
  File "../../.venv/lib/python3.10/site-packages/stable_baselines3/common/callbacks.py", line 114, in on_step
    return self._on_step()
  File "../../.venv/lib/python3.10/site-packages/stable_baselines3/common/callbacks.py", line 462, in _on_step
    episode_rewards, episode_lengths = evaluate_policy(
  File "../../.venv/lib/python3.10/site-packages/stable_baselines3/common/evaluation.py", line 94, in evaluate_policy
    new_observations, rewards, dones, infos = env.step(actions)
  File "../../.venv/lib/python3.10/site-packages/stable_baselines3/common/vec_env/base_vec_env.py", line 206, in step
    return self.step_wait()
  File "../../.venv/lib/python3.10/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 58, in step_wait
    obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step(
  File "../../.venv/lib/python3.10/site-packages/stable_baselines3/common/monitor.py", line 94, in step
    observation, reward, terminated, truncated, info = self.env.step(action)
  File "../../min_working_reproduction.py", line 45, in step
    self._update_state(action)
  File "../../min_working_reproduction.py", line 93, in _update_state
    raise ValueError("New head position is out of bounds")
ValueError: New head position is out of bounds
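The log shows the failure mode. Once evaluation starts after step 10, "Action mask" is no longer printed and the policy keeps choosing action 3 (west) until it walks off the grid. As the traceback shows, plain EvalCallback goes through stable_baselines3.common.evaluation.evaluate_policy, which calls the model's predict without asking the environment for a mask, so a deterministic policy simply takes the highest-scoring action, valid or not. The NumPy illustration below uses made-up scores; the helper `masked_greedy_action` is hypothetical and shows the idea of masking, not sb3-contrib's internal code.

```python
import numpy as np


def masked_greedy_action(logits, mask):
    """Pick the highest-scoring action among the valid ones (mask == True)."""
    logits = np.asarray(logits, dtype=np.float64)
    mask = np.asarray(mask, dtype=bool)
    if not mask.any():
        raise ValueError("at least one action must be valid")
    # Invalid actions get -inf so argmax can never choose them
    return int(np.argmax(np.where(mask, logits, -np.inf)))


# Hypothetical policy scores for [north, east, south, west]; west (3) scores highest
logits = [0.1, -0.3, 0.2, 1.5]
# Head on the west border of the grid: moving west is invalid
mask = [True, True, True, False]

unmasked = int(np.argmax(logits))            # what an unmasked deterministic evaluation picks
masked = masked_greedy_action(logits, mask)  # best valid action
print(unmasked, masked)  # 3 2
```

Training itself is unaffected because MaskablePPO's collect_rollouts does fetch the masks (the "Action mask" lines above), which is why the crash only appears once evaluation begins.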


icheered commented 5 months ago

Thanks to https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/178, I found MaskableEvalCallback (from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback). Replacing my EvalCallback with MaskableEvalCallback fixed my issue. However, this was nowhere in the docs; perhaps it should be added.

araffin commented 5 months ago

However, this was nowhere in the docs; perhaps it should be added.

Yes, it is only shown in the example. I would appreciate a PR that updates the doc and adds a note/warning about that ;)

icheered commented 5 months ago

Created a PR: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/227