instadeepai / jumanji

🕹️ A diverse suite of scalable reinforcement learning environments in JAX
https://instadeepai.github.io/jumanji
Apache License 2.0

fix: autoreset wrappers #223

Closed: sash-a closed this 6 months ago

sash-a commented 7 months ago

There was an issue with the autoreset wrappers: they never showed you the final timestep. This matters when the final timestep is a truncation (discount = 1, timestep.last() = True): we'd want that final observation in order to compute the next value for our value target, but currently there is no way to get it.

Gym fixes this in the same way I am proposing, as can be seen here: they place the final observation in infos. What this PR does is always place either the current observation or the terminal observation in timestep.extras["real_next_obs"].
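
For illustration, the wrapper-side change is roughly the following (a hand-wavy sketch, not the actual diff; auto_reset and TimeStep.replace are placeholder names here):

    def _step_and_keep_real_next_obs(state: State, action: Array) -> Tuple[State, TimeStep]:
        """Sketch: auto-reset on episode end without losing the final observation."""
        state, timestep = env.step(state, action)
        # On a normal step this is the next obs; on the last step of an episode
        # it is the terminal/truncated obs that would otherwise be thrown away.
        real_next_obs = timestep.observation

        # Auto-reset on episode end: the reset observation replaces
        # timestep.observation, as the wrappers already do today.
        state, timestep = jax.lax.cond(
            timestep.last(),
            lambda s, t: auto_reset(s, t),
            lambda s, t: (s, t),
            state,
            timestep,
        )
        # Always expose the true next observation in the extras.
        extras = {**timestep.extras, "real_next_obs": real_next_obs}
        return state, timestep.replace(extras=extras)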

Here's how I'm currently using it for SAC (and it's working well there):

    def step(
        action: Array, obs: Observation, env_state: State, buffer_state: BufferState
    ) -> Tuple[Array, State, BufferState, Dict]:
        """Given an action, step the environment and add to the buffer."""
        env_state, timestep = jax.vmap(env.step)(env_state, action)
        next_obs = timestep.observation
        rewards = timestep.reward
        # discount == 0 marks true termination; truncation keeps discount == 1.
        terms = ~(timestep.discount).astype(bool)
        infos = timestep.extras

        # Terminal/truncated observation on the last step of an episode,
        # otherwise just the current observation.
        real_next_obs = infos["real_next_obs"]

        transition = Transition(obs, action, rewards, terms, real_next_obs)
        buffer_state = rb.add(buffer_state, transition)

        return next_obs, env_state, buffer_state, infos["episode_metrics"]
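
Downstream, that stored observation is what makes it possible to bootstrap correctly on truncation. Roughly, with placeholder critic/batch names (not code from this PR):

    # Sketch of the bootstrapped value target from sampled transitions.
    # terms came from ~discount above, so truncation (discount == 1) still
    # bootstraps from real_next_obs, while true termination (discount == 0) does not.
    next_q = target_q_network(batch.next_obs)  # batch.next_obs is real_next_obs
    target = batch.reward + gamma * (1.0 - batch.term) * next_q
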
clement-bonnet commented 6 months ago

Hi Sasha, this issue seems somewhat related to #106. Returning the reset state/observation instead of the terminal state/observation when auto-resetting has always been the intended behaviour: none of the Jumanji environments use truncation, so one does not need the terminal state to train an actor-critic agent. Now, if a user implements a new Jumanji environment using the Environment abstraction and other tools from Jumanji, including truncation, they may want to use the truncated state/observation in their own training loop, which seems to be your use case, right? Passing it to the extras seems legit to me. 🙌
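
For anyone following along, the termination/truncation distinction plays out in a custom environment's step function roughly like this (a sketch only, using dm_env-style constructors as a stand-in; the details depend on how the environment is written):

    # Illustrative episode endings in a custom environment's step function.
    if episode_solved_or_failed:
        # True terminal state: discount == 0, no bootstrapping needed.
        timestep = termination(reward=reward, observation=obs)
    elif step_count >= time_limit:
        # Truncation: episode cut short, discount == 1, so the agent should
        # still bootstrap from the real next observation.
        timestep = truncation(reward=reward, observation=obs)
    else:
        timestep = transition(reward=reward, observation=obs)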

sash-a commented 6 months ago

> Hi Sasha, this issue seems somewhat related to #106. Returning the reset state/observation instead of the terminal state/observation when auto-resetting has always been the intended behaviour: none of the Jumanji environments use truncation, so one does not need the terminal state to train an actor-critic agent. Now, if a user implements a new Jumanji environment using the Environment abstraction and other tools from Jumanji, including truncation, they may want to use the truncated state/observation in their own training loop, which seems to be your use case, right? Passing it to the extras seems legit to me. 🙌

Yup this is exactly the use case!