apache / datafusion-python

Apache DataFusion Python Bindings
https://datafusion.apache.org/python
Apache License 2.0

feat: Add `flatten` array function #562

Closed: mobley-trent closed this 4 months ago

mobley-trent commented 5 months ago

Which issue does this PR close?

Refer to issue #463

Rationale for this change

What changes are included in this PR?

Are there any user-facing changes?

mobley-trent commented 5 months ago

Hello @andygrove, would you mind giving me a hand with this PR? I exposed Flatten in functions.rs, but the Python array function test for flatten is failing like this:

name = 'flatten'

    def __getattr__(name):
>       return getattr(functions, name)
E       AttributeError: module 'functions' has no attribute 'flatten'
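
The traceback shows the Python-side `functions` module resolving names through a module-level `__getattr__` (PEP 562) that forwards every lookup to the compiled extension module. A name the compiled module does not export therefore surfaces as `AttributeError`, which is what happens when the extension was built before the new binding was added. A minimal sketch of that forwarding pattern, assuming a hypothetical stand-in module `_compiled` in place of the real Rust extension:

```python
import types

# Hypothetical stand-in for the compiled Rust extension. A stale build
# that predates the new binding exports `array_length` but not `flatten`.
_compiled = types.ModuleType("functions")


def array_length(expr):
    return f"array_length({expr})"


_compiled.array_length = array_length

# Python wrapper module whose failed lookups fall through to the compiled
# module, mirroring `return getattr(functions, name)` in the traceback.
wrapper = types.ModuleType("wrapper")


def _forward(name):
    # A module-level __getattr__ is consulted only when normal lookup
    # fails, so every name not defined on `wrapper` is delegated here.
    return getattr(_compiled, name)


wrapper.__getattr__ = _forward

print(wrapper.array_length("arr"))  # delegated to the stand-in module
try:
    wrapper.flatten
except AttributeError as exc:
    print(exc)  # module 'functions' has no attribute 'flatten'
```

Because the Python wrapper only delegates, adding the Rust binding has no visible effect until the extension is rebuilt and reinstalled.
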
ongchi commented 5 months ago

Hi @mobley-trent, did you rebuild the package before running pytest? For example:

# build and install package
maturin develop

Also, don't forget to activate the venv before running this command.
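
After `maturin develop` finishes, a quick check from the same activated venv can confirm whether the installed extension actually exposes the new binding before pytest is rerun. A small sketch; the helper name `has_function` is hypothetical, while `datafusion.functions` and `flatten` are the names discussed in this PR:

```python
import importlib


def has_function(module_name: str, func_name: str) -> bool:
    """Return True if `module_name` imports and exposes `func_name`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Package not installed in the active environment, e.g. the
        # venv was not activated when `maturin develop` ran.
        return False
    # hasattr() also returns False when a module-level __getattr__
    # raises AttributeError for the missing name.
    return hasattr(module, func_name)


if __name__ == "__main__":
    # Should print True once the rebuilt package exports the binding.
    print(has_function("datafusion.functions", "flatten"))
```
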

mobley-trent commented 4 months ago

Hey @ongchi, I tested the flatten function and it's failing. Here is the code:

from datafusion import SessionContext, column
from datafusion import functions as f
import numpy as np
import pyarrow as pa

def py_flatten(arr):
    # Testing helper function
    result = []
    for elem in arr:
        if isinstance(elem, list):
            result.extend(py_flatten(elem))
        else:
            result.append(elem)
    return result

ctx = SessionContext()
data = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]

batch = pa.RecordBatch.from_arrays(
    [np.array(data, dtype=object)], names=["arr"]
)
df = ctx.create_dataframe([[batch]])
col = column("arr")

stmt = f.flatten(col)
py_expr = lambda: [py_flatten(data)]

result = df.select(stmt).collect()[0].column(0).tolist()

print(f"flatten query: {result}")
print(f"py_expr: {py_expr()}")

Results:

>>> flatten query: [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
>>> py_expr: [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]

I expected the flatten query's result to be identical to py_expr. Is there something I overlooked, or is this an underlying bug?

mobley-trent commented 4 months ago

Using a plain SQL flatten query:

ctx = SessionContext()
ctx.sql("select flatten([[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]);")

Result:

DataFrame()
+----------------------------------------------------------------------------------------------------------------------------+
| flatten(make_array(make_array(Float64(1),Float64(2),Float64(3)),make_array(Float64(4),Float64(5)),make_array(Float64(6)))) |
+----------------------------------------------------------------------------------------------------------------------------+
| [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]                                                                                             |
+----------------------------------------------------------------------------------------------------------------------------+
ongchi commented 4 months ago

Hi @mobley-trent, the df created in the test case may be a bit misleading; it is equivalent to this:

❯ SELECT column1 AS arr FROM (VALUES ([1.0, 2.0, 3.0, 3.0]), ([4.0, 5.0, 3.0]), ([6.0]));
+----------------------+
| arr                  |
+----------------------+
| [1.0, 2.0, 3.0, 3.0] |
| [4.0, 5.0, 3.0]      |
| [6.0]                |
+----------------------+

It contains multiple rows, each holding a one-dimensional array, so within any single row there is nothing for flatten to flatten. For the flatten test case, either the existing df should be modified or a new DataFrame should be created in which a row holds a nested array.
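
Put differently, flatten is evaluated once per row, so the expected value in the test must also be computed per row. A pure-Python sketch of those semantics (no DataFusion involved; `flatten_rows` is a hypothetical helper, and `py_flatten` is the same helper used in the test earlier in this thread) shows why three rows of one-dimensional arrays come back unchanged, while a single row holding the nested array is flattened, as in the SQL query:

```python
def py_flatten(arr):
    # Recursively flatten nested lists (same helper as in the test).
    result = []
    for elem in arr:
        if isinstance(elem, list):
            result.extend(py_flatten(elem))
        else:
            result.append(elem)
    return result


def flatten_rows(rows):
    # Apply flatten to each row's value independently, the way a scalar
    # function over an array column is evaluated.
    return [py_flatten(row) for row in rows]


data = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]

# Original test setup: three rows, each a 1-D array -> unchanged.
print(flatten_rows(data))
# [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]

# One row whose value is the whole nested array, like the SQL literal.
print(flatten_rows([data]))
# [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]
```

The second call reproduces py_expr, so the test data needs a row whose value is the nested list itself (for example, by wrapping `data` in an outer list when building the Arrow column) rather than one row per inner list.
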

mobley-trent commented 4 months ago

Fixed the merge conflicts