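The VecExtractDictObs wrapper does not update the terminal_observation entry in the infos returned by the VecEnv: it extracts the configured key from the observation but leaves terminal_observation as the original dict.
Storing the transition in the replay buffer then raises an exception, because the terminal observation is a dict rather than an array of the expected shape.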
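A possible fix, sketched under assumptions: the wrapper could apply the same key extraction to terminal_observation that it already applies to the observation. The helper name extract_terminal_observations below is hypothetical and not part of Stable-Baselines3; it only illustrates the transformation on plain Python data, with lists standing in for NumPy arrays.

```python
from typing import Any, Dict, List


def extract_terminal_observations(infos: List[Dict[str, Any]], key: str) -> List[Dict[str, Any]]:
    """Hypothetical helper: replace a dict-valued terminal_observation by its `key` entry."""
    patched = []
    for info in infos:
        info = dict(info)  # shallow copy so the caller's dict is not mutated
        terminal_obs = info.get("terminal_observation")
        if isinstance(terminal_obs, dict):
            info["terminal_observation"] = terminal_obs[key]
        patched.append(info)
    return patched


# Stand-in data: plain lists instead of NumPy arrays keep the sketch dependency-free.
infos = [
    {"terminal_observation": {"observation": [0.1, 0.2, 0.3, 0.4]}},
    {},  # an environment that did not finish an episode on this step
]
patched_infos = extract_terminal_observations(infos, "observation")
print(patched_infos[0]["terminal_observation"])
```

In the wrapper itself, this would presumably be called on the infos inside step_wait before they are returned, using the key the wrapper was constructed with.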
To Reproduce
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import VecEnvWrapper, DummyVecEnv, VecExtractDictObs, VecEnv
from gym import spaces
from gym import make


class VecObsToDict(VecEnvWrapper):
    def __init__(self, venv: VecEnv):
        super().__init__(venv, spaces.Dict({"observation": venv.observation_space}))

    def step_wait(self):
        obs, rewards, dones, infos = self.venv.step_wait()
        obs = {"observation": obs}
        for info in infos:
            if "terminal_observation" in info:
                info["terminal_observation"] = {"observation": info["terminal_observation"]}
        return obs, rewards, dones, infos

    def reset(self):
        obs = self.venv.reset()
        obs = {"observation": obs}
        return obs


env = make("CartPole-v1")
venv = DummyVecEnv([lambda: env])
venv = VecObsToDict(venv)
venv = VecExtractDictObs(venv, "observation")
dqn = DQN("MlpPolicy", venv)
dqn.learn(100000, progress_bar=True)
Relevant log output / Error message
has occurred: TypeError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
float() argument must be a string or a real number, not 'dict'
  File "C:\Users\samue\miniconda3\envs\thesis\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 466, in _store_transition
    next_obs[i] = infos[i]["terminal_observation"]
  File "C:\Users\samue\miniconda3\envs\thesis\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 558, in collect_rollouts
    self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)
  File "C:\Users\samue\miniconda3\envs\thesis\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 311, in learn
    rollout = self.collect_rollouts(
  File "C:\Users\samue\miniconda3\envs\thesis\lib\site-packages\stable_baselines3\dqn\dqn.py", line 269, in learn
    return super().learn(
  File "C:\Users\samue\Documents\GitHub\thesis\bug.py", line 29, in <module>
    dqn.learn(100000, progress_bar=True)
  File "C:\Users\samue\miniconda3\envs\thesis\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "C:\Users\samue\miniconda3\envs\thesis\lib\runpy.py", line 196, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
TypeError: float() argument must be a string or a real number, not 'dict'
System Info
OS: Windows-10-10.0.22621-SP0 10.0.22621
Python: 3.10.9
Stable-Baselines3: 1.8.0
PyTorch: 2.0.0
GPU Enabled: True
Numpy: 1.23.5
Gym: 0.21.0
Checklist
[X] I have checked that there is no similar issue in the repo