Shmuma / ptan

PyTorch Agent Net: reinforcement learning toolkit for pytorch
MIT License

Tracking pos in ExperienceReplayBuffer #21

Closed 5 years ago

cpnota commented 5 years ago

I was having trouble getting PrioritizedReplayBuffer to work with some of my code. In particular, the _sample_proportional method was always returning an array of zeros. I noticed that the _add method depends on the superclass's self.pos:

    def _add(self, *args, **kwargs):
        idx = self.pos
        super()._add(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

However, the superclass only advances self.pos once the buffer is full:

    def _add(self, sample):
        if len(self.buffer) < self.capacity:
            self.buffer.append(sample)
        else:
            self.buffer[self.pos] = sample
            self.pos = (self.pos + 1) % self.capacity
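To make the behavior concrete, here is a tiny stand-alone reproduction of the base-class logic quoted above (the class name is illustrative, not the repository's actual class):

```python
class OriginalBuffer:
    """Mirrors the base-class _add logic quoted above."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.pos = 0

    def _add(self, sample):
        if len(self.buffer) < self.capacity:
            self.buffer.append(sample)
        else:
            self.buffer[self.pos] = sample
            self.pos = (self.pos + 1) % self.capacity


buf = OriginalBuffer(capacity=1000)
for i in range(10):
    buf._add(i)
# While the buffer is still filling, self.pos is never advanced,
# so a subclass that reads self.pos before calling super()._add()
# always sees 0.
```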

Therefore, idx was always zero while the buffer was filling, so every priority was written to index 0 of the segment trees, and the sample method always returned an array of zeros. Moving the last line (the self.pos update) out of the else block fixed the problem.
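A minimal sketch of the fix described above, applied to the base-class logic (class name is illustrative; only the indentation of the self.pos update changes):

```python
class FixedBuffer:
    """Base-class _add with the pos update moved out of the else block."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.pos = 0

    def _add(self, sample):
        if len(self.buffer) < self.capacity:
            self.buffer.append(sample)
        else:
            self.buffer[self.pos] = sample
        # Advance pos unconditionally, so subclasses reading self.pos
        # see the correct write index even while the buffer is filling.
        self.pos = (self.pos + 1) % self.capacity


buf = FixedBuffer(capacity=3)
buf._add("a")          # pos becomes 1 immediately, not only when full
buf._add("b")
buf._add("c")          # buffer full, pos wraps to 0
buf._add("d")          # overwrites index 0
```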

Am I missing any reason that this shouldn't be done?

Thank you for doing so much for the RL community, by the way! I am enjoying your book.

Shmuma commented 5 years ago

Thanks a lot for noticing this issue!