What happened
When the RolloutWorker samples and reaches an __all_done__, it calls reset() on the sub-envs - that is fine and should stay that way.
However, it also takes the first step() in the sub-envs before returning the last batch and before the next batch is requested (due to the iterator (yield) design of the EnvRunnerV2).
That is not necessarily game-changing, but it becomes tricky when the user needs to change an environment setting during sampling (e.g. via foreach_env on the WorkerSet, between rollouts). The first step of the new episode has already been sampled at that point, so an Episode then contains one sample from the "old" setting and the rest from the "new" setting. This biases the trajectory and the corresponding metrics - especially in evaluation.
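To make the ordering concrete, here is a small self-contained toy (not RLlib code; ToyEnv and runner are made up purely for illustration) that mimics the yield design: the generator resets and already takes the first step of the next episode before yielding the finished one.

class ToyEnv:
    def __init__(self):
        self.count_by = 1   # the "environment setting" we want to change
        self.i = 0

    def reset(self):
        self.i = 0
        return self.i

    def step(self):
        self.i += self.count_by
        return self.i, self.i >= 3  # (obs, terminated)


def runner(env):
    obs = [env.reset()]
    while True:
        terminated = False
        while not terminated:
            o, terminated = env.step()
            obs.append(o)
        next_obs = [env.reset()]
        # The first step of the NEXT episode is already taken here ...
        o, _ = env.step()
        next_obs.append(o)
        # ... and only then is the finished episode handed back.
        yield obs
        obs = next_obs


env = ToyEnv()
it = runner(env)
print(next(it))      # [0, 1, 2, 3] - sampled with count_by=1
env.count_by = 100   # change the setting "between rollouts"
print(next(it))      # [0, 1, 101]  - the first transition still used count_by=1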
Possible workarounds
We could call reset() on all sub-envs before sampling with the new setting, but this still results in the first sample coming from the old setting - as no terminated=True or truncated=True was provided.
We could clean the trajectory in on_postprocess_trajectory(), but for this we need a signal for the switch in environment settings, or we have to know exactly what the values must look like - neither is necessarily easy to obtain (especially in real-world settings or with random starts in general). Maybe via the info dict, but that is not the cleanest solution.
We could run different workers or sub-envs with different settings, but this can waste resources and has to be customized - for most users this is non-trivial.
Maybe the dev team has a different approach to it; at the moment I run different env settings either via different workers or via different sub-envs (a sketch follows below) - it is a bit cumbersome and wastes resources, but it works.
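A hedged sketch of that per-worker setup, assuming the standard EnvContext fields worker_index / vector_index (the mapping to count_by is arbitrary, and DebugCounterEnvSettable is the class from the reproduction script below): the setting is chosen at construction time instead of mutating live envs during sampling.

from ray.tune import register_env

def make_env(ctx):
    # ctx is RLlib's EnvContext; worker_index / vector_index tell us which
    # worker / sub-env this instance belongs to.
    count_by = 1 if ctx.worker_index == 1 else 100  # arbitrary example mapping
    return DebugCounterEnvSettable({"count_by": count_by})

register_env("counter_per_worker", make_env)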
I assume this also affects any task setting of environments via the env_task_fn in Algorithm (e.g. for curriculum learning or policy generalization), as the first num_envs_per_worker * num_rollout_workers trajectories will be biased. This might even happen in "self-play" or "league-training" scenarios where the policies an agent has to play against are changed.
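For reference, a minimal sketch of such an env_task_fn, assuming the standard TaskSettableEnv / env_task_fn hooks (the reward threshold and the task arithmetic are purely illustrative); the first transitions after each task switch would carry the same bias.

from ray.rllib.env.apis.task_settable_env import TaskSettableEnv, TaskType

def curriculum_fn(train_results, task_settable_env: TaskSettableEnv, env_ctx) -> TaskType:
    # Bump the task once the mean episode reward passes an (arbitrary) threshold.
    cur_task = task_settable_env.get_task()
    if train_results["episode_reward_mean"] > 10.0:
        return cur_task + 1
    return cur_task

# Hooked into the algorithm via, e.g.:
# config = PPOConfig().environment(env="counter", env_task_fn=curriculum_fn)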
What you expected to happen
That the sampler first returns the batch and only takes the first step in the environment once another batch is requested. With this setup, switching env settings would not result in faulty trajectories.
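In terms of the toy runner sketched above, the expected ordering would look like this (again just an illustration, not the actual EnvRunnerV2 code):

def runner_expected(env):
    # Same ToyEnv as above: yield the finished episode first; reset and take
    # the first step of the next episode only once the next batch is requested.
    obs = [env.reset()]
    while True:
        terminated = False
        while not terminated:
            o, terminated = env.step()
            obs.append(o)
        yield obs            # hand back the finished episode first ...
        obs = [env.reset()]  # ... reset (and step) only on the next request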
Use cases
Comparing to real-world benchmarks
Curriculum learning
Generalizing policies
Versions / Dependencies
Ray 2.6.0 & Ray nightly
Python 3.9.12
Fedora 37
Reproduction script
import gymnasium as gym
import numpy as np
from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF2Policy
from ray.rllib.examples.env.debug_counter_env import DebugCounterEnv
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.tune import register_env
class DebugCounterEnvSettable(DebugCounterEnv):
    def __init__(self, config=None):
        config = config or {}
        super().__init__(config)
        self.observation_space = gym.spaces.Box(0, np.inf, (1,), np.float32)
        self.count_by = int(config.get("count_by", 1))

    def step(self, action):
        # Here is the change - we step by `count_by`.
        self.i += self.count_by
        truncated = False
        # Here is the change - termination now also depends on `count_by`.
        terminated = self.i >= (15 * self.count_by) + self.start_at_t
        return self._get_obs(), float(self.i % 3), terminated, truncated, {}

    def set_count_by(self, count_by):
        """Make `count_by` settable."""
        self.count_by = count_by


register_env("counter", lambda ctx: DebugCounterEnvSettable(ctx))
config = (
    PPOConfig()
    .environment(
        env="counter",
        disable_env_checking=True,
    )
    .framework(
        framework="tf2",
        eager_tracing=True,
    )
    .rollouts(
        rollout_fragment_length=1,
        batch_mode="complete_episodes",
        num_envs_per_worker=2,
        num_rollout_workers=1,
    )
    .resources(
        num_cpus_per_worker=1,
        num_cpus_for_local_worker=1,
    )
    .rl_module(
        _enable_rl_module_api=True,
    )
    .training(
        _enable_learner_api=True,
    )
    .debugging(
        log_level="DEBUG",
    )
)
workers = WorkerSet(
    env_creator=lambda ctx: DebugCounterEnvSettable(ctx),
    config=config,
    default_policy_class=PPOTF2Policy,
    num_workers=1,
    local_worker=False,
)
# Get benchmark values.
env = DebugCounterEnvSettable(config={})
env_100 = DebugCounterEnvSettable(config={"count_by": 100})
env_obs = []
env_100_obs = []
obs, info = env.reset()
env_obs.append(obs)
obs, info = env_100.reset()
env_100_obs.append(obs)
while True:
    obs, reward, terminated, truncated, info = env.step(None)
    env_obs.append(obs)
    obs, reward, terminated, truncated, info = env_100.step(None)
    env_100_obs.append(obs)
    if terminated:
        break
# Consider only "obs", not "new_obs", i.e. sum only up to the second-to-last element.
print(f"Benchmark value for `count_by=1`: {sum(env_obs[:-1])}")
print(f"Benchmark value for `count_by=100`: {sum(env_100_obs[:-1])}")
# Sample from the environment with the default setting (`count_by=1`).
batches = []
for i in range(4):
    batches.append(workers.foreach_worker(lambda w: w.sample()))
# This causes problems with sampling, as directly after `reset()` the next `step()` is taken
# before `RolloutWorker.sample()` returns. Therefore the first `num_envs_per_worker` episodes
# will contain values from the `count_by=1` setting and thus result in faulty metrics.
workers.foreach_worker(lambda w: w.foreach_env(lambda env: env.set_count_by(100)))
# Reset the environments and run the example again. The result will show that
# there is now one more observation in `"obs"`, but this time due to
# the reset and not due to the `terminated` condition. The counting also
# starts at 0 again and therefore every single observation is an even number.
# workers.foreach_worker(lambda w: w.foreach_env(lambda env: env.reset()))
for i in range(4):
    batches.append(workers.foreach_worker(lambda w: w.sample()))
for i, batch in enumerate(batches):
    if i == 4:
        print("Batches with `count_by=100`:")
    print(len(batch[0]))
    print(sum(batch[0]["default_policy"]["obs"]))
# In the first batch (of each sub-environment) we see that the first observations
# are from the environment setting `count_by=1` (0 and 1); the following ones are
# from `count_by=100` (it also now takes one step longer to reach self.i >= 15 * 100).
print("First batch after setting the environment variable to `count_by=100`:")
print(batches[0][0]["default_policy"]["obs"])
Issue Severity
Medium: It is a significant difficulty but I can work around it.