JuliaStats / StatsBase.jl

Basic statistics for Julia

A faster algorithm for weighted sampling with replacement when k < n by reservoir sampling? #928

Closed: Tortar closed this issue 6 months ago.

Tortar commented 7 months ago

Some time ago I designed an algorithm (described in https://arxiv.org/abs/2403.20256) meant for sampling from data streams. It turns out to be quite a bit faster than the current algorithm in StatsBase in some cases, namely when the number of items in the sample is smaller than the number of items in the population. It surely needs some careful inspection, but here it is:

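```julia
using Random, StatsBase, Distributions

# Weighted sampling with replacement of `n` items from `a` (weights `ws`)
# in a single pass over the data, skipping items that cannot change the reservoir.
function weighted_reservoir_sample(rng, a, ws, n)
    m = min(length(a), n)
    # Initialise the reservoir by sampling with replacement from the first m items.
    view_w_f_n = @view ws[1:m]
    w_sum = sum(view_w_f_n)
    reservoir = sample(rng, (@view a[1:m]), Weights(view_w_f_n, w_sum), n)
    length(a) <= n && return reservoir
    # Running total weight at which the reservoir will next change.
    w_skip = skip(rng, w_sum, n)
    @inbounds for i in n+1:length(a)
        w_el = ws[i]
        w_sum += w_el
        if w_sum > w_skip
            # Item i replaces k slots, k ~ Binomial(n, p) conditioned on k >= 1:
            # q is uniform on [(1-p)^n, 1), i.e. above the CDF mass at k = 0.
            p = w_el/w_sum
            z = (1-p)^(n-3)
            q = rand(rng, Uniform(z*(1-p)*(1-p)*(1-p), 1.0))
            k = choose(n, p, q, z)
            # Partial Fisher-Yates: write a[i] into k distinct random slots.
            for j in 1:k
                r = rand(rng, j:n)
                reservoir[r] = a[i]
                reservoir[r], reservoir[j] = reservoir[j], reservoir[r]
            end
            w_skip = skip(rng, w_sum, n)
        end
    end
    # Randomise the slot order, since the swaps move newly inserted items to the front.
    return shuffle!(rng, reservoir)
end

# Threshold for the running weight: P(threshold >= W') = (w_sum/W')^m,
# the chance that none of the m slots is replaced before the total reaches W'.
function skip(rng, w_sum::AbstractFloat, m)
    q = rand(rng)^(1/m)
    return w_sum/q
end

# Inverse-CDF draw of k ~ Binomial(n, p) for a uniform q. The first terms of
# the CDF are unrolled using z = (1-p)^(n-3); otherwise fall back to `quantile`.
function choose(n, p, q, z)
    m = 1-p
    s = z
    z = s*m*m*(m + n*p)              # P(k <= 1)
    z > q && return 1
    z += n*p*(n-1)*p*s*m/2           # + P(k == 2)
    z > q && return 2
    z += n*p*(n-1)*p*(n-2)*p*s/6     # + P(k == 3)
    z > q && return 3
    return quantile(Binomial(n, p), q)
end
```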
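
Why `skip` and `choose` produce the right distribution comes down to two identities. Assuming the model implied by `p = w_el/w_sum` (each of the `n` slots is independently replaced by item i with probability `w_i/W_i`, where `W_i` is the running weight total), the reservoir survives item i untouched with probability `(W_{i-1}/W_i)^n`. Over a run of items these factors telescope to `(W/W')^n`, which is exactly the probability that `skip`'s threshold `W/U^(1/n)` is at least `W'`. When a change does happen, the number of replaced slots is Binomial(n, p) conditioned on k >= 1, which `choose` draws by inverting the CDF with its first terms unrolled. The sketch below checks both identities numerically with Base Julia only; `prob_untouched`, `unrolled_cdf` and `binom_cdf` are hypothetical helper names, not part of StatsBase.

```julia
# Probability that items with weights `ws`, arriving after an initial total
# weight `w0`, leave all `n` slots untouched when each slot is replaced
# independently with probability w_i / W_i. Also returns the final total W'.
function prob_untouched(w0, ws, n)
    total = w0
    prob = 1.0
    for w in ws
        total += w
        prob *= (1 - w / total)^n    # = (W_{i-1} / W_i)^n
    end
    return prob, total
end

# Cumulative Binomial(n, p) terms written exactly as in `choose`, s = (1-p)^(n-3).
function unrolled_cdf(n, p)
    m = 1 - p
    s = m^(n - 3)
    c1 = s * m * m * (m + n * p)                         # P(k <= 1)
    c2 = c1 + n * p * (n - 1) * p * s * m / 2            # P(k <= 2)
    c3 = c2 + n * p * (n - 1) * p * (n - 2) * p * s / 6  # P(k <= 3)
    return c1, c2, c3
end

# Reference CDF summed directly from the binomial pmf.
binom_cdf(n, p, k) = sum(binomial(n, j) * p^j * (1 - p)^(n - j) for j in 0:k)

n, w0 = 10, 3.0
prob, total = prob_untouched(w0, [0.5, 2.0, 0.25, 1.5], n)
isapprox(prob, (w0 / total)^n)   # telescoping product matches skip's (W/W')^n

p = 0.2
c1, c2, c3 = unrolled_cdf(n, p)
isapprox(c3, binom_cdf(n, p, 3))  # unrolled terms match the binomial CDF
```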

Benchmarking:

```julia
rng = Xoshiro(42);
a = collect(1:10^7);
# Weights: about 10% of items draw from [0, 10), the rest from [0, 1).
wv(el) = rand() < 0.1 ? 10 * rand() : rand()
ws = Weights(wv.(a));

# Warm-up calls so that compilation time is excluded from the timings.
weighted_reservoir_sample(rng, a, ws, 1);
weighted_reservoir_sample(rng, a, ws, 10^4);
sample(rng, a, ws, 1);
sample(rng, a, ws, 10^4);

for i in 0:7
    t1 = @elapsed weighted_reservoir_sample(rng, a, ws, 10^i);
    t2 = @elapsed sample(rng, a, ws, 10^i);
    # A ratio above 1 means the reservoir version is faster.
    println("sample with 10^$i items with population of 10^7 items: $(t2/t1)")
end
```

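This prints the speedup over the current implementation, i.e. the time taken by `sample` divided by the time taken by `weighted_reservoir_sample` (values above 1 mean the new method is faster):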

```
sample with 10^0 items with population of 10^7 items: 0.8358660101066833
sample with 10^1 items with population of 10^7 items: 5.248531783569411
sample with 10^2 items with population of 10^7 items: 19.3146914281279
sample with 10^3 items with population of 10^7 items: 17.139903421544233
sample with 10^4 items with population of 10^7 items: 10.72330908054339
sample with 10^5 items with population of 10^7 items: 3.2609968862521956
sample with 10^6 items with population of 10^7 items: 0.8949282382149918
sample with 10^7 items with population of 10^7 items: 0.9909354681929494
```

Actually, on the dev version (after #927) the gap is even larger for sample sizes from 10^2 to 10^6:

```
sample with 10^0 items with population of 10^7 items: 0.5206986178221531
sample with 10^1 items with population of 10^7 items: 4.426968640113709
sample with 10^2 items with population of 10^7 items: 33.29267938056488
sample with 10^3 items with population of 10^7 items: 31.532953274019665
sample with 10^4 items with population of 10^7 items: 23.297611639617777
sample with 10^5 items with population of 10^7 items: 6.736645208235632
sample with 10^6 items with population of 10^7 items: 1.2902524581091546
sample with 10^7 items with population of 10^7 items: 0.9597508690177733
```
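
Each ratio above comes from a single `@elapsed` call per sample size, so it can be affected by GC pauses and timer noise. A minimal sketch of a steadier measurement, assuming only Base Julia (BenchmarkTools.jl's `@belapsed` does this more thoroughly): run the function once to compile it, then keep the fastest of several timed runs. `min_elapsed` is a hypothetical helper name.

```julia
# Hypothetical helper: warm up `f` once, then return the fastest of `reps`
# timed runs (in seconds) to reduce noise from GC pauses and the OS.
function min_elapsed(f; reps::Integer = 5)
    f()                              # warm-up run, excludes compilation time
    best = Inf
    for _ in 1:reps
        best = min(best, @elapsed f())
    end
    return best
end
```

For example, `t1 = min_elapsed(() -> weighted_reservoir_sample(rng, a, ws, 10^i))` could stand in for the first `@elapsed` line of the loop, and likewise for `sample`.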

FWIW, this passes all my tests in https://github.com/JuliaDynamics/StreamSampling.jl, which also try to assess whether the sample is really random. What do you think of using this method in the cases where it is faster?
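
A frequency check of this kind can be sketched as follows (a minimal sketch, not the StreamSampling.jl test suite): draw many samples, count how often each item appears, and compare with the normalized weights `w / sum(w)`. To keep the example self-contained it uses a naive O(N·n) baseline in which every slot is an independent single-item weighted reservoir; this is a swapped-in reference method, not the algorithm proposed above, and `naive_weighted_reservoir` and `empirical_freqs` are hypothetical names.

```julia
using Random

# Naive baseline: each of the n slots is replaced by item i with probability
# w_i / W_i, so slot j ends up holding item i with probability w_i / W_N.
function naive_weighted_reservoir(rng, a, ws, n)
    reservoir = fill(a[1], n)
    w_sum = float(ws[1])
    for i in 2:length(a)
        w_sum += ws[i]
        p = ws[i] / w_sum
        for j in 1:n
            rand(rng) < p && (reservoir[j] = a[i])
        end
    end
    return reservoir
end

# Fraction of all drawn items equal to each element of `a`, over many trials.
function empirical_freqs(sampler, rng, a, ws, n, trials)
    counts = Dict(x => 0 for x in a)
    for _ in 1:trials
        for x in sampler(rng, a, ws, n)
            counts[x] += 1
        end
    end
    return [counts[x] / (trials * n) for x in a]
end

rng = Xoshiro(42)
a = [1, 2, 3, 4]
ws = [1.0, 2.0, 3.0, 4.0]
freqs = empirical_freqs(naive_weighted_reservoir, rng, a, ws, 2, 50_000)
expected = ws ./ sum(ws)             # [0.1, 0.2, 0.3, 0.4]
maximum(abs.(freqs .- expected))     # small when the sampler is correct
```

Since `weighted_reservoir_sample` takes the same `(rng, a, ws, n)` arguments, it could be passed as `sampler` in place of the baseline; with `n = 2 < length(a)` it goes through the skipping loop rather than the initial `sample` call.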

Tortar commented 6 months ago

It's probably necessary to publish it in a peer-reviewed journal before anything else.