DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

Setting up seed in Custom Gym environment #1932

Closed Chainesh closed 4 months ago

Chainesh commented 4 months ago

🐛 Bug

check_env(env) passes without printing anything, and when I run the code below I see no training output at all. After setting render_mode="human" I can see it's stuck on the first frame. Maybe the issue is with the custom env. Thanks :)

Code example

import gymnasium as gym
from feauture_extractor import MinigridFeaturesExtractor
from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import VecTransposeImage, DummyVecEnv
import tensorboard

missions = ["go to the red key"]

stop_callback = StopTrainingOnRewardThreshold(reward_threshold=0.925, verbose=1)

class CustomEnv(gym.Env):
    def __init__(self, env, mission):
        super().__init__()
        self.env = env
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.mission = mission

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        while self.env.mission != self.mission:
            obs = self.env.reset(**kwargs)
        return obs

    def step(self, action):
        return self.env.step(action)

    def render(self, *args, **kwargs):
        return self.env.render(*args, **kwargs)

for mission in missions:
    policy_kwargs = dict(
        net_arch=dict(pi=[64, 128], vf=[64, 128])
    )

    # Create the training environment
    train_env = gym.make("BabyAI-GoToLocal-v0", max_episode_steps=512)  # , render_mode="human"
    train_env = RGBImgPartialObsWrapper(train_env)
    train_env = CustomEnv(train_env, mission)
    train_env = ImgObsWrapper(train_env)
    train_env = Monitor(train_env)
    train_env = DummyVecEnv([lambda: train_env])
    train_env = VecTransposeImage(train_env)

    # Create the evaluation environment
    eval_env = gym.make("BabyAI-GoToLocal-v0", max_episode_steps=512)  # , render_mode="human"
    eval_env = RGBImgPartialObsWrapper(eval_env)
    eval_env = CustomEnv(eval_env, mission)
    eval_env = ImgObsWrapper(eval_env)
    eval_env = Monitor(eval_env)
    eval_env = DummyVecEnv([lambda: eval_env])
    eval_env = VecTransposeImage(eval_env)

    save_path = f"main_code/New_model/PPO/{mission.replace(' ', '_')}_model"
    eval_callback = EvalCallback(eval_env, callback_on_new_best=stop_callback, eval_freq=8192,
                                 best_model_save_path=save_path, verbose=1, n_eval_episodes=30)

    model = PPO("CnnPolicy", train_env, policy_kwargs=policy_kwargs, verbose=1,
                learning_rate=0.0005, tensorboard_log="./logs/PPO2/",
                batch_size=2048,
                n_epochs=100, seed=42)

    model.learn(int(2.5e6), callback=eval_callback)

    # Close the environments
    train_env.close()
    eval_env.close()

Relevant log output / Error message

Whatever seed I use (42 or other values), nothing is printed. When I set render_mode="human", I can see the environment is stuck on the first frame.

System Info

Checklist

qgallouedec commented 4 months ago

Probably comes from this:

while self.env.mission != self.mission:
    obs = self.env.reset(**kwargs)

You keep resetting with the same seed, so you always sample the same mission, and the while loop never terminates.
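To illustrate the mechanism, here is a minimal sketch with a hypothetical `ToyEnv` standing in for the MiniGrid env: passing the same seed to every `reset()` restores the same RNG state, so the retry loop can never observe a different mission.

```python
import random

class ToyEnv:
    """Hypothetical stand-in for a mission-sampling env (not MiniGrid itself)."""
    MISSIONS = ["go to the red key", "go to the blue ball", "go to the green box"]

    def reset(self, seed=None):
        if seed is not None:
            # Re-seeding restores the exact RNG state, so the draw below repeats.
            self.rng = random.Random(seed)
        self.mission = self.rng.choice(self.MISSIONS)
        return self.mission

env = ToyEnv()
first = env.reset(seed=42)
# Same seed on every reset -> the sampled mission can never change:
repeats = [env.reset(seed=42) for _ in range(10)]
# Dropping the seed lets later resets sample different missions:
varied = {env.reset() for _ in range(50)}
```

With the seed forwarded on every call, `repeats` contains only `first`; without it, `varied` collects several different missions.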

qgallouedec commented 4 months ago

However, note that this isn't really a question related to sb3, I'd advise you to ask it on the gymnasium repo instead.

Chainesh commented 4 months ago

while self.env.mission != self.mission:
    obs = self.env.reset(**kwargs)

What do I need to change here? I'm trying to train BabyAI levels on specific instructions, so I need this filtering. Yeah, sure, I'll do that. Thanks for the reply :)
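One possible change, sketched below against a hypothetical toy env (an assumption, not tested with MiniGrid/SB3): forward the seed only to the first reset, then strip it from the kwargs so the retries can actually draw new missions.

```python
import random

class ToyMissionEnv:
    """Hypothetical stand-in for the MiniGrid env (for illustration only)."""
    MISSIONS = ["go to the red key", "go to the blue ball", "go to the green box"]

    def __init__(self):
        self.rng = random.Random()

    def reset(self, seed=None, **kwargs):
        if seed is not None:
            self.rng = random.Random(seed)
        self.mission = self.rng.choice(self.MISSIONS)
        return self.mission, {}

class RetryUntilMission:
    """Sketch of the fix: seed only the first reset, so the retry loop
    can keep sampling until the target mission appears."""

    def __init__(self, env, mission):
        self.env = env
        self.mission = mission

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        kwargs.pop("seed", None)  # drop the seed before retrying
        while self.env.mission != self.mission:
            obs = self.env.reset(**kwargs)
        return obs

wrapped = RetryUntilMission(ToyMissionEnv(), "go to the red key")
obs = wrapped.reset(seed=42)
```

The first reset is still reproducible via the seed; only the rejected draws are re-sampled from the evolving RNG state, so the loop terminates.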

qgallouedec commented 4 months ago

Please link the gymnasium issue here before closing :)

Chainesh commented 4 months ago

https://github.com/Farama-Foundation/Gymnasium/issues/1061#issue-2310882993 I've opened the issue here :)