Motivation
As far as I know (and as also discussed in #194), it is currently not possible to cherry-pick terminated envs for reset in XLA mode. In XLA mode, `envs.reset()` does not accept an `env_ids` argument, unlike in sync mode. Even if it did (`step_env()` and `send()` do seem to accept `env_ids`), it is hard to generate a dynamically shaped `env_ids` array in XLA mode with JAX. I ran some preliminary experiments to verify this.
As a result, incorrect transitions appear in the collected data, of the form (term_state, any_action, rew -> init_state), as also pointed out in #194.
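To illustrate the dynamic-shape problem, here is a minimal JAX sketch. It is independent of EnvPool, and the function names are made up for illustration. Selecting the indices of terminated envs gives an array whose length depends on the values in `done`, which `jax.jit` cannot trace. Padding to a static `size` works, but the padded slots then need special handling anyway.

```python
import jax
import jax.numpy as jnp

def terminated_ids_dynamic(done):
    # The output length equals the number of True entries, which is only
    # known at run time, so this cannot be traced by jax.jit.
    return jnp.nonzero(done)[0]

@jax.jit
def terminated_ids_padded(done):
    # Passing a static `size` fixes the output shape; unused slots hold -1.
    return jnp.nonzero(done, size=done.shape[0], fill_value=-1)[0]

done = jnp.array([False, True, False, True])
print(terminated_ids_padded(done).tolist())  # [1, 3, -1, -1]

try:
    jax.jit(terminated_ids_dynamic)(done)
except jax.errors.ConcretizationTypeError as err:
    # jnp.nonzero without `size` refuses to run under jit.
    print("dynamic shape rejected:", type(err).__name__)
```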
Solution
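I think we could try adding a `masked_env_ids` argument to the `reset()` and `send()` methods, so that terminated envs can be reset selectively in XLA mode.
`masked_env_ids` would be a boolean mask with the static shape `(env_nums,)`. Only envs whose entry is True would be reset, and dummy observations would be returned for envs whose entry is False.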
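A minimal sketch of the intended semantics, with the C++ reset replaced by a hypothetical pure-JAX stand-in `fake_masked_reset` (this is not an existing EnvPool API; the all-ones initial observations and all-zeros dummy observations are assumptions for illustration). Because the mask always has shape `(num_envs,)`, the whole step stays jit-compatible, and the caller merges reset and current observations with `jnp.where`.

```python
import jax
import jax.numpy as jnp

NUM_ENVS, OBS_DIM = 4, 3

def fake_masked_reset(mask):
    """Hypothetical stand-in for the proposed reset(masked_env_ids=mask).

    Returns a fixed-shape (NUM_ENVS, OBS_DIM) batch: initial observations
    (here all ones) for masked envs and dummy zeros for the rest.
    """
    init_obs = jnp.ones((NUM_ENVS, OBS_DIM))
    dummy_obs = jnp.zeros((NUM_ENVS, OBS_DIM))
    return jnp.where(mask[:, None], init_obs, dummy_obs)

@jax.jit
def step_with_masked_reset(obs, done):
    # `done` has the static shape (NUM_ENVS,), so no dynamic shapes appear.
    reset_obs = fake_masked_reset(done)
    # Keep the current observation for envs that were not reset.
    return jnp.where(done[:, None], reset_obs, obs)

obs = jnp.full((NUM_ENVS, OBS_DIM), 5.0)
done = jnp.array([False, True, False, True])
new_obs = step_with_masked_reset(obs, done)
print(new_obs[:, 0].tolist())  # [5.0, 1.0, 5.0, 1.0]
```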
Alternative Methods
Currently, I work around this by overwriting the wrong transitions with the previous correct ones. This should not make a significant difference for most algorithms.
But if the proposed solution is correct, I think it would be more elegant to have it.
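For reference, here is a sketch of one way such a workaround could look in JAX; the exact approach used in the issue is not specified. It assumes transitions are stored as dicts of arrays with a leading `num_envs` axis, and that a `prev_done` flag marks envs that terminated on the previous step. `overwrite_bogus` is a hypothetical helper, not part of EnvPool.

```python
import jax
import jax.numpy as jnp

@jax.jit
def overwrite_bogus(prev, cur, prev_done):
    """Replace transitions that start from a terminal state.

    prev, cur: dicts of arrays with leading axis num_envs (same structure).
    prev_done: bool array (num_envs,), True where the env terminated on the
    previous step, so `cur` holds the invalid (term_state -> init_state) one.
    """
    def pick(p, c):
        # Broadcast the per-env flag over any trailing feature dimensions.
        mask = prev_done.reshape((-1,) + (1,) * (c.ndim - 1))
        return jnp.where(mask, p, c)
    return jax.tree_util.tree_map(pick, prev, cur)

prev = {"obs": jnp.array([[0.0], [1.0]]), "rew": jnp.array([0.5, 1.0])}
cur = {"obs": jnp.array([[2.0], [3.0]]), "rew": jnp.array([9.0, 2.0])}
fixed = overwrite_bogus(prev, cur, jnp.array([True, False]))
print(fixed["rew"].tolist())  # [0.5, 2.0]
```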
Additional context
Unfortunately, I'm not an expert in C++, so I'm not sure whether the proposed solution, simple as it is, would work as expected.
Based on my understanding, though, it should be implementable as long as the masking is performed on the C++ side.
Checklist
#194 is somewhat related to this, but I think that issue is more about the correct behavior of auto-reset, whereas this one focuses on a simple (but effective) way to resolve a minor issue in the current implementation.