pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

Use sortedness/statistics in comparisons #9796

Open jonashaag opened 1 year ago

jonashaag commented 1 year ago

Problem description

Inspired by some experimentation in #9794, I was wondering why we don't take sortedness and statistics into account in some of the comparison/is_in/... operations. As an example, we can make this optimisation:

# Unoptimized
val == series

# Optimized
if val < series.min(): all-False mask
elif val > series.max(): all-False mask
else: val == series

Example:

val = 1_000_000
s = pl.Series(range(val))

%timeit val == s
7.4 ms ± 94.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit pl.select(pl.repeat(False, val)) if (val > s.max() or val < s.min()) else (val == s)
1.04 ms ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
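
For concreteness, the same shortcut as a plain Python helper (a sketch only; fast_eq is a made-up name, not a Polars API, and it just wraps the pl.repeat trick above):

import polars as pl

def fast_eq(val, s: pl.Series) -> pl.Series:
    # Hypothetical helper (not a Polars API): skip the elementwise
    # comparison when `val` lies outside the series' [min, max] range.
    # Assumes a non-empty numeric series without nulls.
    if val < s.min() or val > s.max():
        return pl.repeat(False, len(s), eager=True)
    return val == s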
mcrumiller commented 1 year ago

Is min/max saved/pre-computed, or does this require a computation of min & max every time? If it's recomputed, I don't get how the optimized version is faster, as computing series.min() should take about the same time as val==series, no?

magarick commented 1 year ago

It looks like the point here is that if the series is sorted you can take the first and last values. If the value is outside that range, you know it's not present. Otherwise you can do a binary search. Other than having to be moderately careful about nulls, this seems straightforward enough.
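
A minimal pure-Python sketch of that sorted fast path (not Polars internals; assumes a list sorted ascending with no nulls):

from bisect import bisect_left

def is_in_sorted(val, xs: list) -> bool:
    # Range check first: outside [xs[0], xs[-1]] means definitely absent.
    if not xs or val < xs[0] or val > xs[-1]:
        return False
    # Otherwise an O(log n) binary search instead of a full scan.
    i = bisect_left(xs, val)
    return i < len(xs) and xs[i] == val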

jonashaag commented 1 year ago

It is still faster if the series is shuffled.

magarick commented 1 year ago

How? Does every series cache its min and max values? If not, you have to do a full pass over the series to get the min and max but if you find an element in the series you can stop early. What am I missing?

jonashaag commented 1 year ago

Does not seem to be cached

In [2]: s = pl.Series(range(1_000)).shuffle()

In [3]: %timeit s.min()
314 ns ± 2.68 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [4]: s = pl.Series(range(10_000)).shuffle()

In [5]: %timeit s.min()
1.68 µs ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [6]: s = pl.Series(range(100_000)).shuffle()

In [7]: %timeit s.min()
15.5 µs ± 12.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
jonashaag commented 1 year ago

Maybe it's because pl.repeat is very fast, so this example may be a special case:

In [15]: s = pl.Series(range(1_000_000)).shuffle()

In [16]: m = s.min()

In [18]: %timeit pl.select(pl.repeat(False, len(s)))
35.6 µs ± 266 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [19]: %timeit s < m
232 µs ± 601 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
magarick commented 1 year ago

> Does seem to be cached
>
> In [2]: s = pl.Series(range(1_000)).shuffle()
>
> In [3]: %timeit s.min()
> 314 ns ± 2.68 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
>
> In [4]: s = pl.Series(range(10_000)).shuffle()
>
> In [5]: %timeit s.min()
> 1.68 µs ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
>
> In [6]: s = pl.Series(range(100_000)).shuffle()
>
> In [7]: %timeit s.min()
> 15.5 µs ± 12.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

Did you mean "doesn't"? Those numbers look like they're increasing approximately linearly with series length (modulo overhead, of course).

And in your second example, I'm not sure what you're trying to show. Repeating False will be faster than doing all the comparisons, but it looks like you can't avoid the comparisons unless you've precomputed the min/max and the series is sorted.

jonashaag commented 1 year ago

Oops, I've corrected the comment. Here are new benchmarks that are less flawed:

In [31]: s = pl.Series(range(1_000_000))

In [33]: %timeit s.min(), s.max()
374 ns ± 1.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [32]: %timeit val == s
259 µs ± 1.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [34]: %timeit val < s
6.39 µs ± 31.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [35]: %timeit val > s
6.88 µs ± 20.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [36]: %timeit pl.select(pl.repeat(False, val)) if (val > s.max() or val < s.min()) else (val == s)
36.5 µs ± 481 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

What I meant to say is that .min() + .max() are almost free on the sorted series, and pl.select(pl.repeat(...)) is much faster than ==.

The numbers change substantially if the data is shuffled:

In [37]: s = pl.Series(range(1_000_000)).shuffle()

In [38]: %timeit s.min(), s.max()  # 1000x slower
310 µs ± 5.47 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [40]: %timeit val == s
232 µs ± 128 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [39]: %timeit val < s
232 µs ± 311 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [41]: %timeit pl.select(pl.repeat(False, val)) if (val > s.max() or val < s.min()) else (val == s)  # slower but still faster than < or ==
188 µs ± 467 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [42]: %timeit pl.select(pl.repeat(False, val)) if (val > s.min() or val < s.max()) else (val == s)  # changed order of > and <, still faster
189 µs ± 901 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
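
A rough accounting of where those ~188 µs go, from the timings above (our arithmetic, assuming min and max each cost about half of the combined ~310 µs pass):

# Shuffled case, 1e6 elements, val = 1_000_000:
#   s.max() alone:          ~155 µs  (one full pass; no cached stats)
#   pl.repeat(False, n):     ~36 µs  (allocate the all-False mask)
#   elementwise val == s:   ~232 µs  (full comparison scan)
# val > s.max() is True (max is 999_999), so `or` short-circuits before
# s.min() ever runs: ~155 + 36 ≈ 190 µs, vs ~232 µs for the plain ==.
# In [42] short-circuits on val > s.min() instead (min is 0), which is
# why it lands at a similar ~189 µs.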
magarick commented 1 year ago

It's because Polars knows a range is sorted when you construct the Series from it: with the sorted flag set, min/max just read the first and last values. If you replace a value, it will unset the flag even if the series is still sorted:

>>> from timeit import timeit
>>> import polars as pl
>>> s = pl.Series(range(1_000_000))
>>> r = s.shuffle()
>>> timeit('s.min()', globals = globals(), number = 10000)
0.007208314724266529
>>> timeit('s.min()', globals = globals(), number = 10000)
0.006978688761591911
>>> timeit('r.min()', globals = globals(), number = 10000)
2.141932046972215
>>> s[0] = -1
>>> timeit('s.min()', globals = globals(), number = 10000)
1.9460285976529121
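
The flag itself is visible from Python (assuming Series.flags, which recent Polars versions expose):

>>> s = pl.Series(range(1_000_000))
>>> s.flags
{'SORTED_ASC': True, 'SORTED_DESC': False}
>>> s[0] = -1
>>> s.flags  # unset, even though the data is in fact still sorted
{'SORTED_ASC': False, 'SORTED_DESC': False}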
jonashaag commented 1 year ago

That's obvious, but my min/max + op dance is still faster than just the op.
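
If you know the data is sorted but the flag got dropped, Series.set_sorted can reassert it; a sketch (note it only sets the flag and trusts the caller, so asserting it on unsorted data would give wrong results):

>>> s = pl.Series(range(1_000_000))
>>> s[0] = -1                # data still sorted, but the flag is gone
>>> s = s.set_sorted()       # reassert the ascending-sorted flag
>>> timeit('s.min()', globals = globals(), number = 10000)  # fast again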