RPegoud / jym

JAX implementation of RL algorithms and vectorized environments
MIT License

Refactor tabular vmapped functions #11

Closed RPegoud closed 10 months ago

RPegoud commented 10 months ago

Currently, the vmapped tabular functions are defined in the global scope:

v_reset = jit(vmap(
    env.reset,
    out_axes=((0, 0), 0),  # ((env_state), obs)
    axis_name="batch_axis",
))

v_step = jit(vmap(
    env.step,
    in_axes=((0, 0), 0),  # ((env_state), action)
    out_axes=((0, 0), 0, 0, 0),  # ((env_state), obs, reward, done)
    axis_name="batch_axis",
))
....

And passed to rollout functions:

def parallel_rollout(
    keys, TIME_STEPS, N_ACTIONS, GRID_SIZE, N_ENV, v_policy, v_update, v_step, v_reset
):

Instead, they should be defined inside the env, agent, and policy classes, following this pattern:

from jax import jit, vmap


class SimpleBandit:
    @staticmethod
    @jit
    def __call__(action, reward, pulls, q_values):
        # Incremental sample-average update for the pulled arm
        q_value = q_values[action]
        pull_update = pulls[action] + 1
        q_update = q_value + (1 / pull_update) * (reward - q_value)
        return q_values.at[action].set(q_update), pulls.at[action].set(pull_update)

    @staticmethod
    @jit
    def batch_update(actions, rewards, pulls, q_values):
        # actions and rewards are batched on axis 0; pulls and q_values
        # (and both outputs) carry the batch on their last axis
        return vmap(
            SimpleBandit.__call__,
            in_axes=(0, 0, -1, -1),
            out_axes=-1,
        )(actions, rewards, pulls, q_values)
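
Applied to an environment, the same pattern might look like the sketch below. `ToyGridWorld`, its `GRID_SIZE`, its dynamics, and the `batch_reset` / `batch_step` method names are all hypothetical and only serve as an illustration; they are not the repository's actual environment. The axis specifications mirror the global `v_reset` and `v_step` definitions above.

```python
import jax.numpy as jnp
from jax import jit, random, vmap


class ToyGridWorld:
    """Hypothetical environment, used only to illustrate the pattern."""

    GRID_SIZE = 4  # assumed square grid, positions in [0, GRID_SIZE - 1]

    @staticmethod
    @jit
    def reset(key):
        # Returns ((env_state), obs), with env_state = (position, key)
        key, subkey = random.split(key)
        position = random.randint(subkey, (2,), 0, ToyGridWorld.GRID_SIZE)
        return (position, key), position

    @staticmethod
    @jit
    def step(env_state, action):
        # Returns ((env_state), obs, reward, done)
        position, key = env_state
        moves = jnp.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
        goal = ToyGridWorld.GRID_SIZE - 1
        new_position = jnp.clip(position + moves[action], 0, goal)
        done = jnp.all(new_position == goal)
        reward = jnp.where(done, 1.0, 0.0)
        return (new_position, key), new_position, reward, done

    @staticmethod
    @jit
    def batch_reset(keys):
        # Batched counterpart of the global v_reset
        return vmap(ToyGridWorld.reset, out_axes=((0, 0), 0))(keys)

    @staticmethod
    @jit
    def batch_step(env_states, actions):
        # Batched counterpart of the global v_step
        return vmap(
            ToyGridWorld.step,
            in_axes=((0, 0), 0),
            out_axes=((0, 0), 0, 0, 0),
        )(env_states, actions)


# Usage: one key per environment, every environment takes action 0
N_ENV = 8
keys = random.split(random.PRNGKey(0), N_ENV)
env_states, obs = ToyGridWorld.batch_reset(keys)
actions = jnp.zeros(N_ENV, dtype=jnp.int32)
env_states, obs, rewards, dones = ToyGridWorld.batch_step(env_states, actions)
```

With the batched methods living on the class, a rollout function could presumably take the env, agent, and policy objects instead of separate `v_step`, `v_reset`, `v_policy`, and `v_update` arguments.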