Closed lquantrill closed 1 year ago
RuntimeError: Given groups=1, weight of size [32, 4, 8, 8], expected input[8, 84, 84, 4] to have 4 channels, but got 84 channels instead
Almost there! PyTorch expects channel-first images, so in your case you need to use `VecTransposeImage`: https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_transpose.py
During training and at test time, SB3 does it for you automatically: https://github.com/DLR-RM/stable-baselines3/blob/fd0cd82339511b54cd3907df228a656f2a32f0b8/stable_baselines3/common/policies.py#L254-L257
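As a minimal sketch of what `VecTransposeImage` does under the hood: the channel-last batch from the question, shape (8, 84, 84, 4), needs to become channel-first (8, 4, 84, 84) before it reaches the CNN. The observation values and the `model` usage below are assumptions for illustration, not code from the issue:

```python
import numpy as np

# Channel-last batch of stacked Atari frames: (n_envs, height, width, channels)
obs = np.zeros((8, 84, 84, 4), dtype=np.uint8)

# PyTorch Conv2d expects channel-first input: (n_envs, channels, height, width)
obs_chw = np.transpose(obs, (0, 3, 1, 2))

print(obs_chw.shape)  # (8, 4, 84, 84)

# Hypothetical usage with an SB3 PPO model (not from the issue):
# obs_tensor, _ = model.policy.obs_to_tensor(obs)  # obs_to_tensor also transposes image obs
# dist = model.policy.get_distribution(obs_tensor)
```

Note that `model.policy.obs_to_tensor` applies this transpose automatically for image observation spaces, which is why `predict()` works on channel-last input during training and evaluation.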
That's worked perfectly. Thanks for your help and fast response time! Really appreciate it
❓ Question
I am trying to get the action distribution out of an SB3 PPO agent using `model.policy.get_distribution(obs)`; however, I am getting some difficult-to-understand errors related to the shape of the observation passed in. The code I am using is as follows:
However, I am getting the following runtime error:
Please could you help me understand what format the obs I pass in needs to be in? As you can see, the shape I pass in is (8, 84, 84, 4), which matches what the error message says it wants. I am unsure how the shape ends up as (32, 4, 8, 8), and I am unsure how to fix this.
Many thanks for your help!