pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

Add stratified sampling #2661

Open johanbog opened 2 years ago

johanbog commented 2 years ago

Describe your feature request

In pandas, you can do stratified sampling of a Series by using `.sample(weights=...)`, where the weights can be ndarray-like or a column in the DataFrame. I would like to do this in polars as well. For now, I'll just go via a pandas DataFrame for this step.

ritchie46 commented 2 years ago

If I understand stratified sampling correctly (I assume you sample by group size), you can do this once #2668 lands, with something like this:

import polars as pl

df = pl.DataFrame({
    "groups": [1] * 5 + [2] * 10,
    "values": range(15)
}).with_row_count("row_idx")

# sample row_idx and groups column
sampled = (df.groupby("groups")
    .agg([
        pl.col("row_idx").sample(0.5, with_replacement=False, seed=1)
    ]).explode("row_idx")
)

# with the sample row_idx we sample the other columns
(pl.concat([sampled, df.select(pl.all().exclude(["groups", "row_idx"])
                              .take(sampled["row_idx"]))], how="horizontal"))
shape: (7, 3)
┌────────┬─────────┬────────┐
│ groups ┆ row_idx ┆ values │
│ ---    ┆ ---     ┆ ---    │
│ i64    ┆ u32     ┆ i64    │
╞════════╪═════════╪════════╡
│ 2      ┆ 5       ┆ 5      │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2      ┆ 9       ┆ 9      │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2      ┆ 14      ┆ 14     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2      ┆ 7       ┆ 7      │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 2      ┆ 10      ┆ 10     │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 1      ┆ 1       ┆ 1      │
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
│ 1      ┆ 0       ┆ 0      │
└────────┴─────────┴────────┘

I don't think it's worth an API method, but we should deliver the components to be able to do this easily.

slonik-az commented 2 years ago

For me, polars' small API (compared to pandas') is one of its major features. There are many different ways of sampling, and stratified sampling is just one of them. API compositionality is a main strength of polars: it lets you build on top of a limited API core.

detrin commented 9 months ago

mentioned in https://www.youtube.com/watch?v=YUMhGp1ryUY

ryu1kn commented 5 months ago

I wanted to use different fractions/weights on different groups so that I can up-sample some groups and down-sample others. Previously I was using Pandas to do this, but maybe thanks to https://github.com/pola-rs/polars/pull/11943 which allows us to specify an Expr in sample's n / fraction, I could get a similar result with Polars (and a lot faster). I'm going to share it here in case someone finds it useful.

Suppose a weight is assigned to each data point:

import pandas as pd

df = pd.DataFrame({
    "groups": [1] * 5 + [2] * 10,
    "values": range(15),
    "weights": [2] * 5 + [1] * 10
})

#      groups  values  weights
#  0        1       0        2
#  1        1       1        2
#  2        1       2        2
#  3        1       3        2
#  ...
#  11       2      11        1
#  12       2      12        1
#  13       2      13        1
#  14       2      14        1

With pandas, we can point `sample` at the `weights` column to supply the weights:

df.sample(len(df), replace=True, weights="weights")

To do weighted sampling with Polars, we can pass an expression to `fraction` (or `n`):

import polars as pl

df_pl = pl.from_pandas(df).with_row_index("row_idx")

df_sampled = (df_pl
    # To calculate fraction later, convert weights to probabilities
    .with_columns(p=pl.col("weights") / pl.col("weights").sum())
    .group_by("groups")
    .agg([
        pl.col("row_idx").sample(
            # Here we give an expression to `fraction`.
            # In each group, all `p` values are the same, so I just pick the first one.
            fraction=pl.col("p").first() * len(df_pl),
            with_replacement=True, seed=1
        )
    ])
    .explode("row_idx")
    .drop_nulls("row_idx")
)

(pl.concat([df_sampled, df_pl.select(pl.all().exclude(["groups", "row_idx"])
                                     .gather(df_sampled["row_idx"]))], how="horizontal"))
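
To see what the `fraction` expression works out to for the example data, the arithmetic can be reproduced in plain Python. This is only a sketch: it assumes the per-group sample size is the group length times `fraction`, truncated toward zero, which is an assumption about how `sample` rounds rather than something stated in this thread.

```python
# Example data from above: 5 rows with weight 2, 10 rows with weight 1.
groups = [1] * 5 + [2] * 10
weights = [2] * 5 + [1] * 10

total_rows = len(groups)
total_weight = sum(weights)

expected_counts = {}
for g in sorted(set(groups)):
    group_size = groups.count(g)
    # Every row in a group carries the same weight, so take the first one.
    w = weights[groups.index(g)]
    p = w / total_weight        # per-row probability, the `p` column
    fraction = p * total_rows   # the value passed as `fraction`
    # ASSUMPTION: sample size = group_size * fraction, truncated toward zero.
    expected_counts[g] = int(group_size * fraction)

print(expected_counts)  # {1: 7, 2: 7}
```

Each group holds half of the total weight (10 out of 20), so each gets about half of the draws: the smaller group is up-sampled (fraction 1.5) and the larger one is down-sampled (fraction 0.75). Under the truncation assumption the total is 14 rather than `len(df)`, so the result is close to, but not identical with, what the pandas call produces.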