hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

Questions about CNN policy input channel [question] #1195

Open DavidLudl opened 2 months ago

DavidLudl commented 2 months ago

Hello,

I am learning how to implement a custom CNN policy and environment with Stable-Baselines3. I am following the "Custom Feature Extractor" example at this link: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html

I am confused about the number of input channels, which the example defines as observation_space.shape[0]. When I examine the observation space of a gymnasium environment:

import gymnasium as gym

env = gym.make("BreakoutNoFrameskip-v4")
print("Observation Space Shape: ", env.observation_space.shape)
print("Image Channel: ", env.observation_space.shape[0])

The output is:

Observation Space Shape:  (210, 160, 3)
Image Channel:  210

However, when I execute the code from the link, there is no error. If I instead pass the last element of the observation space shape, n_input_channels = observation_space.shape[2], which I assumed was the correct channel size, an error is raised. So I want to ask: does SB3 reorder the observation space shape? And when I define my own environment, should I set the space shape to C H W or H W C (that is, where should the channel dimension go)?

Thank you for your time and help.
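
For anyone with the same question, here is a minimal pure-Python sketch of the reordering being asked about. It assumes, as far as I understand SB3's behavior, that image environments are wrapped in VecTransposeImage, so a custom feature extractor receives a channels-first (C, H, W) observation space even when the raw environment reports (H, W, C). The helpers is_channels_first and to_channels_first are hypothetical names used only for illustration, and the "smallest dimension is the channel axis" rule is an assumption about the heuristic, not SB3's actual code.

```python
from typing import Tuple

Shape = Tuple[int, int, int]


def is_channels_first(shape: Shape) -> bool:
    """Hypothetical helper: treat the smallest dimension as the channel axis.

    This is an assumed heuristic for illustration, not SB3's implementation.
    """
    smallest_axis = min(range(len(shape)), key=lambda i: shape[i])
    return smallest_axis == 0


def to_channels_first(shape: Shape) -> Shape:
    """Hypothetical helper: reorder an (H, W, C) shape to (C, H, W).

    Shapes that already look channels-first are returned unchanged.
    """
    if is_channels_first(shape):
        return shape
    height, width, channels = shape
    return (channels, height, width)


# Shape reported by the raw environment in the question above.
raw_shape = (210, 160, 3)

# Shape the feature extractor would see after the assumed transpose.
policy_shape = to_channels_first(raw_shape)

print("Env shape (H, W, C):", raw_shape)
print("Shape seen by the feature extractor (C, H, W):", policy_shape)
print("n_input_channels:", policy_shape[0])
```

Under that assumption, observation_space.shape[0] inside the feature extractor is the channel count (3), while env.observation_space.shape[0] on the raw environment is the image height (210). That would explain why the documented example runs without error and why shape[2] fails.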