-
Notifications
You must be signed in to change notification settings - Fork 1.8k
support sum/avg agg for decimal, change sum(float32) --> float64 #1408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
2095685
e3b005f
3c19c41
ff61c04
9675e30
f4b5655
2d247ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,9 @@ use std::sync::Arc; | |
|
|
||
| use crate::error::{DataFusionError, Result}; | ||
| use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; | ||
| use crate::scalar::ScalarValue; | ||
| use crate::scalar::{ | ||
| ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128, | ||
| }; | ||
| use arrow::compute; | ||
| use arrow::datatypes::DataType; | ||
| use arrow::{ | ||
|
|
@@ -45,9 +47,10 @@ pub struct Avg { | |
| pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> { | ||
| match arg_type { | ||
| DataType::Decimal(precision, scale) => { | ||
| // the new precision and scale for return type of avg function | ||
| let new_precision = 38.min(*precision + 4); | ||
| let new_scale = 38.min(*scale + 4); | ||
| // In Spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). | ||
| // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 | ||
| let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4); | ||
| let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4); | ||
| Ok(DataType::Decimal(new_precision, new_scale)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could potentially use |
||
| } | ||
| DataType::Int8 | ||
|
|
@@ -91,18 +94,15 @@ impl Avg { | |
| name: impl Into<String>, | ||
| data_type: DataType, | ||
| ) -> Self { | ||
| // Average is always Float64, but Avg::new() has a data_type | ||
| // parameter to keep a consistent signature with the other | ||
| // Aggregate expressions. | ||
| match data_type { | ||
| DataType::Float64 | DataType::Decimal(_, _) => Self { | ||
| name: name.into(), | ||
| expr, | ||
| data_type, | ||
| }, | ||
| _ => { | ||
| unreachable!(); | ||
| } | ||
| // the result of avg only supports the Float64 and Decimal data types. | ||
| assert!(matches!( | ||
| data_type, | ||
| DataType::Float64 | DataType::Decimal(_, _) | ||
| )); | ||
| Self { | ||
| name: name.into(), | ||
| expr, | ||
| data_type, | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -260,7 +260,6 @@ mod tests { | |
| // test agg | ||
| let mut decimal_builder = DecimalBuilder::new(6, 10, 0); | ||
| for i in 1..7 { | ||
| // the avg is 3.5, but we get the result of 3 | ||
| decimal_builder.append_value(i as i128)?; | ||
| } | ||
| let array: ArrayRef = Arc::new(decimal_builder.finish()); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,7 @@ use std::sync::Arc; | |
|
|
||
| use crate::error::{DataFusionError, Result}; | ||
| use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; | ||
| use crate::scalar::ScalarValue; | ||
| use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128}; | ||
| use arrow::compute; | ||
| use arrow::datatypes::DataType; | ||
| use arrow::{ | ||
|
|
@@ -56,13 +56,13 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> { | |
| DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { | ||
| Ok(DataType::UInt64) | ||
| } | ||
| // In the https://www.postgresql.org/docs/8.2/functions-aggregate.html doc, | ||
| // In the https://www.postgresql.org/docs/current/functions-aggregate.html doc, | ||
| // the result type for floating-point input is FLOAT64 (double precision). | ||
| DataType::Float64 | DataType::Float32 => Ok(DataType::Float64), | ||
| // Max precision is 38 | ||
| DataType::Decimal(precision, scale) => { | ||
| // In Spark, the result type is DECIMAL(min(38,precision+10), s) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thank you for the context
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @alamb |
||
| let new_precision = 38.min(*precision + 10); | ||
| // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 | ||
| let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 10); | ||
| Ok(DataType::Decimal(new_precision, *scale)) | ||
| } | ||
| other => Err(DataFusionError::Plan(format!( | ||
|
|
@@ -234,8 +234,8 @@ fn sum_decimal( | |
| (None, None) => ScalarValue::Decimal128(None, *precision, *scale), | ||
| (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale), | ||
| (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale), | ||
| (lhs, rhs) => { | ||
| ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale) | ||
| (Some(lhs_value), Some(rhs_value)) => { | ||
| ScalarValue::Decimal128(Some(lhs_value + rhs_value), *precision, *scale) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -250,14 +250,14 @@ fn sum_decimal_with_diff_scale( | |
| // the lhs_scale must be greater than or equal to rhs_scale. | ||
| match (lhs, rhs) { | ||
| (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale), | ||
| (None, rhs) => { | ||
| let new_value = rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32); | ||
| (None, Some(rhs_value)) => { | ||
| let new_value = rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32); | ||
| ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale) | ||
| } | ||
| (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale), | ||
| (lhs, rhs) => { | ||
| (Some(lhs_value), Some(rhs_value)) => { | ||
| let new_value = | ||
| rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs.unwrap(); | ||
| rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs_value; | ||
| ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale) | ||
| } | ||
| } | ||
|
|
@@ -266,17 +266,16 @@ fn sum_decimal_with_diff_scale( | |
| pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> { | ||
| Ok(match (lhs, rhs) { | ||
| (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => { | ||
| let max_precision = p1.max(p2); | ||
| if s1.eq(s2) { | ||
| sum_decimal(v1, v2, p1, s1) | ||
| } else if s1.gt(s2) && p1.ge(p2) { | ||
| // For avg aggregate function. | ||
| // In the avg function, the scale of the result data type differs from the scale of the input data type. | ||
| sum_decimal_with_diff_scale(v1, v2, p1, s1, s2) | ||
| // s1 = s2 | ||
| sum_decimal(v1, v2, max_precision, s1) | ||
| } else if s1.gt(s2) { | ||
| // s1 > s2 | ||
| sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2) | ||
| } else { | ||
| return Err(DataFusionError::Internal(format!( | ||
| "Sum is not expected to receive lhs {:?}, rhs {:?}", | ||
| lhs, rhs | ||
| ))); | ||
| // s1 < s2 | ||
| sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1) | ||
| } | ||
| } | ||
| // float64 coerces everything to f64 | ||
|
|
@@ -419,17 +418,13 @@ mod tests { | |
| ScalarValue::Decimal128(Some(123 + 124 * 10_i128.pow(1)), 10, 3), | ||
| result | ||
| ); | ||
| // negative test with diff scale | ||
| // diff precision and scale for decimal data type | ||
| let left = ScalarValue::Decimal128(Some(123), 10, 2); | ||
| let right = ScalarValue::Decimal128(Some(124), 11, 3); | ||
| let result = sum(&left, &right); | ||
| assert_eq!( | ||
| DataFusionError::Internal(format!( | ||
| "Sum is not expected to receive lhs {:?}, rhs {:?}", | ||
| left, right | ||
| )) | ||
| .to_string(), | ||
| result.unwrap_err().to_string() | ||
| ScalarValue::Decimal128(Some(123 * 10_i128.pow(3 - 2) + 124), 11, 3), | ||
| result.unwrap() | ||
| ); | ||
|
|
||
| // test sum batch | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,6 +33,11 @@ use std::convert::{Infallible, TryInto}; | |
| use std::str::FromStr; | ||
| use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; | ||
|
|
||
| // TODO may need to be moved to arrow-rs | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Moving to arrow-rs would be a good idea I think |
||
| /// The max precision and scale for decimal128 | ||
| pub(crate) const MAX_PRECISION_FOR_DECIMAL128: usize = 38; | ||
| pub(crate) const MAX_SCALE_FOR_DECIMAL128: usize = 38; | ||
|
|
||
| /// Represents a dynamically typed, nullable single value. | ||
| /// This is the single-valued counter-part of arrow’s `Array`. | ||
| #[derive(Clone)] | ||
|
|
@@ -480,8 +485,7 @@ impl ScalarValue { | |
| scale: usize, | ||
| ) -> Result<Self> { | ||
| // make sure the precision and scale are valid | ||
| // TODO const the max precision and min scale | ||
| if precision <= 38 && scale <= precision { | ||
| if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision { | ||
| return Ok(ScalarValue::Decimal128(Some(value), precision, scale)); | ||
| } | ||
| return Err(DataFusionError::Internal(format!( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍