/Users/andrewlamb/Software/arrow-rs/arrow-select/src/filter.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines filter kernels |
19 | | |
20 | | use std::ops::AddAssign; |
21 | | use std::sync::Arc; |
22 | | |
23 | | use arrow_array::builder::BooleanBufferBuilder; |
24 | | use arrow_array::cast::AsArray; |
25 | | use arrow_array::types::{ |
26 | | ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType, |
27 | | }; |
28 | | use arrow_array::*; |
29 | | use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer}; |
30 | | use arrow_buffer::{Buffer, MutableBuffer}; |
31 | | use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator}; |
32 | | use arrow_data::transform::MutableArrayData; |
33 | | use arrow_data::ArrayDataBuilder; |
34 | | use arrow_schema::*; |
35 | | |
36 | | /// If the filter selects more than this fraction of rows, use |
37 | | /// [`SlicesIterator`] to copy ranges of values. Otherwise iterate |
38 | | /// over individual rows using [`IndexIterator`] |
39 | | /// |
40 | | /// Threshold of 0.8 chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009> |
41 | | /// |
42 | | const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; |
43 | | |
44 | | /// An iterator of `(usize, usize)` each representing an interval |
45 | | /// `[start, end)` whose slots of a bitmap [Buffer] are true. |
46 | | /// |
47 | | /// Each interval corresponds to a contiguous region of memory to be |
48 | | /// "taken" from an array to be filtered. |
49 | | /// |
50 | | /// ## Notes: |
51 | | /// |
52 | | /// 1. Ignores the validity bitmap (ignores nulls) |
53 | | /// |
54 | | /// 2. Only performant for filters that copy across long contiguous runs |
55 | | #[derive(Debug)] |
56 | | pub struct SlicesIterator<'a>(BitSliceIterator<'a>); |
57 | | |
58 | | impl<'a> SlicesIterator<'a> { |
59 | | /// Creates a new iterator from a [BooleanArray] |
60 | 0 | pub fn new(filter: &'a BooleanArray) -> Self { |
61 | 0 | Self(filter.values().set_slices()) |
62 | 0 | } |
63 | | } |
64 | | |
65 | | impl Iterator for SlicesIterator<'_> { |
66 | | type Item = (usize, usize); |
67 | | |
68 | 0 | fn next(&mut self) -> Option<Self::Item> { |
69 | 0 | self.0.next() |
70 | 0 | } |
71 | | } |
72 | | |
73 | | /// An iterator of `usize` whose index in [`BooleanArray`] is true |
74 | | /// |
75 | | /// This provides the best performance on most predicates, apart from those which keep |
76 | | /// large runs and therefore favour [`SlicesIterator`] |
77 | | struct IndexIterator<'a> { |
78 | | remaining: usize, |
79 | | iter: BitIndexIterator<'a>, |
80 | | } |
81 | | |
82 | | impl<'a> IndexIterator<'a> { |
83 | 0 | fn new(filter: &'a BooleanArray, remaining: usize) -> Self { |
84 | 0 | assert_eq!(filter.null_count(), 0); |
85 | 0 | let iter = filter.values().set_indices(); |
86 | 0 | Self { remaining, iter } |
87 | 0 | } |
88 | | } |
89 | | |
90 | | impl Iterator for IndexIterator<'_> { |
91 | | type Item = usize; |
92 | | |
93 | 0 | fn next(&mut self) -> Option<Self::Item> { |
94 | 0 | if self.remaining != 0 { |
95 | | // Fascinatingly swapping these two lines around results in a 50% |
96 | | // performance regression for some benchmarks |
97 | 0 | let next = self.iter.next().expect("IndexIterator exhausted early"); |
98 | 0 | self.remaining -= 1; |
99 | | // Must panic if exhausted early as trusted length iterator |
100 | 0 | return Some(next); |
101 | 0 | } |
102 | 0 | None |
103 | 0 | } |
104 | | |
105 | 0 | fn size_hint(&self) -> (usize, Option<usize>) { |
106 | 0 | (self.remaining, Some(self.remaining)) |
107 | 0 | } |
108 | | } |
109 | | |
110 | | /// Counts the number of set bits in `filter` |
111 | 0 | fn filter_count(filter: &BooleanArray) -> usize { |
112 | 0 | filter.values().count_set_bits() |
113 | 0 | } |
114 | | |
115 | | /// Remove null values by do a bitmask AND operation with null bits and the boolean bits. |
116 | 0 | pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray { |
117 | 0 | let nulls = filter.nulls().unwrap(); |
118 | 0 | let mask = filter.values() & nulls.inner(); |
119 | 0 | BooleanArray::new(mask, None) |
120 | 0 | } |
121 | | |
122 | | /// Returns a filtered `values` [`Array`] where the corresponding elements of |
123 | | /// `predicate` are `true`. |
124 | | /// |
125 | | /// # See also |
126 | | /// * [`FilterBuilder`] for more control over the filtering process. |
127 | | /// * [`filter_record_batch`] to filter a [`RecordBatch`] |
128 | | /// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce |
129 | | /// the results into a single array. |
130 | | /// |
131 | | /// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer |
132 | | /// |
133 | | /// # Example |
134 | | /// ```rust |
135 | | /// # use arrow_array::{Int32Array, BooleanArray}; |
136 | | /// # use arrow_select::filter::filter; |
137 | | /// let array = Int32Array::from(vec![5, 6, 7, 8, 9]); |
138 | | /// let filter_array = BooleanArray::from(vec![true, false, false, true, false]); |
139 | | /// let c = filter(&array, &filter_array).unwrap(); |
140 | | /// let c = c.as_any().downcast_ref::<Int32Array>().unwrap(); |
141 | | /// assert_eq!(c, &Int32Array::from(vec![5, 8])); |
142 | | /// ``` |
143 | 0 | pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> { |
144 | 0 | let mut filter_builder = FilterBuilder::new(predicate); |
145 | | |
146 | 0 | if multiple_arrays(values.data_type()) { |
147 | 0 | // Only optimize if filtering more than one array |
148 | 0 | // Otherwise, the overhead of optimization can be more than the benefit |
149 | 0 | filter_builder = filter_builder.optimize(); |
150 | 0 | } |
151 | | |
152 | 0 | let predicate = filter_builder.build(); |
153 | | |
154 | 0 | filter_array(values, &predicate) |
155 | 0 | } |
156 | | |
157 | 0 | fn multiple_arrays(data_type: &DataType) -> bool { |
158 | 0 | match data_type { |
159 | 0 | DataType::Struct(fields) => { |
160 | 0 | fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type()) |
161 | | } |
162 | 0 | DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(), |
163 | 0 | _ => false, |
164 | | } |
165 | 0 | } |
166 | | |
167 | | /// Returns a filtered [RecordBatch] where the corresponding elements of |
168 | | /// `predicate` are true. |
169 | | /// |
170 | | /// This is the equivalent of calling [filter] on each column of the [RecordBatch]. |
171 | 0 | pub fn filter_record_batch( |
172 | 0 | record_batch: &RecordBatch, |
173 | 0 | predicate: &BooleanArray, |
174 | 0 | ) -> Result<RecordBatch, ArrowError> { |
175 | 0 | let mut filter_builder = FilterBuilder::new(predicate); |
176 | 0 | if record_batch.num_columns() > 1 { |
177 | 0 | // Only optimize if filtering more than one column |
178 | 0 | // Otherwise, the overhead of optimization can be more than the benefit |
179 | 0 | filter_builder = filter_builder.optimize(); |
180 | 0 | } |
181 | 0 | let filter = filter_builder.build(); |
182 | | |
183 | 0 | let filtered_arrays = record_batch |
184 | 0 | .columns() |
185 | 0 | .iter() |
186 | 0 | .map(|a| filter_array(a, &filter)) |
187 | 0 | .collect::<Result<Vec<_>, _>>()?; |
188 | 0 | let options = RecordBatchOptions::default().with_row_count(Some(filter.count())); |
189 | 0 | RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options) |
190 | 0 | } |
191 | | |
192 | | /// A builder to construct [`FilterPredicate`] |
193 | | #[derive(Debug)] |
194 | | pub struct FilterBuilder { |
195 | | filter: BooleanArray, |
196 | | count: usize, |
197 | | strategy: IterationStrategy, |
198 | | } |
199 | | |
200 | | impl FilterBuilder { |
201 | | /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`] |
202 | 0 | pub fn new(filter: &BooleanArray) -> Self { |
203 | 0 | let filter = match filter.null_count() { |
204 | 0 | 0 => filter.clone(), |
205 | 0 | _ => prep_null_mask_filter(filter), |
206 | | }; |
207 | | |
208 | 0 | let count = filter_count(&filter); |
209 | 0 | let strategy = IterationStrategy::default_strategy(filter.len(), count); |
210 | | |
211 | 0 | Self { |
212 | 0 | filter, |
213 | 0 | count, |
214 | 0 | strategy, |
215 | 0 | } |
216 | 0 | } |
217 | | |
218 | | /// Compute an optimised representation of the provided `filter` mask that can be |
219 | | /// applied to an array more quickly. |
220 | | /// |
221 | | /// Note: There is limited benefit to calling this to then filter a single array |
222 | | /// Note: This will likely have a larger memory footprint than the original mask |
223 | 0 | pub fn optimize(mut self) -> Self { |
224 | 0 | match self.strategy { |
225 | | IterationStrategy::SlicesIterator => { |
226 | 0 | let slices = SlicesIterator::new(&self.filter).collect(); |
227 | 0 | self.strategy = IterationStrategy::Slices(slices) |
228 | | } |
229 | | IterationStrategy::IndexIterator => { |
230 | 0 | let indices = IndexIterator::new(&self.filter, self.count).collect(); |
231 | 0 | self.strategy = IterationStrategy::Indices(indices) |
232 | | } |
233 | 0 | _ => {} |
234 | | } |
235 | 0 | self |
236 | 0 | } |
237 | | |
238 | | /// Construct the final `FilterPredicate` |
239 | 0 | pub fn build(self) -> FilterPredicate { |
240 | 0 | FilterPredicate { |
241 | 0 | filter: self.filter, |
242 | 0 | count: self.count, |
243 | 0 | strategy: self.strategy, |
244 | 0 | } |
245 | 0 | } |
246 | | } |
247 | | |
248 | | /// The iteration strategy used to evaluate [`FilterPredicate`] |
249 | | #[derive(Debug)] |
250 | | enum IterationStrategy { |
251 | | /// A lazily evaluated iterator of ranges |
252 | | SlicesIterator, |
253 | | /// A lazily evaluated iterator of indices |
254 | | IndexIterator, |
255 | | /// A precomputed list of indices |
256 | | Indices(Vec<usize>), |
257 | | /// A precomputed array of ranges |
258 | | Slices(Vec<(usize, usize)>), |
259 | | /// Select all rows |
260 | | All, |
261 | | /// Select no rows |
262 | | None, |
263 | | } |
264 | | |
265 | | impl IterationStrategy { |
266 | | /// The default [`IterationStrategy`] for a filter of length `filter_length` |
267 | | /// and selecting `filter_count` rows |
268 | 0 | fn default_strategy(filter_length: usize, filter_count: usize) -> Self { |
269 | 0 | if filter_length == 0 || filter_count == 0 { |
270 | 0 | return IterationStrategy::None; |
271 | 0 | } |
272 | | |
273 | 0 | if filter_count == filter_length { |
274 | 0 | return IterationStrategy::All; |
275 | 0 | } |
276 | | |
277 | | // Compute the selectivity of the predicate by dividing the number of true |
278 | | // bits in the predicate by the predicate's total length |
279 | | // |
280 | | // This can then be used as a heuristic for the optimal iteration strategy |
281 | 0 | let selectivity_frac = filter_count as f64 / filter_length as f64; |
282 | 0 | if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD { |
283 | 0 | return IterationStrategy::SlicesIterator; |
284 | 0 | } |
285 | 0 | IterationStrategy::IndexIterator |
286 | 0 | } |
287 | | } |
288 | | |
289 | | /// A filtering predicate that can be applied to an [`Array`] |
290 | | #[derive(Debug)] |
291 | | pub struct FilterPredicate { |
292 | | filter: BooleanArray, |
293 | | count: usize, |
294 | | strategy: IterationStrategy, |
295 | | } |
296 | | |
297 | | impl FilterPredicate { |
298 | | /// Selects rows from `values` based on this [`FilterPredicate`] |
299 | 0 | pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> { |
300 | 0 | filter_array(values, self) |
301 | 0 | } |
302 | | |
303 | | /// Number of rows being selected based on this [`FilterPredicate`] |
304 | 0 | pub fn count(&self) -> usize { |
305 | 0 | self.count |
306 | 0 | } |
307 | | } |
308 | | |
309 | 0 | fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> { |
310 | 0 | if predicate.filter.len() > values.len() { |
311 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
312 | 0 | "Filter predicate of length {} is larger than target array of length {}", |
313 | 0 | predicate.filter.len(), |
314 | 0 | values.len() |
315 | 0 | ))); |
316 | 0 | } |
317 | | |
318 | 0 | match predicate.strategy { |
319 | 0 | IterationStrategy::None => Ok(new_empty_array(values.data_type())), |
320 | 0 | IterationStrategy::All => Ok(values.slice(0, predicate.count)), |
321 | | // actually filter |
322 | 0 | _ => downcast_primitive_array! { |
323 | 0 | values => Ok(Arc::new(filter_primitive(values, predicate))), |
324 | | DataType::Boolean => { |
325 | 0 | let values = values.as_any().downcast_ref::<BooleanArray>().unwrap(); |
326 | 0 | Ok(Arc::new(filter_boolean(values, predicate))) |
327 | | } |
328 | | DataType::Utf8 => { |
329 | 0 | Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate))) |
330 | | } |
331 | | DataType::LargeUtf8 => { |
332 | 0 | Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate))) |
333 | | } |
334 | | DataType::Utf8View => { |
335 | 0 | Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate))) |
336 | | } |
337 | | DataType::Binary => { |
338 | 0 | Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate))) |
339 | | } |
340 | | DataType::LargeBinary => { |
341 | 0 | Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate))) |
342 | | } |
343 | | DataType::BinaryView => { |
344 | 0 | Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate))) |
345 | | } |
346 | | DataType::FixedSizeBinary(_) => { |
347 | 0 | Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate))) |
348 | | } |
349 | | DataType::RunEndEncoded(_, _) => { |
350 | 0 | downcast_run_array!{ |
351 | 0 | values => Ok(Arc::new(filter_run_end_array(values, predicate)?)), |
352 | 0 | t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t) |
353 | | } |
354 | | } |
355 | 0 | DataType::Dictionary(_, _) => downcast_dictionary_array! { |
356 | 0 | values => Ok(Arc::new(filter_dict(values, predicate))), |
357 | 0 | t => unimplemented!("Filter not supported for dictionary type {:?}", t) |
358 | | } |
359 | | DataType::Struct(_) => { |
360 | 0 | Ok(Arc::new(filter_struct(values.as_struct(), predicate)?)) |
361 | | } |
362 | | DataType::Union(_, UnionMode::Sparse) => { |
363 | 0 | Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?)) |
364 | | } |
365 | | _ => { |
366 | 0 | let data = values.to_data(); |
367 | | // fallback to using MutableArrayData |
368 | 0 | let mut mutable = MutableArrayData::new( |
369 | 0 | vec![&data], |
370 | | false, |
371 | 0 | predicate.count, |
372 | | ); |
373 | | |
374 | 0 | match &predicate.strategy { |
375 | 0 | IterationStrategy::Slices(slices) => { |
376 | 0 | slices |
377 | 0 | .iter() |
378 | 0 | .for_each(|(start, end)| mutable.extend(0, *start, *end)); |
379 | | } |
380 | | _ => { |
381 | 0 | let iter = SlicesIterator::new(&predicate.filter); |
382 | 0 | iter.for_each(|(start, end)| mutable.extend(0, start, end)); |
383 | | } |
384 | | } |
385 | | |
386 | 0 | let data = mutable.freeze(); |
387 | 0 | Ok(make_array(data)) |
388 | | } |
389 | | }, |
390 | | } |
391 | 0 | } |
392 | | |
393 | | /// Filter any supported [`RunArray`] based on a [`FilterPredicate`] |
394 | 0 | fn filter_run_end_array<R: RunEndIndexType>( |
395 | 0 | array: &RunArray<R>, |
396 | 0 | predicate: &FilterPredicate, |
397 | 0 | ) -> Result<RunArray<R>, ArrowError> |
398 | 0 | where |
399 | 0 | R::Native: Into<i64> + From<bool>, |
400 | 0 | R::Native: AddAssign, |
401 | | { |
402 | 0 | let run_ends: &RunEndBuffer<R::Native> = array.run_ends(); |
403 | 0 | let mut new_run_ends = vec![R::default_value(); run_ends.len()]; |
404 | | |
405 | 0 | let mut start = 0u64; |
406 | 0 | let mut j = 0; |
407 | 0 | let mut count = R::default_value(); |
408 | 0 | let filter_values = predicate.filter.values(); |
409 | 0 | let run_ends = run_ends.inner(); |
410 | | |
411 | 0 | let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| { |
412 | 0 | let mut keep = false; |
413 | 0 | let mut end = run_ends[i].into() as u64; |
414 | 0 | let difference = end.saturating_sub(filter_values.len() as u64); |
415 | 0 | end -= difference; |
416 | | |
417 | | // Safety: we subtract the difference off `end` so we are always within bounds |
418 | 0 | for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) { |
419 | 0 | count += R::Native::from(pred); |
420 | 0 | keep |= pred |
421 | | } |
422 | | // this is to avoid branching |
423 | 0 | new_run_ends[j] = count; |
424 | 0 | j += keep as usize; |
425 | | |
426 | 0 | start = end; |
427 | 0 | keep |
428 | 0 | }) |
429 | 0 | .into(); |
430 | | |
431 | 0 | new_run_ends.truncate(j); |
432 | | |
433 | 0 | let values = array.values(); |
434 | 0 | let values = filter(&values, &pred)?; |
435 | | |
436 | 0 | let run_ends = PrimitiveArray::<R>::new(new_run_ends.into(), None); |
437 | 0 | RunArray::try_new(&run_ends, &values) |
438 | 0 | } |
439 | | |
440 | | /// Computes a new null mask for `data` based on `predicate` |
441 | | /// |
442 | | /// If the predicate selected no null-rows, returns `None`, otherwise returns |
443 | | /// `Some((null_count, null_buffer))` where `null_count` is the number of nulls |
444 | | /// in the filtered output, and `null_buffer` is the filtered null buffer |
445 | | /// |
446 | 0 | fn filter_null_mask( |
447 | 0 | nulls: Option<&NullBuffer>, |
448 | 0 | predicate: &FilterPredicate, |
449 | 0 | ) -> Option<(usize, Buffer)> { |
450 | 0 | let nulls = nulls?; |
451 | 0 | if nulls.null_count() == 0 { |
452 | 0 | return None; |
453 | 0 | } |
454 | | |
455 | 0 | let nulls = filter_bits(nulls.inner(), predicate); |
456 | | // The filtered `nulls` has a length of `predicate.count` bits and |
457 | | // therefore the null count is this minus the number of valid bits |
458 | 0 | let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count); |
459 | | |
460 | 0 | if null_count == 0 { |
461 | 0 | return None; |
462 | 0 | } |
463 | | |
464 | 0 | Some((null_count, nulls)) |
465 | 0 | } |
466 | | |
467 | | /// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset` |
468 | 0 | fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer { |
469 | 0 | let src = buffer.values(); |
470 | 0 | let offset = buffer.offset(); |
471 | | |
472 | 0 | match &predicate.strategy { |
473 | | IterationStrategy::IndexIterator => { |
474 | 0 | let bits = IndexIterator::new(&predicate.filter, predicate.count) |
475 | 0 | .map(|src_idx| bit_util::get_bit(src, src_idx + offset)); |
476 | | |
477 | | // SAFETY: `IndexIterator` reports its size correctly |
478 | 0 | unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() } |
479 | | } |
480 | 0 | IterationStrategy::Indices(indices) => { |
481 | 0 | let bits = indices |
482 | 0 | .iter() |
483 | 0 | .map(|src_idx| bit_util::get_bit(src, *src_idx + offset)); |
484 | | |
485 | | // SAFETY: `Vec::iter()` reports its size correctly |
486 | 0 | unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() } |
487 | | } |
488 | | IterationStrategy::SlicesIterator => { |
489 | 0 | let mut builder = BooleanBufferBuilder::new(predicate.count); |
490 | 0 | for (start, end) in SlicesIterator::new(&predicate.filter) { |
491 | 0 | builder.append_packed_range(start + offset..end + offset, src) |
492 | | } |
493 | 0 | builder.into() |
494 | | } |
495 | 0 | IterationStrategy::Slices(slices) => { |
496 | 0 | let mut builder = BooleanBufferBuilder::new(predicate.count); |
497 | 0 | for (start, end) in slices { |
498 | 0 | builder.append_packed_range(*start + offset..*end + offset, src) |
499 | | } |
500 | 0 | builder.into() |
501 | | } |
502 | 0 | IterationStrategy::All | IterationStrategy::None => unreachable!(), |
503 | | } |
504 | 0 | } |
505 | | |
506 | | /// `filter` implementation for boolean buffers |
507 | 0 | fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray { |
508 | 0 | let values = filter_bits(array.values(), predicate); |
509 | | |
510 | 0 | let mut builder = ArrayDataBuilder::new(DataType::Boolean) |
511 | 0 | .len(predicate.count) |
512 | 0 | .add_buffer(values); |
513 | | |
514 | 0 | if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { |
515 | 0 | builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); |
516 | 0 | } |
517 | | |
518 | 0 | let data = unsafe { builder.build_unchecked() }; |
519 | 0 | BooleanArray::from(data) |
520 | 0 | } |
521 | | |
522 | | #[inline(never)] |
523 | 0 | fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer { |
524 | 0 | assert!(values.len() >= predicate.filter.len()); |
525 | | |
526 | 0 | match &predicate.strategy { |
527 | | IterationStrategy::SlicesIterator => { |
528 | 0 | let mut buffer = Vec::with_capacity(predicate.count); |
529 | 0 | for (start, end) in SlicesIterator::new(&predicate.filter) { |
530 | 0 | buffer.extend_from_slice(&values[start..end]); |
531 | 0 | } |
532 | 0 | buffer.into() |
533 | | } |
534 | 0 | IterationStrategy::Slices(slices) => { |
535 | 0 | let mut buffer = Vec::with_capacity(predicate.count); |
536 | 0 | for (start, end) in slices { |
537 | 0 | buffer.extend_from_slice(&values[*start..*end]); |
538 | 0 | } |
539 | 0 | buffer.into() |
540 | | } |
541 | | IterationStrategy::IndexIterator => { |
542 | 0 | let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]); |
543 | | |
544 | | // SAFETY: IndexIterator is trusted length |
545 | 0 | unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into() |
546 | | } |
547 | 0 | IterationStrategy::Indices(indices) => { |
548 | 0 | let iter = indices.iter().map(|x| values[*x]); |
549 | 0 | iter.collect::<Vec<_>>().into() |
550 | | } |
551 | 0 | IterationStrategy::All | IterationStrategy::None => unreachable!(), |
552 | | } |
553 | 0 | } |
554 | | |
555 | | /// `filter` implementation for primitive arrays |
556 | 0 | fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T> |
557 | 0 | where |
558 | 0 | T: ArrowPrimitiveType, |
559 | | { |
560 | 0 | let values = array.values(); |
561 | 0 | let buffer = filter_native(values, predicate); |
562 | 0 | let mut builder = ArrayDataBuilder::new(array.data_type().clone()) |
563 | 0 | .len(predicate.count) |
564 | 0 | .add_buffer(buffer); |
565 | | |
566 | 0 | if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { |
567 | 0 | builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); |
568 | 0 | } |
569 | | |
570 | 0 | let data = unsafe { builder.build_unchecked() }; |
571 | 0 | PrimitiveArray::from(data) |
572 | 0 | } |
573 | | |
574 | | /// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be |
575 | | /// used to build a new [`GenericByteArray`] by copying values from the source |
576 | | /// |
577 | | /// TODO(raphael): Could this be used for the take kernel as well? |
578 | | struct FilterBytes<'a, OffsetSize> { |
579 | | src_offsets: &'a [OffsetSize], |
580 | | src_values: &'a [u8], |
581 | | dst_offsets: Vec<OffsetSize>, |
582 | | dst_values: Vec<u8>, |
583 | | cur_offset: OffsetSize, |
584 | | } |
585 | | |
586 | | impl<'a, OffsetSize> FilterBytes<'a, OffsetSize> |
587 | | where |
588 | | OffsetSize: OffsetSizeTrait, |
589 | | { |
590 | 0 | fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self |
591 | 0 | where |
592 | 0 | T: ByteArrayType<Offset = OffsetSize>, |
593 | | { |
594 | 0 | let dst_values = Vec::new(); |
595 | 0 | let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1); |
596 | 0 | let cur_offset = OffsetSize::from_usize(0).unwrap(); |
597 | | |
598 | 0 | dst_offsets.push(cur_offset); |
599 | | |
600 | 0 | Self { |
601 | 0 | src_offsets: array.value_offsets(), |
602 | 0 | src_values: array.value_data(), |
603 | 0 | dst_offsets, |
604 | 0 | dst_values, |
605 | 0 | cur_offset, |
606 | 0 | } |
607 | 0 | } |
608 | | |
609 | | /// Returns the byte offset at `idx` |
610 | | #[inline] |
611 | 0 | fn get_value_offset(&self, idx: usize) -> usize { |
612 | 0 | self.src_offsets[idx].as_usize() |
613 | 0 | } |
614 | | |
615 | | /// Returns the start and end of the value at index `idx` along with its length |
616 | | #[inline] |
617 | 0 | fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) { |
618 | | // These can only fail if `array` contains invalid data |
619 | 0 | let start = self.get_value_offset(idx); |
620 | 0 | let end = self.get_value_offset(idx + 1); |
621 | 0 | let len = OffsetSize::from_usize(end - start).expect("illegal offset range"); |
622 | 0 | (start, end, len) |
623 | 0 | } |
624 | | |
625 | 0 | fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) { |
626 | 0 | self.dst_offsets.extend(iter.map(|idx| { |
627 | 0 | let start = self.src_offsets[idx].as_usize(); |
628 | 0 | let end = self.src_offsets[idx + 1].as_usize(); |
629 | 0 | let len = OffsetSize::from_usize(end - start).expect("illegal offset range"); |
630 | 0 | self.cur_offset += len; |
631 | | |
632 | 0 | self.cur_offset |
633 | 0 | })); |
634 | 0 | } |
635 | | |
636 | | /// Extends the in-progress array by the indexes in the provided iterator |
637 | 0 | fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) { |
638 | 0 | self.dst_values.reserve_exact(self.cur_offset.as_usize()); |
639 | | |
640 | 0 | for idx in iter { |
641 | 0 | let start = self.src_offsets[idx].as_usize(); |
642 | 0 | let end = self.src_offsets[idx + 1].as_usize(); |
643 | 0 | self.dst_values |
644 | 0 | .extend_from_slice(&self.src_values[start..end]); |
645 | 0 | } |
646 | 0 | } |
647 | | |
648 | 0 | fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) { |
649 | 0 | self.dst_offsets.reserve_exact(count); |
650 | 0 | for (start, end) in iter { |
651 | | // These can only fail if `array` contains invalid data |
652 | 0 | for idx in start..end { |
653 | 0 | let (_, _, len) = self.get_value_range(idx); |
654 | 0 | self.cur_offset += len; |
655 | 0 | self.dst_offsets.push(self.cur_offset); |
656 | 0 | } |
657 | | } |
658 | 0 | } |
659 | | |
660 | | /// Extends the in-progress array by the ranges in the provided iterator |
661 | 0 | fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) { |
662 | 0 | self.dst_values.reserve_exact(self.cur_offset.as_usize()); |
663 | | |
664 | 0 | for (start, end) in iter { |
665 | 0 | let value_start = self.get_value_offset(start); |
666 | 0 | let value_end = self.get_value_offset(end); |
667 | 0 | self.dst_values |
668 | 0 | .extend_from_slice(&self.src_values[value_start..value_end]); |
669 | 0 | } |
670 | 0 | } |
671 | | } |
672 | | |
673 | | /// `filter` implementation for byte arrays |
674 | | /// |
675 | | /// Note: NULLs with a non-zero slot length in `array` will have the corresponding |
676 | | /// data copied across. This allows handling the null mask separately from the data |
677 | 0 | fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T> |
678 | 0 | where |
679 | 0 | T: ByteArrayType, |
680 | | { |
681 | 0 | let mut filter = FilterBytes::new(predicate.count, array); |
682 | | |
683 | 0 | match &predicate.strategy { |
684 | | IterationStrategy::SlicesIterator => { |
685 | 0 | filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count); |
686 | 0 | filter.extend_slices(SlicesIterator::new(&predicate.filter)) |
687 | | } |
688 | 0 | IterationStrategy::Slices(slices) => { |
689 | 0 | filter.extend_offsets_slices(slices.iter().cloned(), predicate.count); |
690 | 0 | filter.extend_slices(slices.iter().cloned()) |
691 | | } |
692 | | IterationStrategy::IndexIterator => { |
693 | 0 | filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count)); |
694 | 0 | filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count)) |
695 | | } |
696 | 0 | IterationStrategy::Indices(indices) => { |
697 | 0 | filter.extend_offsets_idx(indices.iter().cloned()); |
698 | 0 | filter.extend_idx(indices.iter().cloned()) |
699 | | } |
700 | 0 | IterationStrategy::All | IterationStrategy::None => unreachable!(), |
701 | | } |
702 | | |
703 | 0 | let mut builder = ArrayDataBuilder::new(T::DATA_TYPE) |
704 | 0 | .len(predicate.count) |
705 | 0 | .add_buffer(filter.dst_offsets.into()) |
706 | 0 | .add_buffer(filter.dst_values.into()); |
707 | | |
708 | 0 | if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { |
709 | 0 | builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); |
710 | 0 | } |
711 | | |
712 | 0 | let data = unsafe { builder.build_unchecked() }; |
713 | 0 | GenericByteArray::from(data) |
714 | 0 | } |
715 | | |
716 | | /// `filter` implementation for byte view arrays. |
717 | 0 | fn filter_byte_view<T: ByteViewType>( |
718 | 0 | array: &GenericByteViewArray<T>, |
719 | 0 | predicate: &FilterPredicate, |
720 | 0 | ) -> GenericByteViewArray<T> { |
721 | 0 | let new_view_buffer = filter_native(array.views(), predicate); |
722 | | |
723 | 0 | let mut builder = ArrayDataBuilder::new(T::DATA_TYPE) |
724 | 0 | .len(predicate.count) |
725 | 0 | .add_buffer(new_view_buffer) |
726 | 0 | .add_buffers(array.data_buffers().to_vec()); |
727 | | |
728 | 0 | if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { |
729 | 0 | builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); |
730 | 0 | } |
731 | | |
732 | 0 | GenericByteViewArray::from(unsafe { builder.build_unchecked() }) |
733 | 0 | } |
734 | | |
735 | 0 | fn filter_fixed_size_binary( |
736 | 0 | array: &FixedSizeBinaryArray, |
737 | 0 | predicate: &FilterPredicate, |
738 | 0 | ) -> FixedSizeBinaryArray { |
739 | 0 | let values: &[u8] = array.values(); |
740 | 0 | let value_length = array.value_length() as usize; |
741 | 0 | let calculate_offset_from_index = |index: usize| index * value_length; |
742 | 0 | let buffer = match &predicate.strategy { |
743 | | IterationStrategy::SlicesIterator => { |
744 | 0 | let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length); |
745 | 0 | for (start, end) in SlicesIterator::new(&predicate.filter) { |
746 | 0 | buffer.extend_from_slice( |
747 | 0 | &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)], |
748 | 0 | ); |
749 | 0 | } |
750 | 0 | buffer |
751 | | } |
752 | 0 | IterationStrategy::Slices(slices) => { |
753 | 0 | let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length); |
754 | 0 | for (start, end) in slices { |
755 | 0 | buffer.extend_from_slice( |
756 | 0 | &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)], |
757 | 0 | ); |
758 | 0 | } |
759 | 0 | buffer |
760 | | } |
761 | | IterationStrategy::IndexIterator => { |
762 | 0 | let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| { |
763 | 0 | &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)] |
764 | 0 | }); |
765 | | |
766 | 0 | let mut buffer = MutableBuffer::new(predicate.count * value_length); |
767 | 0 | iter.for_each(|item| buffer.extend_from_slice(item)); |
768 | 0 | buffer |
769 | | } |
770 | 0 | IterationStrategy::Indices(indices) => { |
771 | 0 | let iter = indices.iter().map(|x| { |
772 | 0 | &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)] |
773 | 0 | }); |
774 | | |
775 | 0 | let mut buffer = MutableBuffer::new(predicate.count * value_length); |
776 | 0 | iter.for_each(|item| buffer.extend_from_slice(item)); |
777 | 0 | buffer |
778 | | } |
779 | 0 | IterationStrategy::All | IterationStrategy::None => unreachable!(), |
780 | | }; |
781 | 0 | let mut builder = ArrayDataBuilder::new(array.data_type().clone()) |
782 | 0 | .len(predicate.count) |
783 | 0 | .add_buffer(buffer.into()); |
784 | | |
785 | 0 | if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { |
786 | 0 | builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); |
787 | 0 | } |
788 | | |
789 | 0 | let data = unsafe { builder.build_unchecked() }; |
790 | 0 | FixedSizeBinaryArray::from(data) |
791 | 0 | } |
792 | | |
793 | | /// `filter` implementation for dictionaries |
794 | 0 | fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T> |
795 | 0 | where |
796 | 0 | T: ArrowDictionaryKeyType, |
797 | 0 | T::Native: num::Num, |
798 | | { |
799 | 0 | let builder = filter_primitive::<T>(array.keys(), predicate) |
800 | 0 | .into_data() |
801 | 0 | .into_builder() |
802 | 0 | .data_type(array.data_type().clone()) |
803 | 0 | .child_data(vec![array.values().to_data()]); |
804 | | |
805 | | // SAFETY: |
806 | | // Keys were valid before, filtered subset is therefore still valid |
807 | 0 | DictionaryArray::from(unsafe { builder.build_unchecked() }) |
808 | 0 | } |
809 | | |
810 | | /// `filter` implementation for structs |
811 | 0 | fn filter_struct( |
812 | 0 | array: &StructArray, |
813 | 0 | predicate: &FilterPredicate, |
814 | 0 | ) -> Result<StructArray, ArrowError> { |
815 | 0 | let columns = array |
816 | 0 | .columns() |
817 | 0 | .iter() |
818 | 0 | .map(|column| filter_array(column, predicate)) |
819 | 0 | .collect::<Result<_, _>>()?; |
820 | | |
821 | 0 | let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) { |
822 | 0 | let buffer = BooleanBuffer::new(nulls, 0, predicate.count); |
823 | | |
824 | 0 | Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) }) |
825 | | } else { |
826 | 0 | None |
827 | | }; |
828 | | |
829 | 0 | Ok(unsafe { |
830 | 0 | StructArray::new_unchecked_with_length( |
831 | 0 | array.fields().clone(), |
832 | 0 | columns, |
833 | 0 | nulls, |
834 | 0 | predicate.count(), |
835 | 0 | ) |
836 | 0 | }) |
837 | 0 | } |
838 | | |
839 | | /// `filter` implementation for sparse unions |
840 | 0 | fn filter_sparse_union( |
841 | 0 | array: &UnionArray, |
842 | 0 | predicate: &FilterPredicate, |
843 | 0 | ) -> Result<UnionArray, ArrowError> { |
844 | 0 | let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else { |
845 | 0 | unreachable!() |
846 | | }; |
847 | | |
848 | 0 | let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate); |
849 | | |
850 | 0 | let children = fields |
851 | 0 | .iter() |
852 | 0 | .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate)) |
853 | 0 | .collect::<Result<_, _>>()?; |
854 | | |
855 | 0 | Ok(unsafe { |
856 | 0 | UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children) |
857 | 0 | }) |
858 | 0 | } |
859 | | |
860 | | #[cfg(test)] |
861 | | mod tests { |
862 | | use super::*; |
863 | | use arrow_array::builder::*; |
864 | | use arrow_array::cast::as_run_array; |
865 | | use arrow_array::types::*; |
866 | | use arrow_data::ArrayData; |
867 | | use rand::distr::uniform::{UniformSampler, UniformUsize}; |
868 | | use rand::distr::{Alphanumeric, StandardUniform}; |
869 | | use rand::prelude::*; |
870 | | use rand::rng; |
871 | | |
872 | | macro_rules! def_temporal_test { |
873 | | ($test:ident, $array_type: ident, $data: expr) => { |
874 | | #[test] |
875 | | fn $test() { |
876 | | let a = $data; |
877 | | let b = BooleanArray::from(vec![true, false, true, false]); |
878 | | let c = filter(&a, &b).unwrap(); |
879 | | let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap(); |
880 | | assert_eq!(2, d.len()); |
881 | | assert_eq!(1, d.value(0)); |
882 | | assert_eq!(3, d.value(1)); |
883 | | } |
884 | | }; |
885 | | } |
886 | | |
887 | | def_temporal_test!( |
888 | | test_filter_date32, |
889 | | Date32Array, |
890 | | Date32Array::from(vec![1, 2, 3, 4]) |
891 | | ); |
892 | | def_temporal_test!( |
893 | | test_filter_date64, |
894 | | Date64Array, |
895 | | Date64Array::from(vec![1, 2, 3, 4]) |
896 | | ); |
897 | | def_temporal_test!( |
898 | | test_filter_time32_second, |
899 | | Time32SecondArray, |
900 | | Time32SecondArray::from(vec![1, 2, 3, 4]) |
901 | | ); |
902 | | def_temporal_test!( |
903 | | test_filter_time32_millisecond, |
904 | | Time32MillisecondArray, |
905 | | Time32MillisecondArray::from(vec![1, 2, 3, 4]) |
906 | | ); |
907 | | def_temporal_test!( |
908 | | test_filter_time64_microsecond, |
909 | | Time64MicrosecondArray, |
910 | | Time64MicrosecondArray::from(vec![1, 2, 3, 4]) |
911 | | ); |
912 | | def_temporal_test!( |
913 | | test_filter_time64_nanosecond, |
914 | | Time64NanosecondArray, |
915 | | Time64NanosecondArray::from(vec![1, 2, 3, 4]) |
916 | | ); |
917 | | def_temporal_test!( |
918 | | test_filter_duration_second, |
919 | | DurationSecondArray, |
920 | | DurationSecondArray::from(vec![1, 2, 3, 4]) |
921 | | ); |
922 | | def_temporal_test!( |
923 | | test_filter_duration_millisecond, |
924 | | DurationMillisecondArray, |
925 | | DurationMillisecondArray::from(vec![1, 2, 3, 4]) |
926 | | ); |
927 | | def_temporal_test!( |
928 | | test_filter_duration_microsecond, |
929 | | DurationMicrosecondArray, |
930 | | DurationMicrosecondArray::from(vec![1, 2, 3, 4]) |
931 | | ); |
932 | | def_temporal_test!( |
933 | | test_filter_duration_nanosecond, |
934 | | DurationNanosecondArray, |
935 | | DurationNanosecondArray::from(vec![1, 2, 3, 4]) |
936 | | ); |
937 | | def_temporal_test!( |
938 | | test_filter_timestamp_second, |
939 | | TimestampSecondArray, |
940 | | TimestampSecondArray::from(vec![1, 2, 3, 4]) |
941 | | ); |
942 | | def_temporal_test!( |
943 | | test_filter_timestamp_millisecond, |
944 | | TimestampMillisecondArray, |
945 | | TimestampMillisecondArray::from(vec![1, 2, 3, 4]) |
946 | | ); |
947 | | def_temporal_test!( |
948 | | test_filter_timestamp_microsecond, |
949 | | TimestampMicrosecondArray, |
950 | | TimestampMicrosecondArray::from(vec![1, 2, 3, 4]) |
951 | | ); |
952 | | def_temporal_test!( |
953 | | test_filter_timestamp_nanosecond, |
954 | | TimestampNanosecondArray, |
955 | | TimestampNanosecondArray::from(vec![1, 2, 3, 4]) |
956 | | ); |
957 | | |
958 | | #[test] |
959 | | fn test_filter_array_slice() { |
960 | | let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4); |
961 | | let b = BooleanArray::from(vec![true, false, false, true]); |
962 | | // filtering with sliced filter array is not currently supported |
963 | | // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4); |
964 | | // let b = b_slice.as_any().downcast_ref().unwrap(); |
965 | | let c = filter(&a, &b).unwrap(); |
966 | | let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap(); |
967 | | assert_eq!(2, d.len()); |
968 | | assert_eq!(6, d.value(0)); |
969 | | assert_eq!(9, d.value(1)); |
970 | | } |
971 | | |
972 | | #[test] |
973 | | fn test_filter_array_low_density() { |
974 | | // this test exercises the all 0's branch of the filter algorithm |
975 | | let mut data_values = (1..=65).collect::<Vec<i32>>(); |
976 | | let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>(); |
977 | | // set up two more values after the batch |
978 | | data_values.extend_from_slice(&[66, 67]); |
979 | | filter_values.extend_from_slice(&[false, true]); |
980 | | let a = Int32Array::from(data_values); |
981 | | let b = BooleanArray::from(filter_values); |
982 | | let c = filter(&a, &b).unwrap(); |
983 | | let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap(); |
984 | | assert_eq!(2, d.len()); |
985 | | assert_eq!(65, d.value(0)); |
986 | | assert_eq!(67, d.value(1)); |
987 | | } |
988 | | |
989 | | #[test] |
990 | | fn test_filter_array_high_density() { |
991 | | // this test exercises the all 1's branch of the filter algorithm |
992 | | let mut data_values = (1..=65).map(Some).collect::<Vec<_>>(); |
993 | | let mut filter_values = (1..=65) |
994 | | .map(|i| !matches!(i % 65, 0)) |
995 | | .collect::<Vec<bool>>(); |
996 | | // set second data value to null |
997 | | data_values[1] = None; |
998 | | // set up two more values after the batch |
999 | | data_values.extend_from_slice(&[Some(66), None, Some(67), None]); |
1000 | | filter_values.extend_from_slice(&[false, true, true, true]); |
1001 | | let a = Int32Array::from(data_values); |
1002 | | let b = BooleanArray::from(filter_values); |
1003 | | let c = filter(&a, &b).unwrap(); |
1004 | | let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap(); |
1005 | | assert_eq!(67, d.len()); |
1006 | | assert_eq!(3, d.null_count()); |
1007 | | assert_eq!(1, d.value(0)); |
1008 | | assert!(d.is_null(1)); |
1009 | | assert_eq!(64, d.value(63)); |
1010 | | assert!(d.is_null(64)); |
1011 | | assert_eq!(67, d.value(65)); |
1012 | | } |
1013 | | |
1014 | | #[test] |
1015 | | fn test_filter_string_array_simple() { |
1016 | | let a = StringArray::from(vec!["hello", " ", "world", "!"]); |
1017 | | let b = BooleanArray::from(vec![true, false, true, false]); |
1018 | | let c = filter(&a, &b).unwrap(); |
1019 | | let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap(); |
1020 | | assert_eq!(2, d.len()); |
1021 | | assert_eq!("hello", d.value(0)); |
1022 | | assert_eq!("world", d.value(1)); |
1023 | | } |
1024 | | |
1025 | | #[test] |
1026 | | fn test_filter_primitive_array_with_null() { |
1027 | | let a = Int32Array::from(vec![Some(5), None]); |
1028 | | let b = BooleanArray::from(vec![false, true]); |
1029 | | let c = filter(&a, &b).unwrap(); |
1030 | | let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap(); |
1031 | | assert_eq!(1, d.len()); |
1032 | | assert!(d.is_null(0)); |
1033 | | } |
1034 | | |
1035 | | #[test] |
1036 | | fn test_filter_string_array_with_null() { |
1037 | | let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]); |
1038 | | let b = BooleanArray::from(vec![true, false, false, true]); |
1039 | | let c = filter(&a, &b).unwrap(); |
1040 | | let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap(); |
1041 | | assert_eq!(2, d.len()); |
1042 | | assert_eq!("hello", d.value(0)); |
1043 | | assert!(!d.is_null(0)); |
1044 | | assert!(d.is_null(1)); |
1045 | | } |
1046 | | |
1047 | | #[test] |
1048 | | fn test_filter_binary_array_with_null() { |
1049 | | let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None]; |
1050 | | let a = BinaryArray::from(data); |
1051 | | let b = BooleanArray::from(vec![true, false, false, true]); |
1052 | | let c = filter(&a, &b).unwrap(); |
1053 | | let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap(); |
1054 | | assert_eq!(2, d.len()); |
1055 | | assert_eq!(b"hello", d.value(0)); |
1056 | | assert!(!d.is_null(0)); |
1057 | | assert!(d.is_null(1)); |
1058 | | } |
1059 | | |
1060 | | fn _test_filter_byte_view<T>() |
1061 | | where |
1062 | | T: ByteViewType, |
1063 | | str: AsRef<T::Native>, |
1064 | | T::Native: PartialEq, |
1065 | | { |
1066 | | let array = { |
1067 | | // ["hello", "world", null, "large payload over 12 bytes", "lulu"] |
1068 | | let mut builder = GenericByteViewBuilder::<T>::new(); |
1069 | | builder.append_value("hello"); |
1070 | | builder.append_value("world"); |
1071 | | builder.append_null(); |
1072 | | builder.append_value("large payload over 12 bytes"); |
1073 | | builder.append_value("lulu"); |
1074 | | builder.finish() |
1075 | | }; |
1076 | | |
1077 | | { |
1078 | | let predicate = BooleanArray::from(vec![true, false, true, true, false]); |
1079 | | let actual = filter(&array, &predicate).unwrap(); |
1080 | | |
1081 | | assert_eq!(actual.len(), 3); |
1082 | | |
1083 | | let expected = { |
1084 | | // ["hello", null, "large payload over 12 bytes"] |
1085 | | let mut builder = GenericByteViewBuilder::<T>::new(); |
1086 | | builder.append_value("hello"); |
1087 | | builder.append_null(); |
1088 | | builder.append_value("large payload over 12 bytes"); |
1089 | | builder.finish() |
1090 | | }; |
1091 | | |
1092 | | assert_eq!(actual.as_ref(), &expected); |
1093 | | } |
1094 | | |
1095 | | { |
1096 | | let predicate = BooleanArray::from(vec![true, false, false, false, true]); |
1097 | | let actual = filter(&array, &predicate).unwrap(); |
1098 | | |
1099 | | assert_eq!(actual.len(), 2); |
1100 | | |
1101 | | let expected = { |
1102 | | // ["hello", "lulu"] |
1103 | | let mut builder = GenericByteViewBuilder::<T>::new(); |
1104 | | builder.append_value("hello"); |
1105 | | builder.append_value("lulu"); |
1106 | | builder.finish() |
1107 | | }; |
1108 | | |
1109 | | assert_eq!(actual.as_ref(), &expected); |
1110 | | } |
1111 | | } |
1112 | | |
1113 | | #[test] |
1114 | | fn test_filter_string_view() { |
1115 | | _test_filter_byte_view::<StringViewType>() |
1116 | | } |
1117 | | |
1118 | | #[test] |
1119 | | fn test_filter_binary_view() { |
1120 | | _test_filter_byte_view::<BinaryViewType>() |
1121 | | } |
1122 | | |
1123 | | #[test] |
1124 | | fn test_filter_fixed_binary() { |
1125 | | let v1 = [1_u8, 2]; |
1126 | | let v2 = [3_u8, 4]; |
1127 | | let v3 = [5_u8, 6]; |
1128 | | let v = vec![&v1, &v2, &v3]; |
1129 | | let a = FixedSizeBinaryArray::from(v); |
1130 | | let b = BooleanArray::from(vec![true, false, true]); |
1131 | | let c = filter(&a, &b).unwrap(); |
1132 | | let d = c |
1133 | | .as_ref() |
1134 | | .as_any() |
1135 | | .downcast_ref::<FixedSizeBinaryArray>() |
1136 | | .unwrap(); |
1137 | | assert_eq!(d.len(), 2); |
1138 | | assert_eq!(d.value(0), &v1); |
1139 | | assert_eq!(d.value(1), &v3); |
1140 | | let c2 = FilterBuilder::new(&b) |
1141 | | .optimize() |
1142 | | .build() |
1143 | | .filter(&a) |
1144 | | .unwrap(); |
1145 | | let d2 = c2 |
1146 | | .as_ref() |
1147 | | .as_any() |
1148 | | .downcast_ref::<FixedSizeBinaryArray>() |
1149 | | .unwrap(); |
1150 | | assert_eq!(d, d2); |
1151 | | |
1152 | | let b = BooleanArray::from(vec![false, false, false]); |
1153 | | let c = filter(&a, &b).unwrap(); |
1154 | | let d = c |
1155 | | .as_ref() |
1156 | | .as_any() |
1157 | | .downcast_ref::<FixedSizeBinaryArray>() |
1158 | | .unwrap(); |
1159 | | assert_eq!(d.len(), 0); |
1160 | | |
1161 | | let b = BooleanArray::from(vec![true, true, true]); |
1162 | | let c = filter(&a, &b).unwrap(); |
1163 | | let d = c |
1164 | | .as_ref() |
1165 | | .as_any() |
1166 | | .downcast_ref::<FixedSizeBinaryArray>() |
1167 | | .unwrap(); |
1168 | | assert_eq!(d.len(), 3); |
1169 | | assert_eq!(d.value(0), &v1); |
1170 | | assert_eq!(d.value(1), &v2); |
1171 | | assert_eq!(d.value(2), &v3); |
1172 | | |
1173 | | let b = BooleanArray::from(vec![false, false, true]); |
1174 | | let c = filter(&a, &b).unwrap(); |
1175 | | let d = c |
1176 | | .as_ref() |
1177 | | .as_any() |
1178 | | .downcast_ref::<FixedSizeBinaryArray>() |
1179 | | .unwrap(); |
1180 | | assert_eq!(d.len(), 1); |
1181 | | assert_eq!(d.value(0), &v3); |
1182 | | let c2 = FilterBuilder::new(&b) |
1183 | | .optimize() |
1184 | | .build() |
1185 | | .filter(&a) |
1186 | | .unwrap(); |
1187 | | let d2 = c2 |
1188 | | .as_ref() |
1189 | | .as_any() |
1190 | | .downcast_ref::<FixedSizeBinaryArray>() |
1191 | | .unwrap(); |
1192 | | assert_eq!(d, d2); |
1193 | | } |
1194 | | |
1195 | | #[test] |
1196 | | fn test_filter_array_slice_with_null() { |
1197 | | let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4); |
1198 | | let b = BooleanArray::from(vec![true, false, false, true]); |
1199 | | // filtering with sliced filter array is not currently supported |
1200 | | // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4); |
1201 | | // let b = b_slice.as_any().downcast_ref().unwrap(); |
1202 | | let c = filter(&a, &b).unwrap(); |
1203 | | let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap(); |
1204 | | assert_eq!(2, d.len()); |
1205 | | assert!(d.is_null(0)); |
1206 | | assert!(!d.is_null(1)); |
1207 | | assert_eq!(9, d.value(1)); |
1208 | | } |
1209 | | |
1210 | | #[test] |
1211 | | fn test_filter_run_end_encoding_array() { |
1212 | | let run_ends = Int64Array::from(vec![2, 3, 8]); |
1213 | | let values = Int64Array::from(vec![7, -2, 9]); |
1214 | | let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); |
1215 | | let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]); |
1216 | | let c = filter(&a, &b).unwrap(); |
1217 | | let actual: &RunArray<Int64Type> = as_run_array(&c); |
1218 | | assert_eq!(4, actual.len()); |
1219 | | |
1220 | | let expected = RunArray::try_new( |
1221 | | &Int64Array::from(vec![1, 2, 4]), |
1222 | | &Int64Array::from(vec![7, -2, 9]), |
1223 | | ) |
1224 | | .expect("Failed to make expected RunArray test is broken"); |
1225 | | |
1226 | | assert_eq!(&actual.run_ends().values(), &expected.run_ends().values()); |
1227 | | assert_eq!(actual.values(), expected.values()) |
1228 | | } |
1229 | | |
1230 | | #[test] |
1231 | | fn test_filter_run_end_encoding_array_remove_value() { |
1232 | | let run_ends = Int32Array::from(vec![2, 3, 8, 10]); |
1233 | | let values = Int32Array::from(vec![7, -2, 9, -8]); |
1234 | | let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); |
1235 | | let b = BooleanArray::from(vec![ |
1236 | | false, true, false, false, true, false, true, false, false, false, |
1237 | | ]); |
1238 | | let c = filter(&a, &b).unwrap(); |
1239 | | let actual: &RunArray<Int32Type> = as_run_array(&c); |
1240 | | assert_eq!(3, actual.len()); |
1241 | | |
1242 | | let expected = |
1243 | | RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9])) |
1244 | | .expect("Failed to make expected RunArray test is broken"); |
1245 | | |
1246 | | assert_eq!(&actual.run_ends().values(), &expected.run_ends().values()); |
1247 | | assert_eq!(actual.values(), expected.values()) |
1248 | | } |
1249 | | |
1250 | | #[test] |
1251 | | fn test_filter_run_end_encoding_array_remove_all_but_one() { |
1252 | | let run_ends = Int16Array::from(vec![2, 3, 8, 10]); |
1253 | | let values = Int16Array::from(vec![7, -2, 9, -8]); |
1254 | | let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); |
1255 | | let b = BooleanArray::from(vec![ |
1256 | | false, false, false, false, false, false, true, false, false, false, |
1257 | | ]); |
1258 | | let c = filter(&a, &b).unwrap(); |
1259 | | let actual: &RunArray<Int16Type> = as_run_array(&c); |
1260 | | assert_eq!(1, actual.len()); |
1261 | | |
1262 | | let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9])) |
1263 | | .expect("Failed to make expected RunArray test is broken"); |
1264 | | |
1265 | | assert_eq!(&actual.run_ends().values(), &expected.run_ends().values()); |
1266 | | assert_eq!(actual.values(), expected.values()) |
1267 | | } |
1268 | | |
1269 | | #[test] |
1270 | | fn test_filter_run_end_encoding_array_empty() { |
1271 | | let run_ends = Int64Array::from(vec![2, 3, 8, 10]); |
1272 | | let values = Int64Array::from(vec![7, -2, 9, -8]); |
1273 | | let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); |
1274 | | let b = BooleanArray::from(vec![ |
1275 | | false, false, false, false, false, false, false, false, false, false, |
1276 | | ]); |
1277 | | let c = filter(&a, &b).unwrap(); |
1278 | | let actual: &RunArray<Int64Type> = as_run_array(&c); |
1279 | | assert_eq!(0, actual.len()); |
1280 | | } |
1281 | | |
1282 | | #[test] |
1283 | | fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() { |
1284 | | let run_ends = Int64Array::from(vec![2, 3, 8, 10]); |
1285 | | let values = Int64Array::from(vec![7, -2, 9, -8]); |
1286 | | let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray"); |
1287 | | let b = BooleanArray::from(vec![false, true, true]); |
1288 | | let c = filter(&a, &b).unwrap(); |
1289 | | let actual: &RunArray<Int64Type> = as_run_array(&c); |
1290 | | assert_eq!(2, actual.len()); |
1291 | | |
1292 | | let expected = RunArray::try_new( |
1293 | | &Int64Array::from(vec![1, 2]), |
1294 | | &Int64Array::from(vec![7, -2]), |
1295 | | ) |
1296 | | .expect("Failed to make expected RunArray test is broken"); |
1297 | | |
1298 | | assert_eq!(&actual.run_ends().values(), &expected.run_ends().values()); |
1299 | | assert_eq!(actual.values(), expected.values()) |
1300 | | } |
1301 | | |
1302 | | #[test] |
1303 | | fn test_filter_dictionary_array() { |
1304 | | let values = [Some("hello"), None, Some("world"), Some("!")]; |
1305 | | let a: Int8DictionaryArray = values.iter().copied().collect(); |
1306 | | let b = BooleanArray::from(vec![false, true, true, false]); |
1307 | | let c = filter(&a, &b).unwrap(); |
1308 | | let d = c |
1309 | | .as_ref() |
1310 | | .as_any() |
1311 | | .downcast_ref::<Int8DictionaryArray>() |
1312 | | .unwrap(); |
1313 | | let value_array = d.values(); |
1314 | | let values = value_array.as_any().downcast_ref::<StringArray>().unwrap(); |
1315 | | // values are cloned in the filtered dictionary array |
1316 | | assert_eq!(3, values.len()); |
1317 | | // but keys are filtered |
1318 | | assert_eq!(2, d.len()); |
1319 | | assert!(d.is_null(0)); |
1320 | | assert_eq!("world", values.value(d.keys().value(1) as usize)); |
1321 | | } |
1322 | | |
1323 | | #[test] |
1324 | | fn test_filter_list_array() { |
1325 | | let value_data = ArrayData::builder(DataType::Int32) |
1326 | | .len(8) |
1327 | | .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7])) |
1328 | | .build() |
1329 | | .unwrap(); |
1330 | | |
1331 | | let value_offsets = Buffer::from_slice_ref([0i64, 3, 6, 8, 8]); |
1332 | | |
1333 | | let list_data_type = |
1334 | | DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false))); |
1335 | | let list_data = ArrayData::builder(list_data_type) |
1336 | | .len(4) |
1337 | | .add_buffer(value_offsets) |
1338 | | .add_child_data(value_data) |
1339 | | .null_bit_buffer(Some(Buffer::from([0b00000111]))) |
1340 | | .build() |
1341 | | .unwrap(); |
1342 | | |
1343 | | // a = [[0, 1, 2], [3, 4, 5], [6, 7], null] |
1344 | | let a = LargeListArray::from(list_data); |
1345 | | let b = BooleanArray::from(vec![false, true, false, true]); |
1346 | | let result = filter(&a, &b).unwrap(); |
1347 | | |
1348 | | // expected: [[3, 4, 5], null] |
1349 | | let value_data = ArrayData::builder(DataType::Int32) |
1350 | | .len(3) |
1351 | | .add_buffer(Buffer::from_slice_ref([3, 4, 5])) |
1352 | | .build() |
1353 | | .unwrap(); |
1354 | | |
1355 | | let value_offsets = Buffer::from_slice_ref([0i64, 3, 3]); |
1356 | | |
1357 | | let list_data_type = |
1358 | | DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false))); |
1359 | | let expected = ArrayData::builder(list_data_type) |
1360 | | .len(2) |
1361 | | .add_buffer(value_offsets) |
1362 | | .add_child_data(value_data) |
1363 | | .null_bit_buffer(Some(Buffer::from([0b00000001]))) |
1364 | | .build() |
1365 | | .unwrap(); |
1366 | | |
1367 | | assert_eq!(&make_array(expected), &result); |
1368 | | } |
1369 | | |
1370 | | #[test] |
1371 | | fn test_slice_iterator_bits() { |
1372 | | let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>(); |
1373 | | let filter = BooleanArray::from(filter_values); |
1374 | | let filter_count = filter_count(&filter); |
1375 | | |
1376 | | let iter = SlicesIterator::new(&filter); |
1377 | | let chunks = iter.collect::<Vec<_>>(); |
1378 | | |
1379 | | assert_eq!(chunks, vec![(1, 2)]); |
1380 | | assert_eq!(filter_count, 1); |
1381 | | } |
1382 | | |
1383 | | #[test] |
1384 | | fn test_slice_iterator_bits1() { |
1385 | | let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>(); |
1386 | | let filter = BooleanArray::from(filter_values); |
1387 | | let filter_count = filter_count(&filter); |
1388 | | |
1389 | | let iter = SlicesIterator::new(&filter); |
1390 | | let chunks = iter.collect::<Vec<_>>(); |
1391 | | |
1392 | | assert_eq!(chunks, vec![(0, 1), (2, 64)]); |
1393 | | assert_eq!(filter_count, 64 - 1); |
1394 | | } |
1395 | | |
1396 | | #[test] |
1397 | | fn test_slice_iterator_chunk_and_bits() { |
1398 | | let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>(); |
1399 | | let filter = BooleanArray::from(filter_values); |
1400 | | let filter_count = filter_count(&filter); |
1401 | | |
1402 | | let iter = SlicesIterator::new(&filter); |
1403 | | let chunks = iter.collect::<Vec<_>>(); |
1404 | | |
1405 | | assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]); |
1406 | | assert_eq!(filter_count, 61 + 61 + 5); |
1407 | | } |
1408 | | |
1409 | | #[test] |
1410 | | fn test_null_mask() { |
1411 | | let a = Int64Array::from(vec![Some(1), Some(2), None]); |
1412 | | |
1413 | | let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]); |
1414 | | let out = filter(&a, &mask1).unwrap(); |
1415 | | assert_eq!(out.as_ref(), &a.slice(0, 2)); |
1416 | | } |
1417 | | |
1418 | | #[test] |
1419 | | fn test_filter_record_batch_no_columns() { |
1420 | | let pred = BooleanArray::from(vec![Some(true), Some(true), None]); |
1421 | | let options = RecordBatchOptions::default().with_row_count(Some(100)); |
1422 | | let record_batch = |
1423 | | RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap(); |
1424 | | let out = filter_record_batch(&record_batch, &pred).unwrap(); |
1425 | | |
1426 | | assert_eq!(out.num_rows(), 2); |
1427 | | } |
1428 | | |
1429 | | #[test] |
1430 | | fn test_fast_path() { |
1431 | | let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]); |
1432 | | |
1433 | | // all true |
1434 | | let mask = BooleanArray::from(vec![true, true, true]); |
1435 | | let out = filter(&a, &mask).unwrap(); |
1436 | | let b = out |
1437 | | .as_any() |
1438 | | .downcast_ref::<PrimitiveArray<Int64Type>>() |
1439 | | .unwrap(); |
1440 | | assert_eq!(&a, b); |
1441 | | |
1442 | | // all false |
1443 | | let mask = BooleanArray::from(vec![false, false, false]); |
1444 | | let out = filter(&a, &mask).unwrap(); |
1445 | | assert_eq!(out.len(), 0); |
1446 | | assert_eq!(out.data_type(), &DataType::Int64); |
1447 | | } |
1448 | | |
1449 | | #[test] |
1450 | | fn test_slices() { |
1451 | | // takes up 2 u64s |
1452 | | let bools = std::iter::repeat_n(true, 10) |
1453 | | .chain(std::iter::repeat_n(false, 30)) |
1454 | | .chain(std::iter::repeat_n(true, 20)) |
1455 | | .chain(std::iter::repeat_n(false, 17)) |
1456 | | .chain(std::iter::repeat_n(true, 4)); |
1457 | | |
1458 | | let bool_array: BooleanArray = bools.map(Some).collect(); |
1459 | | |
1460 | | let slices: Vec<_> = SlicesIterator::new(&bool_array).collect(); |
1461 | | let expected = vec![(0, 10), (40, 60), (77, 81)]; |
1462 | | assert_eq!(slices, expected); |
1463 | | |
1464 | | // slice with offset and truncated len |
1465 | | let len = bool_array.len(); |
1466 | | let sliced_array = bool_array.slice(7, len - 10); |
1467 | | let sliced_array = sliced_array |
1468 | | .as_any() |
1469 | | .downcast_ref::<BooleanArray>() |
1470 | | .unwrap(); |
1471 | | let slices: Vec<_> = SlicesIterator::new(sliced_array).collect(); |
1472 | | let expected = vec![(0, 3), (33, 53), (70, 71)]; |
1473 | | assert_eq!(slices, expected); |
1474 | | } |
1475 | | |
1476 | | fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) { |
1477 | | let mut rng = rng(); |
1478 | | |
1479 | | let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random())) |
1480 | | .take(mask_len) |
1481 | | .collect(); |
1482 | | |
1483 | | let buffer = Buffer::from_iter(bools.iter().cloned()); |
1484 | | |
1485 | | let truncated_length = mask_len - offset - truncate; |
1486 | | |
1487 | | let data = ArrayDataBuilder::new(DataType::Boolean) |
1488 | | .len(truncated_length) |
1489 | | .offset(offset) |
1490 | | .add_buffer(buffer) |
1491 | | .build() |
1492 | | .unwrap(); |
1493 | | |
1494 | | let filter = BooleanArray::from(data); |
1495 | | |
1496 | | let slice_bits: Vec<_> = SlicesIterator::new(&filter) |
1497 | | .flat_map(|(start, end)| start..end) |
1498 | | .collect(); |
1499 | | |
1500 | | let count = filter_count(&filter); |
1501 | | let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect(); |
1502 | | |
1503 | | let expected_bits: Vec<_> = bools |
1504 | | .iter() |
1505 | | .skip(offset) |
1506 | | .take(truncated_length) |
1507 | | .enumerate() |
1508 | | .flat_map(|(idx, v)| v.then(|| idx)) |
1509 | | .collect(); |
1510 | | |
1511 | | assert_eq!(slice_bits, expected_bits); |
1512 | | assert_eq!(index_bits, expected_bits); |
1513 | | } |
1514 | | |
1515 | | #[test] |
1516 | | #[cfg_attr(miri, ignore)] |
1517 | | fn fuzz_test_slices_iterator() { |
1518 | | let mut rng = rng(); |
1519 | | |
1520 | | let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap(); |
1521 | | for _ in 0..100 { |
1522 | | let mask_len = rng.random_range(0..1024); |
1523 | | let max_offset = 64.min(mask_len); |
1524 | | let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0); |
1525 | | |
1526 | | let max_truncate = 128.min(mask_len - offset); |
1527 | | let truncate = uusize |
1528 | | .sample(&mut rng) |
1529 | | .checked_rem(max_truncate) |
1530 | | .unwrap_or(0); |
1531 | | |
1532 | | test_slices_fuzz(mask_len, offset, truncate); |
1533 | | } |
1534 | | |
1535 | | test_slices_fuzz(64, 0, 0); |
1536 | | test_slices_fuzz(64, 8, 0); |
1537 | | test_slices_fuzz(64, 8, 8); |
1538 | | test_slices_fuzz(32, 8, 8); |
1539 | | test_slices_fuzz(32, 5, 9); |
1540 | | } |
1541 | | |
1542 | | /// Filters `values` by `predicate` using standard rust iterators |
1543 | | fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> { |
1544 | | values |
1545 | | .into_iter() |
1546 | | .zip(predicate) |
1547 | | .filter(|(_, x)| **x) |
1548 | | .map(|(a, _)| a) |
1549 | | .collect() |
1550 | | } |
1551 | | |
1552 | | /// Generates an array of length `len` with `valid_percent` non-null values |
1553 | | fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>> |
1554 | | where |
1555 | | StandardUniform: Distribution<T>, |
1556 | | { |
1557 | | let mut rng = rng(); |
1558 | | (0..len) |
1559 | | .map(|_| rng.random_bool(valid_percent).then(|| rng.random())) |
1560 | | .collect() |
1561 | | } |
1562 | | |
1563 | | /// Generates an array of length `len` with `valid_percent` non-null values |
1564 | | fn gen_strings( |
1565 | | len: usize, |
1566 | | valid_percent: f64, |
1567 | | str_len_range: std::ops::Range<usize>, |
1568 | | ) -> Vec<Option<String>> { |
1569 | | let mut rng = rng(); |
1570 | | (0..len) |
1571 | | .map(|_| { |
1572 | | rng.random_bool(valid_percent).then(|| { |
1573 | | let len = rng.random_range(str_len_range.clone()); |
1574 | | (0..len) |
1575 | | .map(|_| char::from(rng.sample(Alphanumeric))) |
1576 | | .collect() |
1577 | | }) |
1578 | | }) |
1579 | | .collect() |
1580 | | } |
1581 | | |
1582 | | /// Returns an iterator that calls `Option::as_deref` on each item |
1583 | | fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> { |
1584 | | src.iter().map(|x| x.as_deref()) |
1585 | | } |
1586 | | |
1587 | | #[test] |
1588 | | #[cfg_attr(miri, ignore)] |
1589 | | fn fuzz_filter() { |
1590 | | let mut rng = rng(); |
1591 | | |
1592 | | for i in 0..100 { |
1593 | | let filter_percent = match i { |
1594 | | 0..=4 => 1., |
1595 | | 5..=10 => 0., |
1596 | | _ => rng.random_range(0.0..1.0), |
1597 | | }; |
1598 | | |
1599 | | let valid_percent = rng.random_range(0.0..1.0); |
1600 | | |
1601 | | let array_len = rng.random_range(32..256); |
1602 | | let array_offset = rng.random_range(0..10); |
1603 | | |
1604 | | // Construct a predicate |
1605 | | let filter_offset = rng.random_range(0..10); |
1606 | | let filter_truncate = rng.random_range(0..10); |
1607 | | let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent))) |
1608 | | .take(array_len + filter_offset - filter_truncate) |
1609 | | .collect(); |
1610 | | |
1611 | | let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some)); |
1612 | | |
1613 | | // Offset predicate |
1614 | | let predicate = predicate.slice(filter_offset, array_len - filter_truncate); |
1615 | | let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap(); |
1616 | | let bools = &bools[filter_offset..]; |
1617 | | |
1618 | | // Test i32 |
1619 | | let values = gen_primitive(array_len + array_offset, valid_percent); |
1620 | | let src = Int32Array::from_iter(values.iter().cloned()); |
1621 | | |
1622 | | let src = src.slice(array_offset, array_len); |
1623 | | let src = src.as_any().downcast_ref::<Int32Array>().unwrap(); |
1624 | | let values = &values[array_offset..]; |
1625 | | |
1626 | | let filtered = filter(src, predicate).unwrap(); |
1627 | | let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap(); |
1628 | | let actual: Vec<_> = array.iter().collect(); |
1629 | | |
1630 | | assert_eq!(actual, filter_rust(values.iter().cloned(), bools)); |
1631 | | |
1632 | | // Test string |
1633 | | let strings = gen_strings(array_len + array_offset, valid_percent, 0..20); |
1634 | | let src = StringArray::from_iter(as_deref(&strings)); |
1635 | | |
1636 | | let src = src.slice(array_offset, array_len); |
1637 | | let src = src.as_any().downcast_ref::<StringArray>().unwrap(); |
1638 | | |
1639 | | let filtered = filter(src, predicate).unwrap(); |
1640 | | let array = filtered.as_any().downcast_ref::<StringArray>().unwrap(); |
1641 | | let actual: Vec<_> = array.iter().collect(); |
1642 | | |
1643 | | let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools); |
1644 | | assert_eq!(actual, expected_strings); |
1645 | | |
1646 | | // Test string dictionary |
1647 | | let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings)); |
1648 | | |
1649 | | let src = src.slice(array_offset, array_len); |
1650 | | let src = src |
1651 | | .as_any() |
1652 | | .downcast_ref::<DictionaryArray<Int32Type>>() |
1653 | | .unwrap(); |
1654 | | |
1655 | | let filtered = filter(src, predicate).unwrap(); |
1656 | | |
1657 | | let array = filtered |
1658 | | .as_any() |
1659 | | .downcast_ref::<DictionaryArray<Int32Type>>() |
1660 | | .unwrap(); |
1661 | | |
1662 | | let values = array |
1663 | | .values() |
1664 | | .as_any() |
1665 | | .downcast_ref::<StringArray>() |
1666 | | .unwrap(); |
1667 | | |
1668 | | let actual: Vec<_> = array |
1669 | | .keys() |
1670 | | .iter() |
1671 | | .map(|key| key.map(|key| values.value(key as usize))) |
1672 | | .collect(); |
1673 | | |
1674 | | assert_eq!(actual, expected_strings); |
1675 | | } |
1676 | | } |
1677 | | |
1678 | | #[test] |
1679 | | fn test_filter_map() { |
1680 | | let mut builder = |
1681 | | MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4)); |
1682 | | // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1} |
1683 | | builder.keys().append_value("key1"); |
1684 | | builder.values().append_value(1); |
1685 | | builder.append(true).unwrap(); |
1686 | | builder.keys().append_value("key2"); |
1687 | | builder.keys().append_value("key3"); |
1688 | | builder.values().append_value(2); |
1689 | | builder.values().append_value(3); |
1690 | | builder.append(true).unwrap(); |
1691 | | builder.append(false).unwrap(); |
1692 | | builder.keys().append_value("key1"); |
1693 | | builder.values().append_value(1); |
1694 | | builder.append(true).unwrap(); |
1695 | | let maparray = Arc::new(builder.finish()) as ArrayRef; |
1696 | | |
1697 | | let indices = vec![Some(true), Some(false), Some(false), Some(true)] |
1698 | | .into_iter() |
1699 | | .collect::<BooleanArray>(); |
1700 | | let got = filter(&maparray, &indices).unwrap(); |
1701 | | |
1702 | | let mut builder = |
1703 | | MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2)); |
1704 | | builder.keys().append_value("key1"); |
1705 | | builder.values().append_value(1); |
1706 | | builder.append(true).unwrap(); |
1707 | | builder.keys().append_value("key1"); |
1708 | | builder.values().append_value(1); |
1709 | | builder.append(true).unwrap(); |
1710 | | let expected = Arc::new(builder.finish()) as ArrayRef; |
1711 | | |
1712 | | assert_eq!(&expected, &got); |
1713 | | } |
1714 | | |
1715 | | #[test] |
1716 | | fn test_filter_fixed_size_list_arrays() { |
1717 | | let value_data = ArrayData::builder(DataType::Int32) |
1718 | | .len(9) |
1719 | | .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8])) |
1720 | | .build() |
1721 | | .unwrap(); |
1722 | | let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 3, false); |
1723 | | let list_data = ArrayData::builder(list_data_type) |
1724 | | .len(3) |
1725 | | .add_child_data(value_data) |
1726 | | .build() |
1727 | | .unwrap(); |
1728 | | let array = FixedSizeListArray::from(list_data); |
1729 | | |
1730 | | let filter_array = BooleanArray::from(vec![true, false, false]); |
1731 | | |
1732 | | let c = filter(&array, &filter_array).unwrap(); |
1733 | | let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap(); |
1734 | | |
1735 | | assert_eq!(filtered.len(), 1); |
1736 | | |
1737 | | let list = filtered.value(0); |
1738 | | assert_eq!( |
1739 | | &[0, 1, 2], |
1740 | | list.as_any().downcast_ref::<Int32Array>().unwrap().values() |
1741 | | ); |
1742 | | |
1743 | | let filter_array = BooleanArray::from(vec![true, false, true]); |
1744 | | |
1745 | | let c = filter(&array, &filter_array).unwrap(); |
1746 | | let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap(); |
1747 | | |
1748 | | assert_eq!(filtered.len(), 2); |
1749 | | |
1750 | | let list = filtered.value(0); |
1751 | | assert_eq!( |
1752 | | &[0, 1, 2], |
1753 | | list.as_any().downcast_ref::<Int32Array>().unwrap().values() |
1754 | | ); |
1755 | | let list = filtered.value(1); |
1756 | | assert_eq!( |
1757 | | &[6, 7, 8], |
1758 | | list.as_any().downcast_ref::<Int32Array>().unwrap().values() |
1759 | | ); |
1760 | | } |
1761 | | |
1762 | | #[test] |
1763 | | fn test_filter_fixed_size_list_arrays_with_null() { |
1764 | | let value_data = ArrayData::builder(DataType::Int32) |
1765 | | .len(10) |
1766 | | .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) |
1767 | | .build() |
1768 | | .unwrap(); |
1769 | | |
1770 | | // Set null buts for the nested array: |
1771 | | // [[0, 1], null, null, [6, 7], [8, 9]] |
1772 | | // 01011001 00000001 |
1773 | | let mut null_bits: [u8; 1] = [0; 1]; |
1774 | | bit_util::set_bit(&mut null_bits, 0); |
1775 | | bit_util::set_bit(&mut null_bits, 3); |
1776 | | bit_util::set_bit(&mut null_bits, 4); |
1777 | | |
1778 | | let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false); |
1779 | | let list_data = ArrayData::builder(list_data_type) |
1780 | | .len(5) |
1781 | | .add_child_data(value_data) |
1782 | | .null_bit_buffer(Some(Buffer::from(null_bits))) |
1783 | | .build() |
1784 | | .unwrap(); |
1785 | | let array = FixedSizeListArray::from(list_data); |
1786 | | |
1787 | | let filter_array = BooleanArray::from(vec![true, true, false, true, false]); |
1788 | | |
1789 | | let c = filter(&array, &filter_array).unwrap(); |
1790 | | let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap(); |
1791 | | |
1792 | | assert_eq!(filtered.len(), 3); |
1793 | | |
1794 | | let list = filtered.value(0); |
1795 | | assert_eq!( |
1796 | | &[0, 1], |
1797 | | list.as_any().downcast_ref::<Int32Array>().unwrap().values() |
1798 | | ); |
1799 | | assert!(filtered.is_null(1)); |
1800 | | let list = filtered.value(2); |
1801 | | assert_eq!( |
1802 | | &[6, 7], |
1803 | | list.as_any().downcast_ref::<Int32Array>().unwrap().values() |
1804 | | ); |
1805 | | } |
1806 | | |
1807 | | fn test_filter_union_array(array: UnionArray) { |
1808 | | let filter_array = BooleanArray::from(vec![true, false, false]); |
1809 | | let c = filter(&array, &filter_array).unwrap(); |
1810 | | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
1811 | | |
1812 | | let mut builder = UnionBuilder::new_dense(); |
1813 | | builder.append::<Int32Type>("A", 1).unwrap(); |
1814 | | let expected_array = builder.build().unwrap(); |
1815 | | |
1816 | | compare_union_arrays(filtered, &expected_array); |
1817 | | |
1818 | | let filter_array = BooleanArray::from(vec![true, false, true]); |
1819 | | let c = filter(&array, &filter_array).unwrap(); |
1820 | | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
1821 | | |
1822 | | let mut builder = UnionBuilder::new_dense(); |
1823 | | builder.append::<Int32Type>("A", 1).unwrap(); |
1824 | | builder.append::<Int32Type>("A", 34).unwrap(); |
1825 | | let expected_array = builder.build().unwrap(); |
1826 | | |
1827 | | compare_union_arrays(filtered, &expected_array); |
1828 | | |
1829 | | let filter_array = BooleanArray::from(vec![true, true, false]); |
1830 | | let c = filter(&array, &filter_array).unwrap(); |
1831 | | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
1832 | | |
1833 | | let mut builder = UnionBuilder::new_dense(); |
1834 | | builder.append::<Int32Type>("A", 1).unwrap(); |
1835 | | builder.append::<Float64Type>("B", 3.2).unwrap(); |
1836 | | let expected_array = builder.build().unwrap(); |
1837 | | |
1838 | | compare_union_arrays(filtered, &expected_array); |
1839 | | } |
1840 | | |
1841 | | #[test] |
1842 | | fn test_filter_union_array_dense() { |
1843 | | let mut builder = UnionBuilder::new_dense(); |
1844 | | builder.append::<Int32Type>("A", 1).unwrap(); |
1845 | | builder.append::<Float64Type>("B", 3.2).unwrap(); |
1846 | | builder.append::<Int32Type>("A", 34).unwrap(); |
1847 | | let array = builder.build().unwrap(); |
1848 | | |
1849 | | test_filter_union_array(array); |
1850 | | } |
1851 | | |
1852 | | #[test] |
1853 | | fn test_filter_run_union_array_dense() { |
1854 | | let mut builder = UnionBuilder::new_dense(); |
1855 | | builder.append::<Int32Type>("A", 1).unwrap(); |
1856 | | builder.append::<Int32Type>("A", 3).unwrap(); |
1857 | | builder.append::<Int32Type>("A", 34).unwrap(); |
1858 | | let array = builder.build().unwrap(); |
1859 | | |
1860 | | let filter_array = BooleanArray::from(vec![true, true, false]); |
1861 | | let c = filter(&array, &filter_array).unwrap(); |
1862 | | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
1863 | | |
1864 | | let mut builder = UnionBuilder::new_dense(); |
1865 | | builder.append::<Int32Type>("A", 1).unwrap(); |
1866 | | builder.append::<Int32Type>("A", 3).unwrap(); |
1867 | | let expected = builder.build().unwrap(); |
1868 | | |
1869 | | assert_eq!(filtered.to_data(), expected.to_data()); |
1870 | | } |
1871 | | |
1872 | | #[test] |
1873 | | fn test_filter_union_array_dense_with_nulls() { |
1874 | | let mut builder = UnionBuilder::new_dense(); |
1875 | | builder.append::<Int32Type>("A", 1).unwrap(); |
1876 | | builder.append::<Float64Type>("B", 3.2).unwrap(); |
1877 | | builder.append_null::<Float64Type>("B").unwrap(); |
1878 | | builder.append::<Int32Type>("A", 34).unwrap(); |
1879 | | let array = builder.build().unwrap(); |
1880 | | |
1881 | | let filter_array = BooleanArray::from(vec![true, true, false, false]); |
1882 | | let c = filter(&array, &filter_array).unwrap(); |
1883 | | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
1884 | | |
1885 | | let mut builder = UnionBuilder::new_dense(); |
1886 | | builder.append::<Int32Type>("A", 1).unwrap(); |
1887 | | builder.append::<Float64Type>("B", 3.2).unwrap(); |
1888 | | let expected_array = builder.build().unwrap(); |
1889 | | |
1890 | | compare_union_arrays(filtered, &expected_array); |
1891 | | |
1892 | | let filter_array = BooleanArray::from(vec![true, false, true, false]); |
1893 | | let c = filter(&array, &filter_array).unwrap(); |
1894 | | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
1895 | | |
1896 | | let mut builder = UnionBuilder::new_dense(); |
1897 | | builder.append::<Int32Type>("A", 1).unwrap(); |
1898 | | builder.append_null::<Float64Type>("B").unwrap(); |
1899 | | let expected_array = builder.build().unwrap(); |
1900 | | |
1901 | | compare_union_arrays(filtered, &expected_array); |
1902 | | } |
1903 | | |
1904 | | #[test] |
1905 | | fn test_filter_union_array_sparse() { |
1906 | | let mut builder = UnionBuilder::new_sparse(); |
1907 | | builder.append::<Int32Type>("A", 1).unwrap(); |
1908 | | builder.append::<Float64Type>("B", 3.2).unwrap(); |
1909 | | builder.append::<Int32Type>("A", 34).unwrap(); |
1910 | | let array = builder.build().unwrap(); |
1911 | | |
1912 | | test_filter_union_array(array); |
1913 | | } |
1914 | | |
1915 | | #[test] |
1916 | | fn test_filter_union_array_sparse_with_nulls() { |
1917 | | let mut builder = UnionBuilder::new_sparse(); |
1918 | | builder.append::<Int32Type>("A", 1).unwrap(); |
1919 | | builder.append::<Float64Type>("B", 3.2).unwrap(); |
1920 | | builder.append_null::<Float64Type>("B").unwrap(); |
1921 | | builder.append::<Int32Type>("A", 34).unwrap(); |
1922 | | let array = builder.build().unwrap(); |
1923 | | |
1924 | | let filter_array = BooleanArray::from(vec![true, false, true, false]); |
1925 | | let c = filter(&array, &filter_array).unwrap(); |
1926 | | let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap(); |
1927 | | |
1928 | | let mut builder = UnionBuilder::new_sparse(); |
1929 | | builder.append::<Int32Type>("A", 1).unwrap(); |
1930 | | builder.append_null::<Float64Type>("B").unwrap(); |
1931 | | let expected_array = builder.build().unwrap(); |
1932 | | |
1933 | | compare_union_arrays(filtered, &expected_array); |
1934 | | } |
1935 | | |
1936 | | fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) { |
1937 | | assert_eq!(union1.len(), union2.len()); |
1938 | | |
1939 | | for i in 0..union1.len() { |
1940 | | let type_id = union1.type_id(i); |
1941 | | |
1942 | | let slot1 = union1.value(i); |
1943 | | let slot2 = union2.value(i); |
1944 | | |
1945 | | assert_eq!(slot1.is_null(0), slot2.is_null(0)); |
1946 | | |
1947 | | if !slot1.is_null(0) && !slot2.is_null(0) { |
1948 | | match type_id { |
1949 | | 0 => { |
1950 | | let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap(); |
1951 | | assert_eq!(slot1.len(), 1); |
1952 | | let value1 = slot1.value(0); |
1953 | | |
1954 | | let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap(); |
1955 | | assert_eq!(slot2.len(), 1); |
1956 | | let value2 = slot2.value(0); |
1957 | | assert_eq!(value1, value2); |
1958 | | } |
1959 | | 1 => { |
1960 | | let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap(); |
1961 | | assert_eq!(slot1.len(), 1); |
1962 | | let value1 = slot1.value(0); |
1963 | | |
1964 | | let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap(); |
1965 | | assert_eq!(slot2.len(), 1); |
1966 | | let value2 = slot2.value(0); |
1967 | | assert_eq!(value1, value2); |
1968 | | } |
1969 | | _ => unreachable!(), |
1970 | | } |
1971 | | } |
1972 | | } |
1973 | | } |
1974 | | |
1975 | | #[test] |
1976 | | fn test_filter_struct() { |
1977 | | let predicate = BooleanArray::from(vec![true, false, true, false]); |
1978 | | |
1979 | | let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"])); |
1980 | | let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"])); |
1981 | | |
1982 | | let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8])); |
1983 | | let b_filtered = Arc::new(Int32Array::from(vec![5, 7])); |
1984 | | |
1985 | | let null_mask = NullBuffer::from(vec![true, false, false, true]); |
1986 | | let null_mask_filtered = NullBuffer::from(vec![true, false]); |
1987 | | |
1988 | | let a_field = Field::new("a", DataType::Utf8, false); |
1989 | | let b_field = Field::new("b", DataType::Int32, false); |
1990 | | |
1991 | | let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None); |
1992 | | let expected = |
1993 | | StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None); |
1994 | | |
1995 | | let result = filter(&array, &predicate).unwrap(); |
1996 | | |
1997 | | assert_eq!(result.to_data(), expected.to_data()); |
1998 | | |
1999 | | let array = StructArray::new( |
2000 | | vec![a_field.clone()].into(), |
2001 | | vec![a.clone()], |
2002 | | Some(null_mask.clone()), |
2003 | | ); |
2004 | | let expected = StructArray::new( |
2005 | | vec![a_field.clone()].into(), |
2006 | | vec![a_filtered.clone()], |
2007 | | Some(null_mask_filtered.clone()), |
2008 | | ); |
2009 | | |
2010 | | let result = filter(&array, &predicate).unwrap(); |
2011 | | |
2012 | | assert_eq!(result.to_data(), expected.to_data()); |
2013 | | |
2014 | | let array = StructArray::new( |
2015 | | vec![a_field.clone(), b_field.clone()].into(), |
2016 | | vec![a.clone(), b.clone()], |
2017 | | None, |
2018 | | ); |
2019 | | let expected = StructArray::new( |
2020 | | vec![a_field.clone(), b_field.clone()].into(), |
2021 | | vec![a_filtered.clone(), b_filtered.clone()], |
2022 | | None, |
2023 | | ); |
2024 | | |
2025 | | let result = filter(&array, &predicate).unwrap(); |
2026 | | |
2027 | | assert_eq!(result.to_data(), expected.to_data()); |
2028 | | |
2029 | | let array = StructArray::new( |
2030 | | vec![a_field.clone(), b_field.clone()].into(), |
2031 | | vec![a.clone(), b.clone()], |
2032 | | Some(null_mask.clone()), |
2033 | | ); |
2034 | | |
2035 | | let expected = StructArray::new( |
2036 | | vec![a_field.clone(), b_field.clone()].into(), |
2037 | | vec![a_filtered.clone(), b_filtered.clone()], |
2038 | | Some(null_mask_filtered.clone()), |
2039 | | ); |
2040 | | |
2041 | | let result = filter(&array, &predicate).unwrap(); |
2042 | | |
2043 | | assert_eq!(result.to_data(), expected.to_data()); |
2044 | | } |
2045 | | |
2046 | | #[test] |
2047 | | fn test_filter_empty_struct() { |
2048 | | /* |
2049 | | "a": { |
2050 | | "b": int64, |
2051 | | "c": {} |
2052 | | }, |
2053 | | */ |
2054 | | let fields = arrow_schema::Field::new( |
2055 | | "a", |
2056 | | arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![ |
2057 | | arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true), |
2058 | | arrow_schema::Field::new( |
2059 | | "c", |
2060 | | arrow_schema::DataType::Struct(arrow_schema::Fields::empty()), |
2061 | | true, |
2062 | | ), |
2063 | | ])), |
2064 | | true, |
2065 | | ); |
2066 | | |
2067 | | /* Test record |
2068 | | {"a":{"c": {}}} |
2069 | | {"a":{"c": {}}} |
2070 | | {"a":{"c": {}}} |
2071 | | */ |
2072 | | |
2073 | | // Create the record batch with the nested struct array |
2074 | | let schema = Arc::new(Schema::new(vec![fields])); |
2075 | | |
2076 | | let b = Arc::new(Int64Array::from(vec![None, None, None])); |
2077 | | let c = Arc::new(StructArray::new_empty_fields( |
2078 | | 3, |
2079 | | Some(NullBuffer::from(vec![true, true, true])), |
2080 | | )); |
2081 | | let a = StructArray::new( |
2082 | | vec![ |
2083 | | Field::new("b", DataType::Int64, true), |
2084 | | Field::new("c", DataType::Struct(Fields::empty()), true), |
2085 | | ] |
2086 | | .into(), |
2087 | | vec![b.clone(), c.clone()], |
2088 | | Some(NullBuffer::from(vec![true, true, true])), |
2089 | | ); |
2090 | | let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap(); |
2091 | | println!("{record_batch:?}"); |
2092 | | |
2093 | | // Apply the filter |
2094 | | let predicate = BooleanArray::from(vec![true, false, true]); |
2095 | | let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap(); |
2096 | | |
2097 | | // The filtered batch should have 2 rows (the 1st and 3rd) |
2098 | | assert_eq!(filtered_batch.num_rows(), 2); |
2099 | | } |
2100 | | } |