Some time ago I designed an algorithm (described in https://arxiv.org/abs/2403.20256) which I thought would be useful for sampling from data streams. It turns out to be faster than the current algorithm in StatsBase by quite a bit. This surely needs some careful inspection, but it is a lot faster in some cases when the number of items in the sample is smaller than the number of items in the population. This is it:
```julia
using Random, StatsBase, Distributions

function weighted_reservoir_sample(rng, a, ws, n)
    # initialize the reservoir with the first min(length(a), n) elements
    m = min(length(a), n)
    view_w_f_n = @view ws[1:m]
    w_sum = sum(view_w_f_n)
    reservoir = sample(rng, (@view a[1:m]), Weights(view_w_f_n, w_sum), n)
    length(a) <= n && return reservoir
    w_skip = skip(rng, w_sum, n)
    @inbounds for i in n+1:length(a)
        w_el = ws[i]
        w_sum += w_el
        if w_sum > w_skip
            p = w_el/w_sum
            z = (1-p)^(n-3)
            q = rand(rng, Uniform(z*(1-p)*(1-p)*(1-p), 1.0))
            k = choose(n, p, q, z)
            # place k copies of the new element in the reservoir
            for j in 1:k
                r = rand(rng, j:n)
                reservoir[r] = a[i]
                reservoir[r], reservoir[j] = reservoir[j], reservoir[r]
            end
            w_skip = skip(rng, w_sum, n)
        end
    end
    return shuffle!(rng, reservoir)
end

# draw the cumulative weight at which the next insertion will happen
function skip(rng, w_sum::AbstractFloat, m)
    q = rand(rng)^(1/m)
    return w_sum/q
end

# number of copies of the new element to place in the reservoir
function choose(n, p, q, z)
    m = 1-p
    s = z
    z = s*m*m*(m + n*p)
    z > q && return 1
    z += n*p*(n-1)*p*s*m/2
    z > q && return 2
    z += n*p*(n-1)*p*(n-2)*p*s/6
    z > q && return 3
    return quantile(Binomial(n, p), q)
end
```
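For intuition about the skip mechanism: for n = 1 it reduces to replacing the current element exactly when the cumulative weight crosses a threshold drawn as `w_sum / rand()`, which yields each item with probability proportional to its weight. A minimal self-contained sketch of just this n = 1 case (using only the standard library; `weighted_pick` is a hypothetical name, not part of the code above):

```julia
using Random

# Sketch of the skip idea for a size-1 weighted reservoir: instead of
# flipping a coin per element, draw the next cumulative weight at which
# the reservoir element will be replaced.
function weighted_pick(rng, a, ws)
    chosen = a[1]
    w_sum = float(ws[1])
    w_skip = w_sum / rand(rng)   # the n = 1 case of skip()
    for i in 2:length(a)
        w_sum += ws[i]
        if w_sum > w_skip
            chosen = a[i]
            w_skip = w_sum / rand(rng)
        end
    end
    return chosen
end

rng = Xoshiro(1)
counts = Dict(1 => 0, 2 => 0)
for _ in 1:100_000
    counts[weighted_pick(rng, [1, 2], [1.0, 3.0])] += 1
end
# item 2 carries 3/4 of the total weight, so it should be picked
# roughly three times as often as item 1
```

The replacement test `w_sum > w_skip` fires with probability `1 - w_sum_old/w_sum_new = w_i/W_i`, which is exactly the classic single-item weighted-reservoir update.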
Benchmarking

```julia
rng = Xoshiro(42);
a = collect(1:10^7);
wv(el) = rand() < 0.1 ? 10 * rand() : rand()
ws = Weights(wv.(a));

# warm up both methods before timing
weighted_reservoir_sample(rng, a, ws, 1);
weighted_reservoir_sample(rng, a, ws, 10^4);
sample(rng, a, ws, 1);
sample(rng, a, ws, 10^4);

for i in 0:7
    t1 = @elapsed weighted_reservoir_sample(rng, a, ws, 10^i)
    t2 = @elapsed sample(rng, a, ws, 10^i)
    println("sample with 10^$i items with population of 10^7 items: $(t2/t1)")
end
```

shows this relative performance improvement with respect to the current algorithm:
```
sample with 10^0 items with population of 10^7 items: 0.8358660101066833
sample with 10^1 items with population of 10^7 items: 5.248531783569411
sample with 10^2 items with population of 10^7 items: 19.3146914281279
sample with 10^3 items with population of 10^7 items: 17.139903421544233
sample with 10^4 items with population of 10^7 items: 10.72330908054339
sample with 10^5 items with population of 10^7 items: 3.2609968862521956
sample with 10^6 items with population of 10^7 items: 0.8949282382149918
sample with 10^7 items with population of 10^7 items: 0.9909354681929494
```
Actually, on the dev version after #927 it is even more pronounced:
```
sample with 10^0 items with population of 10^7 items: 0.5206986178221531
sample with 10^1 items with population of 10^7 items: 4.426968640113709
sample with 10^2 items with population of 10^7 items: 33.29267938056488
sample with 10^3 items with population of 10^7 items: 31.532953274019665
sample with 10^4 items with population of 10^7 items: 23.297611639617777
sample with 10^5 items with population of 10^7 items: 6.736645208235632
sample with 10^6 items with population of 10^7 items: 1.2902524581091546
sample with 10^7 items with population of 10^7 items: 0.9597508690177733
```
FWIW, this passes all my tests in https://github.com/JuliaDynamics/StreamSampling.jl, which also try to assess whether the sample is really random. What do you think about using this method in the cases where it is faster?
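The kind of randomness check meant here can be sketched in a few lines: draw many weighted samples, tally frequencies, and compare them against the expected `w_i / sum(w)` proportions with a crude Pearson chi-square statistic. The sketch below uses only the standard library; `draw_index` is a hypothetical stand-in for whatever sampler is under test, not the algorithm above:

```julia
using Random

# Hypothetical helper: draw one index with probability w[i]/sum(w)
# via inverse-CDF; stands in for the sampler being tested.
function draw_index(rng, w)
    t = rand(rng) * sum(w)
    acc = 0.0
    for (i, wi) in enumerate(w)
        acc += wi
        acc >= t && return i
    end
    return length(w)
end

rng = Xoshiro(7)
w = [1.0, 2.0, 3.0, 4.0]
trials = 200_000
counts = zeros(Int, length(w))
for _ in 1:trials
    counts[draw_index(rng, w)] += 1
end

# Pearson chi-square statistic against the expected proportions;
# with 3 degrees of freedom, values below ~11.3 pass at the 1% level
expected = trials .* w ./ sum(w)
chisq = sum((counts .- expected) .^ 2 ./ expected)
```

A real test suite would repeat this for many sample sizes and weight profiles, but the statistic above is the core of a frequency-based randomness check.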