-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Add coercion rules for AggregateFunctions #1387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c01db7e
fc48c58
7bc7eec
42b2192
14d51da
17064f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,15 +28,16 @@ | |
|
|
||
| use super::{ | ||
| functions::{Signature, Volatility}, | ||
| type_coercion::{coerce, data_types}, | ||
| Accumulator, AggregateExpr, PhysicalExpr, | ||
| }; | ||
| use crate::error::{DataFusionError, Result}; | ||
| use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_types}; | ||
| use crate::physical_plan::distinct_expressions; | ||
| use crate::physical_plan::expressions; | ||
| use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; | ||
| use expressions::{avg_return_type, sum_return_type}; | ||
| use std::{fmt, str::FromStr, sync::Arc}; | ||
|
|
||
| /// the implementation of an aggregate function | ||
| pub type AccumulatorFunctionImplementation = | ||
| Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync>; | ||
|
|
@@ -87,35 +88,38 @@ impl FromStr for AggregateFunction { | |
| return Err(DataFusionError::Plan(format!( | ||
| "There is no built-in function named {}", | ||
| name | ||
| ))) | ||
| ))); | ||
| } | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| /// Returns the datatype of the aggregation function | ||
| /// Returns the datatype of the aggregate function. | ||
| /// This is used to get the returned data type for aggregate expr. | ||
| pub fn return_type( | ||
| fun: &AggregateFunction, | ||
| input_expr_types: &[DataType], | ||
| ) -> Result<DataType> { | ||
| // Note that this function *must* return the same type that the respective physical expression returns | ||
| // or the execution panics. | ||
|
|
||
| // verify that this is a valid set of data types for this function | ||
| data_types(input_expr_types, &signature(fun))?; | ||
| let coerced_data_types = coerce_types(fun, input_expr_types, &signature(fun))?; | ||
|
|
||
| match fun { | ||
| // TODO If the datafusion is compatible with PostgreSQL, the returned data type should be INT64. | ||
| AggregateFunction::Count | AggregateFunction::ApproxDistinct => { | ||
| Ok(DataType::UInt64) | ||
| } | ||
| AggregateFunction::Max | AggregateFunction::Min => { | ||
| Ok(input_expr_types[0].clone()) | ||
| // For min and max agg function, the returned type is same as input type. | ||
| // The coerced_data_types is same with input_types. | ||
| Ok(coerced_data_types[0].clone()) | ||
| } | ||
| AggregateFunction::Sum => sum_return_type(&input_expr_types[0]), | ||
| AggregateFunction::Avg => avg_return_type(&input_expr_types[0]), | ||
| AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), | ||
| AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), | ||
| AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new( | ||
| "item", | ||
| input_expr_types[0].clone(), | ||
| coerced_data_types[0].clone(), | ||
| true, | ||
| )))), | ||
| } | ||
|
|
@@ -131,26 +135,26 @@ pub fn create_aggregate_expr( | |
| name: impl Into<String>, | ||
| ) -> Result<Arc<dyn AggregateExpr>> { | ||
| let name = name.into(); | ||
| let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &signature(fun))?; | ||
| // get the coerced phy exprs if some expr need to be wrapped with the try cast. | ||
| let coerced_phy_exprs = | ||
| coerce_exprs(fun, input_phy_exprs, input_schema, &signature(fun))?; | ||
| if coerced_phy_exprs.is_empty() { | ||
| return Err(DataFusionError::Plan(format!( | ||
| "Invalid or wrong number of arguments passed to aggregate: '{}'", | ||
| name, | ||
| ))); | ||
| } | ||
|
|
||
| let coerced_exprs_types = coerced_phy_exprs | ||
| .iter() | ||
| .map(|e| e.data_type(input_schema)) | ||
| .collect::<Result<Vec<_>>>()?; | ||
|
|
||
| let input_exprs_types = input_phy_exprs | ||
| // get the result data type for this aggregate function | ||
| let input_phy_types = input_phy_exprs | ||
| .iter() | ||
| .map(|e| e.data_type(input_schema)) | ||
| .collect::<Result<Vec<_>>>()?; | ||
|
|
||
| // In order to get the result data type, we must use the original input data type to calculate the result type. | ||
| let return_type = return_type(fun, &input_exprs_types)?; | ||
| let return_type = return_type(fun, &input_phy_types)?; | ||
|
|
||
| Ok(match (fun, distinct) { | ||
| (AggregateFunction::Count, false) => Arc::new(expressions::Count::new( | ||
|
|
@@ -161,7 +165,7 @@ pub fn create_aggregate_expr( | |
| (AggregateFunction::Count, true) => { | ||
| Arc::new(distinct_expressions::DistinctCount::new( | ||
| coerced_exprs_types, | ||
| coerced_phy_exprs.to_vec(), | ||
| coerced_phy_exprs, | ||
| name, | ||
| return_type, | ||
| )) | ||
|
|
@@ -262,6 +266,199 @@ pub fn signature(fun: &AggregateFunction) -> Signature { | |
| mod tests { | ||
| use super::*; | ||
| use crate::error::Result; | ||
| use crate::physical_plan::expressions::{ | ||
| ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum, | ||
| }; | ||
|
|
||
| /// Checks `create_aggregate_expr` for `Count`, `ArrayAgg` and `ApproxDistinct` | ||
| /// (non-distinct) over a single column `c1`, once per listed input data type. | ||
| /// For every combination it asserts the concrete expression type, the | ||
| /// expression name, and the output field reported by `field()`. | ||
| #[test] | ||
| fn test_count_arragg_approx_expr() -> Result<()> { | ||
| let funcs = vec![ | ||
| AggregateFunction::Count, | ||
| AggregateFunction::ArrayAgg, | ||
| AggregateFunction::ApproxDistinct, | ||
| ]; | ||
| let data_types = vec![ | ||
| DataType::UInt32, | ||
| DataType::Int32, | ||
| DataType::Float32, | ||
| DataType::Float64, | ||
| DataType::Decimal(10, 2), | ||
| DataType::Utf8, | ||
| ]; | ||
| // Exercise every (function, input type) combination. | ||
| for fun in funcs { | ||
| for data_type in &data_types { | ||
| // Schema with one nullable column `c1` of the type under test. | ||
| let input_schema = | ||
| Schema::new(vec![Field::new("c1", data_type.clone(), true)]); | ||
| let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new( | ||
| expressions::Column::new_with_schema("c1", &input_schema).unwrap(), | ||
| )]; | ||
| // `false` selects the non-distinct variant; the aggregate is named "c1". | ||
| let result_agg_phy_exprs = create_aggregate_expr( | ||
| &fun, | ||
| false, | ||
| &input_phy_exprs[0..1], | ||
| &input_schema, | ||
| "c1", | ||
| )?; | ||
| match fun { | ||
| // Count: expected output field is a nullable UInt64. | ||
| AggregateFunction::Count => { | ||
| assert!(result_agg_phy_exprs.as_any().is::<Count>()); | ||
| assert_eq!("c1", result_agg_phy_exprs.name()); | ||
| assert_eq!( | ||
| Field::new("c1", DataType::UInt64, true), | ||
| result_agg_phy_exprs.field().unwrap() | ||
| ); | ||
| } | ||
| // ApproxDistinct: expected output field is a non-nullable UInt64. | ||
| AggregateFunction::ApproxDistinct => { | ||
| assert!(result_agg_phy_exprs.as_any().is::<ApproxDistinct>()); | ||
| assert_eq!("c1", result_agg_phy_exprs.name()); | ||
| assert_eq!( | ||
| Field::new("c1", DataType::UInt64, false), | ||
| result_agg_phy_exprs.field().unwrap() | ||
| ); | ||
| } | ||
| // ArrayAgg: expected output is a non-nullable List whose nullable | ||
| // `item` field has the same type as the input column. | ||
| AggregateFunction::ArrayAgg => { | ||
| assert!(result_agg_phy_exprs.as_any().is::<ArrayAgg>()); | ||
| assert_eq!("c1", result_agg_phy_exprs.name()); | ||
| assert_eq!( | ||
| Field::new( | ||
| "c1", | ||
| DataType::List(Box::new(Field::new( | ||
| "item", | ||
| data_type.clone(), | ||
| true | ||
| ))), | ||
| false | ||
| ), | ||
| result_agg_phy_exprs.field().unwrap() | ||
| ); | ||
| } | ||
| // No other aggregate function is listed in `funcs`; nothing to check. | ||
| _ => {} | ||
| }; | ||
| } | ||
| } | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// Checks `create_aggregate_expr` for `Min` and `Max` (non-distinct) over a | ||
| /// single column `c1`, once per listed input data type. The expected output | ||
| /// field keeps the input column's data type and is nullable. | ||
| #[test] | ||
| fn test_min_max_expr() -> Result<()> { | ||
| let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; | ||
| let data_types = vec![ | ||
| DataType::UInt32, | ||
| DataType::Int32, | ||
| DataType::Float32, | ||
| DataType::Float64, | ||
| DataType::Decimal(10, 2), | ||
| DataType::Utf8, | ||
| ]; | ||
| // Exercise every (function, input type) combination. | ||
| for fun in funcs { | ||
| for data_type in &data_types { | ||
| // Schema with one nullable column `c1` of the type under test. | ||
| let input_schema = | ||
| Schema::new(vec![Field::new("c1", data_type.clone(), true)]); | ||
| let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new( | ||
| expressions::Column::new_with_schema("c1", &input_schema).unwrap(), | ||
| )]; | ||
| // `false` selects the non-distinct variant; the aggregate is named "c1". | ||
| let result_agg_phy_exprs = create_aggregate_expr( | ||
| &fun, | ||
| false, | ||
| &input_phy_exprs[0..1], | ||
| &input_schema, | ||
| "c1", | ||
| )?; | ||
| match fun { | ||
| // Min: expected output field has the input type and is nullable. | ||
| AggregateFunction::Min => { | ||
| assert!(result_agg_phy_exprs.as_any().is::<Min>()); | ||
| assert_eq!("c1", result_agg_phy_exprs.name()); | ||
| assert_eq!( | ||
| Field::new("c1", data_type.clone(), true), | ||
| result_agg_phy_exprs.field().unwrap() | ||
| ); | ||
| } | ||
| // Max: same expectation as Min. | ||
| AggregateFunction::Max => { | ||
| assert!(result_agg_phy_exprs.as_any().is::<Max>()); | ||
| assert_eq!("c1", result_agg_phy_exprs.name()); | ||
| assert_eq!( | ||
| Field::new("c1", data_type.clone(), true), | ||
| result_agg_phy_exprs.field().unwrap() | ||
| ); | ||
| } | ||
| // No other aggregate function is listed in `funcs`; nothing to check. | ||
| _ => {} | ||
| }; | ||
| } | ||
| } | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_sum_avg_expr() -> Result<()> { | ||
| let funcs = vec![AggregateFunction::Sum, AggregateFunction::Avg]; | ||
| let data_types = vec![ | ||
| DataType::UInt32, | ||
| DataType::UInt64, | ||
| DataType::Int32, | ||
| DataType::Int64, | ||
| DataType::Float32, | ||
| DataType::Float64, | ||
| ]; | ||
| for fun in funcs { | ||
| for data_type in &data_types { | ||
| let input_schema = | ||
| Schema::new(vec![Field::new("c1", data_type.clone(), true)]); | ||
| let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new( | ||
| expressions::Column::new_with_schema("c1", &input_schema).unwrap(), | ||
| )]; | ||
| let result_agg_phy_exprs = create_aggregate_expr( | ||
| &fun, | ||
| false, | ||
| &input_phy_exprs[0..1], | ||
| &input_schema, | ||
| "c1", | ||
| )?; | ||
| match fun { | ||
| AggregateFunction::Sum => { | ||
| assert!(result_agg_phy_exprs.as_any().is::<Sum>()); | ||
| assert_eq!("c1", result_agg_phy_exprs.name()); | ||
| let mut expect_type = data_type.clone(); | ||
| if matches!( | ||
| data_type, | ||
| DataType::UInt8 | ||
| | DataType::UInt16 | ||
| | DataType::UInt32 | ||
| | DataType::UInt64 | ||
| ) { | ||
| expect_type = DataType::UInt64; | ||
| } else if matches!( | ||
| data_type, | ||
| DataType::Int8 | ||
| | DataType::Int16 | ||
| | DataType::Int32 | ||
| | DataType::Int64 | ||
| ) { | ||
| expect_type = DataType::Int64; | ||
| } else if matches!( | ||
| data_type, | ||
| DataType::Float32 | DataType::Float64 | ||
| ) { | ||
| expect_type = data_type.clone(); | ||
| } | ||
|
Comment on lines
+420
to
+442
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Just FYI, you can write this kind of logic more concisely with something like the following (untested and abbreviated):
```rust
let expect_type = match data_type {
    DataType::UInt8 | .... => DataType::UInt64,
    DataType::Int8 | .... => DataType::Int64,
    _ => data_type.clone(),
};
```
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. #1416 <-- PR with the proposed cleanup
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good style and suggestion! |
||
| assert_eq!( | ||
| Field::new("c1", expect_type.clone(), true), | ||
| result_agg_phy_exprs.field().unwrap() | ||
| ); | ||
| } | ||
| AggregateFunction::Avg => { | ||
| assert!(result_agg_phy_exprs.as_any().is::<Avg>()); | ||
| assert_eq!("c1", result_agg_phy_exprs.name()); | ||
| assert_eq!( | ||
| Field::new("c1", DataType::Float64, true), | ||
| result_agg_phy_exprs.field().unwrap() | ||
| ); | ||
| } | ||
| _ => {} | ||
| }; | ||
| } | ||
| } | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_min_max() -> Result<()> { | ||
|
|
@@ -270,6 +467,16 @@ mod tests { | |
|
|
||
| let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?; | ||
| assert_eq!(DataType::Int32, observed); | ||
|
|
||
| // test decimal for min | ||
| let observed = return_type(&AggregateFunction::Min, &[DataType::Decimal(10, 6)])?; | ||
| assert_eq!(DataType::Decimal(10, 6), observed); | ||
|
|
||
| // test decimal for max | ||
| let observed = | ||
| return_type(&AggregateFunction::Max, &[DataType::Decimal(28, 13)])?; | ||
| assert_eq!(DataType::Decimal(28, 13), observed); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
|
|
@@ -293,6 +500,10 @@ mod tests { | |
|
|
||
| let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?; | ||
| assert_eq!(DataType::UInt64, observed); | ||
|
|
||
| let observed = | ||
| return_type(&AggregateFunction::Count, &[DataType::Decimal(28, 13)])?; | ||
| assert_eq!(DataType::UInt64, observed); | ||
| Ok(()) | ||
| } | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.