MWeltevrede closed this issue 2 years ago
Nice spot, and thank you for all the details! Sorry for the delay in response ^^.
Even if minor, this should be surfaced by the code (or, ideally, fixed). Sadly, the optimize_memory_usage
path is somewhat headachy code that can interfere with many things. The most obvious solution is to raise an exception if both flags are used; I lean towards an exception because this could potentially ruin someone's runs, and they can still run their experiments without memory optimization. Of course, the ideal solution would be to fix the issue, and you seem to have code ready to test out an implementation :)
We would be happy to review a PR that addresses this issue, if you have the time to offer.
No problem! I agree that a quick hotfix for now would be to raise an exception if both flags are true.
I might have some time to implement and test a fix for this issue. However, I'm not 100% sure what the best way to fix it would be.
The core of the problem is that, in the case of timeouts, the last observation of an episode should be stored somewhere. So a straightforward approach would be to simply store those last observations in a separate dictionary (keyed by the position of that transition in the buffer: self.pos).
When sampling, we then check whether we sampled transitions with timeouts, and if so we retrieve the appropriate last observations from the dictionary.
Not sure if that is the most optimal way to do it though. So, if you (or anyone else) has a better idea I would love to hear it!
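To make the dictionary idea concrete, here is a minimal sketch (a toy ring buffer with hypothetical names like TinyBuffer and timeout_next_obs, not the actual SB3 ReplayBuffer). It mimics the memory-optimized layout, where the next observation for index i is read from the shared array at (i + 1) % size, and stashes the true terminal observation in a side dictionary keyed by the buffer position:

```python
import numpy as np

class TinyBuffer:
    """Toy sketch of the proposed fix; names and structure are illustrative only."""

    def __init__(self, size, obs_dim):
        self.size = size
        self.pos = 0
        # One shared array, as with optimize_memory_usage=True.
        self.observations = np.zeros((size, obs_dim), dtype=np.float32)
        self.timeout_next_obs = {}  # pos -> true terminal observation

    def add(self, obs, next_obs, timeout):
        self.observations[self.pos] = obs
        # Shared layout: next_obs lives in the following slot...
        self.observations[(self.pos + 1) % self.size] = next_obs
        if timeout:
            # ...but that slot will be overwritten by the first obs of the
            # next episode, so stash the real terminal observation here.
            self.timeout_next_obs[self.pos] = np.array(next_obs, dtype=np.float32)
        else:
            self.timeout_next_obs.pop(self.pos, None)  # slot is being reused
        self.pos = (self.pos + 1) % self.size

    def next_obs(self, idx):
        # Substitute the stashed terminal observation for timed-out transitions.
        if idx in self.timeout_next_obs:
            return self.timeout_next_obs[idx]
        return self.observations[(idx + 1) % self.size]
```

With this layout, sampling a timed-out transition returns the stashed terminal observation instead of the first observation of the next episode, while all other transitions keep the memory-optimized shared-slot lookup.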
Ah gotcha, that would indeed be a very hairy thing to implement correctly at this stage... I say we go with the hotfix, which instructs users to disable optimized memory usage if they want to use handle_timeout_termination :)
Another approach to fix it could be to increment self.pos by 2 (instead of 1) after we add a timeout transition. This would avoid the scenario in which we overwrite the last observation (at self.pos + 1) with the first observation of the new episode (now at self.pos + 2).
However, this does introduce an invalid transition into the replay buffer (at self.pos + 1) every time we encounter a timeout, which we have to keep track of to avoid sampling it. Therefore, the effective replay buffer size will be smaller than self.buffer_size. Exactly how much smaller depends on the environment (how often timeouts occur and after how many steps). So the actual size of the replay buffer is somewhat hidden from the user, which is perhaps not a very nice property to have.
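A toy sketch of this skip-a-slot idea (hypothetical SkipSlotBuffer, not SB3 code): after a timeout we advance the write position by 2, leaving a hole so the terminal observation at pos + 1 is never overwritten. The holes must be tracked and excluded from sampling, which is exactly the shrunken-effective-size downside described above:

```python
import numpy as np

class SkipSlotBuffer:
    """Illustrative only: skip one slot after each timeout to protect the terminal obs."""

    def __init__(self, size, obs_dim):
        self.size = size
        self.pos = 0
        self.observations = np.zeros((size, obs_dim), dtype=np.float32)
        self.invalid = set()  # positions holding only a terminal obs, never sampled

    def add(self, obs, next_obs, timeout):
        self.observations[self.pos] = obs
        self.invalid.discard(self.pos)  # this slot now holds a real transition
        if timeout:
            hole = (self.pos + 1) % self.size
            self.observations[hole] = next_obs  # preserved terminal observation
            self.invalid.add(hole)              # mark the hole as unsampleable
            self.pos = (self.pos + 2) % self.size
        else:
            self.pos = (self.pos + 1) % self.size

    def valid_indices(self):
        # Effective buffer size = self.size - len(self.invalid)
        return [i for i in range(self.size) if i not in self.invalid]
```

The terminal observation survives because the next episode starts writing two slots ahead, but every timeout permanently costs one slot of capacity until it is recycled.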
I will likely implement a fix for this for my own research, so I could submit a PR if you want (either for the hotfix or a true fix as described above)?
Hmm, but to do sampling right, one has to keep track of the timed-out episodes (to avoid sampling the invalid self.pos + 1 entries), right? Sorry if this was obvious in your message -- I just want to make sure we are on the same page :)
Tbh that sounds like a good amount of additional complexity and if-elsing, and it can get really tricky in the edge scenarios (e.g. the ends/beginnings of the replay buffer). For the sanity and cleanliness of the code-base (a departure from which caused this headache in the first place), I think it might be worth it to just disallow using these two parameters together.
Yes indeed, you would have to keep track of the positions in the replay buffer that are invalid and probably do some if-elsing to avoid sampling those. I don't think it would need special care at the beginning or end of the buffer, but it would add a decent amount of complexity (for example to properly handle vectorized envs).
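The sampling-side bookkeeping being discussed could look something like this (a hypothetical helper, not SB3 code): draw batch indices only from positions that are not marked invalid, e.g. the skipped slots after timeouts:

```python
import numpy as np

def sample_valid_indices(buffer_size, invalid, batch_size, seed=None):
    """Sample batch indices while rejecting invalid positions (illustrative only)."""
    rng = np.random.default_rng(seed)
    # Remove the invalid positions; the effective buffer is smaller
    # than buffer_size by len(invalid).
    valid = np.setdiff1d(np.arange(buffer_size), np.fromiter(invalid, dtype=int))
    return rng.choice(valid, size=batch_size, replace=True)
```

This keeps the rejection logic in one place, but the set of invalid positions still has to be maintained correctly on every add and on every wrap-around of the ring buffer, which is where the complexity accumulates.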
I have submitted a PR for the hotfix. For my own work I have already implemented a true fix using the first approach:
The core of the problem is that in case of timeouts the last observation of an episode should be stored somewhere. So a straightforward approach would be to simply store those last observations in a separate dictionary (keyed by the position of that transition in the buffer: self.pos). When sampling we then check if we sampled transitions with timeouts, and if so we retrieve the appropriate last observations from the dictionary.
I could submit a separate PR containing that fix? That way you can judge more specifically whether the ability to use both optimised memory and timeouts simultaneously is worth the added complexity to the code.
Cheers! Will take a look over it.
Hmm you could submit a PR for that, but I personally would not have time to 100% verify it works correctly (well, that is what tests are for :)). There are also no guarantees if it would be merged, but if you wish to do so, I won't say no!
Edit: If you do not plan to make the PR, please close this issue :)
🐛 Bug
When using the ReplayBuffer class, setting both optimize_memory_usage = True and handle_timeout_termination = True will lead to incorrect behaviour.

This is because when handle_timeout_termination = True, the replay buffer will set the done value at the end of an episode to False if the episode ended due to a timeout: https://github.com/DLR-RM/stable-baselines3/blob/d68f0a2411766beb6da58ee0e989d1a6a72869bc/stable_baselines3/common/buffers.py#L300-L302

In an algorithm like DQN, this means the Q-value target will bootstrap using the next_observation variable: https://github.com/DLR-RM/stable-baselines3/blob/d68f0a2411766beb6da58ee0e989d1a6a72869bc/stable_baselines3/dqn/dqn.py#L196

However, this leads to incorrect behaviour if optimize_memory_usage = True, because the next_observation variable in that case is defined as: https://github.com/DLR-RM/stable-baselines3/blob/d68f0a2411766beb6da58ee0e989d1a6a72869bc/stable_baselines3/common/buffers.py#L291-L292 which will get overwritten by the first state of the new episode. So, when optimize_memory_usage = True, the replay buffer will not return the last state of the episode (which would be the correct behaviour).

As a result, if both optimize_memory_usage = True and handle_timeout_termination = True, a reinforcement learning algorithm will sometimes have target values that bootstrap from the wrong state (the first state of the next episode, rather than the last state of the current episode).

To Reproduce
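As a toy illustration of the overwrite (plain numpy arrays, not the actual SB3 class): with optimize_memory_usage=True a single shared array holds observations, and the next observation for index i is read as observations[(i + 1) % buffer_size], so the first observation of the next episode replaces the terminal observation of the previous one:

```python
import numpy as np

buffer_size = 4
observations = np.zeros((buffer_size, 1), dtype=np.float32)

# Final transition of episode A, which ends in a timeout:
observations[0] = 10.0   # obs at pos 0
observations[1] = 20.0   # terminal obs, stored in the shared "next" slot

# First step of episode B reuses slot 1 and overwrites the terminal obs:
observations[1] = 99.0

# Sampling the transition at index 0 now bootstraps from the wrong state:
next_obs = observations[(0 + 1) % buffer_size]
print(float(next_obs[0]))  # 99.0 instead of the correct 20.0
```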
This bug only affects the Q-value updates if there are episode ends due to timeouts in the environment. In practice, the difference between the correct target and the incorrect target can be small. For this reason, I wasn't able to find a simple scenario in which this leads to noticeably incorrect training results. However, I added a simple code segment that illustrates the problem and why, if both optimize_memory_usage = True and handle_timeout_termination = True, the behaviour will be incorrect.

Expected behavior
If handle_timeout_termination = True, the replay buffer should return the last (terminal) state as the next_observation variable, regardless of whether optimize_memory_usage is True or False.

System Info
OS: Ubuntu 20.04.4 LTS
Python: 3.9.7
Stable-Baselines3: 1.5.1a6
PyTorch: 1.11.0
GPU Enabled: True
Numpy: 1.21.2
Gym: 0.21.0