ray-project / ray


RLlib: ExternalEnv concurrent episode leaking issue #31965

Closed: AydinGokce closed this issue 1 year ago

AydinGokce commented 1 year ago

What happened + What you expected to happen

I'm using RLlib's ExternalEnv with PPO to drive a physical system, so I am limited to one env with one worker. To enforce this, I build the algorithm (PPO) with num_rollout_workers=1, batch_mode="complete_episodes", and num_envs_per_worker=1.

The Bug

After a couple of episodes, I consistently get AssertionError: Too many concurrent episodes, were some leaked? This ExternalEnv was created with max_concurrent=2

The full stack trace:

Traceback (most recent call last):
  File "/home/aydin/Documents/TREC-RLT/scripts/pendulum/pendulum_external_replication.py", line 70, in <module>
    main()
  File "/home/aydin/Documents/TREC-RLT/scripts/pendulum/pendulum_external_replication.py", line 67, in main
    agent.train()
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 367, in train
    raise skipped from exception_cause(skipped)
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 364, in train
    result = self.step()
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 749, in step
    results, train_iter_ctx = self._run_one_training_iteration()
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 2623, in _run_one_training_iteration
    results = self.training_step()
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 318, in training_step
    train_batch = synchronous_parallel_sample(
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/execution/rollout_ops.py", line 85, in synchronous_parallel_sample
    sample_batches = worker_set.foreach_worker(
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 696, in foreach_worker
    handle_remote_call_result_errors(remote_results, self._ignore_worker_failures)
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 73, in handle_remote_call_result_errors
    raise r.get()
ray.exceptions.RayTaskError(AssertionError): ray::RolloutWorker.apply() (pid=195731, ip=192.168.2.49, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x7f0288f98970>)
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/utils/actor_manager.py", line 183, in apply
    raise e
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/utils/actor_manager.py", line 174, in apply
    return func(self, *args, **kwargs)
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/execution/rollout_ops.py", line 86, in <lambda>
    lambda w: w.sample(), local_worker=False, healthy_only=True
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 900, in sample
    batches = [self.input_reader.next()]
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/evaluation/sampler.py", line 92, in next
    batches = [self.get_data()]
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/evaluation/sampler.py", line 285, in get_data
    item = next(self._env_runner)
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/evaluation/sampler.py", line 660, in _env_runner
    unfiltered_obs, rewards, dones, infos, off_policy_actions = base_env.poll()
  File "/home/aydin/anaconda3/envs/trec-rlt/lib/python3.9/site-packages/ray/rllib/env/external_env.py", line 372, in poll
    assert len(results[0]) <= limit, (
AssertionError: Too many concurrent episodes, were some leaked? This ExternalEnv was created with max_concurrent=1

(Note that max_concurrent is set to 2 here because RLlib's polling algorithm checks whether the number of concurrent episodes is < limit, rather than <= limit.)
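
For intuition, here is a minimal standalone sketch, not taken from RLlib's source (names and structure are illustrative only), of the kind of bookkeeping that produces this assertion: if a finished episode is never removed from the tracking dict before the next start_episode(), the concurrent count keeps growing until the limit check fires.

# Illustrative sketch only -- NOT RLlib source. Models episode bookkeeping
# with the strict "< limit" semantics described in the note above.
_episodes = {}  # episode_id -> state; a stale entry here is a "leak"
MAX_CONCURRENT = 1

def start_episode(episode_id):
    # With MAX_CONCURRENT=1, a single un-cleaned-up episode is enough
    # to trip this check on the next start_episode().
    assert len(_episodes) < MAX_CONCURRENT, (
        "Too many concurrent episodes, were some leaked? "
        f"This ExternalEnv was created with max_concurrent={MAX_CONCURRENT}"
    )
    _episodes[episode_id] = "RUNNING"

def end_episode(episode_id):
    # If this removal never happens (or happens too late), the entry
    # leaks and the assertion above eventually fires.
    _episodes.pop(episode_id, None)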

Versions / Dependencies

OS: Ubuntu 20.04
Python: 3.9
Ray: 2.2.0
Full output of conda list is here.

Reproduction script

Here are the env class and the driver script. Place the env class in a file named pendulum_env_replication.py, and run the driver script from the same directory.

Driver Script


from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.external_env import ExternalEnv
from pendulum_env_replication import PendulumExternalEnv

def main():

    register_env("pendulum-v0", lambda c: PendulumExternalEnv(c))
    agent = (
        PPOConfig()
        .resources(num_gpus=1)
        .rollouts(num_rollout_workers=1, batch_mode="complete_episodes", num_envs_per_worker=1)
        .environment(env="pendulum-v0")
        .build()
    )       

    while True:
        agent.train()

if __name__ == "__main__":
    main()
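
For context, with batch_mode="complete_episodes" and a single rollout worker, each agent.train() call blocks until the external env reports at least one finished episode; as the traceback above shows, that sampling path (synchronous_parallel_sample -> RolloutWorker.sample -> base_env.poll) is where the assertion fires.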

Env

from ray.rllib.env.external_env import ExternalEnv
import time
import numpy as np
import gym

class PendulumExternalEnv(ExternalEnv):

    def __init__(self, config):

        # action space [pwm]
        current_limit: float = 1
        self.action_space = gym.spaces.Box(low=-current_limit, high=current_limit, shape=(1,), dtype=np.float32)

        # observation space [current angle (rad), target angle (rad), current angular velocity (rad/s)]
        self.angle_limit: float = 0.2
        observation_limit = np.array([self.angle_limit, self.angle_limit, 2 * np.pi])
        self.observation_space = gym.spaces.Box(low=-observation_limit, high=observation_limit, dtype=np.float32)
        self.observation = np.zeros(3, dtype=np.float32)

        ExternalEnv.__init__(self, self.action_space, self.observation_space, max_concurrent=1)

        # environment variables
        self.episode_duration: float = 5
        self.episode_num: int = 0
        self.start_time = time.time()

    def run(self):
        # Runs in its own thread: repeatedly start an episode, poll for
        # actions until done, then end it and immediately start the next.
        while True:
            episode_id: str = self.start_episode()

            while not self._check_done():
                print("loopin")

                action = self.get_action(episode_id, self.observation)[0]
                self.log_returns(episode_id, 0)  # placeholder zero reward

            print("resetting")
            self.start_time = time.time()
            self.end_episode(episode_id, [0, 0, 0])  # FIXME: add actual obs to end_episode

    def _check_done(self):
        if time.time() - self.start_time > self.episode_duration:
            return True

        if abs(self.observation[1]) > self.angle_limit:
            return True

        return False

Issue Severity

High: It blocks me from completing my task.

Rohan138 commented 1 year ago

Thank you for raising the issue! Fixed by #34374