Coverage Report

Created: 2025-11-17 14:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-select/src/take.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines take kernel for [Array]
19
20
use std::sync::Arc;
21
22
use arrow_array::builder::{BufferBuilder, UInt32Builder};
23
use arrow_array::cast::AsArray;
24
use arrow_array::types::*;
25
use arrow_array::*;
26
use arrow_buffer::{
27
    ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
28
    bit_util,
29
};
30
use arrow_data::ArrayDataBuilder;
31
use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
32
33
use num_traits::{One, Zero};
34
35
/// Take elements by index from [Array], creating a new [Array] from those indexes.
36
///
37
/// ```text
38
/// ┌─────────────────┐      ┌─────────┐                              ┌─────────────────┐
39
/// │        A        │      │    0    │                              │        A        │
40
/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
41
/// │        D        │      │    2    │                              │        B        │
42
/// ├─────────────────┤      ├─────────┤   take(values, indices)      ├─────────────────┤
43
/// │        B        │      │    3    │ ─────────────────────────▶   │        C        │
44
/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
45
/// │        C        │      │    1    │                              │        D        │
46
/// ├─────────────────┤      └─────────┘                              └─────────────────┘
47
/// │        E        │
48
/// └─────────────────┘
49
///    values array          indices array                              result
50
/// ```
51
///
52
/// For selecting values by index from multiple arrays see [`crate::interleave`]
53
///
54
/// Note that this kernel, similar to other kernels in this crate,
55
/// will avoid allocating where not necessary. Consequently
56
/// the returned array may share buffers with the inputs
57
///
58
/// # Errors
59
/// This function errors whenever:
60
/// * An index cannot be casted to `usize` (typically 32 bit architectures)
61
/// * An index is out of bounds and `options` is set to check bounds.
62
///
63
/// # Safety
64
///
65
/// When `options` is not set to check bounds, taking indexes after `len` will panic.
66
///
67
/// # See also
68
/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
69
///   the results into a single array.
70
///
71
/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
72
///
73
/// # Examples
74
/// ```
75
/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
76
/// # use arrow_select::take::take;
77
/// let values = StringArray::from(vec!["zero", "one", "two"]);
78
///
79
/// // Take items at index 2, and 1:
80
/// let indices = UInt32Array::from(vec![2, 1]);
81
/// let taken = take(&values, &indices, None).unwrap();
82
/// let taken = taken.as_string::<i32>();
83
///
84
/// assert_eq!(*taken, StringArray::from(vec!["two", "one"]));
85
/// ```
86
80
pub fn take(
87
80
    values: &dyn Array,
88
80
    indices: &dyn Array,
89
80
    options: Option<TakeOptions>,
90
80
) -> Result<ArrayRef, ArrowError> {
91
80
    let options = options.unwrap_or_default();
92
80
    downcast_integer_array!(
93
        indices => {
94
0
            if options.check_bounds {
95
0
                check_bounds(values.len(), indices)?;
96
0
            }
97
0
            let indices = indices.to_indices();
98
0
            take_impl(values, &indices)
99
        },
100
0
        d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
101
    )
102
80
}
103
104
/// For each [ArrayRef] in the [`Vec<ArrayRef>`], take elements by index and create a new
105
/// [`Vec<ArrayRef>`] from those indices.
106
///
107
/// ```text
108
/// ┌────────┬────────┐
109
/// │        │        │           ┌────────┐                                ┌────────┬────────┐
110
/// │   A    │   1    │           │        │                                │        │        │
111
/// ├────────┼────────┤           │   0    │                                │   A    │   1    │
112
/// │        │        │           ├────────┤                                ├────────┼────────┤
113
/// │   D    │   4    │           │        │                                │        │        │
114
/// ├────────┼────────┤           │   2    │  take_arrays(values,indices)   │   B    │   2    │
115
/// │        │        │           ├────────┤                                ├────────┼────────┤
116
/// │   B    │   2    │           │        │  ───────────────────────────►  │        │        │
117
/// ├────────┼────────┤           │   3    │                                │   C    │   3    │
118
/// │        │        │           ├────────┤                                ├────────┼────────┤
119
/// │   C    │   3    │           │        │                                │        │        │
120
/// ├────────┼────────┤           │   1    │                                │   D    │   4    │
121
/// │        │        │           └────────┘                                └────────┼────────┘
122
/// │   E    │   5    │
123
/// └────────┴────────┘
124
///    values arrays             indices array                                      result
125
/// ```
126
///
127
/// # Errors
128
/// This function errors whenever:
129
/// * An index cannot be casted to `usize` (typically 32 bit architectures)
130
/// * An index is out of bounds and `options` is set to check bounds.
131
///
132
/// # Safety
133
///
134
/// When `options` is not set to check bounds, taking indexes after `len` will panic.
135
///
136
/// # Examples
137
/// ```
138
/// # use std::sync::Arc;
139
/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
140
/// # use arrow_select::take::{take, take_arrays};
141
/// let string_values = Arc::new(StringArray::from(vec!["zero", "one", "two"]));
142
/// let values = Arc::new(UInt32Array::from(vec![0, 1, 2]));
143
///
144
/// // Take items at index 2, and 1:
145
/// let indices = UInt32Array::from(vec![2, 1]);
146
/// let taken_arrays = take_arrays(&[string_values, values], &indices, None).unwrap();
147
/// let taken_string = taken_arrays[0].as_string::<i32>();
148
/// assert_eq!(*taken_string, StringArray::from(vec!["two", "one"]));
149
/// let taken_values = taken_arrays[1].as_primitive();
150
/// assert_eq!(*taken_values, UInt32Array::from(vec![2, 1]));
151
/// ```
152
0
pub fn take_arrays(
153
0
    arrays: &[ArrayRef],
154
0
    indices: &dyn Array,
155
0
    options: Option<TakeOptions>,
156
0
) -> Result<Vec<ArrayRef>, ArrowError> {
157
0
    arrays
158
0
        .iter()
159
0
        .map(|array| take(array.as_ref(), indices, options.clone()))
160
0
        .collect()
161
0
}
162
163
/// Verifies that the non-null values of `indices` are all `< len`
164
2
fn check_bounds<T: ArrowPrimitiveType>(
165
2
    len: usize,
166
2
    indices: &PrimitiveArray<T>,
167
2
) -> Result<(), ArrowError> {
168
2
    if indices.null_count() > 0 {
169
6
        
indices2
.iter().flatten().
try_for_each2
(|index| {
170
6
            let ix = index
171
6
                .to_usize()
172
6
                .ok_or_else(|| ArrowError::ComputeError(
"Cast to usize failed"0
.
to_string0
()))
?0
;
173
6
            if ix >= len {
174
2
                return Err(ArrowError::ComputeError(format!(
175
2
                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
176
2
                )));
177
4
            }
178
4
            Ok(())
179
6
        })
180
    } else {
181
0
        indices.values().iter().try_for_each(|index| {
182
0
            let ix = index
183
0
                .to_usize()
184
0
                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
185
0
            if ix >= len {
186
0
                return Err(ArrowError::ComputeError(format!(
187
0
                    "Array index out of bounds, cannot get item at index {ix} from {len} entries"
188
0
                )));
189
0
            }
190
0
            Ok(())
191
0
        })
192
    }
193
2
}
194
195
#[inline(never)]
196
104
fn take_impl<IndexType: ArrowPrimitiveType>(
197
104
    values: &dyn Array,
198
104
    indices: &PrimitiveArray<IndexType>,
199
104
) -> Result<ArrayRef, ArrowError> {
200
104
    downcast_primitive_array! {
201
4
        values => Ok(Arc::new(take_primitive(values, indices)
?0
)),
202
        DataType::Boolean => {
203
7
            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
204
7
            Ok(Arc::new(take_boolean(values, indices)))
205
        }
206
        DataType::Utf8 => {
207
10
            Ok(
Arc::new9
(take_bytes(values.as_string::<i32>(), indices)
?1
))
208
        }
209
        DataType::LargeUtf8 => {
210
1
            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)
?0
))
211
        }
212
        DataType::Utf8View => {
213
1
            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)
?0
))
214
        }
215
        DataType::List(_) => {
216
4
            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)
?0
))
217
        }
218
        DataType::LargeList(_) => {
219
3
            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)
?0
))
220
        }
221
        DataType::ListView(_) => {
222
4
            Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)
?0
))
223
        }
224
        DataType::LargeListView(_) => {
225
4
            Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)
?0
))
226
        }
227
1
        DataType::FixedSizeList(_, length) => {
228
1
            let values = values
229
1
                .as_any()
230
1
                .downcast_ref::<FixedSizeListArray>()
231
1
                .unwrap();
232
1
            Ok(Arc::new(take_fixed_size_list(
233
1
                values,
234
1
                indices,
235
1
                *length as u32,
236
0
            )?))
237
        }
238
        DataType::Map(_, _) => {
239
1
            let list_arr = ListArray::from(values.as_map().clone());
240
1
            let list_data = take_list::<_, Int32Type>(&list_arr, indices)
?0
;
241
1
            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
242
1
            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
243
        }
244
5
        DataType::Struct(fields) => {
245
5
            let array: &StructArray = values.as_struct();
246
5
            let arrays  = array
247
5
                .columns()
248
5
                .iter()
249
8
                .
map5
(|a| take_impl(a.as_ref(), indices))
250
5
                .collect::<Result<Vec<ArrayRef>, _>>()
?0
;
251
5
            let fields: Vec<(FieldRef, ArrayRef)> =
252
5
                fields.iter().cloned().zip(arrays).collect();
253
254
            // Create the null bit buffer.
255
5
            let is_valid: Buffer = indices
256
5
                .iter()
257
25
                .
map5
(|index| {
258
25
                    if let Some(
index23
) = index {
259
23
                        array.is_valid(index.to_usize().unwrap())
260
                    } else {
261
2
                        false
262
                    }
263
25
                })
264
5
                .collect();
265
266
5
            if fields.is_empty() {
267
1
                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
268
1
                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
269
            } else {
270
4
                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
271
            }
272
        }
273
1
        DataType::Dictionary(_, _) => downcast_dictionary_array! {
274
0
            values => Ok(Arc::new(take_dict(values, indices)?)),
275
0
            t => unimplemented!("Take not supported for dictionary type {:?}", t)
276
        }
277
0
        DataType::RunEndEncoded(_, _) => downcast_run_array! {
278
0
            values => Ok(Arc::new(take_run(values, indices)?)),
279
0
            t => unimplemented!("Take not supported for run type {:?}", t)
280
        }
281
        DataType::Binary => {
282
0
            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
283
        }
284
        DataType::LargeBinary => {
285
0
            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
286
        }
287
        DataType::BinaryView => {
288
1
            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)
?0
))
289
        }
290
0
        DataType::FixedSizeBinary(size) => {
291
0
            let values = values
292
0
                .as_any()
293
0
                .downcast_ref::<FixedSizeBinaryArray>()
294
0
                .unwrap();
295
0
            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
296
        }
297
        DataType::Null => {
298
            // Take applied to a null array produces a null array.
299
2
            if values.len() >= indices.len() {
300
                // If the existing null array is as big as the indices, we can use a slice of it
301
                // to avoid allocating a new null array.
302
1
                Ok(values.slice(0, indices.len()))
303
            } else {
304
                // If the existing null array isn't big enough, create a new one.
305
1
                Ok(new_null_array(&DataType::Null, indices.len()))
306
            }
307
        }
308
1
        DataType::Union(fields, UnionMode::Sparse) => {
309
1
            let mut children = Vec::with_capacity(fields.len());
310
1
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
311
1
            let type_ids = take_native(values.type_ids(), indices);
312
2
            for (type_id, _field) in 
fields1
.
iter1
() {
313
2
                let values = values.child(type_id);
314
2
                let values = take_impl(values, indices)
?0
;
315
2
                children.push(values);
316
            }
317
1
            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)
?0
;
318
1
            Ok(Arc::new(array))
319
        }
320
3
        DataType::Union(fields, UnionMode::Dense) => {
321
3
            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
322
323
3
            let type_ids = <PrimitiveArray<Int8Type>>::try_new(take_native(values.type_ids(), indices), None)
?0
;
324
3
            let offsets = <PrimitiveArray<Int32Type>>::try_new(take_native(values.offsets().unwrap(), indices), None)
?0
;
325
326
3
            let children = fields.iter()
327
5
                .
map3
(|(field_type_id, _)| {
328
23
                    let 
mask5
=
BooleanArray::from_unary5
(
&type_ids5
, |value_type_id| value_type_id == field_type_id);
329
330
5
                    let indices = crate::filter::filter(&offsets, &mask)
?0
;
331
332
5
                    let values = values.child(field_type_id);
333
334
5
                    take_impl(values, indices.as_primitive::<Int32Type>())
335
5
                })
336
3
                .collect::<Result<_, _>>()
?0
;
337
338
3
            let mut child_offsets = [0; 128];
339
340
3
            let offsets = type_ids.values()
341
3
                .iter()
342
13
                .
map3
(|&i| {
343
13
                    let offset = child_offsets[i as usize];
344
345
13
                    child_offsets[i as usize] += 1;
346
347
13
                    offset
348
13
                })
349
3
                .collect();
350
351
3
            let (_, type_ids, _) = type_ids.into_parts();
352
353
3
            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)
?0
;
354
355
3
            Ok(Arc::new(array))
356
        }
357
0
        t => unimplemented!("Take not supported for data type {:?}", t)
358
    }
359
104
}
360
361
/// Configuration for the [`take`] kernel.
///
/// The [`Default`] value disables bounds checking.
#[derive(Clone, Debug, Default)]
pub struct TakeOptions {
    /// When `true`, every index is validated against the length of the values
    /// array before anything is taken, and an out of bounds index produces an
    /// `ArrowError`.
    ///
    /// When `false`, no up-front validation happens and an out of bounds index
    /// makes the kernel panic.
    pub check_bounds: bool,
}
369
370
#[inline(always)]
371
0
fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize, ArrowError> {
372
0
    index
373
0
        .to_usize()
374
0
        .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
375
0
}
376
377
/// `take` implementation for all primitive arrays
378
///
379
/// This checks if an `indices` slot is populated, and gets the value from `values`
380
///  as the populated index.
381
/// If the `indices` slot is null, a null value is returned.
382
/// For example, given:
383
///     values:  [1, 2, 3, null, 5]
384
///     indices: [0, null, 4, 3]
385
/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)]
386
56
fn take_primitive<T, I>(
387
56
    values: &PrimitiveArray<T>,
388
56
    indices: &PrimitiveArray<I>,
389
56
) -> Result<PrimitiveArray<T>, ArrowError>
390
56
where
391
56
    T: ArrowPrimitiveType,
392
56
    I: ArrowPrimitiveType,
393
{
394
56
    let values_buf = take_native(values.values(), indices);
395
56
    let nulls = take_nulls(values.nulls(), indices);
396
56
    Ok(PrimitiveArray::try_new(values_buf, nulls)
?0
.with_data_type(values.data_type().clone()))
397
56
}
398
399
#[inline(never)]
400
83
fn take_nulls<I: ArrowPrimitiveType>(
401
83
    values: Option<&NullBuffer>,
402
83
    indices: &PrimitiveArray<I>,
403
83
) -> Option<NullBuffer> {
404
83
    match values.filter(|n| 
n60
.
null_count60
() > 0) {
405
60
        Some(n) => {
406
60
            let buffer = take_bits(n.inner(), indices);
407
60
            Some(NullBuffer::new(buffer)).filter(|n| n.null_count() > 0)
408
        }
409
23
        None => indices.nulls().cloned(),
410
    }
411
83
}
412
413
#[inline(never)]
414
81
fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
415
81
    values: &[T],
416
81
    indices: &PrimitiveArray<I>,
417
81
) -> ScalarBuffer<T> {
418
81
    match indices.nulls().filter(|n| 
n49
.
null_count49
() > 0) {
419
49
        Some(n) => indices
420
49
            .values()
421
49
            .iter()
422
49
            .enumerate()
423
235
            .
map49
(|(idx, index)| match values.get(index.as_usize()) {
424
232
                Some(v) => *v,
425
3
                None => match n.is_null(idx) {
426
3
                    true => T::default(),
427
0
                    false => panic!("Out-of-bounds index {index:?}"),
428
                },
429
235
            })
430
49
            .collect(),
431
32
        None => indices
432
32
            .values()
433
32
            .iter()
434
171
            .
map32
(|index| values[index.as_usize()])
435
32
            .collect(),
436
    }
437
81
}
438
439
#[inline(never)]
440
67
fn take_bits<I: ArrowPrimitiveType>(
441
67
    values: &BooleanBuffer,
442
67
    indices: &PrimitiveArray<I>,
443
67
) -> BooleanBuffer {
444
67
    let len = indices.len();
445
446
67
    match indices.nulls().filter(|n| 
n49
.
null_count49
() > 0) {
447
49
        Some(nulls) => {
448
49
            let mut output_buffer = MutableBuffer::new_null(len);
449
49
            let output_slice = output_buffer.as_slice_mut();
450
182
            
nulls49
.
valid_indices49
().
for_each49
(|idx| {
451
182
                if values.value(indices.value(idx).as_usize()) {
452
125
                    bit_util::set_bit(output_slice, idx);
453
125
                
}57
454
182
            });
455
49
            BooleanBuffer::new(output_buffer.into(), 0, len)
456
        }
457
        None => {
458
120
            
BooleanBuffer::collect_bool18
(
len18
, |idx: usize| {
459
                // SAFETY: idx<indices.len()
460
120
                values.value(unsafe { indices.value_unchecked(idx).as_usize() })
461
120
            })
462
        }
463
    }
464
67
}
465
466
/// `take` implementation for boolean arrays
467
7
fn take_boolean<IndexType: ArrowPrimitiveType>(
468
7
    values: &BooleanArray,
469
7
    indices: &PrimitiveArray<IndexType>,
470
7
) -> BooleanArray {
471
7
    let val_buf = take_bits(values.values(), indices);
472
7
    let null_buf = take_nulls(values.nulls(), indices);
473
7
    BooleanArray::new(val_buf, null_buf)
474
7
}
475
476
/// `take` implementation for string arrays
477
11
fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
478
11
    array: &GenericByteArray<T>,
479
11
    indices: &PrimitiveArray<IndexType>,
480
11
) -> Result<GenericByteArray<T>, ArrowError> {
481
11
    let mut offsets = Vec::with_capacity(indices.len() + 1);
482
11
    offsets.push(T::Offset::default());
483
484
11
    let input_offsets = array.value_offsets();
485
11
    let mut capacity = 0;
486
11
    let nulls = take_nulls(array.nulls(), indices);
487
488
11
    let (
offsets10
,
values10
) = if array.null_count() == 0 &&
indices.null_count() == 05
{
489
5
        offsets.reserve(indices.len());
490
82.5M
        for index in 
indices5
.
values5
() {
491
82.5M
            let index = index.as_usize();
492
82.5M
            capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
493
82.5M
            offsets.
push82.5M
(
494
82.5M
                T::Offset::from_usize(capacity)
495
82.5M
                    .ok_or_else(|| ArrowError::OffsetOverflowError(
capacity1
))
?1
,
496
            );
497
        }
498
4
        let mut values = Vec::with_capacity(capacity);
499
500
9
        for index in 
indices4
.
values4
() {
501
9
            values.extend_from_slice(array.value(index.as_usize()).as_ref());
502
9
        }
503
4
        (offsets, values)
504
6
    } else if indices.null_count() == 0 {
505
2
        offsets.reserve(indices.len());
506
8
        for index in 
indices2
.
values2
() {
507
8
            let index = index.as_usize();
508
8
            if array.is_valid(index) {
509
5
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
510
5
            
}3
511
8
            offsets.push(
512
8
                T::Offset::from_usize(capacity)
513
8
                    .ok_or_else(|| ArrowError::OffsetOverflowError(
capacity0
))
?0
,
514
            );
515
        }
516
2
        let mut values = Vec::with_capacity(capacity);
517
518
8
        for index in 
indices2
.
values2
() {
519
8
            let index = index.as_usize();
520
8
            if array.is_valid(index) {
521
5
                values.extend_from_slice(array.value(index).as_ref());
522
5
            
}3
523
        }
524
2
        (offsets, values)
525
4
    } else if array.null_count() == 0 {
526
0
        offsets.reserve(indices.len());
527
0
        for (i, index) in indices.values().iter().enumerate() {
528
0
            let index = index.as_usize();
529
0
            if indices.is_valid(i) {
530
0
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
531
0
            }
532
0
            offsets.push(
533
0
                T::Offset::from_usize(capacity)
534
0
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
535
            );
536
        }
537
0
        let mut values = Vec::with_capacity(capacity);
538
539
0
        for (i, index) in indices.values().iter().enumerate() {
540
0
            if indices.is_valid(i) {
541
0
                values.extend_from_slice(array.value(index.as_usize()).as_ref());
542
0
            }
543
        }
544
0
        (offsets, values)
545
    } else {
546
4
        let nulls = nulls.as_ref().unwrap();
547
4
        offsets.reserve(indices.len());
548
18
        for (i, index) in 
indices.values().iter()4
.
enumerate4
() {
549
18
            let index = index.as_usize();
550
18
            if nulls.is_valid(i) {
551
9
                capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize();
552
9
            }
553
18
            offsets.push(
554
18
                T::Offset::from_usize(capacity)
555
18
                    .ok_or_else(|| ArrowError::OffsetOverflowError(
capacity0
))
?0
,
556
            );
557
        }
558
4
        let mut values = Vec::with_capacity(capacity);
559
560
18
        for (i, index) in 
indices.values().iter()4
.
enumerate4
() {
561
            // check index is valid before using index. The value in
562
            // NULL index slots may not be within bounds of array
563
18
            let index = index.as_usize();
564
18
            if nulls.is_valid(i) {
565
9
                values.extend_from_slice(array.value(index).as_ref());
566
9
            }
567
        }
568
4
        (offsets, values)
569
    };
570
571
10
    T::Offset::from_usize(values.len())
572
10
        .ok_or_else(|| ArrowError::OffsetOverflowError(
values0
.
len0
()))
?0
;
573
574
10
    let array = unsafe {
575
10
        let offsets = OffsetBuffer::new_unchecked(offsets.into());
576
10
        GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
577
    };
578
579
10
    Ok(array)
580
11
}
581
582
/// `take` implementation for byte view arrays
583
2
fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
584
2
    array: &GenericByteViewArray<T>,
585
2
    indices: &PrimitiveArray<IndexType>,
586
2
) -> Result<GenericByteViewArray<T>, ArrowError> {
587
2
    let new_views = take_native(array.views(), indices);
588
2
    let new_nulls = take_nulls(array.nulls(), indices);
589
    // Safety:  array.views was valid, and take_native copies only valid values, and verifies bounds
590
2
    Ok(unsafe {
591
2
        GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
592
2
    })
593
2
}
594
595
/// `take` implementation for list arrays
596
///
597
/// Calculates the index and indexed offset for the inner array,
598
/// applying `take` on the inner array, then reconstructing a list array
599
/// with the indexed offsets
600
8
fn take_list<IndexType, OffsetType>(
601
8
    values: &GenericListArray<OffsetType::Native>,
602
8
    indices: &PrimitiveArray<IndexType>,
603
8
) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
604
8
where
605
8
    IndexType: ArrowPrimitiveType,
606
8
    OffsetType: ArrowPrimitiveType,
607
8
    OffsetType::Native: OffsetSizeTrait,
608
8
    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
609
{
610
    // TODO: Some optimizations can be done here such as if it is
611
    // taking the whole list or a contiguous sublist
612
8
    let (list_indices, offsets, null_buf) =
613
8
        take_value_indices_from_list::<IndexType, OffsetType>(values, indices)
?0
;
614
615
8
    let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)
?0
;
616
8
    let value_offsets = Buffer::from_vec(offsets);
617
    // create a new list with taken data and computed null information
618
8
    let list_data = ArrayDataBuilder::new(values.data_type().clone())
619
8
        .len(indices.len())
620
8
        .null_bit_buffer(Some(null_buf.into()))
621
8
        .offset(0)
622
8
        .add_child_data(taken.into_data())
623
8
        .add_buffer(value_offsets);
624
625
8
    let list_data = unsafe { list_data.build_unchecked() };
626
627
8
    Ok(GenericListArray::<OffsetType::Native>::from(list_data))
628
8
}
629
630
8
fn take_list_view<IndexType, OffsetType>(
631
8
    values: &GenericListViewArray<OffsetType::Native>,
632
8
    indices: &PrimitiveArray<IndexType>,
633
8
) -> Result<GenericListViewArray<OffsetType::Native>, ArrowError>
634
8
where
635
8
    IndexType: ArrowPrimitiveType,
636
8
    OffsetType: ArrowPrimitiveType,
637
8
    OffsetType::Native: OffsetSizeTrait,
638
{
639
8
    let taken_offsets = take_native(values.offsets(), indices);
640
8
    let taken_sizes = take_native(values.sizes(), indices);
641
8
    let nulls = take_nulls(values.nulls(), indices);
642
643
8
    let list_view_data = ArrayDataBuilder::new(values.data_type().clone())
644
8
        .len(indices.len())
645
8
        .nulls(nulls)
646
8
        .buffers(vec![taken_offsets.into(), taken_sizes.into()])
647
8
        .child_data(vec![values.values().to_data()]);
648
649
    // SAFETY: all buffers and child nodes for ListView added in constructor
650
8
    let list_view_data = unsafe { list_view_data.build_unchecked() };
651
652
8
    Ok(GenericListViewArray::<OffsetType::Native>::from(
653
8
        list_view_data,
654
8
    ))
655
8
}
656
657
/// `take` implementation for `FixedSizeListArray`
658
///
659
/// Calculates the index and indexed offset for the inner array,
660
/// applying `take` on the inner array, then reconstructing a list array
661
/// with the indexed offsets
662
4
fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
663
4
    values: &FixedSizeListArray,
664
4
    indices: &PrimitiveArray<IndexType>,
665
4
    length: <UInt32Type as ArrowPrimitiveType>::Native,
666
4
) -> Result<FixedSizeListArray, ArrowError> {
667
4
    let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)
?0
;
668
4
    let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)
?0
;
669
670
    // determine null count and null buffer, which are a function of `values` and `indices`
671
4
    let num_bytes = bit_util::ceil(indices.len(), 8);
672
4
    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
673
4
    let null_slice = null_buf.as_slice_mut();
674
675
13
    for i in 0..
indices4
.
len4
() {
676
13
        let index = indices
677
13
            .value(i)
678
13
            .to_usize()
679
13
            .ok_or_else(|| ArrowError::ComputeError(
"Cast to usize failed"0
.
to_string0
()))
?0
;
680
13
        if !indices.is_valid(i) || 
values12
.
is_null12
(
index12
) {
681
3
            bit_util::unset_bit(null_slice, i);
682
10
        }
683
    }
684
685
4
    let list_data = ArrayDataBuilder::new(values.data_type().clone())
686
4
        .len(indices.len())
687
4
        .null_bit_buffer(Some(null_buf.into()))
688
4
        .offset(0)
689
4
        .add_child_data(taken.into_data());
690
691
4
    let list_data = unsafe { list_data.build_unchecked() };
692
693
4
    Ok(FixedSizeListArray::from(list_data))
694
4
}
695
696
0
fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
697
0
    values: &FixedSizeBinaryArray,
698
0
    indices: &PrimitiveArray<IndexType>,
699
0
    size: i32,
700
0
) -> Result<FixedSizeBinaryArray, ArrowError> {
701
0
    let nulls = values.nulls();
702
0
    let array_iter = indices
703
0
        .values()
704
0
        .iter()
705
0
        .map(|idx| {
706
0
            let idx = maybe_usize::<IndexType::Native>(*idx)?;
707
0
            if nulls.map(|n| n.is_valid(idx)).unwrap_or(true) {
708
0
                Ok(Some(values.value(idx)))
709
            } else {
710
0
                Ok(None)
711
            }
712
0
        })
713
0
        .collect::<Result<Vec<_>, ArrowError>>()?
714
0
        .into_iter();
715
716
0
    FixedSizeBinaryArray::try_from_sparse_iter_with_size(array_iter, size)
717
0
}
718
719
/// `take` implementation for dictionary arrays
720
///
721
/// applies `take` to the keys of the dictionary array and returns a new dictionary array
722
/// with the same dictionary values and reordered keys
723
1
fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
724
1
    values: &DictionaryArray<T>,
725
1
    indices: &PrimitiveArray<I>,
726
1
) -> Result<DictionaryArray<T>, ArrowError> {
727
1
    let new_keys = take_primitive(values.keys(), indices)
?0
;
728
1
    Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
729
1
}
730
731
/// `take` implementation for run arrays
732
///
733
/// Finds physical indices for the given logical indices and builds output run array
734
/// by taking values in the input run_array.values at the physical indices.
735
/// The output run array will be run encoded on the physical indices and not on output values.
736
/// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]`
737
/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build
738
/// output `RunArray{ run_ends=[2,4], values=[2,2] }`.
739
1
fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
740
1
    run_array: &RunArray<T>,
741
1
    logical_indices: &PrimitiveArray<I>,
742
1
) -> Result<RunArray<T>, ArrowError> {
743
    // get physical indices for the input logical indices
744
1
    let physical_indices = run_array.get_physical_indices(logical_indices.values())
?0
;
745
746
    // Run encode the physical indices into new_run_ends_builder
747
    // Keep track of the physical indices to take in take_value_indices
748
    // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
749
1
    let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
750
1
    let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
751
1
    let mut new_physical_len = 1;
752
6
    for ix in 1..
physical_indices1
.
len1
() {
753
6
        if physical_indices[ix] != physical_indices[ix - 1] {
754
4
            take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
755
4
            new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
756
4
            new_physical_len += 1;
757
4
        
}2
758
    }
759
1
    take_value_indices
760
1
        .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
761
1
    new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
762
1
    let new_run_ends = unsafe {
763
        // Safety:
764
        // The function builds a valid run_ends array and hence need not be validated.
765
1
        ArrayDataBuilder::new(T::DATA_TYPE)
766
1
            .len(new_physical_len)
767
1
            .null_count(0)
768
1
            .add_buffer(new_run_ends_builder.finish())
769
1
            .build_unchecked()
770
    };
771
772
1
    let take_value_indices: PrimitiveArray<I> = unsafe {
773
        // Safety:
774
        // The function builds a valid take_value_indices array and hence need not be validated.
775
1
        ArrayDataBuilder::new(I::DATA_TYPE)
776
1
            .len(new_physical_len)
777
1
            .null_count(0)
778
1
            .add_buffer(take_value_indices.finish())
779
1
            .build_unchecked()
780
1
            .into()
781
    };
782
783
1
    let new_values = take(run_array.values(), &take_value_indices, None)
?0
;
784
785
1
    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
786
1
        .len(physical_indices.len())
787
1
        .add_child_data(new_run_ends)
788
1
        .add_child_data(new_values.into_data());
789
1
    let array_data = unsafe {
790
        // Safety:
791
        //  This function builds a valid run array and hence can skip validation.
792
1
        builder.build_unchecked()
793
    };
794
1
    Ok(array_data.into())
795
1
}
796
797
/// Takes/filters a list array's inner data using the offsets of the list array.
798
///
799
/// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns
800
/// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2
801
/// elements)
802
#[allow(clippy::type_complexity)]
803
10
fn take_value_indices_from_list<IndexType, OffsetType>(
804
10
    list: &GenericListArray<OffsetType::Native>,
805
10
    indices: &PrimitiveArray<IndexType>,
806
10
) -> Result<
807
10
    (
808
10
        PrimitiveArray<OffsetType>,
809
10
        Vec<OffsetType::Native>,
810
10
        MutableBuffer,
811
10
    ),
812
10
    ArrowError,
813
10
>
814
10
where
815
10
    IndexType: ArrowPrimitiveType,
816
10
    OffsetType: ArrowPrimitiveType,
817
10
    OffsetType::Native: OffsetSizeTrait + std::ops::Add + Zero + One,
818
10
    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
819
{
820
    // TODO: benchmark this function, there might be a faster unsafe alternative
821
10
    let offsets: &[OffsetType::Native] = list.value_offsets();
822
823
10
    let mut new_offsets = Vec::with_capacity(indices.len());
824
10
    let mut values = Vec::new();
825
10
    let mut current_offset = OffsetType::Native::zero();
826
    // add first offset
827
10
    new_offsets.push(OffsetType::Native::zero());
828
829
    // Initialize null buffer
830
10
    let num_bytes = bit_util::ceil(indices.len(), 8);
831
10
    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
832
10
    let null_slice = null_buf.as_slice_mut();
833
834
    // compute the value indices, and set offsets accordingly
835
36
    for i in 0..
indices10
.
len10
() {
836
36
        if indices.is_valid(i) {
837
30
            let 
ix29
= indices
838
30
                .value(i)
839
30
                .to_usize()
840
30
                .ok_or_else(|| ArrowError::ComputeError(
"Cast to usize failed"0
.
to_string0
()))
?1
;
841
29
            let start = offsets[ix];
842
29
            let end = offsets[ix + 1];
843
29
            current_offset += end - start;
844
29
            new_offsets.push(current_offset);
845
846
29
            let mut curr = start;
847
848
            // if start == end, this slot is empty
849
96
            while curr < end {
850
67
                values.push(curr);
851
67
                curr += One::one();
852
67
            }
853
29
            if !list.is_valid(ix) {
854
2
                bit_util::unset_bit(null_slice, i);
855
27
            }
856
6
        } else {
857
6
            bit_util::unset_bit(null_slice, i);
858
6
            new_offsets.push(current_offset);
859
6
        }
860
    }
861
862
9
    Ok((
863
9
        PrimitiveArray::<OffsetType>::from(values),
864
9
        new_offsets,
865
9
        null_buf,
866
9
    ))
867
10
}
868
869
/// Takes/filters a fixed size list array's inner data using the offsets of the list array.
870
6
fn take_value_indices_from_fixed_size_list<IndexType>(
871
6
    list: &FixedSizeListArray,
872
6
    indices: &PrimitiveArray<IndexType>,
873
6
    length: <UInt32Type as ArrowPrimitiveType>::Native,
874
6
) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
875
6
where
876
6
    IndexType: ArrowPrimitiveType,
877
{
878
6
    let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
879
880
21
    for i in 0..
indices6
.
len6
() {
881
21
        if indices.is_valid(i) {
882
20
            let index = indices
883
20
                .value(i)
884
20
                .to_usize()
885
20
                .ok_or_else(|| ArrowError::ComputeError(
"Cast to usize failed"0
.
to_string0
()))
?0
;
886
20
            let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
887
888
            // Safety: Range always has known length.
889
20
            unsafe {
890
20
                values.append_trusted_len_iter(start..start + length);
891
20
            }
892
1
        } else {
893
1
            values.append_nulls(length as usize);
894
1
        }
895
    }
896
897
6
    Ok(values.finish())
898
6
}
899
900
/// To avoid generating take implementations for every index type, instead we
901
/// only generate for UInt32 and UInt64 and coerce inputs to these types
902
trait ToIndices {
903
    type T: ArrowPrimitiveType;
904
905
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
906
}
907
908
macro_rules! to_indices_reinterpret {
909
    ($t:ty, $o:ty) => {
910
        impl ToIndices for PrimitiveArray<$t> {
911
            type T = $o;
912
913
17
            fn to_indices(&self) -> PrimitiveArray<$o> {
914
17
                let cast = ScalarBuffer::new(self.values().inner().clone(), 0, self.len());
915
17
                PrimitiveArray::new(cast, self.nulls().cloned())
916
17
            }
917
        }
918
    };
919
}
920
921
macro_rules! to_indices_identity {
922
    ($t:ty) => {
923
        impl ToIndices for PrimitiveArray<$t> {
924
            type T = $t;
925
926
58
            fn to_indices(&self) -> PrimitiveArray<$t> {
927
58
                self.clone()
928
58
            }
929
        }
930
    };
931
}
932
933
macro_rules! to_indices_widening {
934
    ($t:ty, $o:ty) => {
935
        impl ToIndices for PrimitiveArray<$t> {
936
            type T = UInt32Type;
937
938
3
            fn to_indices(&self) -> PrimitiveArray<$o> {
939
15
                let 
cast3
=
self.values().iter()3
.
copied3
().
map3
(|x| x as _).
collect3
();
940
3
                PrimitiveArray::new(cast, self.nulls().cloned())
941
3
            }
942
        }
943
    };
944
}
945
946
to_indices_widening!(UInt8Type, UInt32Type);
947
to_indices_widening!(Int8Type, UInt32Type);
948
949
to_indices_widening!(UInt16Type, UInt32Type);
950
to_indices_widening!(Int16Type, UInt32Type);
951
952
to_indices_identity!(UInt32Type);
953
to_indices_reinterpret!(Int32Type, UInt32Type);
954
955
to_indices_identity!(UInt64Type);
956
to_indices_reinterpret!(Int64Type, UInt64Type);
957
958
/// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes.
959
///
960
/// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`].
961
///
962
/// # Example
963
/// ```
964
/// # use std::sync::Arc;
965
/// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch};
966
/// # use arrow_schema::{DataType, Field, Schema};
967
/// # use arrow_select::take::take_record_batch;
968
///
969
/// let schema = Arc::new(Schema::new(vec![
970
///     Field::new("a", DataType::Int32, true),
971
///     Field::new("b", DataType::Utf8, true),
972
/// ]));
973
/// let batch = RecordBatch::try_new(
974
///     schema.clone(),
975
///     vec![
976
///         Arc::new(Int32Array::from_iter_values(0..20)),
977
///         Arc::new(StringArray::from_iter_values(
978
///             (0..20).map(|i| format!("str-{}", i)),
979
///         )),
980
///     ],
981
/// )
982
/// .unwrap();
983
///
984
/// let indices = UInt32Array::from(vec![1, 5, 10]);
985
/// let taken = take_record_batch(&batch, &indices).unwrap();
986
///
987
/// let expected = RecordBatch::try_new(
988
///     schema,
989
///     vec![
990
///         Arc::new(Int32Array::from(vec![1, 5, 10])),
991
///         Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
992
///     ],
993
/// )
994
/// .unwrap();
995
/// assert_eq!(taken, expected);
996
/// ```
997
0
pub fn take_record_batch(
998
0
    record_batch: &RecordBatch,
999
0
    indices: &dyn Array,
1000
0
) -> Result<RecordBatch, ArrowError> {
1001
0
    let columns = record_batch
1002
0
        .columns()
1003
0
        .iter()
1004
0
        .map(|c| take(c, indices, None))
1005
0
        .collect::<Result<Vec<_>, _>>()?;
1006
0
    RecordBatch::try_new(record_batch.schema(), columns)
1007
0
}
1008
1009
#[cfg(test)]
1010
mod tests {
1011
    use super::*;
1012
    use arrow_array::builder::*;
1013
    use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
1014
    use arrow_data::ArrayData;
1015
    use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
1016
    use num_traits::ToPrimitive;
1017
1018
2
    fn test_take_decimal_arrays(
1019
2
        data: Vec<Option<i128>>,
1020
2
        index: &UInt32Array,
1021
2
        options: Option<TakeOptions>,
1022
2
        expected_data: Vec<Option<i128>>,
1023
2
        precision: &u8,
1024
2
        scale: &i8,
1025
2
    ) -> Result<(), ArrowError> {
1026
2
        let output = data
1027
2
            .into_iter()
1028
2
            .collect::<Decimal128Array>()
1029
2
            .with_precision_and_scale(*precision, *scale)
1030
2
            .unwrap();
1031
1032
2
        let expected = expected_data
1033
2
            .into_iter()
1034
2
            .collect::<Decimal128Array>()
1035
2
            .with_precision_and_scale(*precision, *scale)
1036
2
            .unwrap();
1037
1038
2
        let expected = Arc::new(expected) as ArrayRef;
1039
2
        let output = take(&output, index, options).unwrap();
1040
2
        assert_eq!(&output, &expected);
1041
2
        Ok(())
1042
2
    }
1043
1044
4
    fn test_take_boolean_arrays(
1045
4
        data: Vec<Option<bool>>,
1046
4
        index: &UInt32Array,
1047
4
        options: Option<TakeOptions>,
1048
4
        expected_data: Vec<Option<bool>>,
1049
4
    ) {
1050
4
        let output = BooleanArray::from(data);
1051
4
        let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1052
4
        let output = take(&output, index, options).unwrap();
1053
4
        assert_eq!(&output, &expected)
1054
4
    }
1055
1056
23
    fn test_take_primitive_arrays<T>(
1057
23
        data: Vec<Option<T::Native>>,
1058
23
        index: &UInt32Array,
1059
23
        options: Option<TakeOptions>,
1060
23
        expected_data: Vec<Option<T::Native>>,
1061
23
    ) -> Result<(), ArrowError>
1062
23
    where
1063
23
        T: ArrowPrimitiveType,
1064
23
        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1065
    {
1066
23
        let output = PrimitiveArray::<T>::from(data);
1067
23
        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1068
23
        let 
output22
= take(&output, index, options)
?1
;
1069
22
        assert_eq!(&output, &expected);
1070
21
        Ok(())
1071
22
    }
1072
1073
1
    fn test_take_primitive_arrays_non_null<T>(
1074
1
        data: Vec<T::Native>,
1075
1
        index: &UInt32Array,
1076
1
        options: Option<TakeOptions>,
1077
1
        expected_data: Vec<Option<T::Native>>,
1078
1
    ) -> Result<(), ArrowError>
1079
1
    where
1080
1
        T: ArrowPrimitiveType,
1081
1
        PrimitiveArray<T>: From<Vec<T::Native>>,
1082
1
        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1083
    {
1084
1
        let output = PrimitiveArray::<T>::from(data);
1085
1
        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1086
1
        let output = take(&output, index, options)
?0
;
1087
1
        assert_eq!(&output, &expected);
1088
1
        Ok(())
1089
1
    }
1090
1091
8
    fn test_take_impl_primitive_arrays<T, I>(
1092
8
        data: Vec<Option<T::Native>>,
1093
8
        index: &PrimitiveArray<I>,
1094
8
        options: Option<TakeOptions>,
1095
8
        expected_data: Vec<Option<T::Native>>,
1096
8
    ) where
1097
8
        T: ArrowPrimitiveType,
1098
8
        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1099
8
        I: ArrowPrimitiveType,
1100
    {
1101
8
        let output = PrimitiveArray::<T>::from(data);
1102
8
        let expected = PrimitiveArray::<T>::from(expected_data);
1103
8
        let output = take(&output, index, options).unwrap();
1104
8
        let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1105
8
        assert_eq!(output, &expected)
1106
8
    }
1107
1108
    // create a simple struct for testing purposes
1109
5
    fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1110
5
        let mut struct_builder = StructBuilder::new(
1111
5
            Fields::from(vec![
1112
5
                Field::new("a", DataType::Boolean, true),
1113
5
                Field::new("b", DataType::Int32, true),
1114
            ]),
1115
5
            vec![
1116
5
                Box::new(BooleanBuilder::with_capacity(values.len())),
1117
5
                Box::new(Int32Builder::with_capacity(values.len())),
1118
            ],
1119
        );
1120
1121
32
        for 
value27
in values {
1122
27
            struct_builder
1123
27
                .field_builder::<BooleanBuilder>(0)
1124
27
                .unwrap()
1125
27
                .append_option(value.and_then(|v| v.0));
1126
27
            struct_builder
1127
27
                .field_builder::<Int32Builder>(1)
1128
27
                .unwrap()
1129
27
                .append_option(value.and_then(|v| v.1));
1130
27
            struct_builder.append(value.is_some());
1131
        }
1132
5
        struct_builder.finish()
1133
5
    }
1134
1135
    #[test]
1136
1
    fn test_take_decimal128_non_null_indices() {
1137
1
        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1138
1
        let precision: u8 = 10;
1139
1
        let scale: i8 = 5;
1140
1
        test_take_decimal_arrays(
1141
1
            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1142
1
            &index,
1143
1
            None,
1144
1
            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1145
1
            &precision,
1146
1
            &scale,
1147
        )
1148
1
        .unwrap();
1149
1
    }
1150
1151
    #[test]
1152
1
    fn test_take_decimal128() {
1153
1
        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1154
1
        let precision: u8 = 10;
1155
1
        let scale: i8 = 5;
1156
1
        test_take_decimal_arrays(
1157
1
            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1158
1
            &index,
1159
1
            None,
1160
1
            vec![Some(3), None, Some(1), Some(3), Some(2)],
1161
1
            &precision,
1162
1
            &scale,
1163
        )
1164
1
        .unwrap();
1165
1
    }
1166
1167
    #[test]
1168
1
    fn test_take_primitive_non_null_indices() {
1169
1
        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1170
1
        test_take_primitive_arrays::<Int8Type>(
1171
1
            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1172
1
            &index,
1173
1
            None,
1174
1
            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1175
        )
1176
1
        .unwrap();
1177
1
    }
1178
1179
    #[test]
1180
1
    fn test_take_primitive_non_null_values() {
1181
1
        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1182
1
        test_take_primitive_arrays::<Int8Type>(
1183
1
            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1184
1
            &index,
1185
1
            None,
1186
1
            vec![Some(3), None, Some(1), Some(3), Some(2)],
1187
        )
1188
1
        .unwrap();
1189
1
    }
1190
1191
    #[test]
1192
1
    fn test_take_primitive_non_null() {
1193
1
        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1194
1
        test_take_primitive_arrays::<Int8Type>(
1195
1
            vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1196
1
            &index,
1197
1
            None,
1198
1
            vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1199
        )
1200
1
        .unwrap();
1201
1
    }
1202
1203
    #[test]
1204
1
    fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1205
1
        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1206
1
        let index = index.slice(2, 4);
1207
1
        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1208
1209
1
        assert_eq!(
1210
            index,
1211
1
            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1212
        );
1213
1214
1
        test_take_primitive_arrays_non_null::<Int64Type>(
1215
1
            vec![0, 10, 20, 30, 40, 50],
1216
1
            index,
1217
1
            None,
1218
1
            vec![Some(20), Some(30), None, None],
1219
        )
1220
1
        .unwrap();
1221
1
    }
1222
1223
    #[test]
1224
1
    fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1225
1
        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1226
1
        let index = index.slice(2, 4);
1227
1
        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1228
1229
1
        assert_eq!(
1230
            index,
1231
1
            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1232
        );
1233
1234
1
        test_take_primitive_arrays::<Int64Type>(
1235
1
            vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1236
1
            index,
1237
1
            None,
1238
1
            vec![Some(20), Some(30), None, None],
1239
        )
1240
1
        .unwrap();
1241
1
    }
1242
1243
    #[test]
1244
1
    fn test_take_primitive() {
1245
1
        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1246
1247
        // int8
1248
1
        test_take_primitive_arrays::<Int8Type>(
1249
1
            vec![Some(0), None, Some(2), Some(3), None],
1250
1
            &index,
1251
1
            None,
1252
1
            vec![Some(3), None, None, Some(3), Some(2)],
1253
        )
1254
1
        .unwrap();
1255
1256
        // int16
1257
1
        test_take_primitive_arrays::<Int16Type>(
1258
1
            vec![Some(0), None, Some(2), Some(3), None],
1259
1
            &index,
1260
1
            None,
1261
1
            vec![Some(3), None, None, Some(3), Some(2)],
1262
        )
1263
1
        .unwrap();
1264
1265
        // int32
1266
1
        test_take_primitive_arrays::<Int32Type>(
1267
1
            vec![Some(0), None, Some(2), Some(3), None],
1268
1
            &index,
1269
1
            None,
1270
1
            vec![Some(3), None, None, Some(3), Some(2)],
1271
        )
1272
1
        .unwrap();
1273
1274
        // int64
1275
1
        test_take_primitive_arrays::<Int64Type>(
1276
1
            vec![Some(0), None, Some(2), Some(3), None],
1277
1
            &index,
1278
1
            None,
1279
1
            vec![Some(3), None, None, Some(3), Some(2)],
1280
        )
1281
1
        .unwrap();
1282
1283
        // uint8
1284
1
        test_take_primitive_arrays::<UInt8Type>(
1285
1
            vec![Some(0), None, Some(2), Some(3), None],
1286
1
            &index,
1287
1
            None,
1288
1
            vec![Some(3), None, None, Some(3), Some(2)],
1289
        )
1290
1
        .unwrap();
1291
1292
        // uint16
1293
1
        test_take_primitive_arrays::<UInt16Type>(
1294
1
            vec![Some(0), None, Some(2), Some(3), None],
1295
1
            &index,
1296
1
            None,
1297
1
            vec![Some(3), None, None, Some(3), Some(2)],
1298
        )
1299
1
        .unwrap();
1300
1301
        // uint32
1302
1
        test_take_primitive_arrays::<UInt32Type>(
1303
1
            vec![Some(0), None, Some(2), Some(3), None],
1304
1
            &index,
1305
1
            None,
1306
1
            vec![Some(3), None, None, Some(3), Some(2)],
1307
        )
1308
1
        .unwrap();
1309
1310
        // int64
1311
1
        test_take_primitive_arrays::<Int64Type>(
1312
1
            vec![Some(0), None, Some(2), Some(-15), None],
1313
1
            &index,
1314
1
            None,
1315
1
            vec![Some(-15), None, None, Some(-15), Some(2)],
1316
        )
1317
1
        .unwrap();
1318
1319
        // interval_year_month
1320
1
        test_take_primitive_arrays::<IntervalYearMonthType>(
1321
1
            vec![Some(0), None, Some(2), Some(-15), None],
1322
1
            &index,
1323
1
            None,
1324
1
            vec![Some(-15), None, None, Some(-15), Some(2)],
1325
        )
1326
1
        .unwrap();
1327
1328
        // interval_day_time
1329
1
        let v1 = IntervalDayTime::new(0, 0);
1330
1
        let v2 = IntervalDayTime::new(2, 0);
1331
1
        let v3 = IntervalDayTime::new(-15, 0);
1332
1
        test_take_primitive_arrays::<IntervalDayTimeType>(
1333
1
            vec![Some(v1), None, Some(v2), Some(v3), None],
1334
1
            &index,
1335
1
            None,
1336
1
            vec![Some(v3), None, None, Some(v3), Some(v2)],
1337
        )
1338
1
        .unwrap();
1339
1340
        // interval_month_day_nano
1341
1
        let v1 = IntervalMonthDayNano::new(0, 0, 0);
1342
1
        let v2 = IntervalMonthDayNano::new(2, 0, 0);
1343
1
        let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1344
1
        test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1345
1
            vec![Some(v1), None, Some(v2), Some(v3), None],
1346
1
            &index,
1347
1
            None,
1348
1
            vec![Some(v3), None, None, Some(v3), Some(v2)],
1349
        )
1350
1
        .unwrap();
1351
1352
        // duration_second
1353
1
        test_take_primitive_arrays::<DurationSecondType>(
1354
1
            vec![Some(0), None, Some(2), Some(-15), None],
1355
1
            &index,
1356
1
            None,
1357
1
            vec![Some(-15), None, None, Some(-15), Some(2)],
1358
        )
1359
1
        .unwrap();
1360
1361
        // duration_millisecond
1362
1
        test_take_primitive_arrays::<DurationMillisecondType>(
1363
1
            vec![Some(0), None, Some(2), Some(-15), None],
1364
1
            &index,
1365
1
            None,
1366
1
            vec![Some(-15), None, None, Some(-15), Some(2)],
1367
        )
1368
1
        .unwrap();
1369
1370
        // duration_microsecond
1371
1
        test_take_primitive_arrays::<DurationMicrosecondType>(
1372
1
            vec![Some(0), None, Some(2), Some(-15), None],
1373
1
            &index,
1374
1
            None,
1375
1
            vec![Some(-15), None, None, Some(-15), Some(2)],
1376
        )
1377
1
        .unwrap();
1378
1379
        // duration_nanosecond
1380
1
        test_take_primitive_arrays::<DurationNanosecondType>(
1381
1
            vec![Some(0), None, Some(2), Some(-15), None],
1382
1
            &index,
1383
1
            None,
1384
1
            vec![Some(-15), None, None, Some(-15), Some(2)],
1385
        )
1386
1
        .unwrap();
1387
1388
        // float32
1389
1
        test_take_primitive_arrays::<Float32Type>(
1390
1
            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1391
1
            &index,
1392
1
            None,
1393
1
            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1394
        )
1395
1
        .unwrap();
1396
1397
        // float64
1398
1
        test_take_primitive_arrays::<Float64Type>(
1399
1
            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1400
1
            &index,
1401
1
            None,
1402
1
            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1403
        )
1404
1
        .unwrap();
1405
1
    }
1406
1407
    #[test]
1408
1
    fn test_take_preserve_timezone() {
1409
1
        let index = Int64Array::from(vec![Some(0), None]);
1410
1411
1
        let input = TimestampNanosecondArray::from(vec![
1412
            1_639_715_368_000_000_000,
1413
            1_639_715_368_000_000_000,
1414
        ])
1415
1
        .with_timezone("UTC".to_string());
1416
1
        let result = take(&input, &index, None).unwrap();
1417
1
        match result.data_type() {
1418
1
            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1419
1
                assert_eq!(tz.clone(), Some("UTC".into()))
1420
            }
1421
0
            _ => panic!(),
1422
        }
1423
1
    }
1424
1425
    #[test]
1426
1
    fn test_take_impl_primitive_with_int64_indices() {
1427
1
        let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1428
1429
        // int16
1430
1
        test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1431
1
            vec![Some(0), None, Some(2), Some(3), None],
1432
1
            &index,
1433
1
            None,
1434
1
            vec![Some(3), None, None, Some(3), Some(2)],
1435
        );
1436
1437
        // int64
1438
1
        test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1439
1
            vec![Some(0), None, Some(2), Some(-15), None],
1440
1
            &index,
1441
1
            None,
1442
1
            vec![Some(-15), None, None, Some(-15), Some(2)],
1443
        );
1444
1445
        // uint64
1446
1
        test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1447
1
            vec![Some(0), None, Some(2), Some(3), None],
1448
1
            &index,
1449
1
            None,
1450
1
            vec![Some(3), None, None, Some(3), Some(2)],
1451
        );
1452
1453
        // duration_millisecond
1454
1
        test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1455
1
            vec![Some(0), None, Some(2), Some(-15), None],
1456
1
            &index,
1457
1
            None,
1458
1
            vec![Some(-15), None, None, Some(-15), Some(2)],
1459
        );
1460
1461
        // float32
1462
1
        test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1463
1
            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1464
1
            &index,
1465
1
            None,
1466
1
            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1467
        );
1468
1
    }
1469
1470
    #[test]
1471
1
    fn test_take_impl_primitive_with_uint8_indices() {
1472
1
        let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1473
1474
        // int16
1475
1
        test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1476
1
            vec![Some(0), None, Some(2), Some(3), None],
1477
1
            &index,
1478
1
            None,
1479
1
            vec![Some(3), None, None, Some(3), Some(2)],
1480
        );
1481
1482
        // duration_millisecond
1483
1
        test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1484
1
            vec![Some(0), None, Some(2), Some(-15), None],
1485
1
            &index,
1486
1
            None,
1487
1
            vec![Some(-15), None, None, Some(-15), Some(2)],
1488
        );
1489
1490
        // float32
1491
1
        test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1492
1
            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1493
1
            &index,
1494
1
            None,
1495
1
            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1496
        );
1497
1
    }
1498
1499
    #[test]
1500
1
    fn test_take_bool() {
1501
1
        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1502
        // boolean
1503
1
        test_take_boolean_arrays(
1504
1
            vec![Some(false), None, Some(true), Some(false), None],
1505
1
            &index,
1506
1
            None,
1507
1
            vec![Some(false), None, None, Some(false), Some(true)],
1508
        );
1509
1
    }
1510
1511
    #[test]
1512
1
    fn test_take_bool_nullable_index() {
1513
        // indices where the masked invalid elements would be out of bounds
1514
1
        let index_data = ArrayData::try_new(
1515
1
            DataType::UInt32,
1516
            6,
1517
1
            Some(Buffer::from_iter(vec![
1518
1
                false, true, false, true, false, true,
1519
1
            ])),
1520
            0,
1521
1
            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1522
1
            vec![],
1523
        )
1524
1
        .unwrap();
1525
1
        let index = UInt32Array::from(index_data);
1526
1
        test_take_boolean_arrays(
1527
1
            vec![Some(true), None, Some(false)],
1528
1
            &index,
1529
1
            None,
1530
1
            vec![None, Some(true), None, None, None, Some(false)],
1531
        );
1532
1
    }
1533
1534
    #[test]
1535
1
    fn test_take_bool_nullable_index_nonnull_values() {
1536
        // indices where the masked invalid elements would be out of bounds
1537
1
        let index_data = ArrayData::try_new(
1538
1
            DataType::UInt32,
1539
            6,
1540
1
            Some(Buffer::from_iter(vec![
1541
1
                false, true, false, true, false, true,
1542
1
            ])),
1543
            0,
1544
1
            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1545
1
            vec![],
1546
        )
1547
1
        .unwrap();
1548
1
        let index = UInt32Array::from(index_data);
1549
1
        test_take_boolean_arrays(
1550
1
            vec![Some(true), Some(true), Some(false)],
1551
1
            &index,
1552
1
            None,
1553
1
            vec![None, Some(true), None, Some(true), None, Some(false)],
1554
        );
1555
1
    }
1556
1557
    #[test]
1558
1
    fn test_take_bool_with_offset() {
1559
1
        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1560
1
        let index = index.slice(2, 4);
1561
1
        let index = index
1562
1
            .as_any()
1563
1
            .downcast_ref::<PrimitiveArray<UInt32Type>>()
1564
1
            .unwrap();
1565
1566
        // boolean
1567
1
        test_take_boolean_arrays(
1568
1
            vec![Some(false), None, Some(true), Some(false), None],
1569
1
            index,
1570
1
            None,
1571
1
            vec![None, Some(false), Some(true), None],
1572
        );
1573
1
    }
1574
1575
2
    fn _test_take_string<'a, K>()
1576
2
    where
1577
2
        K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1578
    {
1579
2
        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1580
1581
2
        let array = K::from(vec![
1582
2
            Some("one"),
1583
2
            None,
1584
2
            Some("three"),
1585
2
            Some("four"),
1586
2
            Some("five"),
1587
        ]);
1588
2
        let actual = take(&array, &index, None).unwrap();
1589
2
        assert_eq!(actual.len(), index.len());
1590
1591
2
        let actual = actual.as_any().downcast_ref::<K>().unwrap();
1592
1593
2
        let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1594
1595
2
        assert_eq!(actual, &expected);
1596
2
    }
1597
1598
    #[test]
1599
1
    fn test_take_string() {
1600
1
        _test_take_string::<StringArray>()
1601
1
    }
1602
1603
    #[test]
1604
1
    fn test_take_large_string() {
1605
1
        _test_take_string::<LargeStringArray>()
1606
1
    }
1607
1608
    #[test]
1609
1
    fn test_take_slice_string() {
1610
1
        let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1611
1
        let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1612
1
        let indices_slice = indices.slice(1, 4);
1613
1
        let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1614
1
        let result = take(&strings, &indices_slice, None).unwrap();
1615
1
        assert_eq!(result.as_ref(), &expected);
1616
1
    }
1617
1618
2
    fn _test_byte_view<T>()
1619
2
    where
1620
2
        T: ByteViewType,
1621
2
        str: AsRef<T::Native>,
1622
2
        T::Native: PartialEq,
1623
    {
1624
2
        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1625
2
        let array = {
1626
            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1627
2
            let mut builder = GenericByteViewBuilder::<T>::new();
1628
2
            builder.append_value("hello");
1629
2
            builder.append_value("world");
1630
2
            builder.append_null();
1631
2
            builder.append_value("large payload over 12 bytes");
1632
2
            builder.append_value("lulu");
1633
2
            builder.finish()
1634
        };
1635
1636
2
        let actual = take(&array, &index, None).unwrap();
1637
1638
2
        assert_eq!(actual.len(), index.len());
1639
1640
2
        let expected = {
1641
            // ["large payload over 12 bytes", null, "world", "large payload over 12 bytes", "lulu", null]
1642
2
            let mut builder = GenericByteViewBuilder::<T>::new();
1643
2
            builder.append_value("large payload over 12 bytes");
1644
2
            builder.append_null();
1645
2
            builder.append_value("world");
1646
2
            builder.append_value("large payload over 12 bytes");
1647
2
            builder.append_value("lulu");
1648
2
            builder.append_null();
1649
2
            builder.finish()
1650
        };
1651
1652
2
        assert_eq!(actual.as_ref(), &expected);
1653
2
    }
1654
1655
    #[test]
1656
1
    fn test_take_string_view() {
1657
1
        _test_byte_view::<StringViewType>()
1658
1
    }
1659
1660
    #[test]
1661
1
    fn test_take_binary_view() {
1662
1
        _test_byte_view::<BinaryViewType>()
1663
1
    }
1664
1665
    macro_rules! test_take_list {
1666
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1667
            // Construct a value array, [[0,0,0], [-1,-2,-1], [], [2,3]]
1668
            let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1669
            // Construct offsets
1670
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1671
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1672
            // Construct a list array from the above two
1673
            let list_data_type =
1674
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
1675
            let list_data = ArrayData::builder(list_data_type.clone())
1676
                .len(4)
1677
                .add_buffer(value_offsets)
1678
                .add_child_data(value_data)
1679
                .build()
1680
                .unwrap();
1681
            let list_array = $list_array_type::from(list_data);
1682
1683
            // index returns: [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1684
            let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);
1685
1686
            let a = take(&list_array, &index, None).unwrap();
1687
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1688
1689
            // construct a value array with expected results:
1690
            // [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1691
            let expected_data = Int32Array::from(vec![
1692
                Some(2),
1693
                Some(3),
1694
                Some(-1),
1695
                Some(-2),
1696
                Some(-1),
1697
                Some(0),
1698
                Some(0),
1699
                Some(0),
1700
            ])
1701
            .into_data();
1702
            // construct offsets
1703
            let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
1704
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1705
            // construct list array from the two
1706
            let expected_list_data = ArrayData::builder(list_data_type)
1707
                .len(5)
1708
                // null buffer remains the same as only the indices have nulls
1709
                .nulls(index.nulls().cloned())
1710
                .add_buffer(expected_offsets)
1711
                .add_child_data(expected_data)
1712
                .build()
1713
                .unwrap();
1714
            let expected_list_array = $list_array_type::from(expected_list_data);
1715
1716
            assert_eq!(a, &expected_list_array);
1717
        }};
1718
    }
1719
1720
    macro_rules! test_take_list_with_value_nulls {
1721
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1722
            // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]]
1723
            let value_data = Int32Array::from(vec![
1724
                Some(0),
1725
                None,
1726
                Some(0),
1727
                Some(-1),
1728
                Some(-2),
1729
                Some(3),
1730
                None,
1731
                Some(5),
1732
                None,
1733
            ])
1734
            .into_data();
1735
            // Construct offsets
1736
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
1737
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1738
            // Construct a list array from the above two
1739
            let list_data_type =
1740
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1741
            let list_data = ArrayData::builder(list_data_type.clone())
1742
                .len(4)
1743
                .add_buffer(value_offsets)
1744
                .null_bit_buffer(Some(Buffer::from([0b11111111])))
1745
                .add_child_data(value_data)
1746
                .build()
1747
                .unwrap();
1748
            let list_array = $list_array_type::from(list_data);
1749
1750
            // index returns: [[null], null, [-1,-2,3], [2,null], [0,null,0]]
1751
            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1752
1753
            let a = take(&list_array, &index, None).unwrap();
1754
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1755
1756
            // construct a value array with expected results:
1757
            // [[null], null, [-1,-2,3], [5,null], [0,null,0]]
1758
            let expected_data = Int32Array::from(vec![
1759
                None,
1760
                Some(-1),
1761
                Some(-2),
1762
                Some(3),
1763
                Some(5),
1764
                None,
1765
                Some(0),
1766
                None,
1767
                Some(0),
1768
            ])
1769
            .into_data();
1770
            // construct offsets
1771
            let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
1772
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1773
            // construct list array from the two
1774
            let expected_list_data = ArrayData::builder(list_data_type)
1775
                .len(5)
1776
                // null buffer remains the same as only the indices have nulls
1777
                .nulls(index.nulls().cloned())
1778
                .add_buffer(expected_offsets)
1779
                .add_child_data(expected_data)
1780
                .build()
1781
                .unwrap();
1782
            let expected_list_array = $list_array_type::from(expected_list_data);
1783
1784
            assert_eq!(a, &expected_list_array);
1785
        }};
1786
    }
1787
1788
    macro_rules! test_take_list_with_nulls {
1789
        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1790
            // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]]
1791
            let value_data = Int32Array::from(vec![
1792
                Some(0),
1793
                None,
1794
                Some(0),
1795
                Some(-1),
1796
                Some(-2),
1797
                Some(3),
1798
                Some(5),
1799
                None,
1800
            ])
1801
            .into_data();
1802
            // Construct offsets
1803
            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1804
            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1805
            // Construct a list array from the above two
1806
            let list_data_type =
1807
                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1808
            let list_data = ArrayData::builder(list_data_type.clone())
1809
                .len(4)
1810
                .add_buffer(value_offsets)
1811
                .null_bit_buffer(Some(Buffer::from([0b11111011])))
1812
                .add_child_data(value_data)
1813
                .build()
1814
                .unwrap();
1815
            let list_array = $list_array_type::from(list_data);
1816
1817
            // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]]
1818
            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1819
1820
            let a = take(&list_array, &index, None).unwrap();
1821
            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1822
1823
            // construct a value array with expected results:
1824
            // [null, null, [-1,-2,3], [5,null], [0,null,0]]
1825
            let expected_data = Int32Array::from(vec![
1826
                Some(-1),
1827
                Some(-2),
1828
                Some(3),
1829
                Some(5),
1830
                None,
1831
                Some(0),
1832
                None,
1833
                Some(0),
1834
            ])
1835
            .into_data();
1836
            // construct offsets
1837
            let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
1838
            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1839
            // construct list array from the two
1840
            let mut null_bits: [u8; 1] = [0; 1];
1841
            bit_util::set_bit(&mut null_bits, 2);
1842
            bit_util::set_bit(&mut null_bits, 3);
1843
            bit_util::set_bit(&mut null_bits, 4);
1844
            let expected_list_data = ArrayData::builder(list_data_type)
1845
                .len(5)
1846
                // null buffer must be recalculated as both values and indices have nulls
1847
                .null_bit_buffer(Some(Buffer::from(null_bits)))
1848
                .add_buffer(expected_offsets)
1849
                .add_child_data(expected_data)
1850
                .build()
1851
                .unwrap();
1852
            let expected_list_array = $list_array_type::from(expected_list_data);
1853
1854
            assert_eq!(a, &expected_list_array);
1855
        }};
1856
    }
1857
1858
8
    fn test_take_list_view_generic<OffsetType: OffsetSizeTrait, ValuesType: ArrowPrimitiveType, F>(
1859
8
        values: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1860
8
        take_indices: Vec<Option<usize>>,
1861
8
        expected: Vec<Option<Vec<Option<ValuesType::Native>>>>,
1862
8
        mapper: F,
1863
8
    ) where
1864
8
        F: Fn(GenericListViewArray<OffsetType>) -> GenericListViewArray<OffsetType>,
1865
    {
1866
8
        let mut list_view_array =
1867
8
            GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1868
1869
38
        for 
value30
in values {
1870
30
            list_view_array.append_option(value);
1871
30
        }
1872
8
        let list_view_array = list_view_array.finish();
1873
8
        let list_view_array = mapper(list_view_array);
1874
1875
8
        let mut indices = UInt64Builder::new();
1876
40
        for 
idx32
in take_indices {
1877
32
            indices.append_option(idx.map(|i| 
i22
.
to_u6422
().
unwrap22
()));
1878
        }
1879
8
        let indices = indices.finish();
1880
1881
8
        let taken = take(&list_view_array, &indices, None)
1882
8
            .unwrap()
1883
8
            .as_list_view()
1884
8
            .clone();
1885
1886
8
        let mut expected_array =
1887
8
            GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
1888
40
        for 
value32
in expected {
1889
32
            expected_array.append_option(value);
1890
32
        }
1891
8
        let expected_array = expected_array.finish();
1892
1893
8
        assert_eq!(taken, expected_array);
1894
8
    }
1895
1896
    macro_rules! list_view_test_case {
1897
        (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{
1898
            test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, |x| x);
1899
            test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, |x| x);
1900
        }};
1901
        (values: $values:expr, transform: $fn:expr, indices: $indices:expr, expected: $expected: expr) => {{
1902
            test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, $fn);
1903
            test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, $fn);
1904
        }};
1905
    }
1906
1907
3
    fn do_take_fixed_size_list_test<T>(
1908
3
        length: <Int32Type as ArrowPrimitiveType>::Native,
1909
3
        input_data: Vec<Option<Vec<Option<T::Native>>>>,
1910
3
        indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
1911
3
        expected_data: Vec<Option<Vec<Option<T::Native>>>>,
1912
3
    ) where
1913
3
        T: ArrowPrimitiveType,
1914
3
        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1915
    {
1916
3
        let indices = UInt32Array::from(indices);
1917
1918
3
        let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
1919
1920
3
        let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
1921
1922
3
        let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
1923
1924
3
        assert_eq!(&output, &expected)
1925
3
    }
1926
1927
    #[test]
1928
1
    fn test_take_list() {
1929
1
        test_take_list!(i32, List, ListArray);
1930
1
    }
1931
1932
    #[test]
1933
1
    fn test_take_large_list() {
1934
1
        test_take_list!(i64, LargeList, LargeListArray);
1935
1
    }
1936
1937
    #[test]
1938
1
    fn test_take_list_with_value_nulls() {
1939
1
        test_take_list_with_value_nulls!(i32, List, ListArray);
1940
1
    }
1941
1942
    #[test]
1943
1
    fn test_take_large_list_with_value_nulls() {
1944
1
        test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
1945
1
    }
1946
1947
    #[test]
1948
1
    fn test_test_take_list_with_nulls() {
1949
1
        test_take_list_with_nulls!(i32, List, ListArray);
1950
1
    }
1951
1952
    #[test]
1953
1
    fn test_test_take_large_list_with_nulls() {
1954
1
        test_take_list_with_nulls!(i64, LargeList, LargeListArray);
1955
1
    }
1956
1957
    #[test]
1958
1
    fn test_test_take_list_view_reversed() {
1959
        // Take reversed indices
1960
1
        list_view_test_case! {
1961
1
            values: vec![
1962
1
                Some(vec![Some(1), None, Some(3)]),
1963
1
                None,
1964
1
                Some(vec![Some(7), Some(8), None]),
1965
            ],
1966
1
            indices: vec![Some(2), Some(1), Some(0)],
1967
1
            expected: vec![
1968
1
                Some(vec![Some(7), Some(8), None]),
1969
1
                None,
1970
1
                Some(vec![Some(1), None, Some(3)]),
1971
            ]
1972
        }
1973
1
    }
1974
1975
    #[test]
1976
1
    fn test_take_list_view_null_indices() {
1977
        // Take with null indices
1978
1
        list_view_test_case! {
1979
1
            values: vec![
1980
1
                Some(vec![Some(1), None, Some(3)]),
1981
1
                None,
1982
1
                Some(vec![Some(7), Some(8), None]),
1983
            ],
1984
1
            indices: vec![None, Some(0), None],
1985
1
            expected: vec![None, Some(vec![Some(1), None, Some(3)]), None]
1986
        }
1987
1
    }
1988
1989
    #[test]
1990
1
    fn test_take_list_view_null_values() {
1991
        // Take at null values
1992
1
        list_view_test_case! {
1993
1
            values: vec![
1994
1
                Some(vec![Some(1), None, Some(3)]),
1995
1
                None,
1996
1
                Some(vec![Some(7), Some(8), None]),
1997
            ],
1998
1
            indices: vec![Some(1), Some(1), Some(1), None, None],
1999
1
            expected: vec![None; 5]
2000
        }
2001
1
    }
2002
2003
    #[test]
2004
1
    fn test_take_list_view_sliced() {
2005
        // Take null indices/values, with slicing.
2006
1
        list_view_test_case! {
2007
1
            values: vec![
2008
1
                Some(vec![Some(1)]),
2009
1
                None,
2010
1
                None,
2011
1
                Some(vec![Some(2), Some(3)]),
2012
1
                Some(vec![Some(4), Some(5)]),
2013
1
                None,
2014
            ],
2015
2
            transform: |l| l.slice(2, 4),
2016
1
            indices: vec![Some(0), Some(3), None, Some(1), Some(2)],
2017
1
            expected: vec![
2018
1
                None, None, None, Some(vec![Some(2), Some(3)]), Some(vec![Some(4), Some(5)])
2019
            ]
2020
        }
2021
1
    }
2022
2023
    #[test]
2024
1
    fn test_take_fixed_size_list() {
2025
1
        do_take_fixed_size_list_test::<Int32Type>(
2026
            3,
2027
1
            vec![
2028
1
                Some(vec![None, Some(1), Some(2)]),
2029
1
                Some(vec![Some(3), Some(4), None]),
2030
1
                Some(vec![Some(6), Some(7), Some(8)]),
2031
            ],
2032
1
            vec![2, 1, 0],
2033
1
            vec![
2034
1
                Some(vec![Some(6), Some(7), Some(8)]),
2035
1
                Some(vec![Some(3), Some(4), None]),
2036
1
                Some(vec![None, Some(1), Some(2)]),
2037
            ],
2038
        );
2039
2040
1
        do_take_fixed_size_list_test::<UInt8Type>(
2041
            1,
2042
1
            vec![
2043
1
                Some(vec![Some(1)]),
2044
1
                Some(vec![Some(2)]),
2045
1
                Some(vec![Some(3)]),
2046
1
                Some(vec![Some(4)]),
2047
1
                Some(vec![Some(5)]),
2048
1
                Some(vec![Some(6)]),
2049
1
                Some(vec![Some(7)]),
2050
1
                Some(vec![Some(8)]),
2051
            ],
2052
1
            vec![2, 7, 0],
2053
1
            vec![
2054
1
                Some(vec![Some(3)]),
2055
1
                Some(vec![Some(8)]),
2056
1
                Some(vec![Some(1)]),
2057
            ],
2058
        );
2059
2060
1
        do_take_fixed_size_list_test::<UInt64Type>(
2061
            3,
2062
1
            vec![
2063
1
                Some(vec![Some(10), Some(11), Some(12)]),
2064
1
                Some(vec![Some(13), Some(14), Some(15)]),
2065
1
                None,
2066
1
                Some(vec![Some(16), Some(17), Some(18)]),
2067
            ],
2068
1
            vec![3, 2, 1, 2, 0],
2069
1
            vec![
2070
1
                Some(vec![Some(16), Some(17), Some(18)]),
2071
1
                None,
2072
1
                Some(vec![Some(13), Some(14), Some(15)]),
2073
1
                None,
2074
1
                Some(vec![Some(10), Some(11), Some(12)]),
2075
            ],
2076
        );
2077
1
    }
2078
2079
    #[test]
2080
    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2081
1
    fn test_take_list_out_of_bounds() {
2082
        // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
2083
1
        let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
2084
        // Construct offsets
2085
1
        let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
2086
        // Construct a list array from the above two
2087
1
        let list_data_type =
2088
1
            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
2089
1
        let list_data = ArrayData::builder(list_data_type)
2090
1
            .len(3)
2091
1
            .add_buffer(value_offsets)
2092
1
            .add_child_data(value_data)
2093
1
            .build()
2094
1
            .unwrap();
2095
1
        let list_array = ListArray::from(list_data);
2096
2097
1
        let index = UInt32Array::from(vec![1000]);
2098
2099
        // A panic is expected here since we have not supplied the check_bounds
2100
        // option.
2101
1
        take(&list_array, &index, None).unwrap();
2102
1
    }
2103
2104
    #[test]
2105
1
    fn test_take_map() {
2106
1
        let values = Int32Array::from(vec![1, 2, 3, 4]);
2107
1
        let array =
2108
1
            MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
2109
1
                .unwrap();
2110
2111
1
        let index = UInt32Array::from(vec![0]);
2112
2113
1
        let result = take(&array, &index, None).unwrap();
2114
1
        let expected: ArrayRef = Arc::new(
2115
1
            MapArray::new_from_strings(
2116
1
                vec!["a", "b", "c"].into_iter(),
2117
1
                &values.slice(0, 3),
2118
1
                &[0, 3],
2119
1
            )
2120
1
            .unwrap(),
2121
1
        );
2122
1
        assert_eq!(&expected, &result);
2123
1
    }
2124
2125
    #[test]
2126
1
    fn test_take_struct() {
2127
1
        let array = create_test_struct(vec![
2128
1
            Some((Some(true), Some(42))),
2129
1
            Some((Some(false), Some(28))),
2130
1
            Some((Some(false), Some(19))),
2131
1
            Some((Some(true), Some(31))),
2132
1
            None,
2133
        ]);
2134
2135
1
        let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
2136
1
        let actual = take(&array, &index, None).unwrap();
2137
1
        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2138
1
        assert_eq!(index.len(), actual.len());
2139
1
        assert_eq!(1, actual.null_count());
2140
2141
1
        let expected = create_test_struct(vec![
2142
1
            Some((Some(true), Some(42))),
2143
1
            Some((Some(true), Some(31))),
2144
1
            Some((Some(false), Some(28))),
2145
1
            Some((Some(true), Some(42))),
2146
1
            Some((Some(false), Some(19))),
2147
1
            None,
2148
        ]);
2149
2150
1
        assert_eq!(&expected, actual);
2151
2152
1
        let nulls = NullBuffer::from(&[false, true, false, true, false, true]);
2153
1
        let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls));
2154
1
        let index = UInt32Array::from(vec![0, 2, 1, 4]);
2155
1
        let actual = take(&empty_struct_arr, &index, None).unwrap();
2156
2157
1
        let expected_nulls = NullBuffer::from(&[false, false, true, false]);
2158
1
        let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls));
2159
1
        assert_eq!(&expected_struct_arr, actual.as_struct());
2160
1
    }
2161
2162
    #[test]
2163
1
    fn test_take_struct_with_null_indices() {
2164
1
        let array = create_test_struct(vec![
2165
1
            Some((Some(true), Some(42))),
2166
1
            Some((Some(false), Some(28))),
2167
1
            Some((Some(false), Some(19))),
2168
1
            Some((Some(true), Some(31))),
2169
1
            None,
2170
        ]);
2171
2172
1
        let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
2173
1
        let actual = take(&array, &index, None).unwrap();
2174
1
        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2175
1
        assert_eq!(index.len(), actual.len());
2176
1
        assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array
2177
2178
1
        let expected = create_test_struct(vec![
2179
1
            None,
2180
1
            Some((Some(true), Some(31))),
2181
1
            Some((Some(false), Some(28))),
2182
1
            None,
2183
1
            Some((Some(true), Some(42))),
2184
1
            None,
2185
        ]);
2186
2187
1
        assert_eq!(&expected, actual);
2188
1
    }
2189
2190
    #[test]
2191
1
    fn test_take_out_of_bounds() {
2192
1
        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2193
1
        let take_opt = TakeOptions { check_bounds: true };
2194
2195
        // int64
2196
1
        let result = test_take_primitive_arrays::<Int64Type>(
2197
1
            vec![Some(0), None, Some(2), Some(3), None],
2198
1
            &index,
2199
1
            Some(take_opt),
2200
1
            vec![None],
2201
        );
2202
1
        assert!(result.is_err());
2203
1
    }
2204
2205
    #[test]
2206
    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2207
1
    fn test_take_out_of_bounds_panic() {
2208
1
        let index = UInt32Array::from(vec![Some(1000)]);
2209
2210
1
        test_take_primitive_arrays::<Int64Type>(
2211
1
            vec![Some(0), Some(1), Some(2), Some(3)],
2212
1
            &index,
2213
1
            None,
2214
1
            vec![None],
2215
        )
2216
1
        .unwrap();
2217
1
    }
2218
2219
    #[test]
2220
1
    fn test_null_array_smaller_than_indices() {
2221
1
        let values = NullArray::new(2);
2222
1
        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2223
2224
1
        let result = take(&values, &indices, None).unwrap();
2225
1
        let expected: ArrayRef = Arc::new(NullArray::new(3));
2226
1
        assert_eq!(&result, &expected);
2227
1
    }
2228
2229
    #[test]
2230
1
    fn test_null_array_larger_than_indices() {
2231
1
        let values = NullArray::new(5);
2232
1
        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2233
2234
1
        let result = take(&values, &indices, None).unwrap();
2235
1
        let expected: ArrayRef = Arc::new(NullArray::new(3));
2236
1
        assert_eq!(&result, &expected);
2237
1
    }
2238
2239
    #[test]
2240
1
    fn test_null_array_indices_out_of_bounds() {
2241
1
        let values = NullArray::new(5);
2242
1
        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2243
2244
1
        let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2245
1
        assert_eq!(
2246
1
            result.unwrap_err().to_string(),
2247
            "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2248
        );
2249
1
    }
2250
2251
    #[test]
2252
1
    fn test_take_dict() {
2253
1
        let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2254
2255
1
        dict_builder.append("foo").unwrap();
2256
1
        dict_builder.append("bar").unwrap();
2257
1
        dict_builder.append("").unwrap();
2258
1
        dict_builder.append_null();
2259
1
        dict_builder.append("foo").unwrap();
2260
1
        dict_builder.append("bar").unwrap();
2261
1
        dict_builder.append("bar").unwrap();
2262
1
        dict_builder.append("foo").unwrap();
2263
2264
1
        let array = dict_builder.finish();
2265
1
        let dict_values = array.values().clone();
2266
1
        let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2267
2268
1
        let indices = UInt32Array::from(vec![
2269
1
            Some(0), // first "foo"
2270
1
            Some(7), // last "foo"
2271
1
            None,    // null index should return null
2272
1
            Some(5), // second "bar"
2273
1
            Some(6), // another "bar"
2274
1
            Some(2), // empty string
2275
1
            Some(3), // input is null at this index
2276
        ]);
2277
2278
1
        let result = take(&array, &indices, None).unwrap();
2279
1
        let result = result
2280
1
            .as_any()
2281
1
            .downcast_ref::<DictionaryArray<Int16Type>>()
2282
1
            .unwrap();
2283
2284
1
        let result_values: StringArray = result.values().to_data().into();
2285
2286
        // dictionary values should stay the same
2287
1
        let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2288
1
        assert_eq!(&expected_values, dict_values);
2289
1
        assert_eq!(&expected_values, &result_values);
2290
2291
1
        let expected_keys = Int16Array::from(vec![
2292
1
            Some(0),
2293
1
            Some(0),
2294
1
            None,
2295
1
            Some(1),
2296
1
            Some(1),
2297
1
            Some(2),
2298
1
            None,
2299
        ]);
2300
1
        assert_eq!(result.keys(), &expected_keys);
2301
1
    }
2302
2303
2
    fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2304
2
    where
2305
2
        S: OffsetSizeTrait + 'static,
2306
2
        T: ArrowPrimitiveType,
2307
2
        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2308
    {
2309
2
        GenericListArray::from_iter_primitive::<T, _, _>(
2310
2
            data.iter()
2311
20
                .
map2
(|x|
x6
.
as_ref6
().
map6
(|x|
x.iter()6
.
map6
(|x| Some(*x)))),
2312
        )
2313
2
    }
2314
2315
    #[test]
    fn test_take_value_index_from_list() {
        // Three lists of lengths 2, 3 and 5 laid over the child values 0..=9.
        let list = build_generic_list::<i32, Int32Type>(vec![
            Some(vec![0, 1]),
            Some(vec![2, 3, 4]),
            Some(vec![5, 6, 7, 8, 9]),
        ]);
        // Select the last list first, then the first one.
        let take_idx = UInt32Array::from(vec![2, 0]);

        let (child_indices, new_offsets, validity) =
            take_value_indices_from_list(&list, &take_idx).unwrap();

        // Child positions of list 2 (5..=9) followed by those of list 0 (0..=1).
        let expected_child = Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]);
        assert_eq!(child_indices, expected_child);
        assert_eq!(new_offsets, vec![0, 5, 7]);
        // Neither selected list is null.
        assert_eq!(validity.as_slice(), &[0b11111111]);
    }
2330
2331
    #[test]
    fn test_take_value_index_from_large_list() {
        // Same layout as the i32 case, but built with 64-bit (large list) offsets.
        let large_list = build_generic_list::<i64, Int32Type>(vec![
            Some(vec![0, 1]),
            Some(vec![2, 3, 4]),
            Some(vec![5, 6, 7, 8, 9]),
        ]);
        let take_idx = UInt32Array::from(vec![2, 0]);

        // Child indices and offsets come back as Int64 for the large list.
        let (child_indices, new_offsets, validity) =
            take_value_indices_from_list::<_, Int64Type>(&large_list, &take_idx).unwrap();

        let expected_child = Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]);
        assert_eq!(child_indices, expected_child);
        assert_eq!(new_offsets, vec![0, 5, 7]);
        assert_eq!(validity.as_slice(), &[0b11111111]);
    }
2347
2348
    #[test]
    fn test_take_runs() {
        let logical_values: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];

        // Run-end encode the values: Int32 run ends over Int32 values.
        let mut encoder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
        encoder.extend(logical_values.into_iter().map(Some));
        let runs = encoder.finish();

        let positions: PrimitiveArray<Int32Type> =
            vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
        let taken = take_run(&runs, &positions).unwrap();

        // The selected logical values are [2, 2, 2, 2, 2, 1, 1]. Consecutive
        // positions that fall in the same input run (2 and 3; 4 and 6) share one
        // output run, so 7 logical slots are covered by 5 runs.
        assert_eq!(taken.len(), 7);
        assert_eq!(taken.run_ends().len(), 7);
        assert_eq!(taken.run_ends().values(), &[1_i32, 3, 4, 5, 7]);

        let run_values = taken.values().as_primitive::<Int32Type>();
        assert_eq!(run_values.values(), &[2, 2, 2, 2, 1]);
    }
2368
2369
    #[test]
    fn test_take_value_index_from_fixed_list() {
        // Four lists of width 3; slot 2 is null but still owns child positions 6..=8.
        let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
            vec![
                Some(vec![Some(1), Some(2), None]),
                Some(vec![Some(4), None, Some(6)]),
                None,
                Some(vec![None, Some(8), Some(9)]),
            ],
            3,
        );

        // (list indices to take, expected child positions)
        let cases: [(Vec<u32>, Vec<u32>); 2] = [
            (vec![2, 1, 0], vec![6, 7, 8, 3, 4, 5, 0, 1, 2]),
            (
                vec![3, 2, 1, 2, 0],
                vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2],
            ),
        ];

        for (picks, expected) in cases {
            let indices = UInt32Array::from(picks);
            let child_positions =
                take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
            assert_eq!(child_positions, UInt32Array::from(expected));
        }
    }
2394
2395
    #[test]
    fn test_take_null_indices() {
        // The index value 400 is out of bounds for `values`, but those slots are
        // masked by the null buffer and must simply produce nulls.
        let validity = NullBuffer::from(vec![true, true, false, false]);
        let indices = Int32Array::new(vec![1, 2, 400, 400].into(), Some(validity));
        let values = Int32Array::from(vec![1, 23, 4, 5]);

        let taken = take(&values, &indices, None).unwrap();
        let collected: Vec<Option<i32>> = taken.as_primitive::<Int32Type>().iter().collect();
        assert_eq!(&collected, &[Some(23), Some(4), None, None])
    }
2410
2411
    #[test]
    fn test_take_fixed_size_list_null_indices() {
        let child = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
        let item_field = Arc::new(Field::new_list_field(child.data_type().clone(), true));
        // Two lists of width 2: [0, 1] and [2, 3].
        let lists = FixedSizeListArray::try_new(item_field, 2, child, None).unwrap();

        // Take the first list, then a null index.
        let indices = Int32Array::from_iter([Some(0), None]);
        let taken = take(&lists, &indices, None).unwrap();

        // The null index contributes a full-width slot of null child values.
        let child_values: Vec<Option<i32>> = taken
            .as_fixed_size_list()
            .values()
            .as_primitive::<Int32Type>()
            .iter()
            .collect();
        assert_eq!(child_values, &[Some(0), Some(1), None, None])
    }
2427
2428
    #[test]
    fn test_take_bytes_null_indices() {
        let strings = StringArray::from(vec![Some("foo"), None]);
        // 400 is out of range for `strings`, but those index slots are null.
        let indices = Int32Array::new(
            vec![0, 1, 400, 400].into(),
            Some(NullBuffer::from_iter(vec![true, true, false, false])),
        );

        let taken = take(&strings, &indices, None).unwrap();
        let got: Vec<Option<&str>> = taken.as_string::<i32>().iter().collect();
        assert_eq!(&got, &[Some("foo"), None, None, None])
    }
2439
2440
    #[test]
    fn test_take_union_sparse() {
        // Sparse union: each child has the same length (5) as the union itself.
        let structs = create_test_struct(vec![
            Some((Some(true), Some(42))),
            Some((Some(false), Some(28))),
            Some((Some(false), Some(19))),
            Some((Some(true), Some(31))),
            None,
        ]);
        let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);

        let struct_field = Arc::new(Field::new("f1", structs.data_type().clone(), true));
        let string_field = Arc::new(Field::new("f2", strings.data_type().clone(), true));
        let union_fields = [(0, struct_field), (1, string_field)].into_iter().collect();

        // Every slot selects type id 1, i.e. the string child.
        let type_ids: ScalarBuffer<i8> = [1; 5].into_iter().collect();
        let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
        let union = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();

        let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
        let taken = take(&union, &index, None).unwrap();
        let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();
        let taken_strings = taken
            .child(1)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        let actual: Vec<Option<&str>> = taken_strings.iter().collect();
        let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
        assert_eq!(expected, actual);
    }
2478
2479
    #[test]
    fn test_take_union_dense() {
        // Dense union of 7 slots over an integer child and a string child
        // (4 values each); `offsets` points each slot into its child.
        let type_ids = ScalarBuffer::<i8>::from(vec![0, 1, 1, 0, 0, 1, 0]);
        let offsets = ScalarBuffer::<i32>::from(vec![0, 0, 1, 1, 2, 2, 3]);
        let ints = UInt32Array::from(vec![10, 20, 30, 40]);
        let strings = StringArray::from(vec![Some("a"), None, Some("c"), Some("d")]);

        let union_fields = [
            (0, Arc::new(Field::new("f1", ints.data_type().clone(), true))),
            (1, Arc::new(Field::new("f2", strings.data_type().clone(), true))),
        ]
        .into_iter()
        .collect();
        let union = UnionArray::try_new(
            union_fields,
            type_ids,
            Some(offsets),
            vec![Arc::new(ints), Arc::new(strings)],
        )
        .unwrap();

        let taken = take(&union, &UInt32Array::from(vec![0, 3, 1, 0, 2, 4]), None).unwrap();
        let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();

        // Each child keeps only the selected values, in selection order, and the
        // offsets are rebuilt to index into those compacted children.
        assert_eq!(
            taken.offsets(),
            Some(&ScalarBuffer::from(vec![0, 1, 0, 2, 1, 3]))
        );
        assert_eq!(taken.type_ids(), &ScalarBuffer::from(vec![0, 0, 1, 0, 1, 0]));
        assert_eq!(
            UInt32Array::from(taken.child(0).to_data()),
            UInt32Array::from(vec![10, 20, 10, 30])
        );
        assert_eq!(
            StringArray::from(taken.child(1).to_data()),
            StringArray::from(vec![Some("a"), None])
        );
    }
2535
2536
    #[test]
    fn test_take_union_dense_using_builder() {
        // Source union: a=1, b=3.0, a=4, a=5, b=2.0
        let mut source = UnionBuilder::new_dense();
        source.append::<Int32Type>("a", 1).unwrap();
        source.append::<Float64Type>("b", 3.0).unwrap();
        source.append::<Int32Type>("a", 4).unwrap();
        source.append::<Int32Type>("a", 5).unwrap();
        source.append::<Float64Type>("b", 2.0).unwrap();
        let union = source.build().unwrap();

        // Result of taking positions [2, 0, 1, 2], built directly for comparison.
        let mut expected = UnionBuilder::new_dense();
        expected.append::<Int32Type>("a", 4).unwrap();
        expected.append::<Int32Type>("a", 1).unwrap();
        expected.append::<Float64Type>("b", 3.0).unwrap();
        expected.append::<Int32Type>("a", 4).unwrap();
        let expected = expected.build().unwrap();

        let indices = UInt32Array::from(vec![2, 0, 1, 2]);
        let actual = take(&union, &indices, None).unwrap();
        assert_eq!(expected.to_data(), actual.to_data());
    }
2564
2565
    #[test]
    fn test_take_union_dense_all_match_issue_6206() {
        // Regression test for arrow-rs issue #6206: taking from a dense union in
        // which every slot has the same type id.
        let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
        let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));

        let array = UnionArray::try_new(
            fields,
            ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
            Some(ScalarBuffer::from_iter(0_i32..5)),
            vec![ints],
        )
        .unwrap();

        let indices = Int64Array::from(vec![0, 2, 4]);
        let taken = take(&array, &indices, None).unwrap();
        assert_eq!(taken.len(), 3);

        // Checking only the length would not catch wrong type ids or values, so
        // also verify each taken slot still selects child 0 and resolves to the
        // value at the requested position.
        let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();
        assert_eq!(taken.type_ids(), &ScalarBuffer::from(vec![0_i8, 0, 0]));
        let values: Vec<i64> = (0..taken.len())
            .map(|i| taken.value(i).as_primitive::<Int64Type>().value(0))
            .collect();
        assert_eq!(values, vec![1, 3, 5]);
    }
2582
2583
    #[test]
    fn test_take_bytes_offset_overflow() {
        // Repeatedly taking the same string must fail with `OffsetOverflowError`
        // once the total byte length no longer fits in the array's i32 offsets.
        //
        // A large value with few indices keeps the index array small (a few
        // hundred KiB rather than hundreds of MiB) while the taken bytes still
        // sum to exactly `i32::MAX + 1`.
        const VALUE_LEN: usize = 1 << 16;
        let repeats = i32::MAX as usize / VALUE_LEN + 1;
        let indices = Int32Array::from(vec![0; repeats]);
        let text = "a".repeat(VALUE_LEN);
        let values = StringArray::from(vec![Some(text)]);
        assert!(matches!(
            take(&values, &indices, None),
            Err(ArrowError::OffsetOverflowError(_))
        ));
    }
2593
}