DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] Using images to train DDPG+HER agent #287

Closed SilviaZirino closed 3 years ago

SilviaZirino commented 3 years ago

Hello, I’m trying to train a DDPG+HER agent that interacts with a custom environment and takes as input an RGB image of the environment. From what I understand, in the previous version of Stable Baselines only 1D observation spaces were supported in HER (as also indicated in HERGoalEnvWrapper), thus excluding image observations:

    if len(goal_space_shape) == 2:
        assert goal_space_shape[1] == 1, "Only 1D observation spaces are supported yet"
    else:
        assert len(goal_space_shape) == 1, "Only 1D observation spaces are supported yet"

In this new version of Stable Baselines, I see no explicit assertion against using 2D spaces, but in ObsDictWrapper the observation and goal dimensions are taken from the first dimension of the shape only:

    if isinstance(self.spaces[0], spaces.Discrete):
        self.obs_dim = 1
        self.goal_dim = 1
    else:
        self.obs_dim = venv.observation_space.spaces["observation"].shape[0]
        self.goal_dim = venv.observation_space.spaces["achieved_goal"].shape[0]
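
For example (illustrative values, not taken from the SB3 code), for an image observation space shape[0] only returns the first axis, not the flattened size:

    import numpy as np
    from gym import spaces

    # Hypothetical channel-first image space, like the one I want to use
    img_space = spaces.Box(low=0, high=255, shape=(1, 64, 64), dtype=np.uint8)

    print(img_space.shape[0])             # 1: only the channel axis
    print(int(np.prod(img_space.shape)))  # 4096: the full flattened size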

Question

Is it possible to train a DDPG+HER agent from images using the implementation of stable baselines 3?

araffin commented 3 years ago

Hello, this feature is currently not supported by SB3, but it should be once we refactor HER, after #243 is merged ;)

If you want to work with images directly (which I do not recommend anyway; it is better to extract an intermediate representation first), you will need to update the ObsDictWrapper so that the conversions from dict to vector and from vector to dict work properly.
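
For instance (a rough, untested sketch; flat_dims is a hypothetical helper, not part of SB3), the dimensions could be computed from the full flattened shape instead of only the first axis:

    import numpy as np
    from gym import spaces

    def flat_dims(observation_space):
        # Sketch of how an updated ObsDictWrapper could compute sizes:
        # np.prod covers multi-dimensional (image) shapes, unlike shape[0]
        obs_dim = int(np.prod(observation_space.spaces["observation"].shape))
        goal_dim = int(np.prod(observation_space.spaces["achieved_goal"].shape))
        return obs_dim, goal_dim

    img = spaces.Box(low=0, high=255, shape=(1, 64, 64), dtype=np.uint8)
    dict_space = spaces.Dict({"observation": img, "achieved_goal": img, "desired_goal": img})
    print(flat_dims(dict_space))  # (4096, 4096)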

A quick example of an image-based GoalEnv:

import numpy as np
import gym
from gym import spaces

from stable_baselines3 import DDPG, HER

class CustomEnv(gym.GoalEnv):
    """Custom Environment that follows gym interface"""

    metadata = {"render.modes": ["human"]}

    def __init__(self):
        super(CustomEnv, self).__init__()
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        N_CHANNELS = 1
        HEIGHT = 64
        WIDTH = 64
        # Example of using an image as input (can be channel-first or channel-last):
        obs_shape = (N_CHANNELS, HEIGHT, WIDTH)  # channel-first
        # obs_shape = (HEIGHT, WIDTH, N_CHANNELS)  # channel-last
        self.observation_space = spaces.Dict(
            {
                "observation": spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8),
                "achieved_goal": spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8),
                "desired_goal": spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8),
            }
        )

    def step(self, action):
        # Dummy transition: random observation, zero reward, episode never ends
        # (termination is handled by max_episode_length below)
        reward = 0.0
        done = False
        return self.observation_space.sample(), reward, done, {}

    def compute_reward(self, achieved_goal, desired_goal, info):
        # Dummy reward; must be vectorized, as HER calls it on batches of goals
        return np.zeros((len(achieved_goal),))

    def reset(self):
        # Return a random initial observation (dict of image arrays)
        return self.observation_space.sample()

    def render(self, mode="human"):
        pass

model = HER(
    "MlpPolicy",
    CustomEnv(),
    DDPG,
    verbose=1,
    buffer_size=1000,
    learning_starts=100,
    max_episode_length=100,
)

model.learn(50000)

araffin commented 3 years ago

The refactored HER is available here: https://github.com/DLR-RM/stable-baselines3/pull/351. It notably adds support for images.
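
For reference, after that refactor HER becomes a replay buffer class rather than a wrapper algorithm. A minimal sketch, assuming SB3 1.x after the refactor and the CustomEnv from the example above (parameter values are illustrative):

    from stable_baselines3 import DDPG
    from stable_baselines3.her import HerReplayBuffer

    model = DDPG(
        "MultiInputPolicy",  # dict observations are handled natively post-refactor
        CustomEnv(),
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=dict(
            n_sampled_goal=4,
            goal_selection_strategy="future",
            online_sampling=True,
            max_episode_length=100,  # needed in 1.x if the env has no TimeLimit
        ),
        buffer_size=1000,
        learning_starts=100,
        verbose=1,
    )
    model.learn(50000)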

JimLiu1213 commented 3 years ago

Hi @araffin, I am also using DDPG+HER on Colab, with images as the only observations. The images have shape (4, 64, 64). The problem is that there is not enough memory on Colab, even though I reduced the buffer size to 1e5; the error 'tcmalloc: large alloc 3276800000 bytes == 0x55ab7f82a000' is shown. I therefore changed the buffer size to 50000, with cnn_output_dim of 128, but the agent does not seem to learn anything. Could you please give me some ideas on how to solve this issue? Thank you very much.
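
For reference, a rough back-of-the-envelope calculation of the buffer footprint (my own numbers, not from the thread): each uint8 image of shape (4, 64, 64) takes 4 * 64 * 64 = 16384 bytes, and the replay buffer stores several such arrays per transition:

    import numpy as np

    img_bytes = int(np.prod((4, 64, 64)))  # 16384 bytes per uint8 image
    buffer_size = 100_000

    print(buffer_size * img_bytes)      # 1638400000 (~1.6 GB) for a single array
    print(2 * buffer_size * img_bytes)  # 3276800000: exactly the reported alloc,
                                        # consistent with e.g. obs + next_obs arrays

With the goal keys stored as well, the total is several times larger, which can easily exhaust Colab's RAM; extracting a compact intermediate representation instead of raw images (as suggested above) would reduce this.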