Coverage Report

Created: 2025-08-26 07:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-string/src/regexp.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines kernel to extract substrings based on a regular
19
//! expression of a \[Large\]StringArray
20
21
use crate::like::StringArrayType;
22
23
use arrow_array::builder::{
24
    BooleanBufferBuilder, GenericStringBuilder, ListBuilder, StringViewBuilder,
25
};
26
use arrow_array::cast::AsArray;
27
use arrow_array::*;
28
use arrow_buffer::NullBuffer;
29
use arrow_data::{ArrayData, ArrayDataBuilder};
30
use arrow_schema::{ArrowError, DataType, Field};
31
use regex::Regex;
32
33
use std::collections::HashMap;
34
use std::sync::Arc;
35
36
/// Return BooleanArray indicating which strings in an array match an array of
37
/// regular expressions.
38
///
39
/// This is equivalent to the SQL `array ~ regex_array`, supporting
40
/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`].
41
///
42
/// If `regex_array` element has an empty value, the corresponding result value is always true.
43
///
44
/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag,
45
/// which allow special search modes, such as case-insensitive and multi-line mode.
46
/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
47
/// for more information.
48
///
49
/// # See Also
50
/// * [`regexp_is_match_scalar`] for matching a single regular expression against an array of strings
51
/// * [`regexp_match`] for extracting groups from a string array based on a regular expression
52
///
53
/// # Example
54
/// ```
55
/// # use arrow_array::{StringArray, BooleanArray};
56
/// # use arrow_string::regexp::regexp_is_match;
57
/// // First array is the array of strings to match
58
/// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]);
59
/// // Second array is the array of regular expressions to match against
60
/// let regex_array = StringArray::from(vec!["^Foo", "^Foo", "Bar$", "Baz"]);
61
/// // Third array is the array of flags to use for each regular expression, if desired
62
/// // (the type must be provided to satisfy type inference for the third parameter)
63
/// let flags_array: Option<&StringArray> = None;
64
/// // The result is a BooleanArray indicating when each string in `array`
65
/// // matches the corresponding regular expression in `regex_array`
66
/// let result = regexp_is_match(&array, &regex_array, flags_array).unwrap();
67
/// assert_eq!(result, BooleanArray::from(vec![true, false, true, true]));
68
/// ```
69
pub fn regexp_is_match<'a, S1, S2, S3>(
70
    array: &'a S1,
71
    regex_array: &'a S2,
72
    flags_array: Option<&'a S3>,
73
) -> Result<BooleanArray, ArrowError>
74
where
75
    &'a S1: StringArrayType<'a>,
76
    &'a S2: StringArrayType<'a>,
77
    &'a S3: StringArrayType<'a>,
78
{
79
    if array.len() != regex_array.len() {
80
        return Err(ArrowError::ComputeError(
81
            "Cannot perform comparison operation on arrays of different length".to_string(),
82
        ));
83
    }
84
85
    let nulls = NullBuffer::union(array.nulls(), regex_array.nulls());
86
87
    let mut patterns: HashMap<String, Regex> = HashMap::new();
88
    let mut result = BooleanBufferBuilder::new(array.len());
89
90
    let complete_pattern = match flags_array {
91
        Some(flags) => Box::new(
92
            regex_array
93
                .iter()
94
                .zip(flags.iter())
95
                .map(|(pattern, flags)| {
96
                    pattern.map(|pattern| match flags {
97
                        Some(flag) => format!("(?{flag}){pattern}"),
98
                        None => pattern.to_string(),
99
                    })
100
                }),
101
        ) as Box<dyn Iterator<Item = Option<String>>>,
102
        None => Box::new(
103
            regex_array
104
                .iter()
105
                .map(|pattern| pattern.map(|pattern| pattern.to_string())),
106
        ),
107
    };
108
109
    array
110
        .iter()
111
        .zip(complete_pattern)
112
        .map(|(value, pattern)| {
113
            match (value, pattern) {
114
                // Required for Postgres compatibility:
115
                // SELECT 'foobarbequebaz' ~ ''); = true
116
                (Some(_), Some(pattern)) if pattern == *"" => {
117
                    result.append(true);
118
                }
119
                (Some(value), Some(pattern)) => {
120
                    let existing_pattern = patterns.get(&pattern);
121
                    let re = match existing_pattern {
122
                        Some(re) => re,
123
                        None => {
124
                            let re = Regex::new(pattern.as_str()).map_err(|e| {
125
                                ArrowError::ComputeError(format!(
126
                                    "Regular expression did not compile: {e:?}"
127
                                ))
128
                            })?;
129
                            patterns.entry(pattern).or_insert(re)
130
                        }
131
                    };
132
                    result.append(re.is_match(value));
133
                }
134
                _ => result.append(false),
135
            }
136
            Ok(())
137
        })
138
        .collect::<Result<Vec<()>, ArrowError>>()?;
139
140
    let data = unsafe {
141
        ArrayDataBuilder::new(DataType::Boolean)
142
            .len(array.len())
143
            .buffers(vec![result.into()])
144
            .nulls(nulls)
145
            .build_unchecked()
146
    };
147
148
    Ok(BooleanArray::from(data))
149
}
150
151
/// Return BooleanArray indicating which strings in an array match a single regular expression.
152
///
153
/// This is equivalent to the SQL `array ~ regex_array`, supporting
154
/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] and a scalar.
155
///
156
/// See the documentation on [`regexp_is_match`] for more details on arguments
157
///
158
/// # See Also
159
/// * [`regexp_is_match`] for matching an array of regular expression against an array of strings
160
/// * [`regexp_match`] for extracting groups from a string array based on a regular expression
161
///
162
/// # Example
163
/// ```
164
/// # use arrow_array::{StringArray, BooleanArray};
165
/// # use arrow_string::regexp::regexp_is_match_scalar;
166
/// // array of strings to match
167
/// let array = StringArray::from(vec!["Foo", "Bar", "FooBar", "Baz"]);
168
/// let regexp = "^Foo"; // regular expression to match against
169
/// let flags: Option<&str> = None;  // flags can control the matching behavior
170
/// // The result is a BooleanArray indicating when each string in `array`
171
/// // matches the regular expression `regexp`
172
/// let result = regexp_is_match_scalar(&array, regexp, None).unwrap();
173
/// assert_eq!(result, BooleanArray::from(vec![true, false, true, false]));
174
/// ```
175
pub fn regexp_is_match_scalar<'a, S>(
176
    array: &'a S,
177
    regex: &str,
178
    flag: Option<&str>,
179
) -> Result<BooleanArray, ArrowError>
180
where
181
    &'a S: StringArrayType<'a>,
182
{
183
    let null_bit_buffer = array.nulls().map(|x| x.inner().sliced());
184
    let mut result = BooleanBufferBuilder::new(array.len());
185
186
    let pattern = match flag {
187
        Some(flag) => format!("(?{flag}){regex}"),
188
        None => regex.to_string(),
189
    };
190
191
    if pattern.is_empty() {
192
        result.append_n(array.len(), true);
193
    } else {
194
        let re = Regex::new(pattern.as_str()).map_err(|e| {
195
            ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
196
        })?;
197
        for i in 0..array.len() {
198
            let value = array.value(i);
199
            result.append(re.is_match(value));
200
        }
201
    }
202
203
    let buffer = result.into();
204
    let data = unsafe {
205
        ArrayData::new_unchecked(
206
            DataType::Boolean,
207
            array.len(),
208
            None,
209
            null_bit_buffer,
210
            0,
211
            vec![buffer],
212
            vec![],
213
        )
214
    };
215
216
    Ok(BooleanArray::from(data))
217
}
218
219
// Shared body of the array-pattern `regexp_match` kernels.
//
// For each (value, pattern) pair this appends one list entry to `$list_builder`:
// the captured groups on a match (or the whole match when the pattern has no
// groups), a single empty string for an empty pattern, and a null entry when
// the value or pattern is null or nothing matched. Iteration stops at the
// shortest of the zipped inputs. Must be expanded inside a function returning
// `Result<_, ArrowError>` because compile failures are propagated with `?`.
macro_rules! process_regexp_array_match {
    ($array:expr, $regex_array:expr, $flags_array:expr, $list_builder:expr) => {
        // Compiled expressions keyed by the complete pattern text (flags included).
        let mut patterns: HashMap<String, Regex> = HashMap::new();

        let complete_pattern = match $flags_array {
            Some(flags) => Box::new($regex_array.iter().zip(flags.iter()).map(
                |(pattern, flag)| {
                    pattern.map(|pattern| {
                        flag.map_or_else(
                            || pattern.to_string(),
                            |value| format!("(?{value}){pattern}"),
                        )
                    })
                },
            )) as Box<dyn Iterator<Item = Option<String>>>,
            None => Box::new(
                $regex_array
                    .iter()
                    .map(|pattern| pattern.map(|pattern| pattern.to_string())),
            ),
        };

        for (value, pattern) in $array.iter().zip(complete_pattern) {
            match (value, pattern) {
                // Required for Postgres compatibility:
                // SELECT regexp_match('foobarbequebaz', ''); = {""}
                (Some(_), Some(pattern)) if pattern.is_empty() => {
                    $list_builder.values().append_value("");
                    $list_builder.append(true);
                }
                (Some(value), Some(pattern)) => {
                    // Reuse the cached regex or compile and cache it.
                    let re = match patterns.entry(pattern) {
                        std::collections::hash_map::Entry::Occupied(slot) => slot.into_mut(),
                        std::collections::hash_map::Entry::Vacant(slot) => {
                            let compiled = Regex::new(slot.key().as_str()).map_err(|e| {
                                ArrowError::ComputeError(format!(
                                    "Regular expression did not compile: {e:?}"
                                ))
                            })?;
                            slot.insert(compiled)
                        }
                    };
                    if let Some(caps) = re.captures(value) {
                        // Group 0 is the whole match; drop it only when the
                        // pattern defines explicit capture groups.
                        let skipped = usize::from(caps.len() > 1);
                        for m in caps.iter().skip(skipped).flatten() {
                            $list_builder.values().append_value(m.as_str());
                        }
                        $list_builder.append(true);
                    } else {
                        $list_builder.append(false);
                    }
                }
                _ => $list_builder.append(false),
            }
        }
    };
}
285
286
0
fn regexp_array_match<OffsetSize: OffsetSizeTrait>(
287
0
    array: &GenericStringArray<OffsetSize>,
288
0
    regex_array: &GenericStringArray<OffsetSize>,
289
0
    flags_array: Option<&GenericStringArray<OffsetSize>>,
290
0
) -> Result<ArrayRef, ArrowError> {
291
0
    let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
292
0
    let mut list_builder = ListBuilder::new(builder);
293
294
0
    process_regexp_array_match!(array, regex_array, flags_array, list_builder);
295
296
0
    Ok(Arc::new(list_builder.finish()))
297
0
}
298
299
fn regexp_array_match_utf8view(
300
    array: &StringViewArray,
301
    regex_array: &StringViewArray,
302
    flags_array: Option<&StringViewArray>,
303
) -> Result<ArrayRef, ArrowError> {
304
    let builder = StringViewBuilder::with_capacity(0);
305
    let mut list_builder = ListBuilder::new(builder);
306
307
    process_regexp_array_match!(array, regex_array, flags_array, list_builder);
308
309
    Ok(Arc::new(list_builder.finish()))
310
}
311
312
0
fn get_scalar_pattern_flag<'a, OffsetSize: OffsetSizeTrait>(
313
0
    regex_array: &'a dyn Array,
314
0
    flag_array: Option<&'a dyn Array>,
315
0
) -> (Option<&'a str>, Option<&'a str>) {
316
0
    let regex = regex_array.as_string::<OffsetSize>();
317
0
    let regex = regex.is_valid(0).then(|| regex.value(0));
318
319
0
    if let Some(flag_array) = flag_array {
320
0
        let flag = flag_array.as_string::<OffsetSize>();
321
0
        (regex, flag.is_valid(0).then(|| flag.value(0)))
322
    } else {
323
0
        (regex, None)
324
    }
325
0
}
326
327
fn get_scalar_pattern_flag_utf8view<'a>(
328
    regex_array: &'a dyn Array,
329
    flag_array: Option<&'a dyn Array>,
330
) -> (Option<&'a str>, Option<&'a str>) {
331
    let regex = regex_array.as_string_view();
332
0
    let regex = regex.is_valid(0).then(|| regex.value(0));
333
334
    if let Some(flag_array) = flag_array {
335
        let flag = flag_array.as_string_view();
336
0
        (regex, flag.is_valid(0).then(|| flag.value(0)))
337
    } else {
338
        (regex, None)
339
    }
340
}
341
342
// Shared body of the scalar-pattern `regexp_match` kernels.
//
// Applies the already compiled `$regex` to every element of `$array` and
// appends one list entry per element to `$list_builder`: the captured groups
// on a match (or the whole match when the pattern has no groups), a single
// empty string when the pattern text is empty, and a null entry when the
// value is null or nothing matched.
macro_rules! process_regexp_match {
    ($array:expr, $regex:expr, $list_builder:expr) => {
        for value in $array.iter() {
            match value {
                None => $list_builder.append(false),
                // Required for Postgres compatibility:
                // SELECT regexp_match('foobarbequebaz', ''); = {""}
                Some(_) if $regex.as_str().is_empty() => {
                    $list_builder.values().append_value("");
                    $list_builder.append(true);
                }
                Some(value) => {
                    if let Some(caps) = $regex.captures(value) {
                        // Group 0 is the whole match; drop it only when the
                        // pattern defines explicit capture groups.
                        let skipped = usize::from(caps.len() > 1);
                        for m in caps.iter().skip(skipped).flatten() {
                            $list_builder.values().append_value(m.as_str());
                        }
                        $list_builder.append(true);
                    } else {
                        $list_builder.append(false);
                    }
                }
            }
        }
    };
}
374
375
0
fn regexp_scalar_match<OffsetSize: OffsetSizeTrait>(
376
0
    array: &GenericStringArray<OffsetSize>,
377
0
    regex: &Regex,
378
0
) -> Result<ArrayRef, ArrowError> {
379
0
    let builder: GenericStringBuilder<OffsetSize> = GenericStringBuilder::with_capacity(0, 0);
380
0
    let mut list_builder = ListBuilder::new(builder);
381
382
0
    process_regexp_match!(array, regex, list_builder);
383
384
0
    Ok(Arc::new(list_builder.finish()))
385
0
}
386
387
fn regexp_scalar_match_utf8view(
388
    array: &StringViewArray,
389
    regex: &Regex,
390
) -> Result<ArrayRef, ArrowError> {
391
    let builder = StringViewBuilder::with_capacity(0);
392
    let mut list_builder = ListBuilder::new(builder);
393
394
    process_regexp_match!(array, regex, list_builder);
395
396
    Ok(Arc::new(list_builder.finish()))
397
}
398
399
/// Extract all groups matched by a regular expression for a given String array.
400
///
401
/// Modelled after the Postgres [regexp_match].
402
///
403
/// Returns a ListArray of [`GenericStringArray`] with each element containing the leftmost-first
404
/// match of the corresponding index in `regex_array` to string in `array`
405
///
406
/// If there is no match, the list element is NULL.
407
///
408
/// If a match is found, and the pattern contains no capturing parenthesized subexpressions,
409
/// then the list element is a single-element [`GenericStringArray`] containing the substring
410
/// matching the whole pattern.
411
///
412
/// If a match is found, and the pattern contains capturing parenthesized subexpressions, then the
413
/// list element is a [`GenericStringArray`] whose n'th element is the substring matching
414
/// the n'th capturing parenthesized subexpression of the pattern.
415
///
416
/// The flags parameter is an optional text string containing zero or more single-letter flags
417
/// that change the function's behavior.
418
///
419
/// # See Also
420
/// * [`regexp_is_match`] for matching (rather than extracting) a regular expression against an array of strings
421
///
422
/// [regexp_match]: https://www.postgresql.org/docs/current/functions-matching.html#FUNCTIONS-POSIX-REGEXP
423
pub fn regexp_match(
424
    array: &dyn Array,
425
    regex_array: &dyn Datum,
426
    flags_array: Option<&dyn Datum>,
427
) -> Result<ArrayRef, ArrowError> {
428
    let (rhs, is_rhs_scalar) = regex_array.get();
429
430
    if array.data_type() != rhs.data_type() {
431
        return Err(ArrowError::ComputeError(
432
            "regexp_match() requires both array and pattern to be either Utf8, Utf8View or LargeUtf8"
433
                .to_string(),
434
        ));
435
    }
436
437
    let (flags, is_flags_scalar) = match flags_array {
438
        Some(flags) => {
439
            let (flags, is_flags_scalar) = flags.get();
440
            (Some(flags), Some(is_flags_scalar))
441
        }
442
        None => (None, None),
443
    };
444
445
    if is_flags_scalar.is_some() && is_rhs_scalar != is_flags_scalar.unwrap() {
446
        return Err(ArrowError::ComputeError(
447
            "regexp_match() requires both pattern and flags to be either scalar or array"
448
                .to_string(),
449
        ));
450
    }
451
452
    if flags_array.is_some() && rhs.data_type() != flags.unwrap().data_type() {
453
        return Err(ArrowError::ComputeError(
454
            "regexp_match() requires both pattern and flags to be either Utf8, Utf8View or LargeUtf8"
455
                .to_string(),
456
        ));
457
    }
458
459
    if is_rhs_scalar {
460
        // Regex and flag is scalars
461
        let (regex, flag) = match rhs.data_type() {
462
            DataType::Utf8View => get_scalar_pattern_flag_utf8view(rhs, flags),
463
            DataType::Utf8 => get_scalar_pattern_flag::<i32>(rhs, flags),
464
            DataType::LargeUtf8 => get_scalar_pattern_flag::<i64>(rhs, flags),
465
            _ => {
466
                return Err(ArrowError::ComputeError(
467
                    "regexp_match() requires pattern to be either Utf8, Utf8View or LargeUtf8"
468
                        .to_string(),
469
                ));
470
            }
471
        };
472
473
        if regex.is_none() {
474
            return Ok(new_null_array(
475
                &DataType::List(Arc::new(Field::new_list_field(
476
                    array.data_type().clone(),
477
                    true,
478
                ))),
479
                array.len(),
480
            ));
481
        }
482
483
        let regex = regex.unwrap();
484
485
        let pattern = if let Some(flag) = flag {
486
            format!("(?{flag}){regex}")
487
        } else {
488
            regex.to_string()
489
        };
490
491
0
        let re = Regex::new(pattern.as_str()).map_err(|e| {
492
0
            ArrowError::ComputeError(format!("Regular expression did not compile: {e:?}"))
493
0
        })?;
494
495
        match array.data_type() {
496
            DataType::Utf8View => regexp_scalar_match_utf8view(array.as_string_view(), &re),
497
            DataType::Utf8 => regexp_scalar_match(array.as_string::<i32>(), &re),
498
            DataType::LargeUtf8 => regexp_scalar_match(array.as_string::<i64>(), &re),
499
            _ => Err(ArrowError::ComputeError(
500
                "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8"
501
                    .to_string(),
502
            )),
503
        }
504
    } else {
505
        match array.data_type() {
506
            DataType::Utf8View => {
507
                let regex_array = rhs.as_string_view();
508
0
                let flags_array = flags.map(|flags| flags.as_string_view());
509
                regexp_array_match_utf8view(array.as_string_view(), regex_array, flags_array)
510
            }
511
            DataType::Utf8 => {
512
                let regex_array = rhs.as_string();
513
0
                let flags_array = flags.map(|flags| flags.as_string());
514
                regexp_array_match(array.as_string::<i32>(), regex_array, flags_array)
515
            }
516
            DataType::LargeUtf8 => {
517
                let regex_array = rhs.as_string();
518
0
                let flags_array = flags.map(|flags| flags.as_string());
519
                regexp_array_match(array.as_string::<i64>(), regex_array, flags_array)
520
            }
521
            _ => Err(ArrowError::ComputeError(
522
                "regexp_match() requires array to be either Utf8, Utf8View or LargeUtf8"
523
                    .to_string(),
524
            )),
525
        }
526
    }
527
}
528
529
#[cfg(test)]
530
mod tests {
531
    use super::*;
532
533
    macro_rules! test_match_single_group {
534
        ($test_name:ident, $values:expr, $patterns:expr, $arr_type:ty, $builder_type:ty, $expected:expr) => {
535
            #[test]
536
            fn $test_name() {
537
                let array: $arr_type = <$arr_type>::from($values);
538
                let pattern: $arr_type = <$arr_type>::from($patterns);
539
540
                let actual = regexp_match(&array, &pattern, None).unwrap();
541
542
                let elem_builder: $builder_type = <$builder_type>::new();
543
                let mut expected_builder = ListBuilder::new(elem_builder);
544
545
                for val in $expected {
546
                    match val {
547
                        Some(v) => {
548
                            expected_builder.values().append_value(v);
549
                            expected_builder.append(true);
550
                        }
551
                        None => expected_builder.append(false),
552
                    }
553
                }
554
555
                let expected = expected_builder.finish();
556
                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
557
                assert_eq!(&expected, result);
558
            }
559
        };
560
    }
561
562
    test_match_single_group!(
563
        match_single_group_string,
564
        vec![
565
            Some("abc-005-def"),
566
            Some("X-7-5"),
567
            Some("X545"),
568
            None,
569
            Some("foobarbequebaz"),
570
            Some("foobarbequebaz"),
571
        ],
572
        vec![
573
            r".*-(\d*)-.*",
574
            r".*-(\d*)-.*",
575
            r".*-(\d*)-.*",
576
            r".*-(\d*)-.*",
577
            r"(bar)(bequ1e)",
578
            ""
579
        ],
580
        StringArray,
581
        GenericStringBuilder<i32>,
582
        [Some("005"), Some("7"), None, None, None, Some("")]
583
    );
584
    test_match_single_group!(
585
        match_single_group_string_view,
586
        vec![
587
            Some("abc-005-def"),
588
            Some("X-7-5"),
589
            Some("X545"),
590
            None,
591
            Some("foobarbequebaz"),
592
            Some("foobarbequebaz"),
593
        ],
594
        vec![
595
            r".*-(\d*)-.*",
596
            r".*-(\d*)-.*",
597
            r".*-(\d*)-.*",
598
            r".*-(\d*)-.*",
599
            r"(bar)(bequ1e)",
600
            ""
601
        ],
602
        StringViewArray,
603
        StringViewBuilder,
604
        [Some("005"), Some("7"), None, None, None, Some("")]
605
    );
606
607
    macro_rules! test_match_single_group_with_flags {
608
        ($test_name:ident, $values:expr, $patterns:expr, $flags:expr, $array_type:ty, $builder_type:ty, $expected:expr) => {
609
            #[test]
610
            fn $test_name() {
611
                let array: $array_type = <$array_type>::from($values);
612
                let pattern: $array_type = <$array_type>::from($patterns);
613
                let flags: $array_type = <$array_type>::from($flags);
614
615
                let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap();
616
617
                let elem_builder: $builder_type = <$builder_type>::new();
618
                let mut expected_builder = ListBuilder::new(elem_builder);
619
620
                for val in $expected {
621
                    match val {
622
                        Some(v) => {
623
                            expected_builder.values().append_value(v);
624
                            expected_builder.append(true);
625
                        }
626
                        None => {
627
                            expected_builder.append(false);
628
                        }
629
                    }
630
                }
631
632
                let expected = expected_builder.finish();
633
                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
634
                assert_eq!(&expected, result);
635
            }
636
        };
637
    }
638
639
    test_match_single_group_with_flags!(
640
        match_single_group_with_flags_string,
641
        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
642
        vec![r"x.*-(\d*)-.*"; 4],
643
        vec!["i"; 4],
644
        StringArray,
645
        GenericStringBuilder<i32>,
646
        [None, Some("7"), None, None]
647
    );
648
    test_match_single_group_with_flags!(
649
        match_single_group_with_flags_stringview,
650
        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
651
        vec![r"x.*-(\d*)-.*"; 4],
652
        vec!["i"; 4],
653
        StringViewArray,
654
        StringViewBuilder,
655
        [None, Some("7"), None, None]
656
    );
657
658
    macro_rules! test_match_scalar_pattern {
659
        ($test_name:ident, $values:expr, $pattern:expr, $flag:expr, $array_type:ty, $builder_type:ty, $expected:expr) => {
660
            #[test]
661
            fn $test_name() {
662
                let array: $array_type = <$array_type>::from($values);
663
664
                let pattern_scalar = Scalar::new(<$array_type>::from(vec![$pattern; 1]));
665
                let flag_scalar = Scalar::new(<$array_type>::from(vec![$flag; 1]));
666
667
                let actual = regexp_match(&array, &pattern_scalar, Some(&flag_scalar)).unwrap();
668
669
                let elem_builder: $builder_type = <$builder_type>::new();
670
                let mut expected_builder = ListBuilder::new(elem_builder);
671
672
                for val in $expected {
673
                    match val {
674
                        Some(v) => {
675
                            expected_builder.values().append_value(v);
676
                            expected_builder.append(true);
677
                        }
678
                        None => expected_builder.append(false),
679
                    }
680
                }
681
682
                let expected = expected_builder.finish();
683
                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
684
                assert_eq!(&expected, result);
685
            }
686
        };
687
    }
688
689
    test_match_scalar_pattern!(
690
        match_scalar_pattern_string_with_flags,
691
        vec![
692
            Some("abc-005-def"),
693
            Some("x-7-5"),
694
            Some("X-0-Y"),
695
            Some("X545"),
696
            None
697
        ],
698
        r"x.*-(\d*)-.*",
699
        Some("i"),
700
        StringArray,
701
        GenericStringBuilder<i32>,
702
        [None, Some("7"), Some("0"), None, None]
703
    );
704
    test_match_scalar_pattern!(
705
        match_scalar_pattern_stringview_with_flags,
706
        vec![
707
            Some("abc-005-def"),
708
            Some("x-7-5"),
709
            Some("X-0-Y"),
710
            Some("X545"),
711
            None
712
        ],
713
        r"x.*-(\d*)-.*",
714
        Some("i"),
715
        StringViewArray,
716
        StringViewBuilder,
717
        [None, Some("7"), Some("0"), None, None]
718
    );
719
720
    test_match_scalar_pattern!(
721
        match_scalar_pattern_string_no_flags,
722
        vec![
723
            Some("abc-005-def"),
724
            Some("x-7-5"),
725
            Some("X-0-Y"),
726
            Some("X545"),
727
            None
728
        ],
729
        r"x.*-(\d*)-.*",
730
        None::<&str>,
731
        StringArray,
732
        GenericStringBuilder<i32>,
733
        [None, Some("7"), None, None, None]
734
    );
735
    test_match_scalar_pattern!(
736
        match_scalar_pattern_stringview_no_flags,
737
        vec![
738
            Some("abc-005-def"),
739
            Some("x-7-5"),
740
            Some("X-0-Y"),
741
            Some("X545"),
742
            None
743
        ],
744
        r"x.*-(\d*)-.*",
745
        None::<&str>,
746
        StringViewArray,
747
        StringViewBuilder,
748
        [None, Some("7"), None, None, None]
749
    );
750
751
    macro_rules! test_match_scalar_no_pattern {
752
        ($test_name:ident, $values:expr, $array_type:ty, $pattern_type:expr, $builder_type:ty, $expected:expr) => {
753
            #[test]
754
            fn $test_name() {
755
                let array: $array_type = <$array_type>::from($values);
756
                let pattern = Scalar::new(new_null_array(&$pattern_type, 1));
757
758
                let actual = regexp_match(&array, &pattern, None).unwrap();
759
760
                let elem_builder: $builder_type = <$builder_type>::new();
761
                let mut expected_builder = ListBuilder::new(elem_builder);
762
763
                for val in $expected {
764
                    match val {
765
                        Some(v) => {
766
                            expected_builder.values().append_value(v);
767
                            expected_builder.append(true);
768
                        }
769
                        None => expected_builder.append(false),
770
                    }
771
                }
772
773
                let expected = expected_builder.finish();
774
                let result = actual.as_any().downcast_ref::<ListArray>().unwrap();
775
                assert_eq!(&expected, result);
776
            }
777
        };
778
    }
779
780
    // Utf8 input with a null Utf8 scalar pattern: every row (including the
    // null one) is expected to be a null list entry.
    test_match_scalar_no_pattern!(
781
        match_scalar_no_pattern_string,
782
        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
783
        StringArray,
784
        DataType::Utf8,
785
        GenericStringBuilder<i32>,
786
        [None::<&str>, None, None, None]
787
    );
788
    // Same case for Utf8View input with a null Utf8View scalar pattern.
    test_match_scalar_no_pattern!(
789
        match_scalar_no_pattern_stringview,
790
        vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None],
791
        StringViewArray,
792
        DataType::Utf8View,
793
        StringViewBuilder,
794
        [None::<&str>, None, None, None]
795
    );
796
797
    // Generates a `#[test]` named `$test_name` that calls `regexp_match` (no
    // flags) with an `$array_type` built from `$values` and a second
    // `$array_type` holding the single pattern `$pattern`.
    //
    // The result is compared with a `ListArray` assembled from `$expected`:
    // each `Some(v)` contributes a valid one-element list containing `v`, each
    // `None` a null list entry.
    macro_rules! test_match_single_group_not_skip {
        (
            $test_name:ident,
            $values:expr,
            $pattern:expr,
            $array_type:ty,
            $builder_type:ty,
            $expected:expr
        ) => {
            #[test]
            fn $test_name() {
                let haystack = <$array_type>::from($values);
                let patterns = <$array_type>::from(vec![$pattern]);
                let matched = regexp_match(&haystack, &patterns, None).unwrap();

                let mut list = ListBuilder::new(<$builder_type>::new());
                for item in $expected {
                    // Push the child value (if any), then close the list slot
                    // with the matching validity bit.
                    let is_valid = match item {
                        Some(text) => {
                            list.values().append_value(text);
                            true
                        }
                        None => false,
                    };
                    list.append(is_valid);
                }
                let expected = list.finish();

                assert_eq!(
                    &expected,
                    matched.as_any().downcast_ref::<ListArray>().unwrap()
                );
            }
        };
    }
825
826
    // Utf8: two input rows but a one-element pattern array ("foo"); the
    // expected result holds a single entry, the match "foo".
    test_match_single_group_not_skip!(
827
        match_single_group_not_skip_string,
828
        vec![Some("foo"), Some("bar")],
829
        r"foo",
830
        StringArray,
831
        GenericStringBuilder<i32>,
832
        [Some("foo")]
833
    );
834
    // Same case for Utf8View input and pattern arrays.
    test_match_single_group_not_skip!(
835
        match_single_group_not_skip_stringview,
836
        vec![Some("foo"), Some("bar")],
837
        r"foo",
838
        StringViewArray,
839
        StringViewBuilder,
840
        [Some("foo")]
841
    );
842
843
    // Generates a `#[test]` named `$test_name` that applies the kernel `$op` to
    // `$left` (values) and `$right` (patterns) and checks the result element by
    // element against `$expected`.
    //
    // Two forms are accepted:
    //   (name, left, right, op, expected)        -> flags argument is `None`
    //   (name, left, right, flag, op, expected)  -> flags argument is `Some(&flag)`
    //
    // `$op` must return a `Result` whose `Ok` value exposes `len()` and
    // `value(i)`; `$expected` must expose `len()` and `iter()` (arrays, slices
    // and `Vec`s all qualify). The lengths are asserted to be equal before the
    // per-element comparison, and a mismatch reports the offending position.
    macro_rules! test_flag_utf8 {
        ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
            #[test]
            fn $test_name() {
                let left = $left;
                let right = $right;
                let res = $op(&left, &right, None).unwrap();
                let expected = $expected;
                assert_eq!(expected.len(), res.len());
                for (i, want) in expected.iter().enumerate() {
                    assert_eq!(
                        res.value(i),
                        *want,
                        "unexpected result at position {}",
                        i
                    );
                }
            }
        };
        ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
            #[test]
            fn $test_name() {
                let left = $left;
                let right = $right;
                // Keep the flags owned here so a reference can be handed to `$op`.
                let flag = Some($flag);
                let res = $op(&left, &right, flag.as_ref()).unwrap();
                let expected = $expected;
                assert_eq!(expected.len(), res.len());
                for (i, want) in expected.iter().enumerate() {
                    assert_eq!(
                        res.value(i),
                        *want,
                        "unexpected result at position {}",
                        i
                    );
                }
            }
        };
    }
874
875
    // Generates a `#[test]` named `$test_name` that applies the scalar kernel
    // `$op` to the array `$left` and the single pattern `$right`, then checks
    // the result element by element against `$expected`.
    //
    // Two forms are accepted:
    //   (name, left, right, op, expected)        -> flag argument is `None`
    //   (name, left, right, flag, op, expected)  -> flag argument is `Some(flag)`
    //
    // `$op` must return a `Result` whose `Ok` value exposes `len()` and
    // `value(i)`; `$left` must expose `value(i)` returning something printable;
    // `$expected` must expose `len()` and `iter()`. On mismatch the failure
    // message names the input value, its position and the pattern.
    macro_rules! test_flag_utf8_scalar {
        ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => {
            #[test]
            fn $test_name() {
                let left = $left;
                let res = $op(&left, $right, None).unwrap();
                let expected = $expected;
                assert_eq!(expected.len(), res.len());
                for (i, want) in expected.iter().enumerate() {
                    assert_eq!(
                        res.value(i),
                        *want,
                        "unexpected result when comparing {} at position {} to {} ",
                        left.value(i),
                        i,
                        $right
                    );
                }
            }
        };
        ($test_name:ident, $left:expr, $right:expr, $flag:expr, $op:expr, $expected:expr) => {
            #[test]
            fn $test_name() {
                let left = $left;
                let flag = Some($flag);
                let res = $op(&left, $right, flag).unwrap();
                let expected = $expected;
                assert_eq!(expected.len(), res.len());
                for (i, want) in expected.iter().enumerate() {
                    assert_eq!(
                        res.value(i),
                        *want,
                        "unexpected result when comparing {} at position {} to {} ",
                        left.value(i),
                        i,
                        $right
                    );
                }
            }
        };
    }
918
919
    // Utf8 values with per-row Utf8 patterns and no flags: the upper-case
    // patterns (`^AR`, `OW$`) and "foo" are expected not to match "arrow",
    // while the empty pattern is expected to match.
    test_flag_utf8!(
920
        test_array_regexp_is_match_utf8,
921
        StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
922
        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
923
        regexp_is_match::<StringArray, StringArray, StringArray>,
924
        [true, false, true, false, false, true]
925
    );
926
    // Same data with the flag "i" on every row: the upper-case patterns are
    // now expected to match; "foo" still is not.
    test_flag_utf8!(
927
        test_array_regexp_is_match_utf8_insensitive,
928
        StringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
929
        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
930
        StringArray::from(vec!["i"; 6]),
931
        regexp_is_match,
932
        [true, true, true, true, false, true]
933
    );
934
935
    // Utf8 values against the scalar pattern "^ar" with no flag: only the
    // lower-case "arrow" is expected to match.
    test_flag_utf8_scalar!(
936
        test_array_regexp_is_match_utf8_scalar,
937
        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
938
        "^ar",
939
        regexp_is_match_scalar,
940
        [true, false, false, false]
941
    );
942
    // Empty scalar pattern: every row is expected to match.
    test_flag_utf8_scalar!(
943
        test_array_regexp_is_match_utf8_scalar_empty,
944
        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
945
        "",
946
        regexp_is_match_scalar,
947
        [true, true, true, true]
948
    );
949
    // Scalar pattern "^ar" with the flag "i": both "arrow" and "ARROW" are
    // expected to match.
    test_flag_utf8_scalar!(
950
        test_array_regexp_is_match_utf8_scalar_insensitive,
951
        StringArray::from(vec!["arrow", "ARROW", "parquet", "PARQUET"]),
952
        "^ar",
953
        "i",
954
        regexp_is_match_scalar,
955
        [true, true, false, false]
956
    );
957
958
    // Utf8View values with per-row Utf8View patterns and no flags: the
    // upper-case patterns (`^AR`, `OW$`) and "foo" are expected not to match
    // "arrow", while the empty pattern is expected to match.
    test_flag_utf8!(
        test_array_regexp_is_match,
        StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
        StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
        regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
        [true, false, true, false, false, true]
    );
965
    // Mixed array types: Utf8View values with Utf8 patterns, no flags. Expected
    // results are the same as for the all-Utf8 case-sensitive test.
    test_flag_utf8!(
966
        test_array_regexp_is_match_2,
967
        StringViewArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
968
        StringArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
969
        regexp_is_match::<StringViewArray, GenericStringArray<i32>, GenericStringArray<i32>>,
970
        [true, false, true, false, false, true]
971
    );
972
    // All-Utf8View values, patterns and flags, with the flag "i" on each of
    // the seven rows: patterns differing from the value only by case are
    // expected to match; "foo" against "row" is not.
    test_flag_utf8!(
973
        test_array_regexp_is_match_insensitive,
974
        StringViewArray::from(vec![
975
            "Official Rust implementation of Apache Arrow",
976
            "apache/arrow-rs",
977
            "apache/arrow-rs",
978
            "parquet",
979
            "parquet",
980
            "row",
981
            "row",
982
        ]),
983
        StringViewArray::from(vec![
984
            ".*rust implement.*",
985
            "^ap",
986
            "^AP",
987
            "et$",
988
            "ET$",
989
            "foo",
990
            ""
991
        ]),
992
        StringViewArray::from(vec!["i"; 7]),
993
        regexp_is_match::<StringViewArray, StringViewArray, StringViewArray>,
994
        [true, true, true, true, true, false, true]
995
    );
996
    // Three different array types at once: LargeUtf8 values, Utf8View
    // patterns and Utf8 flags ("i" on every row).
    test_flag_utf8!(
997
        test_array_regexp_is_match_insensitive_2,
998
        LargeStringArray::from(vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrow"]),
999
        StringViewArray::from(vec!["^ar", "^AR", "ow$", "OW$", "foo", ""]),
1000
        StringArray::from(vec!["i"; 6]),
1001
        regexp_is_match::<GenericStringArray<i64>, StringViewArray, GenericStringArray<i32>>,
1002
        [true, true, true, true, false, true]
1003
    );
1004
1005
    // Utf8View values against the scalar pattern "^ap" with no flag: only the
    // lower-case "apache/arrow-rs" is expected to match.
    test_flag_utf8_scalar!(
1006
        test_array_regexp_is_match_scalar,
1007
        StringViewArray::from(vec![
1008
            "apache/arrow-rs",
1009
            "APACHE/ARROW-RS",
1010
            "parquet",
1011
            "PARQUET",
1012
        ]),
1013
        "^ap",
1014
        regexp_is_match_scalar::<StringViewArray>,
1015
        [true, false, false, false]
1016
    );
1017
    // Empty scalar pattern: every row is expected to match.
    test_flag_utf8_scalar!(
1018
        test_array_regexp_is_match_scalar_empty,
1019
        StringViewArray::from(vec![
1020
            "apache/arrow-rs",
1021
            "APACHE/ARROW-RS",
1022
            "parquet",
1023
            "PARQUET",
1024
        ]),
1025
        "",
1026
        regexp_is_match_scalar::<StringViewArray>,
1027
        [true, true, true, true]
1028
    );
1029
    // Scalar pattern "^ap" with the flag "i": both spellings of
    // "apache/arrow-rs" are expected to match.
    test_flag_utf8_scalar!(
1030
        test_array_regexp_is_match_scalar_insensitive,
1031
        StringViewArray::from(vec![
1032
            "apache/arrow-rs",
1033
            "APACHE/ARROW-RS",
1034
            "parquet",
1035
            "PARQUET",
1036
        ]),
1037
        "^ap",
1038
        "i",
1039
        regexp_is_match_scalar::<StringViewArray>,
1040
        [true, true, false, false]
1041
    );
1042
}