DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] recommended way to convert gymnasium/gym v26 envs to SB3 VecEnvs #1356

elliottower closed this issue 1 year ago

elliottower commented 1 year ago

❓ Question

Posting this here to avoid spamming the Gymnasium integration PR (#1327), since as far as I know this is a use-case question rather than an issue with the PR. I will edit with example code to make things clearer, but mainly I want to know the best practice for converting envs whose step() returns terminated and truncated booleans into SB3's API, which uses a single done signal.

I would like to make vector envs, but I run into issues due to the differing number of return values (5 vs. 4). My initial thought was to ignore truncation and set done equal to termination, but from reading discussions and documentation it seems best to set done when either truncated or terminated is set. PR comments here also suggest using a TimeLimit wrapper to capture the truncation signal. Is the following then the best practice?

done                        = terminated or truncated
info["TimeLimit.truncated"] = not terminated and truncated

Example code of wrapping the env with this TimeLimit wrapper and doing this conversion would be greatly appreciated.
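
For concreteness, here is a minimal sketch of the kind of shim I have in mind; the wrapper name is my own invention, and I don't know whether this matches the recommended approach:

import gymnasium

class DoneStepWrapper(gymnasium.Wrapper):
  """Hypothetical shim: collapse (terminated, truncated) into a single done."""

  def step(self, action):
    obs, reward, terminated, truncated, info = self.env.step(action)
    done = terminated or truncated
    # Keep the truncation signal where SB3 looks for it
    info["TimeLimit.truncated"] = truncated and not terminated
    return obs, reward, done, info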

Relevant references:
https://github.com/DLR-RM/stable-baselines3/blob/feat/gymnasium-support/docs/guide/vec_envs.rst
https://github.com/openai/gym/issues/3102#issuecomment-1275909754
https://gymnasium.farama.org/content/migration-guide/
https://github.com/DLR-RM/stable-baselines3/pull/780#discussion_r1116365773

Edit: a bit more context for what my issue was (converting the step function): https://github.com/DLR-RM/stable-baselines3/pull/1327#issuecomment-1451232543

Full code below: sb3_train.py (updating an older training script that used the old PettingZoo/gym API rather than gymnasium):

"""Binary to run Stable Baselines 3 agents on meltingpot substrates."""

import gymnasium
import stable_baselines3
from stable_baselines3.common import callbacks
from stable_baselines3.common import torch_layers
from stable_baselines3.common import vec_env
from rl_zoo3.gym_patches import PatchedTimeLimit
# from sb3_contrib.common import vec_env # only has async env

import supersuit as ss
import torch
from torch import nn
import torch.nn.functional as F

from examples.pettingzoo import utils
from meltingpot.python import substrate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use this with lambda wrapper returning observations only
class CustomCNN(torch_layers.BaseFeaturesExtractor):
  """Class describing a custom feature extractor."""

  def __init__(
      self,
      observation_space: gymnasium.spaces.Box,
      features_dim=128,
      num_frames=6,
      fcnet_hiddens=(1024, 128),
  ):
    """Construct a custom CNN feature extractor.

    Args:
      observation_space: the observation space as a gymnasium.Space
      features_dim: Number of features extracted. This corresponds to the number
        of units in the last layer.
      num_frames: The number of (consecutive) frames to feed into the network.
      fcnet_hiddens: Sizes of hidden layers.
    """
    super().__init__(observation_space, features_dim)
    # We assume CxHxW images (channels first)
    # Re-ordering will be done by preprocessing or a wrapper

    self.conv = nn.Sequential(
        nn.Conv2d(
            num_frames * 3, num_frames * 3, kernel_size=8, stride=4, padding=0),
        nn.ReLU(),  # 18 * 21 * 21
        nn.Conv2d(
            num_frames * 3, num_frames * 6, kernel_size=5, stride=2, padding=0),
        nn.ReLU(),  # 36 * 9 * 9
        nn.Conv2d(
            num_frames * 6, num_frames * 6, kernel_size=3, stride=1, padding=0),
        nn.ReLU(),  # 36 * 7 * 7
        nn.Flatten(),
    )
    flat_out = num_frames * 6 * 7 * 7
    self.fc1 = nn.Linear(in_features=flat_out, out_features=fcnet_hiddens[0])
    self.fc2 = nn.Linear(
        in_features=fcnet_hiddens[0], out_features=fcnet_hiddens[1])

  def forward(self, observations) -> torch.Tensor:
    # SB3's preprocessing has already converted observations to float tensors
    # in [0, 1]; here we only permute from B x H x W x C to B x C x H x W
    observations = observations.permute(0, 3, 1, 2)
    features = self.conv(observations)
    features = F.relu(self.fc1(features))
    features = F.relu(self.fc2(features))
    return features

def main():
  # Config
  substrate_name = "commons_harvest__open"
  player_roles = substrate.get_config(substrate_name).default_player_roles
  env_config = {"substrate": substrate_name, "roles": player_roles}

  env = utils.parallel_env(env_config)
  rollout_len = 1000
  total_timesteps = 2000000
  num_agents = env.max_num_agents

  max_steps = 1000
  max_steps_eval = 1000

  # Training
  num_cpus = 1  # number of cpus
  num_envs = 1  # number of parallel multi-agent environments
  # number of frames to stack together; use >4 to avoid automatic
  # VecTransposeImage
  num_frames = 4
  # output layer of cnn extractor AND shared layer for policy and value
  # functions
  features_dim = 128
  fcnet_hiddens = [1024, 128]  # Two hidden layers for cnn extractor
  ent_coef = 0.001  # entropy coefficient in loss
  # Batch size follows the rllib baseline implementation
  batch_size = rollout_len * num_envs // 2
  lr = 0.0001
  n_epochs = 30
  gae_lambda = 1.0
  gamma = 0.99
  target_kl = 0.01
  grad_clip = 40
  verbose = 3
  model_path = None  # Replace this with a saved model

  env = utils.parallel_env(
      max_cycles=rollout_len,
      env_config=env_config,
  )
  env = ss.observation_lambda_v0(env, lambda x, _: x["RGB"], lambda s: s["RGB"])
  env = ss.dtype_v0(env, "uint8")
  env = ss.pettingzoo_env_to_vec_env_v1(env)
  env = ss.concat_vec_envs_v1(
      env,
      num_vec_envs=num_envs,
      num_cpus=num_cpus,
      base_class="stable_baselines3")
  env = PatchedTimeLimit(env, max_steps)
  env = vec_env.VecTransposeImage(env, True)
  env = vec_env.VecFrameStack(env, num_frames)

  eval_env = utils.parallel_env(
      max_cycles=rollout_len,
      env_config=env_config,
  )
  eval_env = ss.observation_lambda_v0(eval_env, lambda x, _: x["RGB"],
                                      lambda s: s["RGB"])
  eval_env = ss.dtype_v0(eval_env, "uint8")
  eval_env = ss.pettingzoo_env_to_vec_env_v1(eval_env)
  eval_env = ss.concat_vec_envs_v1(
      eval_env,
      num_vec_envs=1,
      num_cpus=1,
      base_class="stable_baselines3")
  eval_env = PatchedTimeLimit(eval_env, max_steps_eval)
  eval_env = vec_env.VecTransposeImage(eval_env, True)
  eval_env = vec_env.VecFrameStack(eval_env, num_frames)
  eval_freq = 100000 // (num_envs * num_agents)

  policy_kwargs = dict(
      features_extractor_class=CustomCNN,
      features_extractor_kwargs=dict(
          features_dim=features_dim,
          num_frames=num_frames,
          fcnet_hiddens=fcnet_hiddens,
      ),
      net_arch=[features_dim],
  )

  tensorboard_log = "./results/sb3/harvest_open_ppo_paramsharing"

  model = stable_baselines3.PPO(
      "CnnPolicy",
      env=env,
      learning_rate=lr,
      n_steps=rollout_len,
      batch_size=batch_size,
      n_epochs=n_epochs,
      gamma=gamma,
      gae_lambda=gae_lambda,
      ent_coef=ent_coef,
      max_grad_norm=grad_clip,
      target_kl=target_kl,
      policy_kwargs=policy_kwargs,
      tensorboard_log=tensorboard_log,
      verbose=verbose,
  )
  if model_path is not None:
    model = stable_baselines3.PPO.load(model_path, env=env)
  eval_callback = callbacks.EvalCallback(
      eval_env, eval_freq=eval_freq, best_model_save_path=tensorboard_log)
  model.learn(total_timesteps=total_timesteps, callback=eval_callback)

  logdir = model.logger.dir
  model.save(logdir + "/model")
  del model
  model = stable_baselines3.PPO.load(logdir + "/model")  # noqa: F841

if __name__ == "__main__":
  main()

Utils helper file (also updating the original script from the old pettingzoo/gym API to gymnasium):

"""PettingZoo interface to meltingpot environments."""

import functools

from gymnasium import utils as gym_utils
import matplotlib.pyplot as plt
from ml_collections import config_dict
from pettingzoo import utils as pettingzoo_utils
from pettingzoo.utils import wrappers

from examples import utils
from meltingpot.python import substrate

PLAYER_STR_FORMAT = 'player_{index}'
MAX_CYCLES = 1000

def parallel_env(env_config, max_cycles=MAX_CYCLES):
  return _ParallelEnv(env_config, max_cycles)

def raw_env(env_config, max_cycles=MAX_CYCLES):
  return pettingzoo_utils.parallel_to_aec_wrapper(
      parallel_env(env_config, max_cycles))

def env(env_config, max_cycles=MAX_CYCLES):
  aec_env = raw_env(env_config, max_cycles)
  aec_env = wrappers.AssertOutOfBoundsWrapper(aec_env)
  aec_env = wrappers.OrderEnforcingWrapper(aec_env)
  return aec_env

class _MeltingPotPettingZooEnv(pettingzoo_utils.ParallelEnv):
  """An adapter between Melting Pot substrates and PettingZoo's ParallelEnv."""

  def __init__(self, env_config, max_cycles):
    self.env_config = config_dict.ConfigDict(env_config)
    self.max_cycles = max_cycles
    self._env = substrate.build(env_config['substrate'], roles=env_config['roles'])
    self._num_players = len(self._env.observation_spec())
    self.possible_agents = [
        PLAYER_STR_FORMAT.format(index=index)
        for index in range(self._num_players)
    ]
    self.agents = [agent for agent in self.possible_agents]
    observation_space = utils.remove_world_observations_from_space(
        utils.spec_to_space(self._env.observation_spec()[0]))
    self.observation_space = functools.lru_cache(
        maxsize=None)(lambda agent_id: observation_space)
    action_space = utils.spec_to_space(self._env.action_spec()[0])
    self.action_space = functools.lru_cache(maxsize=None)(
        lambda agent_id: action_space)
    self.state_space = utils.spec_to_space(
        self._env.observation_spec()[0]['WORLD.RGB'])

  def state(self):
    return self._env.observation()

  def reset(self, seed=None, **kwargs):
    """See base class."""
    timestep = self._env.reset()
    self.agents = self.possible_agents[:]
    self.num_cycles = 0
    # Gymnasium-style parallel API: reset returns (observations, infos)
    observations = utils.timestep_to_observations(timestep)
    return observations, {agent: {} for agent in self.agents}

  def step(self, action):
    """See base class."""
    actions = [action[agent] for agent in self.agents]
    timestep = self._env.step(actions)
    rewards = {
        agent: timestep.reward[index] for index, agent in enumerate(self.agents)
    }
    self.num_cycles += 1
    termination = timestep.last()
    terminations = {agent: termination for agent in self.agents}
    truncation = self.num_cycles >= self.max_cycles
    truncations = {agent: truncation for agent in self.agents}
    infos = {agent: {} for agent in self.agents}
    if termination or truncation:
      # PettingZoo parallel API: agents are removed once the episode ends
      self.agents = []

    observations = utils.timestep_to_observations(timestep)
    return observations, rewards, terminations, truncations, infos

  def close(self):
    """See base class."""
    self._env.close()

  def render(self, mode='human', filename=None):
    rgb_arr = self.state()['WORLD.RGB']
    if mode == 'human':
      plt.cla()
      plt.imshow(rgb_arr, interpolation='nearest')
      if filename is None:
        plt.show(block=False)
      else:
        plt.savefig(filename)
      return None
    return rgb_arr

class _ParallelEnv(_MeltingPotPettingZooEnv, gym_utils.EzPickle):
  metadata = {'render_modes': ['human', 'rgb_array']}

  def __init__(self, env_config, max_cycles):
    gym_utils.EzPickle.__init__(self, env_config, max_cycles)
    _MeltingPotPettingZooEnv.__init__(self, env_config, max_cycles)

Error:

Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1496, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/elliottower/Documents/GitHub/meltingpot/examples/pettingzoo/sb3_train.py", line 197, in <module>
    if __name__ == "__main__":
  File "/Users/elliottower/Documents/GitHub/meltingpot/examples/pettingzoo/sb3_train.py", line 188, in main
    eval_env, eval_freq=eval_freq, best_model_save_path=tensorboard_log)
  File "/Users/elliottower/Documents/GitHub/meltingpot/venv/lib/python3.9/site-packages/stable_baselines3/ppo/ppo.py", line 304, in learn
    return super().learn(
  File "/Users/elliottower/Documents/GitHub/meltingpot/venv/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 246, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/elliottower/Documents/GitHub/meltingpot/venv/lib/python3.9/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 174, in collect_rollouts
    new_obs, rewards, dones, infos = env.step(clipped_actions)
  File "/Users/elliottower/Documents/GitHub/meltingpot/venv/lib/python3.9/site-packages/stable_baselines3/common/vec_env/base_vec_env.py", line 171, in step
    return self.step_wait()
  File "/Users/elliottower/Documents/GitHub/meltingpot/venv/lib/python3.9/site-packages/stable_baselines3/common/vec_env/vec_frame_stack.py", line 33, in step_wait
    observations, rewards, dones, infos = self.venv.step_wait()
  File "/Users/elliottower/Documents/GitHub/meltingpot/venv/lib/python3.9/site-packages/stable_baselines3/common/vec_env/vec_transpose.py", line 95, in step_wait
    observations, rewards, dones, infos = self.venv.step_wait()
ValueError: too many values to unpack (expected 4)

araffin commented 1 year ago

Hello, if you want to know the conversion from the gym API to the VecEnv API, it is here: https://github.com/DLR-RM/stable-baselines3/blob/feat/gymnasium-support/stable_baselines3/common/vec_env/dummy_vec_env.py#L60-L71

Also relevant: https://github.com/DLR-RM/rl-baselines3-zoo/blob/feat/gymnasium-support/rl_zoo3/gym_patches.py#L27-L50

But the gym API is normally for a single env only. If you want to define a VecEnv directly, you can take a look at https://github.com/DLR-RM/rl-baselines3-zoo/pull/355, where we define a VecEnv for envpool envs.
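
The conversion in the linked dummy_vec_env.py is roughly the following (a paraphrased sketch; see the link above for the exact code):

# Paraphrased sketch of the conversion in the linked dummy_vec_env.py
# (gymnasium-support branch); see the link above for the authoritative version.
obs, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
# Only flag truncation when the env did not also terminate
info["TimeLimit.truncated"] = truncated and not terminated
if done:
    # Save the final observation, then reset so stepping can continue
    info["terminal_observation"] = obs
    obs, reset_info = env.reset()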

elliottower commented 1 year ago

Thanks for the links. The patched time limit works perfectly, but the problem is that I can't get it to work with other wrappers like VecTransposeImage and VecFrameStack, which return 4 values with done instead of 5 with terminated and truncated. I'll post some example code here so the use case is easier to understand.

I found that SuperSuit has a ss.concat_vec_envs_v1() wrapper which works with base_class='stable_baselines3' and wraps the env into a vector env, so I think (?) I can use that instead of writing my own vector env wrapper like https://github.com/DLR-RM/rl-baselines3-zoo/pull/355. Although in this case maybe the best option is to write a wrapper which turns the PettingZoo env into vector envs with the right return values to work with VecFrameStack and VecTransposeImage, since you said the done-based API is going to remain the standard for SB3 internally.

araffin commented 1 year ago

I can't get that to work with other wrappers like VecTransposeImage and VecFrameStack,

It looks like you are mixing gym wrappers and VecEnv wrappers; as you noticed, they don't work together.

a wrapper which turns the pettingzoo env into vector envs that have the right return types to work with VecFrameStack and VecTransposeImage,

Yes, probably the best option.

elliottower commented 1 year ago

Just as an update: I got this working by modifying ss.sb3_vector_wrapper (used in ss.concat_vec_envs_v1; I opened a PR for it). But I was thinking it would probably make the most sense if there were support for creating gymnasium/pettingzoo vector envs directly with stable-baselines3 (@araffin). As said earlier in this thread, the differing number of return values for step() prevents the existing SB3 vec env functions from working.
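
The kind of change involved is roughly a step_wait() that collapses the gymnasium-style 5-tuple into the 4-tuple SB3's VecEnv wrappers expect. This is a hypothetical sketch, not the actual SuperSuit patch:

import numpy as np

# Hypothetical sketch (not the actual SuperSuit patch): collapse the
# gymnasium-style 5-tuple into the (obs, rewards, dones, infos) 4-tuple
# that SB3's VecEnv wrappers expect.
def step_wait(self):
    obs, rewards, terminations, truncations, infos = self.venv.step_wait()
    dones = np.logical_or(terminations, truncations)
    for idx, info in enumerate(infos):
        # Preserve truncation through the info dict, not a 5th return value
        info["TimeLimit.truncated"] = bool(truncations[idx] and not terminations[idx])
    return obs, rewards, dones, infos

The key point is that the truncation signal survives via the info dicts rather than as a separate return value, which is what SB3's done-based internals rely on.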

Working code using ss:

env = utils.parallel_env(render_mode="rgb_array", env_config=env_config, max_cycles=rollout_len)  # load from meltingpot into a PettingZoo env
env = ss.observation_lambda_v0(env, lambda x, _: x["RGB"], lambda s: s["RGB"])
env = ss.pettingzoo_env_to_vec_env_v1(env)
env = ss.concat_vec_envs_v1(
    env,
    num_vec_envs=num_envs,
    num_cpus=num_cpus,
    base_class="stable_baselines3")
env = vec_env.VecMonitor(env)
env = vec_env.VecTransposeImage(env, True)
env = vec_env.VecFrameStack(env, num_frames)

araffin commented 1 year ago

I got this working by modifying

Good to hear =)

I was thinking it would probably make the most sense if there was support for creating gymnasium/pettingzoo vector envs directly with stable-baselines3

As explained in its paper/blog post, SB3 is focused on single-agent model-free RL. Support for more should be done in external repositories (like imitation for imitation/offline RL). We also should not add additional dependencies (like PettingZoo or SuperSuit), so I would disagree with that statement.

The only custom VecEnvs we are considering adding now are envpool and Isaac Gym (both will probably be implemented in the zoo, as they don't cover the full VecEnv feature set).

Working code using ss:

Closing as the original question is solved.