/Users/andrewlamb/Software/arrow-rs/arrow-array/src/record_batch.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! A two-dimensional batch of column-oriented data with a defined |
19 | | //! [schema](arrow_schema::Schema). |
20 | | |
21 | | use crate::cast::AsArray; |
22 | | use crate::{new_empty_array, Array, ArrayRef, StructArray}; |
23 | | use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef}; |
24 | | use std::ops::Index; |
25 | | use std::sync::Arc; |
26 | | |
/// Trait for types that can read `RecordBatch`'s.
///
/// To create from an iterator, see [RecordBatchIterator].
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
    /// Returns the schema of this `RecordBatchReader`.
    ///
    /// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this
    /// reader should have the same schema as returned from this method.
    fn schema(&self) -> SchemaRef;
}
37 | | |
38 | | impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> { |
39 | | fn schema(&self) -> SchemaRef { |
40 | | self.as_ref().schema() |
41 | | } |
42 | | } |
43 | | |
/// Trait for types that can write `RecordBatch`'s.
pub trait RecordBatchWriter {
    /// Write a single batch to the writer.
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;

    /// Write footer or termination data, then mark the writer as done.
    ///
    /// Note this consumes `self`: no further batches can be written afterwards.
    fn close(self) -> Result<(), ArrowError>;
}
52 | | |
/// Creates an array from a literal slice of values,
/// suitable for rapid testing and development.
///
/// Example:
///
/// ```rust
///
/// use arrow_array::create_array;
///
/// let array = create_array!(Int32, [1, 2, 3, 4, 5]);
/// let array = create_array!(Utf8, [Some("a"), Some("b"), None, Some("e")]);
/// ```
/// Support for limited data types is available. The macro will return a compile error if an unsupported data type is used.
/// Presently supported data types are:
/// - `Boolean`, `Null`
/// - `Decimal32`, `Decimal64`, `Decimal128`, `Decimal256`
/// - `Float16`, `Float32`, `Float64`
/// - `Int8`, `Int16`, `Int32`, `Int64`
/// - `UInt8`, `UInt16`, `UInt32`, `UInt64`
/// - `IntervalDayTime`, `IntervalYearMonth`
/// - `Second`, `Millisecond`, `Microsecond`, `Nanosecond`
/// - `Second32`, `Millisecond32`, `Microsecond64`, `Nanosecond64`
/// - `DurationSecond`, `DurationMillisecond`, `DurationMicrosecond`, `DurationNanosecond`
/// - `TimestampSecond`, `TimestampMillisecond`, `TimestampMicrosecond`, `TimestampNanosecond`
/// - `Utf8`, `Utf8View`, `LargeUtf8`, `Binary`, `LargeBinary`
#[macro_export]
macro_rules! create_array {
    // `@from` is used for those types that have a common method `<type>::from`
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // Fixed: previously expanded to `Time64Nanosecond64Array`, which does not
    // exist; the Time64 nanosecond array type is `Time64NanosecondArray`.
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal32) => { $crate::Decimal32Array };
    (@from Decimal64) => { $crate::Decimal64Array };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    // Any other shorthand is rejected at compile time with a clear message.
    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    // `Null` takes a length rather than a value list.
    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    // Binary arrays are built via `from_vec` rather than `From<Vec<_>>`.
    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    // Generic case: resolve the array type via `@from` and build with `From<Vec<_>>`.
    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
139 | | |
/// Creates a record batch from literal slice of values, suitable for rapid
/// testing and development.
///
/// Example:
///
/// ```rust
/// use arrow_array::record_batch;
/// use arrow_schema;
///
/// let batch = record_batch!(
///     ("a", Int32, [1, 2, 3]),
///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, ["alpha", "beta", "gamma"])
/// );
/// ```
/// Due to limitation of [`create_array!`] macro, support for limited data types is available.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {
        {
            // Every generated field is declared nullable, regardless of
            // whether the supplied literals actually contain `None`.
            let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![
                $(
                    arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
                )*
            ]));

            let batch = $crate::RecordBatch::try_new(
                schema,
                vec![$(
                    $crate::create_array!($type, [$($values),*]),
                )*]
            );

            // Expands to `Result<RecordBatch, ArrowError>`: creation can
            // still fail, e.g. when the columns have mismatched lengths.
            batch
        }
    }
}
177 | | |
/// A two-dimensional batch of column-oriented data with a defined
/// [schema](arrow_schema::Schema).
///
/// A `RecordBatch` is a two-dimensional dataset of a number of
/// contiguous arrays, each the same length.
/// A record batch has a schema which must match its arrays’
/// datatypes.
///
/// Record batches are a convenient unit of work for various
/// serialization and computation functions, possibly incremental.
///
/// Use the [`record_batch!`] macro to create a [`RecordBatch`] from
/// literal slice of values, useful for rapid prototyping and testing.
///
/// Example:
/// ```rust
/// use arrow_array::record_batch;
/// let batch = record_batch!(
///     ("a", Int32, [1, 2, 3]),
///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, ["alpha", "beta", "gamma"])
/// );
/// ```
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    // Schema describing the columns; invariants tying it to `columns`
    // are enforced in `try_new_impl`.
    schema: SchemaRef,
    // One array per schema field, all of length `row_count`.
    columns: Vec<Arc<dyn Array>>,

    /// The number of rows in this RecordBatch
    ///
    /// This is stored separately from the columns to handle the case of no columns
    row_count: usize,
}
211 | | |
212 | | impl RecordBatch { |
213 | | /// Creates a `RecordBatch` from a schema and columns. |
214 | | /// |
215 | | /// Expects the following: |
216 | | /// |
217 | | /// * `!columns.is_empty()` |
218 | | /// * `schema.fields.len() == columns.len()` |
219 | | /// * `schema.fields[i].data_type() == columns[i].data_type()` |
220 | | /// * `columns[i].len() == columns[j].len()` |
221 | | /// |
222 | | /// If the conditions are not met, an error is returned. |
223 | | /// |
224 | | /// # Example |
225 | | /// |
226 | | /// ``` |
227 | | /// # use std::sync::Arc; |
228 | | /// # use arrow_array::{Int32Array, RecordBatch}; |
229 | | /// # use arrow_schema::{DataType, Field, Schema}; |
230 | | /// |
231 | | /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); |
232 | | /// let schema = Schema::new(vec![ |
233 | | /// Field::new("id", DataType::Int32, false) |
234 | | /// ]); |
235 | | /// |
236 | | /// let batch = RecordBatch::try_new( |
237 | | /// Arc::new(schema), |
238 | | /// vec![Arc::new(id_array)] |
239 | | /// ).unwrap(); |
240 | | /// ``` |
241 | 248 | pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> { |
242 | 248 | let options = RecordBatchOptions::new(); |
243 | 248 | Self::try_new_impl(schema, columns, &options) |
244 | 248 | } |
245 | | |
    /// Creates a `RecordBatch` from a schema and columns, without validation.
    ///
    /// See [`Self::try_new`] for the checked version.
    ///
    /// # Safety
    ///
    /// Expects the following:
    ///
    /// * `schema.fields.len() == columns.len()`
    /// * `schema.fields[i].data_type() == columns[i].data_type()`
    /// * `columns[i].len() == row_count`
    ///
    /// Note: if the schema does not match the underlying data exactly, it can lead to undefined
    /// behavior, for example, via conversion to a `StructArray`, which in turn could lead
    /// to incorrect access.
    pub unsafe fn new_unchecked(
        schema: SchemaRef,
        columns: Vec<Arc<dyn Array>>,
        row_count: usize,
    ) -> Self {
        // No checks performed here by design: the caller guarantees all
        // of the invariants listed under `# Safety` above.
        Self {
            schema,
            columns,
            row_count,
        }
    }
272 | | |
    /// Creates a `RecordBatch` from a schema and columns, with additional options,
    /// such as whether to strictly validate field names.
    ///
    /// See [`RecordBatch::try_new`] for the expected conditions.
    pub fn try_new_with_options(
        schema: SchemaRef,
        columns: Vec<ArrayRef>,
        options: &RecordBatchOptions,
    ) -> Result<Self, ArrowError> {
        // Thin public wrapper: all validation lives in `try_new_impl`.
        Self::try_new_impl(schema, columns, options)
    }
284 | | |
285 | | /// Creates a new empty [`RecordBatch`]. |
286 | 0 | pub fn new_empty(schema: SchemaRef) -> Self { |
287 | 0 | let columns = schema |
288 | 0 | .fields() |
289 | 0 | .iter() |
290 | 0 | .map(|field| new_empty_array(field.data_type())) |
291 | 0 | .collect(); |
292 | | |
293 | 0 | RecordBatch { |
294 | 0 | schema, |
295 | 0 | columns, |
296 | 0 | row_count: 0, |
297 | 0 | } |
298 | 0 | } |
299 | | |
300 | | /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error |
301 | | /// if any validation check fails, otherwise returns the created [`Self`] |
302 | 248 | fn try_new_impl( |
303 | 248 | schema: SchemaRef, |
304 | 248 | columns: Vec<ArrayRef>, |
305 | 248 | options: &RecordBatchOptions, |
306 | 248 | ) -> Result<Self, ArrowError> { |
307 | | // check that number of fields in schema match column length |
308 | 248 | if schema.fields().len() != columns.len() { |
309 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
310 | 0 | "number of columns({}) must match number of fields({}) in schema", |
311 | 0 | columns.len(), |
312 | 0 | schema.fields().len(), |
313 | 0 | ))); |
314 | 248 | } |
315 | | |
316 | 248 | let row_count = options |
317 | 248 | .row_count |
318 | 248 | .or_else(|| columns.first().map(|col| col.len())) |
319 | 248 | .ok_or_else(|| {0 |
320 | 0 | ArrowError::InvalidArgumentError( |
321 | 0 | "must either specify a row count or at least one column".to_string(), |
322 | 0 | ) |
323 | 0 | })?; |
324 | | |
325 | 1.42k | for (c, f) in columns.iter()248 .zip248 (&schema.fields248 ) { |
326 | 1.42k | if !f.is_nullable() && c.null_count() > 073 { |
327 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
328 | 0 | "Column '{}' is declared as non-nullable but contains null values", |
329 | 0 | f.name() |
330 | 0 | ))); |
331 | 1.42k | } |
332 | | } |
333 | | |
334 | | // check that all columns have the same row count |
335 | 1.42k | if columns.iter()248 .any248 (|c| c.len() != row_count) { |
336 | 0 | let err = match options.row_count { |
337 | 0 | Some(_) => "all columns in a record batch must have the specified row count", |
338 | 0 | None => "all columns in a record batch must have the same length", |
339 | | }; |
340 | 0 | return Err(ArrowError::InvalidArgumentError(err.to_string())); |
341 | 248 | } |
342 | | |
343 | | // function for comparing column type and field type |
344 | | // return true if 2 types are not matched |
345 | 248 | let type_not_match = if options.match_field_names { |
346 | 1.42k | |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type |
347 | | } else { |
348 | 0 | |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { |
349 | 0 | !col_type.equals_datatype(field_type) |
350 | 0 | } |
351 | | }; |
352 | | |
353 | | // check that all columns match the schema |
354 | 248 | let not_match = columns |
355 | 248 | .iter() |
356 | 248 | .zip(schema.fields().iter()) |
357 | 1.42k | .map248 (|(col, field)| (col.data_type(), field.data_type())) |
358 | 248 | .enumerate() |
359 | 248 | .find(type_not_match); |
360 | | |
361 | 248 | if let Some((i0 , (col_type0 , field_type0 ))) = not_match { |
362 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
363 | 0 | "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}"))); |
364 | 248 | } |
365 | | |
366 | 248 | Ok(RecordBatch { |
367 | 248 | schema, |
368 | 248 | columns, |
369 | 248 | row_count, |
370 | 248 | }) |
371 | 248 | } |
372 | | |
    /// Return the schema, columns and row count of this [`RecordBatch`]
    ///
    /// Consumes `self`; no validation is needed since the parts came from a
    /// previously-constructed batch.
    pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
        (self.schema, self.columns, self.row_count)
    }
377 | | |
378 | | /// Override the schema of this [`RecordBatch`] |
379 | | /// |
380 | | /// Returns an error if `schema` is not a superset of the current schema |
381 | | /// as determined by [`Schema::contains`] |
382 | | /// |
383 | | /// See also [`Self::schema_metadata_mut`]. |
384 | 0 | pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> { |
385 | 0 | if !schema.contains(self.schema.as_ref()) { |
386 | 0 | return Err(ArrowError::SchemaError(format!( |
387 | 0 | "target schema is not superset of current schema target={schema} current={}", |
388 | 0 | self.schema |
389 | 0 | ))); |
390 | 0 | } |
391 | | |
392 | 0 | Ok(Self { |
393 | 0 | schema, |
394 | 0 | columns: self.columns, |
395 | 0 | row_count: self.row_count, |
396 | 0 | }) |
397 | 0 | } |
398 | | |
399 | | /// Returns the [`Schema`] of the record batch. |
400 | 27 | pub fn schema(&self) -> SchemaRef { |
401 | 27 | self.schema.clone() |
402 | 27 | } |
403 | | |
    /// Returns a reference to the [`Schema`] of the record batch.
    ///
    /// Unlike [`Self::schema`], this does not bump the `Arc` refcount.
    pub fn schema_ref(&self) -> &SchemaRef {
        &self.schema
    }
408 | | |
409 | | /// Mutable access to the metadata of the schema. |
410 | | /// |
411 | | /// This allows you to modify [`Schema::metadata`] of [`Self::schema`] in a convenient and fast way. |
412 | | /// |
413 | | /// Note this will clone the entire underlying `Schema` object if it is currently shared |
414 | | /// |
415 | | /// # Example |
416 | | /// ``` |
417 | | /// # use std::sync::Arc; |
418 | | /// # use arrow_array::{record_batch, RecordBatch}; |
419 | | /// let mut batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap(); |
420 | | /// // Initially, the metadata is empty |
421 | | /// assert!(batch.schema().metadata().get("key").is_none()); |
422 | | /// // Insert a key-value pair into the metadata |
423 | | /// batch.schema_metadata_mut().insert("key".into(), "value".into()); |
424 | | /// assert_eq!(batch.schema().metadata().get("key"), Some(&String::from("value"))); |
425 | | /// ``` |
426 | 0 | pub fn schema_metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> { |
427 | 0 | let schema = Arc::make_mut(&mut self.schema); |
428 | 0 | &mut schema.metadata |
429 | 0 | } |
430 | | |
431 | | /// Projects the schema onto the specified columns |
432 | 0 | pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> { |
433 | 0 | let projected_schema = self.schema.project(indices)?; |
434 | 0 | let batch_fields = indices |
435 | 0 | .iter() |
436 | 0 | .map(|f| { |
437 | 0 | self.columns.get(*f).cloned().ok_or_else(|| { |
438 | 0 | ArrowError::SchemaError(format!( |
439 | 0 | "project index {} out of bounds, max field {}", |
440 | 0 | f, |
441 | 0 | self.columns.len() |
442 | 0 | )) |
443 | 0 | }) |
444 | 0 | }) |
445 | 0 | .collect::<Result<Vec<_>, _>>()?; |
446 | | |
447 | 0 | RecordBatch::try_new_with_options( |
448 | 0 | SchemaRef::new(projected_schema), |
449 | 0 | batch_fields, |
450 | 0 | &RecordBatchOptions { |
451 | 0 | match_field_names: true, |
452 | 0 | row_count: Some(self.row_count), |
453 | 0 | }, |
454 | | ) |
455 | 0 | } |
456 | | |
    /// Normalize a semi-structured [`RecordBatch`] into a flat table.
    ///
    /// Nested [`Field`]s will generate names separated by `separator`, up to a depth of `max_level`
    /// (unlimited if `None`).
    ///
    /// e.g. given a [`RecordBatch`] with schema:
    ///
    /// ```text
    ///     "foo": StructArray<"bar": Utf8>
    /// ```
    ///
    /// A separator of `"."` would generate a batch with the schema:
    ///
    /// ```text
    ///     "foo.bar": Utf8
    /// ```
    ///
    /// Note that giving a depth of `Some(0)` to `max_level` is the same as passing in `None`;
    /// it will be treated as unlimited.
    ///
    /// # Example
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_array::{ArrayRef, Int64Array, StringArray, StructArray, RecordBatch};
    /// # use arrow_schema::{DataType, Field, Fields, Schema};
    /// #
    /// let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
    /// let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
    ///
    /// let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
    /// let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
    ///
    /// let a = Arc::new(StructArray::from(vec![
    ///     (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
    ///     (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
    /// ]));
    ///
    /// let schema = Schema::new(vec![
    ///     Field::new(
    ///         "a",
    ///         DataType::Struct(Fields::from(vec![animals_field, n_legs_field])),
    ///         false,
    ///     )
    /// ]);
    ///
    /// let normalized = RecordBatch::try_new(Arc::new(schema), vec![a])
    ///     .expect("valid conversion")
    ///     .normalize(".", None)
    ///     .expect("valid normalization");
    ///
    /// let expected = RecordBatch::try_from_iter_with_nullable(vec![
    ///     ("a.animals", animals.clone(), true),
    ///     ("a.n_legs", n_legs.clone(), true),
    /// ])
    /// .expect("valid conversion");
    ///
    /// assert_eq!(expected, normalized);
    /// ```
    pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
        // `Some(0)` and `None` both mean "unlimited depth".
        let max_level = match max_level.unwrap_or(usize::MAX) {
            0 => usize::MAX,
            val => val,
        };
        // Depth-first traversal stack. Entries are pushed in reverse so that
        // popping visits columns in their original left-to-right order.
        // Each entry carries: (nesting depth, column, name parts, field).
        let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
            .columns
            .iter()
            .zip(self.schema.fields())
            .rev()
            .map(|(c, f)| {
                let name_vec: Vec<&str> = vec![f.name()];
                (0, c, name_vec, f)
            })
            .collect();
        let mut columns: Vec<ArrayRef> = Vec::new();
        let mut fields: Vec<FieldRef> = Vec::new();

        while let Some((depth, c, name, field_ref)) = stack.pop() {
            match field_ref.data_type() {
                // Struct columns within the depth limit are expanded into
                // their children; deeper structs are kept as-is.
                DataType::Struct(ff) if depth < max_level => {
                    // Need to zip these in reverse to maintain original order
                    for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
                        let mut name = name.clone();
                        name.push(separator);
                        name.push(fff.name());
                        stack.push((depth + 1, cff, name, fff))
                    }
                }
                // Leaf (or depth-limited) column: emit it with the flattened
                // name, preserving its data type and nullability.
                _ => {
                    let updated_field = Field::new(
                        name.concat(),
                        field_ref.data_type().clone(),
                        field_ref.is_nullable(),
                    );
                    columns.push(c.clone());
                    fields.push(Arc::new(updated_field));
                }
            }
        }
        RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
    }
558 | | |
    /// Returns the number of columns in the record batch.
    ///
    /// # Example
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_array::{Int32Array, RecordBatch};
    /// # use arrow_schema::{DataType, Field, Schema};
    ///
    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
    /// let schema = Schema::new(vec![
    ///     Field::new("id", DataType::Int32, false)
    /// ]);
    ///
    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
    ///
    /// assert_eq!(batch.num_columns(), 1);
    /// ```
    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }
580 | | |
    /// Returns the number of rows in each column.
    ///
    /// # Example
    ///
    /// ```
    /// # use std::sync::Arc;
    /// # use arrow_array::{Int32Array, RecordBatch};
    /// # use arrow_schema::{DataType, Field, Schema};
    ///
    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
    /// let schema = Schema::new(vec![
    ///     Field::new("id", DataType::Int32, false)
    /// ]);
    ///
    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
    ///
    /// assert_eq!(batch.num_rows(), 5);
    /// ```
    pub fn num_rows(&self) -> usize {
        // Stored explicitly rather than derived from the columns, so that a
        // batch with zero columns can still report a row count.
        self.row_count
    }
602 | | |
    /// Get a reference to a column's array by index.
    ///
    /// # Panics
    ///
    /// Panics if `index` is outside of `0..num_columns`.
    pub fn column(&self, index: usize) -> &ArrayRef {
        &self.columns[index]
    }
611 | | |
612 | | /// Get a reference to a column's array by name. |
613 | 0 | pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> { |
614 | 0 | self.schema() |
615 | 0 | .column_with_name(name) |
616 | 0 | .map(|(index, _)| &self.columns[index]) |
617 | 0 | } |
618 | | |
619 | | /// Get a reference to all columns in the record batch. |
620 | 8 | pub fn columns(&self) -> &[ArrayRef] { |
621 | 8 | &self.columns[..] |
622 | 8 | } |
623 | | |
    /// Remove column by index and return it.
    ///
    /// Return the `ArrayRef` if the column is removed.
    ///
    /// # Panics
    ///
    /// Panics if `index`` out of bounds.
    ///
    /// # Example
    ///
    /// ```
    /// use std::sync::Arc;
    /// use arrow_array::{BooleanArray, Int32Array, RecordBatch};
    /// use arrow_schema::{DataType, Field, Schema};
    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
    /// let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
    /// let schema = Schema::new(vec![
    ///     Field::new("id", DataType::Int32, false),
    ///     Field::new("bool", DataType::Boolean, false),
    /// ]);
    ///
    /// let mut batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array), Arc::new(bool_array)]).unwrap();
    ///
    /// let removed_column = batch.remove_column(0);
    /// assert_eq!(removed_column.as_any().downcast_ref::<Int32Array>().unwrap(), &Int32Array::from(vec![1, 2, 3, 4, 5]));
    /// assert_eq!(batch.num_columns(), 1);
    /// ```
    pub fn remove_column(&mut self, index: usize) -> ArrayRef {
        // The schema must stay in sync with the columns: rebuild it without
        // the removed field before removing the array itself.
        let mut builder = SchemaBuilder::from(self.schema.as_ref());
        builder.remove(index);
        self.schema = Arc::new(builder.finish());
        self.columns.remove(index)
    }
657 | | |
658 | | /// Return a new RecordBatch where each column is sliced |
659 | | /// according to `offset` and `length` |
660 | | /// |
661 | | /// # Panics |
662 | | /// |
663 | | /// Panics if `offset` with `length` is greater than column length. |
664 | 0 | pub fn slice(&self, offset: usize, length: usize) -> RecordBatch { |
665 | 0 | assert!((offset + length) <= self.num_rows()); |
666 | | |
667 | 0 | let columns = self |
668 | 0 | .columns() |
669 | 0 | .iter() |
670 | 0 | .map(|column| column.slice(offset, length)) |
671 | 0 | .collect(); |
672 | | |
673 | 0 | Self { |
674 | 0 | schema: self.schema.clone(), |
675 | 0 | columns, |
676 | 0 | row_count: length, |
677 | 0 | } |
678 | 0 | } |
679 | | |
680 | | /// Create a `RecordBatch` from an iterable list of pairs of the |
681 | | /// form `(field_name, array)`, with the same requirements on |
682 | | /// fields and arrays as [`RecordBatch::try_new`]. This method is |
683 | | /// often used to create a single `RecordBatch` from arrays, |
684 | | /// e.g. for testing. |
685 | | /// |
686 | | /// The resulting schema is marked as nullable for each column if |
687 | | /// the array for that column is has any nulls. To explicitly |
688 | | /// specify nullibility, use [`RecordBatch::try_from_iter_with_nullable`] |
689 | | /// |
690 | | /// Example: |
691 | | /// ``` |
692 | | /// # use std::sync::Arc; |
693 | | /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray}; |
694 | | /// |
695 | | /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); |
696 | | /// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"])); |
697 | | /// |
698 | | /// let record_batch = RecordBatch::try_from_iter(vec![ |
699 | | /// ("a", a), |
700 | | /// ("b", b), |
701 | | /// ]); |
702 | | /// ``` |
703 | | /// Another way to quickly create a [`RecordBatch`] is to use the [`record_batch!`] macro, |
704 | | /// which is particularly helpful for rapid prototyping and testing. |
705 | | /// |
706 | | /// Example: |
707 | | /// |
708 | | /// ```rust |
709 | | /// use arrow_array::record_batch; |
710 | | /// let batch = record_batch!( |
711 | | /// ("a", Int32, [1, 2, 3]), |
712 | | /// ("b", Float64, [Some(4.0), None, Some(5.0)]), |
713 | | /// ("c", Utf8, ["alpha", "beta", "gamma"]) |
714 | | /// ); |
715 | | /// ``` |
716 | 1 | pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError> |
717 | 1 | where |
718 | 1 | I: IntoIterator<Item = (F, ArrayRef)>, |
719 | 1 | F: AsRef<str>, |
720 | | { |
721 | | // TODO: implement `TryFrom` trait, once |
722 | | // https://github.com/rust-lang/rust/issues/50133 is no longer an |
723 | | // issue |
724 | 1 | let iter = value.into_iter().map(|(field_name, array)| { |
725 | 1 | let nullable = array.null_count() > 0; |
726 | 1 | (field_name, array, nullable) |
727 | 1 | }); |
728 | | |
729 | 1 | Self::try_from_iter_with_nullable(iter) |
730 | 1 | } |
731 | | |
732 | | /// Create a `RecordBatch` from an iterable list of tuples of the |
733 | | /// form `(field_name, array, nullable)`, with the same requirements on |
734 | | /// fields and arrays as [`RecordBatch::try_new`]. This method is often |
735 | | /// used to create a single `RecordBatch` from arrays, e.g. for |
736 | | /// testing. |
737 | | /// |
738 | | /// Example: |
739 | | /// ``` |
740 | | /// # use std::sync::Arc; |
741 | | /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray}; |
742 | | /// |
743 | | /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); |
744 | | /// let b: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b")])); |
745 | | /// |
746 | | /// // Note neither `a` nor `b` has any actual nulls, but we mark |
747 | | /// // b an nullable |
748 | | /// let record_batch = RecordBatch::try_from_iter_with_nullable(vec![ |
749 | | /// ("a", a, false), |
750 | | /// ("b", b, true), |
751 | | /// ]); |
752 | | /// ``` |
753 | 29 | pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError> |
754 | 29 | where |
755 | 29 | I: IntoIterator<Item = (F, ArrayRef, bool)>, |
756 | 29 | F: AsRef<str>, |
757 | | { |
758 | 29 | let iter = value.into_iter(); |
759 | 29 | let capacity = iter.size_hint().0; |
760 | 29 | let mut schema = SchemaBuilder::with_capacity(capacity); |
761 | 29 | let mut columns = Vec::with_capacity(capacity); |
762 | | |
763 | 250 | for (field_name221 , array221 , nullable221 ) in iter { |
764 | 221 | let field_name = field_name.as_ref(); |
765 | 221 | schema.push(Field::new(field_name, array.data_type().clone(), nullable)); |
766 | 221 | columns.push(array); |
767 | 221 | } |
768 | | |
769 | 29 | let schema = Arc::new(schema.finish()); |
770 | 29 | RecordBatch::try_new(schema, columns) |
771 | 29 | } |
772 | | |
773 | | /// Returns the total number of bytes of memory occupied physically by this batch. |
774 | | /// |
775 | | /// Note that this does not always correspond to the exact memory usage of a |
776 | | /// `RecordBatch` (might overestimate), since multiple columns can share the same |
777 | | /// buffers or slices thereof, the memory used by the shared buffers might be |
778 | | /// counted multiple times. |
779 | 0 | pub fn get_array_memory_size(&self) -> usize { |
780 | 0 | self.columns() |
781 | 0 | .iter() |
782 | 0 | .map(|array| array.get_array_memory_size()) |
783 | 0 | .sum() |
784 | 0 | } |
785 | | } |
786 | | |
/// Options that control the behaviour used when creating a [`RecordBatch`].
///
/// Marked `#[non_exhaustive]`: construct via [`RecordBatchOptions::new`] (or
/// `Default`) and customize with the `with_*` builder methods.
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// Match field names of structs and lists. If set to `true`, the names must match.
    pub match_field_names: bool,

    /// Optional row count, useful for specifying a row count for a RecordBatch with no columns
    pub row_count: Option<usize>,
}
797 | | |
798 | | impl RecordBatchOptions { |
799 | | /// Creates a new `RecordBatchOptions` |
800 | 248 | pub fn new() -> Self { |
801 | 248 | Self { |
802 | 248 | match_field_names: true, |
803 | 248 | row_count: None, |
804 | 248 | } |
805 | 248 | } |
806 | | /// Sets the row_count of RecordBatchOptions and returns self |
807 | 0 | pub fn with_row_count(mut self, row_count: Option<usize>) -> Self { |
808 | 0 | self.row_count = row_count; |
809 | 0 | self |
810 | 0 | } |
811 | | /// Sets the match_field_names of RecordBatchOptions and returns self |
812 | 0 | pub fn with_match_field_names(mut self, match_field_names: bool) -> Self { |
813 | 0 | self.match_field_names = match_field_names; |
814 | 0 | self |
815 | 0 | } |
816 | | } |
817 | | impl Default for RecordBatchOptions { |
818 | 0 | fn default() -> Self { |
819 | 0 | Self::new() |
820 | 0 | } |
821 | | } |
822 | | impl From<StructArray> for RecordBatch { |
823 | 0 | fn from(value: StructArray) -> Self { |
824 | 0 | let row_count = value.len(); |
825 | 0 | let (fields, columns, nulls) = value.into_parts(); |
826 | 0 | assert_eq!( |
827 | 0 | nulls.map(|n| n.null_count()).unwrap_or_default(), |
828 | | 0, |
829 | 0 | "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation" |
830 | | ); |
831 | | |
832 | 0 | RecordBatch { |
833 | 0 | schema: Arc::new(Schema::new(fields)), |
834 | 0 | row_count, |
835 | 0 | columns, |
836 | 0 | } |
837 | 0 | } |
838 | | } |
839 | | |
840 | | impl From<&StructArray> for RecordBatch { |
841 | 0 | fn from(struct_array: &StructArray) -> Self { |
842 | 0 | struct_array.clone().into() |
843 | 0 | } |
844 | | } |
845 | | |
impl Index<&str> for RecordBatch {
    type Output = ArrayRef;

    /// Get a reference to a column's array by name.
    ///
    /// Enables `batch["column_name"]` syntax; use [`RecordBatch::column_by_name`]
    /// directly for a non-panicking `Option` lookup.
    ///
    /// # Panics
    ///
    /// Panics if the name is not in the schema.
    fn index(&self, name: &str) -> &Self::Output {
        self.column_by_name(name).unwrap()
    }
}
858 | | |
/// Generic implementation of [RecordBatchReader] that wraps an iterator.
///
/// # Example
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, RecordBatchIterator, RecordBatchReader};
/// #
/// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
/// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
///
/// let record_batch = RecordBatch::try_from_iter(vec![
///   ("a", a),
///   ("b", b),
/// ]).unwrap();
///
/// let batches: Vec<RecordBatch> = vec![record_batch.clone(), record_batch.clone()];
///
/// let mut reader = RecordBatchIterator::new(batches.into_iter().map(Ok), record_batch.schema());
///
/// assert_eq!(reader.schema(), record_batch.schema());
/// assert_eq!(reader.next().unwrap().unwrap(), record_batch);
/// # assert_eq!(reader.next().unwrap().unwrap(), record_batch);
/// # assert!(reader.next().is_none());
/// ```
pub struct RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    // The wrapped iterator of batches; yielded unchanged by `Iterator::next`.
    inner: I::IntoIter,
    // Schema reported by `RecordBatchReader::schema`. NOTE(review): not
    // validated against yielded batches — the caller must supply a matching
    // schema to uphold the `RecordBatchReader` contract.
    inner_schema: SchemaRef,
}
891 | | |
892 | | impl<I> RecordBatchIterator<I> |
893 | | where |
894 | | I: IntoIterator<Item = Result<RecordBatch, ArrowError>>, |
895 | | { |
896 | | /// Create a new [RecordBatchIterator]. |
897 | | /// |
898 | | /// If `iter` is an infallible iterator, use `.map(Ok)`. |
899 | | pub fn new(iter: I, schema: SchemaRef) -> Self { |
900 | | Self { |
901 | | inner: iter.into_iter(), |
902 | | inner_schema: schema, |
903 | | } |
904 | | } |
905 | | } |
906 | | |
impl<I> Iterator for RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    type Item = I::Item;

    // Pure delegation: batches are yielded from the wrapped iterator
    // unchanged (no schema validation is performed here).
    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next()
    }

    // Forward the inner hint so adapters like `collect` can preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
921 | | |
922 | | impl<I> RecordBatchReader for RecordBatchIterator<I> |
923 | | where |
924 | | I: IntoIterator<Item = Result<RecordBatch, ArrowError>>, |
925 | | { |
926 | | fn schema(&self) -> SchemaRef { |
927 | | self.inner_schema.clone() |
928 | | } |
929 | | } |
930 | | |
931 | | #[cfg(test)] |
932 | | mod tests { |
933 | | use super::*; |
934 | | use crate::{ |
935 | | BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray, |
936 | | }; |
937 | | use arrow_buffer::{Buffer, ToByteSlice}; |
938 | | use arrow_data::{ArrayData, ArrayDataBuilder}; |
939 | | use arrow_schema::Fields; |
940 | | use std::collections::HashMap; |
941 | | |
    // Happy path: two equal-length columns matching the schema build a batch.
    #[test]
    fn create_record_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
        check_batch(record_batch, 5)
    }
956 | | |
    // Same as `create_record_batch` but with a Utf8View (StringViewArray) column.
    #[test]
    fn create_string_view_record_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8View, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();

        assert_eq!(5, record_batch.num_rows());
        assert_eq!(2, record_batch.num_columns());
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
        assert_eq!(
            &DataType::Utf8View,
            record_batch.schema().field(1).data_type()
        );
        assert_eq!(5, record_batch.column(0).len());
        assert_eq!(5, record_batch.column(1).len());
    }
980 | | |
    // Pins the reported memory size of a known batch so accidental changes to
    // buffer accounting are caught. The expected value is implementation-defined;
    // update it deliberately if allocation strategy changes.
    #[test]
    fn byte_size_should_not_regress() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
        assert_eq!(record_batch.get_array_memory_size(), 364);
    }
995 | | |
    // Shared helper: asserts the batch has the canonical two-column
    // (Int32 "a", Utf8 "b") shape with `num_rows` rows in each column.
    fn check_batch(record_batch: RecordBatch, num_rows: usize) {
        assert_eq!(num_rows, record_batch.num_rows());
        assert_eq!(2, record_batch.num_columns());
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
        assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
        assert_eq!(num_rows, record_batch.column(0).len());
        assert_eq!(num_rows, record_batch.column(1).len());
    }
1004 | | |
    // Valid slices (including zero-length) preserve schema and row counts;
    // the final out-of-bounds slice (offset + length > num_rows) must panic.
    #[test]
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
    fn create_record_batch_slice() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);
        let expected_schema = schema.clone();

        let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);

        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();

        let offset = 2;
        let length = 5;
        let record_batch_slice = record_batch.slice(offset, length);

        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
        check_batch(record_batch_slice, 5);

        // Zero-length slices are valid
        let offset = 2;
        let length = 0;
        let record_batch_slice = record_batch.slice(offset, length);

        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
        check_batch(record_batch_slice, 0);

        // 2 + 10 > 8 rows: expected to panic
        let offset = 2;
        let length = 10;
        let _record_batch_slice = record_batch.slice(offset, length);
    }
1038 | | |
    // Slicing an empty zero-column batch: (0, 0) is valid; any non-zero
    // offset/length must panic just like on a populated batch.
    #[test]
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
    fn create_record_batch_slice_empty_batch() {
        let schema = Schema::empty();

        let record_batch = RecordBatch::new_empty(Arc::new(schema));

        let offset = 0;
        let length = 0;
        let record_batch_slice = record_batch.slice(offset, length);
        assert_eq!(0, record_batch_slice.schema().fields().len());

        // 1 + 2 > 0 rows: expected to panic
        let offset = 1;
        let length = 2;
        let _record_batch_slice = record_batch.slice(offset, length);
    }
1055 | | |
    // `try_from_iter` derives nullability from the data: column "a" contains a
    // null so its field becomes nullable; "b" has no nulls and stays required.
    #[test]
    fn create_record_batch_try_from_iter() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");

        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, false),
        ]);
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
        check_batch(record_batch, 5);
    }
1077 | | |
    // `try_from_iter_with_nullable` takes nullability from the caller, not the
    // data: "b" has no nulls yet its field is marked nullable as requested.
    #[test]
    fn create_record_batch_try_from_iter_with_nullable() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

        // Note there are no nulls in a or b, but we specify that b is nullable
        let record_batch =
            RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
                .expect("valid conversion");

        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
        check_batch(record_batch, 5);
    }
1095 | | |
    // A column whose data type differs from the schema (Int64 vs Int32) must
    // be rejected with a descriptive error naming the column index.
    #[test]
    fn create_record_batch_schema_mismatch() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int64Array::from(vec![1, 2, 3, 4, 5]);

        let err = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap_err();
        assert_eq!(err.to_string(), "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0");
    }
1105 | | |
    // Nested field names differ from the schema ("aa1" vs "a1", list item
    // "array" vs default): rejected by default, accepted when
    // `match_field_names` is disabled via `RecordBatchOptions`.
    #[test]
    fn create_record_batch_field_name_mismatch() {
        let fields = vec![
            Field::new("a1", DataType::Int32, false),
            Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
        ];
        let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));

        let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
        // List item deliberately named "array" (not the schema's list-field name)
        let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
            "array",
            DataType::Int8,
            false,
        ))))
        .add_child_data(a2_child.into_data())
        .len(2)
        .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
        .build()
        .unwrap();
        let a2: ArrayRef = Arc::new(ListArray::from(a2));
        // Struct child deliberately named "aa1" (schema says "a1")
        let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
            Field::new("aa1", DataType::Int32, false),
            Field::new("a2", a2.data_type().clone(), false),
        ])))
        .add_child_data(a1.into_data())
        .add_child_data(a2.into_data())
        .len(2)
        .build()
        .unwrap();
        let a: ArrayRef = Arc::new(StructArray::from(a));

        // creating the batch with field name validation should fail
        let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
        assert!(batch.is_err());

        // creating the batch without field name validation should pass
        let options = RecordBatchOptions {
            match_field_names: false,
            row_count: None,
        };
        let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
        assert!(batch.is_ok());
    }
1150 | | |
    // Supplying more columns (2) than schema fields (1) must fail.
    #[test]
    fn create_record_batch_record_mismatch() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);

        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
        assert!(batch.is_err());
    }
1161 | | |
    // From<&StructArray>: each struct field becomes a column and the batch
    // schema mirrors the struct's fields.
    #[test]
    fn create_record_batch_from_struct_array() {
        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                boolean.clone() as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                int.clone() as ArrayRef,
            ),
        ]);

        let batch = RecordBatch::from(&struct_array);
        assert_eq!(2, batch.num_columns());
        assert_eq!(4, batch.num_rows());
        assert_eq!(
            struct_array.data_type(),
            &DataType::Struct(batch.schema().fields().clone())
        );
        assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
        assert_eq!(batch.column(1).as_ref(), int.as_ref());
    }
1187 | | |
    // Two independently built batches with identical schemas and values
    // compare equal.
    #[test]
    fn record_batch_equality() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_eq!(batch1, batch2);
    }
1218 | | |
    /// validates if the record batch can be accessed using `column_name` as index i.e. `record_batch["column_name"]`
    #[test]
    fn record_batch_index_access() {
        let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();

        // Indexing by name returns the same underlying arrays
        assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
        assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
    }
1234 | | |
    // Same schemas, different values in "val": batches compare unequal.
    #[test]
    fn record_batch_vals_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }
1265 | | |
    // Identical data, differing field name ("val" vs "num"): unequal.
    #[test]
    fn record_batch_column_names_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }
1296 | | |
    // Batches with a different number of columns (2 vs 3) compare unequal.
    #[test]
    fn record_batch_column_number_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);

        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);

        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();

        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
        )
        .unwrap();

        assert_ne!(batch1, batch2);
    }
1329 | | |
1330 | | #[test] |
1331 | | fn record_batch_row_count_ne() { |
1332 | | let id_arr1 = Int32Array::from(vec![1, 2, 3]); |
1333 | | let val_arr1 = Int32Array::from(vec![5, 6, 7]); |
1334 | | let schema1 = Schema::new(vec![ |
1335 | | Field::new("id", DataType::Int32, false), |
1336 | | Field::new("val", DataType::Int32, false), |
1337 | | ]); |
1338 | | |
1339 | | let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]); |
1340 | | let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]); |
1341 | | let schema2 = Schema::new(vec![ |
1342 | | Field::new("id", DataType::Int32, false), |
1343 | | Field::new("num", DataType::Int32, false), |
1344 | | ]); |
1345 | | |
1346 | | let batch1 = RecordBatch::try_new( |
1347 | | Arc::new(schema1), |
1348 | | vec![Arc::new(id_arr1), Arc::new(val_arr1)], |
1349 | | ) |
1350 | | .unwrap(); |
1351 | | |
1352 | | let batch2 = RecordBatch::try_new( |
1353 | | Arc::new(schema2), |
1354 | | vec![Arc::new(id_arr2), Arc::new(val_arr2)], |
1355 | | ) |
1356 | | .unwrap(); |
1357 | | |
1358 | | assert_ne!(batch1, batch2); |
1359 | | } |
1360 | | |
    // `normalize` flattens one struct level into dotted column names
    // ("a.animals", ...); also checks `Some(0)` and `None` depth behave the same.
    #[test]
    fn normalize_simple() {
        let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
        let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
        let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));

        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
        let year_field = Arc::new(Field::new("year", DataType::Int64, true));

        let a = Arc::new(StructArray::from(vec![
            (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
            (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
            (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
        ]));

        let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));

        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
                false,
            ),
            Field::new("month", DataType::Int64, true),
        ]);

        let normalized =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
                .expect("valid conversion")
                .normalize(".", Some(0))
                .expect("valid normalization");

        // Struct children are hoisted to top level with "." separated names;
        // the non-struct column "month" is untouched.
        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            ("a.animals", animals.clone(), true),
            ("a.n_legs", n_legs.clone(), true),
            ("a.year", year.clone(), true),
            ("month", month.clone(), true),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);

        // check 0 and None have the same effect
        let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
            .expect("valid conversion")
            .normalize(".", None)
            .expect("valid normalization");

        assert_eq!(expected, normalized);
    }
1412 | | |
    // Two-level struct nesting ("!" -> "1"/"2" -> a/b/c): depth Some(1)
    // flattens only the top level, None flattens fully to "!.1.a" etc.
    #[test]
    fn normalize_nested() {
        // Initialize schema
        let a = Arc::new(Field::new("a", DataType::Int64, true));
        let b = Arc::new(Field::new("b", DataType::Int64, false));
        let c = Arc::new(Field::new("c", DataType::Int64, true));

        let one = Arc::new(Field::new(
            "1",
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
            false,
        ));
        let two = Arc::new(Field::new(
            "2",
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
            true,
        ));

        let exclamation = Arc::new(Field::new(
            "!",
            DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
            false,
        ));

        let schema = Schema::new(vec![exclamation.clone()]);

        // Initialize fields
        let a_field = Int64Array::from(vec![Some(0), Some(1)]);
        let b_field = Int64Array::from(vec![Some(2), Some(3)]);
        let c_field = Int64Array::from(vec![None, Some(4)]);

        let one_field = StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ]);
        let two_field = StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ]);

        let exclamation_field = Arc::new(StructArray::from(vec![
            (one.clone(), Arc::new(one_field) as ArrayRef),
            (two.clone(), Arc::new(two_field) as ArrayRef),
        ]));

        // Normalize top level only: "!.1" and "!.2" remain structs
        let normalized =
            RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
                .expect("valid conversion")
                .normalize(".", Some(1))
                .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            (
                "!.1",
                Arc::new(StructArray::from(vec![
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
                ])) as ArrayRef,
                false,
            ),
            (
                "!.2",
                Arc::new(StructArray::from(vec![
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
                ])) as ArrayRef,
                true,
            ),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);

        // Normalize all levels: leaf columns "!.1.a" .. "!.2.c", nullability
        // taken from the leaf fields
        let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
            .expect("valid conversion")
            .normalize(".", None)
            .expect("valid normalization");

        let expected = RecordBatch::try_from_iter_with_nullable(vec![
            ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
            ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
            ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
            ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
            ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
            ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
        ])
        .expect("valid conversion");

        assert_eq!(expected, normalized);
    }
1509 | | |
    // Normalizing an empty batch matches a batch built from the schema's own
    // `normalize` result — batch and schema normalization must agree.
    #[test]
    fn normalize_empty() {
        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
        let year_field = Arc::new(Field::new("year", DataType::Int64, true));

        let schema = Schema::new(vec![
            Field::new(
                "a",
                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
                false,
            ),
            Field::new("month", DataType::Int64, true),
        ]);

        let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
            .normalize(".", Some(0))
            .expect("valid normalization");

        let expected = RecordBatch::new_empty(Arc::new(
            schema.normalize(".", Some(0)).expect("valid normalization"),
        ));

        assert_eq!(expected, normalized);
    }
1535 | | |
    // `project(&[0, 2])` keeps columns "a" and "c" in order, dropping "b".
    #[test]
    fn project() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
                .expect("valid conversion");

        let expected =
            RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion");

        assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
    }
1551 | | |
    // Projecting zero columns yields an empty-schema batch that still
    // remembers the original row count (3).
    #[test]
    fn project_empty() {
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

        let record_batch =
            RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");

        let expected = RecordBatch::try_new_with_options(
            Arc::new(Schema::empty()),
            vec![],
            &RecordBatchOptions {
                match_field_names: true,
                row_count: Some(3),
            },
        )
        .expect("valid conversion");

        assert_eq!(expected, record_batch.project(&[]).unwrap());
    }
1571 | | |
    // A zero-column batch needs an explicit row count; once provided it can be
    // sliced like any other batch, and a zero-length slice equals an empty batch.
    #[test]
    fn test_no_column_record_batch() {
        let schema = Arc::new(Schema::empty());

        // Without a row count or any column, construction is ambiguous -> error
        let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err();
        assert!(err
            .to_string()
            .contains("must either specify a row count or at least one column"));

        let options = RecordBatchOptions::new().with_row_count(Some(10));

        let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
        assert_eq!(ok.num_rows(), 10);

        let a = ok.slice(2, 5);
        assert_eq!(a.num_rows(), 5);

        let b = ok.slice(5, 0);
        assert_eq!(b.num_rows(), 0);

        assert_ne!(a, b);
        assert_eq!(b, RecordBatch::new_empty(schema))
    }
1595 | | |
    // A column containing nulls must be rejected when its field is declared
    // non-nullable, with a descriptive error naming the column.
    #[test]
    fn test_nulls_in_non_nullable_field() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let maybe_batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
        );
        assert_eq!("Invalid argument error: Column 'a' is declared as non-nullable but contains null values", format!("{}", maybe_batch.err().unwrap()));
    }
1605 | | #[test] |
1606 | | fn test_record_batch_options() { |
1607 | | let options = RecordBatchOptions::new() |
1608 | | .with_match_field_names(false) |
1609 | | .with_row_count(Some(20)); |
1610 | | assert!(!options.match_field_names); |
1611 | | assert_eq!(options.row_count.unwrap(), 20) |
1612 | | } |
1613 | | |
1614 | | #[test] |
1615 | | #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")] |
1616 | | fn test_from_struct() { |
1617 | | let s = StructArray::from(ArrayData::new_null( |
1618 | | // Note child is not nullable |
1619 | | &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()), |
1620 | | 2, |
1621 | | )); |
1622 | | let _ = RecordBatch::from(s); |
1623 | | } |
1624 | | |
1625 | | #[test] |
1626 | | fn test_with_schema() { |
1627 | | let required_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); |
1628 | | let required_schema = Arc::new(required_schema); |
1629 | | let nullable_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); |
1630 | | let nullable_schema = Arc::new(nullable_schema); |
1631 | | |
1632 | | let batch = RecordBatch::try_new( |
1633 | | required_schema.clone(), |
1634 | | vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _], |
1635 | | ) |
1636 | | .unwrap(); |
1637 | | |
1638 | | // Can add nullability |
1639 | | let batch = batch.with_schema(nullable_schema.clone()).unwrap(); |
1640 | | |
1641 | | // Cannot remove nullability |
1642 | | batch.clone().with_schema(required_schema).unwrap_err(); |
1643 | | |
1644 | | // Can add metadata |
1645 | | let metadata = vec![("foo".to_string(), "bar".to_string())] |
1646 | | .into_iter() |
1647 | | .collect(); |
1648 | | let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata); |
1649 | | let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap(); |
1650 | | |
1651 | | // Cannot remove metadata |
1652 | | batch.with_schema(nullable_schema).unwrap_err(); |
1653 | | } |
1654 | | |
1655 | | #[test] |
1656 | | fn test_boxed_reader() { |
1657 | | // Make sure we can pass a boxed reader to a function generic over |
1658 | | // RecordBatchReader. |
1659 | | let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); |
1660 | | let schema = Arc::new(schema); |
1661 | | |
1662 | | let reader = RecordBatchIterator::new(std::iter::empty(), schema); |
1663 | | let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader); |
1664 | | |
1665 | | fn get_size(reader: impl RecordBatchReader) -> usize { |
1666 | | reader.size_hint().0 |
1667 | | } |
1668 | | |
1669 | | let size = get_size(reader); |
1670 | | assert_eq!(size, 0); |
1671 | | } |
1672 | | |
1673 | | #[test] |
1674 | | fn test_remove_column_maintains_schema_metadata() { |
1675 | | let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); |
1676 | | let bool_array = BooleanArray::from(vec![true, false, false, true, true]); |
1677 | | |
1678 | | let mut metadata = HashMap::new(); |
1679 | | metadata.insert("foo".to_string(), "bar".to_string()); |
1680 | | let schema = Schema::new(vec![ |
1681 | | Field::new("id", DataType::Int32, false), |
1682 | | Field::new("bool", DataType::Boolean, false), |
1683 | | ]) |
1684 | | .with_metadata(metadata); |
1685 | | |
1686 | | let mut batch = RecordBatch::try_new( |
1687 | | Arc::new(schema), |
1688 | | vec![Arc::new(id_array), Arc::new(bool_array)], |
1689 | | ) |
1690 | | .unwrap(); |
1691 | | |
1692 | | let _removed_column = batch.remove_column(0); |
1693 | | assert_eq!(batch.schema().metadata().len(), 1); |
1694 | | assert_eq!( |
1695 | | batch.schema().metadata().get("foo").unwrap().as_str(), |
1696 | | "bar" |
1697 | | ); |
1698 | | } |
1699 | | } |