araffin / rl-baselines-zoo

A collection of 100+ pre-trained RL agents using Stable Baselines, training and hyperparameter optimization included.
https://stable-baselines.readthedocs.io/
MIT License

A custom policy for Dict observation spaces. #23

Closed: tatsubori closed this 3 years ago

tatsubori commented 5 years ago

I think this is a fairly large task, so let me raise it as an issue here. Flattening Dict observation spaces (combinations such as image and text) for gym envs like MiniGrid is currently in progress. MlpPolicy can handle the flattened observation, but that is awkward without appropriate feature extraction such as a CNN.

Custom policies are the way to go.

araffin commented 5 years ago

I think you will need to define a policy that inherits from FeedForwardPolicy (or even ActorCriticPolicy), not Mlp or Cnn, because it will be a mix.
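To illustrate what "a mix" could mean here, below is a framework-free sketch in plain Python (not Stable Baselines code): each key of the Dict observation gets its own feature extractor, and the results are concatenated into one feature vector that a policy head could consume. The names `flatten_image`, `encode_text`, and `combine_features` are hypothetical and exist only for this illustration; in a real policy the stand-ins would be a CNN and a text encoder.

```python
def flatten_image(image):
    """Stand-in for a CNN: flatten an H x W x C nested list into a 1D list."""
    return [float(v) for row in image for cell in row for v in cell]


def encode_text(text, max_len=8):
    """Stand-in for a text encoder: fixed-length list of character codes."""
    codes = [float(ord(ch)) for ch in text[:max_len]]
    return codes + [0.0] * (max_len - len(codes))


def combine_features(obs, extractors):
    """Apply one extractor per observation key and concatenate the results."""
    features = []
    for key in sorted(extractors):  # fixed key order keeps the layout stable
        features.extend(extractors[key](obs[key]))
    return features


obs = {
    "image": [[[1, 0, 0], [2, 0, 0]]],  # a tiny 1 x 2 x 3 "image"
    "mission": "go",
}
extractors = {"image": flatten_image, "mission": encode_text}
features = combine_features(obs, extractors)
```

Keeping one extractor per key means the image branch can be swapped for a CNN without touching how the text is handled.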

araffin commented 5 years ago

Looking more closely at the gym-minigrid envs, you don't in fact need a custom policy. You can directly use a provided wrapper: https://github.com/maximecb/gym-minigrid/blob/999599a412db112bc7efa9a0f72f8c315074f8bb/gym_minigrid/wrappers.py#L144

tatsubori commented 5 years ago

To implement something similar to what is available in https://github.com/lcswillems/rl-starter-files, we need a custom policy anyway.

For example, with MiniGrid the observation is a Dict of image, mission text, and other entries. We can either extract only the image (ImgObsWrapper, which returns obs['image']) or flatten everything into a 1D vector (FlatObsWrapper, a MiniGrid-specific version of gym.wrappers.FlattenDictWrapper). The choice pretty much depends on the gym env and the policy.
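A rough sketch of the two options in plain Python, without gym. This is only loosely modeled on the idea behind the two wrappers and is not the actual MiniGrid implementation: the helper names `img_obs` and `flat_obs`, the character set, and the `max_len` value are all assumptions made for illustration.

```python
ALPHABET = "abcdefghijklmnopqrstuvwxyz "  # assumed character set for mission text


def img_obs(obs):
    """Image-only view: keep obs['image'] and drop every other key."""
    return obs["image"]


def flat_obs(obs, max_len=4):
    """Flatten the image plus a one-hot encoding of the mission into a 1D list."""
    flat = [v for row in obs["image"] for cell in row for v in cell]
    mission = obs["mission"].lower()[:max_len]
    for i in range(max_len):
        one_hot = [0] * len(ALPHABET)
        if i < len(mission) and mission[i] in ALPHABET:
            one_hot[ALPHABET.index(mission[i])] = 1
        flat.extend(one_hot)  # unused slots stay all-zero (padding)
    return flat


obs = {"image": [[[1, 0, 0], [2, 5, 0]]], "mission": "go", "direction": 0}
image_only = img_obs(obs)  # the nested 1 x 2 x 3 image, text is lost
vector = flat_obs(obs)     # 6 image values followed by 4 * 27 text values
```

The image-only view keeps the spatial structure a CNN needs but discards the mission, while the flat vector keeps both at the cost of that structure.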

I got it running locally with ImgObsWrapper and a custom CNN policy:

from stable_baselines.common.policies import FeedForwardPolicy as BasePolicy


class CustomGridCnnPolicy(BasePolicy):
    """
    Assumes the env is wrapped with gym_minigrid.wrappers.ImgObsWrapper.
    """

    def __init__(self, *args, **kwargs):
        print("CustomGridCnnPolicy(): {} {}".format(args, kwargs))
        # rl_starter_cnn is a locally defined CNN extractor (not shown here).
        super(CustomGridCnnPolicy, self).__init__(*args, **kwargs,
                    cnn_extractor=rl_starter_cnn)

which doesn't converge faster than MlpPolicy. ;->

I am trying a CNN-LSTM based on one of the rl-starter-files models, but it still ignores the mission text. After that, we can try FlatObsWrapper with another policy that recovers the image and text from the flattened vector.
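Recovering the two parts from a flat vector could look roughly like the sketch below. It assumes a layout of flattened image values followed by one-hot encoded characters; that layout, the character set, and the helper names `split_flat_obs` and `one_hot` are assumptions for illustration, not what FlatObsWrapper is guaranteed to produce.

```python
ALPHABET = "abcdefghijklmnopqrstuvwxyz "  # assumed character set, illustration only


def split_flat_obs(vector, image_shape, max_len):
    """Split a flat vector back into an H x W x C image and a mission string."""
    height, width, channels = image_shape
    n_image = height * width * channels
    pixels, text_part = vector[:n_image], vector[n_image:]
    image = [
        [pixels[(r * width + c) * channels:(r * width + c + 1) * channels]
         for c in range(width)]
        for r in range(height)
    ]
    chars = []
    for i in range(max_len):
        slot = text_part[i * len(ALPHABET):(i + 1) * len(ALPHABET)]
        if 1 in slot:  # an all-zero slot is treated as padding
            chars.append(ALPHABET[slot.index(1)])
    return image, "".join(chars)


def one_hot(ch):
    """One-hot encode a single character over ALPHABET."""
    return [1 if ch == a else 0 for a in ALPHABET]


# Hand-built example vector: a 1 x 2 x 3 image followed by "go" and two padding slots.
vector = [1, 0, 0, 2, 5, 0] + one_hot("g") + one_hot("o") + [0] * len(ALPHABET) * 2
image, mission = split_flat_obs(vector, image_shape=(1, 2, 3), max_len=4)
```

Inside a policy, the recovered image would then go to the CNN branch and the recovered text to a separate text encoder.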

araffin commented 4 years ago

This is now on the roadmap for V3: https://github.com/DLR-RM/stable-baselines3/issues/1

araffin commented 3 years ago

Closing this in favor of https://github.com/DLR-RM/stable-baselines3/pull/243