pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

Inconsistent type casting of `datetime` in functions on `List` column #19917

Open nik-sm opened 6 days ago



Reproducible example

from datetime import datetime
import polars as pl

pl.Config(fmt_str_lengths=120)

## Make toy data
datetimes = pl.datetime_range(
    start=datetime(1970, 1, 1),
    end=datetime(1970, 1, 6),
    interval="1d",
    eager=True,
)
df = pl.DataFrame({"experiment_id": [1, 1, 1, 2, 2, 2], "timestamps": datetimes})
print(df)

## Collect to a column with list of timestamps
df = df.group_by("experiment_id").agg(pl.col("timestamps"))
print(df)

# Trying functions from `dir(pl.series.list.ListNameSpace)`:

## Functions that do not error, but give wrong results
# (datetime is coerced to f64)
print("mean", df.select(pl.col("timestamps").list.mean()))
# Gives:
# ┌────────────┐
# │ timestamps │
# │ ---        │
# │ f64        │
# ╞════════════╡
# │ 8.6400e10  │    # 1 day of microseconds beyond unix epoch
# │ 3.4560e11  │    # 4 days of microseconds beyond unix epoch
# └────────────┘
print("median", df.select(pl.col("timestamps").list.median()))
# Both std and var give just `null` of type f64.
# Not sure what they should be: since the values vary, std should be non-null,
# but it would be some fraction of a duration that is tricky to specify.
print("std", df.select(pl.col("timestamps").list.std()))
print("var", df.select(pl.col("timestamps").list.var()))

## Functions that seem to work correctly
print("min", df.select(pl.col("timestamps").list.min()))
print("max", df.select(pl.col("timestamps").list.max()))
print("arg_max", df.select(pl.col("timestamps").list.arg_max()))
print("arg_min", df.select(pl.col("timestamps").list.arg_min()))
print("len", df.select(pl.col("timestamps").list.len()))
print("first", df.select(pl.col("timestamps").list.first()))
print("last", df.select(pl.col("timestamps").list.last()))
print("head", df.select(pl.col("timestamps").list.head(1)))
print("tail", df.select(pl.col("timestamps").list.tail(1)))
print("sample", df.select(pl.col("timestamps").list.sample(1)))
print("gather", df.select(pl.col("timestamps").list.gather([1])))
print("gather_every", df.select(pl.col("timestamps").list.gather_every(2)))
print("get", df.select(pl.col("timestamps").list.get(1)))
print("contains true", df.select(pl.col("timestamps").list.contains(datetime(1970, 1, 1))))
print("contains false", df.select(pl.col("timestamps").list.contains(datetime(2000, 1, 1))))
print("count_matches 0", df.select(pl.col("timestamps").list.count_matches("foo")))
print("count_matches 1", df.select(pl.col("timestamps").list.count_matches(datetime(1970, 1, 1))))

## Functions that do not error, but I did not test enough to verify correctness
print("sort", df.select(pl.col("timestamps").list.sort()))
print("n_unique", df.select(pl.col("timestamps").list.n_unique()))
print("reverse", df.select(pl.col("timestamps").list.reverse()))

## Functions that error due to type issues
print("contains str errors on is_in", df.with_columns(pl.col("timestamps").list.contains("foo")))
# `count_matches` was happy to accept a string, but `contains` is not; unsure whether this is intended.
#
# polars.exceptions.InvalidOperationError: is_in operation not supported for dtypes `str` and `list[datetime[μs]]`

print("sum", df.with_columns(pl.col("timestamps").list.sum()))
# It seems reasonable that `sum` is not defined for datetime
#
# polars.exceptions.InvalidOperationError: `sum` operation not supported for dtype `datetime[μs]`

Apologies - I didn't have time to try the following functions that also appear to be in the same namespace:

concat, diff, drop_nulls, eval, explode, join, set_difference, set_intersection, set_symmetric_difference, set_union, shift, slice, to_array, to_struct, unique

Log output

No response

Issue description

Some list functions preserve datetime type, some cast to f64, and some error.

In my use case, I have a dataframe with two list columns: one holds sensor readings, the other holds timestamps. I want the average value of the sensor column, and the average timestamp.

Ideally, I would like to simply use `pl.col("timestamps").list.mean()` and get back a datetime of the same precision, corresponding to the average timestamp.

Currently, the .mean() function in this context runs without error, but gives an f64 result that appears to represent the number of time units (here microseconds) since the Unix epoch. So far, most polars functions I have tried which handle datetime preserve the "datetime"-ness correctly, even when exporting to native python objects/numpy/etc.

.first() works fine, implying that indexing over a list of datetime is fine. .min() works fine, implying that numeric comparison of datetime is supported.

.sum() errors. Not sure whether it should or not. Since the average of a list of timestamps makes intuitive sense, and the sum is simply the average multiplied by the length, perhaps sum should be defined as well.

I added a bunch of simple examples above in the hope that it might be useful for exploring this behavior or adding unit tests. Please let me know if more info is useful.

Expected behavior

df_mean = df.select(pl.col("timestamps").list.mean())
df_mean_expected = pl.from_repr(
    """
    ┌─────────────────────┐
    │ timestamps          │
    │ ---                 │
    │ datetime[μs]        │
    ╞═════════════════════╡
    │ 1970-01-02 00:00:00 │
    │ 1970-01-05 00:00:00 │
    └─────────────────────┘
    """
)

df_median = df.with_columns(pl.col("timestamps").list.median())
df_median_expected = pl.from_repr(
    """
    ┌─────────────────────┐
    │ timestamps          │
    │ ---                 │
    │ datetime[μs]        │
    ╞═════════════════════╡
    │ 1970-01-02 00:00:00 │
    │ 1970-01-05 00:00:00 │
    └─────────────────────┘
    """
)

Installed versions

```
--------Version info---------
Polars:               1.14.0
Index type:           UInt32
Platform:             macOS-15.0.1-arm64-arm-64bit
Python:               3.12.6 (main, Sep 27 2024, 17:59:45) [Clang 15.0.0 (clang-1500.3.9.4)]
LTS CPU:              False

----Optional dependencies----
adbc_driver_manager
altair
boto3
cloudpickle
connectorx
deltalake
fastexcel
fsspec
gevent
google.auth
great_tables
matplotlib
nest_asyncio
numpy
openpyxl
pandas
pyarrow
pydantic
pyiceberg
sqlalchemy
torch
xlsx2csv
xlsxwriter
```