DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Bug]: `rollout/success_rate` does not show for Monitor + OnPolicyAlgorithm #1867

Closed N00bcak closed 3 months ago

N00bcak commented 4 months ago

🐛 Bug

Expected Behavior

According to the SB3 documentation, success_rate should be displayed for a Monitor-wrapped environment used with an OnPolicyAlgorithm if:

"...an extra argument to the Monitor wrapper [is passed] to log that value (info_keywords=("is_success",)) and info["is_success"]=True/False on the final step of the episode [is provided]."

Actual Behavior

In actual runs of the environment (see the log output) and judging by the source code, success_rate is not recorded.

Extra Info

OffPolicyAlgorithm appears to behave as expected.
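
Until the value shows up in the training log, one possible workaround is to compute it offline from the CSV file that Monitor writes, since info_keywords = ("is_success", ) adds an is_success column there. The sketch below assumes the usual Monitor layout (a leading "#{...}" metadata line, then a header row, then one row per episode). The function name, the window default, and the sample rows are made up for illustration and are not real run output.

```python
import csv
from collections import deque


def success_rate_from_monitor_csv(text: str, window: int = 100) -> float:
    """Mean of the last `window` `is_success` values in Monitor CSV text."""
    # Drop blank lines and the leading "#{...}" metadata line.
    lines = [ln for ln in text.splitlines() if ln and not ln.startswith("#")]
    recent = deque(maxlen=window)  # keeps only the most recent episodes
    for row in csv.DictReader(lines):
        # CSV values are text, so booleans arrive as "True" / "False".
        recent.append(row["is_success"].strip().lower() in ("true", "1", "1.0"))
    if not recent:
        raise ValueError("no completed episodes in the log")
    return sum(recent) / len(recent)


# Made-up sample in the assumed layout (not taken from an actual run).
sample = """#{"t_start": 0.0, "env_id": "None"}
r,l,t,is_success
21.0,21,0.5,False
50.0,50,1.2,True
500.0,500,9.8,True
"""

print(success_rate_from_monitor_csv(sample, window=2))  # 1.0
```

In practice the text would come from reading the log file that Monitor creates under tmp/; the exact file name is left out here on purpose.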

Code example

from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.ppo.ppo import PPO

from gymnasium.envs.classic_control.cartpole import CartPoleEnv

class CartPoleWithSuccess(CartPoleEnv):
    def step(self, action):
        state, reward, terminated, truncated, info = super().step(action)
        if terminated or truncated: 
            info["is_success"] = (-self.x_threshold < self.state[0] < self.x_threshold
                                    and -self.theta_threshold_radians < self.state[2] < self.theta_threshold_radians
                                    )
        return state, reward, terminated, truncated, info

if __name__ == "__main__":
    env = Monitor(CartPoleWithSuccess(), filename = "tmp/log.csv", info_keywords = ("is_success", ))
    check_env(env)

    model = PPO("MlpPolicy", env = env, stats_window_size = 10, verbose = 1)
    model.learn(total_timesteps = int(5e5), progress_bar = True)

Relevant log output / Error message

Using cuda device
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 21.1     |
|    ep_rew_mean     | 21.1     |
| time/              |          |
|    fps             | 1170     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 28.5       |
|    ep_rew_mean          | 28.5       |
| time/                   |            |
|    fps                  | 1135       |
|    iterations           | 2          |
|    time_elapsed         | 3          |
|    total_timesteps      | 4096       |
| train/                  |            |
|    approx_kl            | 0.00899566 |
|    clip_fraction        | 0.0829     |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.686     |
|    explained_variance   | 0.00679    |
|    learning_rate        | 0.0003     |
|    loss                 | 6.37       |
|    n_updates            | 10         |
|    policy_gradient_loss | -0.0119    |
|    value_loss           | 44.4       |
----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | 50          |
| time/                   |             |
|    fps                  | 1116        |
|    iterations           | 3           |
|    time_elapsed         | 5           |
|    total_timesteps      | 6144        |
| train/                  |             |
|    approx_kl            | 0.008980946 |
|    clip_fraction        | 0.0452      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.672      |
|    explained_variance   | 0.111       |
|    learning_rate        | 0.0003      |
|    loss                 | 14.7        |
|    n_updates            | 20          |
|    policy_gradient_loss | -0.0137     |
|    value_loss           | 39.4        |
-----------------------------------------

System Info

Describe the characteristic of your environment:

Results from python -c "import stable_baselines3 as sb3; sb3.get_system_info()":

- OS: Linux-6.5.0-25-generic-x86_64-with-glibc2.35 # 25~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Feb 20 16:09:15 UTC 2
- Python: 3.10.12
- Stable-Baselines3: 2.2.1
- PyTorch: 2.0.1+cu117
- GPU Enabled: True
- Numpy: 1.25.0
- Cloudpickle: 2.2.1
- Gymnasium: 0.29.1

({'OS': 'Linux-6.5.0-25-generic-x86_64-with-glibc2.35 # 25~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Feb 20 16:09:15 UTC 2', 
'Python': '3.10.12', 
'Stable-Baselines3': '2.2.1', 
'PyTorch': '2.0.1+cu117', 
'GPU Enabled': 'True',
'Numpy': '1.25.0',
'Cloudpickle': '2.2.1', 
'Gymnasium': '0.29.1'}, 
'- OS: Linux-6.5.0-25-generic-x86_64-with-glibc2.35 # 25~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Feb 20 16:09:15 UTC 2\n- Python: 3.10.12\n- Stable-Baselines3: 2.2.1\n- PyTorch: 2.0.1+cu117\n- GPU Enabled: True\n- Numpy: 1.25.0\n- Cloudpickle: 2.2.1\n- Gymnasium: 0.29.1\n')

araffin commented 4 months ago

Duplicate of https://github.com/DLR-RM/stable-baselines3/issues/954#issuecomment-1469569590 and https://github.com/DLR-RM/stable-baselines3/issues/1470

We would welcome a PR that either updates the docs or adds the feature.
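
For anyone who needs the value in the meantime, here is a minimal sketch of the bookkeeping involved. It is not the SB3 implementation. It follows the rule quoted from the docs: info["is_success"] only counts on the final step of an episode. The helper names are hypothetical, and it assumes the per-step infos and dones are available to the caller, for example inside a custom callback. The window of 10 mirrors stats_window_size = 10 from the report.

```python
from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence


def record_successes(
    buffer: Deque[bool],
    infos: Sequence[Dict[str, Any]],
    dones: Sequence[bool],
) -> None:
    """Append is_success for every env whose episode just ended."""
    for info, done in zip(infos, dones):
        is_success = info.get("is_success")
        # Ignore mid-episode steps and envs that do not report success.
        if done and is_success is not None:
            buffer.append(bool(is_success))


def success_rate(buffer: Deque[bool]) -> Optional[float]:
    """Sliding-window mean, or None before any episode has finished."""
    return sum(buffer) / len(buffer) if buffer else None


# Two vectorised steps with two envs each (hypothetical data).
window: Deque[bool] = deque(maxlen=10)
record_successes(window, [{"is_success": True}, {}], [True, False])
record_successes(window, [{"is_success": False}, {"is_success": True}], [True, False])
print(success_rate(window))  # 0.5
```

The resulting number could then be passed to the logger under rollout/success_rate at the end of each rollout.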