/Users/andrewlamb/Software/arrow-rs/arrow-string/src/predicate.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use arrow_array::{Array, ArrayAccessor, BooleanArray, StringViewArray}; |
19 | | use arrow_buffer::BooleanBuffer; |
20 | | use arrow_schema::ArrowError; |
21 | | use memchr::memchr3; |
22 | | use memchr::memmem::Finder; |
23 | | use regex::{Regex, RegexBuilder}; |
24 | | use std::iter::zip; |
25 | | |
26 | | /// A string based predicate |
27 | | #[allow(clippy::large_enum_variant)] |
28 | | pub(crate) enum Predicate<'a> { |
29 | | Eq(&'a str), |
30 | | Contains(Finder<'a>), |
31 | | StartsWith(&'a str), |
32 | | EndsWith(&'a str), |
33 | | |
34 | | /// Equality ignoring ASCII case |
35 | | IEqAscii(&'a str), |
36 | | /// Starts with ignoring ASCII case |
37 | | IStartsWithAscii(&'a str), |
38 | | /// Ends with ignoring ASCII case |
39 | | IEndsWithAscii(&'a str), |
40 | | |
41 | | Regex(Regex), |
42 | | } |
43 | | |
44 | | impl<'a> Predicate<'a> { |
45 | | /// Create a predicate for the given like pattern |
46 | 0 | pub(crate) fn like(pattern: &'a str) -> Result<Self, ArrowError> { |
47 | 0 | if !contains_like_pattern(pattern) { |
48 | 0 | Ok(Self::Eq(pattern)) |
49 | 0 | } else if pattern.ends_with('%') && !contains_like_pattern(&pattern[..pattern.len() - 1]) { |
50 | 0 | Ok(Self::StartsWith(&pattern[..pattern.len() - 1])) |
51 | 0 | } else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) { |
52 | 0 | Ok(Self::EndsWith(&pattern[1..])) |
53 | 0 | } else if pattern.starts_with('%') |
54 | 0 | && pattern.ends_with('%') |
55 | 0 | && !contains_like_pattern(&pattern[1..pattern.len() - 1]) |
56 | | { |
57 | 0 | Ok(Self::contains(&pattern[1..pattern.len() - 1])) |
58 | | } else { |
59 | 0 | Ok(Self::Regex(regex_like(pattern, false)?)) |
60 | | } |
61 | 0 | } |
62 | | |
63 | 0 | pub(crate) fn contains(needle: &'a str) -> Self { |
64 | 0 | Self::Contains(Finder::new(needle.as_bytes())) |
65 | 0 | } |
66 | | |
67 | | /// Create a predicate for the given ilike pattern |
68 | 0 | pub(crate) fn ilike(pattern: &'a str, is_ascii: bool) -> Result<Self, ArrowError> { |
69 | 0 | if is_ascii && pattern.is_ascii() { |
70 | 0 | if !contains_like_pattern(pattern) { |
71 | 0 | return Ok(Self::IEqAscii(pattern)); |
72 | 0 | } else if pattern.ends_with('%') |
73 | 0 | && !pattern.ends_with("\\%") |
74 | 0 | && !contains_like_pattern(&pattern[..pattern.len() - 1]) |
75 | | { |
76 | 0 | return Ok(Self::IStartsWithAscii(&pattern[..pattern.len() - 1])); |
77 | 0 | } else if pattern.starts_with('%') && !contains_like_pattern(&pattern[1..]) { |
78 | 0 | return Ok(Self::IEndsWithAscii(&pattern[1..])); |
79 | 0 | } |
80 | 0 | } |
81 | 0 | Ok(Self::Regex(regex_like(pattern, true)?)) |
82 | 0 | } |
83 | | |
84 | | /// Evaluate this predicate against the given haystack |
85 | 0 | pub(crate) fn evaluate(&self, haystack: &str) -> bool { |
86 | 0 | match self { |
87 | 0 | Predicate::Eq(v) => *v == haystack, |
88 | 0 | Predicate::IEqAscii(v) => haystack.eq_ignore_ascii_case(v), |
89 | 0 | Predicate::Contains(finder) => finder.find(haystack.as_bytes()).is_some(), |
90 | 0 | Predicate::StartsWith(v) => starts_with(haystack, v, equals_kernel), |
91 | 0 | Predicate::IStartsWithAscii(v) => { |
92 | 0 | starts_with(haystack, v, equals_ignore_ascii_case_kernel) |
93 | | } |
94 | 0 | Predicate::EndsWith(v) => ends_with(haystack, v, equals_kernel), |
95 | 0 | Predicate::IEndsWithAscii(v) => ends_with(haystack, v, equals_ignore_ascii_case_kernel), |
96 | 0 | Predicate::Regex(v) => v.is_match(haystack), |
97 | | } |
98 | 0 | } |
99 | | |
100 | | /// Evaluate this predicate against the elements of `array` |
101 | | /// |
102 | | /// If `negate` is true the result of the predicate will be negated |
103 | | #[inline(never)] |
104 | 0 | pub(crate) fn evaluate_array<'i, T>(&self, array: T, negate: bool) -> BooleanArray |
105 | 0 | where |
106 | 0 | T: ArrayAccessor<Item = &'i str>, |
107 | | { |
108 | 0 | match self { |
109 | 0 | Predicate::Eq(v) => BooleanArray::from_unary(array, |haystack| { |
110 | 0 | (haystack.len() == v.len() && haystack == *v) != negate |
111 | 0 | }), |
112 | 0 | Predicate::IEqAscii(v) => BooleanArray::from_unary(array, |haystack| { |
113 | 0 | haystack.eq_ignore_ascii_case(v) != negate |
114 | 0 | }), |
115 | 0 | Predicate::Contains(finder) => BooleanArray::from_unary(array, |haystack| { |
116 | 0 | finder.find(haystack.as_bytes()).is_some() != negate |
117 | 0 | }), |
118 | 0 | Predicate::StartsWith(v) => { |
119 | 0 | if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() { |
120 | 0 | let nulls = string_view_array.logical_nulls(); |
121 | 0 | let values = BooleanBuffer::from( |
122 | 0 | string_view_array |
123 | 0 | .prefix_bytes_iter(v.len()) |
124 | 0 | .map(|haystack| { |
125 | 0 | equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate |
126 | 0 | }) |
127 | 0 | .collect::<Vec<_>>(), |
128 | | ); |
129 | 0 | BooleanArray::new(values, nulls) |
130 | | } else { |
131 | 0 | BooleanArray::from_unary(array, |haystack| { |
132 | 0 | starts_with(haystack, v, equals_kernel) != negate |
133 | 0 | }) |
134 | | } |
135 | | } |
136 | 0 | Predicate::IStartsWithAscii(v) => { |
137 | 0 | if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() { |
138 | 0 | let nulls = string_view_array.logical_nulls(); |
139 | 0 | let values = BooleanBuffer::from( |
140 | 0 | string_view_array |
141 | 0 | .prefix_bytes_iter(v.len()) |
142 | 0 | .map(|haystack| { |
143 | 0 | equals_bytes( |
144 | 0 | haystack, |
145 | 0 | v.as_bytes(), |
146 | 0 | equals_ignore_ascii_case_kernel, |
147 | 0 | ) != negate |
148 | 0 | }) |
149 | 0 | .collect::<Vec<_>>(), |
150 | | ); |
151 | 0 | BooleanArray::new(values, nulls) |
152 | | } else { |
153 | 0 | BooleanArray::from_unary(array, |haystack| { |
154 | 0 | starts_with(haystack, v, equals_ignore_ascii_case_kernel) != negate |
155 | 0 | }) |
156 | | } |
157 | | } |
158 | 0 | Predicate::EndsWith(v) => { |
159 | 0 | if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() { |
160 | 0 | let nulls = string_view_array.logical_nulls(); |
161 | 0 | let values = BooleanBuffer::from( |
162 | 0 | string_view_array |
163 | 0 | .suffix_bytes_iter(v.len()) |
164 | 0 | .map(|haystack| { |
165 | 0 | equals_bytes(haystack, v.as_bytes(), equals_kernel) != negate |
166 | 0 | }) |
167 | 0 | .collect::<Vec<_>>(), |
168 | | ); |
169 | 0 | BooleanArray::new(values, nulls) |
170 | | } else { |
171 | 0 | BooleanArray::from_unary(array, |haystack| { |
172 | 0 | ends_with(haystack, v, equals_kernel) != negate |
173 | 0 | }) |
174 | | } |
175 | | } |
176 | 0 | Predicate::IEndsWithAscii(v) => { |
177 | 0 | if let Some(string_view_array) = array.as_any().downcast_ref::<StringViewArray>() { |
178 | 0 | let nulls = string_view_array.logical_nulls(); |
179 | 0 | let values = BooleanBuffer::from( |
180 | 0 | string_view_array |
181 | 0 | .suffix_bytes_iter(v.len()) |
182 | 0 | .map(|haystack| { |
183 | 0 | equals_bytes( |
184 | 0 | haystack, |
185 | 0 | v.as_bytes(), |
186 | 0 | equals_ignore_ascii_case_kernel, |
187 | 0 | ) != negate |
188 | 0 | }) |
189 | 0 | .collect::<Vec<_>>(), |
190 | | ); |
191 | 0 | BooleanArray::new(values, nulls) |
192 | | } else { |
193 | 0 | BooleanArray::from_unary(array, |haystack| { |
194 | 0 | ends_with(haystack, v, equals_ignore_ascii_case_kernel) != negate |
195 | 0 | }) |
196 | | } |
197 | | } |
198 | 0 | Predicate::Regex(v) => { |
199 | 0 | BooleanArray::from_unary(array, |haystack| v.is_match(haystack) != negate) |
200 | | } |
201 | | } |
202 | 0 | } |
203 | | } |
204 | | |
/// Returns `true` when `lhs` and `rhs` have the same length and every pair of
/// bytes at the same position is accepted by `byte_eq_kernel`.
fn equals_bytes(lhs: &[u8], rhs: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    if lhs.len() != rhs.len() {
        return false;
    }
    zip(lhs, rhs).all(byte_eq_kernel)
}
208 | | |
/// Returns `true` when the leading bytes of `haystack` match `needle` byte by
/// byte according to `byte_eq_kernel`.
///
/// This is faster than `str::starts_with` for small strings.
/// See <https://github.com/apache/arrow-rs/issues/6107> for more details.
fn starts_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    // `get` yields `None` exactly when the needle is longer than the haystack.
    match haystack.as_bytes().get(..needle.len()) {
        Some(head) => zip(head, needle.as_bytes()).all(byte_eq_kernel),
        None => false,
    }
}
/// Returns `true` when the trailing bytes of `haystack` match `needle` byte by
/// byte according to `byte_eq_kernel`, comparing from the last byte backwards.
///
/// This is faster than `str::ends_with` for small strings.
/// See <https://github.com/apache/arrow-rs/issues/6107> for more details.
fn ends_with(haystack: &str, needle: &str, byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
    let (hay, pat) = (haystack.as_bytes(), needle.as_bytes());
    // `checked_sub` fails exactly when the needle is longer than the haystack.
    match hay.len().checked_sub(pat.len()) {
        Some(tail_start) => {
            zip(hay[tail_start..].iter().rev(), pat.iter().rev()).all(byte_eq_kernel)
        }
        None => false,
    }
}
231 | | |
/// Byte comparison kernel: the two bytes of the pair must be identical.
fn equals_kernel(pair: (&u8, &u8)) -> bool {
    let (lhs, rhs) = pair;
    *lhs == *rhs
}
235 | | |
/// Byte comparison kernel: the two bytes of the pair must be equal once ASCII
/// letters have been folded to lower case.
fn equals_ignore_ascii_case_kernel(pair: (&u8, &u8)) -> bool {
    let (lhs, rhs) = pair;
    lhs.to_ascii_lowercase() == rhs.to_ascii_lowercase()
}
239 | | |
240 | | /// Transforms a like `pattern` to a regex compatible pattern. To achieve that, it does: |
241 | | /// |
242 | | /// 1. Replace `LIKE` multi-character wildcards `%` => `.*` (unless they're at the start or end of the pattern, |
243 | | /// where the regex is just truncated - e.g. `%foo%` => `foo` rather than `^.*foo.*$`) |
244 | | /// 2. Replace `LIKE` single-character wildcards `_` => `.` |
245 | | /// 3. Escape regex meta characters to match them and not be evaluated as regex special chars. e.g. `.` => `\\.` |
246 | | /// 4. Replace escaped `LIKE` wildcards removing the escape characters to be able to match it as a regex. e.g. `\\%` => `%` |
247 | 0 | fn regex_like(pattern: &str, case_insensitive: bool) -> Result<Regex, ArrowError> { |
248 | 0 | let mut result = String::with_capacity(pattern.len() * 2); |
249 | 0 | let mut chars_iter = pattern.chars().peekable(); |
250 | 0 | match chars_iter.peek() { |
251 | | // if the pattern starts with `%`, we avoid starting the regex with a slow but meaningless `^.*` |
252 | 0 | Some('%') => { |
253 | 0 | chars_iter.next(); |
254 | 0 | } |
255 | 0 | _ => result.push('^'), |
256 | | }; |
257 | | |
258 | 0 | while let Some(c) = chars_iter.next() { |
259 | 0 | match c { |
260 | | '\\' => { |
261 | 0 | match chars_iter.peek() { |
262 | 0 | Some(&next) => { |
263 | 0 | if regex_syntax::is_meta_character(next) { |
264 | 0 | result.push('\\'); |
265 | 0 | } |
266 | 0 | result.push(next); |
267 | | // Skipping the next char as it is already appended |
268 | 0 | chars_iter.next(); |
269 | | } |
270 | 0 | None => { |
271 | 0 | // Trailing backslash in the pattern. E.g. PostgreSQL and Trino treat it as an error, but e.g. Snowflake treats it as a literal backslash |
272 | 0 | result.push('\\'); |
273 | 0 | result.push('\\'); |
274 | 0 | } |
275 | | } |
276 | | } |
277 | 0 | '%' => result.push_str(".*"), |
278 | 0 | '_' => result.push('.'), |
279 | 0 | c => { |
280 | 0 | if regex_syntax::is_meta_character(c) { |
281 | 0 | result.push('\\'); |
282 | 0 | } |
283 | 0 | result.push(c); |
284 | | } |
285 | | } |
286 | | } |
287 | | // instead of ending the regex with `.*$` and making it needlessly slow, we just end the regex |
288 | 0 | if result.ends_with(".*") { |
289 | 0 | result.pop(); |
290 | 0 | result.pop(); |
291 | 0 | } else { |
292 | 0 | result.push('$'); |
293 | 0 | } |
294 | 0 | RegexBuilder::new(&result) |
295 | 0 | .case_insensitive(case_insensitive) |
296 | 0 | .dot_matches_new_line(true) |
297 | 0 | .build() |
298 | 0 | .map_err(|e| { |
299 | 0 | ArrowError::InvalidArgumentError(format!( |
300 | 0 | "Unable to build regex from LIKE pattern: {e}" |
301 | 0 | )) |
302 | 0 | }) |
303 | 0 | } |
304 | | |
305 | 0 | fn contains_like_pattern(pattern: &str) -> bool { |
306 | 0 | memchr3(b'%', b'_', b'\\', pattern.as_bytes()).is_some() |
307 | 0 | } |
308 | | |
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a predicate from each needle with `build` and asserts that evaluating it
    /// against the paired haystack yields `expected`.
    fn assert_cases<'a>(
        build: impl Fn(&'a str) -> Predicate<'a>,
        expected: bool,
        cases: &[(&'a str, &str)],
    ) {
        for &(needle, haystack) in cases {
            assert_eq!(
                build(needle).evaluate(haystack),
                expected,
                "needle {needle:?} against haystack {haystack:?}"
            );
        }
    }

    #[test]
    fn test_regex_like() {
        // (LIKE pattern, regex it must be translated to)
        let translations = [
            // %..%
            (r"%foobar%", r"foobar"),
            // ..%..
            (r"foo%bar", r"^foo.*bar$"),
            // .._..
            (r"foo_bar", r"^foo.bar$"),
            // escaped wildcards
            (r"\%\_", r"^%_$"),
            // escaped non-wildcard
            (r"\a", r"^a$"),
            // escaped escape and wildcard
            (r"\\%", r"^\\"),
            // escaped escape and non-wildcard
            (r"\\a", r"^\\a$"),
            // regex meta character
            (r".", r"^\.$"),
            (r"$", r"^\$$"),
            (r"\\", r"^\\$"),
        ];

        for (like_pattern, expected_regexp) in translations {
            let regex = regex_like(like_pattern, false).unwrap();
            assert_eq!(regex.to_string(), expected_regexp);
        }
    }

    #[test]
    fn test_contains() {
        assert_cases(
            Predicate::contains,
            true,
            &[
                ("hay", "haystack"),
                ("haystack", "haystack"),
                ("h", "haystack"),
                ("k", "haystack"),
                ("stack", "haystack"),
                ("sta", "haystack"),
                ("stack", "hay£stack"),
                ("y£s", "hay£stack"),
                ("£", "hay£stack"),
                ("a", "a"),
            ],
        );
        // not matching
        assert_cases(
            Predicate::contains,
            false,
            &[
                ("hy", "haystack"),
                ("stackx", "haystack"),
                ("x", "haystack"),
                ("haystack haystack", "haystack"),
            ],
        );
    }

    #[test]
    fn test_starts_with() {
        assert_cases(
            Predicate::StartsWith,
            true,
            &[
                ("hay", "haystack"),
                ("h£ay", "h£aystack"),
                ("haystack", "haystack"),
                ("ha", "haystack"),
                ("h", "haystack"),
                ("", "haystack"),
            ],
        );
        assert_cases(
            Predicate::StartsWith,
            false,
            &[
                ("stack", "haystack"),
                ("haystacks", "haystack"),
                ("HAY", "haystack"),
                ("h£ay", "haystack"),
                ("hay", "h£aystack"),
            ],
        );
    }

    #[test]
    fn test_ends_with() {
        assert_cases(
            Predicate::EndsWith,
            true,
            &[
                ("stack", "haystack"),
                ("st£ack", "hayst£ack"),
                ("haystack", "haystack"),
                ("ck", "haystack"),
                ("k", "haystack"),
                ("", "haystack"),
            ],
        );
        assert_cases(
            Predicate::EndsWith,
            false,
            &[
                ("hay", "haystack"),
                ("STACK", "haystack"),
                ("haystacks", "haystack"),
                ("xhaystack", "haystack"),
                ("st£ack", "haystack"),
                ("stack", "hayst£ack"),
            ],
        );
    }

    #[test]
    fn test_istarts_with() {
        assert_cases(
            Predicate::IStartsWithAscii,
            true,
            &[
                ("hay", "haystack"),
                ("hay", "HAYSTACK"),
                ("HAY", "haystack"),
                ("HaY", "haystack"),
                ("hay", "HaYsTaCk"),
                ("HAY", "HaYsTaCk"),
                ("haystack", "HaYsTaCk"),
                ("HaYsTaCk", "HaYsTaCk"),
                ("", "HaYsTaCk"),
            ],
        );
        assert_cases(
            Predicate::IStartsWithAscii,
            false,
            &[
                ("stack", "haystack"),
                ("haystacks", "haystack"),
                ("h.ay", "haystack"),
                ("hay", "h£aystack"),
            ],
        );
    }

    #[test]
    fn test_iends_with() {
        assert_cases(
            Predicate::IEndsWithAscii,
            true,
            &[
                ("stack", "haystack"),
                ("STACK", "haystack"),
                ("StAcK", "haystack"),
                ("stack", "HAYSTACK"),
                ("STACK", "HAYSTACK"),
                ("StAcK", "HAYSTACK"),
                ("stack", "HAYsTaCk"),
                ("STACK", "HAYsTaCk"),
                ("StAcK", "HAYsTaCk"),
                ("haystack", "haystack"),
                ("HAYSTACK", "haystack"),
                ("haystack", "HAYSTACK"),
                ("ck", "haystack"),
                ("cK", "haystack"),
                ("ck", "haystacK"),
                ("", "haystack"),
            ],
        );
        assert_cases(
            Predicate::IEndsWithAscii,
            false,
            &[
                ("hay", "haystack"),
                ("stac", "HAYSTACK"),
                ("haystacks", "haystack"),
                ("stack", "haystac£k"),
                ("xhaystack", "haystack"),
            ],
        );
    }
}