The line movement = movements[action] should be included as part of the step method, to allow efficient vectorization through vmap.
Therefore, we would have:
while not done:
    state, _ = env_state
    action, key = policy(key, N_ACTIONS, state, q_values)
    env_state, obs, reward, done = env.step(env_state, action)  # movement selection now happens inside step
    q_values = agent.update(state, action, reward, done, obs, q_values)
    n_steps += 1
steps.append(n_steps)
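As a rough sketch of what this could look like (not the original environment code): the movement table, the grid size, the goal cell, and the (position, key) structure of env_state below are all assumptions made for illustration. The point is that once movements[action] is plain array indexing inside step, the whole transition becomes array operations that jax.vmap can batch.

import jax
import jax.numpy as jnp

GRID_SIZE = 5                              # assumed square grid
GOAL = jnp.array([4, 4])                   # assumed goal cell

# Hypothetical per-action movement deltas: up, down, left, right.
movements = jnp.array([[-1, 0], [1, 0], [0, -1], [0, 1]])

def step(env_state, action):
    # The movements[action] lookup now lives inside step, so the whole
    # transition is expressed as array operations that vmap can batch.
    position, key = env_state
    movement = movements[action]
    new_position = jnp.clip(position + movement, 0, GRID_SIZE - 1)
    done = jnp.all(new_position == GOAL)
    reward = jnp.where(done, 1.0, 0.0)
    return (new_position, key), new_position, reward, done

# Vectorize over a batch of environments and actions.
batched_step = jax.vmap(step, in_axes=0)

# Example: step 8 environments in a single call.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
positions = jnp.zeros((8, 2), dtype=jnp.int32)
actions = jnp.zeros(8, dtype=jnp.int32)
env_states, obs, rewards, dones = batched_step((positions, keys), actions)

Every array then carries a leading batch dimension, and one batched_step call replaces a Python loop over environments.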