Implement a Vec<RecordBatch> wrapper for pyarrow.Table convenience
#8790
@@ -44,17 +44,20 @@
//! | `pyarrow.Array` | [ArrayData] |
//! | `pyarrow.RecordBatch` | [RecordBatch] |
//! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) |
//! | `pyarrow.Table` | [Table] (2) |
//!
//! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either
//! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported
//! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically
//! easier to create.)
//!
//! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't
//! have these same concepts. A chunked table is instead represented with
//! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling
//! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader)
//! and then importing the reader as a [ArrowArrayStreamReader].
//! (2) Although arrow-rs offers a [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table)
//! convenience wrapper [Table] (which internally holds `Vec<RecordBatch>`), this is primarily meant for
//! use cases where you already have `Vec<RecordBatch>` on the Rust side and want to export it in
//! bulk as a `pyarrow.Table`. In general, it is recommended to use streaming approaches instead of
//! dealing with bulk data.
//! For example, a `pyarrow.Table` can be imported to Rust through `PyArrowType<ArrowArrayStreamReader>`
//! instead (since `pyarrow.Table` implements the ArrayStream PyCapsule interface).
use std::convert::{From, TryFrom};
use std::ptr::{addr_of, addr_of_mut};
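
The module docs above recommend streaming imports over bulk conversion. As a rough illustration of that recommendation (not part of this diff), the sketch below shows a `#[pyfunction]` that consumes a `pyarrow.Table` through `PyArrowType<ArrowArrayStreamReader>`; the function name `count_rows` is an assumption, and the paths assume the `arrow` crate with the `pyarrow` feature enabled.

```rust
// Sketch only: consume a `pyarrow.Table` lazily via the ArrayStream PyCapsule
// interface instead of materializing it as a `Table` on the Rust side.
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::PyArrowType;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

// The function name `count_rows` is hypothetical.
#[pyfunction]
fn count_rows(reader: PyArrowType<ArrowArrayStreamReader>) -> PyResult<usize> {
    let mut rows = 0;
    // `ArrowArrayStreamReader` iterates over `Result<RecordBatch, ArrowError>`,
    // so batches are pulled one at a time rather than collected up front.
    for batch in reader.0 {
        let batch = batch.map_err(|err| PyValueError::new_err(err.to_string()))?;
        rows += batch.num_rows();
    }
    Ok(rows)
}
```
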
@@ -68,13 +71,13 @@ use arrow_array::{
    make_array,
};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field, Schema};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyList, PyTuple};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
use pyo3::{import_exception, intern};

import_exception!(pyarrow, ArrowException);
/// Represents an exception raised by PyArrow.

@@ -484,6 +487,137 @@ impl IntoPyArrow for ArrowArrayStreamReader {
    }
}

/// A convenience wrapper around `Vec<RecordBatch>` that simplifies conversion from and to
/// `pyarrow.Table`.
///
/// This is useful when you either want to consume a `pyarrow.Table` directly (although, since
/// `pyarrow.Table` implements the ArrayStream PyCapsule interface, you could also consume a
/// `PyArrowType<ArrowArrayStreamReader>` instead) or, more importantly, when you want to export a
/// `Vec<RecordBatch>` from the Rust side as a `pyarrow.Table`.
///
/// ```ignore
/// #[pyfunction]
/// fn return_table(...) -> PyResult<PyArrowType<Table>> {
///     let batches: Vec<RecordBatch> = ...;
///     let schema: SchemaRef = ...;
///     let table = Table::try_new(batches, Some(schema))
///         .map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))?;
///     Ok(PyArrowType(table))
/// }
/// ```
#[derive(Clone)]
pub struct Table {
    record_batches: Vec<RecordBatch>,
    schema: SchemaRef,
}

impl Table {
    /// Creates a new [Table] from its parts without validating them.
    ///
    /// # Safety
    ///
    /// The caller must ensure that every [RecordBatch] in `record_batches` has the same schema
    /// as `schema`.
    pub unsafe fn new_unchecked(record_batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
        Self {
            record_batches,
            schema,
        }
    }

    /// Creates a new [Table], validating that all record batches share the same schema.
    ///
    /// If `schema` is `None`, the schema of the first [RecordBatch] is used, so at least one
    /// batch must be supplied.
    pub fn try_new(
        record_batches: Vec<RecordBatch>,
        schema: Option<SchemaRef>,
    ) -> Result<Self, ArrowError> {
        let schema = match schema {
            Some(s) => s,
            None => {
                record_batches
                    .get(0)
                    .ok_or_else(|| ArrowError::SchemaError(
                        "If no schema is supplied explicitly, there must be at least one RecordBatch!".to_owned()
                    ))?
                    .schema()
                    .clone()
            }
        };
        for record_batch in &record_batches {
            if schema != record_batch.schema() {
                return Err(ArrowError::SchemaError(
                    "All record batches must have the same schema.".to_owned(),
                ));
            }
        }
        Ok(Self {
            record_batches,
            schema,
        })
    }

    /// Returns the [RecordBatch]es in this [Table].
    pub fn record_batches(&self) -> &[RecordBatch] {
        &self.record_batches
    }

    /// Returns the schema of this [Table].
    pub fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// Consumes this [Table] and returns its record batches and schema.
    pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
        (self.record_batches, self.schema)
    }
}
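
For context, and not as part of the diff, here is a rough sketch of how `Table::try_new` might be called from the Rust side. The helper name `build_table` and the single `Int64` column are illustrative assumptions, and the sketch assumes the `Table` type added by this PR is in scope.

```rust
// Sketch only: build a `Table` from record batches that already exist in Rust.
use std::sync::Arc;

use arrow_array::{ArrayRef, Int64Array, RecordBatch};
use arrow_schema::{ArrowError, DataType, Field, Schema};

fn build_table() -> Result<Table, ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef],
    )?;
    // `Some(schema)` validates every batch against the supplied schema; `None`
    // would infer the schema from the first batch instead.
    Table::try_new(vec![batch], Some(schema))
}
```
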
impl TryFrom<ArrowArrayStreamReader> for Table {
    type Error = ArrowError;

    fn try_from(value: ArrowArrayStreamReader) -> Result<Self, ArrowError> {
        let schema = value.schema();
        let batches = value.collect::<Result<Vec<_>, _>>()?;
        // We assume all batches have the same schema here.
        unsafe { Ok(Self::new_unchecked(batches, schema)) }
    }
}

impl FromPyArrow for Table {
    fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> {
        let array_stream_reader: PyResult<ArrowArrayStreamReader> = {
            // First, check whether the object itself implements the Arrow ArrayStream PyCapsule
            // interface (which `pyarrow.Table` does) or is a RecordBatchReader.
            let reader_result = if let Ok(reader) = ArrowArrayStreamReader::from_pyarrow_bound(ob) {
                Some(reader)
            }
            // If that is not the case, check whether it has a `to_reader` method (which
            // `pyarrow.Table` does) whose return value implements the Arrow ArrayStream PyCapsule
            // interface or is a RecordBatchReader.
            else if ob.hasattr(intern!(ob.py(), "to_reader"))? {
                let py_reader = ob.getattr(intern!(ob.py(), "to_reader"))?.call0()?;
                ArrowArrayStreamReader::from_pyarrow_bound(&py_reader).ok()
            } else {
                None
            };

            match reader_result {
                Some(reader) => Ok(reader),
                None => Err(PyTypeError::new_err(
                    "Expected an Arrow Table, an Arrow RecordBatchReader, or another object that conforms to the Arrow ArrayStream PyCapsule interface.",
                )),
            }
        };
        Self::try_from(array_stream_reader?)
            .map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
    }
}
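
To show how this `FromPyArrow` implementation is typically exercised (again not part of the diff), the following is a hedged sketch of a `#[pyfunction]` that accepts a `pyarrow.Table` as a materialized `Table`. The function name `table_row_count` is an assumption, as is the re-export path `arrow::pyarrow::Table`.

```rust
// Sketch only: accept a `pyarrow.Table` (or any object exposing the ArrayStream
// PyCapsule interface) as a fully collected `Table` on the Rust side.
use arrow::pyarrow::{PyArrowType, Table};
use pyo3::prelude::*;

// The function name `table_row_count` is hypothetical.
#[pyfunction]
fn table_row_count(table: PyArrowType<Table>) -> PyResult<usize> {
    // Extraction via `FromPyArrow` collects the whole table into `Vec<RecordBatch>`
    // before the function body runs, unlike the streaming variant shown earlier.
    let rows: usize = table
        .0
        .record_batches()
        .iter()
        .map(|batch| batch.num_rows())
        .sum();
    Ok(rows)
}
```
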

impl IntoPyArrow for Table {
    fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
        let module = py.import(intern!(py, "pyarrow"))?;
        let class = module.getattr(intern!(py, "Table"))?;

        let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?;
        let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));

        let kwargs = PyDict::new(py);
        kwargs.set_item("schema", py_schema)?;

        let table = class.call_method("from_batches", (py_batches,), Some(&kwargs))?;

        Ok(table)
    }
}

/// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`].
///
/// When wrapped around a type `T: FromPyArrow`, it