Closed RPegoud closed 1 year ago
Currently, vmapped tabular functions are defined in the global scope:

    v_reset = jit(vmap(
        env.reset,
        out_axes=((0, 0), 0),  # ((env_state), obs)
        axis_name="batch_axis",
    ))

    v_step = jit(vmap(
        env.step,
        in_axes=((0, 0), 0),  # ((env_state), action)
        out_axes=((0, 0), 0, 0, 0),  # ((env_state), obs, reward, done)
        axis_name="batch_axis",
    ))
    ....

And passed to rollout functions:

    def parallel_rollout(
        keys, TIME_STEPS, N_ACTIONS, GRID_SIZE, N_ENV,
        v_policy, v_update, v_step, v_reset
    ):

Instead, they should be defined within env, agent and policy classes, following the pattern:

    class SimpleBandit:
        @staticmethod
        @jit
        def __call__(action, reward, pulls, q_values):
            q_value = q_values[action]
            pull_update = pulls[action] + 1
            q_update = q_value + (1 / pull_update) * (reward - q_value)
            return q_values.at[action].set(q_update), pulls.at[action].set(pull_update)

        @staticmethod
        @jit
        def batch_update(actions, rewards, pulls, q_values):
            return vmap(
                SimpleBandit.__call__,
                in_axes=(0, 0, -1, -1),
                out_axes=-1,
            )(actions, rewards, pulls, q_values)

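
As a rough sketch of how the same pattern could carry over to an environment, the snippet below moves the vmapped `reset` and `step` into the class so that the rollout only receives the class itself. `ToyEnv`, `batch_reset`, `batch_step`, the always-move-right policy and the simplified `parallel_rollout` signature are all hypothetical and only serve as an illustration; they are not the repository's actual environment or API.

```python
import jax.numpy as jnp
from jax import jit, random, vmap


class ToyEnv:
    """Hypothetical 1-D corridor: start at 0, episode ends at GOAL."""

    GOAL = 3

    @staticmethod
    @jit
    def reset(key):
        position = jnp.zeros((), dtype=jnp.int32)
        return (position, key), position  # ((env_state), obs)

    @staticmethod
    @jit
    def step(env_state, action):
        position, key = env_state
        # action 1 moves right, action 0 moves left; stay inside the corridor
        position = jnp.clip(position + 2 * action - 1, 0, ToyEnv.GOAL)
        done = position == ToyEnv.GOAL
        reward = done.astype(jnp.float32)
        return (position, key), position, reward, done

    @staticmethod
    @jit
    def batch_reset(keys):
        # same out_axes layout as the global v_reset
        return vmap(ToyEnv.reset, out_axes=((0, 0), 0))(keys)

    @staticmethod
    @jit
    def batch_step(env_states, actions):
        # same in_axes / out_axes layout as the global v_step
        return vmap(
            ToyEnv.step,
            in_axes=((0, 0), 0),
            out_axes=((0, 0), 0, 0, 0),
        )(env_states, actions)


def parallel_rollout(keys, time_steps, env):
    # the rollout takes the env class instead of v_step / v_reset
    n_env = keys.shape[0]
    env_states, obs = env.batch_reset(keys)
    total_reward = jnp.zeros(n_env)
    for _ in range(time_steps):
        # placeholder policy (assumption): always move right
        actions = jnp.ones(n_env, dtype=jnp.int32)
        env_states, obs, rewards, dones = env.batch_step(env_states, actions)
        total_reward = total_reward + rewards
    return obs, total_reward


keys = random.split(random.PRNGKey(0), 4)
obs, total_reward = parallel_rollout(keys, time_steps=3, env=ToyEnv)
```

With this layout the rollout signature no longer grows with every vmapped function, and the batching axes live next to the function they describe.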