araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)
MIT License

Support for MultiDiscrete and MultiBinary action spaces in PPO #30

Closed: jan1854 closed this pull request 7 months ago

jan1854 commented 7 months ago

Description

closes #19

Addresses #19. Adds support for MultiDiscrete and MultiBinary action spaces to PPO.

Constructs a multivariate categorical distribution from TensorFlow Probability's Independent and Categorical distributions. Note that a single Categorical distribution over all action dimensions requires every variable to have the same number of categories. Therefore, I pad the logits of each action dimension to the largest number of categories across all dimensions, using -inf as the padding value so that these invalid actions have zero probability.
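A minimal NumPy sketch of the padding idea. It assumes the policy head outputs one flat logit vector of length sum(nvec) per sample, ordered by action dimension; that layout and the helper names (`pad_logits`, `log_prob`, `entropy`, `sample`) are hypothetical, not taken from sbx. The joint log-probability is the sum of per-dimension log-probabilities, which is what wrapping a Categorical in Independent computes over the reinterpreted dimension.

```python
import numpy as np


def pad_logits(flat_logits: np.ndarray, nvec: np.ndarray) -> np.ndarray:
    """Reshape (batch, sum(nvec)) logits to (batch, len(nvec), max(nvec)), padding with -inf."""
    batch = flat_logits.shape[0]
    padded = np.full((batch, len(nvec), int(np.max(nvec))), -np.inf)
    # Offset of each action dimension inside the flat logit vector (assumed layout)
    starts = np.concatenate([[0], np.cumsum(nvec)[:-1]])
    for dim, (start, n) in enumerate(zip(starts, nvec)):
        padded[:, dim, :n] = flat_logits[:, start:start + n]
    return padded


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row max for numerical stability; -inf entries stay -inf
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))


def log_prob(padded_logits: np.ndarray, actions: np.ndarray) -> np.ndarray:
    """Joint log-probability of integer actions of shape (batch, n_dims)."""
    logp = log_softmax(padded_logits)
    per_dim = np.take_along_axis(logp, actions[..., None], axis=-1)[..., 0]
    return per_dim.sum(axis=-1)


def entropy(padded_logits: np.ndarray) -> np.ndarray:
    """Sum of per-dimension entropies; padded slots contribute exactly 0."""
    logp = log_softmax(padded_logits)
    p = np.exp(logp)
    # Replace -inf by 0 before multiplying to avoid 0 * -inf = nan
    return -(p * np.where(np.isfinite(logp), logp, 0.0)).sum(axis=(-2, -1))


def sample(padded_logits: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    # Gumbel-max trick: -inf slots can never win the argmax
    gumbel = rng.gumbel(size=padded_logits.shape)
    return np.argmax(padded_logits + gumbel, axis=-1)
```

Handling the padded entries explicitly matters mainly for the entropy term, where a naive `p * log(p)` would turn the zero-probability slots into `nan`.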

MultiBinary is handled as a special case of MultiDiscrete with two choices per categorical variable.
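This reduction is exact: a categorical variable with two categories and logits (a, b) puts probability sigmoid(b - a) on category 1, so it is a Bernoulli variable. A small sketch checking this numerically; `multibinary_nvec` and `categorical2_prob_of_one` are hypothetical names used only for illustration.

```python
import numpy as np


def multibinary_nvec(n: int) -> np.ndarray:
    """MultiBinary(n) viewed as MultiDiscrete: n variables with 2 categories each."""
    return np.full(n, 2, dtype=np.int64)


def categorical2_prob_of_one(logits: np.ndarray) -> np.ndarray:
    """P(x = 1) for a 2-category categorical with logits[..., 0] and logits[..., 1]."""
    z = logits - logits.max(axis=-1, keepdims=True)
    p = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    return p[..., 1]


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))
```

One consequence of this design is that each binary variable gets two logits instead of one; the extra degree of freedom is redundant because only the logit difference affects the probabilities.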

Only one-dimensional action spaces are supported, so using, e.g., MultiDiscrete([[2],[3]]) or MultiBinary([2, 3]) will result in an exception (as in stable-baselines3).
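The one-dimensional restriction amounts to a shape check before the distribution is built. A minimal sketch with hypothetical function names. The exception type and message that sbx and stable-baselines3 actually raise are not reproduced here, and `ValueError` is an assumption.

```python
import numpy as np


def check_multidiscrete(nvec) -> np.ndarray:
    """Accept a MultiDiscrete nvec only if it is one-dimensional, e.g. [15, 17]."""
    nvec = np.asarray(nvec)
    if nvec.ndim != 1:
        raise ValueError(f"Only 1D MultiDiscrete spaces are supported, got nvec shape {nvec.shape}")
    return nvec


def check_multibinary(n) -> np.ndarray:
    """Accept MultiBinary(n) only for a scalar n or a length-1 shape; return the equivalent nvec."""
    shape = (n,) if np.isscalar(n) else tuple(n)
    if len(shape) != 1:
        raise ValueError(f"Only 1D MultiBinary spaces are supported, got shape {shape}")
    return np.full(shape[0], 2)
```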

Testing

I added tests (tests/test_space, similar to the tests in stable-baselines3) that check that learning runs without errors and that the correct exceptions are raised when PPO is used with multi-dimensional MultiDiscrete and MultiBinary action spaces.

To check for issues with learning performance, I compared sbx's PPO against stable-baselines3's PPO on environments with MultiDiscrete and MultiBinary action spaces. Since the classic Gym benchmarks contain no environments with these action spaces, I used a discretized-action version of Reacher and a binary-action version of Acrobot for testing (see the wrappers below).

Test script for MultiDiscrete action spaces:

```python
from datetime import datetime
from typing import Sequence

import gymnasium as gym
import numpy as np

from sbx import PPO


class ActionDiscretizationWrapper(gym.ActionWrapper):
    """Discretizes each dimension of a Box action space into a fixed number of evenly spaced bins."""

    def __init__(self, env: gym.Env, bins: Sequence[int]):
        super().__init__(env)
        assert isinstance(self.env.action_space, gym.spaces.Box)
        self.action_space = gym.spaces.MultiDiscrete(bins)

    def action(self, action: np.ndarray) -> np.ndarray:
        assert np.all(action < self.action_space.nvec)
        # Map bin index i in {0, ..., n - 1} linearly onto [low, high]
        low, high = self.env.action_space.low, self.env.action_space.high
        cont_action = (high - low) * action / (self.action_space.nvec - 1) + low
        return cont_action.astype(self.env.action_space.dtype)


if __name__ == "__main__":
    # 15 bins for the first action dimension, 17 for the second
    env = ActionDiscretizationWrapper(gym.make("Reacher-v4"), [15, 17])

    date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    agent = PPO("MlpPolicy", env, tensorboard_log=f"out/reacher_discrete_{date_time}")
    agent.learn(1000000, progress_bar=True)
```
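The wrapper maps bin index i of a dimension with n bins onto low + (high - low) * i / (n - 1), so bin 0 hits `low`, bin n - 1 hits `high`, and with an odd number of bins the middle bin lands on the midpoint. A standalone check of that formula without creating the environment, using illustrative bounds of [-1, 1] per dimension (assumed here, not read from Reacher) and the 15 and 17 bins from the script; `bins_to_continuous` is a hypothetical helper.

```python
import numpy as np


def bins_to_continuous(action, nvec, low, high):
    """Same linear mapping as ActionDiscretizationWrapper.action, without the env."""
    action, nvec = np.asarray(action), np.asarray(nvec)
    return (high - low) * action / (nvec - 1) + low


low, high = np.array([-1.0, -1.0]), np.array([1.0, 1.0])  # illustrative bounds
nvec = np.array([15, 17])
print(bins_to_continuous([0, 0], nvec, low, high))    # lower bound
print(bins_to_continuous([14, 16], nvec, low, high))  # upper bound
print(bins_to_continuous([7, 8], nvec, low, high))    # middle bins
```

Because the formula divides by n - 1, every dimension needs at least two bins.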

Test script for MultiBinary action spaces:

```python
from datetime import datetime

import gymnasium as gym
import numpy as np

from sbx import PPO


class BinaryAcrobotWrapper(gym.ActionWrapper):
    def __init__(self, env: gym.Env):
        super().__init__(env)
        # Bit 0 applies torque -1 (original action 0), bit 1 applies torque +1 (original action 2).
        # If both bits or neither are set, the torque is 0 (original action 1).
        self.action_space = gym.spaces.MultiBinary(2)

    def action(self, action: np.ndarray) -> int:
        return int(action[1] - action[0] + 1)


if __name__ == "__main__":
    env = BinaryAcrobotWrapper(gym.make("Acrobot-v1"))

    date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    agent = PPO("MlpPolicy", env, tensorboard_log=f"out/binary_acrobot_{date_time}")
    agent.learn(1000000, progress_bar=True)
```
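The bit-to-action mapping in BinaryAcrobotWrapper can be checked exhaustively without creating the environment. This standalone copy of the formula (the function name and the `TORQUE` table are illustrative, with the torque values taken from the wrapper's comments) enumerates all four bit patterns.

```python
import itertools

import numpy as np


def bits_to_acrobot_action(bits: np.ndarray) -> int:
    """Same formula as BinaryAcrobotWrapper.action."""
    return int(bits[1] - bits[0] + 1)


TORQUE = {0: -1, 1: 0, 2: 1}  # Acrobot discrete action -> torque, per the wrapper's comments

for b0, b1 in itertools.product([0, 1], repeat=2):
    bits = np.array([b0, b1], dtype=np.int8)
    action = bits_to_acrobot_action(bits)
    print(f"bits={b0}{b1} -> action {action} (torque {TORQUE[action]:+d})")
```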

Results: sbx's PPO and stable-baselines3's PPO show the same learning performance on both environments.

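[Figure: reacher_discrete_sbx_vs_sb3, sbx vs. stable-baselines3 PPO on the discretized Reacher environment]

[Figure: acrobot_binary_sbx_vs_sb3, sbx vs. stable-baselines3 PPO on the binary-action Acrobot environment]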


Checklist:

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line.

araffin commented 7 months ago

Hello, thanks again for the PR =) I'll try to have a look in the coming days.

Btw, because of your good contributions, would you be interested in becoming an SBX maintainer? (so you won't have to fork the repo to fix a bug or add a feature)

jan1854 commented 7 months ago

Sounds awesome, I'd be happy to become an SBX maintainer :)

araffin commented 7 months ago

For built-in MultiDiscrete environments, I think there are the Atari games? We would need to use the RAM versions at first, though, until SBX supports CNNs.