diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs index f67e2f49dcbf..37bbd1508c91 100644 --- a/datafusion/functions-aggregate-common/src/accumulator.rs +++ b/datafusion/functions-aggregate-common/src/accumulator.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{Field, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; use datafusion_expr_common::accumulator::Accumulator; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -71,6 +71,13 @@ pub struct AccumulatorArgs<'a> { pub exprs: &'a [Arc<dyn PhysicalExpr>], } +impl AccumulatorArgs<'_> { + /// Returns the return type of the aggregate function. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} + /// Factory that returns an accumulator for the given aggregate function. pub type AccumulatorFactoryFunction = Arc<dyn Fn(AccumulatorArgs) -> Result<Box<dyn Accumulator>> + Send + Sync>; diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 15b5db2d72e0..3ca39aa31589 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -182,7 +182,7 @@ impl AggregateUDFImpl for Avg { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { matches!( args.return_field.data_type(), - DataType::Float64 | DataType::Decimal128(_, _) + DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_) ) } @@ -243,6 +243,45 @@ impl AggregateUDFImpl for Avg { ))) } + (Duration(time_unit), Duration(_result_unit)) => { + let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64); + + match time_unit { + TimeUnit::Second => Ok(Box::new(AvgGroupsAccumulator::< + DurationSecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::<
DurationMillisecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationMicrosecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationNanosecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + } + } + _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", &data_type, diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 19f92ed72e0b..41ce15d7942a 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5098,6 +5098,28 @@ FROM d WHERE column1 IS NOT NULL; statement ok drop table d; +statement ok +create table dn as values + (arrow_cast(10, 'Duration(Second)'), 'a', 1), + (arrow_cast(20, 'Duration(Second)'), 'a', 2), + (NULL, 'b', 1), + (arrow_cast(40, 'Duration(Second)'), 'b', 2), + (arrow_cast(50, 'Duration(Second)'), 'c', 1), + (NULL, 'c', 2); + +query T?I +SELECT column2, avg(column1), column3 FROM dn GROUP BY column2, column3 ORDER BY column2, column3; +---- +a 0 days 0 hours 0 mins 10 secs 1 +a 0 days 0 hours 0 mins 20 secs 2 +b NULL 1 +b 0 days 0 hours 0 mins 40 secs 2 +c 0 days 0 hours 0 mins 50 secs 1 +c NULL 2 + +statement ok +drop table dn; + # Prepare the table with dictionary values for testing statement ok CREATE TABLE value(x bigint) AS VALUES (1), (2), (3), (1), (3), (4), (5), (2);