Coverage Report

Created: 2025-11-17 14:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-select/src/merge.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! [`merge`] and [`merge_n`]: Combine values from two or more arrays
19
20
use crate::filter::{SlicesIterator, prep_null_mask_filter};
21
use crate::zip::zip;
22
use arrow_array::{Array, ArrayRef, BooleanArray, Datum, make_array, new_empty_array};
23
use arrow_data::ArrayData;
24
use arrow_data::transform::MutableArrayData;
25
use arrow_schema::ArrowError;
26
27
/// An index for the [merge_n] function.
///
/// This trait allows the indices argument for [merge_n] to be stored using a more
/// compact representation than `usize` when the input arrays are small.
/// If the number of input arrays is less than 256 for instance, the indices can be stored as `u8`.
///
/// Implementations must ensure that all values which return `None` from [MergeIndex::index] are
/// considered equal by the [PartialEq] and [Eq] implementations.
///
/// [merge_n] groups consecutive equal indices (compared with `==`) into runs and resolves the
/// whole run using the [MergeIndex::index] of its first element. Implementations must therefore
/// also ensure that `a == b` implies `a.index() == b.index()`; otherwise values from the wrong
/// array (or values instead of nulls) may be copied into the result.
pub trait MergeIndex: PartialEq + Eq + Copy {
    /// Returns the index value as an `Option<usize>`.
    ///
    /// `Some(i)` selects the array at position `i` of the `values` argument of [merge_n].
    /// `None` values returned by this function indicate holes in the index array and will result
    /// in null values in the array created by [merge_n].
    fn index(&self) -> Option<usize>;
}
42
43
impl MergeIndex for usize {
44
0
    fn index(&self) -> Option<usize> {
45
0
        Some(*self)
46
0
    }
47
}
48
49
impl MergeIndex for Option<usize> {
50
0
    fn index(&self) -> Option<usize> {
51
0
        *self
52
0
    }
53
}
54
55
/// Merges elements by index from a list of [`Array`], creating a new [`Array`] from
56
/// those values.
57
///
58
/// Each element in `indices` is the index of an array in `values`. The `indices` array is processed
59
/// sequentially. The first occurrence of index value `n` will be mapped to the first
60
/// value of the array at index `n`. The second occurrence to the second value, and so on.
61
/// An index value where `MergeIndex::index` returns `None` is interpreted as a null value.
62
///
63
/// # Implementation notes
64
///
65
/// This algorithm is similar in nature to both [zip] and
66
/// [interleave](crate::interleave::interleave), but there are some important differences.
67
///
68
/// In contrast to [zip], this function supports multiple input arrays. Instead of
69
/// a boolean selection vector, an index array is to take values from the input arrays, and a special
70
/// marker values can be used to indicate null values.
71
///
72
/// In contrast to [interleave](crate::interleave::interleave), this function does not use pairs of
73
/// indices. The values in `indices` serve the same purpose as the first value in the pairs passed
74
/// to `interleave`.
75
/// The index in the array is implicit and is derived from the number of times a particular array
76
/// index occurs.
77
/// The more constrained indexing mechanism used by this algorithm makes it easier to copy values
78
/// in contiguous slices. In the example below, the two subsequent elements from array `2` can be
79
/// copied in a single operation from the source array instead of copying them one by one.
80
/// Long spans of null values are also especially cheap because they do not need to be represented
81
/// in an input array.
82
///
83
/// # Panics
84
///
85
/// This function does not check that the number of occurrences of any particular array index matches
86
/// the length of the corresponding input array. If an array contains more values than required, the
87
/// spurious values will be ignored. If an array contains fewer values than necessary, this function
88
/// will panic.
89
///
90
/// # Example
91
///
92
/// ```text
93
/// ┌───────────┐  ┌─────────┐                             ┌─────────┐
94
/// │┌─────────┐│  │   None  │                             │   NULL  │
95
/// ││    A    ││  ├─────────┤                             ├─────────┤
96
/// │└─────────┘│  │    1    │                             │    B    │
97
/// │┌─────────┐│  ├─────────┤                             ├─────────┤
98
/// ││    B    ││  │    0    │    merge(values, indices)   │    A    │
99
/// │└─────────┘│  ├─────────┤  ─────────────────────────▶ ├─────────┤
100
/// │┌─────────┐│  │   None  │                             │   NULL  │
101
/// ││    C    ││  ├─────────┤                             ├─────────┤
102
/// │├─────────┤│  │    2    │                             │    C    │
103
/// ││    D    ││  ├─────────┤                             ├─────────┤
104
/// │└─────────┘│  │    2    │                             │    D    │
105
/// └───────────┘  └─────────┘                             └─────────┘
106
///    values        indices                                  result
107
///
108
/// ```
109
3
pub fn merge_n(values: &[&dyn Array], indices: &[impl MergeIndex]) -> Result<ArrayRef, ArrowError> {
110
3
    if values.is_empty() {
111
1
        return Err(ArrowError::InvalidArgumentError(
112
1
            "merge_n requires at least one value array".to_string(),
113
1
        ));
114
2
    }
115
116
2
    let data_type = values[0].data_type();
117
118
4
    for array in 
values2
.
iter2
().
skip2
(1) {
119
4
        if array.data_type() != data_type {
120
0
            return Err(ArrowError::InvalidArgumentError(format!(
121
0
                "It is not possible to merge arrays of different data types ({} and {})",
122
0
                data_type,
123
0
                array.data_type()
124
0
            )));
125
4
        }
126
    }
127
128
2
    if indices.is_empty() {
129
1
        return Ok(new_empty_array(data_type));
130
1
    }
131
132
    #[cfg(debug_assertions)]
133
9
    for 
ix8
in indices {
134
8
        if let Some(
index6
) = ix.index() {
135
6
            assert!(
136
6
                index < values.len(),
137
0
                "Index out of bounds: {} >= {}",
138
                index,
139
0
                values.len()
140
            );
141
2
        }
142
    }
143
144
3
    let 
data1
:
Vec<ArrayData>1
=
values1
.
iter1
().
map1
(|a| a.to_data()).
collect1
();
145
1
    let data_refs = data.iter().collect();
146
147
1
    let mut mutable = MutableArrayData::new(data_refs, true, indices.len());
148
149
    // This loop extends the mutable array by taking slices from the partial results.
150
    //
151
    // take_offsets keeps track of how many values have been taken from each array.
152
1
    let mut take_offsets = vec![0; values.len() + 1];
153
1
    let mut start_row_ix = 0;
154
    loop {
155
6
        let array_ix = indices[start_row_ix];
156
157
        // Determine the length of the slice to take.
158
6
        let mut end_row_ix = start_row_ix + 1;
159
8
        while end_row_ix < indices.len() && 
indices[end_row_ix] == array_ix7
{
160
2
            end_row_ix += 1;
161
2
        }
162
6
        let slice_length = end_row_ix - start_row_ix;
163
164
        // Extend mutable with either nulls or with values from the array.
165
6
        match array_ix.index() {
166
2
            None => mutable.extend_nulls(slice_length),
167
4
            Some(index) => {
168
4
                let start_offset = take_offsets[index];
169
4
                let end_offset = start_offset + slice_length;
170
4
                mutable.extend(index, start_offset, end_offset);
171
4
                take_offsets[index] = end_offset;
172
4
            }
173
        }
174
175
6
        if end_row_ix == indices.len() {
176
1
            break;
177
5
        } else {
178
5
            // Set the start_row_ix for the next slice.
179
5
            start_row_ix = end_row_ix;
180
5
        }
181
    }
182
183
1
    Ok(make_array(mutable.freeze()))
184
3
}
185
186
/// Merges two arrays in the order specified by a boolean mask.
187
///
188
/// This algorithm is a variant of [zip] that does not require the truthy and
189
/// falsy arrays to have the same length.
190
///
191
/// When truthy of falsy are [Scalar](arrow_array::Scalar), the single
192
/// scalar value is repeated whenever the mask array contains true or false respectively.
193
///
194
/// # Example
195
///
196
/// ```text
197
///  truthy
198
/// ┌─────────┐  mask
199
/// │    A    │  ┌─────────┐                             ┌─────────┐
200
/// ├─────────┤  │  true   │                             │    A    │
201
/// │    C    │  ├─────────┤                             ├─────────┤
202
/// ├─────────┤  │  true   │                             │    C    │
203
/// │   NULL  │  ├─────────┤                             ├─────────┤
204
/// ├─────────┤  │  false  │  merge(mask, truthy, falsy) │    B    │
205
/// │    D    │  ├─────────┤  ─────────────────────────▶ ├─────────┤
206
/// └─────────┘  │  true   │                             │   NULL  │
207
///  falsy       ├─────────┤                             ├─────────┤
208
/// ┌─────────┐  │  false  │                             │    E    │
209
/// │    B    │  ├─────────┤                             ├─────────┤
210
/// ├─────────┤  │  true   │                             │    D    │
211
/// │    E    │  └─────────┘                             └─────────┘
212
/// └─────────┘
213
/// ```
214
2
pub fn merge(
215
2
    mask: &BooleanArray,
216
2
    truthy: &dyn Datum,
217
2
    falsy: &dyn Datum,
218
2
) -> Result<ArrayRef, ArrowError> {
219
2
    let (truthy_array, truthy_is_scalar) = truthy.get();
220
2
    let (falsy_array, falsy_is_scalar) = falsy.get();
221
222
2
    if truthy_is_scalar && 
falsy_is_scalar0
{
223
        // When both truthy and falsy are scalars, we can use `zip` since the result is the same
224
        // and zip has optimized code for scalars.
225
0
        return zip(mask, truthy, falsy);
226
2
    }
227
228
2
    if truthy_array.data_type() != falsy_array.data_type() {
229
0
        return Err(ArrowError::InvalidArgumentError(
230
0
            "arguments need to have the same data type".into(),
231
0
        ));
232
2
    }
233
234
2
    if truthy_is_scalar && 
truthy_array.len() != 10
{
235
0
        return Err(ArrowError::InvalidArgumentError(
236
0
            "scalar arrays must have 1 element".into(),
237
0
        ));
238
2
    }
239
2
    if falsy_is_scalar && 
falsy_array.len() != 10
{
240
0
        return Err(ArrowError::InvalidArgumentError(
241
0
            "scalar arrays must have 1 element".into(),
242
0
        ));
243
2
    }
244
245
2
    let falsy = falsy_array.to_data();
246
2
    let truthy = truthy_array.to_data();
247
248
2
    let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, mask.len());
249
250
    // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to
251
    // fill with falsy values
252
253
    // keep track of how much is filled
254
2
    let mut filled = 0;
255
2
    let mut falsy_offset = 0;
256
2
    let mut truthy_offset = 0;
257
258
    // Ensure nulls are treated as false
259
2
    let mask_buffer = match mask.null_count() {
260
2
        0 => mask.values().clone(),
261
0
        _ => prep_null_mask_filter(mask).into_parts().0,
262
    };
263
264
3
    
SlicesIterator::from2
(
&mask_buffer2
).
for_each2
(|(start, end)| {
265
        // the gap needs to be filled with falsy values
266
3
        if start > filled {
267
2
            if falsy_is_scalar {
268
0
                for _ in filled..start {
269
0
                    // Copy the first item from the 'falsy' array into the output buffer.
270
0
                    mutable.extend(1, 0, 1);
271
0
                }
272
2
            } else {
273
2
                let falsy_length = start - filled;
274
2
                let falsy_end = falsy_offset + falsy_length;
275
2
                mutable.extend(1, falsy_offset, falsy_end);
276
2
                falsy_offset = falsy_end;
277
2
            }
278
1
        }
279
        // fill with truthy values
280
3
        if truthy_is_scalar {
281
0
            for _ in start..end {
282
0
                // Copy the first item from the 'truthy' array into the output buffer.
283
0
                mutable.extend(0, 0, 1);
284
0
            }
285
3
        } else {
286
3
            let truthy_length = end - start;
287
3
            let truthy_end = truthy_offset + truthy_length;
288
3
            mutable.extend(0, truthy_offset, truthy_end);
289
3
            truthy_offset = truthy_end;
290
3
        }
291
3
        filled = end;
292
3
    });
293
    // the remaining part is falsy
294
2
    if filled < mask.len() {
295
0
        if falsy_is_scalar {
296
0
            for _ in filled..mask.len() {
297
0
                // Copy the first item from the 'falsy' array into the output buffer.
298
0
                mutable.extend(1, 0, 1);
299
0
            }
300
0
        } else {
301
0
            let falsy_length = mask.len() - filled;
302
0
            let falsy_end = falsy_offset + falsy_length;
303
0
            mutable.extend(1, falsy_offset, falsy_end);
304
0
        }
305
2
    }
306
307
2
    let data = mutable.freeze();
308
2
    Ok(make_array(data))
309
2
}
310
311
#[cfg(test)]
mod tests {
    use crate::merge::{MergeIndex, merge, merge_n};
    use arrow_array::cast::AsArray;
    use arrow_array::{Array, ArrayRef, BooleanArray, Int32Array, StringArray};
    use arrow_schema::ArrowError::InvalidArgumentError;

    /// A `u8`-backed merge index that uses `u8::MAX` as the null marker.
    #[derive(PartialEq, Eq, Copy, Clone)]
    struct CompactMergeIndex {
        index: u8,
    }

    impl MergeIndex for CompactMergeIndex {
        fn index(&self) -> Option<usize> {
            if self.index == u8::MAX {
                None
            } else {
                Some(self.index as usize)
            }
        }
    }

    /// Asserts that `array` is a `Utf8` array holding exactly the values in `expected`.
    fn assert_strings(array: &ArrayRef, expected: &[Option<&str>]) {
        let actual: Vec<Option<&str>> = array.as_string::<i32>().iter().collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_merge() {
        let a1 = StringArray::from(vec![Some("A"), Some("B"), Some("E"), None]);
        let a2 = StringArray::from(vec![Some("C"), Some("D")]);

        let mask = BooleanArray::from(vec![true, false, true, false, true, true]);

        let merged = merge(&mask, &a1, &a2).unwrap();
        let merged = merged.as_string::<i32>();

        assert_eq!(merged.len(), mask.len());
        assert!(merged.is_valid(0));
        assert_eq!(merged.value(0), "A");
        assert!(merged.is_valid(1));
        assert_eq!(merged.value(1), "C");
        assert!(merged.is_valid(2));
        assert_eq!(merged.value(2), "B");
        assert!(merged.is_valid(3));
        assert_eq!(merged.value(3), "D");
        assert!(merged.is_valid(4));
        assert_eq!(merged.value(4), "E");
        assert!(!merged.is_valid(5));
    }

    #[test]
    fn test_merge_empty_mask() {
        let a1 = StringArray::from(vec![Some("A")]);
        let a2 = StringArray::from(vec![Some("B")]);
        let mask: Vec<bool> = vec![];
        let mask = BooleanArray::from(mask);
        let result = merge(&mask, &a1, &a2).unwrap();
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn test_merge_scalar_truthy() {
        let mask = BooleanArray::from(vec![true, false, true, false]);
        let truthy = StringArray::new_scalar("T");
        let falsy = StringArray::from(vec!["A", "B"]);

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_strings(&merged, &[Some("T"), Some("A"), Some("T"), Some("B")]);
    }

    #[test]
    fn test_merge_scalar_falsy_with_null_mask() {
        // Null mask entries select the falsy value, including in the trailing gap.
        let mask = BooleanArray::from(vec![Some(true), None, Some(false), Some(true), None]);
        let truthy = StringArray::from(vec!["A", "B"]);
        let falsy = StringArray::new_scalar("F");

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_strings(
            &merged,
            &[Some("A"), Some("F"), Some("F"), Some("B"), Some("F")],
        );
    }

    #[test]
    fn test_merge_both_scalars() {
        let mask = BooleanArray::from(vec![true, false, true]);
        let truthy = StringArray::new_scalar("T");
        let falsy = StringArray::new_scalar("F");

        let merged = merge(&mask, &truthy, &falsy).unwrap();

        assert_strings(&merged, &[Some("T"), Some("F"), Some("T")]);
    }

    #[test]
    fn test_merge_different_types() {
        let mask = BooleanArray::from(vec![true, false]);
        let truthy = StringArray::from(vec!["A"]);
        let falsy = Int32Array::from(vec![1]);

        let merged = merge(&mask, &truthy, &falsy);

        assert!(matches!(merged, Err(InvalidArgumentError { .. })));
    }

    #[test]
    fn test_merge_n() {
        let a1 = StringArray::from(vec![Some("A")]);
        let a2 = StringArray::from(vec![Some("B"), None, None]);
        let a3 = StringArray::from(vec![Some("C"), Some("D")]);

        let indices = vec![
            CompactMergeIndex { index: u8::MAX },
            CompactMergeIndex { index: 1 },
            CompactMergeIndex { index: 0 },
            CompactMergeIndex { index: u8::MAX },
            CompactMergeIndex { index: 2 },
            CompactMergeIndex { index: 2 },
            CompactMergeIndex { index: 1 },
            CompactMergeIndex { index: 1 },
        ];

        let arrays = [a1, a2, a3];
        let array_refs = arrays.iter().map(|a| a as &dyn Array).collect::<Vec<_>>();
        let merged = merge_n(&array_refs, &indices).unwrap();
        let merged = merged.as_string::<i32>();

        assert_eq!(merged.len(), indices.len());
        assert!(!merged.is_valid(0));
        assert!(merged.is_valid(1));
        assert_eq!(merged.value(1), "B");
        assert!(merged.is_valid(2));
        assert_eq!(merged.value(2), "A");
        assert!(!merged.is_valid(3));
        assert!(merged.is_valid(4));
        assert_eq!(merged.value(4), "C");
        assert!(merged.is_valid(5));
        assert_eq!(merged.value(5), "D");
        assert!(!merged.is_valid(6));
        assert!(!merged.is_valid(7));
    }

    #[test]
    fn test_merge_n_usize_indices() {
        let a1 = StringArray::from(vec!["A", "B"]);
        let a2 = StringArray::from(vec!["C"]);
        let arrays: [&dyn Array; 2] = [&a1, &a2];
        let indices: Vec<usize> = vec![0, 1, 0];

        let merged = merge_n(&arrays, &indices).unwrap();

        assert_strings(&merged, &[Some("A"), Some("C"), Some("B")]);
    }

    #[test]
    fn test_merge_n_option_usize_indices() {
        let a1 = StringArray::from(vec!["A"]);
        let a2 = StringArray::from(vec!["B", "C"]);
        let arrays: [&dyn Array; 2] = [&a1, &a2];
        let indices: Vec<Option<usize>> = vec![Some(1), None, Some(0), Some(1)];

        let merged = merge_n(&arrays, &indices).unwrap();

        assert_strings(&merged, &[Some("B"), None, Some("A"), Some("C")]);
    }

    #[test]
    fn test_merge_n_different_types() {
        let a1 = StringArray::from(vec!["A"]);
        let a2 = Int32Array::from(vec![1]);
        let arrays: [&dyn Array; 2] = [&a1, &a2];
        let indices: Vec<usize> = vec![0, 1];

        let merged = merge_n(&arrays, &indices);

        assert!(matches!(merged, Err(InvalidArgumentError { .. })));
    }

    #[test]
    fn test_merge_n_empty_indices() {
        let a1 = StringArray::from(vec![Some("A")]);
        let a2 = StringArray::from(vec![Some("B"), None, None]);
        let a3 = StringArray::from(vec![Some("C"), Some("D")]);

        let indices: Vec<CompactMergeIndex> = vec![];

        let arrays = [a1, a2, a3];
        let array_refs = arrays.iter().map(|a| a as &dyn Array).collect::<Vec<_>>();
        let merged = merge_n(&array_refs, &indices).unwrap();

        assert_eq!(merged.len(), indices.len());
    }

    #[test]
    fn test_merge_n_empty_values() {
        let indices: Vec<CompactMergeIndex> = vec![];

        let arrays: Vec<&dyn Array> = vec![];
        let merged = merge_n(&arrays, &indices);

        assert!(matches!(merged, Err(InvalidArgumentError { .. })));
    }
}