DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

Model Regularization? #610

Closed DavidChiumera-404 closed 3 years ago

DavidChiumera-404 commented 3 years ago

Question

I am wondering how one would approach regularizing a DRL model using this library.

Additional context

I have found resources online claiming it can be done via the entropy coefficient or the discount factor, but I was wondering if something like L1/L2 or dropout is part of this library? I know it could involve modifying the loss function if not, but it isn't clear to me which loss function would make the most sense to use (policy?). I also found this post #240, but it doesn't seem like weight_decay is a parameter used by all model types, such as PPO for example.

Miffyli commented 3 years ago

This, much like many other things, is a somewhat unexplored topic. L2 regularization (of the whole network) has been tried in some contexts, e.g. for generalization (see here), but generally it is not seen as something you need (or should) do in RL, so there is no established answer to "what is the right way to do things".

The easiest one to start with is L2, where you use the weight_decay parameter of the optimizer. For dropout, you need to create a custom policy with a custom network, but lately there have been updates to help with that (#537 and #553).
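
For example, a rough sketch of the dropout route using a custom features extractor passed through policy_kwargs (not an official recipe: the class name DropoutMlpExtractor, the dropout_p/features_dim values and the CartPole-v1 env are just illustrative):

```python
import torch as th
import torch.nn as nn

from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class DropoutMlpExtractor(BaseFeaturesExtractor):
    """Small MLP features extractor with dropout between the hidden layers."""

    def __init__(self, observation_space, features_dim: int = 64, dropout_p: float = 0.2):
        super().__init__(observation_space, features_dim)
        n_input = int(observation_space.shape[0])
        self.net = nn.Sequential(
            nn.Linear(n_input, 64),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(64, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.net(observations)


# Note: whether dropout is active while collecting rollouts depends on how your
# SB3 version toggles train/eval mode on the policy, so check that behaviour.
model = PPO(
    "MlpPolicy",
    "CartPole-v1",
    policy_kwargs=dict(
        features_extractor_class=DropoutMlpExtractor,
        features_extractor_kwargs=dict(features_dim=64, dropout_p=0.2),
    ),
    verbose=1,
)
model.learn(total_timesteps=10_000)
```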

DavidChiumera-404 commented 3 years ago

Thank you kindly for the response. This matches exactly the sentiment I have seen online -> no one really has a playbook and it is still under investigation. You mention using the weight_decay parameter, but it doesn't seem to be available for all models. Is it only implemented in some, or does it go by a different name depending on the model? Apologies if I am overlooking something.

Miffyli commented 3 years ago

Ah, I should have been more clear: the weight_decay parameter is not exposed by the SB3 API, so you have to modify the code/algorithms to achieve this :)
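
For reference, a rough sketch of what that could look like without patching SB3 itself. Option A assumes your SB3 version forwards optimizer_class / optimizer_kwargs from policy_kwargs to the policy constructor; option B is plain PyTorch. The 1e-4 value is just a placeholder:

```python
import torch as th

from stable_baselines3 import PPO

# Option A: hand weight_decay to the policy optimizer through policy_kwargs
# (Adam's built-in weight_decay argument, i.e. L2 regularization).
model = PPO(
    "MlpPolicy",
    "CartPole-v1",
    policy_kwargs=dict(
        optimizer_class=th.optim.Adam,
        optimizer_kwargs=dict(weight_decay=1e-4),
    ),
    verbose=1,
)

# Option B: adjust the parameter groups of the optimizer that SB3 already
# built, after constructing the model (no SB3 code changes needed).
for param_group in model.policy.optimizer.param_groups:
    param_group["weight_decay"] = 1e-4

model.learn(total_timesteps=10_000)
```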

DavidChiumera-404 commented 3 years ago

Thanks for the info! I better get cracking then ;)