instadeepai / jumanji

🕹️ A diverse suite of scalable reinforcement learning environments in JAX
https://instadeepai.github.io/jumanji
Apache License 2.0

fix: timestep extras default value #207

Closed: sash-a closed this issue 9 months ago

sash-a commented 1 year ago

`timestep.extras` defaults to `None`, and there are cases where this does not work. One example is writing a wrapper that needs to add items to `extras`. The wrapper cannot know whether the dictionary has already been initialized by the environment or by another wrapper, so it has to do something like `timestep.extras = timestep.extras or {}` or `jax.lax.cond(timestep.extras is None, lambda: {}, lambda: timestep.extras)`. However, neither approach works in JAX: the first raises a concretization error, and in the second, `true_fun` and `false_fun` do not return the same type (pytree) structure.
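A minimal sketch of the `jax.lax.cond` failure described above, assuming a hypothetical `extras` dict that already holds one entry (`"reward_scale"` is an illustrative key, not a jumanji field). The predicate is passed in as a traced boolean so that both branches are traced under `jax.jit`. Because `lax.cond` requires both branches to return the same pytree structure, `{}` and a one-key dict do not match, and tracing fails. The exact error message may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def init_extras_with_cond(extras, flag):
    # Second approach from the issue: pick a fresh empty dict or the existing
    # extras with lax.cond. Both branches must return the same pytree structure.
    return jax.lax.cond(flag, lambda: {}, lambda: extras)


@jax.jit
def step(flag):
    # Hypothetical extras already populated by an env or another wrapper.
    extras = {"reward_scale": jnp.float32(1.0)}
    return init_extras_with_cond(extras, flag)


try:
    step(jnp.array(True))
    cond_failed = False
except TypeError as err:
    # {} and {"reward_scale": ...} have different tree structures,
    # so lax.cond rejects the pair of branches at trace time.
    cond_failed = True
    print(err)
```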

If the `extras` field defaulted to `{}` instead, a wrapper could simply write `timestep.extras[key] = value`, which solves the initialization issue.
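A sketch of the proposed default, assuming a hypothetical, trimmed-down `TimeStep` written as a plain Python dataclass (jumanji's real `TimeStep` has more fields and must also be a JAX pytree, which this class is not). A literal `= {}` default is rejected by `dataclasses` because it would be shared between instances, so `field(default_factory=dict)` is used to give each timestep its own empty dict. The two wrapper functions are hypothetical; they show that stacked wrappers can write into `extras` without any `None` check.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class TimeStep:
    # Hypothetical stand-in for jumanji's TimeStep; only illustrative fields.
    observation: Any
    reward: float = 0.0
    # default_factory creates a new empty dict per instance, so extras is
    # never None and never shared between timesteps.
    extras: Dict[str, Any] = field(default_factory=dict)


def log_reward_wrapper(timestep: TimeStep) -> TimeStep:
    # Hypothetical wrapper: no need to check whether extras was initialized.
    timestep.extras["raw_reward"] = timestep.reward
    return timestep


def mark_wrapped_wrapper(timestep: TimeStep) -> TimeStep:
    # A second hypothetical wrapper adding its own key on top of the first.
    timestep.extras["wrapped"] = True
    return timestep


ts = mark_wrapped_wrapper(log_reward_wrapper(TimeStep(observation=None, reward=1.5)))
print(ts.extras)  # {'raw_reward': 1.5, 'wrapped': True}
```

Each wrapper only adds its own key, regardless of whether the environment or an earlier wrapper has already written to `extras`.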