Closed RPegoud closed 1 year ago
Refactored the vmap versions of env.step, env.reset, agent.update and policy.call to be defined as class methods, e.g.:
```python
v_update = jit(
    vmap(
        agent.update,
        # state, action, reward, done, next_state, q_values
        in_axes=(0, 0, 0, 0, 0, -1),
        out_axes=-1,
        axis_name="batch_axis",
    )
)
```
Becomes:
```python
@partial(jit, static_argnums=(0,))
def update(self, state, action, reward, done, next_state, q_values):
    next_q_values = q_values[tuple(next_state)]
    target = q_values[tuple(jnp.append(state, action))]
    target += self.learning_rate * (
        reward
        + self.discount
        * jnp.sum(next_q_values * self.softmax_prob_distr(next_q_values))
        - target
    )
    return q_values.at[tuple(jnp.append(state, action))].set(target)

@partial(jit, static_argnums=(0,))
def batch_update(self, state, action, reward, done, next_state, q_values):
    return vmap(
        Expected_Sarsa.update,
        # self, state, action, reward, done, next_state, q_values
        in_axes=(None, 0, 0, 0, 0, 0, -1),
        # return the batch dimension as the last dimension of the output
        out_axes=-1,
        axis_name="batch_axis",
    )(self, state, action, reward, done, next_state, q_values)
```
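The key idea is passing `self` through the vmapped call with `in_axes=None`, so the instance is broadcast while the data arguments are batched. A minimal self-contained sketch of this pattern (the `Counter` class and its names are illustrative, not from the repo):

```python
from functools import partial

import jax.numpy as jnp
from jax import jit, vmap


class Counter:
    """Toy class showing the pattern: vmap over the unbound method,
    with self broadcast via in_axes=None and marked static for jit."""

    def __init__(self, step_size):
        self.step_size = step_size

    @partial(jit, static_argnums=(0,))
    def update(self, value):
        # per-example update rule
        return value + self.step_size

    @partial(jit, static_argnums=(0,))
    def batch_update(self, values):
        # self is not batched (in_axes=None); values are batched along axis 0
        return vmap(Counter.update, in_axes=(None, 0))(self, values)


counter = Counter(step_size=2.0)
out = counter.batch_update(jnp.array([1.0, 2.0, 3.0]))
print(out)  # [3. 4. 5.]
```

Marking `self` as a static argument works here because the instance is hashable and its fields do not change between calls; otherwise `self` would need to be a pytree-compatible argument instead.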