RPegoud / jym

JAX implementation of RL algorithms and vectorized environments
MIT License

Add movement selection to env.step #1

Closed · RPegoud closed this issue 1 year ago

RPegoud commented 1 year ago

Currently, the training loop is implemented as follows:

while not done:
    state, _ = env_state
    action, key = policy(key, N_ACTIONS, state, q_values)
    movement = movements[action]
    env_state, obs, reward, done = env.step(env_state, movement)

    q_values = agent.update(state, action, reward, done, obs, q_values)
    n_steps += 1
steps.append(n_steps)

The line movement = movements[action] should be moved inside the step method to allow efficient vectorization through vmap. The training loop would then become:

while not done:
    state, _ = env_state
    action, key = policy(key, N_ACTIONS, state, q_values)
    env_state, obs, reward, done = env.step(env_state, action)  # movement selection now happens inside env.step

    q_values = agent.update(state, action, reward, done, obs, q_values)
    n_steps += 1
steps.append(n_steps)
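
For context, here is a minimal sketch of how the lookup could live inside step so the whole transition can be vectorized with jax.vmap. The MOVEMENTS table, the position-only state, and the step signature below are illustrative assumptions, not the repository's actual API:

import jax
import jax.numpy as jnp

# Assumed movement table for a 4-action grid world (up, right, down, left),
# purely for illustration.
MOVEMENTS = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])

def step(position, action):
    # The movement lookup happens inside step, so callers only pass the
    # action index and the whole function stays traceable by JAX.
    movement = MOVEMENTS[action]
    return position + movement

# Vectorizing the transition over a batch of parallel environments.
batched_step = jax.vmap(step)

positions = jnp.zeros((8, 2), dtype=jnp.int32)  # 8 environment states
actions = jnp.arange(8) % 4                     # one action per environment
new_positions = batched_step(positions, actions)
print(new_positions.shape)  # (8, 2)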