DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

Support VecEnv for gymnasium.vector.VectorEnv and Brax #1745

Open vyeevani opened 10 months ago

vyeevani commented 10 months ago

🚀 Feature

It would be nice to have a wrapper that takes a gymnasium.vector.VectorEnv and returns a VecEnv.

Motivation

I want to do highly parallelized, hardware-accelerated simulation, which pretty much leaves Isaac or Brax. Brax has a lighter-weight setup and also runs on TPUs. Stable Baselines has well-documented and tested implementations of most of the algorithms I'm interested in, as well as deep integration with the imitation library. I'd like to use both libraries together.

Pitch

Brax currently provides a wrapper for legacy OpenAI Gym vectorized environments. I have a request open to support the Gymnasium vectorized API (pretty much just changing the imports from Gym to Gymnasium). Stable Baselines requires vectorized environments to be implemented against its specific VecEnv specification. As far as I can tell, it's pretty simple to migrate between the Gymnasium vectorized env API and SB3's representation.

I'd like a wrapper class to be provided that implements VecEnv with an underlying gymnasium vectorized env.
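
As a rough illustration of how small the gap is: Gymnasium's vectorized step returns separate terminated and truncated flags, while SB3's VecEnv expects a single dones array and, by convention, a "TimeLimit.truncated" entry in each per-env info dict. A minimal sketch of that one piece, assuming plain Python sequences as input (the helper name merge_done_flags is hypothetical and not part of either library):

```python
from typing import Any, Dict, List, Sequence, Tuple


def merge_done_flags(
    terminated: Sequence[bool], truncated: Sequence[bool]
) -> Tuple[List[bool], List[Dict[str, Any]]]:
    """Collapse Gymnasium-style flags into SB3-style dones plus per-env infos.

    Hypothetical helper for illustration only.
    """
    dones: List[bool] = []
    infos: List[Dict[str, Any]] = []
    for term, trunc in zip(terminated, truncated):
        # SB3 treats an episode as over if it terminated or was truncated.
        dones.append(bool(term) or bool(trunc))
        # Mark pure time-limit endings so the algorithm can still bootstrap.
        infos.append({"TimeLimit.truncated": bool(trunc) and not bool(term)})
    return dones, infos


dones, infos = merge_done_flags([True, False, False], [False, True, False])
```

A full wrapper would also have to handle observations, rewards, seeding, and auto-reset; this only covers the done/truncation mapping.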

Alternatives

Since the public API lets users extend the library, the chief alternative is for users to write this wrapper themselves.

Additional context

No response

vyeevani commented 10 months ago

Where is the duplicate? I searched for it but couldn't find it. I would appreciate a pointer.

araffin commented 10 months ago

Partial duplicate of https://github.com/DLR-RM/stable-baselines3/issues/1568#issuecomment-1600595147 and https://github.com/DLR-RM/stable-baselines3/issues/229

In short: a VecEnvWrapper would indeed be a good idea, but only after gymnasium 1.0 is released and fully tested. Would you be willing to contribute such a wrapper?

Related doc: https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#sb3-with-envpool-or-isaac-gym

Related issues: https://github.com/DLR-RM/stable-baselines3/issues/1712 and https://github.com/DLR-RM/stable-baselines3/issues/772#issuecomment-1048657002

vyeevani commented 9 months ago

I ended up stepping away from the Gymnasium APIs to focus on Brax. Here is what I had:

from typing import Optional, Sequence

from brax.envs.base import PipelineEnv, State, Wrapper
from brax.io import image
from gymnasium import spaces
from gymnasium.vector import utils
import jax
import jax.numpy as jp
import numpy as np

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices

class SB3Wrapper(VecEnv):
  def __init__(self,
               env: PipelineEnv,
               seed: int = 0,
               info_keys: Optional[Sequence[str]] = None,
               backend: Optional[str] = None):
    self._env = env
    self.info_keys = info_keys
    self.metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 1 / self._env.dt
    }
    if not hasattr(self._env, 'batch_size'):
      raise ValueError('underlying env must be batched')

    if not hasattr(self._env, 'episode_length'):
      raise ValueError('underlying env must be wrapped with an episode wrapper')

    obs = np.inf * np.ones(self._env.observation_size, dtype='float32')
    obs_space = spaces.Box(-obs, obs, dtype='float32')

    action = jax.tree_util.tree_map(np.array, self._env.sys.actuator.ctrl_range)
    action_space = spaces.Box(action[:, 0], action[:, 1], dtype='float32')

    self.num_envs = self._env.batch_size
    self.observation_space = obs_space
    self.action_space = action_space

    # self.batch_observation_space = utils.batch_space(obs_space, self.num_envs)
    # self.batch_action_space = utils.batch_space(action_space, self.num_envs)

    self.seed(seed)
    self.backend = backend
    self._state = None

    def reset(key):
      key1, key2 = jax.random.split(key)
      state = self._env.reset(key2)
      return state, state.obs, key1

    self._reset = jax.jit(reset, backend=self.backend)

    def step(state, action):
      state = self._env.step(state, action)
      info = {**state.metrics, **state.info}
      return state, state.obs, state.reward, state.done, state.info['truncation'], info

    self._step = jax.jit(step, backend=self.backend)

  def reset(self, **kwargs):
    self._state, obs, self._key = self._reset(self._key)
    return np.array(obs)

  def step_async(self, action):
    self.action = jp.array(action)

  def step_wait(self):
    self._state, obs, reward, done, truncation, info = self._step(self._state, self.action)
    def batch_dict_to_list_dict(batched_dict, keys_to_process):
      return [{} for i in range(self.num_envs)]
      # if keys_to_process is None:
        # return [{} for i in range(self.num_envs)]
      # # Filter the dictionary to only include specified keys that are JAX arrays
      # filtered_dict = {key: batched_dict[key] for key in keys_to_process if key in batched_dict and isinstance(batched_dict[key], jp.ndarray)}

      # # Find the batch size from the first item in the filtered dictionary
      # batch_size = filtered_dict[next(iter(filtered_dict))].shape[0] if filtered_dict else 0

      # # Create a list of dictionaries for each batch index
      # return [{key: filtered_dict[key][i] for key in filtered_dict} for i in range(batch_size)]
    info = batch_dict_to_list_dict(info, self.info_keys)
    # print(reward)
    return np.array(obs), np.array(reward), np.array(done), info

  def seed(self, seed: int = 0):
    self._key = jax.random.PRNGKey(seed)

  def render(self, mode='human'):
    if mode == 'rgb_array':
      sys, state = self._env.sys, self._state
      if state is None:
        raise RuntimeError('must call reset or step before rendering')
      return image.render_array(sys, state.pipeline_state.take(0), 256, 256)
    else:
      return super().render(mode=mode)  # just raise an exception

  def close(self):
    pass

  def env_is_wrapped(self, wrapper_class, indices=None):
    return [False] * self.num_envs

  def step(self, actions):
    self.step_async(actions)
    return self.step_wait()

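  def get_attr(self, attr_name, indices=None):
    # One batched Brax env backs every index, so repeat the shared value.
    return [getattr(self, attr_name)] * self.num_envs

  def set_attr(self, attr_name, value, indices=None):
    setattr(self, attr_name, value)

  def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
    # Unpack the arguments instead of passing the tuple and dict as two positionals.
    result = getattr(self, method_name)(*method_args, **method_kwargs)
    return [result] * self.num_envs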

class AutoResetWrapper2(Wrapper):
  """Automatically resets Brax envs that are done."""
  def reset(self, rng: jax.Array) -> State:
    base_state = self.env.reset(rng)
    info = base_state.info.copy()
    info.update({
        'initial_base_state': base_state,
        'current_base_state': base_state
    })

    return State(
        pipeline_state=base_state.pipeline_state,
        obs=base_state.obs,
        reward=base_state.reward,
        done=base_state.done,
        metrics=base_state.metrics,
        info=info
    )

  def step(self, state: State, action: jax.Array) -> State:
    initial_base_state = state.info['initial_base_state']
    current_base_state = state.info['current_base_state']
    next_base_state = self.env.step(current_base_state, action)

    done = next_base_state.done
    def where_done(x, y):
      return jp.where(done, x, y)

    info = jax.tree_util.tree_map(where_done, initial_base_state.info, next_base_state.info).copy()
    info.update({
        'initial_base_state': initial_base_state,
        'current_base_state': jax.tree_util.tree_map(where_done, initial_base_state, next_base_state),
    })

    return State(
        pipeline_state=jax.tree_util.tree_map(where_done, initial_base_state.pipeline_state, next_base_state.pipeline_state),
        obs=jax.tree_util.tree_map(where_done, initial_base_state.obs, next_base_state.obs),
        reward=jax.tree_util.tree_map(where_done, initial_base_state.reward, next_base_state.reward),
        done=next_base_state.done,
        metrics=jax.tree_util.tree_map(where_done, initial_base_state.metrics, next_base_state.metrics),
        info=info
    )

from brax.envs.wrappers.training import VmapWrapper, EpisodeWrapper, AutoResetWrapper
from brax.envs.ant import Ant
from brax.envs.humanoid import Humanoid

episode_length = 1000
backend = 'spring'
batch_size = 1024
action_repeat = 1

env = Ant(backend=backend)
env = EpisodeWrapper(env, episode_length, action_repeat=action_repeat)
env = AutoResetWrapper2(env)
env = VmapWrapper(env, batch_size)

vyeevani commented 9 months ago

^ This is really hacky and there's a lot that's terrible about it. It's a high-level sketch of everything that would be needed to get this to work.
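
One of the stubbed-out pieces is the info conversion in step_wait: SB3 expects a list with one info dict per environment, while Brax returns one dict of batched arrays. A minimal sketch of the conversion the commented-out code hints at, assuming every selected value is indexable along its first axis (this standalone signature with num_envs is an assumption for illustration, not the final wrapper API):

```python
from typing import Any, Dict, List, Mapping, Optional, Sequence


def batch_dict_to_list_dict(
    batched: Mapping[str, Sequence[Any]],
    keys: Optional[Sequence[str]],
    num_envs: int,
) -> List[Dict[str, Any]]:
    """Split a dict of batched values into one dict per environment."""
    if not keys:
        # No keys requested: return empty infos, as the stub does.
        return [{} for _ in range(num_envs)]
    # Keep only the requested keys that are actually present.
    selected = {key: batched[key] for key in keys if key in batched}
    # Index every selected value along its leading (batch) axis.
    return [{key: value[i] for key, value in selected.items()} for i in range(num_envs)]


infos = batch_dict_to_list_dict(
    {"steps": [3, 7], "truncation": [0, 1], "unused": [9, 9]},
    keys=["steps", "truncation"],
    num_envs=2,
)
```

With JAX arrays it is probably worth converting each value to NumPy once (for example with np.asarray) before indexing per environment, since indexing device arrays one element at a time tends to be slow.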

araffin commented 9 months ago

Hello, thanks for providing the code =) Do you need any help to get it to work? I would be happy to link it in our doc (and maybe integrate it in the zoo or sb3 contrib) as it should be similar to envpool/isaac gym: https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#sb3-with-envpool-or-isaac-gym
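
For context on what getting it to work likely involves: SB3 algorithms expect a VecEnv to reset finished sub-environments automatically and to store the last observation of the finished episode under info["terminal_observation"]. The AutoResetWrapper2 above swaps in the initial observation on done, so the final observation would need to be kept separately. A minimal sketch of that bookkeeping, assuming both the pre-reset and the reset observations are available per env (apply_auto_reset is a hypothetical helper name, not an SB3 or Brax function):

```python
from typing import Any, Dict, List, Sequence, Tuple


def apply_auto_reset(
    step_obs: Sequence[Any],
    reset_obs: Sequence[Any],
    dones: Sequence[bool],
) -> Tuple[List[Any], List[Dict[str, Any]]]:
    """Return SB3-style observations and infos after an auto-reset."""
    obs_out: List[Any] = []
    infos: List[Dict[str, Any]] = []
    for final, fresh, done in zip(step_obs, reset_obs, dones):
        if done:
            # Keep the true last observation for value bootstrapping.
            infos.append({"terminal_observation": final})
            # The agent sees the first observation of the next episode.
            obs_out.append(fresh)
        else:
            infos.append({})
            obs_out.append(final)
    return obs_out, infos


obs, infos = apply_auto_reset([[1.0], [2.0]], [[0.0], [0.0]], [False, True])
```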

jamesheald commented 1 month ago

@vyeevani did you finalise this into a working version?