-
Notifications
You must be signed in to change notification settings - Fork 1k
Implement a Vec<RecordBatch> wrapper for pyarrow.Table convenience
#8790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -44,17 +44,20 @@ | |||
| //! | `pyarrow.Array` | [ArrayData] | | ||||
| //! | `pyarrow.RecordBatch` | [RecordBatch] | | ||||
| //! | `pyarrow.RecordBatchReader` | [ArrowArrayStreamReader] / `Box<dyn RecordBatchReader + Send>` (1) | | ||||
| //! | `pyarrow.Table` | [Table] (2) | | ||||
| //! | ||||
| //! (1) `pyarrow.RecordBatchReader` can be imported as [ArrowArrayStreamReader]. Either | ||||
| //! [ArrowArrayStreamReader] or `Box<dyn RecordBatchReader + Send>` can be exported | ||||
| //! as `pyarrow.RecordBatchReader`. (`Box<dyn RecordBatchReader + Send>` is typically | ||||
| //! easier to create.) | ||||
| //! | ||||
| //! PyArrow has the notion of chunked arrays and tables, but arrow-rs doesn't | ||||
| //! have these same concepts. A chunked table is instead represented with | ||||
| //! `Vec<RecordBatch>`. A `pyarrow.Table` can be imported to Rust by calling | ||||
| //! [pyarrow.Table.to_reader()](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_reader) | ||||
| //! and then importing the reader as a [ArrowArrayStreamReader]. | ||||
| //! (2) Although arrow-rs offers [Table], a convenience wrapper for [pyarrow.Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html) | ||||
| //! that internally holds `Vec<RecordBatch>`, it is meant primarily for use cases where you already | ||||
| //! have `Vec<RecordBatch>` on the Rust side and want to export that in bulk as a `pyarrow.Table`. | ||||
| //! In general, it is recommended to use streaming approaches instead of dealing with data in bulk. | ||||
| //! For example, a `pyarrow.Table` (or any other object that implements the ArrayStream PyCapsule | ||||
| //! interface) can be imported to Rust through `PyArrowType<ArrowArrayStreamReader>` instead of | ||||
| //! forcing eager reading into `Vec<RecordBatch>`. | ||||
|
|
||||
| use std::convert::{From, TryFrom}; | ||||
| use std::ptr::{addr_of, addr_of_mut}; | ||||
|
|
@@ -68,13 +71,13 @@ use arrow_array::{ | |||
| make_array, | ||||
| }; | ||||
| use arrow_data::ArrayData; | ||||
| use arrow_schema::{ArrowError, DataType, Field, Schema}; | ||||
| use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef}; | ||||
| use pyo3::exceptions::{PyTypeError, PyValueError}; | ||||
| use pyo3::ffi::Py_uintptr_t; | ||||
| use pyo3::import_exception; | ||||
| use pyo3::prelude::*; | ||||
| use pyo3::pybacked::PyBackedStr; | ||||
| use pyo3::types::{PyCapsule, PyList, PyTuple}; | ||||
| use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple}; | ||||
| use pyo3::{import_exception, intern}; | ||||
|
|
||||
| import_exception!(pyarrow, ArrowException); | ||||
| /// Represents an exception raised by PyArrow. | ||||
|
|
@@ -484,6 +487,120 @@ impl IntoPyArrow for ArrowArrayStreamReader { | |||
| } | ||||
| } | ||||
|
|
||||
| /// This is a convenience wrapper around `Vec<RecordBatch>` that tries to simplify conversion from | ||||
| /// and to `pyarrow.Table`. | ||||
| /// | ||||
| /// This could be used in circumstances where you either want to consume a `pyarrow.Table` directly | ||||
| /// (although technically, since `pyarrow.Table` implements the ArrayStreamReader PyCapsule | ||||
| /// interface, one could also consume a `PyArrowType<ArrowArrayStreamReader>` instead) or, more | ||||
| /// importantly, where one wants to export a `pyarrow.Table` from a `Vec<RecordBatch>` from the Rust | ||||
| /// side. | ||||
| /// | ||||
| /// ```ignore | ||||
| /// #[pyfunction] | ||||
| /// fn return_table(...) -> PyResult<PyArrowType<Table>> { | ||||
| /// let batches: Vec<RecordBatch>; | ||||
| /// let schema: SchemaRef; | ||||
| /// PyArrowType(Table::try_new(batches, schema).map_err(|err| err.into_py_err(py))?) | ||||
| /// } | ||||
| /// ``` | ||||
| #[derive(Clone)] | ||||
| pub struct Table { | ||||
| record_batches: Vec<RecordBatch>, | ||||
| schema: SchemaRef, | ||||
| } | ||||
|
|
||||
| impl Table { | ||||
| pub fn try_new( | ||||
| record_batches: Vec<RecordBatch>, | ||||
| schema: SchemaRef, | ||||
| ) -> Result<Self, ArrowError> { | ||||
| /// This function was copied from `pyo3_arrow/utils.rs` for now. I don't understand yet why | ||||
| /// this is required instead of a "normal" `schema == record_batch.schema()` check. | ||||
| /// | ||||
| /// TODO: Either remove this check, replace it with something already existing in `arrow-rs` | ||||
| /// or move it to a central `utils` location. | ||||
| fn schema_equals(left: &SchemaRef, right: &SchemaRef) -> bool { | ||||
| left.fields | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This impl seems incorrect - the zip() operation does not check that the iterators have the same number of items. It actually checks that left is a subset of right or right is a subset of left. So, if
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In principle, instead of using this schema check method at all, I'd much rather have the underlying issue solved by understanding why either the ArrowStreamReader PyCapsule interface of If that issue would be fixed, then this function can be left out again and a normal But in general your comment would be also relevant for @kylebarron as he is using this function as-is in his crate |
||||
| .iter() | ||||
| .zip(right.fields.iter()) | ||||
| .all(|(left_field, right_field)| { | ||||
| left_field.name() == right_field.name() | ||||
| && left_field | ||||
| .data_type() | ||||
| .equals_datatype(right_field.data_type()) | ||||
| }) | ||||
| } | ||||
|
|
||||
| for record_batch in &record_batches { | ||||
| if !schema_equals(&schema, &record_batch.schema()) { | ||||
| return Err(ArrowError::SchemaError( | ||||
| //"All record batches must have the same schema.".to_owned(), | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I only have the more verbose error message here right now to understand what's going on in the schema mismatch. This is currently commented out to signal that this is not intended to be merged as-is, but the schema mismatch issue shall be understood first. In general I'm opinionless about how verbose the error message shall be, I'd happily to eventually remove whatever variant you dislike. |
||||
| format!( | ||||
| "All record batches must have the same schema. \ | ||||
| Expected schema: {:?}, got schema: {:?}", | ||||
| schema, | ||||
| record_batch.schema() | ||||
| ), | ||||
| )); | ||||
| } | ||||
| } | ||||
| Ok(Self { | ||||
| record_batches, | ||||
| schema, | ||||
| }) | ||||
| } | ||||
|
|
||||
| pub fn record_batches(&self) -> &[RecordBatch] { | ||||
| &self.record_batches | ||||
| } | ||||
|
|
||||
| pub fn schema(&self) -> SchemaRef { | ||||
| self.schema.clone() | ||||
| } | ||||
|
|
||||
| pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) { | ||||
| (self.record_batches, self.schema) | ||||
| } | ||||
| } | ||||
|
|
||||
| impl TryFrom<Box<dyn RecordBatchReader>> for Table { | ||||
| type Error = ArrowError; | ||||
|
|
||||
| fn try_from(value: Box<dyn RecordBatchReader>) -> Result<Self, ArrowError> { | ||||
| let schema = value.schema(); | ||||
| let batches = value.collect::<Result<Vec<_>, _>>()?; | ||||
| Self::try_new(batches, schema) | ||||
| } | ||||
| } | ||||
|
|
||||
| /// Convert a `pyarrow.Table` (or any other ArrowArrayStream compliant object) into [`Table`] | ||||
| impl FromPyArrow for Table { | ||||
| fn from_pyarrow_bound(ob: &Bound<PyAny>) -> PyResult<Self> { | ||||
| let reader: Box<dyn RecordBatchReader> = | ||||
| Box::new(ArrowArrayStreamReader::from_pyarrow_bound(ob)?); | ||||
| Self::try_from(reader).map_err(|err| PyErr::new::<PyValueError, _>(err.to_string())) | ||||
| } | ||||
| } | ||||
|
|
||||
| /// Convert a [`Table`] into `pyarrow.Table`. | ||||
| impl IntoPyArrow for Table { | ||||
| fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> { | ||||
| let module = py.import(intern!(py, "pyarrow"))?; | ||||
| let class = module.getattr(intern!(py, "Table"))?; | ||||
|
|
||||
| let py_batches = PyList::new(py, self.record_batches.into_iter().map(PyArrowType))?; | ||||
| let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema)); | ||||
|
|
||||
| let kwargs = PyDict::new(py); | ||||
| kwargs.set_item("schema", py_schema)?; | ||||
|
|
||||
| let reader = class.call_method("from_batches", (py_batches,), Some(&kwargs))?; | ||||
|
|
||||
| Ok(reader) | ||||
| } | ||||
| } | ||||
|
|
||||
| /// A newtype wrapper for types implementing [`FromPyArrow`] or [`IntoPyArrow`]. | ||||
| /// | ||||
| /// When wrapped around a type `T: FromPyArrow`, it | ||||
|
|
||||
Uh oh!
There was an error while loading. Please reload this page.