RPegoud / jym

JAX implementation of RL algorithms and vectorized environments
MIT License

fixed parallel training, added plotting #2

Closed: RPegoud closed this 11 months ago.

RPegoud commented 11 months ago

Resolved the original issue by storing the action-to-movement table as an attribute of the environment, set in `__init__`:

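import jax.numpy as jnp

# BaseEnv comes from the jym repository; its import path is not shown here.
class GridWorld(BaseEnv):
    def __init__(self, initial_state, goal_state, grid_size) -> None:
        super().__init__()

        self.initial_state = initial_state
        self.goal_state = goal_state
        self.grid_size = grid_size
        # Row i is the (row, col) offset applied for action i:
        # 0 -> (0, 1), 1 -> (1, 0), 2 -> (0, -1), 3 -> (-1, 0)
        self.movements = jnp.array([[0, 1], [1, 0], [0, -1], [-1, 0]])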
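Storing the table as an array means an integer action selects its offset by plain indexing. A minimal sketch of that lookup on its own, using an illustrative module-level `movements` array and a hypothetical `to_movement` helper (neither is part of jym), shows that the indexing works inside `jax.jit` and that `jax.vmap` applies it to one action per environment:

```python
import jax
import jax.numpy as jnp

# Same table as GridWorld.movements: row i is the (row, col) offset for action i.
movements = jnp.array([[0, 1], [1, 0], [0, -1], [-1, 0]])


@jax.jit
def to_movement(action):
    # Indexing with a traced integer compiles to a gather, so this works under jit.
    return movements[action]


single = to_movement(2)
# vmap maps the same lookup over a batch of actions, one per environment.
batch = jax.vmap(to_movement)(jnp.array([0, 3, 1]))
```

Here `single` is `[0, -1]` and `batch` is `[[0, 1], [-1, 0], [1, 0]]`. A Python-side mapping such as a dict keyed by action would not work in this position, because under `jit` the action is a tracer, and JAX arrays and tracers are not hashable.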

The step function was then updated to convert the action to a movement:

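from functools import partial

import jax.numpy as jnp
from jax import jit

class GridWorld(BaseEnv):
    # ... __init__ as above; _get_reward_done, _reset_if_done, _get_obs defined elsewhere ...

    # self is a static argument, so self.movements is a compile-time constant.
    @partial(jit, static_argnums=(0,))
    def step(self, env_state, action):
        state, key = env_state
        # Map the integer action to its (row, col) offset.
        movement = self.movements[action]
        # Upper bound is grid_size, inclusive; use grid_size - 1 if grid_size counts cells.
        new_state = jnp.clip(state + movement, jnp.array([0, 0]), self.grid_size)
        reward, done = self._get_reward_done(new_state)

        env_state = new_state, key
        env_state = self._reset_if_done(env_state, done)
        new_state = env_state[0]

        return env_state, self._get_obs(new_state), reward, done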
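The following sketch mirrors the updated `step` in a self-contained class so it runs without the rest of the repository. `TinyGridWorld` and the bodies of its `_get_reward_done` and `_reset_if_done` helpers are hypothetical (the comment does not show the real helpers), the PRNG key carried in `env_state` is dropped, and `grid_size` is assumed to count cells, so the clip bound is `grid_size - 1`. It keeps the `static_argnums=0` pattern on a method and adds a `jax.vmap` call that steps several environments at once, in the spirit of the parallel training this PR targets.

```python
from functools import partial

import jax
import jax.numpy as jnp


class TinyGridWorld:
    """Hypothetical stand-in for GridWorld; reward/done/reset rules are assumptions."""

    def __init__(self, initial_state, goal_state, grid_size):
        self.initial_state = initial_state
        self.goal_state = goal_state
        self.grid_size = grid_size  # assumed: number of cells per axis
        self.movements = jnp.array([[0, 1], [1, 0], [0, -1], [-1, 0]])

    def _get_reward_done(self, state):
        # Assumed rule: reward 1.0 and episode end on reaching the goal cell.
        done = jnp.all(state == self.goal_state)
        return jnp.where(done, 1.0, 0.0), done

    def _reset_if_done(self, state, done):
        # Branch-free reset, so the method stays traceable under jit and vmap.
        return jnp.where(done, self.initial_state, state)

    # self is static: its arrays are embedded as constants in the compiled step.
    @partial(jax.jit, static_argnums=0)
    def step(self, state, action):
        movement = self.movements[action]
        # Clip to valid indices 0..grid_size - 1 (grid_size counts cells here).
        new_state = jnp.clip(state + movement, 0, self.grid_size - 1)
        reward, done = self._get_reward_done(new_state)
        new_state = self._reset_if_done(new_state, done)
        return new_state, reward, done


env = TinyGridWorld(jnp.array([0, 0]), jnp.array([3, 3]), jnp.array([4, 4]))

# vmap steps a batch of environments in lockstep with one compiled call.
states = jnp.array([[3, 2], [0, 0], [2, 3]])
actions = jnp.array([0, 0, 1])
new_states, rewards, dones = jax.vmap(env.step)(states, actions)
```

The first and third environments reach the goal and are reset to `[0, 0]` with reward 1.0, while the second moves to `[0, 1]`. Because `self` is a static argument, JAX hashes the instance (by identity, for a plain class) and reuses the compiled function on later calls; mutating an attribute such as `env.goal_state` after the first call would therefore not trigger recompilation, so the environment is best treated as immutable once jitted.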