DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Question] Why resample SDE noise matrices in PPO optimization? #1929

Closed: brn-dev closed this issue 3 months ago

brn-dev commented 4 months ago

❓ Question

I'm currently implementing a personal RL library and I'm using SB3 as inspiration. I recently implemented SDE and I'm confused about lines 214/215 in your implementation of PPO. There, the SDE noise matrices are resampled before each PPO gradient update, but as far as I can see, this doesn't do anything, since the noise matrices are only used during exploration, not during the PPO updates. Let me elaborate:

In PPO's train(), we reset the noise matrices and then call evaluate_actions right afterwards: https://github.com/DLR-RM/stable-baselines3/blob/6c00565778e5815e4589afc7499aafbd020535ae/stable_baselines3/ppo/ppo.py#L214-L217
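For context, those lines sit inside the minibatch loop of train(). A rough paraphrase of the linked code (not a verbatim copy; details may differ):

# Rough paraphrase of the linked ppo.py lines, not verbatim
for rollout_data in self.rollout_buffer.get(self.batch_size):
    actions = rollout_data.actions
    # Re-sample the noise matrix because the log_std has changed
    if self.use_sde:
        self.policy.reset_noise(self.batch_size)
    values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)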

reset_noise, which just calls sample_weights, does not change any state except the exploration matrices: https://github.com/DLR-RM/stable-baselines3/blob/6c00565778e5815e4589afc7499aafbd020535ae/stable_baselines3/common/distributions.py#L499-L512
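Roughly paraphrasing the linked sample_weights (not verbatim; the actual implementation may differ in details), it only (re)samples the exploration matrices from a centered Gaussian:

# Rough paraphrase of StateDependentNoiseDistribution.sample_weights, not verbatim
def sample_weights(self, log_std, batch_size=1):
    std = self.get_std(log_std)
    self.weights_dist = Normal(th.zeros_like(std), std)
    # Reparametrization trick to pass gradients through the sampling
    self.exploration_mat = self.weights_dist.rsample()
    # Pre-compute one matrix per parallel environment
    self.exploration_matrices = self.weights_dist.rsample((batch_size,))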

The exploration matrices are only used in get_noise, which is in turn only called from sample(): https://github.com/DLR-RM/stable-baselines3/blob/6c00565778e5815e4589afc7499aafbd020535ae/stable_baselines3/common/distributions.py#L593-L603 https://github.com/DLR-RM/stable-baselines3/blob/6c00565778e5815e4589afc7499aafbd020535ae/stable_baselines3/common/distributions.py#L580-L585
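Paraphrasing those two methods roughly (not verbatim, squashing/bijector handling omitted), the exploration matrices only enter through the matrix product in get_noise:

# Rough paraphrase of sample() / get_noise(), not verbatim
def sample(self):
    noise = self.get_noise(self._latent_sde)
    return self.distribution.mean + noise

def get_noise(self, latent_sde):
    latent_sde = latent_sde if self.learn_features else latent_sde.detach()
    # Single exploration matrix: plain matrix multiplication
    if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
        return th.mm(latent_sde, self.exploration_mat)
    # One exploration matrix per parallel env: batched matrix multiplication
    latent_sde = latent_sde.unsqueeze(dim=1)
    return th.bmm(latent_sde, self.exploration_matrices).squeeze(dim=1)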

However, evaluate_actions does not call sample(): https://github.com/DLR-RM/stable-baselines3/blob/6c00565778e5815e4589afc7499aafbd020535ae/stable_baselines3/common/policies.py#L719-L741
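Again as a rough paraphrase (not verbatim): evaluate_actions only queries log_prob and entropy of the distribution, so the exploration matrices are never touched:

# Rough paraphrase of ActorCriticPolicy.evaluate_actions, not verbatim
def evaluate_actions(self, obs, actions):
    features = self.extract_features(obs)
    latent_pi, latent_vf = self.mlp_extractor(features)
    distribution = self._get_action_dist_from_latent(latent_pi)
    log_prob = distribution.log_prob(actions)  # no call to sample()
    values = self.value_net(latent_vf)
    return values, log_prob, distribution.entropy()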

So, in conclusion, the reset_noise call in PPO's train function is useless, isn't it?
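A minimal, hypothetical sanity check of this claim (not from the issue; Pendulum-v1 and the CPU device are just arbitrary choices): if reset_noise really has no effect here, evaluate_actions should return the same values and log-probs before and after resampling.

import torch as th
from stable_baselines3 import PPO

# Hypothetical check, not part of the original issue
model = PPO("MlpPolicy", "Pendulum-v1", use_sde=True, device="cpu")
obs = th.as_tensor(model.env.reset(), dtype=th.float32)
actions = th.as_tensor(model.env.action_space.sample(), dtype=th.float32).unsqueeze(0)

with th.no_grad():
    values1, log_prob1, _ = model.policy.evaluate_actions(obs, actions)
    model.policy.reset_noise(1)  # resample the exploration matrices
    values2, log_prob2, _ = model.policy.evaluate_actions(obs, actions)

assert th.allclose(values1, values2) and th.allclose(log_prob1, log_prob2)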


araffin commented 4 months ago

Hello, thanks for the question. Indeed, you should not need to re-sample the noise, but if you comment it out, it will throw an error, if I remember correctly. Back then (4 years ago), I didn't have much time to investigate why, and re-sampling was a good-enough solution. So, if you comment it out and it works now, I would be happy to receive a PR =) If it doesn't work and you find out why, I'm also happy to hear the answer ;)

brn-dev commented 4 months ago

After removing the reset_noise call, I tested it on the half-cheetah env and it worked without any errors. PR: https://github.com/DLR-RM/stable-baselines3/pull/1933


araffin commented 4 months ago

> After removing the reset_noise call, I tested it on the half-cheetah env

Looking at the std value, it doesn't seem like you were using gSDE, but I did try with python3 -m rl_zoo3.train --algo ppo --env HalfCheetahBulletEnv-v0 -n 20000 --seed 2 -param n_epochs:5 (Bullet env) and it does indeed work.

brn-dev commented 4 months ago


The currently published version does consistently give me a std of a little less than 1.0 with SDE enabled:

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Parallel environments
vec_env = make_vec_env("HalfCheetah-v4", n_envs=4)

model = PPO("MlpPolicy", vec_env, use_sde=True, sde_sample_freq=10, verbose=2)
model.learn(total_timesteps=250000)
model.save("ppo_halfcheetah")

del model # remove to demonstrate saving and loading

model = PPO.load("ppo_halfcheetah")

obs = vec_env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")
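
For reference, the std value reported by the logger during training with gSDE is derived from the policy's state-independent log_std parameter. A quick way to inspect it on the model above (a sketch, assuming the policy exposes it as policy.log_std, which ActorCriticPolicy does when gSDE is enabled; run it separately from the rendering loop):

import torch as th

# Mean std derived from the learned log_std parameter (what "train/std" reports)
print(th.exp(model.policy.log_std).mean().item())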