/Users/andrewlamb/Software/arrow-rs/arrow-select/src/union_extract.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines union_extract kernel for [UnionArray] |
19 | | |
20 | | use crate::take::take; |
21 | | use arrow_array::{ |
22 | | make_array, new_empty_array, new_null_array, Array, ArrayRef, BooleanArray, Int32Array, Scalar, |
23 | | UnionArray, |
24 | | }; |
25 | | use arrow_buffer::{bit_util, BooleanBuffer, MutableBuffer, NullBuffer, ScalarBuffer}; |
26 | | use arrow_data::layout; |
27 | | use arrow_schema::{ArrowError, DataType, UnionFields}; |
28 | | use std::cmp::Ordering; |
29 | | use std::sync::Arc; |
30 | | |
31 | | /// Returns the value of the target field when selected, or NULL otherwise. |
32 | | /// ```text |
33 | | /// ┌─────────────────┐ ┌─────────────────┐ |
34 | | /// │ A=1 │ │ 1 │ |
35 | | /// ├─────────────────┤ ├─────────────────┤ |
36 | | /// │ A=NULL │ │ NULL │ |
37 | | /// ├─────────────────┤ union_extract(values, 'A') ├─────────────────┤ |
38 | | /// │ B='t' │ ────────────────────────────▶ │ NULL │ |
39 | | /// ├─────────────────┤ ├─────────────────┤ |
40 | | /// │ A=3 │ │ 3 │ |
41 | | /// ├─────────────────┤ ├─────────────────┤ |
42 | | /// │ B=NULL │ │ NULL │ |
43 | | /// └─────────────────┘ └─────────────────┘ |
44 | | /// union array result |
45 | | /// ``` |
46 | | /// # Errors |
47 | | /// |
48 | | /// Returns error if target field is not found |
49 | | /// |
50 | | /// # Examples |
51 | | /// ``` |
52 | | /// # use std::sync::Arc; |
53 | | /// # use arrow_schema::{DataType, Field, UnionFields}; |
54 | | /// # use arrow_array::{UnionArray, StringArray, Int32Array}; |
55 | | /// # use arrow_select::union_extract::union_extract; |
56 | | /// let fields = UnionFields::new( |
57 | | /// [1, 3], |
58 | | /// [ |
59 | | /// Field::new("A", DataType::Int32, true), |
60 | | /// Field::new("B", DataType::Utf8, true) |
61 | | /// ] |
62 | | /// ); |
63 | | /// |
64 | | /// let union = UnionArray::try_new( |
65 | | /// fields, |
66 | | /// vec![1, 1, 3, 1, 3].into(), |
67 | | /// None, |
68 | | /// vec![ |
69 | | /// Arc::new(Int32Array::from(vec![Some(1), None, None, Some(3), Some(0)])), |
70 | | /// Arc::new(StringArray::from(vec![None, None, Some("t"), Some("."), None])) |
71 | | /// ] |
72 | | /// ).unwrap(); |
73 | | /// |
74 | | /// // Extract field A |
75 | | /// let extracted = union_extract(&union, "A").unwrap(); |
76 | | /// |
77 | | /// assert_eq!(*extracted, Int32Array::from(vec![Some(1), None, None, Some(3), None])); |
78 | | /// ``` |
79 | | pub fn union_extract(union_array: &UnionArray, target: &str) -> Result<ArrayRef, ArrowError> { |
80 | | let fields = match union_array.data_type() { |
81 | | DataType::Union(fields, _) => fields, |
82 | | _ => unreachable!(), |
83 | | }; |
84 | | |
85 | | let (target_type_id, _) = fields |
86 | | .iter() |
87 | 0 | .find(|field| field.1.name() == target) |
88 | 0 | .ok_or_else(|| { |
89 | 0 | ArrowError::InvalidArgumentError(format!("field {target} not found on union")) |
90 | 0 | })?; |
91 | | |
92 | | match union_array.offsets() { |
93 | | Some(_) => extract_dense(union_array, fields, target_type_id), |
94 | | None => extract_sparse(union_array, fields, target_type_id), |
95 | | } |
96 | | } |
97 | | |
98 | | fn extract_sparse( |
99 | | union_array: &UnionArray, |
100 | | fields: &UnionFields, |
101 | | target_type_id: i8, |
102 | | ) -> Result<ArrayRef, ArrowError> { |
103 | | let target = union_array.child(target_type_id); |
104 | | |
105 | | if fields.len() == 1 // case 1.1: if there is a single field, all type ids are the same, and since union doesn't have a null mask, the result array is exactly the same as it only child |
106 | | || union_array.is_empty() // case 1.2: sparse union length and childrens length must match, if the union is empty, so is any children |
107 | | || target.null_count() == target.len() || target.data_type().is_null() |
108 | | // case 1.3: if all values of the target children are null, regardless of selected type ids, the result will also be completely null |
109 | | { |
110 | | Ok(Arc::clone(target)) |
111 | | } else { |
112 | | match eq_scalar(union_array.type_ids(), target_type_id) { |
113 | | // case 2: all type ids equals our target, and since unions doesn't have a null mask, the result array is exactly the same as our target |
114 | | BoolValue::Scalar(true) => Ok(Arc::clone(target)), |
115 | | // case 3: none type_id matches our target, the result is a null array |
116 | | BoolValue::Scalar(false) => { |
117 | | if layout(target.data_type()).can_contain_null_mask { |
118 | | // case 3.1: target array can contain a null mask |
119 | | //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above |
120 | | let data = unsafe { |
121 | | target |
122 | | .into_data() |
123 | | .into_builder() |
124 | | .nulls(Some(NullBuffer::new_null(target.len()))) |
125 | | .build_unchecked() |
126 | | }; |
127 | | |
128 | | Ok(make_array(data)) |
129 | | } else { |
130 | | // case 3.2: target can't contain a null mask |
131 | | Ok(new_null_array(target.data_type(), target.len())) |
132 | | } |
133 | | } |
134 | | // case 4: some but not all type_id matches our target |
135 | | BoolValue::Buffer(selected) => { |
136 | | if layout(target.data_type()).can_contain_null_mask { |
137 | | // case 4.1: target array can contain a null mask |
138 | 0 | let nulls = match target.nulls().filter(|n| n.null_count() > 0) { |
139 | | // case 4.1.1: our target child has nulls and types other than our target are selected, union the masks |
140 | | // the case where n.null_count() == n.len() is cheaply handled at case 1.3 |
141 | | Some(nulls) => &selected & nulls.inner(), |
142 | | // case 4.1.2: target child has no nulls, but types other than our target are selected, use the selected mask as a null mask |
143 | | None => selected, |
144 | | }; |
145 | | |
146 | | //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above |
147 | | let data = unsafe { |
148 | | assert_eq!(nulls.len(), target.len()); |
149 | | |
150 | | target |
151 | | .into_data() |
152 | | .into_builder() |
153 | | .nulls(Some(nulls.into())) |
154 | | .build_unchecked() |
155 | | }; |
156 | | |
157 | | Ok(make_array(data)) |
158 | | } else { |
159 | | // case 4.2: target can't containt a null mask, zip the values that match with a null value |
160 | | Ok(crate::zip::zip( |
161 | | &BooleanArray::new(selected, None), |
162 | | target, |
163 | | &Scalar::new(new_null_array(target.data_type(), 1)), |
164 | | )?) |
165 | | } |
166 | | } |
167 | | } |
168 | | } |
169 | | } |
170 | | |
171 | | fn extract_dense( |
172 | | union_array: &UnionArray, |
173 | | fields: &UnionFields, |
174 | | target_type_id: i8, |
175 | | ) -> Result<ArrayRef, ArrowError> { |
176 | | let target = union_array.child(target_type_id); |
177 | | let offsets = union_array.offsets().unwrap(); |
178 | | |
179 | | if union_array.is_empty() { |
180 | | // case 1: the union is empty |
181 | | if target.is_empty() { |
182 | | // case 1.1: the target is also empty, do a cheap Arc::clone instead of allocating a new empty array |
183 | | Ok(Arc::clone(target)) |
184 | | } else { |
185 | | // case 1.2: the target is not empty, allocate a new empty array |
186 | | Ok(new_empty_array(target.data_type())) |
187 | | } |
188 | | } else if target.is_empty() { |
189 | | // case 2: the union is not empty but the target is, which implies that none type_id points to it. The result is a null array |
190 | | Ok(new_null_array(target.data_type(), union_array.len())) |
191 | | } else if target.null_count() == target.len() || target.data_type().is_null() { |
192 | | // case 3: since all values on our target are null, regardless of selected type ids and offsets, the result is a null array |
193 | | match target.len().cmp(&union_array.len()) { |
194 | | // case 3.1: since the target is smaller than the union, allocate a new correclty sized null array |
195 | | Ordering::Less => Ok(new_null_array(target.data_type(), union_array.len())), |
196 | | // case 3.2: target equals the union len, return it direcly |
197 | | Ordering::Equal => Ok(Arc::clone(target)), |
198 | | // case 3.3: target len is bigger than the union len, slice it |
199 | | Ordering::Greater => Ok(target.slice(0, union_array.len())), |
200 | | } |
201 | | } else if fields.len() == 1 // case A: since there's a single field, our target, every type id must matches our target |
202 | | || fields |
203 | | .iter() |
204 | 0 | .filter(|(field_type_id, _)| *field_type_id != target_type_id) |
205 | 0 | .all(|(sibling_type_id, _)| union_array.child(sibling_type_id).is_empty()) |
206 | | // case B: since siblings are empty, every type id must matches our target |
207 | | { |
208 | | // case 4: every type id matches our target |
209 | | Ok(extract_dense_all_selected(union_array, target, offsets)?) |
210 | | } else { |
211 | | match eq_scalar(union_array.type_ids(), target_type_id) { |
212 | | // case 4C: all type ids matches our target. |
213 | | // Non empty sibling without any selected value may happen after slicing the parent union, |
214 | | // since only type_ids and offsets are sliced, not the children |
215 | | BoolValue::Scalar(true) => { |
216 | | Ok(extract_dense_all_selected(union_array, target, offsets)?) |
217 | | } |
218 | | BoolValue::Scalar(false) => { |
219 | | // case 5: none type_id matches our target, so the result array will be completely null |
220 | | // Non empty target without any selected value may happen after slicing the parent union, |
221 | | // since only type_ids and offsets are sliced, not the children |
222 | | match (target.len().cmp(&union_array.len()), layout(target.data_type()).can_contain_null_mask) { |
223 | | (Ordering::Less, _) // case 5.1A: our target is smaller than the parent union, allocate a new correclty sized null array |
224 | | | (_, false) => { // case 5.1B: target array can't contain a null mask |
225 | | Ok(new_null_array(target.data_type(), union_array.len())) |
226 | | } |
227 | | // case 5.2: target and parent union lengths are equal, and the target can contain a null mask, let's set it to a all-null null-buffer |
228 | | (Ordering::Equal, true) => { |
229 | | //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above |
230 | | let data = unsafe { |
231 | | target |
232 | | .into_data() |
233 | | .into_builder() |
234 | | .nulls(Some(NullBuffer::new_null(union_array.len()))) |
235 | | .build_unchecked() |
236 | | }; |
237 | | |
238 | | Ok(make_array(data)) |
239 | | } |
240 | | // case 5.3: target is bigger than it's parent union and can contain a null mask, let's slice it, and set it's nulls to a all-null null-buffer |
241 | | (Ordering::Greater, true) => { |
242 | | //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above |
243 | | let data = unsafe { |
244 | | target |
245 | | .into_data() |
246 | | .slice(0, union_array.len()) |
247 | | .into_builder() |
248 | | .nulls(Some(NullBuffer::new_null(union_array.len()))) |
249 | | .build_unchecked() |
250 | | }; |
251 | | |
252 | | Ok(make_array(data)) |
253 | | } |
254 | | } |
255 | | } |
256 | | BoolValue::Buffer(selected) => { |
257 | | //case 6: some type_ids matches our target, but not all. For selected values, take the value pointed by the offset. For unselected, use a valid null |
258 | | Ok(take( |
259 | | target, |
260 | | &Int32Array::new(offsets.clone(), Some(selected.into())), |
261 | | None, |
262 | | )?) |
263 | | } |
264 | | } |
265 | | } |
266 | | } |
267 | | |
268 | | fn extract_dense_all_selected( |
269 | | union_array: &UnionArray, |
270 | | target: &Arc<dyn Array>, |
271 | | offsets: &ScalarBuffer<i32>, |
272 | | ) -> Result<ArrayRef, ArrowError> { |
273 | | let sequential = |
274 | | target.len() - offsets[0] as usize >= union_array.len() && is_sequential(offsets); |
275 | | |
276 | | if sequential && target.len() == union_array.len() { |
277 | | // case 1: all offsets are sequential and both lengths match, return the array directly |
278 | | Ok(Arc::clone(target)) |
279 | | } else if sequential && target.len() > union_array.len() { |
280 | | // case 2: All offsets are sequential, but our target is bigger than our union, slice it, starting at the first offset |
281 | | Ok(target.slice(offsets[0] as usize, union_array.len())) |
282 | | } else { |
283 | | // case 3: Since offsets are not sequential, take them from the child to a new sequential and correcly sized array |
284 | | let indices = Int32Array::try_new(offsets.clone(), None)?; |
285 | | |
286 | | Ok(take(target, &indices, None)?) |
287 | | } |
288 | | } |
289 | | |
290 | | const EQ_SCALAR_CHUNK_SIZE: usize = 512; |
291 | | |
292 | | /// The result of checking which type_ids matches the target type_id |
293 | | #[derive(Debug, PartialEq)] |
294 | | enum BoolValue { |
295 | | /// If true, all type_ids matches the target type_id |
296 | | /// If false, none type_ids matches the target type_id |
297 | | Scalar(bool), |
298 | | /// A mask represeting which type_ids matches the target type_id |
299 | | Buffer(BooleanBuffer), |
300 | | } |
301 | | |
302 | | fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue { |
303 | | eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target) |
304 | | } |
305 | | |
306 | 0 | fn count_first_run(chunk_size: usize, type_ids: &[i8], mut f: impl FnMut(i8) -> bool) -> usize { |
307 | 0 | type_ids |
308 | 0 | .chunks(chunk_size) |
309 | 0 | .take_while(|chunk| chunk.iter().copied().fold(true, |b, v| b & f(v))) |
310 | 0 | .map(|chunk| chunk.len()) |
311 | 0 | .sum() |
312 | 0 | } |
313 | | |
314 | | // This is like MutableBuffer::collect_bool(type_ids.len(), |i| type_ids[i] == target) with fast paths for all true or all false values. |
315 | | fn eq_scalar_inner(chunk_size: usize, type_ids: &[i8], target: i8) -> BoolValue { |
316 | 0 | let true_bits = count_first_run(chunk_size, type_ids, |v| v == target); |
317 | | |
318 | | let (set_bits, val) = if true_bits == type_ids.len() { |
319 | | return BoolValue::Scalar(true); |
320 | | } else if true_bits == 0 { |
321 | 0 | let false_bits = count_first_run(chunk_size, type_ids, |v| v != target); |
322 | | |
323 | | if false_bits == type_ids.len() { |
324 | | return BoolValue::Scalar(false); |
325 | | } else { |
326 | | (false_bits, false) |
327 | | } |
328 | | } else { |
329 | | (true_bits, true) |
330 | | }; |
331 | | |
332 | | // restrict to chunk boundaries |
333 | | let set_bits = set_bits - set_bits % 64; |
334 | | |
335 | | let mut buffer = |
336 | | MutableBuffer::new(bit_util::ceil(type_ids.len(), 8)).with_bitset(set_bits / 8, val); |
337 | | |
338 | 0 | buffer.extend(type_ids[set_bits..].chunks(64).map(|chunk| { |
339 | 0 | chunk |
340 | 0 | .iter() |
341 | 0 | .copied() |
342 | 0 | .enumerate() |
343 | 0 | .fold(0, |packed, (bit_idx, v)| { |
344 | 0 | packed | (((v == target) as u64) << bit_idx) |
345 | 0 | }) |
346 | 0 | })); |
347 | | |
348 | | BoolValue::Buffer(BooleanBuffer::new(buffer.into(), 0, type_ids.len())) |
349 | | } |
350 | | |
351 | | const IS_SEQUENTIAL_CHUNK_SIZE: usize = 64; |
352 | | |
353 | | fn is_sequential(offsets: &[i32]) -> bool { |
354 | | is_sequential_generic::<IS_SEQUENTIAL_CHUNK_SIZE>(offsets) |
355 | | } |
356 | | |
357 | 0 | fn is_sequential_generic<const N: usize>(offsets: &[i32]) -> bool { |
358 | 0 | if offsets.is_empty() { |
359 | 0 | return true; |
360 | 0 | } |
361 | | |
362 | | // fast check this common combination: |
363 | | // 1: sequential nulls are represented as a single null value on the values array, pointed by the same offset multiple times |
364 | | // 2: valid values offsets increase one by one. |
365 | | // example for an union with a single field A with type_id 0: |
366 | | // union = A=7 A=NULL A=NULL A=5 A=9 |
367 | | // a values = 7 NULL 5 9 |
368 | | // offsets = 0 1 1 2 3 |
369 | | // type_ids = 0 0 0 0 0 |
370 | | // this also checks if the last chunk/remainder is sequential relative to the first offset |
371 | 0 | if offsets[0] + offsets.len() as i32 - 1 != offsets[offsets.len() - 1] { |
372 | 0 | return false; |
373 | 0 | } |
374 | | |
375 | 0 | let chunks = offsets.chunks_exact(N); |
376 | | |
377 | 0 | let remainder = chunks.remainder(); |
378 | | |
379 | 0 | chunks.enumerate().all(|(i, chunk)| { |
380 | 0 | let chunk_array = <&[i32; N]>::try_from(chunk).unwrap(); |
381 | | |
382 | | //checks if values within chunk are sequential |
383 | 0 | chunk_array |
384 | 0 | .iter() |
385 | 0 | .copied() |
386 | 0 | .enumerate() |
387 | 0 | .fold(true, |acc, (i, offset)| { |
388 | 0 | acc & (offset == chunk_array[0] + i as i32) |
389 | 0 | }) |
390 | 0 | && offsets[0] + (i * N) as i32 == chunk_array[0] //checks if chunk is sequential relative to the first offset |
391 | 0 | }) && remainder |
392 | 0 | .iter() |
393 | 0 | .copied() |
394 | 0 | .enumerate() |
395 | 0 | .fold(true, |acc, (i, offset)| { |
396 | 0 | acc & (offset == remainder[0] + i as i32) |
397 | 0 | }) //if the remainder is sequential relative to the first offset is checked at the start of the function |
398 | 0 | } |
399 | | |
400 | | #[cfg(test)] |
401 | | mod tests { |
402 | | use super::{eq_scalar_inner, is_sequential_generic, union_extract, BoolValue}; |
403 | | use arrow_array::{new_null_array, Array, Int32Array, NullArray, StringArray, UnionArray}; |
404 | | use arrow_buffer::{BooleanBuffer, ScalarBuffer}; |
405 | | use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode}; |
406 | | use std::sync::Arc; |
407 | | |
408 | | #[test] |
409 | | fn test_eq_scalar() { |
410 | | //multiple all equal chunks, so it's loop and sum logic it's tested |
411 | | //multiple chunks after, so it's loop logic it's tested |
412 | | const ARRAY_LEN: usize = 64 * 4; |
413 | | |
414 | | //so out of 64 boundaries chunks can be generated and checked for |
415 | | const EQ_SCALAR_CHUNK_SIZE: usize = 3; |
416 | | |
417 | | fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue { |
418 | | eq_scalar_inner(EQ_SCALAR_CHUNK_SIZE, type_ids, target) |
419 | | } |
420 | | |
421 | | fn cross_check(left: &[i8], right: i8) -> BooleanBuffer { |
422 | | BooleanBuffer::collect_bool(left.len(), |i| left[i] == right) |
423 | | } |
424 | | |
425 | | assert_eq!(eq_scalar(&[], 1), BoolValue::Scalar(true)); |
426 | | |
427 | | assert_eq!(eq_scalar(&[1], 1), BoolValue::Scalar(true)); |
428 | | assert_eq!(eq_scalar(&[2], 1), BoolValue::Scalar(false)); |
429 | | |
430 | | let mut values = [1; ARRAY_LEN]; |
431 | | |
432 | | assert_eq!(eq_scalar(&values, 1), BoolValue::Scalar(true)); |
433 | | assert_eq!(eq_scalar(&values, 2), BoolValue::Scalar(false)); |
434 | | |
435 | | //every subslice should return the same value |
436 | | for i in 1..ARRAY_LEN { |
437 | | assert_eq!(eq_scalar(&values[..i], 1), BoolValue::Scalar(true)); |
438 | | assert_eq!(eq_scalar(&values[..i], 2), BoolValue::Scalar(false)); |
439 | | } |
440 | | |
441 | | // test that a single change anywhere is checked for |
442 | | for i in 0..ARRAY_LEN { |
443 | | values[i] = 2; |
444 | | |
445 | | assert_eq!( |
446 | | eq_scalar(&values, 1), |
447 | | BoolValue::Buffer(cross_check(&values, 1)) |
448 | | ); |
449 | | assert_eq!( |
450 | | eq_scalar(&values, 2), |
451 | | BoolValue::Buffer(cross_check(&values, 2)) |
452 | | ); |
453 | | |
454 | | values[i] = 1; |
455 | | } |
456 | | } |
457 | | |
458 | | #[test] |
459 | | fn test_is_sequential() { |
460 | | /* |
461 | | the smallest value that satisfies: |
462 | | >1 so the fold logic of a exact chunk executes |
463 | | >2 so a >1 non-exact remainder can exist, and it's fold logic executes |
464 | | */ |
465 | | const CHUNK_SIZE: usize = 3; |
466 | | //we test arrays of size up to 8 = 2 * CHUNK_SIZE + 2: |
467 | | //multiple(2) exact chunks, so the AND logic between them executes |
468 | | //a >1(2) remainder, so: |
469 | | // the AND logic between all exact chunks and the remainder executes |
470 | | // the remainder fold logic executes |
471 | | |
472 | | fn is_sequential(v: &[i32]) -> bool { |
473 | | is_sequential_generic::<CHUNK_SIZE>(v) |
474 | | } |
475 | | |
476 | | assert!(is_sequential(&[])); //empty |
477 | | assert!(is_sequential(&[1])); //single |
478 | | |
479 | | assert!(is_sequential(&[1, 2])); |
480 | | assert!(is_sequential(&[1, 2, 3])); |
481 | | assert!(is_sequential(&[1, 2, 3, 4])); |
482 | | assert!(is_sequential(&[1, 2, 3, 4, 5])); |
483 | | assert!(is_sequential(&[1, 2, 3, 4, 5, 6])); |
484 | | assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7])); |
485 | | assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7, 8])); |
486 | | |
487 | | assert!(!is_sequential(&[8, 7])); |
488 | | assert!(!is_sequential(&[8, 7, 6])); |
489 | | assert!(!is_sequential(&[8, 7, 6, 5])); |
490 | | assert!(!is_sequential(&[8, 7, 6, 5, 4])); |
491 | | assert!(!is_sequential(&[8, 7, 6, 5, 4, 3])); |
492 | | assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2])); |
493 | | assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2, 1])); |
494 | | |
495 | | assert!(!is_sequential(&[0, 2])); |
496 | | assert!(!is_sequential(&[1, 0])); |
497 | | |
498 | | assert!(!is_sequential(&[0, 2, 3])); |
499 | | assert!(!is_sequential(&[1, 0, 3])); |
500 | | assert!(!is_sequential(&[1, 2, 0])); |
501 | | |
502 | | assert!(!is_sequential(&[0, 2, 3, 4])); |
503 | | assert!(!is_sequential(&[1, 0, 3, 4])); |
504 | | assert!(!is_sequential(&[1, 2, 0, 4])); |
505 | | assert!(!is_sequential(&[1, 2, 3, 0])); |
506 | | |
507 | | assert!(!is_sequential(&[0, 2, 3, 4, 5])); |
508 | | assert!(!is_sequential(&[1, 0, 3, 4, 5])); |
509 | | assert!(!is_sequential(&[1, 2, 0, 4, 5])); |
510 | | assert!(!is_sequential(&[1, 2, 3, 0, 5])); |
511 | | assert!(!is_sequential(&[1, 2, 3, 4, 0])); |
512 | | |
513 | | assert!(!is_sequential(&[0, 2, 3, 4, 5, 6])); |
514 | | assert!(!is_sequential(&[1, 0, 3, 4, 5, 6])); |
515 | | assert!(!is_sequential(&[1, 2, 0, 4, 5, 6])); |
516 | | assert!(!is_sequential(&[1, 2, 3, 0, 5, 6])); |
517 | | assert!(!is_sequential(&[1, 2, 3, 4, 0, 6])); |
518 | | assert!(!is_sequential(&[1, 2, 3, 4, 5, 0])); |
519 | | |
520 | | assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7])); |
521 | | assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7])); |
522 | | assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7])); |
523 | | assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7])); |
524 | | assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7])); |
525 | | assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7])); |
526 | | assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0])); |
527 | | |
528 | | assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7, 8])); |
529 | | assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7, 8])); |
530 | | assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7, 8])); |
531 | | assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7, 8])); |
532 | | assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7, 8])); |
533 | | assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7, 8])); |
534 | | assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0, 8])); |
535 | | assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 7, 0])); |
536 | | |
537 | | // checks increments at the chunk boundary |
538 | | assert!(!is_sequential(&[1, 2, 3, 5])); |
539 | | assert!(!is_sequential(&[1, 2, 3, 5, 6])); |
540 | | assert!(!is_sequential(&[1, 2, 3, 5, 6, 7])); |
541 | | assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8])); |
542 | | assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 8, 9])); |
543 | | } |
544 | | |
545 | | fn str1() -> UnionFields { |
546 | | UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, true)]) |
547 | | } |
548 | | |
549 | | fn str1_int3() -> UnionFields { |
550 | | UnionFields::new( |
551 | | vec![1, 3], |
552 | | vec![ |
553 | | Field::new("str", DataType::Utf8, true), |
554 | | Field::new("int", DataType::Int32, true), |
555 | | ], |
556 | | ) |
557 | | } |
558 | | |
559 | | #[test] |
560 | | fn sparse_1_1_single_field() { |
561 | | let union = UnionArray::try_new( |
562 | | //single field |
563 | | str1(), |
564 | | ScalarBuffer::from(vec![1, 1]), // non empty, every type id must match |
565 | | None, //sparse |
566 | | vec![ |
567 | | Arc::new(StringArray::from(vec!["a", "b"])), // not null |
568 | | ], |
569 | | ) |
570 | | .unwrap(); |
571 | | |
572 | | let expected = StringArray::from(vec!["a", "b"]); |
573 | | let extracted = union_extract(&union, "str").unwrap(); |
574 | | |
575 | | assert_eq!(extracted.into_data(), expected.into_data()); |
576 | | } |
577 | | |
578 | | #[test] |
579 | | fn sparse_1_2_empty() { |
580 | | let union = UnionArray::try_new( |
581 | | // multiple fields |
582 | | str1_int3(), |
583 | | ScalarBuffer::from(vec![]), //empty union |
584 | | None, // sparse |
585 | | vec![ |
586 | | Arc::new(StringArray::new_null(0)), |
587 | | Arc::new(Int32Array::new_null(0)), |
588 | | ], |
589 | | ) |
590 | | .unwrap(); |
591 | | |
592 | | let expected = StringArray::new_null(0); |
593 | | let extracted = union_extract(&union, "str").unwrap(); //target type is not Null |
594 | | |
595 | | assert_eq!(extracted.into_data(), expected.into_data()); |
596 | | } |
597 | | |
598 | | #[test] |
599 | | fn sparse_1_3a_null_target() { |
600 | | let union = UnionArray::try_new( |
601 | | // multiple fields |
602 | | UnionFields::new( |
603 | | vec![1, 3], |
604 | | vec![ |
605 | | Field::new("str", DataType::Utf8, true), |
606 | | Field::new("null", DataType::Null, true), // target type is Null |
607 | | ], |
608 | | ), |
609 | | ScalarBuffer::from(vec![1]), //not empty |
610 | | None, // sparse |
611 | | vec![ |
612 | | Arc::new(StringArray::new_null(1)), |
613 | | Arc::new(NullArray::new(1)), // null data type |
614 | | ], |
615 | | ) |
616 | | .unwrap(); |
617 | | |
618 | | let expected = NullArray::new(1); |
619 | | let extracted = union_extract(&union, "null").unwrap(); |
620 | | |
621 | | assert_eq!(extracted.into_data(), expected.into_data()); |
622 | | } |
623 | | |
624 | | #[test] |
625 | | fn sparse_1_3b_null_target() { |
626 | | let union = UnionArray::try_new( |
627 | | // multiple fields |
628 | | str1_int3(), |
629 | | ScalarBuffer::from(vec![1]), //not empty |
630 | | None, // sparse |
631 | | vec![ |
632 | | Arc::new(StringArray::new_null(1)), //all null |
633 | | Arc::new(Int32Array::new_null(1)), |
634 | | ], |
635 | | ) |
636 | | .unwrap(); |
637 | | |
638 | | let expected = StringArray::new_null(1); |
639 | | let extracted = union_extract(&union, "str").unwrap(); //target type is not Null |
640 | | |
641 | | assert_eq!(extracted.into_data(), expected.into_data()); |
642 | | } |
643 | | |
644 | | #[test] |
645 | | fn sparse_2_all_types_match() { |
646 | | let union = UnionArray::try_new( |
647 | | //multiple fields |
648 | | str1_int3(), |
649 | | ScalarBuffer::from(vec![3, 3]), // all types match |
650 | | None, //sparse |
651 | | vec![ |
652 | | Arc::new(StringArray::new_null(2)), |
653 | | Arc::new(Int32Array::from(vec![1, 4])), // not null |
654 | | ], |
655 | | ) |
656 | | .unwrap(); |
657 | | |
658 | | let expected = Int32Array::from(vec![1, 4]); |
659 | | let extracted = union_extract(&union, "int").unwrap(); |
660 | | |
661 | | assert_eq!(extracted.into_data(), expected.into_data()); |
662 | | } |
663 | | |
664 | | #[test] |
665 | | fn sparse_3_1_none_match_target_can_contain_null_mask() { |
666 | | let union = UnionArray::try_new( |
667 | | //multiple fields |
668 | | str1_int3(), |
669 | | ScalarBuffer::from(vec![1, 1, 1, 1]), // none match |
670 | | None, // sparse |
671 | | vec![ |
672 | | Arc::new(StringArray::new_null(4)), |
673 | | Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), // target is not null |
674 | | ], |
675 | | ) |
676 | | .unwrap(); |
677 | | |
678 | | let expected = Int32Array::new_null(4); |
679 | | let extracted = union_extract(&union, "int").unwrap(); |
680 | | |
681 | | assert_eq!(extracted.into_data(), expected.into_data()); |
682 | | } |
683 | | |
684 | | fn str1_union3(union3_datatype: DataType) -> UnionFields { |
685 | | UnionFields::new( |
686 | | vec![1, 3], |
687 | | vec![ |
688 | | Field::new("str", DataType::Utf8, true), |
689 | | Field::new("union", union3_datatype, true), |
690 | | ], |
691 | | ) |
692 | | } |
693 | | |
694 | | #[test] |
695 | | fn sparse_3_2_none_match_cant_contain_null_mask_union_target() { |
696 | | let target_fields = str1(); |
697 | | let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse); |
698 | | |
699 | | let union = UnionArray::try_new( |
700 | | //multiple fields |
701 | | str1_union3(target_type.clone()), |
702 | | ScalarBuffer::from(vec![1, 1]), // none match |
703 | | None, //sparse |
704 | | vec![ |
705 | | Arc::new(StringArray::new_null(2)), |
706 | | //target is not null |
707 | | Arc::new( |
708 | | UnionArray::try_new( |
709 | | target_fields.clone(), |
710 | | ScalarBuffer::from(vec![1, 1]), |
711 | | None, |
712 | | vec![Arc::new(StringArray::from(vec!["a", "b"]))], |
713 | | ) |
714 | | .unwrap(), |
715 | | ), |
716 | | ], |
717 | | ) |
718 | | .unwrap(); |
719 | | |
720 | | let expected = new_null_array(&target_type, 2); |
721 | | let extracted = union_extract(&union, "union").unwrap(); |
722 | | |
723 | | assert_eq!(extracted.into_data(), expected.into_data()); |
724 | | } |
725 | | |
726 | | #[test] |
727 | | fn sparse_4_1_1_target_with_nulls() { |
728 | | let union = UnionArray::try_new( |
729 | | //multiple fields |
730 | | str1_int3(), |
731 | | ScalarBuffer::from(vec![3, 3, 1, 1]), // multiple selected types |
732 | | None, // sparse |
733 | | vec![ |
734 | | Arc::new(StringArray::new_null(4)), |
735 | | Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), // target with nulls |
736 | | ], |
737 | | ) |
738 | | .unwrap(); |
739 | | |
740 | | let expected = Int32Array::from(vec![None, Some(4), None, None]); |
741 | | let extracted = union_extract(&union, "int").unwrap(); |
742 | | |
743 | | assert_eq!(extracted.into_data(), expected.into_data()); |
744 | | } |
745 | | |
746 | | #[test] |
747 | | fn sparse_4_1_2_target_without_nulls() { |
748 | | let union = UnionArray::try_new( |
749 | | //multiple fields |
750 | | str1_int3(), |
751 | | ScalarBuffer::from(vec![1, 3, 3]), // multiple selected types |
752 | | None, // sparse |
753 | | vec![ |
754 | | Arc::new(StringArray::new_null(3)), |
755 | | Arc::new(Int32Array::from(vec![2, 4, 8])), // target without nulls |
756 | | ], |
757 | | ) |
758 | | .unwrap(); |
759 | | |
760 | | let expected = Int32Array::from(vec![None, Some(4), Some(8)]); |
761 | | let extracted = union_extract(&union, "int").unwrap(); |
762 | | |
763 | | assert_eq!(extracted.into_data(), expected.into_data()); |
764 | | } |
765 | | |
766 | | #[test] |
767 | | fn sparse_4_2_some_match_target_cant_contain_null_mask() { |
768 | | let target_fields = str1(); |
769 | | let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse); |
770 | | |
771 | | let union = UnionArray::try_new( |
772 | | //multiple fields |
773 | | str1_union3(target_type), |
774 | | ScalarBuffer::from(vec![3, 1]), // some types match, but not all |
775 | | None, //sparse |
776 | | vec![ |
777 | | Arc::new(StringArray::new_null(2)), |
778 | | Arc::new( |
779 | | UnionArray::try_new( |
780 | | target_fields.clone(), |
781 | | ScalarBuffer::from(vec![1, 1]), |
782 | | None, |
783 | | vec![Arc::new(StringArray::from(vec!["a", "b"]))], |
784 | | ) |
785 | | .unwrap(), |
786 | | ), |
787 | | ], |
788 | | ) |
789 | | .unwrap(); |
790 | | |
791 | | let expected = UnionArray::try_new( |
792 | | target_fields, |
793 | | ScalarBuffer::from(vec![1, 1]), |
794 | | None, |
795 | | vec![Arc::new(StringArray::from(vec![Some("a"), None]))], |
796 | | ) |
797 | | .unwrap(); |
798 | | let extracted = union_extract(&union, "union").unwrap(); |
799 | | |
800 | | assert_eq!(extracted.into_data(), expected.into_data()); |
801 | | } |
802 | | |
803 | | #[test] |
804 | | fn dense_1_1_both_empty() { |
805 | | let union = UnionArray::try_new( |
806 | | str1_int3(), |
807 | | ScalarBuffer::from(vec![]), //empty union |
808 | | Some(ScalarBuffer::from(vec![])), // dense |
809 | | vec![ |
810 | | Arc::new(StringArray::new_null(0)), //empty target |
811 | | Arc::new(Int32Array::new_null(0)), |
812 | | ], |
813 | | ) |
814 | | .unwrap(); |
815 | | |
816 | | let expected = StringArray::new_null(0); |
817 | | let extracted = union_extract(&union, "str").unwrap(); |
818 | | |
819 | | assert_eq!(extracted.into_data(), expected.into_data()); |
820 | | } |
821 | | |
822 | | #[test] |
823 | | fn dense_1_2_empty_union_target_non_empty() { |
824 | | let union = UnionArray::try_new( |
825 | | str1_int3(), |
826 | | ScalarBuffer::from(vec![]), //empty union |
827 | | Some(ScalarBuffer::from(vec![])), // dense |
828 | | vec![ |
829 | | Arc::new(StringArray::new_null(1)), //non empty target |
830 | | Arc::new(Int32Array::new_null(0)), |
831 | | ], |
832 | | ) |
833 | | .unwrap(); |
834 | | |
835 | | let expected = StringArray::new_null(0); |
836 | | let extracted = union_extract(&union, "str").unwrap(); |
837 | | |
838 | | assert_eq!(extracted.into_data(), expected.into_data()); |
839 | | } |
840 | | |
841 | | #[test] |
842 | | fn dense_2_non_empty_union_target_empty() { |
843 | | let union = UnionArray::try_new( |
844 | | str1_int3(), |
845 | | ScalarBuffer::from(vec![3, 3]), //non empty union |
846 | | Some(ScalarBuffer::from(vec![0, 1])), // dense |
847 | | vec![ |
848 | | Arc::new(StringArray::new_null(0)), //empty target |
849 | | Arc::new(Int32Array::new_null(2)), |
850 | | ], |
851 | | ) |
852 | | .unwrap(); |
853 | | |
854 | | let expected = StringArray::new_null(2); |
855 | | let extracted = union_extract(&union, "str").unwrap(); |
856 | | |
857 | | assert_eq!(extracted.into_data(), expected.into_data()); |
858 | | } |
859 | | |
860 | | #[test] |
861 | | fn dense_3_1_null_target_smaller_len() { |
862 | | let union = UnionArray::try_new( |
863 | | str1_int3(), |
864 | | ScalarBuffer::from(vec![3, 3]), //non empty union |
865 | | Some(ScalarBuffer::from(vec![0, 0])), //dense |
866 | | vec![ |
867 | | Arc::new(StringArray::new_null(1)), //smaller target |
868 | | Arc::new(Int32Array::new_null(2)), |
869 | | ], |
870 | | ) |
871 | | .unwrap(); |
872 | | |
873 | | let expected = StringArray::new_null(2); |
874 | | let extracted = union_extract(&union, "str").unwrap(); |
875 | | |
876 | | assert_eq!(extracted.into_data(), expected.into_data()); |
877 | | } |
878 | | |
879 | | #[test] |
880 | | fn dense_3_2_null_target_equal_len() { |
881 | | let union = UnionArray::try_new( |
882 | | str1_int3(), |
883 | | ScalarBuffer::from(vec![3, 3]), //non empty union |
884 | | Some(ScalarBuffer::from(vec![0, 0])), //dense |
885 | | vec![ |
886 | | Arc::new(StringArray::new_null(2)), //equal len |
887 | | Arc::new(Int32Array::new_null(2)), |
888 | | ], |
889 | | ) |
890 | | .unwrap(); |
891 | | |
892 | | let expected = StringArray::new_null(2); |
893 | | let extracted = union_extract(&union, "str").unwrap(); |
894 | | |
895 | | assert_eq!(extracted.into_data(), expected.into_data()); |
896 | | } |
897 | | |
898 | | #[test] |
899 | | fn dense_3_3_null_target_bigger_len() { |
900 | | let union = UnionArray::try_new( |
901 | | str1_int3(), |
902 | | ScalarBuffer::from(vec![3, 3]), //non empty union |
903 | | Some(ScalarBuffer::from(vec![0, 0])), //dense |
904 | | vec![ |
905 | | Arc::new(StringArray::new_null(3)), //bigger len |
906 | | Arc::new(Int32Array::new_null(3)), |
907 | | ], |
908 | | ) |
909 | | .unwrap(); |
910 | | |
911 | | let expected = StringArray::new_null(2); |
912 | | let extracted = union_extract(&union, "str").unwrap(); |
913 | | |
914 | | assert_eq!(extracted.into_data(), expected.into_data()); |
915 | | } |
916 | | |
917 | | #[test] |
918 | | fn dense_4_1a_single_type_sequential_offsets_equal_len() { |
919 | | let union = UnionArray::try_new( |
920 | | // single field |
921 | | str1(), |
922 | | ScalarBuffer::from(vec![1, 1]), //non empty union |
923 | | Some(ScalarBuffer::from(vec![0, 1])), //sequential |
924 | | vec![ |
925 | | Arc::new(StringArray::from(vec!["a1", "b2"])), //equal len, non null |
926 | | ], |
927 | | ) |
928 | | .unwrap(); |
929 | | |
930 | | let expected = StringArray::from(vec!["a1", "b2"]); |
931 | | let extracted = union_extract(&union, "str").unwrap(); |
932 | | |
933 | | assert_eq!(extracted.into_data(), expected.into_data()); |
934 | | } |
935 | | |
936 | | #[test] |
937 | | fn dense_4_2a_single_type_sequential_offsets_bigger() { |
938 | | let union = UnionArray::try_new( |
939 | | // single field |
940 | | str1(), |
941 | | ScalarBuffer::from(vec![1, 1]), //non empty union |
942 | | Some(ScalarBuffer::from(vec![0, 1])), //sequential |
943 | | vec![ |
944 | | Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), //equal len, non null |
945 | | ], |
946 | | ) |
947 | | .unwrap(); |
948 | | |
949 | | let expected = StringArray::from(vec!["a1", "b2"]); |
950 | | let extracted = union_extract(&union, "str").unwrap(); |
951 | | |
952 | | assert_eq!(extracted.into_data(), expected.into_data()); |
953 | | } |
954 | | |
955 | | #[test] |
956 | | fn dense_4_3a_single_type_non_sequential() { |
957 | | let union = UnionArray::try_new( |
958 | | // single field |
959 | | str1(), |
960 | | ScalarBuffer::from(vec![1, 1]), //non empty union |
961 | | Some(ScalarBuffer::from(vec![0, 2])), //non sequential |
962 | | vec![ |
963 | | Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), //equal len, non null |
964 | | ], |
965 | | ) |
966 | | .unwrap(); |
967 | | |
968 | | let expected = StringArray::from(vec!["a1", "c3"]); |
969 | | let extracted = union_extract(&union, "str").unwrap(); |
970 | | |
971 | | assert_eq!(extracted.into_data(), expected.into_data()); |
972 | | } |
973 | | |
974 | | #[test] |
975 | | fn dense_4_1b_empty_siblings_sequential_equal_len() { |
976 | | let union = UnionArray::try_new( |
977 | | // multiple fields |
978 | | str1_int3(), |
979 | | ScalarBuffer::from(vec![1, 1]), //non empty union |
980 | | Some(ScalarBuffer::from(vec![0, 1])), //sequential |
981 | | vec![ |
982 | | Arc::new(StringArray::from(vec!["a", "b"])), //equal len, non null |
983 | | Arc::new(Int32Array::new_null(0)), //empty sibling |
984 | | ], |
985 | | ) |
986 | | .unwrap(); |
987 | | |
988 | | let expected = StringArray::from(vec!["a", "b"]); |
989 | | let extracted = union_extract(&union, "str").unwrap(); |
990 | | |
991 | | assert_eq!(extracted.into_data(), expected.into_data()); |
992 | | } |
993 | | |
994 | | #[test] |
995 | | fn dense_4_2b_empty_siblings_sequential_bigger_len() { |
996 | | let union = UnionArray::try_new( |
997 | | // multiple fields |
998 | | str1_int3(), |
999 | | ScalarBuffer::from(vec![1, 1]), //non empty union |
1000 | | Some(ScalarBuffer::from(vec![0, 1])), //sequential |
1001 | | vec![ |
1002 | | Arc::new(StringArray::from(vec!["a", "b", "c"])), //bigger len, non null |
1003 | | Arc::new(Int32Array::new_null(0)), //empty sibling |
1004 | | ], |
1005 | | ) |
1006 | | .unwrap(); |
1007 | | |
1008 | | let expected = StringArray::from(vec!["a", "b"]); |
1009 | | let extracted = union_extract(&union, "str").unwrap(); |
1010 | | |
1011 | | assert_eq!(extracted.into_data(), expected.into_data()); |
1012 | | } |
1013 | | |
1014 | | #[test] |
1015 | | fn dense_4_3b_empty_sibling_non_sequential() { |
1016 | | let union = UnionArray::try_new( |
1017 | | // multiple fields |
1018 | | str1_int3(), |
1019 | | ScalarBuffer::from(vec![1, 1]), //non empty union |
1020 | | Some(ScalarBuffer::from(vec![0, 2])), //non sequential |
1021 | | vec![ |
1022 | | Arc::new(StringArray::from(vec!["a", "b", "c"])), //non null |
1023 | | Arc::new(Int32Array::new_null(0)), //empty sibling |
1024 | | ], |
1025 | | ) |
1026 | | .unwrap(); |
1027 | | |
1028 | | let expected = StringArray::from(vec!["a", "c"]); |
1029 | | let extracted = union_extract(&union, "str").unwrap(); |
1030 | | |
1031 | | assert_eq!(extracted.into_data(), expected.into_data()); |
1032 | | } |
1033 | | |
1034 | | #[test] |
1035 | | fn dense_4_1c_all_types_match_sequential_equal_len() { |
1036 | | let union = UnionArray::try_new( |
1037 | | // multiple fields |
1038 | | str1_int3(), |
1039 | | ScalarBuffer::from(vec![1, 1]), //all types match |
1040 | | Some(ScalarBuffer::from(vec![0, 1])), //sequential |
1041 | | vec![ |
1042 | | Arc::new(StringArray::from(vec!["a1", "b2"])), //equal len |
1043 | | Arc::new(Int32Array::new_null(2)), //non empty sibling |
1044 | | ], |
1045 | | ) |
1046 | | .unwrap(); |
1047 | | |
1048 | | let expected = StringArray::from(vec!["a1", "b2"]); |
1049 | | let extracted = union_extract(&union, "str").unwrap(); |
1050 | | |
1051 | | assert_eq!(extracted.into_data(), expected.into_data()); |
1052 | | } |
1053 | | |
1054 | | #[test] |
1055 | | fn dense_4_2c_all_types_match_sequential_bigger_len() { |
1056 | | let union = UnionArray::try_new( |
1057 | | // multiple fields |
1058 | | str1_int3(), |
1059 | | ScalarBuffer::from(vec![1, 1]), //all types match |
1060 | | Some(ScalarBuffer::from(vec![0, 1])), //sequential |
1061 | | vec![ |
1062 | | Arc::new(StringArray::from(vec!["a1", "b2", "b3"])), //bigger len |
1063 | | Arc::new(Int32Array::new_null(2)), //non empty sibling |
1064 | | ], |
1065 | | ) |
1066 | | .unwrap(); |
1067 | | |
1068 | | let expected = StringArray::from(vec!["a1", "b2"]); |
1069 | | let extracted = union_extract(&union, "str").unwrap(); |
1070 | | |
1071 | | assert_eq!(extracted.into_data(), expected.into_data()); |
1072 | | } |
1073 | | |
1074 | | #[test] |
1075 | | fn dense_4_3c_all_types_match_non_sequential() { |
1076 | | let union = UnionArray::try_new( |
1077 | | // multiple fields |
1078 | | str1_int3(), |
1079 | | ScalarBuffer::from(vec![1, 1]), //all types match |
1080 | | Some(ScalarBuffer::from(vec![0, 2])), //non sequential |
1081 | | vec![ |
1082 | | Arc::new(StringArray::from(vec!["a1", "b2", "b3"])), |
1083 | | Arc::new(Int32Array::new_null(2)), //non empty sibling |
1084 | | ], |
1085 | | ) |
1086 | | .unwrap(); |
1087 | | |
1088 | | let expected = StringArray::from(vec!["a1", "b3"]); |
1089 | | let extracted = union_extract(&union, "str").unwrap(); |
1090 | | |
1091 | | assert_eq!(extracted.into_data(), expected.into_data()); |
1092 | | } |
1093 | | |
1094 | | #[test] |
1095 | | fn dense_5_1a_none_match_less_len() { |
1096 | | let union = UnionArray::try_new( |
1097 | | // multiple fields |
1098 | | str1_int3(), |
1099 | | ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches |
1100 | | Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense |
1101 | | vec![ |
1102 | | Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // less len |
1103 | | Arc::new(Int32Array::from(vec![1, 2])), |
1104 | | ], |
1105 | | ) |
1106 | | .unwrap(); |
1107 | | |
1108 | | let expected = StringArray::new_null(5); |
1109 | | let extracted = union_extract(&union, "str").unwrap(); |
1110 | | |
1111 | | assert_eq!(extracted.into_data(), expected.into_data()); |
1112 | | } |
1113 | | |
1114 | | #[test] |
1115 | | fn dense_5_1b_cant_contain_null_mask() { |
1116 | | let target_fields = str1(); |
1117 | | let target_type = DataType::Union(target_fields.clone(), UnionMode::Sparse); |
1118 | | |
1119 | | let union = UnionArray::try_new( |
1120 | | // multiple fields |
1121 | | str1_union3(target_type.clone()), |
1122 | | ScalarBuffer::from(vec![1, 1, 1, 1, 1]), //none matches |
1123 | | Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense |
1124 | | vec![ |
1125 | | Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // less len |
1126 | | Arc::new( |
1127 | | UnionArray::try_new( |
1128 | | target_fields.clone(), |
1129 | | ScalarBuffer::from(vec![1]), |
1130 | | None, |
1131 | | vec![Arc::new(StringArray::from(vec!["a"]))], |
1132 | | ) |
1133 | | .unwrap(), |
1134 | | ), // non empty |
1135 | | ], |
1136 | | ) |
1137 | | .unwrap(); |
1138 | | |
1139 | | let expected = new_null_array(&target_type, 5); |
1140 | | let extracted = union_extract(&union, "union").unwrap(); |
1141 | | |
1142 | | assert_eq!(extracted.into_data(), expected.into_data()); |
1143 | | } |
1144 | | |
1145 | | #[test] |
1146 | | fn dense_5_2_none_match_equal_len() { |
1147 | | let union = UnionArray::try_new( |
1148 | | // multiple fields |
1149 | | str1_int3(), |
1150 | | ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches |
1151 | | Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense |
1152 | | vec![ |
1153 | | Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5"])), // equal len |
1154 | | Arc::new(Int32Array::from(vec![1, 2])), |
1155 | | ], |
1156 | | ) |
1157 | | .unwrap(); |
1158 | | |
1159 | | let expected = StringArray::new_null(5); |
1160 | | let extracted = union_extract(&union, "str").unwrap(); |
1161 | | |
1162 | | assert_eq!(extracted.into_data(), expected.into_data()); |
1163 | | } |
1164 | | |
1165 | | #[test] |
1166 | | fn dense_5_3_none_match_greater_len() { |
1167 | | let union = UnionArray::try_new( |
1168 | | // multiple fields |
1169 | | str1_int3(), |
1170 | | ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches |
1171 | | Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense |
1172 | | vec![ |
1173 | | Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5", "f6"])), // greater len |
1174 | | Arc::new(Int32Array::from(vec![1, 2])), //non null |
1175 | | ], |
1176 | | ) |
1177 | | .unwrap(); |
1178 | | |
1179 | | let expected = StringArray::new_null(5); |
1180 | | let extracted = union_extract(&union, "str").unwrap(); |
1181 | | |
1182 | | assert_eq!(extracted.into_data(), expected.into_data()); |
1183 | | } |
1184 | | |
1185 | | #[test] |
1186 | | fn dense_6_some_matches() { |
1187 | | let union = UnionArray::try_new( |
1188 | | // multiple fields |
1189 | | str1_int3(), |
1190 | | ScalarBuffer::from(vec![3, 3, 1, 1, 1]), //some matches |
1191 | | Some(ScalarBuffer::from(vec![0, 1, 0, 1, 2])), // dense |
1192 | | vec![ |
1193 | | Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // non null |
1194 | | Arc::new(Int32Array::from(vec![1, 2])), |
1195 | | ], |
1196 | | ) |
1197 | | .unwrap(); |
1198 | | |
1199 | | let expected = Int32Array::from(vec![Some(1), Some(2), None, None, None]); |
1200 | | let extracted = union_extract(&union, "int").unwrap(); |
1201 | | |
1202 | | assert_eq!(extracted.into_data(), expected.into_data()); |
1203 | | } |
1204 | | |
1205 | | #[test] |
1206 | | fn empty_sparse_union() { |
1207 | | let union = UnionArray::try_new( |
1208 | | UnionFields::empty(), |
1209 | | ScalarBuffer::from(vec![]), |
1210 | | None, |
1211 | | vec![], |
1212 | | ) |
1213 | | .unwrap(); |
1214 | | |
1215 | | assert_eq!( |
1216 | | union_extract(&union, "a").unwrap_err().to_string(), |
1217 | | ArrowError::InvalidArgumentError("field a not found on union".into()).to_string() |
1218 | | ); |
1219 | | } |
1220 | | |
1221 | | #[test] |
1222 | | fn empty_dense_union() { |
1223 | | let union = UnionArray::try_new( |
1224 | | UnionFields::empty(), |
1225 | | ScalarBuffer::from(vec![]), |
1226 | | Some(ScalarBuffer::from(vec![])), |
1227 | | vec![], |
1228 | | ) |
1229 | | .unwrap(); |
1230 | | |
1231 | | assert_eq!( |
1232 | | union_extract(&union, "a").unwrap_err().to_string(), |
1233 | | ArrowError::InvalidArgumentError("field a not found on union".into()).to_string() |
1234 | | ); |
1235 | | } |
1236 | | } |