/Users/andrewlamb/Software/arrow-rs/arrow-select/src/filter.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines filter kernels |
19 | | |
20 | | use std::ops::AddAssign; |
21 | | use std::sync::Arc; |
22 | | |
23 | | use arrow_array::builder::BooleanBufferBuilder; |
24 | | use arrow_array::cast::AsArray; |
25 | | use arrow_array::types::{ |
26 | | ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType, |
27 | | }; |
28 | | use arrow_array::*; |
29 | | use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util}; |
30 | | use arrow_buffer::{Buffer, MutableBuffer}; |
31 | | use arrow_data::ArrayDataBuilder; |
32 | | use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator}; |
33 | | use arrow_data::transform::MutableArrayData; |
34 | | use arrow_schema::*; |
35 | | |
36 | | /// If the filter selects more than this fraction of rows, use |
37 | | /// [`SlicesIterator`] to copy ranges of values. Otherwise iterate |
38 | | /// over individual rows using [`IndexIterator`] |
39 | | /// |
40 | | /// Threshold of 0.8 chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009> |
41 | | /// |
42 | | const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; |
43 | | |
44 | | /// An iterator of `(usize, usize)` each representing an interval |
45 | | /// `[start, end)` whose slots of a bitmap [Buffer] are true. |
46 | | /// |
47 | | /// Each interval corresponds to a contiguous region of memory to be |
48 | | /// "taken" from an array to be filtered. |
49 | | /// |
50 | | /// ## Notes: |
51 | | /// |
52 | | /// 1. Ignores the validity bitmap (ignores nulls) |
53 | | /// |
54 | | /// 2. Only performant for filters that copy across long contiguous runs |
55 | | #[derive(Debug)] |
56 | | pub struct SlicesIterator<'a>(BitSliceIterator<'a>); |
57 | | |
58 | | impl<'a> SlicesIterator<'a> { |
59 | | /// Creates a new iterator from a [BooleanArray] |
60 | 234 | pub fn new(filter: &'a BooleanArray) -> Self { |
61 | 234 | filter.values().into() |
62 | 234 | } |
63 | | } |
64 | | |
65 | | impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> { |
66 | 281 | fn from(filter: &'a BooleanBuffer) -> Self { |
67 | 281 | Self(filter.set_slices()) |
68 | 281 | } |
69 | | } |
70 | | |
71 | | impl Iterator for SlicesIterator<'_> { |
72 | | type Item = (usize, usize); |
73 | | |
74 | 14.7k | fn next(&mut self) -> Option<Self::Item> { |
75 | 14.7k | self.0.next() |
76 | 14.7k | } |
77 | | } |
78 | | |
79 | | /// An iterator of `usize` whose index in [`BooleanArray`] is true |
80 | | /// |
81 | | /// This provides the best performance on most predicates, apart from those which keep |
82 | | /// large runs and therefore favour [`SlicesIterator`] |
83 | | struct IndexIterator<'a> { |
84 | | remaining: usize, |
85 | | iter: BitIndexIterator<'a>, |
86 | | } |
87 | | |
88 | | impl<'a> IndexIterator<'a> { |
89 | 757 | fn new(filter: &'a BooleanArray, remaining: usize) -> Self { |
90 | 757 | assert_eq!(filter.null_count(), 0); |
91 | 757 | let iter = filter.values().set_indices(); |
92 | 757 | Self { remaining, iter } |
93 | 757 | } |
94 | | } |
95 | | |
96 | | impl Iterator for IndexIterator<'_> { |
97 | | type Item = usize; |
98 | | |
99 | 70.4k | fn next(&mut self) -> Option<Self::Item> { |
100 | 70.4k | if self.remaining != 0 { |
101 | | // Fascinatingly swapping these two lines around results in a 50% |
102 | | // performance regression for some benchmarks |
103 | 69.9k | let next = self.iter.next().expect("IndexIterator exhausted early"); |
104 | 69.9k | self.remaining -= 1; |
105 | | // Must panic if exhausted early as trusted length iterator |
106 | 69.9k | return Some(next); |
107 | 523 | } |
108 | 523 | None |
109 | 70.4k | } |
110 | | |
111 | 593 | fn size_hint(&self) -> (usize, Option<usize>) { |
112 | 593 | (self.remaining, Some(self.remaining)) |
113 | 593 | } |
114 | | } |
115 | | |
116 | | /// Counts the number of set bits in `filter` |
117 | 567 | fn filter_count(filter: &BooleanArray) -> usize { |
118 | 567 | filter.values().count_set_bits() |
119 | 567 | } |
120 | | |
121 | | /// Remove null values by do a bitmask AND operation with null bits and the boolean bits. |
122 | 6 | pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray { |
123 | 6 | let nulls = filter.nulls().unwrap(); |
124 | 6 | let mask = filter.values() & nulls.inner(); |
125 | 6 | BooleanArray::new(mask, None) |
126 | 6 | } |
127 | | |
128 | | /// Returns a filtered `values` [`Array`] where the corresponding elements of |
129 | | /// `predicate` are `true`. |
130 | | /// |
131 | | /// If multiple arrays (or record batches) need to be filtered using the same predicate array, |
132 | | /// consider using [FilterBuilder] to create a single [FilterPredicate] and then |
133 | | /// calling [FilterPredicate::filter_record_batch]. |
134 | | /// |
135 | | /// In contrast to this function, it is then the responsibility of the caller |
136 | | /// to use [FilterBuilder::optimize] if appropriate. |
137 | | /// |
138 | | /// # See also |
139 | | /// * [`FilterBuilder`] for more control over the filtering process. |
140 | | /// * [`filter_record_batch`] to filter a [`RecordBatch`] |
141 | | /// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce |
142 | | /// the results into a single array. |
143 | | /// |
144 | | /// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer |
145 | | /// |
146 | | /// # Example |
147 | | /// ```rust |
148 | | /// # use arrow_array::{Int32Array, BooleanArray}; |
149 | | /// # use arrow_select::filter::filter; |
150 | | /// let array = Int32Array::from(vec![5, 6, 7, 8, 9]); |
151 | | /// let filter_array = BooleanArray::from(vec![true, false, false, true, false]); |
152 | | /// let c = filter(&array, &filter_array).unwrap(); |
153 | | /// let c = c.as_any().downcast_ref::<Int32Array>().unwrap(); |
154 | | /// assert_eq!(c, &Int32Array::from(vec![5, 8])); |
155 | | /// ``` |
156 | 375 | pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> { |
157 | 375 | let mut filter_builder = FilterBuilder::new(predicate); |
158 | | |
159 | 375 | if FilterBuilder::is_optimize_beneficial(values.data_type()) { |
160 | 6 | // Only optimize if filtering more than one array |
161 | 6 | // Otherwise, the overhead of optimization can be more than the benefit |
162 | 6 | filter_builder = filter_builder.optimize(); |
163 | 369 | } |
164 | | |
165 | 375 | let predicate = filter_builder.build(); |
166 | | |
167 | 375 | filter_array(values, &predicate) |
168 | 375 | } |
169 | | |
170 | | /// Returns a filtered [RecordBatch] where the corresponding elements of |
171 | | /// `predicate` are true. |
172 | | /// |
173 | | /// This is the equivalent of calling [filter] on each column of the [RecordBatch]. |
174 | | /// |
175 | | /// If multiple record batches (or arrays) need to be filtered using the same predicate array, |
176 | | /// consider using [FilterBuilder] to create a single [FilterPredicate] and then |
177 | | /// calling [FilterPredicate::filter_record_batch]. |
178 | | /// In contrast to this function, it is then the responsibility of the caller |
179 | | /// to use [FilterBuilder::optimize] if appropriate. |
180 | 82 | pub fn filter_record_batch( |
181 | 82 | record_batch: &RecordBatch, |
182 | 82 | predicate: &BooleanArray, |
183 | 82 | ) -> Result<RecordBatch, ArrowError> { |
184 | 82 | let mut filter_builder = FilterBuilder::new(predicate); |
185 | 82 | let num_cols = record_batch.num_columns(); |
186 | 82 | if num_cols > 1 |
187 | 2 | || (num_cols > 0 |
188 | 1 | && FilterBuilder::is_optimize_beneficial( |
189 | 1 | record_batch.schema_ref().field(0).data_type(), |
190 | | )) |
191 | 81 | { |
192 | 81 | // Only optimize if filtering more than one column or if the column contains multiple internal arrays |
193 | 81 | // Otherwise, the overhead of optimization can be more than the benefit |
194 | 81 | filter_builder = filter_builder.optimize(); |
195 | 81 | }1 |
196 | 82 | let filter = filter_builder.build(); |
197 | | |
198 | 82 | filter.filter_record_batch(record_batch) |
199 | 82 | } |
200 | | |
201 | | /// A builder to construct [`FilterPredicate`] |
202 | | #[derive(Debug)] |
203 | | pub struct FilterBuilder { |
204 | | filter: BooleanArray, |
205 | | count: usize, |
206 | | strategy: IterationStrategy, |
207 | | } |
208 | | |
209 | | impl FilterBuilder { |
210 | | /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`] |
211 | 459 | pub fn new(filter: &BooleanArray) -> Self { |
212 | 459 | let filter = match filter.null_count() { |
213 | 457 | 0 => filter.clone(), |
214 | 2 | _ => prep_null_mask_filter(filter), |
215 | | }; |
216 | | |
217 | 459 | let count = filter_count(&filter); |
218 | 459 | let strategy = IterationStrategy::default_strategy(filter.len(), count); |
219 | | |
220 | 459 | Self { |
221 | 459 | filter, |
222 | 459 | count, |
223 | 459 | strategy, |
224 | 459 | } |
225 | 459 | } |
226 | | |
227 | | /// Compute an optimized representation of the provided `filter` mask that can be |
228 | | /// applied to an array more quickly. |
229 | | /// |
230 | | /// When filtering multiple arrays (e.g. a [`RecordBatch`] or a |
231 | | /// [`StructArray`] with multiple fields), optimizing the filter can provide |
232 | | /// significant performance benefits. |
233 | | /// |
234 | | /// However, optimization takes time and can have a larger memory footprint |
235 | | /// than the original mask, so it is often faster to filter a single array, |
236 | | /// without filter optimization. |
237 | 89 | pub fn optimize(mut self) -> Self { |
238 | 89 | match self.strategy { |
239 | | IterationStrategy::SlicesIterator => { |
240 | 20 | let slices = SlicesIterator::new(&self.filter).collect(); |
241 | 20 | self.strategy = IterationStrategy::Slices(slices) |
242 | | } |
243 | | IterationStrategy::IndexIterator => { |
244 | 69 | let indices = IndexIterator::new(&self.filter, self.count).collect(); |
245 | 69 | self.strategy = IterationStrategy::Indices(indices) |
246 | | } |
247 | 0 | _ => {} |
248 | | } |
249 | 89 | self |
250 | 89 | } |
251 | | |
252 | | /// Determines if calling [FilterBuilder::optimize] is beneficial for the |
253 | | /// given type even when filtering just a single array. |
254 | | /// |
255 | | /// See [`FilterBuilder::optimize`] for more details. |
256 | 378 | pub fn is_optimize_beneficial(data_type: &DataType) -> bool { |
257 | 10 | match data_type { |
258 | 5 | DataType::Struct(fields) => { |
259 | 5 | fields.len() > 1 |
260 | 2 | || fields.len() == 1 |
261 | 2 | && FilterBuilder::is_optimize_beneficial(fields[0].data_type()) |
262 | | } |
263 | 4 | DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(), |
264 | 369 | _ => false, |
265 | | } |
266 | 378 | } |
267 | | |
268 | | /// Construct the final `FilterPredicate` |
269 | 459 | pub fn build(self) -> FilterPredicate { |
270 | 459 | FilterPredicate { |
271 | 459 | filter: self.filter, |
272 | 459 | count: self.count, |
273 | 459 | strategy: self.strategy, |
274 | 459 | } |
275 | 459 | } |
276 | | } |
277 | | |
278 | | /// The iteration strategy used to evaluate [`FilterPredicate`] |
279 | | #[derive(Debug)] |
280 | | enum IterationStrategy { |
281 | | /// A lazily evaluated iterator of ranges |
282 | | SlicesIterator, |
283 | | /// A lazily evaluated iterator of indices |
284 | | IndexIterator, |
285 | | /// A precomputed list of indices |
286 | | Indices(Vec<usize>), |
287 | | /// A precomputed array of ranges |
288 | | Slices(Vec<(usize, usize)>), |
289 | | /// Select all rows |
290 | | All, |
291 | | /// Select no rows |
292 | | None, |
293 | | } |
294 | | |
295 | | impl IterationStrategy { |
296 | | /// The default [`IterationStrategy`] for a filter of length `filter_length` |
297 | | /// and selecting `filter_count` rows |
298 | 459 | fn default_strategy(filter_length: usize, filter_count: usize) -> Self { |
299 | 459 | if filter_length == 0 || filter_count == 0 { |
300 | 31 | return IterationStrategy::None; |
301 | 428 | } |
302 | | |
303 | 428 | if filter_count == filter_length { |
304 | 19 | return IterationStrategy::All; |
305 | 409 | } |
306 | | |
307 | | // Compute the selectivity of the predicate by dividing the number of true |
308 | | // bits in the predicate by the predicate's total length |
309 | | // |
310 | | // This can then be used as a heuristic for the optimal iteration strategy |
311 | 409 | let selectivity_frac = filter_count as f64 / filter_length as f64; |
312 | 409 | if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD { |
313 | 60 | return IterationStrategy::SlicesIterator; |
314 | 349 | } |
315 | 349 | IterationStrategy::IndexIterator |
316 | 459 | } |
317 | | } |
318 | | |
319 | | /// A filtering predicate that can be applied to an [`Array`] |
320 | | #[derive(Debug)] |
321 | | pub struct FilterPredicate { |
322 | | filter: BooleanArray, |
323 | | count: usize, |
324 | | strategy: IterationStrategy, |
325 | | } |
326 | | |
327 | | impl FilterPredicate { |
328 | | /// Selects rows from `values` based on this [`FilterPredicate`] |
329 | 2 | pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> { |
330 | 2 | filter_array(values, self) |
331 | 2 | } |
332 | | |
333 | | /// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this |
334 | | /// [`FilterPredicate`]. |
335 | | /// |
336 | | /// This is the equivalent of calling [filter] on each column of the [`RecordBatch`]. |
337 | 82 | pub fn filter_record_batch( |
338 | 82 | &self, |
339 | 82 | record_batch: &RecordBatch, |
340 | 82 | ) -> Result<RecordBatch, ArrowError> { |
341 | 82 | let filtered_arrays = record_batch |
342 | 82 | .columns() |
343 | 82 | .iter() |
344 | 321 | .map82 (|a| filter_array(a, self)) |
345 | 82 | .collect::<Result<Vec<_>, _>>()?0 ; |
346 | | |
347 | | // SAFETY: we know that the set of filtered arrays will match the schema of the original |
348 | | // record batch |
349 | | unsafe { |
350 | 82 | Ok(RecordBatch::new_unchecked( |
351 | 82 | record_batch.schema(), |
352 | 82 | filtered_arrays, |
353 | 82 | self.count, |
354 | 82 | )) |
355 | | } |
356 | 82 | } |
357 | | |
358 | | /// Number of rows being selected based on this [`FilterPredicate`] |
359 | 6 | pub fn count(&self) -> usize { |
360 | 6 | self.count |
361 | 6 | } |
362 | | } |
363 | | |
364 | 714 | fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> { |
365 | 714 | if predicate.filter.len() > values.len() { |
366 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
367 | 0 | "Filter predicate of length {} is larger than target array of length {}", |
368 | 0 | predicate.filter.len(), |
369 | 0 | values.len() |
370 | 0 | ))); |
371 | 714 | } |
372 | | |
373 | 714 | match predicate.strategy { |
374 | 31 | IterationStrategy::None => Ok(new_empty_array(values.data_type())), |
375 | 19 | IterationStrategy::All => Ok(values.slice(0, predicate.count)), |
376 | | // actually filter |
377 | 664 | _ => downcast_primitive_array! { |
378 | 0 | values => Ok(Arc::new(filter_primitive(values, predicate))), |
379 | | DataType::Boolean => { |
380 | 0 | let values = values.as_any().downcast_ref::<BooleanArray>().unwrap(); |
381 | 0 | Ok(Arc::new(filter_boolean(values, predicate))) |
382 | | } |
383 | | DataType::Utf8 => { |
384 | 175 | Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate))) |
385 | | } |
386 | | DataType::LargeUtf8 => { |
387 | 0 | Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate))) |
388 | | } |
389 | | DataType::Utf8View => { |
390 | 82 | Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate))) |
391 | | } |
392 | | DataType::Binary => { |
393 | 1 | Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate))) |
394 | | } |
395 | | DataType::LargeBinary => { |
396 | 0 | Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate))) |
397 | | } |
398 | | DataType::BinaryView => { |
399 | 2 | Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate))) |
400 | | } |
401 | | DataType::FixedSizeBinary(_) => { |
402 | 4 | Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate))) |
403 | | } |
404 | | DataType::ListView(_) => { |
405 | 2 | Ok(Arc::new(filter_list_view::<i32>(values.as_list_view(), predicate))) |
406 | | } |
407 | | DataType::LargeListView(_) => { |
408 | 2 | Ok(Arc::new(filter_list_view::<i64>(values.as_list_view(), predicate))) |
409 | | } |
410 | | DataType::RunEndEncoded(_, _) => { |
411 | 4 | downcast_run_array!{ |
412 | 1 | values => Ok(Arc::new(filter_run_end_array(values, predicate)?0 )), |
413 | 0 | t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t) |
414 | | } |
415 | | } |
416 | 87 | DataType::Dictionary(_, _) => downcast_dictionary_array! { |
417 | 1 | values => Ok(Arc::new(filter_dict(values, predicate))), |
418 | 0 | t => unimplemented!("Filter not supported for dictionary type {:?}", t) |
419 | | } |
420 | | DataType::Struct(_) => { |
421 | 6 | Ok(Arc::new(filter_struct(values.as_struct(), predicate)?0 )) |
422 | | } |
423 | | DataType::Union(_, UnionMode::Sparse) => { |
424 | 4 | Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?0 )) |
425 | | } |
426 | | _ => { |
427 | 11 | let data = values.to_data(); |
428 | | // fallback to using MutableArrayData |
429 | 11 | let mut mutable = MutableArrayData::new( |
430 | 11 | vec![&data], |
431 | | false, |
432 | 11 | predicate.count, |
433 | | ); |
434 | | |
435 | 11 | match &predicate.strategy { |
436 | 0 | IterationStrategy::Slices(slices) => { |
437 | 0 | slices |
438 | 0 | .iter() |
439 | 0 | .for_each(|(start, end)| mutable.extend(0, *start, *end)); |
440 | | } |
441 | | _ => { |
442 | 11 | let iter = SlicesIterator::new(&predicate.filter); |
443 | 17 | iter11 .for_each11 (|(start, end)| mutable.extend(0, start, end)); |
444 | | } |
445 | | } |
446 | | |
447 | 11 | let data = mutable.freeze(); |
448 | 11 | Ok(make_array(data)) |
449 | | } |
450 | | }, |
451 | | } |
452 | 714 | } |
453 | | |
454 | | /// Filter any supported [`RunArray`] based on a [`FilterPredicate`] |
455 | 4 | fn filter_run_end_array<R: RunEndIndexType>( |
456 | 4 | array: &RunArray<R>, |
457 | 4 | predicate: &FilterPredicate, |
458 | 4 | ) -> Result<RunArray<R>, ArrowError> |
459 | 4 | where |
460 | 4 | R::Native: Into<i64> + From<bool>, |
461 | 4 | R::Native: AddAssign, |
462 | | { |
463 | 4 | let run_ends: &RunEndBuffer<R::Native> = array.run_ends(); |
464 | 4 | let mut new_run_ends = vec![R::default_value(); run_ends.len()]; |
465 | | |
466 | 4 | let mut start = 0u64; |
467 | 4 | let mut j = 0; |
468 | 4 | let mut count = R::default_value(); |
469 | 4 | let filter_values = predicate.filter.values(); |
470 | 4 | let run_ends = run_ends.inner(); |
471 | | |
472 | 15 | let pred4 : BooleanArray4 = BooleanBuffer::collect_bool4 (run_ends4 .len4 (), |i| { |
473 | 15 | let mut keep = false; |
474 | 15 | let mut end = run_ends[i].into() as u64; |
475 | 15 | let difference = end.saturating_sub(filter_values.len() as u64); |
476 | 15 | end -= difference; |
477 | | |
478 | | // Safety: we subtract the difference off `end` so we are always within bounds |
479 | 31 | for pred in (start..end)15 .map15 (|i| unsafe { filter_values.value_unchecked(i as usize) }) { |
480 | 31 | count += R::Native::from(pred); |
481 | 31 | keep |= pred |
482 | | } |
483 | | // this is to avoid branching |
484 | 15 | new_run_ends[j] = count; |
485 | 15 | j += keep as usize; |
486 | | |
487 | 15 | start = end; |
488 | 15 | keep |
489 | 15 | }) |
490 | 4 | .into(); |
491 | | |
492 | 4 | new_run_ends.truncate(j); |
493 | | |
494 | 4 | let values = array.values(); |
495 | 4 | let values = filter(&values, &pred)?0 ; |
496 | | |
497 | 4 | let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?0 ; |
498 | 4 | RunArray::try_new(&run_ends, &values) |
499 | 4 | } |
500 | | |
501 | | /// Computes a new null mask for `data` based on `predicate` |
502 | | /// |
503 | | /// If the predicate selected no null-rows, returns `None`, otherwise returns |
504 | | /// `Some((null_count, null_buffer))` where `null_count` is the number of nulls |
505 | | /// in the filtered output, and `null_buffer` is the filtered null buffer |
506 | | /// |
507 | 649 | fn filter_null_mask( |
508 | 649 | nulls: Option<&NullBuffer>, |
509 | 649 | predicate: &FilterPredicate, |
510 | 649 | ) -> Option<(usize, Buffer)> { |
511 | 649 | let nulls605 = nulls?44 ; |
512 | 605 | if nulls.null_count() == 0 { |
513 | 1 | return None; |
514 | 604 | } |
515 | | |
516 | 604 | let nulls = filter_bits(nulls.inner(), predicate); |
517 | | // The filtered `nulls` has a length of `predicate.count` bits and |
518 | | // therefore the null count is this minus the number of valid bits |
519 | 604 | let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count); |
520 | | |
521 | 604 | if null_count == 0 { |
522 | 57 | return None; |
523 | 547 | } |
524 | | |
525 | 547 | Some((null_count, nulls)) |
526 | 649 | } |
527 | | |
528 | | /// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset` |
529 | 604 | fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer { |
530 | 604 | let src = buffer.values(); |
531 | 604 | let offset = buffer.offset(); |
532 | | |
533 | 604 | match &predicate.strategy { |
534 | | IterationStrategy::IndexIterator => { |
535 | 234 | let bits = IndexIterator::new(&predicate.filter, predicate.count) |
536 | 12.6k | .map234 (|src_idx| bit_util::get_bit(src, src_idx + offset)); |
537 | | |
538 | | // SAFETY: `IndexIterator` reports its size correctly |
539 | 234 | unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() } |
540 | | } |
541 | 250 | IterationStrategy::Indices(indices) => { |
542 | 250 | let bits = indices |
543 | 250 | .iter() |
544 | 70.8k | .map250 (|src_idx| bit_util::get_bit(src, *src_idx + offset)); |
545 | | |
546 | | // SAFETY: `Vec::iter()` reports its size correctly |
547 | 250 | unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() } |
548 | | } |
549 | | IterationStrategy::SlicesIterator => { |
550 | 40 | let mut builder = BooleanBufferBuilder::new(predicate.count); |
551 | 659 | for (start, end) in SlicesIterator::new40 (&predicate.filter40 ) { |
552 | 659 | builder.append_packed_range(start + offset..end + offset, src) |
553 | | } |
554 | 40 | builder.into() |
555 | | } |
556 | 80 | IterationStrategy::Slices(slices) => { |
557 | 80 | let mut builder = BooleanBufferBuilder::new(predicate.count); |
558 | 6.04k | for (start5.96k , end5.96k ) in slices { |
559 | 5.96k | builder.append_packed_range(*start + offset..*end + offset, src) |
560 | | } |
561 | 80 | builder.into() |
562 | | } |
563 | 0 | IterationStrategy::All | IterationStrategy::None => unreachable!(), |
564 | | } |
565 | 604 | } |
566 | | |
567 | | /// `filter` implementation for boolean buffers |
568 | 0 | fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray { |
569 | 0 | let values = filter_bits(array.values(), predicate); |
570 | | |
571 | 0 | let mut builder = ArrayDataBuilder::new(DataType::Boolean) |
572 | 0 | .len(predicate.count) |
573 | 0 | .add_buffer(values); |
574 | | |
575 | 0 | if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { |
576 | 0 | builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); |
577 | 0 | } |
578 | | |
579 | 0 | let data = unsafe { builder.build_unchecked() }; |
580 | 0 | BooleanArray::from(data) |
581 | 0 | } |
582 | | |
583 | | #[inline(never)] |
584 | 467 | fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer { |
585 | 467 | assert!(values.len() >= predicate.filter.len()); |
586 | | |
587 | 467 | match &predicate.strategy { |
588 | | IterationStrategy::SlicesIterator => { |
589 | 27 | let mut buffer = Vec::with_capacity(predicate.count); |
590 | 440 | for (start, end) in SlicesIterator::new27 (&predicate.filter27 ) { |
591 | 440 | buffer.extend_from_slice(&values[start..end]); |
592 | 440 | } |
593 | 27 | buffer.into() |
594 | | } |
595 | 60 | IterationStrategy::Slices(slices) => { |
596 | 60 | let mut buffer = Vec::with_capacity(predicate.count); |
597 | 4.53k | for (start4.47k , end4.47k ) in slices { |
598 | 4.47k | buffer.extend_from_slice(&values[*start..*end]); |
599 | 4.47k | } |
600 | 60 | buffer.into() |
601 | | } |
602 | | IterationStrategy::IndexIterator => { |
603 | 8.46k | let iter185 = IndexIterator::new185 (&predicate.filter185 , predicate.count185 ).map185 (|x| values[x]); |
604 | | |
605 | | // SAFETY: IndexIterator is trusted length |
606 | 185 | unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into() |
607 | | } |
608 | 195 | IterationStrategy::Indices(indices) => { |
609 | 53.1k | let iter195 = indices.iter()195 .map195 (|x| values[*x]); |
610 | 195 | iter.collect::<Vec<_>>().into() |
611 | | } |
612 | 0 | IterationStrategy::All | IterationStrategy::None => unreachable!(), |
613 | | } |
614 | 467 | } |
615 | | |
616 | | /// `filter` implementation for primitive arrays |
617 | 375 | fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T> |
618 | 375 | where |
619 | 375 | T: ArrowPrimitiveType, |
620 | | { |
621 | 375 | let values = array.values(); |
622 | 375 | let buffer = filter_native(values, predicate); |
623 | 375 | let mut builder = ArrayDataBuilder::new(array.data_type().clone()) |
624 | 375 | .len(predicate.count) |
625 | 375 | .add_buffer(buffer); |
626 | | |
627 | 375 | if let Some((null_count322 , nulls322 )) = filter_null_mask(array.nulls(), predicate) { |
628 | 322 | builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); |
629 | 322 | }53 |
630 | | |
631 | 375 | let data = unsafe { builder.build_unchecked() }; |
632 | 375 | PrimitiveArray::from(data) |
633 | 375 | } |
634 | | |
635 | | /// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be |
636 | | /// used to build a new [`GenericByteArray`] by copying values from the source |
637 | | /// |
638 | | /// TODO(raphael): Could this be used for the take kernel as well? |
639 | | struct FilterBytes<'a, OffsetSize> { |
640 | | src_offsets: &'a [OffsetSize], |
641 | | src_values: &'a [u8], |
642 | | dst_offsets: Vec<OffsetSize>, |
643 | | dst_values: Vec<u8>, |
644 | | cur_offset: OffsetSize, |
645 | | } |
646 | | |
647 | | impl<'a, OffsetSize> FilterBytes<'a, OffsetSize> |
648 | | where |
649 | | OffsetSize: OffsetSizeTrait, |
650 | | { |
651 | 176 | fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self |
652 | 176 | where |
653 | 176 | T: ByteArrayType<Offset = OffsetSize>, |
654 | | { |
655 | 176 | let dst_values = Vec::new(); |
656 | 176 | let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1); |
657 | 176 | let cur_offset = OffsetSize::from_usize(0).unwrap(); |
658 | | |
659 | 176 | dst_offsets.push(cur_offset); |
660 | | |
661 | 176 | Self { |
662 | 176 | src_offsets: array.value_offsets(), |
663 | 176 | src_values: array.value_data(), |
664 | 176 | dst_offsets, |
665 | 176 | dst_values, |
666 | 176 | cur_offset, |
667 | 176 | } |
668 | 176 | } |
669 | | |
670 | | /// Returns the byte offset at `idx` |
671 | | #[inline] |
672 | 36.0k | fn get_value_offset(&self, idx: usize) -> usize { |
673 | 36.0k | self.src_offsets[idx].as_usize() |
674 | 36.0k | } |
675 | | |
676 | | /// Returns the start and end of the value at index `idx` along with its length |
677 | | #[inline] |
678 | 16.3k | fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) { |
679 | | // These can only fail if `array` contains invalid data |
680 | 16.3k | let start = self.get_value_offset(idx); |
681 | 16.3k | let end = self.get_value_offset(idx + 1); |
682 | 16.3k | let len = OffsetSize::from_usize(end - start).expect("illegal offset range"); |
683 | 16.3k | (start, end, len) |
684 | 16.3k | } |
685 | | |
686 | 143 | fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) { |
687 | 21.9k | self.dst_offsets143 .extend143 (iter143 .map143 (|idx| { |
688 | 21.9k | let start = self.src_offsets[idx].as_usize(); |
689 | 21.9k | let end = self.src_offsets[idx + 1].as_usize(); |
690 | 21.9k | let len = OffsetSize::from_usize(end - start).expect("illegal offset range"); |
691 | 21.9k | self.cur_offset += len; |
692 | | |
693 | 21.9k | self.cur_offset |
694 | 21.9k | })); |
695 | 143 | } |
696 | | |
697 | | /// Extends the in-progress array by the indexes in the provided iterator |
698 | 143 | fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) { |
699 | 143 | self.dst_values.reserve_exact(self.cur_offset.as_usize()); |
700 | | |
701 | 22.0k | for idx21.9k in iter { |
702 | 21.9k | let start = self.src_offsets[idx].as_usize(); |
703 | 21.9k | let end = self.src_offsets[idx + 1].as_usize(); |
704 | 21.9k | self.dst_values |
705 | 21.9k | .extend_from_slice(&self.src_values[start..end]); |
706 | 21.9k | } |
707 | 143 | } |
708 | | |
709 | 33 | fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) { |
710 | 33 | self.dst_offsets.reserve_exact(count); |
711 | 1.74k | for (start1.71k , end1.71k ) in iter { |
712 | | // These can only fail if `array` contains invalid data |
713 | 16.3k | for idx in start1.71k ..end1.71k { |
714 | 16.3k | let (_, _, len) = self.get_value_range(idx); |
715 | 16.3k | self.cur_offset += len; |
716 | 16.3k | self.dst_offsets.push(self.cur_offset); |
717 | 16.3k | } |
718 | | } |
719 | 33 | } |
720 | | |
721 | | /// Extends the in-progress array by the ranges in the provided iterator |
722 | 33 | fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) { |
723 | 33 | self.dst_values.reserve_exact(self.cur_offset.as_usize()); |
724 | | |
725 | 1.74k | for (start1.71k , end1.71k ) in iter { |
726 | 1.71k | let value_start = self.get_value_offset(start); |
727 | 1.71k | let value_end = self.get_value_offset(end); |
728 | 1.71k | self.dst_values |
729 | 1.71k | .extend_from_slice(&self.src_values[value_start..value_end]); |
730 | 1.71k | } |
731 | 33 | } |
732 | | } |
733 | | |
734 | | /// `filter` implementation for byte arrays |
735 | | /// |
736 | | /// Note: NULLs with a non-zero slot length in `array` will have the corresponding |
737 | | /// data copied across. This allows handling the null mask separately from the data |
738 | 176 | fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T> |
739 | 176 | where |
740 | 176 | T: ByteArrayType, |
741 | | { |
742 | 176 | let mut filter = FilterBytes::new(predicate.count, array); |
743 | | |
744 | 176 | match &predicate.strategy { |
745 | | IterationStrategy::SlicesIterator => { |
746 | 13 | filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count); |
747 | 13 | filter.extend_slices(SlicesIterator::new(&predicate.filter)) |
748 | | } |
749 | 20 | IterationStrategy::Slices(slices) => { |
750 | 20 | filter.extend_offsets_slices(slices.iter().cloned(), predicate.count); |
751 | 20 | filter.extend_slices(slices.iter().cloned()) |
752 | | } |
753 | | IterationStrategy::IndexIterator => { |
754 | 81 | filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count)); |
755 | 81 | filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count)) |
756 | | } |
757 | 62 | IterationStrategy::Indices(indices) => { |
758 | 62 | filter.extend_offsets_idx(indices.iter().cloned()); |
759 | 62 | filter.extend_idx(indices.iter().cloned()) |
760 | | } |
761 | 0 | IterationStrategy::All | IterationStrategy::None => unreachable!(), |
762 | | } |
763 | | |
764 | 176 | let mut builder = ArrayDataBuilder::new(T::DATA_TYPE) |
765 | 176 | .len(predicate.count) |
766 | 176 | .add_buffer(filter.dst_offsets.into()) |
767 | 176 | .add_buffer(filter.dst_values.into()); |
768 | | |
769 | 176 | if let Some((null_count147 , nulls147 )) = filter_null_mask(array.nulls(), predicate) { |
770 | 147 | builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); |
771 | 147 | }29 |
772 | | |
773 | 176 | let data = unsafe { builder.build_unchecked() }; |
774 | 176 | GenericByteArray::from(data) |
775 | 176 | } |
776 | | |
777 | | /// `filter` implementation for byte view arrays. |
778 | 84 | fn filter_byte_view<T: ByteViewType>( |
779 | 84 | array: &GenericByteViewArray<T>, |
780 | 84 | predicate: &FilterPredicate, |
781 | 84 | ) -> GenericByteViewArray<T> { |
782 | 84 | let new_view_buffer = filter_native(array.views(), predicate); |
783 | | |
784 | 84 | let mut builder = ArrayDataBuilder::new(T::DATA_TYPE) |
785 | 84 | .len(predicate.count) |
786 | 84 | .add_buffer(new_view_buffer) |
787 | 84 | .add_buffers(array.data_buffers().to_vec()); |
788 | | |
789 | 84 | if let Some((null_count76 , nulls76 )) = filter_null_mask(array.nulls(), predicate) { |
790 | 76 | builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); |
791 | 76 | }8 |
792 | | |
793 | 84 | GenericByteViewArray::from(unsafe { builder.build_unchecked() }) |
794 | 84 | } |
795 | | |
796 | 4 | fn filter_fixed_size_binary( |
797 | 4 | array: &FixedSizeBinaryArray, |
798 | 4 | predicate: &FilterPredicate, |
799 | 4 | ) -> FixedSizeBinaryArray { |
800 | 4 | let values: &[u8] = array.values(); |
801 | 4 | let value_length = array.value_length() as usize; |
802 | 12 | let calculate_offset_from_index4 = |index: usize| index * value_length; |
803 | 4 | let buffer = match &predicate.strategy { |
804 | | IterationStrategy::SlicesIterator => { |
805 | 0 | let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length); |
806 | 0 | for (start, end) in SlicesIterator::new(&predicate.filter) { |
807 | 0 | buffer.extend_from_slice( |
808 | 0 | &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)], |
809 | 0 | ); |
810 | 0 | } |
811 | 0 | buffer |
812 | | } |
813 | 0 | IterationStrategy::Slices(slices) => { |
814 | 0 | let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length); |
815 | 0 | for (start, end) in slices { |
816 | 0 | buffer.extend_from_slice( |
817 | 0 | &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)], |
818 | 0 | ); |
819 | 0 | } |
820 | 0 | buffer |
821 | | } |
822 | | IterationStrategy::IndexIterator => { |
823 | 3 | let iter2 = IndexIterator::new2 (&predicate.filter2 , predicate.count2 ).map2 (|x| { |
824 | 3 | &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)] |
825 | 3 | }); |
826 | | |
827 | 2 | let mut buffer = MutableBuffer::new(predicate.count * value_length); |
828 | 3 | iter2 .for_each2 (|item| buffer.extend_from_slice(item)); |
829 | 2 | buffer |
830 | | } |
831 | 2 | IterationStrategy::Indices(indices) => { |
832 | 3 | let iter2 = indices.iter()2 .map2 (|x| { |
833 | 3 | &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)] |
834 | 3 | }); |
835 | | |
836 | 2 | let mut buffer = MutableBuffer::new(predicate.count * value_length); |
837 | 3 | iter2 .for_each2 (|item| buffer.extend_from_slice(item)); |
838 | 2 | buffer |
839 | | } |
840 | 0 | IterationStrategy::All | IterationStrategy::None => unreachable!(), |
841 | | }; |
842 | 4 | let mut builder = ArrayDataBuilder::new(array.data_type().clone()) |
843 | 4 | .len(predicate.count) |
844 | 4 | .add_buffer(buffer.into()); |
845 | | |
846 | 4 | if let Some((null_count0 , nulls0 )) = filter_null_mask(array.nulls(), predicate) { |
847 | 0 | builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); |
848 | 4 | } |
849 | | |
850 | 4 | let data = unsafe { builder.build_unchecked() }; |
851 | 4 | FixedSizeBinaryArray::from(data) |
852 | 4 | } |
853 | | |
854 | | /// `filter` implementation for dictionaries |
855 | 87 | fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T> |
856 | 87 | where |
857 | 87 | T: ArrowDictionaryKeyType, |
858 | 87 | T::Native: num_traits::Num, |
859 | | { |
860 | 87 | let builder = filter_primitive::<T>(array.keys(), predicate) |
861 | 87 | .into_data() |
862 | 87 | .into_builder() |
863 | 87 | .data_type(array.data_type().clone()) |
864 | 87 | .child_data(vec![array.values().to_data()]); |
865 | | |
866 | | // SAFETY: |
867 | | // Keys were valid before, filtered subset is therefore still valid |
868 | 87 | DictionaryArray::from(unsafe { builder.build_unchecked() }) |
869 | 87 | } |
870 | | |
871 | | /// `filter` implementation for structs |
872 | 6 | fn filter_struct( |
873 | 6 | array: &StructArray, |
874 | 6 | predicate: &FilterPredicate, |
875 | 6 | ) -> Result<StructArray, ArrowError> { |
876 | 6 | let columns = array |
877 | 6 | .columns() |
878 | 6 | .iter() |
879 | 8 | .map6 (|column| filter_array(column, predicate)) |
880 | 6 | .collect::<Result<_, _>>()?0 ; |
881 | | |
882 | 6 | let nulls = if let Some((null_count2 , nulls2 )) = filter_null_mask(array.nulls(), predicate) { |
883 | 2 | let buffer = BooleanBuffer::new(nulls, 0, predicate.count); |
884 | | |
885 | 2 | Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) }) |
886 | | } else { |
887 | 4 | None |
888 | | }; |
889 | | |
890 | 6 | Ok(unsafe { |
891 | 6 | StructArray::new_unchecked_with_length( |
892 | 6 | array.fields().clone(), |
893 | 6 | columns, |
894 | 6 | nulls, |
895 | 6 | predicate.count(), |
896 | 6 | ) |
897 | 6 | }) |
898 | 6 | } |
899 | | |
900 | | /// `filter` implementation for sparse unions |
901 | 4 | fn filter_sparse_union( |
902 | 4 | array: &UnionArray, |
903 | 4 | predicate: &FilterPredicate, |
904 | 4 | ) -> Result<UnionArray, ArrowError> { |
905 | 4 | let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else { |
906 | 0 | unreachable!() |
907 | | }; |
908 | | |
909 | 4 | let type_ids = filter_primitive( |
910 | 4 | &Int8Array::try_new(array.type_ids().clone(), None)?0 , |
911 | 4 | predicate, |
912 | | ); |
913 | | |
914 | 4 | let children = fields |
915 | 4 | .iter() |
916 | 8 | .map4 (|(child_type_id, _)| filter_array(array.child(child_type_id), predicate)) |
917 | 4 | .collect::<Result<_, _>>()?0 ; |
918 | | |
919 | 4 | Ok(unsafe { |
920 | 4 | UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children) |
921 | 4 | }) |
922 | 4 | } |
923 | | |
924 | | /// `filter` implementation for list views |
925 | 4 | fn filter_list_view<OffsetType: OffsetSizeTrait>( |
926 | 4 | array: &GenericListViewArray<OffsetType>, |
927 | 4 | predicate: &FilterPredicate, |
928 | 4 | ) -> GenericListViewArray<OffsetType> { |
929 | 4 | let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate); |
930 | 4 | let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate); |
931 | | |
932 | | // Filter the nulls |
933 | 4 | let nulls = if let Some((null_count0 , nulls0 )) = filter_null_mask(array.nulls(), predicate) { |
934 | 0 | let buffer = BooleanBuffer::new(nulls, 0, predicate.count); |
935 | | |
936 | 0 | Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) }) |
937 | | } else { |
938 | 4 | None |
939 | | }; |
940 | | |
941 | 4 | let list_data = ArrayDataBuilder::new(array.data_type().clone()) |
942 | 4 | .nulls(nulls) |
943 | 4 | .buffers(vec![filtered_offsets, filtered_sizes]) |
944 | 4 | .child_data(vec![array.values().to_data()]) |
945 | 4 | .len(predicate.count); |
946 | | |
947 | 4 | let list_data = unsafe { list_data.build_unchecked() }; |
948 | | |
949 | 4 | GenericListViewArray::from(list_data) |
950 | 4 | } |
951 | | |
952 | | #[cfg(test)] |
953 | | mod tests { |
954 | | use super::*; |
955 | | use arrow_array::builder::*; |
956 | | use arrow_array::cast::as_run_array; |
957 | | use arrow_array::types::*; |
958 | | use arrow_data::ArrayData; |
959 | | use rand::distr::uniform::{UniformSampler, UniformUsize}; |
960 | | use rand::distr::{Alphanumeric, StandardUniform}; |
961 | | use rand::prelude::*; |
962 | | use rand::rng; |
963 | | |
964 | | macro_rules! def_temporal_test { |
965 | | ($test:ident, $array_type: ident, $data: expr) => { |
966 | | #[test] |
967 | 14 | fn $test() { |
968 | 14 | let a = $data; |
969 | 14 | let b = BooleanArray::from(vec![true, false, true, false]); |
970 | 14 | let c = filter(&a, &b).unwrap(); |
971 | 14 | let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap(); |
972 | 14 | assert_eq!(2, d.len()); |
973 | 14 | assert_eq!(1, d.value(0)); |
974 | 14 | assert_eq!(3, d.value(1)); |
975 | 14 | } |
976 | | }; |
977 | | } |
978 | | |
979 | | def_temporal_test!( |
980 | | test_filter_date32, |
981 | | Date32Array, |
982 | | Date32Array::from(vec![1, 2, 3, 4]) |
983 | | ); |
984 | | def_temporal_test!( |
985 | | test_filter_date64, |
986 | | Date64Array, |
987 | | Date64Array::from(vec![1, 2, 3, 4]) |
988 | | ); |
989 | | def_temporal_test!( |
990 | | test_filter_time32_second, |
991 | | Time32SecondArray, |
992 | | Time32SecondArray::from(vec![1, 2, 3, 4]) |
993 | | ); |
994 | | def_temporal_test!( |
995 | | test_filter_time32_millisecond, |
996 | | Time32MillisecondArray, |
997 | | Time32MillisecondArray::from(vec![1, 2, 3, 4]) |
998 | | ); |
999 | | def_temporal_test!( |
1000 | | test_filter_time64_microsecond, |
1001 | | Time64MicrosecondArray, |
1002 | | Time64MicrosecondArray::from(vec![1, 2, 3, 4]) |
1003 | | ); |
1004 | | def_temporal_test!( |
1005 | | test_filter_time64_nanosecond, |
1006 | | Time64NanosecondArray, |
1007 | | Time64NanosecondArray::from(vec![1, 2, 3, 4]) |
1008 | | ); |
1009 | | def_temporal_test!( |
1010 | | test_filter_duration_second, |
1011 | | DurationSecondArray, |
1012 | | DurationSecondArray::from(vec![1, 2, 3, 4]) |
1013 | | ); |
1014 | | def_temporal_test!( |
1015 | | test_filter_duration_millisecond, |
1016 | | DurationMillisecondArray, |
1017 | | DurationMillisecondArray::from(vec![1, 2, 3, 4]) |
1018 | | ); |
1019 | | def_temporal_test!( |
1020 | | test_filter_duration_microsecond, |
1021 | | DurationMicrosecondArray, |
1022 | | DurationMicrosecondArray::from(vec![1, 2, 3, 4]) |
1023 | | ); |
1024 | | def_temporal_test!( |
1025 | | test_filter_duration_nanosecond, |
1026 | | DurationNanosecondArray, |
1027 | | DurationNanosecondArray::from(vec![1, 2, 3, 4]) |
1028 | | ); |
1029 | | def_temporal_test!( |
1030 | | test_filter_timestamp_second, |
1031 | | TimestampSecondArray, |
1032 | | TimestampSecondArray::from(vec![1, 2, 3, 4]) |
1033 | | ); |
1034 | | def_temporal_test!( |
1035 | | test_filter_timestamp_millisecond, |
1036 | | TimestampMillisecondArray, |
1037 | | TimestampMillisecondArray::from(vec![1, 2, 3, 4]) |
1038 | | ); |
1039 | | def_temporal_test!( |
1040 | | test_filter_timestamp_microsecond, |
1041 | | TimestampMicrosecondArray, |
1042 | | TimestampMicrosecondArray::from(vec![1, 2, 3, 4]) |
1043 | | ); |
1044 | | def_temporal_test!( |
1045 | | test_filter_timestamp_nanosecond, |
1046 | | TimestampNanosecondArray, |
1047 | | TimestampNanosecondArray::from(vec![1, 2, 3, 4]) |
1048 | | ); |
1049 | | |
1050 | | #[test] |
1051 | 1 | fn test_filter_array_slice() { |
1052 | 1 | let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4); |
1053 | 1 | let b = BooleanArray::from(vec![true, false, false, true]); |
1054 | | // filtering with sliced filter array is not currently supported |
1055 | | // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4); |
1056 | | // let b = b_slice.as_any().downcast_ref().unwrap(); |
1057 | 1 | let c = filter(&a, &b).unwrap(); |
1058 | 1 | let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap(); |
1059 | 1 | assert_eq!(2, d.len()); |
1060 | 1 | assert_eq!(6, d.value(0)); |
1061 | 1 | assert_eq!(9, d.value(1)); |
1062 | 1 | } |
1063 | | |
1064 | | #[test] |
1065 | 1 | fn test_filter_array_low_density() { |
1066 | | // this test exercises the all 0's branch of the filter algorithm |
1067 | 1 | let mut data_values = (1..=65).collect::<Vec<i32>>(); |
1068 | 65 | let mut filter_values1 = (1..=651 ).map1 (|i| matches!(i % 65, 0)).collect1 ::<Vec<bool>>(); |
1069 | | // set up two more values after the batch |
1070 | 1 | data_values.extend_from_slice(&[66, 67]); |
1071 | 1 | filter_values.extend_from_slice(&[false, true]); |
1072 | 1 | let a = Int32Array::from(data_values); |
1073 | 1 | let b = BooleanArray::from(filter_values); |
1074 | 1 | let c = filter(&a, &b).unwrap(); |
1075 | 1 | let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap(); |
1076 | 1 | assert_eq!(2, d.len()); |
1077 | 1 | assert_eq!(65, d.value(0)); |
1078 | 1 | assert_eq!(67, d.value(1)); |
1079 | 1 | } |
1080 | | |
1081 | | #[test] |
1082 | 1 | fn test_filter_array_high_density() { |
1083 | | // this test exercises the all 1's branch of the filter algorithm |
1084 | 1 | let mut data_values = (1..=65).map(Some).collect::<Vec<_>>(); |
1085 | 1 | let mut filter_values = (1..=65) |
1086 | 65 | .map1 (|i| !matches!64 (i % 65, 0)) |
1087 | 1 | .collect::<Vec<bool>>(); |
1088 | | // set second data value to null |
1089 | 1 | data_values[1] = None; |
1090 | | // set up two more values after the batch |
1091 | 1 | data_values.extend_from_slice(&[Some(66), None, Some(67), None]); |
1092 | 1 | filter_values.extend_from_slice(&[false, true, true, true]); |
1093 | 1 | let a = Int32Array::from(data_values); |
1094 | 1 | let b = BooleanArray::from(filter_values); |
1095 | 1 | let c = filter(&a, &b).unwrap(); |
1096 | 1 | let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap(); |
1097 | 1 | assert_eq!(67, d.len()); |
1098 | 1 | assert_eq!(3, d.null_count()); |
1099 | 1 | assert_eq!(1, d.value(0)); |
1100 | 1 | assert!(d.is_null(1)); |
1101 | 1 | assert_eq!(64, d.value(63)); |
1102 | 1 | assert!(d.is_null(64)); |
1103 | 1 | assert_eq!(67, d.value(65)); |
1104 | 1 | } |
1105 | | |
1106 | | #[test] |
1107 | 1 | fn test_filter_string_array_simple() { |
1108 | 1 | let a = StringArray::from(vec!["hello", " ", "world", "!"]); |
1109 | 1 | let b = BooleanArray::from(vec![true, false, true, false]); |
1110 | 1 | let c = filter(&a, &b).unwrap(); |
1111 | 1 | let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap(); |
1112 | 1 | assert_eq!(2, d.len()); |
1113 | 1 | assert_eq!("hello", d.value(0)); |
1114 | 1 | assert_eq!("world", d.value(1)); |
1115 | 1 | } |
1116 | | |
1117 | | #[test] |
1118 | 1 | fn test_filter_primitive_array_with_null() { |
1119 | 1 | let a = Int32Array::from(vec![Some(5), None]); |
1120 | 1 | let b = BooleanArray::from(vec![false, true]); |
1121 | 1 | let c = filter(&a, &b).unwrap(); |
1122 | 1 | let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap(); |
1123 | 1 | assert_eq!(1, d.len()); |
1124 | 1 | assert!(d.is_null(0)); |
1125 | 1 | } |
1126 | | |
1127 | | #[test] |
1128 | 1 | fn test_filter_string_array_with_null() { |
1129 | 1 | let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]); |
1130 | 1 | let b = BooleanArray::from(vec![true, false, false, true]); |
1131 | 1 | let c = filter(&a, &b).unwrap(); |
1132 | 1 | let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap(); |
1133 | 1 | assert_eq!(2, d.len()); |
1134 | 1 | assert_eq!("hello", d.value(0)); |
1135 | 1 | assert!(!d.is_null(0)); |
1136 | 1 | assert!(d.is_null(1)); |
1137 | 1 | } |
1138 | | |
1139 | | #[test] |
1140 | 1 | fn test_filter_binary_array_with_null() { |
1141 | 1 | let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None]; |
1142 | 1 | let a = BinaryArray::from(data); |
1143 | 1 | let b = BooleanArray::from(vec![true, false, false, true]); |
1144 | 1 | let c = filter(&a, &b).unwrap(); |
1145 | 1 | let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap(); |
1146 | 1 | assert_eq!(2, d.len()); |
1147 | 1 | assert_eq!(b"hello", d.value(0)); |
1148 | 1 | assert!(!d.is_null(0)); |
1149 | 1 | assert!(d.is_null(1)); |
1150 | 1 | } |
1151 | | |
1152 | 2 | fn _test_filter_byte_view<T>() |
1153 | 2 | where |
1154 | 2 | T: ByteViewType, |
1155 | 2 | str: AsRef<T::Native>, |
1156 | 2 | T::Native: PartialEq, |
1157 | | { |
1158 | 2 | let array = { |
1159 | | // ["hello", "world", null, "large payload over 12 bytes", "lulu"] |
1160 | 2 | let mut builder = GenericByteViewBuilder::<T>::new(); |
1161 | 2 | builder.append_value("hello"); |
1162 | 2 | builder.append_value("world"); |
1163 | 2 | builder.append_null(); |
1164 | 2 | builder.append_value("large payload over 12 bytes"); |
1165 | 2 | builder.append_value("lulu"); |
1166 | 2 | builder.finish() |
1167 | | }; |
1168 | | |
1169 | | { |
1170 | 2 | let predicate = BooleanArray::from(vec![true, false, true, true, false]); |
1171 | 2 | let actual = filter(&array, &predicate).unwrap(); |
1172 | | |
1173 | 2 | assert_eq!(actual.len(), 3); |
1174 | | |
1175 | 2 | let expected = { |
1176 | | // ["hello", null, "large payload over 12 bytes"] |
1177 | 2 | let mut builder = GenericByteViewBuilder::<T>::new(); |
1178 | 2 | builder.append_value("hello"); |
1179 | 2 | builder.append_null(); |
1180 | 2 | builder.append_value("large payload over 12 bytes"); |
1181 | 2 | builder.finish() |
1182 | | }; |
1183 | | |
1184 | 2 | assert_eq!(actual.as_ref(), &expected); |
1185 | | } |
1186 | | |
1187 | | { |
1188 | 2 | let predicate = BooleanArray::from(vec![true, false, false, false, true]); |
1189 | 2 | let actual = filter(&array, &predicate).unwrap(); |
1190 | | |
1191 | 2 | assert_eq!(actual.len(), 2); |
1192 | | |
1193 | 2 | let expected = { |
1194 | | // ["hello", "lulu"] |
1195 | 2 | let mut builder = GenericByteViewBuilder::<T>::new(); |
1196 | 2 | builder.append_value("hello"); |
1197 | 2 | builder.append_value("lulu"); |
1198 | 2 | builder.finish() |
1199 | | }; |
1200 | | |
1201 | 2 | assert_eq!(actual.as_ref(), &expected); |
1202 | | } |
1203 | 2 | } |
1204 | | |
1205 | | #[test] |
1206 | 1 | fn test_filter_string_view() { |
1207 | 1 | _test_filter_byte_view::<StringViewType>() |
1208 | 1 | } |
1209 | | |
1210 | | #[test] |
1211 | 1 | fn test_filter_binary_view() { |
1212 | 1 | _test_filter_byte_view::<BinaryViewType>() |
1213 | 1 | } |
1214 | | |
1215 | | #[test] |
1216 | 1 | fn test_filter_fixed_binary() { |
1217 | 1 | let v1 = [1_u8, 2]; |
1218 | 1 | let v2 = [3_u8, 4]; |
1219 | 1 | let v3 = [5_u8, 6]; |
1220 | 1 | let v = vec![&v1, &v2, &v3]; |
1221 | 1 | let a = FixedSizeBinaryArray::from(v); |
1222 | 1 | let b = BooleanArray::from(vec![true, false, true]); |
1223 | 1 | let c = filter(&a, &b).unwrap(); |
1224 | 1 | let d = c |
1225 | 1 | .as_ref() |
1226 | 1 | .as_any() |
1227 | 1 | .downcast_ref::<FixedSizeBinaryArray>() |
1228 | 1 | .unwrap(); |
1229 | 1 | assert_eq!(d.len(), 2); |
1230 | 1 | assert_eq!(d.value(0), &v1); |
1231 | 1 | assert_eq!(d.value(1), &v3); |
1232 | 1 | let c2 = FilterBuilder::new(&b) |
1233 | 1 | .optimize() |
1234 | 1 | .build() |
1235 | 1 | .filter(&a) |
1236 | 1 | .unwrap(); |
1237 | 1 | let d2 = c2 |
1238 | 1 | .as_ref() |
1239 | 1 | .as_any() |
1240 | 1 | .downcast_ref::<FixedSizeBinaryArray>() |
1241 | 1 | .unwrap(); |
1242 | 1 | assert_eq!(d, d2); |
1243 | | |
1244 | 1 | let b = BooleanArray::from(vec![false, false, false]); |
1245 | 1 | let c = filter(&a, &b).unwrap(); |
1246 | 1 | let d = c |
1247 | 1 | .as_ref() |
1248 | 1 | .as_any() |
1249 | 1 | .downcast_ref::<FixedSizeBinaryArray>() |
1250 | 1 | .unwrap(); |
1251 | 1 | assert_eq!(d.len(), 0); |
1252 | | |
1253 | 1 | let b = BooleanArray::from(vec![true, true, true]); |
1254 | 1 | let c = filter(&a, &b).unwrap(); |
1255 | 1 | let d = c |
1256 | 1 | .as_ref() |
1257 | 1 | .as_any() |
1258 | 1 | .downcast_ref::<FixedSizeBinaryArray>() |
1259 | 1 | .unwrap(); |
1260 | 1 | assert_eq!(d.len(), 3); |
1261 | 1 | assert_eq!(d.value(0), &v1); |
1262 | 1 | assert_eq!(d.value(1), &v2); |
1263 | 1 | assert_eq!(d.value(2), &v3); |
1264 | | |
1265 | 1 | let b = BooleanArray::from(vec![false, false, true]); |
1266 | 1 | let c = filter(&a, &b).unwrap(); |
1267 | 1 | let d = c |
1268 | 1 | .as_ref() |
1269 | 1 | .as_any() |
1270 | 1 | .downcast_ref::<FixedSizeBinaryArray>() |
1271 | 1 | .unwrap(); |
1272 | 1 | assert_eq!(d.len(), 1); |
1273 | 1 | assert_eq!(d.value(0), &v3); |
1274 | 1 | let c2 = FilterBuilder::new(&b) |
1275 | 1 | .optimize() |
1276 | 1 | .build() |
1277 | 1 | .filter(&a) |
1278 | 1 | .unwrap(); |
1279 | 1 | let d2 = c2 |
1280 | 1 | .as_ref() |
1281 | 1 | .as_any() |
1282 | 1 | .downcast_ref::<FixedSizeBinaryArray>() |
1283 | 1 | .unwrap(); |
1284 | 1 | assert_eq!(d, d2); |
1285 | 1 | } |
1286 | | |
1287 | | #[test] |
1288 | 1 | fn test_filter_array_slice_with_null() { |
1289 | 1 | let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4); |
1290 | 1 | let b = BooleanArray::from(vec![true, false, false, true]); |
1291 | | // filtering with sliced filter array is not currently supported |
1292 | | // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4); |
1293 | | // let b = b_slice.as_any().downcast_ref().unwrap(); |
1294 | 1 | let c = filter(&a, &b).unwrap(); |
1295 | 1 | let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap(); |
1296 | 1 | assert_eq!(2, d.len()); |
1297 | 1 | assert!(d.is_null(0)); |
1298 | 1 | assert!(!d.is_null(1)); |
1299 | 1 | assert_eq!(9, d.value(1)); |
1300 | 1 | } |
1301 | | |
1302 | | #[test] |
1303 | 1 | fn test_filter_run_end_encoding_array() { |
1304 | 1 | let run_ends = Int64Array::from(vec![2, 3, 8]); |
1305 | 1 | let values = Int64Array::from(vec![7, -2, 9]); |
1306 | 1 | let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); |
1307 | 1 | let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]); |
1308 | 1 | let c = filter(&a, &b).unwrap(); |
1309 | 1 | let actual: &RunArray<Int64Type> = as_run_array(&c); |
1310 | 1 | assert_eq!(4, actual.len()); |
1311 | | |
1312 | 1 | let expected = RunArray::try_new( |
1313 | 1 | &Int64Array::from(vec![1, 2, 4]), |
1314 | 1 | &Int64Array::from(vec![7, -2, 9]), |
1315 | | ) |
1316 | 1 | .expect("Failed to make expected RunArray test is broken"); |
1317 | | |
1318 | 1 | assert_eq!(&actual.run_ends().values(), &expected.run_ends().values()); |
1319 | 1 | assert_eq!(actual.values(), expected.values()) |
1320 | 1 | } |
1321 | | |
1322 | | #[test] |
1323 | 1 | fn test_filter_run_end_encoding_array_remove_value() { |
1324 | 1 | let run_ends = Int32Array::from(vec![2, 3, 8, 10]); |
1325 | 1 | let values = Int32Array::from(vec![7, -2, 9, -8]); |
1326 | 1 | let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); |
1327 | 1 | let b = BooleanArray::from(vec![ |
1328 | | false, true, false, false, true, false, true, false, false, false, |
1329 | | ]); |
1330 | 1 | let c = filter(&a, &b).unwrap(); |
1331 | 1 | let actual: &RunArray<Int32Type> = as_run_array(&c); |
1332 | 1 | assert_eq!(3, actual.len()); |
1333 | | |
1334 | 1 | let expected = |
1335 | 1 | RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9])) |
1336 | 1 | .expect("Failed to make expected RunArray test is broken"); |
1337 | | |
1338 | 1 | assert_eq!(&actual.run_ends().values(), &expected.run_ends().values()); |
1339 | 1 | assert_eq!(actual.values(), expected.values()) |
1340 | 1 | } |
1341 | | |
1342 | | #[test] |
1343 | 1 | fn test_filter_run_end_encoding_array_remove_all_but_one() { |
1344 | 1 | let run_ends = Int16Array::from(vec![2, 3, 8, 10]); |
1345 | 1 | let values = Int16Array::from(vec![7, -2, 9, -8]); |
1346 | 1 | let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); |
1347 | 1 | let b = BooleanArray::from(vec![ |
1348 | | false, false, false, false, false, false, true, false, false, false, |
1349 | | ]); |
1350 | 1 | let c = filter(&a, &b).unwrap(); |
1351 | 1 | let actual: &RunArray<Int16Type> = as_run_array(&c); |
1352 | 1 | assert_eq!(1, actual.len()); |
1353 | | |
1354 | 1 | let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9])) |
1355 | 1 | .expect("Failed to make expected RunArray test is broken"); |
1356 | | |
1357 | 1 | assert_eq!(&actual.run_ends().values(), &expected.run_ends().values()); |
1358 | 1 | assert_eq!(actual.values(), expected.values()) |
1359 | 1 | } |
1360 | | |
1361 | | #[test] |
1362 | 1 | fn test_filter_run_end_encoding_array_empty() { |
1363 | 1 | let run_ends = Int64Array::from(vec![2, 3, 8, 10]); |
1364 | 1 | let values = Int64Array::from(vec![7, -2, 9, -8]); |
1365 | 1 | let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); |
1366 | 1 | let b = BooleanArray::from(vec![ |
1367 | | false, false, false, false, false, false, false, false, false, false, |
1368 | | ]); |
1369 | 1 | let c = filter(&a, &b).unwrap(); |
1370 | 1 | let actual: &RunArray<Int64Type> = as_run_array(&c); |
1371 | 1 | assert_eq!(0, actual.len()); |
1372 | 1 | } |
1373 | | |
1374 | | #[test] |
1375 | 1 | fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() { |
1376 | 1 | let run_ends = Int64Array::from(vec![2, 3, 8, 10]); |
1377 | 1 | let values = Int64Array::from(vec![7, -2, 9, -8]); |
1378 | 1 | let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); |
1379 | 1 | let b = BooleanArray::from(vec![false, true, true]); |
1380 | 1 | let c = filter(&a, &b).unwrap(); |
1381 | 1 | let actual: &RunArray<Int64Type> = as_run_array(&c); |
1382 | 1 | assert_eq!(2, actual.len()); |
1383 | | |
1384 | 1 | let expected = RunArray::try_new( |
1385 | 1 | &Int64Array::from(vec![1, 2]), |
1386 | 1 | &Int64Array::from(vec![7, -2]), |
1387 | | ) |
1388 | 1 | .expect("Failed to make expected RunArray test is broken"); |
1389 | | |
1390 | 1 | assert_eq!(&actual.run_ends().values(), &expected.run_ends().values()); |
1391 | 1 | assert_eq!(actual.values(), expected.values()) |
1392 | 1 | } |
1393 | | |
1394 | | #[test] |
1395 | 1 | fn test_filter_dictionary_array() { |
1396 | 1 | let values = [Some("hello"), None, Some("world"), Some("!")]; |
1397 | 1 | let a: Int8DictionaryArray = values.iter().copied().collect(); |
1398 | 1 | let b = BooleanArray::from(vec![false, true, true, false]); |
1399 | 1 | let c = filter(&a, &b).unwrap(); |
1400 | 1 | let d = c |
1401 | 1 | .as_ref() |
1402 | 1 | .as_any() |
1403 | 1 | .downcast_ref::<Int8DictionaryArray>() |
1404 | 1 | .unwrap(); |
1405 | 1 | let value_array = d.values(); |
1406 | 1 | let values = value_array.as_any().downcast_ref::<StringArray>().unwrap(); |
1407 | | // values are cloned in the filtered dictionary array |
1408 | 1 | assert_eq!(3, values.len()); |
1409 | | // but keys are filtered |
1410 | 1 | assert_eq!(2, d.len()); |
1411 | 1 | assert!(d.is_null(0)); |
1412 | 1 | assert_eq!("world", values.value(d.keys().value(1) as usize)); |
1413 | 1 | } |
1414 | | |
1415 | | #[test] |
1416 | 1 | fn test_filter_list_array() { |
1417 | 1 | let value_data = ArrayData::builder(DataType::Int32) |
1418 | 1 | .len(8) |
1419 | 1 | .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) |
1420 | 1 | .build() |
1421 | 1 | .unwrap(); |
1422 | | |
1423 | 1 | let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]); |
1424 | | |
1425 | 1 | let list_data_type = |
1426 | 1 | DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false))); |
1427 | 1 | let list_data = ArrayData::builder(list_data_type) |
1428 | 1 | .len(4) |
1429 | 1 | .add_buffer(value_offsets) |
1430 | 1 | .add_child_data(value_data) |
1431 | 1 | .null_bit_buffer(Some(Buffer::from([0b00000111]))) |
1432 | 1 | .build() |
1433 | 1 | .unwrap(); |
1434 | | |
1435 | | // a = [[0, 1, 2], [3, 4, 5], [6, 7], null] |
1436 | 1 | let a = LargeListArray::from(list_data); |
1437 | 1 | let b = BooleanArray::from(vec![false, true, false, true]); |
1438 | 1 | let result = filter(&a, &b).unwrap(); |
1439 | | |
1440 | | // expected: [[3, 4, 5], null] |
1441 | 1 | let value_data = ArrayData::builder(DataType::Int32) |
1442 | 1 | .len(3) |
1443 | 1 | .add_buffer(Buffer::from_slice_ref([3, 4, 5])) |
1444 | 1 | .build() |
1445 | 1 | .unwrap(); |
1446 | | |
1447 | 1 | let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]); |
1448 | | |
1449 | 1 | let list_data_type = |
1450 | 1 | DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false))); |
1451 | 1 | let expected = ArrayData::builder(list_data_type) |
1452 | 1 | .len(2) |
1453 | 1 | .add_buffer(value_offsets) |
1454 | 1 | .add_child_data(value_data) |
1455 | 1 | .null_bit_buffer(Some(Buffer::from([0b00000001]))) |
1456 | 1 | .build() |
1457 | 1 | .unwrap(); |
1458 | | |
1459 | 1 | assert_eq!(&make_array(expected), &result); |
1460 | 1 | } |
1461 | | |
1462 | 2 | fn test_case_filter_list_view<T: OffsetSizeTrait>() { |
1463 | | // [[1, 2], null, [], [3,4]] |
1464 | 2 | let mut list_array = GenericListViewBuilder::<T, _>::new(Int32Builder::new()); |
1465 | 2 | list_array.append_value([Some(1), Some(2)]); |
1466 | 2 | list_array.append_null(); |
1467 | 2 | list_array.append_value([]); |
1468 | 2 | list_array.append_value([Some(3), Some(4)]); |
1469 | | |
1470 | 2 | let list_array = list_array.finish(); |
1471 | 2 | let predicate = BooleanArray::from_iter([true, false, true, false]); |
1472 | | |
1473 | | // Filter result: [[1, 2], []] |
1474 | 2 | let filtered = filter(&list_array, &predicate) |
1475 | 2 | .unwrap() |
1476 | 2 | .as_list_view::<T>() |
1477 | 2 | .clone(); |
1478 | | |
1479 | 2 | let mut expected = |
1480 | 2 | GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(5), 3); |
1481 | 2 | expected.append_value([Some(1), Some(2)]); |
1482 | 2 | expected.append_value([]); |
1483 | 2 | let expected = expected.finish(); |
1484 | | |
1485 | 2 | assert_eq!(&filtered, &expected); |
1486 | 2 | } |
1487 | | |
1488 | 2 | fn test_case_filter_sliced_list_view<T: OffsetSizeTrait>() { |
1489 | | // [[1, 2], null, [], [3,4]] |
1490 | 2 | let mut list_array = |
1491 | 2 | GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(6), 4); |
1492 | 2 | list_array.append_value([Some(1), Some(2)]); |
1493 | 2 | list_array.append_null(); |
1494 | 2 | list_array.append_value([]); |
1495 | 2 | list_array.append_value([Some(3), Some(4)]); |
1496 | | |
1497 | 2 | let list_array = list_array.finish(); |
1498 | | |
1499 | | // Sliced: [null, [], [3, 4]] |
1500 | 2 | let sliced = list_array.slice(1, 3); |
1501 | 2 | let predicate = BooleanArray::from_iter([false, false, true]); |
1502 | | |
1503 | | // Filter result: [[1, 2], []] |
1504 | 2 | let filtered = filter(&sliced, &predicate) |
1505 | 2 | .unwrap() |
1506 | 2 | .as_list_view::<T>() |
1507 | 2 | .clone(); |
1508 | | |
1509 | 2 | let mut expected = GenericListViewBuilder::<T, _>::new(Int32Builder::new()); |
1510 | 2 | expected.append_value([Some(3), Some(4)]); |
1511 | 2 | let expected = expected.finish(); |
1512 | | |
1513 | 2 | assert_eq!(&filtered, &expected); |
1514 | 2 | } |
1515 | | |
1516 | | #[test] |
1517 | 1 | fn test_filter_list_view_array() { |
1518 | 1 | test_case_filter_list_view::<i32>(); |
1519 | 1 | test_case_filter_list_view::<i64>(); |
1520 | | |
1521 | 1 | test_case_filter_sliced_list_view::<i32>(); |
1522 | 1 | test_case_filter_sliced_list_view::<i64>(); |
1523 | 1 | } |
1524 | | |
1525 | | #[test] |
1526 | 1 | fn test_slice_iterator_bits() { |
1527 | 64 | let filter_values1 = (0..64)1 .map1 (|i| i == 1).collect1 ::<Vec<bool>>(); |
1528 | 1 | let filter = BooleanArray::from(filter_values); |
1529 | 1 | let filter_count = filter_count(&filter); |
1530 | | |
1531 | 1 | let iter = SlicesIterator::new(&filter); |
1532 | 1 | let chunks = iter.collect::<Vec<_>>(); |
1533 | | |
1534 | 1 | assert_eq!(chunks, vec![(1, 2)]); |
1535 | 1 | assert_eq!(filter_count, 1); |
1536 | 1 | } |
1537 | | |
1538 | | #[test] |
1539 | 1 | fn test_slice_iterator_bits1() { |
1540 | 64 | let filter_values1 = (0..64)1 .map1 (|i| i != 1).collect1 ::<Vec<bool>>(); |
1541 | 1 | let filter = BooleanArray::from(filter_values); |
1542 | 1 | let filter_count = filter_count(&filter); |
1543 | | |
1544 | 1 | let iter = SlicesIterator::new(&filter); |
1545 | 1 | let chunks = iter.collect::<Vec<_>>(); |
1546 | | |
1547 | 1 | assert_eq!(chunks, vec![(0, 1), (2, 64)]); |
1548 | 1 | assert_eq!(filter_count, 64 - 1); |
1549 | 1 | } |
1550 | | |
1551 | | #[test] |
1552 | 1 | fn test_slice_iterator_chunk_and_bits() { |
1553 | 130 | let filter_values1 = (0..130)1 .map1 (|i| i % 62 != 0).collect1 ::<Vec<bool>>(); |
1554 | 1 | let filter = BooleanArray::from(filter_values); |
1555 | 1 | let filter_count = filter_count(&filter); |
1556 | | |
1557 | 1 | let iter = SlicesIterator::new(&filter); |
1558 | 1 | let chunks = iter.collect::<Vec<_>>(); |
1559 | | |
1560 | 1 | assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]); |
1561 | 1 | assert_eq!(filter_count, 61 + 61 + 5); |
1562 | 1 | } |
1563 | | |
1564 | | #[test] |
1565 | 1 | fn test_null_mask() { |
1566 | 1 | let a = Int64Array::from(vec![Some(1), Some(2), None]); |
1567 | | |
1568 | 1 | let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]); |
1569 | 1 | let out = filter(&a, &mask1).unwrap(); |
1570 | 1 | assert_eq!(out.as_ref(), &a.slice(0, 2)); |
1571 | 1 | } |
1572 | | |
1573 | | #[test] |
1574 | 1 | fn test_filter_record_batch_no_columns() { |
1575 | 1 | let pred = BooleanArray::from(vec![Some(true), Some(true), None]); |
1576 | 1 | let options = RecordBatchOptions::default().with_row_count(Some(100)); |
1577 | 1 | let record_batch = |
1578 | 1 | RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap(); |
1579 | 1 | let out = filter_record_batch(&record_batch, &pred).unwrap(); |
1580 | | |
1581 | 1 | assert_eq!(out.num_rows(), 2); |
1582 | 1 | } |
1583 | | |
1584 | | #[test] |
1585 | 1 | fn test_fast_path() { |
1586 | 1 | let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]); |
1587 | | |
1588 | | // all true |
1589 | 1 | let mask = BooleanArray::from(vec![true, true, true]); |
1590 | 1 | let out = filter(&a, &mask).unwrap(); |
1591 | 1 | let b = out |
1592 | 1 | .as_any() |
1593 | 1 | .downcast_ref::<PrimitiveArray<Int64Type>>() |
1594 | 1 | .unwrap(); |
1595 | 1 | assert_eq!(&a, b); |
1596 | | |
1597 | | // all false |
1598 | 1 | let mask = BooleanArray::from(vec![false, false, false]); |
1599 | 1 | let out = filter(&a, &mask).unwrap(); |
1600 | 1 | assert_eq!(out.len(), 0); |
1601 | 1 | assert_eq!(out.data_type(), &DataType::Int64); |
1602 | 1 | } |
1603 | | |
1604 | | #[test] |
1605 | 1 | fn test_slices() { |
1606 | | // takes up 2 u64s |
1607 | 1 | let bools = std::iter::repeat_n(true, 10) |
1608 | 1 | .chain(std::iter::repeat_n(false, 30)) |
1609 | 1 | .chain(std::iter::repeat_n(true, 20)) |
1610 | 1 | .chain(std::iter::repeat_n(false, 17)) |
1611 | 1 | .chain(std::iter::repeat_n(true, 4)); |
1612 | | |
1613 | 1 | let bool_array: BooleanArray = bools.map(Some).collect(); |
1614 | | |
1615 | 1 | let slices: Vec<_> = SlicesIterator::new(&bool_array).collect(); |
1616 | 1 | let expected = vec![(0, 10), (40, 60), (77, 81)]; |
1617 | 1 | assert_eq!(slices, expected); |
1618 | | |
1619 | | // slice with offset and truncated len |
1620 | 1 | let len = bool_array.len(); |
1621 | 1 | let sliced_array = bool_array.slice(7, len - 10); |
1622 | 1 | let sliced_array = sliced_array |
1623 | 1 | .as_any() |
1624 | 1 | .downcast_ref::<BooleanArray>() |
1625 | 1 | .unwrap(); |
1626 | 1 | let slices: Vec<_> = SlicesIterator::new(sliced_array).collect(); |
1627 | 1 | let expected = vec![(0, 3), (33, 53), (70, 71)]; |
1628 | 1 | assert_eq!(slices, expected); |
1629 | 1 | } |
1630 | | |
1631 | 105 | fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) { |
1632 | 105 | let mut rng = rng(); |
1633 | | |
1634 | 54.6k | let bools105 : Vec<bool>105 = std::iter::from_fn105 (|| Some(rng.random())) |
1635 | 105 | .take(mask_len) |
1636 | 105 | .collect(); |
1637 | | |
1638 | 105 | let buffer = Buffer::from_iter(bools.iter().cloned()); |
1639 | | |
1640 | 105 | let truncated_length = mask_len - offset - truncate; |
1641 | | |
1642 | 105 | let data = ArrayDataBuilder::new(DataType::Boolean) |
1643 | 105 | .len(truncated_length) |
1644 | 105 | .offset(offset) |
1645 | 105 | .add_buffer(buffer) |
1646 | 105 | .build() |
1647 | 105 | .unwrap(); |
1648 | | |
1649 | 105 | let filter = BooleanArray::from(data); |
1650 | | |
1651 | 105 | let slice_bits: Vec<_> = SlicesIterator::new(&filter) |
1652 | 11.3k | .flat_map105 (|(start, end)| start..end) |
1653 | 105 | .collect(); |
1654 | | |
1655 | 105 | let count = filter_count(&filter); |
1656 | 105 | let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect(); |
1657 | | |
1658 | 105 | let expected_bits: Vec<_> = bools |
1659 | 105 | .iter() |
1660 | 105 | .skip(offset) |
1661 | 105 | .take(truncated_length) |
1662 | 105 | .enumerate() |
1663 | 45.4k | .flat_map105 (|(idx, v)| v.then(|| idx)) |
1664 | 105 | .collect(); |
1665 | | |
1666 | 105 | assert_eq!(slice_bits, expected_bits); |
1667 | 105 | assert_eq!(index_bits, expected_bits); |
1668 | 105 | } |
1669 | | |
1670 | | #[test] |
1671 | | #[cfg_attr(miri, ignore)] |
1672 | 1 | fn fuzz_test_slices_iterator() { |
1673 | 1 | let mut rng = rng(); |
1674 | | |
1675 | 1 | let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap(); |
1676 | 101 | for _ in 0..100 { |
1677 | 100 | let mask_len = rng.random_range(0..1024); |
1678 | 100 | let max_offset = 64.min(mask_len); |
1679 | 100 | let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0); |
1680 | 100 | |
1681 | 100 | let max_truncate = 128.min(mask_len - offset); |
1682 | 100 | let truncate = uusize |
1683 | 100 | .sample(&mut rng) |
1684 | 100 | .checked_rem(max_truncate) |
1685 | 100 | .unwrap_or(0); |
1686 | 100 | |
1687 | 100 | test_slices_fuzz(mask_len, offset, truncate); |
1688 | 100 | } |
1689 | | |
1690 | 1 | test_slices_fuzz(64, 0, 0); |
1691 | 1 | test_slices_fuzz(64, 8, 0); |
1692 | 1 | test_slices_fuzz(64, 8, 8); |
1693 | 1 | test_slices_fuzz(32, 8, 8); |
1694 | 1 | test_slices_fuzz(32, 5, 9); |
1695 | 1 | } |
1696 | | |
1697 | | /// Filters `values` by `predicate` using standard rust iterators |
1698 | 200 | fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> { |
1699 | 200 | values |
1700 | 200 | .into_iter() |
1701 | 200 | .zip(predicate) |
1702 | 200 | .filter(|(_, x)| **x) |
1703 | 200 | .map(|(a, _)| a) |
1704 | 200 | .collect() |
1705 | 200 | } |
1706 | | |
1707 | | /// Generates an array of length `len` with `valid_percent` non-null values |
1708 | 100 | fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>> |
1709 | 100 | where |
1710 | 100 | StandardUniform: Distribution<T>, |
1711 | | { |
1712 | 100 | let mut rng = rng(); |
1713 | 100 | (0..len) |
1714 | 15.1k | .map100 (|_| rng.random_bool(valid_percent).then(|| rng8.23k .random8.23k ())) |
1715 | 100 | .collect() |
1716 | 100 | } |
1717 | | |
1718 | | /// Generates an array of length `len` with `valid_percent` non-null values |
1719 | 100 | fn gen_strings( |
1720 | 100 | len: usize, |
1721 | 100 | valid_percent: f64, |
1722 | 100 | str_len_range: std::ops::Range<usize>, |
1723 | 100 | ) -> Vec<Option<String>> { |
1724 | 100 | let mut rng = rng(); |
1725 | 100 | (0..len) |
1726 | 15.1k | .map100 (|_| { |
1727 | 15.1k | rng.random_bool(valid_percent).then(|| {8.15k |
1728 | 8.15k | let len = rng.random_range(str_len_range.clone()); |
1729 | 8.15k | (0..len) |
1730 | 77.8k | .map8.15k (|_| char::from(rng.sample(Alphanumeric))) |
1731 | 8.15k | .collect() |
1732 | 8.15k | }) |
1733 | 15.1k | }) |
1734 | 100 | .collect() |
1735 | 100 | } |
1736 | | |
1737 | | /// Returns an iterator that calls `Option::as_deref` on each item |
1738 | 300 | fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> { |
1739 | 44.5k | src300 .iter300 ().map300 (|x| x.as_deref()) |
1740 | 300 | } |
1741 | | |
1742 | | #[test] |
1743 | | #[cfg_attr(miri, ignore)] |
1744 | 1 | fn fuzz_filter() { |
1745 | 1 | let mut rng = rng(); |
1746 | | |
1747 | 101 | for i100 in 0..100 { |
1748 | 100 | let filter_percent = match i { |
1749 | 100 | 0..=4 => 1.5 , |
1750 | 95 | 5..=10 => 0.6 , |
1751 | 89 | _ => rng.random_range(0.0..1.0), |
1752 | | }; |
1753 | | |
1754 | 100 | let valid_percent = rng.random_range(0.0..1.0); |
1755 | | |
1756 | 100 | let array_len = rng.random_range(32..256); |
1757 | 100 | let array_offset = rng.random_range(0..10); |
1758 | | |
1759 | | // Construct a predicate |
1760 | 100 | let filter_offset = rng.random_range(0..10); |
1761 | 100 | let filter_truncate = rng.random_range(0..10); |
1762 | 14.7k | let bools100 : Vec<_>100 = std::iter::from_fn100 (|| Some(rng.random_bool(filter_percent))) |
1763 | 100 | .take(array_len + filter_offset - filter_truncate) |
1764 | 100 | .collect(); |
1765 | | |
1766 | 100 | let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some)); |
1767 | | |
1768 | | // Offset predicate |
1769 | 100 | let predicate = predicate.slice(filter_offset, array_len - filter_truncate); |
1770 | 100 | let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap(); |
1771 | 100 | let bools = &bools[filter_offset..]; |
1772 | | |
1773 | | // Test i32 |
1774 | 100 | let values = gen_primitive(array_len + array_offset, valid_percent); |
1775 | 100 | let src = Int32Array::from_iter(values.iter().cloned()); |
1776 | | |
1777 | 100 | let src = src.slice(array_offset, array_len); |
1778 | 100 | let src = src.as_any().downcast_ref::<Int32Array>().unwrap(); |
1779 | 100 | let values = &values[array_offset..]; |
1780 | | |
1781 | 100 | let filtered = filter(src, predicate).unwrap(); |
1782 | 100 | let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap(); |
1783 | 100 | let actual: Vec<_> = array.iter().collect(); |
1784 | | |
1785 | 100 | assert_eq!(actual, filter_rust(values.iter().cloned(), bools)); |
1786 | | |
1787 | | // Test string |
1788 | 100 | let strings = gen_strings(array_len + array_offset, valid_percent, 0..20); |
1789 | 100 | let src = StringArray::from_iter(as_deref(&strings)); |
1790 | | |
1791 | 100 | let src = src.slice(array_offset, array_len); |
1792 | 100 | let src = src.as_any().downcast_ref::<StringArray>().unwrap(); |
1793 | | |
1794 | 100 | let filtered = filter(src, predicate).unwrap(); |
1795 | 100 | let array = filtered.as_any().downcast_ref::<StringArray>().unwrap(); |
1796 | 100 | let actual: Vec<_> = array.iter().collect(); |
1797 | | |
1798 | 100 | let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools); |
1799 | 100 | assert_eq!(actual, expected_strings); |
1800 | | |
1801 | | // Test string dictionary |
1802 | 100 | let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings)); |
1803 | | |
1804 | 100 | let src = src.slice(array_offset, array_len); |
1805 | 100 | let src = src |
1806 | 100 | .as_any() |
1807 | 100 | .downcast_ref::<DictionaryArray<Int32Type>>() |
1808 | 100 | .unwrap(); |
1809 | | |
1810 | 100 | let filtered = filter(src, predicate).unwrap(); |
1811 | | |
1812 | 100 | let array = filtered |
1813 | 100 | .as_any() |
1814 | 100 | .downcast_ref::<DictionaryArray<Int32Type>>() |
1815 | 100 | .unwrap(); |
1816 | | |
1817 | 100 | let values = array |
1818 | 100 | .values() |
1819 | 100 | .as_any() |
1820 | 100 | .downcast_ref::<StringArray>() |
1821 | 100 | .unwrap(); |
1822 | | |
1823 | 100 | let actual: Vec<_> = array |
1824 | 100 | .keys() |
1825 | 100 | .iter() |
1826 | 6.64k | .map100 (|key| key.map(|key| values3.44k .value3.44k (key as usize3.44k ))) |
1827 | 100 | .collect(); |
1828 | | |
1829 | 100 | assert_eq!(actual, expected_strings); |
1830 | | } |
1831 | 1 | } |
1832 | | |
1833 | | #[test] |
1834 | 1 | fn test_filter_map() { |
1835 | 1 | let mut builder = |
1836 | 1 | MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4)); |
1837 | | // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1} |
1838 | 1 | builder.keys().append_value("key1"); |
1839 | 1 | builder.values().append_value(1); |
1840 | 1 | builder.append(true).unwrap(); |
1841 | 1 | builder.keys().append_value("key2"); |
1842 | 1 | builder.keys().append_value("key3"); |
1843 | 1 | builder.values().append_value(2); |
1844 | 1 | builder.values().append_value(3); |
1845 | 1 | builder.append(true).unwrap(); |
1846 | 1 | builder.append(false).unwrap(); |
1847 | 1 | builder.keys().append_value("key1"); |
1848 | 1 | builder.values().append_value(1); |
1849 | 1 | builder.append(true).unwrap(); |
1850 | 1 | let maparray = Arc::new(builder.finish()) as ArrayRef; |
1851 | | |
1852 | 1 | let indices = vec![Some(true), Some(false), Some(false), Some(true)] |
1853 | 1 | .into_iter() |
1854 | 1 | .collect::<BooleanArray>(); |
1855 | 1 | let got = filter(&maparray, &indices).unwrap(); |
1856 | | |
1857 | 1 | let mut builder = |
1858 | 1 | MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2)); |
1859 | 1 | builder.keys().append_value("key1"); |
1860 | 1 | builder.values().append_value(1); |
1861 | 1 | builder.append(true).unwrap(); |
1862 | 1 | builder.keys().append_value("key1"); |
1863 | 1 | builder.values().append_value(1); |
1864 | 1 | builder.append(true).unwrap(); |
1865 | 1 | let expected = Arc::new(builder.finish()) as ArrayRef; |
1866 | | |
1867 | 1 | assert_eq!(&expected, &got); |
1868 | 1 | } |
1869 | | |
1870 | | #[test] |
1871 | 1 | fn test_filter_fixed_size_list_arrays() { |
1872 | 1 | let value_data = ArrayData::builder(DataType::Int32) |
1873 | 1 | .len(9) |
1874 | 1 | .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8])) |
1875 | 1 | .build() |
1876 | 1 | .unwrap(); |
1877 | 1 | let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false); |
1878 | 1 | let list_data = ArrayData::builder(list_data_type) |
1879 | 1 | .len(3) |
1880 | 1 | .add_child_data(value_data) |
1881 | 1 | .build() |
1882 | 1 | .unwrap(); |
1883 | 1 | let array = FixedSizeListArray::from(list_data); |
1884 | | |
1885 | 1 | let filter_array = BooleanArray::from(vec![true, false, false]); |
1886 | | |
1887 | 1 | let c = filter(&array, &filter_array).unwrap(); |
1888 | 1 | let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap(); |
1889 | | |
1890 | 1 | assert_eq!(filtered.len(), 1); |
1891 | | |
1892 | 1 | let list = filtered.value(0); |
1893 | 1 | assert_eq!( |
1894 | | &[0, 1, 2], |
1895 | 1 | list.as_any().downcast_ref::<Int32Array>().unwrap().values() |
1896 | | ); |
1897 | | |
1898 | 1 | let filter_array = BooleanArray::from(vec![true, false, true]); |
1899 | | |
1900 | 1 | let c = filter(&array, &filter_array).unwrap(); |
1901 | 1 | let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap(); |
1902 | | |
1903 | 1 | assert_eq!(filtered.len(), 2); |
1904 | | |
1905 | 1 | let list = filtered.value(0); |
1906 | 1 | assert_eq!( |
1907 | | &[0, 1, 2], |
1908 | 1 | list.as_any().downcast_ref::<Int32Array>().unwrap().values() |
1909 | | ); |
1910 | 1 | let list = filtered.value(1); |
1911 | 1 | assert_eq!( |
1912 | | &[6, 7, 8], |
1913 | 1 | list.as_any().downcast_ref::<Int32Array>().unwrap().values() |
1914 | | ); |
1915 | 1 | } |
1916 | | |
1917 | | #[test] |
1918 | 1 | fn test_filter_fixed_size_list_arrays_with_null() { |
1919 | 1 | let value_data = ArrayData::builder(DataType::Int32) |
1920 | 1 | .len(10) |
1921 | 1 | .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) |
1922 | 1 | .build() |
1923 | 1 | .unwrap(); |
1924 | | |
1925 | | // Set null buts for the nested array: |
1926 | | // [[0, 1], null, null, [6, 7], [8, 9]] |
1927 | | // 01011001 00000001 |
1928 | 1 | let mut null_bits: [u8; 1] = [0; 1]; |
1929 | 1 | bit_util::set_bit(&mut null_bits, 0); |
1930 | 1 | bit_util::set_bit(&mut null_bits, 3); |
1931 | 1 | bit_util::set_bit(&mut null_bits, 4); |
1932 | | |
1933 | 1 | let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false); |
1934 | 1 | let list_data = ArrayData::builder(list_data_type) |
1935 | 1 | .len(5) |
1936 | 1 | .add_child_data(value_data) |
1937 | 1 | .null_bit_buffer(Some(Buffer::from(null_bits))) |
1938 | 1 | .build() |
1939 | 1 | .unwrap(); |
1940 | 1 | let array = FixedSizeListArray::from(list_data); |
1941 | | |
1942 | 1 | let filter_array = BooleanArray::from(vec![true, true, false, true, false]); |
1943 | | |
1944 | 1 | let c = filter(&array, &filter_array).unwrap(); |
1945 | 1 | let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap(); |
1946 | | |
1947 | 1 | assert_eq!(filtered.len(), 3); |
1948 | | |
1949 | 1 | let list = filtered.value(0); |
1950 | 1 | assert_eq!( |
1951 | | &[0, 1], |
1952 | 1 | list.as_any().downcast_ref::<Int32Array>().unwrap().values() |
1953 | | ); |
1954 | 1 | assert!(filtered.is_null(1)); |
1955 | 1 | let list = filtered.value(2); |
1956 | 1 | assert_eq!( |
1957 | | &[6, 7], |
1958 | 1 | list.as_any().downcast_ref::<Int32Array>().unwrap().values() |
1959 | | ); |
1960 | 1 | } |
1961 | | |
1962 | 2 | fn test_filter_union_array(array: UnionArray) { |
1963 | 2 | let filter_array = BooleanArray::from(vec![true, false, false]); |
1964 | 2 | let c = filter(&array, &filter_array).unwrap(); |
1965 | 2 | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
1966 | | |
1967 | 2 | let mut builder = UnionBuilder::new_dense(); |
1968 | 2 | builder.append::<Int32Type>("A", 1).unwrap(); |
1969 | 2 | let expected_array = builder.build().unwrap(); |
1970 | | |
1971 | 2 | compare_union_arrays(filtered, &expected_array); |
1972 | | |
1973 | 2 | let filter_array = BooleanArray::from(vec![true, false, true]); |
1974 | 2 | let c = filter(&array, &filter_array).unwrap(); |
1975 | 2 | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
1976 | | |
1977 | 2 | let mut builder = UnionBuilder::new_dense(); |
1978 | 2 | builder.append::<Int32Type>("A", 1).unwrap(); |
1979 | 2 | builder.append::<Int32Type>("A", 34).unwrap(); |
1980 | 2 | let expected_array = builder.build().unwrap(); |
1981 | | |
1982 | 2 | compare_union_arrays(filtered, &expected_array); |
1983 | | |
1984 | 2 | let filter_array = BooleanArray::from(vec![true, true, false]); |
1985 | 2 | let c = filter(&array, &filter_array).unwrap(); |
1986 | 2 | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
1987 | | |
1988 | 2 | let mut builder = UnionBuilder::new_dense(); |
1989 | 2 | builder.append::<Int32Type>("A", 1).unwrap(); |
1990 | 2 | builder.append::<Float64Type>("B", 3.2).unwrap(); |
1991 | 2 | let expected_array = builder.build().unwrap(); |
1992 | | |
1993 | 2 | compare_union_arrays(filtered, &expected_array); |
1994 | 2 | } |
1995 | | |
1996 | | #[test] |
1997 | 1 | fn test_filter_union_array_dense() { |
1998 | 1 | let mut builder = UnionBuilder::new_dense(); |
1999 | 1 | builder.append::<Int32Type>("A", 1).unwrap(); |
2000 | 1 | builder.append::<Float64Type>("B", 3.2).unwrap(); |
2001 | 1 | builder.append::<Int32Type>("A", 34).unwrap(); |
2002 | 1 | let array = builder.build().unwrap(); |
2003 | | |
2004 | 1 | test_filter_union_array(array); |
2005 | 1 | } |
2006 | | |
2007 | | #[test] |
2008 | 1 | fn test_filter_run_union_array_dense() { |
2009 | 1 | let mut builder = UnionBuilder::new_dense(); |
2010 | 1 | builder.append::<Int32Type>("A", 1).unwrap(); |
2011 | 1 | builder.append::<Int32Type>("A", 3).unwrap(); |
2012 | 1 | builder.append::<Int32Type>("A", 34).unwrap(); |
2013 | 1 | let array = builder.build().unwrap(); |
2014 | | |
2015 | 1 | let filter_array = BooleanArray::from(vec![true, true, false]); |
2016 | 1 | let c = filter(&array, &filter_array).unwrap(); |
2017 | 1 | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
2018 | | |
2019 | 1 | let mut builder = UnionBuilder::new_dense(); |
2020 | 1 | builder.append::<Int32Type>("A", 1).unwrap(); |
2021 | 1 | builder.append::<Int32Type>("A", 3).unwrap(); |
2022 | 1 | let expected = builder.build().unwrap(); |
2023 | | |
2024 | 1 | assert_eq!(filtered.to_data(), expected.to_data()); |
2025 | 1 | } |
2026 | | |
2027 | | #[test] |
2028 | 1 | fn test_filter_union_array_dense_with_nulls() { |
2029 | 1 | let mut builder = UnionBuilder::new_dense(); |
2030 | 1 | builder.append::<Int32Type>("A", 1).unwrap(); |
2031 | 1 | builder.append::<Float64Type>("B", 3.2).unwrap(); |
2032 | 1 | builder.append_null::<Float64Type>("B").unwrap(); |
2033 | 1 | builder.append::<Int32Type>("A", 34).unwrap(); |
2034 | 1 | let array = builder.build().unwrap(); |
2035 | | |
2036 | 1 | let filter_array = BooleanArray::from(vec![true, true, false, false]); |
2037 | 1 | let c = filter(&array, &filter_array).unwrap(); |
2038 | 1 | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
2039 | | |
2040 | 1 | let mut builder = UnionBuilder::new_dense(); |
2041 | 1 | builder.append::<Int32Type>("A", 1).unwrap(); |
2042 | 1 | builder.append::<Float64Type>("B", 3.2).unwrap(); |
2043 | 1 | let expected_array = builder.build().unwrap(); |
2044 | | |
2045 | 1 | compare_union_arrays(filtered, &expected_array); |
2046 | | |
2047 | 1 | let filter_array = BooleanArray::from(vec![true, false, true, false]); |
2048 | 1 | let c = filter(&array, &filter_array).unwrap(); |
2049 | 1 | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
2050 | | |
2051 | 1 | let mut builder = UnionBuilder::new_dense(); |
2052 | 1 | builder.append::<Int32Type>("A", 1).unwrap(); |
2053 | 1 | builder.append_null::<Float64Type>("B").unwrap(); |
2054 | 1 | let expected_array = builder.build().unwrap(); |
2055 | | |
2056 | 1 | compare_union_arrays(filtered, &expected_array); |
2057 | 1 | } |
2058 | | |
2059 | | #[test] |
2060 | 1 | fn test_filter_union_array_sparse() { |
2061 | 1 | let mut builder = UnionBuilder::new_sparse(); |
2062 | 1 | builder.append::<Int32Type>("A", 1).unwrap(); |
2063 | 1 | builder.append::<Float64Type>("B", 3.2).unwrap(); |
2064 | 1 | builder.append::<Int32Type>("A", 34).unwrap(); |
2065 | 1 | let array = builder.build().unwrap(); |
2066 | | |
2067 | 1 | test_filter_union_array(array); |
2068 | 1 | } |
2069 | | |
2070 | | #[test] |
2071 | 1 | fn test_filter_union_array_sparse_with_nulls() { |
2072 | 1 | let mut builder = UnionBuilder::new_sparse(); |
2073 | 1 | builder.append::<Int32Type>("A", 1).unwrap(); |
2074 | 1 | builder.append::<Float64Type>("B", 3.2).unwrap(); |
2075 | 1 | builder.append_null::<Float64Type>("B").unwrap(); |
2076 | 1 | builder.append::<Int32Type>("A", 34).unwrap(); |
2077 | 1 | let array = builder.build().unwrap(); |
2078 | | |
2079 | 1 | let filter_array = BooleanArray::from(vec![true, false, true, false]); |
2080 | 1 | let c = filter(&array, &filter_array).unwrap(); |
2081 | 1 | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
2082 | | |
2083 | 1 | let mut builder = UnionBuilder::new_sparse(); |
2084 | 1 | builder.append::<Int32Type>("A", 1).unwrap(); |
2085 | 1 | builder.append_null::<Float64Type>("B").unwrap(); |
2086 | 1 | let expected_array = builder.build().unwrap(); |
2087 | | |
2088 | 1 | compare_union_arrays(filtered, &expected_array); |
2089 | 1 | } |
2090 | | |
2091 | 9 | fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) { |
2092 | 9 | assert_eq!(union1.len(), union2.len()); |
2093 | | |
2094 | 16 | for i in 0..union19 .len9 () { |
2095 | 16 | let type_id = union1.type_id(i); |
2096 | | |
2097 | 16 | let slot1 = union1.value(i); |
2098 | 16 | let slot2 = union2.value(i); |
2099 | | |
2100 | 16 | assert_eq!(slot1.is_null(0), slot2.is_null(0)); |
2101 | | |
2102 | 16 | if !slot1.is_null(0) && !slot2.is_null(0)14 { |
2103 | 14 | match type_id { |
2104 | | 0 => { |
2105 | 11 | let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap(); |
2106 | 11 | assert_eq!(slot1.len(), 1); |
2107 | 11 | let value1 = slot1.value(0); |
2108 | | |
2109 | 11 | let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap(); |
2110 | 11 | assert_eq!(slot2.len(), 1); |
2111 | 11 | let value2 = slot2.value(0); |
2112 | 11 | assert_eq!(value1, value2); |
2113 | | } |
2114 | | 1 => { |
2115 | 3 | let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap(); |
2116 | 3 | assert_eq!(slot1.len(), 1); |
2117 | 3 | let value1 = slot1.value(0); |
2118 | | |
2119 | 3 | let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap(); |
2120 | 3 | assert_eq!(slot2.len(), 1); |
2121 | 3 | let value2 = slot2.value(0); |
2122 | 3 | assert_eq!(value1, value2); |
2123 | | } |
2124 | 0 | _ => unreachable!(), |
2125 | | } |
2126 | 2 | } |
2127 | | } |
2128 | 9 | } |
2129 | | |
2130 | | #[test] |
2131 | 1 | fn test_filter_struct() { |
2132 | 1 | let predicate = BooleanArray::from(vec![true, false, true, false]); |
2133 | | |
2134 | 1 | let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"])); |
2135 | 1 | let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"])); |
2136 | | |
2137 | 1 | let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8])); |
2138 | 1 | let b_filtered = Arc::new(Int32Array::from(vec![5, 7])); |
2139 | | |
2140 | 1 | let null_mask = NullBuffer::from(vec![true, false, false, true]); |
2141 | 1 | let null_mask_filtered = NullBuffer::from(vec![true, false]); |
2142 | | |
2143 | 1 | let a_field = Field::new("a", DataType::Utf8, false); |
2144 | 1 | let b_field = Field::new("b", DataType::Int32, false); |
2145 | | |
2146 | 1 | let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None); |
2147 | 1 | let expected = |
2148 | 1 | StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None); |
2149 | | |
2150 | 1 | let result = filter(&array, &predicate).unwrap(); |
2151 | | |
2152 | 1 | assert_eq!(result.to_data(), expected.to_data()); |
2153 | | |
2154 | 1 | let array = StructArray::new( |
2155 | 1 | vec![a_field.clone()].into(), |
2156 | 1 | vec![a.clone()], |
2157 | 1 | Some(null_mask.clone()), |
2158 | | ); |
2159 | 1 | let expected = StructArray::new( |
2160 | 1 | vec![a_field.clone()].into(), |
2161 | 1 | vec![a_filtered.clone()], |
2162 | 1 | Some(null_mask_filtered.clone()), |
2163 | | ); |
2164 | | |
2165 | 1 | let result = filter(&array, &predicate).unwrap(); |
2166 | | |
2167 | 1 | assert_eq!(result.to_data(), expected.to_data()); |
2168 | | |
2169 | 1 | let array = StructArray::new( |
2170 | 1 | vec![a_field.clone(), b_field.clone()].into(), |
2171 | 1 | vec![a.clone(), b.clone()], |
2172 | 1 | None, |
2173 | | ); |
2174 | 1 | let expected = StructArray::new( |
2175 | 1 | vec![a_field.clone(), b_field.clone()].into(), |
2176 | 1 | vec![a_filtered.clone(), b_filtered.clone()], |
2177 | 1 | None, |
2178 | | ); |
2179 | | |
2180 | 1 | let result = filter(&array, &predicate).unwrap(); |
2181 | | |
2182 | 1 | assert_eq!(result.to_data(), expected.to_data()); |
2183 | | |
2184 | 1 | let array = StructArray::new( |
2185 | 1 | vec![a_field.clone(), b_field.clone()].into(), |
2186 | 1 | vec![a.clone(), b.clone()], |
2187 | 1 | Some(null_mask.clone()), |
2188 | | ); |
2189 | | |
2190 | 1 | let expected = StructArray::new( |
2191 | 1 | vec![a_field.clone(), b_field.clone()].into(), |
2192 | 1 | vec![a_filtered.clone(), b_filtered.clone()], |
2193 | 1 | Some(null_mask_filtered.clone()), |
2194 | | ); |
2195 | | |
2196 | 1 | let result = filter(&array, &predicate).unwrap(); |
2197 | | |
2198 | 1 | assert_eq!(result.to_data(), expected.to_data()); |
2199 | 1 | } |
2200 | | |
2201 | | #[test] |
2202 | 1 | fn test_filter_empty_struct() { |
2203 | | /* |
2204 | | "a": { |
2205 | | "b": int64, |
2206 | | "c": {} |
2207 | | }, |
2208 | | */ |
2209 | 1 | let fields = arrow_schema::Field::new( |
2210 | | "a", |
2211 | 1 | arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![ |
2212 | 1 | arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true), |
2213 | 1 | arrow_schema::Field::new( |
2214 | 1 | "c", |
2215 | 1 | arrow_schema::DataType::Struct(arrow_schema::Fields::empty()), |
2216 | 1 | true, |
2217 | 1 | ), |
2218 | 1 | ])), |
2219 | | true, |
2220 | | ); |
2221 | | |
2222 | | /* Test record |
2223 | | {"a":{"c": {}}} |
2224 | | {"a":{"c": {}}} |
2225 | | {"a":{"c": {}}} |
2226 | | */ |
2227 | | |
2228 | | // Create the record batch with the nested struct array |
2229 | 1 | let schema = Arc::new(Schema::new(vec![fields])); |
2230 | | |
2231 | 1 | let b = Arc::new(Int64Array::from(vec![None, None, None])); |
2232 | 1 | let c = Arc::new(StructArray::new_empty_fields( |
2233 | | 3, |
2234 | 1 | Some(NullBuffer::from(vec![true, true, true])), |
2235 | | )); |
2236 | 1 | let a = StructArray::new( |
2237 | 1 | vec![ |
2238 | 1 | Field::new("b", DataType::Int64, true), |
2239 | 1 | Field::new("c", DataType::Struct(Fields::empty()), true), |
2240 | | ] |
2241 | 1 | .into(), |
2242 | 1 | vec![b.clone(), c.clone()], |
2243 | 1 | Some(NullBuffer::from(vec![true, true, true])), |
2244 | | ); |
2245 | 1 | let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap(); |
2246 | 1 | println!("{record_batch:?}"); |
2247 | | |
2248 | | // Apply the filter |
2249 | 1 | let predicate = BooleanArray::from(vec![true, false, true]); |
2250 | 1 | let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap(); |
2251 | | |
2252 | | // The filtered batch should have 2 rows (the 1st and 3rd) |
2253 | 1 | assert_eq!(filtered_batch.num_rows(), 2); |
2254 | 1 | } |
2255 | | } |