DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Bug]: Iteration not updated in locals while learning #1827

Closed: ericrwp closed this 8 months ago

ericrwp commented 8 months ago

🐛 Bug

In `stable_baselines3.common.on_policy_algorithm.OnPolicyAlgorithm.learn`, the `iteration` value is not updated in the callback's `locals` dictionary during learning: callbacks always see `iteration == 0`, even after several iterations have completed.

To Reproduce

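```python
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import ConvertCallback

# Placeholder (assumption): the report does not say which environment was used.
env = "Pendulum-v1"

def callback_function(v_locals, v_globals):
    # Invoked on every environment step with the callback's locals/globals dicts
    iteration_index = v_locals['iteration']
    print(f'iteration_index={iteration_index}')

    return True

checkpoint_callback = ConvertCallback(callback_function)

model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, callback=[checkpoint_callback])
```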

Relevant log output / Error message

```
iteration_index=0
iteration_index=0
iteration_index=0
----------------------------------------
| time/                   |            |
|    fps                  | 69         |
|    iterations           | 3          |
|    time_elapsed         | 87         |
|    total_timesteps      | 6144       |
| train/                  |            |
|    approx_kl            | 0.00776526 |
|    clip_fraction        | 0.0327     |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.37      |
|    explained_variance   | -0.000442  |
|    learning_rate        | 0.0003     |
|    loss                 | 3.45e+04   |
|    n_updates            | 20         |
|    policy_gradient_loss | -0.0103    |
|    value_loss           | 6.32e+04   |
----------------------------------------
iteration_index=0
iteration_index=0
iteration_index=0
```
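A likely explanation, based on a reading of the SB3 source (hedged, not checked against every release): `learn()` passes `locals()` to `callback.on_training_start()` only once, while `iteration` is still 0. After that, `collect_rollouts()` merges its *own* `locals()` into the callback's dict through `callback.update_locals()`. Because `iteration` lives in `learn()`'s frame rather than in `collect_rollouts()`'s, the stored value is never refreshed. The simplified pure-Python model below reproduces the effect; `learn` and `collect_rollouts` here are hypothetical stand-ins, not the real SB3 methods.

```python
from typing import Any, Dict, List


def collect_rollouts(callback_locals: Dict[str, Any], seen: List[int]) -> None:
    # Stand-in for OnPolicyAlgorithm.collect_rollouts: it merges its own
    # locals into the callback dict, and those do not include `iteration`.
    n_steps = 2048
    callback_locals.update(locals())
    # What a step callback would read via self.locals["iteration"]
    seen.append(callback_locals["iteration"])


def learn(n_iterations: int) -> List[int]:
    # Stand-in for OnPolicyAlgorithm.learn
    seen: List[int] = []
    iteration = 0
    # Like callback.on_training_start(locals(), globals()): a one-time snapshot
    callback_locals = dict(locals())
    for _ in range(n_iterations):
        collect_rollouts(callback_locals, seen)
        iteration += 1  # rebinding the local never touches the snapshot
    return seen


print(learn(3))
```

Running it prints `[0, 0, 0]`, matching the `iteration_index=0` lines in the log.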


araffin commented 8 months ago

Hello, if you want the number of iterations, the best option is to keep a counter that you increment on every call of `on_rollout_end()`.