vyeevani opened 1 year ago
Haven't tested this, but I'm thinking something like this:
```python
import jax
from jax import numpy as jp
from brax.envs.base import State, Wrapper


class AutoResetWrapper2(Wrapper):
    """Automatically resets Brax envs that are done.

    The wrapped env's initial and current states are cached in `info`, so this
    wrapper needs no knowledge of the wrappers beneath it.
    """

    def reset(self, rng: jax.Array) -> State:
        base_state = self.env.reset(rng)
        info = base_state.info.copy()
        info.update({
            'initial_base_state': base_state,
            'current_base_state': base_state,
        })
        return base_state.replace(info=info)

    def step(self, state: State, action: jax.Array) -> State:
        initial_base_state = state.info['initial_base_state']
        current_base_state = state.info['current_base_state']
        next_base_state = self.env.step(current_base_state, action)
        done = next_base_state.done

        def where_done(x, y):
            # Reshape a batched `done` of shape (batch,) so it broadcasts
            # against leaves of shape (batch, ...).
            d = done
            if d.shape:
                d = jp.reshape(d, [x.shape[0]] + [1] * (x.ndim - 1))
            return jp.where(d, x, y)

        # Keep the wrapped env's info for this step (e.g. 'truncation'),
        # so clients can still read it on the step where the episode ends.
        info = next_base_state.info.copy()
        info.update({
            'initial_base_state': initial_base_state,
            # Continue from the cached initial state wherever the episode ended.
            'current_base_state': jax.tree_util.tree_map(
                where_done, initial_base_state, next_base_state),
        })
        return State(
            pipeline_state=jax.tree_util.tree_map(
                where_done, initial_base_state.pipeline_state,
                next_base_state.pipeline_state),
            obs=jax.tree_util.tree_map(
                where_done, initial_base_state.obs, next_base_state.obs),
            # Reward, done and metrics describe the transition just taken,
            # so they are not replaced by the reset values.
            reward=next_base_state.reward,
            done=done,
            metrics=next_base_state.metrics,
            info=info,
        )
```
This lets information from the wrapped environments flow back up to the clients that expect it (for instance, the episode wrapper's `truncation` flag), while also caching the wrapped env's initial and current states so they can be reused on reset.
Tested the above (not very rigorously); it seems to work.
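When this wrapper is applied to a batch of environments (for example behind Brax's vmap wrapper), `done` has shape `(batch,)` while leaves such as `obs` have shape `(batch, obs_dim)`, so a bare `jp.where(done, x, y)` does not broadcast correctly. Below is a minimal sketch of a batch-safe `where_done`, assuming the batch axis is the leading axis of every leaf; Brax's own `AutoResetWrapper` uses a similar reshape. The arrays are made-up toy values:

```python
import jax.numpy as jp


def where_done(done, x, y):
    """Pick x (the reset value) where done is set, else y."""
    if done.shape:
        # (batch,) -> (batch, 1, ..., 1) so it broadcasts against (batch, ...) leaves.
        done = jp.reshape(done, [x.shape[0]] + [1] * (x.ndim - 1))
    return jp.where(done, x, y)


done = jp.array([1.0, 0.0])     # env 0 just finished, env 1 did not
initial_obs = jp.zeros((2, 3))  # cached reset observations
next_obs = jp.ones((2, 3))      # observations after stepping
obs = where_done(done, initial_obs, next_obs)
```

Row 0 of `obs` comes from the cached reset observation and row 1 from the stepped one; an unbatched (scalar) `done` skips the reshape and is used as-is.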
Hi @vyeevani
Thanks for the proposal; indeed, there is logic leaking into reset from the episode wrapper. If you changed `first_pipeline_state` to `first_state` (which would contain the first `State` object) in the implementation at HEAD, would that suffice for your use case? Why do you need to store `current_base_state` in the info as well?
My only concern, beyond cleaner semantics, is how this affects performance.
The auto-reset and episode training wrappers are leaking information to each other.
This is a bigger problem if people want to stack their own wrappers. For example, I'd like to write a meta-episode wrapper that aggregates multiple episodes into a single meta-episode for use in a meta-RL setting. This wrapper needs to track the number of episodes so that it can do a meta-episode reset when the count reaches some watermark. However, the auto-reset wrapper would break this, since it wouldn't reset the meta wrapper beneath it.
At a high level, I'd like to separate the state of the auto-reset wrapper from the environments it wraps. I propose doing this by caching both the initial and the current state of the wrapped environment in `info`, and operating only on those.
Note that with this approach, you never need to return the pipeline state through the state itself.
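To make the proposal concrete, here is a dependency-free toy sketch of the caching idea. `State`, `CounterEnv`, `EpisodeWrapper`, and `AutoResetWrapper` are hypothetical, simplified stand-ins (unbatched, no JAX), not Brax's classes. The auto-reset wrapper keeps the wrapped env's initial and current states in `info` and swaps in the initial state when an episode ends, so the episode wrapper's `steps` counter is restored without the auto-reset wrapper knowing it exists:

```python
from dataclasses import dataclass, field, replace
from typing import Any, Dict


@dataclass(frozen=True)
class State:
    """Hypothetical, simplified stand-in for a Brax State (unbatched, no JAX)."""
    obs: int
    reward: float
    done: bool
    info: Dict[str, Any] = field(default_factory=dict)


class CounterEnv:
    """Hypothetical toy env: the observation grows by `action` each step."""

    def reset(self, seed: int) -> State:
        return State(obs=seed, reward=0.0, done=False)

    def step(self, state: State, action: int) -> State:
        return State(obs=state.obs + action, reward=1.0, done=False, info=dict(state.info))


class EpisodeWrapper:
    """Toy episode wrapper: ends an episode after `length` steps."""

    def __init__(self, env, length: int):
        self.env, self.length = env, length

    def reset(self, seed: int) -> State:
        s = self.env.reset(seed)
        return replace(s, info={**s.info, 'steps': 0, 'truncation': False})

    def step(self, state: State, action: int) -> State:
        s = self.env.step(state, action)
        steps = state.info['steps'] + 1
        truncated = steps >= self.length
        return replace(s, done=s.done or truncated,
                       info={**s.info, 'steps': steps, 'truncation': truncated})


class AutoResetWrapper:
    """Caches the wrapped env's initial/current states in info; knows nothing about 'steps'."""

    def __init__(self, env):
        self.env = env

    def reset(self, seed: int) -> State:
        base = self.env.reset(seed)
        return replace(base, info={**base.info, 'initial_base_state': base,
                                   'current_base_state': base})

    def step(self, state: State, action: int) -> State:
        initial = state.info['initial_base_state']
        nxt = self.env.step(state.info['current_base_state'], action)
        # Continue from the cached initial state if the episode just ended.
        current = initial if nxt.done else nxt
        return State(
            obs=current.obs,    # observation of the (possibly reset) next state
            reward=nxt.reward,  # reward of the transition that just happened
            done=nxt.done,
            info={**nxt.info, 'initial_base_state': initial, 'current_base_state': current},
        )


env = AutoResetWrapper(EpisodeWrapper(CounterEnv(), length=2))
s = env.reset(10)
for _ in range(3):
    s = env.step(s, 1)
    print(s.obs, s.done, s.info['steps'])
# prints: 11 False 1 / 10 True 2 / 11 False 1 (one line each)
```

On the step where the episode ends, the client still sees `truncation=True` and `steps=2`, yet the next step starts again from the cached initial state with `steps` back at 0. Because the whole inner state is restored, any wrapper stacked beneath the auto-reset wrapper is reset along with the environment.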