rlcode / per

Prioritized Experience Replay (PER) implementation in PyTorch
MIT License

divide by zero error #4

Open richielo opened 6 years ago

richielo commented 6 years ago

Hello, thank you for the work. I am running into a divide-by-zero error in the line below when calling the sample function to sample from memory. Any idea why?

is_weight /= is_weight.max()

stormont commented 5 years ago

It's caused by this, and it's actually raised from the np.power call in the line above.

I've forked, partially fixed the issues, and made a couple of other changes (plus made the PER memory more configurable).

/Users/{user}/repos/per/prioritized_memory.py:48: UserWarning: Pulled 1 uninitialized samples
  warnings.warn('Pulled {} uninitialized samples'.format(uninitialized_samples))

I'm happy to PR my changes into this repo if @rlcode wants them.
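For anyone hitting this, here is a minimal numpy sketch of that failure mode (the variable names are illustrative, not necessarily the repo's exact code): a priority of 0 from an uninitialized slot makes np.power emit the divide-by-zero warning, and the normalization then poisons the whole weight vector.

import numpy as np

beta = 0.4
n_entries = 4

# One sampled priority is 0 because that slot was never written.
priorities = np.array([0.5, 0.3, 0.2, 0.0])
sampling_probabilities = priorities / priorities.sum()

# 0 ** (-beta) -> inf, and numpy warns "divide by zero encountered in power"
is_weight = np.power(n_entries * sampling_probabilities, -beta)

# the max is inf, so the finite weights collapse to 0 and inf/inf becomes nan
is_weight /= is_weight.max()
print(is_weight)  # roughly [0. 0. 0. nan]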

emunaran commented 5 years ago

Hi guys, @rlcode - many thanks for your work. I also observed uninitialized samples being pulled and got the mentioned unwanted "0". I haven't figured out the problem yet, and I wonder if the sampling process works as it is supposed to. @stormont did you figure out the root cause?

josiahls commented 5 years ago

As referenced in (Schaul et al., 2015), as the TD error approaches 0 we will get divide-by-zero errors. They fix this via:

p_i = |δ_i| + ε

where epsilon is a small value that prevents this. I think you are missing this from your algorithm? I am pretty confident that if you have been testing on CartPole you will never run into this issue; however, in discrete state spaces (like mazes) this becomes a real problem.
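For concreteness, a minimal sketch of the proportional priority from the paper, p_i = (|δ_i| + ε)^α; the constant values and the function name here are just illustrative defaults, not necessarily what this repo uses:

import numpy as np

EPSILON = 0.01  # small constant so a zero TD error still gets a non-zero priority
ALPHA = 0.6     # how strongly prioritization is applied (alpha = 0 -> uniform sampling)

def priority_from_td_error(td_error):
    # Proportional prioritization: p_i = (|delta_i| + epsilon) ** alpha
    return (np.abs(td_error) + EPSILON) ** ALPHA

print(priority_from_td_error(0.0))   # strictly positive, so it can never zero out a weight
print(priority_from_td_error(-1.5))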

yougeyxt commented 5 years ago

Hello, I also found that uninitialized samples get sampled, returning the unwanted data "0". I tried to find the root cause but failed. Did you guys figure out the reason? @stormont @emunaran Many thanks!

yougeyxt commented 5 years ago

Also, according to the paper, when storing a new transition (s, a, r, s') in the memory, its priority should be the maximum priority among the leaf nodes, right? But in the code it uses the TD error computed from s and s', which is different from the paper. I am wondering whether this is a bug or not.
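For reference, the paper's scheme inserts new transitions with the current maximum priority so every transition is replayed at least once before its real TD error is known; the priority is only replaced with one derived from the TD error after the transition has been sampled and trained on. A rough sketch of that insertion on top of a sum-tree memory (the tree layout and method names here are assumptions, not this repo's exact API):

def add_with_max_priority(tree, transition, default_priority=1.0):
    # Assumes a sum tree that keeps its leaf priorities in tree.tree[-tree.capacity:]
    # and exposes add(priority, data); adapt the indexing to your own structure.
    max_priority = tree.tree[-tree.capacity:].max()
    if max_priority == 0:
        max_priority = default_priority  # empty memory: fall back to a default
    tree.add(max_priority, transition)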

Jspujol commented 5 years ago

> Hello, I also found that uninitialized samples get sampled, returning the unwanted data "0". I tried to find the root cause but failed. Did you guys figure out the reason? @stormont @emunaran Many thanks!

Hi there! I faced the same issue, and what I did was to resample another value from that same interval until the result is not an integer (since the data array is initialized with np.zeros, uninitialized slots still hold the integer 0). In the prioritized memory I changed the sampling loop to the following:

for i in range(n):
    a = segment * i
    b = segment * (i + 1)
    # Resample within this segment until we hit a slot that actually holds
    # a transition (uninitialized slots still contain the integer 0).
    while True:
        s = random.uniform(a, b)
        (idx, p, data) = self.tree.get(s)
        if not isinstance(data, int):
            break
    priorities.append(p)
    batch.append(data)
    idxs.append(idx)

This did the trick for me. Hope it does the same for you.

being-aerys commented 3 years ago

If anyone is still wondering why it pulls 0 from the replay memory: the sampled location in the replay memory has not been filled yet, so it still contains the initial value the buffer was initialized with, i.e., 0. If you add a condition that training does not start until the buffer is completely filled, you never encounter this issue.
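In the training loop that gate is a one-line check; a minimal sketch, assuming the memory can report how many real transitions it holds (the len(memory) call and the sample() signature here are assumptions):

BATCH_SIZE = 64
MIN_REPLAY_SIZE = 1000  # or the full capacity if you want the buffer completely filled

# ... inside the environment loop, after adding the latest transition ...
if len(memory) >= MIN_REPLAY_SIZE:
    batch, idxs, is_weights = memory.sample(BATCH_SIZE)
    # train on the batch, then write the new TD errors back as updated priorities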

ZINZINBIN commented 1 year ago

> Hi there! I faced the same issue, and what I did was to resample another value from that same interval until the result is not an integer [...] This did the trick for me.

Brilliant!