Kaixhin / Atari

Persistent advantage learning dueling double DQN for the Arcade Learning Environment
MIT License

Implement rank-based prioritised experience replay #1

Closed Kaixhin closed 8 years ago

Kaixhin commented 8 years ago

Requires a "sum tree" binary heap for efficient execution.

Edit 2016-06-02: Please keep to the contributing guidelines.

Kaixhin commented 8 years ago

Update: A test on Frostbite has only achieved DQN/DDQN scores, even though rank-based prioritised experience replay achieves much higher scores in the paper. Therefore the current implementation is either wrong or suboptimal. Proportional prioritised replay still needs to be implemented.

SimsGautam commented 8 years ago

Hi, what's the current progress with the prioritized experience replay implementation? Has it been successfully tested yet? I noticed it was struck through in the readme. Thanks!

Kaixhin commented 8 years ago

As above, I attempted an implementation of rank-based prioritised experience replay, but it got nowhere near the scores it should have when I ran a test on Frostbite. The paper helpfully provides complete training graphs for comparison, so clearly there's a bug somewhere. Proportional prioritised experience replay has not been attempted yet since it is trickier to get right.

SimsGautam commented 8 years ago

Ah, I see. Any ideas on where the bug could be? I could help take a look as well. I've been trying to find prioritized experience replay, so if you know of any other implementations available publicly, that'd be great! thanks again.

Kaixhin commented 8 years ago

For rank-based, first you need a priority queue, which I've implemented as a max binary heap. I did quick tests halfway through implementation, but checking that would be first.

Apart from the priority (absolute TD-error) being the key, the value is actually an index which indicates the location of the sample in the experience replay buffer. So first comes checking the new sample update and the sample update after learning. The actual algorithm/maths to compare against the paper is here. The logic for picking the partitions for sampling is here.

There are a few more bits, but those are the important ones to try first. I've looked a lot, but this does seem to be the only public attempt - so no chance of porting code from others!
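The partition logic mentioned above can be sketched roughly as follows. This is an illustrative Python version rather than the repository's Lua. The function names, `alpha = 0.7`, and the rule of giving each segment equal cumulative probability are assumptions based on the rank-based scheme in the paper, not a copy of the linked code.

```python
import numpy as np

def rank_based_segments(n, k, alpha=0.7):
    """Rank probabilities P(rank) ~ (1/rank)^alpha and k segment end ranks (1-based)."""
    ranks = np.arange(1, n + 1)
    weights = (1.0 / ranks) ** alpha
    probs = weights / weights.sum()
    cdf = np.cumsum(probs)
    # Segment j ends at the first rank whose cumulative probability reaches j/k.
    targets = np.arange(1, k + 1) / k
    ends = np.searchsorted(cdf, targets) + 1
    ends = np.minimum(ends, n)
    ends[-1] = n  # guard against floating-point round-off in the last target
    return probs, ends

def sample_ranks(ends, rng):
    """Draw one rank uniformly from each segment (stratified sampling)."""
    starts = np.concatenate(([1], ends[:-1] + 1))
    return np.array([rng.integers(lo, hi + 1) for lo, hi in zip(starts, ends)])

probs, ends = rank_based_segments(n=1000, k=10)
ranks = sample_ranks(ends, np.random.default_rng(0))
print(ends)
print(ranks)
```

The sketch assumes `n` is much larger than `k`; with a very small memory, two segments could end at the same rank and the uniform draw would fail.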

mryellow commented 8 years ago

Just a quick glance to see if anything jumps out at my uneducated fresh eyes. Haven't used BinaryHeap structures enough to know if this is something or not...

https://github.com/Kaixhin/Atari/blob/master/structures/BinaryHeap.lua#L116

Would having the recursion inside those conditions mean it only continues rebalancing the heap if it is currently working on an item which is in the wrong position? It stops as soon as it hits an item which doesn't need moving; is that intended? Are they already in order up to this point?

Kaixhin commented 8 years ago

@mryellow Yes that is the case - one of the ideas behind this data structure is that if an element is inserted (or replaced - the less straightforward case), it will either travel up or down (or stay) along its path in the tree (a heap is a specialised tree). So going from an empty or ordered tree, each operation only needs to deal with the element in question.

Tests would be best, but bar that referring to pseudocode/code in other languages would be one way of checking this. One error I might have made is in converting from the usual 0-based indexing to 1-based indexing.

mryellow commented 8 years ago

> each operation only needs to deal with the element in question

Seemed that was the idea, my formal CS is weak.

> One error I might have made is in converting from the usual 0-based indexing to 1-based indexing.

That was one area which stood out, the bumping up/down of size. Could easily be a missing one somewhere, either in an assignment or a condition. I'll do some logging sometime through the week and see if anything jumps out.

mryellow commented 8 years ago

In BinaryHeap should the parent priority always be less than both children?

https://github.com/mryellow/Atari/blob/test-prior/spec/structures.lua#L140-L151

I'm failing those after insert calls, however after calls to updateByVal everything is sorted.

Kaixhin commented 8 years ago

@mryellow In a max heap (pictured in Wikipedia) the parent should always be greater than both children. Inserting or replacing will almost certainly mess up the order, which is why calling upHeap on an insert and both downHeap and upHeap on a replace should fix the heap ordering.
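As a point of comparison for the behaviour described here, a minimal max-heap can be sketched as below. It is written in Python with 0-based indices, not the repository's 1-based Lua, and it stores bare priorities rather than (priority, index) pairs. The class and method names are illustrative, not those of `BinaryHeap.lua`.

```python
class MaxHeap:
    """Array-based max-heap of priorities using 0-based indices."""

    def __init__(self):
        self.array = []

    def up_heap(self, i):
        # Move the element at i towards the root while it beats its parent.
        while i > 0:
            parent = (i - 1) // 2  # with 1-based indices this would be i // 2
            if self.array[i] <= self.array[parent]:
                break
            self.array[i], self.array[parent] = self.array[parent], self.array[i]
            i = parent
        return i

    def down_heap(self, i):
        # Move the element at i towards the leaves while a child beats it.
        n = len(self.array)
        while True:
            left, right, largest = 2 * i + 1, 2 * i + 2, i
            if left < n and self.array[left] > self.array[largest]:
                largest = left
            if right < n and self.array[right] > self.array[largest]:
                largest = right
            if largest == i:
                return i
            self.array[i], self.array[largest] = self.array[largest], self.array[i]
            i = largest

    def insert(self, priority):
        # A new element starts at the end and can only travel upwards.
        self.array.append(priority)
        self.up_heap(len(self.array) - 1)

    def replace(self, i, priority):
        # A replaced element may need to travel in either direction.
        self.array[i] = priority
        self.down_heap(self.up_heap(i))

    def is_valid(self):
        # Every parent must be greater than or equal to its children.
        return all(self.array[(i - 1) // 2] >= self.array[i]
                   for i in range(1, len(self.array)))
```

Checking `is_valid()` after every `insert` and `replace` is one way to catch an off-by-one slip when translating between 0-based and 1-based index formulas.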

mryellow commented 8 years ago

Yeah that's checking out, with "greater than or equal to".

With indices, ISWeights = experience:sample(1), the last index corresponds to the highest ISWeight. Shouldn't this index also point to the highest priority in the sample?

Kaixhin commented 8 years ago

Experience:sample() doesn't take any arguments (number of samples returned is hard-coded as minibatch size). If I remember correctly, rankIndices returns indices within the heap, descending in priority. The hash table is then used to retrieve the actual index in the replay memory. The ISWeights should be inversely proportional to the priority (to counteract for high priority samples being reused a lot).

In the paper they mention that they use the order in the underlying array as a proxy for the priority (noted in line 218), so the final elements might not necessarily be in priority order.
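For reference, the rank-based quantities from the paper that this relationship relies on can be written as follows, with N the number of stored transitions, α the prioritisation exponent and β the importance-sampling exponent. The symbol for the normalised weight (w with a tilde) is notation chosen here, not the paper's.

```latex
p_i = \frac{1}{\operatorname{rank}(i)}, \qquad
P(i) = \frac{p_i^{\alpha}}{\sum_k p_k^{\alpha}}, \qquad
w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^{\beta}, \qquad
\tilde{w}_i = \frac{w_i}{\max_j w_j}
```

A higher priority means a better rank, a larger P(i) and therefore a smaller weight, which is the inverse relationship described above.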

mryellow commented 8 years ago

> The ISWeights should be inversely proportional to the priority

Will have to check further but that part might be failing in the end.

If I fill the heap via Experience:store, then update all the priorities using updateByVal I end up with this:

local indices, ISWeights = experience:sample()
for i = 1, indices:size(1) do
  print("i: " .. i)
  print("indices[i]: " .. indices[i])
  print("ISWeight: " .. ISWeights[i])
  print("priority: " .. experience.priorityQueue.array[indices[i]][1])
  print("val: " .. experience.priorityQueue.array[indices[i]][2])
end
i: 1
indices[i]: 5136
ISWeight: 0.10489251370625
priority: 1.09
val: 4
i: 2
indices[i]: 650
ISWeight: 0.25466125260447
priority: 1.467
val: 83
i: 3
indices[i]: 5395
ISWeight: 0.3869674701564
priority: 1.272
val: 347
i: 4
indices[i]: 1517
ISWeight: 0.48500964026323
priority: 1.294
val: 751
i: 5
indices[i]: 5400
ISWeight: 0.5745158898887
priority: 1.311
val: 1340
i: 6
indices[i]: 3590
ISWeight: 0.65013024797926
priority: 1.204
val: 2045
i: 7
indices[i]: 176
ISWeight: 0.71405114882925
priority: 1.487
val: 2818
i: 8
indices[i]: 4633
ISWeight: 0.82566234942261
priority: 1.226
val: 4630
i: 9
indices[i]: 2934
ISWeight: 0.88447589334189
priority: 1.141
val: 5858
i: 10
indices[i]: 2231
ISWeight: 1
priority: 1.437
val: 8913

Kaixhin commented 8 years ago

It looks fine to me - I just printed rankIndices, self.indices and w on line 244. rankIndices is what indexes into experience.priorityQueue.array - indices is the set of indices for the experience replay buffer itself (e.g. for self.states).
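The two levels of indexing can be illustrated with a toy example. This is a Python sketch with made-up data; `heap`, `rank_indices` and `buffer_indices` are hypothetical stand-ins for `experience.priorityQueue.array`, `rankIndices` and `indices`.

```python
# Hypothetical heap array of (priority, replay_buffer_index) pairs, root first.
heap = [(1.5, 2), (1.2, 3), (1.4, 0), (1.0, 1)]

rank_indices = [0, 2]                                # positions inside the heap array
buffer_indices = [heap[r][1] for r in rank_indices]  # positions inside the replay buffer

right = [heap[r][0] for r in rank_indices]    # priorities of the sampled entries
wrong = [heap[b][0] for b in buffer_indices]  # buffer indices misused as heap positions

print(buffer_indices, right, wrong)
```

Using the replay buffer indices to look into the heap array returns priorities that belong to other entries, which matches the mismatch seen in the printout above.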

mryellow commented 8 years ago

Ahh I was using the wrong index. Priorities mostly checking out when using rankIndex inside sample.

Although items are out of order until updateByVal is called, and even then a few items towards the end are still out of order:

After store, which sets to findMax:

priority: 1.0076759999994
priority: 1.0079299999993
priority: 1.0062629999995
priority: 1.0053179999996
priority: 1.0081579999993
priority: 1.0070929999994
priority: 1.0050279999996
priority: 1.0080359999993
priority: 1.0034719999997
priority: 1.000243

After updateByVal, set to 1 + math.random(500)/1000:

priority: 1.5
priority: 1.498
priority: 1.494
priority: 1.452
priority: 1.469
priority: 1.396
priority: 1.321
priority: 1.172
priority: 1.24
priority: 1.151

Guess there is no guarantee that the items towards end will be the lowest magnitude.

Kaixhin commented 8 years ago

So a check would be to make sure updateByVal is called when needed, but it seems that is the case. And yes, items towards the end are generally lower priority but won't be ordered properly. Thanks a lot for running these tests - it looks like the low-level mechanics seem to be working fine so far.

Although it's unlikely to show anything, it might be worth halving smallConst and testing to see if that has an impact on Catch? I'll see if I can also try that on Frostbite soon.

mryellow commented 8 years ago

> Thanks a lot for running these tests

Idiots make the best testers, hopefully my rubber-duck ignorance hits some gold sooner or later ;-)

I'm maxing out my RAM and running on CPU rather than GPU otherwise I'd kick off some agents, maybe it's all actually working.

> called when needed, but it seems that is the case

Could there be potential for it getting stuck visiting the same experiences, thus only updating those parts of the heap when updatePriorities is called in Agent and never calling updateByVal for some (or many) of them? Leaving parts of the heap broken, enough that they don't fit into stratum in a way which sees them sampled in future (never having both upHeap and downHeap called).

Kaixhin commented 8 years ago

If only idiots made the best developers too... But this really could benefit from a fresh pair of eyes.

That problem should only happen if the heap is broken in the first place - only inserted or updated elements need to be fixed (the calls to upHeap and downHeap should fix anything else that breaks as a result of them being called). So does using priorityQueue:insert to build the heap fail to preserve the heap property in your tests?

mryellow commented 8 years ago

> So does using priorityQueue:insert to build the heap fail to preserve the heap property in your tests?

Well it only gets to looking at the parent for an upHeap, which will push the greatest all the way to the top, but may leave other children of that now swapped parent waiting for a downHeap. Guess long as the sampling hits indexes from everywhere in the heap then they'll all eventually get an updateByVal and hence downHeap.

> which will push the greatest all the way to the top

Testing the index after each insert shows new max going to first position and findMax works after a bunch of random inserts or updates.

Kaixhin commented 8 years ago

Ah that's not an issue - check out how insert works in a binary heap. I'm fairly confident that the priority queue is fine now, so I've started going over the paper and my code more thoroughly now. Made some changes, have a few more to go, and will hopefully get to start a test on Frostbite by the weekend.

Kaixhin commented 8 years ago

Missing implementation detail to do - the heap-based array "is infrequently sorted once every 10^6 steps to prevent the heap becoming too unbalanced".

Also, in line 10 of algorithm 1 it is unclear whether the max IS weight is calculated over all samples or just the ones that were chosen.

mryellow commented 8 years ago

Once again, my math is weak but noob eyes might spot something and can read it well enough to fall over trying.

Where does subscript i come from? Could that be indicating "all samples"? However then w is defined right there, only for experiences in the batch, so surely not already full of all possible weights.

Finding the lowest ISWeights would be easy enough via findMax(), but you'd want the opposite. They don't mention pre-computing the weights and storing them, if it were the overall max I'd imagine they'd be stored much like the priorities with the ability to recall the max without having to calculate it by looking at the whole list.

Kaixhin commented 8 years ago

It's a bit ambiguous, but yes it could well be over just the batch. I printed out the difference between the full sample max and the batch max for a while, with the max difference being ~0.05. I don't think that should make much of a difference, so I'll switch back to batch max for now. Worked out a simple trick for getting the full sample max IS weight here - it comes from the smallest probability, which is the last element in the probability distribution.
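The two normalisation choices can be compared with a short sketch. This is Python rather than the repository's Lua; `rank_probs`, `is_weights` and the values `alpha = 0.7` and `beta = 0.5` are illustrative assumptions.

```python
import numpy as np

def rank_probs(n, alpha):
    """Rank-based sampling probabilities, highest priority (rank 1) first."""
    w = (1.0 / np.arange(1, n + 1)) ** alpha
    return w / w.sum()

def is_weights(probs, sampled, beta, use_global_max=False):
    """Importance-sampling weights for the sampled positions (0-based ranks)."""
    n = len(probs)
    w = (n * probs[sampled]) ** (-beta)
    # The largest possible weight comes from the smallest probability,
    # which is the last element of a rank-ordered distribution.
    w_max = (n * probs[-1]) ** (-beta) if use_global_max else w.max()
    return w / w_max

probs = rank_probs(n=1000, alpha=0.7)
sampled = np.array([0, 10, 500])
batch_norm = is_weights(probs, sampled, beta=0.5)
global_norm = is_weights(probs, sampled, beta=0.5, use_global_max=True)
print(batch_norm, global_norm)
```

With the batch maximum the largest weight in every minibatch is exactly 1. With the global maximum all weights are slightly smaller unless the lowest-ranked transition happens to be sampled.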

I'm hoping that rebalancing the heap is the last part of the puzzle. A quick guess is that sorting the heap's underlying array and reconstructing the heap from that might just do the trick.
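A sketch of that guess, in Python rather than the repository's Lua: an array sorted in descending order already satisfies the max-heap property, so sorting is enough to rebuild the heap, as long as any map from replay buffer index to heap position is rebuilt too. The names `rebuild_sorted` and `positions` are hypothetical.

```python
def rebuild_sorted(heap):
    """Sort (priority, buffer_index) pairs by descending priority and rebuild the position map."""
    rebuilt = sorted(heap, key=lambda entry: entry[0], reverse=True)
    positions = {buffer_index: pos for pos, (_, buffer_index) in enumerate(rebuilt)}
    return rebuilt, positions

def is_max_heap(heap):
    # Every parent priority must be >= the priority of each of its children.
    return all(heap[(i - 1) // 2][0] >= heap[i][0] for i in range(1, len(heap)))

heap = [(1.5, 2), (1.2, 3), (1.4, 0), (1.0, 1)]
rebuilt, positions = rebuild_sorted(heap)
print(rebuilt, positions)
```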

schaul commented 8 years ago

I was just made aware of your effort to reproduce our work and combine all the best bits, very cool, I'm very happy to see that!

We tried to put as much information into the paper as possible, but if we missed something, feel free to ask.

Concretely:

avasilkov commented 8 years ago

@schaul Hi, thanks for your great work!

I have a few questions about prioritised experience replay as well. If you have time, could you answer all or some of them? Thank you!

1) Does the constant assigned to new transitions matter? I have 1000, for example, and all my errors are clipped to [-1, 1] before the square inside the error.

2) I see jumps in my error values after each max-heap sorting, is this normal? (Please see edit.)

3) How often do I need to check if the heap is unbalanced or not?

4) How do you remove elements when the heap is full? Currently I just remove the last element in the heap array.

5) To update the heap in batch, do you need to keep a dictionary from element id to its current position? Or did you use some other method?

6) How bad is it for beta to reach 1.0 too soon, before the end of learning? Alpha stays the same.

7) I train my network every 4 added transitions and my minibatch size is 32. So there are 4 new and unvisited transitions in the heap for every batch update. Should I increase the minibatch size, decrease 4 to 2 or 1, or leave it be?

I don't participate in this repository, but I hope my questions will be helpful to anyone who was trying to find a priority experience replay implementation and ended up here :)

Edit: jumps occur not after sorting, but after recomputing boundaries while heap is not full.

Kaixhin commented 8 years ago

@avasilkov I'll attempt to answer a few questions based on what I've read in the paper/my implementation in this repo.

  1. According to the algorithm, p1 = 1, and thereafter transitions are added with the max priority from the heap, which will be at most 1 with TD-error clipping and the absolute function.
  2. Not sure.
  3. According to the appendix the heap is rebalanced every 1m steps for the Atari domain.
  4. The heap is tied to the experience replay (circular buffer), so when a new transition replaces an old one, the corresponding element should be updated in the heap, rebalancing as needed.
  5. I am using a dictionary to do this.
  6. According to table 3, beta is annealed until the end of training for the Atari domain.
  7. Minibatch size and k should remain at 32 and 4 respectively.
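
The annealing in answer 6 can be sketched as a linear schedule. This is a Python illustration, not code from this repository; `anneal_beta` and the starting value `beta0 = 0.5` are assumptions, and the schedule needs the total number of training steps to be known in advance.

```python
def anneal_beta(step, total_steps, beta0=0.5):
    """Linearly anneal beta from beta0 at step 0 to 1.0 at total_steps."""
    frac = min(max(step / total_steps, 0.0), 1.0)  # clamp progress to [0, 1]
    return beta0 + frac * (1.0 - beta0)

print(anneal_beta(0, 100), anneal_beta(50, 100), anneal_beta(100, 100))
```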

avasilkov commented 8 years ago

@Kaixhin Thank you for your answers! I really like your idea of uniting all these amazing papers into one crazy network!

What do you think about persistent advantage learning combined with double dqn? I'm not sure how to combine them, do they play along well?

2) About my second question, I asked it wrong. The jumps happen while the heap is not full and I call sort + recompute segment boundaries. So it doesn't happen after memory is full, but after each boundary recompute there are jumps: error picture, imgur. Maybe it's because my replay memory doesn't serve transitions that were added between boundary computations and the network sees a lot of new errors? But then it should get back to normal pretty quickly..

3)

> According to the appendix the heap is rebalanced every 1m steps for the Atari domain.

Yes, I know, but @schaul wrote

> the heap can become rather imbalanced, check this and rebalance if necessary

Currently I'm rebalancing it every heap_size(800k) transitions.

4) So in your implementation it's FIFO order, and even if the oldest transition still has a large error, you replace it with new?

6) But what if I don't know when the end will be? I don't train it on Atari. How do you anneal it without knowing for how many steps you are going to anneal it?

7) I have one question that is not about experience replay: should I change the target network update from 10000 steps to 40000 steps because I train my network only every 4 steps?

Thanks again!

Edit: the error pic is averaged batch error taken every 500 steps.

Kaixhin commented 8 years ago

@avasilkov According to one of the authors, PAL does not work well with double Q-learning (see notes in code).

  1. -
  2. I wouldn't be surprised to see the jump in errors for the reason you mentioned, and might expect it to decrease, but RL is harder to debug. I'd suggest more finely recomputing boundaries and seeing the results. A better test would be to try this on a different RL domain.
  3. If there is a heuristic for automatically doing this I'd also be interested to know.
  4. Yes.
  5. -
  6. It won't be as optimal, but if your problem is unbounded then it would make sense to leave it fixed and try different values to see what works best.
  7. I don't think there is a neat connection between the two hyperparameters - it's affected by several aspects of learning and the domain. The DDQN paper switches to 30000 whilst still updating every 4 steps. On the Catch domain in this repo, 4 works (so effectively no target network needed).

avasilkov commented 8 years ago

Thank you, @Kaixhin ! I guess I will go and try FIFO order for removing. Maybe because of my way of removing elements I have stale transitions that only do harm. And thanks for the suggestion about not annealing beta parameter.

mw66 commented 8 years ago

how about https://docs.python.org/2/library/queue.html#Queue.PriorityQueue ? :-)

and how about Python? at least the array index there is 0-based :-)

I read this thread, and it seems you are struggling to get the low-level code to work correctly.

From a language perspective, Python is better than Lua; on top of that, Python has a much bigger community and (standard) libraries. So you do not need to worry about this kind of low-level code, or bugs that hinder your main work on machine learning.

As you mentioned, this is probably the only public attempt at prioritised experience replay. It's a pity to see that your progress is blocked by such low-level issues.

I have a suggestion: how about joining forces with others? e.g.

https://github.com/tambetm/simple_dqn

to "stand on the shoulders of giants" :-)

That Python code is well written, and easy to understand / extend.

I'd really like to see the relatively small ML community work together in a joint effort and make one good & complete implementation, instead of many fragmented ones.

Just my 2 cents.

Kaixhin commented 8 years ago

@mingwugmail I don't want to start a debate about the pros and cons of Torch vs Theano, but some points relevant to this repository:

  1. The priority queue that I wrote has extra methods that require access to the internals.
  2. @mryellow wrote some unit tests to check that this was working correctly.
  3. I am trying to replicate DeepMind's results, but each experiment takes me about 2 weeks on one GPU. One of the authors said that some games may not always give the same results as reported in the paper, which makes this even harder to confirm.

mw66 commented 8 years ago

I just want to provide more info:

simple_dqn is based on Neon, not Theano, and from its doc

http://www.nervanasys.com/deep-reinforcement-learning-with-neon/

""" I chose to base it on Neon, because it has

-- the fastest convolutional kernels,
-- all the required algorithms implemented (e.g. RMSProp),
-- a reasonable Python API. """

and convnets benchmarks

https://github.com/soumith/convnet-benchmarks

Mature library + fastest implementation all will speed up your experiment.

mryellow commented 8 years ago

I have a funny non-stationary domain which gives prioritised replay a bit of a workout. No real results to compare against but have noticed the behaviour change.

There are 3 actions, left, right, noop. During training without prioritised replay the noop will sit very close to random (33%), decreasing as epsilon progresses. However with prioritised replay it will jump all over the place, for example per 10k actions: 17.1%, 33.2%, 8.8%, 17.8%, 65.2% etc. During evaluation it settles down a bit and ends up just above 0% (dropout doubling seems to be pushing other actions up above noop reward).

Kaixhin commented 8 years ago

@mryellow Actually @lake4790k spotted an issue to do with indexing from the priority queue to the experience replay buffer - I've only had time to throw together a quick patch on a new per branch and set off one test, but best to look at that branch for now.

I've also now spotted another issue, which is that terminal states shouldn't be stored in the queue at all. In per I've given them a priority of 0, which I don't think should make too much of a difference, but I'll need to correct this eventually too.

Edit 2016-06-04: First test on the per branch with Frostbite looks to be successful (see below). A working version of prioritised experience replay will be merged in after refactor is finished.

[Image: scores]