Coverage Report

Created: 2025-08-26 07:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-string/src/predicate.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use arrow_array::{Array, ArrayAccessor, BooleanArray, StringViewArray};
19
use arrow_buffer::BooleanBuffer;
20
use arrow_schema::ArrowError;
21
use memchr::memchr3;
22
use memchr::memmem::Finder;
23
use regex::{Regex, RegexBuilder};
24
use std::iter::zip;
25
26
/// A string based predicate
27
#[allow(clippy::large_enum_variant)]
28
pub(crate) enum Predicate<'a> {
29
    Eq(&'a str),
30
    Contains(Finder<'a>),
31
    StartsWith(&'a str),
32
    EndsWith(&'a str),
33
34
    /// Equality ignoring ASCII case
35
    IEqAscii(&'a str),
36
    /// Starts with ignoring ASCII case
37
    IStartsWithAscii(&'a str),
38
    /// Ends with ignoring ASCII case
39
    IEndsWithAscii(&'a str),
40
41
    Regex(Regex),
42
}
43
44
impl<'a> Predicate<'a> {
45
    /// Create a predicate for the given like pattern
46
0
    pub(crate) fn like(pattern: &'a str) -> Result<Self, ArrowError> {
47
0
        if !contains_like_pattern(pattern) {
48
0
            Ok(Self::Eq(pattern))
49
0
        } else if pattern.ends_with('%') && !contains_like_pattern(&pattern[..pattern.len() - 1]) {
50
0
            Ok(Self::StartsWith(&pattern[..pattern.len() - 1]))
51
0
        } else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) {
52
0
            Ok(Self::EndsWith(&pattern[1..]))
53
0
        } else if pattern.starts_with('%')
54
0
            && pattern.ends_with('%')
55
0
            && !contains_like_pattern(&pattern[1..pattern.len() - 1])
56
        {
57
0
            Ok(Self::contains(&pattern[1..pattern.len() - 1]))
58
        } else {
59
0
            Ok(Self::Regex(regex_like(pattern, false)?))
60
        }
61
0
    }
62
63
0
    pub(crate) fn contains(needle: &'a str) -> Self {
64
0
        Self::Contains(Finder::new(needle.as_bytes()))
65
0
    }
66
67
    /// Create a predicate for the given ilike pattern
68
0
    pub(crate) fn ilike(pattern: &'a str, is_ascii: bool) -> Result<Self, ArrowError> {
69
0
        if is_ascii && pattern.is_ascii() {
70
0
            if !contains_like_pattern(pattern) {
71
0
                return Ok(Self::IEqAscii(pattern));
72
0
            } else if pattern.ends_with('%')
73
0
                && !pattern.ends_with("\\%")
74
0
                && !contains_like_pattern(&pattern[..pattern.len() - 1])
75
            {
76
0
                return Ok(Self::IStartsWithAscii(&pattern[..pattern.len() - 1]));
77
0
            } else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) {
78
0
                return Ok(Self::IEndsWithAscii(&pattern[1..]));
79
0
            }
80
0
        }
81
0
        Ok(Self::Regex(regex_like(pattern, true)?))
82
0
    }
83
84
    /// Evaluate this predicate against the given haystack
85
0
    pub(crate) fn evaluate(&self, haystack: &str) -> bool {
86
0
        match self {
87
0
            Predicate::Eq(v) => *v == haystack,
88
0
            Predicate::IEqAscii(v) => haystack.eq_ignore_ascii_case(v),
89
0
            Predicate::Contains(finder) => finder.find(haystack.as_bytes()).is_some(),
90
0
            Predicate::StartsWith(v) => starts_with(haystack, v, equals_kernel),
91
0
            Predicate::IStartsWithAscii(v) => {
92
0
                starts_with(haystack, v, equals_ignore_ascii_case_kernel)
93
            }
94
0
            Predicate::EndsWith(v) => ends_with(haystack, v, equals_kernel),
95
0
            Predicate::IEndsWithAscii(v) => ends_with(haystack, v, equals_ignore_ascii_case_kernel),
96
0
            Predicate::Regex(v) => v.is_match(haystack),
97
        }
98
0
    }
99
100
    /// Evaluate this predicate against the elements of `array`
101
    ///
102
    /// If `negate` is true the result of the predicate will be negated
103
    #[inline(never)]
104
0
    pub(crate) fn evaluate_array<'i, T>(&self, array: T, negate: bool) -> BooleanArray
105
0
    where
106
0
        T: ArrayAccessor<Item = &'i str>,
107
    {
108
0
        match self {
109
0
            Predicate::Eq(v) => BooleanArray::from_unary(array, |haystack| {
110
0
                (haystack.len() == v.len() && haystack == *v) != negate
111
0
            }),
112
0
            Predicate::IEqAscii(v) => BooleanArray::from_unary(array, |haystack| {
113
0
                haystack.eq_ignore_ascii_case(v) != negate
114
0
            }),
115
0
            Predicate::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
116
0
                finder.find(haystack.as_bytes()).is_some() != negate
117
0
            }),
118
0
            Predicate::StartsWith(v) => {
119
0
                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
120
0
                    let nulls = string_view_array.logical_nulls();
121
0
                    let values = BooleanBuffer::from(
122
0
                        string_view_array
123
0
                            .prefix_bytes_iter(v.len())
124
0
                            .map(|haystack| {
125
0
                                equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate
126
0
                            })
127
0
                            .collect::<Vec<_>>(),
128
                    );
129
0
                    BooleanArray::new(values, nulls)
130
                } else {
131
0
                    BooleanArray::from_unary(array, |haystack| {
132
0
                        starts_with(haystack, v, equals_kernel) != negate
133
0
                    })
134
                }
135
            }
136
0
            Predicate::IStartsWithAscii(v) => {
137
0
                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
138
0
                    let nulls = string_view_array.logical_nulls();
139
0
                    let values = BooleanBuffer::from(
140
0
                        string_view_array
141
0
                            .prefix_bytes_iter(v.len())
142
0
                            .map(|haystack| {
143
0
                                equals_bytes(
144
0
                                    haystack,
145
0
                                    v.as_bytes(),
146
0
                                    equals_ignore_ascii_case_kernel,
147
0
                                ) != negate
148
0
                            })
149
0
                            .collect::<Vec<_>>(),
150
                    );
151
0
                    BooleanArray::new(values, nulls)
152
                } else {
153
0
                    BooleanArray::from_unary(array, |haystack| {
154
0
                        starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
155
0
                    })
156
                }
157
            }
158
0
            Predicate::EndsWith(v) => {
159
0
                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
160
0
                    let nulls = string_view_array.logical_nulls();
161
0
                    let values = BooleanBuffer::from(
162
0
                        string_view_array
163
0
                            .suffix_bytes_iter(v.len())
164
0
                            .map(|haystack| {
165
0
                                equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate
166
0
                            })
167
0
                            .collect::<Vec<_>>(),
168
                    );
169
0
                    BooleanArray::new(values, nulls)
170
                } else {
171
0
                    BooleanArray::from_unary(array, |haystack| {
172
0
                        ends_with(haystack, v, equals_kernel) != negate
173
0
                    })
174
                }
175
            }
176
0
            Predicate::IEndsWithAscii(v) => {
177
0
                if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() {
178
0
                    let nulls = string_view_array.logical_nulls();
179
0
                    let values = BooleanBuffer::from(
180
0
                        string_view_array
181
0
                            .suffix_bytes_iter(v.len())
182
0
                            .map(|haystack| {
183
0
                                equals_bytes(
184
0
                                    haystack,
185
0
                                    v.as_bytes(),
186
0
                                    equals_ignore_ascii_case_kernel,
187
0
                                ) != negate
188
0
                            })
189
0
                            .collect::<Vec<_>>(),
190
                    );
191
0
                    BooleanArray::new(values, nulls)
192
                } else {
193
0
                    BooleanArray::from_unary(array, |haystack| {
194
0
                        ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate
195
0
                    })
196
                }
197
            }
198
0
            Predicate::Regex(v) => {
199
0
                BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate)
200
            }
201
        }
202
0
    }
203
}
204
205
0
/// Returns `true` when `lhs` and `rhs` have the same length and `byte_eq_kernel` accepts every
/// pair of bytes at matching positions.
fn equals_bytes(lhs: &[u8], rhs: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    if lhs.len() != rhs.len() {
        return false;
    }
    lhs.iter().zip(rhs).all(byte_eq_kernel)
}
208
209
/// This is faster than `str::starts_with` for small strings.
/// See <https://github.com/apache/arrow-rs/issues/6107> for more details.
///
/// Walks `haystack` and `needle` forwards byte by byte, asking `byte_eq_kernel` whether each
/// `(haystack_byte, needle_byte)` pair matches.
fn starts_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    let (hay, head) = (haystack.as_bytes(), needle.as_bytes());
    // `zip` stops at the shorter input, so the length guard is what rejects a needle that is
    // longer than the haystack.
    hay.len() >= head.len() && hay.iter().zip(head).all(byte_eq_kernel)
}
218
/// This is faster than `str::ends_with` for small strings.
/// See <https://github.com/apache/arrow-rs/issues/6107> for more details.
///
/// Walks `haystack` and `needle` backwards from their last byte, asking `byte_eq_kernel`
/// whether each `(haystack_byte, needle_byte)` pair matches.
fn ends_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    let (hay, tail) = (haystack.as_bytes(), needle.as_bytes());
    if tail.len() > hay.len() {
        return false;
    }
    hay.iter().rev().zip(tail.iter().rev()).all(byte_eq_kernel)
}
231
232
0
/// Byte kernel for exact (case-sensitive) comparison of a pair of bytes.
fn equals_kernel((lhs, rhs): (&u8, &u8)) -> bool {
    lhs.eq(rhs)
}
235
236
0
/// Byte kernel that treats ASCII letters differing only in case as equal; every other byte
/// (including non-ASCII bytes) must match exactly.
fn equals_ignore_ascii_case_kernel((lhs, rhs): (&u8, &u8)) -> bool {
    lhs.to_ascii_lowercase() == rhs.to_ascii_lowercase()
}
239
240
/// Transforms a like `pattern` to a regex compatible pattern. To achieve that, it does:
241
///
242
/// 1. Replace `LIKE` multi-character wildcards `%` => `.*` (unless they're at the start or end of the pattern,
243
///    where the regex is just truncated - e.g. `%foo%` => `foo` rather than `^.*foo.*$`)
244
/// 2. Replace `LIKE` single-character wildcards `_` => `.`
245
/// 3. Escape regex meta characters to match them and not be evaluated as regex special chars. e.g. `.` => `\\.`
246
/// 4. Replace escaped `LIKE` wildcards removing the escape characters to be able to match it as a regex. e.g. `\\%` => `%`
247
0
fn regex_like(pattern: &str, case_insensitive: bool) -> Result<Regex, ArrowError> {
248
0
    let mut result = String::with_capacity(pattern.len() * 2);
249
0
    let mut chars_iter = pattern.chars().peekable();
250
0
    match chars_iter.peek() {
251
        // if the pattern starts with `%`, we avoid starting the regex with a slow but meaningless `^.*`
252
0
        Some('%') => {
253
0
            chars_iter.next();
254
0
        }
255
0
        _ => result.push('^'),
256
    };
257
258
0
    while let Some(c) = chars_iter.next() {
259
0
        match c {
260
            '\\' => {
261
0
                match chars_iter.peek() {
262
0
                    Some(&next) => {
263
0
                        if regex_syntax::is_meta_character(next) {
264
0
                            result.push('\\');
265
0
                        }
266
0
                        result.push(next);
267
                        // Skipping the next char as it is already appended
268
0
                        chars_iter.next();
269
                    }
270
0
                    None => {
271
0
                        // Trailing backslash in the pattern. E.g. PostgreSQL and Trino treat it as an error, but e.g. Snowflake treats it as a literal backslash
272
0
                        result.push('\\');
273
0
                        result.push('\\');
274
0
                    }
275
                }
276
            }
277
0
            '%' => result.push_str(".*"),
278
0
            '_' => result.push('.'),
279
0
            c => {
280
0
                if regex_syntax::is_meta_character(c) {
281
0
                    result.push('\\');
282
0
                }
283
0
                result.push(c);
284
            }
285
        }
286
    }
287
    // instead of ending the regex with `.*$` and making it needlessly slow, we just end the regex
288
0
    if result.ends_with(".*") {
289
0
        result.pop();
290
0
        result.pop();
291
0
    } else {
292
0
        result.push('$');
293
0
    }
294
0
    RegexBuilder::new(&result)
295
0
        .case_insensitive(case_insensitive)
296
0
        .dot_matches_new_line(true)
297
0
        .build()
298
0
        .map_err(|e| {
299
0
            ArrowError::InvalidArgumentError(format!(
300
0
                "Unable to build regex from LIKE pattern: {e}"
301
0
            ))
302
0
        })
303
0
}
304
305
0
fn contains_like_pattern(pattern: &str) -> bool {
306
0
    memchr3(b'%', b'_', b'\\', pattern.as_bytes()).is_some()
307
0
}
308
309
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{StringArray, StringViewArray};

    #[test]
    fn test_regex_like() {
        let test_cases = [
            // %..%
            (r"%foobar%", r"foobar"),
            // ..%..
            (r"foo%bar", r"^foo.*bar$"),
            // .._..
            (r"foo_bar", r"^foo.bar$"),
            // escaped wildcards
            (r"\%\_", r"^%_$"),
            // escaped non-wildcard
            (r"\a", r"^a$"),
            // escaped escape and wildcard
            (r"\\%", r"^\\"),
            // escaped escape and non-wildcard
            (r"\\a", r"^\\a$"),
            // regex meta character
            (r".", r"^\.$"),
            (r"$", r"^\$$"),
            (r"\\", r"^\\$"),
            // trailing unescaped backslash is matched literally
            (r"foo\", r"^foo\\$"),
            // a lone `%` drops both anchors
            (r"%", r"$"),
            // multiple wildcards
            (r"a%b%", r"^a.*b"),
            (r"%a_b", r"a.b$"),
        ];

        for (like_pattern, expected_regexp) in test_cases {
            let r = regex_like(like_pattern, false).unwrap();
            assert_eq!(r.to_string(), expected_regexp);
        }
    }

    #[test]
    fn test_like_predicate_selection() {
        assert!(matches!(Predicate::like("foo").unwrap(), Predicate::Eq("foo")));
        assert!(matches!(
            Predicate::like("foo%").unwrap(),
            Predicate::StartsWith("foo")
        ));
        assert!(matches!(
            Predicate::like("%foo").unwrap(),
            Predicate::EndsWith("foo")
        ));
        assert!(matches!(
            Predicate::like("%foo%").unwrap(),
            Predicate::Contains(_)
        ));
        assert!(matches!(
            Predicate::like("%").unwrap(),
            Predicate::StartsWith("")
        ));
        assert!(matches!(Predicate::like("%%").unwrap(), Predicate::Contains(_)));
        // an escaped trailing `%` is a literal, not a prefix match
        assert!(matches!(Predicate::like(r"foo\%").unwrap(), Predicate::Regex(_)));
        assert!(matches!(Predicate::like("f_o").unwrap(), Predicate::Regex(_)));
    }

    #[test]
    fn test_like_evaluate() {
        let cases = [
            ("foo", "foo", true),
            ("foo", "Foo", false),
            ("foo%", "foobar", true),
            ("%bar", "foobar", true),
            ("%oob%", "foobar", true),
            (r"foo\%", "foo%", true),
            (r"foo\%", "foobar", false),
            // `_` matches a single (possibly multi-byte) character
            ("f_o", "f£o", true),
            // wildcards match newlines
            ("a%b", "a\nb", true),
            ("a_b", "a\nb", true),
            ("%", "", true),
        ];
        for (pattern, haystack, expected) in cases {
            assert_eq!(
                Predicate::like(pattern).unwrap().evaluate(haystack),
                expected,
                "{pattern:?} LIKE {haystack:?}"
            );
        }
    }

    #[test]
    fn test_ilike_predicate_selection() {
        assert!(matches!(
            Predicate::ilike("foo", true).unwrap(),
            Predicate::IEqAscii("foo")
        ));
        assert!(matches!(
            Predicate::ilike("foo%", true).unwrap(),
            Predicate::IStartsWithAscii("foo")
        ));
        assert!(matches!(
            Predicate::ilike("%foo", true).unwrap(),
            Predicate::IEndsWithAscii("foo")
        ));
        // there is no case-insensitive substring fast path
        assert!(matches!(
            Predicate::ilike("%foo%", true).unwrap(),
            Predicate::Regex(_)
        ));
        // an escaped trailing `%` is a literal, not a prefix match
        assert!(matches!(
            Predicate::ilike(r"foo\%", true).unwrap(),
            Predicate::Regex(_)
        ));
        // the ASCII fast paths need both `is_ascii` and an ASCII pattern
        assert!(matches!(
            Predicate::ilike("foo", false).unwrap(),
            Predicate::Regex(_)
        ));
        assert!(matches!(
            Predicate::ilike("f£o", true).unwrap(),
            Predicate::Regex(_)
        ));
    }

    #[test]
    fn test_ilike_evaluate() {
        let cases = [
            ("foo", true, "FoO", true),
            ("foo", true, "fo", false),
            ("FOO%", true, "foobar", true),
            ("%BAR", true, "foobar", true),
            ("%OOB%", true, "foobar", true),
            ("foo", false, "FOO", true),
            ("f£o", true, "F£O", true),
        ];
        for (pattern, is_ascii, haystack, expected) in cases {
            assert_eq!(
                Predicate::ilike(pattern, is_ascii).unwrap().evaluate(haystack),
                expected,
                "{pattern:?} ILIKE {haystack:?} (is_ascii={is_ascii})"
            );
        }
    }

    #[test]
    fn test_evaluate_array_string_view_matches_generic_path() {
        let values = vec![
            Some("haystack"),
            None,
            Some("HAYSTACK"),
            Some(""),
            Some("hay"),
            Some("a much longer haystack value"),
            Some("stack"),
        ];
        let strings = StringArray::from(values.clone());
        let views = StringViewArray::from(values);

        let predicates = [
            Predicate::StartsWith("hay"),
            Predicate::StartsWith(""),
            Predicate::StartsWith("a much longer"),
            Predicate::IStartsWithAscii("HAY"),
            Predicate::EndsWith("stack"),
            Predicate::EndsWith("haystack value"),
            Predicate::IEndsWithAscii("STACK"),
            Predicate::IEndsWithAscii(""),
        ];
        for predicate in &predicates {
            for negate in [false, true] {
                let expected: Vec<_> = predicate.evaluate_array(&strings, negate).iter().collect();
                let actual: Vec<_> = predicate.evaluate_array(&views, negate).iter().collect();
                assert_eq!(actual, expected);
            }
        }
    }

    #[test]
    fn test_contains() {
        assert!(Predicate::contains("hay").evaluate("haystack"));
        assert!(Predicate::contains("haystack").evaluate("haystack"));
        assert!(Predicate::contains("h").evaluate("haystack"));
        assert!(Predicate::contains("k").evaluate("haystack"));
        assert!(Predicate::contains("stack").evaluate("haystack"));
        assert!(Predicate::contains("sta").evaluate("haystack"));
        assert!(Predicate::contains("stack").evaluate("hay£stack"));
        assert!(Predicate::contains("y£s").evaluate("hay£stack"));
        assert!(Predicate::contains("£").evaluate("hay£stack"));
        assert!(Predicate::contains("a").evaluate("a"));
        // not matching
        assert!(!Predicate::contains("hy").evaluate("haystack"));
        assert!(!Predicate::contains("stackx").evaluate("haystack"));
        assert!(!Predicate::contains("x").evaluate("haystack"));
        assert!(!Predicate::contains("haystack haystack").evaluate("haystack"));
    }

    #[test]
    fn test_starts_with() {
        assert!(Predicate::StartsWith("hay").evaluate("haystack"));
        assert!(Predicate::StartsWith("h£ay").evaluate("h£aystack"));
        assert!(Predicate::StartsWith("haystack").evaluate("haystack"));
        assert!(Predicate::StartsWith("ha").evaluate("haystack"));
        assert!(Predicate::StartsWith("h").evaluate("haystack"));
        assert!(Predicate::StartsWith("").evaluate("haystack"));

        assert!(!Predicate::StartsWith("stack").evaluate("haystack"));
        assert!(!Predicate::StartsWith("haystacks").evaluate("haystack"));
        assert!(!Predicate::StartsWith("HAY").evaluate("haystack"));
        assert!(!Predicate::StartsWith("h£ay").evaluate("haystack"));
        assert!(!Predicate::StartsWith("hay").evaluate("h£aystack"));
    }

    #[test]
    fn test_ends_with() {
        assert!(Predicate::EndsWith("stack").evaluate("haystack"));
        assert!(Predicate::EndsWith("st£ack").evaluate("hayst£ack"));
        assert!(Predicate::EndsWith("haystack").evaluate("haystack"));
        assert!(Predicate::EndsWith("ck").evaluate("haystack"));
        assert!(Predicate::EndsWith("k").evaluate("haystack"));
        assert!(Predicate::EndsWith("").evaluate("haystack"));

        assert!(!Predicate::EndsWith("hay").evaluate("haystack"));
        assert!(!Predicate::EndsWith("STACK").evaluate("haystack"));
        assert!(!Predicate::EndsWith("haystacks").evaluate("haystack"));
        assert!(!Predicate::EndsWith("xhaystack").evaluate("haystack"));
        assert!(!Predicate::EndsWith("st£ack").evaluate("haystack"));
        assert!(!Predicate::EndsWith("stack").evaluate("hayst£ack"));
    }

    #[test]
    fn test_istarts_with() {
        assert!(Predicate::IStartsWithAscii("hay").evaluate("haystack"));
        assert!(Predicate::IStartsWithAscii("hay").evaluate("HAYSTACK"));
        assert!(Predicate::IStartsWithAscii("HAY").evaluate("haystack"));
        assert!(Predicate::IStartsWithAscii("HaY").evaluate("haystack"));
        assert!(Predicate::IStartsWithAscii("hay").evaluate("HaYsTaCk"));
        assert!(Predicate::IStartsWithAscii("HAY").evaluate("HaYsTaCk"));
        assert!(Predicate::IStartsWithAscii("haystack").evaluate("HaYsTaCk"));
        assert!(Predicate::IStartsWithAscii("HaYsTaCk").evaluate("HaYsTaCk"));
        assert!(Predicate::IStartsWithAscii("").evaluate("HaYsTaCk"));

        assert!(!Predicate::IStartsWithAscii("stack").evaluate("haystack"));
        assert!(!Predicate::IStartsWithAscii("haystacks").evaluate("haystack"));
        assert!(!Predicate::IStartsWithAscii("h.ay").evaluate("haystack"));
        assert!(!Predicate::IStartsWithAscii("hay").evaluate("h£aystack"));
    }

    #[test]
    fn test_iends_with() {
        assert!(Predicate::IEndsWithAscii("stack").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("STACK").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("StAcK").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("stack").evaluate("HAYSTACK"));
        assert!(Predicate::IEndsWithAscii("STACK").evaluate("HAYSTACK"));
        assert!(Predicate::IEndsWithAscii("StAcK").evaluate("HAYSTACK"));
        assert!(Predicate::IEndsWithAscii("stack").evaluate("HAYsTaCk"));
        assert!(Predicate::IEndsWithAscii("STACK").evaluate("HAYsTaCk"));
        assert!(Predicate::IEndsWithAscii("StAcK").evaluate("HAYsTaCk"));
        assert!(Predicate::IEndsWithAscii("haystack").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("HAYSTACK").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("haystack").evaluate("HAYSTACK"));
        assert!(Predicate::IEndsWithAscii("ck").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("cK").evaluate("haystack"));
        assert!(Predicate::IEndsWithAscii("ck").evaluate("haystacK"));
        assert!(Predicate::IEndsWithAscii("").evaluate("haystack"));

        assert!(!Predicate::IEndsWithAscii("hay").evaluate("haystack"));
        assert!(!Predicate::IEndsWithAscii("stac").evaluate("HAYSTACK"));
        assert!(!Predicate::IEndsWithAscii("haystacks").evaluate("haystack"));
        assert!(!Predicate::IEndsWithAscii("stack").evaluate("haystac£k"));
        assert!(!Predicate::IEndsWithAscii("xhaystack").evaluate("haystack"));
    }
}