23 changes: 22 additions & 1 deletion arrow-pyarrow-integration-testing/src/lib.rs
@@ -32,7 +32,7 @@ use arrow::compute::kernels;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, ToPyArrow};
use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, Table, ToPyArrow};
use arrow::record_batch::RecordBatch;

fn to_py_err(err: ArrowError) -> PyErr {
@@ -140,6 +140,25 @@ fn round_trip_record_batch_reader(
Ok(obj)
}

#[pyfunction]
fn round_trip_table(obj: PyArrowType<Table>) -> PyResult<PyArrowType<Table>> {
Ok(obj)
}

#[pyfunction]
pub fn build_table(
record_batches: Vec<PyArrowType<RecordBatch>>,
schema: PyArrowType<Schema>,
) -> PyResult<PyArrowType<Table>> {
Ok(PyArrowType(
Table::try_new(
record_batches.into_iter().map(|rb| rb.0).collect(),
Arc::new(schema.0),
)
.map_err(to_py_err)?,
))
}

#[pyfunction]
fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
// This makes sure we can correctly consume a RBR and return the error,
@@ -178,6 +197,8 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
m.add_wrapped(wrap_pyfunction!(round_trip_table))?;
m.add_wrapped(wrap_pyfunction!(build_table))?;
m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?;
Ok(())
65 changes: 65 additions & 0 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -613,6 +613,71 @@ def test_table_pycapsule():
assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_empty():
"""
Python -> Rust -> Python
"""
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
table = pa.Table.from_batches([], schema=schema)
new_table = rust.build_table([], schema=schema)

assert table.schema == new_table.schema
assert table == new_table
assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_roundtrip():
"""
Python -> Rust -> Python
"""
metadata = {b'key1': b'value1'}
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata=metadata)
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
]
table = pa.Table.from_batches(batches, schema=schema)
# TODO: Remove these `assert`s as soon as the metadata issue is solved in Rust
assert table.schema.metadata == metadata
assert all(batch.schema.metadata == metadata for batch in table.to_batches())
new_table = rust.round_trip_table(table)

assert table.schema == new_table.schema
assert table == new_table
assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_from_batches():
"""
Python -> Rust -> Python
"""
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
]
table = pa.Table.from_batches(batches)
new_table = rust.build_table(batches, schema)

assert table.schema == new_table.schema
assert table == new_table
assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_error_inconsistent_schema():
"""
Python -> Rust, expecting a schema error
"""
schema_1 = pa.schema([('ints', pa.list_(pa.int32()))])
schema_2 = pa.schema([('floats', pa.list_(pa.float32()))])
batches = [
pa.record_batch([[[1], [2, 42]]], schema_1),
pa.record_batch([[None, [], [5.6, 6.4]]], schema_2),
]
with pytest.raises(pa.ArrowException, match="Schema error: All record batches must have the same schema."):
rust.build_table(batches, schema_1)


def test_reject_other_classes():
# Arbitrary type that is not a PyArrow type
not_pyarrow = ["hello"]
133 changes: 125 additions & 8 deletions arrow-pyarrow/src/lib.rs
@@ -44,17 +44,20 @@
//! | `pyarrow.Array` | [ArrayData] |
//! | `pyarrow.RecordBatch` | [RecordBatch] |
//! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
//! | `pyarrow.Table` | [Table] (2) |
//!
//! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
//! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
//! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
//! easier to create.)
//!
//! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't
//! have these same concepts. A chunked table is instead represented with
//! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling
//! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader)
//! and then importing the reader as a [ArrowArrayStreamReader].
//! (2) Although arrow-rs offers [Table], a convenience wrapper for [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table)
//! that internally holds `Vec<RecordBatch>`, it is meant primarily for use cases where you already
//! have `Vec<RecordBatch>` on the Rust side and want to export that in bulk as a `pyarrow.Table`.
//! In general, it is recommended to use streaming approaches instead of dealing with data in bulk.
//! For example, a `pyarrow.Table` (or any other object that implements the ArrayStream PyCapsule
//! interface) can be imported to Rust through `PyArrowType<ArrowArrayStreamReader>` instead of
//! forcing eager reading into `Vec<RecordBatch>`.
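For illustration, here is a minimal sketch (not part of this diff) of the streaming approach recommended in (2); the function name `count_rows` is hypothetical, and the imports are only the ones the sketch itself needs:

```rust
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::PyArrowType;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Consumes a `pyarrow.Table` (or any ArrowArrayStream-compatible object)
/// batch by batch instead of materializing a `Vec<RecordBatch>` up front.
#[pyfunction]
fn count_rows(reader: PyArrowType<ArrowArrayStreamReader>) -> PyResult<usize> {
    let mut rows = 0;
    for batch in reader.0 {
        let batch = batch.map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))?;
        rows += batch.num_rows();
    }
    Ok(rows)
}
```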

use std::convert::{From, TryFrom};
use std::ptr::{addr_of, addr_of_mut};
@@ -68,13 +71,13 @@ use arrow_array::{
make_array,
};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field, Schema};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyList, PyTuple};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
use pyo3::{import_exception, intern};

import_exception!(pyarrow, ArrowException);
/// Represents an exception raised by PyArrow.
@@ -484,6 +487,120 @@ impl IntoPyArrow for ArrowArrayStreamReader {
}
}

/// This is a convenience wrapper around `Vec<RecordBatch>` that tries to simplify conversion from
/// and to `pyarrow.Table`.
///
/// This can be used when you want to consume a `pyarrow.Table` directly (although technically,
/// since `pyarrow.Table` implements the ArrayStream PyCapsule interface, one could also consume a
/// `PyArrowType<ArrowArrayStreamReader>` instead) or, more importantly, when you want to export a
/// `Vec<RecordBatch>` from the Rust side as a `pyarrow.Table`.
///
/// ```ignore
/// #[pyfunction]
/// fn return_table(...) -> PyResult<PyArrowType<Table>> {
///     let batches: Vec<RecordBatch> = ...;
///     let schema: SchemaRef = ...;
///     let table = Table::try_new(batches, schema)
///         .map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))?;
///     Ok(PyArrowType(table))
/// }
/// ```
#[derive(Clone)]
pub struct Table {
record_batches: Vec<RecordBatch>,
schema: SchemaRef,
}

impl Table {
pub fn try_new(
record_batches: Vec<RecordBatch>,
schema: SchemaRef,
) -> Result<Self, ArrowError> {
/// This function was copied from `pyo3_arrow/utils.rs` for now. I don't understand yet why
/// this is required instead of a "normal" `schema == record_batch.schema()` check.
///
/// TODO: Either remove this check, replace it with something already existing in `arrow-rs`
/// or move it to a central `utils` location.
fn schema_equals(left: &SchemaRef, right: &SchemaRef) -> bool {
left.fields
.iter()
.zip(right.fields.iter())
.all(|(left_field, right_field)| {
left_field.name() == right_field.name()
&& left_field
.data_type()
.equals_datatype(right_field.data_type())
})
}

Review comment (Member) on `schema_equals`:
This impl seems incorrect: the zip() operation does not check that the iterators have the same number of items. It actually checks that left is a subset of right or right is a subset of left. So, if right has one field more than left, this function will still return true.
https://play.rust-lang.org/?version=stable&mode=debug&edition=2024&gist=95c113900129b392365cdfb3b4c2b4e6

Reply (jonded94, Contributor Author, Nov 10, 2025):
In principle, instead of using this schema check method at all, I'd much rather have the underlying issue solved by understanding why either the ArrowStreamReader PyCapsule interface of pyarrow.Table or the Rust Box<dyn RecordBatchReader> side seems to swallow up RecordBatch metadata. If that issue were fixed, this function could be left out again and a normal schema == record_batch.schema() check could be used. This function is only relevant if a RecordBatch coming from a stream reader is expected to no longer carry metadata. In general, your comment is also relevant for @kylebarron, as he uses this function as-is in his crate pyo3-arrow.
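For reference, a minimal sketch (not part of the PR) of the stricter comparison the review comment is pointing at, assuming that requiring equal field counts is the intended fix:

```rust
// Sketch: same name-and-type comparison as above, but schemas with different
// numbers of fields are rejected up front, so a prefix match no longer passes.
fn schema_equals(left: &SchemaRef, right: &SchemaRef) -> bool {
    left.fields.len() == right.fields.len()
        && left
            .fields
            .iter()
            .zip(right.fields.iter())
            .all(|(l, r)| l.name() == r.name() && l.data_type().equals_datatype(r.data_type()))
}
```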

for record_batch in &record_batches {
if !schema_equals(&schema, &record_batch.schema()) {
return Err(ArrowError::SchemaError(
//"All record batches must have the same schema.".to_owned(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
//"All record batches must have the same schema.".to_owned(),

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only have the more verbose error message here right now to understand what's going on in the schema mismatch. This is currently commented out to signal that this is not intended to be merged as-is, but the schema mismatch issue shall be understood first.

In general I'm opinionless about how verbose the error message shall be, I'd happily to eventually remove whatever variant you dislike.

format!(
"All record batches must have the same schema. \
Expected schema: {:?}, got schema: {:?}",
schema,
record_batch.schema()
),
));
}
}
Ok(Self {
record_batches,
schema,
})
}

pub fn record_batches(&self) -> &[RecordBatch] {
&self.record_batches
}

pub fn schema(&self) -> SchemaRef {
self.schema.clone()
}

pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
(self.record_batches, self.schema)
}
}
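As a small illustration (not part of the diff) of these accessors, a hypothetical helper that sums the rows of a Table received from Python:

```rust
// Sketch: `total_rows` is a hypothetical helper built on the accessors above.
fn total_rows(table: &Table) -> usize {
    table
        .record_batches()
        .iter()
        .map(|batch| batch.num_rows())
        .sum()
}
```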

impl TryFrom<Box<dyn RecordBatchReader>> for Table {
type Error = ArrowError;

fn try_from(value: Box<dyn RecordBatchReader>) -> Result<Self, ArrowError> {
let schema = value.schema();
let batches = value.collect::<Result<Vec<_>, _>>()?;
Self::try_new(batches, schema)
}
}

/// Convert a `pyarrow.Table` (or any other ArrowArrayStream compliant object) into [`Table`]
impl FromPyArrow for Table {
fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
let reader: Box<dyn RecordBatchReader> =
Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?);
Self::try_from(reader).map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
}
}

/// Convert a [`Table`] into `pyarrow.Table`.
impl IntoPyArrow for Table {
fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
let module = py.import(intern!(py, "pyarrow"))?;
let class = module.getattr(intern!(py, "Table"))?;

let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));

let kwargs = PyDict::new(py);
kwargs.set_item("schema", py_schema)?;

let table = class.call_method("from_batches", (py_batches,), Some(&kwargs))?;

Ok(table)
}
}

/// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
///
/// When wrapped around a type `T: FromPyArrow`, it