DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Bug]: evaluate_policy called multiple times for vectorized environments #1912

Open LukasFehring opened 2 months ago

LukasFehring commented 2 months ago

🐛 Bug

When calling

from stable_baselines3.common.evaluation import evaluate_policy
def custom_callback(_locals, _globals):
    pass

# model and vec_env are created as in the reproduction below
evaluate_policy(model, vec_env, callback=custom_callback)

with a VecEnv, the callback is executed separately for each of the sub-environments. However, the locals dict contains the batched results for all environments (e.g. the full rewards and dones arrays). Therefore you have to check manually which environment the callback was called for, or only run your logic every n_envs calls.
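A minimal sketch of the "only run every n_envs calls" workaround, exercised with synthetic locals dicts instead of a real evaluation run. `make_every_n_calls_callback` and the `"step"` key are hypothetical names for illustration. The sketch assumes the callback fires exactly n_envs times per vectorized step; judging from the evaluate_policy source, that stops holding once some environments have collected their share of n_eval_episodes (they are skipped from then on), so a counter like this can drift out of phase late in an evaluation.

```python
from typing import Any, Callable, Dict


def make_every_n_calls_callback(n_envs: int, fn: Callable[[Dict[str, Any]], None]):
    """Return a callback that runs `fn` only on every `n_envs`-th invocation.

    Assumption: the callback is invoked exactly `n_envs` times per vectorized
    step. This is the workaround described above, not a library guarantee.
    """
    calls = 0

    def callback(locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
        nonlocal calls
        calls += 1
        # Fire once per block of n_envs calls (i.e. once per assumed step)
        if calls % n_envs == 0:
            fn(locals_)

    return callback


# Synthetic usage: 3 "steps", each producing 4 per-env callback calls
seen = []
cb = make_every_n_calls_callback(n_envs=4, fn=lambda loc: seen.append(loc["step"]))
for step in range(3):
    for i in range(4):
        cb({"step": step, "i": i}, {})
print(seen)  # [0, 1, 2]
```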

To Reproduce

import gymnasium as gym  # SB3 >= 2.0 is built on Gymnasium rather than gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

# Define a simple callback function
def callback(_locals, _globals):
    pass

# Function to create multiple environments
def make_env():
    return gym.make('CartPole-v1')

# Number of environments
num_envs = 4
envs = [make_env for _ in range(num_envs)]

# Create vectorized environment
vec_env = DummyVecEnv(envs)

# Create a model
model = PPO("MlpPolicy", vec_env, verbose=1)

# Train the model
model.learn(total_timesteps=5000)

# Evaluate the policy
mean_reward, std_reward = evaluate_policy(model, vec_env, n_eval_episodes=10, callback=callback)

print("Mean reward:", mean_reward, "STD reward:", std_reward)

Relevant log output / Error message

No response


araffin commented 2 months ago

Hello, what is your use case/expected behavior?

The for loop also decomposes the info per env:

https://github.com/DLR-RM/stable-baselines3/blob/35eccaf04fa011128f02eaecac6caab535686459/stable_baselines3/common/evaluation.py#L99-L106
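For reference, here is a simplified, SB3-free mimic of the loop structure in the linked lines. It is a paraphrase under assumptions, not the actual source; `toy_evaluate` and its `step_fn` argument are hypothetical. After each batched step, the loop iterates over the environments, unpacks `reward`, `done` and `info` for index `i`, and then calls the callback with `locals()`. The real function also skips environments that have already reached their episode target, which this toy omits.

```python
from typing import Any, Callable, Dict, List


def toy_evaluate(
    step_fn: Callable[[], tuple],
    n_envs: int,
    n_steps: int,
    callback: Callable[[Dict[str, Any], Dict[str, Any]], None],
) -> None:
    """Simplified mimic of the per-env loop in evaluate_policy (not the real code)."""
    for _ in range(n_steps):
        # One vectorized step returns batched values for all envs
        rewards, dones, infos = step_fn()
        for i in range(n_envs):
            # Unpack the values for env i so the callback can read them from locals()
            reward = rewards[i]
            done = dones[i]
            info = infos[i]
            callback(locals(), globals())


calls: List[tuple] = []


def record(locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
    # Batched arrays cover all envs; i and the unpacked scalars are per env
    calls.append((locals_["i"], locals_["reward"], len(locals_["rewards"])))


toy_evaluate(lambda: ([1.0, 2.0, 3.0], [False] * 3, [{}] * 3), n_envs=3, n_steps=2, callback=record)
print(calls)
```

Each of the six recorded calls sees the full 3-element `rewards` batch alongside the per-env `i` and `reward`, which matches what the original report observed.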

LukasFehring commented 2 months ago

How so? Both the globals and locals contain information on every environment in the vectorized environment. How am I supposed to determine which env the callback is being called for?

araffin commented 2 months ago

There is the local variable "i".
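A sketch of a callback that uses `i` to act only on the environment it is currently called for. The keys it reads (`i`, `done`, `current_rewards`) are local variables of `evaluate_policy` at the linked commit, not a documented API, so treat them as an assumption that may change between versions; `per_env_callback` and `finished` are hypothetical names. The usage below feeds it hand-built dicts in place of a real run.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical log of finished episodes: (env index, episode return)
finished: List[Tuple[int, float]] = []


def per_env_callback(locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
    """Evaluation callback that handles only the env it is currently called for."""
    i = locals_["i"]        # index of the sub-environment for this call
    done = locals_["done"]  # dones[i], already unpacked by evaluate_policy
    if done:
        # current_rewards holds running returns for all envs; select env i
        finished.append((i, float(locals_["current_rewards"][i])))


# Synthetic usage: env 2 finishes an episode, env 0 does not
per_env_callback({"i": 2, "done": True, "current_rewards": [5.0, 7.0, 9.0]}, {})
per_env_callback({"i": 0, "done": False, "current_rewards": [5.0, 7.0, 9.0]}, {})
print(finished)  # [(2, 9.0)]
```

With a trained model it would be passed as `evaluate_policy(model, vec_env, n_eval_episodes=10, callback=per_env_callback)`. Reading the unpacked per-env values (`done`, `reward`, `info`) avoids indexing the batched arrays by hand; `current_rewards` is indexed here because it only exists in batched form.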

LukasFehring commented 2 months ago

Ah, OK, sorry then. A documentation of locals and globals would probably help to find that! :)

araffin commented 1 month ago

> A documentation of locals and globals would probably help to find that! :)

Feel free to open a PR that updates the docs ;)