RLE-Foundation / RLeXplore

RLeXplore provides stable baselines of exploration methods in reinforcement learning, such as the intrinsic curiosity module (ICM), random network distillation (RND), and rewarding impact-driven exploration (RIDE).
https://docs.rllte.dev/
MIT License

Correct usage with SB3 / Callbacks? #3


emrul commented 1 year ago

Hi, this looks like a really interesting set of algorithms. I wanted to try some out using the SB3-zoo and was hoping for a plug-and-play approach. I wondered if I could integrate rlexplore via callbacks, so I came up with the following:

from stable_baselines3.common.callbacks import BaseCallback
from rlexplore import REVD
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm

class RLeXploreCallback(BaseCallback):
    def __init__(self):
        super().__init__()
        self.explorer = None
        self.buffer = None

    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
        super().init_callback(model)
        env = self.training_env
        # Build the intrinsic-reward module from the (single-env) spaces of the vectorized env.
        self.explorer = REVD(obs_shape=env.observation_space.shape,
                             action_shape=env.action_space.shape,
                             device=model.device,
                             latent_dim=128, beta=1e-2, kappa=1e-5)

        # Pick whichever buffer the algorithm fills, so its rewards can be augmented later.
        if isinstance(self.model, OnPolicyAlgorithm):
            self.buffer = self.model.rollout_buffer
        elif isinstance(self.model, OffPolicyAlgorithm):
            self.buffer = self.model.replay_buffer

    def _on_rollout_end(self) -> None:
        # Compute intrinsic rewards for the freshly collected rollout and add them
        # to the extrinsic rewards in place, before the policy update uses the buffer.
        intrinsic_rewards = self.explorer.compute_irs(
            rollouts={'observations': self.buffer.observations},
            time_steps=self.num_timesteps,
            k=3)
        self.buffer.rewards += intrinsic_rewards[:, :, 0]

    def _on_step(self) -> bool:
        # TODO maybe log to TensorBoard?
        return True

Then I include it in my list of callbacks and it seems to run. However, I'm still poking around without fully understanding what I'm doing (dangerous!), so does the above look correct? If it is, maybe it could be added as an example for others.

Second question: is time_steps=self.num_timesteps the right value to pass here?

Third question: the sample in the examples directory uses the rollout_buffer, but is it valid to use this approach with off-policy algorithms like DQN, switching to the replay_buffer instead?
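
For completeness, the wiring I have in mind looks roughly like this (PPO and Pendulum-v1 are just placeholders, not a recommendation):

from stable_baselines3 import PPO

# The callback only needs to be passed to learn(); it hooks into rollout
# collection through the _on_rollout_end defined above.
model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=100_000, callback=RLeXploreCallback())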

yuanmingqi commented 1 year ago

Hello, the repository is still under development, and any attempt is welcome. You can make a PR to add more examples.

  1. REVD only supports on-policy algorithms (see the sketch below);
  2. time_steps=self.num_timesteps is correct.
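
For instance (an untested sketch, not an official example), init_callback could reject off-policy algorithms up front and keep only the rollout_buffer path, reusing the REVD arguments from the snippet above:

    # Inside RLeXploreCallback
    def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
        super().init_callback(model)
        # REVD is on-policy only, so refuse anything without a rollout buffer.
        if not isinstance(model, OnPolicyAlgorithm):
            raise TypeError("REVD only supports on-policy algorithms (e.g. PPO, A2C)")
        env = self.training_env
        self.explorer = REVD(obs_shape=env.observation_space.shape,
                             action_shape=env.action_space.shape,
                             device=model.device,
                             latent_dim=128, beta=1e-2, kappa=1e-5)
        self.buffer = model.rollout_buffer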