Closed · @Leong1230 closed this 2 years ago
Hello,
Please fill in the issue template completely and provide a minimal working example that reproduces your issue.
I tried to reproduce it with a gym env but have not been able to yet:
```python
import gym
import torch as th
import torch.nn as nn

from stable_baselines3 import A2C
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.evaluation import evaluate_policy


class FlattenBatchNormExtractor(BaseFeaturesExtractor):
    """
    Feature extractor that flattens the input and applies batch normalization.

    :param observation_space:
    """

    def __init__(self, observation_space: gym.Space):
        super().__init__(
            observation_space,
            get_flattened_obs_dim(observation_space),
        )
        self.flatten = nn.Flatten()
        self.batch_norm = nn.BatchNorm1d(self._features_dim)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        result = self.flatten(observations)
        result = self.batch_norm(result)
        return result


policy_kwargs = dict(
    features_extractor_class=FlattenBatchNormExtractor,
)

model = A2C("MlpPolicy", "CartPole-v1", seed=1, verbose=1, policy_kwargs=policy_kwargs)
model.learn(10_000)
model.save("test_a2c")

env = gym.make("CartPole-v1")
env.seed(1)
mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=20, warn=False)
print(f"mean_reward before loading: {mean_reward}")

del model

model = A2C.load("test_a2c")
env = gym.make("CartPole-v1")
env.seed(1)
mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=20, warn=False)
print(f"mean_reward after loading: {mean_reward}")
```
The latter (Method 2) seems to fail to load the trained model. Is it supposed to yield the same results or not?
Duplicate of https://github.com/DLR-RM/stable-baselines3/issues/533 and https://github.com/hill-a/stable-baselines/issues/30#issuecomment-423694592
I can reproduce this. @araffin
```python
import gym

from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecMonitor

# params
N_sim_eval = 1000
gym_name = 'Pendulum-v0'
N_steps_train = 100000
n_stack = 2
log_dir = './Out/Logs'

print('')
print('---- Training ----')
print('')

env = DummyVecEnv([lambda: gym.make(gym_name)])
env = VecFrameStack(env, n_stack=n_stack)
env = VecMonitor(env, log_dir)

model = SAC(
    'MlpPolicy',
    env,
    verbose=1,
    learning_rate=1e-4,
)
model.learn(total_timesteps=N_steps_train)
model.save('Out/save_test')

print('')
print('---- Reward After model.load() ----')
print('')

model = SAC(
    'MlpPolicy',
    env,
    verbose=1,
    learning_rate=1e-4,
)
model.load('Out/save_test')  # note: load() returns a new model; this call does not modify `model`
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=N_sim_eval, deterministic=True)
print(f'mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}')

print('')
print('---- Reward After SAC.load() ----')
print('')

model = SAC.load('Out/save_test')
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=N_sim_eval, deterministic=True)
print(f'mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}')
```
```
---- Reward After model.load() ----
mean_reward:-1394.89 +/- 214.19

---- Reward After SAC.load() ----
mean_reward:-144.10 +/- 80.72
```
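One can also confirm directly that the instance-level call leaves the instance untouched (a sketch, assuming the default SAC actor, whose output layer is exposed as `policy.actor.mu`, and the `env` and saved file from the script above):

```python
import torch as th

# Sketch: calling load() on an instance does not modify that instance.
fresh = SAC('MlpPolicy', env, learning_rate=1e-4)
w_before = fresh.policy.actor.mu.weight.detach().clone()

fresh.load('Out/save_test')  # returns a new model; the return value is discarded here

assert th.equal(w_before, fresh.policy.actor.mu.weight)  # weights unchanged
```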
See my answer in https://github.com/DLR-RM/stable-baselines3/issues/683#issuecomment-996213446: `model.load()` is not in-place, so you must use `model = SAC.load()` (we only show that usage in the documentation, and if you use the RL Zoo, it is done automatically for you).
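In other words (an illustrative sketch, reusing `env` and the saved file from the script above):

```python
# Wrong: load() is a classmethod that builds and returns a NEW model;
# it does not update `model` in place, and the return value is lost.
model = SAC('MlpPolicy', env)
model.load('Out/save_test')

# Correct: bind the returned model (optionally passing the env to use).
model = SAC.load('Out/save_test', env=env)
```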
Question
Sorry, I think it is not about the BN. I got different results when loading the trained model with two different methods (see the reconstruction below). The latter (Method 2) seems to fail to load the trained model. Is it supposed to yield the same results or not?
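Based on araffin's answer above, the two methods presumably contrasted the class-level and instance-level load calls, roughly as follows (a hypothetical reconstruction, not the original snippets; `test_model` is a placeholder path):

```python
from stable_baselines3 import A2C

# Method 1: classmethod load, binding the returned model (loads correctly)
model = A2C.load("test_model")

# Method 2: calling load() on an existing instance; the returned model
# is discarded, so `model` keeps its untrained weights
model = A2C("MlpPolicy", "CartPole-v1")
model.load("test_model")
```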
Originally posted by @Leong1230 in https://github.com/DLR-RM/stable-baselines3/issues/537#issuecomment-945091297