DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Feature Request] Additional Callback before stepping the optimizer #1980

Closed · kcorder closed this issue 1 month ago

kcorder commented 1 month ago

🚀 Feature

Inside the train() method of every algorithm, add an additional callback method that is called before stepping the optimizer. This method could be named on_optimizer_step() and added to the BaseCallback class.

Motivation

It would be useful if there were a way to modify the loss value before the optimizer step. Currently, the only way to do this is to subclass the algorithm and override its train() method, making the needed changes there. However, this requires subclassing every algorithm that one might want to compare with such changes.

I use Stable-Baselines3 for research, which often means adding features to existing algorithms, and I have run into this issue several times now. For example, the Random Convolution method uses a feature-matching loss that can be added to the policy loss. If I start with A2C, implement this, and then want to try it with PPO or some other algorithm, each one requires its own custom subclass. This is tedious and not really necessary for algorithm-agnostic features like the aforementioned method, Random Network Distillation, etc.

Pitch

  1. Add new method on_optimizer_step() to BaseCallback class
  2. Call this method in every algorithm train() method before the optimization step. For instance, in PPO this would be inserted at line 273: https://github.com/DLR-RM/stable-baselines3/blob/bd3c0c653068a6af1993df7be1a12acfb4be0127/stable_baselines3/ppo/ppo.py#L256-L278
  3. (2) requires that a reference to the callback object be saved as a class attribute during learn() (my preference), or that it be passed as an argument, e.g. train(callback). A rough sketch of the proposed hook is given below.
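
A minimal sketch of what the hook could look like, assuming the callback reference is stored on the algorithm during learn(); the subclass name and the self._callback attribute below are hypothetical, not existing SB3 API:

```python
from stable_baselines3.common.callbacks import BaseCallback


class CallbackWithOptimizerHook(BaseCallback):
    """Illustration of the proposed hook; a no-op by default."""

    def on_optimizer_step(self) -> None:
        # Would be called by train() right before the optimizer step,
        # giving the callback a chance to inspect or adjust the update.
        pass

    def _on_step(self) -> bool:
        # Required by BaseCallback; unrelated to the proposed hook.
        return True


# Inside an algorithm's train(), just before the optimization step
# (self._callback being a hypothetical reference saved during learn()):
#
#     self._callback.on_optimizer_step()
#     self.policy.optimizer.zero_grad()
#     loss.backward()
#     self.policy.optimizer.step()
```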

Alternatives

Another way to support this would be to refactor the train() method so that it does not step the optimizer itself, similar to how RLlib does training with a two-step process: first compute the loss, then have learn() call a new optimization_step() method.

This would require even more code changes and is less flexible; e.g., PPO's training loop, with multiple epochs and minibatches per update, would be harder to fit into this scheme.
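
For reference, a rough, self-contained sketch of the kind of split this alternative implies; the class and method names (compute_loss, optimization_step) and the toy loss are hypothetical stand-ins, not SB3 code:

```python
import torch as th


class TwoStepTrainerSketch:
    """Toy illustration: build the loss in one call, apply the gradient
    step in a separate call, so the loss can be modified in between."""

    def __init__(self, policy: th.nn.Module, lr: float = 3e-4):
        self.policy = policy
        self.optimizer = th.optim.Adam(policy.parameters(), lr=lr)
        self._last_loss = None

    def compute_loss(self, batch: th.Tensor) -> th.Tensor:
        # Stand-in for the algorithm-specific loss normally built in train().
        self._last_loss = self.policy(batch).pow(2).mean()
        return self._last_loss

    def optimization_step(self) -> None:
        # learn() would call this after compute_loss(), optionally after
        # user code has added extra terms to the stored loss.
        self.optimizer.zero_grad()
        self._last_loss.backward()
        self.optimizer.step()
```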

Additional context

While this does require touching every algorithm's train() method, implementing the default behavior as a no-op pass (as most callback methods do) makes the change safe.


araffin commented 1 month ago

Hello,

> If I start with A2C, implement this, and then want to try it with PPO or some other algorithm, they all require custom subclasses.

This is indeed the recommended way, especially since you usually want to add some monitoring or do more custom things that are not practical with callbacks. You would only need to copy-paste the train() method and adapt it to your needs.

It seems from the way you describe things that you are mostly interested in A2C/PPO where there is only one optimizer. The approach you suggest would not be trivial for SAC, TD3, or CrossQ, for example.

(btw, A2C is a special case of PPO: https://arxiv.org/abs/2205.09123)

kcorder commented 1 month ago

I see what you mean regarding SAC/TD3 – not just because of multiple optimizers, but because the actor loss depends on the newly updated critic.

For reference, I have implemented ~5 additional features, some of which must go inside the train() method. Adding new features to multiple classes is still a lot of code duplication, so I think I'll subclass the algorithms with a custom mixin so the additional train() code only needs to be written once.
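
A hypothetical sketch of that mixin approach, where the shared extra-loss logic is written once and each algorithm-specific subclass only copies train() and calls the helper; all names below are made up for illustration:

```python
import torch as th
from stable_baselines3 import PPO


class AuxiliaryLossMixin:
    """Shared auxiliary-loss logic, written once and reused across algorithms."""

    aux_coef: float = 0.1  # hypothetical weighting coefficient

    def auxiliary_loss(self, rollout_data) -> th.Tensor:
        # Placeholder for an algorithm-agnostic extra term, e.g. a
        # feature-matching or random-network-distillation loss.
        return th.zeros((), device=self.device)


class AuxPPO(AuxiliaryLossMixin, PPO):
    # train() would be copied from PPO, with the combined loss extended to
    #     loss = loss + self.aux_coef * self.auxiliary_loss(rollout_data)
    # just before the optimizer step. The same mixin can then be combined
    # with other algorithms (e.g. A2C) without rewriting auxiliary_loss().
    pass
```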