instadeepai / jumanji

🕹️ A diverse suite of scalable reinforcement learning environments in JAX
https://instadeepai.github.io/jumanji
Apache License 2.0

feat(wrappers): Inherit type hints from environment #221

Closed: aar65537 closed this issue 7 months ago

aar65537 commented 9 months ago

Is your feature request related to a problem? Please describe

Currently, wrappers don't inherit type hints from the environments they wrap. Take the return types of the reset method as an example. When the environment is used directly, type hinting works as expected.

import jax

from jumanji.environments.logic.game_2048.env import Game2048, State
from jumanji.wrappers import Wrapper

key = jax.random.PRNGKey(0)

game_2048 = Game2048()
state, timestep = game_2048.reset(key)
reveal_type(state)  # jumanji.environments.logic.game_2048.types.State
reveal_type(
    timestep.observation
)  # Tuple[Any, Any, fallback=jumanji.environments.logic.game_2048.types.Observation]

However, when the environment is wrapped, two problems show up. First, we are told that the wrapped environment needs a type annotation. Second, the type information for state and timestep.observation is lost.

game_2048 = Wrapper(Game2048())  # error: Need type annotation for "game_2048"
state, timestep = game_2048.reset(key)
reveal_type(state)  # Any
reveal_type(timestep.observation)  # Any

Explicitly annotating the environment fixes the error and restores the type hint on the state variable. However, there is still no type hint for timestep.observation.

game_2048: Wrapper[State] = Wrapper(Game2048())
state, timestep = game_2048.reset(key)
reveal_type(state)  # jumanji.environments.logic.game_2048.types.State
reveal_type(timestep.observation)  # Any

I believe a similar problem affects most of the other methods on the Wrapper class, such as step, which also returns a state and a timestep.

Describe the solution you'd like

Ideally, the following code would type-check without an explicit annotation and give the same type hints as the unwrapped environment.

game_2048 = Wrapper(Game2048())
state, timestep = game_2048.reset(key)
reveal_type(state)  # jumanji.environments.logic.game_2048.types.State
reveal_type(
    timestep.observation
)  # Tuple[Any, Any, fallback=jumanji.environments.logic.game_2048.types.Observation]

The first issue, needing to annotate the wrapped environment explicitly, can be solved by modifying the __init__ method as follows:

def __init__(self, env: Environment[State]):
    super().__init__()
    self._env = env
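The sketch below lets you check the __init__ change in isolation. It is minimal and self-contained, and it runs and type-checks without JAX or jumanji. The classes Environment, Wrapper, TimeStep, Counter and CounterState are simplified, hypothetical stand-ins rather than the real jumanji classes. A plain int also replaces the PRNG key.

```python
import abc
from dataclasses import dataclass
from typing import Any, Generic, Tuple, TypeVar

State = TypeVar("State")


@dataclass
class TimeStep:
    # Hypothetical stand-in: the observation type is not tracked at all.
    observation: Any


class Environment(abc.ABC, Generic[State]):
    # Hypothetical base class, generic over the state type only.
    @abc.abstractmethod
    def reset(self, key: int) -> Tuple[State, TimeStep]:
        ...


class Wrapper(Environment[State]):
    # Annotating env as Environment[State] (rather than a bare Environment)
    # lets the type checker bind State from the constructor argument.
    def __init__(self, env: Environment[State]) -> None:
        super().__init__()
        self._env = env

    def reset(self, key: int) -> Tuple[State, TimeStep]:
        return self._env.reset(key)


@dataclass
class CounterState:
    count: int


class Counter(Environment[CounterState]):
    # Toy environment whose state type is CounterState.
    def reset(self, key: int) -> Tuple[CounterState, TimeStep]:
        return CounterState(count=key), TimeStep(observation=key)


wrapped = Wrapper(Counter())  # inferred as Wrapper[CounterState]; no annotation needed
state, timestep = wrapped.reset(3)
# Under mypy: reveal_type(state) -> CounterState, reveal_type(timestep.observation) -> Any
```

In this sketch, env is annotated as Environment[State] rather than left bare. That lets the checker infer State = CounterState from the argument, so wrapped needs no annotation. However, timestep.observation is still Any, because this TimeStep does not carry the observation type.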

This also restores the type hint on the state variable. Restoring the type hint on timestep.observation will probably require adding a generic Observation TypeVar to the Environment class.
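One possible shape for the Observation TypeVar is sketched below. This is an assumption, not jumanji's actual design. Again, all classes are simplified, hypothetical stand-ins, and an int replaces the PRNG key. In the sketch:

- Environment is generic over both State and Observation.
- TimeStep is parameterised by the observation type.
- Wrapper forwards both parameters from the environment it wraps.

```python
import abc
from dataclasses import dataclass
from typing import Generic, NamedTuple, Tuple, TypeVar

State = TypeVar("State")
Observation = TypeVar("Observation")


@dataclass
class TimeStep(Generic[Observation]):
    # Hypothetical stand-in: the observation field carries the Observation type.
    observation: Observation
    reward: float


class Environment(abc.ABC, Generic[State, Observation]):
    # Hypothetical base class, generic over both state and observation types.
    @abc.abstractmethod
    def reset(self, key: int) -> Tuple[State, TimeStep[Observation]]:
        ...

    @abc.abstractmethod
    def step(self, state: State, action: int) -> Tuple[State, TimeStep[Observation]]:
        ...


class Wrapper(Environment[State, Observation]):
    # Both type parameters are bound from the wrapped environment.
    def __init__(self, env: Environment[State, Observation]) -> None:
        super().__init__()
        self._env = env

    def reset(self, key: int) -> Tuple[State, TimeStep[Observation]]:
        return self._env.reset(key)

    def step(self, state: State, action: int) -> Tuple[State, TimeStep[Observation]]:
        return self._env.step(state, action)


class CounterObservation(NamedTuple):
    count: int
    parity: int


class Counter(Environment[int, CounterObservation]):
    # Toy environment: the state is an int counter, the observation a NamedTuple.
    def reset(self, key: int) -> Tuple[int, TimeStep[CounterObservation]]:
        obs = CounterObservation(count=key, parity=key % 2)
        return key, TimeStep(observation=obs, reward=0.0)

    def step(self, state: int, action: int) -> Tuple[int, TimeStep[CounterObservation]]:
        new_state = state + action
        obs = CounterObservation(count=new_state, parity=new_state % 2)
        return new_state, TimeStep(observation=obs, reward=float(action))


wrapped = Wrapper(Counter())  # inferred as Wrapper[int, CounterObservation]
state, timestep = wrapped.reset(4)
state, timestep = wrapped.step(state, 1)
# Under mypy: reveal_type(timestep.observation) -> CounterObservation, not Any
```

Under mypy, reveal_type(timestep.observation) should then report the concrete observation type instead of Any. Here that type is a NamedTuple, like Game2048's Observation. As before, no annotation is needed on the wrapped environment.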