Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions datafusion/src/physical_plan/expressions/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ use std::sync::Arc;

use crate::error::{DataFusionError, Result};
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
use crate::scalar::{
ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128,
};
use arrow::compute;
use arrow::datatypes::DataType;
use arrow::{
Expand All @@ -45,9 +47,10 @@ pub struct Avg {
pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
DataType::Decimal(precision, scale) => {
// the new precision and scale for return type of avg function
let new_precision = 38.min(*precision + 4);
let new_scale = 38.min(*scale + 4);
// In Spark, the result type is DECIMAL(min(38, precision+4), min(38, scale+4)).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4);
let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4);
Ok(DataType::Decimal(new_precision, new_scale))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
DataType::Int8
Expand Down Expand Up @@ -91,18 +94,15 @@ impl Avg {
name: impl Into<String>,
data_type: DataType,
) -> Self {
// Average is always Float64, but Avg::new() has a data_type
// parameter to keep a consistent signature with the other
// Aggregate expressions.
match data_type {
DataType::Float64 | DataType::Decimal(_, _) => Self {
name: name.into(),
expr,
data_type,
},
_ => {
unreachable!();
}
// the result of avg only supports the FLOAT64 and Decimal data types.
assert!(matches!(
data_type,
DataType::Float64 | DataType::Decimal(_, _)
));
Self {
name: name.into(),
expr,
data_type,
}
}
}
Expand Down Expand Up @@ -260,7 +260,6 @@ mod tests {
// test agg
let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
for i in 1..7 {
// the avg is 3.5, but we get the result of 3
decimal_builder.append_value(i as i128)?;
}
let array: ArrayRef = Arc::new(decimal_builder.finish());
Expand Down
47 changes: 21 additions & 26 deletions datafusion/src/physical_plan/expressions/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::sync::Arc;

use crate::error::{DataFusionError, Result};
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128};
use arrow::compute;
use arrow::datatypes::DataType;
use arrow::{
Expand Down Expand Up @@ -56,13 +56,13 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
Ok(DataType::UInt64)
}
// In the https://www.postgresql.org/docs/8.2/functions-aggregate.html doc,
// In the https://www.postgresql.org/docs/current/functions-aggregate.html doc,
// the result type of floating-point is FLOAT64 with the double precision.
DataType::Float64 | DataType::Float32 => Ok(DataType::Float64),
// Max precision is 38
DataType::Decimal(precision, scale) => {
// In Spark, the result type is DECIMAL(min(38, precision+10), s).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you for the context

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alamb
The same issue is about precision and scale promotion.
In the PG

postgres=# \d t1
                   Table "public.t1"
 Column |     Type     | Collation | Nullable | Default
--------+--------------+-----------+----------+---------
 c1     | numeric(4,2) |           |          |

postgres=# create table test as select sum(c1) from t1;
SELECT 1
postgres=# \d test;
                Table "public.test"
 Column |  Type   | Collation | Nullable | Default
--------+---------+-----------+----------+---------
 sum    | numeric |           |          |

let new_precision = 38.min(*precision + 10);
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 10);
Ok(DataType::Decimal(new_precision, *scale))
}
other => Err(DataFusionError::Plan(format!(
Expand Down Expand Up @@ -234,8 +234,8 @@ fn sum_decimal(
(None, None) => ScalarValue::Decimal128(None, *precision, *scale),
(None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
(lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
(lhs, rhs) => {
ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
(Some(lhs_value), Some(rhs_value)) => {
ScalarValue::Decimal128(Some(lhs_value + rhs_value), *precision, *scale)
}
}
}
Expand All @@ -250,14 +250,14 @@ fn sum_decimal_with_diff_scale(
// lhs_scale must be greater than or equal to rhs_scale.
match (lhs, rhs) {
(None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale),
(None, rhs) => {
let new_value = rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32);
(None, Some(rhs_value)) => {
let new_value = rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32);
ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
}
(lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale),
(lhs, rhs) => {
(Some(lhs_value), Some(rhs_value)) => {
let new_value =
rhs.unwrap() * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs.unwrap();
rhs_value * 10_i128.pow((lhs_scale - rhs_scale) as u32) + lhs_value;
ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
}
}
Expand All @@ -266,17 +266,16 @@ fn sum_decimal_with_diff_scale(
pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
Ok(match (lhs, rhs) {
(ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
let max_precision = p1.max(p2);
if s1.eq(s2) {
sum_decimal(v1, v2, p1, s1)
} else if s1.gt(s2) && p1.ge(p2) {
// For avg aggravate function.
// In the avg function, the scale of result data type is different with the scale of the input data type.
sum_decimal_with_diff_scale(v1, v2, p1, s1, s2)
// s1 = s2
sum_decimal(v1, v2, max_precision, s1)
} else if s1.gt(s2) {
// s1 > s2
sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
} else {
return Err(DataFusionError::Internal(format!(
"Sum is not expected to receive lhs {:?}, rhs {:?}",
lhs, rhs
)));
// s1 < s2
sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
}
}
// float64 coerces everything to f64
Expand Down Expand Up @@ -419,17 +418,13 @@ mod tests {
ScalarValue::Decimal128(Some(123 + 124 * 10_i128.pow(1)), 10, 3),
result
);
// negative test with diff scale
// diff precision and scale for decimal data type
let left = ScalarValue::Decimal128(Some(123), 10, 2);
let right = ScalarValue::Decimal128(Some(124), 11, 3);
let result = sum(&left, &right);
assert_eq!(
DataFusionError::Internal(format!(
"Sum is not expected to receive lhs {:?}, rhs {:?}",
left, right
))
.to_string(),
result.unwrap_err().to_string()
ScalarValue::Decimal128(Some(123 * 10_i128.pow(3 - 2) + 124), 11, 3),
result.unwrap()
);

// test sum batch
Expand Down
8 changes: 6 additions & 2 deletions datafusion/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ use std::convert::{Infallible, TryInto};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};

// TODO may need to be moved to arrow-rs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving to arrow-rs would be a good idea I think

/// The max precision and scale for decimal128
pub(crate) const MAX_PRECISION_FOR_DECIMAL128: usize = 38;
pub(crate) const MAX_SCALE_FOR_DECIMAL128: usize = 38;

/// Represents a dynamically typed, nullable single value.
/// This is the single-valued counter-part of arrow’s `Array`.
#[derive(Clone)]
Expand Down Expand Up @@ -480,8 +485,7 @@ impl ScalarValue {
scale: usize,
) -> Result<Self> {
// make sure the precision and scale is valid
// TODO const the max precision and min scale
if precision <= 38 && scale <= precision {
if precision <= MAX_PRECISION_FOR_DECIMAL128 && scale <= precision {
return Ok(ScalarValue::Decimal128(Some(value), precision, scale));
}
return Err(DataFusionError::Internal(format!(
Expand Down
58 changes: 26 additions & 32 deletions datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use crate::logical_plan::{
};
use crate::optimizer::utils::exprlist_to_columns;
use crate::prelude::JoinType;
use crate::scalar::ScalarValue;
use crate::scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128};
use crate::{
error::{DataFusionError, Result},
physical_plan::udaf::AggregateUDF,
Expand Down Expand Up @@ -371,27 +371,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLDataType::Char(_) | SQLDataType::Varchar(_) | SQLDataType::Text => {
Ok(DataType::Utf8)
}
SQLDataType::Decimal(precision, scale) => {
match (precision, scale) {
(None, _) | (_, None) => {
SQLDataType::Decimal(precision, scale) => match (precision, scale) {
(None, _) | (_, None) => {
return Err(DataFusionError::Internal(format!(
"Invalid Decimal type ({:?}), precision or scale can't be empty.",
sql_type
)));
}
(Some(p), Some(s)) => {
if (*p as usize) > MAX_PRECISION_FOR_DECIMAL128 || *s > *p {
return Err(DataFusionError::Internal(format!(
"Invalid Decimal type ({:?}), precision or scale can't be empty.",
sql_type
)));
}
(Some(p), Some(s)) => {
// TODO add bound checker in some utils file or function
if *p > 38 || *s > *p {
return Err(DataFusionError::Internal(format!(
"Error Decimal Type ({:?}), precision must be less than or equal to 38 and scale can't be greater than precision",
sql_type
)));
} else {
Ok(DataType::Decimal(*p as usize, *s as usize))
}
} else {
Ok(DataType::Decimal(*p as usize, *s as usize))
}
}
}
},
SQLDataType::Float(_) => Ok(DataType::Float32),
SQLDataType::Real => Ok(DataType::Float32),
SQLDataType::Double => Ok(DataType::Float64),
Expand Down Expand Up @@ -2054,27 +2051,24 @@ pub fn convert_data_type(sql_type: &SQLDataType) -> Result<DataType> {
SQLDataType::Char(_) | SQLDataType::Varchar(_) => Ok(DataType::Utf8),
SQLDataType::Timestamp => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)),
SQLDataType::Date => Ok(DataType::Date32),
SQLDataType::Decimal(precision, scale) => {
match (precision, scale) {
(None, _) | (_, None) => {
SQLDataType::Decimal(precision, scale) => match (precision, scale) {
(None, _) | (_, None) => {
return Err(DataFusionError::Internal(format!(
"Invalid Decimal type ({:?}), precision or scale can't be empty.",
sql_type
)));
}
(Some(p), Some(s)) => {
if (*p as usize) > MAX_PRECISION_FOR_DECIMAL128 || *s > *p {
return Err(DataFusionError::Internal(format!(
"Invalid Decimal type ({:?}), precision or scale can't be empty.",
"Error Decimal Type ({:?})",
sql_type
)));
}
(Some(p), Some(s)) => {
// TODO add bound checker in some utils file or function
if *p > 38 || *s > *p {
return Err(DataFusionError::Internal(format!(
"Error Decimal Type ({:?})",
sql_type
)));
} else {
Ok(DataType::Decimal(*p as usize, *s as usize))
}
} else {
Ok(DataType::Decimal(*p as usize, *s as usize))
}
}
}
},
other => Err(DataFusionError::NotImplemented(format!(
"Unsupported SQL type {:?}",
other
Expand Down