instadeepai / jumanji

🕹️ A diverse suite of scalable reinforcement learning environments in JAX
https://instadeepai.github.io/jumanji
Apache License 2.0

Add dtype choice in step type/functions #256

Closed thomashirtz closed 1 week ago

thomashirtz commented 2 weeks ago

Is your feature request related to a problem? Please describe

In a personal project, I had to be very efficient with memory management, and my rewards were taking a lot of space. In my case, I needed to change the reward and the discount to float16 instead of float32, and I had to copy over the types file to make the modification locally. However, I feel like some use cases may need this flexibility.
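
For illustration, a minimal sketch (plain JAX, not jumanji code) of the kind of saving involved, assuming rewards are stored for many steps, e.g. in a rollout buffer:

import jax.numpy as jnp

# Illustration only: float16 halves the footprint of stored rewards.
num_steps = 1_000_000
rewards_fp32 = jnp.zeros((num_steps,), dtype=jnp.float32)
rewards_fp16 = jnp.zeros((num_steps,), dtype=jnp.float16)
print(rewards_fp32.nbytes, rewards_fp16.nbytes)  # 4000000 2000000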

Describe the solution you'd like

Add an extra parameter to the step functions to specify the dtype (example with one of them):

def truncation(
    reward: Array,
    observation: Observation,
    discount: Optional[Array] = None,
    extras: Optional[Dict] = None,
    shape: Union[int, Sequence[int]] = (),
    dtype: jnp.dtype = jnp.float32,
) -> TimeStep:
    """Returns a `TimeStep` with `step_type` set to `StepType.LAST`.

    Args:
        reward: array.
        observation: array or tree of arrays.
        discount: array.
        extras: environment metric(s) or information returned by the environment but
            not observed by the agent (hence not in the observation). For example, it
            could be whether an invalid action was taken. In most environments, extras
            is None.
        shape: optional parameter to specify the shape of the rewards and discounts.
            Allows multi-agent environment compatibility. Defaults to () for
            scalar reward and discount.
        dtype: dtype of the default discount created when `discount` is None.
            Defaults to jnp.float32.
    Returns:
        TimeStep identified as the truncation of an episode.
    """
    discount = discount if discount is not None else jnp.ones(shape, dtype=dtype)
    extras = extras or {}
    return TimeStep(
        step_type=StepType.LAST,
        reward=reward,
        discount=discount,
        observation=observation,
        extras=extras,
    )
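
For illustration, a hedged usage sketch assuming the modified `truncation` above (the `dummy_observation` name is just for the example; the default stays `jnp.float32`, so existing callers are unaffected):

import jax.numpy as jnp

# Hypothetical usage of the proposed `dtype` argument.
dummy_observation = jnp.zeros((4,))
timestep = truncation(
    reward=jnp.asarray(1.0, dtype=jnp.float16),
    observation=dummy_observation,
    dtype=jnp.float16,  # the default discount is then created as float16
)
assert timestep.discount.dtype == jnp.float16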

I would be happy to do the PR.



sash-a commented 2 weeks ago

I think that would be a nice addition, happy to review it :smile:

sash-a commented 2 weeks ago

To be honest, I think the discounts should probably be booleans while they are stored in the timestep because, for me, they just indicate end of episode. But I think this would add nice flexibility.

thomashirtz commented 2 weeks ago

To be honest, I think the discounts should probably be booleans while they are stored in the timestep because, for me, they just indicate end of episode. But I think this would add nice flexibility.

I'm fine with both, as long as it doesn't take too much space. Should I go with an argument that defaults to boolean, or just boolean?

sash-a commented 2 weeks ago

My only issue is that this strays from the original dm_env API, where the discount is a float so it can represent both the RL discount (gamma) and done.
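
For instance (a sketch only, not dm_env or jumanji code), a single float discount folds "done" into the usual bootstrapped target:

import jax.numpy as jnp

# Sketch: discount = 0.0 at a terminal step masks the bootstrap term, and an
# environment may also fold extra discounting into the same float field.
gamma = 0.99
reward = jnp.asarray(1.0)
next_value = jnp.asarray(0.5)
discount = jnp.asarray(0.0)  # terminal step; 1.0 otherwise
target = reward + gamma * discount * next_value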

Let's definitely add it as an argument, but for the default I'm not sure whether boolean or float32 is best. @clement-bonnet, any thoughts on this?

clement-bonnet commented 2 weeks ago

To my knowledge, having the discount as a float is more common than as a boolean, for the reasons you mentioned @sash-a. I would keep it a float unless there are strong reasons to do otherwise :)

sash-a commented 2 weeks ago

Great, then if you could add the argument with a default of float32, I'm happy to accept the PR.

thomashirtz commented 2 weeks ago

Great, then if you could add the argument with a default of float32, I'm happy to accept the PR.

The PR is available for review :)