araffin / rl-baselines-zoo

A collection of 100+ pre-trained RL agents using Stable Baselines, training and hyperparameter optimization included.
https://stable-baselines.readthedocs.io/
MIT License
1.12k stars, 208 forks

What observation space/environment were the pre-trained Atari DQN agents trained with? #45

Closed adamtupper closed 4 years ago

adamtupper commented 4 years ago

Describe the bug

The pre-trained Atari DQN agents load, but calling predict fails because the observations returned by the Atari environment don't match the observation space the model expects.

Code example

(Using the Space Invaders model copied from the trained_agents directory.)

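from matplotlib import pyplot as plt
from stable_baselines import DQN
from stable_baselines.common.atari_wrappers import make_atari

# make_atari only adds no-op resets and frame skipping (NoopResetEnv,
# MaxAndSkipEnv), so observations are still raw (210, 160, 3) RGB frames.
env = make_atari('SpaceInvadersNoFrameskip-v4')
model = DQN.load('../../atari-models/SpaceInvadersNoFrameskip-v4.pkl')
obs = env.reset()

action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)

plt.imshow(obs)
plt.show()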

Error Message

ValueError                                Traceback (most recent call last)
<ipython-input-21-4057e0f0b223> in <module>
      3 obs = env.reset()
      4 
----> 5 action, _states = model.predict(obs)
      6 obs, rewards, dones, info = env.step(action)
      7 

~/miniconda3/envs/baselines/lib/python3.6/site-packages/stable_baselines/common/base_class.py in predict(self, observation, state, mask, deterministic)
    717             mask = [False for _ in range(self.n_envs)]
    718         observation = np.array(observation)
--> 719         vectorized_env = self._is_vectorized_observation(observation, self.observation_space)
    720 
    721         observation = observation.reshape((-1,) + self.observation_space.shape)

~/miniconda3/envs/baselines/lib/python3.6/site-packages/stable_baselines/common/base_class.py in _is_vectorized_observation(observation, observation_space)
    647                                  "Box environment, please use {} ".format(observation_space.shape) +
    648                                  "or (n_env, {}) for the observation shape."
--> 649                                  .format(", ".join(map(str, observation_space.shape))))
    650         elif isinstance(observation_space, gym.spaces.Discrete):
    651             if observation.shape == ():  # A numpy array of a number, has shape empty tuple '()'

ValueError: Error: Unexpected observation shape (210, 160, 3) for Box environment, please use (84, 84, 4) or (n_env, 84, 84, 4) for the observation shape.
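The expected shape (84, 84, 4) is four stacked 84x84 single-channel frames, while make_atari returns one raw 210x160 RGB frame. The following is a minimal NumPy-only sketch of that kind of preprocessing, to show how one shape turns into the other. It is an illustration under stated assumptions, not the library's code: the helper names to_gray_84 and FrameStacker are hypothetical, and the nearest-neighbour resize stands in for the cv2-based resize that Stable Baselines' own Atari wrappers use.

```python
from collections import deque

import numpy as np


def to_gray_84(frame):
    """Turn a (210, 160, 3) uint8 RGB frame into an (84, 84, 1) grayscale frame.

    Hypothetical helper: nearest-neighbour sampling replaces the cv2 resize
    used by the real wrappers, so pixel values will differ slightly.
    """
    # Luminance weights 0.299/0.587/0.114 (ITU-R BT.601).
    weights = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    gray = frame[..., :3].astype(np.float32) @ weights  # -> (210, 160)
    # Pick 84 evenly spaced rows and columns (nearest-neighbour downsampling).
    rows = np.linspace(0, gray.shape[0] - 1, 84).astype(int)
    cols = np.linspace(0, gray.shape[1] - 1, 84).astype(int)
    small = gray[np.ix_(rows, cols)]  # -> (84, 84)
    return small.astype(np.uint8)[..., None]  # add a channel axis


class FrameStacker:
    """Hypothetical helper: keep the last n_stack frames on the channel axis."""

    def __init__(self, n_stack=4):
        self.n_stack = n_stack
        self.frames = deque(maxlen=n_stack)

    def reset(self, frame):
        # Fill the stack with copies of the first frame of the episode.
        processed = to_gray_84(frame)
        for _ in range(self.n_stack):
            self.frames.append(processed)
        return self.observation()

    def step(self, frame):
        # The deque drops the oldest frame automatically.
        self.frames.append(to_gray_84(frame))
        return self.observation()

    def observation(self):
        return np.concatenate(list(self.frames), axis=-1)  # -> (84, 84, n_stack)
```

Feeding a raw (210, 160, 3) frame through FrameStacker(4).reset(...) yields an array of shape (84, 84, 4), which is the shape the error message asks for.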


araffin commented 4 years ago

Hello, please look at the code or the documentation: you must use make_atari_env and VecFrameStack.
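A minimal sketch of the corrected setup, assuming Stable Baselines 2.x (the TensorFlow 1 version this zoo targets) with gym's Atari environments installed. make_atari_env and VecFrameStack are the functions named in the reply. The helper names make_pretrained_atari_env and run_pretrained are hypothetical, and n_stack=4 is taken from the (84, 84, 4) shape in the error message rather than from the zoo's code.

```python
def make_pretrained_atari_env(env_id, n_envs=1, seed=0, n_stack=4):
    """Hypothetical helper: build an env whose observations are (n_envs, 84, 84, n_stack).

    Assumes Stable Baselines 2.x; imports are done lazily so this module
    can be imported even where stable_baselines is not installed.
    """
    from stable_baselines.common.cmd_util import make_atari_env
    from stable_baselines.common.vec_env import VecFrameStack

    # make_atari_env wraps make_atari with the DeepMind-style wrappers
    # (84x84 grayscale frames) and returns a vectorized env:
    # observations have shape (n_envs, 84, 84, 1).
    env = make_atari_env(env_id, num_env=n_envs, seed=seed)
    # Stack the last n_stack frames on the channel axis: (n_envs, 84, 84, n_stack).
    return VecFrameStack(env, n_stack=n_stack)


def run_pretrained(model_path, env_id, n_steps=1000):
    """Hypothetical helper: run a pre-trained DQN agent for n_steps steps."""
    from stable_baselines import DQN

    env = make_pretrained_atari_env(env_id)
    model = DQN.load(model_path)
    obs = env.reset()
    for _ in range(n_steps):
        # deterministic=True picks the greedy (argmax Q-value) action.
        action, _states = model.predict(obs, deterministic=True)
        # A VecEnv returns batched arrays and resets finished episodes itself.
        obs, rewards, dones, infos = env.step(action)
    env.close()
```

With the paths from the original report, this would be called as run_pretrained('../../atari-models/SpaceInvadersNoFrameskip-v4.pkl', 'SpaceInvadersNoFrameskip-v4'). Because the env is vectorized, obs is batched; to display a frame, take one environment and one stacked channel, e.g. obs[0, :, :, -1].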

adamtupper commented 4 years ago

Thanks for the heads up! I hadn't noticed the difference between make_atari and make_atari_env.