[RLlib] Sampler takes first step before next batch is requested #37893

Open simonsays1980 opened 1 year ago


What happened + What you expected to happen

What happened

Possible workarounds

The dev team may have a different approach to this. For now, I run the different env settings either on different workers or on different sub-envs. This is a bit cumbersome and wastes resources, but it works.
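One way to express this workaround is to pin each setting to a fixed sub-env, so that no setting ever changes between two `sample()` calls. The sketch below is plain Python and only illustrates the assignment logic. `count_by_for` and `SETTINGS` are hypothetical names. The assumption is that an env creator would call it with the worker and sub-env indices that RLlib exposes on its `EnvContext` (`worker_index`, `vector_index`); that wiring is not shown here.

```python
from typing import Sequence

# Hypothetical set of env settings that should all be covered during training.
SETTINGS: Sequence[int] = (1, 100)


def count_by_for(worker_index: int, vector_index: int,
                 num_envs_per_worker: int,
                 settings: Sequence[int] = SETTINGS) -> int:
    """Return the fixed `count_by` value for one sub-env.

    Every (worker_index, vector_index) pair is turned into a single running
    index, and the settings are handed out round-robin over that index, so
    each sub-env keeps the same setting for the whole run.
    """
    global_index = worker_index * num_envs_per_worker + vector_index
    return settings[global_index % len(settings)]


# Example: two rollout workers with two sub-envs each
# (assuming remote workers are numbered 1 and 2).
assignment = {
    (w, v): count_by_for(w, v, num_envs_per_worker=2)
    for w in (1, 2)
    for v in (0, 1)
}
print(assignment)  # {(1, 0): 1, (1, 1): 100, (2, 0): 1, (2, 1): 100}
```

The trade-off is the one mentioned above: every setting permanently occupies its own share of sub-envs, even while it is not needed.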

I also assume that this affects any task setting of environments via the env_task_fn in Algorithm (e.g. for curriculum learning or policy generalization), as the first num_envs_per_worker*num_rollout_workers trajectories will be biased. This might even happen in "self-play" or "league-training" scenarios, where the policies an agent has to play against are changed.

What you expected to happen

That the sampler first returns the batch and takes the first step in the environment only when the next batch is requested. With this behavior, switching env settings between two sample() calls would not produce faulty trajectories.
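The difference between the observed and the expected behavior can be modeled without Ray. The sketch below is a simplified, hypothetical stand-in, not RLlib's sampler: `CounterEnv` mimics `DebugCounterEnvSettable` with `start_at_t=0`, each `sample()` call collects exactly one episode, `step_ahead=True` models the reported behavior (reset plus one step before the batch is returned), and `step_ahead=False` models the expected one.

```python
from typing import List, Optional, Tuple


class CounterEnv:
    """Hypothetical, Ray-free stand-in for DebugCounterEnvSettable."""

    def __init__(self, count_by: int = 1) -> None:
        self.count_by = count_by
        self.i = 0

    def reset(self) -> int:
        self.i = 0
        return self.i

    def step(self) -> Tuple[int, bool]:
        self.i += self.count_by
        # Same termination rule as the repro script (with start_at_t = 0).
        return self.i, self.i >= 15 * self.count_by


class Sampler:
    """Collects one full episode per `sample()` call.

    step_ahead=True: after an episode ends, the env is reset AND stepped
    once before the batch is returned (the reported behavior).
    step_ahead=False: the first step of the next episode is only taken when
    the next batch is requested (the expected behavior).
    """

    def __init__(self, env: CounterEnv, step_ahead: bool) -> None:
        self.env = env
        self.step_ahead = step_ahead
        self.obs = env.reset()
        self.pending: Optional[Tuple[int, bool]] = None

    def sample(self) -> List[int]:
        batch = []  # only the "obs" column, as in the repro script
        while True:
            if self.pending is not None:
                # Reuse the step that was taken ahead of time.
                next_obs, terminated = self.pending
                self.pending = None
            else:
                next_obs, terminated = self.env.step()
            batch.append(self.obs)
            self.obs = next_obs
            if terminated:
                break
        self.obs = self.env.reset()
        if self.step_ahead:
            # This step still uses whatever settings the env has right now.
            self.pending = self.env.step()
        return batch


def second_batch(step_ahead: bool) -> List[int]:
    """Sample once with count_by=1, switch to 100, and sample again."""
    env = CounterEnv(count_by=1)
    sampler = Sampler(env, step_ahead)
    sampler.sample()
    env.count_by = 100  # analogous to env.set_count_by(100) in the repro
    return sampler.sample()


print(second_batch(step_ahead=True)[:3])   # [0, 1, 101]
print(second_batch(step_ahead=False)[:3])  # [0, 100, 200]
```

With `step_ahead=False`, the second batch is `0, 100, ..., 1400` (sum 10500), exactly what a freshly reset `count_by`=100 env produces. With `step_ahead=True`, the batch starts with `0, 1`, is one step longer and sums to 10515, mirroring the faulty first batch described in the reproduction script.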

Use cases

Versions / Dependencies

Ray 2.6.0 and Ray nightly; Python 3.9.12; Fedora 37

Reproduction script

import gymnasium as gym
import numpy as np

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF2Policy
from ray.rllib.examples.env.debug_counter_env import DebugCounterEnv
from ray.rllib.evaluation.worker_set import WorkerSet

from ray.tune import register_env

class DebugCounterEnvSettable(DebugCounterEnv):
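    def __init__(self, config=None):
        config = config or {}
        # Let the parent env set up `i`, `start_at_t` and the action space.
        super().__init__(config)
        # Unbounded obs space, since the counter grows with `count_by`.
        self.observation_space = gym.spaces.Box(0, np.inf, (1,), np.float32)
        self.count_by = int(config.get("count_by", 1))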

    def step(self, action):
        # Here is the change: we step by `count_by`.
        self.i += self.count_by
        truncated = False
        # Second change: termination now also depends on `count_by`.
        terminated = self.i >= (15 * self.count_by) + self.start_at_t
        return self._get_obs(), float(self.i % 3), terminated, truncated, {}

    def set_count_by(self, count_by):
        """Make `count_by` settable."""
        self.count_by = count_by

register_env("counter", lambda ctx: DebugCounterEnvSettable(ctx))

config = (
    PPOConfig()
    # ... (further settings)
)

workers = WorkerSet(
    env_creator=lambda ctx: DebugCounterEnvSettable(ctx),
    # ... (further arguments)
)

# Get benchmark values.
env = DebugCounterEnvSettable(config={})
env_100 = DebugCounterEnvSettable(config={"count_by": 100})

env_obs = []
env_100_obs = []
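obs, info = env.reset()
env_obs.append(obs)
obs, info = env_100.reset()
env_100_obs.append(obs)

while True:
    obs, reward, terminated, truncated, info = env.step(None)
    env_obs.append(obs)
    obs, reward, terminated, truncated, info = env_100.step(None)
    env_100_obs.append(obs)
    # Both envs terminate after the same number of steps (15).
    if terminated:
        break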

# Consider only "obs" not "new_obs", i.e. until the second last element.
print(f"Benchmark value for `count_by`=1: {sum(env_obs[:-1])}")
print(f"Benchmark value for `count_by`=100: {sum(env_100_obs[:-1])}")

# Sample from the environment with default setting (`count_by`=1).
batches = []
for i in range(4):
    batches.append(workers.foreach_worker(lambda w: w.sample()))

# This causes problems with sampling: directly after `reset()`, the next `step()` is
# taken before `RolloutWorker.sample()` returns. Therefore, the first `num_envs_per_worker`
# episodes contain values from the `count_by`=1 setting, which results in faulty metrics.
workers.foreach_worker(lambda w: w.foreach_env(lambda env: env.set_count_by(100)))

# Reset the environments and run the example again. The result will show that
# there is now one more observation in the `"obs"` column, but this time due to
# the reset and not due to the `terminated` condition. The counting also
# starts at 0 again, and therefore every single observation is an even number.
# workers.foreach_worker(lambda w: w.foreach_env(lambda env: env.reset()))

for i in range(4):
    batches.append(workers.foreach_worker(lambda w: w.sample()))

for i, batch in enumerate(batches):
    if i == 4:
        print("Batches with `count_by`=100")
    print(batch)

# In the first batch (of each sub-environment) we see that the first observations
# are from the environment setting `count_by`=1 (0 and 1), while the following ones
# use `count_by`=100 (it now also takes one step longer to reach self.i >= 15*100).
print("First batch after setting environment variable to `count_by`=100.")
print(batches[4])
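Instead of comparing sums against the benchmark values, contaminated episodes can also be detected directly from their observations. The helper below is hypothetical (it is not part of RLlib) and assumes the counter env from this script: an episode starts at `start_at_t` and every subsequent observation grows by exactly `count_by`. It also assumes that the `"obs"` column of a single episode has already been extracted into a plain list of numbers.

```python
from typing import Sequence


def is_consistent(episode_obs: Sequence[float], count_by: int,
                  start_at_t: int = 0) -> bool:
    """Return True if every observation of the episode follows `count_by`."""
    if len(episode_obs) == 0 or episode_obs[0] != start_at_t:
        return False
    # Each consecutive pair must differ by exactly `count_by`.
    return all(b - a == count_by
               for a, b in zip(episode_obs, episode_obs[1:]))


# "obs" column of a clean `count_by`=100 episode: 0, 100, ..., 1400.
clean = [100 * k for k in range(15)]
# First episode after the switch, as described above: 0, 1, 101, ..., 1401.
mixed = [0, 1] + [1 + 100 * k for k in range(1, 15)]

print(is_consistent(clean, count_by=100))  # True
print(is_consistent(mixed, count_by=100))  # False
```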

Issue Severity

Medium: It is a significant difficulty but I can work around it.

ArturNiederfahrenhorst commented 1 year ago

