/Users/andrewlamb/Software/arrow-rs/arrow-array/src/builder/map_builder.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use crate::builder::ArrayBuilder; |
19 | | use crate::{Array, ArrayRef, MapArray, StructArray}; |
20 | | use arrow_buffer::Buffer; |
21 | | use arrow_buffer::{NullBuffer, NullBufferBuilder}; |
22 | | use arrow_data::ArrayData; |
23 | | use arrow_schema::{ArrowError, DataType, Field, FieldRef}; |
24 | | use std::any::Any; |
25 | | use std::sync::Arc; |
26 | | |
27 | | /// Builder for [`MapArray`] |
28 | | /// |
29 | | /// ``` |
30 | | /// # use arrow_array::builder::{Int32Builder, MapBuilder, StringBuilder}; |
31 | | /// # use arrow_array::{Int32Array, StringArray}; |
32 | | /// |
33 | | /// let string_builder = StringBuilder::new(); |
34 | | /// let int_builder = Int32Builder::with_capacity(4); |
35 | | /// |
36 | | /// // Construct `[{"joe": 1}, {"blogs": 2, "foo": 4}, {}, null]` |
37 | | /// let mut builder = MapBuilder::new(None, string_builder, int_builder); |
38 | | /// |
39 | | /// builder.keys().append_value("joe"); |
40 | | /// builder.values().append_value(1); |
41 | | /// builder.append(true).unwrap(); |
42 | | /// |
43 | | /// builder.keys().append_value("blogs"); |
44 | | /// builder.values().append_value(2); |
45 | | /// builder.keys().append_value("foo"); |
46 | | /// builder.values().append_value(4); |
47 | | /// builder.append(true).unwrap(); |
48 | | /// builder.append(true).unwrap(); |
49 | | /// builder.append(false).unwrap(); |
50 | | /// |
51 | | /// let array = builder.finish(); |
52 | | /// assert_eq!(array.value_offsets(), &[0, 1, 3, 3, 3]); |
53 | | /// assert_eq!(array.values().as_ref(), &Int32Array::from(vec![1, 2, 4])); |
54 | | /// assert_eq!(array.keys().as_ref(), &StringArray::from(vec!["joe", "blogs", "foo"])); |
55 | | /// |
56 | | /// ``` |
57 | | #[derive(Debug)] |
58 | | pub struct MapBuilder<K: ArrayBuilder, V: ArrayBuilder> { |
59 | | offsets_builder: Vec<i32>, |
60 | | null_buffer_builder: NullBufferBuilder, |
61 | | field_names: MapFieldNames, |
62 | | key_builder: K, |
63 | | value_builder: V, |
64 | | key_field: Option<FieldRef>, |
65 | | value_field: Option<FieldRef>, |
66 | | } |
67 | | |
68 | | /// The [`Field`] names for a [`MapArray`] |
69 | | #[derive(Debug, Clone)] |
70 | | pub struct MapFieldNames { |
71 | | /// [`Field`] name for map entries |
72 | | pub entry: String, |
73 | | /// [`Field`] name for map key |
74 | | pub key: String, |
75 | | /// [`Field`] name for map value |
76 | | pub value: String, |
77 | | } |
78 | | |
79 | | impl Default for MapFieldNames { |
80 | 0 | fn default() -> Self { |
81 | 0 | Self { |
82 | 0 | entry: "entries".to_string(), |
83 | 0 | key: "keys".to_string(), |
84 | 0 | value: "values".to_string(), |
85 | 0 | } |
86 | 0 | } |
87 | | } |
88 | | |
89 | | impl<K: ArrayBuilder, V: ArrayBuilder> MapBuilder<K, V> { |
90 | | /// Creates a new `MapBuilder` |
91 | 3 | pub fn new(field_names: Option<MapFieldNames>, key_builder: K, value_builder: V) -> Self { |
92 | 3 | let capacity = key_builder.len(); |
93 | 3 | Self::with_capacity(field_names, key_builder, value_builder, capacity) |
94 | 3 | } |
95 | | |
96 | | /// Creates a new `MapBuilder` with capacity |
97 | 3 | pub fn with_capacity( |
98 | 3 | field_names: Option<MapFieldNames>, |
99 | 3 | key_builder: K, |
100 | 3 | value_builder: V, |
101 | 3 | capacity: usize, |
102 | 3 | ) -> Self { |
103 | 3 | let mut offsets_builder = Vec::with_capacity(capacity + 1); |
104 | 3 | offsets_builder.push(0); |
105 | 3 | Self { |
106 | 3 | offsets_builder, |
107 | 3 | null_buffer_builder: NullBufferBuilder::new(capacity), |
108 | 3 | field_names: field_names.unwrap_or_default(), |
109 | 3 | key_builder, |
110 | 3 | value_builder, |
111 | 3 | key_field: None, |
112 | 3 | value_field: None, |
113 | 3 | } |
114 | 3 | } |
115 | | |
116 | | /// Override the field passed to [`MapBuilder::new`] |
117 | | /// |
118 | | /// By default, a non-nullable field is created with the name `keys` |
119 | | /// |
120 | | /// Note: [`Self::finish`] and [`Self::finish_cloned`] will panic if the |
121 | | /// field's data type does not match that of `K` or the field is nullable |
122 | 0 | pub fn with_keys_field(self, field: impl Into<FieldRef>) -> Self { |
123 | 0 | Self { |
124 | 0 | key_field: Some(field.into()), |
125 | 0 | ..self |
126 | 0 | } |
127 | 0 | } |
128 | | |
129 | | /// Override the field passed to [`MapBuilder::new`] |
130 | | /// |
131 | | /// By default, a nullable field is created with the name `values` |
132 | | /// |
133 | | /// Note: [`Self::finish`] and [`Self::finish_cloned`] will panic if the |
134 | | /// field's data type does not match that of `V` |
135 | 0 | pub fn with_values_field(self, field: impl Into<FieldRef>) -> Self { |
136 | 0 | Self { |
137 | 0 | value_field: Some(field.into()), |
138 | 0 | ..self |
139 | 0 | } |
140 | 0 | } |
141 | | |
142 | | /// Returns the key array builder of the map |
143 | | pub fn keys(&mut self) -> &mut K { |
144 | | &mut self.key_builder |
145 | | } |
146 | | |
147 | | /// Returns the value array builder of the map |
148 | | pub fn values(&mut self) -> &mut V { |
149 | | &mut self.value_builder |
150 | | } |
151 | | |
152 | | /// Returns both the key and value array builders of the map |
153 | 2 | pub fn entries(&mut self) -> (&mut K, &mut V) { |
154 | 2 | (&mut self.key_builder, &mut self.value_builder) |
155 | 2 | } |
156 | | |
157 | | /// Finish the current map array slot |
158 | | /// |
159 | | /// Returns an error if the key and values builders are in an inconsistent state. |
160 | | #[inline] |
161 | 6 | pub fn append(&mut self, is_valid: bool) -> Result<(), ArrowError> { |
162 | 6 | if self.key_builder.len() != self.value_builder.len() { |
163 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
164 | 0 | "Cannot append to a map builder when its keys and values have unequal lengths of {} and {}", |
165 | 0 | self.key_builder.len(), |
166 | 0 | self.value_builder.len() |
167 | 0 | ))); |
168 | 6 | } |
169 | 6 | self.offsets_builder.push(self.key_builder.len() as i32); |
170 | 6 | self.null_buffer_builder.append(is_valid); |
171 | 6 | Ok(()) |
172 | 6 | } |
173 | | |
174 | | /// Builds the [`MapArray`] |
175 | 3 | pub fn finish(&mut self) -> MapArray { |
176 | 3 | let len = self.len(); |
177 | | // Build the keys |
178 | 3 | let keys_arr = self.key_builder.finish(); |
179 | 3 | let values_arr = self.value_builder.finish(); |
180 | 3 | let offset_buffer = Buffer::from_vec(std::mem::take(&mut self.offsets_builder)); |
181 | 3 | self.offsets_builder.push(0); |
182 | 3 | let null_bit_buffer = self.null_buffer_builder.finish(); |
183 | | |
184 | 3 | self.finish_helper(keys_arr, values_arr, offset_buffer, null_bit_buffer, len) |
185 | 3 | } |
186 | | |
187 | | /// Builds the [`MapArray`] without resetting the builder. |
188 | 0 | pub fn finish_cloned(&self) -> MapArray { |
189 | 0 | let len = self.len(); |
190 | | // Build the keys |
191 | 0 | let keys_arr = self.key_builder.finish_cloned(); |
192 | 0 | let values_arr = self.value_builder.finish_cloned(); |
193 | 0 | let offset_buffer = Buffer::from_slice_ref(self.offsets_builder.as_slice()); |
194 | 0 | let nulls = self.null_buffer_builder.finish_cloned(); |
195 | 0 | self.finish_helper(keys_arr, values_arr, offset_buffer, nulls, len) |
196 | 0 | } |
197 | | |
198 | 3 | fn finish_helper( |
199 | 3 | &self, |
200 | 3 | keys_arr: Arc<dyn Array>, |
201 | 3 | values_arr: Arc<dyn Array>, |
202 | 3 | offset_buffer: Buffer, |
203 | 3 | nulls: Option<NullBuffer>, |
204 | 3 | len: usize, |
205 | 3 | ) -> MapArray { |
206 | 3 | assert!( |
207 | 3 | keys_arr.null_count() == 0, |
208 | 0 | "Keys array must have no null values, found {} null value(s)", |
209 | 0 | keys_arr.null_count() |
210 | | ); |
211 | | |
212 | 3 | let keys_field = match &self.key_field { |
213 | 0 | Some(f) => { |
214 | 0 | assert!(!f.is_nullable(), "Keys field must not be nullable"); |
215 | 0 | f.clone() |
216 | | } |
217 | 3 | None => Arc::new(Field::new( |
218 | 3 | self.field_names.key.as_str(), |
219 | 3 | keys_arr.data_type().clone(), |
220 | | false, // always non-nullable |
221 | | )), |
222 | | }; |
223 | 3 | let values_field = match &self.value_field { |
224 | 0 | Some(f) => f.clone(), |
225 | 3 | None => Arc::new(Field::new( |
226 | 3 | self.field_names.value.as_str(), |
227 | 3 | values_arr.data_type().clone(), |
228 | | true, |
229 | | )), |
230 | | }; |
231 | | |
232 | 3 | let struct_array = |
233 | 3 | StructArray::from(vec![(keys_field, keys_arr), (values_field, values_arr)]); |
234 | | |
235 | 3 | let map_field = Arc::new(Field::new( |
236 | 3 | self.field_names.entry.as_str(), |
237 | 3 | struct_array.data_type().clone(), |
238 | | false, // always non-nullable |
239 | | )); |
240 | 3 | let array_data = ArrayData::builder(DataType::Map(map_field, false)) // TODO: support sorted keys |
241 | 3 | .len(len) |
242 | 3 | .add_buffer(offset_buffer) |
243 | 3 | .add_child_data(struct_array.into_data()) |
244 | 3 | .nulls(nulls); |
245 | | |
246 | 3 | let array_data = unsafe { array_data.build_unchecked() }; |
247 | | |
248 | 3 | MapArray::from(array_data) |
249 | 3 | } |
250 | | |
251 | | /// Returns the current null buffer as a slice |
252 | | pub fn validity_slice(&self) -> Option<&[u8]> { |
253 | | self.null_buffer_builder.as_slice() |
254 | | } |
255 | | } |
256 | | |
257 | | impl<K: ArrayBuilder, V: ArrayBuilder> ArrayBuilder for MapBuilder<K, V> { |
258 | 6 | fn len(&self) -> usize { |
259 | 6 | self.null_buffer_builder.len() |
260 | 6 | } |
261 | | |
262 | 2 | fn finish(&mut self) -> ArrayRef { |
263 | 2 | Arc::new(self.finish()) |
264 | 2 | } |
265 | | |
266 | | /// Builds the array without resetting the builder. |
267 | 0 | fn finish_cloned(&self) -> ArrayRef { |
268 | 0 | Arc::new(self.finish_cloned()) |
269 | 0 | } |
270 | | |
271 | 0 | fn as_any(&self) -> &dyn Any { |
272 | 0 | self |
273 | 0 | } |
274 | | |
275 | 1 | fn as_any_mut(&mut self) -> &mut dyn Any { |
276 | 1 | self |
277 | 1 | } |
278 | | |
279 | 0 | fn into_box_any(self: Box<Self>) -> Box<dyn Any> { |
280 | 0 | self |
281 | 0 | } |
282 | | } |
283 | | |
284 | | #[cfg(test)] |
285 | | mod tests { |
286 | | use super::*; |
287 | | use crate::builder::{make_builder, Int32Builder, StringBuilder}; |
288 | | use crate::{Int32Array, StringArray}; |
289 | | use std::collections::HashMap; |
290 | | |
291 | | #[test] |
292 | | #[should_panic(expected = "Keys array must have no null values, found 1 null value(s)")] |
293 | | fn test_map_builder_with_null_keys_panics() { |
294 | | let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new()); |
295 | | builder.keys().append_null(); |
296 | | builder.values().append_value(42); |
297 | | builder.append(true).unwrap(); |
298 | | |
299 | | builder.finish(); |
300 | | } |
301 | | |
302 | | #[test] |
303 | | fn test_boxed_map_builder() { |
304 | | let keys_builder = make_builder(&DataType::Utf8, 5); |
305 | | let values_builder = make_builder(&DataType::Int32, 5); |
306 | | |
307 | | let mut builder = MapBuilder::new(None, keys_builder, values_builder); |
308 | | builder |
309 | | .keys() |
310 | | .as_any_mut() |
311 | | .downcast_mut::<StringBuilder>() |
312 | | .expect("should be an StringBuilder") |
313 | | .append_value("1"); |
314 | | builder |
315 | | .values() |
316 | | .as_any_mut() |
317 | | .downcast_mut::<Int32Builder>() |
318 | | .expect("should be an Int32Builder") |
319 | | .append_value(42); |
320 | | builder.append(true).unwrap(); |
321 | | |
322 | | let map_array = builder.finish(); |
323 | | |
324 | | assert_eq!( |
325 | | map_array |
326 | | .keys() |
327 | | .as_any() |
328 | | .downcast_ref::<StringArray>() |
329 | | .expect("should be an StringArray") |
330 | | .value(0), |
331 | | "1" |
332 | | ); |
333 | | assert_eq!( |
334 | | map_array |
335 | | .values() |
336 | | .as_any() |
337 | | .downcast_ref::<Int32Array>() |
338 | | .expect("should be an Int32Array") |
339 | | .value(0), |
340 | | 42 |
341 | | ); |
342 | | } |
343 | | |
344 | | #[test] |
345 | | fn test_with_values_field() { |
346 | | let value_field = Arc::new(Field::new("bars", DataType::Int32, false)); |
347 | | let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new()) |
348 | | .with_values_field(value_field.clone()); |
349 | | builder.keys().append_value(1); |
350 | | builder.values().append_value(2); |
351 | | builder.append(true).unwrap(); |
352 | | builder.append(false).unwrap(); // This is fine as nullability refers to nullability of values |
353 | | builder.keys().append_value(3); |
354 | | builder.values().append_value(4); |
355 | | builder.append(true).unwrap(); |
356 | | let map = builder.finish(); |
357 | | |
358 | | assert_eq!(map.len(), 3); |
359 | | assert_eq!( |
360 | | map.data_type(), |
361 | | &DataType::Map( |
362 | | Arc::new(Field::new( |
363 | | "entries", |
364 | | DataType::Struct( |
365 | | vec![ |
366 | | Arc::new(Field::new("keys", DataType::Int32, false)), |
367 | | value_field.clone() |
368 | | ] |
369 | | .into() |
370 | | ), |
371 | | false, |
372 | | )), |
373 | | false |
374 | | ) |
375 | | ); |
376 | | |
377 | | builder.keys().append_value(5); |
378 | | builder.values().append_value(6); |
379 | | builder.append(true).unwrap(); |
380 | | let map = builder.finish(); |
381 | | |
382 | | assert_eq!(map.len(), 1); |
383 | | assert_eq!( |
384 | | map.data_type(), |
385 | | &DataType::Map( |
386 | | Arc::new(Field::new( |
387 | | "entries", |
388 | | DataType::Struct( |
389 | | vec![ |
390 | | Arc::new(Field::new("keys", DataType::Int32, false)), |
391 | | value_field |
392 | | ] |
393 | | .into() |
394 | | ), |
395 | | false, |
396 | | )), |
397 | | false |
398 | | ) |
399 | | ); |
400 | | } |
401 | | |
402 | | #[test] |
403 | | fn test_with_keys_field() { |
404 | | let mut key_metadata = HashMap::new(); |
405 | | key_metadata.insert("foo".to_string(), "bar".to_string()); |
406 | | let key_field = Arc::new( |
407 | | Field::new("keys", DataType::Int32, false).with_metadata(key_metadata.clone()), |
408 | | ); |
409 | | let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new()) |
410 | | .with_keys_field(key_field.clone()); |
411 | | builder.keys().append_value(1); |
412 | | builder.values().append_value(2); |
413 | | builder.append(true).unwrap(); |
414 | | let map = builder.finish(); |
415 | | |
416 | | assert_eq!(map.len(), 1); |
417 | | assert_eq!( |
418 | | map.data_type(), |
419 | | &DataType::Map( |
420 | | Arc::new(Field::new( |
421 | | "entries", |
422 | | DataType::Struct( |
423 | | vec![ |
424 | | Arc::new( |
425 | | Field::new("keys", DataType::Int32, false) |
426 | | .with_metadata(key_metadata) |
427 | | ), |
428 | | Arc::new(Field::new("values", DataType::Int32, true)) |
429 | | ] |
430 | | .into() |
431 | | ), |
432 | | false, |
433 | | )), |
434 | | false |
435 | | ) |
436 | | ); |
437 | | } |
438 | | |
439 | | #[test] |
440 | | #[should_panic(expected = "Keys field must not be nullable")] |
441 | | fn test_with_nullable_keys_field() { |
442 | | let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new()) |
443 | | .with_keys_field(Arc::new(Field::new("keys", DataType::Int32, true))); |
444 | | |
445 | | builder.keys().append_value(1); |
446 | | builder.values().append_value(2); |
447 | | builder.append(true).unwrap(); |
448 | | |
449 | | builder.finish(); |
450 | | } |
451 | | |
452 | | #[test] |
453 | | #[should_panic(expected = "Incorrect datatype")] |
454 | | fn test_keys_field_type_mismatch() { |
455 | | let mut builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new()) |
456 | | .with_keys_field(Arc::new(Field::new("keys", DataType::Utf8, false))); |
457 | | |
458 | | builder.keys().append_value(1); |
459 | | builder.values().append_value(2); |
460 | | builder.append(true).unwrap(); |
461 | | |
462 | | builder.finish(); |
463 | | } |
464 | | } |