vwxyzjn / ppo-implementation-details

The source code for the blog post The 37 Implementation Details of Proximal Policy Optimization
https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/

Using PPO with RNN #8

Open · anirjoshi opened this issue 8 months ago

anirjoshi commented 8 months ago

I want to use an RNN for RL on a custom environment. My custom environment has variable-sized observations. For example, consider the following environment:

import gymnasium as gym
from gymnasium.spaces import Discrete, MultiDiscrete, Sequence


class ModuloComputationEnv(gym.Env):
    """Environment in which an agent must learn to output mod 2, 3, 4 of the sum of
       the observations seen so far.

    Observations are sequences of integers,
    e.g. (1, 3, 4, 5).

    The action space is a vector of 3 values: the sum of inputs so far %2, %3,
    and %4.

    Rewards are r = -abs(self.ac1-action[0]) - abs(self.ac2-action[1]) - abs(self.ac3-action[2]),
    for all steps.
    """

    def __init__(self, config=None):

        # The input sequence can contain any integer from 0 to 99.
        self.observation_space = Sequence(Discrete(100), seed=2)

        # The action is a vector of 3 values, [%2, %3, %4], of the sum of the input sequence.
        self.action_space = MultiDiscrete([2, 3, 4])

        self.cur_obs = None

        # Episode length counter.
        self.episode_len = 0

        # Running sum modulo 2.
        self.ac1 = 0

        # Running sum modulo 3.
        self.ac2 = 0

        # Running sum modulo 4.
        self.ac3 = 0

    def reset(self, *, seed=None, options=None):
        """Resets the episode and returns the initial observation of the new one."""
        super().reset(seed=seed)

        # Reset the episode length.
        self.episode_len = 0

        # Sample a random sequence from our observation space.
        self.cur_obs = self.observation_space.sample()

        # Sum the initial observation sequence.
        sum_obs = sum(self.cur_obs)

        # Compute %2, %3, and %4 of the running sum.
        self.ac1 = sum_obs % 2
        self.ac2 = sum_obs % 3
        self.ac3 = sum_obs % 4

        # Return initial observation.
        return self.cur_obs, {}

    def step(self, action):
        """Takes a single step in the episode given `action`.

        Returns:
            New observation, reward, terminated-flag, truncated-flag, info-dict (empty).
        """
        # Truncate the episode after 10 steps (a time limit rather than a terminal state).
        self.episode_len += 1
        truncated = self.episode_len >= 10
        terminated = False

        # The reward is the negative distance between the action and the true modulo values.
        reward = abs(self.ac1 - action[0]) + abs(self.ac2 - action[1]) + abs(self.ac3 - action[2])
        reward = -reward

        # Sample a new observation (random sample) and sum it.
        self.cur_obs = self.observation_space.sample()
        sum_obs = sum(self.cur_obs)

        # Update the %2, %3 and %4 values of the running sum.
        self.ac1 = (sum_obs + self.ac1) % 2
        self.ac2 = (sum_obs + self.ac2) % 3
        self.ac3 = (sum_obs + self.ac3) % 4

        return self.cur_obs, reward, terminated, truncated, {}
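
For what it's worth, one common way to feed variable-length observations like these into a recurrent policy is to pad and pack the sequences, run them through an LSTM, and use the final hidden state as the feature vector for the PPO policy (and value) heads. Below is a minimal sketch of that idea, assuming PyTorch; SequenceEncoder, the hyperparameters, and the variable names are illustrative and not anything from this repo.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence


class SequenceEncoder(nn.Module):
    def __init__(self, vocab_size=100, embed_dim=32, hidden_dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # One categorical head per action dimension of MultiDiscrete([2, 3, 4]).
        self.heads = nn.ModuleList([nn.Linear(hidden_dim, n) for n in (2, 3, 4)])

    def forward(self, obs_batch):
        # obs_batch: a list of 1-D LongTensors with different lengths.
        lengths = torch.tensor([len(o) for o in obs_batch])
        padded = pad_sequence(obs_batch, batch_first=True)            # (B, T_max)
        packed = pack_padded_sequence(self.embed(padded), lengths,
                                      batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)                               # h_n: (1, B, H)
        features = h_n.squeeze(0)                                     # (B, H)
        # Per-dimension logits; a PPO agent would build Categorical
        # distributions from these and add a separate value head.
        return [head(features) for head in self.heads]


encoder = SequenceEncoder()
obs = [torch.tensor([1, 3, 4, 5]), torch.tensor([7, 2])]
logits = encoder(obs)  # list of 3 tensors with shapes (2, 2), (2, 3), (2, 4)

Packing the padded batch lets the LSTM ignore padding positions, so each episode's feature vector depends only on the integers actually observed.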