DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Bug] optimize_memory_usage not compatible with handle_timeout_termination #934

Closed: MWeltevrede closed this issue 2 years ago

MWeltevrede commented 2 years ago

🐛 Bug

When using the ReplayBuffer class, setting both optimize_memory_usage = True and handle_timeout_termination = True will lead to incorrect behaviour.

This is because, when handle_timeout_termination = True, the replay buffer sets the done value of an episode's final transition to False if the episode ended due to a timeout: https://github.com/DLR-RM/stable-baselines3/blob/d68f0a2411766beb6da58ee0e989d1a6a72869bc/stable_baselines3/common/buffers.py#L300-L302
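In other words, the done flag the buffer returns at sample time is (roughly) the stored done masked by the timeout flag. A simplified, standalone numpy sketch of that masking (an illustration with made-up values, not the actual SB3 code):

import numpy as np

# Toy illustration of how handle_timeout_termination masks the done flag.
dones = np.array([1.0])     # the episode ended at this transition
timeouts = np.array([1.0])  # ...but only because of a TimeLimit truncation

# The buffer effectively returns dones * (1 - timeouts) as the done signal,
# so a timeout is treated as "not done" and the algorithm will bootstrap.
effective_done = dones * (1 - timeouts)
print(effective_done)  # [0.]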

In an algorithm like DQN, this means the Q value target will bootstrap using the next_observation variable https://github.com/DLR-RM/stable-baselines3/blob/d68f0a2411766beb6da58ee0e989d1a6a72869bc/stable_baselines3/dqn/dqn.py#L196
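Concretely, that (1 - done) factor decides whether the target bootstraps from next_observation. A minimal numpy sketch of the standard 1-step TD target (an illustration with hypothetical values, not the SB3 implementation):

import numpy as np

gamma = 0.99
rewards = np.array([1.0])
effective_dones = np.array([0.0])        # timeout -> done masked to 0
next_q_values = np.array([[0.3, 0.7]])   # hypothetical Q(next_observation, a) estimates

# Standard 1-step TD target: bootstrap from next_observation only when not done.
targets = rewards + gamma * (1 - effective_dones) * next_q_values.max(axis=1)
print(targets)  # depends on next_observation, so a wrong next_observation gives a wrong target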

However, this leads to incorrect behaviour if optimize_memory_usage = True, because the next_observation variable is then defined as: https://github.com/DLR-RM/stable-baselines3/blob/d68f0a2411766beb6da58ee0e989d1a6a72869bc/stable_baselines3/common/buffers.py#L291-L292 This entry gets overwritten by the first state of the new episode, so when optimize_memory_usage = True the replay buffer will not return the last state of the episode (which would be the correct behaviour).

As a result, if both optimize_memory_usage = True and handle_timeout_termination = True, a reinforcement learning algorithm will sometimes bootstrap its target values from the wrong state (the first state of the next episode, rather than the last state of the current episode).

Note that this behaviour does not happen if optimize_memory_usage = False, because in that case the next_observation variable is stored in a separate buffer and won't be overwritten by the first state of the new episode.
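To make the difference concrete, here is a toy sketch of how the next observation is obtained in the two modes (plain numpy with made-up values, not the SB3 ReplayBuffer):

import numpy as np

buffer_size = 4
observations = np.zeros(buffer_size, dtype=np.float32)       # shared buffer
next_observations = np.zeros(buffer_size, dtype=np.float32)  # only exists if optimize_memory_usage = False

idx = 2                        # the episode's last transition is stored here
observations[idx] = 10.0       # last observation before the timeout
next_observations[idx] = 11.0  # true terminal observation, kept in the separate buffer

# With optimize_memory_usage = True, the next observation is read from
# observations[(idx + 1) % buffer_size], but that slot is later overwritten
# with the first observation of the new episode:
observations[(idx + 1) % buffer_size] = -5.0  # reset state of the new episode

print(next_observations[idx])                 # 11.0 -> correct terminal observation
print(observations[(idx + 1) % buffer_size])  # -5.0 -> wrong (post-reset) observation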

To Reproduce

This bug only affects the Q value updates when episodes end due to timeouts in the environment. In practice, the difference between the correct target and the incorrect target can be small. For this reason, I wasn't able to find a simple scenario in which this leads to noticeably incorrect training results. However, I added a simple code segment below that illustrates the problem and shows why the behaviour is incorrect when both optimize_memory_usage = True and handle_timeout_termination = True.

import gym 
import numpy as np

from stable_baselines3 import DQN

env = gym.make('MountainCar-v0')
model = DQN("MlpPolicy", env, optimize_memory_usage=False, replay_buffer_kwargs={'handle_timeout_termination': True})
# running for 250 steps so that a single timeout (after 200 steps) will be in the replay buffer
model.learn(total_timesteps=250)

index_of_done = np.where(model.replay_buffer.dones == 1)[0]
print(f"current observation: {model.replay_buffer.observations[index_of_done]}")

# if optimize_memory_usage is False, the next observation is stored in a separate buffer called next_observations
print(f"next observation if optimize_memory_usage is False: {model.replay_buffer.next_observations[index_of_done]}")

# if optimize_memory_usage is True, the next observation is stored in the same buffer as the current observations, but at the index + 1
# which is the state after an environment reset
print(f"next observation if optimize_memory_usage were True: {model.replay_buffer.observations[index_of_done + 1]}")

print(f"done: {model.replay_buffer.dones[index_of_done]}")
print(f"timeout: {model.replay_buffer.timeouts[index_of_done]}")

# the replay buffer will return done == False
# which means that DQN will bootstrap the return from the next observation, which will be the state after a reset if optimize_memory_usage is True
print(f"value of done returned by ReplayBuffer.sample(): {model.replay_buffer.dones[index_of_done] * (1 - model.replay_buffer.timeouts[index_of_done])}")
Output:

current observation: [[[-0.49728924  0.00461391]]]
next observation if optimize_memory_usage is False: [[[-0.49387246  0.00341679]]]
next observation if optimize_memory_usage were True: [[[-0.51673186  0.        ]]]
done: [[1.]]
timeout: [[1.]]
value of done returned by ReplayBuffer.sample(): [[0.]]

Expected behavior

If handle_timeout_termination = True, the replay buffer should return the last (terminal) state as the next_observation variable, regardless of whether optimize_memory_usage is True or False.

System Info

OS: Ubuntu 20.04.4 LTS
Python: 3.9.7
Stable-Baselines3: 1.5.1a6
PyTorch: 1.11.0
GPU Enabled: True
Numpy: 1.21.2
Gym: 0.21.0

Miffyli commented 2 years ago

Nice spot, and thank you for all the details! Sorry for the delay in response ^^.

Even if minor, this should be brought up by the code (or ideally, fixed). Sadly, optimize_memory_usage is somewhat headachy code that can mess with many things. The most obvious solution is to raise an exception if both flags are used; I lean towards an exception because this could potentially destroy someone's runs, and they can still run them without memory optimization. Of course, the ideal solution would be to fix the issue, and you seem to have code ready to test out the implementation :)

We would be happy to review a PR that addresses this issue, if you have the time to offer.

MWeltevrede commented 2 years ago

No problem! I agree that a quick hotfix for now would be to raise an exception if both flags are true.
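For reference, a minimal sketch of what such a guard could look like (hypothetical helper name, not necessarily the exact check that would end up in SB3):

def check_replay_buffer_kwargs(optimize_memory_usage: bool, handle_timeout_termination: bool) -> None:
    # Reject the combination that leads to bootstrapping from the wrong state.
    if optimize_memory_usage and handle_timeout_termination:
        raise ValueError(
            "ReplayBuffer does not support optimize_memory_usage = True "
            "and handle_timeout_termination = True simultaneously."
        )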

I might have some time to implement and test a fix for this issue. However, I'm not 100% sure what the best way to fix it would be.

The core of the problem is that, in case of timeouts, the last observation of an episode should be stored somewhere. So a straightforward approach would be to simply store those last observations in a separate dictionary (keyed by the position of that transition in the buffer: self.pos). When sampling, we then check whether we sampled transitions with timeouts and, if so, retrieve the appropriate last observations from the dictionary.
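A rough standalone sketch of that idea (toy class and names, not the SB3 ReplayBuffer):

import numpy as np

class ToyMemoryOptimizedBuffer:
    """Toy illustration of keeping terminal observations in a side dictionary."""

    def __init__(self, buffer_size: int, obs_dim: int):
        self.buffer_size = buffer_size
        self.observations = np.zeros((buffer_size, obs_dim), dtype=np.float32)
        self.timeouts = np.zeros(buffer_size, dtype=np.float32)
        self.terminal_obs = {}  # pos -> true last observation of a timed-out episode
        self.pos = 0

    def add(self, obs, next_obs, timeout: bool) -> None:
        self.observations[self.pos] = obs
        self.timeouts[self.pos] = float(timeout)
        if timeout:
            # The shared slot at pos + 1 will be overwritten by the reset state,
            # so stash the real terminal observation separately.
            self.terminal_obs[self.pos] = np.asarray(next_obs, dtype=np.float32)
        else:
            # Drop any stale entry left over from an older, overwritten transition.
            self.terminal_obs.pop(self.pos, None)
        self.pos = (self.pos + 1) % self.buffer_size

    def next_observation(self, idx: int) -> np.ndarray:
        # At sample time, substitute the stored terminal observation for timeouts.
        if self.timeouts[idx]:
            return self.terminal_obs[idx]
        return self.observations[(idx + 1) % self.buffer_size]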

Not sure if that is the best way to do it, though. So, if you (or anyone else) have a better idea, I would love to hear it!

Miffyli commented 2 years ago

Ah gotcha, that would indeed be a very hairy thing to implement correctly at this stage... I say we go with the hotfix, which instructs users to disable optimized memory usage if they want to use handle_timeout_termination :)

MWeltevrede commented 2 years ago

Another approach to fix it could be to increment self.pos by 2 (instead of 1) after adding a transition that ended in a timeout. This would avoid the scenario in which we overwrite the last observation (at self.pos + 1) with the first observation of the new episode (now stored at self.pos + 2).

However, this does introduce an invalid transition into the replay buffer (at self.pos + 1) every time we encounter a timeout, which we have to keep track of to avoid sampling it. Therefore, the effective replay buffer size will be smaller than self.buffer_size. How much smaller depends on the environment (how often timeouts occur and after how many steps). So the actual size of the replay buffer is somewhat hidden from the user, which is perhaps not a very nice property to have.
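A toy sketch of that bookkeeping (hypothetical names, not SB3 code):

import numpy as np

buffer_size = 8
invalid = np.zeros(buffer_size, dtype=bool)  # slots that hold no valid transition
rng = np.random.default_rng()

def advance(pos: int, timed_out: bool) -> int:
    # Return the next write position after storing a transition at `pos`.
    invalid[pos] = False  # the slot just written holds a valid transition again
    if timed_out:
        # Reserve pos + 1 so the terminal observation stored there is never
        # overwritten by the reset state, and never sampled as a transition.
        invalid[(pos + 1) % buffer_size] = True
        return (pos + 2) % buffer_size
    return (pos + 1) % buffer_size

pos = advance(0, timed_out=True)  # a timed-out transition at slot 0 -> pos jumps to 2

# Sampling then has to reject the reserved slot(s):
valid_inds = np.flatnonzero(~invalid)
batch_inds = rng.choice(valid_inds, size=4)
print(pos, valid_inds)  # 2, every index except 1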

I will likely implement a fix for this for my own research, so I could submit a PR if you want (either for the hotfix or a true fix as described above)?

Miffyli commented 2 years ago

Hmm, but to do sampling right, one has to keep track of the timed-out episodes (to avoid sampling the invalid self.pos + 1 entries), right? Sorry if this was obvious in your message -- I just want to make sure we are on the same page :)

Tbh that sounds like a good amount of additional complexity and if-elsing, and it can get really tricky in the edge scenarios (e.g. the beginning/end of the replay buffer). For the sanity and cleanliness of the code-base (a departure from which caused this headache in the first place), I think it might be worth it to just disallow using these two parameters together.

MWeltevrede commented 2 years ago

Yes indeed, you would have to keep track of the positions in the replay buffer that are invalid and probably do some if-elsing to avoid sampling those. I don't think it would need special care at the beginning or end of the buffer, but it would add a decent amount of complexity (for example to properly handle vectorized envs).

MWeltevrede commented 2 years ago

I have submitted a PR for the hotfix. For my own work I have already implemented a true fix using the first approach:

The core of the problem is that, in case of timeouts, the last observation of an episode should be stored somewhere. So a straightforward approach would be to simply store those last observations in a separate dictionary (keyed by the position of that transition in the buffer: self.pos). When sampling, we then check whether we sampled transitions with timeouts and, if so, retrieve the appropriate last observations from the dictionary.

I could submit a separate PR containing that fix? That way you can judge more specifically whether the ability to use both optimised memory and timeouts simultaneously is worth the added complexity to the code.

Miffyli commented 2 years ago

Cheers! Will take a look over it.

Hmm, you could submit a PR for that, but I personally would not have time to 100% verify that it works correctly (well, that is what tests are for :)). There are also no guarantees that it would be merged, but if you wish to do so, I won't say no!

Edit: If you do not plan to make the PR, please close this issue :)