Stable-Baselines-Team / rl-colab-notebooks

Colab notebooks part of the documentation of Stable Baselines reinforcement learning library
https://stable-baselines.readthedocs.io/
MIT License

RNN examples #15

Closed anirjoshi closed 7 months ago

anirjoshi commented 7 months ago

Are there any examples here that use an RNN in the policy? I would like to construct such an example myself, and an existing one would make the implementation easier to follow!

araffin commented 7 months ago

https://sb3-contrib.readthedocs.io/en/master/modules/ppo_recurrent.html#example ?
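
For reference, the linked example boils down to roughly the following (a minimal sketch; see that docs page for the full runnable version):

import numpy as np

from sb3_contrib import RecurrentPPO

# train a PPO agent with an LSTM policy on CartPole
model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
model.learn(5_000)

# at prediction time, the LSTM states and episode-start signals
# have to be passed to predict() explicitly
vec_env = model.get_env()
obs = vec_env.reset()
lstm_states = None
episode_starts = np.ones((vec_env.num_envs,), dtype=bool)
for _ in range(100):
    action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
    obs, rewards, dones, infos = vec_env.step(action)
    episode_starts = dones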

anirjoshi commented 7 months ago

@araffin Thank you for your response. However, is it possible to use this with a custom Gymnasium environment? For example, would it work with the following environment?

import gymnasium as gym
from gymnasium.spaces import Discrete, MultiDiscrete, Sequence


class ModuloComputationEnv(gym.Env):
    """Environment in which an agent must learn to output mod 2, 3, 4 of the sum of
       the observations seen so far.

    Observations are variable-length sequences of integers,
    e.g. (1, 3, 4, 5).

    The action is a vector of 3 values: the running sum of the inputs %2, %3
    and %4.

    Rewards are r = -abs(self.ac1 - action[0]) - abs(self.ac2 - action[1]) - abs(self.ac3 - action[2])
    at every step.
    """

    def __init__(self, config):

        # the input sequence can contain any integer from 0 to 99
        self.observation_space = Sequence(Discrete(100), seed=2)

        #the action is a vector of 3, [%2, %3, %4], of the sum of the input sequence
        self.action_space = MultiDiscrete([2,3,4])

        self.cur_obs = None

        #this variable maintains the episode_length
        self.episode_len = 0

        #this variable maintains %2
        self.ac1 = 0

        #this variable maintains %3
        self.ac2 = 0

        #this variable maintains %4
        self.ac3 = 0

    def reset(self, *, seed=None, options=None):
        """Resets the episode and returns the initial observation of the new one."""
        super().reset(seed=seed)

        # Reset the episode length.
        self.episode_len = 0

        # Sample a random sequence from our observation space.
        self.cur_obs = self.observation_space.sample()

        #take the sum of the initial observation
        sum_obs = sum(self.cur_obs)

        #consider the %2, %3, and %4 of the initial observation
        self.ac1 = sum_obs%2
        self.ac2 = sum_obs%3
        self.ac3 = sum_obs%4

        # Return initial observation.
        return self.cur_obs, {}

    def step(self, action):
        """Takes a single step in the episode given `action`

        Returns:
            New observation, reward, done-flag, info-dict (empty).
        """
        # Set `truncated` flag after 10 steps.
        self.episode_len += 1
        truncated = False
        terminated = self.episode_len >= 10

        # the reward penalizes how far each action component is from the true modulo value
        reward = abs(self.ac1-action[0]) + abs(self.ac2-action[1]) + abs(self.ac3-action[2])
        reward = -reward

        # Set a new observation (random sample).
        self.cur_obs = self.observation_space.sample()

        # recompute the running %2, %3 and %4 values, including the new observation
        self.ac1 = (sum(self.cur_obs) + self.ac1) % 2
        self.ac2 = (sum(self.cur_obs) + self.ac2) % 3
        self.ac3 = (sum(self.cur_obs) + self.ac3) % 4

        return self.cur_obs, reward, terminated, truncated, {}

araffin commented 7 months ago

You don't need to change anything, see https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/230#issuecomment-1908020667 and https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/183#issuecomment-1640028534 and https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/209
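
The gist of those linked comments, as a minimal sketch (`MyCustomEnv` here is a hypothetical placeholder for any custom Gymnasium env whose observation and action spaces SB3 supports):

from sb3_contrib import RecurrentPPO

# MyCustomEnv is hypothetical: any Gymnasium env with SB3-supported
# observation/action spaces can be passed to RecurrentPPO unchanged
env = MyCustomEnv()
model = RecurrentPPO("MlpLstmPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)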

anirjoshi commented 7 months ago

@araffin Thank you for your response, but my observation is a Sequence that does not have a fixed length: at each step, the observation is a variable-length sequence. Is it possible to handle this kind of scenario? From the linked comments, I believe RecurrentPPO is recurrent because it maintains the history across steps, whereas in my environment I do not need the history; the observation itself has variable length. Is this possible? Sorry if this is getting annoying!

araffin commented 7 months ago

but my observation is a Sequence which is not of a fixed length.

Sequence space is not supported by SB3 (https://github.com/DLR-RM/stable-baselines3/issues/1688)

anirjoshi commented 7 months ago

Ok, thank you for the response!