narwhals-dev / narwhals

Lightweight and extensible compatibility layer between dataframe libraries!
https://narwhals-dev.github.io/narwhals/
MIT License
571 stars · 89 forks

[Enh]: Add `Series|Expr.replace` #1223

Open FBruzzesi opened 3 weeks ago

FBruzzesi commented 3 weeks ago

We would like to learn about your use case. For example, if this feature is needed to adopt Narwhals in an open source project, could you please enter the link to it below?

This would enable plotly to do custom sorting without filtering + concatenating:

- nw.concat(
-     [df.filter(nw.col(names) == value) for value in order],
-     how="vertical"
- )
+ (
+     df.with_columns(
+         __custom_sort_col=nw.col(names).replace({v: i for i, v in enumerate(order)})
+     )
+     .sort("__custom_sort_col")
+     .drop("__custom_sort_col")
+ )


Please describe the purpose of the new feature or describe the problem to solve.

Replicate polars Expr|Series.replace
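For context, Polars' `replace` maps listed old values to new values and leaves unmatched values untouched. A minimal pure-Python sketch of those semantics (illustrative only, not Narwhals code):

```python
def replace(values, mapping):
    # Polars-style replace: map values found in `mapping`,
    # pass anything else through unchanged.
    return [mapping.get(v, v) for v in values]

print(replace(["foo", "bar", "quox"], {"foo": 0, "bar": 1}))  # [0, 1, 'quox']
```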

Suggest a solution if possible.

No response

If you have tried alternatives, please describe them below.

No response

Additional information that may help us understand your needs.

No response

MarcoGorelli commented 3 weeks ago

thanks @FBruzzesi !

I think the Polars-native solution would be:

df.sort(pl.col(names).replace({x: i for i, x in enumerate(order)}))

without setting any temporary columns.

That would require both `DataFrame.sort` accepting expressions and `Expr.replace` - I'll take a look
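As a plain-Python illustration of what "sort by a replaced key" means (hypothetical row data, not the eventual API):

```python
order = ["foo", "quox", "bar"]
rank = {v: i for i, v in enumerate(order)}  # foo -> 0, quox -> 1, bar -> 2

rows = [("bar", 3), ("foo", 1), ("quox", 3), ("foo", 2)]
# Sort by the mapped rank of the key column; no temporary column is materialised
result = sorted(rows, key=lambda row: rank[row[0]])
print(result)  # [('foo', 1), ('foo', 2), ('quox', 3), ('bar', 3)]
```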

FBruzzesi commented 3 weeks ago

Thanks for the feedback! Yes, I would imagine that Polars can get around that by using expressions as the sort key. However, we currently don't support expressions in these contexts, and I have the impression that it may not be trivial to allow that within our current framework?

MarcoGorelli commented 3 weeks ago

yup, definitely not trivial...but I think you're right, replace (or rather, replace_all?) is what we need here - and definitely an improvement over concat

MarcoGorelli commented 3 weeks ago

what would you think about doing this using a join? Polars does a join under the hood to do this anyway

https://github.com/pola-rs/polars/blob/dbbd93fae922b94cb1b5e700472a75b5a975fc44/crates/polars-ops/src/series/ops/replace.rs#L203-L233

Example:

import narwhals.stable.v1 as nw
from narwhals.utils import generate_unique_token

@nw.narwhalify(eager_only=True)
def sort_by_custom_order(df, key, order):
    order_key = generate_unique_token(8, df.columns)
    order_df = nw.from_dict(
        {key: order, order_key: range(len(order))},
        native_namespace=nw.get_native_namespace(df),
    )
    return df.join(order_df, on=key, how="left").sort(order_key).drop(order_key)

which, in the Plotly context, you could call as

args["data_frame"] = sort_by_custom_order(df, names, order)

Demo:

import polars as pl
import pandas as pd
import narwhals.stable.v1 as nw
from narwhals.utils import generate_unique_token
import pyarrow as pa

data = {'a': ['foo', 'bar', 'foo', 'foo', 'bar', 'quox', 'foo'], 'b': [1, 3, 2, 6, 3, 3, 4]}
order = ['foo', 'quox', 'bar']

@nw.narwhalify(eager_only=True)
def sort_by_custom_order(df, key, order):
    order_key = generate_unique_token(8, df.columns)
    order_df = nw.from_dict({key: order, order_key: range(len(order))}, native_namespace=nw.get_native_namespace(df))
    return df.join(order_df, on=key, how='left').sort(order_key).drop(order_key)

print(sort_by_custom_order(pd.DataFrame(data), 'a', order))
print(sort_by_custom_order(pl.DataFrame(data), 'a', order))
print(sort_by_custom_order(pa.table(data), 'a', order))

outputs

      a  b
0   foo  1
2   foo  2
3   foo  6
6   foo  4
5  quox  3
1   bar  3
4   bar  3
shape: (7, 2)
┌──────┬─────┐
│ a    ┆ b   │
│ ---  ┆ --- │
│ str  ┆ i64 │
╞══════╪═════╡
│ foo  ┆ 1   │
│ foo  ┆ 2   │
│ foo  ┆ 6   │
│ foo  ┆ 4   │
│ quox ┆ 3   │
│ bar  ┆ 3   │
│ bar  ┆ 3   │
└──────┴─────┘
pyarrow.Table
a: string
b: int64
----
a: [["foo","foo","foo","foo","quox","bar","bar"]]
b: [[1,2,6,4,3,3,3]]
MarcoGorelli commented 3 weeks ago

🤔 nevermind, the join strategy seems to be slower than the concat strategy from the plotly pr 😳

MarcoGorelli commented 3 weeks ago

I made a branch in which I roughly implemented `replace` and `replace_strict`, and it looks like for both pandas and Polars, your concat solution is actually the fastest 🙌

import polars as pl
import pandas as pd
import narwhals.stable.v1 as nw
from narwhals.utils import generate_unique_token
import pyarrow as pa
import numpy as np
rng = np.random.default_rng(1)

pd.set_option('future.no_silent_downcasting', True)

data = {'a': ['foo', 'bar', 'foo', 'foo', 'bar', 'quox', 'foo'], 'b': [1, 3, 2, 6, 3, 3, 4]}
order = ['foo', 'quox', 'bar']

@nw.narwhalify(eager_only=True)
def func(df, key, order):
    order_key = generate_unique_token(8, df.columns)
    order_df = nw.from_dict({key: order, order_key: range(len(order))}, native_namespace=nw.get_native_namespace(df))
    return df.join(order_df, on=key, how='left').sort(order_key).drop(order_key)

@nw.narwhalify
def func2(df, key, order):
    return nw.concat(
            [df.filter(nw.col(key) == value) for value in order], how="vertical"
        )

@nw.narwhalify
def func3(df, key, order):
    token = generate_unique_token(8, df.columns)
    return df.with_columns(nw.col(key).replace_strict({x: i for i, x in enumerate(order)}, return_dtype=nw.UInt8).alias(token)).sort(token).drop(token)

print(func(pd.DataFrame(data), 'a', order))
print(func(pl.DataFrame(data), 'a', order))
print(func2(pd.DataFrame(data), 'a', order))
print(func2(pl.DataFrame(data), 'a', order))
print(func3(pd.DataFrame(data), 'a', order))
print(func3(pl.DataFrame(data), 'a', order))

bigdata = {'a': rng.integers(0, 3, size=100_000), 'b': rng.integers(0, 3, size=100_000), 'c': rng.integers(0, 3, size=100_000)}
order = [1, 0, 2]
In [26]: %timeit _ =  func(pd.DataFrame(bigdata), 'a', order)
8.33 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [27]: %timeit _ =  func2(pd.DataFrame(bigdata), 'a', order)
3.13 ms ± 334 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [28]: %timeit _ =  func3(pd.DataFrame(bigdata), 'a', order)
7.19 ms ± 1.07 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [29]: %timeit _ =  func(pl.DataFrame(bigdata), 'a', order)
12.8 ms ± 200 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [30]: %timeit _ =  func2(pl.DataFrame(bigdata), 'a', order)
1.79 ms ± 74.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [31]: %timeit _ =  func3(pl.DataFrame(bigdata), 'a', order)
12.7 ms ± 232 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

However, Polars doesn't rechunk when it concatenates, so the concat result may end up being slower in later operations. We could do more comprehensive timing tests of a full plotting function with all three methodologies to see which one is better in the full context


Interestingly enough, all of these approaches are faster than the original index-based solution in Plotly:

In [18]: %timeit _ = pd.DataFrame(bigdata).set_index('a').loc[order].reset_index()
19.8 ms ± 2.14 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

This is pleasantly surprising to me; I was expecting that we would be degrading performance here - nice!
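For clarity on the distinction used in the benchmark above: `replace` leaves unmatched values as-is, while `replace_strict` requires every value to be covered by the mapping (or a supplied default). A rough pure-Python sketch of the strict semantics (illustrative only, not the Narwhals implementation):

```python
_MISSING = object()  # sentinel so that None can be a legitimate default

def replace_strict(values, mapping, default=_MISSING):
    # Strict variant: an unmapped value raises unless a default is given.
    out = []
    for v in values:
        if v in mapping:
            out.append(mapping[v])
        elif default is not _MISSING:
            out.append(default)
        else:
            raise ValueError(f"value {v!r} not found in replacement mapping")
    return out

print(replace_strict(["foo", "bar"], {"foo": 0, "bar": 1}))  # [0, 1]
```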

FBruzzesi commented 3 weeks ago

Thanks Marco, that's definitely unexpected. Should we also consider using `group_by`/`partition_by` instead of consecutive filtering? (I won't really be able to take a look before Wed/Thu.)
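The partition idea could look roughly like this in plain Python: bucket the rows by key in a single pass, then emit the buckets in the custom order (a sketch with hypothetical names, not benchmarked):

```python
from collections import defaultdict

def sort_by_order_via_partition(rows, key_idx, order):
    # One pass to bucket rows by key, then concatenate buckets in the
    # requested order. Keys absent from `order` are dropped, matching
    # the filter + concat behaviour.
    buckets = defaultdict(list)
    for row in rows:
        buckets[row[key_idx]].append(row)
    return [row for key in order for row in buckets.get(key, [])]

rows = [("bar", 3), ("foo", 1), ("quox", 3), ("foo", 2)]
print(sort_by_order_via_partition(rows, 0, ["foo", "quox", "bar"]))
```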

FBruzzesi commented 3 days ago

`replace_strict` has been merged. Are we planning to add `replace` as well, or should we close this issue?