instadeepai / Mava

🦁 A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX
Apache License 2.0

Decouple `LogEnvState` #945

Closed sash-a closed 5 months ago

sash-a commented 1 year ago

Having the `LogWrapper` and `LogEnvState` means that we often have to write `state.env_state`, which is quite confusing. Tracking the logging metrics separately in their own state would make this cleaner and more maintainable.

Currently it's done like this:

@dataclass
class LogEnvState:
    """State of the `LogWrapper`."""

    env_state: State
    episode_returns: chex.Numeric
    episode_lengths: chex.Numeric
    # Information about the episode return and length for logging purposes.
    episode_return_info: chex.Numeric
    episode_length_info: chex.Numeric

I'm proposing this:

@dataclass
class LoggerState:
    episode_returns: chex.Numeric
    episode_lengths: chex.Numeric
    # Information about the episode return and length for logging purposes.
    episode_return_info: chex.Numeric
    episode_length_info: chex.Numeric

We would then create methods to reset and step this state, while the env state is reset and stepped separately.
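A rough sketch of what those reset and step methods could look like, to make the proposal concrete. This is only an illustration under assumptions: the function names `reset_logger_state` and `step_logger_state` are hypothetical, plain Python scalars stand in for the `chex.Numeric` fields, and `done` is assumed to be `0` or `1`. The arithmetic masking (multiplying by `done` and `1 - done`) is used instead of `if` branches so the same logic should carry over to JAX arrays under `jit`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LoggerState:
    # Running totals for the episode currently in progress.
    episode_returns: float
    episode_lengths: int
    # Totals of the most recently completed episode, kept for logging.
    episode_return_info: float
    episode_length_info: int


def reset_logger_state() -> LoggerState:
    """Hypothetical helper: a fresh logger state with all metrics zeroed."""
    return LoggerState(0.0, 0, 0.0, 0)


def step_logger_state(state: LoggerState, reward: float, done: int) -> LoggerState:
    """Hypothetical helper: update the metrics after one environment step.

    `done` is assumed to be 1 if the episode ended on this step, else 0.
    """
    not_done = 1 - done
    new_return = state.episode_returns + reward
    new_length = state.episode_lengths + 1
    return LoggerState(
        # Running totals are zeroed when the episode ends.
        episode_returns=new_return * not_done,
        episode_lengths=new_length * not_done,
        # The info fields only change when an episode ends.
        episode_return_info=state.episode_return_info * not_done + new_return * done,
        episode_length_info=state.episode_length_info * not_done + new_length * done,
    )


# Usage: the logger state is carried next to the env state, not around it.
logger_state = reset_logger_state()
logger_state = step_logger_state(logger_state, reward=1.0, done=0)
logger_state = step_logger_state(logger_state, reward=2.0, done=1)
```

With this split, the training loop would carry two states side by side (the env state and the logger state), so there is no nested `state.env_state` access, at the cost of threading one more value through every step.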

I have chatted to @OmaymaMahjoub about this and we agree it should be cleaner.

sash-a commented 5 months ago

On reflection, the current method is fine. This is probably over-optimizing for now.