/Users/andrewlamb/Software/arrow-rs/arrow-avro/src/writer/mod.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Avro writer implementation for the `arrow-avro` crate. |
19 | | //! |
20 | | //! # Overview |
21 | | //! |
22 | | //! * Use **`AvroWriter`** (Object Container File) when you want a |
23 | | //! self‑contained Avro file with header, schema JSON, optional compression, |
24 | | //! blocks, and sync markers. |
25 | | //! * Use **`AvroStreamWriter`** (raw binary stream) when you already know the |
26 | | //! schema out‑of‑band (i.e., via a schema registry) and need a stream |
27 | | //! of Avro‑encoded records with minimal framing. |
28 | | //! |
29 | | |
30 | | /// Encodes `RecordBatch` into the Avro binary format. |
31 | | pub mod encoder; |
32 | | /// Logic for different Avro container file formats. |
33 | | pub mod format; |
34 | | |
35 | | use crate::compression::CompressionCodec; |
36 | | use crate::schema::AvroSchema; |
37 | | use crate::writer::encoder::{encode_record_batch, write_long}; |
38 | | use crate::writer::format::{AvroBinaryFormat, AvroFormat, AvroOcfFormat}; |
39 | | use arrow_array::RecordBatch; |
40 | | use arrow_schema::{ArrowError, Schema}; |
41 | | use std::io::{self, Write}; |
42 | | use std::sync::Arc; |
43 | | |
/// Builder to configure and create a `Writer`.
#[derive(Debug, Clone)]
pub struct WriterBuilder {
    /// Arrow schema the resulting writer will enforce on every batch.
    schema: Schema,
    /// Optional compression codec; `None` writes uncompressed blocks.
    codec: Option<CompressionCodec>,
}
50 | | |
51 | | impl WriterBuilder { |
52 | | /// Create a new builder with default settings. |
53 | 10 | pub fn new(schema: Schema) -> Self { |
54 | 10 | Self { |
55 | 10 | schema, |
56 | 10 | codec: None, |
57 | 10 | } |
58 | 10 | } |
59 | | |
60 | | /// Change the compression codec. |
61 | 0 | pub fn with_compression(mut self, codec: Option<CompressionCodec>) -> Self { |
62 | 0 | self.codec = codec; |
63 | 0 | self |
64 | 0 | } |
65 | | |
66 | | /// Create a new `Writer` with specified `AvroFormat` and builder options. |
67 | 10 | pub fn build<W, F>(self, writer: W) -> Writer<W, F> |
68 | 10 | where |
69 | 10 | W: Write, |
70 | 10 | F: AvroFormat, |
71 | | { |
72 | 10 | Writer { |
73 | 10 | writer, |
74 | 10 | schema: Arc::from(self.schema), |
75 | 10 | format: F::default(), |
76 | 10 | compression: self.codec, |
77 | 10 | started: false, |
78 | 10 | } |
79 | 10 | } |
80 | | } |
81 | | |
/// Generic Avro writer.
///
/// The `F` type parameter selects the framing: [`AvroOcfFormat`] for Object
/// Container Files or [`AvroBinaryFormat`] for a raw record stream. Prefer the
/// [`AvroWriter`] / [`AvroStreamWriter`] aliases over naming this type directly.
#[derive(Debug)]
pub struct Writer<W: Write, F: AvroFormat> {
    /// Underlying output sink.
    writer: W,
    /// Arrow schema that every written batch must match exactly.
    schema: Arc<Schema>,
    /// Format state; for OCF this holds the file's 16-byte sync marker.
    format: F,
    /// Optional codec, passed to `start_stream` and used to compress OCF blocks.
    compression: Option<CompressionCodec>,
    /// Whether the stream preamble (e.g. the OCF header) has been written.
    started: bool,
}
91 | | |
/// Alias for an Avro **Object Container File** writer: emits header, schema
/// JSON, optionally compressed data blocks, and sync markers.
pub type AvroWriter<W> = Writer<W, AvroOcfFormat>;
/// Alias for a raw Avro **binary stream** writer: emits encoded records with
/// minimal framing, for use when the schema is known out-of-band.
pub type AvroStreamWriter<W> = Writer<W, AvroBinaryFormat>;
96 | | |
impl<W: Write> Writer<W, AvroOcfFormat> {
    /// Convenience constructor — equivalent to
    /// `WriterBuilder::new(schema).build::<W, AvroOcfFormat>(writer)`.
    pub fn new(writer: W, schema: Schema) -> Result<Self, ArrowError> {
        Ok(WriterBuilder::new(schema).build::<W, AvroOcfFormat>(writer))
    }

    /// Change the compression codec after construction.
    ///
    /// NOTE(review): call this before the first `write`/`finish` — the codec
    /// is handed to `start_stream` when the header is emitted, so changing it
    /// afterwards has no effect on the header.
    pub fn with_compression(mut self, codec: Option<CompressionCodec>) -> Self {
        self.compression = codec;
        self
    }

    /// Return a reference to the 16‑byte sync marker generated for this file,
    /// or `None` if the format has not produced one yet.
    pub fn sync_marker(&self) -> Option<&[u8; 16]> {
        self.format.sync_marker()
    }
}
114 | | |
impl<W: Write> Writer<W, AvroBinaryFormat> {
    /// Convenience constructor to create a new [`AvroStreamWriter`].
    ///
    /// Returns `Result` for signature parity with the OCF constructor, even
    /// though this path cannot currently fail.
    pub fn new(writer: W, schema: Schema) -> Result<Self, ArrowError> {
        Ok(WriterBuilder::new(schema).build::<W, AvroBinaryFormat>(writer))
    }
}
121 | | |
122 | | impl<W: Write, F: AvroFormat> Writer<W, F> { |
123 | | /// Serialize one [`RecordBatch`] to the output. |
124 | 9 | pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { |
125 | 9 | if !self.started { |
126 | 8 | self.format |
127 | 8 | .start_stream(&mut self.writer, &self.schema, self.compression)?0 ; |
128 | 8 | self.started = true; |
129 | 1 | } |
130 | 9 | if batch.schema() != self.schema { |
131 | 1 | return Err(ArrowError::SchemaError( |
132 | 1 | "Schema of RecordBatch differs from Writer schema".to_string(), |
133 | 1 | )); |
134 | 8 | } |
135 | 8 | match self.format.sync_marker() { |
136 | 8 | Some(&sync) => self.write_ocf_block(batch, &sync), |
137 | 0 | None => self.write_stream(batch), |
138 | | } |
139 | 9 | } |
140 | | |
141 | | /// A convenience method to write a slice of [`RecordBatch`]. |
142 | | /// |
143 | | /// This is equivalent to calling `write` for each batch in the slice. |
144 | 1 | pub fn write_batches(&mut self, batches: &[&RecordBatch]) -> Result<(), ArrowError> { |
145 | 3 | for b2 in batches { |
146 | 2 | self.write(b)?0 ; |
147 | | } |
148 | 1 | Ok(()) |
149 | 1 | } |
150 | | |
151 | | /// Flush remaining buffered data and (for OCF) ensure the header is present. |
152 | 8 | pub fn finish(&mut self) -> Result<(), ArrowError> { |
153 | 8 | if !self.started { |
154 | 1 | self.format |
155 | 1 | .start_stream(&mut self.writer, &self.schema, self.compression)?0 ; |
156 | 1 | self.started = true; |
157 | 7 | } |
158 | 8 | self.writer |
159 | 8 | .flush() |
160 | 8 | .map_err(|e| ArrowError::IoError(format!0 ("Error flushing writer: {e}"0 ), e0 )) |
161 | 8 | } |
162 | | |
163 | | /// Consume the writer, returning the underlying output object. |
164 | 3 | pub fn into_inner(self) -> W { |
165 | 3 | self.writer |
166 | 3 | } |
167 | | |
168 | 8 | fn write_ocf_block(&mut self, batch: &RecordBatch, sync: &[u8; 16]) -> Result<(), ArrowError> { |
169 | 8 | let mut buf = Vec::<u8>::with_capacity(1024); |
170 | 8 | encode_record_batch(batch, &mut buf)?0 ; |
171 | 8 | let encoded = match self.compression { |
172 | 4 | Some(codec) => codec.compress(&buf)?0 , |
173 | 4 | None => buf, |
174 | | }; |
175 | 8 | write_long(&mut self.writer, batch.num_rows() as i64)?0 ; |
176 | 8 | write_long(&mut self.writer, encoded.len() as i64)?0 ; |
177 | 8 | self.writer |
178 | 8 | .write_all(&encoded) |
179 | 8 | .map_err(|e| ArrowError::IoError(format!0 ("Error writing Avro block: {e}"0 ), e0 ))?0 ; |
180 | 8 | self.writer |
181 | 8 | .write_all(sync) |
182 | 8 | .map_err(|e| ArrowError::IoError(format!0 ("Error writing Avro sync: {e}"0 ), e0 ))?0 ; |
183 | 8 | Ok(()) |
184 | 8 | } |
185 | | |
186 | 0 | fn write_stream(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { |
187 | 0 | encode_record_batch(batch, &mut self.writer) |
188 | 0 | } |
189 | | } |
190 | | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reader::ReaderBuilder;
    use crate::test_util::arrow_test_data;
    use arrow_array::{ArrayRef, BinaryArray, Int32Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use std::fs::File;
    use std::io::BufReader;
    use std::sync::Arc;
    use tempfile::NamedTempFile;

    /// Two-column schema (`id: Int32`, `name: Binary`) shared by most tests.
    fn make_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Binary, false),
        ])
    }

    /// Three-row batch matching [`make_schema`].
    fn make_batch() -> RecordBatch {
        let ids = Int32Array::from(vec![1, 2, 3]);
        let names = BinaryArray::from_vec(vec![b"a".as_ref(), b"b".as_ref(), b"c".as_ref()]);
        RecordBatch::try_new(
            Arc::new(make_schema()),
            vec![Arc::new(ids) as ArrayRef, Arc::new(names) as ArrayRef],
        )
        .expect("failed to build test RecordBatch")
    }

    #[test]
    fn test_ocf_writer_generates_header_and_sync() -> Result<(), ArrowError> {
        let batch = make_batch();
        let mut writer = AvroWriter::new(Vec::new(), make_schema())?;
        writer.write(&batch)?;
        writer.finish()?;
        // Capture THIS writer's sync marker before consuming it, so the
        // trailer check below actually verifies the marker that was written
        // (markers are per-file; a second writer's marker would not match).
        let sync = writer
            .sync_marker()
            .copied()
            .expect("OCF writer should have a sync marker after writing");
        let out = writer.into_inner();
        assert_eq!(&out[..4], b"Obj\x01", "OCF magic bytes missing/incorrect");
        let trailer = &out[out.len() - 16..];
        assert_eq!(
            trailer,
            &sync[..],
            "file should end with this writer's 16-byte sync marker"
        );
        Ok(())
    }

    #[test]
    fn test_schema_mismatch_yields_error() {
        let batch = make_batch();
        let alt_schema = Schema::new(vec![Field::new("x", DataType::Int32, false)]);
        let buffer = Vec::<u8>::new();
        let mut writer = AvroWriter::new(buffer, alt_schema).unwrap();
        let err = writer.write(&batch).unwrap_err();
        assert!(matches!(err, ArrowError::SchemaError(_)));
    }

    #[test]
    fn test_write_batches_accumulates_multiple() -> Result<(), ArrowError> {
        let batch1 = make_batch();
        let batch2 = make_batch();
        let buffer = Vec::<u8>::new();
        let mut writer = AvroWriter::new(buffer, make_schema())?;
        writer.write_batches(&[&batch1, &batch2])?;
        writer.finish()?;
        let out = writer.into_inner();
        assert!(out.len() > 4, "combined batches produced tiny file");
        Ok(())
    }

    #[test]
    fn test_finish_without_write_adds_header() -> Result<(), ArrowError> {
        let buffer = Vec::<u8>::new();
        let mut writer = AvroWriter::new(buffer, make_schema())?;
        writer.finish()?;
        let out = writer.into_inner();
        assert_eq!(&out[..4], b"Obj\x01", "finish() should emit OCF header");
        Ok(())
    }

    #[test]
    fn test_write_long_encodes_zigzag_varint() -> Result<(), ArrowError> {
        // Zig-zag maps 0, -1, 1, -2 to 0, 1, 2, 3 respectively.
        let mut buf = Vec::new();
        write_long(&mut buf, 0)?;
        write_long(&mut buf, -1)?;
        write_long(&mut buf, 1)?;
        write_long(&mut buf, -2)?;
        write_long(&mut buf, 2147483647)?;
        assert!(
            buf.starts_with(&[0x00, 0x01, 0x02, 0x03]),
            "zig‑zag varint encodings incorrect: {buf:?}"
        );
        Ok(())
    }

    /// Read each reference file, write it back out with the matching codec,
    /// then read the result again and require batch-for-batch equality.
    #[test]
    fn test_roundtrip_alltypes_roundtrip_writer() -> Result<(), ArrowError> {
        let files = [
            "avro/alltypes_plain.avro",
            "avro/alltypes_plain.snappy.avro",
            "avro/alltypes_plain.zstandard.avro",
            "avro/alltypes_plain.bzip2.avro",
            "avro/alltypes_plain.xz.avro",
        ];
        for rel in files {
            let path = arrow_test_data(rel);
            let rdr_file = File::open(&path).expect("open input avro");
            let mut reader = ReaderBuilder::new()
                .build(BufReader::new(rdr_file))
                .expect("build reader");
            let schema = reader.schema();
            let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
            let original =
                arrow::compute::concat_batches(&schema, &input_batches).expect("concat input");
            let tmp = NamedTempFile::new().expect("create temp file");
            let out_path = tmp.into_temp_path();
            let out_file = File::create(&out_path).expect("create temp avro");
            let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?;
            // Mirror the input file's codec so the roundtrip also exercises
            // the corresponding compression path.
            if rel.contains(".snappy.") {
                writer = writer.with_compression(Some(CompressionCodec::Snappy));
            } else if rel.contains(".zstandard.") {
                writer = writer.with_compression(Some(CompressionCodec::ZStandard));
            } else if rel.contains(".bzip2.") {
                writer = writer.with_compression(Some(CompressionCodec::Bzip2));
            } else if rel.contains(".xz.") {
                writer = writer.with_compression(Some(CompressionCodec::Xz));
            }
            writer.write(&original)?;
            writer.finish()?;
            drop(writer);
            let rt_file = File::open(&out_path).expect("open roundtrip avro");
            let mut rt_reader = ReaderBuilder::new()
                .build(BufReader::new(rt_file))
                .expect("build roundtrip reader");
            let rt_schema = rt_reader.schema();
            let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
            let roundtrip =
                arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
            assert_eq!(
                roundtrip, original,
                "Round-trip batch mismatch for file: {}",
                rel
            );
        }
        Ok(())
    }
}