@@ -19,18 +19,18 @@ use std::any::Any;
1919use std:: cmp:: max;
2020use std:: sync:: Arc ;
2121
22+ use crate :: utils:: { make_scalar_function, utf8_to_str_type} ;
2223use arrow:: array:: {
23- ArrayAccessor , ArrayIter , ArrayRef , AsArray , GenericStringArray , OffsetSizeTrait ,
24+ make_view, Array , ArrayAccessor , ArrayIter , ArrayRef , AsArray , ByteView ,
25+ GenericStringArray , OffsetSizeTrait , StringViewArray ,
2426} ;
2527use arrow:: datatypes:: DataType ;
26-
28+ use arrow_buffer :: { NullBufferBuilder , ScalarBuffer } ;
2729use datafusion_common:: cast:: as_int64_array;
2830use datafusion_common:: { exec_datafusion_err, exec_err, Result } ;
2931use datafusion_expr:: TypeSignature :: Exact ;
3032use datafusion_expr:: { ColumnarValue , ScalarUDFImpl , Signature , Volatility } ;
3133
32- use crate :: utils:: { make_scalar_function, utf8_to_str_type} ;
33-
3434#[ derive( Debug ) ]
3535pub struct SubstrFunc {
3636 signature : Signature ,
@@ -77,7 +77,11 @@ impl ScalarUDFImpl for SubstrFunc {
7777 }
7878
7979 fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
80- utf8_to_str_type ( & arg_types[ 0 ] , "substr" )
80+ if arg_types[ 0 ] == DataType :: Utf8View {
81+ Ok ( DataType :: Utf8View )
82+ } else {
83+ utf8_to_str_type ( & arg_types[ 0 ] , "substr" )
84+ }
8185 }
8286
8387 fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
@@ -89,29 +93,188 @@ impl ScalarUDFImpl for SubstrFunc {
8993 }
9094}
9195
96+ /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
97+ /// substr('alphabet', 3) = 'phabet'
98+ /// substr('alphabet', 3, 2) = 'ph'
99+ /// The implementation uses UTF-8 code points as characters
92100pub fn substr ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
93101 match args[ 0 ] . data_type ( ) {
94102 DataType :: Utf8 => {
95103 let string_array = args[ 0 ] . as_string :: < i32 > ( ) ;
96- calculate_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
104+ string_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
97105 }
98106 DataType :: LargeUtf8 => {
99107 let string_array = args[ 0 ] . as_string :: < i64 > ( ) ;
100- calculate_substr :: < _ , i64 > ( string_array, & args[ 1 ..] )
108+ string_substr :: < _ , i64 > ( string_array, & args[ 1 ..] )
101109 }
102110 DataType :: Utf8View => {
103111 let string_array = args[ 0 ] . as_string_view ( ) ;
104- calculate_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
112+ string_view_substr ( string_array, & args[ 1 ..] )
105113 }
106- other => exec_err ! ( "Unsupported data type {other:?} for function substr" ) ,
114+ other => exec_err ! (
115+ "Unsupported data type {other:?} for function substr,\
116+ expected Utf8View, Utf8 or LargeUtf8."
117+ ) ,
107118 }
108119}
109120
110- /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
111- /// substr('alphabet', 3) = 'phabet'
112- /// substr('alphabet', 3, 2) = 'ph'
113- /// The implementation uses UTF-8 code points as characters
114- fn calculate_substr < ' a , V , T > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
/// Converts a character window into a byte range `[start, end)` of `input`.
///
/// * `start` - zero-based index of the first character to keep.
/// * `count` - number of characters to keep; pass `-1` to ignore the count
///   and keep everything up to the end of `input`.
///
/// Both returned offsets fall on character boundaries. When `start` is past
/// the last character the range is empty and positioned at `input.len()`.
fn get_true_start_end(input: &str, start: usize, count: i64) -> (usize, usize) {
    let total_bytes = input.len();

    // Byte offsets of every character from `start` onwards.
    let mut boundaries = input
        .char_indices()
        .map(|(byte_offset, _)| byte_offset)
        .skip(start);

    let begin = match boundaries.next() {
        Some(byte_offset) => byte_offset,
        // `start` lies beyond the last character: empty range at the end.
        None => return (total_bytes, total_bytes),
    };

    let end = match usize::try_from(count) {
        // Zero characters requested: the range closes where it opens.
        Ok(0) => begin,
        // `begin` was already consumed, so the boundary `wanted` characters
        // after it is the (wanted - 1)-th remaining one; running out of
        // characters means the window reaches the end of the input.
        Ok(wanted) => boundaries.nth(wanted - 1).unwrap_or(total_bytes),
        // Negative count (including the `-1` sentinel): no upper bound.
        Err(_) => total_bytes,
    };

    (begin, end)
}
145+
146+ /// Make a `u128` based on the given substr, start(offset to view.offset), and
147+ /// push into to the given buffers
148+ fn make_and_append_view (
149+ views_buffer : & mut Vec < u128 > ,
150+ null_builder : & mut NullBufferBuilder ,
151+ raw : & u128 ,
152+ substr : & str ,
153+ start : u32 ,
154+ ) {
155+ let substr_len = substr. len ( ) ;
156+ if substr_len == 0 {
157+ null_builder. append_null ( ) ;
158+ views_buffer. push ( 0 ) ;
159+ } else {
160+ let sub_view = if substr_len > 12 {
161+ let view = ByteView :: from ( * raw) ;
162+ make_view ( substr. as_bytes ( ) , view. buffer_index , view. offset + start)
163+ } else {
164+ // inline value does not need block id or offset
165+ make_view ( substr. as_bytes ( ) , 0 , 0 )
166+ } ;
167+ views_buffer. push ( sub_view) ;
168+ null_builder. append_non_null ( ) ;
169+ }
170+ }
171+
172+ // The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
173+ // From<u128> for ByteView
174+ fn string_view_substr (
175+ string_view_array : & StringViewArray ,
176+ args : & [ ArrayRef ] ,
177+ ) -> Result < ArrayRef > {
178+ let mut views_buf = Vec :: with_capacity ( string_view_array. len ( ) ) ;
179+ let mut null_builder = NullBufferBuilder :: new ( string_view_array. len ( ) ) ;
180+
181+ let start_array = as_int64_array ( & args[ 0 ] ) ?;
182+
183+ match args. len ( ) {
184+ 1 => {
185+ for ( idx, ( raw, start) ) in string_view_array
186+ . views ( )
187+ . iter ( )
188+ . zip ( start_array. iter ( ) )
189+ . enumerate ( )
190+ {
191+ if let Some ( start) = start {
192+ let start = ( start - 1 ) . max ( 0 ) as usize ;
193+
194+ // Safety:
195+ // idx is always smaller or equal to string_view_array.views.len()
196+ unsafe {
197+ let str = string_view_array. value_unchecked ( idx) ;
198+ let ( start, end) = get_true_start_end ( str, start, -1 ) ;
199+ let substr = & str[ start..end] ;
200+
201+ make_and_append_view (
202+ & mut views_buf,
203+ & mut null_builder,
204+ raw,
205+ substr,
206+ start as u32 ,
207+ ) ;
208+ }
209+ } else {
210+ null_builder. append_null ( ) ;
211+ views_buf. push ( 0 ) ;
212+ }
213+ }
214+ }
215+ 2 => {
216+ let count_array = as_int64_array ( & args[ 1 ] ) ?;
217+ for ( idx, ( ( raw, start) , count) ) in string_view_array
218+ . views ( )
219+ . iter ( )
220+ . zip ( start_array. iter ( ) )
221+ . zip ( count_array. iter ( ) )
222+ . enumerate ( )
223+ {
224+ if let ( Some ( start) , Some ( count) ) = ( start, count) {
225+ let start = ( start - 1 ) . max ( 0 ) as usize ;
226+ if count < 0 {
227+ return exec_err ! (
228+ "negative substring length not allowed: substr(<str>, {start}, {count})"
229+ ) ;
230+ } else {
231+ // Safety:
232+ // idx is always smaller or equal to string_view_array.views.len()
233+ unsafe {
234+ let str = string_view_array. value_unchecked ( idx) ;
235+ let ( start, end) = get_true_start_end ( str, start, count) ;
236+ let substr = & str[ start..end] ;
237+
238+ make_and_append_view (
239+ & mut views_buf,
240+ & mut null_builder,
241+ raw,
242+ substr,
243+ start as u32 ,
244+ ) ;
245+ }
246+ }
247+ } else {
248+ null_builder. append_null ( ) ;
249+ views_buf. push ( 0 ) ;
250+ }
251+ }
252+ }
253+ other => {
254+ return exec_err ! (
255+ "substr was called with {other} arguments. It requires 2 or 3."
256+ )
257+ }
258+ }
259+
260+ let views_buf = ScalarBuffer :: from ( views_buf) ;
261+ let nulls_buf = null_builder. finish ( ) ;
262+
263+ // Safety:
264+ // (1) The blocks of the given views are all provided
265+ // (2) Each of the range `view.offset+start..end` of view in views_buf is within
266+ // the bounds of each of the blocks
267+ unsafe {
268+ let array = StringViewArray :: new_unchecked (
269+ views_buf,
270+ string_view_array. data_buffers ( ) . to_vec ( ) ,
271+ nulls_buf,
272+ ) ;
273+ Ok ( Arc :: new ( array) as ArrayRef )
274+ }
275+ }
276+
277+ fn string_substr < ' a , V , T > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
115278where
116279 V : ArrayAccessor < Item = & ' a str > ,
117280 T : OffsetSizeTrait ,
@@ -174,8 +337,8 @@ where
174337
175338#[ cfg( test) ]
176339mod tests {
177- use arrow:: array:: { Array , StringArray } ;
178- use arrow:: datatypes:: DataType :: Utf8 ;
340+ use arrow:: array:: { Array , StringArray , StringViewArray } ;
341+ use arrow:: datatypes:: DataType :: { Utf8 , Utf8View } ;
179342
180343 use datafusion_common:: { exec_err, Result , ScalarValue } ;
181344 use datafusion_expr:: { ColumnarValue , ScalarUDFImpl } ;
@@ -193,8 +356,8 @@ mod tests {
193356 ] ,
194357 Ok ( None ) ,
195358 & str ,
196- Utf8 ,
197- StringArray
359+ Utf8View ,
360+ StringViewArray
198361 ) ;
199362 test_function ! (
200363 SubstrFunc :: new( ) ,
@@ -206,8 +369,35 @@ mod tests {
206369 ] ,
207370 Ok ( Some ( "alphabet" ) ) ,
208371 & str ,
209- Utf8 ,
210- StringArray
372+ Utf8View ,
373+ StringViewArray
374+ ) ;
375+ test_function ! (
376+ SubstrFunc :: new( ) ,
377+ & [
378+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
379+ "this és longer than 12B"
380+ ) ) ) ) ,
381+ ColumnarValue :: Scalar ( ScalarValue :: from( 5i64 ) ) ,
382+ ColumnarValue :: Scalar ( ScalarValue :: from( 2i64 ) ) ,
383+ ] ,
384+ Ok ( Some ( " é" ) ) ,
385+ & str ,
386+ Utf8View ,
387+ StringViewArray
388+ ) ;
389+ test_function ! (
390+ SubstrFunc :: new( ) ,
391+ & [
392+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
393+ "this is longer than 12B"
394+ ) ) ) ) ,
395+ ColumnarValue :: Scalar ( ScalarValue :: from( 5i64 ) ) ,
396+ ] ,
397+ Ok ( Some ( " is longer than 12B" ) ) ,
398+ & str ,
399+ Utf8View ,
400+ StringViewArray
211401 ) ;
212402 test_function ! (
213403 SubstrFunc :: new( ) ,
@@ -219,8 +409,8 @@ mod tests {
219409 ] ,
220410 Ok ( Some ( "ésoj" ) ) ,
221411 & str ,
222- Utf8 ,
223- StringArray
412+ Utf8View ,
413+ StringViewArray
224414 ) ;
225415 test_function ! (
226416 SubstrFunc :: new( ) ,
@@ -233,8 +423,8 @@ mod tests {
233423 ] ,
234424 Ok ( Some ( "ph" ) ) ,
235425 & str ,
236- Utf8 ,
237- StringArray
426+ Utf8View ,
427+ StringViewArray
238428 ) ;
239429 test_function ! (
240430 SubstrFunc :: new( ) ,
@@ -247,8 +437,8 @@ mod tests {
247437 ] ,
248438 Ok ( Some ( "phabet" ) ) ,
249439 & str ,
250- Utf8 ,
251- StringArray
440+ Utf8View ,
441+ StringViewArray
252442 ) ;
253443 test_function ! (
254444 SubstrFunc :: new( ) ,
0 commit comments