hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License
4.14k stars 723 forks

Fix re-training with different number of environments #1133

Closed balisujohn closed 3 years ago

balisujohn commented 3 years ago

A patch fixing the ability of a loaded PPO2 model to continue training on a vectorized environment with a different number of environments than the one it was initially trained on.

Description

A test was added to test_ppo2.py which checks that a loaded PPO2 model can train on a vectorized environment with a different number of environments than the one it was initially trained on. set_env was overridden in ppo2.py to call the set_env of the superclass and then explicitly update n_batch to n_envs * n_steps.
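The override described above can be sketched roughly as follows. This is a hypothetical, heavily simplified mock (the real PPO2 and its base class in stable-baselines carry many more attributes); it only shows how the override recomputes n_batch after the base class swaps in the new environment.

```python
from types import SimpleNamespace

# Hypothetical, simplified stand-ins for the real classes; only the
# attributes relevant to this fix (n_envs, n_steps, n_batch) are modeled.
class ActorCriticRLModel:
    def set_env(self, env):
        # The base class stores the new VecEnv and its environment count.
        self.env = env
        self.n_envs = env.num_envs

class PPO2(ActorCriticRLModel):
    def __init__(self, n_envs, n_steps):
        self.n_envs = n_envs
        self.n_steps = n_steps
        self.n_batch = n_envs * n_steps

    def set_env(self, env):
        # Call the superclass first, then recompute n_batch so that the
        # rollout batch size matches the new number of environments.
        super().set_env(env)
        self.n_batch = self.n_envs * self.n_steps

model = PPO2(n_envs=4, n_steps=10)          # n_batch == 40
model.set_env(SimpleNamespace(num_envs=1))  # mimic a 1-env VecEnv
print(model.n_batch)                        # recomputed to 10
```

Without the override, n_batch would stay at its original value and no longer match the rollout shapes produced by the smaller VecEnv.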

Motivation and Context

Prior to this patch, the code in the linked issue would fail; this patch fixes that. Closes #1132


balisujohn commented 3 years ago

Also possibly worth noting: I tried to reproduce this error with stable-baselines3, since development is more active there, and it seems that stable-baselines3 can handle at least a simple version of this situation without crashing.

```python
# derived from https://github.com/DLR-RM/stable-baselines3
import gym

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_util import make_vec_env

env = make_vec_env("CartPole-v1", n_envs=4)

model = PPO("MlpPolicy", env, verbose=1, n_steps=10)
model.learn(total_timesteps=10000)

model.save("test_model")
del model

env = make_vec_env("CartPole-v1", n_envs=1)

model = PPO.load("test_model", env)

model.learn(total_timesteps=10000)
```

(ran without error)

Miffyli commented 3 years ago

Sorry for the delay! Indeed SB3 is more refined in this regard.

Seems like this issue could happen with other algos so they should be fixed as well (e.g. A2C).

@araffin thoughts on merging this? Seems like an appropriate maintenance mode fix.

araffin commented 3 years ago

> Sorry for the delay! Indeed SB3 is more refined in this regard.

Yes, SB3 is the recommended solution.

> Seems like this issue could happen with other algos so they should be fixed as well (e.g. A2C).

Yes, this is probably the case for all A2C-like algorithms (A2C, ACER, ACKTR, and maybe TRPO).

It seems we also need to fix the type annotations (the new version of pytype does not like when things are not properly marked as optional).

> @araffin thoughts on merging this? Seems like an appropriate maintenance mode fix.

I would be happy to merge it after the pytype issues are fixed (you can keep this PR to fix them).

Regarding the test, please use the tmp_path pytest argument (an automatic temporary folder) rather than saving/loading in the same folder (I know that some tests do not use it, but it should be the case, as in SB3 tests).
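The tmp_path suggestion looks roughly like the following hypothetical sketch. DummyModel and the test name are invented stand-ins for the real PPO2 save/load round trip; the point is only that the model file lands in pytest's per-test temporary directory instead of the working directory.

```python
import pickle

# DummyModel is a hypothetical stand-in for PPO2: it models only the
# save/load round trip, not any actual training.
class DummyModel:
    def __init__(self, n_envs):
        self.n_envs = n_envs

    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump({"n_envs": self.n_envs}, f)

    @classmethod
    def load(cls, path, n_envs=None):
        with open(path, "rb") as f:
            state = pickle.load(f)
        model = cls(state["n_envs"])
        if n_envs is not None:
            # Mimic set_env adopting a differently sized environment.
            model.n_envs = n_envs
        return model

def test_retrain_different_n_envs(tmp_path):
    # tmp_path is pytest's built-in fixture: a pathlib.Path pointing at a
    # fresh per-test temporary directory, cleaned up automatically.
    model = DummyModel(n_envs=4)
    save_path = tmp_path / "test_model.pkl"
    model.save(save_path)
    del model

    loaded = DummyModel.load(save_path, n_envs=1)
    assert loaded.n_envs == 1
```

Because every test gets its own directory, saved files from one test can never collide with or leak into another test run.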