Open LilithHafner opened 10 months ago
Interesting. Why not use this, at least for large vectors?
Regarding the performance of `quantile`, see also https://github.com/JuliaStats/Statistics.jl/pull/91.
Note that the better answer here would be a QuickSelect-based approach. Partial sorting does more work than necessary here.
Doesn't `partialsort!` use QuickSelect?
@nalimilan, yes. `partialsort!(v::AbstractVector, k::Integer)` uses QuickSelect in most cases on Julia 1.10.x.
On 1.11.0-rc1 it uses BracketedSort (a generalization of the alg I proposed in this PR).
However, I haven't closed this issue because another key optimization proposed here is that `median` need not make a copy. This is possible because `partialsort` with `alg=BracketedSort` can be (but has not yet been) optimized to not copy the entire array.
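To make the copy concrete, here is a minimal sketch of a median built on `partialsort!` applied to a copy of the input. The helper name `median_via_partialsort` is hypothetical and not part of Statistics.jl, and the sketch assumes a non-empty, 1-based vector with no `NaN` or `missing` values.

```julia
using Statistics  # for `middle`

# Hypothetical helper for illustration; not the Statistics.jl implementation.
# Assumes a non-empty, 1-based vector without NaN or missing values.
function median_via_partialsort(v::AbstractVector{<:Real})
    isempty(v) && throw(ArgumentError("median of an empty array is undefined"))
    w = copy(v)                      # the O(n) copy that this issue proposes to avoid
    n = length(w)
    k1 = (n + 1) ÷ 2                 # lower middle rank
    k2 = n ÷ 2 + 1                   # upper middle rank (equal to k1 when n is odd)
    r = partialsort!(w, k1:k2)       # view of the one or two middle order statistics
    return middle(first(r), last(r))
end
```

The input is left untouched because only the copy `w` is reordered; the cost of that is the O(n) allocation.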
The concept is to take a random sample to quickly find values that almost certainly (99% chance) bracket target value(s), then efficiently pass over the whole input, counting values that fall above/below the bracketed range and explicitly storing only those that fall within the target range. If the median does not fall within the target range, try again with a new random seed up to three times (99.9999% success rate if the randomness is good). If the median does fall within the selected subset, find the exact target values within the selected subset.
Here's a naive implementation that is 4x faster for large inputs and allocates O(n^(2/3)) memory instead of O(n) memory.
I think this is reasonably close to optimal for large inputs, but I paid no heed to optimizing the O(n^(2/3)) factors, so it is likely possible to optimize this to lower the crossover point where this becomes more efficient than the current median code.
This generalizes quite well to `quantiles(n, k)` for short `k`. It has a runtime of `O(n * k)` with a low constant factor. The calls to `partialsort!` can also be replaced with more efficient recursive calls to `quantile`.
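For illustration only, and not the implementation benchmarked in this thread, the sampling concept described above might be sketched as follows. The name `sampled_median`, the sample size `n^(2/3)`, the bracket half-width of `2 * sqrt(s)` sample ranks, and the small-input cutoff of 1000 are all assumptions chosen loosely for this sketch; it also assumes a 1-based vector with no `NaN` or `missing` values.

```julia
using Statistics  # for `median` (fallback) and `middle`

# Hypothetical sketch of the sample-then-bracket idea; constants are assumptions.
function sampled_median(v::AbstractVector{<:Real}; attempts::Int = 3)
    n = length(v)
    n < 1000 && return median(v)         # small inputs: fall back (arbitrary cutoff)
    k1 = (n + 1) ÷ 2                     # lower middle rank
    k2 = n ÷ 2 + 1                       # upper middle rank (equal to k1 when n is odd)
    s = ceil(Int, n^(2/3))               # sample size (assumed)
    d = ceil(Int, 2 * sqrt(s))           # bracket half-width in sample ranks (assumed)
    for _ in 1:attempts
        sample = sort!(rand(v, s))       # random sample with replacement
        lo = sample[max(1, s ÷ 2 - d)]
        hi = sample[min(s, s ÷ 2 + 1 + d)]
        below = 0                        # number of values strictly below the bracket
        kept = eltype(v)[]               # values inside the bracket [lo, hi]
        for x in v
            if x < lo
                below += 1
            elseif x <= hi
                push!(kept, x)
            end
        end
        # The target ranks lie in `kept` only if the bracket captured them.
        if below < k1 && k2 <= below + length(kept)
            r = partialsort!(kept, (k1 - below):(k2 - below))
            return middle(first(r), last(r))
        end
    end
    return median(v)                     # every attempt missed: use the standard path
end
```

Only the values that fall inside the bracket are stored, so the input itself is never copied or reordered on the sampled path. Because a missed bracket falls back to `median(v)`, the result does not depend on the random draws; only the running time does.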
Benchmarks
Runtimes measured in clock cycles per element (@ 3.49 GHz)
An input of length 10^9 OOMs.
Benchmark code
```julia
using BenchmarkTools, Statistics

# `my_median` is the implementation under test and must be defined beforehand.
println("length | median | my_median")
println("-------|--------|----------")
for i in 1:8
    n = 10^i
    print("10^", rpad(i, 2), " | ")
    x = rand(n)
    t0 = @belapsed median($x)
    t0 *= 3.49e9/n  # seconds -> clock cycles per element at 3.49 GHz
    print(rpad(round(t0, digits=2), 4, '0'), " | ")
    t1 = @belapsed my_median($x)
    t1 *= 3.49e9/n
    println(rpad(round(t1, digits=2), 4, '0'))
end
```

I also removed the `length(x) < 2^12` fast path to get accurate results for smaller inputs, and I replaced the `@assert` with `1 <= lo_i || return median(v)`.