takoika / PrioritizedExperienceReplay

Yet another prioritized experience replay buffer implementation.
MIT License

Max Weight Normalization #2

Open benbotto opened 6 years ago

benbotto commented 6 years ago

The proportional code looks like it divides the importance sampling weights by the maximum weight in the sample (https://github.com/takoika/PrioritizedExperienceReplay/blob/master/proportional.py#L84). I think that's incorrect. Shouldn't the weights in the sample be divided by the maximum weight in the entire replay buffer? The weights are used to correct for the bias introduced by prioritized experience replay, i.e. the fact that the entire distribution of replay memory is no longer uniform; therefore, the sampled weights need to be divided by the maximum importance sampling weight over all the replay memory.

Reference the OpenAI baselines code. It's a bitch to read, but I think its PER implementation is correct.

        # Probability of choosing the sample with the lowest priority.
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self._storage)) ** (-beta)
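For reference, here is a rough sketch of how a sampled transition's weight would then be normalized against that buffer-wide maximum (illustrative names only, not code from either repo; p_sample and p_min are sampling probabilities, i.e. a priority divided by the sum of all priorities):

    def normalized_is_weight(p_sample, p_min, n_stored, beta):
        """Importance-sampling weight for one sampled transition, normalized
        by the largest weight anywhere in the buffer, which belongs to the
        lowest-probability (lowest-priority) transition."""
        max_weight = (p_min * n_stored) ** (-beta)
        weight = (p_sample * n_stored) ** (-beta)
        return weight / max_weight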
benbotto commented 6 years ago

Hi sreschke,

I agree that both min and max priorities need to be tracked; however, the update_max_min_priorities method proposed is flawed. The max and min priorities need to be tracked for the samples that presently exist in the buffer, not for every sample that's ever been added. Consider, for example, a buffer size of 8 with these priorities.

8 7 6 5 4 3 2 1

So the max priority is 8 and the min is 1. Now let's say a new item is added with a priority of 2, which causes the 8 to fall off.

7 6 5 4 3 2 1 2

By your algorithm the max priority would still be 8, whereas it should be 7.

This becomes a real issue in OpenAI's DQN implementation, which uses something similar to your proposal. Over time the TD errors--which are generally used to prioritize--grow small. But if even one large TD error is ever encountered throughout the entire training, then the samples will be greatly skewed toward the higher-priority items. You'll see that the highest-priority items tend to get sampled much more frequently than they should, and the symptom is catastrophic forgetting (OpenAI's DQN implementation, for example, doesn't come anywhere near reproducing the numbers in the PER and DQN papers). E.g., consider at a much later time that the 8 items in the buffer are:

.0001 .0003 .0000007 .01 .003 .0091 .0004 .0021

Consider what will happen if the max priority is still 8 and the min still 1! No bueno.
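To put numbers on it, here's a quick back-of-the-envelope check (plain Python, ignoring the alpha exponent for simplicity): insert one new transition with the stale all-time max priority of 8 alongside the current priorities above and look at the resulting sampling probabilities.

    # Current priorities from the example above, plus one new transition
    # that gets inserted with the stale all-time max priority of 8.
    current = [.0001, .0003, .0000007, .01, .003, .0091, .0004, .0021]
    stale_new = 8.0

    total = sum(current) + stale_new
    print(stale_new / total)     # ~0.997: the new item dominates sampling
    print(max(current) / total)  # ~0.001: everything else is nearly ignored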

Here's how I implemented PER: https://github.com/benbotto/bsy-dqn-atari/tree/2.0.1/replay_memory My implementation exceeds the numbers in the original DQN paper, in the Nature paper on DQN, in the DDQN paper, and in the PER paper on Breakout, Space Invaders, and Pong (more training is ongoing).

[b]

On Tue, Aug 14, 2018 at 6:55 PM, sreschke notifications@github.com wrote:

There appears to be a need to keep track of both the max_priority and min_priority of the buffer. The max_priority is tracked because new transitions should be assigned this priority (see line six of Algorithm 1 in the paper: https://arxiv.org/pdf/1511.05952.pdf) to ensure all transitions have a good chance of being trained on at least once. The min_priority is tracked so that the max_weight can be calculated for weight normalization. This tracking can be implemented with the following steps:

  1. Add the max and min priority members to the constructor of the Experience class:

         self.max_priority = -float("inf")
         self.min_priority = float("inf")

  2. Add the following function to the Experience class:

         def update_max_min_priorities(self, priority):
             if priority > self.max_priority:
                 self.max_priority = priority
             if priority < self.min_priority:
                 self.min_priority = priority
             return

  3. Call update_max_min_priorities() in both the add() and priority_update() functions.

  4. In get_batch(), calculate the max_weight and use it to normalize the sample weights:

         max_weight = (self.min_priority * self.memory_size) ** (-beta)
         weights = [i / max_weight for i in weights]  # Normalize for stability

Hope that helps


sreschke commented 6 years ago

Hi benbotto, you're right, I realized my mistake a few minutes after I posted this. Since the sum_tree is implemented using an array, we can actually access the bottom row of the sum_tree with the right indices. Once we have the bottom row, we simply use the regular max and min functions in the standard Python library.

I ended up adding these two functions instead:

    def get_max_priority(self):
        """Gets the max priority in the buffer."""
        index = self.tree.tree_size - 2 ** (self.tree.tree_level - 1)
        return max(self.tree.tree[index:index + self.memory_size])

    def get_min_priority(self):
        """Gets the min priority in the buffer."""
        index = self.tree.tree_size - 2 ** (self.tree.tree_level - 1)
        return min(self.tree.tree[index:index + self.memory_size])

However, these are O(n), which might be too slow. How did you keep track of the max and min priorities in your implementation?

benbotto commented 6 years ago

That seems right to me, but slicing one million items twice and then taking the max and min will be prohibitively slow. That's why I used a SegmentTree for keeping track of the min and max. (See my MaxSegmentTree and MinSegmentTree classes in the previously-mentioned code.)
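For illustration, a minimal min segment tree looks something like this (just a sketch of the idea, not the actual MinSegmentTree class from bsy-dqn-atari):

    import math

    class MinSegmentTree:
        """Minimal sketch: point updates and a buffer-wide min query,
        both in O(log n)."""

        def __init__(self, capacity):
            # Round capacity up to a power of two so the tree is complete.
            self.capacity = 1 << math.ceil(math.log2(capacity))
            self.tree = [float("inf")] * (2 * self.capacity)

        def update(self, index, priority):
            # Set the leaf, then fix each ancestor on the way to the root.
            i = index + self.capacity
            self.tree[i] = priority
            i //= 2
            while i >= 1:
                self.tree[i] = min(self.tree[2 * i], self.tree[2 * i + 1])
                i //= 2

        def min(self):
            return self.tree[1]

A MaxSegmentTree is the same thing with min swapped for max and -inf as the neutral element. Update both trees whenever a priority is written, and the current buffer-wide min and max then cost O(log n) per update and O(1) per query.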
