sail-sg / envpool

C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.
https://envpool.readthedocs.io
Apache License 2.0

[Feature Request] A simple (and effective?) way to support cherry-picked env reset in `xla` mode #293

Open bkkgbkjb opened 6 months ago

bkkgbkjb commented 6 months ago

Motivation

AFAIK (and as also noted in #194), it is currently not possible to cherry-pick terminated envs for reset in xla mode, because:

  1. The envs.reset() method doesn't accept an env_ids argument, unlike in sync mode. And even if it did:
  2. step_env() and send() do seem to accept an env_ids argument, but it's hard to generate a dynamically shaped env_ids array in xla mode under JAX. (I've run some preliminary experiments to verify this.)
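The dynamic-shape problem in point 2 can be reproduced with a small JAX sketch that does not involve envpool at all: turning a done-mask into a list of env ids with jnp.nonzero fails under jax.jit, because the length of the result depends on the data. Passing a fixed size (with a fill value for the padding) makes the shape static and the function traces fine. Whether envpool's send() could accept padded ids such as -1 is only an assumption here, not something the current API is known to support.

```python
import jax
import jax.numpy as jnp


def ids_dynamic(mask):
    # The number of True entries is data-dependent, so the output shape
    # is unknown at trace time and jax.jit rejects it.
    return jnp.nonzero(mask)[0]


def ids_static(mask):
    # Pad the result to a fixed length (num_envs) with -1, so the output
    # shape no longer depends on the mask contents.
    return jnp.nonzero(mask, size=mask.shape[0], fill_value=-1)[0]


mask = jnp.array([False, True, False, True])

try:
    jax.jit(ids_dynamic)(mask)
    dynamic_ok = True
except Exception:  # typically jax.errors.ConcretizationTypeError
    dynamic_ok = False

padded = jax.jit(ids_static)(mask)
print(dynamic_ok, padded.tolist())
```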

As a result, incorrect transitions of the form (term_state, any_action, rew -> init_state) appear, as also pointed out in #194.
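A minimal sketch of how such reset transitions can be flagged, assuming envpool's auto-reset behaves as described in #194: on the step right after an episode ends, the env only resets, ignores the action, and returns the initial observation. The helper name valid_transition_mask is hypothetical.

```python
import jax.numpy as jnp


def valid_transition_mask(prev_terms, prev_truncs):
    # An env that finished on the previous step is reset on this step,
    # so its transition (term_state, any_action, rew -> init_state) is bogus.
    prev_done = jnp.logical_or(prev_terms, prev_truncs)
    return jnp.logical_not(prev_done)


prev_terms = jnp.array([False, True, False])
prev_truncs = jnp.array([False, False, True])
valid = valid_transition_mask(prev_terms, prev_truncs)

# Example use: zero out rewards of the bogus transitions before a loss.
rewards = jnp.array([1.0, 5.0, 2.0])
masked_rewards = jnp.where(valid, rewards, 0.0)
print(valid.tolist(), masked_rewards.tolist())
```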

Solution

I think we could try adding a masked_env_ids argument to the reset() and send() methods.

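That way, we could do something like this in xla mode (the masked_env_ids argument of reset() is the proposed API and does not exist yet):


import jax.numpy as jnp

obs, _ = envs.reset()
handle, recv, send, step = envs.xla()

while True:

    handle, (obs, rew, terms, truncs, info) = step(handle, some_acts)

    # proposed masked auto-resetting: only envs whose mask entry is True are reset
    auto_reset_masks = jnp.logical_or(terms, truncs)
    _obs = envs.reset(masked_env_ids=auto_reset_masks)  # proposed API

    # broadcast the (num_envs,) mask over the observation dims before merging
    mask = auto_reset_masks.reshape(auto_reset_masks.shape + (1,) * (obs.ndim - 1))
    obs = jnp.where(mask, _obs, obs)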

Here masked_env_ids would have a static shape of (num_envs,): the call would reset only the envs whose mask entry is True and return dummy observations for the others.
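To illustrate why a boolean mask sidesteps the dynamic-shape issue, here is a pure-JAX mock of the proposed semantics. masked_reset is a hypothetical stand-in, not an envpool function: it samples a fresh observation for every env and substitutes an all-zero dummy wherever the mask is False, so the output shape is always (num_envs, obs_dim). The trace counter shows that jax.jit compiles it once and reuses the compiled function for different masks.

```python
import jax
import jax.numpy as jnp

NUM_ENVS, OBS_DIM = 4, 3
trace_count = 0  # incremented only when jax.jit (re)traces masked_reset


@jax.jit
def masked_reset(key, masked_env_ids):
    """Hypothetical mock of reset(masked_env_ids=...), not an envpool API."""
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    # Fresh "initial observations" for every env (fixed shape).
    fresh = jax.random.uniform(key, (NUM_ENVS, OBS_DIM))
    # Dummy observations for envs whose mask entry is False.
    dummy = jnp.zeros_like(fresh)
    # Broadcast the (NUM_ENVS,) mask over the observation dimension.
    return jnp.where(masked_env_ids[:, None], fresh, dummy)


mask_a = jnp.array([True, False, False, True])
mask_b = jnp.array([False, True, False, False])
obs_a = masked_reset(jax.random.PRNGKey(0), mask_a)
obs_b = masked_reset(jax.random.PRNGKey(1), mask_b)
print(trace_count, obs_a.shape, obs_b.shape)
```

A real implementation would of course have to perform the reset in the C++ env pool; the mock only shows that the interface keeps shapes static.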

Alternative Methods

Currently, I'm working around this by overwriting the incorrect transitions with the previous correct ones. This should not make a significant difference for most algorithms.
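The workaround might look roughly like this, assuming transitions are stored as a dict of per-env arrays and that each bogus reset transition is replaced by a copy of the same env's previous transition. overwrite_reset_transitions is a hypothetical helper, not the reporter's actual code.

```python
import jax
import jax.numpy as jnp


def overwrite_reset_transitions(prev_batch, batch, prev_done):
    # prev_done: (num_envs,) bool, True where the env's episode ended on the
    # previous step, i.e. where this step holds the bogus
    # (term_state, any_action, rew -> init_state) transition.
    def pick(prev, cur):
        # Broadcast the per-env flag over any trailing dims (e.g. obs dims).
        m = prev_done.reshape(prev_done.shape + (1,) * (cur.ndim - 1))
        return jnp.where(m, prev, cur)

    # Apply the same per-env selection to every field (obs, rew, ...).
    return jax.tree_util.tree_map(pick, prev_batch, batch)


prev_batch = {"obs": jnp.array([[0.0, 0.0], [1.0, 1.0]]), "rew": jnp.array([0.5, 1.5])}
batch = {"obs": jnp.array([[2.0, 2.0], [3.0, 3.0]]), "rew": jnp.array([2.5, 0.0])}
prev_done = jnp.array([False, True])
fixed = overwrite_reset_transitions(prev_batch, batch, prev_done)
print(fixed["obs"].tolist(), fixed["rew"].tolist())
```

Duplicating a transition keeps the batch shape static under jit, at the cost of slightly overweighting the duplicated samples.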

But if the proposed solution works, I think it would be more elegant to have it.

Additional context

Unfortunately, I'm not a C++ expert, and I'm not sure whether the proposed solution, simple as it is, would work as expected. But as far as I understand, it should be implementable as long as the reset is performed inside the C++ processes.


JesseFarebro commented 1 month ago

This would be great to have, as exact evaluation like this (https://github.com/sail-sg/envpool/issues/113#issuecomment-1126725772) currently can't be jitted.