Closed RPegoud closed 1 year ago
Resolved the original issue by storing the movements table as an attribute of the environment (an instance attribute set in `__init__`, so it is reachable as `self.movements`):
class GridWorld(BaseEnv):
    """Grid-world environment whose actions index a table of 2-D movements.

    The constructor stores its arguments unchanged and builds the
    ``movements`` table once, so an integer action can later be turned into
    a displacement by indexing ``self.movements``.
    """

    def __init__(self, initial_state, goal_state, grid_size) -> None:
        """Store the environment configuration and build the movement table.

        Args:
            initial_state: Stored as ``self.initial_state``. Presumably the
                agent's starting position -- confirm against callers.
            goal_state: Stored as ``self.goal_state``. Presumably the target
                position -- confirm against callers.
            grid_size: Stored as ``self.grid_size``. NOTE(review): looks like
                a per-axis bound of the grid; confirm whether it is the
                number of cells or the largest valid index.
        """
        # Zero-argument super() resolves to the same parent initializer as
        # super(GridWorld, self) without repeating the class name.
        super().__init__()
        self.initial_state = initial_state
        self.goal_state = goal_state
        self.grid_size = grid_size
        # One row per action index; each row is a unit displacement along a
        # single axis. Which row means "up"/"down"/"left"/"right" depends on
        # the axis convention used by the rest of the environment.
        self.movements = jnp.array([[0, 1], [1, 0], [0, -1], [-1, 0]])
The step function was then updated to convert the action to a movement:
@partial(jit, static_argnums=(0,))
def step(self, env_state, action):
    """Advance the environment by one action.

    ``self`` is marked static for ``jit`` (argument 0), so it is treated as
    a compile-time constant rather than a traced value.

    Args:
        env_state: A ``(state, key)`` pair. ``key`` is passed through
            untouched by this method (presumably a PRNG key -- confirm).
        action: Index into ``self.movements`` selecting the displacement
            to apply.

    Returns:
        A tuple ``(env_state, observation, reward, done)`` where
        ``env_state`` is whatever ``self._reset_if_done`` returned and
        ``observation`` is ``self._get_obs`` applied to its first element.
    """
    position, key = env_state
    # Translate the action index into its movement vector.
    movement = self.movements[action]

    # NOTE(review): the upper bound is ``self.grid_size`` itself, i.e. the
    # clip is inclusive. If ``grid_size`` is a cell count rather than a
    # maximum index this allows one step past the last cell -- confirm.
    lower_bound = jnp.array([0, 0])
    moved = jnp.clip(position + movement, lower_bound, self.grid_size)

    # Reward and termination are computed on the moved position, before
    # any reset is applied.
    reward, done = self._get_reward_done(moved)
    next_env_state = self._reset_if_done((moved, key), done)

    # Observe the state held in the returned env state, not ``moved``:
    # ``_reset_if_done`` may have replaced it.
    observation = self._get_obs(next_env_state[0])
    return next_env_state, observation, reward, done
Resolved the original issue by storing the movements table as an attribute of the environment (an instance attribute set in `__init__`, so it is reachable as `self.movements`):
The step function was then updated to convert the action to a movement: