/Users/andrewlamb/Software/arrow-rs/arrow-array/src/builder/primitive_dictionary_builder.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use crate::builder::{ArrayBuilder, PrimitiveBuilder}; |
19 | | use crate::types::ArrowDictionaryKeyType; |
20 | | use crate::{ |
21 | | Array, ArrayRef, ArrowPrimitiveType, DictionaryArray, PrimitiveArray, TypedDictionaryArray, |
22 | | }; |
23 | | use arrow_buffer::{ArrowNativeType, ToByteSlice}; |
24 | | use arrow_schema::{ArrowError, DataType}; |
25 | | use num::NumCast; |
26 | | use std::any::Any; |
27 | | use std::collections::HashMap; |
28 | | use std::sync::Arc; |
29 | | |
30 | | /// Wraps a type implementing `ToByteSlice` implementing `Hash` and `Eq` for it |
31 | | /// |
32 | | /// This is necessary to handle types such as f32, which don't natively implement these |
33 | | #[derive(Debug)] |
34 | | struct Value<T>(T); |
35 | | |
36 | | impl<T: ToByteSlice> std::hash::Hash for Value<T> { |
37 | 0 | fn hash<H: std::hash::Hasher>(&self, state: &mut H) { |
38 | 0 | self.0.to_byte_slice().hash(state) |
39 | 0 | } |
40 | | } |
41 | | |
42 | | impl<T: ToByteSlice> PartialEq for Value<T> { |
43 | 0 | fn eq(&self, other: &Self) -> bool { |
44 | 0 | self.0.to_byte_slice().eq(other.0.to_byte_slice()) |
45 | 0 | } |
46 | | } |
47 | | |
48 | | impl<T: ToByteSlice> Eq for Value<T> {} |
49 | | |
50 | | /// Builder for [`DictionaryArray`] of [`PrimitiveArray`] |
51 | | /// |
52 | | /// # Example: |
53 | | /// |
54 | | /// ``` |
55 | | /// |
56 | | /// # use arrow_array::builder::PrimitiveDictionaryBuilder; |
57 | | /// # use arrow_array::types::{UInt32Type, UInt8Type}; |
58 | | /// # use arrow_array::{Array, UInt32Array, UInt8Array}; |
59 | | /// |
60 | | /// let mut builder = PrimitiveDictionaryBuilder::<UInt8Type, UInt32Type>::new(); |
61 | | /// builder.append(12345678).unwrap(); |
62 | | /// builder.append_null(); |
63 | | /// builder.append(22345678).unwrap(); |
64 | | /// let array = builder.finish(); |
65 | | /// |
66 | | /// assert_eq!( |
67 | | /// array.keys(), |
68 | | /// &UInt8Array::from(vec![Some(0), None, Some(1)]) |
69 | | /// ); |
70 | | /// |
71 | | /// // Values are polymorphic and so require a downcast. |
72 | | /// let av = array.values(); |
73 | | /// let ava: &UInt32Array = av.as_any().downcast_ref::<UInt32Array>().unwrap(); |
74 | | /// let avs: &[u32] = ava.values(); |
75 | | /// |
76 | | /// assert!(!array.is_null(0)); |
77 | | /// assert!(array.is_null(1)); |
78 | | /// assert!(!array.is_null(2)); |
79 | | /// |
80 | | /// assert_eq!(avs, &[12345678, 22345678]); |
81 | | /// ``` |
82 | | #[derive(Debug)] |
83 | | pub struct PrimitiveDictionaryBuilder<K, V> |
84 | | where |
85 | | K: ArrowPrimitiveType, |
86 | | V: ArrowPrimitiveType, |
87 | | { |
88 | | keys_builder: PrimitiveBuilder<K>, |
89 | | values_builder: PrimitiveBuilder<V>, |
90 | | map: HashMap<Value<V::Native>, usize>, |
91 | | } |
92 | | |
93 | | impl<K, V> Default for PrimitiveDictionaryBuilder<K, V> |
94 | | where |
95 | | K: ArrowPrimitiveType, |
96 | | V: ArrowPrimitiveType, |
97 | | { |
98 | | fn default() -> Self { |
99 | | Self::new() |
100 | | } |
101 | | } |
102 | | |
103 | | impl<K, V> PrimitiveDictionaryBuilder<K, V> |
104 | | where |
105 | | K: ArrowPrimitiveType, |
106 | | V: ArrowPrimitiveType, |
107 | | { |
108 | | /// Creates a new `PrimitiveDictionaryBuilder`. |
109 | | pub fn new() -> Self { |
110 | | Self { |
111 | | keys_builder: PrimitiveBuilder::new(), |
112 | | values_builder: PrimitiveBuilder::new(), |
113 | | map: HashMap::new(), |
114 | | } |
115 | | } |
116 | | |
117 | | /// Creates a new `PrimitiveDictionaryBuilder` from the provided keys and values builders. |
118 | | /// |
119 | | /// # Panics |
120 | | /// |
121 | | /// This method panics if `keys_builder` or `values_builder` is not empty. |
122 | | pub fn new_from_empty_builders( |
123 | | keys_builder: PrimitiveBuilder<K>, |
124 | | values_builder: PrimitiveBuilder<V>, |
125 | | ) -> Self { |
126 | | assert!( |
127 | | keys_builder.is_empty() && values_builder.is_empty(), |
128 | | "keys and values builders must be empty" |
129 | | ); |
130 | | let values_capacity = values_builder.capacity(); |
131 | | Self { |
132 | | keys_builder, |
133 | | values_builder, |
134 | | map: HashMap::with_capacity(values_capacity), |
135 | | } |
136 | | } |
137 | | |
138 | | /// Creates a new `PrimitiveDictionaryBuilder` from existing `PrimitiveBuilder`s of keys and values. |
139 | | /// |
140 | | /// # Safety |
141 | | /// |
142 | | /// caller must ensure that the passed in builders are valid for DictionaryArray. |
143 | | pub unsafe fn new_from_builders( |
144 | | keys_builder: PrimitiveBuilder<K>, |
145 | | values_builder: PrimitiveBuilder<V>, |
146 | | ) -> Self { |
147 | | let keys = keys_builder.values_slice(); |
148 | | let values = values_builder.values_slice(); |
149 | | let mut map = HashMap::with_capacity(values.len()); |
150 | | |
151 | | keys.iter().zip(values.iter()).for_each(|(key, value)| { |
152 | | map.insert(Value(*value), K::Native::to_usize(*key).unwrap()); |
153 | | }); |
154 | | |
155 | | Self { |
156 | | keys_builder, |
157 | | values_builder, |
158 | | map, |
159 | | } |
160 | | } |
161 | | |
162 | | /// Creates a new `PrimitiveDictionaryBuilder` with the provided capacities |
163 | | /// |
164 | | /// `keys_capacity`: the number of keys, i.e. length of array to build |
165 | | /// `values_capacity`: the number of distinct dictionary values, i.e. size of dictionary |
166 | 0 | pub fn with_capacity(keys_capacity: usize, values_capacity: usize) -> Self { |
167 | 0 | Self { |
168 | 0 | keys_builder: PrimitiveBuilder::with_capacity(keys_capacity), |
169 | 0 | values_builder: PrimitiveBuilder::with_capacity(values_capacity), |
170 | 0 | map: HashMap::with_capacity(values_capacity), |
171 | 0 | } |
172 | 0 | } |
173 | | |
/// Creates a new `PrimitiveDictionaryBuilder` from the existing builder with the same
/// keys and values, but with a new data type for the keys.
///
/// The values builder and de-duplication map of `source` are re-used as-is;
/// only the keys are converted to the new key type.
///
/// # Errors
///
/// Returns [`ArrowError::CastError`] if any existing key cannot be represented
/// in the new key type `K`.
///
/// # Example
/// ```
/// #
/// # use arrow_array::builder::PrimitiveDictionaryBuilder;
/// # use arrow_array::types::{UInt8Type, UInt16Type, UInt64Type};
/// # use arrow_array::UInt16Array;
/// # use arrow_schema::ArrowError;
///
/// let mut u8_keyed_builder = PrimitiveDictionaryBuilder::<UInt8Type, UInt64Type>::new();
///
/// // appending too many values causes the dictionary to overflow
/// for i in 0..256 {
///     u8_keyed_builder.append_value(i);
/// }
/// let result = u8_keyed_builder.append(256);
/// assert!(matches!(result, Err(ArrowError::DictionaryKeyOverflowError{})));
///
/// // we need to upgrade to a larger key type
/// let mut u16_keyed_builder = PrimitiveDictionaryBuilder::<UInt16Type, UInt64Type>::try_new_from_builder(u8_keyed_builder).unwrap();
/// let dictionary_array = u16_keyed_builder.finish();
/// let keys = dictionary_array.keys();
///
/// assert_eq!(keys, &UInt16Array::from_iter(0..256));
/// ```
pub fn try_new_from_builder<K2>(
    mut source: PrimitiveDictionaryBuilder<K2, V>,
) -> Result<Self, ArrowError>
where
    K::Native: NumCast,
    K2: ArrowDictionaryKeyType,
    K2::Native: NumCast,
{
    // The dictionary itself (values + lookup map) is independent of the key
    // width, so it is moved over unchanged.
    let map = source.map;
    let values_builder = source.values_builder;

    // Materialize the source keys, then cast each one to the new key type.
    let source_keys = source.keys_builder.finish();
    let new_keys: PrimitiveArray<K> = source_keys.try_unary(|value| {
        num::cast::cast::<K2::Native, K::Native>(value).ok_or_else(|| {
            ArrowError::CastError(format!(
                "Can't cast dictionary keys from source type {:?} to type {:?}",
                K2::DATA_TYPE,
                K::DATA_TYPE
            ))
        })
    })?;

    // drop source key here because currently source_keys and new_keys are holding reference to
    // the same underlying null_buffer. Below we want to call new_keys.into_builder() it must
    // be the only reference holder.
    drop(source_keys);

    Ok(Self {
        map,
        keys_builder: new_keys
            .into_builder()
            .expect("underlying buffer has no references"),
        values_builder,
    })
}
235 | | } |
236 | | |
237 | | impl<K, V> ArrayBuilder for PrimitiveDictionaryBuilder<K, V> |
238 | | where |
239 | | K: ArrowDictionaryKeyType, |
240 | | V: ArrowPrimitiveType, |
241 | | { |
242 | | /// Returns the builder as an non-mutable `Any` reference. |
243 | | fn as_any(&self) -> &dyn Any { |
244 | | self |
245 | | } |
246 | | |
247 | | /// Returns the builder as an mutable `Any` reference. |
248 | | fn as_any_mut(&mut self) -> &mut dyn Any { |
249 | | self |
250 | | } |
251 | | |
252 | | /// Returns the boxed builder as a box of `Any`. |
253 | | fn into_box_any(self: Box<Self>) -> Box<dyn Any> { |
254 | | self |
255 | | } |
256 | | |
257 | | /// Returns the number of array slots in the builder |
258 | | fn len(&self) -> usize { |
259 | | self.keys_builder.len() |
260 | | } |
261 | | |
262 | | /// Builds the array and reset this builder. |
263 | | fn finish(&mut self) -> ArrayRef { |
264 | | Arc::new(self.finish()) |
265 | | } |
266 | | |
267 | | /// Builds the array without resetting the builder. |
268 | | fn finish_cloned(&self) -> ArrayRef { |
269 | | Arc::new(self.finish_cloned()) |
270 | | } |
271 | | } |
272 | | |
273 | | impl<K, V> PrimitiveDictionaryBuilder<K, V> |
274 | | where |
275 | | K: ArrowDictionaryKeyType, |
276 | | V: ArrowPrimitiveType, |
277 | | { |
278 | | #[inline] |
279 | 0 | fn get_or_insert_key(&mut self, value: V::Native) -> Result<K::Native, ArrowError> { |
280 | 0 | match self.map.get(&Value(value)) { |
281 | 0 | Some(&key) => { |
282 | 0 | Ok(K::Native::from_usize(key).ok_or(ArrowError::DictionaryKeyOverflowError)?) |
283 | | } |
284 | | None => { |
285 | 0 | let key = self.values_builder.len(); |
286 | 0 | self.values_builder.append_value(value); |
287 | 0 | self.map.insert(Value(value), key); |
288 | 0 | Ok(K::Native::from_usize(key).ok_or(ArrowError::DictionaryKeyOverflowError)?) |
289 | | } |
290 | | } |
291 | 0 | } |
292 | | |
293 | | /// Append a primitive value to the array. Return an existing index |
294 | | /// if already present in the values array or a new index if the |
295 | | /// value is appended to the values array. |
296 | | #[inline] |
297 | 0 | pub fn append(&mut self, value: V::Native) -> Result<K::Native, ArrowError> { |
298 | 0 | let key = self.get_or_insert_key(value)?; |
299 | 0 | self.keys_builder.append_value(key); |
300 | 0 | Ok(key) |
301 | 0 | } |
302 | | |
303 | | /// Append a value multiple times to the array. |
304 | | /// This is the same as `append` but allows to append the same value multiple times without doing multiple lookups. |
305 | | /// |
306 | | /// Returns an error if the new index would overflow the key type. |
307 | | pub fn append_n(&mut self, value: V::Native, count: usize) -> Result<K::Native, ArrowError> { |
308 | | let key = self.get_or_insert_key(value)?; |
309 | | self.keys_builder.append_value_n(key, count); |
310 | | Ok(key) |
311 | | } |
312 | | |
313 | | /// Infallibly append a value to this builder |
314 | | /// |
315 | | /// # Panics |
316 | | /// |
317 | | /// Panics if the resulting length of the dictionary values array would exceed `T::Native::MAX` |
318 | | #[inline] |
319 | | pub fn append_value(&mut self, value: V::Native) { |
320 | | self.append(value).expect("dictionary key overflow"); |
321 | | } |
322 | | |
323 | | /// Infallibly append a value to this builder repeatedly `count` times. |
324 | | /// This is the same as `append_value` but allows to append the same value multiple times without doing multiple lookups. |
325 | | /// |
326 | | /// # Panics |
327 | | /// |
328 | | /// Panics if the resulting length of the dictionary values array would exceed `T::Native::MAX` |
329 | | pub fn append_values(&mut self, value: V::Native, count: usize) { |
330 | | self.append_n(value, count) |
331 | | .expect("dictionary key overflow"); |
332 | | } |
333 | | |
334 | | /// Appends a null slot into the builder |
335 | | #[inline] |
336 | 0 | pub fn append_null(&mut self) { |
337 | 0 | self.keys_builder.append_null() |
338 | 0 | } |
339 | | |
340 | | /// Append `n` null slots into the builder |
341 | | #[inline] |
342 | | pub fn append_nulls(&mut self, n: usize) { |
343 | | self.keys_builder.append_nulls(n) |
344 | | } |
345 | | |
346 | | /// Append an `Option` value into the builder |
347 | | /// |
348 | | /// # Panics |
349 | | /// |
350 | | /// Panics if the resulting length of the dictionary values array would exceed `T::Native::MAX` |
351 | | #[inline] |
352 | | pub fn append_option(&mut self, value: Option<V::Native>) { |
353 | | match value { |
354 | | None => self.append_null(), |
355 | | Some(v) => self.append_value(v), |
356 | | }; |
357 | | } |
358 | | |
359 | | /// Append an `Option` value into the builder repeatedly `count` times. |
360 | | /// This is the same as `append_option` but allows to append the same value multiple times without doing multiple lookups. |
361 | | /// |
362 | | /// # Panics |
363 | | /// |
364 | | /// Panics if the resulting length of the dictionary values array would exceed `T::Native::MAX` |
365 | | pub fn append_options(&mut self, value: Option<V::Native>, count: usize) { |
366 | | match value { |
367 | | None => self.keys_builder.append_nulls(count), |
368 | | Some(v) => self.append_values(v, count), |
369 | | }; |
370 | | } |
371 | | |
372 | | /// Extends builder with dictionary |
373 | | /// |
374 | | /// This is the same as [`Self::extend`] but is faster as it translates |
375 | | /// the dictionary values once rather than doing a lookup for each item in the iterator |
376 | | /// |
377 | | /// when dictionary values are null (the actual mapped values) the keys are null |
378 | | /// |
379 | | pub fn extend_dictionary( |
380 | | &mut self, |
381 | | dictionary: &TypedDictionaryArray<K, PrimitiveArray<V>>, |
382 | | ) -> Result<(), ArrowError> { |
383 | | let values = dictionary.values(); |
384 | | |
385 | | let v_len = values.len(); |
386 | | let k_len = dictionary.keys().len(); |
387 | | if v_len == 0 && k_len == 0 { |
388 | | return Ok(()); |
389 | | } |
390 | | |
391 | | // All nulls |
392 | | if v_len == 0 { |
393 | | self.append_nulls(k_len); |
394 | | return Ok(()); |
395 | | } |
396 | | |
397 | | if k_len == 0 { |
398 | | return Err(ArrowError::InvalidArgumentError( |
399 | | "Dictionary keys should not be empty when values are not empty".to_string(), |
400 | | )); |
401 | | } |
402 | | |
403 | | // Orphan values will be carried over to the new dictionary |
404 | | let mapped_values = values |
405 | | .iter() |
406 | | // Dictionary values can technically be null, so we need to handle that |
407 | | .map(|dict_value| { |
408 | | dict_value |
409 | | .map(|dict_value| self.get_or_insert_key(dict_value)) |
410 | | .transpose() |
411 | | }) |
412 | | .collect::<Result<Vec<_>, _>>()?; |
413 | | |
414 | | // Just insert the keys without additional lookups |
415 | | dictionary.keys().iter().for_each(|key| match key { |
416 | | None => self.append_null(), |
417 | | Some(original_dict_index) => { |
418 | | let index = original_dict_index.as_usize().min(v_len - 1); |
419 | | match mapped_values[index] { |
420 | | None => self.append_null(), |
421 | | Some(mapped_value) => self.keys_builder.append_value(mapped_value), |
422 | | } |
423 | | } |
424 | | }); |
425 | | |
426 | | Ok(()) |
427 | | } |
428 | | |
429 | | /// Builds the `DictionaryArray` and reset this builder. |
430 | 0 | pub fn finish(&mut self) -> DictionaryArray<K> { |
431 | 0 | self.map.clear(); |
432 | 0 | let values = self.values_builder.finish(); |
433 | 0 | let keys = self.keys_builder.finish(); |
434 | | |
435 | 0 | let data_type = |
436 | 0 | DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(values.data_type().clone())); |
437 | | |
438 | 0 | let builder = keys |
439 | 0 | .into_data() |
440 | 0 | .into_builder() |
441 | 0 | .data_type(data_type) |
442 | 0 | .child_data(vec![values.into_data()]); |
443 | | |
444 | 0 | DictionaryArray::from(unsafe { builder.build_unchecked() }) |
445 | 0 | } |
446 | | |
447 | | /// Builds the `DictionaryArray` without resetting the builder. |
448 | | pub fn finish_cloned(&self) -> DictionaryArray<K> { |
449 | | let values = self.values_builder.finish_cloned(); |
450 | | let keys = self.keys_builder.finish_cloned(); |
451 | | |
452 | | let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(V::DATA_TYPE)); |
453 | | |
454 | | let builder = keys |
455 | | .into_data() |
456 | | .into_builder() |
457 | | .data_type(data_type) |
458 | | .child_data(vec![values.into_data()]); |
459 | | |
460 | | DictionaryArray::from(unsafe { builder.build_unchecked() }) |
461 | | } |
462 | | |
463 | | /// Builds the `DictionaryArray` without resetting the values builder or |
464 | | /// the internal de-duplication map. |
465 | | /// |
466 | | /// The advantage of doing this is that the values will represent the entire |
467 | | /// set of what has been built so-far by this builder and ensures |
468 | | /// consistency in the assignment of keys to values across multiple calls |
469 | | /// to `finish_preserve_values`. This enables ipc writers to efficiently |
470 | | /// emit delta dictionaries. |
471 | | /// |
472 | | /// The downside to this is that building the record requires creating a |
473 | | /// copy of the values, which can become slowly more expensive if the |
474 | | /// dictionary grows. |
475 | | /// |
476 | | /// Additionally, if record batches from multiple different dictionary |
477 | | /// builders for the same column are fed into a single ipc writer, beware |
478 | | /// that entire dictionaries are likely to be re-sent frequently even when |
479 | | /// the majority of the values are not used by the current record batch. |
480 | | pub fn finish_preserve_values(&mut self) -> DictionaryArray<K> { |
481 | | let values = self.values_builder.finish_cloned(); |
482 | | let keys = self.keys_builder.finish(); |
483 | | |
484 | | let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(V::DATA_TYPE)); |
485 | | |
486 | | let builder = keys |
487 | | .into_data() |
488 | | .into_builder() |
489 | | .data_type(data_type) |
490 | | .child_data(vec![values.into_data()]); |
491 | | |
492 | | DictionaryArray::from(unsafe { builder.build_unchecked() }) |
493 | | } |
494 | | |
/// Returns the current dictionary values buffer as a slice
///
/// These are the dictionary entries held by the values builder, not the
/// per-slot keys.
pub fn values_slice(&self) -> &[V::Native] {
    self.values_builder.values_slice()
}
499 | | |
/// Returns the current dictionary values buffer as a mutable slice
///
/// Note that the internal de-duplication map is keyed by the values as they
/// were when appended; writing through this slice does not update that map.
pub fn values_slice_mut(&mut self) -> &mut [V::Native] {
    self.values_builder.values_slice_mut()
}
504 | | |
/// Returns the current null buffer as a slice
///
/// This is the validity of the keys builder, i.e. of the array slots
/// themselves, not of the dictionary values.
pub fn validity_slice(&self) -> Option<&[u8]> {
    self.keys_builder.validity_slice()
}
509 | | } |
510 | | |
511 | | impl<K: ArrowDictionaryKeyType, P: ArrowPrimitiveType> Extend<Option<P::Native>> |
512 | | for PrimitiveDictionaryBuilder<K, P> |
513 | | { |
514 | | #[inline] |
515 | | fn extend<T: IntoIterator<Item = Option<P::Native>>>(&mut self, iter: T) { |
516 | | for v in iter { |
517 | | self.append_option(v) |
518 | | } |
519 | | } |
520 | | } |
521 | | |
522 | | #[cfg(test)] |
523 | | mod tests { |
524 | | use super::*; |
525 | | |
526 | | use crate::array::{Int32Array, UInt32Array, UInt8Array}; |
527 | | use crate::builder::Decimal128Builder; |
528 | | use crate::cast::AsArray; |
529 | | use crate::types::{ |
530 | | Date32Type, Decimal128Type, DurationNanosecondType, Float32Type, Float64Type, Int16Type, |
531 | | Int32Type, Int64Type, Int8Type, TimestampNanosecondType, UInt16Type, UInt32Type, |
532 | | UInt64Type, UInt8Type, |
533 | | }; |
534 | | |
#[test]
fn test_primitive_dictionary_builder() {
    let mut dict_builder =
        PrimitiveDictionaryBuilder::<UInt8Type, UInt32Type>::with_capacity(3, 2);
    dict_builder.append(12345678).unwrap();
    dict_builder.append_null();
    dict_builder.append(22345678).unwrap();
    let dict = dict_builder.finish();

    // Two distinct values -> keys 0 and 1, with the null slot in between.
    let expected_keys = UInt8Array::from(vec![Some(0), None, Some(1)]);
    assert_eq!(dict.keys(), &expected_keys);

    // Only the middle slot is null.
    for (slot, expect_null) in [false, true, false].into_iter().enumerate() {
        assert_eq!(dict.is_null(slot), expect_null);
    }

    // Values are polymorphic and so require a downcast.
    let dict_values = dict
        .values()
        .as_any()
        .downcast_ref::<UInt32Array>()
        .unwrap();
    let raw: &[u32] = dict_values.values();
    assert_eq!(raw, &[12345678, 22345678]);
}
559 | | |
#[test]
fn test_extend() {
    let first_batch = [1, 2, 3, 1, 2, 3, 1, 2, 3];
    let second_batch = [4, 5, 1, 3, 1];

    let mut dict_builder = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
    dict_builder.extend(first_batch.into_iter().map(Some));
    dict_builder.extend(second_batch.into_iter().map(Some));
    let dict = dict_builder.finish();

    // Values seen in the first batch keep their keys in the second one.
    assert_eq!(
        dict.keys().values(),
        &[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 0, 2, 0]
    );
    // Five distinct values: 1, 2, 3, 4, 5.
    assert_eq!(dict.values().len(), 5);
}
572 | | |
#[test]
#[should_panic(expected = "DictionaryKeyOverflowError")]
fn test_primitive_dictionary_overflow() {
    let mut dict_builder =
        PrimitiveDictionaryBuilder::<UInt8Type, UInt32Type>::with_capacity(257, 257);

    // 256 distinct values use up every key a `u8` can represent (0..=255).
    for offset in 0..256 {
        dict_builder.append(offset + 1000).unwrap();
    }

    // A 257th distinct value needs key 256, which does not fit into `u8`.
    dict_builder.append(1257).unwrap();
}
585 | | |
#[test]
fn test_primitive_dictionary_with_builders() {
    // A non-default value type must survive the trip through the builder.
    let value_type = DataType::Decimal128(1, 2);

    let mut dict_builder =
        PrimitiveDictionaryBuilder::<Int32Type, Decimal128Type>::new_from_empty_builders(
            PrimitiveBuilder::<Int32Type>::new(),
            Decimal128Builder::new().with_data_type(value_type.clone()),
        );
    let dict_array = dict_builder.finish();

    assert_eq!(dict_array.value_type(), value_type);

    let expected_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(value_type));
    assert_eq!(dict_array.data_type(), &expected_type);
}
605 | | |
#[test]
fn test_extend_dictionary() {
    // Source dictionary: values interleaved with two null slots.
    let source_dict = {
        let mut source = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
        source.extend([1, 2, 3, 1, 2, 3, 1, 2, 3].into_iter().map(Some));
        source.extend([None::<i32>]);
        source.extend([4, 5, 1, 3, 1].into_iter().map(Some));
        source.append_null();
        source.finish()
    };

    // Target builder already holds some values, one of which (5) overlaps.
    let mut target = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
    target.extend([6, 6, 7, 6, 5].into_iter().map(Some));
    target
        .extend_dictionary(&source_dict.downcast_dict().unwrap())
        .unwrap();
    let dict = target.finish();

    // Distinct values: 6, 7, 5, 1, 2, 3, 4.
    assert_eq!(dict.values().len(), 7);

    let actual = dict
        .downcast_dict::<Int32Array>()
        .unwrap()
        .into_iter()
        .collect::<Vec<_>>();

    // The logical content is the target's own items followed by the source's.
    let expected: Vec<Option<i32>> = [6, 6, 7, 6, 5]
        .into_iter()
        .map(Some)
        .chain([1, 2, 3, 1, 2, 3, 1, 2, 3].into_iter().map(Some))
        .chain([None])
        .chain([4, 5, 1, 3, 1].into_iter().map(Some))
        .chain([None])
        .collect();

    assert_eq!(actual, expected);
}
659 | | |
#[test]
fn test_extend_dictionary_with_null_in_mapped_value() {
    // Hand-assemble a dictionary whose *values* array contains a null:
    // values = [null, 42], keys = [0, 1].
    let source_dict = {
        let mut dict_values = PrimitiveBuilder::<Int32Type>::new();
        dict_values.append_null();
        dict_values.append_value(42);
        let dict_values = dict_values.finish();

        let mut dict_keys = PrimitiveBuilder::<Int32Type>::new();
        dict_keys.append_value(0);
        dict_keys.append_value(1);
        let dict_keys = dict_keys.finish();

        let dict_type = DataType::Dictionary(
            Box::new(Int32Type::DATA_TYPE),
            Box::new(dict_values.data_type().clone()),
        );

        let data_builder = dict_keys
            .into_data()
            .into_builder()
            .data_type(dict_type)
            .child_data(vec![dict_values.into_data()]);

        DictionaryArray::from(unsafe { data_builder.build_unchecked() })
    };

    // Sanity check of the fixture itself.
    let source_values = source_dict.values().as_primitive::<Int32Type>();
    assert_eq!(
        source_values.into_iter().collect::<Vec<_>>(),
        &[None, Some(42)]
    );

    let mut target = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
    target
        .extend_dictionary(&source_dict.downcast_dict().unwrap())
        .unwrap();
    let dict = target.finish();

    // The null dictionary value is not carried over; only 42 is.
    assert_eq!(dict.values().len(), 1);

    let actual = dict
        .downcast_dict::<Int32Array>()
        .unwrap()
        .into_iter()
        .collect::<Vec<_>>();

    // The slot that pointed at the null value is now a null slot.
    assert_eq!(actual, [None, Some(42)]);
}
711 | | |
#[test]
fn test_extend_all_null_dictionary() {
    // Source with two null slots and an empty dictionary.
    let source_dict = {
        let mut source = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
        source.append_nulls(2);
        source.finish()
    };

    let mut target = PrimitiveDictionaryBuilder::<Int32Type, Int32Type>::new();
    target
        .extend_dictionary(&source_dict.downcast_dict().unwrap())
        .unwrap();
    let dict = target.finish();

    // No values were added, but both null slots were.
    assert_eq!(dict.values().len(), 0);

    let actual = dict
        .downcast_dict::<Int32Array>()
        .unwrap()
        .into_iter()
        .collect::<Vec<_>>();

    assert_eq!(actual, [None, None]);
}
736 | | |
#[test]
fn creating_dictionary_from_builders_should_use_values_capacity_for_the_map() {
    let keys_builder =
        PrimitiveBuilder::<Int32Type>::with_capacity(1).with_data_type(DataType::Int32);
    let values_builder =
        PrimitiveBuilder::<crate::types::TimestampMicrosecondType>::with_capacity(2)
            .with_data_type(DataType::Timestamp(
                arrow_schema::TimeUnit::Microsecond,
                Some("+08:00".into()),
            ));

    let builder = PrimitiveDictionaryBuilder::new_from_empty_builders(keys_builder, values_builder);

    // The map must be able to hold at least as many entries as the values builder.
    let map_capacity = builder.map.capacity();
    let values_capacity = builder.values_builder.capacity();
    assert!(
        map_capacity >= values_capacity,
        "map capacity {} should be at least the values capacity {}",
        map_capacity,
        values_capacity
    );
}
751 | | |
752 | | fn _test_try_new_from_builder_generic_for_key_types<K1, K2, V>(values: Vec<V::Native>) |
753 | | where |
754 | | K1: ArrowDictionaryKeyType, |
755 | | K1::Native: NumCast, |
756 | | K2: ArrowDictionaryKeyType, |
757 | | K2::Native: NumCast + From<u8>, |
758 | | V: ArrowPrimitiveType, |
759 | | { |
760 | | let mut source = PrimitiveDictionaryBuilder::<K1, V>::new(); |
761 | | source.append(values[0]).unwrap(); |
762 | | source.append_null(); |
763 | | source.append(values[1]).unwrap(); |
764 | | source.append(values[2]).unwrap(); |
765 | | |
766 | | let mut result = PrimitiveDictionaryBuilder::<K2, V>::try_new_from_builder(source).unwrap(); |
767 | | let array = result.finish(); |
768 | | |
769 | | let mut expected_keys_builder = PrimitiveBuilder::<K2>::new(); |
770 | | expected_keys_builder |
771 | | .append_value(<<K2 as ArrowPrimitiveType>::Native as From<u8>>::from(0u8)); |
772 | | expected_keys_builder.append_null(); |
773 | | expected_keys_builder |
774 | | .append_value(<<K2 as ArrowPrimitiveType>::Native as From<u8>>::from(1u8)); |
775 | | expected_keys_builder |
776 | | .append_value(<<K2 as ArrowPrimitiveType>::Native as From<u8>>::from(2u8)); |
777 | | let expected_keys = expected_keys_builder.finish(); |
778 | | assert_eq!(array.keys(), &expected_keys); |
779 | | |
780 | | let av = array.values(); |
781 | | let ava = av.as_any().downcast_ref::<PrimitiveArray<V>>().unwrap(); |
782 | | assert_eq!(ava.value(0), values[0]); |
783 | | assert_eq!(ava.value(1), values[1]); |
784 | | assert_eq!(ava.value(2), values[2]); |
785 | | } |
786 | | |
787 | | fn _test_try_new_from_builder_generic_for_value<T>(values: Vec<T::Native>) |
788 | | where |
789 | | T: ArrowPrimitiveType, |
790 | | { |
791 | | // test cast to bigger size unsigned |
792 | | _test_try_new_from_builder_generic_for_key_types::<UInt8Type, UInt16Type, T>( |
793 | | values.clone(), |
794 | | ); |
795 | | // test cast going to smaller size unsigned |
796 | | _test_try_new_from_builder_generic_for_key_types::<UInt16Type, UInt8Type, T>( |
797 | | values.clone(), |
798 | | ); |
799 | | // test cast going to bigger size signed |
800 | | _test_try_new_from_builder_generic_for_key_types::<Int8Type, Int16Type, T>(values.clone()); |
801 | | // test cast going to smaller size signed |
802 | | _test_try_new_from_builder_generic_for_key_types::<Int32Type, Int16Type, T>(values.clone()); |
803 | | // test going from signed to signed for different size changes |
804 | | _test_try_new_from_builder_generic_for_key_types::<UInt8Type, Int16Type, T>(values.clone()); |
805 | | _test_try_new_from_builder_generic_for_key_types::<Int8Type, UInt8Type, T>(values.clone()); |
806 | | _test_try_new_from_builder_generic_for_key_types::<Int8Type, UInt16Type, T>(values.clone()); |
807 | | _test_try_new_from_builder_generic_for_key_types::<Int32Type, Int16Type, T>(values.clone()); |
808 | | } |
809 | | |
/// Exercises `try_new_from_builder` for a range of dictionary value types;
/// each call checks several source/target key type combinations.
#[test]
fn test_try_new_from_builder() {
    // test unsigned types
    _test_try_new_from_builder_generic_for_value::<UInt8Type>(vec![1, 2, 3]);
    _test_try_new_from_builder_generic_for_value::<UInt16Type>(vec![1, 2, 3]);
    _test_try_new_from_builder_generic_for_value::<UInt32Type>(vec![1, 2, 3]);
    _test_try_new_from_builder_generic_for_value::<UInt64Type>(vec![1, 2, 3]);
    // test signed types
    _test_try_new_from_builder_generic_for_value::<Int8Type>(vec![-1, 0, 1]);
    _test_try_new_from_builder_generic_for_value::<Int16Type>(vec![-1, 0, 1]);
    _test_try_new_from_builder_generic_for_value::<Int32Type>(vec![-1, 0, 1]);
    _test_try_new_from_builder_generic_for_value::<Int64Type>(vec![-1, 0, 1]);
    // test some date/time types
    _test_try_new_from_builder_generic_for_value::<Date32Type>(vec![5, 6, 7]);
    _test_try_new_from_builder_generic_for_value::<DurationNanosecondType>(vec![1, 2, 3]);
    _test_try_new_from_builder_generic_for_value::<TimestampNanosecondType>(vec![1, 2, 3]);
    // test some floating point types
    _test_try_new_from_builder_generic_for_value::<Float32Type>(vec![0.1, 0.2, 0.3]);
    _test_try_new_from_builder_generic_for_value::<Float64Type>(vec![-0.1, 0.2, 0.3]);
}
830 | | |
#[test]
fn test_try_new_from_builder_cast_fails() {
    // 257 distinct values produce keys 0..=256; the last one does not fit into `u8`.
    let mut source_builder = PrimitiveDictionaryBuilder::<UInt16Type, UInt64Type>::new();
    for value in 0..257 {
        source_builder.append_value(value);
    }

    let result =
        PrimitiveDictionaryBuilder::<UInt8Type, UInt64Type>::try_new_from_builder(source_builder);

    match result {
        Ok(_) => panic!("narrowing keys up to 256 into UInt8 must fail"),
        Err(e) => {
            assert!(matches!(e, ArrowError::CastError(_)));
            assert_eq!(
                e.to_string(),
                "Cast error: Can't cast dictionary keys from source type UInt16 to type UInt8"
            );
        }
    }
}
852 | | |
#[test]
fn test_finish_preserve_values() {
    /// Raw dictionary values of `array`.
    fn raw_values(array: &DictionaryArray<UInt8Type>) -> &[u32] {
        array
            .values()
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap()
            .values()
    }

    // First batch: two fresh values.
    let mut dict_builder = PrimitiveDictionaryBuilder::<UInt8Type, UInt32Type>::new();
    dict_builder.append(10).unwrap();
    dict_builder.append(20).unwrap();
    let first = dict_builder.finish_preserve_values();
    assert_eq!(first.keys(), &UInt8Array::from(vec![Some(0), Some(1)]));
    assert_eq!(raw_values(&first), &[10, 20]);

    // Second batch: two more fresh values on the same builder.
    dict_builder.append(30).unwrap();
    dict_builder.append(40).unwrap();
    let second = dict_builder.finish_preserve_values();

    // Keys continue after the ones handed out for the first batch ...
    assert_eq!(second.keys(), &UInt8Array::from(vec![Some(2), Some(3)]));
    // ... and resolve to the values of the second batch only.
    let logical = second
        .downcast_dict::<UInt32Array>()
        .unwrap()
        .into_iter()
        .collect::<Vec<_>>();
    assert_eq!(logical, vec![Some(30), Some(40)]);

    // The dictionary itself carries the values of both batches.
    assert_eq!(raw_values(&second), &[10, 20, 30, 40]);
}
893 | | } |