apache / datafusion

Apache DataFusion SQL Query Engine
https://datafusion.apache.org/
Apache License 2.0

Multi-field structs from builtin scalar fn incompatible with UDAF #7012

Open alexwilcoxson-rel opened 1 year ago

alexwilcoxson-rel commented 1 year ago

Describe the bug

We have a use case that requires providing multiple column values to a UDAF. UDAFs support only one column input (unless I'm mistaken; I'm looking at the signature, which accepts a single input data type). This has since been resolved by #7096.


To work around this, we tried packing the columns into a struct column and passing that as input to the UDAF, but we see an error with both the SQL struct() builtin and the Expr API's BuiltinScalarFunction::Struct.
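
Concretely, the two failing forms (shown in full in the reproducer under Tests below) are the SQL struct() builtin and the equivalent Expr built with BuiltinScalarFunction::Struct:

SELECT my_sum(struct(a, b, c)) FROM batch

let packed = Expr::ScalarFunction(ScalarFunction {
    fun: BuiltinScalarFunction::Struct,
    args: vec![col("a"), col("b")],
});
let df = df.aggregate(vec![], vec![udaf.call(vec![packed])])?;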

To Reproduce

Run the tests below and observe the following output.

Failures

failures:

---- tests::test_udaf_pack_many_col_struct_sql stdout ----
Error: type_coercion
caused by
Error during planning: Coercion from [Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])] to the signature Exact([Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])]) failed.

Caused by:
    Error during planning: Coercion from [Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])] to the signature Exact([Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])]) failed.

---- tests::test_udaf_pack_many_col_struct_expr stdout ----
Error: type_coercion
caused by
Error during planning: Coercion from [Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])] to the signature Exact([Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])]) failed.

Caused by:
    Error during planning: Coercion from [Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])] to the signature Exact([Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])]) failed.

Table

Running cargo test -- --nocapture shows the table:

+----+-------+------------+-------------------+---------+
| a  | b     | c          | d                 | e       |
+----+-------+------------+-------------------+---------+
| 12 | true  | hi         | {i: 12, j: true}  | {i: 12} |
| 11 | false | datafusion | {i: 11, j: false} | {i: 11} |
+----+-------+------------+-------------------+---------+

Tests

use datafusion::{physical_plan::Accumulator, scalar::ScalarValue};

#[tokio::main]
async fn main() {}

// Minimal accumulator for the reproducer: sums Int32 values, unwrapping them from
// struct inputs when the argument is a struct column.
#[derive(Default, Debug)]
struct SumUdaf {
    sum: u32,
}

impl Accumulator for SumUdaf {
    fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion::error::Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let arr = &values[0];
        (0..arr.len()).try_for_each(|index| {
            let sv = ScalarValue::try_from_array(&arr, index)?;
            if let ScalarValue::Struct(Some(values), _) = sv {
                for v in values {
                    if let ScalarValue::Int32(Some(v)) = v {
                        self.sum += v as u32;
                    }
                }
            } else if let ScalarValue::Int32(Some(v)) = sv {
                self.sum += v as u32;
            }
            Ok(())
        })
    }

    fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
        Ok(ScalarValue::from(self.sum))
    }

    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }

    fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
        Ok(vec![ScalarValue::from(self.sum)])
    }

    fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion::error::Result<()> {
        if states.is_empty() {
            return Ok(());
        }

        let arr = &states[0];

        (0..arr.len()).try_for_each(|index| {
            if let ScalarValue::UInt32(Some(v)) = ScalarValue::try_from_array(arr, index)? {
                self.sum += v;
            } else {
                unreachable!("")
            }
            Ok(())
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow::{
        array::{
            downcast_array, ArrayBuilder, BooleanBuilder, Int32Builder, StringBuilder,
            StructBuilder, UInt32Array,
        },
        datatypes::{
            DataType as ArrowDataType, Field as ArrowField, FieldRef as ArrowFieldRef,
            Fields as ArrowFields, Schema as ArrowSchema,
        },
        record_batch::RecordBatch,
    };
    use datafusion::{
        logical_expr::{expr::ScalarFunction, AggregateUDF, BuiltinScalarFunction},
        prelude::*,
    };

    fn test_data() -> anyhow::Result<RecordBatch> {
        let d_fields: Vec<ArrowFieldRef> = vec![
            Arc::new(ArrowField::new("i", ArrowDataType::Int32, false)),
            Arc::new(ArrowField::new("j", ArrowDataType::Boolean, false)),
        ];
        let e_fields: Vec<ArrowFieldRef> =
            vec![Arc::new(ArrowField::new("i", ArrowDataType::Int32, false))];
        let schema = ArrowSchema::new(vec![
            ArrowField::new("a", ArrowDataType::Int32, false),
            ArrowField::new("b", ArrowDataType::Boolean, false),
            ArrowField::new("c", ArrowDataType::Utf8, false),
            ArrowField::new_struct("d", &*d_fields, false),
            ArrowField::new_struct("e", &*e_fields, false),
        ]);

        let mut a_builder = Int32Builder::new();
        let mut b_builder = BooleanBuilder::new();
        let mut c_builder = StringBuilder::new();

        a_builder.append_values(&[12, 11], &[true, true]);
        b_builder.append_values(&[true, false], &[true, true])?;
        c_builder.append_value("hi");
        c_builder.append_value("datafusion");

        let struct_builders: Vec<Box<dyn ArrayBuilder>> = vec![
            Box::new(Int32Builder::new()),
            Box::new(BooleanBuilder::new()),
        ];
        let mut d_builder = StructBuilder::new(d_fields, struct_builders);

        d_builder.append(true);
        d_builder.append(true);

        let i_builder = d_builder
            .field_builder::<Int32Builder>(0)
            .ok_or_else(|| anyhow::anyhow!("bad builder"))?;
        i_builder.append_value(12);
        i_builder.append_value(11);
        let j_builder = d_builder
            .field_builder::<BooleanBuilder>(1)
            .ok_or_else(|| anyhow::anyhow!("bad builder"))?;
        j_builder.append_value(true);
        j_builder.append_value(false);

        let mut e_builder = StructBuilder::new(e_fields, vec![Box::new(Int32Builder::new())]);
        e_builder.append(true);
        e_builder.append(true);
        let i_builder = e_builder
            .field_builder::<Int32Builder>(0)
            .ok_or_else(|| anyhow::anyhow!("bad builder"))?;
        i_builder.append_value(12);
        i_builder.append_value(11);

        let mut builders: Vec<Box<dyn ArrayBuilder>> = vec![
            Box::new(a_builder),
            Box::new(b_builder),
            Box::new(c_builder),
            Box::new(d_builder),
            Box::new(e_builder),
        ];
        let arrays = builders.iter_mut().map(|b| b.finish()).collect::<Vec<_>>();

        let batch = RecordBatch::try_new(Arc::new(schema), arrays)?;
        Ok(batch)
    }

    async fn sql(
        sql: impl AsRef<str>,
        udaf_input_type: ArrowDataType,
    ) -> anyhow::Result<DataFrame> {
        let ctx = SessionContext::default();
        let batch = test_data()?;
        ctx.register_batch("batch", batch)?;
        ctx.register_udaf(udaf(udaf_input_type));
        let df = ctx.sql(sql.as_ref()).await?;
        Ok(df)
    }

    fn dataframe() -> anyhow::Result<DataFrame> {
        let ctx = SessionContext::default();
        let batch = test_data()?;
        let df = ctx.read_batch(batch)?;
        Ok(df)
    }

    // Build the my_sum UDAF with the given input type; return type and state type are UInt32.
    fn udaf(input_type: ArrowDataType) -> AggregateUDF {
        create_udaf(
            "my_sum",
            input_type,
            Arc::new(ArrowDataType::UInt32),
            datafusion::logical_expr::Volatility::Immutable,
            Arc::new(|_| Ok(Box::new(SumUdaf::default()))),
            Arc::new(vec![ArrowDataType::UInt32]),
        )
    }

    // Pack the given columns into a single struct expression via BuiltinScalarFunction::Struct.
    fn pack_cols(cols: Vec<impl Into<Column>>) -> Expr {
        Expr::ScalarFunction(ScalarFunction {
            fun: BuiltinScalarFunction::Struct,
            args: cols.into_iter().map(|c| col(c)).collect::<Vec<_>>(),
        })
    }

    async fn assert(df: DataFrame, expected: u32) -> anyhow::Result<()> {
        let result = df.collect().await?;
        let result_arr = result[0].column(0);
        let result_arr = downcast_array::<UInt32Array>(result_arr);
        let actual = result_arr.value(0);
        assert_eq!(expected, actual);
        Ok(())
    }

    #[tokio::test]
    async fn test_show() -> anyhow::Result<()> {
        let df = dataframe()?;
        df.show().await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_udaf_existing_struct_many_col_sql() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("i", ArrowDataType::Int32, false),
            ArrowField::new("j", ArrowDataType::Boolean, false),
        ]));
        let df = sql("SELECT my_sum(d) FROM batch", udaf_input_type).await?;
        assert(df, 23).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_udaf_existing_struct_many_col_expr() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("i", ArrowDataType::Int32, false),
            ArrowField::new("j", ArrowDataType::Boolean, false),
        ]));
        let df = dataframe()?;
        let udaf = udaf(udaf_input_type);

        let df = df.aggregate(vec![], vec![(udaf.call(vec![col("d")]))])?;
        assert(df, 23).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_udaf_existing_struct_one_col_sql() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("i", ArrowDataType::Int32, false),
        ]));
        let df = sql("SELECT my_sum(e) FROM batch", udaf_input_type).await?;
        assert(df, 23).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_udaf_existing_struct_one_col_expr() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("i", ArrowDataType::Int32, false),
        ]));
        let df = dataframe()?;
        let udaf = udaf(udaf_input_type);

        let df = df.aggregate(vec![], vec![(udaf.call(vec![col("e")]))])?;
        assert(df, 23).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_udaf_pack_one_col_struct_sql() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("c0", ArrowDataType::Int32, true),
            // ArrowField::new("c1", ArrowDataType::Boolean, true),
            //ArrowField::new("c2", ArrowDataType::Utf8, true),
        ]));
        let df = sql("SELECT my_sum(struct(a)) FROM batch", udaf_input_type).await?;
        assert(df, 23).await?;
        Ok(())
    }

    // FAILS - Treats all struct fields as Utf8
    #[tokio::test]
    async fn test_udaf_pack_many_col_struct_sql() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("c0", ArrowDataType::Int32, true),
            ArrowField::new("c1", ArrowDataType::Boolean, true),
            ArrowField::new("c2", ArrowDataType::Utf8, true),
        ]));
        let df = sql("SELECT my_sum(struct(a, b, c)) FROM batch", udaf_input_type).await?;
        assert(df, 23).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_udaf_pack_one_col_struct_expr() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([ArrowField::new(
            "c0",
            ArrowDataType::Int32,
            true,
        )]));
        let df = dataframe()?;
        let udaf = udaf(udaf_input_type);
        let packed_expr = pack_cols(vec!["a"]);

        let df = df.aggregate(vec![], vec![udaf.call(vec![packed_expr])])?;
        assert(df, 23).await?;

        Ok(())
    }

    // FAILS - Treats all struct fields as Utf8
    #[tokio::test]
    async fn test_udaf_pack_many_col_struct_expr() -> anyhow::Result<()> {
        let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
            ArrowField::new("c0", ArrowDataType::Int32, true),
            ArrowField::new("c1", ArrowDataType::Boolean, true),
        ]));
        let df = dataframe()?;
        let udaf = udaf(udaf_input_type);
        let packed_expr = pack_cols(vec!["a", "b"]);
        let df = df.aggregate(vec![], vec![udaf.call(vec![packed_expr])])?;

        assert(df, 23).await?;

        Ok(())
    }
}

Expected behavior

We should be able to create a struct with multiple fields using the SQL struct() builtin or the Expr API's BuiltinScalarFunction::Struct and provide that as input to a UDAF.

Additional context

The UDAF here is intentionally very simple, just for the example.

Is this a limitation of UDAFs, or could we open an enhancement request to support multiple input columns?

2010YOUY01 commented 1 year ago

I think it makes sense to let UDAFs support multiple column inputs; there are already built-in aggregate functions, like correlation/covariance, that take multi-column input.
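
For reference, built-in multi-column aggregates are already callable today; a minimal sketch, assuming a SessionContext ctx and a registered table t with numeric columns x and y (hypothetical names):

let df = ctx.sql("SELECT corr(x, y), covar_samp(x, y) FROM t").await?;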

alamb commented 1 year ago

I believe this was fixed in https://github.com/apache/arrow-datafusion/pull/7096

Can you confirm, @alexwilcoxson-rel?

alexwilcoxson-rel commented 1 year ago

@alamb this fixes our initial use case of just needing to provide multiple inputs. There still appears to be an issue with the latest code on main where you can't create a struct and pass it as a single argument to a UDAF, e.g. SELECT my_udaf(struct(col_A, col_B)). This was just our workaround though and is more of an edge case IMO.

So perhaps just keep it with the other "improve struct" issues.
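
For illustration, a rough sketch of the direct multi-input call that #7096 enables, reusing the reproducer's SumUdaf. This assumes the updated create_udaf signature that accepts a Vec<DataType> of argument types instead of a single DataType:

// Assumed post-#7096 API: one DataType entry per UDAF argument.
let udaf = create_udaf(
    "my_sum",
    vec![ArrowDataType::Int32, ArrowDataType::Boolean],
    Arc::new(ArrowDataType::UInt32),
    datafusion::logical_expr::Volatility::Immutable,
    Arc::new(|_| Ok(Box::new(SumUdaf::default()))),
    Arc::new(vec![ArrowDataType::UInt32]),
);

// Columns are passed directly, no struct packing needed; the accumulator's
// update_batch then receives one ArrayRef per argument.
let df = df.aggregate(vec![], vec![udaf.call(vec![col("a"), col("b")])])?;

// Equivalent SQL: SELECT my_sum(a, b) FROM batch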

alamb commented 1 year ago

> This was just our workaround though and is more of an edge case IMO.
>
> So perhaps just keep it with the other "improve struct" issues.

Yes, that makes sense to me -- it is probably worth making a new ticket for just that use case (especially since this ticket has such a nice reproducer) ❤️