Skip to content

Async UDF rewrite input to Aggregate breaks UDAFs #18149

@Jefffrey

Description

@Jefffrey

Describe the bug

When an async UDF is used as the input to certain UDAFs (for example `approx_distinct`), the query fails with a schema mismatch error.

To Reproduce

use arrow::compute::cast_with_options;
use datafusion::arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion::prelude::*;
use std::any::Any;
use std::sync::Arc;
use tonic::async_trait;

/// Minimal scalar UDF used by this reproduction.
///
/// The same type is registered twice in `main`: once as a plain (sync) UDF and
/// once wrapped as an async UDF, so both paths run identical logic.
#[derive(Debug, PartialEq, Eq, Hash)]
struct CustomUDF {
    /// Argument signature handed back by `ScalarUDFImpl::signature`.
    signature: Signature,
}

impl CustomUDF {
    /// Builds the UDF with an immutable signature that accepts exactly one
    /// `Int64` argument.
    fn new() -> Self {
        let signature = Signature::exact(vec![DataType::Int64], Volatility::Immutable);
        Self { signature }
    }
}

/// Synchronous implementation of `custom_udf`; evaluation is delegated to
/// `fn_impl`.
impl ScalarUDFImpl for CustomUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the function is invoked in SQL (`custom_udf(...)`).
    fn name(&self) -> &str {
        "custom_udf"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always reports `Utf8`, regardless of the argument types passed in.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Utf8)
    }

    /// Delegates to `fn_impl`, the body shared with the async implementation.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        fn_impl(args)
    }
}

/// Async implementation of `custom_udf`.
///
/// It calls the same `fn_impl` as the sync path and never awaits anything, so
/// the sync and async registrations are expected to produce identical results.
#[async_trait]
impl AsyncScalarUDFImpl for CustomUDF {
    async fn invoke_async_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> Result<ColumnarValue> {
        // Same body as `ScalarUDFImpl::invoke_with_args`.
        fn_impl(args)
    }
}

/// Shared evaluation body for the sync and async entry points.
///
/// Materializes the first argument as an array of `args.number_rows` rows and
/// casts it to `Utf8` using the default cast options.
fn fn_impl(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    let input = args.args[0].to_array(args.number_rows)?;
    let as_utf8 = cast_with_options(&input, &DataType::Utf8, &Default::default())?;
    Ok(ColumnarValue::Array(as_utf8))
}

/// Runs the same aggregate query twice: first with `custom_udf` registered as
/// a sync UDF, then re-registered as an async UDF. Each round prints the
/// EXPLAIN output (with schemas) followed by the query result.
#[tokio::main]
async fn main() -> Result<()> {
    // The query under test; identical for the sync and async rounds.
    const QUERY: &str = "select approx_distinct(custom_udf(x)) from data";

    // Show schemas in EXPLAIN output so the plans of both rounds can be compared.
    let config = SessionConfig::new().set_bool("datafusion.explain.show_schema", true);
    let ctx = SessionContext::new_with_config(config);

    let create_table = ctx
        .sql("create table data(x int) as values (-10), (2)")
        .await?;
    create_table.collect().await?;

    // sync works fine
    let sync_udf = ScalarUDF::new_from_impl(CustomUDF::new());
    ctx.register_udf(sync_udf);
    let sync_plan = ctx.sql(QUERY).await?.explain(false, false)?;
    sync_plan.show().await?;
    let sync_result = ctx.sql(QUERY).await?;
    sync_result.show().await?;

    // issue with async
    let async_udf = AsyncScalarUDF::new(Arc::new(CustomUDF::new())).into_scalar_udf();
    ctx.register_udf(async_udf);
    let async_plan = ctx.sql(QUERY).await?.explain(false, false)?;
    async_plan.show().await?;
    let async_result = ctx.sql(QUERY).await?;
    async_result.show().await?;

    Ok(())
}
  • Here we test a sync version and the async version; we expect the same result for both

Output:

datafusion (main)$ cargo run --example dataframe
    Finished `dev` profile [unoptimized + debuginfo] target(s) in 0.43s
     Running `/Users/jeffrey/.cargo_target_cache/debug/examples/dataframe`
+---------------+---------------------------------------------------------------------------------------------------------------------------------------+
| plan_type     | plan                                                                                                                                  |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------+
| logical_plan  | Aggregate: groupBy=[[]], aggr=[[approx_distinct(custom_udf(CAST(data.x AS Int64)))]]                                                  |
|               |   TableScan: data projection=[x]                                                                                                      |
| physical_plan | AggregateExec: mode=Single, gby=[], aggr=[approx_distinct(custom_udf(data.x))], schema=[approx_distinct(custom_udf(data.x)):UInt64;N] |
|               |   DataSourceExec: partitions=1, partition_sizes=[1], schema=[x:Int32;N]                                                               |
|               |                                                                                                                                       |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------+
+-------------------------------------+
| approx_distinct(custom_udf(data.x)) |
+-------------------------------------+
| 2                                   |
+-------------------------------------+
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------+
| plan_type     | plan                                                                                                                                                    |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------+
| logical_plan  | Aggregate: groupBy=[[]], aggr=[[approx_distinct(custom_udf(CAST(data.x AS Int64)))]]                                                                    |
|               |   TableScan: data projection=[x]                                                                                                                        |
| physical_plan | AggregateExec: mode=Final, gby=[], aggr=[approx_distinct(custom_udf(data.x))], schema=[approx_distinct(custom_udf(data.x)):UInt64;N]                    |
|               |   CoalescePartitionsExec, schema=[approx_distinct(custom_udf(data.x))[hll_registers]:Binary]                                                            |
|               |     AggregateExec: mode=Partial, gby=[], aggr=[approx_distinct(custom_udf(data.x))], schema=[approx_distinct(custom_udf(data.x))[hll_registers]:Binary] |
|               |       RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1, schema=[x:Int32;N, __async_fn_0:Utf8;N]                                    |
|               |         AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=custom_udf(CAST(x@0 AS Int64)))], schema=[x:Int32;N, __async_fn_0:Utf8;N]         |
|               |           CoalesceBatchesExec: target_batch_size=8192, schema=[x:Int32;N]                                                                               |
|               |             DataSourceExec: partitions=1, partition_sizes=[1], schema=[x:Int32;N]                                                                       |
|               |                                                                                                                                                         |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------+
Error: Internal("PhysicalExpr Column references column '__async_fn_0' at index 1 (zero-based) but input schema only has 1 columns: [\"x\"]")

The error comes from here:

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
let accumulator: Box<dyn Accumulator> = match data_type {

  • line 364

Expected behavior

Should show same result as sync version instead of error.

Additional context

I want to emphasize that this is not an issue with UDAFs themselves (approx_distinct is just the one that surfaces it here). The root cause is that, during physical planning, an aggregate whose input expressions contain an async UDF is rewritten with new expressions:

for agg_func in &mut aggregates {
match self.try_plan_async_exprs(
num_input_columns,
PlannedExprResult::Expr(agg_func.expressions()),
physical_input_schema.as_ref(),
)? {
PlanAsyncExpr::Async(
async_map,
PlannedExprResult::Expr(physical_exprs),
) => {
async_exprs.extend(async_map.async_exprs);
if let Some(new_agg_func) = agg_func.with_new_expressions(
physical_exprs,
agg_func
.order_bys()
.iter()
.cloned()
.map(|x| x.expr)
.collect(),
) {
*agg_func = Arc::new(new_agg_func);
} else {
return internal_err!("Failed to plan async expression");
}
}
PlanAsyncExpr::Sync(PlannedExprResult::Expr(_)) => {
// Do nothing
}
_ => {
return internal_err!(
"Unexpected result from try_plan_async_exprs"
)
}
}
}
let input_exec = if !async_exprs.is_empty() {
Arc::new(AsyncFuncExec::try_new(async_exprs, input_exec)?)
} else {
input_exec
};

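agg_func has new expressions written in to replace the async function call: the call becomes a column reference (here `__async_fn_0`) into the output of AsyncFuncExec. However, the schema that agg_func holds is not updated to include that column; therefore later in the pipeline (see the approx_distinct accumulator above) there is a mismatch between the schema and the input expressions, which is exactly what the error message reports.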

Overall I question this API:

/// Rewrites [`AggregateFunctionExpr`], with new expressions given. The argument should be consistent
/// with the return value of the [`AggregateFunctionExpr::all_expressions`] method.
/// Returns `Some(Arc<dyn AggregateExpr>)` if re-write is supported, otherwise returns `None`.
pub fn with_new_expressions(
&self,
args: Vec<Arc<dyn PhysicalExpr>>,
order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
) -> Option<AggregateFunctionExpr> {

It allows rewriting the expressions but it does not enforce or check that the new expressions conform with the schema it already has.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions