pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

Unexpected error with Expr.list.sample with n=0 in some rows but not all #16232

Closed bobot closed 3 months ago

bobot commented 3 months ago


Reproducible example

import polars as pl

k=2
p=0

df = pl.DataFrame({ "a" : [p]*k+[1+p], "b": [[1]*p]*k+[range(1,p+2)]})
print(df)
print(df.select(pl.col("b").list.sample(n=pl.col("a"),seed=0)))

Log output

shape: (3, 2)
┌─────┬───────────┐
│ a   ┆ b         │
│ --- ┆ ---       │
│ i64 ┆ list[i64] │
╞═════╪═══════════╡
│ 0   ┆ []        │
│ 0   ┆ []        │
│ 1   ┆ [1]       │
└─────┴───────────┘
Traceback (most recent call last):
  File "test_expr_sample.py", line 14, in <module>
    print(df.select(pl.col("b").list.sample(n=pl.col("a"),seed=0)))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/polars/dataframe/frame.py", line 8137, in select
    return self.lazy().select(*exprs, **named_exprs).collect(_eager=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/polars/lazyframe/frame.py", line 1816, in collect
    return wrap_df(ldf.collect(callback))
                   ^^^^^^^^^^^^^^^^^^^^^
polars.exceptions.ShapeError: cannot take a larger sample than the total population when `with_replacement=false`

Issue description

Using Expr.list.sample with n=0 in some rows but not all can raise an unexpected error.

Expected behavior

shape: (3, 1)
┌───────────┐
│ b         │
│ ---       │
│ list[i64] │
╞═══════════╡
│ []        │
│ []        │
│ [1]       │
└───────────┘

It works for k=1, p=0, and seems to work for p>0, so sampling with n=0 in some rows but not all seems to be the heart of the problem.
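For comparison, the expected per-row semantics match Python's standard-library random.sample, where drawing zero elements is valid even from an empty population. A minimal pure-Python sketch of the intended behavior (stdlib only, not Polars code; the function name is mine):

```python
import random

def sample_lists(lists, ns, seed=0):
    # Per-row sampling without replacement. Drawing n=0 elements is
    # always well-defined, even when the row's list is empty.
    rng = random.Random(seed)
    return [rng.sample(values, n) for values, n in zip(lists, ns)]

print(sample_lists([[], [], [1]], [0, 0, 1]))  # → [[], [], [1]]
```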

Installed versions

```
--------Version info---------
Polars:               0.20.26
Index type:           UInt32
Platform:             Linux-6.5.0-28-generic-x86_64-with-glibc2.38
Python:               3.11.6 (main, Oct 8 2023, 05:06:43) [GCC 13.2.0]
----Optional dependencies----
adbc_driver_manager:
cloudpickle:
connectorx:
deltalake:
fastexcel:
fsspec:
gevent:
hvplot:
matplotlib:
nest_asyncio:         1.6.0
numpy:                1.26.4
openpyxl:
pandas:               2.2.2
pyarrow:              16.0.0
pydantic:             2.7.1
pyiceberg:
pyxlsb:
sqlalchemy:
torch:
xlsx2csv:
xlsxwriter:
```
bobot commented 3 months ago

Also it works if the row with a=1 (so n=1) is first.

KernelA commented 3 months ago

I faced the same problem. Polars version is 0.20.25.

itamarst commented 3 months ago

I'm taking a look at this.

itamarst commented 3 months ago

The implementation delegates to lst_sample_n (in polars_core::chunked_array::list::namespace). Here's the relevant snippet:

```rust
ca.try_zip_and_apply_amortized(n, |opt_s, opt_n| {
    match (opt_s, opt_n) {
        (Some(s), Some(n)) => s
            .as_ref()
            .sample_n(n as usize, with_replacement, shuffle, seed)
            .map(Some),
        _ => Ok(None),
    }
})
```
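In rough terms, try_zip_and_apply_amortized walks two columns in lockstep and applies a fallible closure per row, propagating nulls. A hypothetical pure-Python sketch of that shape (function name and signature are mine, not the actual API, and it ignores the buffer-amortization that matters later):

```python
def try_zip_and_apply(col_a, col_b, f):
    # Walk two columns in lockstep; a None (null) in either column
    # propagates as None, otherwise apply the fallible closure f.
    out = []
    for a, b in zip(col_a, col_b):
        if a is None or b is None:
            out.append(None)
        else:
            out.append(f(a, b))  # f may raise, aborting the whole apply
    return out

# Mirrors the match above: process rows where both values are present.
print(try_zip_and_apply([[1, 2], None], [1, None], lambda s, n: s[:n]))
# → [[1], None]
```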

Here's the sample_n() function, from polars_core::chunked_array::random:

```rust
impl Series {
    pub fn sample_n(
        &self,
        n: usize,
        with_replacement: bool,
        shuffle: bool,
        seed: Option<u64>,
    ) -> PolarsResult<Self> {
        ensure_shape(n, self.len(), with_replacement)?;
        if n == 0 {
            return Ok(self.clear());
        }
        let len = self.len();

        match with_replacement {
            true => {
                let idx = create_rand_index_with_replacement(n, len, seed);
                debug_assert_eq!(len, self.len());
                // SAFETY: we know that we never go out of bounds.
                unsafe { Ok(self.take_unchecked(&idx)) }
            },
            false => {
                let idx = create_rand_index_no_replacement(n, len, seed, shuffle);
                debug_assert_eq!(len, self.len());
                // SAFETY: we know that we never go out of bounds.
                unsafe { Ok(self.take_unchecked(&idx)) }
            },
        }
    }
}
```

Changing self.clear() to Series::new_empty(self.name(), self.dtype()) seems to solve the problem!

Given clear() takes &self, you wouldn't expect it to cause side effects. Except that try_zip_and_apply_amortized() uses UnstableSeries, so my current suspicion is that somewhere in there the invariants are wrong or not being applied. Which means this is plausibly a bigger problem; certainly this change doesn't seem like it should help. And yet...
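The suspected hazard can be illustrated with a toy Python analogue (hypothetical class, not Polars' actual data structures): an iterator amortizes allocations by rewriting one shared buffer for every row, and a clear()-style shortcut hands back a reference into that shared buffer instead of fresh, independent storage:

```python
class AmortizedView:
    # Toy analogue of UnstableSeries: a view whose backing buffer the
    # iterator reuses (rewrites in place) for every row.
    def __init__(self):
        self._buf = []

    def set(self, values):
        self._buf[:] = values  # reuse the same list object every row

    def as_ref(self):
        return self._buf       # hands out the *shared* buffer

view = AmortizedView()
results = []
for row in ([], [], [1]):
    view.set(row)
    if not row:
        # clear()-style shortcut: returns the shared buffer itself
        # instead of a fresh, independent empty list.
        results.append(view.as_ref())
    else:
        results.append(list(view.as_ref()))

# The "empty" results alias the reused buffer, which the last row
# overwrote, so the earlier rows are retroactively corrupted:
print(results)  # → [[1], [1], [1]], not [[], [], [1]]
```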

itamarst commented 3 months ago

@reswqa I think maybe you wrote the original code? Do you have any insight?

itamarst commented 3 months ago

Documentation for UnstableSeries says:

/// A wrapper type that should make it a bit more clear that we should not clone Series

But the API makes it trivial to clone (my_unstable_series.as_ref().clone()), and in fact that is what Series::clear() was doing in this particular edge case.

itamarst commented 3 months ago

Here is a demonstration of what I believe is unsound behavior of the UnstableSeries API:

```rust
fn undefined_behavior() {
    let mut series = Series::new("a", [1, 2]);
    let mut_ref = &mut series;
    let addr1 = std::ptr::addr_of!(*mut_ref);
    let mut unstable = UnstableSeries::new(mut_ref);

    // This won't compile:
    // let non_mut_ref = series.clone();

    // But we can now have both a mutable reference and a non-mutable
    // reference at the same time a different way, which is completely
    // unsound:
    let non_mut_ref = unstable.as_ref();
    let addr2 = std::ptr::addr_of!(*non_mut_ref);
    assert_eq!(addr1 as usize, addr2 as usize);

    let cloned = non_mut_ref.clone();

    // And now we can mutate cloned, at a distance:
    let series2 = Series::new("b", [3, 4]);
    unstable.swap(&mut series2.array_ref(0).clone());
    assert_eq!(cloned.sum::<usize>().unwrap(), 7);
}
```