-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Description
Describe the bug
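When an async UDF is used as the input to certain UDAFs, physical planning produces an aggregate whose input expressions do not match the schema it carries. Query execution then fails with an internal schema error.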
To Reproduce
use arrow::compute::cast_with_options;
use datafusion::arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion::prelude::*;
use std::any::Any;
use std::sync::Arc;
use tonic::async_trait;
/// Minimal scalar UDF used to reproduce the async-UDF-inside-aggregate planning issue.
///
/// It accepts exactly one `Int64` argument and produces a `Utf8` value (see `fn_impl`).
#[derive(Debug, PartialEq, Eq, Hash)]
struct CustomUDF {
    /// Declared signature: a single `Int64` argument, `Volatility::Immutable`.
    signature: Signature,
}

impl CustomUDF {
    /// Creates the UDF with its exact single-`Int64` signature.
    fn new() -> Self {
        let accepted_types = vec![DataType::Int64];
        Self {
            signature: Signature::exact(accepted_types, Volatility::Immutable),
        }
    }
}
// Synchronous UDF interface. Evaluation is delegated to `fn_impl`, which is shared with the
// async implementation so both registrations compute the same values.
impl ScalarUDFImpl for CustomUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the function is callable from SQL.
    fn name(&self) -> &str {
        "custom_udf"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The output type is always `Utf8`, whatever the argument types are.
    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
        let output_type = DataType::Utf8;
        Ok(output_type)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        fn_impl(args)
    }
}
// Async UDF interface. It reuses the synchronous evaluation path (`fn_impl`), so any
// difference between the sync and async runs comes from planning, not from the function body.
#[async_trait]
impl AsyncScalarUDFImpl for CustomUDF {
    async fn invoke_async_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        fn_impl(args)
    }
}
fn fn_impl(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let arg = &args.args[0];
let array = arg.to_array(args.number_rows)?;
let array = cast_with_options(&array, &DataType::Utf8, &Default::default())?;
Ok(ColumnarValue::Array(array))
}
#[tokio::main]
async fn main() -> Result<()> {
let ctx = SessionContext::new_with_config(
SessionConfig::new().set_bool("datafusion.explain.show_schema", true),
);
ctx.sql("create table data(x int) as values (-10), (2)")
.await?
.collect()
.await?;
// sync works fine
ctx.register_udf(ScalarUDF::new_from_impl(CustomUDF::new()));
ctx.sql("select approx_distinct(custom_udf(x)) from data")
.await?
.explain(false, false)?
.show()
.await?;
ctx.sql("select approx_distinct(custom_udf(x)) from data")
.await?
.show()
.await?;
// issue with async
ctx.register_udf(AsyncScalarUDF::new(Arc::new(CustomUDF::new())).into_scalar_udf());
ctx.sql("select approx_distinct(custom_udf(x)) from data")
.await?
.explain(false, false)?
.show()
.await?;
ctx.sql("select approx_distinct(custom_udf(x)) from data")
.await?
.show()
.await?;
Ok(())
}- Here we test a sync version and the async version; we expect the same result for both
Output:
datafusion (main)$ cargo run --example dataframe
Finished `dev` profile [unoptimized + debuginfo] target(s) in 0.43s
Running `/Users/jeffrey/.cargo_target_cache/debug/examples/dataframe`
+---------------+---------------------------------------------------------------------------------------------------------------------------------------+
| plan_type | plan |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------+
| logical_plan | Aggregate: groupBy=[[]], aggr=[[approx_distinct(custom_udf(CAST(data.x AS Int64)))]] |
| | TableScan: data projection=[x] |
| physical_plan | AggregateExec: mode=Single, gby=[], aggr=[approx_distinct(custom_udf(data.x))], schema=[approx_distinct(custom_udf(data.x)):UInt64;N] |
| | DataSourceExec: partitions=1, partition_sizes=[1], schema=[x:Int32;N] |
| | |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------+
+-------------------------------------+
| approx_distinct(custom_udf(data.x)) |
+-------------------------------------+
| 2 |
+-------------------------------------+
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------+
| plan_type | plan |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------+
| logical_plan | Aggregate: groupBy=[[]], aggr=[[approx_distinct(custom_udf(CAST(data.x AS Int64)))]] |
| | TableScan: data projection=[x] |
| physical_plan | AggregateExec: mode=Final, gby=[], aggr=[approx_distinct(custom_udf(data.x))], schema=[approx_distinct(custom_udf(data.x)):UInt64;N] |
| | CoalescePartitionsExec, schema=[approx_distinct(custom_udf(data.x))[hll_registers]:Binary] |
| | AggregateExec: mode=Partial, gby=[], aggr=[approx_distinct(custom_udf(data.x))], schema=[approx_distinct(custom_udf(data.x))[hll_registers]:Binary] |
| | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1, schema=[x:Int32;N, __async_fn_0:Utf8;N] |
| | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=custom_udf(CAST(x@0 AS Int64)))], schema=[x:Int32;N, __async_fn_0:Utf8;N] |
| | CoalesceBatchesExec: target_batch_size=8192, schema=[x:Int32;N] |
| | DataSourceExec: partitions=1, partition_sizes=[1], schema=[x:Int32;N] |
| | |
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------+
Error: Internal("PhysicalExpr Column references column '__async_fn_0' at index 1 (zero-based) but input schema only has 1 columns: [\"x\"]")
The error comes from here:
datafusion/datafusion/functions-aggregate/src/approx_distinct.rs
Lines 363 to 366 in e323357
| fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { | |
| let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; | |
| let accumulator: Box<dyn Accumulator> = match data_type { |
- line 364
Expected behavior
The async version should produce the same result as the sync version instead of an error.
Additional context
I want to emphasize that this is not a bug in the UDAF itself (in this case approx_distinct). The cause is physical planning: when an aggregate's arguments contain async UDFs, the planner rewrites the aggregate with new expressions:
datafusion/datafusion/core/src/physical_planner.rs
Lines 745 to 785 in e323357
| for agg_func in &mut aggregates { | |
| match self.try_plan_async_exprs( | |
| num_input_columns, | |
| PlannedExprResult::Expr(agg_func.expressions()), | |
| physical_input_schema.as_ref(), | |
| )? { | |
| PlanAsyncExpr::Async( | |
| async_map, | |
| PlannedExprResult::Expr(physical_exprs), | |
| ) => { | |
| async_exprs.extend(async_map.async_exprs); | |
| if let Some(new_agg_func) = agg_func.with_new_expressions( | |
| physical_exprs, | |
| agg_func | |
| .order_bys() | |
| .iter() | |
| .cloned() | |
| .map(|x| x.expr) | |
| .collect(), | |
| ) { | |
| *agg_func = Arc::new(new_agg_func); | |
| } else { | |
| return internal_err!("Failed to plan async expression"); | |
| } | |
| } | |
| PlanAsyncExpr::Sync(PlannedExprResult::Expr(_)) => { | |
| // Do nothing | |
| } | |
| _ => { | |
| return internal_err!( | |
| "Unexpected result from try_plan_async_exprs" | |
| ) | |
| } | |
| } | |
| } | |
| let input_exec = if !async_exprs.is_empty() { | |
| Arc::new(AsyncFuncExec::try_new(async_exprs, input_exec)?) | |
| } else { | |
| input_exec | |
| }; |
- Introduced by "Support Async UDFs as input to aggregations" (#17619)
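`agg_func` is rebuilt with new expressions in which each async function call is replaced by a reference to the column produced by `AsyncFuncExec` (e.g. `__async_fn_0`). However, the schema stored in `agg_func` is not updated. As a result, later in the pipeline (see approx_distinct above) the input expressions reference columns the schema does not contain (here, index 1 of a 1-column schema `[x]`).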
Overall I question this API:
datafusion/datafusion/physical-expr/src/aggregate.rs
Lines 613 to 620 in e323357
| /// Rewrites [`AggregateFunctionExpr`], with new expressions given. The argument should be consistent | |
| /// with the return value of the [`AggregateFunctionExpr::all_expressions`] method. | |
| /// Returns `Some(Arc<dyn AggregateExpr>)` if re-write is supported, otherwise returns `None`. | |
| pub fn with_new_expressions( | |
| &self, | |
| args: Vec<Arc<dyn PhysicalExpr>>, | |
| order_by_exprs: Vec<Arc<dyn PhysicalExpr>>, | |
| ) -> Option<AggregateFunctionExpr> { |
It allows the expressions to be rewritten, but it neither enforces nor checks that the new expressions are consistent with the schema the aggregate already holds.