diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index d7c536ed2771..8c3df46a22be 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1845,9 +1845,9 @@ mod tests { #[tokio::test] async fn aggregate_decimal_min() -> Result<()> { let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); - let result = plan_and_collect(&mut ctx, "select min(c1) from d_table") .await .unwrap(); @@ -1858,6 +1858,10 @@ mod tests { "| -100.009 |", "+-----------------+", ]; + assert_eq!( + &DataType::Decimal(10, 3), + result[0].schema().field(0).data_type() + ); assert_batches_sorted_eq!(expected, &result); Ok(()) } @@ -1865,6 +1869,7 @@ mod tests { #[tokio::test] async fn aggregate_decimal_max() -> Result<()> { let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); @@ -1878,6 +1883,58 @@ mod tests { "| 110.009 |", "+-----------------+", ]; + assert_eq!( + &DataType::Decimal(10, 3), + result[0].schema().field(0).data_type() + ); + assert_batches_sorted_eq!(expected, &result); + Ok(()) + } + + #[tokio::test] + async fn aggregate_decimal_sum() -> Result<()> { + let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,3) + ctx.register_table("d_table", test::table_with_decimal()) + .unwrap(); + let result = plan_and_collect(&mut ctx, "select sum(c1) from d_table") + .await + .unwrap(); + let expected = vec![ + "+-----------------+", + "| SUM(d_table.c1) |", + "+-----------------+", + "| 100.000 |", + "+-----------------+", + ]; + assert_eq!( + &DataType::Decimal(20, 3), + result[0].schema().field(0).data_type() + ); + assert_batches_sorted_eq!(expected, &result); + Ok(()) + } + + #[tokio::test] + async fn aggregate_decimal_avg() -> Result<()> { + let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,3) + ctx.register_table("d_table", test::table_with_decimal()) + .unwrap(); + let result = plan_and_collect(&mut ctx, "select avg(c1) from d_table") + .await + .unwrap(); + let expected = vec![ + "+-----------------+", + "| AVG(d_table.c1) |", + "+-----------------+", + "| 5.0000000 |", + "+-----------------+", + ]; + assert_eq!( + &DataType::Decimal(14, 7), + result[0].schema().field(0).data_type() + ); assert_batches_sorted_eq!(expected, &result); Ok(()) } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 50e1a82c74c2..e9f9696a56e8 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -426,7 +426,7 @@ mod tests { | DataType::Int16 | DataType::Int32 | DataType::Int64 => DataType::Int64, - DataType::Float32 | DataType::Float64 => data_type.clone(), + DataType::Float32 | DataType::Float64 => DataType::Float64, _ => data_type.clone(), }; @@ -470,6 +470,29 @@ mod tests { Ok(()) } + #[test] + fn test_sum_return_type() -> Result<()> { + let observed = return_type(&AggregateFunction::Sum, &[DataType::Int32])?; + assert_eq!(DataType::Int64, observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt8])?; + assert_eq!(DataType::UInt64, observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::Float32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::Decimal(10, 5)])?; + assert_eq!(DataType::Decimal(20, 5), observed); + + let observed = return_type(&AggregateFunction::Sum, &[DataType::Decimal(35, 5)])?; + assert_eq!(DataType::Decimal(38, 5), observed); + + Ok(()) + } + #[test] fn test_sum_no_utf8() { let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8]); @@ -504,6 +527,15 @@ mod tests { let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64])?; assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Avg, &[DataType::Int32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(10, 6)])?; + assert_eq!(DataType::Decimal(14, 10), observed); + + let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(36, 6)])?; + assert_eq!(DataType::Decimal(38, 10), observed); Ok(()) } diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index d7b437528d5c..e76e4a6b023e 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -193,8 +193,7 @@ mod tests { let input_types = vec![ vec![DataType::Int32], vec![DataType::Float32], - // support the decimal data type - // vec![DataType::Decimal(20, 3)], + vec![DataType::Decimal(20, 3)], ]; for fun in funs { for input_type in &input_types { diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index feb568c8dd72..f09298998a2a 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -23,7 +23,9 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; -use crate::scalar::ScalarValue; +use crate::scalar::{ + ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128, +}; use arrow::compute; use arrow::datatypes::DataType; use arrow::{ @@ -38,11 +40,19 @@ use super::{format_state_name, sum}; pub struct Avg { name: String, expr: Arc, + data_type: DataType, } /// function return type of an average pub fn avg_return_type(arg_type: &DataType) -> Result { match arg_type { + DataType::Decimal(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4); + let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4); + Ok(DataType::Decimal(new_precision, new_scale)) + } DataType::Int8 | DataType::Int16 | DataType::Int32 @@ -73,6 +83,7 @@ pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool { | DataType::Int64 | DataType::Float32 | DataType::Float64 + | DataType::Decimal(_, _) ) } @@ -83,14 +94,15 @@ impl Avg { name: impl Into, data_type: DataType, ) -> Self { - // Average is always Float64, but Avg::new() has a data_type - // parameter to keep a consistent signature with the other - // Aggregate expressions. - assert_eq!(data_type, DataType::Float64); - + // the result of avg just support FLOAT64 and Decimal data type. + assert!(matches!( + data_type, + DataType::Float64 | DataType::Decimal(_, _) + )); Self { name: name.into(), expr, + data_type, } } } @@ -102,7 +114,14 @@ impl AggregateExpr for Avg { } fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(AvgAccumulator::try_new( + // avg is f64 or decimal + &self.data_type, + )?)) } fn state_fields(&self) -> Result> { @@ -114,19 +133,12 @@ impl AggregateExpr for Avg { ), Field::new( &format_state_name(&self.name, "sum"), - DataType::Float64, + self.data_type.clone(), true, ), ]) } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(AvgAccumulator::try_new( - // avg is f64 - &DataType::Float64, - )?)) - } - fn expressions(&self) -> Vec> { vec![self.expr.clone()] } @@ -205,6 +217,17 @@ impl Accumulator for AvgAccumulator { ScalarValue::Float64(e) => { Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) } + ScalarValue::Decimal128(value, precision, scale) => { + Ok(match value { + None => ScalarValue::Decimal128(None, precision, scale), + // TODO add the checker for overflow the precision + Some(v) => ScalarValue::Decimal128( + Some(v / self.count as i128), + precision, + scale, + ), + }) + } _ => Err(DataFusionError::Internal( "Sum should be f64 on average".to_string(), )), @@ -220,6 +243,73 @@ mod tests { use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; + #[test] + fn test_avg_return_data_type() -> Result<()> { + let data_type = DataType::Decimal(10, 5); + let result_type = avg_return_type(&data_type)?; + assert_eq!(DataType::Decimal(14, 9), result_type); + + let data_type = DataType::Decimal(36, 10); + let result_type = avg_return_type(&data_type)?; + assert_eq!(DataType::Decimal(38, 14), result_type); + Ok(()) + } + + #[test] + fn avg_decimal() -> Result<()> { + // test agg + let mut decimal_builder = DecimalBuilder::new(6, 10, 0); + for i in 1..7 { + decimal_builder.append_value(i as i128)?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + + generic_test_op!( + array, + DataType::Decimal(10, 0), + Avg, + ScalarValue::Decimal128(Some(35000), 14, 4), + DataType::Decimal(14, 4) + ) + } + + #[test] + fn avg_decimal_with_nulls() -> Result<()> { + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + if i == 2 { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(i)?; + } + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Avg, + ScalarValue::Decimal128(Some(32500), 14, 4), + DataType::Decimal(14, 4) + ) + } + + #[test] + fn avg_decimal_all_nulls() -> Result<()> { + // test agg + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for _i in 1..6 { + decimal_builder.append_null()?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Avg, + ScalarValue::Decimal128(None, 14, 4), + DataType::Decimal(14, 4) + ) + } + #[test] fn avg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index c570aef72b52..027736dbc478 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; -use crate::scalar::ScalarValue; +use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128}; use arrow::compute; use arrow::datatypes::DataType; use arrow::{ @@ -35,6 +35,8 @@ use arrow::{ }; use super::format_state_name; +use crate::arrow::array::Array; +use arrow::array::DecimalArray; /// SUM aggregate expression #[derive(Debug)] @@ -54,8 +56,15 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { Ok(DataType::UInt64) } - DataType::Float32 => Ok(DataType::Float32), - DataType::Float64 => Ok(DataType::Float64), + // In the https://www.postgresql.org/docs/current/functions-aggregate.html doc, + // the result type of floating-point is FLOAT64 with the double precision. + DataType::Float64 | DataType::Float32 => Ok(DataType::Float64), + DataType::Decimal(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 10); + Ok(DataType::Decimal(new_precision, *scale)) + } other => Err(DataFusionError::Plan(format!( "SUM does not support type \"{:?}\"", other @@ -76,6 +85,7 @@ pub(crate) fn is_sum_support_arg_type(arg_type: &DataType) -> bool { | DataType::Int64 | DataType::Float32 | DataType::Float64 + | DataType::Decimal(_, _) ) } @@ -109,6 +119,10 @@ impl AggregateExpr for Sum { )) } + fn create_accumulator(&self) -> Result> { + Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) + } + fn state_fields(&self) -> Result> { Ok(vec![Field::new( &format_state_name(&self.name, "sum"), @@ -121,10 +135,6 @@ impl AggregateExpr for Sum { vec![self.expr.clone()] } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) - } - fn name(&self) -> &str { &self.name } @@ -153,9 +163,34 @@ macro_rules! typed_sum_delta_batch { }}; } +// TODO implement this in arrow-rs with simd +// https://github.com/apache/arrow-rs/issues/1010 +fn sum_decimal_batch( + values: &ArrayRef, + precision: &usize, + scale: &usize, +) -> Result { + let array = values.as_any().downcast_ref::().unwrap(); + + if array.null_count() == array.len() { + return Ok(ScalarValue::Decimal128(None, *precision, *scale)); + } + + let mut result = 0_i128; + for i in 0..array.len() { + if array.is_valid(i) { + result += array.value(i); + } + } + Ok(ScalarValue::Decimal128(Some(result), *precision, *scale)) +} + // sums the array and returns a ScalarValue of its corresponding type. pub(super) fn sum_batch(values: &ArrayRef) -> Result { Ok(match values.data_type() { + DataType::Decimal(precision, scale) => { + sum_decimal_batch(values, precision, scale)? + } DataType::Float64 => typed_sum_delta_batch!(values, Float64Array, Float64), DataType::Float32 => typed_sum_delta_batch!(values, Float32Array, Float32), DataType::Int64 => typed_sum_delta_batch!(values, Int64Array, Int64), @@ -170,7 +205,7 @@ pub(super) fn sum_batch(values: &ArrayRef) -> Result { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive the type {:?}", e - ))) + ))); } }) } @@ -187,8 +222,62 @@ macro_rules! typed_sum { }}; } +// TODO implement this in arrow-rs with simd +// https://github.com/apache/arrow-rs/issues/1010 +fn sum_decimal( + lhs: &Option, + rhs: &Option, + precision: &usize, + scale: &usize, +) -> ScalarValue { + match (lhs, rhs) { + (None, None) => ScalarValue::Decimal128(None, *precision, *scale), + (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale), + (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale), + (Some(lhs_value), Some(rhs_value)) => { + ScalarValue::Decimal128(Some(lhs_value + rhs_value), *precision, *scale) + } + } +} + +fn sum_decimal_with_diff_scale( + lhs: &Option, + rhs: &Option, + precision: &usize, + lhs_scale: &usize, + rhs_scale: &usize, +) -> ScalarValue { + // the lhs_scale must be greater or equal rhs_scale. + match (lhs, rhs) { + (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale), + (None, Some(rhs_value)) => { + let new_value = rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32); + ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale) + } + (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale), + (Some(lhs_value), Some(rhs_value)) => { + let new_value = + rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs_value; + ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale) + } + } +} + pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { Ok(match (lhs, rhs) { + (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => { + let max_precision = p1.max(p2); + if s1.eq(s2) { + // s1 = s2 + sum_decimal(v1, v2, max_precision, s1) + } else if s1.gt(s2) { + // s1 > s2 + sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2) + } else { + // s1 < s2 + sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1) + } + } // float64 coerces everything to f64 (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { typed_sum!(lhs, rhs, Float64, f64) @@ -254,16 +343,14 @@ pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive a scalar {:?}", e - ))) + ))); } }) } impl Accumulator for SumAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - self.sum = sum(&self.sum, &sum_batch(values)?)?; - Ok(()) + fn state(&self) -> Result> { + Ok(vec![self.sum.clone()]) } fn update(&mut self, values: &[ScalarValue]) -> Result<()> { @@ -272,6 +359,12 @@ impl Accumulator for SumAccumulator { Ok(()) } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + self.sum = sum(&self.sum, &sum_batch(values)?)?; + Ok(()) + } + fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { // sum(sum1, sum2) = sum1 + sum2 self.update(states) @@ -282,11 +375,9 @@ impl Accumulator for SumAccumulator { self.update_batch(states) } - fn state(&self) -> Result> { - Ok(vec![self.sum.clone()]) - } - fn evaluate(&self) -> Result { + // TODO: add the checker for overflow + // For the decimal(precision,_) data type, the absolute of value must be less than 10^precision. Ok(self.sum.clone()) } } @@ -294,11 +385,145 @@ impl Accumulator for SumAccumulator { #[cfg(test)] mod tests { use super::*; + use crate::arrow::array::DecimalBuilder; use crate::physical_plan::expressions::col; use crate::{error::Result, generic_test_op}; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + #[test] + fn test_sum_return_data_type() -> Result<()> { + let data_type = DataType::Decimal(10, 5); + let result_type = sum_return_type(&data_type)?; + assert_eq!(DataType::Decimal(20, 5), result_type); + + let data_type = DataType::Decimal(36, 10); + let result_type = sum_return_type(&data_type)?; + assert_eq!(DataType::Decimal(38, 10), result_type); + Ok(()) + } + + #[test] + fn sum_decimal() -> Result<()> { + // test sum + let left = ScalarValue::Decimal128(Some(123), 10, 2); + let right = ScalarValue::Decimal128(Some(124), 10, 2); + let result = sum(&left, &right)?; + assert_eq!(ScalarValue::Decimal128(Some(123 + 124), 10, 2), result); + // test sum decimal with diff scale + let left = ScalarValue::Decimal128(Some(123), 10, 3); + let right = ScalarValue::Decimal128(Some(124), 10, 2); + let result = sum(&left, &right)?; + assert_eq!( + ScalarValue::Decimal128(Some(123 + 124 * 10_i128.pow(1)), 10, 3), + result + ); + // diff precision and scale for decimal data type + let left = ScalarValue::Decimal128(Some(123), 10, 2); + let right = ScalarValue::Decimal128(Some(124), 11, 3); + let result = sum(&left, &right); + assert_eq!( + ScalarValue::Decimal128(Some(123 * 10_i128.pow(3 - 2) + 124), 11, 3), + result.unwrap() + ); + + // test sum batch + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + decimal_builder.append_value(i as i128)?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + let result = sum_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(Some(15), 10, 0), result); + + // test agg + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + decimal_builder.append_value(i as i128)?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + + generic_test_op!( + array, + DataType::Decimal(10, 0), + Sum, + ScalarValue::Decimal128(Some(15), 20, 0), + DataType::Decimal(20, 0) + ) + } + + #[test] + fn sum_decimal_with_nulls() -> Result<()> { + // test sum + let left = ScalarValue::Decimal128(None, 10, 2); + let right = ScalarValue::Decimal128(Some(123), 10, 2); + let result = sum(&left, &right)?; + assert_eq!(ScalarValue::Decimal128(Some(123), 10, 2), result); + + // test with batch + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + if i == 2 { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(i)?; + } + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + let result = sum_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(Some(13), 10, 0), result); + + // test agg + let mut decimal_builder = DecimalBuilder::new(5, 35, 0); + for i in 1..6 { + if i == 2 { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(i)?; + } + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(35, 0), + Sum, + ScalarValue::Decimal128(Some(13), 38, 0), + DataType::Decimal(38, 0) + ) + } + + #[test] + fn sum_decimal_all_nulls() -> Result<()> { + // test sum + let left = ScalarValue::Decimal128(None, 10, 2); + let right = ScalarValue::Decimal128(None, 10, 2); + let result = sum(&left, &right)?; + assert_eq!(ScalarValue::Decimal128(None, 10, 2), result); + + // test with batch + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for _i in 1..6 { + decimal_builder.append_null()?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + let result = sum_batch(&array)?; + assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); + + // test agg + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for _i in 1..6 { + decimal_builder.append_null()?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Sum, + ScalarValue::Decimal128(None, 20, 0), + DataType::Decimal(20, 0) + ) + } + #[test] fn sum_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index e9eafe1c109c..35ebb2aa8193 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -33,6 +33,11 @@ use std::convert::{Infallible, TryInto}; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; +// TODO may need to be moved to arrow-rs +/// The max precision and scale for decimal128 +pub(crate) const MAX_PRECISION_FOR_DECIMAL128: usize = 38; +pub(crate) const MAX_SCALE_FOR_DECIMAL128: usize = 38; + /// Represents a dynamically typed, nullable single value. /// This is the single-valued counter-part of arrow’s `Array`. #[derive(Clone)] @@ -480,8 +485,7 @@ impl ScalarValue { scale: usize, ) -> Result { // make sure the precision and scale is valid - // TODO const the max precision and min scale - if precision <= 38 && scale <= precision { + if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision { return Ok(ScalarValue::Decimal128(Some(value), precision, scale)); } return Err(DataFusionError::Internal(format!( diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index bce50e5610d3..0ede5ad8559e 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -20,7 +20,7 @@ use arrow::datatypes::DataType; use crate::logical_plan::{Expr, LogicalPlan}; -use crate::scalar::ScalarValue; +use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128}; use crate::{ error::{DataFusionError, Result}, logical_plan::{Column, ExpressionVisitor, Recursion}, @@ -520,7 +520,7 @@ pub(crate) fn make_decimal_type( } (Some(p), Some(s)) => { // Arrow decimal is i128 meaning 38 maximum decimal digits - if p > 38 || s > p { + if (p as usize) > MAX_PRECISION_FOR_DECIMAL128 || s > p { return Err(DataFusionError::Internal(format!( "For decimal(precision, scale) precision must be less than or equal to 38 and scale can't be greater than precision. Got ({}, {})", p, s