pola-rs / pyo3-polars

Plugins/extension for Polars
MIT License

Maintain input series names, in rust, when a plugin is called within .over() context #79

Closed azmyrajab closed 3 months ago

azmyrajab commented 4 months ago

Hi - thanks a lot for making it easy to write nice polars plugins!

In my plugin I produce a Polars struct Series whose field names are based on the names of the Series passed in as inputs: &[Series] (the context is naming least-squares coefficients and returning a struct Series after some manipulation of the inputs).

This seemingly works well when called in a normal context, but when the plugin extension expression is chained with .over() the input series appear to have empty names ("").

Here is a simplified dummy example:

fn output_struct_dtype(input_fields: &[Field]) -> PolarsResult<Field> {
    for field in input_fields {
        println!("field_name={:?}", field.name());
    }
    Ok(Field::new(
        "coefficients",
        DataType::Struct(input_fields.to_vec()),
    ))
}

#[polars_expr(output_type_func=output_struct_dtype)]
fn inputs_to_struct(inputs: &[Series]) -> PolarsResult<Series> {
    for input in inputs {
        println!("series_name={:?}", input.name());
    }
    // Create DataFrame from the vector of Series
    let df = DataFrame::new(inputs.to_vec()).unwrap();
    // Convert DataFrame to a Series of struct dtype
    Ok(df.into_struct("coefficients").into_series().with_name("coefficients"))
}

Now on the Python side, let's say we have:

from pathlib import Path

import polars as pl
from polars.plugins import register_plugin_function

def convert_series_to_struct(*inputs: pl.Expr) -> pl.Expr:
    return register_plugin_function(
        plugin_path=Path(__file__).parent,
        function_name="inputs_to_struct",
        args=inputs,
    )

1. First we try select on the expression without an .over() context:

import polars as pl
from test_project import convert_series_to_struct

df = pl.DataFrame({
    "y": [1.16, -2.16, -1.57, 0.21, 0.22, 1.6, -2.11, -2.92, -0.86, 0.47],
    "x1": [0.72, -2.43, -0.63, 0.05, -0.07, 0.65, -0.02, -1.64, -0.92, -0.27],
    "x2": [0.24, 0.18, -0.95, 0.23, 0.44, 1.01, -2.08, -1.36, 0.01, 0.75],
    "group": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
    "weights": [0.34, 0.97, 0.39, 0.8, 0.57, 0.41, 0.19, 0.87, 0.06, 0.34],
})

df.select(convert_series_to_struct(pl.col("x1"), pl.col("x2")))

[image: screenshot of the output]

2. Next, let's try select on the expression chained with .over(), with "POLARS_VERBOSE" set:

df.select(convert_series_to_struct(pl.col("x1"), pl.col("x2")).over("group"))

[image: screenshot of the output]

Notice that the input series names are lost, while the input fields (which are used for the output type annotation) keep theirs. Because every input series is then named "", building the DataFrame fails with a duplicate error.

So far I've side-stepped this by naming the columns of the interim dataframe, in Rust, with arbitrary names ("1", "2", ..., "n") and then calling something like .struct.rename_fields([f.meta.output_name() for f in features]) on the Python side, but this prevents me from using input_wildcard_expansion=True and is probably not clean.

Any idea if it is easy to propagate series names like you do for fields? Or any settings etc. that I may be missing?

Thanks a lot!

azmyrajab commented 4 months ago

Hello! Would it be possible to ask for a rough ETA for something like this to be fixed? Unfortunately I'm not familiar enough with the codebase to know where to look, but I'm happy to try a PR if someone can provide guidance.

Unfortunately, this is causing inconsistent behaviour in my pl.Struct outputs in the meantime.

cmdlineluser commented 3 months ago

There was a recent issue where this was fixed in Polars for qcut: https://github.com/pola-rs/polars/pull/15715 (input name "empty" in a groupby context)

It seems this is controlled via the pass_name_to_apply field of FunctionOptions:

            if self.pass_name_to_apply {
                s.rename(&name);
            }

It seems that register_plugin_function also has this option which should fix this.

register_plugin_function(
    ...,
    pass_name_to_apply=True
)

(I'm not sure why it defaults to False - perhaps someone with more knowledge can answer that.)

azmyrajab commented 3 months ago

Hi Karl,

Thank you! Sorry, I totally missed that parameter of register_plugin_function. It seems it defaults to False for performance reasons (I guess implicitly assuming that most people won't use pl.Struct in their Rust implementations).

In any case, I set it to "True" and my issue seems to have been resolved.

    pass_name_to_apply
        If set to `True`, the `Series` passed to the function in a group-by operation
        will ensure the name is set. This is an extra heap allocation per group.

I will close this issue, as setting it to True works for my use case.

cmdlineluser commented 3 months ago

> for performance reasons

Ah! I didn't realize the function had its own docs page.

It seems like something that could be added to the User Guide or to @MarcoGorelli's Plugin Tutorial - as it could be considered a bit of a "gotcha".