Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

[Question] Setting net_arch of recurrent policies #172

Closed: joelmichelson closed this issue 1 year ago

joelmichelson commented 1 year ago

❓ Question

I am attempting to use recurrent policies as follows:


```python
import torch as th
from sb3_contrib.common.recurrent.policies import RecurrentActorCriticCnnPolicy

# CustomCNN, obs_space, act_space and lr_schedule are defined elsewhere (not shown).
net_arch = [
    {'activation_fn': th.nn.ReLU, 'pi': [32, 32, 32, 32], 'vf': [33, 32, 32, 32]},  # dummy values just to show up clearly in print statement
    {'lstm': 55},
    {'activation_fn': th.nn.ReLU, 'pi': [25], 'vf': [26]},
]
policy_kwargs = {
    'net_arch': net_arch,
    'features_extractor_class': CustomCNN,
    'lstm_hidden_size': 64,
    'features_extractor_kwargs': {
        # this seems to work correctly
    },
}
policy = RecurrentActorCriticCnnPolicy(obs_space, act_space, lr_schedule, **policy_kwargs)
print('policy', policy)
```

I later pass this policy to RecurrentPPO when initializing it. However, when I print the policy, the summary shows that my CustomCNN implementation is used correctly, but not all of net_arch appears to be applied:

```
policy RecurrentActorCriticCnnPolicy(
  (features_extractor): CustomCNN(
    (cnn): Sequential( # this looks correct
      (0): Conv2d(5, 4, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(4, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): ReLU()
      (4): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=72, out_features=32, bias=True)
      (1): ReLU()
    )
  )
  (mlp_extractor): MlpExtractor(
    (shared_net): Sequential()
    (policy_net): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=32, bias=True)
      (5): Tanh()
      (6): Linear(in_features=32, out_features=32, bias=True)
      (7): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=64, out_features=33, bias=True)
      (1): Tanh()
      (2): Linear(in_features=33, out_features=32, bias=True)
      (3): Tanh()
      (4): Linear(in_features=32, out_features=32, bias=True)
      (5): Tanh()
      (6): Linear(in_features=32, out_features=32, bias=True)
      (7): Tanh()
    )
  )
  (action_net): Linear(in_features=32, out_features=4, bias=True)
  (value_net): Linear(in_features=32, out_features=1, bias=True)
  (lstm_actor): LSTM(32, 64)
  (lstm_critic): LSTM(32, 64)
)
```

So the first item in net_arch is applied correctly, except that activation_fn is not set (the layers use Tanh instead of ReLU). The subsequent items don't appear to do anything.

I'm not sure whether I'm misunderstanding how the recurrent policy is set up or initializing it incorrectly. Are the LSTM layers in this policy summary unused, as it appears, or is the real architecture different from what the summary shows (do the LSTM layers always sit between the features extractor and the MLP)? If my setup is wrong, how can I write a working net_arch?
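The summary lists submodules in the order they were registered, which is not necessarily the order data flows through them. One clue is that the first Linear layers of mlp_extractor take 64 inputs, matching the LSTM hidden size rather than the 32-dimensional CNN output. Below is a minimal pure-PyTorch sketch of the actor path using the sizes from the printout, assuming the order features extractor, then lstm_actor, then mlp_extractor.policy_net, then action_net. The variable names are illustrative stand-ins for the printed modules, and the sketch omits the episode-start resets of the hidden state that sb3-contrib performs.

```python
import torch as th
from torch import nn

# Sizes read off the printed summary above.
FEATURES_DIM = 32  # output size of CustomCNN's linear layer
LSTM_HIDDEN = 64   # lstm_hidden_size
N_ACTIONS = 4      # out_features of action_net

# Stand-ins for lstm_actor, mlp_extractor.policy_net and action_net.
lstm_actor = nn.LSTM(FEATURES_DIM, LSTM_HIDDEN)
policy_net = nn.Sequential(
    nn.Linear(LSTM_HIDDEN, 32), nn.Tanh(),
    nn.Linear(32, 32), nn.Tanh(),
    nn.Linear(32, 32), nn.Tanh(),
    nn.Linear(32, 32), nn.Tanh(),
)
action_net = nn.Linear(32, N_ACTIONS)

# 10 time steps for a batch of 3 environments, shaped
# (seq_len, batch, features) as nn.LSTM expects by default.
features = th.randn(10, 3, FEATURES_DIM)
lstm_out, (h_n, c_n) = lstm_actor(features)  # LSTM runs first: (10, 3, 64)
latent_pi = policy_net(lstm_out)             # then the MLP: (10, 3, 32)
logits = action_net(latent_pi)               # then the action head: (10, 3, 4)
print(lstm_out.shape, latent_pi.shape, logits.shape)
```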


araffin commented 1 year ago

Hello,

```python
net_arch = [
    {'activation_fn': th.nn.ReLU, 'pi': [32, 32, 32, 32], 'vf': [33, 32, 32, 32]},  # dummy values just to show up clearly in print statement
    {'lstm': 55},
    {'activation_fn': th.nn.ReLU, 'pi': [25], 'vf': [26]}
]
```

Where did you see that syntax? The net_arch argument must be a dictionary in SB3 v1.8.0 (we made that change to simplify the code and remove inconsistent behavior).

To understand what you can change, the best place to look is: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/aacded79c5e8357545fd94999c2a18cb8f285cb6/sb3_contrib/common/recurrent/policies.py#L83-L87

joelmichelson commented 1 year ago

I apologize for the confusion. I'm not sure why I was using a list for net_arch.