Skip to content

Commit dfd4442

Browse files
authored
Make FirstValue a UDAF, Change AggregateUDFImpl::accumulator signature, support ORDER BY for UDAFs (#9874)
* first draft Signed-off-by: jayzhan211 <[email protected]> * clippy fix Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * use one vector for ordering req Signed-off-by: jayzhan211 <[email protected]> * add sort exprs to accumulator Signed-off-by: jayzhan211 <[email protected]> * clippy Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * fix doc test Signed-off-by: jayzhan211 <[email protected]> * change to ref Signed-off-by: jayzhan211 <[email protected]> * fix typo Signed-off-by: jayzhan211 <[email protected]> * fix doc Signed-off-by: jayzhan211 <[email protected]> * fmt Signed-off-by: jayzhan211 <[email protected]> * move schema and logical ordering exprs Signed-off-by: jayzhan211 <[email protected]> * remove redundant info Signed-off-by: jayzhan211 <[email protected]> * rename Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * add ignore nulls Signed-off-by: jayzhan211 <[email protected]> * fix conflict Signed-off-by: jayzhan211 <[email protected]> * backup Signed-off-by: jayzhan211 <[email protected]> * complete return_type Signed-off-by: jayzhan211 <[email protected]> * complete replace Signed-off-by: jayzhan211 <[email protected]> * split to first value udf Signed-off-by: jayzhan211 <[email protected]> * replace accumulator Signed-off-by: jayzhan211 <[email protected]> * fmt Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * small fix Signed-off-by: jayzhan211 <[email protected]> * remove ordering types Signed-off-by: jayzhan211 <[email protected]> * make state fields more flexible Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * replace done Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> 
* rm comments Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * rm test1 Signed-off-by: jayzhan211 <[email protected]> * fix state fields Signed-off-by: jayzhan211 <[email protected]> * fmt Signed-off-by: jayzhan211 <[email protected]> * args struct for accumulator Signed-off-by: jayzhan211 <[email protected]> * simplify Signed-off-by: jayzhan211 <[email protected]> * add sig Signed-off-by: jayzhan211 <[email protected]> * add comments Signed-off-by: jayzhan211 <[email protected]> * fmt Signed-off-by: jayzhan211 <[email protected]> * fix docs Signed-off-by: jayzhan211 <[email protected]> * use exprs utils Signed-off-by: jayzhan211 <[email protected]> * rm state type Signed-off-by: jayzhan211 <[email protected]> * add comment Signed-off-by: jayzhan211 <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]>
1 parent d2ba901 commit dfd4442

File tree

24 files changed

+450
-134
lines changed

24 files changed

+450
-134
lines changed

datafusion-examples/examples/advanced_udaf.rs

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use arrow_schema::{Field, Schema};
1819
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
1920
use datafusion_physical_expr::NullState;
2021
use std::{any::Any, sync::Arc};
@@ -30,7 +31,8 @@ use datafusion::error::Result;
3031
use datafusion::prelude::*;
3132
use datafusion_common::{cast::as_float64_array, ScalarValue};
3233
use datafusion_expr::{
33-
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
34+
function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl,
35+
GroupsAccumulator, Signature,
3436
};
3537

3638
/// This example shows how to use the full AggregateUDFImpl API to implement a user
@@ -85,13 +87,21 @@ impl AggregateUDFImpl for GeoMeanUdaf {
8587
/// is supported, DataFusion will use this row oriented
8688
/// accumulator when the aggregate function is used as a window function
8789
/// or when there are only aggregates (no GROUP BY columns) in the plan.
88-
fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
90+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
8991
Ok(Box::new(GeometricMean::new()))
9092
}
9193

9294
/// This is the description of the state. accumulator's state() must match the types here.
93-
fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
94-
Ok(vec![DataType::Float64, DataType::UInt32])
95+
fn state_fields(
96+
&self,
97+
_name: &str,
98+
value_type: DataType,
99+
_ordering_fields: Vec<arrow_schema::Field>,
100+
) -> Result<Vec<arrow_schema::Field>> {
101+
Ok(vec![
102+
Field::new("prod", value_type, true),
103+
Field::new("n", DataType::UInt32, true),
104+
])
95105
}
96106

97107
/// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator`
@@ -191,7 +201,6 @@ impl Accumulator for GeometricMean {
191201

192202
// create local session context with an in-memory table
193203
fn create_context() -> Result<SessionContext> {
194-
use datafusion::arrow::datatypes::{Field, Schema};
195204
use datafusion::datasource::MemTable;
196205
// define a schema.
197206
let schema = Arc::new(Schema::new(vec![

datafusion/core/src/execution/context/mod.rs

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,11 +69,14 @@ use datafusion_common::{
6969
OwnedTableReference, SchemaReference,
7070
};
7171
use datafusion_execution::registry::SerializerRegistry;
72+
use datafusion_expr::type_coercion::aggregates::NUMERICS;
73+
use datafusion_expr::{create_first_value, Signature, Volatility};
7274
use datafusion_expr::{
7375
logical_plan::{DdlStatement, Statement},
7476
var_provider::is_system_variables,
7577
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
7678
};
79+
use datafusion_physical_expr::create_first_value_accumulator;
7780
use datafusion_sql::{
7881
parser::{CopyToSource, CopyToStatement, DFParser},
7982
planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel},
@@ -82,6 +85,7 @@ use datafusion_sql::{
8285

8386
use async_trait::async_trait;
8487
use chrono::{DateTime, Utc};
88+
use log::debug;
8589
use parking_lot::RwLock;
8690
use sqlparser::dialect::dialect_from_str;
8791
use url::Url;
@@ -1451,6 +1455,22 @@ impl SessionState {
14511455
datafusion_functions_array::register_all(&mut new_self)
14521456
.expect("can not register array expressions");
14531457

1458+
let first_value = create_first_value(
1459+
"FIRST_VALUE",
1460+
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
1461+
Arc::new(create_first_value_accumulator),
1462+
);
1463+
1464+
match new_self.register_udaf(Arc::new(first_value)) {
1465+
Ok(Some(existing_udaf)) => {
1466+
debug!("Overwrite existing UDAF: {}", existing_udaf.name());
1467+
}
1468+
Ok(None) => {}
1469+
Err(err) => {
1470+
panic!("Failed to register UDAF: {}", err);
1471+
}
1472+
}
1473+
14541474
new_self
14551475
}
14561476
/// Returns new [`SessionState`] using the provided

datafusion/core/src/physical_planner.rs

Lines changed: 31 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -247,24 +247,20 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
247247
distinct,
248248
args,
249249
filter,
250-
order_by,
250+
order_by: _,
251251
null_treatment: _,
252252
}) => match func_def {
253253
AggregateFunctionDefinition::BuiltIn(..) => {
254254
create_function_physical_name(func_def.name(), *distinct, args)
255255
}
256256
AggregateFunctionDefinition::UDF(fun) => {
257-
// TODO: Add support for filter and order by in AggregateUDF
257+
// TODO: Add support for filter in AggregateUDF
258258
if filter.is_some() {
259259
return exec_err!(
260260
"aggregate expression with filter is not supported"
261261
);
262262
}
263-
if order_by.is_some() {
264-
return exec_err!(
265-
"aggregate expression with order_by is not supported"
266-
);
267-
}
263+
268264
let names = args
269265
.iter()
270266
.map(|e| create_physical_name(e, false))
@@ -1667,20 +1663,22 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
16671663
)?),
16681664
None => None,
16691665
};
1670-
let order_by = match order_by {
1671-
Some(e) => Some(create_physical_sort_exprs(
1672-
e,
1673-
logical_input_schema,
1674-
execution_props,
1675-
)?),
1676-
None => None,
1677-
};
1666+
16781667
let ignore_nulls = null_treatment
16791668
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
16801669
== NullTreatment::IgnoreNulls;
16811670
let (agg_expr, filter, order_by) = match func_def {
16821671
AggregateFunctionDefinition::BuiltIn(fun) => {
1683-
let ordering_reqs = order_by.clone().unwrap_or(vec![]);
1672+
let physical_sort_exprs = match order_by {
1673+
Some(exprs) => Some(create_physical_sort_exprs(
1674+
exprs,
1675+
logical_input_schema,
1676+
execution_props,
1677+
)?),
1678+
None => None,
1679+
};
1680+
let ordering_reqs: Vec<PhysicalSortExpr> =
1681+
physical_sort_exprs.clone().unwrap_or(vec![]);
16841682
let agg_expr = aggregates::create_aggregate_expr(
16851683
fun,
16861684
*distinct,
@@ -1690,16 +1688,30 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
16901688
name,
16911689
ignore_nulls,
16921690
)?;
1693-
(agg_expr, filter, order_by)
1691+
(agg_expr, filter, physical_sort_exprs)
16941692
}
16951693
AggregateFunctionDefinition::UDF(fun) => {
1694+
let sort_exprs = order_by.clone().unwrap_or(vec![]);
1695+
let physical_sort_exprs = match order_by {
1696+
Some(exprs) => Some(create_physical_sort_exprs(
1697+
exprs,
1698+
logical_input_schema,
1699+
execution_props,
1700+
)?),
1701+
None => None,
1702+
};
1703+
let ordering_reqs: Vec<PhysicalSortExpr> =
1704+
physical_sort_exprs.clone().unwrap_or(vec![]);
16961705
let agg_expr = udaf::create_aggregate_expr(
16971706
fun,
16981707
&args,
1708+
&sort_exprs,
1709+
&ordering_reqs,
16991710
physical_input_schema,
17001711
name,
1701-
);
1702-
(agg_expr?, filter, order_by)
1712+
ignore_nulls,
1713+
)?;
1714+
(agg_expr, filter, physical_sort_exprs)
17031715
}
17041716
AggregateFunctionDefinition::Name(_) => {
17051717
return internal_err!(

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,8 @@ use datafusion::{
4545
};
4646
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
4747
use datafusion_expr::{
48-
create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF,
48+
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
49+
SimpleAggregateUDF,
4950
};
5051
use datafusion_physical_expr::expressions::AvgAccumulator;
5152

@@ -491,7 +492,7 @@ impl TimeSum {
491492
// Returns the same type as its input
492493
let return_type = timestamp_type.clone();
493494

494-
let state_type = vec![timestamp_type.clone()];
495+
let state_fields = vec![Field::new("sum", timestamp_type, true)];
495496

496497
let volatility = Volatility::Immutable;
497498

@@ -505,7 +506,7 @@ impl TimeSum {
505506
return_type,
506507
volatility,
507508
accumulator,
508-
state_type,
509+
state_fields,
509510
));
510511

511512
// register the selector as "time_sum"
@@ -591,6 +592,11 @@ impl FirstSelector {
591592
fn register(ctx: &mut SessionContext) {
592593
let return_type = Self::output_datatype();
593594
let state_type = Self::state_datatypes();
595+
let state_fields = state_type
596+
.into_iter()
597+
.enumerate()
598+
.map(|(i, t)| Field::new(format!("{i}"), t, true))
599+
.collect::<Vec<_>>();
594600

595601
// Possible input signatures
596602
let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];
@@ -607,7 +613,7 @@ impl FirstSelector {
607613
Signature::one_of(signatures, volatility),
608614
return_type,
609615
accumulator,
610-
state_type,
616+
state_fields,
611617
));
612618

613619
// register the selector as "first"
@@ -717,15 +723,11 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
717723
Ok(DataType::UInt64)
718724
}
719725

720-
fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
726+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
721727
// should use groups accumulator
722728
panic!("accumulator shouldn't invoke");
723729
}
724730

725-
fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
726-
Ok(vec![DataType::UInt64])
727-
}
728-
729731
fn groups_accumulator_supported(&self) -> bool {
730732
true
731733
}

datafusion/expr/src/expr.rs

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -577,14 +577,15 @@ impl AggregateFunction {
577577
distinct: bool,
578578
filter: Option<Box<Expr>>,
579579
order_by: Option<Vec<Expr>>,
580+
null_treatment: Option<NullTreatment>,
580581
) -> Self {
581582
Self {
582583
func_def: AggregateFunctionDefinition::UDF(udf),
583584
args,
584585
distinct,
585586
filter,
586587
order_by,
587-
null_treatment: None,
588+
null_treatment,
588589
}
589590
}
590591
}

0 commit comments

Comments
 (0)