hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License
4.16k stars 725 forks source link

Environment checker returns assertion error contradicting debug statements #1168

Closed techboy-coder closed 2 years ago

techboy-coder commented 2 years ago

While trying to create my own custom environment, I ran into a strange problem. According to Stable-Baselines3's custom environment checker (check_env), the observation returned by step() does not match the specified observation space. However, I print the shape and data type of every observation returned, as well as those of a sample drawn from the observation space, and they all match. The error message therefore seems to contradict my debug output (the print lines shown below).

# terminal output. See full output below.
Sample Observation (462,) int32
Reset Observation (462,) int32
Step Observation (462,) int32

Error: I am getting the following error: AssertionError: The observation returned by the `step()` method does not match the given observation space

Full error message: ```python ╰─> python .\app.py py version: 3.9.12 (main, Apr 4 2022, 05:22:27) [MSC v.1916 64 bit (AMD64)] Dimensions: 462 (flattened). Sample Observation (462,) int32 Reset Observation (462,) int32 Step Observation (462,) int32 ╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮ │ C:\Users\Shivr\Documents\Code\Projects\MA9\home\app.py:157 in │ │ │ │ 154 │ │ return self.obs │ │ 155 │ │ 156 env = SnakeEnv() │ │ ❱ 157 check_env(env) │ │ │ │ C:\Users\Shivr\miniconda3\lib\site-packages\stable_baselines3\common\env_checker.py:288 in │ │ check_env │ │ │ │ 285 │ │ │ ) │ │ 286 │ │ │ 287 │ # ============ Check the returned values =============== │ │ ❱ 288 │ _check_returned_values(env, observation_space, action_space) │ │ 289 │ │ │ 290 │ # ==== Check the render method and the declared render modes ==== │ │ 291 │ if not skip_render_check: │ │ │ │ C:\Users\Shivr\miniconda3\lib\site-packages\stable_baselines3\common\env_checker.py:172 in │ │ _check_returned_values │ │ │ │ 169 │ │ │ │ raise AssertionError(f"Error while checking key={key}: " + str(e)) from │ │ 170 │ │ │ 171 │ else: │ │ ❱ 172 │ │ _check_obs(obs, observation_space, "step") │ │ 173 │ │ │ 174 │ # We also allow int because the reward will be cast to float │ │ 175 │ assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a │ │ │ │ C:\Users\Shivr\miniconda3\lib\site-packages\stable_baselines3\common\env_checker.py:112 in │ │ _check_obs │ │ │ │ 109 │ elif _is_numpy_array_space(observation_space): │ │ 110 │ │ assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name} │ │ 111 │ │ │ ❱ 112 │ assert observation_space.contains( │ │ 113 │ │ obs │ │ 114 │ ), f"The observation returned by the `{method_name}()` method does not match the giv │ │ 115 │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯ AssertionError: The observation returned by the `step()` method does 
not match the given observation space ```
Complete code. ```python from rich import pretty pretty.install() from rich.traceback import install install(show_locals=False) import sys print("py version:", sys.version) from stable_baselines3.common.env_checker import check_env import tqdm import os import threading import websocket import numpy as np import numpy import time import json import nanoid import gym from gym import spaces def dictToJson(d): return json.dumps(d) def jsonToDict(j): return json.loads(j) # websocket.enableTrace(True) class Client: def __init__(self, ViewDistance=10): self.url = "ws://localhost:8080/ws" self.wsapp = websocket.WebSocketApp(self.url, on_message=self.on_message, on_open=self.on_open, on_close=self.on_close) self.wst = threading.Thread(target=self.wsapp.run_forever) self.wst.daemon = True self.wst.start() conn_timeout = 5 while not self.wsapp.sock.connected and conn_timeout: time.sleep(1) conn_timeout -= 1 self.id = nanoid.generate(size = 4) self.job = "" self.viewdistance = ViewDistance self.field = [] self.direction = -1 self.score = 0 self.dead = False self.data = np.empty(shape=(2*self.viewdistance+1, 2*self.viewdistance+1)) self.join() def send(self, msg): self.wsapp.send(msg) def join(self): self.send(self.joinMsg()) def joinMsg(self): msg = { "Task":{ "Job":"join" }, "Player":{ "Name":"bot #"+ self.id, "ViewDistance":self.viewdistance } } return dictToJson(msg) def step(self, direction=-1): if direction >= 0 and direction <= 3: msg = { "Task":{ "Job":"turn", "Amount":direction } } msg = dictToJson(msg) self.send(msg) time.sleep(0.05) for row in self.field: ix = row["X"]+self.viewdistance iy = row["Y"]+self.viewdistance # Empty = 0, food = 1, snakebody = 2, snakehead = 3 rowtype = row["Type"] rowSnakeHead = row["SnakeHead"] fieldtype = 0 if rowtype == 1: fieldtype = 1 if rowtype == 2: if rowSnakeHead: fieldtype = 3 else: fieldtype = 2 self.data[ix][iy] = fieldtype return self.data, self.direction, self.score, self.dead def on_open(self, ws): # print("Open Connection, 
sending join request...") ws.send(self.joinMsg()) def on_message(self, wsappm, msg): msg = jsonToDict(msg) self.job = msg["Task"]["Job"] if self.job == "update": self.field = msg["Field"] self.direction = msg["Task"]["Amount"] self.score = msg["Player"]["Score"] if self.job == "dead": self.dead = True self.wsapp.close() # print(f"\r {self.id} moving {self.direction} has score of {self.score}.", end = " ") def on_close(self, ws, close_status_code, close_msg): self.dead = True print("Connection closed") def close(self): self.wsapp.close() print(f"Connection for {self.id} closed") class SnakeEnv(gym.Env): def __init__(self, viewdistance=10): super(SnakeEnv, self).__init__() self.viewdistance = viewdistance self.action_space = spaces.Discrete(4) # space = (2*viewdistance+1, 2*viewdistance+1) + 1 row for direction (which means cols stays same, extra row) # Flatten obs space dims = (2*viewdistance+1+1)*(2*viewdistance+1) print(f"Dimensions: {dims} (flattened).") self.observation_space = spaces.Box(-1.0, 4.0, shape=(dims,), dtype=int) sample = self.observation_space.sample() print("Sample Observation", np.shape(sample), sample.dtype) # init blank game self.obs = np.zeros(shape=(dims,)).astype(dtype=int) # Client connects to realtime game self.client = Client(ViewDistance=viewdistance) def step(self, action): data, direction, score, dead = self.client.step(direction = action) dirobs = np.empty(2*self.viewdistance+1) dirobs.fill(direction) obs = np.vstack([data, dirobs]) obs = obs.flatten() obs = obs.astype(int) print("Step Observation",np.shape(obs), obs.dtype) self.obs = obs return obs, score, dead, {} def close(self): self.client.close() def reset(self): self.client = Client(ViewDistance=self.viewdistance) print("Reset Observation",np.shape(self.obs), self.obs.dtype) return self.obs env = SnakeEnv() check_env(env) ``` I am not sure if this is an error on the Stable Baslines side or if I am doing something wrong here. 
I would appreciate it if you could help me or point me in the right direction.
araffin commented 2 years ago

Hello, you posted your issue in the old SB2 repo, but you are using SB3... Did you check the values too? They must be within the limits defined by the observation space, not just have the right shape and dtype.

techboy-coder commented 2 years ago

Hey 👋, Thanks for letting me know about the wrong repo. I just found the new one... I was also able to solve the issue.

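I found that np.empty() (used in the Client class to initialize self.data) does not initialize the array, so it contains arbitrary leftover values. After the cast to int in step(), these show up as out-of-range values such as the minimum signed 32-bit integer (-2147483648), far outside the Box(-1.0, 4.0) bounds of my observation space (see the locals dump below). So yes, I had to keep the values within the defined limits (I now just call arr.fill(0) on the array after creating it). Not doing so triggered the assertion error.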


│                                                                                                  │
│ ╭─────────────────────────────────────────── locals ───────────────────────────────────────────╮ │
│ │       method_name = 'step'                                                                   │ │
│ │               obs = array([          0,           0,           0,           0, -2147483648,  │ │
│ │                     │   │   │   │    0, -2147483648, -2147483648,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648, -2147483648,  │ │
│ │                     │      -2147483648, -2147483648, -2147483648, -2147483648, -2147483648,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648, -2147483648,  │ │
│ │                     │      -2147483648, -2147483648, -2147483648, -2147483648, -2147483648,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648, -2147483648,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0,           0, -2147483648,           0, -2147483648,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0, -2147483648, -2147483648, -2147483648,           0,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0,           0, -2147483648,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648, -2147483648,  │ │
│ │                     │      -2147483648,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648, -2147483648,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648,           0,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0,           0, -2147483648,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0,           0, -2147483648,  │ │
│ │                     │      -2147483648, -2147483648, -2147483648,           0, -2147483648,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0, -2147483648, -2147483648, -2147483648,           0,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648, -2147483648,  │ │
│ │                     │      -2147483648, -2147483648, -2147483648,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │      -2147483648,           0,           0, -2147483648, -2147483648,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │      -2147483648, -2147483648, -2147483648,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648,           0,  │ │
│ │                     │      -2147483648, -2147483648,           0,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0, -2147483648,           0,           0,     4999656,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648,           0,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648, -2147483648,  │ │
│ │                     │      -2147483648, -2147483648, -2147483648, -2147483648,           0,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648, -2147483648,  │ │
│ │                     │      -2147483648, -2147483648, -2147483648, -2147483648,           0,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648, -2147483648,  │ │
│ │                     │      -2147483648, -2147483648, -2147483648, -2147483648, -2147483648,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648, -2147483648,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
│ │                     │   │   │   │    0,           0,           0, -2147483648,           0,  │ │
│ │                     │   │   │   │    0, -2147483648,           0,           0, -2147483648,  │ │
│ │                     │   │   │   │    0,           0,           0,           0,           0,  │ │
# ...

Also, thanks for taking the time and replying :).