Coverage Report

Created: 2025-08-26 07:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-array/src/array/union_array.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
#![allow(clippy::enum_clike_unportable_variant)]
18
19
use crate::{make_array, Array, ArrayRef};
20
use arrow_buffer::bit_chunk_iterator::{BitChunkIterator, BitChunks};
21
use arrow_buffer::buffer::NullBuffer;
22
use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer};
23
use arrow_data::{ArrayData, ArrayDataBuilder};
24
use arrow_schema::{ArrowError, DataType, UnionFields, UnionMode};
25
/// Contains the `UnionArray` type.
26
///
27
use std::any::Any;
28
use std::collections::HashSet;
29
use std::sync::Arc;
30
31
/// An array of [values of varying types](https://arrow.apache.org/docs/format/Columnar.html#union-layout)
32
///
33
/// Each slot in a [UnionArray] can have a value chosen from a number
34
/// of types.  Each of the possible types are named like the fields of
35
/// a [`StructArray`](crate::StructArray).  A `UnionArray` can
36
/// have two possible memory layouts, "dense" or "sparse".  For more
37
/// information on please see the
38
/// [specification](https://arrow.apache.org/docs/format/Columnar.html#union-layout).
39
///
40
/// [UnionBuilder](crate::builder::UnionBuilder) can be used to
41
/// create [UnionArray]'s of primitive types. `UnionArray`'s of nested
42
/// types are also supported but not via `UnionBuilder`, see the tests
43
/// for examples.
44
///
45
/// # Examples
46
/// ## Create a dense UnionArray `[1, 3.2, 34]`
47
/// ```
48
/// use arrow_buffer::ScalarBuffer;
49
/// use arrow_schema::*;
50
/// use std::sync::Arc;
51
/// use arrow_array::{Array, Int32Array, Float64Array, UnionArray};
52
///
53
/// let int_array = Int32Array::from(vec![1, 34]);
54
/// let float_array = Float64Array::from(vec![3.2]);
55
/// let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
56
/// let offsets = [0, 0, 1].into_iter().collect::<ScalarBuffer<i32>>();
57
///
58
/// let union_fields = [
59
///     (0, Arc::new(Field::new("A", DataType::Int32, false))),
60
///     (1, Arc::new(Field::new("B", DataType::Float64, false))),
61
/// ].into_iter().collect::<UnionFields>();
62
///
63
/// let children = vec![
64
///     Arc::new(int_array) as Arc<dyn Array>,
65
///     Arc::new(float_array),
66
/// ];
67
///
68
/// let array = UnionArray::try_new(
69
///     union_fields,
70
///     type_ids,
71
///     Some(offsets),
72
///     children,
73
/// ).unwrap();
74
///
75
/// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
76
/// assert_eq!(1, value);
77
///
78
/// let value = array.value(1).as_any().downcast_ref::<Float64Array>().unwrap().value(0);
79
/// assert!(3.2 - value < f64::EPSILON);
80
///
81
/// let value = array.value(2).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
82
/// assert_eq!(34, value);
83
/// ```
84
///
85
/// ## Create a sparse UnionArray `[1, 3.2, 34]`
86
/// ```
87
/// use arrow_buffer::ScalarBuffer;
88
/// use arrow_schema::*;
89
/// use std::sync::Arc;
90
/// use arrow_array::{Array, Int32Array, Float64Array, UnionArray};
91
///
92
/// let int_array = Int32Array::from(vec![Some(1), None, Some(34)]);
93
/// let float_array = Float64Array::from(vec![None, Some(3.2), None]);
94
/// let type_ids = [0_i8, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
95
///
96
/// let union_fields = [
97
///     (0, Arc::new(Field::new("A", DataType::Int32, false))),
98
///     (1, Arc::new(Field::new("B", DataType::Float64, false))),
99
/// ].into_iter().collect::<UnionFields>();
100
///
101
/// let children = vec![
102
///     Arc::new(int_array) as Arc<dyn Array>,
103
///     Arc::new(float_array),
104
/// ];
105
///
106
/// let array = UnionArray::try_new(
107
///     union_fields,
108
///     type_ids,
109
///     None,
110
///     children,
111
/// ).unwrap();
112
///
113
/// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
114
/// assert_eq!(1, value);
115
///
116
/// let value = array.value(1).as_any().downcast_ref::<Float64Array>().unwrap().value(0);
117
/// assert!(3.2 - value < f64::EPSILON);
118
///
119
/// let value = array.value(2).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
120
/// assert_eq!(34, value);
121
/// ```
122
#[derive(Clone)]
123
pub struct UnionArray {
124
    data_type: DataType,
125
    type_ids: ScalarBuffer<i8>,
126
    offsets: Option<ScalarBuffer<i32>>,
127
    fields: Vec<Option<ArrayRef>>,
128
}
129
130
impl UnionArray {
131
    /// Creates a new `UnionArray`.
132
    ///
133
    /// Accepts type ids, child arrays and optionally offsets (for dense unions) to create
134
    /// a new `UnionArray`.  This method makes no attempt to validate the data provided by the
135
    /// caller and assumes that each of the components are correct and consistent with each other.
136
    /// See `try_new` for an alternative that validates the data provided.
137
    ///
138
    /// # Safety
139
    ///
140
    /// The `type_ids` values should be positive and must match one of the type ids of the fields provided in `fields`.
141
    /// These values are used to index into the `children` arrays.
142
    ///
143
    /// The `offsets` is provided in the case of a dense union, sparse unions should use `None`.
144
    /// If provided the `offsets` values should be positive and must be less than the length of the
145
    /// corresponding array.
146
    ///
147
    /// In both cases above we use signed integer types to maintain compatibility with other
148
    /// Arrow implementations.
149
0
    pub unsafe fn new_unchecked(
150
0
        fields: UnionFields,
151
0
        type_ids: ScalarBuffer<i8>,
152
0
        offsets: Option<ScalarBuffer<i32>>,
153
0
        children: Vec<ArrayRef>,
154
0
    ) -> Self {
155
0
        let mode = if offsets.is_some() {
156
0
            UnionMode::Dense
157
        } else {
158
0
            UnionMode::Sparse
159
        };
160
161
0
        let len = type_ids.len();
162
0
        let builder = ArrayData::builder(DataType::Union(fields, mode))
163
0
            .add_buffer(type_ids.into_inner())
164
0
            .child_data(children.into_iter().map(Array::into_data).collect())
165
0
            .len(len);
166
167
0
        let data = match offsets {
168
0
            Some(offsets) => builder.add_buffer(offsets.into_inner()).build_unchecked(),
169
0
            None => builder.build_unchecked(),
170
        };
171
0
        Self::from(data)
172
0
    }
173
174
    /// Attempts to create a new `UnionArray`, validating the inputs provided.
175
    ///
176
    /// The order of child arrays child array order must match the fields order
177
0
    pub fn try_new(
178
0
        fields: UnionFields,
179
0
        type_ids: ScalarBuffer<i8>,
180
0
        offsets: Option<ScalarBuffer<i32>>,
181
0
        children: Vec<ArrayRef>,
182
0
    ) -> Result<Self, ArrowError> {
183
        // There must be a child array for every field.
184
0
        if fields.len() != children.len() {
185
0
            return Err(ArrowError::InvalidArgumentError(
186
0
                "Union fields length must match child arrays length".to_string(),
187
0
            ));
188
0
        }
189
190
0
        if let Some(offsets) = &offsets {
191
            // There must be an offset value for every type id value.
192
0
            if offsets.len() != type_ids.len() {
193
0
                return Err(ArrowError::InvalidArgumentError(
194
0
                    "Type Ids and Offsets lengths must match".to_string(),
195
0
                ));
196
0
            }
197
        } else {
198
            // Sparse union child arrays must be equal in length to the length of the union
199
0
            for child in &children {
200
0
                if child.len() != type_ids.len() {
201
0
                    return Err(ArrowError::InvalidArgumentError(
202
0
                        "Sparse union child arrays must be equal in length to the length of the union".to_string(),
203
0
                    ));
204
0
                }
205
            }
206
        }
207
208
        // Create mapping from type id to array lengths.
209
0
        let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
210
0
        let mut array_lens = vec![i32::MIN; max_id + 1];
211
0
        for (cd, (field_id, _)) in children.iter().zip(fields.iter()) {
212
0
            array_lens[field_id as usize] = cd.len() as i32;
213
0
        }
214
215
        // Type id values must match one of the fields.
216
0
        for id in &type_ids {
217
0
            match array_lens.get(*id as usize) {
218
0
                Some(x) if *x != i32::MIN => {}
219
                _ => {
220
0
                    return Err(ArrowError::InvalidArgumentError(
221
0
                        "Type Ids values must match one of the field type ids".to_owned(),
222
0
                    ))
223
                }
224
            }
225
        }
226
227
        // Check the value offsets are in bounds.
228
0
        if let Some(offsets) = &offsets {
229
0
            let mut iter = type_ids.iter().zip(offsets.iter());
230
0
            if iter.any(|(type_id, &offset)| offset < 0 || offset >= array_lens[*type_id as usize])
231
            {
232
0
                return Err(ArrowError::InvalidArgumentError(
233
0
                    "Offsets must be positive and within the length of the Array".to_owned(),
234
0
                ));
235
0
            }
236
0
        }
237
238
        // Safety:
239
        // - Arguments validated above.
240
0
        let union_array = unsafe { Self::new_unchecked(fields, type_ids, offsets, children) };
241
0
        Ok(union_array)
242
0
    }
243
244
    /// Accesses the child array for `type_id`.
245
    ///
246
    /// # Panics
247
    ///
248
    /// Panics if the `type_id` provided is not present in the array's DataType
249
    /// in the `Union`.
250
0
    pub fn child(&self, type_id: i8) -> &ArrayRef {
251
0
        assert!((type_id as usize) < self.fields.len());
252
0
        let boxed = &self.fields[type_id as usize];
253
0
        boxed.as_ref().expect("invalid type id")
254
0
    }
255
256
    /// Returns the `type_id` for the array slot at `index`.
257
    ///
258
    /// # Panics
259
    ///
260
    /// Panics if `index` is greater than or equal to the number of child arrays
261
0
    pub fn type_id(&self, index: usize) -> i8 {
262
0
        assert!(index < self.type_ids.len());
263
0
        self.type_ids[index]
264
0
    }
265
266
    /// Returns the `type_ids` buffer for this array
267
0
    pub fn type_ids(&self) -> &ScalarBuffer<i8> {
268
0
        &self.type_ids
269
0
    }
270
271
    /// Returns the `offsets` buffer if this is a dense array
272
0
    pub fn offsets(&self) -> Option<&ScalarBuffer<i32>> {
273
0
        self.offsets.as_ref()
274
0
    }
275
276
    /// Returns the offset into the underlying values array for the array slot at `index`.
277
    ///
278
    /// # Panics
279
    ///
280
    /// Panics if `index` is greater than or equal the length of the array.
281
0
    pub fn value_offset(&self, index: usize) -> usize {
282
0
        assert!(index < self.len());
283
0
        match &self.offsets {
284
0
            Some(offsets) => offsets[index] as usize,
285
0
            None => self.offset() + index,
286
        }
287
0
    }
288
289
    /// Returns the array's value at index `i`.
290
    ///
291
    /// Note: This method does not check for nulls and the value is arbitrary
292
    /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index.
293
    ///
294
    /// # Panics
295
    /// Panics if index `i` is out of bounds
296
0
    pub fn value(&self, i: usize) -> ArrayRef {
297
0
        let type_id = self.type_id(i);
298
0
        let value_offset = self.value_offset(i);
299
0
        let child = self.child(type_id);
300
0
        child.slice(value_offset, 1)
301
0
    }
302
303
    /// Returns the names of the types in the union.
304
0
    pub fn type_names(&self) -> Vec<&str> {
305
0
        match self.data_type() {
306
0
            DataType::Union(fields, _) => fields
307
0
                .iter()
308
0
                .map(|(_, f)| f.name().as_str())
309
0
                .collect::<Vec<&str>>(),
310
0
            _ => unreachable!("Union array's data type is not a union!"),
311
        }
312
0
    }
313
314
    /// Returns whether the `UnionArray` is dense (or sparse if `false`).
315
0
    fn is_dense(&self) -> bool {
316
0
        match self.data_type() {
317
0
            DataType::Union(_, mode) => mode == &UnionMode::Dense,
318
0
            _ => unreachable!("Union array's data type is not a union!"),
319
        }
320
0
    }
321
322
    /// Returns a zero-copy slice of this array with the indicated offset and length.
323
0
    pub fn slice(&self, offset: usize, length: usize) -> Self {
324
0
        let (offsets, fields) = match self.offsets.as_ref() {
325
            // If dense union, slice offsets
326
0
            Some(offsets) => (Some(offsets.slice(offset, length)), self.fields.clone()),
327
            // Otherwise need to slice sparse children
328
            None => {
329
0
                let fields = self
330
0
                    .fields
331
0
                    .iter()
332
0
                    .map(|x| x.as_ref().map(|x| x.slice(offset, length)))
333
0
                    .collect();
334
0
                (None, fields)
335
            }
336
        };
337
338
0
        Self {
339
0
            data_type: self.data_type.clone(),
340
0
            type_ids: self.type_ids.slice(offset, length),
341
0
            offsets,
342
0
            fields,
343
0
        }
344
0
    }
345
346
    /// Deconstruct this array into its constituent parts
347
    ///
348
    /// # Example
349
    ///
350
    /// ```
351
    /// # use arrow_array::array::UnionArray;
352
    /// # use arrow_array::types::Int32Type;
353
    /// # use arrow_array::builder::UnionBuilder;
354
    /// # use arrow_buffer::ScalarBuffer;
355
    /// # fn main() -> Result<(), arrow_schema::ArrowError> {
356
    /// let mut builder = UnionBuilder::new_dense();
357
    /// builder.append::<Int32Type>("a", 1).unwrap();
358
    /// let union_array = builder.build()?;
359
    ///
360
    /// // Deconstruct into parts
361
    /// let (union_fields, type_ids, offsets, children) = union_array.into_parts();
362
    ///
363
    /// // Reconstruct from parts
364
    /// let union_array = UnionArray::try_new(
365
    ///     union_fields,
366
    ///     type_ids,
367
    ///     offsets,
368
    ///     children,
369
    /// );
370
    /// # Ok(())
371
    /// # }
372
    /// ```
373
    #[allow(clippy::type_complexity)]
374
0
    pub fn into_parts(
375
0
        self,
376
0
    ) -> (
377
0
        UnionFields,
378
0
        ScalarBuffer<i8>,
379
0
        Option<ScalarBuffer<i32>>,
380
0
        Vec<ArrayRef>,
381
0
    ) {
382
        let Self {
383
0
            data_type,
384
0
            type_ids,
385
0
            offsets,
386
0
            mut fields,
387
0
        } = self;
388
0
        match data_type {
389
0
            DataType::Union(union_fields, _) => {
390
0
                let children = union_fields
391
0
                    .iter()
392
0
                    .map(|(type_id, _)| fields[type_id as usize].take().unwrap())
393
0
                    .collect();
394
0
                (union_fields, type_ids, offsets, children)
395
            }
396
0
            _ => unreachable!(),
397
        }
398
0
    }
399
400
    /// Computes the logical nulls for a sparse union, optimized for when there's a lot of fields without nulls
401
0
    fn mask_sparse_skip_without_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
402
        // Example logic for a union with 5 fields, a, b & c with nulls, d & e without nulls:
403
        // let [a_nulls, b_nulls, c_nulls] = nulls;
404
        // let [is_a, is_b, is_c] = masks;
405
        // let is_d_or_e = !(is_a | is_b | is_c)
406
        // let union_chunk_nulls = is_d_or_e  | (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls)
407
0
        let fold = |(with_nulls_selected, union_nulls), (is_field, field_nulls)| {
408
0
            (
409
0
                with_nulls_selected | is_field,
410
0
                union_nulls | (is_field & field_nulls),
411
0
            )
412
0
        };
413
414
0
        self.mask_sparse_helper(
415
0
            nulls,
416
0
            |type_ids_chunk_array, nulls_masks_iters| {
417
0
                let (with_nulls_selected, union_nulls) = nulls_masks_iters
418
0
                    .iter_mut()
419
0
                    .map(|(field_type_id, field_nulls)| {
420
0
                        let field_nulls = field_nulls.next().unwrap();
421
0
                        let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
422
423
0
                        (is_field, field_nulls)
424
0
                    })
425
0
                    .fold((0, 0), fold);
426
427
                // In the example above, this is the is_d_or_e = !(is_a | is_b) part
428
0
                let without_nulls_selected = !with_nulls_selected;
429
430
                // if a field without nulls is selected, the value is always true(set bit)
431
                // otherwise, the true/set bits have been computed above
432
0
                without_nulls_selected | union_nulls
433
0
            },
434
0
            |type_ids_remainder, bit_chunks| {
435
0
                let (with_nulls_selected, union_nulls) = bit_chunks
436
0
                    .iter()
437
0
                    .map(|(field_type_id, field_bit_chunks)| {
438
0
                        let field_nulls = field_bit_chunks.remainder_bits();
439
0
                        let is_field = selection_mask(type_ids_remainder, *field_type_id);
440
441
0
                        (is_field, field_nulls)
442
0
                    })
443
0
                    .fold((0, 0), fold);
444
445
0
                let without_nulls_selected = !with_nulls_selected;
446
447
0
                without_nulls_selected | union_nulls
448
0
            },
449
        )
450
0
    }
451
452
    /// Computes the logical nulls for a sparse union, optimized for when there's a lot of fields fully null
453
0
    fn mask_sparse_skip_fully_null(&self, mut nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
454
0
        let fields = match self.data_type() {
455
0
            DataType::Union(fields, _) => fields,
456
0
            _ => unreachable!("Union array's data type is not a union!"),
457
        };
458
459
0
        let type_ids = fields.iter().map(|(id, _)| id).collect::<HashSet<_>>();
460
0
        let with_nulls = nulls.iter().map(|(id, _)| *id).collect::<HashSet<_>>();
461
462
0
        let without_nulls_ids = type_ids
463
0
            .difference(&with_nulls)
464
0
            .copied()
465
0
            .collect::<Vec<_>>();
466
467
0
        nulls.retain(|(_, nulls)| nulls.null_count() < nulls.len());
468
469
        // Example logic for a union with 6 fields, a, b & c with nulls, d & e without nulls, and f fully_null:
470
        // let [a_nulls, b_nulls, c_nulls] = nulls;
471
        // let [is_a, is_b, is_c, is_d, is_e] = masks;
472
        // let union_chunk_nulls = is_d | is_e | (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls)
473
0
        self.mask_sparse_helper(
474
0
            nulls,
475
0
            |type_ids_chunk_array, nulls_masks_iters| {
476
0
                let union_nulls = nulls_masks_iters.iter_mut().fold(
477
                    0,
478
0
                    |union_nulls, (field_type_id, nulls_iter)| {
479
0
                        let field_nulls = nulls_iter.next().unwrap();
480
481
0
                        if field_nulls == 0 {
482
0
                            union_nulls
483
                        } else {
484
0
                            let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
485
486
0
                            union_nulls | (is_field & field_nulls)
487
                        }
488
0
                    },
489
                );
490
491
                // Given the example above, this is the is_d_or_e = (is_d | is_e) part
492
0
                let without_nulls_selected =
493
0
                    without_nulls_selected(type_ids_chunk_array, &without_nulls_ids);
494
495
                // if a field without nulls is selected, the value is always true(set bit)
496
                // otherwise, the true/set bits have been computed above
497
0
                union_nulls | without_nulls_selected
498
0
            },
499
0
            |type_ids_remainder, bit_chunks| {
500
0
                let union_nulls =
501
0
                    bit_chunks
502
0
                        .iter()
503
0
                        .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
504
0
                            let is_field = selection_mask(type_ids_remainder, *field_type_id);
505
0
                            let field_nulls = field_bit_chunks.remainder_bits();
506
507
0
                            union_nulls | is_field & field_nulls
508
0
                        });
509
510
0
                union_nulls | without_nulls_selected(type_ids_remainder, &without_nulls_ids)
511
0
            },
512
        )
513
0
    }
514
515
    /// Computes the logical nulls for a sparse union, optimized for when all fields contains nulls
516
0
    fn mask_sparse_all_with_nulls_skip_one(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
517
        // Example logic for a union with 3 fields, a, b & c, all containing nulls:
518
        // let [a_nulls, b_nulls, c_nulls] = nulls;
519
        // We can skip the first field: it's selection mask is the negation of all others selection mask
520
        // let [is_b, is_c] = selection_masks;
521
        // let is_a = !(is_b | is_c)
522
        // let union_chunk_nulls = (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls)
523
0
        self.mask_sparse_helper(
524
0
            nulls,
525
0
            |type_ids_chunk_array, nulls_masks_iters| {
526
0
                let (is_not_first, union_nulls) = nulls_masks_iters[1..] // skip first
527
0
                    .iter_mut()
528
0
                    .fold(
529
0
                        (0, 0),
530
0
                        |(is_not_first, union_nulls), (field_type_id, nulls_iter)| {
531
0
                            let field_nulls = nulls_iter.next().unwrap();
532
0
                            let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
533
534
0
                            (
535
0
                                is_not_first | is_field,
536
0
                                union_nulls | (is_field & field_nulls),
537
0
                            )
538
0
                        },
539
                    );
540
541
0
                let is_first = !is_not_first;
542
0
                let first_nulls = nulls_masks_iters[0].1.next().unwrap();
543
544
0
                (is_first & first_nulls) | union_nulls
545
0
            },
546
0
            |type_ids_remainder, bit_chunks| {
547
0
                bit_chunks
548
0
                    .iter()
549
0
                    .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
550
0
                        let field_nulls = field_bit_chunks.remainder_bits();
551
                        // The same logic as above, except that since this runs at most once,
552
                        // it doesn't make difference to speed-up the first selection mask
553
0
                        let is_field = selection_mask(type_ids_remainder, *field_type_id);
554
555
0
                        union_nulls | (is_field & field_nulls)
556
0
                    })
557
0
            },
558
        )
559
0
    }
560
561
    /// Maps `nulls` to `BitChunk's` and then to `BitChunkIterator's`, then divides `self.type_ids` into exact chunks of 64 values,
562
    /// calling `mask_chunk` for every exact chunk, and `mask_remainder` for the remainder, if any, collecting the result in a `BooleanBuffer`
563
0
    fn mask_sparse_helper(
564
0
        &self,
565
0
        nulls: Vec<(i8, NullBuffer)>,
566
0
        mut mask_chunk: impl FnMut(&[i8; 64], &mut [(i8, BitChunkIterator)]) -> u64,
567
0
        mask_remainder: impl FnOnce(&[i8], &[(i8, BitChunks)]) -> u64,
568
0
    ) -> BooleanBuffer {
569
0
        let bit_chunks = nulls
570
0
            .iter()
571
0
            .map(|(type_id, nulls)| (*type_id, nulls.inner().bit_chunks()))
572
0
            .collect::<Vec<_>>();
573
574
0
        let mut nulls_masks_iter = bit_chunks
575
0
            .iter()
576
0
            .map(|(type_id, bit_chunks)| (*type_id, bit_chunks.iter()))
577
0
            .collect::<Vec<_>>();
578
579
0
        let chunks_exact = self.type_ids.chunks_exact(64);
580
0
        let remainder = chunks_exact.remainder();
581
582
0
        let chunks = chunks_exact.map(|type_ids_chunk| {
583
0
            let type_ids_chunk_array = <&[i8; 64]>::try_from(type_ids_chunk).unwrap();
584
585
0
            mask_chunk(type_ids_chunk_array, &mut nulls_masks_iter)
586
0
        });
587
588
        // SAFETY:
589
        // chunks is a ChunksExact iterator, which implements TrustedLen, and correctly reports its length
590
0
        let mut buffer = unsafe { MutableBuffer::from_trusted_len_iter(chunks) };
591
592
0
        if !remainder.is_empty() {
593
0
            buffer.push(mask_remainder(remainder, &bit_chunks));
594
0
        }
595
596
0
        BooleanBuffer::new(buffer.into(), 0, self.type_ids.len())
597
0
    }
598
599
    /// Computes the logical nulls for a sparse or dense union, by gathering individual bits from the null buffer of the selected field
600
0
    fn gather_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
601
0
        let one_null = NullBuffer::new_null(1);
602
0
        let one_valid = NullBuffer::new_valid(1);
603
604
        // Unsafe code below depend on it:
605
        // To remove one branch from the loop, if the a type_id is not utilized, or it's logical_nulls is None/all set,
606
        // we use a null buffer of len 1 and a index_mask of 0, or the true null buffer and usize::MAX otherwise.
607
        // We then unconditionally access the null buffer with index & index_mask,
608
        // which always return 0 for the 1-len buffer, or the true index unchanged otherwise
609
        // We also use a 256 array, so llvm knows that `type_id as u8 as usize` is always in bounds
610
0
        let mut logical_nulls_array = [(&one_valid, Mask::Zero); 256];
611
612
0
        for (type_id, nulls) in &nulls {
613
0
            if nulls.null_count() == nulls.len() {
614
0
                // Similarly, if all values are null, use a 1-null null-buffer to reduce cache pressure a bit
615
0
                logical_nulls_array[*type_id as u8 as usize] = (&one_null, Mask::Zero);
616
0
            } else {
617
0
                logical_nulls_array[*type_id as u8 as usize] = (nulls, Mask::Max);
618
0
            }
619
        }
620
621
0
        match &self.offsets {
622
0
            Some(offsets) => {
623
0
                assert_eq!(self.type_ids.len(), offsets.len());
624
625
0
                BooleanBuffer::collect_bool(self.type_ids.len(), |i| unsafe {
626
                    // SAFETY: BooleanBuffer::collect_bool calls us 0..self.type_ids.len()
627
0
                    let type_id = *self.type_ids.get_unchecked(i);
628
                    // SAFETY: We asserted that offsets len and self.type_ids len are equal
629
0
                    let offset = *offsets.get_unchecked(i);
630
631
0
                    let (nulls, offset_mask) = &logical_nulls_array[type_id as u8 as usize];
632
633
                    // SAFETY:
634
                    // If offset_mask is Max
635
                    // 1. Offset validity is checked at union creation
636
                    // 2. If the null buffer len equals it's array len is checked at array creation
637
                    // If offset_mask is Zero, the null buffer len is 1
638
0
                    nulls
639
0
                        .inner()
640
0
                        .value_unchecked(offset as usize & *offset_mask as usize)
641
0
                })
642
            }
643
            None => {
644
0
                BooleanBuffer::collect_bool(self.type_ids.len(), |index| unsafe {
645
                    // SAFETY: BooleanBuffer::collect_bool calls us 0..self.type_ids.len()
646
0
                    let type_id = *self.type_ids.get_unchecked(index);
647
648
0
                    let (nulls, index_mask) = &logical_nulls_array[type_id as u8 as usize];
649
650
                    // SAFETY:
651
                    // If index_mask is Max
652
                    // 1. On sparse union, every child len match it's parent, this is checked at union creation
653
                    // 2. If the null buffer len equals it's array len is checked at array creation
654
                    // If index_mask is Zero, the null buffer len is 1
655
0
                    nulls.inner().value_unchecked(index & *index_mask as usize)
656
0
                })
657
            }
658
        }
659
0
    }
660
661
    /// Returns a vector of tuples containing each field's type_id and its logical null buffer.
662
    /// Only fields with non-zero null counts are included.
663
0
    fn fields_logical_nulls(&self) -> Vec<(i8, NullBuffer)> {
664
0
        self.fields
665
0
            .iter()
666
0
            .enumerate()
667
0
            .filter_map(|(type_id, field)| Some((type_id as i8, field.as_ref()?.logical_nulls()?)))
668
0
            .filter(|(_, nulls)| nulls.null_count() > 0)
669
0
            .collect()
670
0
    }
671
}
672
673
impl From<ArrayData> for UnionArray {
674
0
    fn from(data: ArrayData) -> Self {
675
0
        let (fields, mode) = match data.data_type() {
676
0
            DataType::Union(fields, mode) => (fields, *mode),
677
0
            d => panic!("UnionArray expected ArrayData with type Union got {d}"),
678
        };
679
0
        let (type_ids, offsets) = match mode {
680
0
            UnionMode::Sparse => (
681
0
                ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
682
0
                None,
683
0
            ),
684
0
            UnionMode::Dense => (
685
0
                ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
686
0
                Some(ScalarBuffer::new(
687
0
                    data.buffers()[1].clone(),
688
0
                    data.offset(),
689
0
                    data.len(),
690
0
                )),
691
0
            ),
692
        };
693
694
0
        let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
695
0
        let mut boxed_fields = vec![None; max_id + 1];
696
0
        for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) {
697
0
            boxed_fields[field_id as usize] = Some(make_array(cd.clone()));
698
0
        }
699
0
        Self {
700
0
            data_type: data.data_type().clone(),
701
0
            type_ids,
702
0
            offsets,
703
0
            fields: boxed_fields,
704
0
        }
705
0
    }
706
}
707
708
impl From<UnionArray> for ArrayData {
709
0
    fn from(array: UnionArray) -> Self {
710
0
        let len = array.len();
711
0
        let f = match &array.data_type {
712
0
            DataType::Union(f, _) => f,
713
0
            _ => unreachable!(),
714
        };
715
0
        let buffers = match array.offsets {
716
0
            Some(o) => vec![array.type_ids.into_inner(), o.into_inner()],
717
0
            None => vec![array.type_ids.into_inner()],
718
        };
719
720
0
        let child = f
721
0
            .iter()
722
0
            .map(|(i, _)| array.fields[i as usize].as_ref().unwrap().to_data())
723
0
            .collect();
724
725
0
        let builder = ArrayDataBuilder::new(array.data_type)
726
0
            .len(len)
727
0
            .buffers(buffers)
728
0
            .child_data(child);
729
0
        unsafe { builder.build_unchecked() }
730
0
    }
731
}
732
733
impl Array for UnionArray {
734
0
    fn as_any(&self) -> &dyn Any {
735
0
        self
736
0
    }
737
738
0
    fn to_data(&self) -> ArrayData {
739
0
        self.clone().into()
740
0
    }
741
742
0
    fn into_data(self) -> ArrayData {
743
0
        self.into()
744
0
    }
745
746
0
    fn data_type(&self) -> &DataType {
747
0
        &self.data_type
748
0
    }
749
750
0
    fn slice(&self, offset: usize, length: usize) -> ArrayRef {
751
0
        Arc::new(self.slice(offset, length))
752
0
    }
753
754
0
    fn len(&self) -> usize {
755
0
        self.type_ids.len()
756
0
    }
757
758
0
    fn is_empty(&self) -> bool {
759
0
        self.type_ids.is_empty()
760
0
    }
761
762
0
    fn shrink_to_fit(&mut self) {
763
0
        self.type_ids.shrink_to_fit();
764
0
        if let Some(offsets) = &mut self.offsets {
765
0
            offsets.shrink_to_fit();
766
0
        }
767
0
        for array in self.fields.iter_mut().flatten() {
768
0
            array.shrink_to_fit();
769
0
        }
770
0
        self.fields.shrink_to_fit();
771
0
    }
772
773
0
    fn offset(&self) -> usize {
774
0
        0
775
0
    }
776
777
0
    fn nulls(&self) -> Option<&NullBuffer> {
778
0
        None
779
0
    }
780
781
0
    fn logical_nulls(&self) -> Option<NullBuffer> {
782
0
        let fields = match self.data_type() {
783
0
            DataType::Union(fields, _) => fields,
784
0
            _ => unreachable!(),
785
        };
786
787
0
        if fields.len() <= 1 {
788
0
            return self.fields.iter().find_map(|field_opt| {
789
0
                field_opt
790
0
                    .as_ref()
791
0
                    .and_then(|field| field.logical_nulls())
792
0
                    .map(|logical_nulls| {
793
0
                        if self.is_dense() {
794
0
                            self.gather_nulls(vec![(0, logical_nulls)]).into()
795
                        } else {
796
0
                            logical_nulls
797
                        }
798
0
                    })
799
0
            });
800
0
        }
801
802
0
        let logical_nulls = self.fields_logical_nulls();
803
804
0
        if logical_nulls.is_empty() {
805
0
            return None;
806
0
        }
807
808
0
        let fully_null_count = logical_nulls
809
0
            .iter()
810
0
            .filter(|(_, nulls)| nulls.null_count() == nulls.len())
811
0
            .count();
812
813
0
        if fully_null_count == fields.len() {
814
0
            if let Some((_, exactly_sized)) = logical_nulls
815
0
                .iter()
816
0
                .find(|(_, nulls)| nulls.len() == self.len())
817
            {
818
0
                return Some(exactly_sized.clone());
819
0
            }
820
821
0
            if let Some((_, bigger)) = logical_nulls
822
0
                .iter()
823
0
                .find(|(_, nulls)| nulls.len() > self.len())
824
            {
825
0
                return Some(bigger.slice(0, self.len()));
826
0
            }
827
828
0
            return Some(NullBuffer::new_null(self.len()));
829
0
        }
830
831
0
        let boolean_buffer = match &self.offsets {
832
0
            Some(_) => self.gather_nulls(logical_nulls),
833
            None => {
834
                // Choose the fastest way to compute the logical nulls
835
                // Gather computes one null per iteration, while the others work on 64 nulls chunks,
836
                // but must also compute selection masks, which is expensive,
837
                // so it's cost is the number of selection masks computed per chunk
838
                // Since computing the selection mask gets auto-vectorized, it's performance depends on which simd feature is enabled
839
                // For gather, the cost is the threshold where masking becomes slower than gather, which is determined with benchmarks
840
                // TODO: bench on avx512f(feature is still unstable)
841
0
                let gather_relative_cost = if cfg!(target_feature = "avx2") {
842
0
                    10
843
0
                } else if cfg!(target_feature = "sse4.1") {
844
0
                    3
845
0
                } else if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") {
846
                    // x86 baseline includes sse2
847
0
                    2
848
                } else {
849
                    // TODO: bench on non x86
850
                    // Always use gather on non benchmarked archs because even though it may slower on some cases,
851
                    // it's performance depends only on the union length, without being affected by the number of fields
852
0
                    0
853
                };
854
855
0
                let strategies = [
856
0
                    (SparseStrategy::Gather, gather_relative_cost, true),
857
0
                    (
858
0
                        SparseStrategy::MaskAllFieldsWithNullsSkipOne,
859
0
                        fields.len() - 1,
860
0
                        fields.len() == logical_nulls.len(),
861
0
                    ),
862
0
                    (
863
0
                        SparseStrategy::MaskSkipWithoutNulls,
864
0
                        logical_nulls.len(),
865
0
                        true,
866
0
                    ),
867
0
                    (
868
0
                        SparseStrategy::MaskSkipFullyNull,
869
0
                        fields.len() - fully_null_count,
870
0
                        true,
871
0
                    ),
872
0
                ];
873
874
0
                let (strategy, _, _) = strategies
875
0
                    .iter()
876
0
                    .filter(|(_, _, applicable)| *applicable)
877
0
                    .min_by_key(|(_, cost, _)| cost)
878
0
                    .unwrap();
879
880
0
                match strategy {
881
0
                    SparseStrategy::Gather => self.gather_nulls(logical_nulls),
882
                    SparseStrategy::MaskAllFieldsWithNullsSkipOne => {
883
0
                        self.mask_sparse_all_with_nulls_skip_one(logical_nulls)
884
                    }
885
                    SparseStrategy::MaskSkipWithoutNulls => {
886
0
                        self.mask_sparse_skip_without_nulls(logical_nulls)
887
                    }
888
                    SparseStrategy::MaskSkipFullyNull => {
889
0
                        self.mask_sparse_skip_fully_null(logical_nulls)
890
                    }
891
                }
892
            }
893
        };
894
895
0
        let null_buffer = NullBuffer::from(boolean_buffer);
896
897
0
        if null_buffer.null_count() > 0 {
898
0
            Some(null_buffer)
899
        } else {
900
0
            None
901
        }
902
0
    }
903
904
0
    fn is_nullable(&self) -> bool {
905
0
        self.fields
906
0
            .iter()
907
0
            .flatten()
908
0
            .any(|field| field.is_nullable())
909
0
    }
910
911
0
    fn get_buffer_memory_size(&self) -> usize {
912
0
        let mut sum = self.type_ids.inner().capacity();
913
0
        if let Some(o) = self.offsets.as_ref() {
914
0
            sum += o.inner().capacity()
915
0
        }
916
0
        self.fields
917
0
            .iter()
918
0
            .flat_map(|x| x.as_ref().map(|x| x.get_buffer_memory_size()))
919
0
            .sum::<usize>()
920
0
            + sum
921
0
    }
922
923
0
    fn get_array_memory_size(&self) -> usize {
924
0
        let mut sum = self.type_ids.inner().capacity();
925
0
        if let Some(o) = self.offsets.as_ref() {
926
0
            sum += o.inner().capacity()
927
0
        }
928
0
        std::mem::size_of::<Self>()
929
0
            + self
930
0
                .fields
931
0
                .iter()
932
0
                .flat_map(|x| x.as_ref().map(|x| x.get_array_memory_size()))
933
0
                .sum::<usize>()
934
0
            + sum
935
0
    }
936
}
937
938
impl std::fmt::Debug for UnionArray {
939
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
940
0
        let header = if self.is_dense() {
941
0
            "UnionArray(Dense)\n["
942
        } else {
943
0
            "UnionArray(Sparse)\n["
944
        };
945
0
        writeln!(f, "{header}")?;
946
947
0
        writeln!(f, "-- type id buffer:")?;
948
0
        writeln!(f, "{:?}", self.type_ids)?;
949
950
0
        if let Some(offsets) = &self.offsets {
951
0
            writeln!(f, "-- offsets buffer:")?;
952
0
            writeln!(f, "{offsets:?}")?;
953
0
        }
954
955
0
        let fields = match self.data_type() {
956
0
            DataType::Union(fields, _) => fields,
957
0
            _ => unreachable!(),
958
        };
959
960
0
        for (type_id, field) in fields.iter() {
961
0
            let child = self.child(type_id);
962
0
            writeln!(
963
0
                f,
964
0
                "-- child {}: \"{}\" ({:?})",
965
                type_id,
966
0
                field.name(),
967
0
                field.data_type()
968
0
            )?;
969
0
            std::fmt::Debug::fmt(child, f)?;
970
0
            writeln!(f)?;
971
        }
972
0
        writeln!(f, "]")
973
0
    }
974
}
975
976
/// How to compute the logical nulls of a sparse union. All strategies return the same result.
977
/// Those starting with Mask perform bitwise masking for each chunk of 64 values, including
978
/// computing expensive selection masks of fields: which fields masks must be computed is the
979
/// difference between them
980
enum SparseStrategy {
981
    /// Gather individual bits from the null buffer of the selected field
982
    Gather,
983
    /// All fields contains nulls, so we can skip the selection mask computation of one field by negating the others
984
    MaskAllFieldsWithNullsSkipOne,
985
    /// Skip the selection mask computation of the fields without nulls
986
    MaskSkipWithoutNulls,
987
    /// Skip the selection mask computation of the fully nulls fields
988
    MaskSkipFullyNull,
989
}
990
991
#[derive(Copy, Clone)]
992
#[repr(usize)]
993
enum Mask {
994
    Zero = 0,
995
    // false positive, see https://github.com/rust-lang/rust-clippy/issues/8043
996
    #[allow(clippy::enum_clike_unportable_variant)]
997
    Max = usize::MAX,
998
}
999
1000
0
fn selection_mask(type_ids_chunk: &[i8], type_id: i8) -> u64 {
1001
0
    type_ids_chunk
1002
0
        .iter()
1003
0
        .copied()
1004
0
        .enumerate()
1005
0
        .fold(0, |packed, (bit_idx, v)| {
1006
0
            packed | (((v == type_id) as u64) << bit_idx)
1007
0
        })
1008
0
}
1009
1010
/// Returns a bitmask where bits indicate if any id from `without_nulls_ids` exist in `type_ids_chunk`.
1011
0
fn without_nulls_selected(type_ids_chunk: &[i8], without_nulls_ids: &[i8]) -> u64 {
1012
0
    without_nulls_ids
1013
0
        .iter()
1014
0
        .fold(0, |fully_valid_selected, field_type_id| {
1015
0
            fully_valid_selected | selection_mask(type_ids_chunk, *field_type_id)
1016
0
        })
1017
0
}
1018
1019
#[cfg(test)]
1020
mod tests {
1021
    use super::*;
1022
    use std::collections::HashSet;
1023
1024
    use crate::array::Int8Type;
1025
    use crate::builder::UnionBuilder;
1026
    use crate::cast::AsArray;
1027
    use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type};
1028
    use crate::{Float64Array, Int32Array, Int64Array, StringArray};
1029
    use crate::{Int8Array, RecordBatch};
1030
    use arrow_buffer::Buffer;
1031
    use arrow_schema::{Field, Schema};
1032
1033
    #[test]
1034
    fn test_dense_i32() {
1035
        let mut builder = UnionBuilder::new_dense();
1036
        builder.append::<Int32Type>("a", 1).unwrap();
1037
        builder.append::<Int32Type>("b", 2).unwrap();
1038
        builder.append::<Int32Type>("c", 3).unwrap();
1039
        builder.append::<Int32Type>("a", 4).unwrap();
1040
        builder.append::<Int32Type>("c", 5).unwrap();
1041
        builder.append::<Int32Type>("a", 6).unwrap();
1042
        builder.append::<Int32Type>("b", 7).unwrap();
1043
        let union = builder.build().unwrap();
1044
1045
        let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
1046
        let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1];
1047
        let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];
1048
1049
        // Check type ids
1050
        assert_eq!(*union.type_ids(), expected_type_ids);
1051
        for (i, id) in expected_type_ids.iter().enumerate() {
1052
            assert_eq!(id, &union.type_id(i));
1053
        }
1054
1055
        // Check offsets
1056
        assert_eq!(*union.offsets().unwrap(), expected_offsets);
1057
        for (i, id) in expected_offsets.iter().enumerate() {
1058
            assert_eq!(union.value_offset(i), *id as usize);
1059
        }
1060
1061
        // Check data
1062
        assert_eq!(
1063
            *union.child(0).as_primitive::<Int32Type>().values(),
1064
            [1_i32, 4, 6]
1065
        );
1066
        assert_eq!(
1067
            *union.child(1).as_primitive::<Int32Type>().values(),
1068
            [2_i32, 7]
1069
        );
1070
        assert_eq!(
1071
            *union.child(2).as_primitive::<Int32Type>().values(),
1072
            [3_i32, 5]
1073
        );
1074
1075
        assert_eq!(expected_array_values.len(), union.len());
1076
        for (i, expected_value) in expected_array_values.iter().enumerate() {
1077
            assert!(!union.is_null(i));
1078
            let slot = union.value(i);
1079
            let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1080
            assert_eq!(slot.len(), 1);
1081
            let value = slot.value(0);
1082
            assert_eq!(expected_value, &value);
1083
        }
1084
    }
1085
1086
    #[test]
1087
    fn slice_union_array_single_field() {
1088
        // Dense Union
1089
        // [1, null, 3, null, 4]
1090
        let union_array = {
1091
            let mut builder = UnionBuilder::new_dense();
1092
            builder.append::<Int32Type>("a", 1).unwrap();
1093
            builder.append_null::<Int32Type>("a").unwrap();
1094
            builder.append::<Int32Type>("a", 3).unwrap();
1095
            builder.append_null::<Int32Type>("a").unwrap();
1096
            builder.append::<Int32Type>("a", 4).unwrap();
1097
            builder.build().unwrap()
1098
        };
1099
1100
        // [null, 3, null]
1101
        let union_slice = union_array.slice(1, 3);
1102
        let logical_nulls = union_slice.logical_nulls().unwrap();
1103
1104
        assert_eq!(logical_nulls.len(), 3);
1105
        assert!(logical_nulls.is_null(0));
1106
        assert!(logical_nulls.is_valid(1));
1107
        assert!(logical_nulls.is_null(2));
1108
    }
1109
1110
    #[test]
1111
    #[cfg_attr(miri, ignore)]
1112
    fn test_dense_i32_large() {
1113
        let mut builder = UnionBuilder::new_dense();
1114
1115
        let expected_type_ids = vec![0_i8; 1024];
1116
        let expected_offsets: Vec<_> = (0..1024).collect();
1117
        let expected_array_values: Vec<_> = (1..=1024).collect();
1118
1119
        expected_array_values
1120
            .iter()
1121
            .for_each(|v| builder.append::<Int32Type>("a", *v).unwrap());
1122
1123
        let union = builder.build().unwrap();
1124
1125
        // Check type ids
1126
        assert_eq!(*union.type_ids(), expected_type_ids);
1127
        for (i, id) in expected_type_ids.iter().enumerate() {
1128
            assert_eq!(id, &union.type_id(i));
1129
        }
1130
1131
        // Check offsets
1132
        assert_eq!(*union.offsets().unwrap(), expected_offsets);
1133
        for (i, id) in expected_offsets.iter().enumerate() {
1134
            assert_eq!(union.value_offset(i), *id as usize);
1135
        }
1136
1137
        for (i, expected_value) in expected_array_values.iter().enumerate() {
1138
            assert!(!union.is_null(i));
1139
            let slot = union.value(i);
1140
            let slot = slot.as_primitive::<Int32Type>();
1141
            assert_eq!(slot.len(), 1);
1142
            let value = slot.value(0);
1143
            assert_eq!(expected_value, &value);
1144
        }
1145
    }
1146
1147
    #[test]
1148
    fn test_dense_mixed() {
1149
        let mut builder = UnionBuilder::new_dense();
1150
        builder.append::<Int32Type>("a", 1).unwrap();
1151
        builder.append::<Int64Type>("c", 3).unwrap();
1152
        builder.append::<Int32Type>("a", 4).unwrap();
1153
        builder.append::<Int64Type>("c", 5).unwrap();
1154
        builder.append::<Int32Type>("a", 6).unwrap();
1155
        let union = builder.build().unwrap();
1156
1157
        assert_eq!(5, union.len());
1158
        for i in 0..union.len() {
1159
            let slot = union.value(i);
1160
            assert!(!union.is_null(i));
1161
            match i {
1162
                0 => {
1163
                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1164
                    assert_eq!(slot.len(), 1);
1165
                    let value = slot.value(0);
1166
                    assert_eq!(1_i32, value);
1167
                }
1168
                1 => {
1169
                    let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1170
                    assert_eq!(slot.len(), 1);
1171
                    let value = slot.value(0);
1172
                    assert_eq!(3_i64, value);
1173
                }
1174
                2 => {
1175
                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1176
                    assert_eq!(slot.len(), 1);
1177
                    let value = slot.value(0);
1178
                    assert_eq!(4_i32, value);
1179
                }
1180
                3 => {
1181
                    let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1182
                    assert_eq!(slot.len(), 1);
1183
                    let value = slot.value(0);
1184
                    assert_eq!(5_i64, value);
1185
                }
1186
                4 => {
1187
                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1188
                    assert_eq!(slot.len(), 1);
1189
                    let value = slot.value(0);
1190
                    assert_eq!(6_i32, value);
1191
                }
1192
                _ => unreachable!(),
1193
            }
1194
        }
1195
    }
1196
1197
    #[test]
1198
    fn test_dense_mixed_with_nulls() {
1199
        let mut builder = UnionBuilder::new_dense();
1200
        builder.append::<Int32Type>("a", 1).unwrap();
1201
        builder.append::<Int64Type>("c", 3).unwrap();
1202
        builder.append::<Int32Type>("a", 10).unwrap();
1203
        builder.append_null::<Int32Type>("a").unwrap();
1204
        builder.append::<Int32Type>("a", 6).unwrap();
1205
        let union = builder.build().unwrap();
1206
1207
        assert_eq!(5, union.len());
1208
        for i in 0..union.len() {
1209
            let slot = union.value(i);
1210
            match i {
1211
                0 => {
1212
                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1213
                    assert!(!slot.is_null(0));
1214
                    assert_eq!(slot.len(), 1);
1215
                    let value = slot.value(0);
1216
                    assert_eq!(1_i32, value);
1217
                }
1218
                1 => {
1219
                    let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1220
                    assert!(!slot.is_null(0));
1221
                    assert_eq!(slot.len(), 1);
1222
                    let value = slot.value(0);
1223
                    assert_eq!(3_i64, value);
1224
                }
1225
                2 => {
1226
                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1227
                    assert!(!slot.is_null(0));
1228
                    assert_eq!(slot.len(), 1);
1229
                    let value = slot.value(0);
1230
                    assert_eq!(10_i32, value);
1231
                }
1232
                3 => assert!(slot.is_null(0)),
1233
                4 => {
1234
                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1235
                    assert!(!slot.is_null(0));
1236
                    assert_eq!(slot.len(), 1);
1237
                    let value = slot.value(0);
1238
                    assert_eq!(6_i32, value);
1239
                }
1240
                _ => unreachable!(),
1241
            }
1242
        }
1243
    }
1244
1245
    #[test]
1246
    fn test_dense_mixed_with_nulls_and_offset() {
1247
        let mut builder = UnionBuilder::new_dense();
1248
        builder.append::<Int32Type>("a", 1).unwrap();
1249
        builder.append::<Int64Type>("c", 3).unwrap();
1250
        builder.append::<Int32Type>("a", 10).unwrap();
1251
        builder.append_null::<Int32Type>("a").unwrap();
1252
        builder.append::<Int32Type>("a", 6).unwrap();
1253
        let union = builder.build().unwrap();
1254
1255
        let slice = union.slice(2, 3);
1256
        let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap();
1257
1258
        assert_eq!(3, new_union.len());
1259
        for i in 0..new_union.len() {
1260
            let slot = new_union.value(i);
1261
            match i {
1262
                0 => {
1263
                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1264
                    assert!(!slot.is_null(0));
1265
                    assert_eq!(slot.len(), 1);
1266
                    let value = slot.value(0);
1267
                    assert_eq!(10_i32, value);
1268
                }
1269
                1 => assert!(slot.is_null(0)),
1270
                2 => {
1271
                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1272
                    assert!(!slot.is_null(0));
1273
                    assert_eq!(slot.len(), 1);
1274
                    let value = slot.value(0);
1275
                    assert_eq!(6_i32, value);
1276
                }
1277
                _ => unreachable!(),
1278
            }
1279
        }
1280
    }
1281
1282
    #[test]
1283
    fn test_dense_mixed_with_str() {
1284
        let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
1285
        let int_array = Int32Array::from(vec![5, 6]);
1286
        let float_array = Float64Array::from(vec![10.0]);
1287
1288
        let type_ids = [1, 0, 0, 2, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();
1289
        let offsets = [0, 0, 1, 0, 2, 1]
1290
            .into_iter()
1291
            .collect::<ScalarBuffer<i32>>();
1292
1293
        let fields = [
1294
            (0, Arc::new(Field::new("A", DataType::Utf8, false))),
1295
            (1, Arc::new(Field::new("B", DataType::Int32, false))),
1296
            (2, Arc::new(Field::new("C", DataType::Float64, false))),
1297
        ]
1298
        .into_iter()
1299
        .collect::<UnionFields>();
1300
        let children = [
1301
            Arc::new(string_array) as Arc<dyn Array>,
1302
            Arc::new(int_array),
1303
            Arc::new(float_array),
1304
        ]
1305
        .into_iter()
1306
        .collect();
1307
        let array =
1308
            UnionArray::try_new(fields, type_ids.clone(), Some(offsets.clone()), children).unwrap();
1309
1310
        // Check type ids
1311
        assert_eq!(*array.type_ids(), type_ids);
1312
        for (i, id) in type_ids.iter().enumerate() {
1313
            assert_eq!(id, &array.type_id(i));
1314
        }
1315
1316
        // Check offsets
1317
        assert_eq!(*array.offsets().unwrap(), offsets);
1318
        for (i, id) in offsets.iter().enumerate() {
1319
            assert_eq!(*id as usize, array.value_offset(i));
1320
        }
1321
1322
        // Check values
1323
        assert_eq!(6, array.len());
1324
1325
        let slot = array.value(0);
1326
        let value = slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0);
1327
        assert_eq!(5, value);
1328
1329
        let slot = array.value(1);
1330
        let value = slot
1331
            .as_any()
1332
            .downcast_ref::<StringArray>()
1333
            .unwrap()
1334
            .value(0);
1335
        assert_eq!("foo", value);
1336
1337
        let slot = array.value(2);
1338
        let value = slot
1339
            .as_any()
1340
            .downcast_ref::<StringArray>()
1341
            .unwrap()
1342
            .value(0);
1343
        assert_eq!("bar", value);
1344
1345
        let slot = array.value(3);
1346
        let value = slot
1347
            .as_any()
1348
            .downcast_ref::<Float64Array>()
1349
            .unwrap()
1350
            .value(0);
1351
        assert_eq!(10.0, value);
1352
1353
        let slot = array.value(4);
1354
        let value = slot
1355
            .as_any()
1356
            .downcast_ref::<StringArray>()
1357
            .unwrap()
1358
            .value(0);
1359
        assert_eq!("baz", value);
1360
1361
        let slot = array.value(5);
1362
        let value = slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0);
1363
        assert_eq!(6, value);
1364
    }
1365
1366
    #[test]
1367
    fn test_sparse_i32() {
1368
        let mut builder = UnionBuilder::new_sparse();
1369
        builder.append::<Int32Type>("a", 1).unwrap();
1370
        builder.append::<Int32Type>("b", 2).unwrap();
1371
        builder.append::<Int32Type>("c", 3).unwrap();
1372
        builder.append::<Int32Type>("a", 4).unwrap();
1373
        builder.append::<Int32Type>("c", 5).unwrap();
1374
        builder.append::<Int32Type>("a", 6).unwrap();
1375
        builder.append::<Int32Type>("b", 7).unwrap();
1376
        let union = builder.build().unwrap();
1377
1378
        let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
1379
        let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];
1380
1381
        // Check type ids
1382
        assert_eq!(*union.type_ids(), expected_type_ids);
1383
        for (i, id) in expected_type_ids.iter().enumerate() {
1384
            assert_eq!(id, &union.type_id(i));
1385
        }
1386
1387
        // Check offsets, sparse union should only have a single buffer
1388
        assert!(union.offsets().is_none());
1389
1390
        // Check data
1391
        assert_eq!(
1392
            *union.child(0).as_primitive::<Int32Type>().values(),
1393
            [1_i32, 0, 0, 4, 0, 6, 0],
1394
        );
1395
        assert_eq!(
1396
            *union.child(1).as_primitive::<Int32Type>().values(),
1397
            [0_i32, 2_i32, 0, 0, 0, 0, 7]
1398
        );
1399
        assert_eq!(
1400
            *union.child(2).as_primitive::<Int32Type>().values(),
1401
            [0_i32, 0, 3_i32, 0, 5, 0, 0]
1402
        );
1403
1404
        assert_eq!(expected_array_values.len(), union.len());
1405
        for (i, expected_value) in expected_array_values.iter().enumerate() {
1406
            assert!(!union.is_null(i));
1407
            let slot = union.value(i);
1408
            let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1409
            assert_eq!(slot.len(), 1);
1410
            let value = slot.value(0);
1411
            assert_eq!(expected_value, &value);
1412
        }
1413
    }
1414
1415
    #[test]
1416
    fn test_sparse_mixed() {
1417
        let mut builder = UnionBuilder::new_sparse();
1418
        builder.append::<Int32Type>("a", 1).unwrap();
1419
        builder.append::<Float64Type>("c", 3.0).unwrap();
1420
        builder.append::<Int32Type>("a", 4).unwrap();
1421
        builder.append::<Float64Type>("c", 5.0).unwrap();
1422
        builder.append::<Int32Type>("a", 6).unwrap();
1423
        let union = builder.build().unwrap();
1424
1425
        let expected_type_ids = vec![0_i8, 1, 0, 1, 0];
1426
1427
        // Check type ids
1428
        assert_eq!(*union.type_ids(), expected_type_ids);
1429
        for (i, id) in expected_type_ids.iter().enumerate() {
1430
            assert_eq!(id, &union.type_id(i));
1431
        }
1432
1433
        // Check offsets, sparse union should only have a single buffer, i.e. no offsets
1434
        assert!(union.offsets().is_none());
1435
1436
        for i in 0..union.len() {
1437
            let slot = union.value(i);
1438
            assert!(!union.is_null(i));
1439
            match i {
1440
                0 => {
1441
                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1442
                    assert_eq!(slot.len(), 1);
1443
                    let value = slot.value(0);
1444
                    assert_eq!(1_i32, value);
1445
                }
1446
                1 => {
1447
                    let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
1448
                    assert_eq!(slot.len(), 1);
1449
                    let value = slot.value(0);
1450
                    assert_eq!(value, 3_f64);
1451
                }
1452
                2 => {
1453
                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1454
                    assert_eq!(slot.len(), 1);
1455
                    let value = slot.value(0);
1456
                    assert_eq!(4_i32, value);
1457
                }
1458
                3 => {
1459
                    let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
1460
                    assert_eq!(slot.len(), 1);
1461
                    let value = slot.value(0);
1462
                    assert_eq!(5_f64, value);
1463
                }
1464
                4 => {
1465
                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1466
                    assert_eq!(slot.len(), 1);
1467
                    let value = slot.value(0);
1468
                    assert_eq!(6_i32, value);
1469
                }
1470
                _ => unreachable!(),
1471
            }
1472
        }
1473
    }
1474
1475
    #[test]
    fn test_sparse_mixed_with_nulls() {
        // Logical contents: [a=1, a=null, c=3.0, a=4]
        let mut builder = UnionBuilder::new_sparse();
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        let union = builder.build().unwrap();

        // "a" is expected to map to type id 0 and "c" to type id 1.
        let expected_type_ids = vec![0_i8, 0, 1, 0];
        assert_eq!(*union.type_ids(), expected_type_ids);
        expected_type_ids
            .iter()
            .enumerate()
            .for_each(|(idx, expected)| assert_eq!(expected, &union.type_id(idx)));

        // A sparse union carries only the type id buffer, so there are no offsets.
        assert!(union.offsets().is_none());

        for idx in 0..union.len() {
            let value = union.value(idx);
            match idx {
                0 | 3 => {
                    let ints = value.as_any().downcast_ref::<Int32Array>().unwrap();
                    assert!(!ints.is_null(0));
                    assert_eq!(ints.len(), 1);
                    let expected = if idx == 0 { 1_i32 } else { 4_i32 };
                    assert_eq!(expected, ints.value(0));
                }
                1 => assert!(value.is_null(0)),
                2 => {
                    let floats = value.as_any().downcast_ref::<Float64Array>().unwrap();
                    assert!(!floats.is_null(0));
                    assert_eq!(floats.len(), 1);
                    assert_eq!(floats.value(0), 3_f64);
                }
                _ => unreachable!(),
            }
        }
    }
1524
1525
    #[test]
    fn test_sparse_mixed_with_nulls_and_offset() {
        // Logical contents: [a=1, a=null, c=3.0, c=null, a=4]
        let mut builder = UnionBuilder::new_sparse();
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append_null::<Float64Type>("c").unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        let union = builder.build().unwrap();

        // Skip the first element so every lookup goes through a non-zero offset;
        // the slice holds [a=null, c=3.0, c=null, a=4].
        let sliced = union.slice(1, 4);
        let sliced = sliced.as_any().downcast_ref::<UnionArray>().unwrap();
        assert_eq!(4, sliced.len());

        for pos in 0..sliced.len() {
            let value = sliced.value(pos);
            match pos {
                0 | 2 => assert!(value.is_null(0)),
                1 => {
                    let floats = value.as_primitive::<Float64Type>();
                    assert!(!floats.is_null(0));
                    assert_eq!(floats.len(), 1);
                    assert_eq!(floats.value(0), 3_f64);
                }
                3 => {
                    let ints = value.as_primitive::<Int32Type>();
                    assert!(!ints.is_null(0));
                    assert_eq!(ints.len(), 1);
                    assert_eq!(4_i32, ints.value(0));
                }
                _ => unreachable!(),
            }
        }
    }
1562
1563
    fn test_union_validity(union_array: &UnionArray) {
1564
        assert_eq!(union_array.null_count(), 0);
1565
1566
        for i in 0..union_array.len() {
1567
            assert!(!union_array.is_null(i));
1568
            assert!(union_array.is_valid(i));
1569
        }
1570
    }
1571
1572
    #[test]
    fn test_union_array_validity() {
        // Appends [a=1, a=null, c=3.0, c=null, a=4] and builds the union.
        fn populate(mut builder: UnionBuilder) -> UnionArray {
            builder.append::<Int32Type>("a", 1).unwrap();
            builder.append_null::<Int32Type>("a").unwrap();
            builder.append::<Float64Type>("c", 3.0).unwrap();
            builder.append_null::<Float64Type>("c").unwrap();
            builder.append::<Int32Type>("a", 4).unwrap();
            builder.build().unwrap()
        }

        // Nulls stored in the children must not surface as nulls on the union
        // itself, for either the sparse or the dense layout.
        test_union_validity(&populate(UnionBuilder::new_sparse()));
        test_union_validity(&populate(UnionBuilder::new_dense()));
    }
1594
1595
    #[test]
    fn test_type_check() {
        let mut builder = UnionBuilder::new_sparse();
        builder.append::<Float32Type>("a", 1.0).unwrap();

        // Re-using the column name "a" with a different primitive type must fail.
        let message = builder.append::<Int32Type>("a", 1).unwrap_err().to_string();
        let expected =
            "Attempt to write col \"a\" with type Int32 doesn't match existing type Float32";
        assert!(message.contains(expected), "{message}");
    }
1608
1609
    #[test]
    fn slice_union_array() {
        // Appends [a=1, a=null, c=3.0, c=null, a=4] to `builder` and builds it.
        fn create_union(mut builder: UnionBuilder) -> UnionArray {
            builder.append::<Int32Type>("a", 1).unwrap();
            builder.append_null::<Int32Type>("a").unwrap();
            builder.append::<Float64Type>("c", 3.0).unwrap();
            builder.append_null::<Float64Type>("c").unwrap();
            builder.append::<Int32Type>("a", 4).unwrap();
            builder.build().unwrap()
        }

        // Wraps `union` as the only (nullable) column of a `RecordBatch`.
        // The column is named after what it holds: a union, not a struct.
        fn create_batch(union: UnionArray) -> RecordBatch {
            let schema = Schema::new(vec![Field::new(
                "union_array",
                union.data_type().clone(),
                true,
            )]);

            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap()
        }

        // Checks a batch whose single column is the slice [a=null, c=3.0, c=null].
        fn test_slice_union(record_batch_slice: RecordBatch) {
            let union_slice = record_batch_slice
                .column(0)
                .as_any()
                .downcast_ref::<UnionArray>()
                .unwrap();

            // Slicing the batch must also slice the union column itself.
            assert_eq!(union_slice.len(), 3);

            assert_eq!(union_slice.type_id(0), 0);
            assert_eq!(union_slice.type_id(1), 1);
            assert_eq!(union_slice.type_id(2), 1);

            let slot = union_slice.value(0);
            let array = slot.as_primitive::<Int32Type>();
            assert_eq!(array.len(), 1);
            assert!(array.is_null(0));

            let slot = union_slice.value(1);
            let array = slot.as_primitive::<Float64Type>();
            assert_eq!(array.len(), 1);
            assert!(array.is_valid(0));
            assert_eq!(array.value(0), 3.0);

            let slot = union_slice.value(2);
            let array = slot.as_primitive::<Float64Type>();
            assert_eq!(array.len(), 1);
            assert!(array.is_null(0));
        }

        // Sparse Union
        let builder = UnionBuilder::new_sparse();
        let record_batch = create_batch(create_union(builder));
        // [null, 3.0, null]
        let record_batch_slice = record_batch.slice(1, 3);
        test_slice_union(record_batch_slice);

        // Dense Union
        let builder = UnionBuilder::new_dense();
        let record_batch = create_batch(create_union(builder));
        // [null, 3.0, null]
        let record_batch_slice = record_batch.slice(1, 3);
        test_slice_union(record_batch_slice);
    }
1673
1674
    #[test]
    fn test_custom_type_ids() {
        // Non-contiguous type ids: 8 -> strings, 4 -> integers, 9 -> floats.
        let data_type = DataType::Union(
            UnionFields::new(
                vec![8, 4, 9],
                vec![
                    Field::new("strings", DataType::Utf8, false),
                    Field::new("integers", DataType::Int32, false),
                    Field::new("floats", DataType::Float64, false),
                ],
            ),
            UnionMode::Dense,
        );

        let strings = StringArray::from(vec!["foo", "bar", "baz"]);
        let integers = Int32Array::from(vec![5, 6, 4]);
        let floats = Float64Array::from(vec![10.0]);

        // Dense layout: each slot's offset indexes into the child that its type id selects.
        let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
        let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);

        let data = ArrayData::builder(data_type)
            .len(7)
            .buffers(vec![type_ids, value_offsets])
            .child_data(vec![
                strings.into_data(),
                integers.into_data(),
                floats.into_data(),
            ])
            .build()
            .unwrap();
        let array = UnionArray::from(data);

        let expect_int = |slot: usize, expected: i32| {
            let v = array.value(slot);
            assert_eq!(v.data_type(), &DataType::Int32);
            assert_eq!(v.len(), 1);
            assert_eq!(v.as_primitive::<Int32Type>().value(0), expected);
        };
        let expect_str = |slot: usize, expected: &str| {
            let v = array.value(slot);
            assert_eq!(v.data_type(), &DataType::Utf8);
            assert_eq!(v.len(), 1);
            assert_eq!(v.as_string::<i32>().value(0), expected);
        };
        let expect_float = |slot: usize, expected: f64| {
            let v = array.value(slot);
            assert_eq!(v.data_type(), &DataType::Float64);
            assert_eq!(v.len(), 1);
            assert_eq!(v.as_primitive::<Float64Type>().value(0), expected);
        };

        expect_int(0, 5);
        expect_str(1, "foo");
        expect_int(2, 6);
        expect_str(3, "bar");
        expect_float(4, 10.0);
        expect_int(5, 4);
        expect_str(6, "baz");
    }
1743
1744
    #[test]
    fn into_parts() {
        // Dense union [a=1, b=2, a=3]
        let mut builder = UnionBuilder::new_dense();
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append::<Int8Type>("b", 2).unwrap();
        builder.append::<Int32Type>("a", 3).unwrap();
        let dense_union = builder.build().unwrap();

        let expected_fields = [
            &Arc::new(Field::new("a", DataType::Int32, false)),
            &Arc::new(Field::new("b", DataType::Int8, false)),
        ];
        let (fields, type_ids, offsets, children) = dense_union.into_parts();
        let actual_fields: Vec<_> = fields.iter().map(|(_, f)| f).collect();
        assert_eq!(actual_fields, expected_fields);
        assert_eq!(type_ids, [0, 1, 0]);
        let dense_offsets = offsets.as_ref().expect("dense union must expose offsets");
        assert_eq!(dense_offsets, &[0, 0, 1]);

        // The parts must reassemble into an equally long union.
        let rebuilt = UnionArray::try_new(fields, type_ids, offsets, children)
            .expect("dense parts should rebuild");
        assert_eq!(rebuilt.len(), 3);

        // Sparse union with the same logical contents.
        let mut builder = UnionBuilder::new_sparse();
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append::<Int8Type>("b", 2).unwrap();
        builder.append::<Int32Type>("a", 3).unwrap();
        let sparse_union = builder.build().unwrap();

        let (fields, type_ids, offsets, children) = sparse_union.into_parts();
        assert_eq!(type_ids, [0, 1, 0]);
        assert!(offsets.is_none());

        let rebuilt = UnionArray::try_new(fields, type_ids, offsets, children)
            .expect("sparse parts should rebuild");
        assert_eq!(rebuilt.len(), 3);
    }
1786
1787
    #[test]
    fn into_parts_custom_type_ids() {
        // Non-contiguous type ids: 8 -> strings, 4 -> integers, 9 -> floats.
        let set_field_type_ids: [i8; 3] = [8, 4, 9];
        let data_type = DataType::Union(
            UnionFields::new(
                set_field_type_ids,
                [
                    Field::new("strings", DataType::Utf8, false),
                    Field::new("integers", DataType::Int32, false),
                    Field::new("floats", DataType::Float64, false),
                ],
            ),
            UnionMode::Dense,
        );
        let children_data = vec![
            StringArray::from(vec!["foo", "bar", "baz"]).into_data(),
            Int32Array::from(vec![5, 6, 4]).into_data(),
            Float64Array::from(vec![10.0]).into_data(),
        ];
        let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
        let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);
        let data = ArrayData::builder(data_type)
            .len(7)
            .buffers(vec![type_ids, value_offsets])
            .child_data(children_data)
            .build()
            .unwrap();
        let array = UnionArray::from(data);

        // Every declared type id is used, so the used and declared id sets coincide.
        let (fields, type_ids, offsets, children) = array.into_parts();
        let observed: HashSet<i8> = type_ids.iter().copied().collect();
        let declared: HashSet<i8> = set_field_type_ids.into_iter().collect();
        assert_eq!(observed, declared);

        let rebuilt = UnionArray::try_new(fields, type_ids, offsets, children)
            .expect("parts with custom type ids should rebuild");
        assert_eq!(rebuilt.len(), 7);
    }
1828
1829
    #[test]
    fn test_invalid() {
        // Two Utf8 fields with non-sequential type ids: "a" -> 3, "b" -> 2.
        let fields = UnionFields::new(
            [3, 2],
            [
                Field::new("a", DataType::Utf8, false),
                Field::new("b", DataType::Utf8, false),
            ],
        );
        let children = vec![
            Arc::new(StringArray::from_iter_values(["a", "b"])) as _,
            Arc::new(StringArray::from_iter_values(["c", "d"])) as _,
        ];

        // Sparse: children of length 2 cannot back a union of length 3.
        let err = UnionArray::try_new(
            fields.clone(),
            vec![3, 3, 2].into(),
            None,
            children.clone(),
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "Invalid argument error: Sparse union child arrays must be equal in length to the length of the union"
        );

        // Sparse: type id 1 is not declared.
        let err = UnionArray::try_new(fields.clone(), vec![1, 2].into(), None, children.clone())
            .unwrap_err();
        assert_eq!(
            err.to_string(),
            "Invalid argument error: Type Ids values must match one of the field type ids"
        );

        // Sparse: type id 7 is not declared.
        let err = UnionArray::try_new(fields.clone(), vec![7, 2].into(), None, children)
            .unwrap_err();
        assert_eq!(
            err.to_string(),
            "Invalid argument error: Type Ids values must match one of the field type ids"
        );

        // Dense: children may differ in length; offsets index into the selected child.
        let children = vec![
            Arc::new(StringArray::from_iter_values(["a", "b"])) as _,
            Arc::new(StringArray::from_iter_values(["c"])) as _,
        ];
        let type_ids = ScalarBuffer::from(vec![3_i8, 3, 2]);

        // Valid: the last slot selects "b" (length 1) at offset 0.
        UnionArray::try_new(
            fields.clone(),
            type_ids.clone(),
            Some(vec![0, 1, 0].into()),
            children.clone(),
        )
        .unwrap();

        // Invalid: offset 1 is out of range for "b".
        let err = UnionArray::try_new(
            fields.clone(),
            type_ids.clone(),
            Some(vec![0, 1, 1].into()),
            children.clone(),
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "Invalid argument error: Offsets must be positive and within the length of the Array"
        );

        // Invalid: two offsets for three type ids.
        let err = UnionArray::try_new(
            fields.clone(),
            type_ids.clone(),
            Some(vec![0, 1].into()),
            children,
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "Invalid argument error: Type Ids and Offsets lengths must match"
        );

        // Invalid: two declared fields but no child arrays.
        let err = UnionArray::try_new(fields, type_ids, None, vec![]).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Invalid argument error: Union fields length must match child arrays length"
        );
    }
1899
1900
    #[test]
    fn test_logical_nulls_fast_paths() {
        // Fast path fields.len() <= 1: a union with no fields at all.
        let no_fields =
            UnionArray::try_new(UnionFields::empty(), vec![].into(), None, vec![]).unwrap();
        assert_eq!(no_fields.logical_nulls(), None);

        // Both fields are declared non-nullable.
        let non_nullable_fields = UnionFields::new(
            [1, 3],
            [
                Field::new("a", DataType::Int8, false),
                Field::new("b", DataType::Int8, false),
            ],
        );
        let non_nullable = UnionArray::try_new(
            non_nullable_fields,
            vec![1].into(),
            None,
            vec![
                Arc::new(Int8Array::from_value(5, 1)),
                Arc::new(Int8Array::from_value(5, 1)),
            ],
        )
        .unwrap();
        assert_eq!(non_nullable.logical_nulls(), None);

        // Both fields are declared nullable; the cases below vary the child contents.
        let nullable_fields = UnionFields::new(
            [1, 3],
            [
                Field::new("a", DataType::Int8, true),
                Field::new("b", DataType::Int8, true),
            ],
        );

        // Nullable fields whose children contain no nulls.
        let no_nulls = UnionArray::try_new(
            nullable_fields.clone(),
            vec![1, 1].into(),
            None,
            vec![
                Arc::new(Int8Array::from_value(-5, 2)),
                Arc::new(Int8Array::from_value(-5, 2)),
            ],
        )
        .unwrap();
        assert_eq!(no_nulls.logical_nulls(), None);

        // Sparse: every child is completely null and as long as its parent.
        let sparse_all_null = UnionArray::try_new(
            nullable_fields.clone(),
            vec![1, 1].into(),
            None,
            vec![
                Arc::new(Int8Array::new_null(2)),
                Arc::new(Int8Array::new_null(2)),
            ],
        )
        .unwrap();
        assert_eq!(
            sparse_all_null.logical_nulls(),
            Some(NullBuffer::new_null(2))
        );

        // Dense: every child is completely null and longer than its parent.
        let dense_all_null = UnionArray::try_new(
            nullable_fields,
            vec![1, 1].into(),
            Some(vec![0, 1].into()),
            vec![
                Arc::new(Int8Array::new_null(3)),
                Arc::new(Int8Array::new_null(3)),
            ],
        )
        .unwrap();
        assert_eq!(
            dense_all_null.logical_nulls(),
            Some(NullBuffer::new_null(2))
        );
    }
1975
1976
    #[test]
    fn test_dense_union_logical_nulls_gather() {
        // Dense union [{A=1}, {A=2}, {B=3.2}, {B=}, {C=}, {C=}]: both C slots
        // point at offset 0 of a one-element, all-null string child.
        let children = vec![
            Arc::new(Int32Array::from(vec![1, 2])) as Arc<dyn Array>,
            Arc::new(Float64Array::from(vec![Some(3.2), None])),
            Arc::new(StringArray::new_null(1)),
        ];
        let type_ids: ScalarBuffer<i8> = [1, 1, 3, 3, 4, 4].into_iter().collect();
        let offsets: ScalarBuffer<i32> = [0, 1, 0, 1, 0, 0].into_iter().collect();

        let array = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap();

        // `logical_nulls` and a direct call to `gather_nulls` must agree.
        let expected = BooleanBuffer::from(vec![true, true, true, false, false, false]);
        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
        assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls()));
    }
2000
2001
    #[test]
    fn test_sparse_union_logical_nulls_mask_all_nulls_skip_one() {
        let fields: UnionFields = [
            (1, Arc::new(Field::new("A", DataType::Int32, true))),
            (3, Arc::new(Field::new("B", DataType::Float64, true))),
        ]
        .into_iter()
        .collect();

        // `logical_nulls` and the dedicated mask routine must both yield `expected`.
        let check = |array: &UnionArray, expected: &BooleanBuffer| {
            assert_eq!(expected, &array.logical_nulls().unwrap().into_inner());
            assert_eq!(
                expected,
                &array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls())
            );
        };

        // Sparse union [{A=}, {A=}, {B=3.2}, {B=}]: A is fully null, B is mixed.
        let children = vec![
            Arc::new(Int32Array::new_null(4)) as Arc<dyn Array>,
            Arc::new(Float64Array::from(vec![None, None, Some(3.2), None])),
        ];
        let type_ids: ScalarBuffer<i8> = [1, 1, 3, 3].into_iter().collect();
        let array = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap();
        check(&array, &BooleanBuffer::from(vec![false, false, true, false]));

        // Same pattern repeated to generate two full 64-bit chunks plus a
        // non-empty remainder (2 * 64 + 32 slots).
        let len = 2 * 64 + 32;
        let float_array = Float64Array::from_iter([Some(3.2), None].into_iter().cycle().take(len));
        let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3].into_iter().cycle().take(len));
        let array = UnionArray::try_new(
            fields,
            type_ids,
            None,
            vec![Arc::new(Int32Array::new_null(len)), Arc::new(float_array)],
        )
        .unwrap();

        assert_eq!(array.len(), len);
        check(
            &array,
            &BooleanBuffer::from_iter([false, false, true, false].into_iter().cycle().take(len)),
        );
    }
2052
2053
    #[test]
    fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_valid() {
        // Sparse union [{A=2}, {A=2}, {B=4.2}, {B=4.2}, {C=}, {C=}]:
        // A and B are fully valid children, C is fully null.
        let int_array = Int32Array::from_value(2, 6);
        let float_array = Float64Array::from_value(4.2, 6);
        let str_array = StringArray::new_null(6);
        let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>();

        let children = vec![
            Arc::new(int_array) as Arc<dyn Array>,
            Arc::new(float_array),
            Arc::new(str_array),
        ];

        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();

        let expected = BooleanBuffer::from(vec![true, true, true, true, false, false]);

        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
        assert_eq!(
            expected,
            array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
        );

        // Like above, but repeated to generate two full 64-bit chunks and a
        // non-empty remainder (2 * 64 + 32 slots). Here C alternates
        // [null, "a"], so within each six-slot cycle the C slot at position 4
        // is null and the one at position 5 is valid.
        let len = 2 * 64 + 32;

        let int_array = Int32Array::from_value(2, len);
        let float_array = Float64Array::from_value(4.2, len);
        let str_array = StringArray::from_iter([None, Some("a")].into_iter().cycle().take(len));
        let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));

        let children = vec![
            Arc::new(int_array) as Arc<dyn Array>,
            Arc::new(float_array),
            Arc::new(str_array),
        ];

        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();

        let expected = BooleanBuffer::from_iter(
            [true, true, true, true, false, true]
                .into_iter()
                .cycle()
                .take(len),
        );

        assert_eq!(array.len(), len);
        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
        assert_eq!(
            expected,
            array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
        );
    }
2107
2108
    #[test]
    fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_null() {
        // `logical_nulls` and the dedicated mask routine must both yield `expected`.
        let check = |array: &UnionArray, expected: &BooleanBuffer| {
            assert_eq!(expected, &array.logical_nulls().unwrap().into_inner());
            assert_eq!(
                expected,
                &array.mask_sparse_skip_fully_null(array.fields_logical_nulls())
            );
        };

        // Builds a sparse union of `len` slots cycling through type ids
        // [1, 1, 3, 3, 4, 4], where A and C are fully null and B is fully valid.
        let build = |len: usize| {
            let children = vec![
                Arc::new(Int32Array::new_null(len)) as Arc<dyn Array>,
                Arc::new(Float64Array::from_value(4.2, len)),
                Arc::new(StringArray::new_null(len)),
            ];
            let type_ids =
                ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));
            UnionArray::try_new(union_fields(), type_ids, None, children).unwrap()
        };

        // Sparse union [{A=}, {A=}, {B=4.2}, {B=4.2}, {C=}, {C=}]
        let array = build(6);
        check(
            &array,
            &BooleanBuffer::from(vec![false, false, true, true, false, false]),
        );

        // Same pattern repeated to generate two full 64-bit chunks plus a
        // non-empty remainder (2 * 64 + 32 slots).
        let len = 2 * 64 + 32;
        let array = build(len);
        assert_eq!(array.len(), len);
        check(
            &array,
            &BooleanBuffer::from_iter(
                [false, false, true, true, false, false]
                    .into_iter()
                    .cycle()
                    .take(len),
            ),
        );
    }
2162
2163
    #[test]
    fn test_sparse_union_logical_nulls_gather() {
        // 50 Int32 fields with odd type ids 1, 3, 5, ...; the children repeat
        // [non_null, mixed, fully_null], so id 1 -> non_null, id 3 -> mixed,
        // and id 5 -> fully_null.
        let n_fields = 50;

        let non_null = Int32Array::from_value(2, 4);
        let mixed = Int32Array::from(vec![None, None, Some(1), None]);
        let fully_null = Int32Array::new_null(4);

        let fields: UnionFields = (1..)
            .step_by(2)
            .take(n_fields)
            .map(|id| {
                (
                    id,
                    Arc::new(Field::new(format!("f{id}"), DataType::Int32, true)),
                )
            })
            .collect();

        let children: Vec<ArrayRef> = [
            Arc::new(non_null) as ArrayRef,
            Arc::new(mixed),
            Arc::new(fully_null),
        ]
        .into_iter()
        .cycle()
        .take(n_fields)
        .collect();

        let array = UnionArray::try_new(fields, vec![1, 3, 3, 5].into(), None, children).unwrap();

        let expected = BooleanBuffer::from(vec![true, false, true, false]);
        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
        assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls()));
    }
2201
2202
    fn union_fields() -> UnionFields {
2203
        [
2204
            (1, Arc::new(Field::new("A", DataType::Int32, true))),
2205
            (3, Arc::new(Field::new("B", DataType::Float64, true))),
2206
            (4, Arc::new(Field::new("C", DataType::Utf8, true))),
2207
        ]
2208
        .into_iter()
2209
        .collect()
2210
    }
2211
2212
    #[test]
    fn test_is_nullable() {
        // (int child has nulls, float child has nulls, expected is_nullable)
        let cases = [
            (false, false, false),
            (true, false, true),
            (false, true, true),
            (true, true, true),
        ];
        for (int_nullable, float_nullable, expected) in cases {
            assert_eq!(
                create_union_array(int_nullable, float_nullable).is_nullable(),
                expected
            );
        }
    }
2219
2220
    /// Create a union array with a float and integer field
2221
    ///
2222
    /// If the `int_nullable` is true, the integer field will have nulls
2223
    /// If the `float_nullable` is true, the float field will have nulls
2224
    ///
2225
    /// Note the `Field` definitions are always declared to be nullable
2226
    fn create_union_array(int_nullable: bool, float_nullable: bool) -> UnionArray {
2227
        let int_array = if int_nullable {
2228
            Int32Array::from(vec![Some(1), None, Some(3)])
2229
        } else {
2230
            Int32Array::from(vec![1, 2, 3])
2231
        };
2232
        let float_array = if float_nullable {
2233
            Float64Array::from(vec![Some(3.2), None, Some(4.2)])
2234
        } else {
2235
            Float64Array::from(vec![3.2, 4.2, 5.2])
2236
        };
2237
        let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
2238
        let offsets = [0, 0, 0].into_iter().collect::<ScalarBuffer<i32>>();
2239
        let union_fields = [
2240
            (0, Arc::new(Field::new("A", DataType::Int32, true))),
2241
            (1, Arc::new(Field::new("B", DataType::Float64, true))),
2242
        ]
2243
        .into_iter()
2244
        .collect::<UnionFields>();
2245
2246
        let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
2247
2248
        UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap()
2249
    }
2250
}