Closed tatsubori closed 3 years ago
I think you will need to define a policy that inherits from FeedForwardPolicy (or even ActorCriticPolicy), not Mlp or Cnn, because it will be a mix of both.
Looking more closely at the gym-minigrid envs, you don't actually need a custom policy. You can directly use a provided wrapper: https://github.com/maximecb/gym-minigrid/blob/999599a412db112bc7efa9a0f72f8c315074f8bb/gym_minigrid/wrappers.py#L144
To implement something similar to https://github.com/lcswillems/rl-starter-files, we need a custom policy anyway.
For example, MiniGrid gives a Dict observation containing an image, the mission text, and other entries. We can either extract only the image (ImgObsWrapper, which returns obs['image']) or flatten everything into a 1D vector (FlatObsWrapper, a MiniGrid-specific version of gym.wrappers.FlattenDictWrapper). Which one fits depends on the gym env and the policy.
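To make the two options concrete, here is a minimal sketch in plain Python. It does not use gym_minigrid; the toy observation, the helper names (`extract_image`, `flatten_obs`, `encode_mission`), and the character-code text encoding are all assumptions for illustration, and the real FlatObsWrapper may encode the mission differently.

```python
# Hypothetical toy Dict observation: a 2x2 "image" with 3 channels plus a mission string.
obs = {
    "image": [[[1, 0, 0], [2, 5, 0]],
              [[1, 0, 0], [8, 1, 0]]],
    "mission": "go to goal",
}


def extract_image(observation):
    """Option 1: keep only the image entry (the idea behind ImgObsWrapper)."""
    return observation["image"]


def flatten_image(image):
    """Flatten a nested rows x cols x channels list into a 1D list of floats."""
    return [float(value) for row in image for cell in row for value in cell]


def encode_mission(mission, max_len):
    """Assumed encoding: one character code per slot, truncated and zero-padded to max_len."""
    codes = [float(ord(ch)) for ch in mission[:max_len]]
    return codes + [0.0] * (max_len - len(codes))


def flatten_obs(observation, max_len=16):
    """Option 2: concatenate image and text into one 1D vector (the idea behind FlatObsWrapper)."""
    return flatten_image(observation["image"]) + encode_mission(observation["mission"], max_len)


image_only = extract_image(obs)
flat = flatten_obs(obs, max_len=16)
print(len(flat))  # 12 image values + 16 text slots
```

The image-only form keeps the spatial layout a CNN needs but drops the mission text, while the flat form keeps everything but leaves the policy to work out which entries are pixels and which are text.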
I got it running locally with ImgObsWrapper and a custom CNN policy:
class CustomGridCnnPolicy(BasePolicy):  # as common.FeedForwardPolicy
    """
    Assuming gym_minigrid.wrappers.ImgObsWrapper
    """
    def __init__(self, *args, **kwargs):
        print("CustomGridCnnPolicy(): {} {}".format(args, kwargs))
        super(CustomGridCnnPolicy, self).__init__(*args, **kwargs,
                                                  cnn_extractor=rl_starter_cnn)
which doesn't converge faster than MlpPolicy. ;->
I am now trying a CNN-LSTM policy based on one of the rl-starter-files models, still ignoring the mission text. After that we can try FlatObsWrapper with another policy that recovers the image and text from the flattened vector.
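As a rough illustration of what recovering the image and text from a flat vector could involve, here is a sketch in plain Python. The layout (image values first, then zero-padded character codes) and the helper name `split_flat_obs` are assumptions made for this example, not the actual FlatObsWrapper layout.

```python
def split_flat_obs(flat, image_shape, text_len):
    """Split a flat vector back into a nested image and a mission string.

    Assumed layout: height*width*channels image values, followed by
    text_len character codes where 0 means padding.
    """
    height, width, channels = image_shape
    image_size = height * width * channels
    if len(flat) != image_size + text_len:
        raise ValueError("flat vector length does not match the assumed layout")

    image_part = flat[:image_size]
    text_part = flat[image_size:]

    # Rebuild rows x cols x channels from the row-major flat image values.
    image = [
        [image_part[(r * width + c) * channels:(r * width + c + 1) * channels]
         for c in range(width)]
        for r in range(height)
    ]
    # Drop padding slots and decode the remaining character codes.
    mission = "".join(chr(int(code)) for code in text_part if code != 0)
    return image, mission


# Hypothetical flat observation: a 2x2x3 image, then "go" padded to 4 slots.
flat = [1, 0, 0, 2, 5, 0, 1, 0, 0, 8, 1, 0] + [ord("g"), ord("o"), 0, 0]
image, mission = split_flat_obs(flat, image_shape=(2, 2, 3), text_len=4)
print(image)
print(mission)
```

In a real policy the two recovered parts would then go to separate feature extractors, for example a CNN for the image and an embedding or recurrent layer for the text, before being merged.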
This is now on the roadmap for v3: https://github.com/DLR-RM/stable-baselines3/issues/1
Closing this in favor of https://github.com/DLR-RM/stable-baselines3/pull/243
I think this is a fairly large task, so let me raise it as an issue here. Work on flattening Dict observation spaces (combinations like image and text) for gym envs like MiniGrid is currently ongoing. MlpPolicy can handle the flattened input, but this is awkward without appropriate feature extraction such as a CNN.
Custom policies are the way to go.