pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

scan_pyarrow_dataset not filtering on partitions #16300

Open mtofano opened 1 month ago

mtofano commented 1 month ago

Checks

Reproducible example

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import polars as pl

# here is my dataset definition
# (dataset_path, s3_fs, and trade_date are defined elsewhere and not shown here)
dataset = ds.dataset(
    source=dataset_path,
    filesystem=s3_fs,  # instance of S3FileSystem
    format="arrow",
    partitioning=ds.partitioning(
        schema=pa.schema(
            [
                pa.field("underlier_id", pa.int64()),
                pa.field("trade_date", pa.date32()),
            ]
        ),
    ),
)

# pyarrow works in < 1s
data = dataset.filter(
    (pc.field("underlier_id") == 5135108)
    & (pc.field("trade_date") == trade_date)
).to_table()

# but polars scan_pyarrow_dataset never completes
data = pl.scan_pyarrow_dataset(dataset).filter(
    pl.col("underlier_id") == 5135108,
    pl.col("trade_date") == trade_date
).collect()
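
As an aside, one workaround sketch (not part of the reproduction): apply the filter to the pyarrow dataset before handing it to polars, so pyarrow prunes the partitions itself. This assumes a pyarrow version that provides Dataset.filter and that scan_pyarrow_dataset accepts the already-filtered dataset.

# Hedged workaround sketch: let pyarrow prune the partitions up front,
# then scan the filtered dataset with polars.
filtered_dataset = dataset.filter(
    (pc.field("underlier_id") == 5135108)
    & (pc.field("trade_date") == trade_date)
)
data = pl.scan_pyarrow_dataset(filtered_dataset).collect()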

Log output

No response
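
One way to capture log output for a run like this, assuming polars' verbose mode behaves as documented, is to enable it before collecting; a small sketch:

import polars as pl

# Hedged sketch: turn on polars' verbose output (equivalent to setting the
# POLARS_VERBOSE environment variable) so collect() prints what the engine does.
pl.Config.set_verbose(True)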

Issue description

I have a large dataset on S3 consisting of a large number of .arrow files. We are using directory partitioning by an integer id and a date, which looks like this:

/5135108
    /2016-01-01
        /part-0.arrow
    ...
    /2024-05-17
        /part-0.arrow
/5130371
    /2016-01-01
    ...
    /2024-05-17

We are using pyarrow to write the entirety of this dataset. On the read side polars is much preferred because of its expressiveness. I want to use the scan_pyarrow_dataset function to read and perform filtering with predicate pushdown. However, it seems that polars is not filtering out the partitions specified in the query. When I run with pyarrow it takes less than a second to read the data from a single file, but when I use polars scan_pyarrow_dataset it never completes and hangs forever. I assume this is because it is not actually filtering out the partitions and is trying to read everything.
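
For reference, the layout above is what pyarrow's directory partitioning produces when writing; a minimal sketch of such a write (the table contents and output path here are purely illustrative, not the code we actually use):

import datetime as dt
import pyarrow as pa
import pyarrow.dataset as ds

# Illustrative table only; the real dataset is written by a separate pipeline.
table = pa.table({
    "underlier_id": pa.array([5135108, 5135108], type=pa.int64()),
    "trade_date": pa.array([dt.date(2016, 1, 1), dt.date(2016, 1, 4)], type=pa.date32()),
    "price": [1.0, 2.0],
})

# Directory partitioning (no key=value segments) yields paths like
# <base_dir>/5135108/2016-01-01/part-0.arrow, matching the layout above.
ds.write_dataset(
    table,
    base_dir="illustrative_local_or_s3_path",
    format="arrow",
    partitioning=ds.partitioning(
        pa.schema([
            pa.field("underlier_id", pa.int64()),
            pa.field("trade_date", pa.date32()),
        ])
    ),
)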

Expected behavior

I would expect this to filter out the irrelevant partitions from the reads, and push any predicates down to the scan level just as pyarrow does, but that does not seem to be the case.

Installed versions

```
--------Version info---------
Polars:               0.20.26
Index type:           UInt32
Platform:             Linux-4.18.0-513.9.1.el8_9.x86_64-x86_64-with-glibc2.28
Python:               3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:50:21) [GCC 12.3.0]
----Optional dependencies----
adbc_driver_manager:
cloudpickle:          3.0.0
connectorx:
deltalake:
fastexcel:
fsspec:               2024.3.1
gevent:
hvplot:
matplotlib:
nest_asyncio:         1.6.0
numpy:                1.26.4
openpyxl:
pandas:               1.5.3
pyarrow:              15.0.2
pydantic:             2.7.1
pyiceberg:
pyxlsb:
sqlalchemy:           1.4.49
torch:
xlsx2csv:
xlsxwriter:
```
ion-elgreco commented 1 month ago

You can check the plan with df.explain. You should see the filter being pushed down into the scan as a pyarrow compute expression.

If it's correctly showing pushed-down pyarrow compute expressions, then this rather points to an issue in pyarrow, where filters are not converted to partition filters.
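
A minimal sketch of that check, reusing the dataset and predicate from the report above (the exact plan text can vary between polars versions):

lf = pl.scan_pyarrow_dataset(dataset).filter(
    (pl.col("underlier_id") == 5135108)
    & (pl.col("trade_date") == trade_date)
)

# If the pushdown succeeded, the predicate appears inside the PYTHON SCAN node
# as a pyarrow compute expression; a separate FILTER node above the scan means
# it was not pushed down.
print(lf.explain())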

ritchie46 commented 4 weeks ago

Yes, we just pass the predicates to pyarrow. So I think this should be taken upstream.

mtofano commented 4 weeks ago

[screenshots: debug session inside _scan_pyarrow_dataset_impl, showing no predicate argument and no filter passed to ds.to_table]

I don't think the issue is with pyarrow, as running to_table and passing in the compute expressions directly works as expected outside of polars.

I suspect the predicates are not being passed to to_table as we would expect when using scan_pyarrow_dataset. See the screenshots above from my debug session. In the _scan_pyarrow_dataset_impl function I can see that no predicates are passed in as an argument, and thus no filter is provided to ds.to_table. The predicates seem to be getting lost in translation somewhere.

However, the query plan from the output of explain() looks correct to me:

data.explain()
'FILTER [([(col("underlier_id")) == (5135108)]) & ([(col("trade_date")) == (2016-01-04)])] FROM\n\n  PYTHON SCAN \n  PROJECT */7 COLUMNS'
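
For reference, a successful pushdown should effectively reduce to the plain pyarrow call below (a sketch of the intended effect, not of the actual _scan_pyarrow_dataset_impl internals):

# Sketch of what the pushed-down scan should amount to: the polars predicate is
# handed to Dataset.to_table as a pyarrow compute expression, so pyarrow can
# prune the partition directories before reading any files.
expr = (pc.field("underlier_id") == 5135108) & (pc.field("trade_date") == trade_date)
data = pl.from_arrow(dataset.to_table(filter=expr))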
ion-elgreco commented 4 weeks ago

So filtering on non-date/datetime columns works; see the example below.

Run this code as-is

import polars as pl

df = pl.DataFrame({
    "foo": [1,2,3],
    "bar": [1,2,3],
    "baz": [1,2,3],
}, schema={"foo": pl.Int64, "bar": pl.Date, "baz": pl.Int64,})

df.write_delta(
    'test_table_scan',
    mode='overwrite',
    delta_write_options={"partition_by": ["foo", "bar"], "engine": "rust"},
    overwrite_schema=True,
)

print(
    pl.scan_delta('test_table_scan').filter(pl.col('foo')==2).collect()
)

However, a predicate that contains a date or datetime breaks the predicate pushdown into pyarrow; a similar issue: https://github.com/pola-rs/polars/issues/16248


import polars as pl

df = pl.DataFrame({
    "foo": [1,2,3],
    "bar": [1,2,2],
    "baz": [1,2,3],
}, schema={"foo": pl.Int64, "bar": pl.Date, "baz": pl.Int64,})

df.write_delta(
    'test_table_scan',
    mode='overwrite',
    delta_write_options={"partition_by": ["foo", "bar"], "engine": "rust"},
    overwrite_schema=True,
)

print(
    pl.scan_delta('test_table_scan').filter(pl.col('foo')==2, pl.col('bar')== pl.date(1970,1,3)).collect()
)
ion-elgreco commented 4 weeks ago

It seems the pushdown is not working when the predicate includes dates/datetimes @ritchie46

print(pl.scan_delta('test_table_scan').filter(pl.col('foo')==2, pl.col('bar')== pl.date(1970,1,3)).explain(optimized=True))

FILTER [([(col("foo")) == (2)]) & ([(col("bar")) == (dyn int: 1970.dt.datetime([dyn int: 1, dyn int: 3, dyn int: 0, dyn int: 0, dyn int: 0, dyn int: 0, String(raise)]).strict_cast(Date))])] FROM

  PYTHON SCAN 
  PROJECT */3 COLUMNS

This issue is related: https://github.com/pola-rs/polars/issues/11152

mtofano commented 3 weeks ago

Thank you very much for the replies!

Out of curiosity, what exactly is it about dates that breaks the predicate pushdown? Fixing this would be very nice, as the current behavior makes scan_pyarrow_dataset unusable on date-partitioned datasets, and it is a very powerful feature we'd love to take advantage of :)