Coverage Report

Created: 2025-11-17 14:14

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-array/src/record_batch.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! A two-dimensional batch of column-oriented data with a defined
19
//! [schema](arrow_schema::Schema).
20
21
use crate::cast::AsArray;
22
use crate::{Array, ArrayRef, StructArray, new_empty_array};
23
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24
use std::ops::Index;
25
use std::sync::Arc;
26
27
/// Trait for types that can read `RecordBatch`'s.
28
///
29
/// To create from an iterator, see [RecordBatchIterator].
30
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
31
    /// Returns the schema of this `RecordBatchReader`.
32
    ///
33
    /// Implementation of this trait should guarantee that all `RecordBatch`'s returned by this
34
    /// reader should have the same schema as returned from this method.
35
    fn schema(&self) -> SchemaRef;
36
}
37
38
impl<R: RecordBatchReader + ?Sized> RecordBatchReader for Box<R> {
39
    fn schema(&self) -> SchemaRef {
40
        self.as_ref().schema()
41
    }
42
}
43
44
/// Trait for types that can write `RecordBatch`'s.
45
pub trait RecordBatchWriter {
46
    /// Write a single batch to the writer.
47
    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;
48
49
    /// Write footer or termination data, then mark the writer as done.
50
    fn close(self) -> Result<(), ArrowError>;
51
}
52
53
/// Creates an array from a literal slice of values,
54
/// suitable for rapid testing and development.
55
///
56
/// Example:
57
///
58
/// ```rust
59
///
60
/// use arrow_array::create_array;
61
///
62
/// let array = create_array!(Int32, [1, 2, 3, 4, 5]);
63
/// let array = create_array!(Utf8, [Some("a"), Some("b"), None, Some("e")]);
64
/// ```
65
/// Support for limited data types is available. The macro will return a compile error if an unsupported data type is used.
66
/// Presently supported data types are:
67
/// - `Boolean`, `Null`
68
/// - `Decimal32`, `Decimal64`, `Decimal128`, `Decimal256`
69
/// - `Float16`, `Float32`, `Float64`
70
/// - `Int8`, `Int16`, `Int32`, `Int64`
71
/// - `UInt8`, `UInt16`, `UInt32`, `UInt64`
72
/// - `IntervalDayTime`, `IntervalYearMonth`
73
/// - `Second`, `Millisecond`, `Microsecond`, `Nanosecond`
74
/// - `Second32`, `Millisecond32`, `Microsecond64`, `Nanosecond64`
75
/// - `DurationSecond`, `DurationMillisecond`, `DurationMicrosecond`, `DurationNanosecond`
76
/// - `TimestampSecond`, `TimestampMillisecond`, `TimestampMicrosecond`, `TimestampNanosecond`
77
/// - `Utf8`, `Utf8View`, `LargeUtf8`, `Binary`, `LargeBinary`
78
#[macro_export]
79
macro_rules! create_array {
80
    // `@from` is used for those types that have a common method `<type>::from`
81
    (@from Boolean) => { $crate::BooleanArray };
82
    (@from Int8) => { $crate::Int8Array };
83
    (@from Int16) => { $crate::Int16Array };
84
    (@from Int32) => { $crate::Int32Array };
85
    (@from Int64) => { $crate::Int64Array };
86
    (@from UInt8) => { $crate::UInt8Array };
87
    (@from UInt16) => { $crate::UInt16Array };
88
    (@from UInt32) => { $crate::UInt32Array };
89
    (@from UInt64) => { $crate::UInt64Array };
90
    (@from Float16) => { $crate::Float16Array };
91
    (@from Float32) => { $crate::Float32Array };
92
    (@from Float64) => { $crate::Float64Array };
93
    (@from Utf8) => { $crate::StringArray };
94
    (@from Utf8View) => { $crate::StringViewArray };
95
    (@from LargeUtf8) => { $crate::LargeStringArray };
96
    (@from IntervalDayTime) => { $crate::IntervalDayTimeArray };
97
    (@from IntervalYearMonth) => { $crate::IntervalYearMonthArray };
98
    (@from Second) => { $crate::TimestampSecondArray };
99
    (@from Millisecond) => { $crate::TimestampMillisecondArray };
100
    (@from Microsecond) => { $crate::TimestampMicrosecondArray };
101
    (@from Nanosecond) => { $crate::TimestampNanosecondArray };
102
    (@from Second32) => { $crate::Time32SecondArray };
103
    (@from Millisecond32) => { $crate::Time32MillisecondArray };
104
    (@from Microsecond64) => { $crate::Time64MicrosecondArray };
105
    (@from Nanosecond64) => { $crate::Time64Nanosecond64Array };
106
    (@from DurationSecond) => { $crate::DurationSecondArray };
107
    (@from DurationMillisecond) => { $crate::DurationMillisecondArray };
108
    (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray };
109
    (@from DurationNanosecond) => { $crate::DurationNanosecondArray };
110
    (@from Decimal32) => { $crate::Decimal32Array };
111
    (@from Decimal64) => { $crate::Decimal64Array };
112
    (@from Decimal128) => { $crate::Decimal128Array };
113
    (@from Decimal256) => { $crate::Decimal256Array };
114
    (@from TimestampSecond) => { $crate::TimestampSecondArray };
115
    (@from TimestampMillisecond) => { $crate::TimestampMillisecondArray };
116
    (@from TimestampMicrosecond) => { $crate::TimestampMicrosecondArray };
117
    (@from TimestampNanosecond) => { $crate::TimestampNanosecondArray };
118
119
    (@from $ty: ident) => {
120
        compile_error!(concat!("Unsupported data type: ", stringify!($ty)))
121
    };
122
123
    (Null, $size: expr) => {
124
        std::sync::Arc::new($crate::NullArray::new($size))
125
    };
126
127
    (Binary, [$($values: expr),*]) => {
128
        std::sync::Arc::new($crate::BinaryArray::from_vec(vec![$($values),*]))
129
    };
130
131
    (LargeBinary, [$($values: expr),*]) => {
132
        std::sync::Arc::new($crate::LargeBinaryArray::from_vec(vec![$($values),*]))
133
    };
134
135
    ($ty: tt, [$($values: expr),*]) => {
136
        std::sync::Arc::new(<$crate::create_array!(@from $ty)>::from(vec![$($values),*]))
137
    };
138
}
139
140
/// Creates a record batch from literal slice of values, suitable for rapid
141
/// testing and development.
142
///
143
/// Example:
144
///
145
/// ```rust
146
/// use arrow_array::record_batch;
147
/// use arrow_schema;
148
///
149
/// let batch = record_batch!(
150
///     ("a", Int32, [1, 2, 3]),
151
///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
152
///     ("c", Utf8, ["alpha", "beta", "gamma"])
153
/// );
154
/// ```
155
/// Due to limitation of [`create_array!`] macro, support for limited data types is available.
156
#[macro_export]
157
macro_rules! record_batch {
158
    ($(($name: expr, $type: ident, [$($values: expr),*])),*) => {
159
        {
160
            let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![
161
                $(
162
                    arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
163
                )*
164
            ]));
165
166
            let batch = $crate::RecordBatch::try_new(
167
                schema,
168
                vec![$(
169
                    $crate::create_array!($type, [$($values),*]),
170
                )*]
171
            );
172
173
            batch
174
        }
175
    }
176
}
177
178
/// A two-dimensional batch of column-oriented data with a defined
179
/// [schema](arrow_schema::Schema).
180
///
181
/// A `RecordBatch` is a two-dimensional dataset of a number of
182
/// contiguous arrays, each the same length.
183
/// A record batch has a schema which must match its arrays’
184
/// datatypes.
185
///
186
/// Record batches are a convenient unit of work for various
187
/// serialization and computation functions, possibly incremental.
188
///
189
/// Use the [`record_batch!`] macro to create a [`RecordBatch`] from
190
/// literal slice of values, useful for rapid prototyping and testing.
191
///
192
/// Example:
193
/// ```rust
194
/// use arrow_array::record_batch;
195
/// let batch = record_batch!(
196
///     ("a", Int32, [1, 2, 3]),
197
///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
198
///     ("c", Utf8, ["alpha", "beta", "gamma"])
199
/// );
200
/// ```
201
#[derive(Clone, Debug, PartialEq)]
202
pub struct RecordBatch {
203
    schema: SchemaRef,
204
    columns: Vec<Arc<dyn Array>>,
205
206
    /// The number of rows in this RecordBatch
207
    ///
208
    /// This is stored separately from the columns to handle the case of no columns
209
    row_count: usize,
210
}
211
212
impl RecordBatch {
213
    /// Creates a `RecordBatch` from a schema and columns.
214
    ///
215
    /// Expects the following:
216
    ///
217
    ///  * `!columns.is_empty()`
218
    ///  * `schema.fields.len() == columns.len()`
219
    ///  * `schema.fields[i].data_type() == columns[i].data_type()`
220
    ///  * `columns[i].len() == columns[j].len()`
221
    ///
222
    /// If the conditions are not met, an error is returned.
223
    ///
224
    /// # Example
225
    ///
226
    /// ```
227
    /// # use std::sync::Arc;
228
    /// # use arrow_array::{Int32Array, RecordBatch};
229
    /// # use arrow_schema::{DataType, Field, Schema};
230
    ///
231
    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
232
    /// let schema = Schema::new(vec![
233
    ///     Field::new("id", DataType::Int32, false)
234
    /// ]);
235
    ///
236
    /// let batch = RecordBatch::try_new(
237
    ///     Arc::new(schema),
238
    ///     vec![Arc::new(id_array)]
239
    /// ).unwrap();
240
    /// ```
241
128
    pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self, ArrowError> {
242
128
        let options = RecordBatchOptions::new();
243
128
        Self::try_new_impl(schema, columns, &options)
244
128
    }
245
246
    /// Creates a `RecordBatch` from a schema and columns, without validation.
247
    ///
248
    /// See [`Self::try_new`] for the checked version.
249
    ///
250
    /// # Safety
251
    ///
252
    /// Expects the following:
253
    ///
254
    ///  * `schema.fields.len() == columns.len()`
255
    ///  * `schema.fields[i].data_type() == columns[i].data_type()`
256
    ///  * `columns[i].len() == row_count`
257
    ///
258
    /// Note: if the schema does not match the underlying data exactly, it can lead to undefined
259
    /// behavior, for example, via conversion to a `StructArray`, which in turn could lead
260
    /// to incorrect access.
261
185
    pub unsafe fn new_unchecked(
262
185
        schema: SchemaRef,
263
185
        columns: Vec<Arc<dyn Array>>,
264
185
        row_count: usize,
265
185
    ) -> Self {
266
185
        Self {
267
185
            schema,
268
185
            columns,
269
185
            row_count,
270
185
        }
271
185
    }
272
273
    /// Creates a `RecordBatch` from a schema and columns, with additional options,
274
    /// such as whether to strictly validate field names.
275
    ///
276
    /// See [`RecordBatch::try_new`] for the expected conditions.
277
150
    pub fn try_new_with_options(
278
150
        schema: SchemaRef,
279
150
        columns: Vec<ArrayRef>,
280
150
        options: &RecordBatchOptions,
281
150
    ) -> Result<Self, ArrowError> {
282
150
        Self::try_new_impl(schema, columns, options)
283
150
    }
284
285
    /// Creates a new empty [`RecordBatch`].
286
3
    pub fn new_empty(schema: SchemaRef) -> Self {
287
3
        let columns = schema
288
3
            .fields()
289
3
            .iter()
290
3
            .map(|field| new_empty_array(field.data_type()))
291
3
            .collect();
292
293
3
        RecordBatch {
294
3
            schema,
295
3
            columns,
296
3
            row_count: 0,
297
3
        }
298
3
    }
299
300
    /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error
301
    /// if any validation check fails, otherwise returns the created [`Self`]
302
278
    fn try_new_impl(
303
278
        schema: SchemaRef,
304
278
        columns: Vec<ArrayRef>,
305
278
        options: &RecordBatchOptions,
306
278
    ) -> Result<Self, ArrowError> {
307
        // check that number of fields in schema match column length
308
278
        if schema.fields().len() != columns.len() {
309
0
            return Err(ArrowError::InvalidArgumentError(format!(
310
0
                "number of columns({}) must match number of fields({}) in schema",
311
0
                columns.len(),
312
0
                schema.fields().len(),
313
0
            )));
314
278
        }
315
316
278
        let row_count = options
317
278
            .row_count
318
278
            .or_else(|| 
columns.first()128
.
map128
(|col|
col128
.
len128
()))
319
278
            .ok_or_else(|| 
{0
320
0
                ArrowError::InvalidArgumentError(
321
0
                    "must either specify a row count or at least one column".to_string(),
322
0
                )
323
0
            })?;
324
325
572
        for (c, f) in 
columns.iter()278
.
zip278
(
&schema.fields278
) {
326
572
            if !f.is_nullable() && 
c.null_count() > 060
{
327
0
                return Err(ArrowError::InvalidArgumentError(format!(
328
0
                    "Column '{}' is declared as non-nullable but contains null values",
329
0
                    f.name()
330
0
                )));
331
572
            }
332
        }
333
334
        // check that all columns have the same row count
335
572
        if 
columns.iter()278
.
any278
(|c| c.len() != row_count) {
336
0
            let err = match options.row_count {
337
0
                Some(_) => "all columns in a record batch must have the specified row count",
338
0
                None => "all columns in a record batch must have the same length",
339
            };
340
0
            return Err(ArrowError::InvalidArgumentError(err.to_string()));
341
278
        }
342
343
        // function for comparing column type and field type
344
        // return true if 2 types are not matched
345
278
        let type_not_match = if options.match_field_names {
346
572
            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
347
        } else {
348
0
            |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
349
0
                !col_type.equals_datatype(field_type)
350
0
            }
351
        };
352
353
        // check that all columns match the schema
354
278
        let not_match = columns
355
278
            .iter()
356
278
            .zip(schema.fields().iter())
357
572
            .
map278
(|(col, field)| (col.data_type(), field.data_type()))
358
278
            .enumerate()
359
278
            .find(type_not_match);
360
361
278
        if let Some((
i0
, (
col_type0
,
field_type0
))) = not_match {
362
0
            return Err(ArrowError::InvalidArgumentError(format!(
363
0
                "column types must match schema types, expected {field_type} but found {col_type} at column index {i}"
364
0
            )));
365
278
        }
366
367
278
        Ok(RecordBatch {
368
278
            schema,
369
278
            columns,
370
278
            row_count,
371
278
        })
372
278
    }
373
374
    /// Return the schema, columns and row count of this [`RecordBatch`]
375
372
    pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
376
372
        (self.schema, self.columns, self.row_count)
377
372
    }
378
379
    /// Override the schema of this [`RecordBatch`]
380
    ///
381
    /// Returns an error if `schema` is not a superset of the current schema
382
    /// as determined by [`Schema::contains`]
383
    ///
384
    /// See also [`Self::schema_metadata_mut`].
385
0
    pub fn with_schema(self, schema: SchemaRef) -> Result<Self, ArrowError> {
386
0
        if !schema.contains(self.schema.as_ref()) {
387
0
            return Err(ArrowError::SchemaError(format!(
388
0
                "target schema is not superset of current schema target={schema} current={}",
389
0
                self.schema
390
0
            )));
391
0
        }
392
393
0
        Ok(Self {
394
0
            schema,
395
0
            columns: self.columns,
396
0
            row_count: self.row_count,
397
0
        })
398
0
    }
399
400
    /// Returns the [`Schema`] of the record batch.
401
145
    pub fn schema(&self) -> SchemaRef {
402
145
        self.schema.clone()
403
145
    }
404
405
    /// Returns a reference to the [`Schema`] of the record batch.
406
1
    pub fn schema_ref(&self) -> &SchemaRef {
407
1
        &self.schema
408
1
    }
409
410
    /// Mutable access to the metadata of the schema.
411
    ///
412
    /// This allows you to modify [`Schema::metadata`] of [`Self::schema`] in a convenient and fast way.
413
    ///
414
    /// Note this will clone the entire underlying `Schema` object if it is currently shared
415
    ///
416
    /// # Example
417
    /// ```
418
    /// # use std::sync::Arc;
419
    /// # use arrow_array::{record_batch, RecordBatch};
420
    /// let mut batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
421
    /// // Initially, the metadata is empty
422
    /// assert!(batch.schema().metadata().get("key").is_none());
423
    /// // Insert a key-value pair into the metadata
424
    /// batch.schema_metadata_mut().insert("key".into(), "value".into());
425
    /// assert_eq!(batch.schema().metadata().get("key"), Some(&String::from("value")));
426
    /// ```
427
0
    pub fn schema_metadata_mut(&mut self) -> &mut std::collections::HashMap<String, String> {
428
0
        let schema = Arc::make_mut(&mut self.schema);
429
0
        &mut schema.metadata
430
0
    }
431
432
    /// Projects the schema onto the specified columns
433
0
    pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
434
0
        let projected_schema = self.schema.project(indices)?;
435
0
        let batch_fields = indices
436
0
            .iter()
437
0
            .map(|f| {
438
0
                self.columns.get(*f).cloned().ok_or_else(|| {
439
0
                    ArrowError::SchemaError(format!(
440
0
                        "project index {} out of bounds, max field {}",
441
0
                        f,
442
0
                        self.columns.len()
443
0
                    ))
444
0
                })
445
0
            })
446
0
            .collect::<Result<Vec<_>, _>>()?;
447
448
        unsafe {
449
            // Since we're starting from a valid RecordBatch and project
450
            // creates a strict subset of the original, there's no need to
451
            // redo the validation checks in `try_new_with_options`.
452
0
            Ok(RecordBatch::new_unchecked(
453
0
                SchemaRef::new(projected_schema),
454
0
                batch_fields,
455
0
                self.row_count,
456
0
            ))
457
        }
458
0
    }
459
460
    /// Normalize a semi-structured [`RecordBatch`] into a flat table.
461
    ///
462
    /// Nested [`Field`]s will generate names separated by `separator`, up to a depth of `max_level`
463
    /// (unlimited if `None`).
464
    ///
465
    /// e.g. given a [`RecordBatch`] with schema:
466
    ///
467
    /// ```text
468
    ///     "foo": StructArray<"bar": Utf8>
469
    /// ```
470
    ///
471
    /// A separator of `"."` would generate a batch with the schema:
472
    ///
473
    /// ```text
474
    ///     "foo.bar": Utf8
475
    /// ```
476
    ///
477
    /// Note that giving a depth of `Some(0)` to `max_level` is the same as passing in `None`;
478
    /// it will be treated as unlimited.
479
    ///
480
    /// # Example
481
    ///
482
    /// ```
483
    /// # use std::sync::Arc;
484
    /// # use arrow_array::{ArrayRef, Int64Array, StringArray, StructArray, RecordBatch};
485
    /// # use arrow_schema::{DataType, Field, Fields, Schema};
486
    /// #
487
    /// let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
488
    /// let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
489
    ///
490
    /// let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
491
    /// let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
492
    ///
493
    /// let a = Arc::new(StructArray::from(vec![
494
    ///     (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
495
    ///     (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
496
    /// ]));
497
    ///
498
    /// let schema = Schema::new(vec![
499
    ///     Field::new(
500
    ///         "a",
501
    ///         DataType::Struct(Fields::from(vec![animals_field, n_legs_field])),
502
    ///         false,
503
    ///     )
504
    /// ]);
505
    ///
506
    /// let normalized = RecordBatch::try_new(Arc::new(schema), vec![a])
507
    ///     .expect("valid conversion")
508
    ///     .normalize(".", None)
509
    ///     .expect("valid normalization");
510
    ///
511
    /// let expected = RecordBatch::try_from_iter_with_nullable(vec![
512
    ///     ("a.animals", animals.clone(), true),
513
    ///     ("a.n_legs", n_legs.clone(), true),
514
    /// ])
515
    /// .expect("valid conversion");
516
    ///
517
    /// assert_eq!(expected, normalized);
518
    /// ```
519
0
    pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
520
0
        let max_level = match max_level.unwrap_or(usize::MAX) {
521
0
            0 => usize::MAX,
522
0
            val => val,
523
        };
524
0
        let mut stack: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
525
0
            .columns
526
0
            .iter()
527
0
            .zip(self.schema.fields())
528
0
            .rev()
529
0
            .map(|(c, f)| {
530
0
                let name_vec: Vec<&str> = vec![f.name()];
531
0
                (0, c, name_vec, f)
532
0
            })
533
0
            .collect();
534
0
        let mut columns: Vec<ArrayRef> = Vec::new();
535
0
        let mut fields: Vec<FieldRef> = Vec::new();
536
537
0
        while let Some((depth, c, name, field_ref)) = stack.pop() {
538
0
            match field_ref.data_type() {
539
0
                DataType::Struct(ff) if depth < max_level => {
540
                    // Need to zip these in reverse to maintain original order
541
0
                    for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
542
0
                        let mut name = name.clone();
543
0
                        name.push(separator);
544
0
                        name.push(fff.name());
545
0
                        stack.push((depth + 1, cff, name, fff))
546
                    }
547
                }
548
0
                _ => {
549
0
                    let updated_field = Field::new(
550
0
                        name.concat(),
551
0
                        field_ref.data_type().clone(),
552
0
                        field_ref.is_nullable(),
553
0
                    );
554
0
                    columns.push(c.clone());
555
0
                    fields.push(Arc::new(updated_field));
556
0
                }
557
            }
558
        }
559
0
        RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
560
0
    }
561
562
    /// Returns the number of columns in the record batch.
563
    ///
564
    /// # Example
565
    ///
566
    /// ```
567
    /// # use std::sync::Arc;
568
    /// # use arrow_array::{Int32Array, RecordBatch};
569
    /// # use arrow_schema::{DataType, Field, Schema};
570
    ///
571
    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
572
    /// let schema = Schema::new(vec![
573
    ///     Field::new("id", DataType::Int32, false)
574
    /// ]);
575
    ///
576
    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
577
    ///
578
    /// assert_eq!(batch.num_columns(), 1);
579
    /// ```
580
84
    pub fn num_columns(&self) -> usize {
581
84
        self.columns.len()
582
84
    }
583
584
    /// Returns the number of rows in each column.
585
    ///
586
    /// # Example
587
    ///
588
    /// ```
589
    /// # use std::sync::Arc;
590
    /// # use arrow_array::{Int32Array, RecordBatch};
591
    /// # use arrow_schema::{DataType, Field, Schema};
592
    ///
593
    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
594
    /// let schema = Schema::new(vec![
595
    ///     Field::new("id", DataType::Int32, false)
596
    /// ]);
597
    ///
598
    /// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array)]).unwrap();
599
    ///
600
    /// assert_eq!(batch.num_rows(), 5);
601
    /// ```
602
473
    pub fn num_rows(&self) -> usize {
603
473
        self.row_count
604
473
    }
605
606
    /// Get a reference to a column's array by index.
607
    ///
608
    /// # Panics
609
    ///
610
    /// Panics if `index` is outside of `0..num_columns`.
611
341
    pub fn column(&self, index: usize) -> &ArrayRef {
612
341
        &self.columns[index]
613
341
    }
614
615
    /// Get a reference to a column's array by name.
616
13
    pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
617
13
        self.schema()
618
13
            .column_with_name(name)
619
13
            .map(|(index, _)| &self.columns[index])
620
13
    }
621
622
    /// Get a reference to all columns in the record batch.
623
159
    pub fn columns(&self) -> &[ArrayRef] {
624
159
        &self.columns[..]
625
159
    }
626
627
    /// Remove column by index and return it.
628
    ///
629
    /// Return the `ArrayRef` if the column is removed.
630
    ///
631
    /// # Panics
632
    ///
633
    /// Panics if `index` is out of bounds.
634
    ///
635
    /// # Example
636
    ///
637
    /// ```
638
    /// use std::sync::Arc;
639
    /// use arrow_array::{BooleanArray, Int32Array, RecordBatch};
640
    /// use arrow_schema::{DataType, Field, Schema};
641
    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
642
    /// let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
643
    /// let schema = Schema::new(vec![
644
    ///     Field::new("id", DataType::Int32, false),
645
    ///     Field::new("bool", DataType::Boolean, false),
646
    /// ]);
647
    ///
648
    /// let mut batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array), Arc::new(bool_array)]).unwrap();
649
    ///
650
    /// let removed_column = batch.remove_column(0);
651
    /// assert_eq!(removed_column.as_any().downcast_ref::<Int32Array>().unwrap(), &Int32Array::from(vec![1, 2, 3, 4, 5]));
652
    /// assert_eq!(batch.num_columns(), 1);
653
    /// ```
654
0
    pub fn remove_column(&mut self, index: usize) -> ArrayRef {
655
0
        let mut builder = SchemaBuilder::from(self.schema.as_ref());
656
0
        builder.remove(index);
657
0
        self.schema = Arc::new(builder.finish());
658
0
        self.columns.remove(index)
659
0
    }
660
661
    /// Return a new RecordBatch where each column is sliced
662
    /// according to `offset` and `length`
663
    ///
664
    /// # Panics
665
    ///
666
    /// Panics if `offset` with `length` is greater than column length.
667
77
    pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
668
77
        assert!((offset + length) <= self.num_rows());
669
670
77
        let columns = self
671
77
            .columns()
672
77
            .iter()
673
158
            .
map77
(|column| column.slice(offset, length))
674
77
            .collect();
675
676
77
        Self {
677
77
            schema: self.schema.clone(),
678
77
            columns,
679
77
            row_count: length,
680
77
        }
681
77
    }
682
683
    /// Create a `RecordBatch` from an iterable list of pairs of the
684
    /// form `(field_name, array)`, with the same requirements on
685
    /// fields and arrays as [`RecordBatch::try_new`]. This method is
686
    /// often used to create a single `RecordBatch` from arrays,
687
    /// e.g. for testing.
688
    ///
689
    /// The resulting schema is marked as nullable for each column if
690
    /// the array for that column has any nulls. To explicitly
691
    /// specify nullability, use [`RecordBatch::try_from_iter_with_nullable`]
692
    ///
693
    /// Example:
694
    /// ```
695
    /// # use std::sync::Arc;
696
    /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
697
    ///
698
    /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
699
    /// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
700
    ///
701
    /// let record_batch = RecordBatch::try_from_iter(vec![
702
    ///   ("a", a),
703
    ///   ("b", b),
704
    /// ]);
705
    /// ```
706
    /// Another way to quickly create a [`RecordBatch`] is to use the [`record_batch!`] macro,
707
    /// which is particularly helpful for rapid prototyping and testing.
708
    ///
709
    /// Example:
710
    ///
711
    /// ```rust
712
    /// use arrow_array::record_batch;
713
    /// let batch = record_batch!(
714
    ///     ("a", Int32, [1, 2, 3]),
715
    ///     ("b", Float64, [Some(4.0), None, Some(5.0)]),
716
    ///     ("c", Utf8, ["alpha", "beta", "gamma"])
717
    /// );
718
    /// ```
719
41
    pub fn try_from_iter<I, F>(value: I) -> Result<Self, ArrowError>
720
41
    where
721
41
        I: IntoIterator<Item = (F, ArrayRef)>,
722
41
        F: AsRef<str>,
723
    {
724
        // TODO: implement `TryFrom` trait, once
725
        // https://github.com/rust-lang/rust/issues/50133 is no longer an
726
        // issue
727
161
        let 
iter41
=
value41
.
into_iter41
().
map41
(|(field_name, array)| {
728
161
            let nullable = array.null_count() > 0;
729
161
            (field_name, array, nullable)
730
161
        });
731
732
41
        Self::try_from_iter_with_nullable(iter)
733
41
    }
734
735
    /// Create a `RecordBatch` from an iterable list of tuples of the
736
    /// form `(field_name, array, nullable)`, with the same requirements on
737
    /// fields and arrays as [`RecordBatch::try_new`]. This method is often
738
    /// used to create a single `RecordBatch` from arrays, e.g. for
739
    /// testing.
740
    ///
741
    /// Example:
742
    /// ```
743
    /// # use std::sync::Arc;
744
    /// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
745
    ///
746
    /// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
747
    /// let b: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b")]));
748
    ///
749
    /// // Note neither `a` nor `b` has any actual nulls, but we mark
750
    /// // b as nullable
751
    /// let record_batch = RecordBatch::try_from_iter_with_nullable(vec![
752
    ///   ("a", a, false),
753
    ///   ("b", b, true),
754
    /// ]);
755
    /// ```
756
41
    pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self, ArrowError>
757
41
    where
758
41
        I: IntoIterator<Item = (F, ArrayRef, bool)>,
759
41
        F: AsRef<str>,
760
    {
761
41
        let iter = value.into_iter();
762
41
        let capacity = iter.size_hint().0;
763
41
        let mut schema = SchemaBuilder::with_capacity(capacity);
764
41
        let mut columns = Vec::with_capacity(capacity);
765
766
202
        for (
field_name161
,
array161
,
nullable161
) in iter {
767
161
            let field_name = field_name.as_ref();
768
161
            schema.push(Field::new(field_name, array.data_type().clone(), nullable));
769
161
            columns.push(array);
770
161
        }
771
772
41
        let schema = Arc::new(schema.finish());
773
41
        RecordBatch::try_new(schema, columns)
774
41
    }
775
776
    /// Returns the total number of bytes of memory occupied physically by this batch.
777
    ///
778
    /// Note that this does not always correspond to the exact memory usage of a
779
    /// `RecordBatch` (might overestimate), since multiple columns can share the same
780
    /// buffers or slices thereof, the memory used by the shared buffers might be
781
    /// counted multiple times.
782
0
    pub fn get_array_memory_size(&self) -> usize {
783
0
        self.columns()
784
0
            .iter()
785
0
            .map(|array| array.get_array_memory_size())
786
0
            .sum()
787
0
    }
788
}
789
790
/// Options that control the behaviour used when creating a [`RecordBatch`].
///
/// Construct via [`RecordBatchOptions::new`] or [`Default`] and customize with
/// the `with_*` builder methods; the struct is `#[non_exhaustive]` so new
/// options may be added without a breaking change.
#[derive(Debug)]
#[non_exhaustive]
pub struct RecordBatchOptions {
    /// Match field names of structs and lists. If set to `true`, the names must match.
    pub match_field_names: bool,

    /// Optional row count, useful for specifying a row count for a RecordBatch with no columns
    pub row_count: Option<usize>,
}
800
801
impl RecordBatchOptions {
802
    /// Creates a new `RecordBatchOptions`
803
278
    pub fn new() -> Self {
804
278
        Self {
805
278
            match_field_names: true,
806
278
            row_count: None,
807
278
        }
808
278
    }
809
    /// Sets the row_count of RecordBatchOptions and returns self
810
147
    pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
811
147
        self.row_count = row_count;
812
147
        self
813
147
    }
814
    /// Sets the match_field_names of RecordBatchOptions and returns self
815
0
    pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
816
0
        self.match_field_names = match_field_names;
817
0
        self
818
0
    }
819
}
820
impl Default for RecordBatchOptions {
821
4
    fn default() -> Self {
822
4
        Self::new()
823
4
    }
824
}
825
impl From<StructArray> for RecordBatch {
826
0
    fn from(value: StructArray) -> Self {
827
0
        let row_count = value.len();
828
0
        let (fields, columns, nulls) = value.into_parts();
829
0
        assert_eq!(
830
0
            nulls.map(|n| n.null_count()).unwrap_or_default(),
831
            0,
832
0
            "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
833
        );
834
835
0
        RecordBatch {
836
0
            schema: Arc::new(Schema::new(fields)),
837
0
            row_count,
838
0
            columns,
839
0
        }
840
0
    }
841
}
842
843
impl From<&StructArray> for RecordBatch {
844
0
    fn from(struct_array: &StructArray) -> Self {
845
0
        struct_array.clone().into()
846
0
    }
847
}
848
849
impl Index<&str> for RecordBatch {
850
    type Output = ArrayRef;
851
852
    /// Get a reference to a column's array by name.
853
    ///
854
    /// # Panics
855
    ///
856
    /// Panics if the name is not in the schema.
857
0
    fn index(&self, name: &str) -> &Self::Output {
858
0
        self.column_by_name(name).unwrap()
859
0
    }
860
}
861
862
/// Generic implementation of [RecordBatchReader] that wraps an iterator.
///
/// # Example
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, RecordBatchIterator, RecordBatchReader};
/// #
/// let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
/// let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
///
/// let record_batch = RecordBatch::try_from_iter(vec![
///   ("a", a),
///   ("b", b),
/// ]).unwrap();
///
/// let batches: Vec<RecordBatch> = vec![record_batch.clone(), record_batch.clone()];
///
/// let mut reader = RecordBatchIterator::new(batches.into_iter().map(Ok), record_batch.schema());
///
/// assert_eq!(reader.schema(), record_batch.schema());
/// assert_eq!(reader.next().unwrap().unwrap(), record_batch);
/// # assert_eq!(reader.next().unwrap().unwrap(), record_batch);
/// # assert!(reader.next().is_none());
/// ```
pub struct RecordBatchIterator<I>
where
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
{
    /// The wrapped iterator of fallible batches; `next()` forwards to it.
    inner: I::IntoIter,
    /// Schema reported via [`RecordBatchReader::schema`]. NOTE(review): yielded
    /// batches are expected to match this schema, but it is not verified here.
    inner_schema: SchemaRef,
}
894
895
impl<I> RecordBatchIterator<I>
896
where
897
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
898
{
899
    /// Create a new [RecordBatchIterator].
900
    ///
901
    /// If `iter` is an infallible iterator, use `.map(Ok)`.
902
    pub fn new(iter: I, schema: SchemaRef) -> Self {
903
        Self {
904
            inner: iter.into_iter(),
905
            inner_schema: schema,
906
        }
907
    }
908
}
909
910
impl<I> Iterator for RecordBatchIterator<I>
911
where
912
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
913
{
914
    type Item = I::Item;
915
916
    fn next(&mut self) -> Option<Self::Item> {
917
        self.inner.next()
918
    }
919
920
    fn size_hint(&self) -> (usize, Option<usize>) {
921
        self.inner.size_hint()
922
    }
923
}
924
925
impl<I> RecordBatchReader for RecordBatchIterator<I>
926
where
927
    I: IntoIterator<Item = Result<RecordBatch, ArrowError>>,
928
{
929
    fn schema(&self) -> SchemaRef {
930
        self.inner_schema.clone()
931
    }
932
}
933
934
#[cfg(test)]
935
mod tests {
936
    use super::*;
937
    use crate::{
938
        BooleanArray, Int8Array, Int32Array, Int64Array, ListArray, StringArray, StringViewArray,
939
    };
940
    use arrow_buffer::{Buffer, ToByteSlice};
941
    use arrow_data::{ArrayData, ArrayDataBuilder};
942
    use arrow_schema::Fields;
943
    use std::collections::HashMap;
944
945
    #[test]
946
    fn create_record_batch() {
947
        let schema = Schema::new(vec![
948
            Field::new("a", DataType::Int32, false),
949
            Field::new("b", DataType::Utf8, false),
950
        ]);
951
952
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
953
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
954
955
        let record_batch =
956
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
957
        check_batch(record_batch, 5)
958
    }
959
960
    #[test]
961
    fn create_string_view_record_batch() {
962
        let schema = Schema::new(vec![
963
            Field::new("a", DataType::Int32, false),
964
            Field::new("b", DataType::Utf8View, false),
965
        ]);
966
967
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
968
        let b = StringViewArray::from(vec!["a", "b", "c", "d", "e"]);
969
970
        let record_batch =
971
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
972
973
        assert_eq!(5, record_batch.num_rows());
974
        assert_eq!(2, record_batch.num_columns());
975
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
976
        assert_eq!(
977
            &DataType::Utf8View,
978
            record_batch.schema().field(1).data_type()
979
        );
980
        assert_eq!(5, record_batch.column(0).len());
981
        assert_eq!(5, record_batch.column(1).len());
982
    }
983
984
    #[test]
985
    fn byte_size_should_not_regress() {
986
        let schema = Schema::new(vec![
987
            Field::new("a", DataType::Int32, false),
988
            Field::new("b", DataType::Utf8, false),
989
        ]);
990
991
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
992
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
993
994
        let record_batch =
995
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
996
        assert_eq!(record_batch.get_array_memory_size(), 364);
997
    }
998
999
    fn check_batch(record_batch: RecordBatch, num_rows: usize) {
1000
        assert_eq!(num_rows, record_batch.num_rows());
1001
        assert_eq!(2, record_batch.num_columns());
1002
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
1003
        assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
1004
        assert_eq!(num_rows, record_batch.column(0).len());
1005
        assert_eq!(num_rows, record_batch.column(1).len());
1006
    }
1007
1008
    #[test]
1009
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1010
    fn create_record_batch_slice() {
1011
        let schema = Schema::new(vec![
1012
            Field::new("a", DataType::Int32, false),
1013
            Field::new("b", DataType::Utf8, false),
1014
        ]);
1015
        let expected_schema = schema.clone();
1016
1017
        let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
1018
        let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);
1019
1020
        let record_batch =
1021
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]).unwrap();
1022
1023
        let offset = 2;
1024
        let length = 5;
1025
        let record_batch_slice = record_batch.slice(offset, length);
1026
1027
        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1028
        check_batch(record_batch_slice, 5);
1029
1030
        let offset = 2;
1031
        let length = 0;
1032
        let record_batch_slice = record_batch.slice(offset, length);
1033
1034
        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
1035
        check_batch(record_batch_slice, 0);
1036
1037
        let offset = 2;
1038
        let length = 10;
1039
        let _record_batch_slice = record_batch.slice(offset, length);
1040
    }
1041
1042
    #[test]
1043
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
1044
    fn create_record_batch_slice_empty_batch() {
1045
        let schema = Schema::empty();
1046
1047
        let record_batch = RecordBatch::new_empty(Arc::new(schema));
1048
1049
        let offset = 0;
1050
        let length = 0;
1051
        let record_batch_slice = record_batch.slice(offset, length);
1052
        assert_eq!(0, record_batch_slice.schema().fields().len());
1053
1054
        let offset = 1;
1055
        let length = 2;
1056
        let _record_batch_slice = record_batch.slice(offset, length);
1057
    }
1058
1059
    #[test]
1060
    fn create_record_batch_try_from_iter() {
1061
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
1062
            Some(1),
1063
            Some(2),
1064
            None,
1065
            Some(4),
1066
            Some(5),
1067
        ]));
1068
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1069
1070
        let record_batch =
1071
            RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).expect("valid conversion");
1072
1073
        let expected_schema = Schema::new(vec![
1074
            Field::new("a", DataType::Int32, true),
1075
            Field::new("b", DataType::Utf8, false),
1076
        ]);
1077
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1078
        check_batch(record_batch, 5);
1079
    }
1080
1081
    #[test]
1082
    fn create_record_batch_try_from_iter_with_nullable() {
1083
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
1084
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1085
1086
        // Note there are no nulls in a or b, but we specify that b is nullable
1087
        let record_batch =
1088
            RecordBatch::try_from_iter_with_nullable(vec![("a", a, false), ("b", b, true)])
1089
                .expect("valid conversion");
1090
1091
        let expected_schema = Schema::new(vec![
1092
            Field::new("a", DataType::Int32, false),
1093
            Field::new("b", DataType::Utf8, true),
1094
        ]);
1095
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
1096
        check_batch(record_batch, 5);
1097
    }
1098
1099
    #[test]
1100
    fn create_record_batch_schema_mismatch() {
1101
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1102
1103
        let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
1104
1105
        let err = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap_err();
1106
        assert_eq!(
1107
            err.to_string(),
1108
            "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0"
1109
        );
1110
    }
1111
1112
    #[test]
1113
    fn create_record_batch_field_name_mismatch() {
1114
        let fields = vec![
1115
            Field::new("a1", DataType::Int32, false),
1116
            Field::new_list("a2", Field::new_list_field(DataType::Int8, false), false),
1117
        ];
1118
        let schema = Arc::new(Schema::new(vec![Field::new_struct("a", fields, true)]));
1119
1120
        let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1121
        let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
1122
        let a2 = ArrayDataBuilder::new(DataType::List(Arc::new(Field::new(
1123
            "array",
1124
            DataType::Int8,
1125
            false,
1126
        ))))
1127
        .add_child_data(a2_child.into_data())
1128
        .len(2)
1129
        .add_buffer(Buffer::from([0i32, 3, 4].to_byte_slice()))
1130
        .build()
1131
        .unwrap();
1132
        let a2: ArrayRef = Arc::new(ListArray::from(a2));
1133
        let a = ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![
1134
            Field::new("aa1", DataType::Int32, false),
1135
            Field::new("a2", a2.data_type().clone(), false),
1136
        ])))
1137
        .add_child_data(a1.into_data())
1138
        .add_child_data(a2.into_data())
1139
        .len(2)
1140
        .build()
1141
        .unwrap();
1142
        let a: ArrayRef = Arc::new(StructArray::from(a));
1143
1144
        // creating the batch with field name validation should fail
1145
        let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
1146
        assert!(batch.is_err());
1147
1148
        // creating the batch without field name validation should pass
1149
        let options = RecordBatchOptions {
1150
            match_field_names: false,
1151
            row_count: None,
1152
        };
1153
        let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
1154
        assert!(batch.is_ok());
1155
    }
1156
1157
    #[test]
1158
    fn create_record_batch_record_mismatch() {
1159
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
1160
1161
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1162
        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
1163
1164
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
1165
        assert!(batch.is_err());
1166
    }
1167
1168
    #[test]
1169
    fn create_record_batch_from_struct_array() {
1170
        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
1171
        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
1172
        let struct_array = StructArray::from(vec![
1173
            (
1174
                Arc::new(Field::new("b", DataType::Boolean, false)),
1175
                boolean.clone() as ArrayRef,
1176
            ),
1177
            (
1178
                Arc::new(Field::new("c", DataType::Int32, false)),
1179
                int.clone() as ArrayRef,
1180
            ),
1181
        ]);
1182
1183
        let batch = RecordBatch::from(&struct_array);
1184
        assert_eq!(2, batch.num_columns());
1185
        assert_eq!(4, batch.num_rows());
1186
        assert_eq!(
1187
            struct_array.data_type(),
1188
            &DataType::Struct(batch.schema().fields().clone())
1189
        );
1190
        assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
1191
        assert_eq!(batch.column(1).as_ref(), int.as_ref());
1192
    }
1193
1194
    #[test]
1195
    fn record_batch_equality() {
1196
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1197
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1198
        let schema1 = Schema::new(vec![
1199
            Field::new("id", DataType::Int32, false),
1200
            Field::new("val", DataType::Int32, false),
1201
        ]);
1202
1203
        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1204
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1205
        let schema2 = Schema::new(vec![
1206
            Field::new("id", DataType::Int32, false),
1207
            Field::new("val", DataType::Int32, false),
1208
        ]);
1209
1210
        let batch1 = RecordBatch::try_new(
1211
            Arc::new(schema1),
1212
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1213
        )
1214
        .unwrap();
1215
1216
        let batch2 = RecordBatch::try_new(
1217
            Arc::new(schema2),
1218
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1219
        )
1220
        .unwrap();
1221
1222
        assert_eq!(batch1, batch2);
1223
    }
1224
1225
    /// validates if the record batch can be accessed using `column_name` as index i.e. `record_batch["column_name"]`
1226
    #[test]
1227
    fn record_batch_index_access() {
1228
        let id_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
1229
        let val_arr = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
1230
        let schema1 = Schema::new(vec![
1231
            Field::new("id", DataType::Int32, false),
1232
            Field::new("val", DataType::Int32, false),
1233
        ]);
1234
        let record_batch =
1235
            RecordBatch::try_new(Arc::new(schema1), vec![id_arr.clone(), val_arr.clone()]).unwrap();
1236
1237
        assert_eq!(record_batch["id"].as_ref(), id_arr.as_ref());
1238
        assert_eq!(record_batch["val"].as_ref(), val_arr.as_ref());
1239
    }
1240
1241
    #[test]
1242
    fn record_batch_vals_ne() {
1243
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1244
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1245
        let schema1 = Schema::new(vec![
1246
            Field::new("id", DataType::Int32, false),
1247
            Field::new("val", DataType::Int32, false),
1248
        ]);
1249
1250
        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1251
        let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1252
        let schema2 = Schema::new(vec![
1253
            Field::new("id", DataType::Int32, false),
1254
            Field::new("val", DataType::Int32, false),
1255
        ]);
1256
1257
        let batch1 = RecordBatch::try_new(
1258
            Arc::new(schema1),
1259
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1260
        )
1261
        .unwrap();
1262
1263
        let batch2 = RecordBatch::try_new(
1264
            Arc::new(schema2),
1265
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1266
        )
1267
        .unwrap();
1268
1269
        assert_ne!(batch1, batch2);
1270
    }
1271
1272
    #[test]
1273
    fn record_batch_column_names_ne() {
1274
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1275
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1276
        let schema1 = Schema::new(vec![
1277
            Field::new("id", DataType::Int32, false),
1278
            Field::new("val", DataType::Int32, false),
1279
        ]);
1280
1281
        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1282
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1283
        let schema2 = Schema::new(vec![
1284
            Field::new("id", DataType::Int32, false),
1285
            Field::new("num", DataType::Int32, false),
1286
        ]);
1287
1288
        let batch1 = RecordBatch::try_new(
1289
            Arc::new(schema1),
1290
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1291
        )
1292
        .unwrap();
1293
1294
        let batch2 = RecordBatch::try_new(
1295
            Arc::new(schema2),
1296
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1297
        )
1298
        .unwrap();
1299
1300
        assert_ne!(batch1, batch2);
1301
    }
1302
1303
    #[test]
1304
    fn record_batch_column_number_ne() {
1305
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
1306
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
1307
        let schema1 = Schema::new(vec![
1308
            Field::new("id", DataType::Int32, false),
1309
            Field::new("val", DataType::Int32, false),
1310
        ]);
1311
1312
        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1313
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1314
        let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1315
        let schema2 = Schema::new(vec![
1316
            Field::new("id", DataType::Int32, false),
1317
            Field::new("val", DataType::Int32, false),
1318
            Field::new("num", DataType::Int32, false),
1319
        ]);
1320
1321
        let batch1 = RecordBatch::try_new(
1322
            Arc::new(schema1),
1323
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1324
        )
1325
        .unwrap();
1326
1327
        let batch2 = RecordBatch::try_new(
1328
            Arc::new(schema2),
1329
            vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
1330
        )
1331
        .unwrap();
1332
1333
        assert_ne!(batch1, batch2);
1334
    }
1335
1336
    #[test]
1337
    fn record_batch_row_count_ne() {
1338
        let id_arr1 = Int32Array::from(vec![1, 2, 3]);
1339
        let val_arr1 = Int32Array::from(vec![5, 6, 7]);
1340
        let schema1 = Schema::new(vec![
1341
            Field::new("id", DataType::Int32, false),
1342
            Field::new("val", DataType::Int32, false),
1343
        ]);
1344
1345
        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
1346
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
1347
        let schema2 = Schema::new(vec![
1348
            Field::new("id", DataType::Int32, false),
1349
            Field::new("num", DataType::Int32, false),
1350
        ]);
1351
1352
        let batch1 = RecordBatch::try_new(
1353
            Arc::new(schema1),
1354
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
1355
        )
1356
        .unwrap();
1357
1358
        let batch2 = RecordBatch::try_new(
1359
            Arc::new(schema2),
1360
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
1361
        )
1362
        .unwrap();
1363
1364
        assert_ne!(batch1, batch2);
1365
    }
1366
1367
    #[test]
1368
    fn normalize_simple() {
1369
        let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
1370
        let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
1371
        let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));
1372
1373
        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1374
        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1375
        let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1376
1377
        let a = Arc::new(StructArray::from(vec![
1378
            (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
1379
            (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
1380
            (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
1381
        ]));
1382
1383
        let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));
1384
1385
        let schema = Schema::new(vec![
1386
            Field::new(
1387
                "a",
1388
                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1389
                false,
1390
            ),
1391
            Field::new("month", DataType::Int64, true),
1392
        ]);
1393
1394
        let normalized =
1395
            RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone(), month.clone()])
1396
                .expect("valid conversion")
1397
                .normalize(".", Some(0))
1398
                .expect("valid normalization");
1399
1400
        let expected = RecordBatch::try_from_iter_with_nullable(vec![
1401
            ("a.animals", animals.clone(), true),
1402
            ("a.n_legs", n_legs.clone(), true),
1403
            ("a.year", year.clone(), true),
1404
            ("month", month.clone(), true),
1405
        ])
1406
        .expect("valid conversion");
1407
1408
        assert_eq!(expected, normalized);
1409
1410
        // check 0 and None have the same effect
1411
        let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
1412
            .expect("valid conversion")
1413
            .normalize(".", None)
1414
            .expect("valid normalization");
1415
1416
        assert_eq!(expected, normalized);
1417
    }
1418
1419
    #[test]
1420
    fn normalize_nested() {
1421
        // Initialize schema
1422
        let a = Arc::new(Field::new("a", DataType::Int64, true));
1423
        let b = Arc::new(Field::new("b", DataType::Int64, false));
1424
        let c = Arc::new(Field::new("c", DataType::Int64, true));
1425
1426
        let one = Arc::new(Field::new(
1427
            "1",
1428
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1429
            false,
1430
        ));
1431
        let two = Arc::new(Field::new(
1432
            "2",
1433
            DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
1434
            true,
1435
        ));
1436
1437
        let exclamation = Arc::new(Field::new(
1438
            "!",
1439
            DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
1440
            false,
1441
        ));
1442
1443
        let schema = Schema::new(vec![exclamation.clone()]);
1444
1445
        // Initialize fields
1446
        let a_field = Int64Array::from(vec![Some(0), Some(1)]);
1447
        let b_field = Int64Array::from(vec![Some(2), Some(3)]);
1448
        let c_field = Int64Array::from(vec![None, Some(4)]);
1449
1450
        let one_field = StructArray::from(vec![
1451
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1452
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1453
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1454
        ]);
1455
        let two_field = StructArray::from(vec![
1456
            (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1457
            (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1458
            (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1459
        ]);
1460
1461
        let exclamation_field = Arc::new(StructArray::from(vec![
1462
            (one.clone(), Arc::new(one_field) as ArrayRef),
1463
            (two.clone(), Arc::new(two_field) as ArrayRef),
1464
        ]));
1465
1466
        // Normalize top level
1467
        let normalized =
1468
            RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
1469
                .expect("valid conversion")
1470
                .normalize(".", Some(1))
1471
                .expect("valid normalization");
1472
1473
        let expected = RecordBatch::try_from_iter_with_nullable(vec![
1474
            (
1475
                "!.1",
1476
                Arc::new(StructArray::from(vec![
1477
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1478
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1479
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1480
                ])) as ArrayRef,
1481
                false,
1482
            ),
1483
            (
1484
                "!.2",
1485
                Arc::new(StructArray::from(vec![
1486
                    (a.clone(), Arc::new(a_field.clone()) as ArrayRef),
1487
                    (b.clone(), Arc::new(b_field.clone()) as ArrayRef),
1488
                    (c.clone(), Arc::new(c_field.clone()) as ArrayRef),
1489
                ])) as ArrayRef,
1490
                true,
1491
            ),
1492
        ])
1493
        .expect("valid conversion");
1494
1495
        assert_eq!(expected, normalized);
1496
1497
        // Normalize all levels
1498
        let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
1499
            .expect("valid conversion")
1500
            .normalize(".", None)
1501
            .expect("valid normalization");
1502
1503
        let expected = RecordBatch::try_from_iter_with_nullable(vec![
1504
            ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
1505
            ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
1506
            ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
1507
            ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
1508
            ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
1509
            ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
1510
        ])
1511
        .expect("valid conversion");
1512
1513
        assert_eq!(expected, normalized);
1514
    }
1515
1516
    #[test]
1517
    fn normalize_empty() {
1518
        let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
1519
        let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
1520
        let year_field = Arc::new(Field::new("year", DataType::Int64, true));
1521
1522
        let schema = Schema::new(vec![
1523
            Field::new(
1524
                "a",
1525
                DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
1526
                false,
1527
            ),
1528
            Field::new("month", DataType::Int64, true),
1529
        ]);
1530
1531
        let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
1532
            .normalize(".", Some(0))
1533
            .expect("valid normalization");
1534
1535
        let expected = RecordBatch::new_empty(Arc::new(
1536
            schema.normalize(".", Some(0)).expect("valid normalization"),
1537
        ));
1538
1539
        assert_eq!(expected, normalized);
1540
    }
1541
1542
    #[test]
1543
    fn project() {
1544
        let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
1545
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
1546
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1547
1548
        let record_batch =
1549
            RecordBatch::try_from_iter(vec![("a", a.clone()), ("b", b.clone()), ("c", c.clone())])
1550
                .expect("valid conversion");
1551
1552
        let expected =
1553
            RecordBatch::try_from_iter(vec![("a", a), ("c", c)]).expect("valid conversion");
1554
1555
        assert_eq!(expected, record_batch.project(&[0, 2]).unwrap());
1556
    }
1557
1558
    #[test]
1559
    fn project_empty() {
1560
        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
1561
1562
        let record_batch =
1563
            RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");
1564
1565
        let expected = RecordBatch::try_new_with_options(
1566
            Arc::new(Schema::empty()),
1567
            vec![],
1568
            &RecordBatchOptions {
1569
                match_field_names: true,
1570
                row_count: Some(3),
1571
            },
1572
        )
1573
        .expect("valid conversion");
1574
1575
        assert_eq!(expected, record_batch.project(&[]).unwrap());
1576
    }
1577
1578
    #[test]
1579
    fn test_no_column_record_batch() {
1580
        let schema = Arc::new(Schema::empty());
1581
1582
        let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err();
1583
        assert!(
1584
            err.to_string()
1585
                .contains("must either specify a row count or at least one column")
1586
        );
1587
1588
        let options = RecordBatchOptions::new().with_row_count(Some(10));
1589
1590
        let ok = RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap();
1591
        assert_eq!(ok.num_rows(), 10);
1592
1593
        let a = ok.slice(2, 5);
1594
        assert_eq!(a.num_rows(), 5);
1595
1596
        let b = ok.slice(5, 0);
1597
        assert_eq!(b.num_rows(), 0);
1598
1599
        assert_ne!(a, b);
1600
        assert_eq!(b, RecordBatch::new_empty(schema))
1601
    }
1602
1603
    #[test]
    fn test_nulls_in_non_nullable_field() {
        // A column containing nulls must be rejected when its field is
        // declared non-nullable.
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let maybe_batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
        );
        // `unwrap_err` asserts failure and yields the error in one step
        // (clearer than `.err().unwrap()`), and `to_string` is the idiomatic
        // way to render a `Display` type (instead of `format!("{}", ..)`).
        assert_eq!(
            "Invalid argument error: Column 'a' is declared as non-nullable but contains null values",
            maybe_batch.unwrap_err().to_string()
        );
    }
1615
    #[test]
    fn test_record_batch_options() {
        // Builder-style setters should be reflected in the stored fields.
        let options = RecordBatchOptions::new()
            .with_match_field_names(false)
            .with_row_count(Some(20));
        assert!(!options.match_field_names);
        assert_eq!(options.row_count, Some(20))
    }
1623
1624
    #[test]
    #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")]
    fn test_from_struct() {
        // A top-level null makes the StructArray itself nullable, which
        // cannot be represented as a RecordBatch and must panic.
        let fields = vec![
            // Note child is not nullable
            Field::new("foo", DataType::Int32, false),
        ];
        let data = ArrayData::new_null(&DataType::Struct(fields.into()), 2);
        let _ = RecordBatch::from(StructArray::from(data));
    }
1634
1635
    #[test]
    fn test_with_schema() {
        let required_schema = Arc::new(Schema::new(vec![Field::new(
            "a",
            DataType::Int32,
            false,
        )]));
        let nullable_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));

        let batch = RecordBatch::try_new(
            required_schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as _],
        )
        .unwrap();

        // Relaxing a field from required to nullable is always sound.
        let batch = batch.with_schema(nullable_schema.clone()).unwrap();

        // Tightening nullable back to required is rejected.
        batch.clone().with_schema(required_schema).unwrap_err();

        // Adding schema-level metadata is allowed.
        let metadata = std::iter::once(("foo".to_string(), "bar".to_string())).collect();
        let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata);
        let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap();

        // Dropping metadata is not.
        batch.with_schema(nullable_schema).unwrap_err();
    }
1664
1665
    #[test]
    fn test_boxed_reader() {
        // Make sure we can pass a boxed reader to a function generic over
        // RecordBatchReader.
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        let boxed: Box<dyn RecordBatchReader + Send> =
            Box::new(RecordBatchIterator::new(std::iter::empty(), schema));

        // Generic helper exercising the trait bound on the boxed reader.
        fn get_size(reader: impl RecordBatchReader) -> usize {
            reader.size_hint().0
        }

        // An empty iterator's size hint lower bound is 0.
        assert_eq!(get_size(boxed), 0);
    }
1682
1683
    #[test]
    fn test_remove_column_maintains_schema_metadata() {
        // Removing a column must not disturb schema-level metadata.
        let ids = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let flags = BooleanArray::from(vec![true, false, false, true, true]);

        let metadata: HashMap<_, _> =
            std::iter::once(("foo".to_string(), "bar".to_string())).collect();
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("bool", DataType::Boolean, false),
        ];
        let schema = Schema::new(fields).with_metadata(metadata);

        let mut batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(ids), Arc::new(flags)]).unwrap();

        let _removed_column = batch.remove_column(0);

        // The metadata entry survives the column removal.
        assert_eq!(batch.schema().metadata().len(), 1);
        assert_eq!(
            batch.schema().metadata().get("foo").unwrap().as_str(),
            "bar"
        );
    }
1709
}