DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] What format does obs need to be in when passed to model.policy.get_distribution(obs)? #1497

Closed: lquantrill closed this issue 1 year ago

lquantrill commented 1 year ago

❓ Question

I am trying to get the action distribution out of an SB3 PPO agent using 'model.policy.get_distribution(obs)', but I am getting some hard-to-understand errors related to the shape of the observation passed in. The code I am using is as follows:

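```python
import numpy as np
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

# 8 parallel Atari envs; each observation stacks the last 4 frames
env = make_atari_env('Seaquest-v4', n_envs=8, seed=1)
env = VecFrameStack(env, n_stack=4)

model = PPO.load(model_path)  # model_path: path to the saved agent

obs = env.reset()
done = np.full(8, False)

print(obs.shape)  # =(8, 84, 84, 4)

model.policy.get_distribution(torch.from_numpy(obs).cuda())
```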

However, I am getting the following runtime error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 model.policy.get_distribution(torch.from_numpy(obs).cuda())

File ~/Documents/sprint3/sb3_venv/lib/python3.10/site-packages/stable_baselines3/common/policies.py:712, in ActorCriticPolicy.get_distribution(self, obs)
    705 def get_distribution(self, obs: th.Tensor) -> Distribution:
    706     """
    707     Get the current policy distribution given the observations.
    708 
    709     :param obs:
    710     :return: the action distribution.
    711     """
--> 712     features = super().extract_features(obs, self.pi_features_extractor)
    713     latent_pi = self.mlp_extractor.forward_actor(features)
    714     return self._get_action_dist_from_latent(latent_pi)

File ~/Documents/sprint3/sb3_venv/lib/python3.10/site-packages/stable_baselines3/common/policies.py:131, in BaseModel.extract_features(self, obs, features_extractor)
    123 """
    124 Preprocess the observation if needed and extract features.
    125 
   (...)
    128  :return: The extracted features
    129 """
    130 preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
--> 131 return features_extractor(preprocessed_obs)

File ~/Documents/sprint3/sb3_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Documents/sprint3/sb3_venv/lib/python3.10/site-packages/stable_baselines3/common/torch_layers.py:106, in NatureCNN.forward(self, observations)
    105 def forward(self, observations: th.Tensor) -> th.Tensor:
--> 106     return self.linear(self.cnn(observations))

File ~/Documents/sprint3/sb3_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Documents/sprint3/sb3_venv/lib/python3.10/site-packages/torch/nn/modules/container.py:204, in Sequential.forward(self, input)
    202 def forward(self, input):
    203     for module in self:
--> 204         input = module(input)
    205     return input

File ~/Documents/sprint3/sb3_venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
   1186 # If we don't have any hooks, we want to skip the rest of the logic in
   1187 # this function, and just call forward.
   1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190     return forward_call(*input, **kwargs)
   1191 # Do not call functions when jit is used
   1192 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Documents/sprint3/sb3_venv/lib/python3.10/site-packages/torch/nn/modules/conv.py:463, in Conv2d.forward(self, input)
    462 def forward(self, input: Tensor) -> Tensor:
--> 463     return self._conv_forward(input, self.weight, self.bias)

File ~/Documents/sprint3/sb3_venv/lib/python3.10/site-packages/torch/nn/modules/conv.py:459, in Conv2d._conv_forward(self, input, weight, bias)
    455 if self.padding_mode != 'zeros':
    456     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    457                     weight, bias, self.stride,
    458                     _pair(0), self.dilation, self.groups)
--> 459 return F.conv2d(input, weight, bias, self.stride,
    460                 self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [32, 4, 8, 8], expected input[8, 84, 84, 4] to have 4 channels, but got 84 channels instead

Please could you help me understand what format the obs I am passing in needs to be in? As you can see, the shape I pass in is (8, 84, 84, 4), which matches the input shape shown in the error message. I am unsure how the weight shape ends up as (32, 4, 8, 8), and I am unsure how to fix this.

Many thanks for your help!


araffin commented 1 year ago

weight of size [32, 4, 8, 8], expected input[8, 84, 84, 4] to have 4 channels, but got 84 channels instead

Almost there: PyTorch expects channel-first images, so in your case you need to wrap your env with VecTransposeImage: https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_transpose.py
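A minimal sketch of what the error message means, using plain PyTorch rather than SB3. The layer below is assumed to mirror the first convolution of SB3's NatureCNN for 4 stacked frames (32 filters, 8x8 kernel, stride 4), so its weight has shape [32, 4, 8, 8], i.e. [out_channels, in_channels, kernel_h, kernel_w]. Conv2d reads dim 1 of its input as the channel axis, so a channel-last batch of shape (8, 84, 84, 4) is seen as having 84 channels; permuting it to (8, 4, 84, 84) makes it valid. The zero tensors are placeholders, not real Atari frames.

```python
import torch
import torch.nn as nn

# Assumed to mirror NatureCNN's first layer for 4 stacked frames:
# weight shape is [out_channels=32, in_channels=4, kernel_h=8, kernel_w=8].
conv = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)

# Placeholder batch laid out like the VecFrameStack output: (n_envs, H, W, n_stack).
obs_nhwc = torch.zeros(8, 84, 84, 4)

# Conv2d treats dim 1 as channels, so this raises
# "expected input[8, 84, 84, 4] to have 4 channels, but got 84 channels instead".
try:
    conv(obs_nhwc)
    channel_last_failed = False
except RuntimeError:
    channel_last_failed = True

# Move the channel axis from last to position 1: (N, H, W, C) -> (N, C, H, W).
obs_nchw = obs_nhwc.permute(0, 3, 1, 2)
out = conv(obs_nchw)

print(tuple(conv.weight.shape))  # (32, 4, 8, 8)
print(tuple(obs_nchw.shape))     # (8, 4, 84, 84)
print(tuple(out.shape))          # (8, 32, 20, 20)
```

In SB3 itself, the fix suggested above amounts to wrapping the stacked env, e.g. `env = VecTransposeImage(VecFrameStack(env, n_stack=4))`, so that `env.reset()` already returns channel-first observations before they are converted to a tensor.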

During training and at test time, SB3 does this transposition for you automatically: https://github.com/DLR-RM/stable-baselines3/blob/fd0cd82339511b54cd3907df228a656f2a32f0b8/stable_baselines3/common/policies.py#L254-L257

https://github.com/DLR-RM/stable-baselines3/blob/fd0cd82339511b54cd3907df228a656f2a32f0b8/stable_baselines3/common/preprocessing.py#LL72C1-L89C1
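The automatic handling linked above boils down to deciding whether an image observation is channel-last and transposing it if so. Below is a hypothetical, simplified numpy sketch of that idea; `is_channels_first` and `to_channels_first` are illustrative names, not SB3 functions. It assumes, as SB3's own heuristic appears to, that the smallest axis of a single observation's shape is the channel axis, and applies the same kind of (0, 3, 1, 2) batch transpose that VecTransposeImage performs.

```python
import numpy as np


def is_channels_first(obs_shape: tuple) -> bool:
    """Hypothetical helper: treat the smallest axis of a single (3D) image
    observation as the channel axis and report whether it comes first."""
    return int(np.argmin(obs_shape)) == 0


def to_channels_first(batch_obs: np.ndarray, obs_shape: tuple) -> np.ndarray:
    """Hypothetical helper: return an (N, C, H, W) batch, transposing
    channel-last (N, H, W, C) input and passing channel-first input through."""
    if is_channels_first(obs_shape):
        return batch_obs
    return np.transpose(batch_obs, (0, 3, 1, 2))


# Placeholder batch shaped like the VecFrameStack output in the question.
obs = np.zeros((8, 84, 84, 4), dtype=np.uint8)
single_obs_shape = obs.shape[1:]  # (84, 84, 4)

fixed = to_channels_first(obs, single_obs_shape)
print(fixed.shape)  # (8, 4, 84, 84)

# Input that is already channel-first is returned unchanged.
print(to_channels_first(fixed, fixed.shape[1:]).shape)  # (8, 4, 84, 84)
```

The sketch ignores cases SB3 also has to handle (such as dict observations or a single unbatched observation), so it illustrates the idea rather than replacing VecTransposeImage.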

lquantrill commented 1 year ago

That worked perfectly. Thanks for your help and the fast response! Really appreciate it.