DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License
9.24k stars 1.71k forks

[Bug]: VecExtractDictObs does not handle terminal observation #1441

Closed WeberSamuel closed 1 year ago

WeberSamuel commented 1 year ago

🐛 Bug

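The VecExtractDictObs wrapper currently does not update the terminal_observation entry in the infos returned by the wrapped VecEnv. The terminal observation therefore stays a dict while the regular observations are already extracted, and storing it into a replay buffer raises an exception because the observation type and shape do not match.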

To Reproduce

from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import VecEnvWrapper, DummyVecEnv, VecExtractDictObs, VecEnv
from gym import spaces
from gym import make

class VecObsToDict(VecEnvWrapper):
    def __init__(self, venv: VecEnv):
        super().__init__(venv, spaces.Dict({"observation": venv.observation_space}))

    def step_wait(self):
        obs, rewards, dones, infos = self.venv.step_wait()
        obs = {"observation": obs}
        for info in infos:
            if "terminal_observation" in info:
                info["terminal_observation"] = {"observation": info["terminal_observation"]}
        return obs, rewards, dones, infos

    def reset(self):
        obs = self.venv.reset()
        obs = {"observation": obs}
        return obs

env = make("CartPole-v1")
venv = DummyVecEnv([lambda: env])
venv = VecObsToDict(venv)
venv = VecExtractDictObs(venv, "observation")

dqn = DQN("MlpPolicy", venv)
dqn.learn(100000, progress_bar=True)

Relevant log output / Error message

has occurred: TypeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
float() argument must be a string or a real number, not 'dict'
  File "C:\Users\samue\miniconda3\envs\thesis\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 466, in _store_transition
    next_obs[i] = infos[i]["terminal_observation"]
  File "C:\Users\samue\miniconda3\envs\thesis\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 558, in collect_rollouts
    self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)
  File "C:\Users\samue\miniconda3\envs\thesis\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 311, in learn
    rollout = self.collect_rollouts(
  File "C:\Users\samue\miniconda3\envs\thesis\lib\site-packages\stable_baselines3\dqn\dqn.py", line 269, in learn
    return super().learn(
  File "C:\Users\samue\Documents\GitHub\thesis\bug.py", line 29, in <module>
    dqn.learn(100000, progress_bar=True)
  File "C:\Users\samue\miniconda3\envs\thesis\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "C:\Users\samue\miniconda3\envs\thesis\lib\runpy.py", line 196, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
TypeError: float() argument must be a string or a real number, not 'dict'
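
The traceback shows the replay buffer receiving a dict where it expects an array-like terminal observation. Below is a minimal sketch of the kind of change that would avoid this, assuming the wrapper only needs to index terminal_observation with the same key it already uses for regular observations. The helper name extract_key_from_step and the toy data are hypothetical, and plain Python lists stand in for a real VecEnv; this is an illustration, not the code from the PR.

```python
def extract_key_from_step(step_result, key):
    """Hypothetical helper: mimic a dict-extracting step_wait() on a
    (obs, rewards, dones, infos) tuple, also handling terminal observations."""
    obs, rewards, dones, infos = step_result
    for info in infos:
        # When an episode ends, the final observation is stored in the info
        # dict, so it needs the same key extraction as the regular observation.
        if "terminal_observation" in info:
            info["terminal_observation"] = info["terminal_observation"][key]
    return obs[key], rewards, dones, infos


# Toy step result for a single environment whose episode just ended.
step_result = (
    {"observation": [[0.1, 0.2, 0.3, 0.4]]},
    [1.0],
    [True],
    [{"terminal_observation": {"observation": [0.5, 0.6, 0.7, 0.8]}}],
)

obs, rewards, dones, infos = extract_key_from_step(step_result, "observation")
```

With this change, both obs and infos[0]["terminal_observation"] hold the extracted values instead of dicts, so the assignment next_obs[i] = infos[i]["terminal_observation"] from the traceback would no longer be given a dict.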

araffin commented 1 year ago

Hello, thanks for spotting the bug and for the PR =). Next time, please do not open the PR from the master branch of your fork, as it prevents maintainers from pushing edits.