DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

predict method fails after learn with Dict observation space #1661

Closed · jamesmr13 closed this issue 1 year ago

jamesmr13 commented 1 year ago

🐛 Bug

My specs:

Windows, Python 3.11

Stable_baselines3

Name: stable-baselines3
Version: 2.1.0
Summary: Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.
Home-page: https://github.com/DLR-RM/stable-baselines3
Author: Antonin Raffin
Author-email: antonin.raffin@dlr.de
License: MIT
Location: C:\Users\jrober23\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages
Requires: cloudpickle, gymnasium, matplotlib, numpy, pandas, torch
Required-by: sb3-contrib

Gymnasium

Name: gymnasium
Version: 0.28.1
Summary: A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym).
Home-page:
Author:
Author-email: Farama Foundation <contact@farama.org>
License: MIT License
Location: C:\Users\jrober23\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages
Requires: cloudpickle, farama-notifications, jax-jumpy, numpy, typing-extensions
Required-by: Shimmy, stable-baselines3

I am using a custom gym environment with a Dict observation space, shown here:

self.observation_space = Dict({
    'bearings': Box(low=-26.0, high=50.0, shape=(720,), dtype=np.float64),
    'position': Box(low=np.array([0, 0]), high=np.array(screen_size), shape=(2,)),
    'wp_info': Box(low=np.array([0, 0]), high=np.array((self.max_dist, 360)), shape=(2,), dtype=np.float64),
})

After the model learns, I want to test it on the environment in a pygame rendering mode; the code for that is here:

env = ENV.BearingsOnlyNavigationDiscrete(render_mode=None, screen_size=(200,200))

model = PPO("MultiInputPolicy", env, verbose=1, n_steps=512)
model.learn(total_timesteps=int(timesteps))

env = ENV.BearingsOnlyNavigationDiscrete(render_mode="human", screen_size=(200,200))
obs = env.reset()
while True:
    action, _ = model.predict(obs)
    obs, r, end, _, _ = env.step(action)
    if end:
        obs = env.reset()

My learn call seemingly functions normally; this is the output from 1024 timesteps, prior to the error being thrown.

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 29.8     |
|    ep_rew_mean     | -35.9    |
| time/              |          |
|    fps             | 1356     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 512      |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 26.5        |
|    ep_rew_mean          | -30.4       |
| time/                   |             |
|    fps                  | 1007        |
|    iterations           | 2           |
|    time_elapsed         | 1           |
|    total_timesteps      | 1024        |
| train/                  |             |
|    approx_kl            | 0.011428952 |
|    clip_fraction        | 0.116       |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.48       |
|    explained_variance   | 0.00251     |
|    learning_rate        | 0.0003      |
|    loss                 | 94.8        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0202     |
|    value_loss           | 212         |
-----------------------------------------

I get the same error when calling predict on the SimpleMultiObsEnv provided by stable_baselines3, so I have used that code for the minimal example to reproduce the bug.

Code example

from stable_baselines3 import PPO
from stable_baselines3.common.envs import SimpleMultiObsEnv

env = SimpleMultiObsEnv(random_start=False)

model = PPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=1024)

i = 0
obs = env.reset()
while i < 200:
    action, _ = model.predict(obs)
    obs, r, end, _, _ = env.step(action)
    i = i + 1
    if end:
        obs = env.reset()

Relevant log output / Error message

Traceback (most recent call last):
  File "c:\Users\Public\Documents\OpenGym\PPO_agent.py", line 56, in <module>
    action, _ = model.predict(obs)
                ^^^^^^^^^^^^^^^^^^
  File "C:\Users\jrober23\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\stable_baselines3\common\base_class.py", line 555, in predict
    return self.policy.predict(observation, state, episode_start, deterministic)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jrober23\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\stable_baselines3\common\policies.py", line 346, in predict
    observation, vectorized_env = self.obs_to_tensor(observation)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jrober23\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\stable_baselines3\common\policies.py", line 264, in obs_to_tensor
    vectorized_env = is_vectorized_observation(observation, self.observation_space)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jrober23\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\stable_baselines3\common\utils.py", line 399, in is_vectorized_observation 
    return is_vec_obs_func(observation, observation_space)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jrober23\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\stable_baselines3\common\utils.py", line 349, in is_vectorized_dict_observation
    if observation[key].shape != subspace.shape:
       ~~~~~~~~~~~^^^^^
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

System Info

No response

araffin commented 1 year ago

Duplicate of https://github.com/DLR-RM/stable-baselines3/issues/1637#issuecomment-1660158096
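For anyone landing on the same traceback: the linked issue comes down to the Gymnasium reset API. In Gymnasium, env.reset() returns an (obs, info) tuple, while predict() expects only the observation dict, so passing the raw tuple fails inside is_vectorized_dict_observation. Below is a minimal sketch of the repro above with the reset/step returns unpacked; this is not an official example, just the apparent fix assuming the cause matches the linked issue.

from stable_baselines3 import PPO
from stable_baselines3.common.envs import SimpleMultiObsEnv

env = SimpleMultiObsEnv(random_start=False)

model = PPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=1024)

# Gymnasium's reset() returns (obs, info); predict() only wants the obs dict
obs, _ = env.reset()
for _ in range(200):
    action, _ = model.predict(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, _ = env.reset()

Alternatively, calling vec_env = model.get_env() and obs = vec_env.reset() keeps the old single-return interface, since VecEnv.reset() returns only the observation.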

jammingchiu commented 1 year ago

The observation output is: [{'observation': array([ 3.8439669e-02, -2.1944723e-12, 1.9740014e-01, 0.0000000e+00, -0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 7.5348794e-02, 9.0728514e-02, 2.0000000e-02, 0.0000000e+00, -0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32), 'achieved_goal': array([0.07534879, 0.09072851, 0.02 ], dtype=float32), 'desired_goal': array([-0.00748878, 0.07134774, 0.1988182 ], dtype=float32)} {'is_success': array(False)}]

The observation_space.spaces.items() output is: odict_items([('achieved_goal', Box(-10.0, 10.0, (3,), float32)), ('desired_goal', Box(-10.0, 10.0, (3,), float32)), ('observation', Box(-10.0, 10.0, (19,), float32))])

so you can try this:

for key, subspace in observation_space.spaces.items():
    obs = observation[0][key]
    if obs.shape != subspace.shape:
        all_non_vectorized = False
        break
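Note: indexing observation[0] works here because the value printed above is the (obs, info) pair returned by env.reset(), so element 0 is the observation dict. Unpacking at the call site achieves the same thing without patching stable_baselines3 internals, e.g. (a minimal sketch, same idea as the fix in the linked issue):

obs, _ = env.reset()
action, _ = model.predict(obs)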