DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License
8.38k stars · 1.61k forks

[Question] CheckpointCallback keep last K #1893

Closed NickLucche closed 2 months ago

NickLucche commented 2 months ago

❓ Question

Hey, thanks a lot for your work maintaining this framework throughout the years!

In order to keep only the last K checkpoints (instead of the whole history from step 0), I have written a custom callback that deletes the old ones. I was wondering, is there a more "compliant" way to do this already built into the framework? If not, would this kind of feature be of any interest to add to CheckpointCallback directly (e.g. a save_top_k argument)?

qgallouedec commented 2 months ago

Thank you for your feedback! No, this feature doesn't exist. Can you share your callback here? I don't think many people need it, but if we get a lot of feedback about it we might integrate this feature.

NickLucche commented 2 months ago

Sure, here's my callback:

import os
from heapq import heappop, heappush

from stable_baselines3.common.callbacks import BaseCallback

class KeepLastKCheckpoints(BaseCallback):
    def __init__(self, k: int, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0):
        super().__init__(verbose)
        self.k = k
        self.save_freq = save_freq
        self.save_path = save_path
        self.name_prefix = name_prefix
        # keep checkpoints in a min-heap, keyed by timestep, so we can easily remove the oldest
        self._ckpt_heap = []

    def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> str:
        # from `CheckpointCallback`
        return os.path.join(self.save_path, f"{self.name_prefix}_{checkpoint_type}{self.num_timesteps}_steps.{extension}")

    def _on_step(self) -> bool:
        # track CheckpointCallback trigger
        if self.n_calls % self.save_freq == 0:
            heappush(self._ckpt_heap, (self.num_timesteps, self._checkpoint_path(extension="zip")))
            if len(self._ckpt_heap) > self.k:
                _, ckpt_name = heappop(self._ckpt_heap)
                os.remove(ckpt_name)

        return super()._on_step()

As you can see it's not very clean, because I had to copy-paste most of the path logic from CheckpointCallback, so I could see this being integrated fairly easily. We don't even need to store the second tuple value `self._checkpoint_path(...)` in the heap; we could just make the timestep an argument to that method and handle the replay buffer the same way (which I don't need in my case). Also, I currently have to place this callback after the CheckpointCallback to make sure they stay in sync.
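For what it's worth, the rotation logic itself can be exercised in isolation with plain files and the standard library only (no SB3 needed). This is a hedged sketch, not the callback above: `rotate_checkpoints` and its parameters are made up for illustration, with empty files standing in for `model.save(path)`, but the heap bookkeeping mirrors `_on_step`:

```python
import os
import tempfile
from heapq import heappop, heappush


def rotate_checkpoints(save_path, k, timesteps, name_prefix="rl_model"):
    """Simulate saving one checkpoint per timestep while keeping only the last k."""
    heap = []  # min-heap keyed by timestep, so heappop yields the oldest checkpoint
    for t in timesteps:
        path = os.path.join(save_path, f"{name_prefix}_{t}_steps.zip")
        open(path, "w").close()  # stand-in for model.save(path)
        heappush(heap, (t, path))
        if len(heap) > k:
            _, oldest = heappop(heap)  # evict the oldest checkpoint on disk
            os.remove(oldest)
    return sorted(os.listdir(save_path))


with tempfile.TemporaryDirectory() as d:
    kept = rotate_checkpoints(d, k=2, timesteps=[100, 200, 300, 400])
    print(kept)  # only the two most recent checkpoints survive
```

Since `num_timesteps` only ever increases, a plain `collections.deque` would work just as well as a heap here; the heap only matters if checkpoints could be registered out of order.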