/Users/andrewlamb/Software/arrow-rs/arrow-ord/src/cmp.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Comparison kernels for `Array`s. |
19 | | //! |
20 | | //! These kernels can leverage SIMD if available on your system. Currently no runtime |
21 | | //! detection is provided, you should enable the specific SIMD intrinsics using |
22 | | //! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation |
23 | | //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. |
24 | | //! |
25 | | |
26 | | use arrow_array::cast::AsArray; |
27 | | use arrow_array::types::{ByteArrayType, ByteViewType}; |
28 | | use arrow_array::{ |
29 | | downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, Datum, |
30 | | FixedSizeBinaryArray, GenericByteArray, GenericByteViewArray, |
31 | | }; |
32 | | use arrow_buffer::bit_util::ceil; |
33 | | use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; |
34 | | use arrow_schema::ArrowError; |
35 | | use arrow_select::take::take; |
36 | | use std::cmp::Ordering; |
37 | | use std::ops::Not; |
38 | | |
/// The comparison evaluated by the kernels in this module.
#[derive(Debug, Copy, Clone)]
enum Op {
    /// `l == r`
    Equal,
    /// `l != r`
    NotEqual,
    /// `l < r`
    Less,
    /// `l <= r`
    LessEqual,
    /// `l > r`
    Greater,
    /// `l >= r`
    GreaterEqual,
    /// `l IS DISTINCT FROM r`
    Distinct,
    /// `l IS NOT DISTINCT FROM r`
    NotDistinct,
}

/// Renders the operator symbol, as embedded in the kernels' error messages.
impl std::fmt::Display for Op {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Op::Equal => "==",
            Op::NotEqual => "!=",
            Op::Less => "<",
            Op::LessEqual => "<=",
            Op::Greater => ">",
            Op::GreaterEqual => ">=",
            Op::Distinct => "IS DISTINCT FROM",
            Op::NotDistinct => "IS NOT DISTINCT FROM",
        };
        f.write_str(symbol)
    }
}
65 | | |
66 | | /// Perform `left == right` operation on two [`Datum`]. |
67 | | /// |
68 | | /// Comparing null values on either side will yield a null in the corresponding |
69 | | /// slot of the resulting [`BooleanArray`]. |
70 | | /// |
71 | | /// For floating values like f32 and f64, this comparison produces an ordering in accordance to |
72 | | /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. |
73 | | /// Note that totalOrder treats positive and negative zeros as different. If it is necessary |
74 | | /// to treat them as equal, please normalize zeros before calling this kernel. See |
75 | | /// [`f32::total_cmp`] and [`f64::total_cmp`]. |
76 | | /// |
77 | | /// Nested types, such as lists, are not supported as the null semantics are not well-defined. |
78 | | /// For comparisons involving nested types see [`crate::ord::make_comparator`] |
79 | 0 | pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> { |
80 | 0 | compare_op(Op::Equal, lhs, rhs) |
81 | 0 | } |
82 | | |
83 | | /// Perform `left != right` operation on two [`Datum`]. |
84 | | /// |
85 | | /// Comparing null values on either side will yield a null in the corresponding |
86 | | /// slot of the resulting [`BooleanArray`]. |
87 | | /// |
88 | | /// For floating values like f32 and f64, this comparison produces an ordering in accordance to |
89 | | /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. |
90 | | /// Note that totalOrder treats positive and negative zeros as different. If it is necessary |
91 | | /// to treat them as equal, please normalize zeros before calling this kernel. See |
92 | | /// [`f32::total_cmp`] and [`f64::total_cmp`]. |
93 | | /// |
94 | | /// Nested types, such as lists, are not supported as the null semantics are not well-defined. |
95 | | /// For comparisons involving nested types see [`crate::ord::make_comparator`] |
96 | 0 | pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> { |
97 | 0 | compare_op(Op::NotEqual, lhs, rhs) |
98 | 0 | } |
99 | | |
100 | | /// Perform `left < right` operation on two [`Datum`]. |
101 | | /// |
102 | | /// Comparing null values on either side will yield a null in the corresponding |
103 | | /// slot of the resulting [`BooleanArray`]. |
104 | | /// |
105 | | /// For floating values like f32 and f64, this comparison produces an ordering in accordance to |
106 | | /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. |
107 | | /// Note that totalOrder treats positive and negative zeros as different. If it is necessary |
108 | | /// to treat them as equal, please normalize zeros before calling this kernel. See |
109 | | /// [`f32::total_cmp`] and [`f64::total_cmp`]. |
110 | | /// |
111 | | /// Nested types, such as lists, are not supported as the null semantics are not well-defined. |
112 | | /// For comparisons involving nested types see [`crate::ord::make_comparator`] |
113 | 0 | pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> { |
114 | 0 | compare_op(Op::Less, lhs, rhs) |
115 | 0 | } |
116 | | |
117 | | /// Perform `left <= right` operation on two [`Datum`]. |
118 | | /// |
119 | | /// Comparing null values on either side will yield a null in the corresponding |
120 | | /// slot of the resulting [`BooleanArray`]. |
121 | | /// |
122 | | /// For floating values like f32 and f64, this comparison produces an ordering in accordance to |
123 | | /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. |
124 | | /// Note that totalOrder treats positive and negative zeros as different. If it is necessary |
125 | | /// to treat them as equal, please normalize zeros before calling this kernel. See |
126 | | /// [`f32::total_cmp`] and [`f64::total_cmp`]. |
127 | | /// |
128 | | /// Nested types, such as lists, are not supported as the null semantics are not well-defined. |
129 | | /// For comparisons involving nested types see [`crate::ord::make_comparator`] |
130 | 0 | pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> { |
131 | 0 | compare_op(Op::LessEqual, lhs, rhs) |
132 | 0 | } |
133 | | |
134 | | /// Perform `left > right` operation on two [`Datum`]. |
135 | | /// |
136 | | /// Comparing null values on either side will yield a null in the corresponding |
137 | | /// slot of the resulting [`BooleanArray`]. |
138 | | /// |
139 | | /// For floating values like f32 and f64, this comparison produces an ordering in accordance to |
140 | | /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. |
141 | | /// Note that totalOrder treats positive and negative zeros as different. If it is necessary |
142 | | /// to treat them as equal, please normalize zeros before calling this kernel. See |
143 | | /// [`f32::total_cmp`] and [`f64::total_cmp`]. |
144 | | /// |
145 | | /// Nested types, such as lists, are not supported as the null semantics are not well-defined. |
146 | | /// For comparisons involving nested types see [`crate::ord::make_comparator`] |
147 | 0 | pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> { |
148 | 0 | compare_op(Op::Greater, lhs, rhs) |
149 | 0 | } |
150 | | |
151 | | /// Perform `left >= right` operation on two [`Datum`]. |
152 | | /// |
153 | | /// Comparing null values on either side will yield a null in the corresponding |
154 | | /// slot of the resulting [`BooleanArray`]. |
155 | | /// |
156 | | /// For floating values like f32 and f64, this comparison produces an ordering in accordance to |
157 | | /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. |
158 | | /// Note that totalOrder treats positive and negative zeros as different. If it is necessary |
159 | | /// to treat them as equal, please normalize zeros before calling this kernel. See |
160 | | /// [`f32::total_cmp`] and [`f64::total_cmp`]. |
161 | | /// |
162 | | /// Nested types, such as lists, are not supported as the null semantics are not well-defined. |
163 | | /// For comparisons involving nested types see [`crate::ord::make_comparator`] |
164 | 0 | pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> { |
165 | 0 | compare_op(Op::GreaterEqual, lhs, rhs) |
166 | 0 | } |
167 | | |
168 | | /// Perform `left IS DISTINCT FROM right` operation on two [`Datum`] |
169 | | /// |
170 | | /// [`distinct`] is similar to [`neq`], only differing in null handling. In particular, two |
171 | | /// operands are considered DISTINCT if they have a different value or if one of them is NULL |
172 | | /// and the other isn't. The result of [`distinct`] is never NULL. |
173 | | /// |
174 | | /// For floating values like f32 and f64, this comparison produces an ordering in accordance to |
175 | | /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. |
176 | | /// Note that totalOrder treats positive and negative zeros as different. If it is necessary |
177 | | /// to treat them as equal, please normalize zeros before calling this kernel. See |
178 | | /// [`f32::total_cmp`] and [`f64::total_cmp`]. |
179 | | /// |
180 | | /// Nested types, such as lists, are not supported as the null semantics are not well-defined. |
181 | | /// For comparisons involving nested types see [`crate::ord::make_comparator`] |
182 | 0 | pub fn distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> { |
183 | 0 | compare_op(Op::Distinct, lhs, rhs) |
184 | 0 | } |
185 | | |
186 | | /// Perform `left IS NOT DISTINCT FROM right` operation on two [`Datum`] |
187 | | /// |
188 | | /// [`not_distinct`] is similar to [`eq`], only differing in null handling. In particular, two |
189 | | /// operands are considered `NOT DISTINCT` if they have the same value or if both of them |
190 | | /// is NULL. The result of [`not_distinct`] is never NULL. |
191 | | /// |
192 | | /// For floating values like f32 and f64, this comparison produces an ordering in accordance to |
193 | | /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. |
194 | | /// Note that totalOrder treats positive and negative zeros as different. If it is necessary |
195 | | /// to treat them as equal, please normalize zeros before calling this kernel. See |
196 | | /// [`f32::total_cmp`] and [`f64::total_cmp`]. |
197 | | /// |
198 | | /// Nested types, such as lists, are not supported as the null semantics are not well-defined. |
199 | | /// For comparisons involving nested types see [`crate::ord::make_comparator`] |
200 | 0 | pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> { |
201 | 0 | compare_op(Op::NotDistinct, lhs, rhs) |
202 | 0 | } |
203 | | |
/// Perform `op` on the provided `Datum`
///
/// Shared implementation behind every comparison kernel in this module. Either side may be a
/// scalar (the flag returned by `Datum::get`) and/or dictionary encoded. For a dictionary
/// operand, the values are compared and the results are mapped through the keys inside `apply`.
///
/// The result has the length of `rhs` when `lhs` is a scalar, otherwise the length of `lhs`.
/// Nulls are resolved here from the operands' logical null buffers, separately from the value
/// comparison. Ordinary operators propagate nulls from either side. `Op::Distinct` and
/// `Op::NotDistinct` treat null as a comparable value and never produce nulls.
///
/// Returns `ArrowError::InvalidArgumentError` in three cases: both sides are arrays of
/// different lengths, either (dictionary-unwrapped) data type is nested, or the two
/// (dictionary-unwrapped) data types differ.
#[inline(never)]
fn compare_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, ArrowError> {
    use arrow_schema::DataType::*;
    let (l, l_s) = lhs.get();
    let (r, r_s) = rhs.get();

    let l_len = l.len();
    let r_len = r.len();

    // A scalar is broadcast against the other side, so lengths only need to match for two arrays
    if l_len != r_len && !l_s && !r_s {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Cannot compare arrays of different lengths, got {l_len} vs {r_len}"
        )));
    }

    // Output length: the non-scalar side wins when `lhs` is a scalar
    let len = match l_s {
        true => r_len,
        false => l_len,
    };

    // Logical nulls are taken before dictionary unwrapping, so they describe the outer
    // (possibly dictionary-encoded) operands rather than the dictionary values
    let l_nulls = l.logical_nulls();
    let r_nulls = r.logical_nulls();

    // From here on `l` / `r` refer to the dictionary values when the operand is a dictionary;
    // `l_v` / `r_v` keep the dictionary so `apply` can map results back through its keys
    let l_v = l.as_any_dictionary_opt();
    let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
    let l_t = l.data_type();

    let r_v = r.as_any_dictionary_opt();
    let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
    let r_t = r.data_type();

    if r_t.is_nested() || l_t.is_nested() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Nested comparison: {l_t} {op} {r_t} (hint: use make_comparator instead)"
        )));
    } else if l_t != r_t {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Invalid comparison operation: {l_t} {op} {r_t}"
        )));
    }

    // Defer computation as may not be necessary
    // (e.g. a null scalar compared with `==` yields an all-null result without reading values).
    // `apply` returns `None` for empty operands and the `Null` type; those become an
    // all-unset buffer of the output length.
    let values = || -> BooleanBuffer {
        let d = downcast_primitive_array! {
            (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
            (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
            (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
            (Utf8View, Utf8View) => apply(op, l.as_string_view(), l_s, l_v, r.as_string_view(), r_s, r_v),
            (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
            (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
            (BinaryView, BinaryView) => apply(op, l.as_binary_view(), l_s, l_v, r.as_binary_view(), r_s, r_v),
            (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
            (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
            (Null, Null) => None,
            _ => unreachable!(),
        };
        d.unwrap_or_else(|| BooleanBuffer::new_unset(len))
    };

    // Treat a null buffer without any nulls as absent, so the cheaper arms below apply
    let l_nulls = l_nulls.filter(|n| n.null_count() > 0);
    let r_nulls = r_nulls.filter(|n| n.null_count() > 0);
    Ok(match (l_nulls, l_s, r_nulls, r_s) {
        (Some(l), true, Some(r), true) | (Some(l), false, Some(r), false) => {
            // Either both sides are scalar or neither side is scalar
            // Here `l` / `r` are validity bitmaps (1 = valid), processed 64 rows at a time
            match op {
                Op::Distinct => {
                    let values = values();
                    let l = l.inner().bit_chunks().iter_padded();
                    let r = r.inner().bit_chunks().iter_padded();
                    let ne = values.bit_chunks().iter_padded();

                    // Distinct when exactly one side is null, or both are valid and the values differ
                    let c = |((l, r), n)| (l ^ r) | (l & r & n);
                    let buffer = l.zip(r).zip(ne).map(c).collect();
                    BooleanBuffer::new(buffer, 0, len).into()
                }
                Op::NotDistinct => {
                    let values = values();
                    let l = l.inner().bit_chunks().iter_padded();
                    let r = r.inner().bit_chunks().iter_padded();
                    let e = values.bit_chunks().iter_padded();

                    // Not distinct when both sides are null, or both are valid and the values match
                    let c = |((l, r), e)| u64::not(l | r) | (l & r & e);
                    let buffer = l.zip(r).zip(e).map(c).collect();
                    BooleanBuffer::new(buffer, 0, len).into()
                }
                // Ordinary operators: a row is null if it is null on either side
                _ => BooleanArray::new(values(), NullBuffer::union(Some(&l), Some(&r))),
            }
        }
        (Some(_), true, Some(a), false) | (Some(a), false, Some(_), true) => {
            // Scalar is null, other side is non-scalar and nullable
            // `a` is the array side's validity: valid rows are distinct from the null scalar,
            // null rows are not
            match op {
                Op::Distinct => a.into_inner().into(),
                Op::NotDistinct => a.into_inner().not().into(),
                _ => BooleanArray::new_null(len),
            }
        }
        (Some(nulls), is_scalar, None, _) | (None, _, Some(nulls), is_scalar) => {
            // Only one side is nullable
            // `is_scalar` describes the side that owns `nulls`
            match is_scalar {
                true => match op {
                    // Scalar is null, other side is not nullable
                    Op::Distinct => BooleanBuffer::new_set(len).into(),
                    Op::NotDistinct => BooleanBuffer::new_unset(len).into(),
                    _ => BooleanArray::new_null(len),
                },
                false => match op {
                    Op::Distinct => {
                        let values = values();
                        let l = nulls.inner().bit_chunks().iter_padded();
                        let ne = values.bit_chunks().iter_padded();
                        // A null row is always distinct from the non-null other side;
                        // otherwise defer to the value comparison
                        let c = |(l, n)| u64::not(l) | n;
                        let buffer = l.zip(ne).map(c).collect();
                        BooleanBuffer::new(buffer, 0, len).into()
                    }
                    // Not distinct only where the row is valid and the values are equal
                    Op::NotDistinct => (nulls.inner() & &values()).into(),
                    _ => BooleanArray::new(values(), Some(nulls)),
                },
            }
        }
        // Neither side is nullable
        (None, _, None, _) => BooleanArray::new(values(), None),
    })
}
328 | | |
329 | | /// Perform a potentially vectored `op` on the provided `ArrayOrd` |
330 | 0 | fn apply<T: ArrayOrd>( |
331 | 0 | op: Op, |
332 | 0 | l: T, |
333 | 0 | l_s: bool, |
334 | 0 | l_v: Option<&dyn AnyDictionaryArray>, |
335 | 0 | r: T, |
336 | 0 | r_s: bool, |
337 | 0 | r_v: Option<&dyn AnyDictionaryArray>, |
338 | 0 | ) -> Option<BooleanBuffer> { |
339 | 0 | if l.len() == 0 || r.len() == 0 { |
340 | 0 | return None; // Handle empty dictionaries |
341 | 0 | } |
342 | | |
343 | 0 | if !l_s && !r_s && (l_v.is_some() || r_v.is_some()) { |
344 | | // Not scalar and at least one side has a dictionary, need to perform vectored comparison |
345 | 0 | let l_v = l_v |
346 | 0 | .map(|x| x.normalized_keys()) |
347 | 0 | .unwrap_or_else(|| (0..l.len()).collect()); |
348 | | |
349 | 0 | let r_v = r_v |
350 | 0 | .map(|x| x.normalized_keys()) |
351 | 0 | .unwrap_or_else(|| (0..r.len()).collect()); |
352 | | |
353 | 0 | assert_eq!(l_v.len(), r_v.len()); // Sanity check |
354 | | |
355 | 0 | Some(match op { |
356 | 0 | Op::Equal | Op::NotDistinct => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_eq), |
357 | 0 | Op::NotEqual | Op::Distinct => apply_op_vectored(l, &l_v, r, &r_v, true, T::is_eq), |
358 | 0 | Op::Less => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_lt), |
359 | 0 | Op::LessEqual => apply_op_vectored(r, &r_v, l, &l_v, true, T::is_lt), |
360 | 0 | Op::Greater => apply_op_vectored(r, &r_v, l, &l_v, false, T::is_lt), |
361 | 0 | Op::GreaterEqual => apply_op_vectored(l, &l_v, r, &r_v, true, T::is_lt), |
362 | | }) |
363 | | } else { |
364 | 0 | let l_s = l_s.then(|| l_v.map(|x| x.normalized_keys()[0]).unwrap_or_default()); |
365 | 0 | let r_s = r_s.then(|| r_v.map(|x| x.normalized_keys()[0]).unwrap_or_default()); |
366 | | |
367 | 0 | let buffer = match op { |
368 | 0 | Op::Equal | Op::NotDistinct => apply_op(l, l_s, r, r_s, false, T::is_eq), |
369 | 0 | Op::NotEqual | Op::Distinct => apply_op(l, l_s, r, r_s, true, T::is_eq), |
370 | 0 | Op::Less => apply_op(l, l_s, r, r_s, false, T::is_lt), |
371 | 0 | Op::LessEqual => apply_op(r, r_s, l, l_s, true, T::is_lt), |
372 | 0 | Op::Greater => apply_op(r, r_s, l, l_s, false, T::is_lt), |
373 | 0 | Op::GreaterEqual => apply_op(l, l_s, r, r_s, true, T::is_lt), |
374 | | }; |
375 | | |
376 | | // If a side had a dictionary, and was not scalar, we need to materialize this |
377 | 0 | Some(match (l_v, r_v) { |
378 | 0 | (Some(l_v), _) if l_s.is_none() => take_bits(l_v, buffer), |
379 | 0 | (_, Some(r_v)) if r_s.is_none() => take_bits(r_v, buffer), |
380 | 0 | _ => buffer, |
381 | | }) |
382 | | } |
383 | 0 | } |
384 | | |
385 | | /// Perform a take operation on `buffer` with the given dictionary |
386 | 0 | fn take_bits(v: &dyn AnyDictionaryArray, buffer: BooleanBuffer) -> BooleanBuffer { |
387 | 0 | let array = take(&BooleanArray::new(buffer, None), v.keys(), None).unwrap(); |
388 | 0 | array.as_boolean().values().clone() |
389 | 0 | } |
390 | | |
391 | | /// Invokes `f` with values `0..len` collecting the boolean results into a new `BooleanBuffer` |
392 | | /// |
393 | | /// This is similar to [`MutableBuffer::collect_bool`] but with |
394 | | /// the option to efficiently negate the result |
395 | 0 | fn collect_bool(len: usize, neg: bool, f: impl Fn(usize) -> bool) -> BooleanBuffer { |
396 | 0 | let mut buffer = MutableBuffer::new(ceil(len, 64) * 8); |
397 | | |
398 | 0 | let chunks = len / 64; |
399 | 0 | let remainder = len % 64; |
400 | 0 | for chunk in 0..chunks { |
401 | 0 | let mut packed = 0; |
402 | 0 | for bit_idx in 0..64 { |
403 | 0 | let i = bit_idx + chunk * 64; |
404 | 0 | packed |= (f(i) as u64) << bit_idx; |
405 | 0 | } |
406 | 0 | if neg { |
407 | 0 | packed = !packed |
408 | 0 | } |
409 | | |
410 | | // SAFETY: Already allocated sufficient capacity |
411 | 0 | unsafe { buffer.push_unchecked(packed) } |
412 | | } |
413 | | |
414 | 0 | if remainder != 0 { |
415 | 0 | let mut packed = 0; |
416 | 0 | for bit_idx in 0..remainder { |
417 | 0 | let i = bit_idx + chunks * 64; |
418 | 0 | packed |= (f(i) as u64) << bit_idx; |
419 | 0 | } |
420 | 0 | if neg { |
421 | 0 | packed = !packed |
422 | 0 | } |
423 | | |
424 | | // SAFETY: Already allocated sufficient capacity |
425 | 0 | unsafe { buffer.push_unchecked(packed) } |
426 | 0 | } |
427 | 0 | BooleanBuffer::new(buffer.into(), 0, len) |
428 | 0 | } |
429 | | |
430 | | /// Applies `op` to possibly scalar `ArrayOrd` |
431 | | /// |
432 | | /// If l is scalar `l_s` will be `Some(idx)` where `idx` is the index of the scalar value in `l` |
433 | | /// If r is scalar `r_s` will be `Some(idx)` where `idx` is the index of the scalar value in `r` |
434 | | /// |
435 | | /// If `neg` is true the result of `op` will be negated |
436 | 0 | fn apply_op<T: ArrayOrd>( |
437 | 0 | l: T, |
438 | 0 | l_s: Option<usize>, |
439 | 0 | r: T, |
440 | 0 | r_s: Option<usize>, |
441 | 0 | neg: bool, |
442 | 0 | op: impl Fn(T::Item, T::Item) -> bool, |
443 | 0 | ) -> BooleanBuffer { |
444 | 0 | match (l_s, r_s) { |
445 | | (None, None) => { |
446 | 0 | assert_eq!(l.len(), r.len()); |
447 | 0 | collect_bool(l.len(), neg, |idx| unsafe { |
448 | 0 | op(l.value_unchecked(idx), r.value_unchecked(idx)) |
449 | 0 | }) |
450 | | } |
451 | 0 | (Some(l_s), Some(r_s)) => { |
452 | 0 | let a = l.value(l_s); |
453 | 0 | let b = r.value(r_s); |
454 | 0 | std::iter::once(op(a, b) ^ neg).collect() |
455 | | } |
456 | 0 | (Some(l_s), None) => { |
457 | 0 | let v = l.value(l_s); |
458 | 0 | collect_bool(r.len(), neg, |idx| op(v, unsafe { r.value_unchecked(idx) })) |
459 | | } |
460 | 0 | (None, Some(r_s)) => { |
461 | 0 | let v = r.value(r_s); |
462 | 0 | collect_bool(l.len(), neg, |idx| op(unsafe { l.value_unchecked(idx) }, v)) |
463 | | } |
464 | | } |
465 | 0 | } |
466 | | |
467 | | /// Applies `op` to possibly scalar `ArrayOrd` with the given indices |
468 | 0 | fn apply_op_vectored<T: ArrayOrd>( |
469 | 0 | l: T, |
470 | 0 | l_v: &[usize], |
471 | 0 | r: T, |
472 | 0 | r_v: &[usize], |
473 | 0 | neg: bool, |
474 | 0 | op: impl Fn(T::Item, T::Item) -> bool, |
475 | 0 | ) -> BooleanBuffer { |
476 | 0 | assert_eq!(l_v.len(), r_v.len()); |
477 | 0 | collect_bool(l_v.len(), neg, |idx| unsafe { |
478 | 0 | let l_idx = *l_v.get_unchecked(idx); |
479 | 0 | let r_idx = *r_v.get_unchecked(idx); |
480 | 0 | op(l.value_unchecked(l_idx), r.value_unchecked(r_idx)) |
481 | 0 | }) |
482 | 0 | } |
483 | | |
/// Uniform element access and ordering, shared by the comparison kernels in this module.
trait ArrayOrd {
    /// The element handed to [`ArrayOrd::is_eq`] / [`ArrayOrd::is_lt`]. It must be cheap to copy.
    type Item: Copy;

    /// Number of accessible elements.
    fn len(&self) -> usize;

    /// Returns the element at `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.len()`.
    fn value(&self, idx: usize) -> Self::Item {
        assert!(idx < self.len());
        // SAFETY: the assertion above guarantees `idx` is in bounds
        unsafe { self.value_unchecked(idx) }
    }

    /// Returns the element at `idx` without bounds checking.
    ///
    /// # Safety
    ///
    /// Safe if `idx < self.len()`
    unsafe fn value_unchecked(&self, idx: usize) -> Self::Item;

    /// Returns `true` when `l` equals `r`.
    fn is_eq(l: Self::Item, r: Self::Item) -> bool;

    /// Returns `true` when `l` is strictly less than `r`.
    fn is_lt(l: Self::Item, r: Self::Item) -> bool;
}
503 | | |
504 | | impl ArrayOrd for &BooleanArray { |
505 | | type Item = bool; |
506 | | |
507 | 0 | fn len(&self) -> usize { |
508 | 0 | Array::len(self) |
509 | 0 | } |
510 | | |
511 | 0 | unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { |
512 | 0 | BooleanArray::value_unchecked(self, idx) |
513 | 0 | } |
514 | | |
515 | 0 | fn is_eq(l: Self::Item, r: Self::Item) -> bool { |
516 | 0 | l == r |
517 | 0 | } |
518 | | |
519 | 0 | fn is_lt(l: Self::Item, r: Self::Item) -> bool { |
520 | 0 | !l & r |
521 | 0 | } |
522 | | } |
523 | | |
524 | | impl<T: ArrowNativeTypeOp> ArrayOrd for &[T] { |
525 | | type Item = T; |
526 | | |
527 | 0 | fn len(&self) -> usize { |
528 | 0 | (*self).len() |
529 | 0 | } |
530 | | |
531 | 0 | unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { |
532 | 0 | *self.get_unchecked(idx) |
533 | 0 | } |
534 | | |
535 | 0 | fn is_eq(l: Self::Item, r: Self::Item) -> bool { |
536 | 0 | l.is_eq(r) |
537 | 0 | } |
538 | | |
539 | 0 | fn is_lt(l: Self::Item, r: Self::Item) -> bool { |
540 | 0 | l.is_lt(r) |
541 | 0 | } |
542 | | } |
543 | | |
544 | | impl<'a, T: ByteArrayType> ArrayOrd for &'a GenericByteArray<T> { |
545 | | type Item = &'a [u8]; |
546 | | |
547 | 0 | fn len(&self) -> usize { |
548 | 0 | Array::len(self) |
549 | 0 | } |
550 | | |
551 | 0 | unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { |
552 | 0 | GenericByteArray::value_unchecked(self, idx).as_ref() |
553 | 0 | } |
554 | | |
555 | 0 | fn is_eq(l: Self::Item, r: Self::Item) -> bool { |
556 | 0 | l == r |
557 | 0 | } |
558 | | |
559 | 0 | fn is_lt(l: Self::Item, r: Self::Item) -> bool { |
560 | 0 | l < r |
561 | 0 | } |
562 | | } |
563 | | |
564 | | impl<'a, T: ByteViewType> ArrayOrd for &'a GenericByteViewArray<T> { |
565 | | /// This is the item type for the GenericByteViewArray::compare |
566 | | /// Item.0 is the array, Item.1 is the index |
567 | | type Item = (&'a GenericByteViewArray<T>, usize); |
568 | | |
569 | | #[inline(always)] |
570 | 0 | fn is_eq(l: Self::Item, r: Self::Item) -> bool { |
571 | 0 | let l_view = unsafe { l.0.views().get_unchecked(l.1) }; |
572 | 0 | let r_view = unsafe { r.0.views().get_unchecked(r.1) }; |
573 | 0 | if l.0.data_buffers().is_empty() && r.0.data_buffers().is_empty() { |
574 | | // For eq case, we can directly compare the inlined bytes |
575 | 0 | return l_view == r_view; |
576 | 0 | } |
577 | | |
578 | 0 | let l_len = *l_view as u32; |
579 | 0 | let r_len = *r_view as u32; |
580 | | // This is a fast path for equality check. |
581 | | // We don't need to look at the actual bytes to determine if they are equal. |
582 | 0 | if l_len != r_len { |
583 | 0 | return false; |
584 | 0 | } |
585 | 0 | if l_len == 0 && r_len == 0 { |
586 | 0 | return true; |
587 | 0 | } |
588 | | |
589 | | // # Safety |
590 | | // The index is within bounds as it is checked in value() |
591 | 0 | unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, r.1).is_eq() } |
592 | 0 | } |
593 | | |
594 | | #[inline(always)] |
595 | 0 | fn is_lt(l: Self::Item, r: Self::Item) -> bool { |
596 | | // If both arrays use only the inline buffer |
597 | 0 | if l.0.data_buffers().is_empty() && r.0.data_buffers().is_empty() { |
598 | 0 | let l_view = unsafe { l.0.views().get_unchecked(l.1) }; |
599 | 0 | let r_view = unsafe { r.0.views().get_unchecked(r.1) }; |
600 | 0 | return GenericByteViewArray::<T>::inline_key_fast(*l_view) |
601 | 0 | < GenericByteViewArray::<T>::inline_key_fast(*r_view); |
602 | 0 | } |
603 | | |
604 | | // Fallback to the generic, unchecked comparison for non-inline cases |
605 | | // # Safety |
606 | | // The index is within bounds as it is checked in value() |
607 | 0 | unsafe { GenericByteViewArray::compare_unchecked(l.0, l.1, r.0, r.1).is_lt() } |
608 | 0 | } |
609 | | |
610 | 0 | fn len(&self) -> usize { |
611 | 0 | Array::len(self) |
612 | 0 | } |
613 | | |
614 | 0 | unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { |
615 | 0 | (self, idx) |
616 | 0 | } |
617 | | } |
618 | | |
619 | | impl<'a> ArrayOrd for &'a FixedSizeBinaryArray { |
620 | | type Item = &'a [u8]; |
621 | | |
622 | 0 | fn len(&self) -> usize { |
623 | 0 | Array::len(self) |
624 | 0 | } |
625 | | |
626 | 0 | unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { |
627 | 0 | FixedSizeBinaryArray::value_unchecked(self, idx) |
628 | 0 | } |
629 | | |
630 | 0 | fn is_eq(l: Self::Item, r: Self::Item) -> bool { |
631 | 0 | l == r |
632 | 0 | } |
633 | | |
634 | 0 | fn is_lt(l: Self::Item, r: Self::Item) -> bool { |
635 | 0 | l < r |
636 | 0 | } |
637 | | } |
638 | | |
639 | | /// Compares two [`GenericByteViewArray`] at index `left_idx` and `right_idx` |
640 | | #[inline(always)] |
641 | 0 | pub fn compare_byte_view<T: ByteViewType>( |
642 | 0 | left: &GenericByteViewArray<T>, |
643 | 0 | left_idx: usize, |
644 | 0 | right: &GenericByteViewArray<T>, |
645 | 0 | right_idx: usize, |
646 | 0 | ) -> Ordering { |
647 | 0 | assert!(left_idx < left.len()); |
648 | 0 | assert!(right_idx < right.len()); |
649 | 0 | if left.data_buffers().is_empty() && right.data_buffers().is_empty() { |
650 | 0 | let l_view = unsafe { left.views().get_unchecked(left_idx) }; |
651 | 0 | let r_view = unsafe { right.views().get_unchecked(right_idx) }; |
652 | 0 | return GenericByteViewArray::<T>::inline_key_fast(*l_view) |
653 | 0 | .cmp(&GenericByteViewArray::<T>::inline_key_fast(*r_view)); |
654 | 0 | } |
655 | 0 | unsafe { GenericByteViewArray::compare_unchecked(left, left_idx, right, right_idx) } |
656 | 0 | } |
657 | | |
658 | | #[cfg(test)] |
659 | | mod tests { |
660 | | use std::sync::Arc; |
661 | | |
662 | | use arrow_array::{DictionaryArray, Int32Array, Scalar, StringArray}; |
663 | | |
664 | | use super::*; |
665 | | |
666 | | #[test] |
667 | | fn test_null_dict() { |
668 | | let a = DictionaryArray::new(Int32Array::new_null(10), Arc::new(Int32Array::new_null(0))); |
669 | | let r = eq(&a, &a).unwrap(); |
670 | | assert_eq!(r.null_count(), 10); |
671 | | |
672 | | let a = DictionaryArray::new( |
673 | | Int32Array::from(vec![1, 2, 3, 4, 5, 6]), |
674 | | Arc::new(Int32Array::new_null(10)), |
675 | | ); |
676 | | let r = eq(&a, &a).unwrap(); |
677 | | assert_eq!(r.null_count(), 6); |
678 | | |
679 | | let scalar = |
680 | | DictionaryArray::new(Int32Array::new_null(1), Arc::new(Int32Array::new_null(0))); |
681 | | let r = eq(&a, &Scalar::new(&scalar)).unwrap(); |
682 | | assert_eq!(r.null_count(), 6); |
683 | | |
684 | | let scalar = |
685 | | DictionaryArray::new(Int32Array::new_null(1), Arc::new(Int32Array::new_null(0))); |
686 | | let r = eq(&Scalar::new(&scalar), &Scalar::new(&scalar)).unwrap(); |
687 | | assert_eq!(r.null_count(), 1); |
688 | | |
689 | | let a = DictionaryArray::new( |
690 | | Int32Array::from(vec![0, 1, 2]), |
691 | | Arc::new(Int32Array::from(vec![3, 2, 1])), |
692 | | ); |
693 | | let r = eq(&a, &Scalar::new(&scalar)).unwrap(); |
694 | | assert_eq!(r.null_count(), 3); |
695 | | } |
696 | | |
697 | | #[test] |
698 | | fn is_distinct_from_non_nulls() { |
699 | | let left_int_array = Int32Array::from(vec![0, 1, 2, 3, 4]); |
700 | | let right_int_array = Int32Array::from(vec![4, 3, 2, 1, 0]); |
701 | | |
702 | | assert_eq!( |
703 | | BooleanArray::from(vec![true, true, false, true, true,]), |
704 | | distinct(&left_int_array, &right_int_array).unwrap() |
705 | | ); |
706 | | assert_eq!( |
707 | | BooleanArray::from(vec![false, false, true, false, false,]), |
708 | | not_distinct(&left_int_array, &right_int_array).unwrap() |
709 | | ); |
710 | | } |
711 | | |
712 | | #[test] |
713 | | fn is_distinct_from_nulls() { |
714 | | // [0, 0, NULL, 0, 0, 0] |
715 | | let left_int_array = Int32Array::new( |
716 | | vec![0, 0, 1, 3, 0, 0].into(), |
717 | | Some(NullBuffer::from(vec![true, true, false, true, true, true])), |
718 | | ); |
719 | | // [0, NULL, NULL, NULL, 0, NULL] |
720 | | let right_int_array = Int32Array::new( |
721 | | vec![0; 6].into(), |
722 | | Some(NullBuffer::from(vec![ |
723 | | true, false, false, false, true, false, |
724 | | ])), |
725 | | ); |
726 | | |
727 | | assert_eq!( |
728 | | BooleanArray::from(vec![false, true, false, true, false, true,]), |
729 | | distinct(&left_int_array, &right_int_array).unwrap() |
730 | | ); |
731 | | |
732 | | assert_eq!( |
733 | | BooleanArray::from(vec![true, false, true, false, true, false,]), |
734 | | not_distinct(&left_int_array, &right_int_array).unwrap() |
735 | | ); |
736 | | } |
737 | | |
738 | | #[test] |
739 | | fn test_distinct_scalar() { |
740 | | let a = Int32Array::new_scalar(12); |
741 | | let b = Int32Array::new_scalar(12); |
742 | | assert!(!distinct(&a, &b).unwrap().value(0)); |
743 | | assert!(not_distinct(&a, &b).unwrap().value(0)); |
744 | | |
745 | | let a = Int32Array::new_scalar(12); |
746 | | let b = Int32Array::new_null(1); |
747 | | assert!(distinct(&a, &b).unwrap().value(0)); |
748 | | assert!(!not_distinct(&a, &b).unwrap().value(0)); |
749 | | assert!(distinct(&b, &a).unwrap().value(0)); |
750 | | assert!(!not_distinct(&b, &a).unwrap().value(0)); |
751 | | |
752 | | let b = Scalar::new(b); |
753 | | assert!(distinct(&a, &b).unwrap().value(0)); |
754 | | assert!(!not_distinct(&a, &b).unwrap().value(0)); |
755 | | |
756 | | assert!(!distinct(&b, &b).unwrap().value(0)); |
757 | | assert!(not_distinct(&b, &b).unwrap().value(0)); |
758 | | |
759 | | let a = Int32Array::new( |
760 | | vec![0, 1, 2, 3].into(), |
761 | | Some(vec![false, false, true, true].into()), |
762 | | ); |
763 | | let expected = BooleanArray::from(vec![false, false, true, true]); |
764 | | assert_eq!(distinct(&a, &b).unwrap(), expected); |
765 | | assert_eq!(distinct(&b, &a).unwrap(), expected); |
766 | | |
767 | | let expected = BooleanArray::from(vec![true, true, false, false]); |
768 | | assert_eq!(not_distinct(&a, &b).unwrap(), expected); |
769 | | assert_eq!(not_distinct(&b, &a).unwrap(), expected); |
770 | | |
771 | | let b = Int32Array::new_scalar(1); |
772 | | let expected = BooleanArray::from(vec![true; 4]); |
773 | | assert_eq!(distinct(&a, &b).unwrap(), expected); |
774 | | assert_eq!(distinct(&b, &a).unwrap(), expected); |
775 | | let expected = BooleanArray::from(vec![false; 4]); |
776 | | assert_eq!(not_distinct(&a, &b).unwrap(), expected); |
777 | | assert_eq!(not_distinct(&b, &a).unwrap(), expected); |
778 | | |
779 | | let b = Int32Array::new_scalar(3); |
780 | | let expected = BooleanArray::from(vec![true, true, true, false]); |
781 | | assert_eq!(distinct(&a, &b).unwrap(), expected); |
782 | | assert_eq!(distinct(&b, &a).unwrap(), expected); |
783 | | let expected = BooleanArray::from(vec![false, false, false, true]); |
784 | | assert_eq!(not_distinct(&a, &b).unwrap(), expected); |
785 | | assert_eq!(not_distinct(&b, &a).unwrap(), expected); |
786 | | } |
787 | | |
788 | | #[test] |
789 | | fn test_scalar_negation() { |
790 | | let a = Int32Array::new_scalar(54); |
791 | | let b = Int32Array::new_scalar(54); |
792 | | let r = eq(&a, &b).unwrap(); |
793 | | assert!(r.value(0)); |
794 | | |
795 | | let r = neq(&a, &b).unwrap(); |
796 | | assert!(!r.value(0)) |
797 | | } |
798 | | |
799 | | #[test] |
800 | | fn test_scalar_empty() { |
801 | | let a = Int32Array::new_null(0); |
802 | | let b = Int32Array::new_scalar(23); |
803 | | let r = eq(&a, &b).unwrap(); |
804 | | assert_eq!(r.len(), 0); |
805 | | let r = eq(&b, &a).unwrap(); |
806 | | assert_eq!(r.len(), 0); |
807 | | } |
808 | | |
809 | | #[test] |
810 | | fn test_dictionary_nulls() { |
811 | | let values = StringArray::from(vec![Some("us-west"), Some("us-east")]); |
812 | | let nulls = NullBuffer::from(vec![false, true, true]); |
813 | | |
814 | | let key_values = vec![100i32, 1i32, 0i32].into(); |
815 | | let keys = Int32Array::new(key_values, Some(nulls)); |
816 | | let col = DictionaryArray::try_new(keys, Arc::new(values)).unwrap(); |
817 | | |
818 | | neq(&col.slice(0, col.len() - 1), &col.slice(1, col.len() - 1)).unwrap(); |
819 | | } |
820 | | } |