Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1845,9 +1845,9 @@ mod tests {
#[tokio::test]
async fn aggregate_decimal_min() -> Result<()> {
let mut ctx = ExecutionContext::new();
// the data type of c1 is decimal(10,3)
ctx.register_table("d_table", test::table_with_decimal())
.unwrap();

let result = plan_and_collect(&mut ctx, "select min(c1) from d_table")
.await
.unwrap();
Expand All @@ -1858,13 +1858,18 @@ mod tests {
"| -100.009 |",
"+-----------------+",
];
assert_eq!(
&DataType::Decimal(10, 3),
result[0].schema().field(0).data_type()
);
assert_batches_sorted_eq!(expected, &result);
Ok(())
}

#[tokio::test]
async fn aggregate_decimal_max() -> Result<()> {
let mut ctx = ExecutionContext::new();
// the data type of c1 is decimal(10,3)
ctx.register_table("d_table", test::table_with_decimal())
.unwrap();

Expand All @@ -1878,6 +1883,58 @@ mod tests {
"| 110.009 |",
"+-----------------+",
];
assert_eq!(
&DataType::Decimal(10, 3),
result[0].schema().field(0).data_type()
);
assert_batches_sorted_eq!(expected, &result);
Ok(())
}

#[tokio::test]
async fn aggregate_decimal_sum() -> Result<()> {
let mut ctx = ExecutionContext::new();
// the data type of c1 is decimal(10,3)
ctx.register_table("d_table", test::table_with_decimal())
.unwrap();
let result = plan_and_collect(&mut ctx, "select sum(c1) from d_table")
.await
.unwrap();
let expected = vec![
"+-----------------+",
"| SUM(d_table.c1) |",
"+-----------------+",
"| 100.000 |",
"+-----------------+",
];
assert_eq!(
&DataType::Decimal(20, 3),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

result[0].schema().field(0).data_type()
);
assert_batches_sorted_eq!(expected, &result);
Ok(())
}

// Integration test: AVG over a decimal column must produce a decimal
// result type widened by the Spark-style rule (precision+4, scale+4,
// both capped at 38) and the expected averaged value.
#[tokio::test]
async fn aggregate_decimal_avg() -> Result<()> {
    let mut ctx = ExecutionContext::new();
    // the data type of c1 is decimal(10,3)
    ctx.register_table("d_table", test::table_with_decimal())
        .unwrap();
    let result = plan_and_collect(&mut ctx, "select avg(c1) from d_table")
        .await
        .unwrap();
    // Pretty-printed batch output expected from the query.
    let expected = vec![
        "+-----------------+",
        "| AVG(d_table.c1) |",
        "+-----------------+",
        "| 5.0000000 |",
        "+-----------------+",
    ];
    // avg(decimal(10,3)) -> decimal(10+4, 3+4) = decimal(14,7)
    assert_eq!(
        &DataType::Decimal(14, 7),
        result[0].schema().field(0).data_type()
    );
    assert_batches_sorted_eq!(expected, &result);
    Ok(())
}
Expand Down
34 changes: 33 additions & 1 deletion datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ mod tests {
| DataType::Int16
| DataType::Int32
| DataType::Int64 => DataType::Int64,
DataType::Float32 | DataType::Float64 => data_type.clone(),
DataType::Float32 | DataType::Float64 => DataType::Float64,
_ => data_type.clone(),
};

Expand Down Expand Up @@ -470,6 +470,29 @@ mod tests {
Ok(())
}

/// Checks the result type SUM produces for each supported numeric
/// input type, including decimal precision widening.
#[test]
fn test_sum_return_type() -> Result<()> {
    // (input type, expected SUM result type) pairs, checked in order.
    let cases = vec![
        (DataType::Int32, DataType::Int64),
        (DataType::UInt8, DataType::UInt64),
        (DataType::Float32, DataType::Float64),
        (DataType::Float64, DataType::Float64),
        // sum(decimal(p, s)) widens the precision by 10, capped at 38.
        (DataType::Decimal(10, 5), DataType::Decimal(20, 5)),
        (DataType::Decimal(35, 5), DataType::Decimal(38, 5)),
    ];
    for (input, expected) in cases {
        let actual = return_type(&AggregateFunction::Sum, &[input])?;
        assert_eq!(expected, actual);
    }
    Ok(())
}

#[test]
fn test_sum_no_utf8() {
let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8]);
Expand Down Expand Up @@ -504,6 +527,15 @@ mod tests {

let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64])?;
assert_eq!(DataType::Float64, observed);

let observed = return_type(&AggregateFunction::Avg, &[DataType::Int32])?;
assert_eq!(DataType::Float64, observed);

let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(10, 6)])?;
assert_eq!(DataType::Decimal(14, 10), observed);

let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(36, 6)])?;
assert_eq!(DataType::Decimal(38, 10), observed);
Ok(())
}

Expand Down
3 changes: 1 addition & 2 deletions datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ mod tests {
let input_types = vec![
vec![DataType::Int32],
vec![DataType::Float32],
// support the decimal data type
// vec![DataType::Decimal(20, 3)],
vec![DataType::Decimal(20, 3)],
];
for fun in funs {
for input_type in &input_types {
Expand Down
120 changes: 105 additions & 15 deletions datafusion/src/physical_plan/expressions/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ use std::sync::Arc;

use crate::error::{DataFusionError, Result};
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
use crate::scalar::{
ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128,
};
use arrow::compute;
use arrow::datatypes::DataType;
use arrow::{
Expand All @@ -38,11 +40,19 @@ use super::{format_state_name, sum};
pub struct Avg {
name: String,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
}

/// function return type of an average
pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
DataType::Decimal(precision, scale) => {
// in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4);
let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4);
Ok(DataType::Decimal(new_precision, new_scale))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
DataType::Int8
| DataType::Int16
| DataType::Int32
Expand Down Expand Up @@ -73,6 +83,7 @@ pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
| DataType::Int64
| DataType::Float32
| DataType::Float64
| DataType::Decimal(_, _)
)
}

Expand All @@ -83,14 +94,15 @@ impl Avg {
name: impl Into<String>,
data_type: DataType,
) -> Self {
// Average is always Float64, but Avg::new() has a data_type
// parameter to keep a consistent signature with the other
// Aggregate expressions.
assert_eq!(data_type, DataType::Float64);

// the result of avg just support FLOAT64 and Decimal data type.
assert!(matches!(
data_type,
DataType::Float64 | DataType::Decimal(_, _)
));
Self {
name: name.into(),
expr,
data_type,
}
}
}
Expand All @@ -102,7 +114,14 @@ impl AggregateExpr for Avg {
}

fn field(&self) -> Result<Field> {
Ok(Field::new(&self.name, DataType::Float64, true))
Ok(Field::new(&self.name, self.data_type.clone(), true))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(AvgAccumulator::try_new(
// avg is f64 or decimal
&self.data_type,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if a sum of decimal(10,2) can be decimal(20,2) shouldn't the accumulator state also be decimal(20,2) to avoid overflow?

I think handling overflow is probably fine to for a later date / PR, but it is strange to me that there is a discrepancy between the type for sum and the accumulator type for computing avg

Copy link
Contributor Author

@liukun4515 liukun4515 Dec 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The result type of phy expr(sum/avg) is same with each Accumulator, and it was decided by sum_return_type and avg_return_type.

If the column is decimal(8,2), the avg of this column must be less than 10^8-1, but we need more digits to represent the decimal part. For example, The avg of 3,4,6 is 4.333333....., we should increase the scale part.

For the sum agg, we just should increase the precision part, and the rule of adding 10 to precision is spark coercion rule for sum decimal. We can have our rules for decimal if we want.
@alamb
We can just follow the spark now, and change the rules if we want to define own rules.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add issue to track the overflow.
#1460

)?))
}

fn state_fields(&self) -> Result<Vec<Field>> {
Expand All @@ -114,19 +133,12 @@ impl AggregateExpr for Avg {
),
Field::new(
&format_state_name(&self.name, "sum"),
DataType::Float64,
self.data_type.clone(),
true,
),
])
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(AvgAccumulator::try_new(
// avg is f64
&DataType::Float64,
)?))
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}
Expand Down Expand Up @@ -205,6 +217,17 @@ impl Accumulator for AvgAccumulator {
ScalarValue::Float64(e) => {
Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64)))
}
ScalarValue::Decimal128(value, precision, scale) => {
Ok(match value {
None => ScalarValue::Decimal128(None, precision, scale),
// TODO add the checker for overflow the precision
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Some(v) => ScalarValue::Decimal128(
Some(v / self.count as i128),
precision,
scale,
),
})
}
_ => Err(DataFusionError::Internal(
"Sum should be f64 on average".to_string(),
)),
Expand All @@ -220,6 +243,73 @@ mod tests {
use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};

/// avg(decimal(p, s)) yields decimal(min(38, p + 4), min(38, s + 4)).
#[test]
fn test_avg_return_data_type() -> Result<()> {
    // Plenty of headroom: both precision and scale grow by 4.
    assert_eq!(
        DataType::Decimal(14, 9),
        avg_return_type(&DataType::Decimal(10, 5))?
    );
    // Precision saturates at the decimal128 maximum of 38.
    assert_eq!(
        DataType::Decimal(38, 14),
        avg_return_type(&DataType::Decimal(36, 10))?
    );
    Ok(())
}

// AVG accumulator over a non-null decimal(10, 0) column.
#[test]
fn avg_decimal() -> Result<()> {
    // test agg
    // Build a decimal(10, 0) array containing the values 1..=6.
    let mut decimal_builder = DecimalBuilder::new(6, 10, 0);
    for i in 1..7 {
        decimal_builder.append_value(i as i128)?;
    }
    let array: ArrayRef = Arc::new(decimal_builder.finish());

    // avg(1..=6) = 3.5; in the widened result type decimal(14, 4)
    // that is the unscaled value 35000.
    generic_test_op!(
        array,
        DataType::Decimal(10, 0),
        Avg,
        ScalarValue::Decimal128(Some(35000), 14, 4),
        DataType::Decimal(14, 4)
    )
}

// AVG accumulator must skip nulls: only the non-null decimal values
// contribute to both the sum and the count.
#[test]
fn avg_decimal_with_nulls() -> Result<()> {
    // Values 1, null, 3, 4, 5 in a decimal(10, 0) array.
    let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
    for i in 1..6 {
        if i == 2 {
            decimal_builder.append_null()?;
        } else {
            decimal_builder.append_value(i)?;
        }
    }
    let array: ArrayRef = Arc::new(decimal_builder.finish());
    // avg(1, 3, 4, 5) = 13 / 4 = 3.25; as decimal(14, 4) the unscaled
    // value is 32500.
    generic_test_op!(
        array,
        DataType::Decimal(10, 0),
        Avg,
        ScalarValue::Decimal128(Some(32500), 14, 4),
        DataType::Decimal(14, 4)
    )
}

// AVG over an all-null decimal column must be the null scalar of the
// widened result type, not zero and not an error.
#[test]
fn avg_decimal_all_nulls() -> Result<()> {
    // test agg
    // Build a decimal(10, 0) array of five nulls.
    let mut decimal_builder = DecimalBuilder::new(5, 10, 0);
    for _i in 1..6 {
        decimal_builder.append_null()?;
    }
    let array: ArrayRef = Arc::new(decimal_builder.finish());
    // Expected: NULL typed as decimal(14, 4).
    generic_test_op!(
        array,
        DataType::Decimal(10, 0),
        Avg,
        ScalarValue::Decimal128(None, 14, 4),
        DataType::Decimal(14, 4)
    )
}

#[test]
fn avg_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
Expand Down
Loading