parquet/src/arrow/async_reader/mod.rs (128 additions, 0 deletions)
@@ -613,6 +613,9 @@ impl<T> std::fmt::Debug for StreamState<T> {

/// An asynchronous [`Stream`](https://docs.rs/futures/latest/futures/stream/trait.Stream.html) of [`RecordBatch`]
/// for a parquet file that can be constructed using [`ParquetRecordBatchStreamBuilder`].
///
/// `ParquetRecordBatchStream` also provides [`ParquetRecordBatchStream::next_row_group`] for fetching row groups,
/// allowing users to decode record batches separately from I/O.
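///
/// A minimal sketch of that pattern (the `example` function and `input` value are
/// placeholders), assuming a Tokio runtime and that the returned
/// [`ParquetRecordBatchReader`] can be moved to another thread:
///
/// ```no_run
/// # use parquet::arrow::async_reader::{AsyncFileReader, ParquetRecordBatchStreamBuilder};
/// # use parquet::errors::Result;
/// # async fn example<T: AsyncFileReader + Unpin + Send + 'static>(input: T) -> Result<()> {
/// let mut stream = ParquetRecordBatchStreamBuilder::new(input).await?.build()?;
/// while let Some(row_group) = stream.next_row_group().await? {
///     // All bytes for this row group have already been fetched, so decoding is
///     // CPU-bound and can run off the I/O task.
///     tokio::task::spawn_blocking(move || {
///         for batch in row_group {
///             let _batch = batch.unwrap();
///         }
///     });
/// }
/// # Ok(())
/// # }
/// ```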
pub struct ParquetRecordBatchStream<T> {
metadata: Arc<ParquetMetaData>,

@@ -654,6 +657,66 @@ impl<T> ParquetRecordBatchStream<T> {
}
}

impl<T> ParquetRecordBatchStream<T>
where
T: AsyncFileReader + Unpin + Send + 'static,
{
/// Fetches the next row group from the stream.
///
/// Users can call this function repeatedly to fetch row groups, decoding the returned readers concurrently with subsequent I/O.
///
/// ## Notes
///
/// A `ParquetRecordBatchStream` should be consumed either as a `Stream` or via `next_row_group`; the two APIs should not be mixed on the same instance.
///
/// ## Returns
///
/// - `Ok(None)` if the stream has ended.
/// - `Err(error)` if the stream has errored. All subsequent calls will return `Ok(None)`.
/// - `Ok(Some(reader))` which holds all the data for the row group.
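///
/// ## Example
///
/// A minimal sketch of draining the stream row group by row group (the `example`
/// function and `input` value are placeholders for illustration):
///
/// ```no_run
/// # use parquet::arrow::async_reader::{AsyncFileReader, ParquetRecordBatchStreamBuilder};
/// # use parquet::errors::Result;
/// # async fn example<T: AsyncFileReader + Unpin + Send + 'static>(input: T) -> Result<()> {
/// let mut stream = ParquetRecordBatchStreamBuilder::new(input).await?.build()?;
/// while let Some(row_group) = stream.next_row_group().await? {
///     // `row_group` is a synchronous ParquetRecordBatchReader holding all the
///     // data for one row group.
///     for batch in row_group {
///         let _batch = batch?;
///     }
/// }
/// # Ok(())
/// # }
/// ```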
pub async fn next_row_group(&mut self) -> Result<Option<ParquetRecordBatchReader>> {
Member (Author) commented: I'm not sure if `next_row_group` is the best name, open to other options.

Contributor commented: I think it is a good, clear name, as it clearly explains what it does.

loop {
match &mut self.state {
StreamState::Decoding(_) | StreamState::Reading(_) => unreachable!(),
Contributor commented: I think this should probably return an error saying not to mix polling the stream and using this API.
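
A rough sketch of the suggested behavior (the exact wording is hypothetical; `ParquetError::General` is the parquet crate's existing catch-all error variant):

StreamState::Decoding(_) | StreamState::Reading(_) => {
    return Err(ParquetError::General(
        "ParquetRecordBatchStream cannot be polled as a Stream and advanced with next_row_group at the same time".to_string(),
    ))
}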

StreamState::Init => {
let row_group_idx = match self.row_groups.pop_front() {
Some(idx) => idx,
None => return Ok(None),
};

let row_count = self.metadata.row_group(row_group_idx).num_rows() as usize;

let selection = self.selection.as_mut().map(|s| s.split_off(row_count));

let reader_factory = self.reader.take().expect("lost reader");

let (reader_factory, maybe_reader) = reader_factory
.read_row_group(
row_group_idx,
selection,
self.projection.clone(),
self.batch_size,
)
.await
.map_err(|err| {
self.state = StreamState::Error;
err
})?;
self.reader = Some(reader_factory);

if let Some(reader) = maybe_reader {
return Ok(Some(reader));
} else {
// All rows skipped, read next row group
continue;
}
}
StreamState::Error => return Ok(None), // An error has already occurred; the stream has ended.
}
}
}
}

impl<T> Stream for ParquetRecordBatchStream<T>
where
T: AsyncFileReader + Unpin + Send + 'static,
@@ -1020,6 +1083,71 @@ mod tests {
);
}

#[tokio::test]
async fn test_async_reader_with_next_row_group() {
let testdata = arrow::util::test_util::parquet_test_data();
let path = format!("{testdata}/alltypes_plain.parquet");
let data = Bytes::from(std::fs::read(path).unwrap());

let metadata = ParquetMetaDataReader::new()
.parse_and_finish(&data)
.unwrap();
let metadata = Arc::new(metadata);

assert_eq!(metadata.num_row_groups(), 1);

let async_reader = TestReader {
data: data.clone(),
metadata: metadata.clone(),
requests: Default::default(),
};

let requests = async_reader.requests.clone();
let builder = ParquetRecordBatchStreamBuilder::new(async_reader)
.await
.unwrap();

let mask = ProjectionMask::leaves(builder.parquet_schema(), vec![1, 2]);
let mut stream = builder
.with_projection(mask.clone())
.with_batch_size(1024)
.build()
.unwrap();

let mut readers = vec![];
while let Some(reader) = stream.next_row_group().await.unwrap() {
readers.push(reader);
}

let async_batches: Vec<_> = readers
.into_iter()
.flat_map(|r| r.map(|v| v.unwrap()).collect::<Vec<_>>())
.collect();

let sync_batches = ParquetRecordBatchReaderBuilder::try_new(data)
.unwrap()
.with_projection(mask)
.with_batch_size(104)
.build()
.unwrap()
.collect::<ArrowResult<Vec<_>>>()
.unwrap();

assert_eq!(async_batches, sync_batches);

let requests = requests.lock().unwrap();
let (offset_1, length_1) = metadata.row_group(0).column(1).byte_range();
let (offset_2, length_2) = metadata.row_group(0).column(2).byte_range();

assert_eq!(
&requests[..],
&[
offset_1 as usize..(offset_1 + length_1) as usize,
offset_2 as usize..(offset_2 + length_2) as usize
]
);
}

#[tokio::test]
async fn test_async_reader_with_index() {
let testdata = arrow::util::test_util::parquet_test_data();