pola-rs / pyo3-polars

Plugins/extension for Polars
MIT License
232 stars 38 forks source link

plugins' names not respected? #59

Closed MarcoGorelli closed 7 months ago

MarcoGorelli commented 7 months ago

Here's a reproducible example:

Cargo.toml

[package]
name = "minimal_plugin"
version = "0.1.0"
edition = "2021"

[lib]
name = "minimal_plugin"
crate-type= ["cdylib"]

[dependencies]
pyo3 = { version = "0.20.0", features = ["extension-module"] }
pyo3-polars = { version = "0.10.0", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
polars = { version = "0.36.2", default-features = false }

[target.'cfg(target_os = "linux")'.dependencies]
jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] }

pyproject.toml

[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"

[project]
name = "minimal_plugin"
requires-python = ">=3.8"
classifiers = [
  "Programming Language :: Rust",
  "Programming Language :: Python :: Implementation :: CPython",
  "Programming Language :: Python :: Implementation :: PyPy",
]

minimal_plugin/__init__.py:

import polars as pl
from polars.utils.udfs import _get_shared_lib_location
from polars.type_aliases import IntoExpr

# Location of this package's shared library, as returned by polars' private
# helper for this file's path; it is passed as `lib=` to `register_plugin`.
lib = _get_shared_lib_location(__file__)

@pl.api.register_expr_namespace("mp")
class MinimalExamples:
    """Expression namespace registered under the name ``mp``."""

    def __init__(self, expr: pl.Expr):
        # Keep the wrapped expression so the namespace methods can build on it.
        self._expr = expr

    def rename(self) -> pl.Expr:
        """Register the plugin function exported under the symbol ``rename``.

        The plugin call is declared element-wise and is applied to the
        wrapped expression.
        """
        plugin_options = {
            "lib": lib,
            "symbol": "rename",
            "is_elementwise": True,
        }
        return self._expr.register_plugin(**plugin_options)

src/lib.rs:

// Declares the `expressions` module (shown below as `src/expressions.rs`).
mod expressions;

#[cfg(target_os = "linux")]
use jemallocator::Jemalloc;

// Use jemalloc as the global allocator. Both the import above and this
// static are compiled only when the target OS is Linux.
#[global_allocator]
#[cfg(target_os = "linux")]
static ALLOC: Jemalloc = Jemalloc;

src/expressions.rs:

#![allow(clippy::unused_unit)]
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;

/// Output-type function used by the `rename` plugin expression.
///
/// Returns a clone of the first input field unchanged.
///
/// # Panics
///
/// Panics if `input_fields` is empty, because the first element is indexed
/// directly.
fn same_output_type(input_fields: &[Field]) -> PolarsResult<Field> {
    Ok(input_fields[0].clone())
}

/// Plugin expression that returns a copy of its first input, renamed to `"foo"`.
///
/// NOTE: this is the reproducer for the issue. The name set here does not show
/// up in the resulting DataFrame; per the discussion, plugin functions are not
/// allowed to change the output name.
///
/// # Panics
///
/// Panics if `inputs` is empty, because the first element is indexed directly.
#[polars_expr(output_type_func=same_output_type)]
fn rename(inputs: &[Series]) -> PolarsResult<Series> {
    let mut renamed = inputs[0].clone();
    renamed.rename("foo");
    Ok(renamed)
}

run.py

import polars as pl
import minimal_plugin  # noqa: F401

# Two integer columns; only 'a' is passed through the plugin below.
df = pl.DataFrame({'a': [1,2,3], 'b': [4,5,6]})
# Reported behaviour: the printed frame still has just columns 'a' and 'b';
# the "foo" name set inside the plugin does not appear.
print(df.with_columns(pl.col('a').mp.rename()))

This outputs:

shape: (3, 2)
┌─────┬─────┐
│ a   ┆ b   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 1   ┆ 4   │
│ 2   ┆ 5   │
│ 3   ┆ 6   │
└─────┴─────┘

Expected:

shape: (3, 3)
┌─────┬─────┬─────┐
│ a   ┆ b   ┆ foo │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1   ┆ 4   ┆ 1   │
│ 2   ┆ 5   ┆ 2   │
│ 3   ┆ 6   ┆ 3   │
└─────┴─────┴─────┘
ritchie46 commented 7 months ago

Functions are not really allowed to change names. We follow the left-hand rule and, of course, `alias`.

MarcoGorelli commented 7 months ago

thanks - @ion-elgreco you may want to not name the output "neighbours" here

https://github.com/ion-elgreco/polars-hash/blob/618d53ee7905b080e2f57de0fe27a684e9bc386f/polars_hash/polars_hash/src/geohashers.rs#L102-L114

Otherwise, the lazy schema will be incorrect.

check this report from the discord:

When I run the following lazily:

import polars_hash as plh

(pl.from_dicts({'h1':'sp1xk2m6194y'})
 # .lazy()
 .with_columns(plh.col('h1').geohash.neighbors())
 .unnest('h1')
 .select(pl.concat_list('n', 'ne'))
 # .fetch()
)

I get: SchemaError: expected struct dtype, got: 'str'

ion-elgreco commented 7 months ago

@MarcoGorelli ah oops, I missed that report, only just saw the issue on my repo now.

I'll pass the original name of the series back then instead