26 changes: 25 additions & 1 deletion arrow-pyarrow-integration-testing/src/lib.rs
@@ -32,7 +32,7 @@ use arrow::compute::kernels;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ffi_stream::ArrowArrayStreamReader;
-use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, ToPyArrow};
+use arrow::pyarrow::{FromPyArrow, PyArrowException, PyArrowType, Table, ToPyArrow};
use arrow::record_batch::RecordBatch;

fn to_py_err(err: ArrowError) -> PyErr {
@@ -140,6 +140,28 @@ fn round_trip_record_batch_reader(
    Ok(obj)
}

#[pyfunction]
fn round_trip_table(obj: PyArrowType<Table>) -> PyResult<PyArrowType<Table>> {
    Ok(obj)
}

/// Function for testing whether a `Vec<RecordBatch>` is exportable as a `pyarrow.Table`, with or
/// without explicitly providing a schema.
#[pyfunction]
#[pyo3(signature = (record_batches, *, schema=None))]
pub fn build_table(
    record_batches: Vec<PyArrowType<RecordBatch>>,
    schema: Option<PyArrowType<Schema>>,
) -> PyResult<PyArrowType<Table>> {
    Ok(PyArrowType(
        Table::try_new(
            record_batches.into_iter().map(|rb| rb.0).collect(),
            schema.map(|s| Arc::new(s.0)),
        )
        .map_err(to_py_err)?,
    ))
}

#[pyfunction]
fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
    // This makes sure we can correctly consume a RBR and return the error,
@@ -178,6 +200,8 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &Bound<PyModule>) -> PyResu
    m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
    m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
    m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
    m.add_wrapped(wrap_pyfunction!(round_trip_table))?;
    m.add_wrapped(wrap_pyfunction!(build_table))?;
    m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
    m.add_wrapped(wrap_pyfunction!(boxed_reader_roundtrip))?;
    Ok(())
74 changes: 74 additions & 0 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -613,6 +613,80 @@ def test_table_pycapsule():
    assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_empty():
    """
    Python -> Rust -> Python
    """
    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
    table = pa.Table.from_batches([], schema=schema)
    new_table = rust.build_table([], schema=schema)

    assert table.schema == new_table.schema
    assert table == new_table
    assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_roundtrip():
    """
    Python -> Rust -> Python
    """
    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
    batches = [
        pa.record_batch([[[1], [2, 42]]], schema),
        pa.record_batch([[None, [], [5, 6]]], schema),
    ]
    table = pa.Table.from_batches(batches)
    new_table = rust.round_trip_table(table)

    assert table.schema == new_table.schema
    assert table == new_table
    assert len(table.to_batches()) == len(new_table.to_batches())


@pytest.mark.parametrize("set_schema", (True, False))
def test_table_from_batches(set_schema: bool):
    """
    Python -> Rust -> Python
    """
    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
    batches = [
        pa.record_batch([[[1], [2, 42]]], schema),
        pa.record_batch([[None, [], [5, 6]]], schema),
    ]
    table = pa.Table.from_batches(batches)
    new_table = rust.build_table(batches, schema=schema if set_schema else None)

    assert table.schema == new_table.schema
    assert table == new_table
    assert len(table.to_batches()) == len(new_table.to_batches())


def test_table_error_inconsistent_schema():
    """
    Python -> Rust
    """
    schema_1 = pa.schema([('ints', pa.list_(pa.int32()))])
    schema_2 = pa.schema([('floats', pa.list_(pa.float32()))])
    batches = [
        pa.record_batch([[[1], [2, 42]]], schema_1),
        pa.record_batch([[None, [], [5.6, 6.4]]], schema_2),
    ]
    with pytest.raises(pa.ArrowException, match="Schema error: All record batches must have the same schema."):
        rust.build_table(batches)


def test_table_error_no_schema():
    """
    Python -> Rust
    """
    batches = []
    with pytest.raises(
        pa.ArrowException,
        match="Schema error: If no schema is supplied explicitly, there must be at least one RecordBatch!"
    ):
        rust.build_table(batches)


def test_reject_other_classes():
    # Arbitrary type that is not a PyArrow type
    not_pyarrow = ["hello"]
150 changes: 142 additions & 8 deletions arrow-pyarrow/src/lib.rs
@@ -44,17 +44,20 @@
//! | `pyarrow.Array` | [ArrayData] |
//! | `pyarrow.RecordBatch` | [RecordBatch] |
//! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
+//! | `pyarrow.Table` | [Table] (2) |
//!
//! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
//! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
//! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
//! easier to create.)
//!
-//! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't
-//! have these same concepts. A chunked table is instead represented with
-//! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling
-//! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader)
-//! and then importing the reader as a [ArrowArrayStreamReader].
+//! (2) arrow-rs offers [Table] as a convenience wrapper (internally holding a
+//! `Vec<RecordBatch>`) for [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html).
+//! It is mainly intended for cases where you already have a `Vec<RecordBatch>` on the Rust side
+//! and want to export it in bulk as a `pyarrow.Table`. In general, prefer streaming approaches
+//! over moving bulk data: for example, a `pyarrow.Table` can instead be imported to Rust through
+//! `PyArrowType<ArrowArrayStreamReader>`, since `pyarrow.Table` implements the ArrayStream
+//! PyCapsule interface.

Member: I think it would be good to note here that another advantage of using ArrowArrayStreamReader is that it works with tables and stream input out of the box. It doesn't matter which type the user passes in.

Member: Well, actually, a slight correction: assuming PyCapsule Interface input, both Table and ArrowArrayStreamReader will work with both table and stream input out of the box; the difference is just whether the Rust code materializes the data. This is why I have this table in the pyo3-arrow docs:

[Image: table from the pyo3-arrow docs showing which input types each Rust type accepts]

Member: Also, reading through the docs again, I'd suggest making a reference to Box<dyn RecordBatchReader> rather than ArrowArrayStreamReader. The former is a higher-level API and much easier to use.

Contributor Author:

> I think it would be good to note here that another advantage of using ArrowArrayStreamReader is that it works with tables and stream input out of the box.

I added that in the docs.

> I'd suggest making a reference to Box<dyn RecordBatchReader> rather than ArrowArrayStreamReader. The former is a higher level API and much easier to use.

I'm not exactly sure what you mean here. Box<dyn RecordBatchReader> only implements IntoPyArrow, but not FromPyArrow. So for the case I describe in the new documentation (consuming a pyarrow.Table in Rust via a streaming approach), Box<dyn RecordBatchReader> sadly doesn't help. One has to use ArrowArrayStreamReader, since that is what properly implements FromPyArrow.

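To make the asymmetry discussed in this thread concrete, here is a minimal sketch (hypothetical function names, assuming the `arrow::pyarrow` exports added in this PR): `ArrowArrayStreamReader` implements `FromPyArrow`, so it works as an argument type for both table and stream input, while `Box<dyn RecordBatchReader + Send>` only implements `IntoPyArrow` and is therefore limited to the return side.

```rust
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::PyArrowType;
use arrow::record_batch::{RecordBatch, RecordBatchIterator, RecordBatchReader};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Accepts a `pyarrow.Table` *or* a `pyarrow.RecordBatchReader`: both expose the
/// Arrow ArrayStream PyCapsule interface, and `ArrowArrayStreamReader` implements
/// `FromPyArrow`.
#[pyfunction]
fn consume_stream(reader: PyArrowType<ArrowArrayStreamReader>) -> PyResult<usize> {
    // Stream the batches one at a time instead of materializing them all; count rows.
    let mut rows = 0;
    for batch in reader.0 {
        let batch = batch.map_err(|e| PyValueError::new_err(e.to_string()))?;
        rows += batch.num_rows();
    }
    Ok(rows)
}

/// The reverse direction: `Box<dyn RecordBatchReader + Send>` works as a *return*
/// type (it implements `IntoPyArrow`), but could not be used as an argument type.
#[pyfunction]
fn produce_stream(
    batches: Vec<PyArrowType<RecordBatch>>,
) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
    let batches: Vec<RecordBatch> = batches.into_iter().map(|b| b.0).collect();
    // For brevity, this sketch requires at least one batch to infer the schema.
    let schema = batches
        .first()
        .ok_or_else(|| PyValueError::new_err("need at least one RecordBatch"))?
        .schema();
    let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
    Ok(PyArrowType(Box::new(reader)))
}
```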

use std::convert::{From, TryFrom};
use std::ptr::{addr_of, addr_of_mut};
@@ -68,13 +71,13 @@ use arrow_array::{
    make_array,
};
use arrow_data::ArrayData;
-use arrow_schema::{ArrowError, DataType, Field, Schema};
+use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
-use pyo3::import_exception;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
-use pyo3::types::{PyCapsule, PyList, PyTuple};
+use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
+use pyo3::{import_exception, intern};

import_exception!(pyarrow, ArrowException);
/// Represents an exception raised by PyArrow.
@@ -484,6 +487,137 @@ impl IntoPyArrow for ArrowArrayStreamReader {
    }
}

/// A convenience wrapper around `Vec<RecordBatch>` that simplifies conversion from and to
/// `pyarrow.Table`.
///
/// Use it either to consume a `pyarrow.Table` directly (although technically, since
/// `pyarrow.Table` implements the ArrayStream PyCapsule interface, one could also consume a
/// `PyArrowType<ArrowArrayStreamReader>` instead) or, more importantly, to export a
/// `Vec<RecordBatch>` from the Rust side as a `pyarrow.Table`.
///
/// ```ignore
/// #[pyfunction]
/// fn return_table(...) -> PyResult<PyArrowType<Table>> {
///     let batches: Vec<RecordBatch>;
///     let schema: SchemaRef;
///     let table = Table::try_new(batches, Some(schema))
///         .map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))?;
///     Ok(PyArrowType(table))
/// }
/// ```
#[derive(Clone)]
pub struct Table {
    record_batches: Vec<RecordBatch>,
    schema: SchemaRef,
}

impl Table {
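    /// Creates a [Table] from `record_batches` and `schema` without validating them.
    ///
    /// # Safety
    ///
    /// The caller must ensure that every [RecordBatch] in `record_batches` has the same schema
    /// as `schema`; use [Table::try_new] to have this checked.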
    pub unsafe fn new_unchecked(record_batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
        Self {
            record_batches,
            schema,
        }
    }

    pub fn try_new(
        record_batches: Vec<RecordBatch>,
        schema: Option<SchemaRef>,
    ) -> Result<Self, ArrowError> {
        let schema = match schema {
            Some(s) => s,
            None => record_batches
                .first()
                .ok_or_else(|| {
                    ArrowError::SchemaError(
                        "If no schema is supplied explicitly, there must be at least one RecordBatch!"
                            .to_owned(),
                    )
                })?
                .schema()
                .clone(),
        };
        for record_batch in &record_batches {
            if schema != record_batch.schema() {
                return Err(ArrowError::SchemaError(
                    "All record batches must have the same schema.".to_owned(),
                ));
            }
        }
        Ok(Self {
            record_batches,
            schema,
        })
    }

    pub fn record_batches(&self) -> &[RecordBatch] {
        &self.record_batches
    }

    pub fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
        (self.record_batches, self.schema)
    }
}

impl TryFrom<ArrowArrayStreamReader> for Table {
    type Error = ArrowError;

    fn try_from(value: ArrowArrayStreamReader) -> Result<Self, ArrowError> {
        let schema = value.schema();
        let batches = value.collect::<Result<Vec<_>, _>>()?;
        // Safety: every batch yielded by the reader shares the reader's schema.
        unsafe { Ok(Self::new_unchecked(batches, schema)) }
    }
}

impl FromPyArrow for Table {
    fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
        let array_stream_reader: PyResult<ArrowArrayStreamReader> = {
            // First, check whether the object itself implements the Arrow ArrayStream PyCapsule
            // interface (which `pyarrow.Table` does) or is a RecordBatchReader.
            let reader_result = if let Ok(reader) = ArrowArrayStreamReader::from_pyarrow_bound(ob) {
                Some(reader)
            }
            // Otherwise, check whether it has a `to_reader` method (which `pyarrow.Table` also
            // has) whose return value implements the ArrayStream PyCapsule interface or is a
            // RecordBatchReader.
            else if ob.hasattr(intern!(ob.py(), "to_reader"))? {
                let py_reader = ob.getattr(intern!(ob.py(), "to_reader"))?.call0()?;
                ArrowArrayStreamReader::from_pyarrow_bound(&py_reader).ok()
            } else {
                None
            };

            match reader_result {
                Some(reader) => Ok(reader),
                None => Err(PyTypeError::new_err(
                    "Expected Arrow Table, Arrow RecordBatchReader or other object which conforms to the Arrow ArrayStream PyCapsule interface.",
                )),
            }
        };
        Self::try_from(array_stream_reader?)
            .map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
    }
}

impl IntoPyArrow for Table {
    fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
        let module = py.import(intern!(py, "pyarrow"))?;
        let class = module.getattr(intern!(py, "Table"))?;

        let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
        let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));

        let kwargs = PyDict::new(py);
        kwargs.set_item("schema", py_schema)?;

        let table = class.call_method("from_batches", (py_batches,), Some(&kwargs))?;

        Ok(table)
    }
}
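
As a usage note, `into_pyarrow` can also be driven by hand rather than through a `#[pyfunction]` return type. A minimal sketch with a hypothetical `export` helper, assuming an embedded Python interpreter with `pyarrow` installed:

```rust
use arrow::pyarrow::{IntoPyArrow, Table};
use pyo3::prelude::*;

// Hypothetical helper: hand a Rust-side Table to Python and inspect the result.
fn export(table: Table) -> PyResult<()> {
    Python::with_gil(|py| {
        // Under the hood this calls pyarrow.Table.from_batches with the wrapped schema.
        let py_table = table.into_pyarrow(py)?;
        let num_rows: usize = py_table.getattr("num_rows")?.extract()?;
        println!("exported a pyarrow.Table with {num_rows} rows");
        Ok(())
    })
}
```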

/// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
///
/// When wrapped around a type `T: FromPyArrow`, it