-
Notifications
You must be signed in to change notification settings - Fork 1.8k
support sum/avg agg for decimal, change sum(float32) --> float64 #1408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
2095685
e3b005f
3c19c41
ff61c04
9675e30
f4b5655
2d247ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,9 @@ use std::sync::Arc; | |
|
|
||
| use crate::error::{DataFusionError, Result}; | ||
| use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; | ||
| use crate::scalar::ScalarValue; | ||
| use crate::scalar::{ | ||
| ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128, | ||
| }; | ||
| use arrow::compute; | ||
| use arrow::datatypes::DataType; | ||
| use arrow::{ | ||
|
|
@@ -38,11 +40,19 @@ use super::{format_state_name, sum}; | |
| pub struct Avg { | ||
| name: String, | ||
| expr: Arc<dyn PhysicalExpr>, | ||
| data_type: DataType, | ||
| } | ||
|
|
||
| /// function return type of an average | ||
| pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> { | ||
| match arg_type { | ||
| DataType::Decimal(precision, scale) => { | ||
| // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
| // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 | ||
| let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4); | ||
| let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4); | ||
| Ok(DataType::Decimal(new_precision, new_scale)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could potentially use |
||
| } | ||
| DataType::Int8 | ||
| | DataType::Int16 | ||
| | DataType::Int32 | ||
|
|
@@ -73,6 +83,7 @@ pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool { | |
| | DataType::Int64 | ||
| | DataType::Float32 | ||
| | DataType::Float64 | ||
| | DataType::Decimal(_, _) | ||
| ) | ||
| } | ||
|
|
||
|
|
@@ -83,14 +94,15 @@ impl Avg { | |
| name: impl Into<String>, | ||
| data_type: DataType, | ||
| ) -> Self { | ||
| // Average is always Float64, but Avg::new() has a data_type | ||
| // parameter to keep a consistent signature with the other | ||
| // Aggregate expressions. | ||
| assert_eq!(data_type, DataType::Float64); | ||
|
|
||
| // the result of avg just support FLOAT64 and Decimal data type. | ||
| assert!(matches!( | ||
| data_type, | ||
| DataType::Float64 | DataType::Decimal(_, _) | ||
| )); | ||
| Self { | ||
| name: name.into(), | ||
| expr, | ||
| data_type, | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -102,7 +114,14 @@ impl AggregateExpr for Avg { | |
| } | ||
|
|
||
| fn field(&self) -> Result<Field> { | ||
| Ok(Field::new(&self.name, DataType::Float64, true)) | ||
| Ok(Field::new(&self.name, self.data_type.clone(), true)) | ||
| } | ||
|
|
||
| fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { | ||
| Ok(Box::new(AvgAccumulator::try_new( | ||
| // avg is f64 or decimal | ||
| &self.data_type, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if a sum of I think handling overflow is probably fine to for a later date / PR, but it is strange to me that there is a discrepancy between the type for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The result type of phy expr(sum/avg) is same with each Accumulator, and it was decided by If the column is decimal(8,2), the avg of this column must be less than For the sum agg, we just should increase the precision part, and the rule of adding
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add issue to track the overflow. |
||
| )?)) | ||
| } | ||
|
|
||
| fn state_fields(&self) -> Result<Vec<Field>> { | ||
|
|
@@ -114,19 +133,12 @@ impl AggregateExpr for Avg { | |
| ), | ||
| Field::new( | ||
| &format_state_name(&self.name, "sum"), | ||
| DataType::Float64, | ||
| self.data_type.clone(), | ||
| true, | ||
| ), | ||
| ]) | ||
| } | ||
|
|
||
| fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { | ||
| Ok(Box::new(AvgAccumulator::try_new( | ||
| // avg is f64 | ||
| &DataType::Float64, | ||
| )?)) | ||
| } | ||
|
|
||
| fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { | ||
| vec![self.expr.clone()] | ||
| } | ||
|
|
@@ -205,6 +217,17 @@ impl Accumulator for AvgAccumulator { | |
| ScalarValue::Float64(e) => { | ||
| Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) | ||
| } | ||
| ScalarValue::Decimal128(value, precision, scale) => { | ||
| Ok(match value { | ||
| None => ScalarValue::Decimal128(None, precision, scale), | ||
| // TODO add the checker for overflow the precision | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
| Some(v) => ScalarValue::Decimal128( | ||
| Some(v / self.count as i128), | ||
| precision, | ||
| scale, | ||
| ), | ||
| }) | ||
| } | ||
| _ => Err(DataFusionError::Internal( | ||
| "Sum should be f64 on average".to_string(), | ||
| )), | ||
|
|
@@ -220,6 +243,73 @@ mod tests { | |
| use arrow::record_batch::RecordBatch; | ||
| use arrow::{array::*, datatypes::*}; | ||
|
|
||
| #[test] | ||
| fn test_avg_return_data_type() -> Result<()> { | ||
| let data_type = DataType::Decimal(10, 5); | ||
| let result_type = avg_return_type(&data_type)?; | ||
| assert_eq!(DataType::Decimal(14, 9), result_type); | ||
|
|
||
| let data_type = DataType::Decimal(36, 10); | ||
| let result_type = avg_return_type(&data_type)?; | ||
| assert_eq!(DataType::Decimal(38, 14), result_type); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn avg_decimal() -> Result<()> { | ||
| // test agg | ||
| let mut decimal_builder = DecimalBuilder::new(6, 10, 0); | ||
| for i in 1..7 { | ||
| decimal_builder.append_value(i as i128)?; | ||
| } | ||
| let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
|
|
||
| generic_test_op!( | ||
| array, | ||
| DataType::Decimal(10, 0), | ||
| Avg, | ||
| ScalarValue::Decimal128(Some(35000), 14, 4), | ||
| DataType::Decimal(14, 4) | ||
| ) | ||
| } | ||
|
|
||
| #[test] | ||
| fn avg_decimal_with_nulls() -> Result<()> { | ||
| let mut decimal_builder = DecimalBuilder::new(5, 10, 0); | ||
| for i in 1..6 { | ||
| if i == 2 { | ||
| decimal_builder.append_null()?; | ||
| } else { | ||
| decimal_builder.append_value(i)?; | ||
| } | ||
| } | ||
| let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
| generic_test_op!( | ||
| array, | ||
| DataType::Decimal(10, 0), | ||
| Avg, | ||
| ScalarValue::Decimal128(Some(32500), 14, 4), | ||
| DataType::Decimal(14, 4) | ||
| ) | ||
| } | ||
|
|
||
| #[test] | ||
| fn avg_decimal_all_nulls() -> Result<()> { | ||
| // test agg | ||
| let mut decimal_builder = DecimalBuilder::new(5, 10, 0); | ||
| for _i in 1..6 { | ||
| decimal_builder.append_null()?; | ||
| } | ||
| let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
| generic_test_op!( | ||
| array, | ||
| DataType::Decimal(10, 0), | ||
| Avg, | ||
| ScalarValue::Decimal128(None, 14, 4), | ||
| DataType::Decimal(14, 4) | ||
| ) | ||
| } | ||
|
|
||
| #[test] | ||
| fn avg_i32() -> Result<()> { | ||
| let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍