Closed NickLucche closed 2 months ago
Thank you for your feedback! No, this feature doesn't exist. Can you share your callback here? I don't think many people need it, but if we get a lot of feedback about it we might integrate this feature.
Sure, here's my callback:
import os
from heapq import heappop, heappush

from stable_baselines3.common.callbacks import BaseCallback


class KeepLastKCheckpoints(BaseCallback):
    def __init__(self, k: int, save_freq: int, save_path: str, name_prefix: str = "rl_model", verbose: int = 0):
        super().__init__(verbose)
        self.k = k
        self.save_freq = save_freq
        self.save_path = save_path
        self.name_prefix = name_prefix
        # keep checkpoints in a min-heap, keyed by timestep, so we can easily remove the oldest
        self._ckpt_heap = []

    def _checkpoint_path(self, checkpoint_type: str = "", extension: str = "") -> str:
        # copied from `CheckpointCallback`
        return os.path.join(self.save_path, f"{self.name_prefix}_{checkpoint_type}{self.num_timesteps}_steps.{extension}")

    def _on_step(self) -> bool:
        # mirror the CheckpointCallback trigger condition
        if self.n_calls % self.save_freq == 0:
            heappush(self._ckpt_heap, (self.num_timesteps, self._checkpoint_path(extension="zip")))
            if len(self._ckpt_heap) > self.k:
                # pop the entry with the smallest timestep, i.e. the oldest checkpoint
                _, ckpt_name = heappop(self._ckpt_heap)
                os.remove(ckpt_name)
        return super()._on_step()
As you can see, it's not very clean because I had to copy-paste most of the logic from `CheckpointCallback`, so I could see this being integrated fairly easily.
We don't even need to store the second tuple value, `self._checkpoint_path(..)`, in the heap: we can just make the timestep an argument to it and handle the replay buffer the same way (which I don't need in my case).
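The simplification described above could look roughly like the sketch below: the path helper takes the timestep as an explicit argument, so the heap only has to hold timesteps and paths are rebuilt on demand. `checkpoint_path` and `LastKTracker` are hypothetical names used for illustration, not part of Stable-Baselines3, and the file name pattern simply mirrors the one in the callback above.

```python
import os
from heapq import heappop, heappush
from typing import List


def checkpoint_path(save_path: str, name_prefix: str, timestep: int,
                    checkpoint_type: str = "", extension: str = "zip") -> str:
    # Hypothetical helper: same naming scheme as the callback above,
    # but the timestep is passed in instead of read from the callback state.
    return os.path.join(save_path, f"{name_prefix}_{checkpoint_type}{timestep}_steps.{extension}")


class LastKTracker:
    """Hypothetical helper that remembers only the timesteps of the last k checkpoints."""

    def __init__(self, k: int):
        self.k = k
        self._heap: List[int] = []  # min-heap of timesteps, oldest on top

    def push(self, timestep: int) -> List[int]:
        """Record a new checkpoint and return the timesteps whose files should be deleted."""
        heappush(self._heap, timestep)
        evicted = []
        while len(self._heap) > self.k:
            evicted.append(heappop(self._heap))
        return evicted


if __name__ == "__main__":
    tracker = LastKTracker(k=2)
    for step in (1000, 2000, 3000):
        for old in tracker.push(step):
            # Other artifacts saved at the same timestep could reuse the helper
            # with a different checkpoint_type / extension.
            print("would delete", checkpoint_path("ckpts", "rl_model", old))
```

Because the tracker returns timesteps rather than paths, every file saved at that timestep can be derived from the same value by varying `checkpoint_type` and `extension`.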
Also, I currently have to put the callback after the `CheckpointCallback` one to make sure the two stay in sync.
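One way to loosen that ordering requirement might be to prune based on what is actually on disk instead of predicting the path of the file that was just saved. This is a different technique from the heap-based callback above. It is a minimal sketch, assuming checkpoints follow the `{name_prefix}_{timestep}_steps.zip` pattern shown earlier; `prune_checkpoints` is a hypothetical helper, not a Stable-Baselines3 function.

```python
import os
import re
from typing import List


def prune_checkpoints(save_path: str, name_prefix: str, k: int) -> List[str]:
    """Delete all but the k newest '<name_prefix>_<timestep>_steps.zip' files.

    Hypothetical helper; returns the paths that were deleted.
    """
    pattern = re.compile(rf"^{re.escape(name_prefix)}_(\d+)_steps\.zip$")
    found = []
    for fname in os.listdir(save_path):
        match = pattern.match(fname)
        if match:
            # Sort key is the integer timestep, not the file name string.
            found.append((int(match.group(1)), os.path.join(save_path, fname)))
    found.sort()  # oldest timestep first
    to_delete = found[:-k] if k > 0 else found
    deleted = []
    for _, path in to_delete:
        os.remove(path)
        deleted.append(path)
    return deleted
```

Called from `_on_step` whenever the save condition fires, this never tries to remove a file that does not exist. If it runs before the checkpoint for the current step is written, the directory would briefly hold k + 1 files until the next call, so running it after `CheckpointCallback` is still the tidier arrangement.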
❓ Question
Hey, thanks a lot for your work maintaining this framework throughout the years!
In order to keep only the last K checkpoints (instead of the whole history from step 0), I have written a custom callback that deletes the old ones. I was wondering, is there a more "compliant" way to do this already built into the framework? If not, would there be any interest in adding this kind of feature to the `CheckpointCallback` capabilities directly (e.g. a `save_top_k` argument)?