instadeepai / jumanji

🕹️ A diverse suite of scalable reinforcement learning environments in JAX
https://instadeepai.github.io/jumanji
Apache License 2.0

Sample a successor state by passing a key to Environment.step #211

Open · carlosgmartin opened this issue 11 months ago


Is your feature request related to a problem? Please describe

Currently, one obtains a successor state by calling Environment.step(state, action). The state itself contains a PRNG key, which is derived from the key argument of Environment.reset and then split and propagated from state to state throughout the episode. This is how Jumanji simulates stochastic environments.
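A minimal sketch of the current pattern, using a hypothetical toy coin-flip environment (CoinState and KeyInStateCoinEnv are illustrative names, not Jumanji classes). The key lives in the state, and each step splits it into one half for the current transition and one half for the successor:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class CoinState(NamedTuple):
    """Hypothetical state for a toy coin-flip environment."""
    heads: jax.Array  # number of counted heads so far
    key: jax.Array    # PRNG key carried inside the state


class KeyInStateCoinEnv:
    """Toy environment where all randomness is carried in state.key."""

    def reset(self, key: jax.Array) -> CoinState:
        return CoinState(heads=jnp.array(0, dtype=jnp.int32), key=key)

    def step(self, state: CoinState, action: jax.Array) -> CoinState:
        # Split the carried key: one half drives this transition, the other
        # half is propagated to the successor so later steps get fresh randomness.
        next_key, flip_key = jax.random.split(state.key)
        heads = jax.random.bernoulli(flip_key).astype(jnp.int32)
        # The action (0 or 1) decides whether this flip is counted.
        return CoinState(heads=state.heads + heads * action, key=next_key)


env = KeyInStateCoinEnv()
state = env.reset(jax.random.PRNGKey(0))
for _ in range(3):
    state = env.step(state, jnp.array(1, dtype=jnp.int32))
```

Note that the only source of randomness is the key passed to reset; everything afterwards is a deterministic function of the state.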

However, this approach has some disadvantages:

  1. It does not allow one to (re)sample successor states: calling step again with the same state and action always returns the same successor.
  2. If an agent receives a State as input, it can plan (think AlphaZero) with access to the environment's future randomness, breaking the assumption that this randomness is unpredictable and letting the agent "cheat".
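Both disadvantages follow from the successor being a deterministic function of (state, action). A small sketch with a hypothetical keyed state and step function (illustrative names, not Jumanji's API) makes this concrete:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    value: jax.Array
    key: jax.Array  # hypothetical key carried in the state


def step(state: State, action: jax.Array) -> State:
    # The successor depends only on (state, action): state.key fixes the noise.
    next_key, noise_key = jax.random.split(state.key)
    noise = jax.random.normal(noise_key)
    return State(value=state.value + action + noise, key=next_key)


state = State(value=jnp.array(0.0), key=jax.random.PRNGKey(0))
action = jnp.array(1.0)

# Problem 1: "resampling" by calling step again just reproduces the same successor.
first = step(state, action)
second = step(state, action)
resampled_identical = bool(first.value == second.value)

# Problem 2: an agent handed `state` can roll the same step function forward
# and foresee the environment's future noise draws exactly.
plan, real = state, state
for _ in range(3):
    plan = step(plan, action)  # agent's internal lookahead
    real = step(real, action)  # environment's actual transitions
agent_foresees_future = bool(plan.value == real.value)
```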

Describe the solution you'd like

Allow Environment.step to receive a key argument directly, as in Environment.step(state, action, key). gymnax already takes this approach, and pgx intends to adopt it: https://github.com/sotetsuk/pgx/issues/1043.
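A minimal sketch of the proposed signature, assuming a hypothetical toy environment (KeyedStepEnv is illustrative, not an existing Jumanji or gymnax class). With the key supplied at call time, the state carries no randomness, so resampling successors is just a matter of passing different keys, for example through jax.vmap:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    """Hypothetical keyless state: only the environment's own quantities."""
    value: jax.Array


class KeyedStepEnv:
    """Toy environment using the proposed step(state, action, key) signature."""

    def reset(self, key: jax.Array) -> State:
        return State(value=jax.random.normal(key))

    def step(self, state: State, action: jax.Array, key: jax.Array) -> State:
        # All transition randomness comes from the caller-supplied key;
        # nothing random is stored in the state.
        noise = jax.random.normal(key)
        return State(value=state.value + action + noise)


env = KeyedStepEnv()
reset_key, step_key = jax.random.split(jax.random.PRNGKey(0))
state = env.reset(reset_key)
action = jnp.array(1.0)

# Resampling: vmap over a batch of keys to draw 8 candidate successors at once.
keys = jax.random.split(step_key, 8)
successors = jax.vmap(env.step, in_axes=(None, None, 0))(state, action, keys)

# Reproducibility: reusing a key reproduces the same successor.
again = env.step(state, action, keys[0])
```

Because the agent only ever sees a keyless State, it cannot read off the randomness the environment will use next.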

In the medium/long term, I would support deprecating the State.key attribute entirely; it is currently the only constraint the StateProtocol enforces. Removing it would allow State objects to be completely generic (they could be strings, ints, tuples, dicts, etc.).
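Since JAX transformations operate on arbitrary pytrees, a step function that takes its key explicitly places no requirements on the state's shape. The sketch below uses two hypothetical step functions (not Jumanji API): one whose state is a bare integer array and one whose state is a plain dict, both compiled with jax.jit:

```python
import jax
import jax.numpy as jnp


@jax.jit
def random_walk_step(state: jax.Array, action: jax.Array, key: jax.Array) -> jax.Array:
    # State is a bare integer array: no key attribute, no protocol to satisfy.
    # The successor moves by `action` plus a random step of -1 or +1.
    jitter = jax.random.choice(key, jnp.array([-1, 1]))
    return state + action + jitter


@jax.jit
def inventory_step(state: dict, action: jax.Array, key: jax.Array) -> dict:
    # State is a plain dict pytree; `action` is units restocked, demand is random.
    demand = jax.random.poisson(key, 3.0)
    stock = jnp.maximum(state["stock"] + action - demand, 0)
    return {"stock": stock, "day": state["day"] + 1}


key = jax.random.PRNGKey(0)
position = random_walk_step(jnp.array(0), jnp.array(0), key)
inventory = inventory_step({"stock": jnp.array(5), "day": jnp.array(0)}, jnp.array(2), key)
```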

Describe alternatives you've considered

One alternative is to copy the state, replace its key attribute, and pass the copy to Environment.step (to address the first issue) or to the agent (to address the second). However, this approach seems hacky and error-prone.
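A sketch of this workaround, assuming the state is a NamedTuple so that _replace can swap the key (State, step, resample_successor and hide_randomness are hypothetical helpers for illustration):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    value: jax.Array
    key: jax.Array  # hypothetical key attribute, as in the current design


def step(state: State, action: jax.Array) -> State:
    next_key, noise_key = jax.random.split(state.key)
    return State(value=state.value + action + jax.random.normal(noise_key), key=next_key)


def resample_successor(state: State, action: jax.Array, key: jax.Array) -> State:
    # Workaround for issue 1: overwrite the carried key before stepping.
    return step(state._replace(key=key), action)


def hide_randomness(state: State, key: jax.Array) -> State:
    # Workaround for issue 2: hand the agent a copy whose key is unrelated
    # to the one the environment will actually use.
    return state._replace(key=key)


state = State(value=jnp.array(0.0), key=jax.random.PRNGKey(0))
action = jnp.array(1.0)
sample_a = resample_successor(state, action, jax.random.PRNGKey(1))
sample_b = resample_successor(state, action, jax.random.PRNGKey(2))
agent_view = hide_randomness(state, jax.random.PRNGKey(3))
```

It works, but every caller must remember to swap the key, and the resampled successor's carried key now descends from the substituted key rather than the original stream, which is easy to overlook.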

Fundamentally, step should arguably be treated as an intrinsically stochastic function, which implies that it should receive its own key at call time. (The key can be None for environments that don't need randomness.)
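A sketch of how an optional key could look from the caller's side, assuming hypothetical deterministic and stochastic step functions that share the step(state, action, key) signature. The rollout helper owns the randomness: it splits a fresh subkey per step, or passes None when no key is given:

```python
from typing import Callable, NamedTuple, Optional, Sequence

import jax
import jax.numpy as jnp


class State(NamedTuple):
    position: jax.Array  # hypothetical 1-D position


def deterministic_step(state: State, action: jax.Array, key: Optional[jax.Array] = None) -> State:
    # Deterministic dynamics: the key parameter exists but is never used.
    del key
    return State(position=state.position + action)


def slippery_step(state: State, action: jax.Array, key: Optional[jax.Array] = None) -> State:
    # Stochastic dynamics: with probability 0.1 the action "slips" and has no effect.
    if key is None:
        raise ValueError("slippery_step needs a PRNG key")
    slipped = jax.random.bernoulli(key, p=0.1)
    return State(position=state.position + jnp.where(slipped, 0, action))


def rollout(step_fn: Callable, state: State, actions: Sequence[jax.Array],
            key: Optional[jax.Array] = None) -> State:
    # The caller owns the randomness: split a fresh subkey per step, or pass None.
    for action in actions:
        subkey = None
        if key is not None:
            key, subkey = jax.random.split(key)
        state = step_fn(state, action, subkey)
    return state


start = State(position=jnp.array(0))
actions = [jnp.array(1)] * 5
det_final = rollout(deterministic_step, start, actions)
slip_final = rollout(slippery_step, start, actions, jax.random.PRNGKey(0))
```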