/Users/andrewlamb/Software/arrow-rs/arrow-array/src/record_batch.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! A two-dimensional batch of column-oriented data with a defined |
19 | | //! [schema](arrow_schema::Schema). |
20 | | |
21 | | use crate::cast::AsArray; |
22 | | use crate::{Array, ArrayRef, StructArray, new_empty_array}; |
23 | | use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef}; |
24 | | use std::ops::Index; |
25 | | use std::sync::Arc; |
26 | | |
27 | | /// Trait for types that can read `RecordBatch`'s. |
28 | | /// |
29 | | /// To create from an iterator, see [RecordBatchIterator]. |
30 | | pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> { |
31 | | /// Returns the schema of this `RecordBatchReader`. |
32 | | /// |
33 | | /// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this |
34 | | /// reader should have the same schema as returned from this method. |
35 | | fn schema(&self) -> SchemaRef; |
36 | | } |
37 | | |
38 | | impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> { |
39 | | fn schema(&self) -> SchemaRef { |
40 | | self.as_ref().schema() |
41 | | } |
42 | | } |
43 | | |
44 | | /// Trait for types that can write `RecordBatch`'s. |
45 | | pub trait RecordBatchWriter { |
46 | | /// Write a single batch to the writer. |
47 | | fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>; |
48 | | |
49 | | /// Write footer or termination data, then mark the writer as done. |
50 | | fn close(self) -> Result<(), ArrowError>; |
51 | | } |
52 | | |
/// Creates an array from a literal slice of values,
/// suitable for rapid testing and development.
///
/// The macro expands to an `Arc` wrapping the concrete array type selected by
/// the first argument.
///
/// Example:
///
/// ```rust
///
/// use arrow_array::create_array;
///
/// let array = create_array!(Int32, [1, 2, 3, 4, 5]);
/// let array = create_array!(Utf8, [Some("a"), Some("b"), None, Some("e")]);
/// ```
/// Support for limited data types is available. The macro will return a compile error if an unsupported data type is used.
/// Presently supported data types are:
/// - `Boolean`, `Null`
/// - `Decimal32`, `Decimal64`, `Decimal128`, `Decimal256`
/// - `Float16`, `Float32`, `Float64`
/// - `Int8`, `Int16`, `Int32`, `Int64`
/// - `UInt8`, `UInt16`, `UInt32`, `UInt64`
/// - `IntervalDayTime`, `IntervalYearMonth`
/// - `Second`, `Millisecond`, `Microsecond`, `Nanosecond`
/// - `Second32`, `Millisecond32`, `Microsecond64`, `Nanosecond64`
/// - `DurationSecond`, `DurationMillisecond`, `DurationMicrosecond`, `DurationNanosecond`
/// - `TimestampSecond`, `TimestampMillisecond`, `TimestampMicrosecond`, `TimestampNanosecond`
/// - `Utf8`, `Utf8View`, `LargeUtf8`, `Binary`, `LargeBinary`
///
/// Unlike every other type, `Null` takes the array length instead of a list
/// of values: `create_array!(Null, 3)`.
#[macro_export]
macro_rules! create_array {
    // `@from` maps a type name to the array type that is built through
    // `<type>::from(Vec<_>)`.
    (@from Boolean) => { $crate::BooleanArray };
    (@from Int8) => { $crate::Int8Array };
    (@from Int16) => { $crate::Int16Array };
    (@from Int32) => { $crate::Int32Array };
    (@from Int64) => { $crate::Int64Array };
    (@from UInt8) => { $crate::UInt8Array };
    (@from UInt16) => { $crate::UInt16Array };
    (@from UInt32) => { $crate::UInt32Array };
    (@from UInt64) => { $crate::UInt64Array };
    (@from Float16) => { $crate::Float16Array };
    (@from Float32) => { $crate::Float32Array };
    (@from Float64) => { $crate::Float64Array };
    (@from Utf8) => { $crate::StringArray };
    (@from Utf8View) => { $crate::StringViewArray };
    (@from LargeUtf8) => { $crate::LargeStringArray };
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
    (@from Second) => { $crate::TimestampSecondArray };
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
    (@from Second32) => { $crate::Time32SecondArray };
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
    // The array type is `Time64NanosecondArray`, following the
    // `Time64<Unit>Array` pattern of the arm above; the `64` suffix belongs
    // only to the macro keyword.
    (@from Nanosecond64) => { $crate::Time64NanosecondArray };
    (@from DurationSecond) => { $crate::DurationSecondArray };
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
    (@from Decimal32) => { $crate::Decimal32Array };
    (@from Decimal64) => { $crate::Decimal64Array };
    (@from Decimal128) => { $crate::Decimal128Array };
    (@from Decimal256) => { $crate::Decimal256Array };
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };

    // Any other identifier is rejected at compile time with a readable message.
    (@from $ty: ident) => {
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
    };

    // `Null` arrays carry no values, only a length.
    (Null, $size: expr) => {
        std::sync::Arc::new($crate::NullArray::new($size))
    };

    // Binary arrays are built with `from_vec` rather than `From<Vec<_>>`.
    (Binary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
    };

    (LargeBinary, [$($values: expr),*]) => {
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
    };

    // Generic entry point: resolve the array type through `@from`, then
    // convert the literal values with `From<Vec<_>>`.
    ($ty: tt, [$($values: expr),*]) => {
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
    };
}
139 | | |
/// Builds a record batch from literal slices of values, suitable for rapid
/// testing and development.
///
/// Each column is written as `(name, Type, [values...])`. Every generated
/// field is declared nullable, and the macro evaluates to the `Result`
/// returned by `RecordBatch::try_new`.
///
/// The expansion refers to `arrow_schema::...` paths, so `arrow_schema` must
/// be resolvable at the call site.
///
/// Example:
///
/// ```rust
/// use arrow_array::record_batch;
/// use arrow_schema;
///
/// let batch = record_batch!(
///     ("a", Int32, [1, 2, 3]),
///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, ["alpha", "beta", "gamma"])
/// );
/// ```
/// Due to limitation of [`create_array!`] macro, support for limited data types is available.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {
        $crate::RecordBatch::try_new(
            // One nullable field per column tuple.
            std::sync::Arc::new(arrow_schema::Schema::new(vec![
                $(
                    arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
                )*
            ])),
            // One array per column tuple, in the same order as the fields.
            vec![$(
                $crate::create_array!($type, [$($values),*]),
            )*]
        )
    }
}
177 | | |
178 | | /// A two-dimensional batch of column-oriented data with a defined |
179 | | /// [schema](arrow_schema::Schema). |
180 | | /// |
181 | | /// A `RecordBatch` is a two-dimensional dataset of a number of |
182 | | /// contiguous arrays, each the same length. |
183 | | /// A record batch has a schema which must match its arrays’ |
184 | | /// datatypes. |
185 | | /// |
186 | | /// Record batches are a convenient unit of work for various |
187 | | /// serialization and computation functions, possibly incremental. |
188 | | /// |
189 | | /// Use the [`record_batch!`] macro to create a [`RecordBatch`] from |
190 | | /// literal slice of values, useful for rapid prototyping and testing. |
191 | | /// |
192 | | /// Example: |
193 | | /// ```rust |
194 | | /// use arrow_array::record_batch; |
195 | | /// let batch = record_batch!( |
196 | | /// ("a", Int32, [1, 2, 3]), |
197 | | /// ("b", Float64, [Some(4.0), None, Some(5.0)]), |
198 | | /// ("c", Utf8, ["alpha", "beta", "gamma"]) |
199 | | /// ); |
200 | | /// ``` |
201 | | #[derive(Clone, Debug, PartialEq)] |
202 | | pub struct RecordBatch { |
203 | | schema: SchemaRef, |
204 | | columns: Vec<Arc<dyn Array>>, |
205 | | |
206 | | /// The number of rows in this RecordBatch |
207 | | /// |
208 | | /// This is stored separately from the columns to handle the case of no columns |
209 | | row_count: usize, |
210 | | } |
211 | | |
212 | | impl RecordBatch { |
213 | | /// Creates a `RecordBatch` from a schema and columns. |
214 | | /// |
215 | | /// Expects the following: |
216 | | /// |
217 | | /// * `!columns.is_empty()` |
218 | | /// * `schema.fields.len() == columns.len()` |
219 | | /// * `schema.fields[i].data_type() == columns[i].data_type()` |
220 | | /// * `columns[i].len() == columns[j].len()` |
221 | | /// |
222 | | /// If the conditions are not met, an error is returned. |
223 | | /// |
224 | | /// # Example |
225 | | /// |
226 | | /// ``` |
227 | | /// # use std::sync::Arc; |
228 | | /// # use arrow_array::{Int32Array, RecordBatch}; |
229 | | /// # use arrow_schema::{DataType, Field, Schema}; |
230 | | /// |
231 | | /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); |
232 | | /// let schema = Schema::new(vec![ |
233 | | /// Field::new("id", DataType::Int32, false) |
234 | | /// ]); |
235 | | /// |
236 | | /// let batch = RecordBatch::try_new( |
237 | | /// Arc::new(schema), |
238 | | /// vec![Arc::new(id_array)] |
239 | | /// ).unwrap(); |
240 | | /// ``` |
241 | 128 | pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> { |
242 | 128 | let options = RecordBatchOptions::new(); |
243 | 128 | Self::try_new_impl(schema, columns, &options) |
244 | 128 | } |
245 | | |
246 | | /// Creates a `RecordBatch` from a schema and columns, without validation. |
247 | | /// |
248 | | /// See [`Self::try_new`] for the checked version. |
249 | | /// |
250 | | /// # Safety |
251 | | /// |
252 | | /// Expects the following: |
253 | | /// |
254 | | /// * `schema.fields.len() == columns.len()` |
255 | | /// * `schema.fields[i].data_type() == columns[i].data_type()` |
256 | | /// * `columns[i].len() == row_count` |
257 | | /// |
258 | | /// Note: if the schema does not match the underlying data exactly, it can lead to undefined |
259 | | /// behavior, for example, via conversion to a `StructArray`, which in turn could lead |
260 | | /// to incorrect access. |
261 | 185 | pub unsafe fn new_unchecked( |
262 | 185 | schema: SchemaRef, |
263 | 185 | columns: Vec<Arc<dyn Array>>, |
264 | 185 | row_count: usize, |
265 | 185 | ) -> Self { |
266 | 185 | Self { |
267 | 185 | schema, |
268 | 185 | columns, |
269 | 185 | row_count, |
270 | 185 | } |
271 | 185 | } |
272 | | |
273 | | /// Creates a `RecordBatch` from a schema and columns, with additional options, |
274 | | /// such as whether to strictly validate field names. |
275 | | /// |
276 | | /// See [`RecordBatch::try_new`] for the expected conditions. |
277 | 150 | pub fn try_new_with_options( |
278 | 150 | schema: SchemaRef, |
279 | 150 | columns: Vec<ArrayRef>, |
280 | 150 | options: &RecordBatchOptions, |
281 | 150 | ) -> Result<Self, ArrowError> { |
282 | 150 | Self::try_new_impl(schema, columns, options) |
283 | 150 | } |
284 | | |
285 | | /// Creates a new empty [`RecordBatch`]. |
286 | 3 | pub fn new_empty(schema: SchemaRef) -> Self { |
287 | 3 | let columns = schema |
288 | 3 | .fields() |
289 | 3 | .iter() |
290 | 3 | .map(|field| new_empty_array(field.data_type())) |
291 | 3 | .collect(); |
292 | | |
293 | 3 | RecordBatch { |
294 | 3 | schema, |
295 | 3 | columns, |
296 | 3 | row_count: 0, |
297 | 3 | } |
298 | 3 | } |
299 | | |
300 | | /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error |
301 | | /// if any validation check fails, otherwise returns the created [`Self`] |
302 | 278 | fn try_new_impl( |
303 | 278 | schema: SchemaRef, |
304 | 278 | columns: Vec<ArrayRef>, |
305 | 278 | options: &RecordBatchOptions, |
306 | 278 | ) -> Result<Self, ArrowError> { |
307 | | // check that number of fields in schema match column length |
308 | 278 | if schema.fields().len() != columns.len() { |
309 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
310 | 0 | "number of columns({}) must match number of fields({}) in schema", |
311 | 0 | columns.len(), |
312 | 0 | schema.fields().len(), |
313 | 0 | ))); |
314 | 278 | } |
315 | | |
316 | 278 | let row_count = options |
317 | 278 | .row_count |
318 | 278 | .or_else(|| columns.first()128 .map128 (|col| col128 .len128 ())) |
319 | 278 | .ok_or_else(|| {0 |
320 | 0 | ArrowError::InvalidArgumentError( |
321 | 0 | "must either specify a row count or at least one column".to_string(), |
322 | 0 | ) |
323 | 0 | })?; |
324 | | |
325 | 572 | for (c, f) in columns.iter()278 .zip278 (&schema.fields278 ) { |
326 | 572 | if !f.is_nullable() && c.null_count() > 060 { |
327 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
328 | 0 | "Column '{}' is declared as non-nullable but contains null values", |
329 | 0 | f.name() |
330 | 0 | ))); |
331 | 572 | } |
332 | | } |
333 | | |
334 | | // check that all columns have the same row count |
335 | 572 | if columns.iter()278 .any278 (|c| c.len() != row_count) { |
336 | 0 | let err = match options.row_count { |
337 | 0 | Some(_) => "all columns in a record batch must have the specified row count", |
338 | 0 | None => "all columns in a record batch must have the same length", |
339 | | }; |
340 | 0 | return Err(ArrowError::InvalidArgumentError(err.to_string())); |
341 | 278 | } |
342 | | |
343 | | // function for comparing column type and field type |
344 | | // return true if 2 types are not matched |
345 | 278 | let type_not_match = if options.match_field_names { |
346 | 572 | |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type |
347 | | } else { |
348 | 0 | |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { |
349 | 0 | !col_type.equals_datatype(field_type) |
350 | 0 | } |
351 | | }; |
352 | | |
353 | | // check that all columns match the schema |
354 | 278 | let not_match = columns |
355 | 278 | .iter() |
356 | 278 | .zip(schema.fields().iter()) |
357 | 572 | .map278 (|(col, field)| (col.data_type(), field.data_type())) |
358 | 278 | .enumerate() |
359 | 278 | .find(type_not_match); |
360 | | |
361 | 278 | if let Some((i0 , (col_type0 , field_type0 ))) = not_match { |
362 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
363 | 0 | "column types must match schema types, expected {field_type} but found {col_type} at column index {i}" |
364 | 0 | ))); |
365 | 278 | } |
366 | | |
367 | 278 | Ok(RecordBatch { |
368 | 278 | schema, |
369 | 278 | columns, |
370 | 278 | row_count, |
371 | 278 | }) |
372 | 278 | } |
373 | | |
374 | | /// Return the schema, columns and row count of this [`RecordBatch`] |
375 | 372 | pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) { |
376 | 372 | (self.schema, self.columns, self.row_count) |
377 | 372 | } |
378 | | |
379 | | /// Override the schema of this [`RecordBatch`] |
380 | | /// |
381 | | /// Returns an error if `schema` is not a superset of the current schema |
382 | | /// as determined by [`Schema::contains`] |
383 | | /// |
384 | | /// See also [`Self::schema_metadata_mut`]. |
385 | 0 | pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> { |
386 | 0 | if !schema.contains(self.schema.as_ref()) { |
387 | 0 | return Err(ArrowError::SchemaError(format!( |
388 | 0 | "target schema is not superset of current schema target={schema} current={}", |
389 | 0 | self.schema |
390 | 0 | ))); |
391 | 0 | } |
392 | | |
393 | 0 | Ok(Self { |
394 | 0 | schema, |
395 | 0 | columns: self.columns, |
396 | 0 | row_count: self.row_count, |
397 | 0 | }) |
398 | 0 | } |
399 | | |
400 | | /// Returns the [`Schema`] of the record batch. |
401 | 145 | pub fn schema(&self) -> SchemaRef { |
402 | 145 | self.schema.clone() |
403 | 145 | } |
404 | | |
405 | | /// Returns a reference to the [`Schema`] of the record batch. |
406 | 1 | pub fn schema_ref(&self) -> &SchemaRef { |
407 | 1 | &self.schema |
408 | 1 | } |
409 | | |
410 | | /// Mutable access to the metadata of the schema. |
411 | | /// |
412 | | /// This allows you to modify [`Schema::metadata`] of [`Self::schema`] in a convenient and fast way. |
413 | | /// |
414 | | /// Note this will clone the entire underlying `Schema` object if it is currently shared |
415 | | /// |
416 | | /// # Example |
417 | | /// ``` |
418 | | /// # use std::sync::Arc; |
419 | | /// # use arrow_array::{record_batch, RecordBatch}; |
420 | | /// let mut batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap(); |
421 | | /// // Initially, the metadata is empty |
422 | | /// assert!(batch.schema().metadata().get("key").is_none()); |
423 | | /// // Insert a key-value pair into the metadata |
424 | | /// batch.schema_metadata_mut().insert("key".into(), "value".into()); |
425 | | /// assert_eq!(batch.schema().metadata().get("key"), Some(&String::from("value"))); |
426 | | /// ``` |
427 | 0 | pub fn schema_metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> { |
428 | 0 | let schema = Arc::make_mut(&mut self.schema); |
429 | 0 | &mut schema.metadata |
430 | 0 | } |
431 | | |
432 | | /// Projects the schema onto the specified columns |
433 | 0 | pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> { |
434 | 0 | let projected_schema = self.schema.project(indices)?; |
435 | 0 | let batch_fields = indices |
436 | 0 | .iter() |
437 | 0 | .map(|f| { |
438 | 0 | self.columns.get(*f).cloned().ok_or_else(|| { |
439 | 0 | ArrowError::SchemaError(format!( |
440 | 0 | "project index {} out of bounds, max field {}", |
441 | 0 | f, |
442 | 0 | self.columns.len() |
443 | 0 | )) |
444 | 0 | }) |
445 | 0 | }) |
446 | 0 | .collect::<Result<Vec<_>, _>>()?; |
447 | | |
448 | | unsafe { |
449 | | // Since we're starting from a valid RecordBatch and project |
450 | | // creates a strict subset of the original, there's no need to |
451 | | // redo the validation checks in `try_new_with_options`. |
452 | 0 | Ok(RecordBatch::new_unchecked( |
453 | 0 | SchemaRef::new(projected_schema), |
454 | 0 | batch_fields, |
455 | 0 | self.row_count, |
456 | 0 | )) |
457 | | } |
458 | 0 | } |
459 | | |
460 | | /// Normalize a semi-structured [`RecordBatch`] into a flat table. |
461 | | /// |
462 | | /// Nested [`Field`]s will generate names separated by `separator`, up to a depth of `max_level` |
463 | | /// (unlimited if `None`). |
464 | | /// |
465 | | /// e.g. given a [`RecordBatch`] with schema: |
466 | | /// |
467 | | /// ```text |
468 | | /// "foo": StructArray<"bar": Utf8> |
469 | | /// ``` |
470 | | /// |
471 | | /// A separator of `"."` would generate a batch with the schema: |
472 | | /// |
473 | | /// ```text |
474 | | /// "foo.bar": Utf8 |
475 | | /// ``` |
476 | | /// |
477 | | /// Note that giving a depth of `Some(0)` to `max_level` is the same as passing in `None`; |
478 | | /// it will be treated as unlimited. |
479 | | /// |
480 | | /// # Example |
481 | | /// |
482 | | /// ``` |
483 | | /// # use std::sync::Arc; |
484 | | /// # use arrow_array::{ArrayRef, Int64Array, StringArray, StructArray, RecordBatch}; |
485 | | /// # use arrow_schema::{DataType, Field, Fields, Schema}; |
486 | | /// # |
487 | | /// let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""])); |
488 | | /// let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)])); |
489 | | /// |
490 | | /// let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true)); |
491 | | /// let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true)); |
492 | | /// |
493 | | /// let a = Arc::new(StructArray::from(vec![ |
494 | | /// (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef), |
495 | | /// (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef), |
496 | | /// ])); |
497 | | /// |
498 | | /// let schema = Schema::new(vec![ |
499 | | /// Field::new( |
500 | | /// "a", |
501 | | /// DataType::Struct(Fields::from(vec![animals_field, n_legs_field])), |
502 | | /// false, |
503 | | /// ) |
504 | | /// ]); |
505 | | /// |
506 | | /// let normalized = RecordBatch::try_new(Arc::new(schema), vec![a]) |
507 | | /// .expect("valid conversion") |
508 | | /// .normalize(".", None) |
509 | | /// .expect("valid normalization"); |
510 | | /// |
511 | | /// let expected = RecordBatch::try_from_iter_with_nullable(vec![ |
512 | | /// ("a.animals", animals.clone(), true), |
513 | | /// ("a.n_legs", n_legs.clone(), true), |
514 | | /// ]) |
515 | | /// .expect("valid conversion"); |
516 | | /// |
517 | | /// assert_eq!(expected, normalized); |
518 | | /// ``` |
519 | 0 | pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> { |
520 | 0 | let max_level = match max_level.unwrap_or(usize::MAX) { |
521 | 0 | 0 => usize::MAX, |
522 | 0 | val => val, |
523 | | }; |
524 | 0 | let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self |
525 | 0 | .columns |
526 | 0 | .iter() |
527 | 0 | .zip(self.schema.fields()) |
528 | 0 | .rev() |
529 | 0 | .map(|(c, f)| { |
530 | 0 | let name_vec: Vec<&str> = vec![f.name()]; |
531 | 0 | (0, c, name_vec, f) |
532 | 0 | }) |
533 | 0 | .collect(); |
534 | 0 | let mut columns: Vec<ArrayRef> = Vec::new(); |
535 | 0 | let mut fields: Vec<FieldRef> = Vec::new(); |
536 | | |
537 | 0 | while let Some((depth, c, name, field_ref)) = stack.pop() { |
538 | 0 | match field_ref.data_type() { |
539 | 0 | DataType::Struct(ff) if depth < max_level => { |
540 | | // Need to zip these in reverse to maintain original order |
541 | 0 | for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() { |
542 | 0 | let mut name = name.clone(); |
543 | 0 | name.push(separator); |
544 | 0 | name.push(fff.name()); |
545 | 0 | stack.push((depth + 1, cff, name, fff)) |
546 | | } |
547 | | } |
548 | 0 | _ => { |
549 | 0 | let updated_field = Field::new( |
550 | 0 | name.concat(), |
551 | 0 | field_ref.data_type().clone(), |
552 | 0 | field_ref.is_nullable(), |
553 | 0 | ); |
554 | 0 | columns.push(c.clone()); |
555 | 0 | fields.push(Arc::new(updated_field)); |
556 | 0 | } |
557 | | } |
558 | | } |
559 | 0 | RecordBatch::try_new(Arc::new(Schema::new(fields)), columns) |
560 | 0 | } |
561 | | |
562 | | /// Returns the number of columns in the record batch. |
563 | | /// |
564 | | /// # Example |
565 | | /// |
566 | | /// ``` |
567 | | /// # use std::sync::Arc; |
568 | | /// # use arrow_array::{Int32Array, RecordBatch}; |
569 | | /// # use arrow_schema::{DataType, Field, Schema}; |
570 | | /// |
571 | | /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); |
572 | | /// let schema = Schema::new(vec![ |
573 | | /// Field::new("id", DataType::Int32, false) |
574 | | /// ]); |
575 | | /// |
576 | | /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap(); |
577 | | /// |
578 | | /// assert_eq!(batch.num_columns(), 1); |
579 | | /// ``` |
580 | 84 | pub fn num_columns(&self) -> usize { |
581 | 84 | self.columns.len() |
582 | 84 | } |
583 | | |
584 | | /// Returns the number of rows in each column. |
585 | | /// |
586 | | /// # Example |
587 | | /// |
588 | | /// ``` |
589 | | /// # use std::sync::Arc; |
590 | | /// # use arrow_array::{Int32Array, RecordBatch}; |
591 | | /// # use arrow_schema::{DataType, Field, Schema}; |
592 | | /// |
593 | | /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); |
594 | | /// let schema = Schema::new(vec![ |
595 | | /// Field::new("id", DataType::Int32, false) |
596 | | /// ]); |
597 | | /// |
598 | | /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap(); |
599 | | /// |
600 | | /// assert_eq!(batch.num_rows(), 5); |
601 | | /// ``` |
602 | 473 | pub fn num_rows(&self) -> usize { |
603 | 473 | self.row_count |
604 | 473 | } |
605 | | |
606 | | /// Get a reference to a column's array by index. |
607 | | /// |
608 | | /// # Panics |
609 | | /// |
610 | | /// Panics if `index` is outside of `0..num_columns`. |
611 | 341 | pub fn column(&self, index: usize) -> &ArrayRef { |
612 | 341 | &self.columns[index] |
613 | 341 | } |
614 | | |
615 | | /// Get a reference to a column's array by name. |
616 | 13 | pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> { |
617 | 13 | self.schema() |
618 | 13 | .column_with_name(name) |
619 | 13 | .map(|(index, _)| &self.columns[index]) |
620 | 13 | } |
621 | | |
622 | | /// Get a reference to all columns in the record batch. |
623 | 159 | pub fn columns(&self) -> &[ArrayRef] { |
624 | 159 | &self.columns[..] |
625 | 159 | } |
626 | | |
627 | | /// Remove column by index and return it. |
628 | | /// |
629 | | /// Return the `ArrayRef` if the column is removed. |
630 | | /// |
631 | | /// # Panics |
632 | | /// |
633 | | /// Panics if `index`` out of bounds. |
634 | | /// |
635 | | /// # Example |
636 | | /// |
637 | | /// ``` |
638 | | /// use std::sync::Arc; |
639 | | /// use arrow_array::{BooleanArray, Int32Array, RecordBatch}; |
640 | | /// use arrow_schema::{DataType, Field, Schema}; |
641 | | /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); |
642 | | /// let bool_array = BooleanArray::from(vec![true, false, false, true, true]); |
643 | | /// let schema = Schema::new(vec![ |
644 | | /// Field::new("id", DataType::Int32, false), |
645 | | /// Field::new("bool", DataType::Boolean, false), |
646 | | /// ]); |
647 | | /// |
648 | | /// let mut batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array), Arc::new(bool_array)]).unwrap(); |
649 | | /// |
650 | | /// let removed_column = batch.remove_column(0); |
651 | | /// assert_eq!(removed_column.as_any().downcast_ref::<Int32Array>().unwrap(), &Int32Array::from(vec![1, 2, 3, 4, 5])); |
652 | | /// assert_eq!(batch.num_columns(), 1); |
653 | | /// ``` |
654 | 0 | pub fn remove_column(&mut self, index: usize) -> ArrayRef { |
655 | 0 | let mut builder = SchemaBuilder::from(self.schema.as_ref()); |
656 | 0 | builder.remove(index); |
657 | 0 | self.schema = Arc::new(builder.finish()); |
658 | 0 | self.columns.remove(index) |
659 | 0 | } |
660 | | |
661 | | /// Return a new RecordBatch where each column is sliced |
662 | | /// according to `offset` and `length` |
663 | | /// |
664 | | /// # Panics |
665 | | /// |
666 | | /// Panics if `offset` with `length` is greater than column length. |
667 | 77 | pub fn slice(&self, offset: usize, length: usize) -> RecordBatch { |
668 | 77 | assert!((offset + length) <= self.num_rows()); |
669 | | |
670 | 77 | let columns = self |
671 | 77 | .columns() |
672 | 77 | .iter() |
673 | 158 | .map77 (|column| column.slice(offset, length)) |
674 | 77 | .collect(); |
675 | | |
676 | 77 | Self { |
677 | 77 | schema: self.schema.clone(), |
678 | 77 | columns, |
679 | 77 | row_count: length, |
680 | 77 | } |
681 | 77 | } |
682 | | |
683 | | /// Create a `RecordBatch` from an iterable list of pairs of the |
684 | | /// form `(field_name, array)`, with the same requirements on |
685 | | /// fields and arrays as [`RecordBatch::try_new`]. This method is |
686 | | /// often used to create a single `RecordBatch` from arrays, |
687 | | /// e.g. for testing. |
688 | | /// |
689 | | /// The resulting schema is marked as nullable for each column if |
690 | | /// the array for that column is has any nulls. To explicitly |
691 | | /// specify nullibility, use [`RecordBatch::try_from_iter_with_nullable`] |
692 | | /// |
693 | | /// Example: |
694 | | /// ``` |
695 | | /// # use std::sync::Arc; |
696 | | /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray}; |
697 | | /// |
698 | | /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); |
699 | | /// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"])); |
700 | | /// |
701 | | /// let record_batch = RecordBatch::try_from_iter(vec![ |
702 | | /// ("a", a), |
703 | | /// ("b", b), |
704 | | /// ]); |
705 | | /// ``` |
706 | | /// Another way to quickly create a [`RecordBatch`] is to use the [`record_batch!`] macro, |
707 | | /// which is particularly helpful for rapid prototyping and testing. |
708 | | /// |
709 | | /// Example: |
710 | | /// |
711 | | /// ```rust |
712 | | /// use arrow_array::record_batch; |
713 | | /// let batch = record_batch!( |
714 | | /// ("a", Int32, [1, 2, 3]), |
715 | | /// ("b", Float64, [Some(4.0), None, Some(5.0)]), |
716 | | /// ("c", Utf8, ["alpha", "beta", "gamma"]) |
717 | | /// ); |
718 | | /// ``` |
719 | 41 | pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError> |
720 | 41 | where |
721 | 41 | I: IntoIterator<Item = (F, ArrayRef)>, |
722 | 41 | F: AsRef<str>, |
723 | | { |
724 | | // TODO: implement `TryFrom` trait, once |
725 | | // https://github.com/rust-lang/rust/issues/50133 is no longer an |
726 | | // issue |
727 | 161 | let iter41 = value41 .into_iter41 ().map41 (|(field_name, array)| { |
728 | 161 | let nullable = array.null_count() > 0; |
729 | 161 | (field_name, array, nullable) |
730 | 161 | }); |
731 | | |
732 | 41 | Self::try_from_iter_with_nullable(iter) |
733 | 41 | } |
734 | | |
735 | | /// Create a `RecordBatch` from an iterable list of tuples of the |
736 | | /// form `(field_name, array, nullable)`, with the same requirements on |
737 | | /// fields and arrays as [`RecordBatch::try_new`]. This method is often |
738 | | /// used to create a single `RecordBatch` from arrays, e.g. for |
739 | | /// testing. |
740 | | /// |
741 | | /// Example: |
742 | | /// ``` |
743 | | /// # use std::sync::Arc; |
744 | | /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray}; |
745 | | /// |
746 | | /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); |
747 | | /// let b: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b")])); |
748 | | /// |
749 | | /// // Note neither `a` nor `b` has any actual nulls, but we mark |
750 | | /// // b an nullable |
751 | | /// let record_batch = RecordBatch::try_from_iter_with_nullable(vec![ |
752 | | /// ("a", a, false), |
753 | | /// ("b", b, true), |
754 | | /// ]); |
755 | | /// ``` |
756 | 41 | pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError> |
757 | 41 | where |
758 | 41 | I: IntoIterator<Item = (F, ArrayRef, bool)>, |
759 | 41 | F: AsRef<str>, |
760 | | { |
761 | 41 | let iter = value.into_iter(); |
762 | 41 | let capacity = iter.size_hint().0; |
763 | 41 | let mut schema = SchemaBuilder::with_capacity(capacity); |
764 | 41 | let mut columns = Vec::with_capacity(capacity); |
765 | | |
766 | 202 | for (field_name161 , array161 , nullable161 ) in iter { |
767 | 161 | let field_name = field_name.as_ref(); |
768 | 161 | schema.push(Field::new(field_name, array.data_type().clone(), nullable)); |
769 | 161 | columns.push(array); |
770 | 161 | } |
771 | | |
772 | 41 | let schema = Arc::new(schema.finish()); |
773 | 41 | RecordBatch::try_new(schema, columns) |
774 | 41 | } |
775 | | |
776 | | /// Returns the total number of bytes of memory occupied physically by this batch. |
777 | | /// |
778 | | /// Note that this does not always correspond to the exact memory usage of a |
779 | | /// `RecordBatch` (might overestimate), since multiple columns can share the same |
780 | | /// buffers or slices thereof, the memory used by the shared buffers might be |
781 | | /// counted multiple times. |
782 | 0 | pub fn get_array_memory_size(&self) -> usize { |
783 | 0 | self.columns() |
784 | 0 | .iter() |
785 | 0 | .map(|array| array.get_array_memory_size()) |
786 | 0 | .sum() |
787 | 0 | } |
788 | | } |
789 | | |
/// Settings consulted when a [`RecordBatch`] is constructed.
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// Whether field names of structs and lists have to match. `true` means they must.
    pub match_field_names: bool,

    /// Explicit row count, if any; useful for a RecordBatch that has no columns
    pub row_count: Option<usize>,
}

impl RecordBatchOptions {
    /// Creates options with field-name matching switched on and no explicit row count.
    pub fn new() -> Self {
        let match_field_names = true;
        let row_count = None;
        Self {
            match_field_names,
            row_count,
        }
    }

    /// Replaces `row_count` and returns the options for further chaining.
    pub fn with_row_count(self, row_count: Option<usize>) -> Self {
        Self { row_count, ..self }
    }

    /// Replaces `match_field_names` and returns the options for further chaining.
    pub fn with_match_field_names(self, match_field_names: bool) -> Self {
        Self {
            match_field_names,
            ..self
        }
    }
}

impl Default for RecordBatchOptions {
    /// Same as [`RecordBatchOptions::new`].
    fn default() -> Self {
        RecordBatchOptions::new()
    }
}
825 | | impl From<StructArray> for RecordBatch { |
826 | 0 | fn from(value: StructArray) -> Self { |
827 | 0 | let row_count = value.len(); |
828 | 0 | let (fields, columns, nulls) = value.into_parts(); |
829 | 0 | assert_eq!( |
830 | 0 | nulls.map(|n| n.null_count()).unwrap_or_default(), |
831 | | 0, |
832 | 0 | "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation" |
833 | | ); |
834 | | |
835 | 0 | RecordBatch { |
836 | 0 | schema: Arc::new(Schema::new(fields)), |
837 | 0 | row_count, |
838 | 0 | columns, |
839 | 0 | } |
840 | 0 | } |
841 | | } |
842 | | |
843 | | impl From<&StructArray> for RecordBatch { |
844 | 0 | fn from(struct_array: &StructArray) -> Self { |
845 | 0 | struct_array.clone().into() |
846 | 0 | } |
847 | | } |
848 | | |
849 | | impl Index<&str> for RecordBatch { |
850 | | type Output = ArrayRef; |
851 | | |
852 | | /// Get a reference to a column's array by name. |
853 | | /// |
854 | | /// # Panics |
855 | | /// |
856 | | /// Panics if the name is not in the schema. |
857 | 0 | fn index(&self, name: &str) -> &Self::Output { |
858 | 0 | self.column_by_name(name).unwrap() |
859 | 0 | } |
860 | | } |
861 | | |
862 | | /// Generic implementation of [RecordBatchReader] that wraps an iterator. |
863 | | /// |
864 | | /// # Example |
865 | | /// |
866 | | /// ``` |
867 | | /// # use std::sync::Arc; |
868 | | /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, RecordBatchIterator, RecordBatchReader}; |
869 | | /// # |
870 | | /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); |
871 | | /// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"])); |
872 | | /// |
873 | | /// let record_batch = RecordBatch::try_from_iter(vec![ |
874 | | /// ("a", a), |
875 | | /// ("b", b), |
876 | | /// ]).unwrap(); |
877 | | /// |
878 | | /// let batches: Vec<RecordBatch> = vec![record_batch.clone(), record_batch.clone()]; |
879 | | /// |
880 | | /// let mut reader = RecordBatchIterator::new(batches.into_iter().map(Ok), record_batch.schema()); |
881 | | /// |
882 | | /// assert_eq!(reader.schema(), record_batch.schema()); |
883 | | /// assert_eq!(reader.next().unwrap().unwrap(), record_batch); |
884 | | /// # assert_eq!(reader.next().unwrap().unwrap(), record_batch); |
885 | | /// # assert!(reader.next().is_none()); |
886 | | /// ``` |
887 | | pub struct RecordBatchIterator<I> |
888 | | where |
889 | | I: IntoIterator<Item = Result<RecordBatch, ArrowError>>, |
890 | | { |
891 | | inner: I::IntoIter, |
892 | | inner_schema: SchemaRef, |
893 | | } |
894 | | |
895 | | impl<I> RecordBatchIterator<I> |
896 | | where |
897 | | I: IntoIterator<Item = Result<RecordBatch, ArrowError>>, |
898 | | { |
899 | | /// Create a new [RecordBatchIterator]. |
900 | | /// |
901 | | /// If `iter` is an infallible iterator, use `.map(Ok)`. |
902 | | pub fn new(iter: I, schema: SchemaRef) -> Self { |
903 | | Self { |
904 | | inner: iter.into_iter(), |
905 | | inner_schema: schema, |
906 | | } |
907 | | } |
908 | | } |
909 | | |
910 | | impl<I> Iterator for RecordBatchIterator<I> |
911 | | where |
912 | | I: IntoIterator<Item = Result<RecordBatch, ArrowError>>, |
913 | | { |
914 | | type Item = I::Item; |
915 | | |
916 | | fn next(&mut self) -> Option<Self::Item> { |
917 | | self.inner.next() |
918 | | } |
919 | | |
920 | | fn size_hint(&self) -> (usize, Option<usize>) { |
921 | | self.inner.size_hint() |
922 | | } |
923 | | } |
924 | | |
925 | | impl<I> RecordBatchReader for RecordBatchIterator<I> |
926 | | where |
927 | | I: IntoIterator<Item = Result<RecordBatch, ArrowError>>, |
928 | | { |
929 | | fn schema(&self) -> SchemaRef { |
930 | | self.inner_schema.clone() |
931 | | } |
932 | | } |
933 | | |
934 | | #[cfg(test)] |
935 | | mod tests { |
936 | | use super::*; |
937 | | use crate::{ |
938 | | BooleanArray, Int8Array, Int32Array, Int64Array, ListArray, StringArray, StringViewArray, |
939 | | }; |
940 | | use arrow_buffer::{Buffer, ToByteSlice}; |
941 | | use arrow_data::{ArrayData, ArrayDataBuilder}; |
942 | | use arrow_schema::Fields; |
943 | | use std::collections::HashMap; |
944 | | |
/// A batch built from an Int32 and a Utf8 column passes `check_batch`.
#[test]
fn create_record_batch() {
    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let ints: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

    let batch = RecordBatch::try_new(schema, vec![ints, strings]).unwrap();
    check_batch(batch, 5)
}
959 | | |
/// Same as `create_record_batch`, but the string column is a `Utf8View`.
#[test]
fn create_string_view_record_batch() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8View, false),
    ]));

    let ints: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let views: ArrayRef = Arc::new(StringViewArray::from(vec!["a", "b", "c", "d", "e"]));

    let batch = RecordBatch::try_new(schema, vec![ints, views]).unwrap();

    assert_eq!(5, batch.num_rows());
    assert_eq!(2, batch.num_columns());

    let batch_schema = batch.schema();
    assert_eq!(&DataType::Int32, batch_schema.field(0).data_type());
    assert_eq!(&DataType::Utf8View, batch_schema.field(1).data_type());

    for column in batch.columns() {
        assert_eq!(5, column.len());
    }
}
983 | | |
/// Pins the reported memory size of a small two-column batch.
#[test]
fn byte_size_should_not_regress() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]));

    let ints: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

    let batch = RecordBatch::try_new(schema, vec![ints, strings]).unwrap();
    let reported = batch.get_array_memory_size();
    assert_eq!(reported, 364);
}
998 | | |
999 | | fn check_batch(record_batch: RecordBatch, num_rows: usize) { |
1000 | | assert_eq!(num_rows, record_batch.num_rows()); |
1001 | | assert_eq!(2, record_batch.num_columns()); |
1002 | | assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type()); |
1003 | | assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type()); |
1004 | | assert_eq!(num_rows, record_batch.column(0).len()); |
1005 | | assert_eq!(num_rows, record_batch.column(1).len()); |
1006 | | } |
1007 | | |
/// In-range slices keep the schema and have the requested length; a slice
/// that reaches past the end of the batch panics.
#[test]
#[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
fn create_record_batch_slice() {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, false),
    ]);
    let expected_schema = schema.clone();

    let ints: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]));
    let strings: ArrayRef = Arc::new(StringArray::from(vec![
        "a", "b", "c", "d", "e", "f", "h", "i",
    ]));

    let record_batch = RecordBatch::try_new(Arc::new(schema), vec![ints, strings]).unwrap();

    // Valid slices, including an empty one.
    for (offset, length) in [(2, 5), (2, 0)] {
        let sliced = record_batch.slice(offset, length);
        assert_eq!(sliced.schema().as_ref(), &expected_schema);
        check_batch(sliced, length);
    }

    // 2 + 10 exceeds the 8 available rows, which triggers the expected panic.
    let _ = record_batch.slice(2, 10);
}
1041 | | |
/// A zero-length slice of an empty batch is allowed; any other range panics.
#[test]
#[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
fn create_record_batch_slice_empty_batch() {
    let record_batch = RecordBatch::new_empty(Arc::new(Schema::empty()));

    let empty_slice = record_batch.slice(0, 0);
    assert_eq!(0, empty_slice.schema().fields().len());

    // Out of range for a batch without rows: expected to panic.
    let _ = record_batch.slice(1, 2);
}
1058 | | |
/// `try_from_iter` marks a field nullable exactly when its array holds a null.
#[test]
fn create_record_batch_try_from_iter() {
    let with_null = vec![Some(1), Some(2), None, Some(4), Some(5)];
    let a: ArrayRef = Arc::new(Int32Array::from(with_null));
    let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

    let columns = vec![("a", a), ("b", b)];
    let record_batch = RecordBatch::try_from_iter(columns).expect("valid conversion");

    let expected_schema = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, false),
    ]);
    assert_eq!(record_batch.schema().as_ref(), &expected_schema);
    check_batch(record_batch, 5);
}
1080 | | |
/// The nullability passed to `try_from_iter_with_nullable` ends up in the schema.
#[test]
fn create_record_batch_try_from_iter_with_nullable() {
    let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));

    // Neither array contains a null; `b` is nevertheless declared nullable.
    let columns = vec![("a", a, false), ("b", b, true)];
    let record_batch =
        RecordBatch::try_from_iter_with_nullable(columns).expect("valid conversion");

    let expected_schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
    ]);
    assert_eq!(record_batch.schema().as_ref(), &expected_schema);
    check_batch(record_batch, 5);
}
1098 | | |
/// An Int64 column under an Int32 field is rejected with a descriptive error.
#[test]
fn create_record_batch_schema_mismatch() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let wrong_type: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));

    let message = RecordBatch::try_new(schema, vec![wrong_type])
        .unwrap_err()
        .to_string();
    assert_eq!(
        message,
        "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0"
    );
}
1111 | | |
/// A struct column whose nested field names differ from the schema is
/// rejected by default and accepted once field-name matching is turned off.
#[test]
fn create_record_batch_field_name_mismatch() {
    // Schema side: struct "a" with children "a1" (Int32) and "a2" (list of Int8).
    let schema_children = vec![
        Field::new("a1", DataType::Int32, false),
        Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
    ];
    let schema = Arc::new(Schema::new(vec![Field::new_struct(
        "a",
        schema_children,
        true,
    )]));

    // Data side: the list item field is called "array" and the first struct
    // child is called "aa1".
    let int_column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
    let list_values = Int8Array::from(vec![1, 2, 3, 4]);
    let list_type = DataType::List(Arc::new(Field::new("array", DataType::Int8, false)));
    let list_data = ArrayDataBuilder::new(list_type)
        .add_child_data(list_values.into_data())
        .len(2)
        .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
        .build()
        .unwrap();
    let list_column: ArrayRef = Arc::new(ListArray::from(list_data));

    let struct_type = DataType::Struct(Fields::from(vec![
        Field::new("aa1", DataType::Int32, false),
        Field::new("a2", list_column.data_type().clone(), false),
    ]));
    let struct_data = ArrayDataBuilder::new(struct_type)
        .add_child_data(int_column.into_data())
        .add_child_data(list_column.into_data())
        .len(2)
        .build()
        .unwrap();
    let struct_column: ArrayRef = Arc::new(StructArray::from(struct_data));

    // With field name validation the batch cannot be created.
    let strict = RecordBatch::try_new(schema.clone(), vec![struct_column.clone()]);
    assert!(strict.is_err());

    // Without field name validation the same inputs are accepted.
    let options = RecordBatchOptions::new().with_match_field_names(false);
    let relaxed = RecordBatch::try_new_with_options(schema, vec![struct_column], &options);
    assert!(relaxed.is_ok());
}
1156 | | |
/// Supplying two columns for a single-field schema is an error.
#[test]
fn create_record_batch_record_mismatch() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    let first: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let second: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

    let result = RecordBatch::try_new(schema, vec![first, second]);
    assert!(result.is_err());
}
1167 | | |
/// Converting a `&StructArray` yields a batch with the struct's fields as
/// schema and its children as columns.
#[test]
fn create_record_batch_from_struct_array() {
    let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
    let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));

    let bool_field = Arc::new(Field::new("b", DataType::Boolean, false));
    let int_field = Arc::new(Field::new("c", DataType::Int32, false));
    let struct_array = StructArray::from(vec![
        (bool_field, boolean.clone() as ArrayRef),
        (int_field, int.clone() as ArrayRef),
    ]);

    let batch = RecordBatch::from(&struct_array);
    assert_eq!(2, batch.num_columns());
    assert_eq!(4, batch.num_rows());

    let schema_as_struct = DataType::Struct(batch.schema().fields().clone());
    assert_eq!(struct_array.data_type(), &schema_as_struct);

    assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
    assert_eq!(batch.column(1).as_ref(), int.as_ref());
}
1193 | | |
/// Two batches built independently from identical schemas and values are equal.
#[test]
fn record_batch_equality() {
    let build = || {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let ids: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let vals: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
        RecordBatch::try_new(Arc::new(schema), vec![ids, vals]).unwrap()
    };

    let batch1 = build();
    let batch2 = build();

    assert_eq!(batch1, batch2);
}
1224 | | |
/// Checks that a column can be fetched by name through indexing,
/// i.e. `record_batch["column_name"]`.
#[test]
fn record_batch_index_access() {
    let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));

    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("val", DataType::Int32, false),
    ]));
    let columns: Vec<ArrayRef> = vec![id_arr.clone(), val_arr.clone()];
    let record_batch = RecordBatch::try_new(schema, columns).unwrap();

    assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
    assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
}
1240 | | |
/// Batches with the same schema but different values in one column are unequal.
#[test]
fn record_batch_vals_ne() {
    let build = |vals: Vec<i32>| {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let id_column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let val_column: ArrayRef = Arc::new(Int32Array::from(vals));
        RecordBatch::try_new(Arc::new(schema), vec![id_column, val_column]).unwrap()
    };

    let batch1 = build(vec![5, 6, 7, 8]);
    let batch2 = build(vec![1, 2, 3, 4]);

    assert_ne!(batch1, batch2);
}
1271 | | |
/// Batches with identical data but a differently named column are unequal.
#[test]
fn record_batch_column_names_ne() {
    let build = |second_name: &str| {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new(second_name, DataType::Int32, false),
        ]);
        let id_column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let val_column: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
        RecordBatch::try_new(Arc::new(schema), vec![id_column, val_column]).unwrap()
    };

    let batch1 = build("val");
    let batch2 = build("num");

    assert_ne!(batch1, batch2);
}
1302 | | |
/// A batch with an additional column is unequal to one without it.
#[test]
fn record_batch_column_number_ne() {
    let int_field = |name: &str| Field::new(name, DataType::Int32, false);
    let int_column = |values: Vec<i32>| -> ArrayRef { Arc::new(Int32Array::from(values)) };

    let batch1 = RecordBatch::try_new(
        Arc::new(Schema::new(vec![int_field("id"), int_field("val")])),
        vec![int_column(vec![1, 2, 3, 4]), int_column(vec![5, 6, 7, 8])],
    )
    .unwrap();

    let batch2 = RecordBatch::try_new(
        Arc::new(Schema::new(vec![
            int_field("id"),
            int_field("val"),
            int_field("num"),
        ])),
        vec![
            int_column(vec![1, 2, 3, 4]),
            int_column(vec![5, 6, 7, 8]),
            int_column(vec![5, 6, 7, 8]),
        ],
    )
    .unwrap();

    assert_ne!(batch1, batch2);
}
1335 | | |
/// Batches that share a schema but hold a different number of rows are unequal.
#[test]
fn record_batch_row_count_ne() {
    // Both batches deliberately use the same field names so that the row
    // count is the only thing that differs. The second schema used to name
    // its field "num", which made the batches unequal regardless of rows.
    let build = |ids: Vec<i32>, vals: Vec<i32>| {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let id_column: ArrayRef = Arc::new(Int32Array::from(ids));
        let val_column: ArrayRef = Arc::new(Int32Array::from(vals));
        RecordBatch::try_new(Arc::new(schema), vec![id_column, val_column]).unwrap()
    };

    let batch1 = build(vec![1, 2, 3], vec![5, 6, 7]);
    let batch2 = build(vec![1, 2, 3, 4], vec![5, 6, 7, 8]);

    assert_eq!(batch1.schema(), batch2.schema());
    assert_ne!(batch1.num_rows(), batch2.num_rows());
    assert_ne!(batch1, batch2);
}
1366 | | |
/// Normalizing a batch with one struct column flattens the struct's children
/// into top-level columns named `a.<child>`; depth `Some(0)` and `None`
/// produce the same result.
#[test]
fn normalize_simple() {
    let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
    let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
    let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));
    let month: ArrayRef = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));

    let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
    let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
    let year_field = Arc::new(Field::new("year", DataType::Int64, true));

    let a: ArrayRef = Arc::new(StructArray::from(vec![
        (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
        (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
        (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
    ]));

    let struct_type = DataType::Struct(Fields::from(vec![
        animals_field,
        n_legs_field,
        year_field,
    ]));
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", struct_type, false),
        Field::new("month", DataType::Int64, true),
    ]));

    let expected = RecordBatch::try_from_iter_with_nullable(vec![
        ("a.animals", animals.clone(), true),
        ("a.n_legs", n_legs.clone(), true),
        ("a.year", year.clone(), true),
        ("month", month.clone(), true),
    ])
    .expect("valid conversion");

    let with_zero_depth = RecordBatch::try_new(schema.clone(), vec![a.clone(), month.clone()])
        .expect("valid conversion")
        .normalize(".", Some(0))
        .expect("valid normalization");
    assert_eq!(expected, with_zero_depth);

    // check 0 and None have the same effect
    let with_no_depth = RecordBatch::try_new(schema, vec![a, month])
        .expect("valid conversion")
        .normalize(".", None)
        .expect("valid normalization");
    assert_eq!(expected, with_no_depth);
}
1418 | | |
/// Normalizing a doubly nested struct: depth `Some(1)` unwraps only the outer
/// struct, while `None` flattens down to the leaf columns.
#[test]
fn normalize_nested() {
    // Leaf fields shared by both inner structs.
    let a = Arc::new(Field::new("a", DataType::Int64, true));
    let b = Arc::new(Field::new("b", DataType::Int64, false));
    let c = Arc::new(Field::new("c", DataType::Int64, true));
    let leaf_fields = || Fields::from(vec![a.clone(), b.clone(), c.clone()]);

    // Inner struct fields "1" and "2", wrapped by the outer struct "!".
    let one = Arc::new(Field::new("1", DataType::Struct(leaf_fields()), false));
    let two = Arc::new(Field::new("2", DataType::Struct(leaf_fields()), true));
    let exclamation = Arc::new(Field::new(
        "!",
        DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
        false,
    ));
    let schema = Arc::new(Schema::new(vec![exclamation]));

    // Leaf data.
    let a_field = Int64Array::from(vec![Some(0), Some(1)]);
    let b_field = Int64Array::from(vec![Some(2), Some(3)]);
    let c_field = Int64Array::from(vec![None, Some(4)]);

    // Every inner struct array in this test has the same three children.
    let leaf_struct = || {
        StructArray::from(vec![
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
        ])
    };

    let exclamation_field: ArrayRef = Arc::new(StructArray::from(vec![
        (one.clone(), Arc::new(leaf_struct()) as ArrayRef),
        (two.clone(), Arc::new(leaf_struct()) as ArrayRef),
    ]));

    // Normalize top level
    let top_level_only = RecordBatch::try_new(schema.clone(), vec![exclamation_field.clone()])
        .expect("valid conversion")
        .normalize(".", Some(1))
        .expect("valid normalization");

    let expected_top_level = RecordBatch::try_from_iter_with_nullable(vec![
        ("!.1", Arc::new(leaf_struct()) as ArrayRef, false),
        ("!.2", Arc::new(leaf_struct()) as ArrayRef, true),
    ])
    .expect("valid conversion");

    assert_eq!(expected_top_level, top_level_only);

    // Normalize all levels
    let fully_flat = RecordBatch::try_new(schema, vec![exclamation_field])
        .expect("valid conversion")
        .normalize(".", None)
        .expect("valid normalization");

    let expected_flat = RecordBatch::try_from_iter_with_nullable(vec![
        ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
        ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
        ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
        ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
        ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
        ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
    ])
    .expect("valid conversion");

    assert_eq!(expected_flat, fully_flat);
}
1515 | | |
/// Normalizing an empty batch equals an empty batch of the normalized schema.
#[test]
fn normalize_empty() {
    let struct_children = Fields::from(vec![
        Arc::new(Field::new("animals", DataType::Utf8, true)),
        Arc::new(Field::new("n_legs", DataType::Int64, true)),
        Arc::new(Field::new("year", DataType::Int64, true)),
    ]);

    let schema = Schema::new(vec![
        Field::new("a", DataType::Struct(struct_children), false),
        Field::new("month", DataType::Int64, true),
    ]);

    let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
        .normalize(".", Some(0))
        .expect("valid normalization");

    let normalized_schema = schema.normalize(".", Some(0)).expect("valid normalization");
    let expected = RecordBatch::new_empty(Arc::new(normalized_schema));

    assert_eq!(expected, normalized);
}
1541 | | |
/// Projecting columns 0 and 2 of a three-column batch keeps exactly those two.
#[test]
fn project() {
    let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

    let all_columns = vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())];
    let record_batch = RecordBatch::try_from_iter(all_columns).expect("valid conversion");

    let kept_columns = vec![("a", a), ("c", c)];
    let expected = RecordBatch::try_from_iter(kept_columns).expect("valid conversion");

    let projected = record_batch.project(&[0, 2]).unwrap();
    assert_eq!(expected, projected);
}
1557 | | |
/// Projecting no columns yields a column-less batch that keeps the row count.
#[test]
fn project_empty() {
    let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

    let record_batch =
        RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");

    // Field-name matching stays at its default (`true`); only the row count is set.
    let options = RecordBatchOptions::new().with_row_count(Some(3));
    let expected =
        RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options)
            .expect("valid conversion");

    let projected = record_batch.project(&[]).unwrap();
    assert_eq!(expected, projected);
}
1577 | | |
/// A batch without columns needs an explicit row count, and slicing such a
/// batch adjusts that row count.
#[test]
fn test_no_column_record_batch() {
    let schema = Arc::new(Schema::empty());

    // No columns and no row count: rejected.
    let message = RecordBatch::try_new(schema.clone(), vec![])
        .unwrap_err()
        .to_string();
    assert!(message.contains("must either specify a row count or at least one column"));

    // No columns but an explicit row count: accepted.
    let options = RecordBatchOptions::new().with_row_count(Some(10));
    let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
    assert_eq!(ok.num_rows(), 10);

    let a = ok.slice(2, 5);
    assert_eq!(a.num_rows(), 5);

    let b = ok.slice(5, 0);
    assert_eq!(b.num_rows(), 0);

    assert_ne!(a, b);
    assert_eq!(b, RecordBatch::new_empty(schema))
}
1602 | | |
/// A column holding a null must be rejected when its field is declared
/// non-nullable, and the error must carry the exact message checked below.
#[test]
fn test_nulls_in_non_nullable_field() {
    let non_nullable = Field::new("a", DataType::Int32, false);
    let with_null = Int32Array::from(vec![Some(1), None]);

    let error = RecordBatch::try_new(
        Arc::new(Schema::new(vec![non_nullable])),
        vec![Arc::new(with_null)],
    )
    .unwrap_err();

    assert_eq!(
        "Invalid argument error: Column 'a' is declared as non-nullable but contains null values",
        error.to_string()
    );
}
/// The builder-style setters on `RecordBatchOptions` must store exactly the
/// values they were given.
#[test]
fn test_record_batch_options() {
    let configured = RecordBatchOptions::new()
        .with_match_field_names(false)
        .with_row_count(Some(20));

    assert!(!configured.match_field_names);
    assert_eq!(configured.row_count, Some(20));
}
1623 | | |
/// Converting a `StructArray` built from `ArrayData::new_null` into a
/// `RecordBatch` is expected to panic with the message named in
/// `should_panic`.
#[test]
#[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
fn test_from_struct() {
    // Note the child field itself is declared non-nullable.
    let child = Field::new("foo", DataType::Int32, false);
    let struct_type = DataType::Struct(vec![child].into());

    let null_struct = StructArray::from(ArrayData::new_null(&struct_type, 2));
    let _ = RecordBatch::from(null_struct);
}
1634 | | |
/// `with_schema` must accept a schema that adds nullability or metadata and
/// must reject one that takes either of them away again.
#[test]
fn test_with_schema() {
    let required_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let nullable_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));

    let values = Int32Array::from(vec![1, 2, 3]);
    let batch =
        RecordBatch::try_new(required_schema.clone(), vec![Arc::new(values) as _]).unwrap();

    // required -> nullable is accepted
    let batch = batch.with_schema(nullable_schema.clone()).unwrap();

    // nullable -> required is rejected
    assert!(batch.clone().with_schema(required_schema).is_err());

    // adding schema-level metadata is accepted
    let metadata = HashMap::from([("foo".to_string(), "bar".to_string())]);
    let metadata_schema = Schema::clone(&nullable_schema).with_metadata(metadata);
    let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();

    // dropping that metadata again is rejected
    assert!(batch.with_schema(nullable_schema).is_err());
}
1664 | | |
/// A `Box<dyn RecordBatchReader + Send>` must be usable wherever a function
/// is generic over `RecordBatchReader`.
#[test]
fn test_boxed_reader() {
    // Generic consumer: the fact that the call below compiles with a boxed
    // trait object is the main thing being checked.
    fn get_size<R: RecordBatchReader>(reader: R) -> usize {
        let (lower_bound, _) = reader.size_hint();
        lower_bound
    }

    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    let iterator = RecordBatchIterator::new(std::iter::empty(), schema);
    let boxed: Box<dyn RecordBatchReader + Send> = Box::new(iterator);

    // An empty iterator reports a lower size bound of zero.
    assert_eq!(get_size(boxed), 0);
}
1682 | | |
/// Removing a column from a batch must leave the schema-level metadata
/// (`"foo" -> "bar"`) in place.
#[test]
fn test_remove_column_maintains_schema_metadata() {
    let metadata = HashMap::from([("foo".to_string(), "bar".to_string())]);
    let fields = vec![
        Field::new("id", DataType::Int32, false),
        Field::new("bool", DataType::Boolean, false),
    ];
    let schema = Arc::new(Schema::new(fields).with_metadata(metadata));

    let ids = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let flags = BooleanArray::from(vec![true, false, false, true, true]);
    let mut batch = RecordBatch::try_new(schema, vec![Arc::new(ids), Arc::new(flags)]).unwrap();

    let _removed_column = batch.remove_column(0);

    // Exactly one metadata entry must remain, with its original value.
    let schema_after = batch.schema();
    let metadata_after = schema_after.metadata();
    assert_eq!(metadata_after.len(), 1);
    assert_eq!(metadata_after.get("foo").map(String::as_str), Some("bar"));
}
1709 | | } |