From 2095685b15cd75c48b7193b92080c97c24d7aa0e Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Tue, 7 Dec 2021 12:02:52 +0800 Subject: [PATCH 1/4] support sum/avg agg for decimal --- .../src/physical_plan/expressions/average.rs | 33 ++- .../src/physical_plan/expressions/sum.rs | 200 ++++++++++++++++-- 2 files changed, 212 insertions(+), 21 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index 17d3041453d0..afec4ea0fbab 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -43,6 +43,7 @@ pub struct Avg { /// function return type of an average pub fn avg_return_type(arg_type: &DataType) -> Result { match arg_type { + // TODO decimal type DataType::Int8 | DataType::Int16 | DataType::Int32 @@ -89,6 +90,13 @@ impl AggregateExpr for Avg { Ok(Field::new(&self.name, DataType::Float64, true)) } + fn create_accumulator(&self) -> Result> { + Ok(Box::new(AvgAccumulator::try_new( + // avg is f64 + &DataType::Float64, + )?)) + } + fn state_fields(&self) -> Result> { Ok(vec![ Field::new( @@ -104,13 +112,6 @@ impl AggregateExpr for Avg { ]) } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(AvgAccumulator::try_new( - // avg is f64 - &DataType::Float64, - )?)) - } - fn expressions(&self) -> Vec> { vec![self.expr.clone()] } @@ -204,6 +205,24 @@ mod tests { use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; + #[test] + fn avg_decimal() -> Result<()> { + // TODO + Ok(()) + } + + #[test] + fn avg_decimal_with_nulls() -> Result<()> { + // TODO + Ok(()) + } + + #[test] + fn avg_decimal_all_nulls() -> Result<()> { + // TODO + Ok(()) + } + #[test] fn avg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index c3f57e31e0d5..96b556fb4d61 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -35,6 +35,8 @@ use arrow::{ }; use super::format_state_name; +use crate::arrow::array::Array; +use arrow::array::DecimalArray; /// SUM aggregate expression #[derive(Debug)] @@ -56,6 +58,7 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { } DataType::Float32 => Ok(DataType::Float32), DataType::Float64 => Ok(DataType::Float64), + DataType::Decimal(precision, scale) => Ok(DataType::Decimal(*precision, *scale)), other => Err(DataFusionError::Plan(format!( "SUM does not support type \"{:?}\"", other @@ -93,6 +96,10 @@ impl AggregateExpr for Sum { )) } + fn create_accumulator(&self) -> Result> { + Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) + } + fn state_fields(&self) -> Result> { Ok(vec![Field::new( &format_state_name(&self.name, "sum"), @@ -105,10 +112,6 @@ impl AggregateExpr for Sum { vec![self.expr.clone()] } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) - } - fn name(&self) -> &str { &self.name } @@ -137,9 +140,33 @@ macro_rules! typed_sum_delta_batch { }}; } +fn sum_decimal_batch( + values: &ArrayRef, + precision: &usize, + scale: &usize, +) -> Result { + // TODO, if the values is empty, what should we return? + // None or 0 + let array = values.as_any().downcast_ref::().unwrap(); + let mut result = 0_i128; + for i in 0..array.len() { + if array.is_valid(i) { + result += array.value(i); + } + } + if array.null_count() == array.len() { + return Ok(ScalarValue::Decimal128(None, *precision, *scale)); + } + Ok(ScalarValue::Decimal128(Some(result), *precision, *scale)) +} + // sums the array and returns a ScalarValue of its corresponding type. pub(super) fn sum_batch(values: &ArrayRef) -> Result { Ok(match values.data_type() { + DataType::Decimal(precision, scale) => { + // TODO the result data type should use the new precision and scale + sum_decimal_batch(values, precision, scale)? + } DataType::Float64 => typed_sum_delta_batch!(values, Float64Array, Float64), DataType::Float32 => typed_sum_delta_batch!(values, Float32Array, Float32), DataType::Int64 => typed_sum_delta_batch!(values, Int64Array, Int64), @@ -154,7 +181,7 @@ pub(super) fn sum_batch(values: &ArrayRef) -> Result { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive the type {:?}", e - ))) + ))); } }) } @@ -171,8 +198,35 @@ macro_rules! typed_sum { }}; } +fn sum_decimal( + lhs: &Option, + rhs: &Option, + precision: &usize, + scale: &usize, +) -> ScalarValue { + match (lhs, rhs) { + (None, None) => ScalarValue::Decimal128(None, *precision, *scale), + (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale), + (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale), + (lhs, rhs) => { + ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale) + } + } +} + pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { Ok(match (lhs, rhs) { + (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + // TODO the result data type should use the new precision and scale + sum_decimal(v1, v2, p1, s1) + } else { + return Err(DataFusionError::Internal(format!( + "Sum is not expected to receive lhs {:?}, rhs {:?}", + lhs, rhs + ))); + } + } // float64 coerces everything to f64 (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { typed_sum!(lhs, rhs, Float64, f64) @@ -238,16 +292,14 @@ pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive a scalar {:?}", e - ))) + ))); } }) } impl Accumulator for SumAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - self.sum = sum(&self.sum, &sum_batch(values)?)?; - Ok(()) + fn state(&self) -> Result> { + Ok(vec![self.sum.clone()]) } fn update(&mut self, values: &[ScalarValue]) -> Result<()> { @@ -256,6 +308,12 @@ impl Accumulator for SumAccumulator { Ok(()) } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + self.sum = sum(&self.sum, &sum_batch(values)?)?; + Ok(()) + } + fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { // sum(sum1, sum2) = sum1 + sum2 self.update(states) @@ -266,10 +324,6 @@ impl Accumulator for SumAccumulator { self.update_batch(states) } - fn state(&self) -> Result> { - Ok(vec![self.sum.clone()]) - } - fn evaluate(&self) -> Result { Ok(self.sum.clone()) } @@ -278,11 +332,129 @@ impl Accumulator for SumAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::arrow::array::DecimalBuilder; use crate::physical_plan::expressions::col; use crate::{error::Result, generic_test_op}; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + #[test] + fn sum_decimal() -> Result<()> { + // test sum + let left = ScalarValue::Decimal128(Some(123), 10, 2); + let right = ScalarValue::Decimal128(Some(124), 10, 2); + let result = sum(&left, &right)?; + assert_eq!(ScalarValue::Decimal128(Some(123 + 124), 10, 2), result); + // negative test + let left = ScalarValue::Decimal128(Some(123), 10, 2); + let right = ScalarValue::Decimal128(Some(124), 11, 2); + let result = sum(&left, &right); + assert_eq!( + DataFusionError::Internal(format!( + "Sum is not expected to receive lhs {:?}, rhs {:?}", + left, right + )) + .to_string(), + result.unwrap_err().to_string() + ); + + // test sum batch + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + decimal_builder.append_value(i as i128)?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + let result = sum_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(Some(15), 10, 0), result); + + // test agg + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + decimal_builder.append_value(i as i128)?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + + generic_test_op!( + array, + DataType::Decimal(10, 0), + Sum, + ScalarValue::Decimal128(Some(15), 10, 0), + DataType::Decimal(10, 0) + ) + } + + #[test] + fn sum_decimal_with_nulls() -> Result<()> { + // test sum + let left = ScalarValue::Decimal128(None, 10, 2); + let right = ScalarValue::Decimal128(Some(123), 10, 2); + let result = sum(&left, &right)?; + assert_eq!(ScalarValue::Decimal128(Some(123), 10, 2), result); + + // test with batch + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + if i == 2 { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(i)?; + } + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + let result = sum_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(Some(13), 10, 0), result); + + // test agg + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + if i == 2 { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(i)?; + } + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Sum, + ScalarValue::Decimal128(Some(13), 10, 0), + DataType::Decimal(10, 0) + ) + } + + #[test] + fn sum_decimal_all_nulls() -> Result<()> { + // test sum + let left = ScalarValue::Decimal128(None, 10, 2); + let right = ScalarValue::Decimal128(None, 10, 2); + let result = sum(&left, &right)?; + assert_eq!(ScalarValue::Decimal128(None, 10, 2), result); + + // test with batch + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for _i in 1..6 { + decimal_builder.append_null()?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + let result = sum_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); + + // test agg + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for _i in 1..6 { + decimal_builder.append_null()?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Sum, + ScalarValue::Decimal128(None, 10, 0), + DataType::Decimal(10, 0) + ) + } + #[test] fn sum_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); From 3c19c41eb0eb20577fab8abe4375a0f7f192d303 Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Wed, 8 Dec 2021 20:07:36 +0800 Subject: [PATCH 2/4] support sum/avg agg for decimal --- datafusion/src/execution/context.rs | 52 +++++++++++ datafusion/src/physical_plan/aggregates.rs | 1 + .../coercion_rule/aggregate_rule.rs | 3 +- .../src/physical_plan/expressions/average.rs | 88 +++++++++++++++---- .../src/physical_plan/expressions/sum.rs | 43 ++++++--- 5 files changed, 157 insertions(+), 30 deletions(-) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 59d6f44f59b1..b369c3fece2c 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1842,6 +1842,58 @@ mod tests { Ok(()) } + #[tokio::test] + async fn aggregate_decimal_min() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("d_table", test::table_with_decimal()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "select min(c1) from d_table") + .await + .unwrap(); + let expected = vec![ + "+-----------------+", + "| MIN(d_table.c1) |", + "+-----------------+", + "| -100.009 |", + "+-----------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + Ok(()) + } + + #[tokio::test] + async fn aggregate_decimal_max() -> Result<()> { + let mut ctx = ExecutionContext::new(); + ctx.register_table("d_table", test::table_with_decimal()) + .unwrap(); + + let result = plan_and_collect(&mut ctx, "select max(c1) from d_table") + .await + .unwrap(); + let expected = vec![ + "+-----------------+", + "| MAX(d_table.c1) |", + "+-----------------+", + "| 100.009 |", + "+-----------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + Ok(()) + } + + #[tokio::test] + async fn aggregate_decimal_sum() -> Result<()> { + Ok(()) + } + + #[tokio::test] + async fn aggregate_decimal_avg() -> Result<()> { + Ok(()) + } + + + #[tokio::test] async fn aggregate() -> Result<()> { let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?; diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 50e1a82c74c2..47fc0ac31af4 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -115,6 +115,7 @@ pub fn return_type( // The coerced_data_types is same with input_types. Ok(coerced_data_types[0].clone()) } + // TODO get the new decimal data type AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new( diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index d7b437528d5c..e76e4a6b023e 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -193,8 +193,7 @@ mod tests { let input_types = vec![ vec![DataType::Int32], vec![DataType::Float32], - // support the decimal data type - // vec![DataType::Decimal(20, 3)], + vec![DataType::Decimal(20, 3)], ]; for fun in funs { for input_type in &input_types { diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index 20d400e5a581..0411584e2e46 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -38,12 +38,14 @@ use super::{format_state_name, sum}; pub struct Avg { name: String, expr: Arc, + data_type: DataType, } /// function return type of an average pub fn avg_return_type(arg_type: &DataType) -> Result { match arg_type { - // TODO decimal type + // TODO how to handler decimal data type + DataType::Decimal(precision,scale) => Ok(DataType::Decimal(*precision, *scale)), DataType::Int8 | DataType::Int16 | DataType::Int32 @@ -74,6 +76,7 @@ pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool { | DataType::Int64 | DataType::Float32 | DataType::Float64 + | DataType::Decimal(_,_) ) } @@ -87,11 +90,17 @@ impl Avg { // Average is always Float64, but Avg::new() has a data_type // parameter to keep a consistent signature with the other // Aggregate expressions. - assert_eq!(data_type, DataType::Float64); - - Self { - name: name.into(), - expr, + match data_type { + DataType::Float64 | DataType::Decimal(_,_) => { + Self { + name: name.into(), + expr, + data_type, + } + } + _ => { + unreachable!(); + } } } } @@ -103,13 +112,13 @@ impl AggregateExpr for Avg { } fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + Ok(Field::new(&self.name, self.data_type.clone(), true)) } fn create_accumulator(&self) -> Result> { Ok(Box::new(AvgAccumulator::try_new( - // avg is f64 - &DataType::Float64, + // avg is f64 or decimal + &self.data_type )?)) } @@ -122,7 +131,7 @@ impl AggregateExpr for Avg { ), Field::new( &format_state_name(&self.name, "sum"), - DataType::Float64, + self.data_type.clone(), true, ), ]) @@ -206,6 +215,15 @@ impl Accumulator for AvgAccumulator { ScalarValue::Float64(e) => { Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) } + ScalarValue::Decimal128(value, precision, scale) => { + // TODO support the decimal data type + Ok(match value { + None => ScalarValue::Decimal128(None, precision, scale), + Some(v) => { + ScalarValue::Decimal128(Some(v / self.count as i128), precision, scale) + } + }) + } _ => Err(DataFusionError::Internal( "Sum should be f64 on average".to_string(), )), @@ -223,20 +241,58 @@ mod tests { #[test] fn avg_decimal() -> Result<()> { - // TODO - Ok(()) + // test agg + let mut decimal_builder = DecimalBuilder::new(6, 10, 0); + for i in 1..7 { + // the avg is 3.5, but we get the result of 3 + decimal_builder.append_value(i as i128)?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + + generic_test_op!( + array, + DataType::Decimal(10, 0), + Avg, + ScalarValue::Decimal128(Some(3), 10, 0), + DataType::Decimal(10, 0) + ) } #[test] fn avg_decimal_with_nulls() -> Result<()> { - // TODO - Ok(()) + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + if i == 2 { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(i)?; + } + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Avg, + ScalarValue::Decimal128(Some(3), 10, 0), + DataType::Decimal(10, 0) + ) } #[test] fn avg_decimal_all_nulls() -> Result<()> { - // TODO - Ok(()) + // test agg + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for _i in 1..6 { + decimal_builder.append_null()?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Avg, + ScalarValue::Decimal128(None, 10, 0), + DataType::Decimal(10, 0) + ) } #[test] diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index bd1af9719a7d..5c2ec2e70a47 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -56,9 +56,19 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { Ok(DataType::UInt64) } + // In the https://www.postgresql.org/docs/8.2/functions-aggregate.html doc, + // the result type of floating-point is FLOAT64 with the double precision. + // TODO, we should change this rule DataType::Float32 => Ok(DataType::Float32), DataType::Float64 => Ok(DataType::Float64), - DataType::Decimal(precision, scale) => Ok(DataType::Decimal(*precision, *scale)), + // TODO get the wider precision + // Max precision is 38 + DataType::Decimal(precision, scale) => { + // // in the spark, the result type is DECIMAL(p + min(10, 31-p), s) + // let new_precision = *precision + 10.min(31-*precision); + // Ok(DataType::Decimal(new_precision, *scale)) + Ok(DataType::Decimal(*precision, *scale)) + } other => Err(DataFusionError::Plan(format!( "SUM does not support type \"{:?}\"", other @@ -79,6 +89,7 @@ pub(crate) fn is_sum_support_arg_type(arg_type: &DataType) -> bool { | DataType::Int64 | DataType::Float32 | DataType::Float64 + | DataType::Decimal(_,_) ) } @@ -156,23 +167,25 @@ macro_rules! typed_sum_delta_batch { }}; } +// TODO implement this in arrow-rs with simd +// https://github.com/apache/arrow-rs/issues/1010 fn sum_decimal_batch( values: &ArrayRef, precision: &usize, scale: &usize, ) -> Result { - // TODO, if the values is empty, what should we return? - // None or 0 let array = values.as_any().downcast_ref::().unwrap(); + + if array.null_count() == array.len() { + return Ok(ScalarValue::Decimal128(None, *precision, *scale)); + } + let mut result = 0_i128; for i in 0..array.len() { if array.is_valid(i) { result += array.value(i); } } - if array.null_count() == array.len() { - return Ok(ScalarValue::Decimal128(None, *precision, *scale)); - } Ok(ScalarValue::Decimal128(Some(result), *precision, *scale)) } @@ -180,7 +193,6 @@ fn sum_decimal_batch( pub(super) fn sum_batch(values: &ArrayRef) -> Result { Ok(match values.data_type() { DataType::Decimal(precision, scale) => { - // TODO the result data type should use the new precision and scale sum_decimal_batch(values, precision, scale)? } DataType::Float64 => typed_sum_delta_batch!(values, Float64Array, Float64), @@ -214,6 +226,8 @@ macro_rules! typed_sum { }}; } +// TODO implement this in arrow-rs with simd +// https://github.com/apache/arrow-rs/issues/1010 fn sum_decimal( lhs: &Option, rhs: &Option, @@ -232,9 +246,8 @@ fn sum_decimal( pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { Ok(match (lhs, rhs) { - (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => { - if p1.eq(p2) && s1.eq(s2) { - // TODO the result data type should use the new precision and scale + (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, _p2, s2)) => { + if s1.eq(s2) { sum_decimal(v1, v2, p1, s1) } else { return Err(DataFusionError::Internal(format!( @@ -320,12 +333,16 @@ impl Accumulator for SumAccumulator { fn update(&mut self, values: &[ScalarValue]) -> Result<()> { // sum(v1, v2, v3) = v1 + v2 + v3 + // For the decimal data type, the precision of `sum` may be different from that of value, + // but the scale must be same. self.sum = sum(&self.sum, &values[0])?; Ok(()) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; + // For the decimal data type, the precision of `sum` may be different from that of value, + // but the scale must be same. self.sum = sum(&self.sum, &sum_batch(values)?)?; Ok(()) } @@ -341,6 +358,8 @@ impl Accumulator for SumAccumulator { } fn evaluate(&self) -> Result { + // TODO: For the decimal(precision,_) data type, the absolute of value must be less than 10^precision. + // We should add the checker Ok(self.sum.clone()) } } @@ -361,9 +380,9 @@ mod tests { let right = ScalarValue::Decimal128(Some(124), 10, 2); let result = sum(&left, &right)?; assert_eq!(ScalarValue::Decimal128(Some(123 + 124), 10, 2), result); - // negative test + // negative test with diff scale let left = ScalarValue::Decimal128(Some(123), 10, 2); - let right = ScalarValue::Decimal128(Some(124), 11, 2); + let right = ScalarValue::Decimal128(Some(124), 11, 3); let result = sum(&left, &right); assert_eq!( DataFusionError::Internal(format!( From 9675e3020326d78c377d94e11404d54b7e5cebc5 Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Tue, 14 Dec 2021 13:29:16 +0800 Subject: [PATCH 3/4] suppor the avg and add test --- datafusion/src/execution/context.rs | 49 ++++++++++- datafusion/src/physical_plan/aggregates.rs | 35 +++++++- .../src/physical_plan/expressions/average.rs | 34 ++++++-- .../src/physical_plan/expressions/sum.rs | 85 ++++++++++++++----- 4 files changed, 168 insertions(+), 35 deletions(-) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 6efcd3f96c64..8c3df46a22be 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1845,9 +1845,9 @@ mod tests { #[tokio::test] async fn aggregate_decimal_min() -> Result<()> { let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); - let result = plan_and_collect(&mut ctx, "select min(c1) from d_table") .await .unwrap(); @@ -1858,6 +1858,10 @@ mod tests { "| -100.009 |", "+-----------------+", ]; + assert_eq!( + &DataType::Decimal(10, 3), + result[0].schema().field(0).data_type() + ); assert_batches_sorted_eq!(expected, &result); Ok(()) } @@ -1865,6 +1869,7 @@ mod tests { #[tokio::test] async fn aggregate_decimal_max() -> Result<()> { let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); @@ -1878,17 +1883,59 @@ mod tests { "| 110.009 |", "+-----------------+", ]; + assert_eq!( + &DataType::Decimal(10, 3), + result[0].schema().field(0).data_type() + ); assert_batches_sorted_eq!(expected, &result); Ok(()) } #[tokio::test] async fn aggregate_decimal_sum() -> Result<()> { + let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,3) + ctx.register_table("d_table", test::table_with_decimal()) + .unwrap(); + let result = plan_and_collect(&mut ctx, "select sum(c1) from d_table") + .await + .unwrap(); + let expected = vec![ + "+-----------------+", + "| SUM(d_table.c1) |", + "+-----------------+", + "| 100.000 |", + "+-----------------+", + ]; + assert_eq!( + &DataType::Decimal(20, 3), + result[0].schema().field(0).data_type() + ); + assert_batches_sorted_eq!(expected, &result); Ok(()) } #[tokio::test] async fn aggregate_decimal_avg() -> Result<()> { + let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,3) + ctx.register_table("d_table", test::table_with_decimal()) + .unwrap(); + let result = plan_and_collect(&mut ctx, "select avg(c1) from d_table") + .await + .unwrap(); + let expected = vec![ + "+-----------------+", + "| AVG(d_table.c1) |", + "+-----------------+", + "| 5.0000000 |", + "+-----------------+", + ]; + assert_eq!( + &DataType::Decimal(14, 7), + result[0].schema().field(0).data_type() + ); + assert_batches_sorted_eq!(expected, &result); Ok(()) } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 47fc0ac31af4..e9f9696a56e8 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -115,7 +115,6 @@ pub fn return_type( // The coerced_data_types is same with input_types. Ok(coerced_data_types[0].clone()) } - // TODO get the new decimal data type AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new( @@ -427,7 +426,7 @@ mod tests { | DataType::Int16 | DataType::Int32 | DataType::Int64 => DataType::Int64, - DataType::Float32 | DataType::Float64 => data_type.clone(), + DataType::Float32 | DataType::Float64 => DataType::Float64, _ => data_type.clone(), }; @@ -471,6 +470,29 @@ mod tests { Ok(()) } + #[test] + fn test_sum_return_type() -> Result<()> { + let observed = return_type(&AggregateFunction::Sum, &[DataType::Int32])?; + assert_eq!(DataType::Int64, observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt8])?; + assert_eq!(DataType::UInt64, observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::Float32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::Decimal(10, 5)])?; + assert_eq!(DataType::Decimal(20, 5), observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::Decimal(35, 5)])?; + assert_eq!(DataType::Decimal(38, 5), observed); + + Ok(()) + } + #[test] fn test_sum_no_utf8() { let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8]); @@ -505,6 +527,15 @@ mod tests { let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64])?; assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Avg, &[DataType::Int32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(10, 6)])?; + assert_eq!(DataType::Decimal(14, 10), observed); + + let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(36, 6)])?; + assert_eq!(DataType::Decimal(38, 10), observed); Ok(()) } diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index 4a8373659b7e..3d781c5c8234 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -44,8 +44,12 @@ pub struct Avg { /// function return type of an average pub fn avg_return_type(arg_type: &DataType) -> Result { match arg_type { - // TODO how to handler decimal data type - DataType::Decimal(precision, scale) => Ok(DataType::Decimal(*precision, *scale)), + DataType::Decimal(precision, scale) => { + // the new precision and scale for return type of avg function + let new_precision = 38.min(*precision + 4); + let new_scale = 38.min(*scale + 4); + Ok(DataType::Decimal(new_precision, new_scale)) + } DataType::Int8 | DataType::Int16 | DataType::Int32 @@ -214,9 +218,9 @@ impl Accumulator for AvgAccumulator { Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) } ScalarValue::Decimal128(value, precision, scale) => { - // TODO support the decimal data type Ok(match value { None => ScalarValue::Decimal128(None, precision, scale), + // TODO add the checker for overflow the precision Some(v) => ScalarValue::Decimal128( Some(v / self.count as i128), precision, @@ -239,6 +243,18 @@ mod tests { use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; + #[test] + fn test_avg_return_data_type() -> Result<()> { + let data_type = DataType::Decimal(10, 5); + let result_type = avg_return_type(&data_type)?; + assert_eq!(DataType::Decimal(14, 9), result_type); + + let data_type = DataType::Decimal(36, 10); + let result_type = avg_return_type(&data_type)?; + assert_eq!(DataType::Decimal(38, 14), result_type); + Ok(()) + } + #[test] fn avg_decimal() -> Result<()> { // test agg @@ -253,8 +269,8 @@ mod tests { array, DataType::Decimal(10, 0), Avg, - ScalarValue::Decimal128(Some(3), 10, 0), - DataType::Decimal(10, 0) + ScalarValue::Decimal128(Some(35000), 14, 4), + DataType::Decimal(14, 4) ) } @@ -273,8 +289,8 @@ mod tests { array, DataType::Decimal(10, 0), Avg, - ScalarValue::Decimal128(Some(3), 10, 0), - DataType::Decimal(10, 0) + ScalarValue::Decimal128(Some(32500), 14, 4), + DataType::Decimal(14, 4) ) } @@ -290,8 +306,8 @@ mod tests { array, DataType::Decimal(10, 0), Avg, - ScalarValue::Decimal128(None, 10, 0), - DataType::Decimal(10, 0) + ScalarValue::Decimal128(None, 14, 4), + DataType::Decimal(14, 4) ) } diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index c9ebafb0ad06..f308415ee3ff 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -58,16 +58,12 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { } // In the https://www.postgresql.org/docs/8.2/functions-aggregate.html doc, // the result type of floating-point is FLOAT64 with the double precision. - // TODO, we should change this rule - DataType::Float32 => Ok(DataType::Float32), - DataType::Float64 => Ok(DataType::Float64), - // TODO get the wider precision + DataType::Float64 | DataType::Float32 => Ok(DataType::Float64), // Max precision is 38 DataType::Decimal(precision, scale) => { - // // in the spark, the result type is DECIMAL(p + min(10, 31-p), s) - // let new_precision = *precision + 10.min(31-*precision); - // Ok(DataType::Decimal(new_precision, *scale)) - Ok(DataType::Decimal(*precision, *scale)) + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + let new_precision = 38.min(*precision + 10); + Ok(DataType::Decimal(new_precision, *scale)) } other => Err(DataFusionError::Plan(format!( "SUM does not support type \"{:?}\"", @@ -244,11 +240,38 @@ fn sum_decimal( } } +fn sum_decimal_with_diff_scale( + lhs: &Option, + rhs: &Option, + precision: &usize, + lhs_scale: &usize, + rhs_scale: &usize, +) -> ScalarValue { + // the lhs_scale must be greater or equal rhs_scale. + match (lhs, rhs) { + (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale), + (None, rhs) => { + let new_value = rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32); + ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale) + } + (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale), + (lhs, rhs) => { + let new_value = + rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs.unwrap(); + ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale) + } + } +} + pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { Ok(match (lhs, rhs) { - (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, _p2, s2)) => { + (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => { if s1.eq(s2) { sum_decimal(v1, v2, p1, s1) + } else if s1.gt(s2) && p1.ge(p2) { + // For avg aggravate function. + // In the avg function, the scale of result data type is different with the scale of the input data type. + sum_decimal_with_diff_scale(v1, v2, p1, s1, s2) } else { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive lhs {:?}, rhs {:?}", @@ -333,16 +356,12 @@ impl Accumulator for SumAccumulator { fn update(&mut self, values: &[ScalarValue]) -> Result<()> { // sum(v1, v2, v3) = v1 + v2 + v3 - // For the decimal data type, the precision of `sum` may be different from that of value, - // but the scale must be same. self.sum = sum(&self.sum, &values[0])?; Ok(()) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; - // For the decimal data type, the precision of `sum` may be different from that of value, - // but the scale must be same. self.sum = sum(&self.sum, &sum_batch(values)?)?; Ok(()) } @@ -358,8 +377,8 @@ impl Accumulator for SumAccumulator { } fn evaluate(&self) -> Result { - // TODO: For the decimal(precision,_) data type, the absolute of value must be less than 10^precision. - // We should add the checker + // TODO: add the checker for overflow + // For the decimal(precision,_) data type, the absolute of value must be less than 10^precision. Ok(self.sum.clone()) } } @@ -373,6 +392,18 @@ mod tests { use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + #[test] + fn test_sum_return_data_type() -> Result<()> { + let data_type = DataType::Decimal(10, 5); + let result_type = sum_return_type(&data_type)?; + assert_eq!(DataType::Decimal(20, 5), result_type); + + let data_type = DataType::Decimal(36, 10); + let result_type = sum_return_type(&data_type)?; + assert_eq!(DataType::Decimal(38, 10), result_type); + Ok(()) + } + #[test] fn sum_decimal() -> Result<()> { // test sum @@ -380,6 +411,14 @@ mod tests { let right = ScalarValue::Decimal128(Some(124), 10, 2); let result = sum(&left, &right)?; assert_eq!(ScalarValue::Decimal128(Some(123 + 124), 10, 2), result); + // test sum decimal with diff scale + let left = ScalarValue::Decimal128(Some(123), 10, 3); + let right = ScalarValue::Decimal128(Some(124), 10, 2); + let result = sum(&left, &right)?; + assert_eq!( + ScalarValue::Decimal128(Some(123 + 124 * 10_i128.pow(1)), 10, 3), + result + ); // negative test with diff scale let left = ScalarValue::Decimal128(Some(123), 10, 2); let right = ScalarValue::Decimal128(Some(124), 11, 3); @@ -413,8 +452,8 @@ mod tests { array, DataType::Decimal(10, 0), Sum, - ScalarValue::Decimal128(Some(15), 10, 0), - DataType::Decimal(10, 0) + ScalarValue::Decimal128(Some(15), 20, 0), + DataType::Decimal(20, 0) ) } @@ -440,7 +479,7 @@ mod tests { assert_eq!(ScalarValue::Decimal128(Some(13), 10, 0), result); // test agg - let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + let mut decimal_builder = DecimalBuilder::new(5, 35, 0); for i in 1..6 { if i == 2 { decimal_builder.append_null()?; @@ -451,10 +490,10 @@ mod tests { let array: ArrayRef = Arc::new(decimal_builder.finish()); generic_test_op!( array, - DataType::Decimal(10, 0), + DataType::Decimal(35, 0), Sum, - ScalarValue::Decimal128(Some(13), 10, 0), - DataType::Decimal(10, 0) + ScalarValue::Decimal128(Some(13), 38, 0), + DataType::Decimal(38, 0) ) } @@ -485,8 +524,8 @@ mod tests { array, DataType::Decimal(10, 0), Sum, - ScalarValue::Decimal128(None, 10, 0), - DataType::Decimal(10, 0) + ScalarValue::Decimal128(None, 20, 0), + DataType::Decimal(20, 0) ) } From f4b56559b4ebb2f3e2b644a7f6b91ac165d88c2f Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Fri, 17 Dec 2021 17:01:21 +0800 Subject: [PATCH 4/4] add comments and const --- .../src/physical_plan/expressions/average.rs | 33 +++++------ .../src/physical_plan/expressions/sum.rs | 47 +++++++-------- datafusion/src/scalar.rs | 8 ++- datafusion/src/sql/planner.rs | 58 +++++++++---------- 4 files changed, 69 insertions(+), 77 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index 3d781c5c8234..f09298998a2a 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -23,7 +23,9 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; -use crate::scalar::ScalarValue; +use crate::scalar::{ + ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128, +}; use arrow::compute; use arrow::datatypes::DataType; use arrow::{ @@ -45,9 +47,10 @@ pub struct Avg { pub fn avg_return_type(arg_type: &DataType) -> Result { match arg_type { DataType::Decimal(precision, scale) => { - // the new precision and scale for return type of avg function - let new_precision = 38.min(*precision + 4); - let new_scale = 38.min(*scale + 4); + // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4); + let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4); Ok(DataType::Decimal(new_precision, new_scale)) } DataType::Int8 @@ -91,18 +94,15 @@ impl Avg { name: impl Into, data_type: DataType, ) -> Self { - // Average is always Float64, but Avg::new() has a data_type - // parameter to keep a consistent signature with the other - // Aggregate expressions. - match data_type { - DataType::Float64 | DataType::Decimal(_, _) => Self { - name: name.into(), - expr, - data_type, - }, - _ => { - unreachable!(); - } + // the result of avg just support FLOAT64 and Decimal data type. + assert!(matches!( + data_type, + DataType::Float64 | DataType::Decimal(_, _) + )); + Self { + name: name.into(), + expr, + data_type, } } } @@ -260,7 +260,6 @@ mod tests { // test agg let mut decimal_builder = DecimalBuilder::new(6, 10, 0); for i in 1..7 { - // the avg is 3.5, but we get the result of 3 decimal_builder.append_value(i as i128)?; } let array: ArrayRef = Arc::new(decimal_builder.finish()); diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index f308415ee3ff..027736dbc478 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; -use crate::scalar::ScalarValue; +use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128}; use arrow::compute; use arrow::datatypes::DataType; use arrow::{ @@ -56,13 +56,13 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { Ok(DataType::UInt64) } - // In the https://www.postgresql.org/docs/8.2/functions-aggregate.html doc, + // In the https://www.postgresql.org/docs/current/functions-aggregate.html doc, // the result type of floating-point is FLOAT64 with the double precision. DataType::Float64 | DataType::Float32 => Ok(DataType::Float64), - // Max precision is 38 DataType::Decimal(precision, scale) => { // in the spark, the result type is DECIMAL(min(38,precision+10), s) - let new_precision = 38.min(*precision + 10); + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 10); Ok(DataType::Decimal(new_precision, *scale)) } other => Err(DataFusionError::Plan(format!( @@ -234,8 +234,8 @@ fn sum_decimal( (None, None) => ScalarValue::Decimal128(None, *precision, *scale), (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale), (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale), - (lhs, rhs) => { - ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale) + (Some(lhs_value), Some(rhs_value)) => { + ScalarValue::Decimal128(Some(lhs_value + rhs_value), *precision, *scale) } } } @@ -250,14 +250,14 @@ fn sum_decimal_with_diff_scale( // the lhs_scale must be greater or equal rhs_scale. match (lhs, rhs) { (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale), - (None, rhs) => { - let new_value = rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32); + (None, Some(rhs_value)) => { + let new_value = rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32); ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale) } (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale), - (lhs, rhs) => { + (Some(lhs_value), Some(rhs_value)) => { let new_value = - rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs.unwrap(); + rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs_value; ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale) } } @@ -266,17 +266,16 @@ fn sum_decimal_with_diff_scale( pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { Ok(match (lhs, rhs) { (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => { + let max_precision = p1.max(p2); if s1.eq(s2) { - sum_decimal(v1, v2, p1, s1) - } else if s1.gt(s2) && p1.ge(p2) { - // For avg aggravate function. - // In the avg function, the scale of result data type is different with the scale of the input data type. - sum_decimal_with_diff_scale(v1, v2, p1, s1, s2) + // s1 = s2 + sum_decimal(v1, v2, max_precision, s1) + } else if s1.gt(s2) { + // s1 > s2 + sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2) } else { - return Err(DataFusionError::Internal(format!( - "Sum is not expected to receive lhs {:?}, rhs {:?}", - lhs, rhs - ))); + // s1 < s2 + sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1) } } // float64 coerces everything to f64 @@ -419,17 +418,13 @@ mod tests { ScalarValue::Decimal128(Some(123 + 124 * 10_i128.pow(1)), 10, 3), result ); - // negative test with diff scale + // diff precision and scale for decimal data type let left = ScalarValue::Decimal128(Some(123), 10, 2); let right = ScalarValue::Decimal128(Some(124), 11, 3); let result = sum(&left, &right); assert_eq!( - DataFusionError::Internal(format!( - "Sum is not expected to receive lhs {:?}, rhs {:?}", - left, right - )) - .to_string(), - result.unwrap_err().to_string() + ScalarValue::Decimal128(Some(123 * 10_i128.pow(3 - 2) + 124), 11, 3), + result.unwrap() ); // test sum batch diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index e9eafe1c109c..35ebb2aa8193 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -33,6 +33,11 @@ use std::convert::{Infallible, TryInto}; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; +// TODO may need to be moved to arrow-rs +/// The max precision and scale for decimal128 +pub(crate) const MAX_PRECISION_FOR_DECIMAL128: usize = 38; +pub(crate) const MAX_SCALE_FOR_DECIMAL128: usize = 38; + /// Represents a dynamically typed, nullable single value. /// This is the single-valued counter-part of arrow’s `Array`. #[derive(Clone)] @@ -480,8 +485,7 @@ impl ScalarValue { scale: usize, ) -> Result { // make sure the precision and scale is valid - // TODO const the max precision and min scale - if precision <= 38 && scale <= precision { + if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision { return Ok(ScalarValue::Decimal128(Some(value), precision, scale)); } return Err(DataFusionError::Internal(format!( diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 72c1962a3e16..02020fb54645 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -35,7 +35,7 @@ use crate::logical_plan::{ }; use crate::optimizer::utils::exprlist_to_columns; use crate::prelude::JoinType; -use crate::scalar::ScalarValue; +use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128}; use crate::{ error::{DataFusionError, Result}, physical_plan::udaf::AggregateUDF, @@ -371,27 +371,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLDataType::Char(_) | SQLDataType::Varchar(_) | SQLDataType::Text => { Ok(DataType::Utf8) } - SQLDataType::Decimal(precision, scale) => { - match (precision, scale) { - (None, _) | (_, None) => { + SQLDataType::Decimal(precision, scale) => match (precision, scale) { + (None, _) | (_, None) => { + return Err(DataFusionError::Internal(format!( + "Invalid Decimal type ({:?}), precision or scale can't be empty.", + sql_type + ))); + } + (Some(p), Some(s)) => { + if (*p as usize) > MAX_PRECISION_FOR_DECIMAL128 || *s > *p { return Err(DataFusionError::Internal(format!( - "Invalid Decimal type ({:?}), precision or scale can't be empty.", - sql_type - ))); - } - (Some(p), Some(s)) => { - // TODO add bound checker in some utils file or function - if *p > 38 || *s > *p { - return Err(DataFusionError::Internal(format!( "Error Decimal Type ({:?}), precision must be less than or equal to 38 and scale can't be greater than precision", sql_type ))); - } else { - Ok(DataType::Decimal(*p as usize, *s as usize)) - } + } else { + Ok(DataType::Decimal(*p as usize, *s as usize)) } } - } + }, SQLDataType::Float(_) => Ok(DataType::Float32), SQLDataType::Real => Ok(DataType::Float32), SQLDataType::Double => Ok(DataType::Float64), @@ -2054,27 +2051,24 @@ pub fn convert_data_type(sql_type: &SQLDataType) -> Result { SQLDataType::Char(_) | SQLDataType::Varchar(_) => Ok(DataType::Utf8), SQLDataType::Timestamp => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)), SQLDataType::Date => Ok(DataType::Date32), - SQLDataType::Decimal(precision, scale) => { - match (precision, scale) { - (None, _) | (_, None) => { + SQLDataType::Decimal(precision, scale) => match (precision, scale) { + (None, _) | (_, None) => { + return Err(DataFusionError::Internal(format!( + "Invalid Decimal type ({:?}), precision or scale can't be empty.", + sql_type + ))); + } + (Some(p), Some(s)) => { + if (*p as usize) > MAX_PRECISION_FOR_DECIMAL128 || *s > *p { return Err(DataFusionError::Internal(format!( - "Invalid Decimal type ({:?}), precision or scale can't be empty.", + "Error Decimal Type ({:?})", sql_type ))); - } - (Some(p), Some(s)) => { - // TODO add bound checker in some utils file or function - if *p > 38 || *s > *p { - return Err(DataFusionError::Internal(format!( - "Error Decimal Type ({:?})", - sql_type - ))); - } else { - Ok(DataType::Decimal(*p as usize, *s as usize)) - } + } else { + Ok(DataType::Decimal(*p as usize, *s as usize)) } } - } + }, other => Err(DataFusionError::NotImplemented(format!( "Unsupported SQL type {:?}", other