/Users/andrewlamb/Software/arrow-rs/arrow-array/src/array/union_array.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | #![allow(clippy::enum_clike_unportable_variant)] |
18 | | |
19 | | use crate::{make_array, Array, ArrayRef}; |
20 | | use arrow_buffer::bit_chunk_iterator::{BitChunkIterator, BitChunks}; |
21 | | use arrow_buffer::buffer::NullBuffer; |
22 | | use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer}; |
23 | | use arrow_data::{ArrayData, ArrayDataBuilder}; |
24 | | use arrow_schema::{ArrowError, DataType, UnionFields, UnionMode}; |
25 | | /// Contains the `UnionArray` type. |
26 | | /// |
27 | | use std::any::Any; |
28 | | use std::collections::HashSet; |
29 | | use std::sync::Arc; |
30 | | |
31 | | /// An array of [values of varying types](https://arrow.apache.org/docs/format/Columnar.html#union-layout) |
32 | | /// |
33 | | /// Each slot in a [UnionArray] can have a value chosen from a number |
34 | | /// of types. Each of the possible types are named like the fields of |
35 | | /// a [`StructArray`](crate::StructArray). A `UnionArray` can |
36 | | /// have two possible memory layouts, "dense" or "sparse". For more |
37 | | /// information on please see the |
38 | | /// [specification](https://arrow.apache.org/docs/format/Columnar.html#union-layout). |
39 | | /// |
40 | | /// [UnionBuilder](crate::builder::UnionBuilder) can be used to |
41 | | /// create [UnionArray]'s of primitive types. `UnionArray`'s of nested |
42 | | /// types are also supported but not via `UnionBuilder`, see the tests |
43 | | /// for examples. |
44 | | /// |
45 | | /// # Examples |
46 | | /// ## Create a dense UnionArray `[1, 3.2, 34]` |
47 | | /// ``` |
48 | | /// use arrow_buffer::ScalarBuffer; |
49 | | /// use arrow_schema::*; |
50 | | /// use std::sync::Arc; |
51 | | /// use arrow_array::{Array, Int32Array, Float64Array, UnionArray}; |
52 | | /// |
53 | | /// let int_array = Int32Array::from(vec![1, 34]); |
54 | | /// let float_array = Float64Array::from(vec![3.2]); |
55 | | /// let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>(); |
56 | | /// let offsets = [0, 0, 1].into_iter().collect::<ScalarBuffer<i32>>(); |
57 | | /// |
58 | | /// let union_fields = [ |
59 | | /// (0, Arc::new(Field::new("A", DataType::Int32, false))), |
60 | | /// (1, Arc::new(Field::new("B", DataType::Float64, false))), |
61 | | /// ].into_iter().collect::<UnionFields>(); |
62 | | /// |
63 | | /// let children = vec![ |
64 | | /// Arc::new(int_array) as Arc<dyn Array>, |
65 | | /// Arc::new(float_array), |
66 | | /// ]; |
67 | | /// |
68 | | /// let array = UnionArray::try_new( |
69 | | /// union_fields, |
70 | | /// type_ids, |
71 | | /// Some(offsets), |
72 | | /// children, |
73 | | /// ).unwrap(); |
74 | | /// |
75 | | /// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0); |
76 | | /// assert_eq!(1, value); |
77 | | /// |
78 | | /// let value = array.value(1).as_any().downcast_ref::<Float64Array>().unwrap().value(0); |
79 | | /// assert!(3.2 - value < f64::EPSILON); |
80 | | /// |
81 | | /// let value = array.value(2).as_any().downcast_ref::<Int32Array>().unwrap().value(0); |
82 | | /// assert_eq!(34, value); |
83 | | /// ``` |
84 | | /// |
85 | | /// ## Create a sparse UnionArray `[1, 3.2, 34]` |
86 | | /// ``` |
87 | | /// use arrow_buffer::ScalarBuffer; |
88 | | /// use arrow_schema::*; |
89 | | /// use std::sync::Arc; |
90 | | /// use arrow_array::{Array, Int32Array, Float64Array, UnionArray}; |
91 | | /// |
92 | | /// let int_array = Int32Array::from(vec![Some(1), None, Some(34)]); |
93 | | /// let float_array = Float64Array::from(vec![None, Some(3.2), None]); |
94 | | /// let type_ids = [0_i8, 1, 0].into_iter().collect::<ScalarBuffer<i8>>(); |
95 | | /// |
96 | | /// let union_fields = [ |
97 | | /// (0, Arc::new(Field::new("A", DataType::Int32, false))), |
98 | | /// (1, Arc::new(Field::new("B", DataType::Float64, false))), |
99 | | /// ].into_iter().collect::<UnionFields>(); |
100 | | /// |
101 | | /// let children = vec![ |
102 | | /// Arc::new(int_array) as Arc<dyn Array>, |
103 | | /// Arc::new(float_array), |
104 | | /// ]; |
105 | | /// |
106 | | /// let array = UnionArray::try_new( |
107 | | /// union_fields, |
108 | | /// type_ids, |
109 | | /// None, |
110 | | /// children, |
111 | | /// ).unwrap(); |
112 | | /// |
113 | | /// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0); |
114 | | /// assert_eq!(1, value); |
115 | | /// |
116 | | /// let value = array.value(1).as_any().downcast_ref::<Float64Array>().unwrap().value(0); |
117 | | /// assert!(3.2 - value < f64::EPSILON); |
118 | | /// |
119 | | /// let value = array.value(2).as_any().downcast_ref::<Int32Array>().unwrap().value(0); |
120 | | /// assert_eq!(34, value); |
121 | | /// ``` |
122 | | #[derive(Clone)] |
123 | | pub struct UnionArray { |
124 | | data_type: DataType, |
125 | | type_ids: ScalarBuffer<i8>, |
126 | | offsets: Option<ScalarBuffer<i32>>, |
127 | | fields: Vec<Option<ArrayRef>>, |
128 | | } |
129 | | |
130 | | impl UnionArray { |
131 | | /// Creates a new `UnionArray`. |
132 | | /// |
133 | | /// Accepts type ids, child arrays and optionally offsets (for dense unions) to create |
134 | | /// a new `UnionArray`. This method makes no attempt to validate the data provided by the |
135 | | /// caller and assumes that each of the components are correct and consistent with each other. |
136 | | /// See `try_new` for an alternative that validates the data provided. |
137 | | /// |
138 | | /// # Safety |
139 | | /// |
140 | | /// The `type_ids` values should be positive and must match one of the type ids of the fields provided in `fields`. |
141 | | /// These values are used to index into the `children` arrays. |
142 | | /// |
143 | | /// The `offsets` is provided in the case of a dense union, sparse unions should use `None`. |
144 | | /// If provided the `offsets` values should be positive and must be less than the length of the |
145 | | /// corresponding array. |
146 | | /// |
147 | | /// In both cases above we use signed integer types to maintain compatibility with other |
148 | | /// Arrow implementations. |
149 | 0 | pub unsafe fn new_unchecked( |
150 | 0 | fields: UnionFields, |
151 | 0 | type_ids: ScalarBuffer<i8>, |
152 | 0 | offsets: Option<ScalarBuffer<i32>>, |
153 | 0 | children: Vec<ArrayRef>, |
154 | 0 | ) -> Self { |
155 | 0 | let mode = if offsets.is_some() { |
156 | 0 | UnionMode::Dense |
157 | | } else { |
158 | 0 | UnionMode::Sparse |
159 | | }; |
160 | | |
161 | 0 | let len = type_ids.len(); |
162 | 0 | let builder = ArrayData::builder(DataType::Union(fields, mode)) |
163 | 0 | .add_buffer(type_ids.into_inner()) |
164 | 0 | .child_data(children.into_iter().map(Array::into_data).collect()) |
165 | 0 | .len(len); |
166 | | |
167 | 0 | let data = match offsets { |
168 | 0 | Some(offsets) => builder.add_buffer(offsets.into_inner()).build_unchecked(), |
169 | 0 | None => builder.build_unchecked(), |
170 | | }; |
171 | 0 | Self::from(data) |
172 | 0 | } |
173 | | |
174 | | /// Attempts to create a new `UnionArray`, validating the inputs provided. |
175 | | /// |
176 | | /// The order of child arrays child array order must match the fields order |
177 | 0 | pub fn try_new( |
178 | 0 | fields: UnionFields, |
179 | 0 | type_ids: ScalarBuffer<i8>, |
180 | 0 | offsets: Option<ScalarBuffer<i32>>, |
181 | 0 | children: Vec<ArrayRef>, |
182 | 0 | ) -> Result<Self, ArrowError> { |
183 | | // There must be a child array for every field. |
184 | 0 | if fields.len() != children.len() { |
185 | 0 | return Err(ArrowError::InvalidArgumentError( |
186 | 0 | "Union fields length must match child arrays length".to_string(), |
187 | 0 | )); |
188 | 0 | } |
189 | | |
190 | 0 | if let Some(offsets) = &offsets { |
191 | | // There must be an offset value for every type id value. |
192 | 0 | if offsets.len() != type_ids.len() { |
193 | 0 | return Err(ArrowError::InvalidArgumentError( |
194 | 0 | "Type Ids and Offsets lengths must match".to_string(), |
195 | 0 | )); |
196 | 0 | } |
197 | | } else { |
198 | | // Sparse union child arrays must be equal in length to the length of the union |
199 | 0 | for child in &children { |
200 | 0 | if child.len() != type_ids.len() { |
201 | 0 | return Err(ArrowError::InvalidArgumentError( |
202 | 0 | "Sparse union child arrays must be equal in length to the length of the union".to_string(), |
203 | 0 | )); |
204 | 0 | } |
205 | | } |
206 | | } |
207 | | |
208 | | // Create mapping from type id to array lengths. |
209 | 0 | let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize; |
210 | 0 | let mut array_lens = vec![i32::MIN; max_id + 1]; |
211 | 0 | for (cd, (field_id, _)) in children.iter().zip(fields.iter()) { |
212 | 0 | array_lens[field_id as usize] = cd.len() as i32; |
213 | 0 | } |
214 | | |
215 | | // Type id values must match one of the fields. |
216 | 0 | for id in &type_ids { |
217 | 0 | match array_lens.get(*id as usize) { |
218 | 0 | Some(x) if *x != i32::MIN => {} |
219 | | _ => { |
220 | 0 | return Err(ArrowError::InvalidArgumentError( |
221 | 0 | "Type Ids values must match one of the field type ids".to_owned(), |
222 | 0 | )) |
223 | | } |
224 | | } |
225 | | } |
226 | | |
227 | | // Check the value offsets are in bounds. |
228 | 0 | if let Some(offsets) = &offsets { |
229 | 0 | let mut iter = type_ids.iter().zip(offsets.iter()); |
230 | 0 | if iter.any(|(type_id, &offset)| offset < 0 || offset >= array_lens[*type_id as usize]) |
231 | | { |
232 | 0 | return Err(ArrowError::InvalidArgumentError( |
233 | 0 | "Offsets must be positive and within the length of the Array".to_owned(), |
234 | 0 | )); |
235 | 0 | } |
236 | 0 | } |
237 | | |
238 | | // Safety: |
239 | | // - Arguments validated above. |
240 | 0 | let union_array = unsafe { Self::new_unchecked(fields, type_ids, offsets, children) }; |
241 | 0 | Ok(union_array) |
242 | 0 | } |
243 | | |
244 | | /// Accesses the child array for `type_id`. |
245 | | /// |
246 | | /// # Panics |
247 | | /// |
248 | | /// Panics if the `type_id` provided is not present in the array's DataType |
249 | | /// in the `Union`. |
250 | 0 | pub fn child(&self, type_id: i8) -> &ArrayRef { |
251 | 0 | assert!((type_id as usize) < self.fields.len()); |
252 | 0 | let boxed = &self.fields[type_id as usize]; |
253 | 0 | boxed.as_ref().expect("invalid type id") |
254 | 0 | } |
255 | | |
256 | | /// Returns the `type_id` for the array slot at `index`. |
257 | | /// |
258 | | /// # Panics |
259 | | /// |
260 | | /// Panics if `index` is greater than or equal to the number of child arrays |
261 | 0 | pub fn type_id(&self, index: usize) -> i8 { |
262 | 0 | assert!(index < self.type_ids.len()); |
263 | 0 | self.type_ids[index] |
264 | 0 | } |
265 | | |
266 | | /// Returns the `type_ids` buffer for this array |
267 | 0 | pub fn type_ids(&self) -> &ScalarBuffer<i8> { |
268 | 0 | &self.type_ids |
269 | 0 | } |
270 | | |
271 | | /// Returns the `offsets` buffer if this is a dense array |
272 | 0 | pub fn offsets(&self) -> Option<&ScalarBuffer<i32>> { |
273 | 0 | self.offsets.as_ref() |
274 | 0 | } |
275 | | |
276 | | /// Returns the offset into the underlying values array for the array slot at `index`. |
277 | | /// |
278 | | /// # Panics |
279 | | /// |
280 | | /// Panics if `index` is greater than or equal the length of the array. |
281 | 0 | pub fn value_offset(&self, index: usize) -> usize { |
282 | 0 | assert!(index < self.len()); |
283 | 0 | match &self.offsets { |
284 | 0 | Some(offsets) => offsets[index] as usize, |
285 | 0 | None => self.offset() + index, |
286 | | } |
287 | 0 | } |
288 | | |
289 | | /// Returns the array's value at index `i`. |
290 | | /// |
291 | | /// Note: This method does not check for nulls and the value is arbitrary |
292 | | /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. |
293 | | /// |
294 | | /// # Panics |
295 | | /// Panics if index `i` is out of bounds |
296 | 0 | pub fn value(&self, i: usize) -> ArrayRef { |
297 | 0 | let type_id = self.type_id(i); |
298 | 0 | let value_offset = self.value_offset(i); |
299 | 0 | let child = self.child(type_id); |
300 | 0 | child.slice(value_offset, 1) |
301 | 0 | } |
302 | | |
303 | | /// Returns the names of the types in the union. |
304 | 0 | pub fn type_names(&self) -> Vec<&str> { |
305 | 0 | match self.data_type() { |
306 | 0 | DataType::Union(fields, _) => fields |
307 | 0 | .iter() |
308 | 0 | .map(|(_, f)| f.name().as_str()) |
309 | 0 | .collect::<Vec<&str>>(), |
310 | 0 | _ => unreachable!("Union array's data type is not a union!"), |
311 | | } |
312 | 0 | } |
313 | | |
314 | | /// Returns whether the `UnionArray` is dense (or sparse if `false`). |
315 | 0 | fn is_dense(&self) -> bool { |
316 | 0 | match self.data_type() { |
317 | 0 | DataType::Union(_, mode) => mode == &UnionMode::Dense, |
318 | 0 | _ => unreachable!("Union array's data type is not a union!"), |
319 | | } |
320 | 0 | } |
321 | | |
322 | | /// Returns a zero-copy slice of this array with the indicated offset and length. |
323 | 0 | pub fn slice(&self, offset: usize, length: usize) -> Self { |
324 | 0 | let (offsets, fields) = match self.offsets.as_ref() { |
325 | | // If dense union, slice offsets |
326 | 0 | Some(offsets) => (Some(offsets.slice(offset, length)), self.fields.clone()), |
327 | | // Otherwise need to slice sparse children |
328 | | None => { |
329 | 0 | let fields = self |
330 | 0 | .fields |
331 | 0 | .iter() |
332 | 0 | .map(|x| x.as_ref().map(|x| x.slice(offset, length))) |
333 | 0 | .collect(); |
334 | 0 | (None, fields) |
335 | | } |
336 | | }; |
337 | | |
338 | 0 | Self { |
339 | 0 | data_type: self.data_type.clone(), |
340 | 0 | type_ids: self.type_ids.slice(offset, length), |
341 | 0 | offsets, |
342 | 0 | fields, |
343 | 0 | } |
344 | 0 | } |
345 | | |
346 | | /// Deconstruct this array into its constituent parts |
347 | | /// |
348 | | /// # Example |
349 | | /// |
350 | | /// ``` |
351 | | /// # use arrow_array::array::UnionArray; |
352 | | /// # use arrow_array::types::Int32Type; |
353 | | /// # use arrow_array::builder::UnionBuilder; |
354 | | /// # use arrow_buffer::ScalarBuffer; |
355 | | /// # fn main() -> Result<(), arrow_schema::ArrowError> { |
356 | | /// let mut builder = UnionBuilder::new_dense(); |
357 | | /// builder.append::<Int32Type>("a", 1).unwrap(); |
358 | | /// let union_array = builder.build()?; |
359 | | /// |
360 | | /// // Deconstruct into parts |
361 | | /// let (union_fields, type_ids, offsets, children) = union_array.into_parts(); |
362 | | /// |
363 | | /// // Reconstruct from parts |
364 | | /// let union_array = UnionArray::try_new( |
365 | | /// union_fields, |
366 | | /// type_ids, |
367 | | /// offsets, |
368 | | /// children, |
369 | | /// ); |
370 | | /// # Ok(()) |
371 | | /// # } |
372 | | /// ``` |
373 | | #[allow(clippy::type_complexity)] |
374 | 0 | pub fn into_parts( |
375 | 0 | self, |
376 | 0 | ) -> ( |
377 | 0 | UnionFields, |
378 | 0 | ScalarBuffer<i8>, |
379 | 0 | Option<ScalarBuffer<i32>>, |
380 | 0 | Vec<ArrayRef>, |
381 | 0 | ) { |
382 | | let Self { |
383 | 0 | data_type, |
384 | 0 | type_ids, |
385 | 0 | offsets, |
386 | 0 | mut fields, |
387 | 0 | } = self; |
388 | 0 | match data_type { |
389 | 0 | DataType::Union(union_fields, _) => { |
390 | 0 | let children = union_fields |
391 | 0 | .iter() |
392 | 0 | .map(|(type_id, _)| fields[type_id as usize].take().unwrap()) |
393 | 0 | .collect(); |
394 | 0 | (union_fields, type_ids, offsets, children) |
395 | | } |
396 | 0 | _ => unreachable!(), |
397 | | } |
398 | 0 | } |
399 | | |
400 | | /// Computes the logical nulls for a sparse union, optimized for when there's a lot of fields without nulls |
401 | 0 | fn mask_sparse_skip_without_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer { |
402 | | // Example logic for a union with 5 fields, a, b & c with nulls, d & e without nulls: |
403 | | // let [a_nulls, b_nulls, c_nulls] = nulls; |
404 | | // let [is_a, is_b, is_c] = masks; |
405 | | // let is_d_or_e = !(is_a | is_b | is_c) |
406 | | // let union_chunk_nulls = is_d_or_e | (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls) |
407 | 0 | let fold = |(with_nulls_selected, union_nulls), (is_field, field_nulls)| { |
408 | 0 | ( |
409 | 0 | with_nulls_selected | is_field, |
410 | 0 | union_nulls | (is_field & field_nulls), |
411 | 0 | ) |
412 | 0 | }; |
413 | | |
414 | 0 | self.mask_sparse_helper( |
415 | 0 | nulls, |
416 | 0 | |type_ids_chunk_array, nulls_masks_iters| { |
417 | 0 | let (with_nulls_selected, union_nulls) = nulls_masks_iters |
418 | 0 | .iter_mut() |
419 | 0 | .map(|(field_type_id, field_nulls)| { |
420 | 0 | let field_nulls = field_nulls.next().unwrap(); |
421 | 0 | let is_field = selection_mask(type_ids_chunk_array, *field_type_id); |
422 | | |
423 | 0 | (is_field, field_nulls) |
424 | 0 | }) |
425 | 0 | .fold((0, 0), fold); |
426 | | |
427 | | // In the example above, this is the is_d_or_e = !(is_a | is_b) part |
428 | 0 | let without_nulls_selected = !with_nulls_selected; |
429 | | |
430 | | // if a field without nulls is selected, the value is always true(set bit) |
431 | | // otherwise, the true/set bits have been computed above |
432 | 0 | without_nulls_selected | union_nulls |
433 | 0 | }, |
434 | 0 | |type_ids_remainder, bit_chunks| { |
435 | 0 | let (with_nulls_selected, union_nulls) = bit_chunks |
436 | 0 | .iter() |
437 | 0 | .map(|(field_type_id, field_bit_chunks)| { |
438 | 0 | let field_nulls = field_bit_chunks.remainder_bits(); |
439 | 0 | let is_field = selection_mask(type_ids_remainder, *field_type_id); |
440 | | |
441 | 0 | (is_field, field_nulls) |
442 | 0 | }) |
443 | 0 | .fold((0, 0), fold); |
444 | | |
445 | 0 | let without_nulls_selected = !with_nulls_selected; |
446 | | |
447 | 0 | without_nulls_selected | union_nulls |
448 | 0 | }, |
449 | | ) |
450 | 0 | } |
451 | | |
452 | | /// Computes the logical nulls for a sparse union, optimized for when there's a lot of fields fully null |
453 | 0 | fn mask_sparse_skip_fully_null(&self, mut nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer { |
454 | 0 | let fields = match self.data_type() { |
455 | 0 | DataType::Union(fields, _) => fields, |
456 | 0 | _ => unreachable!("Union array's data type is not a union!"), |
457 | | }; |
458 | | |
459 | 0 | let type_ids = fields.iter().map(|(id, _)| id).collect::<HashSet<_>>(); |
460 | 0 | let with_nulls = nulls.iter().map(|(id, _)| *id).collect::<HashSet<_>>(); |
461 | | |
462 | 0 | let without_nulls_ids = type_ids |
463 | 0 | .difference(&with_nulls) |
464 | 0 | .copied() |
465 | 0 | .collect::<Vec<_>>(); |
466 | | |
467 | 0 | nulls.retain(|(_, nulls)| nulls.null_count() < nulls.len()); |
468 | | |
469 | | // Example logic for a union with 6 fields, a, b & c with nulls, d & e without nulls, and f fully_null: |
470 | | // let [a_nulls, b_nulls, c_nulls] = nulls; |
471 | | // let [is_a, is_b, is_c, is_d, is_e] = masks; |
472 | | // let union_chunk_nulls = is_d | is_e | (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls) |
473 | 0 | self.mask_sparse_helper( |
474 | 0 | nulls, |
475 | 0 | |type_ids_chunk_array, nulls_masks_iters| { |
476 | 0 | let union_nulls = nulls_masks_iters.iter_mut().fold( |
477 | | 0, |
478 | 0 | |union_nulls, (field_type_id, nulls_iter)| { |
479 | 0 | let field_nulls = nulls_iter.next().unwrap(); |
480 | | |
481 | 0 | if field_nulls == 0 { |
482 | 0 | union_nulls |
483 | | } else { |
484 | 0 | let is_field = selection_mask(type_ids_chunk_array, *field_type_id); |
485 | | |
486 | 0 | union_nulls | (is_field & field_nulls) |
487 | | } |
488 | 0 | }, |
489 | | ); |
490 | | |
491 | | // Given the example above, this is the is_d_or_e = (is_d | is_e) part |
492 | 0 | let without_nulls_selected = |
493 | 0 | without_nulls_selected(type_ids_chunk_array, &without_nulls_ids); |
494 | | |
495 | | // if a field without nulls is selected, the value is always true(set bit) |
496 | | // otherwise, the true/set bits have been computed above |
497 | 0 | union_nulls | without_nulls_selected |
498 | 0 | }, |
499 | 0 | |type_ids_remainder, bit_chunks| { |
500 | 0 | let union_nulls = |
501 | 0 | bit_chunks |
502 | 0 | .iter() |
503 | 0 | .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| { |
504 | 0 | let is_field = selection_mask(type_ids_remainder, *field_type_id); |
505 | 0 | let field_nulls = field_bit_chunks.remainder_bits(); |
506 | | |
507 | 0 | union_nulls | is_field & field_nulls |
508 | 0 | }); |
509 | | |
510 | 0 | union_nulls | without_nulls_selected(type_ids_remainder, &without_nulls_ids) |
511 | 0 | }, |
512 | | ) |
513 | 0 | } |
514 | | |
515 | | /// Computes the logical nulls for a sparse union, optimized for when all fields contains nulls |
516 | 0 | fn mask_sparse_all_with_nulls_skip_one(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer { |
517 | | // Example logic for a union with 3 fields, a, b & c, all containing nulls: |
518 | | // let [a_nulls, b_nulls, c_nulls] = nulls; |
519 | | // We can skip the first field: it's selection mask is the negation of all others selection mask |
520 | | // let [is_b, is_c] = selection_masks; |
521 | | // let is_a = !(is_b | is_c) |
522 | | // let union_chunk_nulls = (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls) |
523 | 0 | self.mask_sparse_helper( |
524 | 0 | nulls, |
525 | 0 | |type_ids_chunk_array, nulls_masks_iters| { |
526 | 0 | let (is_not_first, union_nulls) = nulls_masks_iters[1..] // skip first |
527 | 0 | .iter_mut() |
528 | 0 | .fold( |
529 | 0 | (0, 0), |
530 | 0 | |(is_not_first, union_nulls), (field_type_id, nulls_iter)| { |
531 | 0 | let field_nulls = nulls_iter.next().unwrap(); |
532 | 0 | let is_field = selection_mask(type_ids_chunk_array, *field_type_id); |
533 | | |
534 | 0 | ( |
535 | 0 | is_not_first | is_field, |
536 | 0 | union_nulls | (is_field & field_nulls), |
537 | 0 | ) |
538 | 0 | }, |
539 | | ); |
540 | | |
541 | 0 | let is_first = !is_not_first; |
542 | 0 | let first_nulls = nulls_masks_iters[0].1.next().unwrap(); |
543 | | |
544 | 0 | (is_first & first_nulls) | union_nulls |
545 | 0 | }, |
546 | 0 | |type_ids_remainder, bit_chunks| { |
547 | 0 | bit_chunks |
548 | 0 | .iter() |
549 | 0 | .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| { |
550 | 0 | let field_nulls = field_bit_chunks.remainder_bits(); |
551 | | // The same logic as above, except that since this runs at most once, |
552 | | // it doesn't make difference to speed-up the first selection mask |
553 | 0 | let is_field = selection_mask(type_ids_remainder, *field_type_id); |
554 | | |
555 | 0 | union_nulls | (is_field & field_nulls) |
556 | 0 | }) |
557 | 0 | }, |
558 | | ) |
559 | 0 | } |
560 | | |
561 | | /// Maps `nulls` to `BitChunk's` and then to `BitChunkIterator's`, then divides `self.type_ids` into exact chunks of 64 values, |
562 | | /// calling `mask_chunk` for every exact chunk, and `mask_remainder` for the remainder, if any, collecting the result in a `BooleanBuffer` |
563 | 0 | fn mask_sparse_helper( |
564 | 0 | &self, |
565 | 0 | nulls: Vec<(i8, NullBuffer)>, |
566 | 0 | mut mask_chunk: impl FnMut(&[i8; 64], &mut [(i8, BitChunkIterator)]) -> u64, |
567 | 0 | mask_remainder: impl FnOnce(&[i8], &[(i8, BitChunks)]) -> u64, |
568 | 0 | ) -> BooleanBuffer { |
569 | 0 | let bit_chunks = nulls |
570 | 0 | .iter() |
571 | 0 | .map(|(type_id, nulls)| (*type_id, nulls.inner().bit_chunks())) |
572 | 0 | .collect::<Vec<_>>(); |
573 | | |
574 | 0 | let mut nulls_masks_iter = bit_chunks |
575 | 0 | .iter() |
576 | 0 | .map(|(type_id, bit_chunks)| (*type_id, bit_chunks.iter())) |
577 | 0 | .collect::<Vec<_>>(); |
578 | | |
579 | 0 | let chunks_exact = self.type_ids.chunks_exact(64); |
580 | 0 | let remainder = chunks_exact.remainder(); |
581 | | |
582 | 0 | let chunks = chunks_exact.map(|type_ids_chunk| { |
583 | 0 | let type_ids_chunk_array = <&[i8; 64]>::try_from(type_ids_chunk).unwrap(); |
584 | | |
585 | 0 | mask_chunk(type_ids_chunk_array, &mut nulls_masks_iter) |
586 | 0 | }); |
587 | | |
588 | | // SAFETY: |
589 | | // chunks is a ChunksExact iterator, which implements TrustedLen, and correctly reports its length |
590 | 0 | let mut buffer = unsafe { MutableBuffer::from_trusted_len_iter(chunks) }; |
591 | | |
592 | 0 | if !remainder.is_empty() { |
593 | 0 | buffer.push(mask_remainder(remainder, &bit_chunks)); |
594 | 0 | } |
595 | | |
596 | 0 | BooleanBuffer::new(buffer.into(), 0, self.type_ids.len()) |
597 | 0 | } |
598 | | |
599 | | /// Computes the logical nulls for a sparse or dense union, by gathering individual bits from the null buffer of the selected field |
600 | 0 | fn gather_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer { |
601 | 0 | let one_null = NullBuffer::new_null(1); |
602 | 0 | let one_valid = NullBuffer::new_valid(1); |
603 | | |
604 | | // Unsafe code below depend on it: |
605 | | // To remove one branch from the loop, if the a type_id is not utilized, or it's logical_nulls is None/all set, |
606 | | // we use a null buffer of len 1 and a index_mask of 0, or the true null buffer and usize::MAX otherwise. |
607 | | // We then unconditionally access the null buffer with index & index_mask, |
608 | | // which always return 0 for the 1-len buffer, or the true index unchanged otherwise |
609 | | // We also use a 256 array, so llvm knows that `type_id as u8 as usize` is always in bounds |
610 | 0 | let mut logical_nulls_array = [(&one_valid, Mask::Zero); 256]; |
611 | | |
612 | 0 | for (type_id, nulls) in &nulls { |
613 | 0 | if nulls.null_count() == nulls.len() { |
614 | 0 | // Similarly, if all values are null, use a 1-null null-buffer to reduce cache pressure a bit |
615 | 0 | logical_nulls_array[*type_id as u8 as usize] = (&one_null, Mask::Zero); |
616 | 0 | } else { |
617 | 0 | logical_nulls_array[*type_id as u8 as usize] = (nulls, Mask::Max); |
618 | 0 | } |
619 | | } |
620 | | |
621 | 0 | match &self.offsets { |
622 | 0 | Some(offsets) => { |
623 | 0 | assert_eq!(self.type_ids.len(), offsets.len()); |
624 | | |
625 | 0 | BooleanBuffer::collect_bool(self.type_ids.len(), |i| unsafe { |
626 | | // SAFETY: BooleanBuffer::collect_bool calls us 0..self.type_ids.len() |
627 | 0 | let type_id = *self.type_ids.get_unchecked(i); |
628 | | // SAFETY: We asserted that offsets len and self.type_ids len are equal |
629 | 0 | let offset = *offsets.get_unchecked(i); |
630 | | |
631 | 0 | let (nulls, offset_mask) = &logical_nulls_array[type_id as u8 as usize]; |
632 | | |
633 | | // SAFETY: |
634 | | // If offset_mask is Max |
635 | | // 1. Offset validity is checked at union creation |
636 | | // 2. If the null buffer len equals it's array len is checked at array creation |
637 | | // If offset_mask is Zero, the null buffer len is 1 |
638 | 0 | nulls |
639 | 0 | .inner() |
640 | 0 | .value_unchecked(offset as usize & *offset_mask as usize) |
641 | 0 | }) |
642 | | } |
643 | | None => { |
644 | 0 | BooleanBuffer::collect_bool(self.type_ids.len(), |index| unsafe { |
645 | | // SAFETY: BooleanBuffer::collect_bool calls us 0..self.type_ids.len() |
646 | 0 | let type_id = *self.type_ids.get_unchecked(index); |
647 | | |
648 | 0 | let (nulls, index_mask) = &logical_nulls_array[type_id as u8 as usize]; |
649 | | |
650 | | // SAFETY: |
651 | | // If index_mask is Max |
652 | | // 1. On sparse union, every child len match it's parent, this is checked at union creation |
653 | | // 2. If the null buffer len equals it's array len is checked at array creation |
654 | | // If index_mask is Zero, the null buffer len is 1 |
655 | 0 | nulls.inner().value_unchecked(index & *index_mask as usize) |
656 | 0 | }) |
657 | | } |
658 | | } |
659 | 0 | } |
660 | | |
661 | | /// Returns a vector of tuples containing each field's type_id and its logical null buffer. |
662 | | /// Only fields with non-zero null counts are included. |
663 | 0 | fn fields_logical_nulls(&self) -> Vec<(i8, NullBuffer)> { |
664 | 0 | self.fields |
665 | 0 | .iter() |
666 | 0 | .enumerate() |
667 | 0 | .filter_map(|(type_id, field)| Some((type_id as i8, field.as_ref()?.logical_nulls()?))) |
668 | 0 | .filter(|(_, nulls)| nulls.null_count() > 0) |
669 | 0 | .collect() |
670 | 0 | } |
671 | | } |
672 | | |
673 | | impl From<ArrayData> for UnionArray { |
674 | 0 | fn from(data: ArrayData) -> Self { |
675 | 0 | let (fields, mode) = match data.data_type() { |
676 | 0 | DataType::Union(fields, mode) => (fields, *mode), |
677 | 0 | d => panic!("UnionArray expected ArrayData with type Union got {d}"), |
678 | | }; |
679 | 0 | let (type_ids, offsets) = match mode { |
680 | 0 | UnionMode::Sparse => ( |
681 | 0 | ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()), |
682 | 0 | None, |
683 | 0 | ), |
684 | 0 | UnionMode::Dense => ( |
685 | 0 | ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()), |
686 | 0 | Some(ScalarBuffer::new( |
687 | 0 | data.buffers()[1].clone(), |
688 | 0 | data.offset(), |
689 | 0 | data.len(), |
690 | 0 | )), |
691 | 0 | ), |
692 | | }; |
693 | | |
694 | 0 | let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize; |
695 | 0 | let mut boxed_fields = vec![None; max_id + 1]; |
696 | 0 | for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) { |
697 | 0 | boxed_fields[field_id as usize] = Some(make_array(cd.clone())); |
698 | 0 | } |
699 | 0 | Self { |
700 | 0 | data_type: data.data_type().clone(), |
701 | 0 | type_ids, |
702 | 0 | offsets, |
703 | 0 | fields: boxed_fields, |
704 | 0 | } |
705 | 0 | } |
706 | | } |
707 | | |
708 | | impl From<UnionArray> for ArrayData { |
709 | 0 | fn from(array: UnionArray) -> Self { |
710 | 0 | let len = array.len(); |
711 | 0 | let f = match &array.data_type { |
712 | 0 | DataType::Union(f, _) => f, |
713 | 0 | _ => unreachable!(), |
714 | | }; |
715 | 0 | let buffers = match array.offsets { |
716 | 0 | Some(o) => vec![array.type_ids.into_inner(), o.into_inner()], |
717 | 0 | None => vec![array.type_ids.into_inner()], |
718 | | }; |
719 | | |
720 | 0 | let child = f |
721 | 0 | .iter() |
722 | 0 | .map(|(i, _)| array.fields[i as usize].as_ref().unwrap().to_data()) |
723 | 0 | .collect(); |
724 | | |
725 | 0 | let builder = ArrayDataBuilder::new(array.data_type) |
726 | 0 | .len(len) |
727 | 0 | .buffers(buffers) |
728 | 0 | .child_data(child); |
729 | 0 | unsafe { builder.build_unchecked() } |
730 | 0 | } |
731 | | } |
732 | | |
733 | | impl Array for UnionArray { |
734 | 0 | fn as_any(&self) -> &dyn Any { |
735 | 0 | self |
736 | 0 | } |
737 | | |
738 | 0 | fn to_data(&self) -> ArrayData { |
739 | 0 | self.clone().into() |
740 | 0 | } |
741 | | |
742 | 0 | fn into_data(self) -> ArrayData { |
743 | 0 | self.into() |
744 | 0 | } |
745 | | |
746 | 0 | fn data_type(&self) -> &DataType { |
747 | 0 | &self.data_type |
748 | 0 | } |
749 | | |
750 | 0 | fn slice(&self, offset: usize, length: usize) -> ArrayRef { |
751 | 0 | Arc::new(self.slice(offset, length)) |
752 | 0 | } |
753 | | |
754 | 0 | fn len(&self) -> usize { |
755 | 0 | self.type_ids.len() |
756 | 0 | } |
757 | | |
758 | 0 | fn is_empty(&self) -> bool { |
759 | 0 | self.type_ids.is_empty() |
760 | 0 | } |
761 | | |
762 | 0 | fn shrink_to_fit(&mut self) { |
763 | 0 | self.type_ids.shrink_to_fit(); |
764 | 0 | if let Some(offsets) = &mut self.offsets { |
765 | 0 | offsets.shrink_to_fit(); |
766 | 0 | } |
767 | 0 | for array in self.fields.iter_mut().flatten() { |
768 | 0 | array.shrink_to_fit(); |
769 | 0 | } |
770 | 0 | self.fields.shrink_to_fit(); |
771 | 0 | } |
772 | | |
773 | 0 | fn offset(&self) -> usize { |
774 | 0 | 0 |
775 | 0 | } |
776 | | |
777 | 0 | fn nulls(&self) -> Option<&NullBuffer> { |
778 | 0 | None |
779 | 0 | } |
780 | | |
781 | 0 | fn logical_nulls(&self) -> Option<NullBuffer> { |
782 | 0 | let fields = match self.data_type() { |
783 | 0 | DataType::Union(fields, _) => fields, |
784 | 0 | _ => unreachable!(), |
785 | | }; |
786 | | |
787 | 0 | if fields.len() <= 1 { |
788 | 0 | return self.fields.iter().find_map(|field_opt| { |
789 | 0 | field_opt |
790 | 0 | .as_ref() |
791 | 0 | .and_then(|field| field.logical_nulls()) |
792 | 0 | .map(|logical_nulls| { |
793 | 0 | if self.is_dense() { |
794 | 0 | self.gather_nulls(vec![(0, logical_nulls)]).into() |
795 | | } else { |
796 | 0 | logical_nulls |
797 | | } |
798 | 0 | }) |
799 | 0 | }); |
800 | 0 | } |
801 | | |
802 | 0 | let logical_nulls = self.fields_logical_nulls(); |
803 | | |
804 | 0 | if logical_nulls.is_empty() { |
805 | 0 | return None; |
806 | 0 | } |
807 | | |
808 | 0 | let fully_null_count = logical_nulls |
809 | 0 | .iter() |
810 | 0 | .filter(|(_, nulls)| nulls.null_count() == nulls.len()) |
811 | 0 | .count(); |
812 | | |
813 | 0 | if fully_null_count == fields.len() { |
814 | 0 | if let Some((_, exactly_sized)) = logical_nulls |
815 | 0 | .iter() |
816 | 0 | .find(|(_, nulls)| nulls.len() == self.len()) |
817 | | { |
818 | 0 | return Some(exactly_sized.clone()); |
819 | 0 | } |
820 | | |
821 | 0 | if let Some((_, bigger)) = logical_nulls |
822 | 0 | .iter() |
823 | 0 | .find(|(_, nulls)| nulls.len() > self.len()) |
824 | | { |
825 | 0 | return Some(bigger.slice(0, self.len())); |
826 | 0 | } |
827 | | |
828 | 0 | return Some(NullBuffer::new_null(self.len())); |
829 | 0 | } |
830 | | |
831 | 0 | let boolean_buffer = match &self.offsets { |
832 | 0 | Some(_) => self.gather_nulls(logical_nulls), |
833 | | None => { |
834 | | // Choose the fastest way to compute the logical nulls |
835 | | // Gather computes one null per iteration, while the others work on 64 nulls chunks, |
836 | | // but must also compute selection masks, which is expensive, |
837 | | // so it's cost is the number of selection masks computed per chunk |
838 | | // Since computing the selection mask gets auto-vectorized, it's performance depends on which simd feature is enabled |
839 | | // For gather, the cost is the threshold where masking becomes slower than gather, which is determined with benchmarks |
840 | | // TODO: bench on avx512f(feature is still unstable) |
841 | 0 | let gather_relative_cost = if cfg!(target_feature = "avx2") { |
842 | 0 | 10 |
843 | 0 | } else if cfg!(target_feature = "sse4.1") { |
844 | 0 | 3 |
845 | 0 | } else if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") { |
846 | | // x86 baseline includes sse2 |
847 | 0 | 2 |
848 | | } else { |
849 | | // TODO: bench on non x86 |
850 | | // Always use gather on non benchmarked archs because even though it may slower on some cases, |
851 | | // it's performance depends only on the union length, without being affected by the number of fields |
852 | 0 | 0 |
853 | | }; |
854 | | |
855 | 0 | let strategies = [ |
856 | 0 | (SparseStrategy::Gather, gather_relative_cost, true), |
857 | 0 | ( |
858 | 0 | SparseStrategy::MaskAllFieldsWithNullsSkipOne, |
859 | 0 | fields.len() - 1, |
860 | 0 | fields.len() == logical_nulls.len(), |
861 | 0 | ), |
862 | 0 | ( |
863 | 0 | SparseStrategy::MaskSkipWithoutNulls, |
864 | 0 | logical_nulls.len(), |
865 | 0 | true, |
866 | 0 | ), |
867 | 0 | ( |
868 | 0 | SparseStrategy::MaskSkipFullyNull, |
869 | 0 | fields.len() - fully_null_count, |
870 | 0 | true, |
871 | 0 | ), |
872 | 0 | ]; |
873 | | |
874 | 0 | let (strategy, _, _) = strategies |
875 | 0 | .iter() |
876 | 0 | .filter(|(_, _, applicable)| *applicable) |
877 | 0 | .min_by_key(|(_, cost, _)| cost) |
878 | 0 | .unwrap(); |
879 | | |
880 | 0 | match strategy { |
881 | 0 | SparseStrategy::Gather => self.gather_nulls(logical_nulls), |
882 | | SparseStrategy::MaskAllFieldsWithNullsSkipOne => { |
883 | 0 | self.mask_sparse_all_with_nulls_skip_one(logical_nulls) |
884 | | } |
885 | | SparseStrategy::MaskSkipWithoutNulls => { |
886 | 0 | self.mask_sparse_skip_without_nulls(logical_nulls) |
887 | | } |
888 | | SparseStrategy::MaskSkipFullyNull => { |
889 | 0 | self.mask_sparse_skip_fully_null(logical_nulls) |
890 | | } |
891 | | } |
892 | | } |
893 | | }; |
894 | | |
895 | 0 | let null_buffer = NullBuffer::from(boolean_buffer); |
896 | | |
897 | 0 | if null_buffer.null_count() > 0 { |
898 | 0 | Some(null_buffer) |
899 | | } else { |
900 | 0 | None |
901 | | } |
902 | 0 | } |
903 | | |
904 | 0 | fn is_nullable(&self) -> bool { |
905 | 0 | self.fields |
906 | 0 | .iter() |
907 | 0 | .flatten() |
908 | 0 | .any(|field| field.is_nullable()) |
909 | 0 | } |
910 | | |
911 | 0 | fn get_buffer_memory_size(&self) -> usize { |
912 | 0 | let mut sum = self.type_ids.inner().capacity(); |
913 | 0 | if let Some(o) = self.offsets.as_ref() { |
914 | 0 | sum += o.inner().capacity() |
915 | 0 | } |
916 | 0 | self.fields |
917 | 0 | .iter() |
918 | 0 | .flat_map(|x| x.as_ref().map(|x| x.get_buffer_memory_size())) |
919 | 0 | .sum::<usize>() |
920 | 0 | + sum |
921 | 0 | } |
922 | | |
923 | 0 | fn get_array_memory_size(&self) -> usize { |
924 | 0 | let mut sum = self.type_ids.inner().capacity(); |
925 | 0 | if let Some(o) = self.offsets.as_ref() { |
926 | 0 | sum += o.inner().capacity() |
927 | 0 | } |
928 | 0 | std::mem::size_of::<Self>() |
929 | 0 | + self |
930 | 0 | .fields |
931 | 0 | .iter() |
932 | 0 | .flat_map(|x| x.as_ref().map(|x| x.get_array_memory_size())) |
933 | 0 | .sum::<usize>() |
934 | 0 | + sum |
935 | 0 | } |
936 | | } |
937 | | |
938 | | impl std::fmt::Debug for UnionArray { |
939 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
940 | 0 | let header = if self.is_dense() { |
941 | 0 | "UnionArray(Dense)\n[" |
942 | | } else { |
943 | 0 | "UnionArray(Sparse)\n[" |
944 | | }; |
945 | 0 | writeln!(f, "{header}")?; |
946 | | |
947 | 0 | writeln!(f, "-- type id buffer:")?; |
948 | 0 | writeln!(f, "{:?}", self.type_ids)?; |
949 | | |
950 | 0 | if let Some(offsets) = &self.offsets { |
951 | 0 | writeln!(f, "-- offsets buffer:")?; |
952 | 0 | writeln!(f, "{offsets:?}")?; |
953 | 0 | } |
954 | | |
955 | 0 | let fields = match self.data_type() { |
956 | 0 | DataType::Union(fields, _) => fields, |
957 | 0 | _ => unreachable!(), |
958 | | }; |
959 | | |
960 | 0 | for (type_id, field) in fields.iter() { |
961 | 0 | let child = self.child(type_id); |
962 | 0 | writeln!( |
963 | 0 | f, |
964 | 0 | "-- child {}: \"{}\" ({:?})", |
965 | | type_id, |
966 | 0 | field.name(), |
967 | 0 | field.data_type() |
968 | 0 | )?; |
969 | 0 | std::fmt::Debug::fmt(child, f)?; |
970 | 0 | writeln!(f)?; |
971 | | } |
972 | 0 | writeln!(f, "]") |
973 | 0 | } |
974 | | } |
975 | | |
976 | | /// How to compute the logical nulls of a sparse union. All strategies return the same result. |
977 | | /// Those starting with Mask perform bitwise masking for each chunk of 64 values, including |
978 | | /// computing expensive selection masks of fields: which fields masks must be computed is the |
979 | | /// difference between them |
980 | | enum SparseStrategy { |
981 | | /// Gather individual bits from the null buffer of the selected field |
982 | | Gather, |
983 | | /// All fields contains nulls, so we can skip the selection mask computation of one field by negating the others |
984 | | MaskAllFieldsWithNullsSkipOne, |
985 | | /// Skip the selection mask computation of the fields without nulls |
986 | | MaskSkipWithoutNulls, |
987 | | /// Skip the selection mask computation of the fully nulls fields |
988 | | MaskSkipFullyNull, |
989 | | } |
990 | | |
991 | | #[derive(Copy, Clone)] |
992 | | #[repr(usize)] |
993 | | enum Mask { |
994 | | Zero = 0, |
995 | | // false positive, see https://github.com/rust-lang/rust-clippy/issues/8043 |
996 | | #[allow(clippy::enum_clike_unportable_variant)] |
997 | | Max = usize::MAX, |
998 | | } |
999 | | |
1000 | 0 | fn selection_mask(type_ids_chunk: &[i8], type_id: i8) -> u64 { |
1001 | 0 | type_ids_chunk |
1002 | 0 | .iter() |
1003 | 0 | .copied() |
1004 | 0 | .enumerate() |
1005 | 0 | .fold(0, |packed, (bit_idx, v)| { |
1006 | 0 | packed | (((v == type_id) as u64) << bit_idx) |
1007 | 0 | }) |
1008 | 0 | } |
1009 | | |
1010 | | /// Returns a bitmask where bits indicate if any id from `without_nulls_ids` exist in `type_ids_chunk`. |
1011 | 0 | fn without_nulls_selected(type_ids_chunk: &[i8], without_nulls_ids: &[i8]) -> u64 { |
1012 | 0 | without_nulls_ids |
1013 | 0 | .iter() |
1014 | 0 | .fold(0, |fully_valid_selected, field_type_id| { |
1015 | 0 | fully_valid_selected | selection_mask(type_ids_chunk, *field_type_id) |
1016 | 0 | }) |
1017 | 0 | } |
1018 | | |
1019 | | #[cfg(test)] |
1020 | | mod tests { |
1021 | | use super::*; |
1022 | | use std::collections::HashSet; |
1023 | | |
1024 | | use crate::array::Int8Type; |
1025 | | use crate::builder::UnionBuilder; |
1026 | | use crate::cast::AsArray; |
1027 | | use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type}; |
1028 | | use crate::{Float64Array, Int32Array, Int64Array, StringArray}; |
1029 | | use crate::{Int8Array, RecordBatch}; |
1030 | | use arrow_buffer::Buffer; |
1031 | | use arrow_schema::{Field, Schema}; |
1032 | | |
1033 | | #[test] |
1034 | | fn test_dense_i32() { |
1035 | | let mut builder = UnionBuilder::new_dense(); |
1036 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1037 | | builder.append::<Int32Type>("b", 2).unwrap(); |
1038 | | builder.append::<Int32Type>("c", 3).unwrap(); |
1039 | | builder.append::<Int32Type>("a", 4).unwrap(); |
1040 | | builder.append::<Int32Type>("c", 5).unwrap(); |
1041 | | builder.append::<Int32Type>("a", 6).unwrap(); |
1042 | | builder.append::<Int32Type>("b", 7).unwrap(); |
1043 | | let union = builder.build().unwrap(); |
1044 | | |
1045 | | let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1]; |
1046 | | let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1]; |
1047 | | let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7]; |
1048 | | |
1049 | | // Check type ids |
1050 | | assert_eq!(*union.type_ids(), expected_type_ids); |
1051 | | for (i, id) in expected_type_ids.iter().enumerate() { |
1052 | | assert_eq!(id, &union.type_id(i)); |
1053 | | } |
1054 | | |
1055 | | // Check offsets |
1056 | | assert_eq!(*union.offsets().unwrap(), expected_offsets); |
1057 | | for (i, id) in expected_offsets.iter().enumerate() { |
1058 | | assert_eq!(union.value_offset(i), *id as usize); |
1059 | | } |
1060 | | |
1061 | | // Check data |
1062 | | assert_eq!( |
1063 | | *union.child(0).as_primitive::<Int32Type>().values(), |
1064 | | [1_i32, 4, 6] |
1065 | | ); |
1066 | | assert_eq!( |
1067 | | *union.child(1).as_primitive::<Int32Type>().values(), |
1068 | | [2_i32, 7] |
1069 | | ); |
1070 | | assert_eq!( |
1071 | | *union.child(2).as_primitive::<Int32Type>().values(), |
1072 | | [3_i32, 5] |
1073 | | ); |
1074 | | |
1075 | | assert_eq!(expected_array_values.len(), union.len()); |
1076 | | for (i, expected_value) in expected_array_values.iter().enumerate() { |
1077 | | assert!(!union.is_null(i)); |
1078 | | let slot = union.value(i); |
1079 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1080 | | assert_eq!(slot.len(), 1); |
1081 | | let value = slot.value(0); |
1082 | | assert_eq!(expected_value, &value); |
1083 | | } |
1084 | | } |
1085 | | |
1086 | | #[test] |
1087 | | fn slice_union_array_single_field() { |
1088 | | // Dense Union |
1089 | | // [1, null, 3, null, 4] |
1090 | | let union_array = { |
1091 | | let mut builder = UnionBuilder::new_dense(); |
1092 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1093 | | builder.append_null::<Int32Type>("a").unwrap(); |
1094 | | builder.append::<Int32Type>("a", 3).unwrap(); |
1095 | | builder.append_null::<Int32Type>("a").unwrap(); |
1096 | | builder.append::<Int32Type>("a", 4).unwrap(); |
1097 | | builder.build().unwrap() |
1098 | | }; |
1099 | | |
1100 | | // [null, 3, null] |
1101 | | let union_slice = union_array.slice(1, 3); |
1102 | | let logical_nulls = union_slice.logical_nulls().unwrap(); |
1103 | | |
1104 | | assert_eq!(logical_nulls.len(), 3); |
1105 | | assert!(logical_nulls.is_null(0)); |
1106 | | assert!(logical_nulls.is_valid(1)); |
1107 | | assert!(logical_nulls.is_null(2)); |
1108 | | } |
1109 | | |
1110 | | #[test] |
1111 | | #[cfg_attr(miri, ignore)] |
1112 | | fn test_dense_i32_large() { |
1113 | | let mut builder = UnionBuilder::new_dense(); |
1114 | | |
1115 | | let expected_type_ids = vec![0_i8; 1024]; |
1116 | | let expected_offsets: Vec<_> = (0..1024).collect(); |
1117 | | let expected_array_values: Vec<_> = (1..=1024).collect(); |
1118 | | |
1119 | | expected_array_values |
1120 | | .iter() |
1121 | | .for_each(|v| builder.append::<Int32Type>("a", *v).unwrap()); |
1122 | | |
1123 | | let union = builder.build().unwrap(); |
1124 | | |
1125 | | // Check type ids |
1126 | | assert_eq!(*union.type_ids(), expected_type_ids); |
1127 | | for (i, id) in expected_type_ids.iter().enumerate() { |
1128 | | assert_eq!(id, &union.type_id(i)); |
1129 | | } |
1130 | | |
1131 | | // Check offsets |
1132 | | assert_eq!(*union.offsets().unwrap(), expected_offsets); |
1133 | | for (i, id) in expected_offsets.iter().enumerate() { |
1134 | | assert_eq!(union.value_offset(i), *id as usize); |
1135 | | } |
1136 | | |
1137 | | for (i, expected_value) in expected_array_values.iter().enumerate() { |
1138 | | assert!(!union.is_null(i)); |
1139 | | let slot = union.value(i); |
1140 | | let slot = slot.as_primitive::<Int32Type>(); |
1141 | | assert_eq!(slot.len(), 1); |
1142 | | let value = slot.value(0); |
1143 | | assert_eq!(expected_value, &value); |
1144 | | } |
1145 | | } |
1146 | | |
1147 | | #[test] |
1148 | | fn test_dense_mixed() { |
1149 | | let mut builder = UnionBuilder::new_dense(); |
1150 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1151 | | builder.append::<Int64Type>("c", 3).unwrap(); |
1152 | | builder.append::<Int32Type>("a", 4).unwrap(); |
1153 | | builder.append::<Int64Type>("c", 5).unwrap(); |
1154 | | builder.append::<Int32Type>("a", 6).unwrap(); |
1155 | | let union = builder.build().unwrap(); |
1156 | | |
1157 | | assert_eq!(5, union.len()); |
1158 | | for i in 0..union.len() { |
1159 | | let slot = union.value(i); |
1160 | | assert!(!union.is_null(i)); |
1161 | | match i { |
1162 | | 0 => { |
1163 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1164 | | assert_eq!(slot.len(), 1); |
1165 | | let value = slot.value(0); |
1166 | | assert_eq!(1_i32, value); |
1167 | | } |
1168 | | 1 => { |
1169 | | let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap(); |
1170 | | assert_eq!(slot.len(), 1); |
1171 | | let value = slot.value(0); |
1172 | | assert_eq!(3_i64, value); |
1173 | | } |
1174 | | 2 => { |
1175 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1176 | | assert_eq!(slot.len(), 1); |
1177 | | let value = slot.value(0); |
1178 | | assert_eq!(4_i32, value); |
1179 | | } |
1180 | | 3 => { |
1181 | | let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap(); |
1182 | | assert_eq!(slot.len(), 1); |
1183 | | let value = slot.value(0); |
1184 | | assert_eq!(5_i64, value); |
1185 | | } |
1186 | | 4 => { |
1187 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1188 | | assert_eq!(slot.len(), 1); |
1189 | | let value = slot.value(0); |
1190 | | assert_eq!(6_i32, value); |
1191 | | } |
1192 | | _ => unreachable!(), |
1193 | | } |
1194 | | } |
1195 | | } |
1196 | | |
1197 | | #[test] |
1198 | | fn test_dense_mixed_with_nulls() { |
1199 | | let mut builder = UnionBuilder::new_dense(); |
1200 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1201 | | builder.append::<Int64Type>("c", 3).unwrap(); |
1202 | | builder.append::<Int32Type>("a", 10).unwrap(); |
1203 | | builder.append_null::<Int32Type>("a").unwrap(); |
1204 | | builder.append::<Int32Type>("a", 6).unwrap(); |
1205 | | let union = builder.build().unwrap(); |
1206 | | |
1207 | | assert_eq!(5, union.len()); |
1208 | | for i in 0..union.len() { |
1209 | | let slot = union.value(i); |
1210 | | match i { |
1211 | | 0 => { |
1212 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1213 | | assert!(!slot.is_null(0)); |
1214 | | assert_eq!(slot.len(), 1); |
1215 | | let value = slot.value(0); |
1216 | | assert_eq!(1_i32, value); |
1217 | | } |
1218 | | 1 => { |
1219 | | let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap(); |
1220 | | assert!(!slot.is_null(0)); |
1221 | | assert_eq!(slot.len(), 1); |
1222 | | let value = slot.value(0); |
1223 | | assert_eq!(3_i64, value); |
1224 | | } |
1225 | | 2 => { |
1226 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1227 | | assert!(!slot.is_null(0)); |
1228 | | assert_eq!(slot.len(), 1); |
1229 | | let value = slot.value(0); |
1230 | | assert_eq!(10_i32, value); |
1231 | | } |
1232 | | 3 => assert!(slot.is_null(0)), |
1233 | | 4 => { |
1234 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1235 | | assert!(!slot.is_null(0)); |
1236 | | assert_eq!(slot.len(), 1); |
1237 | | let value = slot.value(0); |
1238 | | assert_eq!(6_i32, value); |
1239 | | } |
1240 | | _ => unreachable!(), |
1241 | | } |
1242 | | } |
1243 | | } |
1244 | | |
1245 | | #[test] |
1246 | | fn test_dense_mixed_with_nulls_and_offset() { |
1247 | | let mut builder = UnionBuilder::new_dense(); |
1248 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1249 | | builder.append::<Int64Type>("c", 3).unwrap(); |
1250 | | builder.append::<Int32Type>("a", 10).unwrap(); |
1251 | | builder.append_null::<Int32Type>("a").unwrap(); |
1252 | | builder.append::<Int32Type>("a", 6).unwrap(); |
1253 | | let union = builder.build().unwrap(); |
1254 | | |
1255 | | let slice = union.slice(2, 3); |
1256 | | let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap(); |
1257 | | |
1258 | | assert_eq!(3, new_union.len()); |
1259 | | for i in 0..new_union.len() { |
1260 | | let slot = new_union.value(i); |
1261 | | match i { |
1262 | | 0 => { |
1263 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1264 | | assert!(!slot.is_null(0)); |
1265 | | assert_eq!(slot.len(), 1); |
1266 | | let value = slot.value(0); |
1267 | | assert_eq!(10_i32, value); |
1268 | | } |
1269 | | 1 => assert!(slot.is_null(0)), |
1270 | | 2 => { |
1271 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1272 | | assert!(!slot.is_null(0)); |
1273 | | assert_eq!(slot.len(), 1); |
1274 | | let value = slot.value(0); |
1275 | | assert_eq!(6_i32, value); |
1276 | | } |
1277 | | _ => unreachable!(), |
1278 | | } |
1279 | | } |
1280 | | } |
1281 | | |
1282 | | #[test] |
1283 | | fn test_dense_mixed_with_str() { |
1284 | | let string_array = StringArray::from(vec!["foo", "bar", "baz"]); |
1285 | | let int_array = Int32Array::from(vec![5, 6]); |
1286 | | let float_array = Float64Array::from(vec![10.0]); |
1287 | | |
1288 | | let type_ids = [1, 0, 0, 2, 0, 1].into_iter().collect::<ScalarBuffer<i8>>(); |
1289 | | let offsets = [0, 0, 1, 0, 2, 1] |
1290 | | .into_iter() |
1291 | | .collect::<ScalarBuffer<i32>>(); |
1292 | | |
1293 | | let fields = [ |
1294 | | (0, Arc::new(Field::new("A", DataType::Utf8, false))), |
1295 | | (1, Arc::new(Field::new("B", DataType::Int32, false))), |
1296 | | (2, Arc::new(Field::new("C", DataType::Float64, false))), |
1297 | | ] |
1298 | | .into_iter() |
1299 | | .collect::<UnionFields>(); |
1300 | | let children = [ |
1301 | | Arc::new(string_array) as Arc<dyn Array>, |
1302 | | Arc::new(int_array), |
1303 | | Arc::new(float_array), |
1304 | | ] |
1305 | | .into_iter() |
1306 | | .collect(); |
1307 | | let array = |
1308 | | UnionArray::try_new(fields, type_ids.clone(), Some(offsets.clone()), children).unwrap(); |
1309 | | |
1310 | | // Check type ids |
1311 | | assert_eq!(*array.type_ids(), type_ids); |
1312 | | for (i, id) in type_ids.iter().enumerate() { |
1313 | | assert_eq!(id, &array.type_id(i)); |
1314 | | } |
1315 | | |
1316 | | // Check offsets |
1317 | | assert_eq!(*array.offsets().unwrap(), offsets); |
1318 | | for (i, id) in offsets.iter().enumerate() { |
1319 | | assert_eq!(*id as usize, array.value_offset(i)); |
1320 | | } |
1321 | | |
1322 | | // Check values |
1323 | | assert_eq!(6, array.len()); |
1324 | | |
1325 | | let slot = array.value(0); |
1326 | | let value = slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0); |
1327 | | assert_eq!(5, value); |
1328 | | |
1329 | | let slot = array.value(1); |
1330 | | let value = slot |
1331 | | .as_any() |
1332 | | .downcast_ref::<StringArray>() |
1333 | | .unwrap() |
1334 | | .value(0); |
1335 | | assert_eq!("foo", value); |
1336 | | |
1337 | | let slot = array.value(2); |
1338 | | let value = slot |
1339 | | .as_any() |
1340 | | .downcast_ref::<StringArray>() |
1341 | | .unwrap() |
1342 | | .value(0); |
1343 | | assert_eq!("bar", value); |
1344 | | |
1345 | | let slot = array.value(3); |
1346 | | let value = slot |
1347 | | .as_any() |
1348 | | .downcast_ref::<Float64Array>() |
1349 | | .unwrap() |
1350 | | .value(0); |
1351 | | assert_eq!(10.0, value); |
1352 | | |
1353 | | let slot = array.value(4); |
1354 | | let value = slot |
1355 | | .as_any() |
1356 | | .downcast_ref::<StringArray>() |
1357 | | .unwrap() |
1358 | | .value(0); |
1359 | | assert_eq!("baz", value); |
1360 | | |
1361 | | let slot = array.value(5); |
1362 | | let value = slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0); |
1363 | | assert_eq!(6, value); |
1364 | | } |
1365 | | |
1366 | | #[test] |
1367 | | fn test_sparse_i32() { |
1368 | | let mut builder = UnionBuilder::new_sparse(); |
1369 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1370 | | builder.append::<Int32Type>("b", 2).unwrap(); |
1371 | | builder.append::<Int32Type>("c", 3).unwrap(); |
1372 | | builder.append::<Int32Type>("a", 4).unwrap(); |
1373 | | builder.append::<Int32Type>("c", 5).unwrap(); |
1374 | | builder.append::<Int32Type>("a", 6).unwrap(); |
1375 | | builder.append::<Int32Type>("b", 7).unwrap(); |
1376 | | let union = builder.build().unwrap(); |
1377 | | |
1378 | | let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1]; |
1379 | | let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7]; |
1380 | | |
1381 | | // Check type ids |
1382 | | assert_eq!(*union.type_ids(), expected_type_ids); |
1383 | | for (i, id) in expected_type_ids.iter().enumerate() { |
1384 | | assert_eq!(id, &union.type_id(i)); |
1385 | | } |
1386 | | |
1387 | | // Check offsets, sparse union should only have a single buffer |
1388 | | assert!(union.offsets().is_none()); |
1389 | | |
1390 | | // Check data |
1391 | | assert_eq!( |
1392 | | *union.child(0).as_primitive::<Int32Type>().values(), |
1393 | | [1_i32, 0, 0, 4, 0, 6, 0], |
1394 | | ); |
1395 | | assert_eq!( |
1396 | | *union.child(1).as_primitive::<Int32Type>().values(), |
1397 | | [0_i32, 2_i32, 0, 0, 0, 0, 7] |
1398 | | ); |
1399 | | assert_eq!( |
1400 | | *union.child(2).as_primitive::<Int32Type>().values(), |
1401 | | [0_i32, 0, 3_i32, 0, 5, 0, 0] |
1402 | | ); |
1403 | | |
1404 | | assert_eq!(expected_array_values.len(), union.len()); |
1405 | | for (i, expected_value) in expected_array_values.iter().enumerate() { |
1406 | | assert!(!union.is_null(i)); |
1407 | | let slot = union.value(i); |
1408 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1409 | | assert_eq!(slot.len(), 1); |
1410 | | let value = slot.value(0); |
1411 | | assert_eq!(expected_value, &value); |
1412 | | } |
1413 | | } |
1414 | | |
1415 | | #[test] |
1416 | | fn test_sparse_mixed() { |
1417 | | let mut builder = UnionBuilder::new_sparse(); |
1418 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1419 | | builder.append::<Float64Type>("c", 3.0).unwrap(); |
1420 | | builder.append::<Int32Type>("a", 4).unwrap(); |
1421 | | builder.append::<Float64Type>("c", 5.0).unwrap(); |
1422 | | builder.append::<Int32Type>("a", 6).unwrap(); |
1423 | | let union = builder.build().unwrap(); |
1424 | | |
1425 | | let expected_type_ids = vec![0_i8, 1, 0, 1, 0]; |
1426 | | |
1427 | | // Check type ids |
1428 | | assert_eq!(*union.type_ids(), expected_type_ids); |
1429 | | for (i, id) in expected_type_ids.iter().enumerate() { |
1430 | | assert_eq!(id, &union.type_id(i)); |
1431 | | } |
1432 | | |
1433 | | // Check offsets, sparse union should only have a single buffer, i.e. no offsets |
1434 | | assert!(union.offsets().is_none()); |
1435 | | |
1436 | | for i in 0..union.len() { |
1437 | | let slot = union.value(i); |
1438 | | assert!(!union.is_null(i)); |
1439 | | match i { |
1440 | | 0 => { |
1441 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1442 | | assert_eq!(slot.len(), 1); |
1443 | | let value = slot.value(0); |
1444 | | assert_eq!(1_i32, value); |
1445 | | } |
1446 | | 1 => { |
1447 | | let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap(); |
1448 | | assert_eq!(slot.len(), 1); |
1449 | | let value = slot.value(0); |
1450 | | assert_eq!(value, 3_f64); |
1451 | | } |
1452 | | 2 => { |
1453 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1454 | | assert_eq!(slot.len(), 1); |
1455 | | let value = slot.value(0); |
1456 | | assert_eq!(4_i32, value); |
1457 | | } |
1458 | | 3 => { |
1459 | | let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap(); |
1460 | | assert_eq!(slot.len(), 1); |
1461 | | let value = slot.value(0); |
1462 | | assert_eq!(5_f64, value); |
1463 | | } |
1464 | | 4 => { |
1465 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1466 | | assert_eq!(slot.len(), 1); |
1467 | | let value = slot.value(0); |
1468 | | assert_eq!(6_i32, value); |
1469 | | } |
1470 | | _ => unreachable!(), |
1471 | | } |
1472 | | } |
1473 | | } |
1474 | | |
1475 | | #[test] |
1476 | | fn test_sparse_mixed_with_nulls() { |
1477 | | let mut builder = UnionBuilder::new_sparse(); |
1478 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1479 | | builder.append_null::<Int32Type>("a").unwrap(); |
1480 | | builder.append::<Float64Type>("c", 3.0).unwrap(); |
1481 | | builder.append::<Int32Type>("a", 4).unwrap(); |
1482 | | let union = builder.build().unwrap(); |
1483 | | |
1484 | | let expected_type_ids = vec![0_i8, 0, 1, 0]; |
1485 | | |
1486 | | // Check type ids |
1487 | | assert_eq!(*union.type_ids(), expected_type_ids); |
1488 | | for (i, id) in expected_type_ids.iter().enumerate() { |
1489 | | assert_eq!(id, &union.type_id(i)); |
1490 | | } |
1491 | | |
1492 | | // Check offsets, sparse union should only have a single buffer, i.e. no offsets |
1493 | | assert!(union.offsets().is_none()); |
1494 | | |
1495 | | for i in 0..union.len() { |
1496 | | let slot = union.value(i); |
1497 | | match i { |
1498 | | 0 => { |
1499 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1500 | | assert!(!slot.is_null(0)); |
1501 | | assert_eq!(slot.len(), 1); |
1502 | | let value = slot.value(0); |
1503 | | assert_eq!(1_i32, value); |
1504 | | } |
1505 | | 1 => assert!(slot.is_null(0)), |
1506 | | 2 => { |
1507 | | let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap(); |
1508 | | assert!(!slot.is_null(0)); |
1509 | | assert_eq!(slot.len(), 1); |
1510 | | let value = slot.value(0); |
1511 | | assert_eq!(value, 3_f64); |
1512 | | } |
1513 | | 3 => { |
1514 | | let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap(); |
1515 | | assert!(!slot.is_null(0)); |
1516 | | assert_eq!(slot.len(), 1); |
1517 | | let value = slot.value(0); |
1518 | | assert_eq!(4_i32, value); |
1519 | | } |
1520 | | _ => unreachable!(), |
1521 | | } |
1522 | | } |
1523 | | } |
1524 | | |
1525 | | #[test] |
1526 | | fn test_sparse_mixed_with_nulls_and_offset() { |
1527 | | let mut builder = UnionBuilder::new_sparse(); |
1528 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1529 | | builder.append_null::<Int32Type>("a").unwrap(); |
1530 | | builder.append::<Float64Type>("c", 3.0).unwrap(); |
1531 | | builder.append_null::<Float64Type>("c").unwrap(); |
1532 | | builder.append::<Int32Type>("a", 4).unwrap(); |
1533 | | let union = builder.build().unwrap(); |
1534 | | |
1535 | | let slice = union.slice(1, 4); |
1536 | | let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap(); |
1537 | | |
1538 | | assert_eq!(4, new_union.len()); |
1539 | | for i in 0..new_union.len() { |
1540 | | let slot = new_union.value(i); |
1541 | | match i { |
1542 | | 0 => assert!(slot.is_null(0)), |
1543 | | 1 => { |
1544 | | let slot = slot.as_primitive::<Float64Type>(); |
1545 | | assert!(!slot.is_null(0)); |
1546 | | assert_eq!(slot.len(), 1); |
1547 | | let value = slot.value(0); |
1548 | | assert_eq!(value, 3_f64); |
1549 | | } |
1550 | | 2 => assert!(slot.is_null(0)), |
1551 | | 3 => { |
1552 | | let slot = slot.as_primitive::<Int32Type>(); |
1553 | | assert!(!slot.is_null(0)); |
1554 | | assert_eq!(slot.len(), 1); |
1555 | | let value = slot.value(0); |
1556 | | assert_eq!(4_i32, value); |
1557 | | } |
1558 | | _ => unreachable!(), |
1559 | | } |
1560 | | } |
1561 | | } |
1562 | | |
1563 | | fn test_union_validity(union_array: &UnionArray) { |
1564 | | assert_eq!(union_array.null_count(), 0); |
1565 | | |
1566 | | for i in 0..union_array.len() { |
1567 | | assert!(!union_array.is_null(i)); |
1568 | | assert!(union_array.is_valid(i)); |
1569 | | } |
1570 | | } |
1571 | | |
1572 | | #[test] |
1573 | | fn test_union_array_validity() { |
1574 | | let mut builder = UnionBuilder::new_sparse(); |
1575 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1576 | | builder.append_null::<Int32Type>("a").unwrap(); |
1577 | | builder.append::<Float64Type>("c", 3.0).unwrap(); |
1578 | | builder.append_null::<Float64Type>("c").unwrap(); |
1579 | | builder.append::<Int32Type>("a", 4).unwrap(); |
1580 | | let union = builder.build().unwrap(); |
1581 | | |
1582 | | test_union_validity(&union); |
1583 | | |
1584 | | let mut builder = UnionBuilder::new_dense(); |
1585 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1586 | | builder.append_null::<Int32Type>("a").unwrap(); |
1587 | | builder.append::<Float64Type>("c", 3.0).unwrap(); |
1588 | | builder.append_null::<Float64Type>("c").unwrap(); |
1589 | | builder.append::<Int32Type>("a", 4).unwrap(); |
1590 | | let union = builder.build().unwrap(); |
1591 | | |
1592 | | test_union_validity(&union); |
1593 | | } |
1594 | | |
1595 | | #[test] |
1596 | | fn test_type_check() { |
1597 | | let mut builder = UnionBuilder::new_sparse(); |
1598 | | builder.append::<Float32Type>("a", 1.0).unwrap(); |
1599 | | let err = builder.append::<Int32Type>("a", 1).unwrap_err().to_string(); |
1600 | | assert!( |
1601 | | err.contains( |
1602 | | "Attempt to write col \"a\" with type Int32 doesn't match existing type Float32" |
1603 | | ), |
1604 | | "{}", |
1605 | | err |
1606 | | ); |
1607 | | } |
1608 | | |
1609 | | #[test] |
1610 | | fn slice_union_array() { |
1611 | | // [1, null, 3.0, null, 4] |
1612 | | fn create_union(mut builder: UnionBuilder) -> UnionArray { |
1613 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1614 | | builder.append_null::<Int32Type>("a").unwrap(); |
1615 | | builder.append::<Float64Type>("c", 3.0).unwrap(); |
1616 | | builder.append_null::<Float64Type>("c").unwrap(); |
1617 | | builder.append::<Int32Type>("a", 4).unwrap(); |
1618 | | builder.build().unwrap() |
1619 | | } |
1620 | | |
1621 | | fn create_batch(union: UnionArray) -> RecordBatch { |
1622 | | let schema = Schema::new(vec![Field::new( |
1623 | | "struct_array", |
1624 | | union.data_type().clone(), |
1625 | | true, |
1626 | | )]); |
1627 | | |
1628 | | RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap() |
1629 | | } |
1630 | | |
1631 | | fn test_slice_union(record_batch_slice: RecordBatch) { |
1632 | | let union_slice = record_batch_slice |
1633 | | .column(0) |
1634 | | .as_any() |
1635 | | .downcast_ref::<UnionArray>() |
1636 | | .unwrap(); |
1637 | | |
1638 | | assert_eq!(union_slice.type_id(0), 0); |
1639 | | assert_eq!(union_slice.type_id(1), 1); |
1640 | | assert_eq!(union_slice.type_id(2), 1); |
1641 | | |
1642 | | let slot = union_slice.value(0); |
1643 | | let array = slot.as_primitive::<Int32Type>(); |
1644 | | assert_eq!(array.len(), 1); |
1645 | | assert!(array.is_null(0)); |
1646 | | |
1647 | | let slot = union_slice.value(1); |
1648 | | let array = slot.as_primitive::<Float64Type>(); |
1649 | | assert_eq!(array.len(), 1); |
1650 | | assert!(array.is_valid(0)); |
1651 | | assert_eq!(array.value(0), 3.0); |
1652 | | |
1653 | | let slot = union_slice.value(2); |
1654 | | let array = slot.as_primitive::<Float64Type>(); |
1655 | | assert_eq!(array.len(), 1); |
1656 | | assert!(array.is_null(0)); |
1657 | | } |
1658 | | |
1659 | | // Sparse Union |
1660 | | let builder = UnionBuilder::new_sparse(); |
1661 | | let record_batch = create_batch(create_union(builder)); |
1662 | | // [null, 3.0, null] |
1663 | | let record_batch_slice = record_batch.slice(1, 3); |
1664 | | test_slice_union(record_batch_slice); |
1665 | | |
1666 | | // Dense Union |
1667 | | let builder = UnionBuilder::new_dense(); |
1668 | | let record_batch = create_batch(create_union(builder)); |
1669 | | // [null, 3.0, null] |
1670 | | let record_batch_slice = record_batch.slice(1, 3); |
1671 | | test_slice_union(record_batch_slice); |
1672 | | } |
1673 | | |
1674 | | #[test] |
1675 | | fn test_custom_type_ids() { |
1676 | | let data_type = DataType::Union( |
1677 | | UnionFields::new( |
1678 | | vec![8, 4, 9], |
1679 | | vec![ |
1680 | | Field::new("strings", DataType::Utf8, false), |
1681 | | Field::new("integers", DataType::Int32, false), |
1682 | | Field::new("floats", DataType::Float64, false), |
1683 | | ], |
1684 | | ), |
1685 | | UnionMode::Dense, |
1686 | | ); |
1687 | | |
1688 | | let string_array = StringArray::from(vec!["foo", "bar", "baz"]); |
1689 | | let int_array = Int32Array::from(vec![5, 6, 4]); |
1690 | | let float_array = Float64Array::from(vec![10.0]); |
1691 | | |
1692 | | let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]); |
1693 | | let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]); |
1694 | | |
1695 | | let data = ArrayData::builder(data_type) |
1696 | | .len(7) |
1697 | | .buffers(vec![type_ids, value_offsets]) |
1698 | | .child_data(vec![ |
1699 | | string_array.into_data(), |
1700 | | int_array.into_data(), |
1701 | | float_array.into_data(), |
1702 | | ]) |
1703 | | .build() |
1704 | | .unwrap(); |
1705 | | |
1706 | | let array = UnionArray::from(data); |
1707 | | |
1708 | | let v = array.value(0); |
1709 | | assert_eq!(v.data_type(), &DataType::Int32); |
1710 | | assert_eq!(v.len(), 1); |
1711 | | assert_eq!(v.as_primitive::<Int32Type>().value(0), 5); |
1712 | | |
1713 | | let v = array.value(1); |
1714 | | assert_eq!(v.data_type(), &DataType::Utf8); |
1715 | | assert_eq!(v.len(), 1); |
1716 | | assert_eq!(v.as_string::<i32>().value(0), "foo"); |
1717 | | |
1718 | | let v = array.value(2); |
1719 | | assert_eq!(v.data_type(), &DataType::Int32); |
1720 | | assert_eq!(v.len(), 1); |
1721 | | assert_eq!(v.as_primitive::<Int32Type>().value(0), 6); |
1722 | | |
1723 | | let v = array.value(3); |
1724 | | assert_eq!(v.data_type(), &DataType::Utf8); |
1725 | | assert_eq!(v.len(), 1); |
1726 | | assert_eq!(v.as_string::<i32>().value(0), "bar"); |
1727 | | |
1728 | | let v = array.value(4); |
1729 | | assert_eq!(v.data_type(), &DataType::Float64); |
1730 | | assert_eq!(v.len(), 1); |
1731 | | assert_eq!(v.as_primitive::<Float64Type>().value(0), 10.0); |
1732 | | |
1733 | | let v = array.value(5); |
1734 | | assert_eq!(v.data_type(), &DataType::Int32); |
1735 | | assert_eq!(v.len(), 1); |
1736 | | assert_eq!(v.as_primitive::<Int32Type>().value(0), 4); |
1737 | | |
1738 | | let v = array.value(6); |
1739 | | assert_eq!(v.data_type(), &DataType::Utf8); |
1740 | | assert_eq!(v.len(), 1); |
1741 | | assert_eq!(v.as_string::<i32>().value(0), "baz"); |
1742 | | } |
1743 | | |
1744 | | #[test] |
1745 | | fn into_parts() { |
1746 | | let mut builder = UnionBuilder::new_dense(); |
1747 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1748 | | builder.append::<Int8Type>("b", 2).unwrap(); |
1749 | | builder.append::<Int32Type>("a", 3).unwrap(); |
1750 | | let dense_union = builder.build().unwrap(); |
1751 | | |
1752 | | let field = [ |
1753 | | &Arc::new(Field::new("a", DataType::Int32, false)), |
1754 | | &Arc::new(Field::new("b", DataType::Int8, false)), |
1755 | | ]; |
1756 | | let (union_fields, type_ids, offsets, children) = dense_union.into_parts(); |
1757 | | assert_eq!( |
1758 | | union_fields |
1759 | | .iter() |
1760 | | .map(|(_, field)| field) |
1761 | | .collect::<Vec<_>>(), |
1762 | | field |
1763 | | ); |
1764 | | assert_eq!(type_ids, [0, 1, 0]); |
1765 | | assert!(offsets.is_some()); |
1766 | | assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]); |
1767 | | |
1768 | | let result = UnionArray::try_new(union_fields, type_ids, offsets, children); |
1769 | | assert!(result.is_ok()); |
1770 | | assert_eq!(result.unwrap().len(), 3); |
1771 | | |
1772 | | let mut builder = UnionBuilder::new_sparse(); |
1773 | | builder.append::<Int32Type>("a", 1).unwrap(); |
1774 | | builder.append::<Int8Type>("b", 2).unwrap(); |
1775 | | builder.append::<Int32Type>("a", 3).unwrap(); |
1776 | | let sparse_union = builder.build().unwrap(); |
1777 | | |
1778 | | let (union_fields, type_ids, offsets, children) = sparse_union.into_parts(); |
1779 | | assert_eq!(type_ids, [0, 1, 0]); |
1780 | | assert!(offsets.is_none()); |
1781 | | |
1782 | | let result = UnionArray::try_new(union_fields, type_ids, offsets, children); |
1783 | | assert!(result.is_ok()); |
1784 | | assert_eq!(result.unwrap().len(), 3); |
1785 | | } |
1786 | | |
1787 | | #[test] |
1788 | | fn into_parts_custom_type_ids() { |
1789 | | let set_field_type_ids: [i8; 3] = [8, 4, 9]; |
1790 | | let data_type = DataType::Union( |
1791 | | UnionFields::new( |
1792 | | set_field_type_ids, |
1793 | | [ |
1794 | | Field::new("strings", DataType::Utf8, false), |
1795 | | Field::new("integers", DataType::Int32, false), |
1796 | | Field::new("floats", DataType::Float64, false), |
1797 | | ], |
1798 | | ), |
1799 | | UnionMode::Dense, |
1800 | | ); |
1801 | | let string_array = StringArray::from(vec!["foo", "bar", "baz"]); |
1802 | | let int_array = Int32Array::from(vec![5, 6, 4]); |
1803 | | let float_array = Float64Array::from(vec![10.0]); |
1804 | | let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]); |
1805 | | let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]); |
1806 | | let data = ArrayData::builder(data_type) |
1807 | | .len(7) |
1808 | | .buffers(vec![type_ids, value_offsets]) |
1809 | | .child_data(vec![ |
1810 | | string_array.into_data(), |
1811 | | int_array.into_data(), |
1812 | | float_array.into_data(), |
1813 | | ]) |
1814 | | .build() |
1815 | | .unwrap(); |
1816 | | let array = UnionArray::from(data); |
1817 | | |
1818 | | let (union_fields, type_ids, offsets, children) = array.into_parts(); |
1819 | | assert_eq!( |
1820 | | type_ids.iter().collect::<HashSet<_>>(), |
1821 | | set_field_type_ids.iter().collect::<HashSet<_>>() |
1822 | | ); |
1823 | | let result = UnionArray::try_new(union_fields, type_ids, offsets, children); |
1824 | | assert!(result.is_ok()); |
1825 | | let array = result.unwrap(); |
1826 | | assert_eq!(array.len(), 7); |
1827 | | } |
1828 | | |
1829 | | #[test] |
1830 | | fn test_invalid() { |
1831 | | let fields = UnionFields::new( |
1832 | | [3, 2], |
1833 | | [ |
1834 | | Field::new("a", DataType::Utf8, false), |
1835 | | Field::new("b", DataType::Utf8, false), |
1836 | | ], |
1837 | | ); |
1838 | | let children = vec![ |
1839 | | Arc::new(StringArray::from_iter_values(["a", "b"])) as _, |
1840 | | Arc::new(StringArray::from_iter_values(["c", "d"])) as _, |
1841 | | ]; |
1842 | | |
1843 | | let type_ids = vec![3, 3, 2].into(); |
1844 | | let err = |
1845 | | UnionArray::try_new(fields.clone(), type_ids, None, children.clone()).unwrap_err(); |
1846 | | assert_eq!( |
1847 | | err.to_string(), |
1848 | | "Invalid argument error: Sparse union child arrays must be equal in length to the length of the union" |
1849 | | ); |
1850 | | |
1851 | | let type_ids = vec![1, 2].into(); |
1852 | | let err = |
1853 | | UnionArray::try_new(fields.clone(), type_ids, None, children.clone()).unwrap_err(); |
1854 | | assert_eq!( |
1855 | | err.to_string(), |
1856 | | "Invalid argument error: Type Ids values must match one of the field type ids" |
1857 | | ); |
1858 | | |
1859 | | let type_ids = vec![7, 2].into(); |
1860 | | let err = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap_err(); |
1861 | | assert_eq!( |
1862 | | err.to_string(), |
1863 | | "Invalid argument error: Type Ids values must match one of the field type ids" |
1864 | | ); |
1865 | | |
1866 | | let children = vec![ |
1867 | | Arc::new(StringArray::from_iter_values(["a", "b"])) as _, |
1868 | | Arc::new(StringArray::from_iter_values(["c"])) as _, |
1869 | | ]; |
1870 | | let type_ids = ScalarBuffer::from(vec![3_i8, 3, 2]); |
1871 | | let offsets = Some(vec![0, 1, 0].into()); |
1872 | | UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children.clone()).unwrap(); |
1873 | | |
1874 | | let offsets = Some(vec![0, 1, 1].into()); |
1875 | | let err = UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children.clone()) |
1876 | | .unwrap_err(); |
1877 | | |
1878 | | assert_eq!( |
1879 | | err.to_string(), |
1880 | | "Invalid argument error: Offsets must be positive and within the length of the Array" |
1881 | | ); |
1882 | | |
1883 | | let offsets = Some(vec![0, 1].into()); |
1884 | | let err = |
1885 | | UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children).unwrap_err(); |
1886 | | |
1887 | | assert_eq!( |
1888 | | err.to_string(), |
1889 | | "Invalid argument error: Type Ids and Offsets lengths must match" |
1890 | | ); |
1891 | | |
1892 | | let err = UnionArray::try_new(fields.clone(), type_ids, None, vec![]).unwrap_err(); |
1893 | | |
1894 | | assert_eq!( |
1895 | | err.to_string(), |
1896 | | "Invalid argument error: Union fields length must match child arrays length" |
1897 | | ); |
1898 | | } |
1899 | | |
1900 | | #[test] |
1901 | | fn test_logical_nulls_fast_paths() { |
1902 | | // fields.len() <= 1 |
1903 | | let array = UnionArray::try_new(UnionFields::empty(), vec![].into(), None, vec![]).unwrap(); |
1904 | | |
1905 | | assert_eq!(array.logical_nulls(), None); |
1906 | | |
1907 | | let fields = UnionFields::new( |
1908 | | [1, 3], |
1909 | | [ |
1910 | | Field::new("a", DataType::Int8, false), // non nullable |
1911 | | Field::new("b", DataType::Int8, false), // non nullable |
1912 | | ], |
1913 | | ); |
1914 | | let array = UnionArray::try_new( |
1915 | | fields, |
1916 | | vec![1].into(), |
1917 | | None, |
1918 | | vec![ |
1919 | | Arc::new(Int8Array::from_value(5, 1)), |
1920 | | Arc::new(Int8Array::from_value(5, 1)), |
1921 | | ], |
1922 | | ) |
1923 | | .unwrap(); |
1924 | | |
1925 | | assert_eq!(array.logical_nulls(), None); |
1926 | | |
1927 | | let nullable_fields = UnionFields::new( |
1928 | | [1, 3], |
1929 | | [ |
1930 | | Field::new("a", DataType::Int8, true), // nullable but without nulls |
1931 | | Field::new("b", DataType::Int8, true), // nullable but without nulls |
1932 | | ], |
1933 | | ); |
1934 | | let array = UnionArray::try_new( |
1935 | | nullable_fields.clone(), |
1936 | | vec![1, 1].into(), |
1937 | | None, |
1938 | | vec![ |
1939 | | Arc::new(Int8Array::from_value(-5, 2)), // nullable but without nulls |
1940 | | Arc::new(Int8Array::from_value(-5, 2)), // nullable but without nulls |
1941 | | ], |
1942 | | ) |
1943 | | .unwrap(); |
1944 | | |
1945 | | assert_eq!(array.logical_nulls(), None); |
1946 | | |
1947 | | let array = UnionArray::try_new( |
1948 | | nullable_fields.clone(), |
1949 | | vec![1, 1].into(), |
1950 | | None, |
1951 | | vec![ |
1952 | | // every children is completly null |
1953 | | Arc::new(Int8Array::new_null(2)), // all null, same len as it's parent |
1954 | | Arc::new(Int8Array::new_null(2)), // all null, same len as it's parent |
1955 | | ], |
1956 | | ) |
1957 | | .unwrap(); |
1958 | | |
1959 | | assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2))); |
1960 | | |
1961 | | let array = UnionArray::try_new( |
1962 | | nullable_fields.clone(), |
1963 | | vec![1, 1].into(), |
1964 | | Some(vec![0, 1].into()), |
1965 | | vec![ |
1966 | | // every children is completly null |
1967 | | Arc::new(Int8Array::new_null(3)), // bigger that parent |
1968 | | Arc::new(Int8Array::new_null(3)), // bigger that parent |
1969 | | ], |
1970 | | ) |
1971 | | .unwrap(); |
1972 | | |
1973 | | assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2))); |
1974 | | } |
1975 | | |
1976 | | #[test] |
1977 | | fn test_dense_union_logical_nulls_gather() { |
1978 | | // union of [{A=1}, {A=2}, {B=3.2}, {B=}, {C=}, {C=}] |
1979 | | let int_array = Int32Array::from(vec![1, 2]); |
1980 | | let float_array = Float64Array::from(vec![Some(3.2), None]); |
1981 | | let str_array = StringArray::new_null(1); |
1982 | | let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>(); |
1983 | | let offsets = [0, 1, 0, 1, 0, 0] |
1984 | | .into_iter() |
1985 | | .collect::<ScalarBuffer<i32>>(); |
1986 | | |
1987 | | let children = vec![ |
1988 | | Arc::new(int_array) as Arc<dyn Array>, |
1989 | | Arc::new(float_array), |
1990 | | Arc::new(str_array), |
1991 | | ]; |
1992 | | |
1993 | | let array = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap(); |
1994 | | |
1995 | | let expected = BooleanBuffer::from(vec![true, true, true, false, false, false]); |
1996 | | |
1997 | | assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); |
1998 | | assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls())); |
1999 | | } |
2000 | | |
2001 | | #[test] |
2002 | | fn test_sparse_union_logical_nulls_mask_all_nulls_skip_one() { |
2003 | | let fields: UnionFields = [ |
2004 | | (1, Arc::new(Field::new("A", DataType::Int32, true))), |
2005 | | (3, Arc::new(Field::new("B", DataType::Float64, true))), |
2006 | | ] |
2007 | | .into_iter() |
2008 | | .collect(); |
2009 | | |
2010 | | // union of [{A=}, {A=}, {B=3.2}, {B=}] |
2011 | | let int_array = Int32Array::new_null(4); |
2012 | | let float_array = Float64Array::from(vec![None, None, Some(3.2), None]); |
2013 | | let type_ids = [1, 1, 3, 3].into_iter().collect::<ScalarBuffer<i8>>(); |
2014 | | |
2015 | | let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)]; |
2016 | | |
2017 | | let array = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap(); |
2018 | | |
2019 | | let expected = BooleanBuffer::from(vec![false, false, true, false]); |
2020 | | |
2021 | | assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); |
2022 | | assert_eq!( |
2023 | | expected, |
2024 | | array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls()) |
2025 | | ); |
2026 | | |
2027 | | //like above, but repeated to genereate two exact bitmasks and a non empty remainder |
2028 | | let len = 2 * 64 + 32; |
2029 | | |
2030 | | let int_array = Int32Array::new_null(len); |
2031 | | let float_array = Float64Array::from_iter([Some(3.2), None].into_iter().cycle().take(len)); |
2032 | | let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3].into_iter().cycle().take(len)); |
2033 | | |
2034 | | let array = UnionArray::try_new( |
2035 | | fields, |
2036 | | type_ids, |
2037 | | None, |
2038 | | vec![Arc::new(int_array), Arc::new(float_array)], |
2039 | | ) |
2040 | | .unwrap(); |
2041 | | |
2042 | | let expected = |
2043 | | BooleanBuffer::from_iter([false, false, true, false].into_iter().cycle().take(len)); |
2044 | | |
2045 | | assert_eq!(array.len(), len); |
2046 | | assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); |
2047 | | assert_eq!( |
2048 | | expected, |
2049 | | array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls()) |
2050 | | ); |
2051 | | } |
2052 | | |
2053 | | #[test] |
2054 | | fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_valid() { |
2055 | | // union of [{A=2}, {A=2}, {B=3.2}, {B=}, {C=}, {C=}] |
2056 | | let int_array = Int32Array::from_value(2, 6); |
2057 | | let float_array = Float64Array::from_value(4.2, 6); |
2058 | | let str_array = StringArray::new_null(6); |
2059 | | let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>(); |
2060 | | |
2061 | | let children = vec![ |
2062 | | Arc::new(int_array) as Arc<dyn Array>, |
2063 | | Arc::new(float_array), |
2064 | | Arc::new(str_array), |
2065 | | ]; |
2066 | | |
2067 | | let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); |
2068 | | |
2069 | | let expected = BooleanBuffer::from(vec![true, true, true, true, false, false]); |
2070 | | |
2071 | | assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); |
2072 | | assert_eq!( |
2073 | | expected, |
2074 | | array.mask_sparse_skip_without_nulls(array.fields_logical_nulls()) |
2075 | | ); |
2076 | | |
2077 | | //like above, but repeated to genereate two exact bitmasks and a non empty remainder |
2078 | | let len = 2 * 64 + 32; |
2079 | | |
2080 | | let int_array = Int32Array::from_value(2, len); |
2081 | | let float_array = Float64Array::from_value(4.2, len); |
2082 | | let str_array = StringArray::from_iter([None, Some("a")].into_iter().cycle().take(len)); |
2083 | | let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len)); |
2084 | | |
2085 | | let children = vec![ |
2086 | | Arc::new(int_array) as Arc<dyn Array>, |
2087 | | Arc::new(float_array), |
2088 | | Arc::new(str_array), |
2089 | | ]; |
2090 | | |
2091 | | let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); |
2092 | | |
2093 | | let expected = BooleanBuffer::from_iter( |
2094 | | [true, true, true, true, false, true] |
2095 | | .into_iter() |
2096 | | .cycle() |
2097 | | .take(len), |
2098 | | ); |
2099 | | |
2100 | | assert_eq!(array.len(), len); |
2101 | | assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); |
2102 | | assert_eq!( |
2103 | | expected, |
2104 | | array.mask_sparse_skip_without_nulls(array.fields_logical_nulls()) |
2105 | | ); |
2106 | | } |
2107 | | |
2108 | | #[test] |
2109 | | fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_null() { |
2110 | | // union of [{A=}, {A=}, {B=4.2}, {B=4.2}, {C=}, {C=}] |
2111 | | let int_array = Int32Array::new_null(6); |
2112 | | let float_array = Float64Array::from_value(4.2, 6); |
2113 | | let str_array = StringArray::new_null(6); |
2114 | | let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>(); |
2115 | | |
2116 | | let children = vec![ |
2117 | | Arc::new(int_array) as Arc<dyn Array>, |
2118 | | Arc::new(float_array), |
2119 | | Arc::new(str_array), |
2120 | | ]; |
2121 | | |
2122 | | let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); |
2123 | | |
2124 | | let expected = BooleanBuffer::from(vec![false, false, true, true, false, false]); |
2125 | | |
2126 | | assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); |
2127 | | assert_eq!( |
2128 | | expected, |
2129 | | array.mask_sparse_skip_fully_null(array.fields_logical_nulls()) |
2130 | | ); |
2131 | | |
2132 | | //like above, but repeated to genereate two exact bitmasks and a non empty remainder |
2133 | | let len = 2 * 64 + 32; |
2134 | | |
2135 | | let int_array = Int32Array::new_null(len); |
2136 | | let float_array = Float64Array::from_value(4.2, len); |
2137 | | let str_array = StringArray::new_null(len); |
2138 | | let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len)); |
2139 | | |
2140 | | let children = vec![ |
2141 | | Arc::new(int_array) as Arc<dyn Array>, |
2142 | | Arc::new(float_array), |
2143 | | Arc::new(str_array), |
2144 | | ]; |
2145 | | |
2146 | | let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); |
2147 | | |
2148 | | let expected = BooleanBuffer::from_iter( |
2149 | | [false, false, true, true, false, false] |
2150 | | .into_iter() |
2151 | | .cycle() |
2152 | | .take(len), |
2153 | | ); |
2154 | | |
2155 | | assert_eq!(array.len(), len); |
2156 | | assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); |
2157 | | assert_eq!( |
2158 | | expected, |
2159 | | array.mask_sparse_skip_fully_null(array.fields_logical_nulls()) |
2160 | | ); |
2161 | | } |
2162 | | |
2163 | | #[test] |
2164 | | fn test_sparse_union_logical_nulls_gather() { |
2165 | | let n_fields = 50; |
2166 | | |
2167 | | let non_null = Int32Array::from_value(2, 4); |
2168 | | let mixed = Int32Array::from(vec![None, None, Some(1), None]); |
2169 | | let fully_null = Int32Array::new_null(4); |
2170 | | |
2171 | | let array = UnionArray::try_new( |
2172 | | (1..) |
2173 | | .step_by(2) |
2174 | | .map(|i| { |
2175 | | ( |
2176 | | i, |
2177 | | Arc::new(Field::new(format!("f{i}"), DataType::Int32, true)), |
2178 | | ) |
2179 | | }) |
2180 | | .take(n_fields) |
2181 | | .collect(), |
2182 | | vec![1, 3, 3, 5].into(), |
2183 | | None, |
2184 | | [ |
2185 | | Arc::new(non_null) as ArrayRef, |
2186 | | Arc::new(mixed), |
2187 | | Arc::new(fully_null), |
2188 | | ] |
2189 | | .into_iter() |
2190 | | .cycle() |
2191 | | .take(n_fields) |
2192 | | .collect(), |
2193 | | ) |
2194 | | .unwrap(); |
2195 | | |
2196 | | let expected = BooleanBuffer::from(vec![true, false, true, false]); |
2197 | | |
2198 | | assert_eq!(expected, array.logical_nulls().unwrap().into_inner()); |
2199 | | assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls())); |
2200 | | } |
2201 | | |
2202 | | fn union_fields() -> UnionFields { |
2203 | | [ |
2204 | | (1, Arc::new(Field::new("A", DataType::Int32, true))), |
2205 | | (3, Arc::new(Field::new("B", DataType::Float64, true))), |
2206 | | (4, Arc::new(Field::new("C", DataType::Utf8, true))), |
2207 | | ] |
2208 | | .into_iter() |
2209 | | .collect() |
2210 | | } |
2211 | | |
2212 | | #[test] |
2213 | | fn test_is_nullable() { |
2214 | | assert!(!create_union_array(false, false).is_nullable()); |
2215 | | assert!(create_union_array(true, false).is_nullable()); |
2216 | | assert!(create_union_array(false, true).is_nullable()); |
2217 | | assert!(create_union_array(true, true).is_nullable()); |
2218 | | } |
2219 | | |
2220 | | /// Create a union array with a float and integer field |
2221 | | /// |
2222 | | /// If the `int_nullable` is true, the integer field will have nulls |
2223 | | /// If the `float_nullable` is true, the float field will have nulls |
2224 | | /// |
2225 | | /// Note the `Field` definitions are always declared to be nullable |
2226 | | fn create_union_array(int_nullable: bool, float_nullable: bool) -> UnionArray { |
2227 | | let int_array = if int_nullable { |
2228 | | Int32Array::from(vec![Some(1), None, Some(3)]) |
2229 | | } else { |
2230 | | Int32Array::from(vec![1, 2, 3]) |
2231 | | }; |
2232 | | let float_array = if float_nullable { |
2233 | | Float64Array::from(vec![Some(3.2), None, Some(4.2)]) |
2234 | | } else { |
2235 | | Float64Array::from(vec![3.2, 4.2, 5.2]) |
2236 | | }; |
2237 | | let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>(); |
2238 | | let offsets = [0, 0, 0].into_iter().collect::<ScalarBuffer<i32>>(); |
2239 | | let union_fields = [ |
2240 | | (0, Arc::new(Field::new("A", DataType::Int32, true))), |
2241 | | (1, Arc::new(Field::new("B", DataType::Float64, true))), |
2242 | | ] |
2243 | | .into_iter() |
2244 | | .collect::<UnionFields>(); |
2245 | | |
2246 | | let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)]; |
2247 | | |
2248 | | UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap() |
2249 | | } |
2250 | | } |