Coverage Report

Created: 2025-08-26 07:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-select/src/filter.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines filter kernels
19
20
use std::ops::AddAssign;
21
use std::sync::Arc;
22
23
use arrow_array::builder::BooleanBufferBuilder;
24
use arrow_array::cast::AsArray;
25
use arrow_array::types::{
26
    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27
};
28
use arrow_array::*;
29
use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, NullBuffer, RunEndBuffer};
30
use arrow_buffer::{Buffer, MutableBuffer};
31
use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
32
use arrow_data::transform::MutableArrayData;
33
use arrow_data::ArrayDataBuilder;
34
use arrow_schema::*;
35
36
/// If the filter selects more than this fraction of rows, use
37
/// [`SlicesIterator`] to copy ranges of values. Otherwise iterate
38
/// over individual rows using [`IndexIterator`]
39
///
40
/// Threshold of 0.8 chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009>
41
///
42
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
43
44
/// An iterator of `(usize, usize)` each representing an interval
45
/// `[start, end)` whose slots of a bitmap [Buffer] are true.
46
///
47
/// Each interval corresponds to a contiguous region of memory to be
48
/// "taken" from an array to be filtered.
49
///
50
/// ## Notes:
51
///
52
/// 1. Ignores the validity bitmap (ignores nulls)
53
///
54
/// 2. Only performant for filters that copy across long contiguous runs
55
#[derive(Debug)]
56
pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
57
58
impl<'a> SlicesIterator<'a> {
59
    /// Creates a new iterator from a [BooleanArray]
60
0
    pub fn new(filter: &'a BooleanArray) -> Self {
61
0
        Self(filter.values().set_slices())
62
0
    }
63
}
64
65
impl Iterator for SlicesIterator<'_> {
66
    type Item = (usize, usize);
67
68
0
    fn next(&mut self) -> Option<Self::Item> {
69
0
        self.0.next()
70
0
    }
71
}
72
73
/// An iterator of `usize` whose index in [`BooleanArray`] is true
74
///
75
/// This provides the best performance on most predicates, apart from those which keep
76
/// large runs and therefore favour [`SlicesIterator`]
77
struct IndexIterator<'a> {
78
    remaining: usize,
79
    iter: BitIndexIterator<'a>,
80
}
81
82
impl<'a> IndexIterator<'a> {
83
0
    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
84
0
        assert_eq!(filter.null_count(), 0);
85
0
        let iter = filter.values().set_indices();
86
0
        Self { remaining, iter }
87
0
    }
88
}
89
90
impl Iterator for IndexIterator<'_> {
91
    type Item = usize;
92
93
0
    fn next(&mut self) -> Option<Self::Item> {
94
0
        if self.remaining != 0 {
95
            // Fascinatingly swapping these two lines around results in a 50%
96
            // performance regression for some benchmarks
97
0
            let next = self.iter.next().expect("IndexIterator exhausted early");
98
0
            self.remaining -= 1;
99
            // Must panic if exhausted early as trusted length iterator
100
0
            return Some(next);
101
0
        }
102
0
        None
103
0
    }
104
105
0
    fn size_hint(&self) -> (usize, Option<usize>) {
106
0
        (self.remaining, Some(self.remaining))
107
0
    }
108
}
109
110
/// Counts the number of set bits in `filter`
111
0
fn filter_count(filter: &BooleanArray) -> usize {
112
0
    filter.values().count_set_bits()
113
0
}
114
115
/// Remove null values by do a bitmask AND operation with null bits and the boolean bits.
116
0
pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
117
0
    let nulls = filter.nulls().unwrap();
118
0
    let mask = filter.values() & nulls.inner();
119
0
    BooleanArray::new(mask, None)
120
0
}
121
122
/// Returns a filtered `values` [`Array`] where the corresponding elements of
123
/// `predicate` are `true`.
124
///
125
/// # See also
126
/// * [`FilterBuilder`] for more control over the filtering process.
127
/// * [`filter_record_batch`] to filter a [`RecordBatch`]
128
/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
129
///   the results into a single array.
130
///
131
/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
132
///
133
/// # Example
134
/// ```rust
135
/// # use arrow_array::{Int32Array, BooleanArray};
136
/// # use arrow_select::filter::filter;
137
/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
138
/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
139
/// let c = filter(&array, &filter_array).unwrap();
140
/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
141
/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
142
/// ```
143
0
pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
144
0
    let mut filter_builder = FilterBuilder::new(predicate);
145
146
0
    if multiple_arrays(values.data_type()) {
147
0
        // Only optimize if filtering more than one array
148
0
        // Otherwise, the overhead of optimization can be more than the benefit
149
0
        filter_builder = filter_builder.optimize();
150
0
    }
151
152
0
    let predicate = filter_builder.build();
153
154
0
    filter_array(values, &predicate)
155
0
}
156
157
0
fn multiple_arrays(data_type: &DataType) -> bool {
158
0
    match data_type {
159
0
        DataType::Struct(fields) => {
160
0
            fields.len() > 1 || fields.len() == 1 && multiple_arrays(fields[0].data_type())
161
        }
162
0
        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
163
0
        _ => false,
164
    }
165
0
}
166
167
/// Returns a filtered [RecordBatch] where the corresponding elements of
168
/// `predicate` are true.
169
///
170
/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
171
0
pub fn filter_record_batch(
172
0
    record_batch: &RecordBatch,
173
0
    predicate: &BooleanArray,
174
0
) -> Result<RecordBatch, ArrowError> {
175
0
    let mut filter_builder = FilterBuilder::new(predicate);
176
0
    if record_batch.num_columns() > 1 {
177
0
        // Only optimize if filtering more than one column
178
0
        // Otherwise, the overhead of optimization can be more than the benefit
179
0
        filter_builder = filter_builder.optimize();
180
0
    }
181
0
    let filter = filter_builder.build();
182
183
0
    let filtered_arrays = record_batch
184
0
        .columns()
185
0
        .iter()
186
0
        .map(|a| filter_array(a, &filter))
187
0
        .collect::<Result<Vec<_>, _>>()?;
188
0
    let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
189
0
    RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
190
0
}
191
192
/// A builder to construct [`FilterPredicate`]
193
#[derive(Debug)]
194
pub struct FilterBuilder {
195
    filter: BooleanArray,
196
    count: usize,
197
    strategy: IterationStrategy,
198
}
199
200
impl FilterBuilder {
201
    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
202
0
    pub fn new(filter: &BooleanArray) -> Self {
203
0
        let filter = match filter.null_count() {
204
0
            0 => filter.clone(),
205
0
            _ => prep_null_mask_filter(filter),
206
        };
207
208
0
        let count = filter_count(&filter);
209
0
        let strategy = IterationStrategy::default_strategy(filter.len(), count);
210
211
0
        Self {
212
0
            filter,
213
0
            count,
214
0
            strategy,
215
0
        }
216
0
    }
217
218
    /// Compute an optimised representation of the provided `filter` mask that can be
219
    /// applied to an array more quickly.
220
    ///
221
    /// Note: There is limited benefit to calling this to then filter a single array
222
    /// Note: This will likely have a larger memory footprint than the original mask
223
0
    pub fn optimize(mut self) -> Self {
224
0
        match self.strategy {
225
            IterationStrategy::SlicesIterator => {
226
0
                let slices = SlicesIterator::new(&self.filter).collect();
227
0
                self.strategy = IterationStrategy::Slices(slices)
228
            }
229
            IterationStrategy::IndexIterator => {
230
0
                let indices = IndexIterator::new(&self.filter, self.count).collect();
231
0
                self.strategy = IterationStrategy::Indices(indices)
232
            }
233
0
            _ => {}
234
        }
235
0
        self
236
0
    }
237
238
    /// Construct the final `FilterPredicate`
239
0
    pub fn build(self) -> FilterPredicate {
240
0
        FilterPredicate {
241
0
            filter: self.filter,
242
0
            count: self.count,
243
0
            strategy: self.strategy,
244
0
        }
245
0
    }
246
}
247
248
/// The iteration strategy used to evaluate [`FilterPredicate`]
249
#[derive(Debug)]
250
enum IterationStrategy {
251
    /// A lazily evaluated iterator of ranges
252
    SlicesIterator,
253
    /// A lazily evaluated iterator of indices
254
    IndexIterator,
255
    /// A precomputed list of indices
256
    Indices(Vec<usize>),
257
    /// A precomputed array of ranges
258
    Slices(Vec<(usize, usize)>),
259
    /// Select all rows
260
    All,
261
    /// Select no rows
262
    None,
263
}
264
265
impl IterationStrategy {
266
    /// The default [`IterationStrategy`] for a filter of length `filter_length`
267
    /// and selecting `filter_count` rows
268
0
    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
269
0
        if filter_length == 0 || filter_count == 0 {
270
0
            return IterationStrategy::None;
271
0
        }
272
273
0
        if filter_count == filter_length {
274
0
            return IterationStrategy::All;
275
0
        }
276
277
        // Compute the selectivity of the predicate by dividing the number of true
278
        // bits in the predicate by the predicate's total length
279
        //
280
        // This can then be used as a heuristic for the optimal iteration strategy
281
0
        let selectivity_frac = filter_count as f64 / filter_length as f64;
282
0
        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
283
0
            return IterationStrategy::SlicesIterator;
284
0
        }
285
0
        IterationStrategy::IndexIterator
286
0
    }
287
}
288
289
/// A filtering predicate that can be applied to an [`Array`]
290
#[derive(Debug)]
291
pub struct FilterPredicate {
292
    filter: BooleanArray,
293
    count: usize,
294
    strategy: IterationStrategy,
295
}
296
297
impl FilterPredicate {
298
    /// Selects rows from `values` based on this [`FilterPredicate`]
299
0
    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
300
0
        filter_array(values, self)
301
0
    }
302
303
    /// Number of rows being selected based on this [`FilterPredicate`]
304
0
    pub fn count(&self) -> usize {
305
0
        self.count
306
0
    }
307
}
308
309
0
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
310
0
    if predicate.filter.len() > values.len() {
311
0
        return Err(ArrowError::InvalidArgumentError(format!(
312
0
            "Filter predicate of length {} is larger than target array of length {}",
313
0
            predicate.filter.len(),
314
0
            values.len()
315
0
        )));
316
0
    }
317
318
0
    match predicate.strategy {
319
0
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
320
0
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
321
        // actually filter
322
0
        _ => downcast_primitive_array! {
323
0
            values => Ok(Arc::new(filter_primitive(values, predicate))),
324
            DataType::Boolean => {
325
0
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
326
0
                Ok(Arc::new(filter_boolean(values, predicate)))
327
            }
328
            DataType::Utf8 => {
329
0
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
330
            }
331
            DataType::LargeUtf8 => {
332
0
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
333
            }
334
            DataType::Utf8View => {
335
0
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
336
            }
337
            DataType::Binary => {
338
0
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
339
            }
340
            DataType::LargeBinary => {
341
0
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
342
            }
343
            DataType::BinaryView => {
344
0
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
345
            }
346
            DataType::FixedSizeBinary(_) => {
347
0
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
348
            }
349
            DataType::RunEndEncoded(_, _) => {
350
0
                downcast_run_array!{
351
0
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
352
0
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
353
                }
354
            }
355
0
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
356
0
                values => Ok(Arc::new(filter_dict(values, predicate))),
357
0
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
358
            }
359
            DataType::Struct(_) => {
360
0
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
361
            }
362
            DataType::Union(_, UnionMode::Sparse) => {
363
0
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
364
            }
365
            _ => {
366
0
                let data = values.to_data();
367
                // fallback to using MutableArrayData
368
0
                let mut mutable = MutableArrayData::new(
369
0
                    vec![&data],
370
                    false,
371
0
                    predicate.count,
372
                );
373
374
0
                match &predicate.strategy {
375
0
                    IterationStrategy::Slices(slices) => {
376
0
                        slices
377
0
                            .iter()
378
0
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
379
                    }
380
                    _ => {
381
0
                        let iter = SlicesIterator::new(&predicate.filter);
382
0
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
383
                    }
384
                }
385
386
0
                let data = mutable.freeze();
387
0
                Ok(make_array(data))
388
            }
389
        },
390
    }
391
0
}
392
393
/// Filter any supported [`RunArray`] based on a [`FilterPredicate`]
394
0
fn filter_run_end_array<R: RunEndIndexType>(
395
0
    array: &RunArray<R>,
396
0
    predicate: &FilterPredicate,
397
0
) -> Result<RunArray<R>, ArrowError>
398
0
where
399
0
    R::Native: Into<i64> + From<bool>,
400
0
    R::Native: AddAssign,
401
{
402
0
    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
403
0
    let mut new_run_ends = vec![R::default_value(); run_ends.len()];
404
405
0
    let mut start = 0u64;
406
0
    let mut j = 0;
407
0
    let mut count = R::default_value();
408
0
    let filter_values = predicate.filter.values();
409
0
    let run_ends = run_ends.inner();
410
411
0
    let pred: BooleanArray = BooleanBuffer::collect_bool(run_ends.len(), |i| {
412
0
        let mut keep = false;
413
0
        let mut end = run_ends[i].into() as u64;
414
0
        let difference = end.saturating_sub(filter_values.len() as u64);
415
0
        end -= difference;
416
417
        // Safety: we subtract the difference off `end` so we are always within bounds
418
0
        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
419
0
            count += R::Native::from(pred);
420
0
            keep |= pred
421
        }
422
        // this is to avoid branching
423
0
        new_run_ends[j] = count;
424
0
        j += keep as usize;
425
426
0
        start = end;
427
0
        keep
428
0
    })
429
0
    .into();
430
431
0
    new_run_ends.truncate(j);
432
433
0
    let values = array.values();
434
0
    let values = filter(&values, &pred)?;
435
436
0
    let run_ends = PrimitiveArray::<R>::new(new_run_ends.into(), None);
437
0
    RunArray::try_new(&run_ends, &values)
438
0
}
439
440
/// Computes a new null mask for `data` based on `predicate`
441
///
442
/// If the predicate selected no null-rows, returns `None`, otherwise returns
443
/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls
444
/// in the filtered output, and `null_buffer` is the filtered null buffer
445
///
446
0
fn filter_null_mask(
447
0
    nulls: Option<&NullBuffer>,
448
0
    predicate: &FilterPredicate,
449
0
) -> Option<(usize, Buffer)> {
450
0
    let nulls = nulls?;
451
0
    if nulls.null_count() == 0 {
452
0
        return None;
453
0
    }
454
455
0
    let nulls = filter_bits(nulls.inner(), predicate);
456
    // The filtered `nulls` has a length of `predicate.count` bits and
457
    // therefore the null count is this minus the number of valid bits
458
0
    let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
459
460
0
    if null_count == 0 {
461
0
        return None;
462
0
    }
463
464
0
    Some((null_count, nulls))
465
0
}
466
467
/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset`
468
0
fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
469
0
    let src = buffer.values();
470
0
    let offset = buffer.offset();
471
472
0
    match &predicate.strategy {
473
        IterationStrategy::IndexIterator => {
474
0
            let bits = IndexIterator::new(&predicate.filter, predicate.count)
475
0
                .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
476
477
            // SAFETY: `IndexIterator` reports its size correctly
478
0
            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
479
        }
480
0
        IterationStrategy::Indices(indices) => {
481
0
            let bits = indices
482
0
                .iter()
483
0
                .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
484
485
            // SAFETY: `Vec::iter()` reports its size correctly
486
0
            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
487
        }
488
        IterationStrategy::SlicesIterator => {
489
0
            let mut builder = BooleanBufferBuilder::new(predicate.count);
490
0
            for (start, end) in SlicesIterator::new(&predicate.filter) {
491
0
                builder.append_packed_range(start + offset..end + offset, src)
492
            }
493
0
            builder.into()
494
        }
495
0
        IterationStrategy::Slices(slices) => {
496
0
            let mut builder = BooleanBufferBuilder::new(predicate.count);
497
0
            for (start, end) in slices {
498
0
                builder.append_packed_range(*start + offset..*end + offset, src)
499
            }
500
0
            builder.into()
501
        }
502
0
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
503
    }
504
0
}
505
506
/// `filter` implementation for boolean buffers
507
0
fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
508
0
    let values = filter_bits(array.values(), predicate);
509
510
0
    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
511
0
        .len(predicate.count)
512
0
        .add_buffer(values);
513
514
0
    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
515
0
        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
516
0
    }
517
518
0
    let data = unsafe { builder.build_unchecked() };
519
0
    BooleanArray::from(data)
520
0
}
521
522
#[inline(never)]
523
0
fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
524
0
    assert!(values.len() >= predicate.filter.len());
525
526
0
    match &predicate.strategy {
527
        IterationStrategy::SlicesIterator => {
528
0
            let mut buffer = Vec::with_capacity(predicate.count);
529
0
            for (start, end) in SlicesIterator::new(&predicate.filter) {
530
0
                buffer.extend_from_slice(&values[start..end]);
531
0
            }
532
0
            buffer.into()
533
        }
534
0
        IterationStrategy::Slices(slices) => {
535
0
            let mut buffer = Vec::with_capacity(predicate.count);
536
0
            for (start, end) in slices {
537
0
                buffer.extend_from_slice(&values[*start..*end]);
538
0
            }
539
0
            buffer.into()
540
        }
541
        IterationStrategy::IndexIterator => {
542
0
            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
543
544
            // SAFETY: IndexIterator is trusted length
545
0
            unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
546
        }
547
0
        IterationStrategy::Indices(indices) => {
548
0
            let iter = indices.iter().map(|x| values[*x]);
549
0
            iter.collect::<Vec<_>>().into()
550
        }
551
0
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
552
    }
553
0
}
554
555
/// `filter` implementation for primitive arrays
556
0
fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
557
0
where
558
0
    T: ArrowPrimitiveType,
559
{
560
0
    let values = array.values();
561
0
    let buffer = filter_native(values, predicate);
562
0
    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
563
0
        .len(predicate.count)
564
0
        .add_buffer(buffer);
565
566
0
    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
567
0
        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
568
0
    }
569
570
0
    let data = unsafe { builder.build_unchecked() };
571
0
    PrimitiveArray::from(data)
572
0
}
573
574
/// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be
575
/// used to build a new [`GenericByteArray`] by copying values from the source
576
///
577
/// TODO(raphael): Could this be used for the take kernel as well?
578
struct FilterBytes<'a, OffsetSize> {
579
    src_offsets: &'a [OffsetSize],
580
    src_values: &'a [u8],
581
    dst_offsets: Vec<OffsetSize>,
582
    dst_values: Vec<u8>,
583
    cur_offset: OffsetSize,
584
}
585
586
impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
587
where
588
    OffsetSize: OffsetSizeTrait,
589
{
590
0
    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
591
0
    where
592
0
        T: ByteArrayType<Offset = OffsetSize>,
593
    {
594
0
        let dst_values = Vec::new();
595
0
        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
596
0
        let cur_offset = OffsetSize::from_usize(0).unwrap();
597
598
0
        dst_offsets.push(cur_offset);
599
600
0
        Self {
601
0
            src_offsets: array.value_offsets(),
602
0
            src_values: array.value_data(),
603
0
            dst_offsets,
604
0
            dst_values,
605
0
            cur_offset,
606
0
        }
607
0
    }
608
609
    /// Returns the byte offset at `idx`
610
    #[inline]
611
0
    fn get_value_offset(&self, idx: usize) -> usize {
612
0
        self.src_offsets[idx].as_usize()
613
0
    }
614
615
    /// Returns the start and end of the value at index `idx` along with its length
616
    #[inline]
617
0
    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
618
        // These can only fail if `array` contains invalid data
619
0
        let start = self.get_value_offset(idx);
620
0
        let end = self.get_value_offset(idx + 1);
621
0
        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
622
0
        (start, end, len)
623
0
    }
624
625
0
    fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
626
0
        self.dst_offsets.extend(iter.map(|idx| {
627
0
            let start = self.src_offsets[idx].as_usize();
628
0
            let end = self.src_offsets[idx + 1].as_usize();
629
0
            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
630
0
            self.cur_offset += len;
631
632
0
            self.cur_offset
633
0
        }));
634
0
    }
635
636
    /// Extends the in-progress array by the indexes in the provided iterator
637
0
    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
638
0
        self.dst_values.reserve_exact(self.cur_offset.as_usize());
639
640
0
        for idx in iter {
641
0
            let start = self.src_offsets[idx].as_usize();
642
0
            let end = self.src_offsets[idx + 1].as_usize();
643
0
            self.dst_values
644
0
                .extend_from_slice(&self.src_values[start..end]);
645
0
        }
646
0
    }
647
648
0
    fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
649
0
        self.dst_offsets.reserve_exact(count);
650
0
        for (start, end) in iter {
651
            // These can only fail if `array` contains invalid data
652
0
            for idx in start..end {
653
0
                let (_, _, len) = self.get_value_range(idx);
654
0
                self.cur_offset += len;
655
0
                self.dst_offsets.push(self.cur_offset);
656
0
            }
657
        }
658
0
    }
659
660
    /// Extends the in-progress array by the ranges in the provided iterator
661
0
    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
662
0
        self.dst_values.reserve_exact(self.cur_offset.as_usize());
663
664
0
        for (start, end) in iter {
665
0
            let value_start = self.get_value_offset(start);
666
0
            let value_end = self.get_value_offset(end);
667
0
            self.dst_values
668
0
                .extend_from_slice(&self.src_values[value_start..value_end]);
669
0
        }
670
0
    }
671
}
672
673
/// `filter` implementation for byte arrays
674
///
675
/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
676
/// data copied across. This allows handling the null mask separately from the data
677
0
fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
678
0
where
679
0
    T: ByteArrayType,
680
{
681
0
    let mut filter = FilterBytes::new(predicate.count, array);
682
683
0
    match &predicate.strategy {
684
        IterationStrategy::SlicesIterator => {
685
0
            filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
686
0
            filter.extend_slices(SlicesIterator::new(&predicate.filter))
687
        }
688
0
        IterationStrategy::Slices(slices) => {
689
0
            filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
690
0
            filter.extend_slices(slices.iter().cloned())
691
        }
692
        IterationStrategy::IndexIterator => {
693
0
            filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
694
0
            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
695
        }
696
0
        IterationStrategy::Indices(indices) => {
697
0
            filter.extend_offsets_idx(indices.iter().cloned());
698
0
            filter.extend_idx(indices.iter().cloned())
699
        }
700
0
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
701
    }
702
703
0
    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
704
0
        .len(predicate.count)
705
0
        .add_buffer(filter.dst_offsets.into())
706
0
        .add_buffer(filter.dst_values.into());
707
708
0
    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
709
0
        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
710
0
    }
711
712
0
    let data = unsafe { builder.build_unchecked() };
713
0
    GenericByteArray::from(data)
714
0
}
715
716
/// `filter` implementation for byte view arrays.
717
0
fn filter_byte_view<T: ByteViewType>(
718
0
    array: &GenericByteViewArray<T>,
719
0
    predicate: &FilterPredicate,
720
0
) -> GenericByteViewArray<T> {
721
0
    let new_view_buffer = filter_native(array.views(), predicate);
722
723
0
    let mut builder = ArrayDataBuilder::new(T::DATA_TYPE)
724
0
        .len(predicate.count)
725
0
        .add_buffer(new_view_buffer)
726
0
        .add_buffers(array.data_buffers().to_vec());
727
728
0
    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
729
0
        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
730
0
    }
731
732
0
    GenericByteViewArray::from(unsafe { builder.build_unchecked() })
733
0
}
734
735
0
fn filter_fixed_size_binary(
736
0
    array: &FixedSizeBinaryArray,
737
0
    predicate: &FilterPredicate,
738
0
) -> FixedSizeBinaryArray {
739
0
    let values: &[u8] = array.values();
740
0
    let value_length = array.value_length() as usize;
741
0
    let calculate_offset_from_index = |index: usize| index * value_length;
742
0
    let buffer = match &predicate.strategy {
743
        IterationStrategy::SlicesIterator => {
744
0
            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
745
0
            for (start, end) in SlicesIterator::new(&predicate.filter) {
746
0
                buffer.extend_from_slice(
747
0
                    &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
748
0
                );
749
0
            }
750
0
            buffer
751
        }
752
0
        IterationStrategy::Slices(slices) => {
753
0
            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
754
0
            for (start, end) in slices {
755
0
                buffer.extend_from_slice(
756
0
                    &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
757
0
                );
758
0
            }
759
0
            buffer
760
        }
761
        IterationStrategy::IndexIterator => {
762
0
            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
763
0
                &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
764
0
            });
765
766
0
            let mut buffer = MutableBuffer::new(predicate.count * value_length);
767
0
            iter.for_each(|item| buffer.extend_from_slice(item));
768
0
            buffer
769
        }
770
0
        IterationStrategy::Indices(indices) => {
771
0
            let iter = indices.iter().map(|x| {
772
0
                &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
773
0
            });
774
775
0
            let mut buffer = MutableBuffer::new(predicate.count * value_length);
776
0
            iter.for_each(|item| buffer.extend_from_slice(item));
777
0
            buffer
778
        }
779
0
        IterationStrategy::All | IterationStrategy::None => unreachable!(),
780
    };
781
0
    let mut builder = ArrayDataBuilder::new(array.data_type().clone())
782
0
        .len(predicate.count)
783
0
        .add_buffer(buffer.into());
784
785
0
    if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
786
0
        builder = builder.null_count(null_count).null_bit_buffer(Some(nulls));
787
0
    }
788
789
0
    let data = unsafe { builder.build_unchecked() };
790
0
    FixedSizeBinaryArray::from(data)
791
0
}
792
793
/// `filter` implementation for dictionaries
794
0
fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
795
0
where
796
0
    T: ArrowDictionaryKeyType,
797
0
    T::Native: num::Num,
798
{
799
0
    let builder = filter_primitive::<T>(array.keys(), predicate)
800
0
        .into_data()
801
0
        .into_builder()
802
0
        .data_type(array.data_type().clone())
803
0
        .child_data(vec![array.values().to_data()]);
804
805
    // SAFETY:
806
    // Keys were valid before, filtered subset is therefore still valid
807
0
    DictionaryArray::from(unsafe { builder.build_unchecked() })
808
0
}
809
810
/// `filter` implementation for structs
811
0
fn filter_struct(
812
0
    array: &StructArray,
813
0
    predicate: &FilterPredicate,
814
0
) -> Result<StructArray, ArrowError> {
815
0
    let columns = array
816
0
        .columns()
817
0
        .iter()
818
0
        .map(|column| filter_array(column, predicate))
819
0
        .collect::<Result<_, _>>()?;
820
821
0
    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
822
0
        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
823
824
0
        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
825
    } else {
826
0
        None
827
    };
828
829
0
    Ok(unsafe {
830
0
        StructArray::new_unchecked_with_length(
831
0
            array.fields().clone(),
832
0
            columns,
833
0
            nulls,
834
0
            predicate.count(),
835
0
        )
836
0
    })
837
0
}
838
839
/// `filter` implementation for sparse unions
840
0
fn filter_sparse_union(
841
0
    array: &UnionArray,
842
0
    predicate: &FilterPredicate,
843
0
) -> Result<UnionArray, ArrowError> {
844
0
    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
845
0
        unreachable!()
846
    };
847
848
0
    let type_ids = filter_primitive(&Int8Array::new(array.type_ids().clone(), None), predicate);
849
850
0
    let children = fields
851
0
        .iter()
852
0
        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
853
0
        .collect::<Result<_, _>>()?;
854
855
0
    Ok(unsafe {
856
0
        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
857
0
    })
858
0
}
859
860
#[cfg(test)]
861
mod tests {
862
    use super::*;
863
    use arrow_array::builder::*;
864
    use arrow_array::cast::as_run_array;
865
    use arrow_array::types::*;
866
    use arrow_data::ArrayData;
867
    use rand::distr::uniform::{UniformSampler, UniformUsize};
868
    use rand::distr::{Alphanumeric, StandardUniform};
869
    use rand::prelude::*;
870
    use rand::rng;
871
872
    // Defines a `#[test]` named `$test` that filters `$data` (expected to hold the values
    // 1, 2, 3, 4) with the mask [true, false, true, false] and checks that the result
    // downcasts to `$array_type` and contains exactly 1 and 3.
    macro_rules! def_temporal_test {
        ($test:ident, $array_type: ident, $data: expr) => {
            #[test]
            fn $test() {
                let input = $data;
                let mask = BooleanArray::from(vec![true, false, true, false]);
                let filtered = filter(&input, &mask).unwrap();
                let typed = filtered
                    .as_ref()
                    .as_any()
                    .downcast_ref::<$array_type>()
                    .unwrap();
                assert_eq!(2, typed.len());
                assert_eq!(1, typed.value(0));
                assert_eq!(3, typed.value(1));
            }
        };
    }
886
887
    def_temporal_test!(
888
        test_filter_date32,
889
        Date32Array,
890
        Date32Array::from(vec![1, 2, 3, 4])
891
    );
892
    def_temporal_test!(
893
        test_filter_date64,
894
        Date64Array,
895
        Date64Array::from(vec![1, 2, 3, 4])
896
    );
897
    def_temporal_test!(
898
        test_filter_time32_second,
899
        Time32SecondArray,
900
        Time32SecondArray::from(vec![1, 2, 3, 4])
901
    );
902
    def_temporal_test!(
903
        test_filter_time32_millisecond,
904
        Time32MillisecondArray,
905
        Time32MillisecondArray::from(vec![1, 2, 3, 4])
906
    );
907
    def_temporal_test!(
908
        test_filter_time64_microsecond,
909
        Time64MicrosecondArray,
910
        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
911
    );
912
    def_temporal_test!(
913
        test_filter_time64_nanosecond,
914
        Time64NanosecondArray,
915
        Time64NanosecondArray::from(vec![1, 2, 3, 4])
916
    );
917
    def_temporal_test!(
918
        test_filter_duration_second,
919
        DurationSecondArray,
920
        DurationSecondArray::from(vec![1, 2, 3, 4])
921
    );
922
    def_temporal_test!(
923
        test_filter_duration_millisecond,
924
        DurationMillisecondArray,
925
        DurationMillisecondArray::from(vec![1, 2, 3, 4])
926
    );
927
    def_temporal_test!(
928
        test_filter_duration_microsecond,
929
        DurationMicrosecondArray,
930
        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
931
    );
932
    def_temporal_test!(
933
        test_filter_duration_nanosecond,
934
        DurationNanosecondArray,
935
        DurationNanosecondArray::from(vec![1, 2, 3, 4])
936
    );
937
    def_temporal_test!(
938
        test_filter_timestamp_second,
939
        TimestampSecondArray,
940
        TimestampSecondArray::from(vec![1, 2, 3, 4])
941
    );
942
    def_temporal_test!(
943
        test_filter_timestamp_millisecond,
944
        TimestampMillisecondArray,
945
        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
946
    );
947
    def_temporal_test!(
948
        test_filter_timestamp_microsecond,
949
        TimestampMicrosecondArray,
950
        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
951
    );
952
    def_temporal_test!(
953
        test_filter_timestamp_nanosecond,
954
        TimestampNanosecondArray,
955
        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
956
    );
957
958
    /// Filtering must honour the offset of a sliced values array and of a sliced predicate.
    #[test]
    fn test_filter_array_slice() {
        // a = [6, 7, 8, 9], viewed at offset 1 of a longer backing buffer
        let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
        let b = BooleanArray::from(vec![true, false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert_eq!(6, d.value(0));
        assert_eq!(9, d.value(1));

        // The same mask, expressed as a window into a longer predicate, must select
        // exactly the same rows
        let b_slice = BooleanArray::from(vec![false, true, false, false, true]).slice(1, 4);
        let c_slice = filter(&a, &b_slice).unwrap();
        assert_eq!(&c_slice, &c);
    }
971
972
    /// Selects only two rows (65 and 67) out of 67.
    #[test]
    fn test_filter_array_low_density() {
        // this test exercises the all 0's branch of the filter algorithm:
        // the first 64 mask bits are all unset, bit 65 is set,
        // and two more values follow after the batch
        let values: Vec<i32> = (1..=65).chain([66, 67]).collect();
        let mask: Vec<bool> = (1..=65)
            .map(|i| i % 65 == 0)
            .chain([false, true])
            .collect();

        let input = Int32Array::from(values);
        let predicate = BooleanArray::from(mask);
        let filtered = filter(&input, &predicate).unwrap();
        let kept = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(2, kept.len());
        assert_eq!(65, kept.value(0));
        assert_eq!(67, kept.value(1));
    }
988
989
    /// Selects 67 of 69 rows, three of which are null.
    #[test]
    fn test_filter_array_high_density() {
        // this test exercises the all 1's branch of the filter algorithm:
        // the first 64 mask bits are all set and bit 65 is unset.
        // The second data value is null, and four more values follow after the batch.
        let values: Vec<Option<i32>> = (1..=65)
            .map(|v| (v != 2).then_some(v))
            .chain([Some(66), None, Some(67), None])
            .collect();
        let mask: Vec<bool> = (1..=65)
            .map(|i| i % 65 != 0)
            .chain([false, true, true, true])
            .collect();

        let input = Int32Array::from(values);
        let predicate = BooleanArray::from(mask);
        let filtered = filter(&input, &predicate).unwrap();
        let kept = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(67, kept.len());
        assert_eq!(3, kept.null_count());
        assert_eq!(1, kept.value(0));
        assert!(kept.is_null(1));
        assert_eq!(64, kept.value(63));
        assert!(kept.is_null(64));
        assert_eq!(67, kept.value(65));
    }
1013
1014
    /// Filtering a null-free `StringArray` keeps only the selected strings.
    #[test]
    fn test_filter_string_array_simple() {
        let strings = StringArray::from(vec!["hello", " ", "world", "!"]);
        let mask = BooleanArray::from(vec![true, false, true, false]);
        let filtered = filter(&strings, &mask).unwrap();
        let kept = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(2, kept.len());
        assert_eq!("hello", kept.value(0));
        assert_eq!("world", kept.value(1));
    }
1024
1025
    /// A selected null primitive value stays null in the output.
    #[test]
    fn test_filter_primitive_array_with_null() {
        let input = Int32Array::from(vec![Some(5), None]);
        let mask = BooleanArray::from(vec![false, true]);
        let filtered = filter(&input, &mask).unwrap();
        let kept = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();
        assert_eq!(1, kept.len());
        assert!(kept.is_null(0));
    }
1034
1035
    /// Validity is carried through when filtering a `StringArray` that contains nulls.
    #[test]
    fn test_filter_string_array_with_null() {
        let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
        let mask = BooleanArray::from(vec![true, false, false, true]);
        let filtered = filter(&strings, &mask).unwrap();
        let kept = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(2, kept.len());
        assert_eq!("hello", kept.value(0));
        assert!(!kept.is_null(0));
        assert!(kept.is_null(1));
    }
1046
1047
    /// Validity is carried through when filtering a `BinaryArray` that contains nulls.
    #[test]
    fn test_filter_binary_array_with_null() {
        let raw: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
        let input = BinaryArray::from(raw);
        let mask = BooleanArray::from(vec![true, false, false, true]);
        let filtered = filter(&input, &mask).unwrap();
        let kept = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<BinaryArray>()
            .unwrap();
        assert_eq!(2, kept.len());
        assert_eq!(b"hello", kept.value(0));
        assert!(!kept.is_null(0));
        assert!(kept.is_null(1));
    }
1059
1060
    fn _test_filter_byte_view<T>()
1061
    where
1062
        T: ByteViewType,
1063
        str: AsRef<T::Native>,
1064
        T::Native: PartialEq,
1065
    {
1066
        let array = {
1067
            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1068
            let mut builder = GenericByteViewBuilder::<T>::new();
1069
            builder.append_value("hello");
1070
            builder.append_value("world");
1071
            builder.append_null();
1072
            builder.append_value("large payload over 12 bytes");
1073
            builder.append_value("lulu");
1074
            builder.finish()
1075
        };
1076
1077
        {
1078
            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1079
            let actual = filter(&array, &predicate).unwrap();
1080
1081
            assert_eq!(actual.len(), 3);
1082
1083
            let expected = {
1084
                // ["hello", null, "large payload over 12 bytes"]
1085
                let mut builder = GenericByteViewBuilder::<T>::new();
1086
                builder.append_value("hello");
1087
                builder.append_null();
1088
                builder.append_value("large payload over 12 bytes");
1089
                builder.finish()
1090
            };
1091
1092
            assert_eq!(actual.as_ref(), &expected);
1093
        }
1094
1095
        {
1096
            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1097
            let actual = filter(&array, &predicate).unwrap();
1098
1099
            assert_eq!(actual.len(), 2);
1100
1101
            let expected = {
1102
                // ["hello", "lulu"]
1103
                let mut builder = GenericByteViewBuilder::<T>::new();
1104
                builder.append_value("hello");
1105
                builder.append_value("lulu");
1106
                builder.finish()
1107
            };
1108
1109
            assert_eq!(actual.as_ref(), &expected);
1110
        }
1111
    }
1112
1113
    /// Runs the shared byte-view filter checks for `StringViewType`.
    #[test]
    fn test_filter_string_view() {
        _test_filter_byte_view::<StringViewType>();
    }
1117
1118
    /// Runs the shared byte-view filter checks for `BinaryViewType`.
    #[test]
    fn test_filter_binary_view() {
        _test_filter_byte_view::<BinaryViewType>();
    }
1122
1123
    /// Filters a three element `FixedSizeBinaryArray` with mixed, all-false, all-true and
    /// single-row masks, and checks that an optimized predicate gives the same result.
    #[test]
    fn test_filter_fixed_binary() {
        let first = [1_u8, 2];
        let second = [3_u8, 4];
        let third = [5_u8, 6];
        let input = FixedSizeBinaryArray::from(vec![&first, &second, &third]);

        // Mixed mask: keeps the first and last values
        let mask = BooleanArray::from(vec![true, false, true]);
        let filtered = filter(&input, &mask).unwrap();
        let kept = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(kept.len(), 2);
        assert_eq!(kept.value(0), &first);
        assert_eq!(kept.value(1), &third);
        let optimized = FilterBuilder::new(&mask)
            .optimize()
            .build()
            .filter(&input)
            .unwrap();
        let kept_optimized = optimized
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(kept, kept_optimized);

        // All false: nothing is kept
        let mask = BooleanArray::from(vec![false, false, false]);
        let filtered = filter(&input, &mask).unwrap();
        let kept = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(kept.len(), 0);

        // All true: everything is kept, in order
        let mask = BooleanArray::from(vec![true, true, true]);
        let filtered = filter(&input, &mask).unwrap();
        let kept = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(kept.len(), 3);
        assert_eq!(kept.value(0), &first);
        assert_eq!(kept.value(1), &second);
        assert_eq!(kept.value(2), &third);

        // Only the last value is kept
        let mask = BooleanArray::from(vec![false, false, true]);
        let filtered = filter(&input, &mask).unwrap();
        let kept = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(kept.len(), 1);
        assert_eq!(kept.value(0), &third);
        let optimized = FilterBuilder::new(&mask)
            .optimize()
            .build()
            .filter(&input)
            .unwrap();
        let kept_optimized = optimized
            .as_ref()
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        assert_eq!(kept, kept_optimized);
    }
1194
1195
    /// Filtering a sliced array with nulls must honour the slice offset for both values
    /// and validity, and must also accept a sliced predicate.
    #[test]
    fn test_filter_array_slice_with_null() {
        // a = [null, 7, 8, 9], viewed at offset 1 of a longer backing buffer
        let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
        let b = BooleanArray::from(vec![true, false, false, true]);
        let c = filter(&a, &b).unwrap();
        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(2, d.len());
        assert!(d.is_null(0));
        assert!(!d.is_null(1));
        assert_eq!(9, d.value(1));

        // The same mask, expressed as a window into a longer predicate, must select
        // exactly the same rows
        let b_slice = BooleanArray::from(vec![false, true, false, false, true]).slice(1, 4);
        let c_slice = filter(&a, &b_slice).unwrap();
        assert_eq!(&c_slice, &c);
    }
1209
1210
    /// Keeps every other row of a run array; all three runs survive with shorter lengths.
    #[test]
    fn test_filter_run_end_encoding_array() {
        let input = RunArray::try_new(
            &Int64Array::from(vec![2, 3, 8]),
            &Int64Array::from(vec![7, -2, 9]),
        )
        .expect("Failed to create RunArray");
        let mask = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
        let filtered = filter(&input, &mask).unwrap();
        let actual: &RunArray<Int64Type> = as_run_array(&filtered);
        assert_eq!(4, actual.len());

        let expected = RunArray::try_new(
            &Int64Array::from(vec![1, 2, 4]),
            &Int64Array::from(vec![7, -2, 9]),
        )
        .expect("Failed to make expected RunArray test is broken");

        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
        assert_eq!(actual.values(), expected.values());
    }
1229
1230
    /// Runs with no selected rows (`-2` and `-8`) are dropped from the output.
    #[test]
    fn test_filter_run_end_encoding_array_remove_value() {
        let input = RunArray::try_new(
            &Int32Array::from(vec![2, 3, 8, 10]),
            &Int32Array::from(vec![7, -2, 9, -8]),
        )
        .expect("Failed to create RunArray");
        let mask = BooleanArray::from(vec![
            false, true, false, false, true, false, true, false, false, false,
        ]);
        let filtered = filter(&input, &mask).unwrap();
        let actual: &RunArray<Int32Type> = as_run_array(&filtered);
        assert_eq!(3, actual.len());

        let expected =
            RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
                .expect("Failed to make expected RunArray test is broken");

        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
        assert_eq!(actual.values(), expected.values());
    }
1249
1250
    /// A single selected row leaves one run of length one.
    #[test]
    fn test_filter_run_end_encoding_array_remove_all_but_one() {
        let input = RunArray::try_new(
            &Int16Array::from(vec![2, 3, 8, 10]),
            &Int16Array::from(vec![7, -2, 9, -8]),
        )
        .expect("Failed to create RunArray");
        let mask = BooleanArray::from(vec![
            false, false, false, false, false, false, true, false, false, false,
        ]);
        let filtered = filter(&input, &mask).unwrap();
        let actual: &RunArray<Int16Type> = as_run_array(&filtered);
        assert_eq!(1, actual.len());

        let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
            .expect("Failed to make expected RunArray test is broken");

        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
        assert_eq!(actual.values(), expected.values());
    }
1268
1269
    /// An all-false mask yields an empty run array.
    #[test]
    fn test_filter_run_end_encoding_array_empty() {
        let input = RunArray::try_new(
            &Int64Array::from(vec![2, 3, 8, 10]),
            &Int64Array::from(vec![7, -2, 9, -8]),
        )
        .expect("Failed to create RunArray");
        let mask = BooleanArray::from(vec![
            false, false, false, false, false, false, false, false, false, false,
        ]);
        let filtered = filter(&input, &mask).unwrap();
        let actual: &RunArray<Int64Type> = as_run_array(&filtered);
        assert_eq!(0, actual.len());
    }
1281
1282
    /// The predicate (3 rows) is shorter than the run array (10 rows); only the rows it
    /// covers are considered.
    #[test]
    fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
        let input = RunArray::try_new(
            &Int64Array::from(vec![2, 3, 8, 10]),
            &Int64Array::from(vec![7, -2, 9, -8]),
        )
        .expect("Failed to create RunArray");
        let mask = BooleanArray::from(vec![false, true, true]);
        let filtered = filter(&input, &mask).unwrap();
        let actual: &RunArray<Int64Type> = as_run_array(&filtered);
        assert_eq!(2, actual.len());

        let expected = RunArray::try_new(
            &Int64Array::from(vec![1, 2]),
            &Int64Array::from(vec![7, -2]),
        )
        .expect("Failed to make expected RunArray test is broken");

        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
        assert_eq!(actual.values(), expected.values());
    }
1301
1302
    /// Filtering a dictionary array filters the keys and leaves the values untouched.
    #[test]
    fn test_filter_dictionary_array() {
        let entries = [Some("hello"), None, Some("world"), Some("!")];
        let input: Int8DictionaryArray = entries.iter().copied().collect();
        let mask = BooleanArray::from(vec![false, true, true, false]);
        let filtered = filter(&input, &mask).unwrap();
        let dict = filtered
            .as_ref()
            .as_any()
            .downcast_ref::<Int8DictionaryArray>()
            .unwrap();
        let dict_values = dict.values();
        let strings = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
        // values are cloned in the filtered dictionary array
        assert_eq!(3, strings.len());
        // but keys are filtered
        assert_eq!(2, dict.len());
        assert!(dict.is_null(0));
        assert_eq!("world", strings.value(dict.keys().value(1) as usize));
    }
1322
1323
    /// Filters a `LargeListArray` with a trailing null entry and compares the result
    /// against an array built directly from the expected offsets, values and validity.
    #[test]
    fn test_filter_list_array() {
        let list_data_type =
            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int32, false)));

        //  input = [[0, 1, 2], [3, 4, 5], [6, 7], null]
        let input_values = ArrayData::builder(DataType::Int32)
            .len(8)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
            .build()
            .unwrap();
        let input_data = ArrayData::builder(list_data_type.clone())
            .len(4)
            .add_buffer(Buffer::from_slice_ref([0i64, 3, 6, 8, 8]))
            .add_child_data(input_values)
            .null_bit_buffer(Some(Buffer::from([0b00000111])))
            .build()
            .unwrap();
        let input = LargeListArray::from(input_data);

        let mask = BooleanArray::from(vec![false, true, false, true]);
        let actual = filter(&input, &mask).unwrap();

        // expected: [[3, 4, 5], null]
        let expected_values = ArrayData::builder(DataType::Int32)
            .len(3)
            .add_buffer(Buffer::from_slice_ref([3, 4, 5]))
            .build()
            .unwrap();
        let expected_data = ArrayData::builder(list_data_type)
            .len(2)
            .add_buffer(Buffer::from_slice_ref([0i64, 3, 3]))
            .add_child_data(expected_values)
            .null_bit_buffer(Some(Buffer::from([0b00000001])))
            .build()
            .unwrap();

        assert_eq!(&make_array(expected_data), &actual);
    }
1369
1370
    /// A single set bit in a 64 bit mask produces one slice of length one.
    #[test]
    fn test_slice_iterator_bits() {
        let bits: Vec<bool> = (0..64).map(|i| i == 1).collect();
        let mask = BooleanArray::from(bits);
        let selected = filter_count(&mask);

        let chunks: Vec<_> = SlicesIterator::new(&mask).collect();

        assert_eq!(chunks, vec![(1, 2)]);
        assert_eq!(selected, 1);
    }
1382
1383
    /// A single unset bit in a 64 bit mask splits the selection into two slices.
    #[test]
    fn test_slice_iterator_bits1() {
        let bits: Vec<bool> = (0..64).map(|i| i != 1).collect();
        let mask = BooleanArray::from(bits);
        let selected = filter_count(&mask);

        let chunks: Vec<_> = SlicesIterator::new(&mask).collect();

        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
        assert_eq!(selected, 64 - 1);
    }
1395
1396
    /// Every 62nd bit of a 130 bit mask is unset (positions 0, 62 and 124), so the
    /// slices span more than one 64 bit chunk.
    #[test]
    fn test_slice_iterator_chunk_and_bits() {
        let bits: Vec<bool> = (0..130).map(|i| i % 62 != 0).collect();
        let mask = BooleanArray::from(bits);
        let selected = filter_count(&mask);

        let chunks: Vec<_> = SlicesIterator::new(&mask).collect();

        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
        assert_eq!(selected, 61 + 61 + 5);
    }
1408
1409
    /// A null entry in the predicate does not select its row.
    #[test]
    fn test_null_mask() {
        let input = Int64Array::from(vec![Some(1), Some(2), None]);

        let mask = BooleanArray::from(vec![Some(true), Some(true), None]);
        let filtered = filter(&input, &mask).unwrap();
        assert_eq!(filtered.as_ref(), &input.slice(0, 2));
    }
1417
1418
    /// Filtering a batch without columns still reports the number of selected rows.
    #[test]
    fn test_filter_record_batch_no_columns() {
        let mask = BooleanArray::from(vec![Some(true), Some(true), None]);
        let options = RecordBatchOptions::default().with_row_count(Some(100));
        let empty_schema = Arc::new(Schema::empty());
        let batch = RecordBatch::try_new_with_options(empty_schema, vec![], &options).unwrap();
        let filtered = filter_record_batch(&batch, &mask).unwrap();

        assert_eq!(filtered.num_rows(), 2);
    }
1428
1429
    /// All-true and all-false masks: the first returns the input unchanged, the second
    /// an empty array of the same data type.
    #[test]
    fn test_fast_path() {
        let input: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);

        // all true
        let keep_all = BooleanArray::from(vec![true, true, true]);
        let filtered = filter(&input, &keep_all).unwrap();
        let unchanged = filtered
            .as_any()
            .downcast_ref::<PrimitiveArray<Int64Type>>()
            .unwrap();
        assert_eq!(&input, unchanged);

        // all false
        let keep_none = BooleanArray::from(vec![false, false, false]);
        let filtered = filter(&input, &keep_none).unwrap();
        assert_eq!(filtered.len(), 0);
        assert_eq!(filtered.data_type(), &DataType::Int64);
    }
1448
1449
    /// `SlicesIterator` reports runs of set bits, both for a whole mask and for a
    /// window into it.
    #[test]
    fn test_slices() {
        // takes up 2 u64s: 10 set, 30 unset, 20 set, 17 unset, 4 set
        let pattern = std::iter::repeat_n(true, 10)
            .chain(std::iter::repeat_n(false, 30))
            .chain(std::iter::repeat_n(true, 20))
            .chain(std::iter::repeat_n(false, 17))
            .chain(std::iter::repeat_n(true, 4));

        let mask: BooleanArray = pattern.map(Some).collect();

        let whole: Vec<_> = SlicesIterator::new(&mask).collect();
        assert_eq!(whole, vec![(0, 10), (40, 60), (77, 81)]);

        // slice with offset and truncated len
        let window = mask.slice(7, mask.len() - 10);
        let window = window.as_any().downcast_ref::<BooleanArray>().unwrap();
        let windowed: Vec<_> = SlicesIterator::new(window).collect();
        assert_eq!(windowed, vec![(0, 3), (33, 53), (70, 71)]);
    }
1475
1476
    fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1477
        let mut rng = rng();
1478
1479
        let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1480
            .take(mask_len)
1481
            .collect();
1482
1483
        let buffer = Buffer::from_iter(bools.iter().cloned());
1484
1485
        let truncated_length = mask_len - offset - truncate;
1486
1487
        let data = ArrayDataBuilder::new(DataType::Boolean)
1488
            .len(truncated_length)
1489
            .offset(offset)
1490
            .add_buffer(buffer)
1491
            .build()
1492
            .unwrap();
1493
1494
        let filter = BooleanArray::from(data);
1495
1496
        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1497
            .flat_map(|(start, end)| start..end)
1498
            .collect();
1499
1500
        let count = filter_count(&filter);
1501
        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1502
1503
        let expected_bits: Vec<_> = bools
1504
            .iter()
1505
            .skip(offset)
1506
            .take(truncated_length)
1507
            .enumerate()
1508
            .flat_map(|(idx, v)| v.then(|| idx))
1509
            .collect();
1510
1511
        assert_eq!(slice_bits, expected_bits);
1512
        assert_eq!(index_bits, expected_bits);
1513
    }
1514
1515
    /// Runs `test_slices_fuzz` on 100 random (length, offset, truncate) combinations
    /// followed by a handful of fixed ones.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn fuzz_test_slices_iterator() {
        let mut source = rng();

        let any_usize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
        for _ in 0..100 {
            let mask_len: usize = source.random_range(0..1024);

            // `checked_rem` is `None` for a zero bound, in which case 0 is used
            let max_offset = mask_len.min(64);
            let offset = any_usize
                .sample(&mut source)
                .checked_rem(max_offset)
                .unwrap_or(0);

            let max_truncate = (mask_len - offset).min(128);
            let truncate = any_usize
                .sample(&mut source)
                .checked_rem(max_truncate)
                .unwrap_or(0);

            test_slices_fuzz(mask_len, offset, truncate);
        }

        // Fixed (mask_len, offset, truncate) cases
        let fixed_cases = [(64, 0, 0), (64, 8, 0), (64, 8, 8), (32, 8, 8), (32, 5, 9)];
        for (mask_len, offset, truncate) in fixed_cases {
            test_slices_fuzz(mask_len, offset, truncate);
        }
    }
1541
1542
    /// Reference filter built on plain iterators: keeps each item of `values` whose
    /// matching entry in `predicate` is `true`.
    ///
    /// Items are paired with `predicate` positionally; pairing stops at the shorter input.
    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
        values
            .into_iter()
            .zip(predicate)
            .filter_map(|(value, &keep)| keep.then_some(value))
            .collect()
    }
1551
1552
    /// Generates an array of length `len` with `valid_percent` non-null values
1553
    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1554
    where
1555
        StandardUniform: Distribution<T>,
1556
    {
1557
        let mut rng = rng();
1558
        (0..len)
1559
            .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1560
            .collect()
1561
    }
1562
1563
    /// Generates an array of length `len` with `valid_percent` non-null values
1564
    fn gen_strings(
1565
        len: usize,
1566
        valid_percent: f64,
1567
        str_len_range: std::ops::Range<usize>,
1568
    ) -> Vec<Option<String>> {
1569
        let mut rng = rng();
1570
        (0..len)
1571
            .map(|_| {
1572
                rng.random_bool(valid_percent).then(|| {
1573
                    let len = rng.random_range(str_len_range.clone());
1574
                    (0..len)
1575
                        .map(|_| char::from(rng.sample(Alphanumeric)))
1576
                        .collect()
1577
                })
1578
            })
1579
            .collect()
1580
    }
1581
1582
    /// Iterates over `src`, yielding each entry as an `Option` of its `Deref` target
    /// (for example `Option<String>` becomes `Option<&str>`)
    fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
        src.iter().map(Option::as_deref)
    }
1586
1587
    /// Randomised check of `filter` against the plain-Rust `filter_rust`
    /// reference for Int32, Utf8 and dictionary-encoded Utf8 inputs, using
    /// sliced (offset) inputs and predicates.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn fuzz_filter() {
        let mut rng = rng();

        for round in 0..100 {
            // Pin the selectivity to its extremes for the first rounds (keep
            // everything, then keep nothing) and randomise it afterwards.
            let filter_percent = if round <= 4 {
                1.
            } else if round <= 10 {
                0.
            } else {
                rng.random_range(0.0..1.0)
            };

            let valid_percent = rng.random_range(0.0..1.0);

            let array_len = rng.random_range(32..256);
            let array_offset = rng.random_range(0..10);

            // Build a predicate that is longer than needed at the front
            // (`filter_offset`) and shorter than the data at the back
            // (`filter_truncate`).
            let filter_offset = rng.random_range(0..10);
            let filter_truncate = rng.random_range(0..10);
            let mask_len = array_len + filter_offset - filter_truncate;
            let bools: Vec<bool> = (0..mask_len)
                .map(|_| rng.random_bool(filter_percent))
                .collect();

            let full_predicate = BooleanArray::from_iter(bools.iter().map(|&keep| Some(keep)));

            // Slice the predicate so that it carries a non-zero offset
            let sliced_predicate = full_predicate.slice(filter_offset, array_len - filter_truncate);
            let predicate = sliced_predicate
                .as_any()
                .downcast_ref::<BooleanArray>()
                .unwrap();
            let bools = &bools[filter_offset..];

            // --- Int32 ---
            let int_values = gen_primitive(array_len + array_offset, valid_percent);
            let int_full = Int32Array::from_iter(int_values.iter().cloned());

            let int_sliced = int_full.slice(array_offset, array_len);
            let int_src = int_sliced.as_any().downcast_ref::<Int32Array>().unwrap();

            let filtered = filter(int_src, predicate).unwrap();
            let int_out = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
            let actual: Vec<_> = int_out.iter().collect();

            assert_eq!(
                actual,
                filter_rust(int_values[array_offset..].iter().cloned(), bools)
            );

            // --- Utf8 ---
            let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
            let str_full = StringArray::from_iter(as_deref(&strings));

            let str_sliced = str_full.slice(array_offset, array_len);
            let str_src = str_sliced.as_any().downcast_ref::<StringArray>().unwrap();

            let filtered = filter(str_src, predicate).unwrap();
            let str_out = filtered.as_any().downcast_ref::<StringArray>().unwrap();
            let actual: Vec<_> = str_out.iter().collect();

            let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
            assert_eq!(actual, expected_strings);

            // --- Dictionary-encoded Utf8 (same logical values as above) ---
            let dict_full = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));

            let dict_sliced = dict_full.slice(array_offset, array_len);
            let dict_src = dict_sliced
                .as_any()
                .downcast_ref::<DictionaryArray<Int32Type>>()
                .unwrap();

            let filtered = filter(dict_src, predicate).unwrap();
            let dict_out = filtered
                .as_any()
                .downcast_ref::<DictionaryArray<Int32Type>>()
                .unwrap();
            let dict_values = dict_out
                .values()
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap();

            // Resolve every key through the dictionary so the result can be
            // compared with the plain string expectation.
            let actual: Vec<_> = dict_out
                .keys()
                .iter()
                .map(|key| key.map(|key| dict_values.value(key as usize)))
                .collect();

            assert_eq!(actual, expected_strings);
        }
    }
1677
1678
    /// Filtering a map array keeps exactly the selected entries.
    #[test]
    fn test_filter_map() {
        /// Builds a map array from `rows`; a `None` row becomes a null entry.
        fn build_map(rows: &[Option<Vec<(&str, i64)>>], value_capacity: usize) -> ArrayRef {
            let mut builder = MapBuilder::new(
                None,
                StringBuilder::new(),
                Int64Builder::with_capacity(value_capacity),
            );
            for row in rows {
                match row {
                    Some(entries) => {
                        for &(key, value) in entries.iter() {
                            builder.keys().append_value(key);
                            builder.values().append_value(value);
                        }
                        builder.append(true).unwrap();
                    }
                    None => builder.append(false).unwrap(),
                }
            }
            Arc::new(builder.finish()) as ArrayRef
        }

        // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1}]
        let maparray = build_map(
            &[
                Some(vec![("key1", 1)]),
                Some(vec![("key2", 2), ("key3", 3)]),
                None,
                Some(vec![("key1", 1)]),
            ],
            4,
        );

        // Keep only the first and the last entry
        let predicate =
            BooleanArray::from_iter([Some(true), Some(false), Some(false), Some(true)]);
        let got = filter(&maparray, &predicate).unwrap();

        // [{"key1": 1}, {"key1": 1}]
        let expected = build_map(&[Some(vec![("key1", 1)]), Some(vec![("key1", 1)])], 2);

        assert_eq!(&expected, &got);
    }
1714
1715
    /// Filtering a fixed-size list array (without nulls) keeps whole lists.
    #[test]
    fn test_filter_fixed_size_list_arrays() {
        // Nine child values grouped into three lists of width 3:
        //  [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        let child = ArrayData::builder(DataType::Int32)
            .len(9)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8]))
            .build()
            .unwrap();
        let list_type = DataType::new_fixed_size_list(DataType::Int32, 3, false);
        let list_data = ArrayData::builder(list_type)
            .len(3)
            .add_child_data(child)
            .build()
            .unwrap();
        let array = FixedSizeListArray::from(list_data);

        // (predicate, lists expected to survive, in order)
        let cases: [(Vec<bool>, Vec<[i32; 3]>); 2] = [
            (vec![true, false, false], vec![[0, 1, 2]]),
            (vec![true, false, true], vec![[0, 1, 2], [6, 7, 8]]),
        ];

        for (mask, expected_lists) in cases {
            let filter_array = BooleanArray::from(mask);

            let c = filter(&array, &filter_array).unwrap();
            let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();

            assert_eq!(filtered.len(), expected_lists.len());

            for (row, expected) in expected_lists.iter().enumerate() {
                let list = filtered.value(row);
                assert_eq!(
                    expected,
                    list.as_any().downcast_ref::<Int32Array>().unwrap().values()
                );
            }
        }
    }
1761
1762
    /// Filtering a fixed-size list array carries null lists through to the
    /// output.
    #[test]
    fn test_filter_fixed_size_list_arrays_with_null() {
        let child = ArrayData::builder(DataType::Int32)
            .len(10)
            .add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))
            .build()
            .unwrap();

        // Validity bits for the outer list array: bits 0, 3 and 4 are set, so
        // only those rows are non-null:
        //  [[0, 1], null, null, [6, 7], [8, 9]]
        let mut null_bits: [u8; 1] = [0; 1];
        for valid_row in [0, 3, 4] {
            bit_util::set_bit(&mut null_bits, valid_row);
        }

        let list_type = DataType::new_fixed_size_list(DataType::Int32, 2, false);
        let list_data = ArrayData::builder(list_type)
            .len(5)
            .add_child_data(child)
            .null_bit_buffer(Some(Buffer::from(null_bits)))
            .build()
            .unwrap();
        let array = FixedSizeListArray::from(list_data);

        // Selects rows 0, 1 and 3: [0, 1], null, [6, 7]
        let filter_array = BooleanArray::from(vec![true, true, false, true, false]);

        let c = filter(&array, &filter_array).unwrap();
        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();

        assert_eq!(filtered.len(), 3);

        // Output rows 0 and 2 carry values; row 1 is the null that was kept.
        for (row, expected) in [(0, [0, 1]), (2, [6, 7])] {
            let list = filtered.value(row);
            assert_eq!(
                &expected,
                list.as_any().downcast_ref::<Int32Array>().unwrap().values()
            );
        }
        assert!(filtered.is_null(1));
    }
1806
1807
    fn test_filter_union_array(array: UnionArray) {
1808
        let filter_array = BooleanArray::from(vec![true, false, false]);
1809
        let c = filter(&array, &filter_array).unwrap();
1810
        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1811
1812
        let mut builder = UnionBuilder::new_dense();
1813
        builder.append::<Int32Type>("A", 1).unwrap();
1814
        let expected_array = builder.build().unwrap();
1815
1816
        compare_union_arrays(filtered, &expected_array);
1817
1818
        let filter_array = BooleanArray::from(vec![true, false, true]);
1819
        let c = filter(&array, &filter_array).unwrap();
1820
        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1821
1822
        let mut builder = UnionBuilder::new_dense();
1823
        builder.append::<Int32Type>("A", 1).unwrap();
1824
        builder.append::<Int32Type>("A", 34).unwrap();
1825
        let expected_array = builder.build().unwrap();
1826
1827
        compare_union_arrays(filtered, &expected_array);
1828
1829
        let filter_array = BooleanArray::from(vec![true, true, false]);
1830
        let c = filter(&array, &filter_array).unwrap();
1831
        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1832
1833
        let mut builder = UnionBuilder::new_dense();
1834
        builder.append::<Int32Type>("A", 1).unwrap();
1835
        builder.append::<Float64Type>("B", 3.2).unwrap();
1836
        let expected_array = builder.build().unwrap();
1837
1838
        compare_union_arrays(filtered, &expected_array);
1839
    }
1840
1841
    /// Runs the shared union scenarios against a dense union.
    #[test]
    fn test_filter_union_array_dense() {
        // [A: 1, B: 3.2, A: 34]
        let mut dense = UnionBuilder::new_dense();
        dense.append::<Int32Type>("A", 1).unwrap();
        dense.append::<Float64Type>("B", 3.2).unwrap();
        dense.append::<Int32Type>("A", 34).unwrap();

        test_filter_union_array(dense.build().unwrap());
    }
1851
1852
    /// Filters a dense union whose slots all belong to the same variant and
    /// compares the resulting array data with a freshly built union.
    #[test]
    fn test_filter_run_union_array_dense() {
        // [A: 1, A: 3, A: 34]
        let mut builder = UnionBuilder::new_dense();
        for value in [1, 3, 34] {
            builder.append::<Int32Type>("A", value).unwrap();
        }
        let array = builder.build().unwrap();

        let filter_array = BooleanArray::from(vec![true, true, false]);
        let c = filter(&array, &filter_array).unwrap();
        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();

        // [A: 1, A: 3]
        let mut builder = UnionBuilder::new_dense();
        for value in [1, 3] {
            builder.append::<Int32Type>("A", value).unwrap();
        }
        let expected = builder.build().unwrap();

        assert_eq!(filtered.to_data(), expected.to_data());
    }
1871
1872
    /// Filtering a dense union that contains a null slot.
    #[test]
    fn test_filter_union_array_dense_with_nulls() {
        // [A: 1, B: 3.2, B: null, A: 34]
        let mut builder = UnionBuilder::new_dense();
        builder.append::<Int32Type>("A", 1).unwrap();
        builder.append::<Float64Type>("B", 3.2).unwrap();
        builder.append_null::<Float64Type>("B").unwrap();
        builder.append::<Int32Type>("A", 34).unwrap();
        let array = builder.build().unwrap();

        // Filters `array` with `mask` and compares the outcome to `expected`.
        let check = |mask: Vec<bool>, expected: UnionArray| {
            let filter_array = BooleanArray::from(mask);
            let c = filter(&array, &filter_array).unwrap();
            let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
            compare_union_arrays(filtered, &expected);
        };

        // The null slot is filtered out
        let mut builder = UnionBuilder::new_dense();
        builder.append::<Int32Type>("A", 1).unwrap();
        builder.append::<Float64Type>("B", 3.2).unwrap();
        check(vec![true, true, false, false], builder.build().unwrap());

        // The null slot is retained
        let mut builder = UnionBuilder::new_dense();
        builder.append::<Int32Type>("A", 1).unwrap();
        builder.append_null::<Float64Type>("B").unwrap();
        check(vec![true, false, true, false], builder.build().unwrap());
    }
1903
1904
    /// Runs the shared union scenarios against a sparse union.
    #[test]
    fn test_filter_union_array_sparse() {
        // [A: 1, B: 3.2, A: 34]
        let mut sparse = UnionBuilder::new_sparse();
        sparse.append::<Int32Type>("A", 1).unwrap();
        sparse.append::<Float64Type>("B", 3.2).unwrap();
        sparse.append::<Int32Type>("A", 34).unwrap();

        test_filter_union_array(sparse.build().unwrap());
    }
1914
1915
    /// Filtering a sparse union keeps a selected null slot.
    #[test]
    fn test_filter_union_array_sparse_with_nulls() {
        // [A: 1, B: 3.2, B: null, A: 34]
        let mut input = UnionBuilder::new_sparse();
        input.append::<Int32Type>("A", 1).unwrap();
        input.append::<Float64Type>("B", 3.2).unwrap();
        input.append_null::<Float64Type>("B").unwrap();
        input.append::<Int32Type>("A", 34).unwrap();
        let array = input.build().unwrap();

        // [A: 1, B: null]
        let mut wanted = UnionBuilder::new_sparse();
        wanted.append::<Int32Type>("A", 1).unwrap();
        wanted.append_null::<Float64Type>("B").unwrap();
        let expected_array = wanted.build().unwrap();

        let filter_array = BooleanArray::from(vec![true, false, true, false]);
        let c = filter(&array, &filter_array).unwrap();
        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();

        compare_union_arrays(filtered, &expected_array);
    }
1935
1936
    fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
1937
        assert_eq!(union1.len(), union2.len());
1938
1939
        for i in 0..union1.len() {
1940
            let type_id = union1.type_id(i);
1941
1942
            let slot1 = union1.value(i);
1943
            let slot2 = union2.value(i);
1944
1945
            assert_eq!(slot1.is_null(0), slot2.is_null(0));
1946
1947
            if !slot1.is_null(0) && !slot2.is_null(0) {
1948
                match type_id {
1949
                    0 => {
1950
                        let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
1951
                        assert_eq!(slot1.len(), 1);
1952
                        let value1 = slot1.value(0);
1953
1954
                        let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
1955
                        assert_eq!(slot2.len(), 1);
1956
                        let value2 = slot2.value(0);
1957
                        assert_eq!(value1, value2);
1958
                    }
1959
                    1 => {
1960
                        let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
1961
                        assert_eq!(slot1.len(), 1);
1962
                        let value1 = slot1.value(0);
1963
1964
                        let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
1965
                        assert_eq!(slot2.len(), 1);
1966
                        let value2 = slot2.value(0);
1967
                        assert_eq!(value1, value2);
1968
                    }
1969
                    _ => unreachable!(),
1970
                }
1971
            }
1972
        }
1973
    }
1974
1975
    /// Filters struct arrays with one and with two children, each without and
    /// with a struct-level null mask.
    #[test]
    fn test_filter_struct() {
        let predicate = BooleanArray::from(vec![true, false, true, false]);

        // Children and what each of them looks like once rows 0 and 2 are kept
        let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
        let a_filtered: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"]));

        let b: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
        let b_filtered: ArrayRef = Arc::new(Int32Array::from(vec![5, 7]));

        // Struct-level validity before and after filtering
        let null_mask = NullBuffer::from(vec![true, false, false, true]);
        let null_mask_filtered = NullBuffer::from(vec![true, false]);

        let a_field = Field::new("a", DataType::Utf8, false);
        let b_field = Field::new("b", DataType::Int32, false);

        // (fields, input children, expected children)
        let layouts = [
            (
                vec![a_field.clone()],
                vec![a.clone()],
                vec![a_filtered.clone()],
            ),
            (
                vec![a_field.clone(), b_field.clone()],
                vec![a.clone(), b.clone()],
                vec![a_filtered.clone(), b_filtered.clone()],
            ),
        ];

        for (fields, columns, expected_columns) in layouts {
            for with_nulls in [false, true] {
                let array = StructArray::new(
                    fields.clone().into(),
                    columns.clone(),
                    with_nulls.then(|| null_mask.clone()),
                );
                let expected = StructArray::new(
                    fields.clone().into(),
                    expected_columns.clone(),
                    with_nulls.then(|| null_mask_filtered.clone()),
                );

                let result = filter(&array, &predicate).unwrap();

                assert_eq!(result.to_data(), expected.to_data());
            }
        }
    }
2045
2046
    /// Filters a record batch whose struct column contains a child struct
    /// with no fields at all, and checks the resulting row count.
    #[test]
    fn test_filter_empty_struct() {
        // Column layout:
        //
        //     "a": {
        //         "b": int64,
        //         "c": {}
        //     }
        let child_fields = Fields::from(vec![
            Field::new("b", DataType::Int64, true),
            Field::new("c", DataType::Struct(Fields::empty()), true),
        ]);
        let schema = Arc::new(Schema::new(vec![Field::new(
            "a",
            DataType::Struct(child_fields.clone()),
            true,
        )]));

        // Three identical rows; "b" is null in every one of them:
        //
        //     {"a":{"c": {}}}
        //     {"a":{"c": {}}}
        //     {"a":{"c": {}}}
        let b: ArrayRef = Arc::new(Int64Array::from(vec![None, None, None]));
        let c: ArrayRef = Arc::new(StructArray::new_empty_fields(
            3,
            Some(NullBuffer::from(vec![true, true, true])),
        ));
        let a = StructArray::new(
            child_fields,
            vec![b, c],
            Some(NullBuffer::from(vec![true, true, true])),
        );
        let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
        println!("{record_batch:?}");

        // Keep the 1st and the 3rd row
        let predicate = BooleanArray::from(vec![true, false, true]);
        let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();

        assert_eq!(filtered_batch.num_rows(), 2);
    }
2100
}