@@ -37,7 +37,8 @@ use arrow::array::Array;
3737use arrow:: array:: ArrowNativeTypeOp ;
3838use arrow:: datatypes:: { ArrowNativeType , ArrowPrimitiveType } ;
3939
40- use datafusion_common:: { DataFusionError , HashSet , Result , ScalarValue } ;
40+ use datafusion_common:: { internal_err, DataFusionError , HashSet , Result , ScalarValue } ;
41+ use datafusion_doc:: DocSection ;
4142use datafusion_expr:: function:: StateFieldsArgs ;
4243use datafusion_expr:: {
4344 function:: AccumulatorArgs , utils:: format_state_name, Accumulator , AggregateUDFImpl ,
@@ -172,6 +173,45 @@ impl AggregateUDFImpl for Median {
172173 }
173174 }
174175
176+ fn groups_accumulator_supported ( & self , args : AccumulatorArgs ) -> bool {
177+ !args. is_distinct
178+ }
179+
180+ fn create_groups_accumulator (
181+ & self ,
182+ args : AccumulatorArgs ,
183+ ) -> Result < Box < dyn GroupsAccumulator > > {
184+ let num_args = args. exprs . len ( ) ;
185+ if num_args != 1 {
186+ return internal_err ! (
187+ "median should only have 1 arg, but found num args:{}" ,
188+ args. exprs. len( )
189+ ) ;
190+ }
191+
192+ let dt = args. exprs [ 0 ] . data_type ( args. schema ) ?;
193+
194+ macro_rules! helper {
195+ ( $t: ty, $dt: expr) => {
196+ Ok ( Box :: new( MedianGroupsAccumulator :: <$t>:: new( $dt) ) )
197+ } ;
198+ }
199+
200+ downcast_integer ! {
201+ dt => ( helper, dt) ,
202+ DataType :: Float16 => helper!( Float16Type , dt) ,
203+ DataType :: Float32 => helper!( Float32Type , dt) ,
204+ DataType :: Float64 => helper!( Float64Type , dt) ,
205+ DataType :: Decimal128 ( _, _) => helper!( Decimal128Type , dt) ,
206+ DataType :: Decimal256 ( _, _) => helper!( Decimal256Type , dt) ,
207+ _ => Err ( DataFusionError :: NotImplemented ( format!(
208+ "MedianGroupsAccumulator not supported for {} with {}" ,
209+ args. name,
210+ dt,
211+ ) ) ) ,
212+ }
213+ }
214+
175215 fn aliases ( & self ) -> & [ String ] {
176216 & [ ]
177217 }
0 commit comments