Open: UPUPGOO opened this issue 4 years ago
I do not quite understand what the issue is. Could you reformat your example code with triple-ticks (``` like this ```)? Also, the suggested change appears to be the same as the original.
Sorry for that. I have modified my question.
Thank you! Now I see the issue. Yes, I think there is one extra `end -= 1`, and because of it the last added item is not included in the sampling. This can be tested with:
```python
import numpy as np
from stable_baselines.common.buffers import PrioritizedReplayBuffer

buffer = PrioritizedReplayBuffer(100, 0.6)
for i in range(10):
    x = np.array([i])
    buffer.add(x, x, x, x, x)

print(buffer.sample(10, beta=0.5))
print(buffer.sample(10, beta=0.5))
...  # you will never see state "9" in the sampled experiences (neither as state nor as next state)
```
A PR to fix this would be welcomed :). I am not sure where the extra "-1" should be fixed exactly given the documentation of the functions. Also an update to tests for this change would be nice.
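For reference, one possible place to drop the extra "-1" would be the call site in `_sample_proportional` (only a sketch of the option that comes up later in this thread, not necessarily where the final fix belongs; the alternative is adjusting the `end -= 1` inside `SegmentTree.reduce`, depending on which inclusive/exclusive convention the docstrings are meant to guarantee):

```python
# stable_baselines/common/buffers.py, PrioritizedReplayBuffer._sample_proportional

# current: the most recently added transition never contributes to `total`
total = self._it_sum.sum(0, len(self._storage) - 1)

# possible fix: pass len(self._storage) directly, since SegmentTree.reduce
# already subtracts 1 from `end` internally
total = self._it_sum.sum(0, len(self._storage))
```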
By the way, the code for calculating the weights in the `sample` method of the `PrioritizedReplayBuffer` class can be simplified.
Change

```python
p_min = self._it_min.min() / self._it_sum.sum()
max_weight = (p_min * len(self._storage)) ** (-beta)
p_sample = self._it_sum[idxes] / self._it_sum.sum()
weights = (p_sample * len(self._storage)) ** (-beta) / max_weight
```

to

```python
weights = (self._it_sum[idxes] / self._it_min.min()) ** (-beta)
```
This follows from a short algebraic derivation (sketched below).
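For completeness, a sketch of the algebra (writing $N$ for `len(self._storage)`, $S$ for `self._it_sum.sum()`, $p_i$ for `self._it_sum[idxes]` divided by $S$, and $p_{\min}$ for `self._it_min.min()` divided by $S$): `max_weight` equals $(p_{\min} N)^{-\beta}$, which is the largest value $(p_i N)^{-\beta}$ can take, so

$$
\text{weights} = \frac{(p_i N)^{-\beta}}{(p_{\min} N)^{-\beta}} = \left(\frac{p_i}{p_{\min}}\right)^{-\beta},
$$

and since $p_i / p_{\min}$ equals `self._it_sum[idxes] / self._it_min.min()` (the common factor $S$ cancels), this is exactly the one-line version above.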
I also did some experiments to verify this with the following code.
```python
import numpy as np
from stable_baselines.common.buffers import PrioritizedReplayBuffer

buffer = PrioritizedReplayBuffer(100, 0.6)
for i in range(10):
    x = np.array([i])
    buffer.add(x, x, x, x, x)

# update priorities to [0.05 0.1 0.15 0.2 0.25 0.3 0.35 0.4 0.45 0.5]
buffer.update_priorities(np.arange(10), np.linspace(0.05, 0.5, 10))

data = buffer.sample(10, beta=0.5)
weights1 = data[-2]
idxes = data[-1]
weights2 = (buffer._it_sum[idxes] / buffer._it_min.min()) ** (-0.5)
print(weights1 - weights2)
```
The result is all zeros. The original code matches the formula in the paper more closely, but I think the simplified version is much faster.
Nice catch! It indeed checks out. I do not know how much faster it would be (the bottleneck is in the segment tree summing), but it is more compact and also easier to understand by just looking at the computation.
Feel free to make a PR that includes these two changes :)
I agree with you. This change is just a small tweak and probably has little impact. I will make a PR later. Thank you.
Hello, does this mean that Prioritized Experience Replay in DQN isn't working in Stable Baselines?
@Jogima-cyber
Dare I say most of it is working, except that the last added sample is not included in the random sampling process. Given the number of samples in the buffer, this seems like a minuscule error (which should still be fixed!), but I cannot say for sure that the effect on learning is small.
@UPUPGOO
Any update on the PR for this? I am asking to check whether somebody is working on this and wants to make a PR out of it. If not, I can add it.
Sorry for the late reply. I made a PR but it did not seem to pass. You can update this. Thank you.
In the file https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/common/buffers.py (line 206), the code

```python
total = self._it_sum.sum(0, len(self._storage) - 1)
```

computes the total priority and sets the `end` parameter of `self._it_sum.sum` to `len(self._storage) - 1`. But in the file https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/common/segment_tree.py (line 75), the line `end -= 1` in the `reduce` function, which is called by `self._it_sum.sum` above, subtracts 1 again. Has 1 been subtracted twice?
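To spell out the suspected double subtraction, a sketch of the call chain with 10 stored transitions:

```python
# buffers.py, _sample_proportional, with 10 items in self._storage:
total = self._it_sum.sum(0, len(self._storage) - 1)   # -> sum(0, 9)

# segment_tree.py, SegmentTree.reduce (called by sum):
#   end -= 1                                           # 9 -> 8
#
# so the reduction only covers indices 0..8, and the priority of the most
# recently added transition (index 9) is never counted in `total`.
```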
I simply verified my idea with the following code. If I change `len(buffer._storage) - 1` to `len(buffer._storage)`, I get the correct result. Because I added 10 new transitions to the buffer, the total priority should be 10. If I misunderstood the code, please let me know.
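A minimal sketch of that check (assuming each newly added transition still carries the default priority of 1.0):

```python
import numpy as np
from stable_baselines.common.buffers import PrioritizedReplayBuffer

buffer = PrioritizedReplayBuffer(100, 0.6)
for i in range(10):
    x = np.array([i])
    buffer.add(x, x, x, x, x)  # each new item gets the default priority of 1.0

# end = len(storage) - 1, then reduce() subtracts 1 again: only 9 items counted
print(buffer._it_sum.sum(0, len(buffer._storage) - 1))  # 9.0

# passing len(storage) directly counts all 10 items
print(buffer._it_sum.sum(0, len(buffer._storage)))      # 10.0
```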