/Users/andrewlamb/Software/arrow-rs/arrow-avro/src/reader/header.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Decoder for [`Header`] |
19 | | |
20 | | use crate::compression::{CompressionCodec, CODEC_METADATA_KEY}; |
21 | | use crate::reader::vlq::VLQDecoder; |
22 | | use crate::schema::{Schema, SCHEMA_METADATA_KEY}; |
23 | | use arrow_schema::ArrowError; |
24 | | |
/// Internal state machine for [`HeaderDecoder`].
///
/// Input may arrive in arbitrary chunks, so every state can be suspended
/// part-way through (tracked via `bytes_remaining` / the VLQ decoder) and
/// resumed on the next call to `decode`.
#[derive(Debug)]
enum HeaderDecoderState {
    /// Decoding the [`MAGIC`] prefix
    Magic,
    /// Decoding a block count
    BlockCount,
    /// Decoding a block byte length
    BlockLen,
    /// Decoding a key length
    KeyLen,
    /// Decoding a key string
    Key,
    /// Decoding a value length
    ValueLen,
    /// Decoding a value payload
    Value,
    /// Decoding sync marker
    Sync,
    /// Finished decoding
    Finished,
}
46 | | |
/// A decoded header for an [Object Container File](https://avro.apache.org/docs/1.11.1/specification/#object-container-files)
#[derive(Debug, Clone)]
pub struct Header {
    /// End offsets into `meta_buf`, stored as `[key_end, value_end]` pairs,
    /// one pair per metadata entry (consumed pairwise by [`Self::metadata`])
    meta_offsets: Vec<usize>,
    /// The concatenated raw bytes of all metadata keys and values
    meta_buf: Vec<u8>,
    /// The 16-byte sync marker decoded from the file header
    sync: [u8; 16],
}
54 | | |
55 | | impl Header { |
56 | | /// Returns an iterator over the meta keys in this header |
57 | 170 | pub fn metadata(&self) -> impl Iterator<Item = (&[u8], &[u8])> { |
58 | 170 | let mut last = 0; |
59 | 289 | self.meta_offsets.chunks_exact(2)170 .map170 (move |w| { |
60 | 289 | let start = last; |
61 | 289 | last = w[1]; |
62 | 289 | (&self.meta_buf[start..w[0]], &self.meta_buf[w[0]..w[1]]) |
63 | 289 | }) |
64 | 170 | } |
65 | | |
66 | | /// Returns the value for a given metadata key if present |
67 | 169 | pub fn get(&self, key: impl AsRef<[u8]>) -> Option<&[u8]> { |
68 | 169 | self.metadata() |
69 | 286 | .find_map169 (|(k, v)| (k == key.as_ref()).then_some(v)) |
70 | 169 | } |
71 | | |
72 | | /// Returns the sync token for this file |
73 | 2 | pub fn sync(&self) -> [u8; 16] { |
74 | 2 | self.sync |
75 | 2 | } |
76 | | |
77 | | /// Returns the [`CompressionCodec`] if any |
78 | 73 | pub fn compression(&self) -> Result<Option<CompressionCodec>, ArrowError> { |
79 | 73 | let v = self.get(CODEC_METADATA_KEY); |
80 | 73 | match v { |
81 | 73 | None | Some(b"null") => Ok(None)13 , |
82 | 60 | Some(b"deflate") => Ok(Some(CompressionCodec::Deflate))0 , |
83 | 60 | Some(b"snappy") => Ok(Some(CompressionCodec::Snappy))39 , |
84 | 21 | Some(b"zstandard") => Ok(Some(CompressionCodec::ZStandard))7 , |
85 | 14 | Some(b"bzip2") => Ok(Some(CompressionCodec::Bzip2))7 , |
86 | 7 | Some(b"xz") => Ok(Some(CompressionCodec::Xz)), |
87 | 0 | Some(v) => Err(ArrowError::ParseError(format!( |
88 | 0 | "Unrecognized compression codec \'{}\'", |
89 | 0 | String::from_utf8_lossy(v) |
90 | 0 | ))), |
91 | | } |
92 | 73 | } |
93 | | |
94 | | /// Returns the [`Schema`] if any |
95 | 94 | pub(crate) fn schema(&self) -> Result<Option<Schema<'_>>, ArrowError> { |
96 | 94 | self.get(SCHEMA_METADATA_KEY) |
97 | 94 | .map(|x| { |
98 | 94 | serde_json::from_slice(x).map_err(|e| {0 |
99 | 0 | ArrowError::ParseError(format!("Failed to parse Avro schema JSON: {e}")) |
100 | 0 | }) |
101 | 94 | }) |
102 | 94 | .transpose() |
103 | 94 | } |
104 | | } |
105 | | |
/// A decoder for [`Header`]
///
/// The avro file format does not encode the length of the header, and so it
/// is necessary to provide a push-based decoder that can be used with streams
#[derive(Debug)]
pub struct HeaderDecoder {
    /// Current position in the decoding state machine
    state: HeaderDecoderState,
    /// Decoder for variable-length integers (see [`VLQDecoder`])
    vlq_decoder: VLQDecoder,

    /// The end offsets of strings in `meta_buf`
    meta_offsets: Vec<usize>,
    /// The raw binary data of the metadata map
    meta_buf: Vec<u8>,

    /// The decoded sync marker
    sync_marker: [u8; 16],

    /// The number of remaining tuples in the current block
    tuples_remaining: usize,
    /// The number of bytes remaining in the current string/bytes payload
    /// (also counts the outstanding magic/sync bytes while in those states;
    /// initialized to `MAGIC.len()` by `Default`)
    bytes_remaining: usize,
}
128 | | |
129 | | impl Default for HeaderDecoder { |
130 | 99 | fn default() -> Self { |
131 | 99 | Self { |
132 | 99 | state: HeaderDecoderState::Magic, |
133 | 99 | meta_offsets: vec![], |
134 | 99 | meta_buf: vec![], |
135 | 99 | sync_marker: [0; 16], |
136 | 99 | vlq_decoder: Default::default(), |
137 | 99 | tuples_remaining: 0, |
138 | 99 | bytes_remaining: MAGIC.len(), |
139 | 99 | } |
140 | 99 | } |
141 | | } |
142 | | |
143 | | const MAGIC: &[u8; 4] = b"Obj\x01"; |
144 | | |
145 | | impl HeaderDecoder { |
146 | | /// Parse [`Header`] from `buf`, returning the number of bytes read |
147 | | /// |
148 | | /// This method can be called multiple times with consecutive chunks of data, allowing |
149 | | /// integration with chunked IO systems like [`BufRead::fill_buf`] |
150 | | /// |
151 | | /// All errors should be considered fatal, and decoding aborted |
152 | | /// |
153 | | /// Once the entire [`Header`] has been decoded this method will not read any further |
154 | | /// input bytes, and the header can be obtained with [`Self::flush`] |
155 | | /// |
156 | | /// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf |
157 | 111 | pub fn decode(&mut self, mut buf: &[u8]) -> Result<usize, ArrowError> { |
158 | 111 | let max_read = buf.len(); |
159 | 1.44k | while !buf.is_empty() { |
160 | 1.43k | match self.state { |
161 | | HeaderDecoderState::Magic => { |
162 | 103 | let remaining = &MAGIC[MAGIC.len() - self.bytes_remaining..]; |
163 | 103 | let to_decode = buf.len().min(remaining.len()); |
164 | 103 | if !buf.starts_with(&remaining[..to_decode]) { |
165 | 1 | return Err(ArrowError::ParseError("Incorrect avro magic".to_string())); |
166 | 102 | } |
167 | 102 | self.bytes_remaining -= to_decode; |
168 | 102 | buf = &buf[to_decode..]; |
169 | 102 | if self.bytes_remaining == 0 { |
170 | 98 | self.state = HeaderDecoderState::BlockCount; |
171 | 98 | }4 |
172 | | } |
173 | | HeaderDecoderState::BlockCount => { |
174 | 192 | if let Some(block_count) = self.vlq_decoder.long(&mut buf) { |
175 | 192 | match block_count.try_into() { |
176 | 96 | Ok(0) => { |
177 | 96 | self.state = HeaderDecoderState::Sync; |
178 | 96 | self.bytes_remaining = 16; |
179 | 96 | } |
180 | 96 | Ok(remaining) => { |
181 | 96 | self.tuples_remaining = remaining; |
182 | 96 | self.state = HeaderDecoderState::KeyLen; |
183 | 96 | } |
184 | 0 | Err(_) => { |
185 | 0 | self.tuples_remaining = block_count.unsigned_abs() as _; |
186 | 0 | self.state = HeaderDecoderState::BlockLen; |
187 | 0 | } |
188 | | } |
189 | 0 | } |
190 | | } |
191 | | HeaderDecoderState::BlockLen => { |
192 | 0 | if self.vlq_decoder.long(&mut buf).is_some() { |
193 | 0 | self.state = HeaderDecoderState::KeyLen |
194 | 0 | } |
195 | | } |
196 | | HeaderDecoderState::Key => { |
197 | 235 | let to_read = self.bytes_remaining.min(buf.len()); |
198 | 235 | self.meta_buf.extend_from_slice(&buf[..to_read]); |
199 | 235 | self.bytes_remaining -= to_read; |
200 | 235 | buf = &buf[to_read..]; |
201 | 235 | if self.bytes_remaining == 0 { |
202 | 235 | self.meta_offsets.push(self.meta_buf.len()); |
203 | 235 | self.state = HeaderDecoderState::ValueLen; |
204 | 235 | }0 |
205 | | } |
206 | | HeaderDecoderState::Value => { |
207 | 243 | let to_read = self.bytes_remaining.min(buf.len()); |
208 | 243 | self.meta_buf.extend_from_slice(&buf[..to_read]); |
209 | 243 | self.bytes_remaining -= to_read; |
210 | 243 | buf = &buf[to_read..]; |
211 | 243 | if self.bytes_remaining == 0 { |
212 | 235 | self.meta_offsets.push(self.meta_buf.len()); |
213 | | |
214 | 235 | self.tuples_remaining -= 1; |
215 | 235 | match self.tuples_remaining { |
216 | 96 | 0 => self.state = HeaderDecoderState::BlockCount, |
217 | 139 | _ => self.state = HeaderDecoderState::KeyLen, |
218 | | } |
219 | 8 | } |
220 | | } |
221 | | HeaderDecoderState::KeyLen => { |
222 | 235 | if let Some(len) = self.vlq_decoder.long(&mut buf) { |
223 | 235 | self.bytes_remaining = len as _; |
224 | 235 | self.state = HeaderDecoderState::Key; |
225 | 235 | }0 |
226 | | } |
227 | | HeaderDecoderState::ValueLen => { |
228 | 235 | if let Some(len) = self.vlq_decoder.long(&mut buf) { |
229 | 235 | self.bytes_remaining = len as _; |
230 | 235 | self.state = HeaderDecoderState::Value; |
231 | 235 | }0 |
232 | | } |
233 | | HeaderDecoderState::Sync => { |
234 | 96 | let to_decode = buf.len().min(self.bytes_remaining); |
235 | 96 | let write = &mut self.sync_marker[16 - to_decode..]; |
236 | 96 | write[..to_decode].copy_from_slice(&buf[..to_decode]); |
237 | 96 | self.bytes_remaining -= to_decode; |
238 | 96 | buf = &buf[to_decode..]; |
239 | 96 | if self.bytes_remaining == 0 { |
240 | 96 | self.state = HeaderDecoderState::Finished; |
241 | 96 | }0 |
242 | | } |
243 | 96 | HeaderDecoderState::Finished => return Ok(max_read - buf.len()), |
244 | | } |
245 | | } |
246 | 14 | Ok(max_read) |
247 | 111 | } |
248 | | |
249 | | /// Flush this decoder returning the parsed [`Header`] if any |
250 | 96 | pub fn flush(&mut self) -> Option<Header> { |
251 | 96 | match self.state { |
252 | | HeaderDecoderState::Finished => { |
253 | 96 | self.state = HeaderDecoderState::Magic; |
254 | 96 | Some(Header { |
255 | 96 | meta_offsets: std::mem::take(&mut self.meta_offsets), |
256 | 96 | meta_buf: std::mem::take(&mut self.meta_buf), |
257 | 96 | sync: self.sync_marker, |
258 | 96 | }) |
259 | | } |
260 | 0 | _ => None, |
261 | | } |
262 | 96 | } |
263 | | } |
264 | | |
#[cfg(test)]
mod test {
    use super::*;
    use crate::codec::{AvroDataType, AvroField};
    use crate::reader::read_header;
    use crate::schema::SCHEMA_METADATA_KEY;
    use crate::test_util::arrow_test_data;
    use arrow_schema::{DataType, Field, Fields, TimeUnit};
    use std::fs::File;
    use std::io::{BufRead, BufReader};

    /// Exercises magic-prefix handling: byte-at-a-time decoding, the byte
    /// count returned for a whole prefix, and the error on a corrupt prefix.
    #[test]
    fn test_header_decode() {
        // The magic may be fed one byte at a time
        let mut decoder = HeaderDecoder::default();
        for m in MAGIC {
            decoder.decode(std::slice::from_ref(m)).unwrap();
        }

        // Decoding the whole magic consumes exactly its 4 bytes
        let mut decoder = HeaderDecoder::default();
        assert_eq!(decoder.decode(MAGIC).unwrap(), 4);

        // A wrong byte part-way through the magic is a fatal parse error
        let mut decoder = HeaderDecoder::default();
        decoder.decode(b"Ob").unwrap();
        let err = decoder.decode(b"s").unwrap_err().to_string();
        assert_eq!(err, "Parser error: Incorrect avro magic");
    }

    /// Reads the [`Header`] of an Avro file on disk, using a deliberately
    /// small buffer so decoding spans multiple chunks.
    fn decode_file(file: &str) -> Header {
        let file = File::open(file).unwrap();
        read_header(BufReader::with_capacity(100, file)).unwrap()
    }

    /// End-to-end header decoding against two reference files: checks the
    /// embedded schema JSON, the derived Arrow field, metadata key order,
    /// and the decoded sync markers.
    #[test]
    fn test_header() {
        let header = decode_file(&arrow_test_data("avro/alltypes_plain.avro"));
        let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap();
        let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"id","type":["int","null"]},{"name":"bool_col","type":["boolean","null"]},{"name":"tinyint_col","type":["int","null"]},{"name":"smallint_col","type":["int","null"]},{"name":"int_col","type":["int","null"]},{"name":"bigint_col","type":["long","null"]},{"name":"float_col","type":["float","null"]},{"name":"double_col","type":["double","null"]},{"name":"date_string_col","type":["bytes","null"]},{"name":"string_col","type":["bytes","null"]},{"name":"timestamp_col","type":[{"type":"long","logicalType":"timestamp-micros"},"null"]}]}"#;
        assert_eq!(schema_json, expected);
        let schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap();
        let field = AvroField::try_from(&schema).unwrap();

        assert_eq!(
            field.field(),
            Field::new(
                "topLevelRecord",
                DataType::Struct(Fields::from(vec![
                    Field::new("id", DataType::Int32, true),
                    Field::new("bool_col", DataType::Boolean, true),
                    Field::new("tinyint_col", DataType::Int32, true),
                    Field::new("smallint_col", DataType::Int32, true),
                    Field::new("int_col", DataType::Int32, true),
                    Field::new("bigint_col", DataType::Int64, true),
                    Field::new("float_col", DataType::Float32, true),
                    Field::new("double_col", DataType::Float64, true),
                    Field::new("date_string_col", DataType::Binary, true),
                    Field::new("string_col", DataType::Binary, true),
                    Field::new(
                        "timestamp_col",
                        DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())),
                        true
                    ),
                ])),
                false
            )
        );

        assert_eq!(
            u128::from_le_bytes(header.sync()),
            226966037233754408753420635932530907102
        );

        let header = decode_file(&arrow_test_data("avro/fixed_length_decimal.avro"));

        // Metadata keys are yielded in file order
        let meta: Vec<_> = header
            .metadata()
            .map(|(k, _)| std::str::from_utf8(k).unwrap())
            .collect();

        assert_eq!(
            meta,
            &["avro.schema", "org.apache.spark.version", "avro.codec"]
        );

        let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap();
        let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"value","type":[{"type":"fixed","name":"fixed","namespace":"topLevelRecord.value","size":11,"logicalType":"decimal","precision":25,"scale":2},"null"]}]}"#;
        assert_eq!(schema_json, expected);
        let _schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap();
        assert_eq!(
            u128::from_le_bytes(header.sync()),
            325166208089902833952788552656412487328
        );
    }
}