/Users/andrewlamb/Software/arrow-rs/arrow-select/src/merge.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! [`merge`] and [`merge_n`]: Combine values from two or more arrays |
19 | | |
20 | | use crate::filter::{SlicesIterator, prep_null_mask_filter}; |
21 | | use crate::zip::zip; |
22 | | use arrow_array::{Array, ArrayRef, BooleanArray, Datum, make_array, new_empty_array}; |
23 | | use arrow_data::ArrayData; |
24 | | use arrow_data::transform::MutableArrayData; |
25 | | use arrow_schema::ArrowError; |
26 | | |
/// An index for the [`merge_n`] function.
///
/// This trait allows the indices argument for [`merge_n`] to be stored using a more
/// compact representation than `usize` when the number of input arrays is small.
/// If the number of input arrays is less than 256 for instance, the indices can be stored as `u8`.
///
/// Implementations must ensure that all values which return `None` from [`MergeIndex::index`]
/// are considered equal by the [`PartialEq`] and [`Eq`] implementations. [`merge_n`] groups
/// consecutive equal indices into a single run, so unequal "hole" markers would only split a
/// run of nulls into several smaller ones.
pub trait MergeIndex: PartialEq + Eq + Copy {
    /// Returns the index value as an `Option<usize>`.
    ///
    /// `Some(n)` selects the input array at position `n`. `None` values returned by this
    /// function indicate holes in the index array and will result in null values in the
    /// array created by [`merge_n`].
    fn index(&self) -> Option<usize>;
}
42 | | |
43 | | impl MergeIndex for usize { |
44 | 0 | fn index(&self) -> Option<usize> { |
45 | 0 | Some(*self) |
46 | 0 | } |
47 | | } |
48 | | |
49 | | impl MergeIndex for Option<usize> { |
50 | 0 | fn index(&self) -> Option<usize> { |
51 | 0 | *self |
52 | 0 | } |
53 | | } |
54 | | |
55 | | /// Merges elements by index from a list of [`Array`], creating a new [`Array`] from |
56 | | /// those values. |
57 | | /// |
58 | | /// Each element in `indices` is the index of an array in `values`. The `indices` array is processed |
59 | | /// sequentially. The first occurrence of index value `n` will be mapped to the first |
60 | | /// value of the array at index `n`. The second occurrence to the second value, and so on. |
61 | | /// An index value where `MergeIndex::index` returns `None` is interpreted as a null value. |
62 | | /// |
63 | | /// # Implementation notes |
64 | | /// |
65 | | /// This algorithm is similar in nature to both [zip] and |
66 | | /// [interleave](crate::interleave::interleave), but there are some important differences. |
67 | | /// |
68 | | /// In contrast to [zip], this function supports multiple input arrays. Instead of |
69 | | /// a boolean selection vector, an index array is to take values from the input arrays, and a special |
70 | | /// marker values can be used to indicate null values. |
71 | | /// |
72 | | /// In contrast to [interleave](crate::interleave::interleave), this function does not use pairs of |
73 | | /// indices. The values in `indices` serve the same purpose as the first value in the pairs passed |
74 | | /// to `interleave`. |
75 | | /// The index in the array is implicit and is derived from the number of times a particular array |
76 | | /// index occurs. |
77 | | /// The more constrained indexing mechanism used by this algorithm makes it easier to copy values |
78 | | /// in contiguous slices. In the example below, the two subsequent elements from array `2` can be |
79 | | /// copied in a single operation from the source array instead of copying them one by one. |
80 | | /// Long spans of null values are also especially cheap because they do not need to be represented |
81 | | /// in an input array. |
82 | | /// |
83 | | /// # Panics |
84 | | /// |
85 | | /// This function does not check that the number of occurrences of any particular array index matches |
86 | | /// the length of the corresponding input array. If an array contains more values than required, the |
87 | | /// spurious values will be ignored. If an array contains fewer values than necessary, this function |
88 | | /// will panic. |
89 | | /// |
90 | | /// # Example |
91 | | /// |
92 | | /// ```text |
93 | | /// ┌───────────┐ ┌─────────┐ ┌─────────┐ |
94 | | /// │┌─────────┐│ │ None │ │ NULL │ |
95 | | /// ││ A ││ ├─────────┤ ├─────────┤ |
96 | | /// │└─────────┘│ │ 1 │ │ B │ |
97 | | /// │┌─────────┐│ ├─────────┤ ├─────────┤ |
98 | | /// ││ B ││ │ 0 │ merge(values, indices) │ A │ |
99 | | /// │└─────────┘│ ├─────────┤ ─────────────────────────▶ ├─────────┤ |
100 | | /// │┌─────────┐│ │ None │ │ NULL │ |
101 | | /// ││ C ││ ├─────────┤ ├─────────┤ |
102 | | /// │├─────────┤│ │ 2 │ │ C │ |
103 | | /// ││ D ││ ├─────────┤ ├─────────┤ |
104 | | /// │└─────────┘│ │ 2 │ │ D │ |
105 | | /// └───────────┘ └─────────┘ └─────────┘ |
106 | | /// values indices result |
107 | | /// |
108 | | /// ``` |
109 | 3 | pub fn merge_n(values: &[&dyn Array], indices: &[impl MergeIndex]) -> Result<ArrayRef, ArrowError> { |
110 | 3 | if values.is_empty() { |
111 | 1 | return Err(ArrowError::InvalidArgumentError( |
112 | 1 | "merge_n requires at least one value array".to_string(), |
113 | 1 | )); |
114 | 2 | } |
115 | | |
116 | 2 | let data_type = values[0].data_type(); |
117 | | |
118 | 4 | for array in values2 .iter2 ().skip2 (1) { |
119 | 4 | if array.data_type() != data_type { |
120 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
121 | 0 | "It is not possible to merge arrays of different data types ({} and {})", |
122 | 0 | data_type, |
123 | 0 | array.data_type() |
124 | 0 | ))); |
125 | 4 | } |
126 | | } |
127 | | |
128 | 2 | if indices.is_empty() { |
129 | 1 | return Ok(new_empty_array(data_type)); |
130 | 1 | } |
131 | | |
132 | | #[cfg(debug_assertions)] |
133 | 9 | for ix8 in indices { |
134 | 8 | if let Some(index6 ) = ix.index() { |
135 | 6 | assert!( |
136 | 6 | index < values.len(), |
137 | 0 | "Index out of bounds: {} >= {}", |
138 | | index, |
139 | 0 | values.len() |
140 | | ); |
141 | 2 | } |
142 | | } |
143 | | |
144 | 3 | let data1 : Vec<ArrayData>1 = values1 .iter1 ().map1 (|a| a.to_data()).collect1 (); |
145 | 1 | let data_refs = data.iter().collect(); |
146 | | |
147 | 1 | let mut mutable = MutableArrayData::new(data_refs, true, indices.len()); |
148 | | |
149 | | // This loop extends the mutable array by taking slices from the partial results. |
150 | | // |
151 | | // take_offsets keeps track of how many values have been taken from each array. |
152 | 1 | let mut take_offsets = vec![0; values.len() + 1]; |
153 | 1 | let mut start_row_ix = 0; |
154 | | loop { |
155 | 6 | let array_ix = indices[start_row_ix]; |
156 | | |
157 | | // Determine the length of the slice to take. |
158 | 6 | let mut end_row_ix = start_row_ix + 1; |
159 | 8 | while end_row_ix < indices.len() && indices[end_row_ix] == array_ix7 { |
160 | 2 | end_row_ix += 1; |
161 | 2 | } |
162 | 6 | let slice_length = end_row_ix - start_row_ix; |
163 | | |
164 | | // Extend mutable with either nulls or with values from the array. |
165 | 6 | match array_ix.index() { |
166 | 2 | None => mutable.extend_nulls(slice_length), |
167 | 4 | Some(index) => { |
168 | 4 | let start_offset = take_offsets[index]; |
169 | 4 | let end_offset = start_offset + slice_length; |
170 | 4 | mutable.extend(index, start_offset, end_offset); |
171 | 4 | take_offsets[index] = end_offset; |
172 | 4 | } |
173 | | } |
174 | | |
175 | 6 | if end_row_ix == indices.len() { |
176 | 1 | break; |
177 | 5 | } else { |
178 | 5 | // Set the start_row_ix for the next slice. |
179 | 5 | start_row_ix = end_row_ix; |
180 | 5 | } |
181 | | } |
182 | | |
183 | 1 | Ok(make_array(mutable.freeze())) |
184 | 3 | } |
185 | | |
186 | | /// Merges two arrays in the order specified by a boolean mask. |
187 | | /// |
188 | | /// This algorithm is a variant of [zip] that does not require the truthy and |
189 | | /// falsy arrays to have the same length. |
190 | | /// |
191 | | /// When truthy of falsy are [Scalar](arrow_array::Scalar), the single |
192 | | /// scalar value is repeated whenever the mask array contains true or false respectively. |
193 | | /// |
194 | | /// # Example |
195 | | /// |
196 | | /// ```text |
197 | | /// truthy |
198 | | /// ┌─────────┐ mask |
199 | | /// │ A │ ┌─────────┐ ┌─────────┐ |
200 | | /// ├─────────┤ │ true │ │ A │ |
201 | | /// │ C │ ├─────────┤ ├─────────┤ |
202 | | /// ├─────────┤ │ true │ │ C │ |
203 | | /// │ NULL │ ├─────────┤ ├─────────┤ |
204 | | /// ├─────────┤ │ false │ merge(mask, truthy, falsy) │ B │ |
205 | | /// │ D │ ├─────────┤ ─────────────────────────▶ ├─────────┤ |
206 | | /// └─────────┘ │ true │ │ NULL │ |
207 | | /// falsy ├─────────┤ ├─────────┤ |
208 | | /// ┌─────────┐ │ false │ │ E │ |
209 | | /// │ B │ ├─────────┤ ├─────────┤ |
210 | | /// ├─────────┤ │ true │ │ D │ |
211 | | /// │ E │ └─────────┘ └─────────┘ |
212 | | /// └─────────┘ |
213 | | /// ``` |
214 | 2 | pub fn merge( |
215 | 2 | mask: &BooleanArray, |
216 | 2 | truthy: &dyn Datum, |
217 | 2 | falsy: &dyn Datum, |
218 | 2 | ) -> Result<ArrayRef, ArrowError> { |
219 | 2 | let (truthy_array, truthy_is_scalar) = truthy.get(); |
220 | 2 | let (falsy_array, falsy_is_scalar) = falsy.get(); |
221 | | |
222 | 2 | if truthy_is_scalar && falsy_is_scalar0 { |
223 | | // When both truthy and falsy are scalars, we can use `zip` since the result is the same |
224 | | // and zip has optimized code for scalars. |
225 | 0 | return zip(mask, truthy, falsy); |
226 | 2 | } |
227 | | |
228 | 2 | if truthy_array.data_type() != falsy_array.data_type() { |
229 | 0 | return Err(ArrowError::InvalidArgumentError( |
230 | 0 | "arguments need to have the same data type".into(), |
231 | 0 | )); |
232 | 2 | } |
233 | | |
234 | 2 | if truthy_is_scalar && truthy_array.len() != 10 { |
235 | 0 | return Err(ArrowError::InvalidArgumentError( |
236 | 0 | "scalar arrays must have 1 element".into(), |
237 | 0 | )); |
238 | 2 | } |
239 | 2 | if falsy_is_scalar && falsy_array.len() != 10 { |
240 | 0 | return Err(ArrowError::InvalidArgumentError( |
241 | 0 | "scalar arrays must have 1 element".into(), |
242 | 0 | )); |
243 | 2 | } |
244 | | |
245 | 2 | let falsy = falsy_array.to_data(); |
246 | 2 | let truthy = truthy_array.to_data(); |
247 | | |
248 | 2 | let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, mask.len()); |
249 | | |
250 | | // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to |
251 | | // fill with falsy values |
252 | | |
253 | | // keep track of how much is filled |
254 | 2 | let mut filled = 0; |
255 | 2 | let mut falsy_offset = 0; |
256 | 2 | let mut truthy_offset = 0; |
257 | | |
258 | | // Ensure nulls are treated as false |
259 | 2 | let mask_buffer = match mask.null_count() { |
260 | 2 | 0 => mask.values().clone(), |
261 | 0 | _ => prep_null_mask_filter(mask).into_parts().0, |
262 | | }; |
263 | | |
264 | 3 | SlicesIterator::from2 (&mask_buffer2 ).for_each2 (|(start, end)| { |
265 | | // the gap needs to be filled with falsy values |
266 | 3 | if start > filled { |
267 | 2 | if falsy_is_scalar { |
268 | 0 | for _ in filled..start { |
269 | 0 | // Copy the first item from the 'falsy' array into the output buffer. |
270 | 0 | mutable.extend(1, 0, 1); |
271 | 0 | } |
272 | 2 | } else { |
273 | 2 | let falsy_length = start - filled; |
274 | 2 | let falsy_end = falsy_offset + falsy_length; |
275 | 2 | mutable.extend(1, falsy_offset, falsy_end); |
276 | 2 | falsy_offset = falsy_end; |
277 | 2 | } |
278 | 1 | } |
279 | | // fill with truthy values |
280 | 3 | if truthy_is_scalar { |
281 | 0 | for _ in start..end { |
282 | 0 | // Copy the first item from the 'truthy' array into the output buffer. |
283 | 0 | mutable.extend(0, 0, 1); |
284 | 0 | } |
285 | 3 | } else { |
286 | 3 | let truthy_length = end - start; |
287 | 3 | let truthy_end = truthy_offset + truthy_length; |
288 | 3 | mutable.extend(0, truthy_offset, truthy_end); |
289 | 3 | truthy_offset = truthy_end; |
290 | 3 | } |
291 | 3 | filled = end; |
292 | 3 | }); |
293 | | // the remaining part is falsy |
294 | 2 | if filled < mask.len() { |
295 | 0 | if falsy_is_scalar { |
296 | 0 | for _ in filled..mask.len() { |
297 | 0 | // Copy the first item from the 'falsy' array into the output buffer. |
298 | 0 | mutable.extend(1, 0, 1); |
299 | 0 | } |
300 | 0 | } else { |
301 | 0 | let falsy_length = mask.len() - filled; |
302 | 0 | let falsy_end = falsy_offset + falsy_length; |
303 | 0 | mutable.extend(1, falsy_offset, falsy_end); |
304 | 0 | } |
305 | 2 | } |
306 | | |
307 | 2 | let data = mutable.freeze(); |
308 | 2 | Ok(make_array(data)) |
309 | 2 | } |
310 | | |
#[cfg(test)]
mod tests {
    use crate::merge::{MergeIndex, merge, merge_n};
    use arrow_array::cast::AsArray;
    use arrow_array::{Array, BooleanArray, StringArray};
    use arrow_schema::ArrowError::InvalidArgumentError;

    /// One-byte index type used to exercise a custom [`MergeIndex`] implementation;
    /// `u8::MAX` is the marker for a hole (null).
    #[derive(PartialEq, Eq, Copy, Clone)]
    struct CompactMergeIndex {
        index: u8,
    }

    impl MergeIndex for CompactMergeIndex {
        fn index(&self) -> Option<usize> {
            if self.index == u8::MAX {
                None
            } else {
                Some(self.index as usize)
            }
        }
    }

    // Two arrays of different lengths are interleaved by a boolean mask.
    #[test]
    fn test_merge() {
        let a1 = StringArray::from(vec![Some("A"), Some("B"), Some("E"), None]);
        let a2 = StringArray::from(vec![Some("C"), Some("D")]);

        let indices = BooleanArray::from(vec![true, false, true, false, true, true]);

        let merged = merge(&indices, &a1, &a2).unwrap();
        let merged = merged.as_string::<i32>();

        assert_eq!(merged.len(), indices.len());
        assert!(merged.is_valid(0));
        assert_eq!(merged.value(0), "A");
        assert!(merged.is_valid(1));
        assert_eq!(merged.value(1), "C");
        assert!(merged.is_valid(2));
        assert_eq!(merged.value(2), "B");
        assert!(merged.is_valid(3));
        assert_eq!(merged.value(3), "D");
        assert!(merged.is_valid(4));
        assert_eq!(merged.value(4), "E");
        assert!(!merged.is_valid(5));
    }

    // An empty mask yields an empty result even when the inputs hold values.
    #[test]
    fn test_merge_empty_mask() {
        let a1 = StringArray::from(vec![Some("A")]);
        let a2 = StringArray::from(vec![Some("B")]);
        let mask: Vec<bool> = vec![];
        let mask = BooleanArray::from(mask);
        let result = merge(&mask, &a1, &a2).unwrap();
        assert_eq!(result.len(), 0);
    }

    // Covers hole markers, single values, and multi-value runs from the same array.
    #[test]
    fn test_merge_n() {
        let a1 = StringArray::from(vec![Some("A")]);
        let a2 = StringArray::from(vec![Some("B"), None, None]);
        let a3 = StringArray::from(vec![Some("C"), Some("D")]);

        let indices = vec![
            CompactMergeIndex { index: u8::MAX },
            CompactMergeIndex { index: 1 },
            CompactMergeIndex { index: 0 },
            CompactMergeIndex { index: u8::MAX },
            CompactMergeIndex { index: 2 },
            CompactMergeIndex { index: 2 },
            CompactMergeIndex { index: 1 },
            CompactMergeIndex { index: 1 },
        ];

        let arrays = [a1, a2, a3];
        let array_refs = arrays.iter().map(|a| a as &dyn Array).collect::<Vec<_>>();
        let merged = merge_n(&array_refs, &indices).unwrap();
        let merged = merged.as_string::<i32>();

        assert_eq!(merged.len(), indices.len());
        assert!(!merged.is_valid(0));
        assert!(merged.is_valid(1));
        assert_eq!(merged.value(1), "B");
        assert!(merged.is_valid(2));
        assert_eq!(merged.value(2), "A");
        assert!(!merged.is_valid(3));
        assert!(merged.is_valid(4));
        assert_eq!(merged.value(4), "C");
        assert!(merged.is_valid(5));
        assert_eq!(merged.value(5), "D");
        assert!(!merged.is_valid(6));
        assert!(!merged.is_valid(7));
    }

    // Empty `indices` with non-empty `values` produces an empty array.
    #[test]
    fn test_merge_n_empty_indices() {
        let a1 = StringArray::from(vec![Some("A")]);
        let a2 = StringArray::from(vec![Some("B"), None, None]);
        let a3 = StringArray::from(vec![Some("C"), Some("D")]);

        let indices: Vec<CompactMergeIndex> = vec![];

        let arrays = [a1, a2, a3];
        let array_refs = arrays.iter().map(|a| a as &dyn Array).collect::<Vec<_>>();
        let merged = merge_n(&array_refs, &indices).unwrap();

        assert_eq!(merged.len(), indices.len());
    }

    // Empty `values` is rejected with an error rather than a panic.
    #[test]
    fn test_merge_n_empty_values() {
        let indices: Vec<CompactMergeIndex> = vec![];

        let arrays: Vec<&dyn Array> = vec![];
        let merged = merge_n(&arrays, &indices);

        assert!(matches!(merged, Err(InvalidArgumentError { .. })));
    }
}