DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

Recalculate Returns and Advantages After Callback to Ensure Reward Consistency (common/on_policy_algorithm.py) #2000

Open mhyrzt opened 4 weeks ago

mhyrzt commented 4 weeks ago

Hello there, SB3 team! 👋

I hope this message finds you well. First and foremost, I want to extend my gratitude for the fantastic work you’ve done on the Stable Baselines3 library. It's truly a remarkable tool and has been invaluable for my projects.

I apologize in advance for any inconsistencies or if I have missed any items in the checklist. Your understanding and guidance are greatly appreciated.

In this pull request, I’ve added a second call to rollout_buffer.compute_returns_and_advantage() after the on_rollout_end() callback is invoked. This handles custom callback logic that modifies the rewards stored in the rollout buffer. Because returns and advantages are recalculated, the policy update uses advantage estimates computed from the final rewards, not from the rewards as they were before the callback ran.

This is especially useful when using callbacks that manipulate rewards for custom reward shaping or augmentation.
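To make the scenario concrete, here is the kind of callback this change targets: one that rewrites the stored rewards in place when the rollout ends. This is a minimal stand-in, not SB3 code. ToyRolloutBuffer and ClipRewardsCallback are hypothetical names, and clipping to [-1, 1] is only an example transformation. In SB3 itself, such a callback would subclass BaseCallback, override _on_rollout_end(), and reach the buffer through self.model.rollout_buffer. The toy classes avoid the SB3 dependency.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyRolloutBuffer:
    """Hypothetical stand-in for the reward storage of a rollout buffer."""

    rewards: List[float] = field(default_factory=list)


class ClipRewardsCallback:
    """Hypothetical callback that reshapes rewards once a rollout is complete."""

    def __init__(self, buffer: ToyRolloutBuffer, low: float = -1.0, high: float = 1.0) -> None:
        self.buffer = buffer
        self.low = low
        self.high = high

    def on_rollout_end(self) -> None:
        # Overwrite the stored rewards in place; any advantages computed earlier are now stale
        self.buffer.rewards[:] = [min(max(r, self.low), self.high) for r in self.buffer.rewards]


buffer = ToyRolloutBuffer(rewards=[5.0, -3.0, 0.5])
ClipRewardsCallback(buffer).on_rollout_end()
print(buffer.rewards)  # [1.0, -1.0, 0.5]
```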

Description

Recalculated returns and advantages after the on_rollout_end callback to handle any reward transformations or manipulations performed by custom callbacks. This ensures that the rollout buffer reflects the updated rewards when computing advantages for the policy update step.
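The recomputed quantities are GAE(λ) advantages and the corresponding returns. The sketch below is a minimal pure-Python restatement of the backward recursion that RolloutBuffer.compute_returns_and_advantage() performs. It assumes a single environment and uses plain lists instead of the buffer's NumPy arrays. compute_gae is a hypothetical name, and this is not the library code. Each advantage depends on the reward at its own step and, through last_gae, on every later reward in the same episode. Editing any reward after the call therefore leaves earlier advantages out of date.

```python
from typing import List, Tuple


def compute_gae(
    rewards: List[float],
    values: List[float],
    episode_starts: List[float],
    last_value: float,
    last_done: float,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Hypothetical single-env restatement of the GAE(lambda) recursion."""
    n = len(rewards)
    advantages = [0.0] * n
    last_gae = 0.0
    for step in reversed(range(n)):
        if step == n - 1:
            # Bootstrap from the value of the observation that follows the rollout
            next_non_terminal = 1.0 - last_done
            next_value = last_value
        else:
            # episode_starts[step + 1] == 1 means `step` ended an episode
            next_non_terminal = 1.0 - episode_starts[step + 1]
            next_value = values[step + 1]
        delta = rewards[step] + gamma * next_value * next_non_terminal - values[step]
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[step] = last_gae
    # TD(lambda) returns: advantages plus the value estimates
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


advantages, returns = compute_gae(
    rewards=[1.0, 1.0], values=[0.0, 0.0], episode_starts=[1.0, 0.0],
    last_value=0.0, last_done=1.0, gamma=0.5, gae_lambda=1.0,
)
print(advantages, returns)  # [1.5, 1.0] [1.5, 1.0]
```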

Previously, rollout_buffer.compute_returns_and_advantage() was only called before the callback. If any rewards were modified by the callback during on_rollout_end(), the advantages would not reflect these changes. By adding a second call to compute_returns_and_advantage(), we ensure that the correct values are used for training.
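The ordering problem can be reproduced outside SB3 with a simplified GAE helper. The helper gae is hypothetical and assumes a single environment with no episode boundaries inside the rollout. The sketch computes advantages, edits one reward the way a callback might during on_rollout_end(), and then compares the earlier result with a fresh computation. Changing only the last reward also changes the advantage of the step before it.

```python
from typing import List


def gae(rewards: List[float], values: List[float], last_value: float,
        gamma: float = 0.99, lam: float = 0.95) -> List[float]:
    """Hypothetical simplified GAE: one env, no episode boundaries in the rollout."""
    advantages = [0.0] * len(rewards)
    last_gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * next_value - values[t]
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
        next_value = values[t]
    return advantages


rewards = [1.0, 1.0]
values = [0.0, 0.0]

# 1) Advantages computed before the callback runs (the original ordering)
stale = gae(rewards, values, last_value=0.0, gamma=0.5, lam=1.0)

# 2) A callback edits the stored rewards in place during on_rollout_end()
rewards[1] = 3.0

# 3) Only a recomputation picks up the edit
fresh = gae(rewards, values, last_value=0.0, gamma=0.5, lam=1.0)
print(stale, fresh)  # [1.5, 1.0] [2.5, 3.0]
```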

Code Changes

# stable_baselines3 > common > on_policy_algorithm.py > OnPolicyAlgorithm > collect_rollouts > lines 239 to 244

rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

callback.update_locals(locals())

callback.on_rollout_end()

# Added this line
rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
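The extra call is safe because the recursion rebuilds advantages and returns from the stored rewards and values instead of accumulating into them. If no callback touched the rewards, the second call reproduces the first result, at the cost of one more backward pass over the buffer. The toy class below checks that property. ToyBuffer and its compute() method are hypothetical and use a simplified single-environment GAE without episode boundaries.

```python
from typing import List


class ToyBuffer:
    """Hypothetical single-env buffer whose compute() overwrites its outputs."""

    def __init__(self, rewards: List[float], values: List[float]) -> None:
        self.rewards = rewards
        self.values = values
        self.advantages: List[float] = []
        self.returns: List[float] = []

    def compute(self, last_value: float, gamma: float = 0.99, lam: float = 0.95) -> None:
        # Simplified GAE: no episode boundaries inside the rollout
        advantages = [0.0] * len(self.rewards)
        last_gae, next_value = 0.0, last_value
        for t in reversed(range(len(self.rewards))):
            delta = self.rewards[t] + gamma * next_value - self.values[t]
            last_gae = delta + gamma * lam * last_gae
            advantages[t] = last_gae
            next_value = self.values[t]
        # Outputs are rebuilt from scratch on every call, never accumulated
        self.advantages = advantages
        self.returns = [a + v for a, v in zip(advantages, self.values)]


buf = ToyBuffer(rewards=[1.0, 0.0, 2.0], values=[0.5, 0.5, 0.5])
buf.compute(last_value=0.0)
first = (list(buf.advantages), list(buf.returns))
buf.compute(last_value=0.0)  # second call, rewards untouched
second = (list(buf.advantages), list(buf.returns))
print(first == second)  # True
```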

Motivation and Context

Why is this change required?

This change ensures that if a callback modifies rewards during the on_rollout_end() phase, the computed returns and advantages are updated accordingly. Without recalculating these values, any changes to the rewards made by the callback would not be reflected in the advantage estimates used for policy updates, potentially leading to suboptimal training performance.

Recalculating returns and advantages keeps them consistent with the updated rewards. This matters whenever reward shaping or other reward modifications are implemented in custom callbacks.

Types of changes

Checklist

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

mhyrzt commented 3 weeks ago

Hi again,

Before I begin, I just wanted to apologize for any headaches and for being a bit of a noob since I rarely contribute to projects.

I’ve just run make commit-checks and wanted to give an update. I’ve checked most of the items in the checklist, and here is the output:

$ make commit-checks        
# Sort imports
ruff check --select I stable_baselines3/ tests/ docs/conf.py setup.py --fix
All checks passed!
# Reformat using black
black stable_baselines3/ tests/ docs/conf.py setup.py
All done! ✨ 🍰 ✨
96 files left unchanged.
mypy stable_baselines3/ tests/ docs/conf.py setup.py
Success: no issues found in 94 source files
# Stop the build if there are Python syntax errors or undefined names
# See https://www.flake8rules.com/
ruff check stable_baselines3/ tests/ docs/conf.py setup.py --select=E9,F63,F7,F82 --output-format=full
All checks passed!
# exit-zero treats all errors as warnings.
ruff check stable_baselines3/ tests/ docs/conf.py setup.py --exit-zero --output-format=concise
All checks passed!

I’ve added my username to /docs/misc/changelog.rst, but I got a bit confused about how or where else I need to change the file.

Also, I haven’t written any tests for the suggested bug fix yet. If you think it's necessary, I'll write them ASAP.
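If a test is wanted, one option is to check the ordering in isolation. The sketch below is a pytest-style test built on several assumptions. ToyBuffer, end_of_rollout and the test name are hypothetical. The buffer uses a simplified single-environment GAE, and the "callback" is a plain function that doubles the rewards. A test against SB3 itself would need a real model and a reward-modifying callback, which this sketch does not attempt.

```python
from typing import Callable, List

GAMMA, LAM = 0.5, 1.0  # small values chosen to keep the arithmetic easy to check


class ToyBuffer:
    """Hypothetical stand-in for a rollout buffer (one env, no episode boundaries)."""

    def __init__(self, rewards: List[float], values: List[float]) -> None:
        self.rewards = rewards
        self.values = values
        self.advantages: List[float] = []

    def compute_returns_and_advantage(self, last_value: float) -> None:
        advantages = [0.0] * len(self.rewards)
        last_gae, next_value = 0.0, last_value
        for t in reversed(range(len(self.rewards))):
            delta = self.rewards[t] + GAMMA * next_value - self.values[t]
            last_gae = delta + GAMMA * LAM * last_gae
            advantages[t] = last_gae
            next_value = self.values[t]
        self.advantages = advantages


def end_of_rollout(buffer: ToyBuffer, on_rollout_end: Callable[[ToyBuffer], None]) -> None:
    """Mirrors the proposed ordering: compute, run the callback, compute again."""
    buffer.compute_returns_and_advantage(last_value=0.0)
    on_rollout_end(buffer)
    buffer.compute_returns_and_advantage(last_value=0.0)


def test_advantages_reflect_callback_rewards() -> None:
    def double_rewards(buffer: ToyBuffer) -> None:
        # Reward modification performed "inside the callback"
        buffer.rewards[:] = [2.0 * r for r in buffer.rewards]

    buffer = ToyBuffer(rewards=[1.0, 1.0], values=[0.0, 0.0])
    end_of_rollout(buffer, double_rewards)

    # Reference: advantages computed directly from the already-doubled rewards
    expected = ToyBuffer(rewards=[2.0, 2.0], values=[0.0, 0.0])
    expected.compute_returns_and_advantage(last_value=0.0)
    assert buffer.advantages == expected.advantages == [3.0, 2.0]
```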

araffin commented 2 weeks ago

“I have raised an issue to propose this change (required for new features and bug fixes)”

This step is important: it is where we discuss the issue and decide whether a fix or feature is needed at all, or whether there are better alternatives.