/Users/andrewlamb/Software/arrow-rs/arrow-string/src/binary_like.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Provide SQL's CONTAINS, STARTS_WITH, ENDS_WITH operators for Arrow's binary arrays |
19 | | |
20 | | use crate::binary_predicate::BinaryPredicate; |
21 | | |
22 | | use arrow_array::cast::AsArray; |
23 | | use arrow_array::*; |
24 | | use arrow_schema::*; |
25 | | use arrow_select::take::take; |
26 | | |
27 | | #[derive(Debug)] |
28 | | pub(crate) enum Op { |
29 | | Contains, |
30 | | StartsWith, |
31 | | EndsWith, |
32 | | } |
33 | | |
34 | | impl TryFrom<crate::like::Op> for Op { |
35 | | type Error = ArrowError; |
36 | | |
37 | 0 | fn try_from(value: crate::like::Op) -> Result<Self, Self::Error> { |
38 | 0 | match value { |
39 | 0 | crate::like::Op::Contains => Ok(Op::Contains), |
40 | 0 | crate::like::Op::StartsWith => Ok(Op::StartsWith), |
41 | 0 | crate::like::Op::EndsWith => Ok(Op::EndsWith), |
42 | 0 | _ => Err(ArrowError::InvalidArgumentError(format!( |
43 | 0 | "Invalid binary operation: {value}" |
44 | 0 | ))), |
45 | | } |
46 | 0 | } |
47 | | } |
48 | | |
49 | | impl std::fmt::Display for Op { |
50 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
51 | 0 | match self { |
52 | 0 | Op::Contains => write!(f, "CONTAINS"), |
53 | 0 | Op::StartsWith => write!(f, "STARTS_WITH"), |
54 | 0 | Op::EndsWith => write!(f, "ENDS_WITH"), |
55 | | } |
56 | 0 | } |
57 | | } |
58 | | |
59 | 0 | pub(crate) fn binary_apply<'a, 'i, T: BinaryArrayType<'a> + 'a>( |
60 | 0 | op: Op, |
61 | 0 | l: T, |
62 | 0 | l_s: bool, |
63 | 0 | l_v: Option<&'a dyn AnyDictionaryArray>, |
64 | 0 | r: T, |
65 | 0 | r_s: bool, |
66 | 0 | r_v: Option<&'a dyn AnyDictionaryArray>, |
67 | 0 | ) -> Result<BooleanArray, ArrowError> { |
68 | 0 | let l_len = l_v.map(|l| l.len()).unwrap_or(l.len()); |
69 | 0 | if r_s { |
70 | 0 | let idx = match r_v { |
71 | 0 | Some(dict) if dict.null_count() != 0 => return Ok(BooleanArray::new_null(l_len)), |
72 | 0 | Some(dict) => dict.normalized_keys()[0], |
73 | 0 | None => 0, |
74 | | }; |
75 | 0 | if r.is_null(idx) { |
76 | 0 | return Ok(BooleanArray::new_null(l_len)); |
77 | 0 | } |
78 | 0 | op_scalar::<T>(op, l, l_v, r.value(idx)) |
79 | | } else { |
80 | 0 | match (l_s, l_v, r_v) { |
81 | | (true, None, None) => { |
82 | 0 | let v = l.is_valid(0).then(|| l.value(0)); |
83 | 0 | op_binary(op, std::iter::repeat(v), r.iter()) |
84 | | } |
85 | 0 | (true, Some(l_v), None) => { |
86 | 0 | let idx = l_v.is_valid(0).then(|| l_v.normalized_keys()[0]); |
87 | 0 | let v = idx.and_then(|idx| l.is_valid(idx).then(|| l.value(idx))); |
88 | 0 | op_binary(op, std::iter::repeat(v), r.iter()) |
89 | | } |
90 | 0 | (true, None, Some(r_v)) => { |
91 | 0 | let v = l.is_valid(0).then(|| l.value(0)); |
92 | 0 | op_binary(op, std::iter::repeat(v), vectored_iter(r, r_v)) |
93 | | } |
94 | 0 | (true, Some(l_v), Some(r_v)) => { |
95 | 0 | let idx = l_v.is_valid(0).then(|| l_v.normalized_keys()[0]); |
96 | 0 | let v = idx.and_then(|idx| l.is_valid(idx).then(|| l.value(idx))); |
97 | 0 | op_binary(op, std::iter::repeat(v), vectored_iter(r, r_v)) |
98 | | } |
99 | 0 | (false, None, None) => op_binary(op, l.iter(), r.iter()), |
100 | 0 | (false, Some(l_v), None) => op_binary(op, vectored_iter(l, l_v), r.iter()), |
101 | 0 | (false, None, Some(r_v)) => op_binary(op, l.iter(), vectored_iter(r, r_v)), |
102 | 0 | (false, Some(l_v), Some(r_v)) => { |
103 | 0 | op_binary(op, vectored_iter(l, l_v), vectored_iter(r, r_v)) |
104 | | } |
105 | | } |
106 | | } |
107 | 0 | } |
108 | | |
109 | | #[inline(never)] |
110 | 0 | fn op_scalar<'a, T: BinaryArrayType<'a>>( |
111 | 0 | op: Op, |
112 | 0 | l: T, |
113 | 0 | l_v: Option<&dyn AnyDictionaryArray>, |
114 | 0 | r: &[u8], |
115 | 0 | ) -> Result<BooleanArray, ArrowError> { |
116 | 0 | let r = match op { |
117 | 0 | Op::Contains => BinaryPredicate::contains(r).evaluate_array(l, false), |
118 | 0 | Op::StartsWith => BinaryPredicate::StartsWith(r).evaluate_array(l, false), |
119 | 0 | Op::EndsWith => BinaryPredicate::EndsWith(r).evaluate_array(l, false), |
120 | | }; |
121 | | |
122 | 0 | Ok(match l_v { |
123 | 0 | Some(v) => take(&r, v.keys(), None)?.as_boolean().clone(), |
124 | 0 | None => r, |
125 | | }) |
126 | 0 | } |
127 | | |
128 | 0 | fn vectored_iter<'a, T: BinaryArrayType<'a> + 'a>( |
129 | 0 | a: T, |
130 | 0 | a_v: &'a dyn AnyDictionaryArray, |
131 | 0 | ) -> impl Iterator<Item = Option<&'a [u8]>> + 'a { |
132 | 0 | let nulls = a_v.nulls(); |
133 | 0 | let keys = a_v.normalized_keys(); |
134 | 0 | keys.into_iter().enumerate().map(move |(idx, key)| { |
135 | 0 | if nulls.map(|n| n.is_null(idx)).unwrap_or_default() || a.is_null(key) { |
136 | 0 | return None; |
137 | 0 | } |
138 | 0 | Some(a.value(key)) |
139 | 0 | }) |
140 | 0 | } |
141 | | |
142 | | #[inline(never)] |
143 | 0 | fn op_binary<'a>( |
144 | 0 | op: Op, |
145 | 0 | l: impl Iterator<Item = Option<&'a [u8]>>, |
146 | 0 | r: impl Iterator<Item = Option<&'a [u8]>>, |
147 | 0 | ) -> Result<BooleanArray, ArrowError> { |
148 | 0 | match op { |
149 | 0 | Op::Contains => Ok(l |
150 | 0 | .zip(r) |
151 | 0 | .map(|(l, r)| Some(bytes_contains(l?, r?))) |
152 | 0 | .collect()), |
153 | 0 | Op::StartsWith => Ok(l |
154 | 0 | .zip(r) |
155 | 0 | .map(|(l, r)| Some(BinaryPredicate::StartsWith(r?).evaluate(l?))) |
156 | 0 | .collect()), |
157 | 0 | Op::EndsWith => Ok(l |
158 | 0 | .zip(r) |
159 | 0 | .map(|(l, r)| Some(BinaryPredicate::EndsWith(r?).evaluate(l?))) |
160 | 0 | .collect()), |
161 | | } |
162 | 0 | } |
163 | | |
164 | 0 | fn bytes_contains(haystack: &[u8], needle: &[u8]) -> bool { |
165 | 0 | memchr::memmem::find(haystack, needle).is_some() |
166 | 0 | } |