hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

Images as observation_space in Pendulum-v0 (or in Classic Control) #915

Closed: meric-sakarya closed this issue 4 years ago

meric-sakarya commented 4 years ago

Hello, I am working on the Pendulum-v0 environment and using SAC and TD3 implementations from Stable Baselines. I want to alter the observation_space so I can train the model using image inputs, instead of the float32 array it currently uses. Any help regarding how to solve this task would be greatly appreciated. Also, if there isn't already one, a framework or a wrapper to easily use image inputs as observations would be a nice feature to have.

Miffyli commented 4 years ago

stable-baselines agents support taking images as inputs if you set the policy to "CnnPolicy", in which case the input is processed with a small CNN. You need to modify the Gym env itself to get image output. The solution seems to be to simply call image_obs = env.render("rgb_array") and return that from reset and step, which should be easy with a wrapper along these lines (not tested / complete; you might also want to resize the image to something smaller):

from gym import Wrapper, spaces

class RGBArrayAsObservationWrapper(Wrapper):
    """
    Use env.render(rgb_array) as observation
    rather than the observation environment provides
    """
    def __init__(self, env):
        super().__init__(env)
        # TODO this might not work before environment has been reset
        dummy_obs = env.render("rgb_array")
        # Update observation space
        # TODO assign correct low and high
        self.observation_space = spaces.Box(low=0, high=255, shape=dummy_obs.shape, dtype=dummy_obs.dtype)

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        obs = env.render("rgb_array")
        return obs

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        obs = env.render("rgb_array")
        return obs, reward, done, info

meric-sakarya commented 4 years ago

First, thank you very much for your quick and detailed response! I added self.reset() to the first line with TODO and changed the env.render() calls to self.env.render() calls. Here is the complete code that I am using.

import gym
import numpy as np
from gym.envs.classic_control import PendulumEnv
from stable_baselines.common.env_checker import check_env

from stable_baselines.sac.policies import CnnPolicy
from stable_baselines import TD3

from gym import Wrapper, spaces

class RGBArrayAsObservationWrapper(Wrapper):

    def __init__(self, env):
        super(RGBArrayAsObservationWrapper, self).__init__(env)
        self.reset()
        dummy_obs = env.render('rgb_array')
        self.observation_space = spaces.Box(low=0, high=255, shape=dummy_obs.shape, dtype=dummy_obs.dtype)

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        obs = self.env.render("rgb_array")
        return obs

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        obs = self.env.render("rgb_array")
        return obs, reward, done, info

TEST_COUNT = 100

pendulum_env = PendulumEnv()
pendulum_env = RGBArrayAsObservationWrapper(pendulum_env)
check_env(pendulum_env, warn=True)

model = TD3(CnnPolicy, pendulum_env, verbose=1)
model.learn(total_timesteps=10_000, log_interval=10)
model.save("td3_pendulum")

sum_rewards = 0
for i in range(TEST_COUNT):
    obs = pendulum_env.reset()
    done = False
    while not done:
        action, _states = model.predict(obs)
        obs, rewards, done, info = pendulum_env.step(action)
        pendulum_env.render()
        sum_rewards += rewards

print(sum_rewards / TEST_COUNT)

My goal is simple: to compare the performance of the two algorithms, TD3 and SAC, when working with image inputs. My issues are listed below.

I can't figure out the performance because the logging does not seem to work, even though verbose is set to 1.

I tried debugging to figure out how fast training is done, and it is incredibly slow. Does anyone have any idea what might be causing the training to be this slow?

I don't want the video to pop up after executing self.env.render() in the init function since there is no need.

Again, any help would be greatly appreciated.

Miffyli commented 4 years ago

I am glad to hear you got the code to work 👍

SAC/TD3: Did you also change the policy to the one from TD3, not SAC? Also regarding this: SAC/TD3 are known to be inefficient at learning from images. I suggest you try A2C/PPO, unless SAC/TD3 are the very algorithms you want to study.
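
For reference, a minimal, untested sketch of trying PPO2 with the built-in CNN policy on the wrapped environment from the code above (the total_timesteps and log_interval values are only placeholders):

from stable_baselines import PPO2

# PPO2 accepts the registered "CnnPolicy" by name, so there is no need to
# import a policy class from a specific algorithm's module.
model = PPO2("CnnPolicy", pendulum_env, verbose=1)
model.learn(total_timesteps=10_000, log_interval=10)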

I can't figure out the performance because the logging does not seem to work, even though verbose is set to 1.

The logs are displayed after a fixed number of updates, and if training is slow this can take a long while.

I tried debugging to figure out how fast training is done, and it is incredibly slow. Does anyone have any idea what might be causing the training to be this slow?

I imagine the rendering code is not very fast as it is meant for debugging/enjoyment use (observed by humans). Also remember to resize the image to something smaller (e.g. 40x40), because the original image is probably too large.

I don't want the video to pop up after executing self.env.render() in the init function since there is no need.

Likely unavoidable, as the rendering is done with OpenGL-based code and OpenGL requires a valid screen surface to draw on.

If there are no more issues related to stable-baselines, you can close the issue.

meric-sakarya commented 4 years ago

Thanks for the fast answer yet again, I am amazed by how fast you are replying! :)

Did you also change the policy to one from TD3, not SAC?

Yes, it started working but the learning phase is even slower than TD3. I have been waiting for just one time_step for about 10 minutes now.

Also remember to resize the image to something smaller (e.g. 40x40), because the original image is probably too large.

I am an absolute beginner when it comes to stable_baselines; can you maybe explain how I might do that?

Miffyli commented 4 years ago

Again, this is not a place for tech support, and I am closing the issue as there do not seem to be any issues related to stable-baselines.

Yes, it started working but the learning phase is even slower than TD3. I have been waiting for just one time_step for about 10 minutes now.

You could double-check how fast the environment is with a random agent (action = env.action_space.sample()).
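
As a rough illustration, here is a sketch assuming the wrapped pendulum_env from the code above (the 1000-step budget is arbitrary):

import time

obs = pendulum_env.reset()
n_steps = 1000
start = time.time()
for _ in range(n_steps):
    # Sample random actions so any remaining slowness comes from the
    # environment and its rendering, not from the agent.
    action = pendulum_env.action_space.sample()
    obs, reward, done, info = pendulum_env.step(action)
    if done:
        obs = pendulum_env.reset()
print("steps per second:", n_steps / (time.time() - start))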

I am an absolute beginner when it comes to stable_baselines, can you maybe explain how I might do that?

Using a simple image resize function, like the one from scikit-image.
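
For example, a minimal sketch using skimage.transform.resize inside the wrapper; the _render_obs helper name and the 40x40 target size are just illustrative assumptions:

from skimage.transform import resize

def _render_obs(self):
    # Render the full-size frame, then downscale it to 40x40 while keeping
    # the channel dimension; preserve_range keeps the 0-255 pixel values,
    # which are cast back to uint8 to match the Box observation space.
    frame = self.env.render("rgb_array")
    small = resize(frame, (40, 40, frame.shape[-1]),
                   preserve_range=True, anti_aliasing=True)
    return small.astype("uint8")

reset and step would then return _render_obs() instead of the raw render, and the observation_space built in __init__ would need to use the resized shape.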

Edit:

Thanks for the fast answer yet again, I am amazed by how fast you are replying! :)

Cheers :). I try my best.