DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Feature Request] Multi-Cluster Support for SubprocVecEnv #1345

Open snowyday opened 1 year ago

snowyday commented 1 year ago

🚀 Feature

I propose enhancing SubprocVecEnv to support multiple clusters. To achieve this, I have created a new class called DistVecEnv that is fully compatible with SubprocVecEnv, as follows:

class DistVecEnv(VecEnv):
    """
    Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
    process, allowing significant speed up when the environment is computationally complex.

    For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
    number of logical cores on your CPU.

    :param env_fns: Environments to run in subprocesses

    Notes
    ------
    `Actor` is a support class for `DistVecEnv` that controls remote environments, similar to `_worker` for
    `SubprocEnv`.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]):
        self.waiting = False
        self.closed = False
        self.ref_steps = []
        self.actors: List[Actor] = [Actor.remote(env=env_fn) for env_fn in env_fns]
        observation_space, action_space = ray.get(self.actors[0].get_spaces.remote())
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def step_async(self, actions: np.ndarray) -> None:
        self.ref_results = [actor.step.remote(action) for actor, action in zip(self.actors, actions)]
        self.waiting = True

    def step_wait(self) -> VecEnvStepReturn:
        results = ray.get(self.ref_results)
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        if seed is None:
            seed = np.random.randint(0, 2**32 - 1)
        return ray.get([actor.seed.remote(seed + idx) for idx, actor in enumerate(self.actors)])

    def reset(self) -> VecEnvObs:
        obs = ray.get([actor.reset.remote() for actor in self.actors])
        return _flatten_obs(obs, self.observation_space)

    def close(self) -> None:
        if self.closed:
            return
        ray.get([actor.close.remote() for actor in self.actors])
        self.closed = True

    def get_images(self) -> Sequence[np.ndarray]:
        return ray.get([actor.render.remote("rgb_array") for actor in self.actors])

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """Return attribute from vectorized environment (see base class)."""
        indices = self._get_indices(indices)
        return ray.get([self.actors[index].get_attr.remote(attr_name) for index in indices])

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """Set attribute inside vectorized environments (see base class)."""
        indices = self._get_indices(indices)
        ray.get([self.actors[index].set_attr.remote(attr_name, value) for index in indices])

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        """Call instance methods of vectorized environments."""
        indices = self._get_indices(indices)
        refs = [self.actors[index].env_method.remote(method_name, (method_args, method_kwargs)) for index in indices]
        return ray.get(refs)

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        """Check if worker environments are wrapped with a given wrapper"""
        indices = self._get_indices(indices)
        return ray.get([self.actors[index].is_wrapped.remote(wrapper_class) for index in indices])

Train a PPO agent on CartPole-v1 using 8 environments.

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DistVecEnv

if __name__ == "__main__":
    env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=DistVecEnv)
    model = PPO("MlpPolicy", env, device="cpu")
    model.learn(total_timesteps=25_000)

To see the implementation of DistVecEnv, please refer to my gist.

Motivation

The current SubprocVecEnv implementation, based on Python multiprocessing, is limited to a single machine and may not be suitable for scaling to larger cluster environments. Utilizing Ray for distributed computing can provide a more scalable solution.

Pitch

This feature request aims to complement the SubprocVecEnv class with a new class called DistVecEnv that uses the Ray library to distribute multiple environments to their own processes, potentially across several nodes.

Alternatives

Supporting MPI would be a complex and challenging task (see MPIVecEnv) and may not provide the scalability and ease of use that Ray offers.

Additional context

Issues related to this feature request have been discussed for several years:

Checklist

araffin commented 1 year ago

Hello, thanks for the proposal, I will try to take a look at it soon (my stack was quite full last week). Could you add the missing imports so I can run the code?

snowyday commented 1 year ago

Hi @araffin, thank you for your response! I've attached the complete code as requested, and I hope it is helpful. The missing imports have been added, and Ray is the only additional library required (installable with `pip install -U "ray[default]"`). My version of Ray is 2.3.0, but I have only used basic functions, so it should work with any Ray version installed via pip.

I tried test_vec_envs.py with VEC_ENV_CLASSES = [DistVecEnv], and it passed. I believe that at least PPO, as in the `__main__` block of the full code below, should work.
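For illustration, a standalone smoke test in the spirit of that parametrization might look like the sketch below. This is not the actual SB3 test file: DistVecEnv is assumed to be importable from a local dist_vec_env.py module, and the old Gym API used throughout this code is assumed.

```python
# Hedged sketch of a smoke test over several vec env classes.
import gym
import numpy as np
import pytest
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

from dist_vec_env import DistVecEnv  # hypothetical local module containing the class above

N_ENVS = 4


@pytest.mark.parametrize("vec_env_cls", [DummyVecEnv, SubprocVecEnv, DistVecEnv])
def test_vec_env_smoke(vec_env_cls):
    vec_env = vec_env_cls([lambda: gym.make("CartPole-v1") for _ in range(N_ENVS)])
    obs = vec_env.reset()
    assert obs.shape == (N_ENVS, 4)  # CartPole observation is 4-dimensional
    for _ in range(10):
        actions = np.array([vec_env.action_space.sample() for _ in range(N_ENVS)])
        obs, rewards, dones, infos = vec_env.step(actions)
        assert obs.shape == (N_ENVS, 4)
        assert rewards.shape == (N_ENVS,)
        assert dones.shape == (N_ENVS,)
        assert len(infos) == N_ENVS
    vec_env.close()
```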

Show code

```python
import ray
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union

import gym
import numpy as np

from stable_baselines3.common.vec_env.base_vec_env import (
    VecEnv,
    VecEnvIndices,
    VecEnvObs,
    VecEnvStepReturn,
)
from stable_baselines3.common.vec_env.subproc_vec_env import _flatten_obs


@ray.remote
class Actor:
    def __init__(self, env: gym.Env) -> None:
        from stable_baselines3.common.env_util import is_wrapped

        self.env = env()
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space

    def get_attr(self, attr_name: str) -> Any:
        return getattr(self.env, attr_name)

    def set_attr(self, attr_name: str, value: Any) -> None:
        setattr(self.env, attr_name, value)

    def get_spaces(self) -> Tuple[Any, Any]:
        return self.observation_space, self.action_space

    def step(self, action: Any) -> Any:
        observation, reward, done, info = self.env.step(action)
        if done:
            # save final observation where user can get it, then reset
            info["terminal_observation"] = observation
            observation = self.env.reset()
        return observation, reward, done, info

    def seed(self, value) -> Union[None, int]:
        return self.env.seed(value)

    def reset(self) -> Any:
        return self.env.reset()

    def close(self) -> None:
        self.env.close()

    def render(self, mode: str) -> np.ndarray:
        return self.env.render(mode=mode)

    def env_method(self, method_name, args) -> Any:
        return getattr(self.env, method_name)(*args[0], **args[1])

    def is_wrapped(self, wrapper_class: Type[gym.Wrapper]) -> bool:
        return is_wrapped(self.env, wrapper_class)


class DistVecEnv(VecEnv):
    """
    Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
    process, allowing significant speed up when the environment is computationally complex.

    For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
    number of logical cores on your CPU.

    :param env_fns: Environments to run in subprocesses

    Notes
    ------
    `Actor` is a support class for `DistVecEnv` that controls remote environments, similar to `_worker` for
    `SubprocEnv`.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]):
        self.waiting = False
        self.closed = False
        self.ref_steps = []
        self.actors: List[Actor] = [Actor.remote(env=env_fn) for env_fn in env_fns]
        observation_space, action_space = ray.get(self.actors[0].get_spaces.remote())
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def step_async(self, actions: np.ndarray) -> None:
        self.ref_results = [actor.step.remote(action) for actor, action in zip(self.actors, actions)]
        self.waiting = True

    def step_wait(self) -> VecEnvStepReturn:
        results = ray.get(self.ref_results)
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        if seed is None:
            seed = np.random.randint(0, 2**32 - 1)
        return ray.get([actor.seed.remote(seed + idx) for idx, actor in enumerate(self.actors)])

    def reset(self) -> VecEnvObs:
        obs = ray.get([actor.reset.remote() for actor in self.actors])
        return _flatten_obs(obs, self.observation_space)

    def close(self) -> None:
        if self.closed:
            return
        ray.get([actor.close.remote() for actor in self.actors])
        self.closed = True

    def get_images(self) -> Sequence[np.ndarray]:
        return ray.get([actor.render.remote("rgb_array") for actor in self.actors])

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """Return attribute from vectorized environment (see base class)."""
        indices = self._get_indices(indices)
        return ray.get([self.actors[index].get_attr.remote(attr_name) for index in indices])

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """Set attribute inside vectorized environments (see base class)."""
        indices = self._get_indices(indices)
        ray.get([self.actors[index].set_attr.remote(attr_name, value) for index in indices])

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        """Call instance methods of vectorized environments."""
        indices = self._get_indices(indices)
        refs = [self.actors[index].env_method.remote(method_name, (method_args, method_kwargs)) for index in indices]
        return ray.get(refs)

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        """Check if worker environments are wrapped with a given wrapper"""
        indices = self._get_indices(indices)
        return ray.get([self.actors[index].is_wrapped.remote(wrapper_class) for index in indices])


if __name__ == "__main__":
    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env

    env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=DistVecEnv)
    model = PPO("MlpPolicy", env, device="cpu")
    model.learn(total_timesteps=25_000)
```

Please feel free to let me know if you have any further questions or concerns. (By the way, just a quick quip, I'm not tied in any way with Ray - if I were, I'd surely use it for RLlib.)
araffin commented 1 year ago

Thanks I could give it a quick try and it does work =) There seems to be a big overhead compared to SubprocVecEnv or DummyVecEnv (both at startup and during data collection), do you know where it comes from? The only use-case that would be interesting with this is the multi-node support of ray, so its place is probably the RL Zoo (https://github.com/DLR-RM/rl-baselines3-zoo).
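One way to narrow this down is to time construction plus first reset and pure stepping separately for each vec env class. The sketch below is illustrative only; the dist_vec_env import path is an assumption, and the absolute numbers will depend on the machine:

```python
# Rough timing harness: construction + first reset vs. pure stepping.
import time

import numpy as np
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

from dist_vec_env import DistVecEnv  # hypothetical local module from this thread


def profile(vec_env_cls, n_envs=6, n_steps=1_000):
    start = time.perf_counter()
    env = make_vec_env("CartPole-v1", n_envs=n_envs, vec_env_cls=vec_env_cls)
    env.reset()
    # includes process/actor creation (and the implicit ray.init for DistVecEnv)
    startup = time.perf_counter() - start

    actions = np.array([env.action_space.sample() for _ in range(n_envs)])
    start = time.perf_counter()
    for _ in range(n_steps):
        env.step(actions)  # vec envs auto-reset on done, no manual resets needed
    stepping = time.perf_counter() - start
    env.close()
    return startup, stepping


if __name__ == "__main__":
    for cls in (DummyVecEnv, SubprocVecEnv, DistVecEnv):
        startup, stepping = profile(cls)
        print(f"{cls.__name__}: startup {startup:.2f}s, 1000 steps in {stepping:.2f}s")
```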

I slightly updated the code and fixed some minor bugs:

EDIT: related but looks more complex: https://github.com/ingambe/RayEnvWrapper/blob/0a4fc91807297db620dec4c763ca48d5c2479fc4/RayEnvWrapper/CustomRayRemoteEnv.py

Show code

```python
import ray
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union

import gym
import numpy as np

from stable_baselines3.common.vec_env.base_vec_env import (
    VecEnv,
    VecEnvIndices,
    VecEnvObs,
    VecEnvStepReturn,
)
from stable_baselines3.common.vec_env.subproc_vec_env import _flatten_obs, SubprocVecEnv


@ray.remote
class RayWorker:
    def __init__(self, env: Callable[[], gym.Env]) -> None:
        self.env = env()
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space

    def get_attr(self, attr_name: str) -> Any:
        return getattr(self.env, attr_name)

    def set_attr(self, attr_name: str, value: Any) -> None:
        setattr(self.env, attr_name, value)

    def get_spaces(self) -> Tuple[Any, Any]:
        return self.observation_space, self.action_space

    def step(self, action: Any) -> Any:
        observation, reward, done, info = self.env.step(action)
        if done:
            # save final observation where user can get it, then reset
            info["terminal_observation"] = observation
            observation = self.env.reset()
        return observation, reward, done, info

    def seed(self, value) -> Union[None, int]:
        return self.env.seed(value)

    def reset(self) -> Any:
        return self.env.reset()

    def close(self) -> None:
        self.env.close()

    def render(self, mode: str) -> np.ndarray:
        return self.env.render(mode=mode)

    def env_method(self, method_name, args) -> Any:
        return getattr(self.env, method_name)(*args[0], **args[1])

    def is_wrapped(self, wrapper_class: Type[gym.Wrapper]) -> bool:
        from stable_baselines3.common.env_util import is_wrapped

        return is_wrapped(self.env, wrapper_class)


class DistVecEnv(VecEnv):
    """
    Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
    process, allowing significant speed up when the environment is computationally complex.

    For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
    number of logical cores on your CPU.

    :param env_fns: Environments to run in subprocesses

    Notes
    ------
    `RayWorker` is a support class for `DistVecEnv` that controls remote environments, similar to `_worker` for
    `SubprocEnv`.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]):
        self.waiting = False
        self.closed = False
        self.ref_steps = []
        self.workers: List[RayWorker] = [RayWorker.remote(env=env_fn) for env_fn in env_fns]
        observation_space, action_space = ray.get(self.workers[0].get_spaces.remote())
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def step_async(self, actions: np.ndarray) -> None:
        self.ref_results = [actor.step.remote(action) for actor, action in zip(self.workers, actions)]
        self.waiting = True

    def step_wait(self) -> VecEnvStepReturn:
        results = ray.get(self.ref_results)
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos

    def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
        if seed is None:
            seed = np.random.randint(0, 2**32 - 1)
        return ray.get([actor.seed.remote(seed + idx) for idx, actor in enumerate(self.workers)])

    def reset(self) -> VecEnvObs:
        obs = ray.get([actor.reset.remote() for actor in self.workers])
        return _flatten_obs(obs, self.observation_space)

    def close(self) -> None:
        if self.closed:
            return
        ray.get([actor.close.remote() for actor in self.workers])
        self.closed = True

    def get_images(self) -> Sequence[np.ndarray]:
        return ray.get([actor.render.remote("rgb_array") for actor in self.workers])

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """Return attribute from vectorized environment (see base class)."""
        indices = self._get_indices(indices)
        return ray.get([self.workers[index].get_attr.remote(attr_name) for index in indices])

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """Set attribute inside vectorized environments (see base class)."""
        indices = self._get_indices(indices)
        ray.get([self.workers[index].set_attr.remote(attr_name, value) for index in indices])

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        """Call instance methods of vectorized environments."""
        indices = self._get_indices(indices)
        refs = [self.workers[index].env_method.remote(method_name, (method_args, method_kwargs)) for index in indices]
        return ray.get(refs)

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        """Check if worker environments are wrapped with a given wrapper"""
        indices = self._get_indices(indices)
        return ray.get([self.workers[index].is_wrapped.remote(wrapper_class) for index in indices])


if __name__ == "__main__":
    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.evaluation import evaluate_policy

    env = make_vec_env("CartPole-v1", n_envs=6, vec_env_cls=DistVecEnv)
    # env = make_vec_env("CartPole-v1", n_envs=6, vec_env_cls=SubprocVecEnv)
    # dummy vec env:
    # env = make_vec_env("CartPole-v1", n_envs=6)
    model = PPO("MlpPolicy", env, seed=0, device="cpu", n_steps=256, n_epochs=1)
    model.learn(total_timesteps=10_000, progress_bar=True)
    print(evaluate_policy(model, model.get_env()))
```
snowyday commented 1 year ago

@araffin Thank you for your feedback! I'm glad the sample code worked well for you, and thanks for fixing the bugs!

There seems to be a big overhead compared to SubprocVecEnv or DummyVecEnv (both at startup and during data collection), do you know where it comes from?

Please note that the current version of Ray (1.5 and above) will automatically call ray.init() on the first use of a Ray remote API, which can cause the overhead you mentioned.
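For example, calling ray.init() explicitly before building the vectorized env keeps that one-time startup cost out of the comparison (a minimal sketch; the dist_vec_env module name is an assumption):

```python
import ray
from stable_baselines3.common.env_util import make_vec_env

from dist_vec_env import DistVecEnv  # hypothetical local module containing the class above

if __name__ == "__main__":
    ray.init()  # pay the cluster startup cost here, once, before any timing
    env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=DistVecEnv)
```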

during data collection

Regarding the overhead during data collection, I have found no significant difference in speed compared to SubprocVecEnv using Multiprocessing in a single-node environment. According to Nishihara's benchmarking, Ray is faster than Multiprocessing. However, please note that the startup method of Ray can affect the speed, as mentioned above.

Before running the training, I recommend allocating one CPU for the learner process plus N CPUs for the actor processes on the Ray worker nodes, and starting Ray with only the N actor CPUs. This should solve the problem. In my environment, I run it with 83 actors (< 28 CPUs x 3 nodes).

In a multi-node environment, the bottleneck will be communication between nodes, so some overhead compared to Multiprocessing is inevitable. It is a trade-off against the number of actors, and it is up to the user to determine which is more efficient for learning.
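To make that setup concrete, here is a rough sketch; the commands, port, and CPU counts below are illustrative assumptions rather than a verified recipe:

```python
# Illustrative multi-node setup (assumption: 3 nodes with 28 cores each,
# one core on the head node kept free for the learner process):
#
#   head node:    ray start --head --num-cpus=27
#   worker nodes: ray start --address=<head-ip>:6379 --num-cpus=28
#
# The training script then joins the existing cluster instead of starting a
# local one, so Ray can spread the environment actors across the nodes.
import ray

if __name__ == "__main__":
    ray.init(address="auto")  # connect to the already running cluster
    # Optionally (not in the original code), each actor could reserve one CPU, e.g.
    # inside DistVecEnv.__init__:
    # self.actors = [Actor.options(num_cpus=1).remote(env=env_fn) for env_fn in env_fns]
```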

The only use-case that would be interesting with this is the multi-node support of ray, so its place is probably the RL Zoo (https://github.com/DLR-RM/rl-baselines3-zoo).

I agree with your observation that the exciting thing about this approach is Ray's multi-node support! Though using Ray can be more complicated than multiprocessing, I hope people who want it will be able to find it in a place like the RL Zoo.

Lastly, I appreciate you pointing out the related work; its motivation is similar to mine. If I had not started from SubprocVecEnv with Ray, I would probably have written it like that. I think the design of SubprocVecEnv is really good!

araffin commented 1 year ago

Please note that the current version of Ray (1.5 and above) will automatically call ray.init() on the first use of a Ray remote API, which can cause the overhead you mentioned.

Yes, I noticed, but I was talking about the runtime overhead after the init (see my code).

Regarding the overhead during data collection, I have found no significant difference in speed compared to SubprocVecEnv using Multiprocessing in a single-node environment.

In a multi-node environment, the bottleneck will be communication between nodes, so some overhead compared to Multiprocessing is inevitable. It is a trade-off against the number of actors, and it is up to the user to determine which is more efficient for learning.

Have you tried my code snippet? I was actually surprised by how much overhead there is (the slowdown was significant); we are probably misusing Ray in some way.

snowyday commented 1 year ago

Thank you for your comment!

I have noticed that in `step_wait`, ray.get is around 10 times slower than recv on an mp.connection.Connection (mp: ~1e-5 s vs. ray: ~1e-4 s per call). This can likely be attributed to object de-serialization.
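For reference, this kind of round-trip comparison can be reproduced outside of SB3 with a small sketch like the one below (absolute numbers will differ per machine; it only illustrates the pipe vs. ray.get latency gap):

```python
# Hedged micro-benchmark sketch: multiprocessing Pipe echo vs. Ray actor call + ray.get.
import time
from multiprocessing import Pipe, Process

import numpy as np
import ray


def pipe_worker(remote):
    # echo loop: receive an array, send it back, stop on None
    while True:
        msg = remote.recv()
        if msg is None:
            break
        remote.send(msg)


@ray.remote
class EchoActor:
    def ping(self, x):
        return x


if __name__ == "__main__":
    n_iters = 1_000
    obs = np.zeros(4, dtype=np.float32)  # CartPole-sized observation

    parent, child = Pipe()
    proc = Process(target=pipe_worker, args=(child,), daemon=True)
    proc.start()
    start = time.perf_counter()
    for _ in range(n_iters):
        parent.send(obs)
        parent.recv()
    print("pipe round-trip:", (time.perf_counter() - start) / n_iters)
    parent.send(None)

    ray.init()
    echo = EchoActor.remote()
    ray.get(echo.ping.remote(obs))  # warm-up call, excludes actor startup
    start = time.perf_counter()
    for _ in range(n_iters):
        ray.get(echo.ping.remote(obs))
    print("ray round-trip:", (time.perf_counter() - start) / n_iters)
```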

The current implementation of ray within the code should not cause any issues.

Ray is best utilized in multi-node setups. A speed difference on the order of 1e-4 seconds per step is not a primary concern there, since the bottleneck will come from inter-node communication. DistVecEnv is particularly well-suited for users who need additional nodes.

george-adams1 commented 1 year ago

@snowyday did this ever make it into main?

snowyday commented 1 year ago

@george-adams1 No, it hasn't been integrated. It seems that the overhead caused by using Ray is not acceptable for some use cases when compared to multiprocessing. If you take a look at my gist, you'll see that it's easy to use, so if it's useful for your work, I would appreciate your comments.

1-Bart-1 commented 3 months ago

I would love to see this as a feature in SB3; it would be really useful for anyone who has access to a cluster. In RL, the simulation speed can often be the bottleneck rather than the neural-network training speed. Anyway, here is an updated version that is compatible with the current subproc_vec_env and Gymnasium.

import ray
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import gymnasium as gym
import numpy as np

from stable_baselines3.common.vec_env.base_vec_env import (
    VecEnv,
    VecEnvIndices,
    VecEnvObs,
    VecEnvStepReturn,
)

from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.vec_env.subproc_vec_env import _flatten_obs

@ray.remote
class Actor:
    def __init__(self, env: Callable[[], gym.Env]) -> None:
        self.env = env()
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.reset_info: Optional[Dict[str, Any]] = {}

    def step(self, action: Any) -> Any:
        observation, reward, terminated, truncated, info = self.env.step(action)
        # convert to SB3 VecEnv api
        done = terminated or truncated
        info["TimeLimit.truncated"] = truncated and not terminated
        if done:
            # save final observation where user can get it, then reset
            info["terminal_observation"] = observation
            observation, self.reset_info = self.env.reset()
        return (observation, reward, done, info, self.reset_info)

    def reset(self, seed, option) -> Any:
        maybe_options = {"options": option} if option else {}
        observation, self.reset_info = self.env.reset(seed=seed, **maybe_options)
        return (observation, self.reset_info)

    def render(self) -> np.ndarray:
        return self.env.render()

    def close(self) -> None:
        self.env.close()

    def get_spaces(self) -> Tuple[Any, Any]:
        return (self.observation_space, self.action_space)

    def env_method(self, method_name, args) -> Any:
        return getattr(self.env, method_name)(*args[0], **args[1])

    def get_attr(self, attr_name: str) -> Any:
        return getattr(self.env, attr_name)

    def set_attr(self, attr_name: str, value: Any) -> None:
        setattr(self.env, attr_name, value)

    def is_wrapped(self, wrapper_class: Type[gym.Wrapper]) -> bool:
        return is_wrapped(self.env, wrapper_class)

class DistVecEnv(VecEnv):
    """
    Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
    process, allowing significant speed up when the environment is computationally complex.

    For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
    number of logical cores on your CPU.

    :param env_fns: Environments to run in subprocesses

    Notes
    ------
    `Actor` is a support class for `DistVecEnv` that controls remote environments, similar to `_worker` for
    `SubprocEnv`.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]]):
        self.waiting = False
        self.closed = False
        self.ref_steps = []
        self.actors: List[Actor] = [Actor.remote(env=env_fn) for env_fn in env_fns]
        observation_space, action_space = ray.get(self.actors[0].get_spaces.remote())
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)

    def step_async(self, actions: np.ndarray) -> None:
        self.ref_results = [actor.step.remote(action) for actor, action in zip(self.actors, actions)]
        self.waiting = True

    def step_wait(self) -> VecEnvStepReturn:
        results = ray.get(self.ref_results)
        self.waiting = False
        obs, rews, dones, infos, self.reset_infos = zip(*results)
        return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos

    def reset(self) -> VecEnvObs:
        self.ref_results = []
        for env_idx, actor in enumerate(self.actors):
            self.ref_results.append(actor.reset.remote(self._seeds[env_idx], self._options[env_idx]))
        results = ray.get(self.ref_results)

        obs, self.reset_infos = zip(*results)
        self._reset_seeds()
        self._reset_options()
        return _flatten_obs(obs, self.observation_space)

    def close(self) -> None:
        if self.closed:
            return
        if self.waiting:
            ray.get(self.ref_results)
        ray.get([actor.close.remote() for actor in self.actors])
        self.closed = True

    def get_images(self) -> Sequence[np.ndarray]:
        # render() takes no mode argument here: with Gymnasium the render mode is fixed at env creation
        return ray.get([actor.render.remote() for actor in self.actors])

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """Return attribute from vectorized environment (see base class)."""
        indices = self._get_indices(indices)
        return ray.get([self.actors[index].get_attr.remote(attr_name) for index in indices])

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """Set attribute inside vectorized environments (see base class)."""
        indices = self._get_indices(indices)
        ray.get([self.actors[index].set_attr.remote(attr_name, value) for index in indices])

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        """Call instance methods of vectorized environments."""
        indices = self._get_indices(indices)
        refs = [self.actors[index].env_method.remote(method_name, (method_args, method_kwargs)) for index in indices]
        return ray.get(refs)

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        """Check if worker environments are wrapped with a given wrapper"""
        indices = self._get_indices(indices)
        return ray.get([self.actors[index].is_wrapped.remote(wrapper_class) for index in indices])
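For completeness, a usage sketch for this updated version (assuming the code above is saved as a local dist_vec_env.py; the hyperparameters are arbitrary):

```python
import ray
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from dist_vec_env import DistVecEnv  # the class above, assumed saved as dist_vec_env.py

if __name__ == "__main__":
    ray.init()  # optional: Ray otherwise initializes itself on first actor creation
    env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=DistVecEnv)
    model = PPO("MlpPolicy", env, device="cpu")
    model.learn(total_timesteps=25_000)
    env.close()
```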