DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] Shared feature extractor and gradient #2006

Closed brn-dev closed 2 weeks ago

brn-dev commented 2 weeks ago

❓ Question

Hi, I have two questions:

Firstly, the comment here says the opposite of the code. I am assuming the code is the intended behavior, right? https://github.com/DLR-RM/stable-baselines3/blob/512eea923afad6f6da4bb53d72b6ea4c6d856e59/stable_baselines3/common/policies.py#L972-L975

Secondly, does this only make sense in off-policy settings (when using q-critics)? I checked out the ActorCriticPolicy, which is used in on-policy algorithms, and there both the actor and the critic loss influence the feature extractor's parameters (if shared): https://github.com/DLR-RM/stable-baselines3/blob/512eea923afad6f6da4bb53d72b6ea4c6d856e59/stable_baselines3/common/policies.py#L645-L651 https://github.com/DLR-RM/stable-baselines3/blob/512eea923afad6f6da4bb53d72b6ea4c6d856e59/stable_baselines3/common/policies.py#L660-L672 https://github.com/DLR-RM/stable-baselines3/blob/512eea923afad6f6da4bb53d72b6ea4c6d856e59/stable_baselines3/common/torch_layers.py#L252-L257

The features generated by the first feature extractor (self.features_extractor) are passed into the mlp_extractor, where they are used to form the actor and value latent features without any disruption to the gradient. So both the actor and the critic loss influence this feature extractor, whereas in the ContinuousCritic we prevent the feature extractor from being updated by the critic loss if it is shared with the actor.
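For illustration, here is a minimal standalone PyTorch sketch of the shared on-policy setup described above (not SB3 code; all names are made up), showing that gradients from both heads reach the shared feature extractor:

```python
import torch
import torch.nn as nn

# Toy shared trunk with separate actor and value heads (illustrative only).
obs_dim, feat_dim, act_dim = 4, 16, 2
features_extractor = nn.Linear(obs_dim, feat_dim)  # shared by both heads
actor_head = nn.Linear(feat_dim, act_dim)
value_head = nn.Linear(feat_dim, 1)

obs = torch.randn(8, obs_dim)
features = features_extractor(obs)       # no gradient blocking anywhere
action_logits = actor_head(features)
values = value_head(features)

# Dummy stand-ins for the policy and value losses.
loss = action_logits.pow(2).mean() + values.pow(2).mean()
loss.backward()

# Both losses contribute gradient to the shared extractor.
print(features_extractor.weight.grad.abs().sum() > 0)  # tensor(True)
```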


araffin commented 2 weeks ago

Firstly, the comment here says the opposite of the code.

I think the comment is correct, although a bit confusing. When the feature extractor is shared, you don't want gradients to be back-propagated through the critic, hence set_grad_enabled(not self.share_features_extractor).
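To illustrate the effect of that context manager, here is a minimal standalone PyTorch sketch (not the actual ContinuousCritic code; names are made up) of how set_grad_enabled(not share_features_extractor) keeps the critic loss from reaching a shared feature extractor:

```python
import torch
import torch.nn as nn

# Illustrative reproduction of the gradient-blocking pattern (not SB3 code).
share_features_extractor = True

features_extractor = nn.Linear(4, 16)   # shared with the actor
q_net = nn.Linear(16 + 2, 1)            # q-critic taking (features, action)

obs = torch.randn(8, 4)
actions = torch.randn(8, 2)

# When the extractor is shared, disable grad so the critic loss
# cannot back-propagate into it.
with torch.set_grad_enabled(not share_features_extractor):
    features = features_extractor(obs)

q_values = q_net(torch.cat([features, actions], dim=1))
critic_loss = q_values.pow(2).mean()
critic_loss.backward()

print(features_extractor.weight.grad)   # None: blocked by set_grad_enabled
print(q_net.weight.grad is not None)    # True: the q-network is still updated
```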

Secondly, does this only make sense in off-policy settings

This choice is based on experience. The recommendation is actually not to share the feature extractor when using off-policy algorithms. Another option would be to use the critic loss only to learn the features, but it is not recommended to use both losses when sharing the feature extractor (they can conflict).
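As an illustration of that alternative (letting only the critic loss shape the features), a hypothetical sketch, not something SB3 does by default, would detach the shared features before they enter the actor:

```python
import torch
import torch.nn as nn

# Hypothetical variant: only the critic loss shapes the shared features.
features_extractor = nn.Linear(4, 16)   # shared
actor = nn.Linear(16, 2)
critic = nn.Linear(16 + 2, 1)

obs = torch.randn(8, 4)
features = features_extractor(obs)

# The actor sees detached features, so the actor loss cannot update the extractor.
actions = actor(features.detach())
actor_loss = actions.pow(2).mean()

# The critic sees the attached features, so only the critic loss reaches the extractor.
q_values = critic(torch.cat([features, actions.detach()], dim=1))
critic_loss = q_values.pow(2).mean()

(actor_loss + critic_loss).backward()
print(features_extractor.weight.grad is not None)  # True, driven by the critic loss alone
```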

ActorCriticPolicy, which is used in on-policy algorithms

In the on-policy setting, where the actor and critic updates are done in one go, it seems (experimentally) that you can use both losses to update the features (and you can also use separate networks).

brn-dev commented 2 weeks ago

Makes sense, thanks!