DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Feature Request] Recompute the advantage of a minibatch in ppo #445

Closed · yangysc closed this issue 3 years ago

yangysc commented 3 years ago

🚀 Feature

According to this paper, recomputing the advantages can improve PPO performance. The corresponding feature is provided by the tianshou library:

https://github.com/thu-ml/tianshou/blob/655d5fb14fe85ea9da86b441456286fa1f078384/tianshou/policy/modelfree/ppo.py#L107

But I don't know how to add this to SB3. Some hints on how to do that would be very helpful.

Thanks!

Motivation

I am comparing stable-baselines3, tianshou, and RLlib to see which gives the best PPO performance.

Pitch

Recompute the advantages during PPO training (e.g. at every epoch), instead of computing them only once after collecting the rollout.
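
Ideally this would be exposed as an opt-in flag. A hypothetical sketch of the interface (the recompute_advantage argument below does not exist in SB3; the name is made up here only to illustrate the proposal):

from stable_baselines3 import PPO

# Hypothetical flag (not part of SB3): recompute returns/advantages at the start
# of every training epoch instead of reusing the ones cached during rollout collection.
model = PPO("MlpPolicy", "CartPole-v1", n_epochs=10, recompute_advantage=True)
model.learn(total_timesteps=100_000)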


Miffyli commented 3 years ago

Hmm indeed it seems to be a beneficial and simple trick to add, but is not part of the original PPO (not a "baseline"). Granted, we have some other small modifications in the PPO compared to the original OAI baselines. @araffin what would be your take? I am up for such a feature (not too intrusive, seems beneficial, has proper experiments behind it).

yangysc commented 3 years ago

Hmm indeed it seems to be a beneficial and simple trick to add, but is not part of the original PPO (not a "baseline"). Granted, we have some other small modifications in the PPO compared to the original OAI baselines. @araffin what would be your take? I am up for such a feature (not too intrusive, seems beneficial, has proper experiments behind it).

Thanks for your support. Yes, it does not belong to the original PPO, but maybe we can leave it as an option and let more users give feedback. I would like to benchmark the difference on my own task if it gets added.

araffin commented 3 years ago

Hello, if you want to try it out, you mostly need to avoid overwriting the variables here: https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/buffers.py#L441 and recompute the values (using the current value function) plus the advantages (using compute_returns_and_advantage()) at every epoch.

As for including that feature in SB3, I'm afraid it would require hacking the code in several places...

yangysc commented 3 years ago

compute_returns_and_advantage

Hello, many thanks for your help. The second parameter of compute_returns_and_advantage() is the dones variable. How can I get it from the buffer?

Thanks again.

araffin commented 3 years ago

The last episode starts are saved: https://github.com/DLR-RM/stable-baselines3/blob/df6f9de8f46509dad47e6d2e5620aa993b0fc883/stable_baselines3/common/on_policy_algorithm.py#L200 and the last values can be recomputed using the last observation.

Please give it a try, and then you can post a link to your code in case others want to check it out ;)
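
Putting the two comments above together, the per-epoch recompute inside PPO.train() could look roughly like this (just a sketch following the names linked above, assuming ppo.py has import torch as th and from stable_baselines3.common.utils import obs_as_tensor available; not a tested patch):

# At the start of every training epoch:
with th.no_grad():
    # Value of the last observation, predicted by the current value function
    _, last_values, _ = self.policy(obs_as_tensor(self._last_obs, self.device))
# Recompute returns and GAE advantages in the rollout buffer
self.rollout_buffer.compute_returns_and_advantage(
    last_values=last_values, dones=self._last_episode_starts
)
# For full consistency, the cached self.rollout_buffer.values would also need to be
# refreshed with the current value function before this call.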

yangysc commented 3 years ago

Solved :) Thanks!

davidsblom commented 3 years ago

@yangysc does it help to recompute the advantage?

yangysc commented 3 years ago

@yangysc does it help to recompute the advantage?

I did as araffin kindly suggested. Here is the result for CartPole. It seems recomputing is slightly better in this case, although the FPS drops.

[Plot: recompute_adv_cartpole (CartPole learning curves with and without advantage recomputation)]

araffin commented 3 years ago

Thanks for sharing. Btw, are you recomputing it after each epoch or after each gradient step? And is there a public link (in case someone else is interested in the future)?

yangysc commented 3 years ago

Hello, I recompute it after each epoch, consistent with the tianshou library: https://github.com/thu-ml/tianshou/blob/655d5fb14fe85ea9da86b441456286fa1f078384/tianshou/policy/modelfree/ppo.py#L107

I pasted the main modification below. Hopefully you can help check whether there are any potential problems.

# ppo.py
    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer

            # Recompute returns and advantages before each epoch
            with th.no_grad():
                _, last_values, _ = self.policy(obs_as_tensor(self._last_obs, self.device))
                self.rollout_buffer.compute_returns_and_advantage(last_values, dones=self._last_episode_starts)

            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                # TODO: investigate why there is no issue with the gradient
                # if that line is commented (as in SAC)
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                values = values.flatten()

                # Normalize advantage
                advantages = rollout_data.advantages

To support multiple envs, I did what you suggested before and avoided overwriting the variables (https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/buffers.py#L441), reshaping them whenever sampling. We don't have to do this with only one env, but reshaping at every sampling step heavily slows down the learning process... Do you have a good solution for this? I sketched one possible workaround after the snippet below, but maybe there is something better.

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
        # Flatten from (buffer_size, n_envs, ...) to (buffer_size * n_envs, ...) at sampling time,
        # keeping the original arrays intact for compute_returns_and_advantage()
        data = (
            self.swap_and_flatten(self.observations)[batch_inds],
            self.swap_and_flatten(self.actions)[batch_inds],
            self.swap_and_flatten(self.values)[batch_inds].flatten(),
            self.swap_and_flatten(self.log_probs)[batch_inds].flatten(),
            self.swap_and_flatten(self.advantages)[batch_inds].flatten(),
            self.swap_and_flatten(self.returns)[batch_inds].flatten(),
            # self.swap_and_flatten(self.to_torch(self.rewards[batch_inds]).flatten()),
            # self.swap_and_flatten(self.to_torch(self.episode_starts[batch_inds]).flatten()),
            # batch_inds
        )
        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
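
One way to avoid the repeated swap_and_flatten per minibatch might be to flatten once per epoch into copies, without overwriting the original (buffer_size, n_envs, ...) arrays. A rough, untested sketch of what get() could look like then (it reuses the imports and helpers already in buffers.py: np, Optional, Generator, RolloutBufferSamples, swap_and_flatten, to_torch; the local flat dict is just an illustrative name):

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)

        # Flatten once per call (i.e. once per epoch) into copies,
        # leaving the original (buffer_size, n_envs, ...) arrays untouched
        # so compute_returns_and_advantage() can still be called between epochs
        flat = {
            name: self.swap_and_flatten(self.__dict__[name])
            for name in ["observations", "actions", "values", "log_probs", "advantages", "returns"]
        }

        if batch_size is None:
            # Return everything, don't create minibatches
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            batch_inds = indices[start_idx : start_idx + batch_size]
            data = (
                flat["observations"][batch_inds],
                flat["actions"][batch_inds],
                flat["values"][batch_inds].flatten(),
                flat["log_probs"][batch_inds].flatten(),
                flat["advantages"][batch_inds].flatten(),
                flat["returns"][batch_inds].flatten(),
            )
            yield RolloutBufferSamples(*tuple(map(self.to_torch, data)))
            start_idx += batch_size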

I also tried recomputing it after sampling a batch of observations yesterday (i.e. at each gradient step): I returned the sample indices together with the sampled data, but unfortunately I didn't manage to make it work :<

araffin commented 3 years ago

I pasted the main modification below. Hopefully you can help check if there is any potential problems.

I see that you are re-computing the last values but not updating the values in the rollout buffer, is that intended?

yangysc commented 3 years ago

I see that you are re-computing the last values but not updating the values in the rollout buffer, is that intended?

Fixed this.

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer

            # Recompute advantages with the current value function
            with th.no_grad():
                _, last_values, _ = self.policy(obs_as_tensor(self._last_obs, self.device))
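                # Shape check (worth double-checking): rollout_buffer.observations has shape
                # (buffer_size, n_envs, *obs_shape), so with n_envs > 1 it should be reshaped
                # to (buffer_size * n_envs, *obs_shape) before the forward pass, and new_values
                # reshaped back to (buffer_size, n_envs) to match rollout_buffer.values.
                # When training on GPU, .cpu() is also needed before .numpy().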
                _, new_values, _ = self.policy(obs_as_tensor(self.rollout_buffer.observations, self.device))
                self.rollout_buffer.values = new_values.detach().numpy()
                self.rollout_buffer.compute_returns_and_advantage(last_values, dones=self._last_episode_starts)

            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                # TODO: investigate why there is no issue with the gradient
                # if that line is commented (as in SAC)
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                values = values.flatten()

                # indices = rollout_data.indices
                # reward = self.rollout_buffer.rewards[indices].flatten()
                #
                # episode_starts = rollout_data.episode_starts
                #
                # # indices = rollout_data[6]  # this is the pos we append indice to DictRolloutBufferSamples, see L735 of buffers.py
                # # values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                # with th.no_grad():
                #     # Recompute value for the last timestep
                #     _, last_values, _ = self.policy(obs_as_tensor(self._last_obs, self.device))
                # returns, advantages = compute_returns_and_advantage(last_values=last_values,
                #                                            rewards=reward,
                #                                            values=values.detach(),
                #                                            episode_starts=episode_starts,
                #                                            gamma=self.gamma,
                #                                            gae_lambda=self.gae_lambda,
                #                                            buffer_size=self.batch_size,
                #                                            dones=self._last_episode_starts)
                # Normalize advantage
                advantages = rollout_data.advantages

[Plot: CartPole results with and without advantage recomputation (return, FPS, entropy loss)]

It seems recomputing is slightly better, at the cost of a lower FPS. The entropy loss also seems to be much lower.

araffin commented 3 years ago

It seems recomputing is slightly better, at the cost of a lower FPS. The entropy loss also seems to be much lower.

Usually CartPole is nice for debugging and checking that you did not break anything, but overall I would go for a more challenging environment like HalfCheetahBulletEnv-v0 (or AntBullet, Hopper, and the harder Walker2DBulletEnv-v0) to see whether any improvement can be achieved. I would recommend using the RL Zoo for such a comparison (the learning curve for normal PPO is already there).
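
For a quick sanity check outside the zoo, a minimal script could look like this (a sketch only; for a fair comparison you would want the zoo's tuned hyperparameters and several seeds, and pybullet must be installed):

import gym
import pybullet_envs  # noqa: F401  # registers the *BulletEnv-v0 environments

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

env = gym.make("HalfCheetahBulletEnv-v0")
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=1_000_000)

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"mean_reward={mean_reward:.1f} +/- {std_reward:.1f}")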

yangysc commented 3 years ago

I agree. If the code above is fine, I would like to compare them later

araffin commented 3 years ago

I agree. If the code above is fine, I would like to compare them later

Looks ok but I would double check the shapes

yangysc commented 3 years ago

I agree. If the code above is fine, I would like to compare them later

Looks ok but I would double check the shapes

Cool, thanks again :)

Karlheinzniebuhr commented 1 year ago

I agree. If the code above is fine, I would like to compare them later

Hi! Were you able to test your implementation on more environments? I would love to know how it performed.