Eventual-Inc / Daft

Distributed DataFrame for Python designed for the cloud, powered by Rust
https://getdaft.io
Apache License 2.0

Create a map_get expression to enable users to efficiently work with Map DataTypes #2240

Closed daveqs closed 1 month ago

daveqs commented 1 month ago

Is your feature request related to a problem? Please describe. There is no Daft-native way to look up values by key in a map, and therefore it is not possible to work with Map DataType columns except through UDFs.

Describe the solution you'd like A Daft-native expression that is functionally equivalent (or as similar as is practical) to PyArrow's pyarrow.compute.map_lookup() or to Spark's pyspark.sql.functions.element_at().

Describe alternatives you've considered The following Daft UDF provides the functionality I need to work with Map DataType columns, although it is limited in that it requires hardcoding the target DataType (in this example, DataType.int64()). It would be preferable for the target DataType to be inferred from the Map DataType column. The UDF presumes that no key appears more than once within a map, which I believe is a reasonable assumption for the use cases I can envision.

import pyarrow as pa
import daft

@daft.udf(return_dtype=daft.DataType.int64())
def map_get(map_col, key):
    # Convert the Daft Series to an Arrow MapArray
    map_array = map_col.to_arrow()

    # Extract the scalar key from the Daft expression; it is used to look up values in the MapArray
    key_scalar = key.to_pylist()[0]

    # Use the Arrow compute function to look up the value associated with the key in each map
    values_array = pa.compute.map_lookup(map_array, pa.scalar(key_scalar, type=pa.large_string()), 'first')

    # Convert the resulting Arrow Array back to a Daft Series
    return daft.Series.from_arrow(values_array)

Additional context The expression should return None values in rows where the user-supplied key is not contained in the map, which is the behavior exhibited by the UDF. The following example can be used to demonstrate the UDF:

import pyarrow as pa
import daft

# Example data
keys = pa.array(['x', 'y', 'x', 'z'])
values = pa.array([1, 2, 4, 5])
offsets = pa.array([0, 2, 2, 4])

# Create the MapArray
map_array = pa.MapArray.from_arrays(offsets, keys, values)
map_col = pa.chunked_array([map_array])

# Make a pyarrow table from map_col and a column of timestamps, then convert to a daft dataframe
arrow_tbl = pa.table({'timestamp_col': pa.chunked_array([pa.array([1, 2, 3], type=pa.timestamp('s'))]), 'map_col': map_col})
daft_df = daft.from_arrow(arrow_tbl)

# Define a Daft UDF that retrieves, for each map in the column, the value associated with a given key
@daft.udf(return_dtype=daft.DataType.int64())
def map_get(map_col, key):
    # Convert Daft Series to Arrow MapArray
    map_array = map_col.to_arrow()

    # Extract the scalar key from the Daft expression; it is used to look up values in the MapArray
    key_scalar = key.to_pylist()[0]

    # Use the Arrow compute function to look up the values associated with the keys in the MapArray
    values_array = pa.compute.map_lookup(map_array, pa.scalar(key_scalar, type=pa.large_string()), 'first')

    # Convert the resulting Arrow Array to a Daft Series
    return daft.Series.from_arrow(values_array)

# Apply the UDF to the map column to extract the value associated with the key 'x'
daft_df = daft_df.with_column("x", map_get(daft_df["map_col"], key=daft.lit("x")))
daft_df.show()
jaychia commented 1 month ago

This is a great feature request, adding it to our list of to-dos

daveqs commented 1 month ago

@jaychia great, thank you!