RPegoud / jym

JAX implementation of RL algorithms and vectorized environments
MIT License
32 stars 2 forks source link

11-refactor-tabular-vmapped-functions #12

Closed RPegoud closed 1 year ago

RPegoud commented 1 year ago

Refactored the vmapped versions of env.step, env.reset, agent.update, and policy.call so that they are defined as class methods, e.g.:

# "Before" pattern quoted in this issue: the jitted/vmapped update is built
# outside the class and is bound to one concrete `agent` instance, so it has
# to be reconstructed for every agent rather than living on the class.
v_update = jit(vmap(
                   agent.update,
                   #  state, action, reward, done, next_state, q_values
                   in_axes=(0, 0, 0, 0, 0, -1),
                   out_axes=-1,
                   axis_name="batch_axis"
                   ))

Becomes:

@partial(jit, static_argnums=(0))
    def update(self, state, action, reward, done, next_state, q_values):
        next_q_values = q_values[tuple(next_state)]
        target = q_values[tuple(jnp.append(state, action))]
        target += self.learning_rate * (
            reward
            + self.discount
            * jnp.sum(next_q_values * self.softmax_prob_distr(next_q_values))
            - target
        )
        return q_values.at[tuple(jnp.append(state, action))].set(target)

    @partial(jit, static_argnums=(0,))
    def batch_update(self, state, action, reward, done, next_state, q_values):
        """Vectorised (batched) version of ``update``.

        The per-transition arguments are batched along axis 0, while
        ``q_values`` carries the batch as its LAST axis (``in_axes=-1``);
        the output keeps that same convention (``out_axes=-1``).
        """
        return vmap(
            # Use type(self) rather than a hard-coded class name so that
            # subclasses overriding `update` are dispatched correctly.
            type(self).update,
            #  self, state, action, reward, done, next_state, q_values
            in_axes=(None, 0, 0, 0, 0, 0, -1),
            # return the batch dimension as the last dimension of the output
            out_axes=-1,
            axis_name="batch_axis",
        )(self, state, action, reward, done, next_state, q_values)