JuliaReinforcementLearning / ReinforcementLearningTrajectories.jl

A generalized experience replay buffer for reinforcement learning
MIT License

SumTree sampling errors #59

Closed. CasBex closed this issue 1 year ago.

CasBex commented 1 year ago

TL;DR

The sampling function for CircularPrioritizedTraces is broken: it sometimes yields zero-priority samples. This is caused by rounding errors when broadcasting multiplication over a SumTree. It is better to broadcast over the leaves of the SumTree and use the result to construct a new SumTree. Proof is in the final section of this issue.

Background

I opened an issue in ReinforcementLearning.jl earlier today about the Prioritized DQN example not working properly and managed to trace it to the SumTree implementation in this package.

When randomly sampling from a SumTree{Float32}, samples with priority 0f0 can sometimes appear, which should be impossible. I have had difficulty coming up with a minimal example that demonstrates this, but it can be reproduced reliably by running the JuliaRL_PrioritizedDQN_CartPole experiment from ReinforcementLearningExperiments.jl, which then errors with the message displayed in the issue linked above.
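For reference, the experiment can be run with something like the following (assuming the experiment-string macro re-exported by ReinforcementLearningExperiments; the exact invocation may differ between versions):

using ReinforcementLearningExperiments
run(E`JuliaRL_PrioritizedDQN_CartPole`)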

Random samples with zero priority

I managed to extract some data by modifying get(t::SumTree, v) in sum_tree.jl as follows:


function Base.get(t::SumTree, v)
    parent_ind = 1
    leaf_ind = parent_ind
    vs = Vector{typeof(v)}()   # debug: v at each visited node
    visited = Int[]            # debug: indices of the visited nodes
    while true
        push!(visited, parent_ind)
        push!(vs, v)
        left_child_ind = parent_ind * 2
        right_child_ind = left_child_ind + 1
        if left_child_ind > length(t.tree)
            leaf_ind = parent_ind
            break
        else
            if v ≤ t.tree[left_child_ind]
                parent_ind = left_child_ind
            else
                v -= t.tree[left_child_ind]
                parent_ind = right_child_ind
            end
        end
    end
    if leaf_ind <= t.nparents
        leaf_ind += t.capacity
    end
    p = t.tree[leaf_ind]
    if p == 0f0  # debug: a zero-priority leaf should be unreachable
        @show vs, visited
        @show t.tree
        @assert false
    end
    ind = leaf_ind - t.nparents
    real_ind = ind >= t.first ? ind - t.first + 1 : ind + t.capacity - t.first + 1
    real_ind, p
end

This gives the following output:

julia> hcat(vs, t.tree[visited])
11×2 Matrix{Float32}:
 9.75795      200.816
 9.75795       90.1445
 9.75795       36.01
 9.75795       18.6549
 9.75795        9.75824
 6.88227       6.88182 # Notice that here v > t.tree[parent_ind], which should never be possible
 3.89462        3.89439
 3.17859        3.17821
 2.09281        2.09244
 1.52258        1.52222
 0.000346422    0.0 # zero-priority node: this should never be sampled

Rounding errors

In a sum tree, the essential invariant is parent == left + right. For floats this is obviously subject to rounding errors, but it should hold approximately.
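Even a single Float32 add-then-subtract round trip need not return the original value:

julia> (0.1f0 + 0.2f0) - 0.2f0 == 0.1f0
false

The descent in get performs exactly this kind of subtraction (v -= t.tree[left_child_ind]) at every level, so the error compounds along the path from the root to a leaf.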

To multiply a SumTree a with a vector b, two methods are possible:

  1. multiply directly via broadcast: a.*=b
  2. make a copy of the leaves; multiply the leaves with b; reconstruct a new SumTree from the multiplied leaves

In the code below these methods are copy_multiply and multiply_copy, respectively. The code shows that method 1 introduces larger invariant violations due to numerical rounding than method 2, and it is this rounding that causes the zero-priority sampling above.

using StatsBase
using ReinforcementLearningTrajectories
using StableRNGs

# residual parent - (left + right) for every internal node;
# under exact arithmetic every entry would be zero
function error_tree(tree)
    errors = Vector{eltype(tree)}()
    for i in 1:length(tree)
        l = 2 * i
        r = l + 1
        if r > length(tree)
            break
        end
        push!(errors, tree[i] - tree[l] - tree[r])
    end
    return errors
end

metric(x) = sqrt(sum(abs2, x)) # ℓ2 norm of the invariant violations

# method 1: broadcast the multiplication directly over the tree
function copy_multiply(stree, m)
    new_tree = deepcopy(stree)
    new_tree .*= m
    return new_tree
end

# method 2: multiply only the leaves, then rebuild a fresh tree from them
function multiply_copy(stree, m)
    tree = stree.tree
    leaves = tree[cld(length(tree), 2):end] # the leaves occupy the back half of the flat array
    new_tree = SumTree(eltype(stree), stree.capacity)
    append!(new_tree, leaves .* m)
    return new_tree
end

rng = StableRNG(123)
n = 1024

a = SumTree(n)
append!(a, rand(rng, eltype(a), n))

b = rand(rng, Bool, n) # zeroes out roughly half of the priorities

c = copy_multiply(a, b)
d = multiply_copy(a, b)

@show metric(error_tree(c.tree)), metric(error_tree(d.tree))

Output (note that the second error is on the order of eps(Float32)):

(metric(error_tree(c.tree)), metric(error_tree(d.tree))) = (0.00038676208f0, 7.38864f-5)
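The gap is plausibly because the in-place broadcast updates leaves one at a time and propagates each rounded delta through the stored ancestor sums, whereas rebuilding recomputes every internal node fresh from the scaled leaves. A minimal sketch of such a bottom-up rebuild on the flat layout used above (illustrative only, not the package's code):

# Recompute every internal node from its children, bottom-up.
# Assumes the binary-heap layout used above: node i has children 2i and 2i + 1.
function rebuild!(tree::Vector{T}) where {T}
    for i in div(length(tree), 2):-1:1
        l, r = 2i, 2i + 1
        tree[i] = tree[l] + (r <= length(tree) ? tree[r] : zero(T))
    end
    return tree
end

After such a rebuild every node is within one rounding error of the sum of its children, which is what multiply_copy achieves by constructing a fresh SumTree.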
findmyway commented 1 year ago

👍 A very comprehensive analysis!

6.88227 6.88182 # Notice here that v > t.tree[parent_ind] which should never be possible

Can we simply add a bound check here to force the value after subtraction to be within the range of the right leaf?
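Perhaps along these lines in the descent loop of Base.get (a sketch of the idea only, not a tested fix):

else
    # rounding in the subtraction can push v past the total of the
    # right subtree, so cap it there (illustrative clamp only)
    v = min(v - t.tree[left_child_ind], t.tree[right_child_ind])
    parent_ind = right_child_ind
end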

CasBex commented 1 year ago

I would have thought so as well, but 72e3791442d6aacc02884ea7719e61ba67ae36ce still gave errors for some random seeds. We can test both approaches after I've written a test for #60

CasBex commented 1 year ago

I have made a new branch for the fix in 72e3791 and cannot seem to get the tests to fail. However, when running the JuliaRL_PrioritizedDQN_CartPole experiment with seed 93 I do get an error, so I guess it is not completely OK.

CasBex commented 1 year ago

https://github.com/CasBex/ReinforcementLearningTrajectories.jl/tree/sumtree_alternative

findmyway commented 1 year ago

I think I got one very extreme case:

julia> t = SumTree(4)
0-element SumTree{Float32}

julia> t.first = 2
2

julia> push!(t, 0.1)
0.1

julia> push!(t, 0.1)
0.1

julia> push!(t, 0.1)
0.1

julia> t
3-element SumTree{Float32}:
 0.1
 0.1
 0.1

julia> t.tree
7-element Vector{Float32}:
 0.3
 0.1
 0.2
 0.0
 0.1
 0.1
 0.1

julia> get(t, 0)
┌ Info: debug
└   parent_ind = 1
┌ Info: debug
└   parent_ind = 2
┌ Info: debug
└   parent_ind = 4
(4, 0.0f0)

The root cause, I guess, is that we only do a very loose check here:

        if left_child_ind > length(t.tree)
            leaf_ind = parent_ind
            break

We should also make sure left_child_ind (or right_child_ind) is actually a valid index at that point. But I agree your post-check in #60 is also fine.

Also, in a very extreme case, rand will return exactly 0 (I was bitten by this before), which is not what we want. Better to add a check for it while sampling, or simply apply the following change:

-             if v ≤ t.tree[left_child_ind]
+             if v < t.tree[left_child_ind]

And we'd better add a bound check in the setindex! method to ensure values after modification are never negative.
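A rough sketch of both guards (illustrative names only; assumes t.tree[1] holds the total priority, as in the dumps above):

# 1. Re-draw until the sample is strictly positive, so that get(t, v)
#    never receives v == 0 (sample_nonzero is a hypothetical helper):
function sample_nonzero(rng, t::SumTree)
    v = rand(rng, eltype(t)) * t.tree[1]
    while iszero(v)
        v = rand(rng, eltype(t)) * t.tree[1]
    end
    return get(t, v)
end

# 2. Reject negative priorities up front, e.g. at the top of setindex!:
#    p >= 0 || throw(ArgumentError("priority must be non-negative, got $p"))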

CasBex commented 1 year ago

-             if v ≤ t.tree[left_child_ind]
+             if v < t.tree[left_child_ind]

I get how this could fix your example, but consider that in practice we also have actual items with weight 0; the search may just as well land on the right child, and that child may be zero. Consider the following:


julia> t = SumTree(4)
0-element SumTree{Float32}

julia> append!(t, [0.1, 0.1, 0.1, 0.1])

julia> t .*= Bool[1, 0, 1, 1]
4-element SumTree{Float32}:
 0.1
 0.0
 0.1
 0.1

julia> t.tree
7-element Vector{Float32}:
 0.3
 0.1
 0.2
 0.1
 0.0
 0.1
 0.1

julia> get(t, 0)
(1, 0.1f0)

julia> get(t, 0.1)
(1, 0.1f0)

julia> get(t, 0.2)
(3, 0.1f0)

Given the proposed change, the get(t, 0.1) call would fall on the zero-priority node.

CasBex commented 1 year ago

Closed by #60