/Users/andrewlamb/Software/arrow-rs/arrow-string/src/regexp.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines kernel to extract substrings based on a regular |
19 | | //! expression of a \[Large\]StringArray |
20 | | |
21 | | use crate::like::StringArrayType; |
22 | | |
23 | | use arrow_array::builder::{ |
24 | | BooleanBufferBuilder, GenericStringBuilder, ListBuilder, StringViewBuilder, |
25 | | }; |
26 | | use arrow_array::cast::AsArray; |
27 | | use arrow_array::*; |
28 | | use arrow_buffer::NullBuffer; |
29 | | use arrow_data::{ArrayData, ArrayDataBuilder}; |
30 | | use arrow_schema::{ArrowError, DataType, Field}; |
31 | | use regex::Regex; |
32 | | |
33 | | use std::collections::HashMap; |
34 | | use std::sync::Arc; |
35 | | |
36 | | /// Return BooleanArray indicating which strings in an array match an array of |
37 | | /// regular expressions. |
38 | | /// |
39 | | /// This is equivalent to the SQL `array ~ regex_array`, supporting |
40 | | /// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. |
41 | | /// |
42 | | /// If `regex_array` element has an empty value, the corresponding result value is always true. |
43 | | /// |
44 | | /// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag, |
45 | | /// which allow special search modes, such as case-insensitive and multi-line mode. |
46 | | /// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) |
47 | | /// for more information. |
48 | | /// |
49 | | /// # See Also |
50 | | /// * [`regexp_is_match_scalar`] for matching a single regular expression against an array of strings |
51 | | /// * [`regexp_match`] for extracting groups from a string array based on a regular expression |
52 | | /// |
53 | | /// # Example |
54 | | /// ``` |
55 | | /// # use arrow_array::{StringArray, BooleanArray}; |
56 | | /// # use arrow_string::regexp::regexp_is_match; |
57 | | /// // First array is the array of strings to match |
58 | | /// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]); |
59 | | /// // Second array is the array of regular expressions to match against |
60 | | /// let regex_array = StringArray::from(vec!["^Foo", "^Foo", "Bar$", "Baz"]); |
61 | | /// // Third array is the array of flags to use for each regular expression, if desired |
62 | | /// // (the type must be provided to satisfy type inference for the third parameter) |
63 | | /// let flags_array: Option<&StringArray> = None; |
64 | | /// // The result is a BooleanArray indicating when each string in `array` |
65 | | /// // matches the corresponding regular expression in `regex_array` |
66 | | /// let result = regexp_is_match(&array, ®ex_array, flags_array).unwrap(); |
67 | | /// assert_eq!(result, BooleanArray::from(vec![true, false, true, true])); |
68 | | /// ``` |
69 | | pub fn regexp_is_match<'a, S1, S2, S3>( |
70 | | array: &'a S1, |
71 | | regex_array: &'a S2, |
72 | | flags_array: Option<&'a S3>, |
73 | | ) -> Result<BooleanArray, ArrowError> |
74 | | where |
75 | | &'a S1: StringArrayType<'a>, |
76 | | &'a S2: StringArrayType<'a>, |
77 | | &'a S3: StringArrayType<'a>, |
78 | | { |
79 | | if array.len() != regex_array.len() { |
80 | | return Err(ArrowError::ComputeError( |
81 | | "Cannot perform comparison operation on arrays of different length".to_string(), |
82 | | )); |
83 | | } |
84 | | |
85 | | let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); |
86 | | |
87 | | let mut patterns: HashMap<String, Regex> = HashMap::new(); |
88 | | let mut result = BooleanBufferBuilder::new(array.len()); |
89 | | |
90 | | let complete_pattern = match flags_array { |
91 | | Some(flags) => Box::new( |
92 | | regex_array |
93 | | .iter() |
94 | | .zip(flags.iter()) |
95 | | .map(|(pattern, flags)| { |
96 | | pattern.map(|pattern| match flags { |
97 | | Some(flag) => format!("(?{flag}){pattern}"), |
98 | | None => pattern.to_string(), |
99 | | }) |
100 | | }), |
101 | | ) as Box<dyn Iterator<Item = Option<String>>>, |
102 | | None => Box::new( |
103 | | regex_array |
104 | | .iter() |
105 | | .map(|pattern| pattern.map(|pattern| pattern.to_string())), |
106 | | ), |
107 | | }; |
108 | | |
109 | | array |
110 | | .iter() |
111 | | .zip(complete_pattern) |
112 | | .map(|(value, pattern)| { |
113 | | match (value, pattern) { |
114 | | // Required for Postgres compatibility: |
115 | | // SELECT 'foobarbequebaz' ~ ''); = true |
116 | | (Some(_), Some(pattern)) if pattern == *"" => { |
117 | | result.append(true); |
118 | | } |
119 | | (Some(value), Some(pattern)) => { |
120 | | let existing_pattern = patterns.get(&pattern); |
121 | | let re = match existing_pattern { |
122 | | Some(re) => re, |
123 | | None => { |
124 | | let re = Regex::new(pattern.as_str()).map_err(|e| { |
125 | | ArrowError::ComputeError(format!( |
126 | | "Regular expression did not compile: {e:?}" |
127 | | )) |
128 | | })?; |
129 | | patterns.entry(pattern).or_insert(re) |
130 | | } |
131 | | }; |
132 | | result.append(re.is_match(value)); |
133 | | } |
134 | | _ => result.append(false), |
135 | | } |
136 | | Ok(()) |
137 | | }) |
138 | | .collect::<Result<Vec<()>, ArrowError>>()?; |
139 | | |
140 | | let data = unsafe { |
141 | | ArrayDataBuilder::new(DataType::Boolean) |
142 | | .len(array.len()) |
143 | | .buffers(vec![result.into()]) |
144 | | .nulls(nulls) |
145 | | .build_unchecked() |
146 | | }; |
147 | | |
148 | | Ok(BooleanArray::from(data)) |
149 | | } |
150 | | |
151 | | /// Return BooleanArray indicating which strings in an array match a single regular expression. |
152 | | /// |
153 | | /// This is equivalent to the SQL `array ~ regex_array`, supporting |
154 | | /// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] and a scalar. |
155 | | /// |
156 | | /// See the documentation on [`regexp_is_match`] for more details on arguments |
157 | | /// |
158 | | /// # See Also |
159 | | /// * [`regexp_is_match`] for matching an array of regular expression against an array of strings |
160 | | /// * [`regexp_match`] for extracting groups from a string array based on a regular expression |
161 | | /// |
162 | | /// # Example |
163 | | /// ``` |
164 | | /// # use arrow_array::{StringArray, BooleanArray}; |
165 | | /// # use arrow_string::regexp::regexp_is_match_scalar; |
166 | | /// // array of strings to match |
167 | | /// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]); |
168 | | /// let regexp = "^Foo"; // regular expression to match against |
169 | | /// let flags: Option<&str> = None; // flags can control the matching behavior |
170 | | /// // The result is a BooleanArray indicating when each string in `array` |
171 | | /// // matches the regular expression `regexp` |
172 | | /// let result = regexp_is_match_scalar(&array, regexp, None).unwrap(); |
173 | | /// assert_eq!(result, BooleanArray::from(vec![true, false, true, false])); |
174 | | /// ``` |
175 | | pub fn regexp_is_match_scalar<'a, S>( |
176 | | array: &'a S, |
177 | | regex: &str, |
178 | | flag: Option<&str>, |
179 | | ) -> Result<BooleanArray, ArrowError> |
180 | | where |
181 | | &'a S: StringArrayType<'a>, |
182 | | { |
183 | | let null_bit_buffer = array.nulls().map(|x| x.inner().sliced()); |
184 | | let mut result = BooleanBufferBuilder::new(array.len()); |
185 | | |
186 | | let pattern = match flag { |
187 | | Some(flag) => format!("(?{flag}){regex}"), |
188 | | None => regex.to_string(), |
189 | | }; |
190 | | |
191 | | if pattern.is_empty() { |
192 | | result.append_n(array.len(), true); |
193 | | } else { |
194 | | let re = Regex::new(pattern.as_str()).map_err(|e| { |
195 | | ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}")) |
196 | | })?; |
197 | | for i in 0..array.len() { |
198 | | let value = array.value(i); |
199 | | result.append(re.is_match(value)); |
200 | | } |
201 | | } |
202 | | |
203 | | let buffer = result.into(); |
204 | | let data = unsafe { |
205 | | ArrayData::new_unchecked( |
206 | | DataType::Boolean, |
207 | | array.len(), |
208 | | None, |
209 | | null_bit_buffer, |
210 | | 0, |
211 | | vec![buffer], |
212 | | vec![], |
213 | | ) |
214 | | }; |
215 | | |
216 | | Ok(BooleanArray::from(data)) |
217 | | } |
218 | | |
// Matches every row of `$array` against the pattern found at the same position
// of `$regex_array` (optionally prefixed with the flags at that position of
// `$flags_array`) and appends one list entry per row to `$list_builder`.
//
// * null value or null pattern -> null list entry
// * empty pattern              -> list containing a single empty string
// * no match                   -> null list entry
// * match                      -> the participating capture groups, or the
//                                 whole match when the pattern has no groups
//
// Must be expanded inside a function returning `Result<_, ArrowError>`, since
// a pattern that fails to compile is propagated with `?`.
macro_rules! process_regexp_array_match {
    ($array:expr, $regex_array:expr, $flags_array:expr, $list_builder:expr) => {
        // Compiled expressions keyed by the complete pattern text (flags included).
        let mut compiled: HashMap<String, Regex> = HashMap::new();

        let full_patterns = match $flags_array {
            Some(flags) => Box::new($regex_array.iter().zip(flags.iter()).map(
                |(regex, flag)| {
                    regex.map(|regex| match flag {
                        Some(flag) => format!("(?{flag}){regex}"),
                        None => regex.to_string(),
                    })
                },
            )) as Box<dyn Iterator<Item = Option<String>>>,
            None => Box::new(
                $regex_array
                    .iter()
                    .map(|regex| regex.map(|regex| regex.to_string())),
            ),
        };

        for (value, pattern) in $array.iter().zip(full_patterns) {
            match (value, pattern) {
                // Required for Postgres compatibility:
                // SELECT regexp_match('foobarbequebaz', ''); = {""}
                (Some(_), Some(pattern)) if pattern.is_empty() => {
                    $list_builder.values().append_value("");
                    $list_builder.append(true);
                }
                (Some(value), Some(pattern)) => {
                    let re = match compiled.entry(pattern) {
                        std::collections::hash_map::Entry::Occupied(slot) => slot.into_mut(),
                        std::collections::hash_map::Entry::Vacant(slot) => {
                            let re = Regex::new(slot.key().as_str()).map_err(|e| {
                                ArrowError::ComputeError(format!(
                                    "Regular expression did not compile: {e:?}"
                                ))
                            })?;
                            slot.insert(re)
                        }
                    };
                    match re.captures(value) {
                        Some(caps) => {
                            // Group 0 is the whole match; drop it only when the
                            // pattern defines explicit groups.
                            let skip_whole_match = usize::from(caps.len() > 1);
                            for group in caps.iter().skip(skip_whole_match).flatten() {
                                $list_builder.values().append_value(group.as_str());
                            }
                            $list_builder.append(true);
                        }
                        None => $list_builder.append(false),
                    }
                }
                _ => $list_builder.append(false),
            }
        }
    };
}
285 | | |
286 | 0 | fn regexp_array_match<OffsetSize: OffsetSizeTrait>( |
287 | 0 | array: &GenericStringArray<OffsetSize>, |
288 | 0 | regex_array: &GenericStringArray<OffsetSize>, |
289 | 0 | flags_array: Option<&GenericStringArray<OffsetSize>>, |
290 | 0 | ) -> Result<ArrayRef, ArrowError> { |
291 | 0 | let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0); |
292 | 0 | let mut list_builder = ListBuilder::new(builder); |
293 | | |
294 | 0 | process_regexp_array_match!(array, regex_array, flags_array, list_builder); |
295 | | |
296 | 0 | Ok(Arc::new(list_builder.finish())) |
297 | 0 | } |
298 | | |
299 | | fn regexp_array_match_utf8view( |
300 | | array: &StringViewArray, |
301 | | regex_array: &StringViewArray, |
302 | | flags_array: Option<&StringViewArray>, |
303 | | ) -> Result<ArrayRef, ArrowError> { |
304 | | let builder = StringViewBuilder::with_capacity(0); |
305 | | let mut list_builder = ListBuilder::new(builder); |
306 | | |
307 | | process_regexp_array_match!(array, regex_array, flags_array, list_builder); |
308 | | |
309 | | Ok(Arc::new(list_builder.finish())) |
310 | | } |
311 | | |
312 | 0 | fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>( |
313 | 0 | regex_array: &'a dyn Array, |
314 | 0 | flag_array: Option<&'a dyn Array>, |
315 | 0 | ) -> (Option<&'a str>, Option<&'a str>) { |
316 | 0 | let regex = regex_array.as_string::<OffsetSize>(); |
317 | 0 | let regex = regex.is_valid(0).then(|| regex.value(0)); |
318 | | |
319 | 0 | if let Some(flag_array) = flag_array { |
320 | 0 | let flag = flag_array.as_string::<OffsetSize>(); |
321 | 0 | (regex, flag.is_valid(0).then(|| flag.value(0))) |
322 | | } else { |
323 | 0 | (regex, None) |
324 | | } |
325 | 0 | } |
326 | | |
327 | | fn get_scalar_pattern_flag_utf8view<'a>( |
328 | | regex_array: &'a dyn Array, |
329 | | flag_array: Option<&'a dyn Array>, |
330 | | ) -> (Option<&'a str>, Option<&'a str>) { |
331 | | let regex = regex_array.as_string_view(); |
332 | 0 | let regex = regex.is_valid(0).then(|| regex.value(0)); |
333 | | |
334 | | if let Some(flag_array) = flag_array { |
335 | | let flag = flag_array.as_string_view(); |
336 | 0 | (regex, flag.is_valid(0).then(|| flag.value(0))) |
337 | | } else { |
338 | | (regex, None) |
339 | | } |
340 | | } |
341 | | |
// Matches every row of `$array` against the single, already compiled `$regex`
// and appends one list entry per row to `$list_builder`.
//
// * null value    -> null list entry
// * empty pattern -> list containing a single empty string
// * no match      -> null list entry
// * match         -> the participating capture groups, or the whole match
//                    when the pattern has no groups
macro_rules! process_regexp_match {
    ($array:expr, $regex:expr, $list_builder:expr) => {{
        let pattern_is_empty = $regex.as_str().is_empty();

        for value in $array.iter() {
            match value {
                None => $list_builder.append(false),
                // Required for Postgres compatibility:
                // SELECT regexp_match('foobarbequebaz', ''); = {""}
                Some(_) if pattern_is_empty => {
                    $list_builder.values().append_value("");
                    $list_builder.append(true);
                }
                Some(text) => match $regex.captures(text) {
                    None => $list_builder.append(false),
                    Some(caps) => {
                        // Group 0 is the whole match; drop it only when the
                        // pattern defines explicit groups.
                        let skip_whole_match = usize::from(caps.len() > 1);
                        for group in caps.iter().skip(skip_whole_match).flatten() {
                            $list_builder.values().append_value(group.as_str());
                        }
                        $list_builder.append(true);
                    }
                },
            }
        }
    }};
}
374 | | |
375 | 0 | fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>( |
376 | 0 | array: &GenericStringArray<OffsetSize>, |
377 | 0 | regex: &Regex, |
378 | 0 | ) -> Result<ArrayRef, ArrowError> { |
379 | 0 | let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0); |
380 | 0 | let mut list_builder = ListBuilder::new(builder); |
381 | | |
382 | 0 | process_regexp_match!(array, regex, list_builder); |
383 | | |
384 | 0 | Ok(Arc::new(list_builder.finish())) |
385 | 0 | } |
386 | | |
387 | | fn regexp_scalar_match_utf8view( |
388 | | array: &StringViewArray, |
389 | | regex: &Regex, |
390 | | ) -> Result<ArrayRef, ArrowError> { |
391 | | let builder = StringViewBuilder::with_capacity(0); |
392 | | let mut list_builder = ListBuilder::new(builder); |
393 | | |
394 | | process_regexp_match!(array, regex, list_builder); |
395 | | |
396 | | Ok(Arc::new(list_builder.finish())) |
397 | | } |
398 | | |
399 | | /// Extract all groups matched by a regular expression for a given String array. |
400 | | /// |
401 | | /// Modelled after the Postgres [regexp_match]. |
402 | | /// |
403 | | /// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first |
404 | | /// match of the corresponding index in `regex_array` to string in `array` |
405 | | /// |
406 | | /// If there is no match, the list element is NULL. |
407 | | /// |
408 | | /// If a match is found, and the pattern contains no capturing parenthesized subexpressions, |
409 | | /// then the list element is a single-element [`GenericStringArray`] containing the substring |
410 | | /// matching the whole pattern. |
411 | | /// |
412 | | /// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the |
413 | | /// list element is a [`GenericStringArray`] whose n'th element is the substring matching |
414 | | /// the n'th capturing parenthesized subexpression of the pattern. |
415 | | /// |
416 | | /// The flags parameter is an optional text string containing zero or more single-letter flags |
417 | | /// that change the function's behavior. |
418 | | /// |
419 | | /// # See Also |
420 | | /// * [`regexp_is_match`] for matching (rather than extracting) a regular expression against an array of strings |
421 | | /// |
422 | | /// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP |
423 | | pub fn regexp_match( |
424 | | array: &dyn Array, |
425 | | regex_array: &dyn Datum, |
426 | | flags_array: Option<&dyn Datum>, |
427 | | ) -> Result<ArrayRef, ArrowError> { |
428 | | let (rhs, is_rhs_scalar) = regex_array.get(); |
429 | | |
430 | | if array.data_type() != rhs.data_type() { |
431 | | return Err(ArrowError::ComputeError( |
432 | | "regexp_match() requires both array and pattern to be either Utf8, Utf8View or LargeUtf8" |
433 | | .to_string(), |
434 | | )); |
435 | | } |
436 | | |
437 | | let (flags, is_flags_scalar) = match flags_array { |
438 | | Some(flags) => { |
439 | | let (flags, is_flags_scalar) = flags.get(); |
440 | | (Some(flags), Some(is_flags_scalar)) |
441 | | } |
442 | | None => (None, None), |
443 | | }; |
444 | | |
445 | | if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() { |
446 | | return Err(ArrowError::ComputeError( |
447 | | "regexp_match() requires both pattern and flags to be either scalar or array" |
448 | | .to_string(), |
449 | | )); |
450 | | } |
451 | | |
452 | | if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() { |
453 | | return Err(ArrowError::ComputeError( |
454 | | "regexp_match() requires both pattern and flags to be either Utf8, Utf8View or LargeUtf8" |
455 | | .to_string(), |
456 | | )); |
457 | | } |
458 | | |
459 | | if is_rhs_scalar { |
460 | | // Regex and flag is scalars |
461 | | let (regex, flag) = match rhs.data_type() { |
462 | | DataType::Utf8View => get_scalar_pattern_flag_utf8view(rhs, flags), |
463 | | DataType::Utf8 => get_scalar_pattern_flag::<i32>(rhs, flags), |
464 | | DataType::LargeUtf8 => get_scalar_pattern_flag::<i64>(rhs, flags), |
465 | | _ => { |
466 | | return Err(ArrowError::ComputeError( |
467 | | "regexp_match() requires pattern to be either Utf8, Utf8View or LargeUtf8" |
468 | | .to_string(), |
469 | | )); |
470 | | } |
471 | | }; |
472 | | |
473 | | if regex.is_none() { |
474 | | return Ok(new_null_array( |
475 | | &DataType::List(Arc::new(Field::new_list_field( |
476 | | array.data_type().clone(), |
477 | | true, |
478 | | ))), |
479 | | array.len(), |
480 | | )); |
481 | | } |
482 | | |
483 | | let regex = regex.unwrap(); |
484 | | |
485 | | let pattern = if let Some(flag) = flag { |
486 | | format!("(?{flag}){regex}") |
487 | | } else { |
488 | | regex.to_string() |
489 | | }; |
490 | | |
491 | 0 | let re = Regex::new(pattern.as_str()).map_err(|e| { |
492 | 0 | ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}")) |
493 | 0 | })?; |
494 | | |
495 | | match array.data_type() { |
496 | | DataType::Utf8View => regexp_scalar_match_utf8view(array.as_string_view(), &re), |
497 | | DataType::Utf8 => regexp_scalar_match(array.as_string::<i32>(), &re), |
498 | | DataType::LargeUtf8 => regexp_scalar_match(array.as_string::<i64>(), &re), |
499 | | _ => Err(ArrowError::ComputeError( |
500 | | "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8" |
501 | | .to_string(), |
502 | | )), |
503 | | } |
504 | | } else { |
505 | | match array.data_type() { |
506 | | DataType::Utf8View => { |
507 | | let regex_array = rhs.as_string_view(); |
508 | 0 | let flags_array = flags.map(|flags| flags.as_string_view()); |
509 | | regexp_array_match_utf8view(array.as_string_view(), regex_array, flags_array) |
510 | | } |
511 | | DataType::Utf8 => { |
512 | | let regex_array = rhs.as_string(); |
513 | 0 | let flags_array = flags.map(|flags| flags.as_string()); |
514 | | regexp_array_match(array.as_string::<i32>(), regex_array, flags_array) |
515 | | } |
516 | | DataType::LargeUtf8 => { |
517 | | let regex_array = rhs.as_string(); |
518 | 0 | let flags_array = flags.map(|flags| flags.as_string()); |
519 | | regexp_array_match(array.as_string::<i64>(), regex_array, flags_array) |
520 | | } |
521 | | _ => Err(ArrowError::ComputeError( |
522 | | "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8" |
523 | | .to_string(), |
524 | | )), |
525 | | } |
526 | | } |
527 | | } |
528 | | |
529 | | #[cfg(test)] |
530 | | mod tests { |
531 | | use super::*; |
532 | | |
533 | | macro_rules! test_match_single_group { |
534 | | ($test_name:ident, $values:expr, $patterns:expr, $arr_type:ty, $builder_type:ty, $expected:expr) => { |
535 | | #[test] |
536 | | fn $test_name() { |
537 | | let array: $arr_type = <$arr_type>::from($values); |
538 | | let pattern: $arr_type = <$arr_type>::from($patterns); |
539 | | |
540 | | let actual = regexp_match(&array, &pattern, None).unwrap(); |
541 | | |
542 | | let elem_builder: $builder_type = <$builder_type>::new(); |
543 | | let mut expected_builder = ListBuilder::new(elem_builder); |
544 | | |
545 | | for val in $expected { |
546 | | match val { |
547 | | Some(v) => { |
548 | | expected_builder.values().append_value(v); |
549 | | expected_builder.append(true); |
550 | | } |
551 | | None => expected_builder.append(false), |
552 | | } |
553 | | } |
554 | | |
555 | | let expected = expected_builder.finish(); |
556 | | let result = actual.as_any().downcast_ref::<ListArray>().unwrap(); |
557 | | assert_eq!(&expected, result); |
558 | | } |
559 | | }; |
560 | | } |
561 | | |
562 | | test_match_single_group!( |
563 | | match_single_group_string, |
564 | | vec![ |
565 | | Some("abc-005-def"), |
566 | | Some("X-7-5"), |
567 | | Some("X545"), |
568 | | None, |
569 | | Some("foobarbequebaz"), |
570 | | Some("foobarbequebaz"), |
571 | | ], |
572 | | vec![ |
573 | | r".*-(\d*)-.*", |
574 | | r".*-(\d*)-.*", |
575 | | r".*-(\d*)-.*", |
576 | | r".*-(\d*)-.*", |
577 | | r"(bar)(bequ1e)", |
578 | | "" |
579 | | ], |
580 | | StringArray, |
581 | | GenericStringBuilder<i32>, |
582 | | [Some("005"), Some("7"), None, None, None, Some("")] |
583 | | ); |
584 | | test_match_single_group!( |
585 | | match_single_group_string_view, |
586 | | vec![ |
587 | | Some("abc-005-def"), |
588 | | Some("X-7-5"), |
589 | | Some("X545"), |
590 | | None, |
591 | | Some("foobarbequebaz"), |
592 | | Some("foobarbequebaz"), |
593 | | ], |
594 | | vec![ |
595 | | r".*-(\d*)-.*", |
596 | | r".*-(\d*)-.*", |
597 | | r".*-(\d*)-.*", |
598 | | r".*-(\d*)-.*", |
599 | | r"(bar)(bequ1e)", |
600 | | "" |
601 | | ], |
602 | | StringViewArray, |
603 | | StringViewBuilder, |
604 | | [Some("005"), Some("7"), None, None, None, Some("")] |
605 | | ); |
606 | | |
607 | | macro_rules! test_match_single_group_with_flags { |
608 | | ($test_name:ident, $values:expr, $patterns:expr, $flags:expr, $array_type:ty, $builder_type:ty, $expected:expr) => { |
609 | | #[test] |
610 | | fn $test_name() { |
611 | | let array: $array_type = <$array_type>::from($values); |
612 | | let pattern: $array_type = <$array_type>::from($patterns); |
613 | | let flags: $array_type = <$array_type>::from($flags); |
614 | | |
615 | | let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap(); |
616 | | |
617 | | let elem_builder: $builder_type = <$builder_type>::new(); |
618 | | let mut expected_builder = ListBuilder::new(elem_builder); |
619 | | |
620 | | for val in $expected { |
621 | | match val { |
622 | | Some(v) => { |
623 | | expected_builder.values().append_value(v); |
624 | | expected_builder.append(true); |
625 | | } |
626 | | None => { |
627 | | expected_builder.append(false); |
628 | | } |
629 | | } |
630 | | } |
631 | | |
632 | | let expected = expected_builder.finish(); |
633 | | let result = actual.as_any().downcast_ref::<ListArray>().unwrap(); |
634 | | assert_eq!(&expected, result); |
635 | | } |
636 | | }; |
637 | | } |
638 | | |
639 | | test_match_single_group_with_flags!( |
640 | | match_single_group_with_flags_string, |
641 | | vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], |
642 | | vec![r"x.*-(\d*)-.*"; 4], |
643 | | vec!["i"; 4], |
644 | | StringArray, |
645 | | GenericStringBuilder<i32>, |
646 | | [None, Some("7"), None, None] |
647 | | ); |
648 | | test_match_single_group_with_flags!( |
649 | | match_single_group_with_flags_stringview, |
650 | | vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], |
651 | | vec![r"x.*-(\d*)-.*"; 4], |
652 | | vec!["i"; 4], |
653 | | StringViewArray, |
654 | | StringViewBuilder, |
655 | | [None, Some("7"), None, None] |
656 | | ); |
657 | | |
658 | | macro_rules! test_match_scalar_pattern { |
659 | | ($test_name:ident, $values:expr, $pattern:expr, $flag:expr, $array_type:ty, $builder_type:ty, $expected:expr) => { |
660 | | #[test] |
661 | | fn $test_name() { |
662 | | let array: $array_type = <$array_type>::from($values); |
663 | | |
664 | | let pattern_scalar = Scalar::new(<$array_type>::from(vec![$pattern; 1])); |
665 | | let flag_scalar = Scalar::new(<$array_type>::from(vec![$flag; 1])); |
666 | | |
667 | | let actual = regexp_match(&array, &pattern_scalar, Some(&flag_scalar)).unwrap(); |
668 | | |
669 | | let elem_builder: $builder_type = <$builder_type>::new(); |
670 | | let mut expected_builder = ListBuilder::new(elem_builder); |
671 | | |
672 | | for val in $expected { |
673 | | match val { |
674 | | Some(v) => { |
675 | | expected_builder.values().append_value(v); |
676 | | expected_builder.append(true); |
677 | | } |
678 | | None => expected_builder.append(false), |
679 | | } |
680 | | } |
681 | | |
682 | | let expected = expected_builder.finish(); |
683 | | let result = actual.as_any().downcast_ref::<ListArray>().unwrap(); |
684 | | assert_eq!(&expected, result); |
685 | | } |
686 | | }; |
687 | | } |
688 | | |
689 | | test_match_scalar_pattern!( |
690 | | match_scalar_pattern_string_with_flags, |
691 | | vec![ |
692 | | Some("abc-005-def"), |
693 | | Some("x-7-5"), |
694 | | Some("X-0-Y"), |
695 | | Some("X545"), |
696 | | None |
697 | | ], |
698 | | r"x.*-(\d*)-.*", |
699 | | Some("i"), |
700 | | StringArray, |
701 | | GenericStringBuilder<i32>, |
702 | | [None, Some("7"), Some("0"), None, None] |
703 | | ); |
704 | | test_match_scalar_pattern!( |
705 | | match_scalar_pattern_stringview_with_flags, |
706 | | vec![ |
707 | | Some("abc-005-def"), |
708 | | Some("x-7-5"), |
709 | | Some("X-0-Y"), |
710 | | Some("X545"), |
711 | | None |
712 | | ], |
713 | | r"x.*-(\d*)-.*", |
714 | | Some("i"), |
715 | | StringViewArray, |
716 | | StringViewBuilder, |
717 | | [None, Some("7"), Some("0"), None, None] |
718 | | ); |
719 | | |
720 | | test_match_scalar_pattern!( |
721 | | match_scalar_pattern_string_no_flags, |
722 | | vec![ |
723 | | Some("abc-005-def"), |
724 | | Some("x-7-5"), |
725 | | Some("X-0-Y"), |
726 | | Some("X545"), |
727 | | None |
728 | | ], |
729 | | r"x.*-(\d*)-.*", |
730 | | None::<&str>, |
731 | | StringArray, |
732 | | GenericStringBuilder<i32>, |
733 | | [None, Some("7"), None, None, None] |
734 | | ); |
735 | | test_match_scalar_pattern!( |
736 | | match_scalar_pattern_stringview_no_flags, |
737 | | vec![ |
738 | | Some("abc-005-def"), |
739 | | Some("x-7-5"), |
740 | | Some("X-0-Y"), |
741 | | Some("X545"), |
742 | | None |
743 | | ], |
744 | | r"x.*-(\d*)-.*", |
745 | | None::<&str>, |
746 | | StringViewArray, |
747 | | StringViewBuilder, |
748 | | [None, Some("7"), None, None, None] |
749 | | ); |
750 | | |
751 | | macro_rules! test_match_scalar_no_pattern { |
752 | | ($test_name:ident, $values:expr, $array_type:ty, $pattern_type:expr, $builder_type:ty, $expected:expr) => { |
753 | | #[test] |
754 | | fn $test_name() { |
755 | | let array: $array_type = <$array_type>::from($values); |
756 | | let pattern = Scalar::new(new_null_array(&$pattern_type, 1)); |
757 | | |
758 | | let actual = regexp_match(&array, &pattern, None).unwrap(); |
759 | | |
760 | | let elem_builder: $builder_type = <$builder_type>::new(); |
761 | | let mut expected_builder = ListBuilder::new(elem_builder); |
762 | | |
763 | | for val in $expected { |
764 | | match val { |
765 | | Some(v) => { |
766 | | expected_builder.values().append_value(v); |
767 | | expected_builder.append(true); |
768 | | } |
769 | | None => expected_builder.append(false), |
770 | | } |
771 | | } |
772 | | |
773 | | let expected = expected_builder.finish(); |
774 | | let result = actual.as_any().downcast_ref::<ListArray>().unwrap(); |
775 | | assert_eq!(&expected, result); |
776 | | } |
777 | | }; |
778 | | } |
779 | | |
780 | | test_match_scalar_no_pattern!( |
781 | | match_scalar_no_pattern_string, |
782 | | vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], |
783 | | StringArray, |
784 | | DataType::Utf8, |
785 | | GenericStringBuilder<i32>, |
786 | | [None::<&str>, None, None, None] |
787 | | ); |
788 | | test_match_scalar_no_pattern!( |
789 | | match_scalar_no_pattern_stringview, |
790 | | vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None], |
791 | | StringViewArray, |
792 | | DataType::Utf8View, |
793 | | StringViewBuilder, |
794 | | [None::<&str>, None, None, None] |
795 | | ); |
796 | | |
797 | | macro_rules! test_match_single_group_not_skip { |
798 | | ($test_name:ident, $values:expr, $pattern:expr, $array_type:ty, $builder_type:ty, $expected:expr) => { |
799 | | #[test] |
800 | | fn $test_name() { |
801 | | let array: $array_type = <$array_type>::from($values); |
802 | | let pattern: $array_type = <$array_type>::from(vec![$pattern]); |
803 | | |
804 | | let actual = regexp_match(&array, &pattern, None).unwrap(); |
805 | | |
806 | | let elem_builder: $builder_type = <$builder_type>::new(); |
807 | | let mut expected_builder = ListBuilder::new(elem_builder); |
808 | | |
809 | | for val in $expected { |
810 | | match val { |
811 | | Some(v) => { |
812 | | expected_builder.values().append_value(v); |
813 | | expected_builder.append(true); |
814 | | } |
815 | | None => expected_builder.append(false), |
816 | | } |
817 | | } |
818 | | |
819 | | let expected = expected_builder.finish(); |
820 | | let result = actual.as_any().downcast_ref::<ListArray>().unwrap(); |
821 | | assert_eq!(&expected, result); |
822 | | } |
823 | | }; |
824 | | } |
825 | | |
826 | | test_match_single_group_not_skip!( |
827 | | match_single_group_not_skip_string, |
828 | | vec![Some("foo"), Some("bar")], |
829 | | r"foo", |
830 | | StringArray, |
831 | | GenericStringBuilder<i32>, |
832 | | [Some("foo")] |
833 | | ); |
834 | | test_match_single_group_not_skip!( |
835 | | match_single_group_not_skip_stringview, |
836 | | vec![Some("foo"), Some("bar")], |
837 | | r"foo", |
838 | | StringViewArray, |
839 | | StringViewBuilder, |
840 | | [Some("foo")] |
841 | | ); |
842 | | |
/// Generates a `#[test]` that runs an array-vs-array regex kernel and checks
/// every output slot against a list of expected booleans.
///
/// Accepted forms:
/// * `(name, left, right, op, expected)` – calls `op(&left, &right, None)`.
/// * `(name, left, right, flag, op, expected)` – calls
///   `op(&left, &right, Some(&flag))`.
///
/// `left`, `right` must expose `value(i)` yielding something `Debug`, and
/// `expected` must be iterable via `.iter()` over `bool`s. The result length is
/// asserted to equal `expected.len()` before the element-wise comparison, and a
/// mismatch reports the offending value, pattern and position.
macro_rules! test_flag_utf8 {
    ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
        #[test]
        fn $test_name() {
            let left = $left;
            let right = $right;
            let res = $op(&left, &right, None).unwrap();
            let expected = $expected;
            assert_eq!(expected.len(), res.len());
            for (i, want) in expected.iter().enumerate() {
                assert_eq!(
                    res.value(i),
                    *want,
                    "unexpected result when matching {:?} against {:?} at position {}",
                    left.value(i),
                    right.value(i),
                    i
                );
            }
        }
    };
    ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
        #[test]
        fn $test_name() {
            let left = $left;
            let right = $right;
            let flag = Some($flag);
            // The kernel takes the flags by reference, hence `as_ref()`.
            let res = $op(&left, &right, flag.as_ref()).unwrap();
            let expected = $expected;
            assert_eq!(expected.len(), res.len());
            for (i, want) in expected.iter().enumerate() {
                assert_eq!(
                    res.value(i),
                    *want,
                    "unexpected result when matching {:?} against {:?} at position {}",
                    left.value(i),
                    right.value(i),
                    i
                );
            }
        }
    };
}
874 | | |
/// Generates a `#[test]` that runs an array-vs-scalar regex kernel and checks
/// every output slot against a list of expected booleans.
///
/// Accepted forms:
/// * `(name, left, right, op, expected)` – calls `op(&left, right, None)`.
/// * `(name, left, right, flag, op, expected)` – calls
///   `op(&left, right, Some(flag))`.
///
/// The result length is asserted to equal `expected.len()` first; a mismatching
/// slot is reported together with the input value, its position and the scalar
/// pattern.
macro_rules! test_flag_utf8_scalar {
    ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
        #[test]
        fn $test_name() {
            let left = $left;
            let res = $op(&left, $right, None).unwrap();
            let expected = $expected;
            assert_eq!(expected.len(), res.len());
            for (pos, want) in expected.iter().enumerate() {
                assert_eq!(
                    res.value(pos),
                    *want,
                    "unexpected result when comparing {} at position {} to {} ",
                    left.value(pos),
                    pos,
                    $right
                );
            }
        }
    };
    ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
        #[test]
        fn $test_name() {
            let left = $left;
            let flag = Some($flag);
            let res = $op(&left, $right, flag).unwrap();
            let expected = $expected;
            assert_eq!(expected.len(), res.len());
            for (pos, want) in expected.iter().enumerate() {
                assert_eq!(
                    res.value(pos),
                    *want,
                    "unexpected result when comparing {} at position {} to {} ",
                    left.value(pos),
                    pos,
                    $right
                );
            }
        }
    };
}
918 | | |
919 | | test_flag_utf8!( |
920 | | test_array_regexp_is_match_utf8, |
921 | | StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), |
922 | | StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), |
923 | | regexp_is_match::<StringArray, StringArray, StringArray>, |
924 | | [true, false, true, false, false, true] |
925 | | ); |
926 | | test_flag_utf8!( |
927 | | test_array_regexp_is_match_utf8_insensitive, |
928 | | StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), |
929 | | StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), |
930 | | StringArray::from(vec!["i"; 6]), |
931 | | regexp_is_match, |
932 | | [true, true, true, true, false, true] |
933 | | ); |
934 | | |
935 | | test_flag_utf8_scalar!( |
936 | | test_array_regexp_is_match_utf8_scalar, |
937 | | StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), |
938 | | "^ar", |
939 | | regexp_is_match_scalar, |
940 | | [true, false, false, false] |
941 | | ); |
942 | | test_flag_utf8_scalar!( |
943 | | test_array_regexp_is_match_utf8_scalar_empty, |
944 | | StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), |
945 | | "", |
946 | | regexp_is_match_scalar, |
947 | | [true, true, true, true] |
948 | | ); |
949 | | test_flag_utf8_scalar!( |
950 | | test_array_regexp_is_match_utf8_scalar_insensitive, |
951 | | StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]), |
952 | | "^ar", |
953 | | "i", |
954 | | regexp_is_match_scalar, |
955 | | [true, true, false, false] |
956 | | ); |
957 | | |
958 | | test_flag_utf8!( |
959 | | tes_array_regexp_is_match, |
960 | | StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), |
961 | | StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), |
962 | | regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>, |
963 | | [true, false, true, false, false, true] |
964 | | ); |
965 | | test_flag_utf8!( |
966 | | test_array_regexp_is_match_2, |
967 | | StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), |
968 | | StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), |
969 | | regexp_is_match::<StringViewArray, GenericStringArray<i32>, GenericStringArray<i32>>, |
970 | | [true, false, true, false, false, true] |
971 | | ); |
972 | | test_flag_utf8!( |
973 | | test_array_regexp_is_match_insensitive, |
974 | | StringViewArray::from(vec![ |
975 | | "Official Rust implementation of Apache Arrow", |
976 | | "apache/arrow-rs", |
977 | | "apache/arrow-rs", |
978 | | "parquet", |
979 | | "parquet", |
980 | | "row", |
981 | | "row", |
982 | | ]), |
983 | | StringViewArray::from(vec![ |
984 | | ".*rust implement.*", |
985 | | "^ap", |
986 | | "^AP", |
987 | | "et$", |
988 | | "ET$", |
989 | | "foo", |
990 | | "" |
991 | | ]), |
992 | | StringViewArray::from(vec!["i"; 7]), |
993 | | regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>, |
994 | | [true, true, true, true, true, false, true] |
995 | | ); |
996 | | test_flag_utf8!( |
997 | | test_array_regexp_is_match_insensitive_2, |
998 | | LargeStringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]), |
999 | | StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]), |
1000 | | StringArray::from(vec!["i"; 6]), |
1001 | | regexp_is_match::<GenericStringArray<i64>, StringViewArray, GenericStringArray<i32>>, |
1002 | | [true, true, true, true, false, true] |
1003 | | ); |
1004 | | |
1005 | | test_flag_utf8_scalar!( |
1006 | | test_array_regexp_is_match_scalar, |
1007 | | StringViewArray::from(vec![ |
1008 | | "apache/arrow-rs", |
1009 | | "APACHE/ARROW-RS", |
1010 | | "parquet", |
1011 | | "PARQUET", |
1012 | | ]), |
1013 | | "^ap", |
1014 | | regexp_is_match_scalar::<StringViewArray>, |
1015 | | [true, false, false, false] |
1016 | | ); |
1017 | | test_flag_utf8_scalar!( |
1018 | | test_array_regexp_is_match_scalar_empty, |
1019 | | StringViewArray::from(vec![ |
1020 | | "apache/arrow-rs", |
1021 | | "APACHE/ARROW-RS", |
1022 | | "parquet", |
1023 | | "PARQUET", |
1024 | | ]), |
1025 | | "", |
1026 | | regexp_is_match_scalar::<StringViewArray>, |
1027 | | [true, true, true, true] |
1028 | | ); |
1029 | | test_flag_utf8_scalar!( |
1030 | | test_array_regexp_is_match_scalar_insensitive, |
1031 | | StringViewArray::from(vec![ |
1032 | | "apache/arrow-rs", |
1033 | | "APACHE/ARROW-RS", |
1034 | | "parquet", |
1035 | | "PARQUET", |
1036 | | ]), |
1037 | | "^ap", |
1038 | | "i", |
1039 | | regexp_is_match_scalar::<StringViewArray>, |
1040 | | [true, true, false, false] |
1041 | | ); |
1042 | | } |