Skip to content

Commit 7a2613a

Browse files
committed
Add option to skip validation when reading IPC streams/files
1 parent e08ad17 commit 7a2613a

File tree

2 files changed

+134
-1
lines changed

2 files changed

+134
-1
lines changed

arrow-ipc/src/reader.rs

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ use arrow_data::ArrayData;
4040
use arrow_schema::*;
4141

4242
use crate::compression::CompressionCodec;
43+
use crate::reader::private::UnsafeFlag;
4344
use crate::{Block, FieldNode, Message, MetadataVersion, CONTINUATION_MARKER};
4445
use DataType::*;
4546

@@ -65,6 +66,41 @@ fn read_buffer(
6566
(false, Some(decompressor)) => decompressor.decompress_to_buffer(&buf_data),
6667
}
6768
}
69+
70+
mod private {
71+
/// A boolean flag that cannot be mutated outside of unsafe code.
72+
///
73+
/// Defaults to a value of false.
74+
///
75+
/// This structure is used to enforce safety in the various readers
76+
#[derive(Debug, Clone, Copy)]
77+
pub struct UnsafeFlag(bool);
78+
79+
impl Default for UnsafeFlag {
80+
fn default() -> Self {
81+
Self::new()
82+
}
83+
}
84+
85+
impl UnsafeFlag {
86+
/// Creates a new `UnsafeFlag` with the value set to `false`
87+
#[inline]
88+
pub const fn new() -> Self {
89+
Self(false)
90+
}
91+
92+
#[inline]
93+
pub unsafe fn set(&mut self, val: bool) {
94+
self.0 = val;
95+
}
96+
97+
#[inline]
98+
pub fn get(&self) -> bool {
99+
self.0
100+
}
101+
}
102+
}
103+
68104
impl RecordBatchDecoder<'_> {
69105
/// Coordinates reading arrays based on data types.
70106
///
@@ -376,6 +412,18 @@ struct RecordBatchDecoder<'a> {
376412
/// Are buffers required to already be aligned? See
377413
/// [`RecordBatchDecoder::with_require_alignment`] for details
378414
require_alignment: bool,
415+
/// Should validation be skipped when reading data?
416+
///
417+
/// Defaults to false.
418+
///
419+
/// If true [`ArrayData::validate`] is not called after reading
420+
///
421+
/// # Safety
422+
///
423+
/// This flag can only be set to true using `unsafe` APIs. However, once true
424+
/// subsequent calls to `build()` may result in undefined behavior if the data
425+
/// is not valid.
426+
skip_validation: UnsafeFlag,
379427
}
380428

381429
impl<'a> RecordBatchDecoder<'a> {
@@ -410,6 +458,7 @@ impl<'a> RecordBatchDecoder<'a> {
410458
buffers: buffers.iter(),
411459
projection: None,
412460
require_alignment: false,
461+
skip_validation: UnsafeFlag::new(),
413462
})
414463
}
415464

@@ -432,6 +481,21 @@ impl<'a> RecordBatchDecoder<'a> {
432481
self
433482
}
434483

484+
/// Set skip_validation (default: false)
485+
///
486+
/// Note this is a pub(crate) API and can not be used outside of this crate
487+
///
488+
/// If true, validation is skipped.
489+
///
490+
/// # Safety
491+
///
492+
/// Relies on `UnsafeFlag` to enforce safety -- can only be enabled via
493+
/// unsafe APIs.
494+
pub(crate) fn with_skip_validation(mut self, skip_validation: UnsafeFlag) -> Self {
495+
self.skip_validation = skip_validation;
496+
self
497+
}
498+
435499
/// Read the record batch, consuming the reader
436500
fn read_record_batch(mut self) -> Result<RecordBatch, ArrowError> {
437501
let mut variadic_counts: VecDeque<i64> = self
@@ -601,7 +665,16 @@ pub fn read_dictionary(
601665
dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
602666
metadata: &MetadataVersion,
603667
) -> Result<(), ArrowError> {
604-
read_dictionary_impl(buf, batch, schema, dictionaries_by_id, metadata, false)
668+
let skip_validation = UnsafeFlag::new(); // do not skip valididation
669+
read_dictionary_impl(
670+
buf,
671+
batch,
672+
schema,
673+
dictionaries_by_id,
674+
metadata,
675+
false,
676+
skip_validation,
677+
)
605678
}
606679

607680
fn read_dictionary_impl(
@@ -611,6 +684,7 @@ fn read_dictionary_impl(
611684
dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
612685
metadata: &MetadataVersion,
613686
require_alignment: bool,
687+
skip_validation: UnsafeFlag,
614688
) -> Result<(), ArrowError> {
615689
if batch.isDelta() {
616690
return Err(ArrowError::InvalidArgumentError(
@@ -642,6 +716,7 @@ fn read_dictionary_impl(
642716
metadata,
643717
)?
644718
.with_require_alignment(require_alignment)
719+
.with_skip_validation(skip_validation)
645720
.read_record_batch()?;
646721

647722
Some(record_batch.column(0).clone())
@@ -772,6 +847,7 @@ pub struct FileDecoder {
772847
version: MetadataVersion,
773848
projection: Option<Vec<usize>>,
774849
require_alignment: bool,
850+
skip_validation: UnsafeFlag,
775851
}
776852

777853
impl FileDecoder {
@@ -783,6 +859,7 @@ impl FileDecoder {
783859
dictionaries: Default::default(),
784860
projection: None,
785861
require_alignment: false,
862+
skip_validation: UnsafeFlag::new(),
786863
}
787864
}
788865

@@ -809,6 +886,21 @@ impl FileDecoder {
809886
self
810887
}
811888

889+
/// Specifies whether validation should be skipped when reading data (default to `false`)
890+
///
891+
/// # Safety
892+
///
893+
/// This flag must only be set to `true` when you trust and are sure the data you are
894+
/// reading is a valid Arrow IPC file, otherwise undefined behavior may
895+
/// result.
896+
///
897+
/// For example, some programs may wish to trust reading IPC files written
898+
/// by the same process that created the files.
899+
pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
900+
self.skip_validation.set(skip_validation);
901+
self
902+
}
903+
812904
fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message<'a>, ArrowError> {
813905
let message = parse_message(buf)?;
814906

@@ -834,6 +926,7 @@ impl FileDecoder {
834926
&mut self.dictionaries,
835927
&message.version(),
836928
self.require_alignment,
929+
self.skip_validation,
837930
)
838931
}
839932
t => Err(ArrowError::ParseError(format!(
@@ -1250,6 +1343,9 @@ pub struct StreamReader<R> {
12501343

12511344
/// Optional projection
12521345
projection: Option<(Vec<usize>, Schema)>,
1346+
1347+
/// Should the reader skip validation
1348+
skip_validation: UnsafeFlag,
12531349
}
12541350

12551351
impl<R> fmt::Debug for StreamReader<R> {
@@ -1329,6 +1425,7 @@ impl<R: Read> StreamReader<R> {
13291425
finished: false,
13301426
dictionaries_by_id,
13311427
projection,
1428+
skip_validation: UnsafeFlag::new(),
13321429
})
13331430
}
13341431

@@ -1437,6 +1534,7 @@ impl<R: Read> StreamReader<R> {
14371534
&mut self.dictionaries_by_id,
14381535
&message.version(),
14391536
false,
1537+
self.skip_validation,
14401538
)?;
14411539

14421540
// read the next message until we encounter a RecordBatch
@@ -1462,6 +1560,21 @@ impl<R: Read> StreamReader<R> {
14621560
pub fn get_mut(&mut self) -> &mut R {
14631561
&mut self.reader
14641562
}
1563+
1564+
/// Specifies whether validation should be skipped when reading data (default to `false`)
1565+
///
1566+
/// # Safety
1567+
///
1568+
/// This flag must only be set to `true` when you trust and are sure the data you are
1569+
/// reading is a valid Arrow IPC file, otherwise undefined behavior may
1570+
/// result.
1571+
///
1572+
/// For example, some programs may wish to trust reading IPC files written
1573+
/// by the same process that created the files.
1574+
pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
1575+
self.skip_validation.set(skip_validation);
1576+
self
1577+
}
14651578
}
14661579

14671580
impl<R: Read> Iterator for StreamReader<R> {

arrow-ipc/src/reader/stream.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use arrow_buffer::{Buffer, MutableBuffer};
2424
use arrow_schema::{ArrowError, SchemaRef};
2525

2626
use crate::convert::MessageBuffer;
27+
use crate::reader::private::UnsafeFlag;
2728
use crate::reader::{read_dictionary_impl, RecordBatchDecoder};
2829
use crate::{MessageHeader, CONTINUATION_MARKER};
2930

@@ -42,6 +43,8 @@ pub struct StreamDecoder {
4243
buf: MutableBuffer,
4344
/// Whether or not array data in input buffers are required to be aligned
4445
require_alignment: bool,
46+
/// Should we skip validation when reading arrays?
47+
skip_validation: UnsafeFlag,
4548
}
4649

4750
#[derive(Debug)]
@@ -102,6 +105,21 @@ impl StreamDecoder {
102105
self
103106
}
104107

108+
/// Specifies whether validation should be skipped when reading data (default to `false`)
109+
///
110+
/// # Safety
111+
///
112+
/// This flag must only be set to `true` when you trust and are sure the data you are
113+
/// reading is a valid Arrow IPC stream, otherwise undefined behavior may
114+
/// result.
115+
///
116+
/// For example, some programs may wish to trust reading IPC streams written
117+
/// by the same process that created the files.
118+
pub unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self {
119+
self.skip_validation.set(skip_validation);
120+
self
121+
}
122+
105123
/// Try to read the next [`RecordBatch`] from the provided [`Buffer`]
106124
///
107125
/// [`Buffer::advance`] will be called on `buffer` for any consumed bytes.
@@ -219,6 +237,7 @@ impl StreamDecoder {
219237
&version,
220238
)?
221239
.with_require_alignment(self.require_alignment)
240+
.with_skip_validation(self.skip_validation)
222241
.read_record_batch()?;
223242
self.state = DecoderState::default();
224243
return Ok(Some(batch));
@@ -235,6 +254,7 @@ impl StreamDecoder {
235254
&mut self.dictionaries,
236255
&version,
237256
self.require_alignment,
257+
self.skip_validation,
238258
)?;
239259
self.state = DecoderState::default();
240260
}

0 commit comments

Comments
 (0)