rust-itertools / itertools

Extra iterator adaptors, iterator methods, free functions, and macros.
https://docs.rs/itertools/
Apache License 2.0

Implement `Combinations::nth` #914

Closed kinto-b closed 3 months ago

kinto-b commented 3 months ago

Hi there, this PR addresses #301 in the same way as #329. @Philippe-Cholet, I was indeed wrong that this wouldn't give a substantial performance gain:

cargo bench --bench specializations "combinations[1234]/nth"

combinations1/nth       time:   [4.8406 ms 4.8562 ms 4.8734 ms]
                        change: [-53.963% -53.706% -53.442%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 15 outliers among 100 measurements (15.00%)
  6 (6.00%) high mild
  9 (9.00%) high severe

combinations2/nth       time:   [4.6338 ms 4.6532 ms 4.6762 ms]
                        change: [-55.198% -54.769% -54.327%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 11 outliers among 100 measurements (11.00%)
  1 (1.00%) high mild
  10 (10.00%) high severe

combinations3/nth       time:   [4.7743 ms 4.8138 ms 4.8600 ms]
                        change: [-53.934% -53.322% -52.707%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 14 outliers among 100 measurements (14.00%)
  8 (8.00%) high mild
  6 (6.00%) high severe

combinations4/nth       time:   [5.1255 ms 5.1524 ms 5.1841 ms]
                        change: [-53.459% -52.613% -51.852%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 18 outliers among 100 measurements (18.00%)
  6 (6.00%) high mild
  12 (12.00%) high severe

As discussed in the comments of #301, an even more efficient solution, which would remain performant even in the large n/k regime, would be to use the combinatorial number system. For record-keeping's sake, here's the function I sketched out before I realised that itertools uses a 'reversed' lexicographic ordering.

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        self.state += n;

        // https://en.wikipedia.org/wiki/Combinatorial_number_system#Finding_the_k-combination_for_a_given_number
        let mut remainder = self.state;
        let mut delta = 0;
        for i in (0..self.k()).rev() {
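            // Find the smallest m with C(m, i + 1) > remainder; indices[i] is
            // then m - 1, the largest value whose binomial still fits, and
            // `delta` (the previous binomial) is subtracted from the remainder.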
            let mut m = i;
            while let Some(d) = checked_binomial(m, i + 1) {
                if d > remainder {
                    self.indices[i] = m - 1;
                    remainder -= delta;
                    break;
                }

                m += 1;
                delta = d;
            }
        }

        // We may need to pad out the pool
        if let Some(x) = self.indices.last() {
            self.pool.prefill(x + 1);
            if *x >= self.n() {
                return None;
            }
        }

        // We should only return an empty list once
        if self.indices.is_empty() && self.state > 0 {
            return None;
        }

        Some(self.indices.iter().map(|i| self.pool[*i].clone()).collect())
    }

It's possible to use this same idea with the lexicographic ordering used by itertools, but we'd need to know the number of elements in the underlying iterator. Suppose we do, and we label this N. Then the function above would just need to be modified so that (see the sketch below):

  1. We initialise remainder = checked_binomial(N, self.k()) - n - 1
  2. We 'invert' the indices before returning the combination, so that i = N-1-i, and then reverse their order to ensure that the indices remain strictly increasing.
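
Here is a minimal standalone sketch of that adaptation (illustrative only, not the itertools implementation; the name nth_indices, the big_n/k parameters, and the plain binomial helper are assumptions):

    // Sketch: compute the index combination at 0-based position `n` in
    // itertools' lexicographic order, for a pool of `big_n` elements and
    // combinations of length `k`. Returns None when `n` is out of range.
    fn nth_indices(n: usize, big_n: usize, k: usize) -> Option<Vec<usize>> {
        // Plain binomial coefficient; a real implementation would use
        // checked arithmetic (e.g. `checked_binomial`) to avoid overflow.
        fn binomial(n: usize, k: usize) -> usize {
            if k > n {
                return 0;
            }
            (0..k).fold(1, |acc, i| acc * (n - i) / (i + 1))
        }

        let total = binomial(big_n, k);
        if n >= total {
            return None;
        }
        // Step 1: decode the 'mirrored' position with the combinatorial
        // number system, as in the sketch above.
        let mut remainder = total - n - 1;
        let mut indices = vec![0; k];
        for i in (0..k).rev() {
            // Largest c with C(c, i + 1) <= remainder.
            let mut c = i;
            while binomial(c + 1, i + 1) <= remainder {
                c += 1;
            }
            indices[i] = c;
            remainder -= binomial(c, i + 1);
        }
        // Step 2: invert each index (i -> N - 1 - i) and reverse the order,
        // so the indices are strictly increasing and match itertools' order.
        let mut result: Vec<usize> = indices.iter().map(|&i| big_n - 1 - i).collect();
        result.reverse();
        Some(result)
    }

For instance, nth_indices(3, 4, 2) should give Some(vec![1, 2]), matching the fourth combination, [1, 2], produced by (0..4).combinations(2).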
codecov[bot] commented 3 months ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 94.45%. Comparing base (6814180) to head (2ffbd31). Report is 48 commits behind head on master.

:exclamation: Current head 2ffbd31 differs from pull request most recent head 2c5a2ba. Consider uploading reports for the commit 2c5a2ba to get more accurate results

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #914      +/-   ##
==========================================
+ Coverage   94.38%   94.45%   +0.06%
==========================================
  Files          48       48
  Lines        6665     6870     +205
==========================================
+ Hits         6291     6489     +198
- Misses        374      381       +7
```


kinto-b commented 3 months ago

Have made those two patches, thanks for reviewing :)

cargo bench --bench specializations "combinations[1234]/nth"

combinations1/nth       time:   [3.9183 ms 3.9953 ms 4.0795 ms]
                        change: [-19.349% -17.726% -16.196%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 9 outliers among 100 measurements (9.00%)
  8 (8.00%) high mild
  1 (1.00%) high severe

combinations2/nth       time:   [3.5200 ms 3.5950 ms 3.6761 ms]
                        change: [-24.364% -22.741% -21.022%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 17 outliers among 100 measurements (17.00%)
  10 (10.00%) high mild
  7 (7.00%) high severe

combinations3/nth       time:   [3.5204 ms 3.5581 ms 3.5994 ms]
                        change: [-27.161% -26.085% -24.983%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 4 outliers among 100 measurements (4.00%)
  2 (2.00%) high mild
  2 (2.00%) high severe

combinations4/nth       time:   [3.7733 ms 3.8175 ms 3.8691 ms]
                        change: [-26.955% -25.909% -24.814%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 6 outliers among 100 measurements (6.00%)
  2 (2.00%) high mild
  4 (4.00%) high severe
kinto-b commented 3 months ago

You can squash them all together at merge time:

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/about-pull-request-merges#squash-and-merge-your-commits

Here are the benchmarks (comparing to master):

combinations1/nth       time:   [3.9468 ms 3.9861 ms 4.0303 ms]
                        change: [-65.786% -65.264% -64.765%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 8 outliers among 100 measurements (8.00%)
  3 (3.00%) high mild
  5 (5.00%) high severe

combinations2/nth       time:   [3.8809 ms 3.9623 ms 4.0542 ms]
                        change: [-65.562% -64.704% -63.750%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 14 outliers among 100 measurements (14.00%)
  3 (3.00%) high mild
  11 (11.00%) high severe

combinations3/nth       time:   [4.0474 ms 4.1469 ms 4.2530 ms]
                        change: [-65.039% -64.163% -63.323%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 3 outliers among 100 measurements (3.00%)
  3 (3.00%) high mild

combinations4/nth       time:   [4.0968 ms 4.1343 ms 4.1770 ms]
                        change: [-65.459% -65.056% -64.615%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 8 outliers among 100 measurements (8.00%)
  2 (2.00%) high mild
  6 (6.00%) high severe
kinto-b commented 3 months ago

Also, thanks for reviewing these PRs. If there's anything else you think I might be able to work on, let me know :)

Philippe-Cholet commented 3 months ago

I think the "Squash and merge" option is not enabled (and I'm not the owner), so I'll merge as is.

Thanks for the benchmarks. -65% is nice, right? The benchmark repeatedly calls .nth(0) ... .nth(9), so it does not give much information about each individual nth(n), though. nth(n) used to perform n+1 heap allocations and now performs only one, thanks! I think it's roughly n+1 times faster.
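
For reference, a minimal standalone check of what the specialization buys from the caller's side (illustrative only; the concrete combination value assumes itertools' usual lexicographic-by-index order):

    use itertools::Itertools;

    fn main() {
        // `nth(4)` must agree with skipping four combinations and taking the
        // next one; per the numbers above, the specialized `nth` allocates a
        // single Vec instead of five.
        let via_skip = (0..10).combinations(3).skip(4).next();
        let via_nth = (0..10).combinations(3).nth(4);
        assert_eq!(via_skip, via_nth);
        assert_eq!(via_nth, Some(vec![0, 1, 6]));
    }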

About making nth even more performant with the "combinatorial number system": it would require storing a state field (and therefore updating it elsewhere too) or using my remaining_for function each time nth is called. I guess it would be fantastic for .nth(/*big*/n). I could imagine something like

    #[inline(never)]
    fn nth_perf(&mut self, n: usize) ... { ... }

    fn nth(&mut self, n: usize) ... {
        if n == 0 { ... }
        if n >= ARBITRARY_VALUE { return self.nth_perf(n); }
        ...
    }

but I wonder if it's worth the trouble. With Vec items, our combinations are more convenient than truly performant. It's nice to have a specialized nth, but for more critical performance, an entirely different implementation could be better. Without people asking for it, I tend to think we should not bother.

> Also, thanks for reviewing these PRs. If there's anything else you think I might be able to work on, let me know :)

Well, I'll surely mention you soon.

EDIT: @kinto-b Do you want to make a helper function that does indices.iter().map(index then clone).collect() and reduce code duplication? There are also other fold specializations, and some other nth, count, and last specializations too.
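
A minimal sketch of what such a helper could look like (the name and the slice-like pool access are assumptions for illustration, not the actual itertools internals):

    // Hypothetical helper: turn the stored index positions into an owned combination.
    fn indices_to_combination<T: Clone>(pool: &[T], indices: &[usize]) -> Vec<T> {
        indices.iter().map(|&i| pool[i].clone()).collect()
    }

The various specializations could then call this instead of repeating the map-and-collect expression.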

kinto-b commented 3 months ago

@Philippe-Cholet Yep, I'll add that helper and have a browse of the other outstanding fold specializations :) Feel free to @ me on anything.