Coverage Report

Created: 2025-11-17 14:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-select/src/filter.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines filter kernels
19
20
use std::ops::AddAssign;
21
use std::sync::Arc;
22
23
use arrow_array::builder::BooleanBufferBuilder;
24
use arrow_array::cast::AsArray;
25
use arrow_array::types::{
26
    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27
};
28
use arrow_array::*;
29
use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer, bit_util};
30
use arrow_buffer::{Buffer, MutableBuffer};
31
use arrow_data::ArrayDataBuilder;
32
use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
33
use arrow_data::transform::MutableArrayData;
34
use arrow_schema::*;
35
36
/// Selectivity above which filtering copies contiguous ranges of values.
///
/// When the fraction of rows kept by a filter is greater than this value,
/// [`SlicesIterator`] is used to copy runs of selected values; otherwise each
/// selected row is visited individually with [`IndexIterator`].
///
/// The threshold of 0.8 was chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009>
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44
/// An iterator of `(usize, usize)` each representing an interval
45
/// `[start, end)` whose slots of a bitmap [Buffer] are true.
46
///
47
/// Each interval corresponds to a contiguous region of memory to be
48
/// "taken" from an array to be filtered.
49
///
50
/// ## Notes:
51
///
52
/// 1. Ignores the validity bitmap (ignores nulls)
53
///
54
/// 2. Only performant for filters that copy across long contiguous runs
55
#[derive(Debug)]
56
pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58
impl<'a> SlicesIterator<'a> {
59
    /// Creates a new iterator from a [BooleanArray]
60
234
    pub fn new(filter: &'a BooleanArray) -> Self {
61
234
        filter.values().into()
62
234
    }
63
}
64
65
impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> {
66
281
    fn from(filter: &'a BooleanBuffer) -> Self {
67
281
        Self(filter.set_slices())
68
281
    }
69
}
70
71
impl Iterator for SlicesIterator<'_> {
72
    type Item = (usize, usize);
73
74
14.7k
    fn next(&mut self) -> Option<Self::Item> {
75
14.7k
        self.0.next()
76
14.7k
    }
77
}
78
79
/// An iterator of `usize` whose index in [`BooleanArray`] is true
80
///
81
/// This provides the best performance on most predicates, apart from those which keep
82
/// large runs and therefore favour [`SlicesIterator`]
83
struct IndexIterator<'a> {
84
    remaining: usize,
85
    iter: BitIndexIterator<'a>,
86
}
87
88
impl<'a> IndexIterator<'a> {
89
757
    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
90
757
        assert_eq!(filter.null_count(), 0);
91
757
        let iter = filter.values().set_indices();
92
757
        Self { remaining, iter }
93
757
    }
94
}
95
96
impl Iterator for IndexIterator<'_> {
97
    type Item = usize;
98
99
70.4k
    fn next(&mut self) -> Option<Self::Item> {
100
70.4k
        if self.remaining != 0 {
101
            // Fascinatingly swapping these two lines around results in a 50%
102
            // performance regression for some benchmarks
103
69.9k
            let next = self.iter.next().expect("IndexIterator exhausted early");
104
69.9k
            self.remaining -= 1;
105
            // Must panic if exhausted early as trusted length iterator
106
69.9k
            return Some(next);
107
523
        }
108
523
        None
109
70.4k
    }
110
111
593
    fn size_hint(&self) -> (usize, Option<usize>) {
112
593
        (self.remaining, Some(self.remaining))
113
593
    }
114
}
115
116
/// Counts the number of set bits in `filter`
117
567
fn filter_count(filter: &BooleanArray) -> usize {
118
567
    filter.values().count_set_bits()
119
567
}
120
121
/// Remove null values by do a bitmask AND operation with null bits and the boolean bits.
122
6
pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
123
6
    let nulls = filter.nulls().unwrap();
124
6
    let mask = filter.values() & nulls.inner();
125
6
    BooleanArray::new(mask, None)
126
6
}
127
128
/// Returns a filtered `values` [`Array`] where the corresponding elements of
129
/// `predicate` are `true`.
130
///
131
/// If multiple arrays (or record batches) need to be filtered using the same predicate array,
132
/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
133
/// calling [FilterPredicate::filter_record_batch].
134
///
135
/// In contrast to this function, it is then the responsibility of the caller
136
/// to use [FilterBuilder::optimize] if appropriate.
137
///
138
/// # See also
139
/// * [`FilterBuilder`] for more control over the filtering process.
140
/// * [`filter_record_batch`] to filter a [`RecordBatch`]
141
/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
142
///   the results into a single array.
143
///
144
/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
145
///
146
/// # Example
147
/// ```rust
148
/// # use arrow_array::{Int32Array, BooleanArray};
149
/// # use arrow_select::filter::filter;
150
/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
151
/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
152
/// let c = filter(&array, &filter_array).unwrap();
153
/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
154
/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
155
/// ```
156
375
pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
157
375
    let mut filter_builder = FilterBuilder::new(predicate);
158
159
375
    if FilterBuilder::is_optimize_beneficial(values.data_type()) {
160
6
        // Only optimize if filtering more than one array
161
6
        // Otherwise, the overhead of optimization can be more than the benefit
162
6
        filter_builder = filter_builder.optimize();
163
369
    }
164
165
375
    let predicate = filter_builder.build();
166
167
375
    filter_array(values, &predicate)
168
375
}
169
170
/// Returns a filtered [RecordBatch] where the corresponding elements of
171
/// `predicate` are true.
172
///
173
/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
174
///
175
/// If multiple record batches (or arrays) need to be filtered using the same predicate array,
176
/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
177
/// calling [FilterPredicate::filter_record_batch].
178
/// In contrast to this function, it is then the responsibility of the caller
179
/// to use [FilterBuilder::optimize] if appropriate.
180
82
pub fn filter_record_batch(
181
82
    record_batch: &RecordBatch,
182
82
    predicate: &BooleanArray,
183
82
) -> Result<RecordBatch, ArrowError> {
184
82
    let mut filter_builder = FilterBuilder::new(predicate);
185
82
    let num_cols = record_batch.num_columns();
186
82
    if num_cols > 1
187
2
        || (num_cols > 0
188
1
            && FilterBuilder::is_optimize_beneficial(
189
1
                record_batch.schema_ref().field(0).data_type(),
190
            ))
191
81
    {
192
81
        // Only optimize if filtering more than one column or if the column contains multiple internal arrays
193
81
        // Otherwise, the overhead of optimization can be more than the benefit
194
81
        filter_builder = filter_builder.optimize();
195
81
    
}1
196
82
    let filter = filter_builder.build();
197
198
82
    filter.filter_record_batch(record_batch)
199
82
}
200
201
/// A builder to construct [`FilterPredicate`]
202
#[derive(Debug)]
203
pub struct FilterBuilder {
204
    filter: BooleanArray,
205
    count: usize,
206
    strategy: IterationStrategy,
207
}
208
209
impl FilterBuilder {
210
    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
211
459
    pub fn new(filter: &BooleanArray) -> Self {
212
459
        let filter = match filter.null_count() {
213
457
            0 => filter.clone(),
214
2
            _ => prep_null_mask_filter(filter),
215
        };
216
217
459
        let count = filter_count(&filter);
218
459
        let strategy = IterationStrategy::default_strategy(filter.len(), count);
219
220
459
        Self {
221
459
            filter,
222
459
            count,
223
459
            strategy,
224
459
        }
225
459
    }
226
227
    /// Compute an optimized representation of the provided `filter` mask that can be
228
    /// applied to an array more quickly.
229
    ///
230
    /// When filtering multiple arrays (e.g. a [`RecordBatch`] or a
231
    /// [`StructArray`] with multiple fields), optimizing the filter can provide
232
    /// significant performance benefits.
233
    ///
234
    /// However, optimization takes time and can have a larger memory footprint
235
    /// than the original mask, so it is often faster to filter a single array,
236
    /// without filter optimization.
237
89
    pub fn optimize(mut self) -> Self {
238
89
        match self.strategy {
239
            IterationStrategy::SlicesIterator => {
240
20
                let slices = SlicesIterator::new(&self.filter).collect();
241
20
                self.strategy = IterationStrategy::Slices(slices)
242
            }
243
            IterationStrategy::IndexIterator => {
244
69
                let indices = IndexIterator::new(&self.filter, self.count).collect();
245
69
                self.strategy = IterationStrategy::Indices(indices)
246
            }
247
0
            _ => {}
248
        }
249
89
        self
250
89
    }
251
252
    /// Determines if calling [FilterBuilder::optimize] is beneficial for the
253
    /// given type even when filtering just a single array.
254
    ///
255
    /// See [`FilterBuilder::optimize`] for more details.
256
378
    pub fn is_optimize_beneficial(data_type: &DataType) -> bool {
257
10
        match data_type {
258
5
            DataType::Struct(fields) => {
259
5
                fields.len() > 1
260
2
                    || fields.len() == 1
261
2
                        && FilterBuilder::is_optimize_beneficial(fields[0].data_type())
262
            }
263
4
            DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
264
369
            _ => false,
265
        }
266
378
    }
267
268
    /// Construct the final `FilterPredicate`
269
459
    pub fn build(self) -> FilterPredicate {
270
459
        FilterPredicate {
271
459
            filter: self.filter,
272
459
            count: self.count,
273
459
            strategy: self.strategy,
274
459
        }
275
459
    }
276
}
277
278
/// The iteration strategy used to evaluate [`FilterPredicate`]
279
#[derive(Debug)]
280
enum IterationStrategy {
281
    /// A lazily evaluated iterator of ranges
282
    SlicesIterator,
283
    /// A lazily evaluated iterator of indices
284
    IndexIterator,
285
    /// A precomputed list of indices
286
    Indices(Vec<usize>),
287
    /// A precomputed array of ranges
288
    Slices(Vec<(usize, usize)>),
289
    /// Select all rows
290
    All,
291
    /// Select no rows
292
    None,
293
}
294
295
impl IterationStrategy {
296
    /// The default [`IterationStrategy`] for a filter of length `filter_length`
297
    /// and selecting `filter_count` rows
298
459
    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
299
459
        if filter_length == 0 || filter_count == 0 {
300
31
            return IterationStrategy::None;
301
428
        }
302
303
428
        if filter_count == filter_length {
304
19
            return IterationStrategy::All;
305
409
        }
306
307
        // Compute the selectivity of the predicate by dividing the number of true
308
        // bits in the predicate by the predicate's total length
309
        //
310
        // This can then be used as a heuristic for the optimal iteration strategy
311
409
        let selectivity_frac = filter_count as f64 / filter_length as f64;
312
409
        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
313
60
            return IterationStrategy::SlicesIterator;
314
349
        }
315
349
        IterationStrategy::IndexIterator
316
459
    }
317
}
318
319
/// A filtering predicate that can be applied to an [`Array`]
320
#[derive(Debug)]
321
pub struct FilterPredicate {
322
    filter: BooleanArray,
323
    count: usize,
324
    strategy: IterationStrategy,
325
}
326
327
impl FilterPredicate {
328
    /// Selects rows from `values` based on this [`FilterPredicate`]
329
2
    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
330
2
        filter_array(values, self)
331
2
    }
332
333
    /// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this
334
    /// [`FilterPredicate`].
335
    ///
336
    /// This is the equivalent of calling [filter] on each column of the [`RecordBatch`].
337
82
    pub fn filter_record_batch(
338
82
        &self,
339
82
        record_batch: &RecordBatch,
340
82
    ) -> Result<RecordBatch, ArrowError> {
341
82
        let filtered_arrays = record_batch
342
82
            .columns()
343
82
            .iter()
344
321
            .
map82
(|a| filter_array(a, self))
345
82
            .collect::<Result<Vec<_>, _>>()
?0
;
346
347
        // SAFETY: we know that the set of filtered arrays will match the schema of the original
348
        // record batch
349
        unsafe {
350
82
            Ok(RecordBatch::new_unchecked(
351
82
                record_batch.schema(),
352
82
                filtered_arrays,
353
82
                self.count,
354
82
            ))
355
        }
356
82
    }
357
358
    /// Number of rows being selected based on this [`FilterPredicate`]
359
6
    pub fn count(&self) -> usize {
360
6
        self.count
361
6
    }
362
}
363
364
714
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
365
714
    if predicate.filter.len() > values.len() {
366
0
        return Err(ArrowError::InvalidArgumentError(format!(
367
0
            "Filter predicate of length {} is larger than target array of length {}",
368
0
            predicate.filter.len(),
369
0
            values.len()
370
0
        )));
371
714
    }
372
373
714
    match predicate.strategy {
374
31
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
375
19
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
376
        // actually filter
377
664
        _ => downcast_primitive_array! {
378
0
            values => Ok(Arc::new(filter_primitive(values, predicate))),
379
            DataType::Boolean => {
380
0
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
381
0
                Ok(Arc::new(filter_boolean(values, predicate)))
382
            }
383
            DataType::Utf8 => {
384
175
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
385
            }
386
            DataType::LargeUtf8 => {
387
0
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
388
            }
389
            DataType::Utf8View => {
390
82
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
391
            }
392
            DataType::Binary => {
393
1
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
394
            }
395
            DataType::LargeBinary => {
396
0
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
397
            }
398
            DataType::BinaryView => {
399
2
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
400
            }
401
            DataType::FixedSizeBinary(_) => {
402
4
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
403
            }
404
            DataType::ListView(_) => {
405
2
                Ok(Arc::new(filter_list_view::<i32>(values.as_list_view(), predicate)))
406
            }
407
            DataType::LargeListView(_) => {
408
2
                Ok(Arc::new(filter_list_view::<i64>(values.as_list_view(), predicate)))
409
            }
410
            DataType::RunEndEncoded(_, _) => {
411
4
                downcast_run_array!{
412
1
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)
?0
)),
413
0
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
414
                }
415
            }
416
87
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
417
1
                values => Ok(Arc::new(filter_dict(values, predicate))),
418
0
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
419
            }
420
            DataType::Struct(_) => {
421
6
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)
?0
))
422
            }
423
            DataType::Union(_, UnionMode::Sparse) => {
424
4
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)
?0
))
425
            }
426
            _ => {
427
11
                let data = values.to_data();
428
                // fallback to using MutableArrayData
429
11
                let mut mutable = MutableArrayData::new(
430
11
                    vec![&data],
431
                    false,
432
11
                    predicate.count,
433
                );
434
435
11
                match &predicate.strategy {
436
0
                    IterationStrategy::Slices(slices) => {
437
0
                        slices
438
0
                            .iter()
439
0
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
440
                    }
441
                    _ => {
442
11
                        let iter = SlicesIterator::new(&predicate.filter);
443
17
                        
iter11
.
for_each11
(|(start, end)| mutable.extend(0, start, end));
444
                    }
445
                }
446
447
11
                let data = mutable.freeze();
448
11
                Ok(make_array(data))
449
            }
450
        },
451
    }
452
714
}
453
454
/// Filter any supported [`RunArray`] based on a [`FilterPredicate`]
455
4
fn filter_run_end_array<R: RunEndIndexType>(
456
4
    array: &RunArray<R>,
457
4
    predicate: &FilterPredicate,
458
4
) -> Result<RunArray<R>, ArrowError>
459
4
where
460
4
    R::Native: Into<i64> + From<bool>,
461
4
    R::Native: AddAssign,
462
{
463
4
    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
464
4
    let mut new_run_ends = vec![R::default_value(); run_ends.len()];
465
466
4
    let mut start = 0u64;
467
4
    let mut j = 0;
468
4
    let mut count = R::default_value();
469
4
    let filter_values = predicate.filter.values();
470
4
    let run_ends = run_ends.inner();
471
472
15
    let 
pred4
:
BooleanArray4
=
BooleanBuffer::collect_bool4
(
run_ends4
.
len4
(), |i| {
473
15
        let mut keep = false;
474
15
        let mut end = run_ends[i].into() as u64;
475
15
        let difference = end.saturating_sub(filter_values.len() as u64);
476
15
        end -= difference;
477
478
        // Safety: we subtract the difference off `end` so we are always within bounds
479
31
        for pred in 
(start..end)15
.
map15
(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
480
31
            count += R::Native::from(pred);
481
31
            keep |= pred
482
        }
483
        // this is to avoid branching
484
15
        new_run_ends[j] = count;
485
15
        j += keep as usize;
486
487
15
        start = end;
488
15
        keep
489
15
    })
490
4
    .into();
491
492
4
    new_run_ends.truncate(j);
493
494
4
    let values = array.values();
495
4
    let values = filter(&values, &pred)
?0
;
496
497
4
    let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)
?0
;
498
4
    RunArray::try_new(&run_ends, &values)
499
4
}
500
501
/// Computes a new null mask for `data` based on `predicate`
502
///
503
/// If the predicate selected no null-rows, returns `None`, otherwise returns
504
/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls
505
/// in the filtered output, and `null_buffer` is the filtered null buffer
506
///
507
649
fn filter_null_mask(
508
649
    nulls: Option<&NullBuffer>,
509
649
    predicate: &FilterPredicate,
510
649
) -> Option<(usize, Buffer)> {
511
649
    let 
nulls605
= nulls
?44
;
512
605
    if nulls.null_count() == 0 {
513
1
        return None;
514
604
    }
515
516
604
    let nulls = filter_bits(nulls.inner(), predicate);
517
    // The filtered `nulls` has a length of `predicate.count` bits and
518
    // therefore the null count is this minus the number of valid bits
519
604
    let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
520
521
604
    if null_count == 0 {
522
57
        return None;
523
547
    }
524
525
547
    Some((null_count, nulls))
526
649
}
527
528
/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset`
529
604
fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
530
604
    let src = buffer.values();
531
604
    let offset = buffer.offset();
532
533
604
    match &predicate.strategy {
534
        IterationStrategy::IndexIterator => {
535
234
            let bits = IndexIterator::new(&predicate.filter, predicate.count)
536
12.6k
                .
map234
(|src_idx| bit_util::get_bit(src, src_idx + offset));
537
538
            // SAFETY: `IndexIterator` reports its size correctly
539
234
            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
540
        }
541
250
        IterationStrategy::Indices(indices) => {
542
250
            let bits = indices
543
250
                .iter()
544
70.8k
                .
map250
(|src_idx| bit_util::get_bit(src, *src_idx + offset));
545
546
            // SAFETY: `Vec::iter()` reports its size correctly
547
250
            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
548
        }
549
        IterationStrategy::SlicesIterator => {
550
40
            let mut builder = BooleanBufferBuilder::new(predicate.count);
551
659
            for (start, end) in 
SlicesIterator::new40
(
&predicate.filter40
) {
552
659
                builder.append_packed_range(start + offset..end + offset, src)
553
            }
554
40
            builder.into()
555
        }
556
80
        IterationStrategy::Slices(slices) => {
557
80
            let mut builder = BooleanBufferBuilder::new(predicate.count);
558
6.04k
            for (
start5.96k
,
end5.96k
) in slices {
559
5.96k
                builder.append_packed_range(*start + offset..*end + offset, src)
560
            }
561
80
            builder.into()
562
        }
563
0
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
564
    }
565
604
}
566
567
/// `filter` implementation for boolean buffers
568
0
fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
569
0
    let values = filter_bits(array.values(), predicate);
570
571
0
    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
572
0
        .len(predicate.count)
573
0
        .add_buffer(values);
574
575
0
    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
576
0
        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
577
0
    }
578
579
0
    let data = unsafe { builder.build_unchecked() };
580
0
    BooleanArray::from(data)
581
0
}
582
583
#[inline(never)]
584
467
fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
585
467
    assert!(values.len() >= predicate.filter.len());
586
587
467
    match &predicate.strategy {
588
        IterationStrategy::SlicesIterator => {
589
27
            let mut buffer = Vec::with_capacity(predicate.count);
590
440
            for (start, end) in 
SlicesIterator::new27
(
&predicate.filter27
) {
591
440
                buffer.extend_from_slice(&values[start..end]);
592
440
            }
593
27
            buffer.into()
594
        }
595
60
        IterationStrategy::Slices(slices) => {
596
60
            let mut buffer = Vec::with_capacity(predicate.count);
597
4.53k
            for (
start4.47k
,
end4.47k
) in slices {
598
4.47k
                buffer.extend_from_slice(&values[*start..*end]);
599
4.47k
            }
600
60
            buffer.into()
601
        }
602
        IterationStrategy::IndexIterator => {
603
8.46k
            let 
iter185
=
IndexIterator::new185
(
&predicate.filter185
,
predicate.count185
).
map185
(|x| values[x]);
604
605
            // SAFETY: IndexIterator is trusted length
606
185
            unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
607
        }
608
195
        IterationStrategy::Indices(indices) => {
609
53.1k
            let 
iter195
=
indices.iter()195
.
map195
(|x| values[*x]);
610
195
            iter.collect::<Vec<_>>().into()
611
        }
612
0
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
613
    }
614
467
}
615
616
/// `filter` implementation for primitive arrays
617
375
fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
618
375
where
619
375
    T: ArrowPrimitiveType,
620
{
621
375
    let values = array.values();
622
375
    let buffer = filter_native(values, predicate);
623
375
    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
624
375
        .len(predicate.count)
625
375
        .add_buffer(buffer);
626
627
375
    if let Some((
null_count322
,
nulls322
)) = filter_null_mask(array.nulls(), predicate) {
628
322
        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
629
322
    
}53
630
631
375
    let data = unsafe { builder.build_unchecked() };
632
375
    PrimitiveArray::from(data)
633
375
}
634
635
/// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be
636
/// used to build a new [`GenericByteArray`] by copying values from the source
637
///
638
/// TODO(raphael): Could this be used for the take kernel as well?
639
struct FilterBytes<'a, OffsetSize> {
640
    src_offsets: &'a [OffsetSize],
641
    src_values: &'a [u8],
642
    dst_offsets: Vec<OffsetSize>,
643
    dst_values: Vec<u8>,
644
    cur_offset: OffsetSize,
645
}
646
647
impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
648
where
649
    OffsetSize: OffsetSizeTrait,
650
{
651
176
    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
652
176
    where
653
176
        T: ByteArrayType<Offset = OffsetSize>,
654
    {
655
176
        let dst_values = Vec::new();
656
176
        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
657
176
        let cur_offset = OffsetSize::from_usize(0).unwrap();
658
659
176
        dst_offsets.push(cur_offset);
660
661
176
        Self {
662
176
            src_offsets: array.value_offsets(),
663
176
            src_values: array.value_data(),
664
176
            dst_offsets,
665
176
            dst_values,
666
176
            cur_offset,
667
176
        }
668
176
    }
669
670
    /// Returns the byte offset at `idx`
671
    #[inline]
672
36.0k
    fn get_value_offset(&self, idx: usize) -> usize {
673
36.0k
        self.src_offsets[idx].as_usize()
674
36.0k
    }
675
676
    /// Returns the start and end of the value at index `idx` along with its length
677
    #[inline]
678
16.3k
    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
679
        // These can only fail if `array` contains invalid data
680
16.3k
        let start = self.get_value_offset(idx);
681
16.3k
        let end = self.get_value_offset(idx + 1);
682
16.3k
        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
683
16.3k
        (start, end, len)
684
16.3k
    }
685
686
143
    fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
687
21.9k
        
self.dst_offsets143
.
extend143
(
iter143
.
map143
(|idx| {
688
21.9k
            let start = self.src_offsets[idx].as_usize();
689
21.9k
            let end = self.src_offsets[idx + 1].as_usize();
690
21.9k
            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
691
21.9k
            self.cur_offset += len;
692
693
21.9k
            self.cur_offset
694
21.9k
        }));
695
143
    }
696
697
    /// Extends the in-progress array by the indexes in the provided iterator
698
143
    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
699
143
        self.dst_values.reserve_exact(self.cur_offset.as_usize());
700
701
22.0k
        for 
idx21.9k
in iter {
702
21.9k
            let start = self.src_offsets[idx].as_usize();
703
21.9k
            let end = self.src_offsets[idx + 1].as_usize();
704
21.9k
            self.dst_values
705
21.9k
                .extend_from_slice(&self.src_values[start..end]);
706
21.9k
        }
707
143
    }
708
709
33
    fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
710
33
        self.dst_offsets.reserve_exact(count);
711
1.74k
        for (
start1.71k
,
end1.71k
) in iter {
712
            // These can only fail if `array` contains invalid data
713
16.3k
            for idx in 
start1.71k
..
end1.71k
{
714
16.3k
                let (_, _, len) = self.get_value_range(idx);
715
16.3k
                self.cur_offset += len;
716
16.3k
                self.dst_offsets.push(self.cur_offset);
717
16.3k
            }
718
        }
719
33
    }
720
721
    /// Extends the in-progress array by the ranges in the provided iterator
722
33
    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
723
33
        self.dst_values.reserve_exact(self.cur_offset.as_usize());
724
725
1.74k
        for (
start1.71k
,
end1.71k
) in iter {
726
1.71k
            let value_start = self.get_value_offset(start);
727
1.71k
            let value_end = self.get_value_offset(end);
728
1.71k
            self.dst_values
729
1.71k
                .extend_from_slice(&self.src_values[value_start..value_end]);
730
1.71k
        }
731
33
    }
732
}
733
734
/// `filter` implementation for byte arrays
735
///
736
/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
737
/// data copied across. This allows handling the null mask separately from the data
738
176
fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
739
176
where
740
176
    T: ByteArrayType,
741
{
742
176
    let mut filter = FilterBytes::new(predicate.count, array);
743
744
176
    match &predicate.strategy {
745
        IterationStrategy::SlicesIterator => {
746
13
            filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
747
13
            filter.extend_slices(SlicesIterator::new(&predicate.filter))
748
        }
749
20
        IterationStrategy::Slices(slices) => {
750
20
            filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
751
20
            filter.extend_slices(slices.iter().cloned())
752
        }
753
        IterationStrategy::IndexIterator => {
754
81
            filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
755
81
            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
756
        }
757
62
        IterationStrategy::Indices(indices) => {
758
62
            filter.extend_offsets_idx(indices.iter().cloned());
759
62
            filter.extend_idx(indices.iter().cloned())
760
        }
761
0
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
762
    }
763
764
176
    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
765
176
        .len(predicate.count)
766
176
        .add_buffer(filter.dst_offsets.into())
767
176
        .add_buffer(filter.dst_values.into());
768
769
176
    if let Some((
null_count147
,
nulls147
)) = filter_null_mask(array.nulls(), predicate) {
770
147
        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
771
147
    
}29
772
773
176
    let data = unsafe { builder.build_unchecked() };
774
176
    GenericByteArray::from(data)
775
176
}
776
777
/// `filter` implementation for byte view arrays.
778
84
fn filter_byte_view<T: ByteViewType>(
779
84
    array: &GenericByteViewArray<T>,
780
84
    predicate: &FilterPredicate,
781
84
) -> GenericByteViewArray<T> {
782
84
    let new_view_buffer = filter_native(array.views(), predicate);
783
784
84
    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
785
84
        .len(predicate.count)
786
84
        .add_buffer(new_view_buffer)
787
84
        .add_buffers(array.data_buffers().to_vec());
788
789
84
    if let Some((
null_count76
,
nulls76
)) = filter_null_mask(array.nulls(), predicate) {
790
76
        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
791
76
    
}8
792
793
84
    GenericByteViewArray::from(unsafe { builder.build_unchecked() })
794
84
}
795
796
4
fn filter_fixed_size_binary(
797
4
    array: &FixedSizeBinaryArray,
798
4
    predicate: &FilterPredicate,
799
4
) -> FixedSizeBinaryArray {
800
4
    let values: &[u8] = array.values();
801
4
    let value_length = array.value_length() as usize;
802
12
    let 
calculate_offset_from_index4
= |index: usize| index * value_length;
803
4
    let buffer = match &predicate.strategy {
804
        IterationStrategy::SlicesIterator => {
805
0
            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
806
0
            for (start, end) in SlicesIterator::new(&predicate.filter) {
807
0
                buffer.extend_from_slice(
808
0
                    &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
809
0
                );
810
0
            }
811
0
            buffer
812
        }
813
0
        IterationStrategy::Slices(slices) => {
814
0
            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
815
0
            for (start, end) in slices {
816
0
                buffer.extend_from_slice(
817
0
                    &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
818
0
                );
819
0
            }
820
0
            buffer
821
        }
822
        IterationStrategy::IndexIterator => {
823
3
            let 
iter2
=
IndexIterator::new2
(
&predicate.filter2
,
predicate.count2
).
map2
(|x| {
824
3
                &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
825
3
            });
826
827
2
            let mut buffer = MutableBuffer::new(predicate.count * value_length);
828
3
            
iter2
.
for_each2
(|item| buffer.extend_from_slice(item));
829
2
            buffer
830
        }
831
2
        IterationStrategy::Indices(indices) => {
832
3
            let 
iter2
=
indices.iter()2
.
map2
(|x| {
833
3
                &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
834
3
            });
835
836
2
            let mut buffer = MutableBuffer::new(predicate.count * value_length);
837
3
            
iter2
.
for_each2
(|item| buffer.extend_from_slice(item));
838
2
            buffer
839
        }
840
0
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
841
    };
842
4
    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
843
4
        .len(predicate.count)
844
4
        .add_buffer(buffer.into());
845
846
4
    if let Some((
null_count0
,
nulls0
)) = filter_null_mask(array.nulls(), predicate) {
847
0
        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
848
4
    }
849
850
4
    let data = unsafe { builder.build_unchecked() };
851
4
    FixedSizeBinaryArray::from(data)
852
4
}
853
854
/// `filter` implementation for dictionaries
855
87
fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
856
87
where
857
87
    T: ArrowDictionaryKeyType,
858
87
    T::Native: num_traits::Num,
859
{
860
87
    let builder = filter_primitive::<T>(array.keys(), predicate)
861
87
        .into_data()
862
87
        .into_builder()
863
87
        .data_type(array.data_type().clone())
864
87
        .child_data(vec![array.values().to_data()]);
865
866
    // SAFETY:
867
    // Keys were valid before, filtered subset is therefore still valid
868
87
    DictionaryArray::from(unsafe { builder.build_unchecked() })
869
87
}
870
871
/// `filter` implementation for structs
872
6
fn filter_struct(
873
6
    array: &StructArray,
874
6
    predicate: &FilterPredicate,
875
6
) -> Result<StructArray, ArrowError> {
876
6
    let columns = array
877
6
        .columns()
878
6
        .iter()
879
8
        .
map6
(|column| filter_array(column, predicate))
880
6
        .collect::<Result<_, _>>()
?0
;
881
882
6
    let nulls = if let Some((
null_count2
,
nulls2
)) = filter_null_mask(array.nulls(), predicate) {
883
2
        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
884
885
2
        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
886
    } else {
887
4
        None
888
    };
889
890
6
    Ok(unsafe {
891
6
        StructArray::new_unchecked_with_length(
892
6
            array.fields().clone(),
893
6
            columns,
894
6
            nulls,
895
6
            predicate.count(),
896
6
        )
897
6
    })
898
6
}
899
900
/// `filter` implementation for sparse unions
901
4
fn filter_sparse_union(
902
4
    array: &UnionArray,
903
4
    predicate: &FilterPredicate,
904
4
) -> Result<UnionArray, ArrowError> {
905
4
    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
906
0
        unreachable!()
907
    };
908
909
4
    let type_ids = filter_primitive(
910
4
        &Int8Array::try_new(array.type_ids().clone(), None)
?0
,
911
4
        predicate,
912
    );
913
914
4
    let children = fields
915
4
        .iter()
916
8
        .
map4
(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
917
4
        .collect::<Result<_, _>>()
?0
;
918
919
4
    Ok(unsafe {
920
4
        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
921
4
    })
922
4
}
923
924
/// `filter` implementation for list views
925
4
fn filter_list_view<OffsetType: OffsetSizeTrait>(
926
4
    array: &GenericListViewArray<OffsetType>,
927
4
    predicate: &FilterPredicate,
928
4
) -> GenericListViewArray<OffsetType> {
929
4
    let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate);
930
4
    let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate);
931
932
    // Filter the nulls
933
4
    let nulls = if let Some((
null_count0
,
nulls0
)) = filter_null_mask(array.nulls(), predicate) {
934
0
        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
935
936
0
        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
937
    } else {
938
4
        None
939
    };
940
941
4
    let list_data = ArrayDataBuilder::new(array.data_type().clone())
942
4
        .nulls(nulls)
943
4
        .buffers(vec![filtered_offsets, filtered_sizes])
944
4
        .child_data(vec![array.values().to_data()])
945
4
        .len(predicate.count);
946
947
4
    let list_data = unsafe { list_data.build_unchecked() };
948
949
4
    GenericListViewArray::from(list_data)
950
4
}
951
952
#[cfg(test)]
953
mod tests {
954
    use super::*;
955
    use arrow_array::builder::*;
956
    use arrow_array::cast::as_run_array;
957
    use arrow_array::types::*;
958
    use arrow_data::ArrayData;
959
    use rand::distr::uniform::{UniformSampler, UniformUsize};
960
    use rand::distr::{Alphanumeric, StandardUniform};
961
    use rand::prelude::*;
962
    use rand::rng;
963
964
    /// Generates a `#[test]` named `$test` that filters `$data` (expected to hold
    /// `[1, 2, 3, 4]`) with `[true, false, true, false]` and checks `[1, 3]` remains.
    macro_rules! def_temporal_test {
        ($test:ident, $array_type: ident, $data: expr) => {
            #[test]
            fn $test() {
                let values = $data;
                let predicate = BooleanArray::from(vec![true, false, true, false]);
                let filtered = filter(&values, &predicate).unwrap();
                let result = filtered
                    .as_ref()
                    .as_any()
                    .downcast_ref::<$array_type>()
                    .unwrap();
                assert_eq!(result.len(), 2);
                assert_eq!(result.value(0), 1);
                assert_eq!(result.value(1), 3);
            }
        };
    }
978
979
    def_temporal_test!(
980
        test_filter_date32,
981
        Date32Array,
982
        Date32Array::from(vec![1, 2, 3, 4])
983
    );
984
    def_temporal_test!(
985
        test_filter_date64,
986
        Date64Array,
987
        Date64Array::from(vec![1, 2, 3, 4])
988
    );
989
    def_temporal_test!(
990
        test_filter_time32_second,
991
        Time32SecondArray,
992
        Time32SecondArray::from(vec![1, 2, 3, 4])
993
    );
994
    def_temporal_test!(
995
        test_filter_time32_millisecond,
996
        Time32MillisecondArray,
997
        Time32MillisecondArray::from(vec![1, 2, 3, 4])
998
    );
999
    def_temporal_test!(
1000
        test_filter_time64_microsecond,
1001
        Time64MicrosecondArray,
1002
        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
1003
    );
1004
    def_temporal_test!(
1005
        test_filter_time64_nanosecond,
1006
        Time64NanosecondArray,
1007
        Time64NanosecondArray::from(vec![1, 2, 3, 4])
1008
    );
1009
    def_temporal_test!(
1010
        test_filter_duration_second,
1011
        DurationSecondArray,
1012
        DurationSecondArray::from(vec![1, 2, 3, 4])
1013
    );
1014
    def_temporal_test!(
1015
        test_filter_duration_millisecond,
1016
        DurationMillisecondArray,
1017
        DurationMillisecondArray::from(vec![1, 2, 3, 4])
1018
    );
1019
    def_temporal_test!(
1020
        test_filter_duration_microsecond,
1021
        DurationMicrosecondArray,
1022
        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
1023
    );
1024
    def_temporal_test!(
1025
        test_filter_duration_nanosecond,
1026
        DurationNanosecondArray,
1027
        DurationNanosecondArray::from(vec![1, 2, 3, 4])
1028
    );
1029
    def_temporal_test!(
1030
        test_filter_timestamp_second,
1031
        TimestampSecondArray,
1032
        TimestampSecondArray::from(vec![1, 2, 3, 4])
1033
    );
1034
    def_temporal_test!(
1035
        test_filter_timestamp_millisecond,
1036
        TimestampMillisecondArray,
1037
        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
1038
    );
1039
    def_temporal_test!(
1040
        test_filter_timestamp_microsecond,
1041
        TimestampMicrosecondArray,
1042
        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
1043
    );
1044
    def_temporal_test!(
1045
        test_filter_timestamp_nanosecond,
1046
        TimestampNanosecondArray,
1047
        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
1048
    );
1049
1050
    #[test]
    fn test_filter_array_slice() {
        // Sliced values: [6, 7, 8, 9]. Only the values are sliced here; filtering
        // with a sliced predicate is not exercised by this test.
        let values = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
        let predicate = BooleanArray::from(vec![true, false, false, true]);

        let filtered = filter(&values, &predicate).unwrap();
        let result = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), 6);
        assert_eq!(result.value(1), 9);
    }
1063
1064
    #[test]
    fn test_filter_array_low_density() {
        // Exercises the all 0's branch of the filter algorithm: of the first 65
        // predicate bits only the last one is set.
        let mut data_values: Vec<i32> = (1..=65).collect();
        let mut filter_values: Vec<bool> = (1..=65).map(|i| i % 65 == 0).collect();

        // Two more values after the batch; only the last one is kept.
        data_values.extend([66, 67]);
        filter_values.extend([false, true]);

        let values = Int32Array::from(data_values);
        let predicate = BooleanArray::from(filter_values);
        let filtered = filter(&values, &predicate).unwrap();
        let result = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), 65);
        assert_eq!(result.value(1), 67);
    }
1080
1081
    #[test]
    fn test_filter_array_high_density() {
        // this test exercises the all 1's branch of the filter algorithm
        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
        // keep every row of the batch except the 65th
        let mut filter_values = (1..=65)
            .map(|i| !matches!(i % 65, 0))
            .collect::<Vec<bool>>();
        // set second data value to null (it is kept by the filter)
        data_values[1] = None;
        // set up four more values after the batch: [66, null, 67, null];
        // all but the first of them are kept
        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
        filter_values.extend_from_slice(&[false, true, true, true]);
        let a = Int32Array::from(data_values);
        let b = BooleanArray::from(filter_values);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        // 64 rows from the batch plus 3 from the tail
        assert_eq!(67, d.len());
        assert_eq!(3, d.null_count());
        assert_eq!(1, d.value(0));
        assert!(d.is_null(1));
        assert_eq!(64, d.value(63));
        assert!(d.is_null(64));
        assert_eq!(67, d.value(65));
        assert!(d.is_null(66));
    }
1105
1106
    #[test]
    fn test_filter_string_array_simple() {
        // Keep the entries at even positions.
        let strings = StringArray::from(vec!["hello", " ", "world", "!"]);
        let keep_even = BooleanArray::from(vec![true, false, true, false]);

        let filtered = filter(&strings, &keep_even).unwrap();
        let result = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), "hello");
        assert_eq!(result.value(1), "world");
    }
1116
1117
    #[test]
    fn test_filter_primitive_array_with_null() {
        // Only the null slot is selected.
        let values = Int32Array::from(vec![Some(5), None]);
        let predicate = BooleanArray::from(vec![false, true]);

        let filtered = filter(&values, &predicate).unwrap();
        let result = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        assert_eq!(result.len(), 1);
        assert!(result.is_null(0));
    }
1126
1127
    #[test]
    fn test_filter_string_array_with_null() {
        // Select the first (valid) and last (null) entries.
        let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
        let predicate = BooleanArray::from(vec![true, false, false, true]);

        let filtered = filter(&strings, &predicate).unwrap();
        let result = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), "hello");
        assert!(!result.is_null(0));
        assert!(result.is_null(1));
    }
1138
1139
    #[test]
    fn test_filter_binary_array_with_null() {
        // Select the first (valid) and last (null) entries.
        let items: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
        let binaries = BinaryArray::from(items);
        let predicate = BooleanArray::from(vec![true, false, false, true]);

        let filtered = filter(&binaries, &predicate).unwrap();
        let result = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<BinaryArray>()
            .unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(b"hello", result.value(0));
        assert!(!result.is_null(0));
        assert!(result.is_null(1));
    }
1151
1152
2
    fn _test_filter_byte_view<T>()
1153
2
    where
1154
2
        T: ByteViewType,
1155
2
        str: AsRef<T::Native>,
1156
2
        T::Native: PartialEq,
1157
    {
1158
2
        let array = {
1159
            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1160
2
            let mut builder = GenericByteViewBuilder::<T>::new();
1161
2
            builder.append_value("hello");
1162
2
            builder.append_value("world");
1163
2
            builder.append_null();
1164
2
            builder.append_value("large payload over 12 bytes");
1165
2
            builder.append_value("lulu");
1166
2
            builder.finish()
1167
        };
1168
1169
        {
1170
2
            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1171
2
            let actual = filter(&array, &predicate).unwrap();
1172
1173
2
            assert_eq!(actual.len(), 3);
1174
1175
2
            let expected = {
1176
                // ["hello", null, "large payload over 12 bytes"]
1177
2
                let mut builder = GenericByteViewBuilder::<T>::new();
1178
2
                builder.append_value("hello");
1179
2
                builder.append_null();
1180
2
                builder.append_value("large payload over 12 bytes");
1181
2
                builder.finish()
1182
            };
1183
1184
2
            assert_eq!(actual.as_ref(), &expected);
1185
        }
1186
1187
        {
1188
2
            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1189
2
            let actual = filter(&array, &predicate).unwrap();
1190
1191
2
            assert_eq!(actual.len(), 2);
1192
1193
2
            let expected = {
1194
                // ["hello", "lulu"]
1195
2
                let mut builder = GenericByteViewBuilder::<T>::new();
1196
2
                builder.append_value("hello");
1197
2
                builder.append_value("lulu");
1198
2
                builder.finish()
1199
            };
1200
1201
2
            assert_eq!(actual.as_ref(), &expected);
1202
        }
1203
2
    }
1204
1205
    /// Runs the shared byte view filter checks for `StringViewArray`.
    #[test]
    fn test_filter_string_view() {
        _test_filter_byte_view::<StringViewType>();
    }
1209
1210
    /// Runs the shared byte view filter checks for `BinaryViewArray`.
    #[test]
    fn test_filter_binary_view() {
        _test_filter_byte_view::<BinaryViewType>();
    }
1214
1215
    #[test]
    fn test_filter_fixed_binary() {
        /// Downcasts a filter result back to a `FixedSizeBinaryArray`.
        fn as_fixed(array: &ArrayRef) -> &FixedSizeBinaryArray {
            array
                .as_ref()
                .as_any()
                .downcast_ref::<FixedSizeBinaryArray>()
                .unwrap()
        }

        let v1 = [1_u8, 2];
        let v2 = [3_u8, 4];
        let v3 = [5_u8, 6];
        let array = FixedSizeBinaryArray::from(vec![&v1, &v2, &v3]);

        // First and last rows, via both the plain and the optimized filter.
        let predicate = BooleanArray::from(vec![true, false, true]);
        let plain = filter(&array, &predicate).unwrap();
        let result = as_fixed(&plain);
        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), &v1);
        assert_eq!(result.value(1), &v3);
        let optimized = FilterBuilder::new(&predicate).optimize().build().filter(&array).unwrap();
        assert_eq!(result, as_fixed(&optimized));

        // Nothing selected.
        let predicate = BooleanArray::from(vec![false, false, false]);
        let none = filter(&array, &predicate).unwrap();
        assert_eq!(as_fixed(&none).len(), 0);

        // Everything selected.
        let predicate = BooleanArray::from(vec![true, true, true]);
        let all = filter(&array, &predicate).unwrap();
        let result = as_fixed(&all);
        assert_eq!(result.len(), 3);
        assert_eq!(result.value(0), &v1);
        assert_eq!(result.value(1), &v2);
        assert_eq!(result.value(2), &v3);

        // Only the last row, via both the plain and the optimized filter.
        let predicate = BooleanArray::from(vec![false, false, true]);
        let last = filter(&array, &predicate).unwrap();
        let result = as_fixed(&last);
        assert_eq!(result.len(), 1);
        assert_eq!(result.value(0), &v3);
        let optimized = FilterBuilder::new(&predicate).optimize().build().filter(&array).unwrap();
        assert_eq!(result, as_fixed(&optimized));
    }
1286
1287
    #[test]
    fn test_filter_array_slice_with_null() {
        // Sliced values: [null, 7, 8, 9]. Only the values are sliced here; filtering
        // with a sliced predicate is not exercised by this test.
        let values = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
        let predicate = BooleanArray::from(vec![true, false, false, true]);

        let filtered = filter(&values, &predicate).unwrap();
        let result = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        assert_eq!(result.len(), 2);
        assert!(result.is_null(0));
        assert!(!result.is_null(1));
        assert_eq!(result.value(1), 9);
    }
1301
1302
    #[test]
    fn test_filter_run_end_encoding_array() {
        // Logical values: [7, 7, -2, 9, 9, 9, 9, 9]
        let run_ends = Int64Array::from(vec![2, 3, 8]);
        let run_values = Int64Array::from(vec![7, -2, 9]);
        let array = RunArray::try_new(&run_ends, &run_values).expect("Failed to create RunArray");

        // Keep every even row: [7, -2, 9, 9]
        let predicate = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
        let filtered = filter(&array, &predicate).unwrap();
        let actual: &RunArray<Int64Type> = as_run_array(&filtered);
        assert_eq!(actual.len(), 4);

        let expected = RunArray::try_new(
            &Int64Array::from(vec![1, 2, 4]),
            &Int64Array::from(vec![7, -2, 9]),
        )
        .expect("Failed to make expected RunArray test is broken");

        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
        assert_eq!(actual.values(), expected.values());
    }
1321
1322
    #[test]
    fn test_filter_run_end_encoding_array_remove_value() {
        // Logical values: [7, 7, -2, 9, 9, 9, 9, 9, -8, -8]
        let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
        let run_values = Int32Array::from(vec![7, -2, 9, -8]);
        let array = RunArray::try_new(&run_ends, &run_values).expect("Failed to create RunArray");

        // Keep rows 1, 4 and 6: [7, 9, 9] -- the -2 and -8 runs disappear.
        let predicate = BooleanArray::from(vec![
            false, true, false, false, true, false, true, false, false, false,
        ]);
        let filtered = filter(&array, &predicate).unwrap();
        let actual: &RunArray<Int32Type> = as_run_array(&filtered);
        assert_eq!(actual.len(), 3);

        let expected_ends = Int32Array::from(vec![1, 3]);
        let expected_values = Int32Array::from(vec![7, 9]);
        let expected = RunArray::try_new(&expected_ends, &expected_values)
            .expect("Failed to make expected RunArray test is broken");

        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
        assert_eq!(actual.values(), expected.values());
    }
1341
1342
    #[test]
    fn test_filter_run_end_encoding_array_remove_all_but_one() {
        // Logical values: [7, 7, -2, 9, 9, 9, 9, 9, -8, -8]
        let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
        let run_values = Int16Array::from(vec![7, -2, 9, -8]);
        let array = RunArray::try_new(&run_ends, &run_values).expect("Failed to create RunArray");

        // Keep only row 6: [9]
        let predicate = BooleanArray::from(vec![
            false, false, false, false, false, false, true, false, false, false,
        ]);
        let filtered = filter(&array, &predicate).unwrap();
        let actual: &RunArray<Int16Type> = as_run_array(&filtered);
        assert_eq!(actual.len(), 1);

        let expected_ends = Int16Array::from(vec![1]);
        let expected_values = Int16Array::from(vec![9]);
        let expected = RunArray::try_new(&expected_ends, &expected_values)
            .expect("Failed to make expected RunArray test is broken");

        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
        assert_eq!(actual.values(), expected.values());
    }
1360
1361
    #[test]
    fn test_filter_run_end_encoding_array_empty() {
        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
        let run_values = Int64Array::from(vec![7, -2, 9, -8]);
        let array = RunArray::try_new(&run_ends, &run_values).expect("Failed to create RunArray");

        // A predicate that selects nothing yields an empty run array.
        let predicate = BooleanArray::from(vec![
            false, false, false, false, false, false, false, false, false, false,
        ]);
        let filtered = filter(&array, &predicate).unwrap();
        let actual: &RunArray<Int64Type> = as_run_array(&filtered);
        assert_eq!(actual.len(), 0);
    }
1373
1374
    #[test]
    fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
        // Logical values: [7, 7, -2, 9, 9, 9, 9, 9, -8, -8]
        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
        let run_values = Int64Array::from(vec![7, -2, 9, -8]);
        let array = RunArray::try_new(&run_ends, &run_values).expect("Failed to create RunArray");

        // The predicate is shorter than the array; rows 1 and 2 are kept: [7, -2]
        let predicate = BooleanArray::from(vec![false, true, true]);
        let filtered = filter(&array, &predicate).unwrap();
        let actual: &RunArray<Int64Type> = as_run_array(&filtered);
        assert_eq!(actual.len(), 2);

        let expected_ends = Int64Array::from(vec![1, 2]);
        let expected_values = Int64Array::from(vec![7, -2]);
        let expected = RunArray::try_new(&expected_ends, &expected_values)
            .expect("Failed to make expected RunArray test is broken");

        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
        assert_eq!(actual.values(), expected.values());
    }
1393
1394
    #[test]
    fn test_filter_dictionary_array() {
        let items = [Some("hello"), None, Some("world"), Some("!")];
        let dict: Int8DictionaryArray = items.iter().copied().collect();
        let predicate = BooleanArray::from(vec![false, true, true, false]);

        let filtered = filter(&dict, &predicate).unwrap();
        let result = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<Int8DictionaryArray>()
            .unwrap();
        let dictionary_values = result
            .values()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        // values are cloned in the filtered dictionary array
        assert_eq!(dictionary_values.len(), 3);
        // but keys are filtered
        assert_eq!(result.len(), 2);
        assert!(result.is_null(0));
        let key = result.keys().value(1) as usize;
        assert_eq!(dictionary_values.value(key), "world");
    }
1414
1415
    #[test]
    fn test_filter_list_array() {
        /// Builds `LargeList<Int32>` array data from flat items, list offsets and a
        /// one-byte validity bitmap.
        fn large_list(items: &[i32], offsets: &[i64], validity: u8) -> ArrayData {
            let value_data = ArrayData::builder(DataType::Int32)
                .len(items.len())
                .add_buffer(Buffer::from_slice_ref(items))
                .build()
                .unwrap();
            let list_data_type =
                DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));
            ArrayData::builder(list_data_type)
                .len(offsets.len() - 1)
                .add_buffer(Buffer::from_slice_ref(offsets))
                .add_child_data(value_data)
                .null_bit_buffer(Some(Buffer::from([validity])))
                .build()
                .unwrap()
        }

        //  a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
        let list = LargeListArray::from(large_list(
            &[0, 1, 2, 3, 4, 5, 6, 7],
            &[0, 3, 6, 8, 8],
            0b00000111,
        ));
        let predicate = BooleanArray::from(vec![false, true, false, true]);
        let result = filter(&list, &predicate).unwrap();

        // expected: [[3, 4, 5], null]
        let expected = large_list(&[3, 4, 5], &[0, 3, 3], 0b00000001);
        assert_eq!(&make_array(expected), &result);
    }
1461
1462
2
    fn test_case_filter_list_view<T: OffsetSizeTrait>() {
1463
        // [[1, 2], null, [], [3,4]]
1464
2
        let mut list_array = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1465
2
        list_array.append_value([Some(1), Some(2)]);
1466
2
        list_array.append_null();
1467
2
        list_array.append_value([]);
1468
2
        list_array.append_value([Some(3), Some(4)]);
1469
1470
2
        let list_array = list_array.finish();
1471
2
        let predicate = BooleanArray::from_iter([true, false, true, false]);
1472
1473
        // Filter result: [[1, 2], []]
1474
2
        let filtered = filter(&list_array, &predicate)
1475
2
            .unwrap()
1476
2
            .as_list_view::<T>()
1477
2
            .clone();
1478
1479
2
        let mut expected =
1480
2
            GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(5), 3);
1481
2
        expected.append_value([Some(1), Some(2)]);
1482
2
        expected.append_value([]);
1483
2
        let expected = expected.finish();
1484
1485
2
        assert_eq!(&filtered, &expected);
1486
2
    }
1487
1488
2
    fn test_case_filter_sliced_list_view<T: OffsetSizeTrait>() {
1489
        // [[1, 2], null, [], [3,4]]
1490
2
        let mut list_array =
1491
2
            GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(6), 4);
1492
2
        list_array.append_value([Some(1), Some(2)]);
1493
2
        list_array.append_null();
1494
2
        list_array.append_value([]);
1495
2
        list_array.append_value([Some(3), Some(4)]);
1496
1497
2
        let list_array = list_array.finish();
1498
1499
        // Sliced: [null, [], [3, 4]]
1500
2
        let sliced = list_array.slice(1, 3);
1501
2
        let predicate = BooleanArray::from_iter([false, false, true]);
1502
1503
        // Filter result: [[1, 2], []]
1504
2
        let filtered = filter(&sliced, &predicate)
1505
2
            .unwrap()
1506
2
            .as_list_view::<T>()
1507
2
            .clone();
1508
1509
2
        let mut expected = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1510
2
        expected.append_value([Some(3), Some(4)]);
1511
2
        let expected = expected.finish();
1512
1513
2
        assert_eq!(&filtered, &expected);
1514
2
    }
1515
1516
    #[test]
    fn test_filter_list_view_array() {
        // Unsliced input, 32-bit then 64-bit offsets.
        test_case_filter_list_view::<i32>();
        test_case_filter_list_view::<i64>();

        // Sliced input, 32-bit then 64-bit offsets.
        test_case_filter_sliced_list_view::<i32>();
        test_case_filter_sliced_list_view::<i64>();
    }
1524
1525
    #[test]
    fn test_slice_iterator_bits() {
        // Only bit 1 of a single 64-bit word is set.
        let mask: Vec<bool> = (0..64).map(|idx| idx == 1).collect();
        let predicate = BooleanArray::from(mask);
        let selected = filter_count(&predicate);

        let ranges: Vec<_> = SlicesIterator::new(&predicate).collect();

        assert_eq!(ranges, vec![(1, 2)]);
        assert_eq!(selected, 1);
    }
1537
1538
    #[test]
    fn test_slice_iterator_bits1() {
        // Every bit of a 64-bit word is set except bit 1.
        let mask: Vec<bool> = (0..64).map(|idx| idx != 1).collect();
        let predicate = BooleanArray::from(mask);
        let selected = filter_count(&predicate);

        let ranges: Vec<_> = SlicesIterator::new(&predicate).collect();

        assert_eq!(ranges, vec![(0, 1), (2, 64)]);
        assert_eq!(selected, 64 - 1);
    }
1550
1551
    #[test]
    fn test_slice_iterator_chunk_and_bits() {
        // 130 bits with bits 0, 62 and 124 cleared; the middle run crosses
        // the first 64-bit word boundary.
        let mask: Vec<bool> = (0..130).map(|idx| idx % 62 != 0).collect();
        let predicate = BooleanArray::from(mask);
        let selected = filter_count(&predicate);

        let ranges: Vec<_> = SlicesIterator::new(&predicate).collect();

        assert_eq!(ranges, vec![(1, 62), (63, 124), (125, 130)]);
        assert_eq!(selected, 61 + 61 + 5);
    }
1563
1564
    #[test]
    fn test_null_mask() {
        // The null predicate slot does not select its row.
        let values = Int64Array::from(vec![Some(1), Some(2), None]);
        let mask = BooleanArray::from(vec![Some(true), Some(true), None]);

        let filtered = filter(&values, &mask).unwrap();
        let expected = values.slice(0, 2);
        assert_eq!(filtered.as_ref(), &expected);
    }
1572
1573
    #[test]
    fn test_filter_record_batch_no_columns() {
        // A column-less batch whose row count is carried only by the options.
        let options = RecordBatchOptions::default().with_row_count(Some(100));
        let schema = Arc::new(Schema::empty());
        let batch = RecordBatch::try_new_with_options(schema, vec![], &options).unwrap();

        // Two `true` slots and one null slot.
        let predicate = BooleanArray::from(vec![Some(true), Some(true), None]);
        let filtered = filter_record_batch(&batch, &predicate).unwrap();

        assert_eq!(filtered.num_rows(), 2);
    }
1583
1584
    #[test]
    fn test_fast_path() {
        let source: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);

        // Selecting every row yields an array equal to the input.
        let all_true = BooleanArray::from(vec![true, true, true]);
        let kept = filter(&source, &all_true).unwrap();
        let kept = kept
            .as_any()
            .downcast_ref::<PrimitiveArray<Int64Type>>()
            .unwrap();
        assert_eq!(&source, kept);

        // Selecting no rows yields an empty Int64 array.
        let all_false = BooleanArray::from(vec![false, false, false]);
        let dropped = filter(&source, &all_false).unwrap();
        assert_eq!(dropped.len(), 0);
        assert_eq!(dropped.data_type(), &DataType::Int64);
    }
1603
1604
    #[test]
    fn test_slices() {
        // 10 + 30 + 20 + 17 + 4 = 81 bits, i.e. two u64 words.
        let runs = [(true, 10), (false, 30), (true, 20), (false, 17), (true, 4)];
        let bool_array: BooleanArray = runs
            .iter()
            .flat_map(|&(value, len)| std::iter::repeat_n(value, len))
            .map(Some)
            .collect();

        let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
        assert_eq!(slices, vec![(0, 10), (40, 60), (77, 81)]);

        // Offset 7 with length 81 - 10 = 71: every range shifts left by 7
        // and the final one is clipped to the new length.
        let len = bool_array.len();
        let sliced = bool_array.slice(7, len - 10);
        let sliced = sliced.as_any().downcast_ref::<BooleanArray>().unwrap();
        let slices: Vec<_> = SlicesIterator::new(sliced).collect();
        assert_eq!(slices, vec![(0, 3), (33, 53), (70, 71)]);
    }
1630
1631
105
    fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1632
105
        let mut rng = rng();
1633
1634
54.6k
        let 
bools105
:
Vec<bool>105
=
std::iter::from_fn105
(|| Some(rng.random()))
1635
105
            .take(mask_len)
1636
105
            .collect();
1637
1638
105
        let buffer = Buffer::from_iter(bools.iter().cloned());
1639
1640
105
        let truncated_length = mask_len - offset - truncate;
1641
1642
105
        let data = ArrayDataBuilder::new(DataType::Boolean)
1643
105
            .len(truncated_length)
1644
105
            .offset(offset)
1645
105
            .add_buffer(buffer)
1646
105
            .build()
1647
105
            .unwrap();
1648
1649
105
        let filter = BooleanArray::from(data);
1650
1651
105
        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1652
11.3k
            .
flat_map105
(|(start, end)| start..end)
1653
105
            .collect();
1654
1655
105
        let count = filter_count(&filter);
1656
105
        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1657
1658
105
        let expected_bits: Vec<_> = bools
1659
105
            .iter()
1660
105
            .skip(offset)
1661
105
            .take(truncated_length)
1662
105
            .enumerate()
1663
45.4k
            .
flat_map105
(|(idx, v)| v.then(|| idx))
1664
105
            .collect();
1665
1666
105
        assert_eq!(slice_bits, expected_bits);
1667
105
        assert_eq!(index_bits, expected_bits);
1668
105
    }
1669
1670
    #[test]
    #[cfg_attr(miri, ignore)]
    fn fuzz_test_slices_iterator() {
        let mut rng = rng();
        let any_usize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();

        // Maps a raw sample into `0..bound`, or to 0 when `bound` is zero
        // (`checked_rem` returns `None` for a zero divisor).
        let reduce = |raw: usize, bound: usize| raw.checked_rem(bound).unwrap_or(0);

        for _ in 0..100 {
            let mask_len = rng.random_range(0..1024);
            let offset = reduce(any_usize.sample(&mut rng), 64.min(mask_len));
            let truncate = reduce(any_usize.sample(&mut rng), 128.min(mask_len - offset));
            test_slices_fuzz(mask_len, offset, truncate);
        }

        // Fixed cases as (mask_len, offset, truncate).
        let fixed = [(64, 0, 0), (64, 8, 0), (64, 8, 8), (32, 8, 8), (32, 5, 9)];
        for (mask_len, offset, truncate) in fixed {
            test_slices_fuzz(mask_len, offset, truncate);
        }
    }
1696
1697
    /// Reference filter: keeps each element of `values` whose paired entry in
    /// `predicate` is `true`.
    ///
    /// Pairing stops at the shorter of the two inputs.
    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
        let mut kept = Vec::new();
        for (value, &keep) in values.into_iter().zip(predicate.iter()) {
            if keep {
                kept.push(value);
            }
        }
        kept
    }
1706
1707
    /// Generates an array of length `len` with `valid_percent` non-null values
1708
100
    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1709
100
    where
1710
100
        StandardUniform: Distribution<T>,
1711
    {
1712
100
        let mut rng = rng();
1713
100
        (0..len)
1714
15.1k
            .
map100
(|_| rng.random_bool(valid_percent).then(||
rng8.23k
.
random8.23k
()))
1715
100
            .collect()
1716
100
    }
1717
1718
    /// Generates an array of length `len` with `valid_percent` non-null values
1719
100
    fn gen_strings(
1720
100
        len: usize,
1721
100
        valid_percent: f64,
1722
100
        str_len_range: std::ops::Range<usize>,
1723
100
    ) -> Vec<Option<String>> {
1724
100
        let mut rng = rng();
1725
100
        (0..len)
1726
15.1k
            .
map100
(|_| {
1727
15.1k
                rng.random_bool(valid_percent).then(|| 
{8.15k
1728
8.15k
                    let len = rng.random_range(str_len_range.clone());
1729
8.15k
                    (0..len)
1730
77.8k
                        .
map8.15k
(|_| char::from(rng.sample(Alphanumeric)))
1731
8.15k
                        .collect()
1732
8.15k
                })
1733
15.1k
            })
1734
100
            .collect()
1735
100
    }
1736
1737
    /// Lazily borrows each `Option<T>` in `src` as `Option<&T::Target>`, for
    /// example turning `&[Option<String>]` into an iterator of `Option<&str>`.
    fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
        src.iter().map(Option::as_deref)
    }
1741
1742
    #[test]
    #[cfg_attr(miri, ignore)]
    fn fuzz_filter() {
        let mut rng = rng();

        for round in 0..100 {
            // Rounds 0..=4 keep every row (`random_bool(1.0)`), rounds 5..=10
            // keep none (`random_bool(0.0)`), and later rounds pick at random.
            let filter_percent = if round <= 4 {
                1.
            } else if round <= 10 {
                0.
            } else {
                rng.random_range(0.0..1.0)
            };

            let valid_percent = rng.random_range(0.0..1.0);

            let array_len = rng.random_range(32..256);
            let array_offset = rng.random_range(0..10);

            // Build a wider mask, then slice it so the predicate carries its
            // own offset and may be shorter than the values.
            let filter_offset = rng.random_range(0..10);
            let filter_truncate = rng.random_range(0..10);
            let mask_len = array_len + filter_offset - filter_truncate;
            let mask: Vec<bool> = (0..mask_len)
                .map(|_| rng.random_bool(filter_percent))
                .collect();

            let predicate = BooleanArray::from_iter(mask.iter().cloned().map(Some));
            let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
            let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
            let keep = &mask[filter_offset..];

            // Int32 values
            let ints = gen_primitive(array_len + array_offset, valid_percent);
            let int_array = Int32Array::from_iter(ints.iter().cloned());
            let int_array = int_array.slice(array_offset, array_len);
            let int_src = int_array.as_any().downcast_ref::<Int32Array>().unwrap();

            let int_out = filter(int_src, predicate).unwrap();
            let int_out = int_out.as_any().downcast_ref::<Int32Array>().unwrap();
            let int_actual: Vec<_> = int_out.iter().collect();
            let int_expected = filter_rust(ints[array_offset..].iter().cloned(), keep);
            assert_eq!(int_actual, int_expected);

            // Utf8 values
            let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
            let string_array = StringArray::from_iter(as_deref(&strings));
            let string_array = string_array.slice(array_offset, array_len);
            let string_src = string_array
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap();

            let string_out = filter(string_src, predicate).unwrap();
            let string_out = string_out.as_any().downcast_ref::<StringArray>().unwrap();
            let string_actual: Vec<_> = string_out.iter().collect();
            let expected_strings = filter_rust(as_deref(&strings[array_offset..]), keep);
            assert_eq!(string_actual, expected_strings);

            // Utf8 dictionary values
            let dict_array = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
            let dict_array = dict_array.slice(array_offset, array_len);
            let dict_src = dict_array
                .as_any()
                .downcast_ref::<DictionaryArray<Int32Type>>()
                .unwrap();

            let dict_out = filter(dict_src, predicate).unwrap();
            let dict_out = dict_out
                .as_any()
                .downcast_ref::<DictionaryArray<Int32Type>>()
                .unwrap();
            let dictionary = dict_out
                .values()
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap();
            let dict_actual: Vec<_> = dict_out
                .keys()
                .iter()
                .map(|key| key.map(|k| dictionary.value(k as usize)))
                .collect();
            assert_eq!(dict_actual, expected_strings);
        }
    }
1832
1833
    #[test]
    fn test_filter_map() {
        let mut builder =
            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
        // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1}]
        builder.keys().append_value("key1");
        builder.values().append_value(1);
        builder.append(true).unwrap();
        builder.keys().append_value("key2");
        builder.keys().append_value("key3");
        builder.values().append_value(2);
        builder.values().append_value(3);
        builder.append(true).unwrap();
        builder.append(false).unwrap();
        builder.keys().append_value("key1");
        builder.values().append_value(1);
        builder.append(true).unwrap();
        let map_array = Arc::new(builder.finish()) as ArrayRef;

        // Boolean selection mask (not indices): keep the first and last maps.
        let predicate = vec![Some(true), Some(false), Some(false), Some(true)]
            .into_iter()
            .collect::<BooleanArray>();
        let got = filter(&map_array, &predicate).unwrap();

        // Expected: [{"key1": 1}, {"key1": 1}]
        let mut builder =
            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
        builder.keys().append_value("key1");
        builder.values().append_value(1);
        builder.append(true).unwrap();
        builder.keys().append_value("key1");
        builder.values().append_value(1);
        builder.append(true).unwrap();
        let expected = Arc::new(builder.finish()) as ArrayRef;

        assert_eq!(&expected, &got);
    }
1869
1870
    #[test]
    fn test_filter_fixed_size_list_arrays() {
        // Three fixed-size lists of three Int32 values:
        // [0, 1, 2], [3, 4, 5], [6, 7, 8]
        let child = ArrayData::builder(DataType::Int32)
            .len(9)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
            .build()
            .unwrap();
        let array = FixedSizeListArray::from(
            ArrayData::builder(DataType::new_fixed_size_list(DataType::Int32, 3, false))
                .len(3)
                .add_child_data(child)
                .build()
                .unwrap(),
        );

        // (mask, lists expected to survive, in order)
        let cases: [(Vec<bool>, Vec<[i32; 3]>); 2] = [
            (vec![true, false, false], vec![[0, 1, 2]]),
            (vec![true, false, true], vec![[0, 1, 2], [6, 7, 8]]),
        ];

        for (mask, expected_lists) in cases {
            let predicate = BooleanArray::from(mask);
            let result = filter(&array, &predicate).unwrap();
            let filtered = result
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap();

            assert_eq!(filtered.len(), expected_lists.len());
            for (row, expected) in expected_lists.iter().enumerate() {
                let list = filtered.value(row);
                let ints = list.as_any().downcast_ref::<Int32Array>().unwrap();
                assert_eq!(expected, ints.values());
            }
        }
    }
1916
1917
    #[test]
    fn test_filter_fixed_size_list_arrays_with_null() {
        let value_data = ArrayData::builder(DataType::Int32)
            .len(10)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
            .build()
            .unwrap();

        // Set null bits for the nested array:
        //  [[0, 1], null, null, [6, 7], [8, 9]]
        // Validity bitmap, least significant bit = row 0: 0b00011001
        let mut null_bits: [u8; 1] = [0; 1];
        bit_util::set_bit(&mut null_bits, 0);
        bit_util::set_bit(&mut null_bits, 3);
        bit_util::set_bit(&mut null_bits, 4);

        let list_data_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
        let list_data = ArrayData::builder(list_data_type)
            .len(5)
            .add_child_data(value_data)
            .null_bit_buffer(Some(Buffer::from(null_bits)))
            .build()
            .unwrap();
        let array = FixedSizeListArray::from(list_data);

        // Keep rows 0, 1 and 3: [[0, 1], null, [6, 7]]
        let filter_array = BooleanArray::from(vec![true, true, false, true, false]);

        let c = filter(&array, &filter_array).unwrap();
        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();

        assert_eq!(filtered.len(), 3);
        assert_eq!(filtered.null_count(), 1);

        assert!(filtered.is_valid(0));
        let list = filtered.value(0);
        assert_eq!(
            &[0, 1],
            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
        );
        assert!(filtered.is_null(1));
        assert!(filtered.is_valid(2));
        let list = filtered.value(2);
        assert_eq!(
            &[6, 7],
            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
        );
    }
1961
1962
2
    fn test_filter_union_array(array: UnionArray) {
1963
2
        let filter_array = BooleanArray::from(vec![true, false, false]);
1964
2
        let c = filter(&array, &filter_array).unwrap();
1965
2
        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1966
1967
2
        let mut builder = UnionBuilder::new_dense();
1968
2
        builder.append::<Int32Type>("A", 1).unwrap();
1969
2
        let expected_array = builder.build().unwrap();
1970
1971
2
        compare_union_arrays(filtered, &expected_array);
1972
1973
2
        let filter_array = BooleanArray::from(vec![true, false, true]);
1974
2
        let c = filter(&array, &filter_array).unwrap();
1975
2
        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1976
1977
2
        let mut builder = UnionBuilder::new_dense();
1978
2
        builder.append::<Int32Type>("A", 1).unwrap();
1979
2
        builder.append::<Int32Type>("A", 34).unwrap();
1980
2
        let expected_array = builder.build().unwrap();
1981
1982
2
        compare_union_arrays(filtered, &expected_array);
1983
1984
2
        let filter_array = BooleanArray::from(vec![true, true, false]);
1985
2
        let c = filter(&array, &filter_array).unwrap();
1986
2
        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1987
1988
2
        let mut builder = UnionBuilder::new_dense();
1989
2
        builder.append::<Int32Type>("A", 1).unwrap();
1990
2
        builder.append::<Float64Type>("B", 3.2).unwrap();
1991
2
        let expected_array = builder.build().unwrap();
1992
1993
2
        compare_union_arrays(filtered, &expected_array);
1994
2
    }
1995
1996
    #[test]
    fn test_filter_union_array_dense() {
        // Dense union rows: A=1, B=3.2, A=34
        let mut dense = UnionBuilder::new_dense();
        dense.append::<Int32Type>("A", 1).unwrap();
        dense.append::<Float64Type>("B", 3.2).unwrap();
        dense.append::<Int32Type>("A", 34).unwrap();

        test_filter_union_array(dense.build().unwrap());
    }
2006
2007
    #[test]
    fn test_filter_run_union_array_dense() {
        // A dense union whose rows all belong to the single child "A".
        let mut input = UnionBuilder::new_dense();
        for v in [1, 3, 34] {
            input.append::<Int32Type>("A", v).unwrap();
        }
        let input = input.build().unwrap();

        let keep_first_two = BooleanArray::from(vec![true, true, false]);
        let result = filter(&input, &keep_first_two).unwrap();
        let actual = result.as_any().downcast_ref::<UnionArray>().unwrap();

        let mut want = UnionBuilder::new_dense();
        for v in [1, 3] {
            want.append::<Int32Type>("A", v).unwrap();
        }
        let want = want.build().unwrap();

        assert_eq!(actual.to_data(), want.to_data());
    }
2026
2027
    #[test]
    fn test_filter_union_array_dense_with_nulls() {
        // Dense union rows: A=1, B=3.2, B=null, A=34
        let mut input = UnionBuilder::new_dense();
        input.append::<Int32Type>("A", 1).unwrap();
        input.append::<Float64Type>("B", 3.2).unwrap();
        input.append_null::<Float64Type>("B").unwrap();
        input.append::<Int32Type>("A", 34).unwrap();
        let input = input.build().unwrap();

        // Keep the first two (non-null) rows.
        let mask = BooleanArray::from(vec![true, true, false, false]);
        let out = filter(&input, &mask).unwrap();
        let mut want = UnionBuilder::new_dense();
        want.append::<Int32Type>("A", 1).unwrap();
        want.append::<Float64Type>("B", 3.2).unwrap();
        compare_union_arrays(
            out.as_any().downcast_ref::<UnionArray>().unwrap(),
            &want.build().unwrap(),
        );

        // Keep row 0 and the null `B` row.
        let mask = BooleanArray::from(vec![true, false, true, false]);
        let out = filter(&input, &mask).unwrap();
        let mut want = UnionBuilder::new_dense();
        want.append::<Int32Type>("A", 1).unwrap();
        want.append_null::<Float64Type>("B").unwrap();
        compare_union_arrays(
            out.as_any().downcast_ref::<UnionArray>().unwrap(),
            &want.build().unwrap(),
        );
    }
2058
2059
    #[test]
    fn test_filter_union_array_sparse() {
        // Sparse union rows: A=1, B=3.2, A=34
        let mut sparse = UnionBuilder::new_sparse();
        sparse.append::<Int32Type>("A", 1).unwrap();
        sparse.append::<Float64Type>("B", 3.2).unwrap();
        sparse.append::<Int32Type>("A", 34).unwrap();

        test_filter_union_array(sparse.build().unwrap());
    }
2069
2070
    #[test]
    fn test_filter_union_array_sparse_with_nulls() {
        // Sparse union rows: A=1, B=3.2, B=null, A=34
        let mut input = UnionBuilder::new_sparse();
        input.append::<Int32Type>("A", 1).unwrap();
        input.append::<Float64Type>("B", 3.2).unwrap();
        input.append_null::<Float64Type>("B").unwrap();
        input.append::<Int32Type>("A", 34).unwrap();
        let input = input.build().unwrap();

        // Keep row 0 and the null `B` row.
        let mask = BooleanArray::from(vec![true, false, true, false]);
        let out = filter(&input, &mask).unwrap();

        let mut want = UnionBuilder::new_sparse();
        want.append::<Int32Type>("A", 1).unwrap();
        want.append_null::<Float64Type>("B").unwrap();

        compare_union_arrays(
            out.as_any().downcast_ref::<UnionArray>().unwrap(),
            &want.build().unwrap(),
        );
    }
2090
2091
9
    fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
2092
9
        assert_eq!(union1.len(), union2.len());
2093
2094
16
        for i in 0..
union19
.
len9
() {
2095
16
            let type_id = union1.type_id(i);
2096
2097
16
            let slot1 = union1.value(i);
2098
16
            let slot2 = union2.value(i);
2099
2100
16
            assert_eq!(slot1.is_null(0), slot2.is_null(0));
2101
2102
16
            if !slot1.is_null(0) && 
!slot2.is_null(0)14
{
2103
14
                match type_id {
2104
                    0 => {
2105
11
                        let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
2106
11
                        assert_eq!(slot1.len(), 1);
2107
11
                        let value1 = slot1.value(0);
2108
2109
11
                        let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
2110
11
                        assert_eq!(slot2.len(), 1);
2111
11
                        let value2 = slot2.value(0);
2112
11
                        assert_eq!(value1, value2);
2113
                    }
2114
                    1 => {
2115
3
                        let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
2116
3
                        assert_eq!(slot1.len(), 1);
2117
3
                        let value1 = slot1.value(0);
2118
2119
3
                        let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2120
3
                        assert_eq!(slot2.len(), 1);
2121
3
                        let value2 = slot2.value(0);
2122
3
                        assert_eq!(value1, value2);
2123
                    }
2124
0
                    _ => unreachable!(),
2125
                }
2126
2
            }
2127
        }
2128
9
    }
2129
2130
    #[test]
    fn test_filter_struct() {
        let predicate = BooleanArray::from(vec![true, false, true, false]);

        let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
        let a_filtered: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"]));
        let b: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
        let b_filtered: ArrayRef = Arc::new(Int32Array::from(vec![5, 7]));

        let null_mask = NullBuffer::from(vec![true, false, false, true]);
        let null_mask_filtered = NullBuffer::from(vec![true, false]);

        let a_field = Field::new("a", DataType::Utf8, false);
        let b_field = Field::new("b", DataType::Int32, false);

        // (fields, input columns, expected filtered columns)
        let layouts = [
            (vec![a_field.clone()], vec![a.clone()], vec![a_filtered.clone()]),
            (
                vec![a_field, b_field],
                vec![a, b],
                vec![a_filtered, b_filtered],
            ),
        ];

        // Each layout is checked first without, then with, a struct-level
        // null mask.
        for (fields, columns, filtered_columns) in layouts {
            for nulls in [None, Some((null_mask.clone(), null_mask_filtered.clone()))] {
                let (input_nulls, expected_nulls) = nulls.unzip();
                let array = StructArray::new(fields.clone().into(), columns.clone(), input_nulls);
                let expected = StructArray::new(
                    fields.clone().into(),
                    filtered_columns.clone(),
                    expected_nulls,
                );

                let result = filter(&array, &predicate).unwrap();
                assert_eq!(result.to_data(), expected.to_data());
            }
        }
    }
2200
2201
    #[test]
    fn test_filter_empty_struct() {
        // Schema:
        //   "a": {
        //       "b": int64,
        //       "c": {}
        //   }
        let field = Field::new(
            "a",
            DataType::Struct(Fields::from(vec![
                Field::new("b", DataType::Int64, true),
                Field::new("c", DataType::Struct(Fields::empty()), true),
            ])),
            true,
        );

        // Three test records, each with a null "b" and an empty "c":
        //   {"a":{"c": {}}}
        //   {"a":{"c": {}}}
        //   {"a":{"c": {}}}
        let schema = Arc::new(Schema::new(vec![field]));

        let b = Arc::new(Int64Array::from(vec![None, None, None]));
        let c = Arc::new(StructArray::new_empty_fields(
            3,
            Some(NullBuffer::from(vec![true, true, true])),
        ));
        let a = StructArray::new(
            vec![
                Field::new("b", DataType::Int64, true),
                Field::new("c", DataType::Struct(Fields::empty()), true),
            ]
            .into(),
            vec![b, c],
            Some(NullBuffer::from(vec![true, true, true])),
        );
        let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();

        // Apply the filter, keeping the 1st and 3rd rows.
        let predicate = BooleanArray::from(vec![true, false, true]);
        let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();

        assert_eq!(filtered_batch.num_rows(), 2);

        // The nested struct, including its field-less child "c", must be
        // filtered down to the same length.
        let filtered_a = filtered_batch.column(0).as_struct();
        assert_eq!(filtered_a.len(), 2);
        assert_eq!(filtered_a.column(1).len(), 2);
    }
2255
}