google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.

Increase modularization of training wrappers #422

Open · vyeevani opened 7 months ago

vyeevani commented 7 months ago

The auto-reset and episode training wrappers leak information into each other.

This becomes a bigger problem when people want to stack their own wrappers. For example, I'd like to write a meta-episode wrapper that aggregates multiple episodes into a single meta-episode for use in a meta-RL setting. That wrapper needs to track the number of completed episodes so it can trigger a meta-episode reset once the count reaches some watermark. The auto-reset wrapper breaks this, since it resets the environment without ever resetting the meta wrapper beneath it.
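
Very roughly, the kind of wrapper I have in mind looks like this (an untested sketch; MetaEpisodeWrapper, num_episodes, and the episode_count key are placeholder names, not anything in Brax):

import jax
from jax import numpy as jp

from brax.envs.base import Env, State, Wrapper


class MetaEpisodeWrapper(Wrapper):
  """Hypothetical wrapper aggregating several episodes into one meta-episode."""

  def __init__(self, env: Env, num_episodes: int):
    super().__init__(env)
    self.num_episodes = num_episodes

  def reset(self, rng: jax.Array) -> State:
    state = self.env.reset(rng)
    state.info['episode_count'] = jp.zeros_like(state.done)
    return state

  def step(self, state: State, action: jax.Array) -> State:
    state = self.env.step(state, action)
    # Count inner-episode terminations.
    count = state.info['episode_count'] + state.done
    state.info['episode_count'] = count
    # Suppress done until the watermark, so the outer training stack only
    # sees meta-episode boundaries. Composing this with the current
    # AutoResetWrapper is exactly what breaks today.
    meta_done = jp.where(count >= self.num_episodes, state.done, jp.zeros_like(state.done))
    return state.replace(done=meta_done)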

At a high level, I'd like to separate the auto-reset wrapper's state from the environment it wraps. I propose to do this by caching both the initial state and the current state of the wrapped environment in info, and operating only on those cached states:

  1. In reset, get the base state and store it in info under two separate keys: initial_base_state and current_base_state.
  2. In reset, take initial_base_state's observation, reward, and done, and package them along with the info from step (1).
  3. In step, if done, return the same aggregated state as in step (2).
  4. In step, if not done, return the evolved state by updating the current base state.

Note that with this approach, the wrapper never needs to return the pipeline state through the State itself.

vyeevani commented 7 months ago

Haven't tested this, but I'm thinking something like this:

import jax
from jax import numpy as jp

from brax.envs.base import State, Wrapper


class AutoResetWrapper2(Wrapper):
  """Automatically resets Brax envs that are done.

  Keeps its own bookkeeping out of the wrapped env by caching the initial
  and current base states in info.
  """

  def reset(self, rng: jax.Array) -> State:
    base_state = self.env.reset(rng)
    # Copy info before adding our keys so the cached states don't contain
    # themselves.
    info = base_state.info.copy()
    info.update({
        'initial_base_state': base_state,
        'current_base_state': base_state,
    })

    return State(
        pipeline_state=base_state.pipeline_state,
        obs=base_state.obs,
        reward=base_state.reward,
        done=base_state.done,
        metrics=base_state.metrics,
        info=info,
    )

  def step(self, state: State, action: jax.Array) -> State:
    # Step the cached base state rather than the aggregated state we
    # handed out.
    initial_base_state = state.info['initial_base_state']
    current_base_state = state.info['current_base_state']
    next_base_state = self.env.step(current_base_state, action)

    done = next_base_state.done

    def where_done(x, y):
      # Pick the initial value where the episode just finished, else the
      # evolved one. Assumes done broadcasts against each leaf; batched
      # envs may need a reshape as in the existing AutoResetWrapper.
      return jp.where(done, x, y)

    info = jax.tree_map(
        where_done, initial_base_state.info, next_base_state.info).copy()
    info.update({
        'initial_base_state': initial_base_state,
        'current_base_state': jax.tree_map(
            where_done, initial_base_state, next_base_state),
    })

    return State(
        pipeline_state=jax.tree_map(
            where_done,
            initial_base_state.pipeline_state,
            next_base_state.pipeline_state),
        obs=jax.tree_map(where_done, initial_base_state.obs, next_base_state.obs),
        reward=jax.tree_map(
            where_done, initial_base_state.reward, next_base_state.reward),
        done=next_base_state.done,
        metrics=jax.tree_map(
            where_done, initial_base_state.metrics, next_base_state.metrics),
        info=info,
    )

vyeevani commented 7 months ago

This lets us pass information from the wrapped environment up to the clients that expect it (the episode wrapper's truncation flag, for instance) while caching the same state so it can be reused during reset.
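
For example, something like this should work end to end (untested sketch; 'fast' is just a small test env, and the wrapper stacking here is illustrative):

import jax
from jax import numpy as jp

from brax import envs
from brax.envs.wrappers.training import EpisodeWrapper

# EpisodeWrapper sets info['truncation']; AutoResetWrapper2 passes it through.
env = EpisodeWrapper(envs.get_environment('fast'), episode_length=5, action_repeat=1)
env = AutoResetWrapper2(env)

reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

state = reset_fn(jax.random.PRNGKey(0))
for _ in range(10):
  state = step_fn(state, jp.zeros(env.action_size))
  print(state.done, state.info['truncation'])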

vyeevani commented 7 months ago

Tested the above (not very rigorously); it seems to work.

btaba commented 7 months ago

Hi @vyeevani

Thanks for the proposal; there is indeed logic from the episode wrapper leaking into reset.

If you changed first_pipeline_state to first_state (which would contain the first State object) in the implementation at HEAD, would that suffice for your use case? Why do you need to store current_base_state in the info as well?
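
Concretely, I'm imagining something like this (untested sketch; the steps bookkeeping is the leaked episode-wrapper logic mentioned above, and details may differ from what's actually at HEAD):

import jax
from jax import numpy as jp

from brax.envs.base import State, Wrapper


class AutoResetWrapper(Wrapper):
  """Automatically resets Brax envs that are done."""

  def reset(self, rng: jax.Array) -> State:
    state = self.env.reset(rng)
    # Cache the whole first State under one key; build a fresh info dict so
    # the cached state doesn't end up containing itself.
    return state.replace(info={**state.info, 'first_state': state})

  def step(self, state: State, action: jax.Array) -> State:
    # Episode-wrapper bookkeeping that currently leaks into this wrapper.
    if 'steps' in state.info:
      steps = state.info['steps']
      steps = jp.where(state.done, jp.zeros_like(steps), steps)
      state.info.update(steps=steps)
    state = state.replace(done=jp.zeros_like(state.done))
    state = self.env.step(state, action)

    def where_done(x, y):
      done = state.done
      if done.shape:  # broadcast done against batched leaves
        done = jp.reshape(done, [x.shape[0]] + [1] * (len(x.shape) - 1))
      return jp.where(done, x, y)

    first_state = state.info['first_state']
    return state.replace(
        pipeline_state=jax.tree_map(
            where_done, first_state.pipeline_state, state.pipeline_state),
        obs=jax.tree_map(where_done, first_state.obs, state.obs),
    )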

My only concern, beyond cleaner semantics, is how this affects performance.