@@ -40,6 +40,7 @@ use datafusion_common::tree_node::{
4040use datafusion_common:: {
4141 plan_err, Column , DFSchema , Result , ScalarValue , TableReference ,
4242} ;
43+ use datafusion_functions_window_common:: field:: WindowUDFFieldArgs ;
4344use sqlparser:: ast:: {
4445 display_comma_separated, ExceptSelectItem , ExcludeSelectItem , IlikeSelectItem ,
4546 NullTreatment , RenameSelectItem , ReplaceSelectElement ,
@@ -706,6 +707,7 @@ impl WindowFunctionDefinition {
706707 & self ,
707708 input_expr_types : & [ DataType ] ,
708709 _input_expr_nullable : & [ bool ] ,
710+ display_name : & str ,
709711 ) -> Result < DataType > {
710712 match self {
711713 WindowFunctionDefinition :: BuiltInWindowFunction ( fun) => {
@@ -714,12 +716,9 @@ impl WindowFunctionDefinition {
714716 WindowFunctionDefinition :: AggregateUDF ( fun) => {
715717 fun. return_type ( input_expr_types)
716718 }
717- WindowFunctionDefinition :: WindowUDF ( _) => {
718- // To get the return data type of the result from
719- // evaluating the user-defined window function instead
720- // use the `WindowUDF::field` trait method.
721- unreachable ! ( )
722- }
719+ WindowFunctionDefinition :: WindowUDF ( fun) => fun
720+ . field ( WindowUDFFieldArgs :: new ( input_expr_types, display_name) )
721+ . map ( |field| field. data_type ( ) . clone ( ) ) ,
723722 }
724723 }
725724
@@ -2558,10 +2557,10 @@ mod test {
25582557 #[ test]
25592558 fn test_first_value_return_type ( ) -> Result < ( ) > {
25602559 let fun = find_df_window_func ( "first_value" ) . unwrap ( ) ;
2561- let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] ) ?;
2560+ let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] , "" ) ?;
25622561 assert_eq ! ( DataType :: Utf8 , observed) ;
25632562
2564- let observed = fun. return_type ( & [ DataType :: UInt64 ] , & [ true ] ) ?;
2563+ let observed = fun. return_type ( & [ DataType :: UInt64 ] , & [ true ] , "" ) ?;
25652564 assert_eq ! ( DataType :: UInt64 , observed) ;
25662565
25672566 Ok ( ( ) )
@@ -2570,10 +2569,10 @@ mod test {
25702569 #[ test]
25712570 fn test_last_value_return_type ( ) -> Result < ( ) > {
25722571 let fun = find_df_window_func ( "last_value" ) . unwrap ( ) ;
2573- let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] ) ?;
2572+ let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] , "" ) ?;
25742573 assert_eq ! ( DataType :: Utf8 , observed) ;
25752574
2576- let observed = fun. return_type ( & [ DataType :: Float64 ] , & [ true ] ) ?;
2575+ let observed = fun. return_type ( & [ DataType :: Float64 ] , & [ true ] , "" ) ?;
25772576 assert_eq ! ( DataType :: Float64 , observed) ;
25782577
25792578 Ok ( ( ) )
@@ -2582,10 +2581,10 @@ mod test {
25822581 #[ test]
25832582 fn test_lead_return_type ( ) -> Result < ( ) > {
25842583 let fun = find_df_window_func ( "lead" ) . unwrap ( ) ;
2585- let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] ) ?;
2584+ let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] , "" ) ?;
25862585 assert_eq ! ( DataType :: Utf8 , observed) ;
25872586
2588- let observed = fun. return_type ( & [ DataType :: Float64 ] , & [ true ] ) ?;
2587+ let observed = fun. return_type ( & [ DataType :: Float64 ] , & [ true ] , "" ) ?;
25892588 assert_eq ! ( DataType :: Float64 , observed) ;
25902589
25912590 Ok ( ( ) )
@@ -2594,10 +2593,10 @@ mod test {
25942593 #[ test]
25952594 fn test_lag_return_type ( ) -> Result < ( ) > {
25962595 let fun = find_df_window_func ( "lag" ) . unwrap ( ) ;
2597- let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] ) ?;
2596+ let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] , "" ) ?;
25982597 assert_eq ! ( DataType :: Utf8 , observed) ;
25992598
2600- let observed = fun. return_type ( & [ DataType :: Float64 ] , & [ true ] ) ?;
2599+ let observed = fun. return_type ( & [ DataType :: Float64 ] , & [ true ] , "" ) ?;
26012600 assert_eq ! ( DataType :: Float64 , observed) ;
26022601
26032602 Ok ( ( ) )
@@ -2607,11 +2606,11 @@ mod test {
26072606 fn test_nth_value_return_type ( ) -> Result < ( ) > {
26082607 let fun = find_df_window_func ( "nth_value" ) . unwrap ( ) ;
26092608 let observed =
2610- fun. return_type ( & [ DataType :: Utf8 , DataType :: UInt64 ] , & [ true , true ] ) ?;
2609+ fun. return_type ( & [ DataType :: Utf8 , DataType :: UInt64 ] , & [ true , true ] , "" ) ?;
26112610 assert_eq ! ( DataType :: Utf8 , observed) ;
26122611
26132612 let observed =
2614- fun. return_type ( & [ DataType :: Float64 , DataType :: UInt64 ] , & [ true , true ] ) ?;
2613+ fun. return_type ( & [ DataType :: Float64 , DataType :: UInt64 ] , & [ true , true ] , "" ) ?;
26152614 assert_eq ! ( DataType :: Float64 , observed) ;
26162615
26172616 Ok ( ( ) )
@@ -2620,7 +2619,7 @@ mod test {
26202619 #[ test]
26212620 fn test_percent_rank_return_type ( ) -> Result < ( ) > {
26222621 let fun = find_df_window_func ( "percent_rank" ) . unwrap ( ) ;
2623- let observed = fun. return_type ( & [ ] , & [ ] ) ?;
2622+ let observed = fun. return_type ( & [ ] , & [ ] , "" ) ?;
26242623 assert_eq ! ( DataType :: Float64 , observed) ;
26252624
26262625 Ok ( ( ) )
@@ -2629,7 +2628,7 @@ mod test {
26292628 #[ test]
26302629 fn test_cume_dist_return_type ( ) -> Result < ( ) > {
26312630 let fun = find_df_window_func ( "cume_dist" ) . unwrap ( ) ;
2632- let observed = fun. return_type ( & [ ] , & [ ] ) ?;
2631+ let observed = fun. return_type ( & [ ] , & [ ] , "" ) ?;
26332632 assert_eq ! ( DataType :: Float64 , observed) ;
26342633
26352634 Ok ( ( ) )
@@ -2638,7 +2637,7 @@ mod test {
26382637 #[ test]
26392638 fn test_ntile_return_type ( ) -> Result < ( ) > {
26402639 let fun = find_df_window_func ( "ntile" ) . unwrap ( ) ;
2641- let observed = fun. return_type ( & [ DataType :: Int16 ] , & [ true ] ) ?;
2640+ let observed = fun. return_type ( & [ DataType :: Int16 ] , & [ true ] , "" ) ?;
26422641 assert_eq ! ( DataType :: UInt64 , observed) ;
26432642
26442643 Ok ( ( ) )
0 commit comments