pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

NumPy ufunc implementation on Expr causes potential performance pitfall #15873

Open stinodego opened 2 months ago

stinodego commented 2 months ago

The problem

Consider the following code example:

import numpy as np
import polars as pl

# Good - native Polars expression
e1 = pl.col("a") + np.float64(1.0)
print(e1)  # [(col("a")) + (dyn float: 1.0)]

# Bad - Python UDF
e2 = np.float64(1.0) + pl.col("a")
print(e2)  # col("a").python_udf()

The first expression works as expected: it enters into Expr.__add__, converts the numpy float to a literal expression, and combines the two into a native Polars expression.

The second expression is problematic: it enters into the __add__ method of the numpy object, then calls the __array_ufunc__ implementation of the Expr object. This results in an elementwise map_batches call, which is much slower!
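To make the two dispatch paths concrete, here is a minimal sketch that uses a toy class instead of Expr. `Probe` is a hypothetical stand-in, not Polars code; it only reports which hook was called, assuming NumPy's usual ufunc override protocol.

```python
import numpy as np


class Probe:
    """Hypothetical stand-in for an expression type; not Polars code."""

    def __add__(self, other):
        return "Probe.__add__"

    def __radd__(self, other):
        return "Probe.__radd__"

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # NumPy calls this hook when a ufunc receives a Probe as an input.
        return f"Probe.__array_ufunc__({ufunc.__name__}, {method})"


probe = Probe()

# Probe is the left operand, so Python tries Probe.__add__ first.
left_result = probe + np.float64(1.0)

# np.float64.__add__ runs first and hands the work to np.add,
# which dispatches to Probe.__array_ufunc__ rather than Probe.__radd__.
right_result = np.float64(1.0) + probe

print(left_result)   # Probe.__add__
print(right_result)  # Probe.__array_ufunc__(add, __call__)
```

Note that `__radd__` is never reached in the second case, which is why the operand order changes the kind of expression that gets built.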

Potential solutions

If we remove the ufunc implementation on Expr completely, the problem is solved as the code will now go through Expr.__radd__ rather than the ufunc code.

If we want to keep the ufunc code, we must detect inputs that can be converted to native Polars expressions. There are quite a few valid ufuncs though: https://numpy.org/doc/stable/reference/ufuncs.html#available-ufuncs
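The first option can be illustrated with a toy class that has no `__array_ufunc__` at all. This is only a sketch: `NoUfunc` is hypothetical, and it assumes NumPy falls back to treating an unknown object as a generic Python object, which is the behaviour I'd expect from current NumPy releases.

```python
import numpy as np


class NoUfunc:
    """Hypothetical expression-like class with no __array_ufunc__ hook."""

    def __add__(self, other):
        return "NoUfunc.__add__"

    def __radd__(self, other):
        return "NoUfunc.__radd__"


# With no ufunc hook to dispatch to, NumPy's generic object fallback ends up
# applying Python's own "+" to the operands, which reaches the reflected
# operator on NoUfunc.
result = np.float64(1.0) + NoUfunc()
print(result)
```

Under that assumption the reflected operator is what ends up producing the result, which is the path the first option relies on.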

stinodego commented 2 months ago

@alexander-beedie I'm not sure what the best way to go is here - do you have an idea on how to solve this issue?

alexander-beedie commented 2 months ago

Hmm... The best way to start is probably to survey the ufuncs and see how many are missing native polars equivalents. I suspect the number is manageable, in which case we might want to add some of the more reasonable missing ones as native expressions. That would then give us a decent basis for disabling expression-based ufuncs (I don't really have a sense of how widely used they are) 🤔 In the meantime, we could add detection of the low-hanging fruit, such as add, if we're seeing this "in the wild".

deanm0000 commented 2 months ago

I don't think it'd be an elementwise map_batches. It's the equivalent of doing

df.select(pl.col('a').map_batches(lambda x: np.add(np.float64(1.0), x)))

It dispatches to numpy's addition instead of polars' but it's still vectorized, so I don't think there's a big performance hit there.

Consider

df=pl.DataFrame({'a':np.random.normal(0,1,100_000_000)})
%%timeit
df.select(pl.col("a") + np.float64(1.0))
### 151 ms ± 27.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
df.select(np.float64(1.0) + pl.col("a"))
### 152 ms ± 14.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

and just for good measure

assert (
    df
    .with_columns(
        b=pl.col("a") + np.float64(1.0),
        c=np.float64(1.0) + pl.col("a")
    )
    .filter(pl.col('b')!=pl.col('c'))
    .shape[0]
)==0

deanm0000 commented 2 months ago

For this use case I think you only want to take back the symbolic operators rather than worrying about all available np ufuncs. I think that can be done with the following additions to __array_ufunc__

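# Map NumPy ufuncs to the names of the equivalent Expr methods
func_mapping={
    np.add: "add",
    np.subtract: "sub",
    np.multiply: "mul",
    np.divide: "truediv",
    np.bitwise_or: "or_",
    np.less: "lt",
    # etc.
}
if ufunc in func_mapping:
    # Return a native expression instead of falling through to map_batches
    return getattr(pl.lit(inputs[0]), func_mapping[ufunc])(inputs[1])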

There ought to be more checks for robustness, but the idea is that when __array_ufunc__ sees one of these ufuncs, it returns the native polars operator instead of sending the call through map_batches, bypassing numpy entirely. I don't think there's a way to tell whether someone typed np.float64(1.1) + pl.col('a') or np.add(np.float64(1.1), pl.col('a')). If there is, I don't think we'd want to preempt the latter, but it seems unlikely enough not to worry about.
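Here's a rough, self-contained sketch of that idea with a couple of the extra checks, using a toy class so it runs without touching polars. `MiniExpr`, its `lit` helper and the "udf(...)" fallback are all hypothetical; the toy only records the expression as text and only handles numeric literals.

```python
import numpy as np


class MiniExpr:
    """Hypothetical toy expression that only records its structure as text."""

    def __init__(self, text):
        self.text = text

    @staticmethod
    def lit(value):
        # Wrap a plain number unless the value is already an expression.
        if isinstance(value, MiniExpr):
            return value
        return MiniExpr(repr(float(value)))

    def add(self, other):
        return MiniExpr(f"({self.text} + {MiniExpr.lit(other).text})")

    def sub(self, other):
        return MiniExpr(f"({self.text} - {MiniExpr.lit(other).text})")

    def mul(self, other):
        return MiniExpr(f"({self.text} * {MiniExpr.lit(other).text})")

    # ufuncs that have a native equivalent on this toy class
    _NATIVE = {np.add: "add", np.subtract: "sub", np.multiply: "mul"}

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        name = self._NATIVE.get(ufunc)
        # Only rewrite plain binary calls; reductions, out=, etc. fall through.
        if name is not None and method == "__call__" and len(inputs) == 2 and not kwargs:
            left, right = inputs
            return getattr(MiniExpr.lit(left), name)(right)
        # Stand-in for the map_batches fallback.
        return MiniExpr(f"udf({ufunc.__name__})")


print((np.float64(1.0) + MiniExpr("a")).text)  # (1.0 + a)
print(np.subtract(MiniExpr("a"), 3).text)      # (a - 3.0)
print(np.sin(MiniExpr("a")).text)              # udf(sin)
```

Wrapping the left input with `lit` only when it isn't already an expression keeps the operand order intact for both `scalar + expr` and `np.add(expr, scalar)`, which matters for non-commutative ufuncs like subtract.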

@stinodego I think the reason you said it was elementwise was because of this flag

https://github.com/pola-rs/polars/blob/c95e41f7b8987daf3df67f08329be41e0dea2f79/py-polars/polars/expr/expr.py#L333

but that flag is just so the optimizer knows it's safe to run in streaming. It doesn't signal to run the operation one element at a time. Am I guessing right?

Anyway, let me know what you guys think.