 //! This module contains end to end demonstrations of creating
 //! user defined aggregate functions

+use std::any::Any;
+use std::collections::HashMap;
 use std::hash::{DefaultHasher, Hash, Hasher};
 use std::mem::{size_of, size_of_val};
 use std::sync::{
@@ -26,10 +28,10 @@ use std::sync::{
 };

 use arrow::array::{
-    types::UInt64Type, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray,
+    record_batch, types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray,
+    StringArray, StructArray, UInt64Array,
 };
 use arrow::datatypes::{Fields, Schema};
-
 use datafusion::common::test_util::batches_to_string;
 use datafusion::dataframe::DataFrame;
 use datafusion::datasource::MemTable;
@@ -48,11 +50,12 @@ use datafusion::{
     prelude::SessionContext,
     scalar::ScalarValue,
 };
-use datafusion_common::assert_contains;
+use datafusion_common::{assert_contains, exec_datafusion_err};
 use datafusion_common::{cast::as_primitive_array, exec_err};
+use datafusion_expr::expr::WindowFunction;
 use datafusion_expr::{
-    col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
-    LogicalPlanBuilder, SimpleAggregateUDF,
+    col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr,
+    GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition,
 };
 use datafusion_functions_aggregate::average::AvgAccumulator;

@@ -781,7 +784,7 @@ struct TestGroupsAccumulator {
 }

 impl AggregateUDFImpl for TestGroupsAccumulator {
-    fn as_any(&self) -> &dyn std::any::Any {
+    fn as_any(&self) -> &dyn Any {
         self
     }

@@ -890,3 +893,263 @@ impl GroupsAccumulator for TestGroupsAccumulator {
         size_of::<u64>()
     }
 }
+
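+/// Aggregate UDF used to demonstrate metadata handling: `return_field` attaches the
+/// stored map as output field metadata, and `accumulator` inspects the input field's
+/// metadata to decide whether to double the summed result.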
+#[derive(Debug)]
+struct MetadataBasedAggregateUdf {
+    name: String,
+    signature: Signature,
+    metadata: HashMap<String, String>,
+}
+
+impl MetadataBasedAggregateUdf {
+    fn new(metadata: HashMap<String, String>) -> Self {
+        // The name we return must be unique. Otherwise we will not call distinct
+        // instances of this UDF. This is a small hack for the unit tests to get unique
+        // names, but you could do something more elegant with the metadata.
+        let name = format!("metadata_based_udf_{}", metadata.len());
+        Self {
+            name,
+            signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable),
+            metadata,
+        }
+    }
+}
+
+impl AggregateUDFImpl for MetadataBasedAggregateUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        unimplemented!("this should never be called since return_field is implemented");
+    }
+
+    fn return_field(&self, _arg_fields: &[Field]) -> Result<Field> {
+        Ok(Field::new(self.name(), DataType::UInt64, true)
+            .with_metadata(self.metadata.clone()))
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        let input_expr = acc_args
+            .exprs
+            .first()
+            .ok_or(exec_datafusion_err!("Expected one argument"))?;
+        let input_field = input_expr.return_field(acc_args.schema)?;
+
+        let double_output = input_field
+            .metadata()
+            .get("modify_values")
+            .map(|v| v == "double_output")
+            .unwrap_or(false);
+
+        Ok(Box::new(MetadataBasedAccumulator {
+            double_output,
+            curr_sum: 0,
+        }))
+    }
+}
+
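+/// Accumulator that sums `UInt64` values; when `double_output` is set (driven by the
+/// `modify_values=double_output` input metadata checked above) the final sum is
+/// doubled in `evaluate`.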
+#[derive(Debug)]
+struct MetadataBasedAccumulator {
+    double_output: bool,
+    curr_sum: u64,
+}
+
+impl Accumulator for MetadataBasedAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let arr = values[0]
+            .as_any()
+            .downcast_ref::<UInt64Array>()
+            .ok_or(exec_datafusion_err!("Expected UInt64Array"))?;
+
+        self.curr_sum = arr.iter().fold(self.curr_sum, |a, b| a + b.unwrap_or(0));
+
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let v = match self.double_output {
+            true => self.curr_sum * 2,
+            false => self.curr_sum,
+        };
+
+        Ok(ScalarValue::from(v))
+    }
+
+    fn size(&self) -> usize {
+        9
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![ScalarValue::from(self.curr_sum)])
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.update_batch(states)
+    }
+}
+
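+// The test below exercises the four input/output metadata combinations through the
+// DataFrame aggregate API. As a rough sketch (not exercised here), the same UDAF could
+// also be registered for SQL use with something like
+// `ctx.register_udaf(AggregateUDF::from(MetadataBasedAggregateUdf::new(HashMap::new())));`
+// and then invoked by the name returned from `name()`.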
+#[tokio::test]
+async fn test_metadata_based_aggregate() -> Result<()> {
+    let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("no_metadata", DataType::UInt64, true),
+        Field::new("with_metadata", DataType::UInt64, true).with_metadata(
+            [("modify_values".to_string(), "double_output".to_string())]
+                .into_iter()
+                .collect(),
+        ),
+    ]));
+
+    let batch = RecordBatch::try_new(
+        schema,
+        vec![Arc::clone(&data_array), Arc::clone(&data_array)],
+    )?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("t", batch)?;
+    let df = ctx.table("t").await?;
+
+    let no_output_meta_udf =
+        AggregateUDF::from(MetadataBasedAggregateUdf::new(HashMap::new()));
+    let with_output_meta_udf = AggregateUDF::from(MetadataBasedAggregateUdf::new(
+        [("output_metatype".to_string(), "custom_value".to_string())]
+            .into_iter()
+            .collect(),
+    ));
+
+    let df = df.aggregate(
+        vec![],
+        vec![
+            no_output_meta_udf
+                .call(vec![col("no_metadata")])
+                .alias("meta_no_in_no_out"),
+            no_output_meta_udf
+                .call(vec![col("with_metadata")])
+                .alias("meta_with_in_no_out"),
+            with_output_meta_udf
+                .call(vec![col("no_metadata")])
+                .alias("meta_no_in_with_out"),
+            with_output_meta_udf
+                .call(vec![col("with_metadata")])
+                .alias("meta_with_in_with_out"),
+        ],
+    )?;
+
+    let actual = df.collect().await?;
+
+    // To test for output metadata handling, we set the expected values on the result
+    // To test for input metadata handling, we check the numbers returned
+    let mut output_meta = HashMap::new();
+    let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string());
+    let expected_schema = Schema::new(vec![
+        Field::new("meta_no_in_no_out", DataType::UInt64, true),
+        Field::new("meta_with_in_no_out", DataType::UInt64, true),
+        Field::new("meta_no_in_with_out", DataType::UInt64, true)
+            .with_metadata(output_meta.clone()),
+        Field::new("meta_with_in_with_out", DataType::UInt64, true)
+            .with_metadata(output_meta.clone()),
+    ]);
+
+    let expected = record_batch!(
+        ("meta_no_in_no_out", UInt64, [50]),
+        ("meta_with_in_no_out", UInt64, [100]),
+        ("meta_no_in_with_out", UInt64, [50]),
+        ("meta_with_in_with_out", UInt64, [100])
+    )?
+    .with_schema(Arc::new(expected_schema))?;
+
+    assert_eq!(expected, actual[0]);
+
+    Ok(())
+}
+
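+// Same metadata checks as above, but invoking the UDAF through the window function
+// path (`Expr::WindowFunction` with no ordering, i.e. over the whole partition), so
+// every row carries the aggregate value.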
+#[tokio::test]
+async fn test_metadata_based_aggregate_as_window() -> Result<()> {
+    let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("no_metadata", DataType::UInt64, true),
+        Field::new("with_metadata", DataType::UInt64, true).with_metadata(
+            [("modify_values".to_string(), "double_output".to_string())]
+                .into_iter()
+                .collect(),
+        ),
+    ]));
+
+    let batch = RecordBatch::try_new(
+        schema,
+        vec![Arc::clone(&data_array), Arc::clone(&data_array)],
+    )?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("t", batch)?;
+    let df = ctx.table("t").await?;
+
+    let no_output_meta_udf = Arc::new(AggregateUDF::from(
+        MetadataBasedAggregateUdf::new(HashMap::new()),
+    ));
+    let with_output_meta_udf =
+        Arc::new(AggregateUDF::from(MetadataBasedAggregateUdf::new(
+            [("output_metatype".to_string(), "custom_value".to_string())]
+                .into_iter()
+                .collect(),
+        )));
+
+    let df = df.select(vec![
+        Expr::WindowFunction(WindowFunction::new(
+            WindowFunctionDefinition::AggregateUDF(Arc::clone(&no_output_meta_udf)),
+            vec![col("no_metadata")],
+        ))
+        .alias("meta_no_in_no_out"),
+        Expr::WindowFunction(WindowFunction::new(
+            WindowFunctionDefinition::AggregateUDF(no_output_meta_udf),
+            vec![col("with_metadata")],
+        ))
+        .alias("meta_with_in_no_out"),
+        Expr::WindowFunction(WindowFunction::new(
+            WindowFunctionDefinition::AggregateUDF(Arc::clone(&with_output_meta_udf)),
+            vec![col("no_metadata")],
+        ))
+        .alias("meta_no_in_with_out"),
+        Expr::WindowFunction(WindowFunction::new(
+            WindowFunctionDefinition::AggregateUDF(with_output_meta_udf),
+            vec![col("with_metadata")],
+        ))
+        .alias("meta_with_in_with_out"),
+    ])?;
+
+    let actual = df.collect().await?;
+
+    // To test for output metadata handling, we set the expected values on the result
+    // To test for input metadata handling, we check the numbers returned
+    let mut output_meta = HashMap::new();
+    let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string());
+    let expected_schema = Schema::new(vec![
+        Field::new("meta_no_in_no_out", DataType::UInt64, true),
+        Field::new("meta_with_in_no_out", DataType::UInt64, true),
+        Field::new("meta_no_in_with_out", DataType::UInt64, true)
+            .with_metadata(output_meta.clone()),
+        Field::new("meta_with_in_with_out", DataType::UInt64, true)
+            .with_metadata(output_meta.clone()),
+    ]);
+
+    let expected = record_batch!(
+        ("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 50]),
+        ("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]),
+        ("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]),
+        ("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100])
+    )?
+    .with_schema(Arc::new(expected_schema))?;
+
+    assert_eq!(expected, actual[0]);
+
+    Ok(())
+}