hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License
4.16k stars, 725 forks

MlpPolicy with MultiDiscrete environment #1103

Closed: mg64ve closed this issue 3 years ago

mg64ve commented 3 years ago

Hi, I am running stable-baselines in Docker with CUDA 10. As expected, the following example from the documentation works:

import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common import make_vec_env
from stable_baselines import PPO2

# multiprocess environment
env = make_vec_env('CartPole-v1', n_envs=4)

model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
model.save("ppo2_cartpole")

del model # remove to demonstrate saving and loading

model = PPO2.load("ppo2_cartpole")

# Enjoy trained agent
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()

I want to test the same thing with a MultiDiscrete action space, so I wrote a CustomCartpole environment that uses a two-dimensional action space (it is only a simple example):

import gym

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common import make_vec_env
from stable_baselines import PPO2
from stable_baselines.common.env_checker import check_env

import gym.envs.classic_control
from gym import logger  # needed for logger.warn() in step()
import math
import numpy as np

class CustomCartpole(gym.envs.classic_control.CartPoleEnv):
    """Add a dimension to the cartpole action space that is used as 'speed' button."""

    def __init__(self):
        super().__init__()
        self.force_mag = 5.0
        self.action_space = gym.spaces.MultiDiscrete([2, 4])

    def step(self, action):
        err_msg = "%r (%s) invalid" % (action, type(action))
        print('{}'.format(err_msg))
        assert self.action_space.contains(action), err_msg

        x, x_dot, theta, theta_dot = self.state
        force = self.force_mag if action[0] == 1 else -self.force_mag
        force *= (action[1] + 1)
        costheta = math.cos(theta)
        sintheta = math.sin(theta)

        temp = (force + self.polemass_length * theta_dot ** 2 * sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass))
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        if self.kinematics_integrator == 'euler':
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:  # semi-implicit euler
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        self.state = (x, x_dot, theta, theta_dot)

        done = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or theta < -self.theta_threshold_radians
            or theta > self.theta_threshold_radians
        )

        if not done:
            reward = 1.0
        elif self.steps_beyond_done is None:
            # Pole just fell!
            self.steps_beyond_done = 0
            reward = 1.0
        else:
            if self.steps_beyond_done == 0:
                logger.warn(
                    "You are calling 'step()' even though this "
                    "environment has already returned done = True. You "
                    "should always call 'reset()' once you receive 'done = "
                    "True' -- any further steps are undefined behavior."
                )
            self.steps_beyond_done += 1
            reward = 0.0

        return np.array(self.state), reward, done, {}

env = CustomCartpole()
check_env(env)
model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
model.save("ppo2_customcartpole")

del model # remove to demonstrate saving and loading

model = PPO2.load("ppo2_customcartpole")  # same name as used in model.save()

# Enjoy trained agent
obs = env.reset()
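while True:
    action, _states = model.predict(obs)
    obs, reward, done, info = env.step(action)
    env.render()
    # This env is not vectorized, so it has to be reset manually after each episode
    if done:
        obs = env.reset()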
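The action encoding used in `CustomCartpole.step` can be isolated into a small pure-Python function to check which force each `MultiDiscrete([2, 4])` action produces. `action_to_force` is a hypothetical helper written for illustration; it only restates the two force lines from the environment above (with `force_mag = 5.0`).

```python
from typing import Sequence


def action_to_force(action: Sequence[int], force_mag: float = 5.0) -> float:
    """Map a MultiDiscrete([2, 4]) action to a horizontal force.

    Restates the arithmetic at the top of CustomCartpole.step:
    action[0] picks the direction (1 = push right, 0 = push left) and
    action[1] in {0, 1, 2, 3} scales the magnitude by 1x to 4x.
    """
    direction, speed = int(action[0]), int(action[1])
    force = force_mag if direction == 1 else -force_mag
    return force * (speed + 1)


if __name__ == "__main__":
    for a in ([1, 0], [0, 0], [1, 3], [0, 3]):
        print(a, action_to_force(a))
```

So the "speed button" lets the agent choose forces of 5, 10, 15 or 20 in either direction.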

I am getting the following error:

array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid

I checked the documentation and did not find anything about which action spaces MlpPolicy supports. I assumed it would check the dimensions and adjust the policy accordingly. What is wrong?
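As far as I know, no special policy class is needed: Stable-Baselines builds the policy's output distribution from the action space, and for a MultiDiscrete space it uses one categorical distribution per dimension, fed by a single output layer of size `sum(nvec)`. The dependency-free sketch below only illustrates that idea for the deterministic case (`model.predict(obs, deterministic=True)`), where each dimension takes the argmax of its own slice of logits. `split_logits` and `deterministic_action` are hypothetical helpers, not library functions, and the logits are made-up numbers.

```python
from typing import List, Sequence


def split_logits(logits: Sequence[float], nvec: Sequence[int]) -> List[List[float]]:
    """Split a flat logits vector of length sum(nvec) into one chunk per action dimension."""
    if len(logits) != sum(nvec):
        raise ValueError("expected %d logits, got %d" % (sum(nvec), len(logits)))
    chunks, start = [], 0
    for n in nvec:
        chunks.append(list(logits[start:start + n]))
        start += n
    return chunks


def deterministic_action(logits: Sequence[float], nvec: Sequence[int]) -> List[int]:
    """Pick the most likely category in each dimension (argmax per chunk)."""
    return [max(range(len(c)), key=c.__getitem__) for c in split_logits(logits, nvec)]


if __name__ == "__main__":
    # MultiDiscrete([2, 4]) -> 2 + 4 = 6 logits from the policy network's last layer
    logits = [0.1, 0.7, -1.0, 0.3, 2.0, 0.5]
    print(deterministic_action(logits, [2, 4]))  # [1, 2]
```

With the default `deterministic=False`, each dimension is instead sampled from its own categorical distribution, which is why the logged actions above vary from step to step.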

araffin commented 3 years ago

Hello, I think you are just printing the error message even though there is no error. If there were one, the message would have been printed only once, and execution would have stopped when the `assert` raised an AssertionError.
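In the environment above, the message is built and printed unconditionally at the top of `step`, while the `assert` only fires for an out-of-range action. Below is a small pure-Python sketch of the intended behaviour, where the message is produced only when the check fails. It assumes the membership rule for `MultiDiscrete` is "same length as `nvec` and `0 <= a_i < n_i` in every dimension" (gym's own `contains` may also check shape and dtype); `multi_discrete_contains` and `invalid_action_message` are hypothetical helpers.

```python
from typing import Optional, Sequence


def multi_discrete_contains(action: Sequence[int], nvec: Sequence[int]) -> bool:
    """Approximate MultiDiscrete membership: same length as nvec and 0 <= a_i < n_i."""
    if len(action) != len(nvec):
        return False
    return all(0 <= int(a) < int(n) for a, n in zip(action, nvec))


def invalid_action_message(action: Sequence[int], nvec: Sequence[int]) -> Optional[str]:
    """Build the error message only when the action is actually invalid."""
    if multi_discrete_contains(action, nvec):
        return None
    return "%r (%s) invalid" % (action, type(action).__name__)


if __name__ == "__main__":
    nvec = [2, 4]  # MultiDiscrete([2, 4]) from CustomCartpole
    for a in ([1, 1], [0, 3], [2, 0], [0, 4]):
        print(a, invalid_action_message(a, nvec))
```

Every action in the log above, such as `[1, 1]` or `[0, 3]`, passes this check, so no message would be printed for them.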

mg64ve commented 3 years ago

Thanks @araffin, you are right, it keeps going. However, it does not print any ep_reward_mean data. Is it learning? The following is part of the log:

array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
--------------------------------------
| approxkl           | 0.00026964353 |
| clipfrac           | 0.0           |
| explained_variance | 0.019         |
| fps                | 215           |
| n_updates          | 14            |
| policy_entropy     | 2.0163999     |
| policy_loss        | -0.0038015784 |
| serial_timesteps   | 1792          |
| time_elapsed       | 8.49          |
| total_timesteps    | 1792          |
| value_loss         | 39.603714     |
--------------------------------------
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
-------------------------------------
| approxkl           | 0.0018063349 |
| clipfrac           | 0.015625     |
| explained_variance | 0.000335     |
| fps                | 211          |
| n_updates          | 15           |
| policy_entropy     | 2.0012653    |
| policy_loss        | -0.008501572 |
| serial_timesteps   | 1920         |
| time_elapsed       | 9.08         |
| total_timesteps    | 1920         |
| value_loss         | 42.672264    |
-------------------------------------
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 2]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 2]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([1, 0]) (<class 'numpy.ndarray'>) invalid
array([0, 3]) (<class 'numpy.ndarray'>) invalid
array([0, 1]) (<class 'numpy.ndarray'>) invalid
array([1, 1]) (<class 'numpy.ndarray'>) invalid
--------------------------------------
| approxkl           | 0.00018022978 |
| clipfrac           | 0.0           |
| explained_variance | 0.0575        |
| fps                | 218           |
| n_updates          | 16            |
| policy_entropy     | 1.9997047     |
| policy_loss        | -0.0007104698 |
| serial_timesteps   | 2048          |
| time_elapsed       | 9.69          |
| total_timesteps    | 2048          |
| value_loss         | 64.8821       |
--------------------------------------

And this is what the original script prints:

--------------------------------------
| approxkl           | 0.00018330614 |
| clipfrac           | 0.0           |
| ep_len_mean        | 21.4          |
| ep_reward_mean     | 21.4          |
| explained_variance | 0.00317       |
| fps                | 439           |
| n_updates          | 1             |
| policy_entropy     | 0.69296503    |
| policy_loss        | -0.0029364298 |
| serial_timesteps   | 128           |
| time_elapsed       | 2e-05         |
| total_timesteps    | 512           |
| value_loss         | 40.458572     |
--------------------------------------
araffin commented 3 years ago

However it does not print any ep_reward_mean data. Is it learning ?

This is a duplicate of https://github.com/hill-a/stable-baselines/issues/24, and it is also covered in the documentation: you need to wrap the environment in a Monitor wrapper to get ep_reward_mean. But yes, it is learning (the losses are reported).
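In Stable-Baselines the wrapper is `stable_baselines.bench.Monitor`, applied as `env = Monitor(env, filename=None)` before creating the model; as far as I can tell, `make_vec_env` already applies it, which would explain why the CartPole example reports `ep_reward_mean`. The sketch below is not that class. It is a hypothetical, dependency-free `EpisodeStatsWrapper` (with a toy `ToyEnv`) showing the bookkeeping involved: accumulate per-episode reward and length, keep finished episodes in a window (100 here, an assumed value), and expose them through `info["episode"]`, which, to my understanding, is where PPO2 reads them from.

```python
from collections import deque
from typing import Any, Deque, Tuple


class EpisodeStatsWrapper:
    """Hypothetical stand-in for a Monitor-style wrapper (not the library class)."""

    def __init__(self, env: Any, window: int = 100):
        self.env = env
        self.episode_rewards: Deque[float] = deque(maxlen=window)
        self.episode_lengths: Deque[int] = deque(maxlen=window)
        self._reward = 0.0
        self._length = 0

    def reset(self) -> Any:
        # Start counting a fresh episode
        self._reward, self._length = 0.0, 0
        return self.env.reset()

    def step(self, action: Any) -> Tuple[Any, float, bool, dict]:
        obs, reward, done, info = self.env.step(action)
        self._reward += reward
        self._length += 1
        if done:
            # Store the finished episode and report it through `info`
            self.episode_rewards.append(self._reward)
            self.episode_lengths.append(self._length)
            info = dict(info, episode={"r": self._reward, "l": self._length})
        return obs, reward, done, info

    def ep_reward_mean(self) -> float:
        """Mean reward over the stored episodes (NaN if none finished yet)."""
        if not self.episode_rewards:
            return float("nan")
        return sum(self.episode_rewards) / len(self.episode_rewards)


class ToyEnv:
    """Toy environment: every episode lasts `horizon` steps with reward 1.0 per step."""

    def __init__(self, horizon: int):
        self.horizon = horizon
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: Any) -> Tuple[int, float, bool, dict]:
        self.t += 1
        return self.t, 1.0, self.t >= self.horizon, {}


if __name__ == "__main__":
    env = EpisodeStatsWrapper(ToyEnv(horizon=3))
    for _ in range(2):
        env.reset()
        done = False
        while not done:
            _, _, done, info = env.step(0)
    print(list(env.episode_rewards), env.ep_reward_mean())  # [3.0, 3.0] 3.0
```

Without such a wrapper nothing records finished episodes, so the training log can only show the loss statistics, as in the CustomCartpole output above.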

Anyway, I would recommend using Stable-Baselines3: https://github.com/DLR-RM/stable-baselines3 (there, the env is automatically wrapped in a Monitor wrapper when possible).

Closing, as the original issue was solved.

mg64ve commented 3 years ago

Thanks @araffin. With SB3 I get almost the same error, but it does print ep_reward_mean:

[screenshot of the SB3 training output]

Miffyli commented 3 years ago

Again, as araffin pointed out, that "error message" comes from your environment (the line where you print it), not from stable-baselines. We do not offer tech support for fixing custom environments.

araffin commented 3 years ago

There is no error, just a print.

mg64ve commented 3 years ago

OK, thanks @Miffyli and @araffin. I did not realize that the print came from my environment.