Coverage Report

Created: 2025-08-26 07:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-avro/src/writer/mod.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Avro writer implementation for the `arrow-avro` crate.
19
//!
20
//! # Overview
21
//!
22
//! *   Use **`AvroWriter`** (Object Container File) when you want a
23
//!     self‑contained Avro file with header, schema JSON, optional compression,
24
//!     blocks, and sync markers.
25
//! *   Use **`AvroStreamWriter`** (raw binary stream) when you already know the
26
//!     schema out‑of‑band (e.g., via a schema registry) and need a stream
27
//!     of Avro‑encoded records with minimal framing.
28
//!
29
30
/// Encodes `RecordBatch` into the Avro binary format.
31
pub mod encoder;
32
/// Logic for different Avro container file formats.
33
pub mod format;
34
35
use crate::compression::CompressionCodec;
36
use crate::schema::AvroSchema;
37
use crate::writer::encoder::{encode_record_batch, write_long};
38
use crate::writer::format::{AvroBinaryFormat, AvroFormat, AvroOcfFormat};
39
use arrow_array::RecordBatch;
40
use arrow_schema::{ArrowError, Schema};
41
use std::io::{self, Write};
42
use std::sync::Arc;
43
44
/// Builder to configure and create a `Writer`.
45
#[derive(Debug, Clone)]
46
pub struct WriterBuilder {
47
    schema: Schema,
48
    codec: Option<CompressionCodec>,
49
}
50
51
impl WriterBuilder {
52
    /// Create a new builder with default settings.
53
10
    pub fn new(schema: Schema) -> Self {
54
10
        Self {
55
10
            schema,
56
10
            codec: None,
57
10
        }
58
10
    }
59
60
    /// Change the compression codec.
61
0
    pub fn with_compression(mut self, codec: Option<CompressionCodec>) -> Self {
62
0
        self.codec = codec;
63
0
        self
64
0
    }
65
66
    /// Create a new `Writer` with specified `AvroFormat` and builder options.
67
10
    pub fn build<W, F>(self, writer: W) -> Writer<W, F>
68
10
    where
69
10
        W: Write,
70
10
        F: AvroFormat,
71
    {
72
10
        Writer {
73
10
            writer,
74
10
            schema: Arc::from(self.schema),
75
10
            format: F::default(),
76
10
            compression: self.codec,
77
10
            started: false,
78
10
        }
79
10
    }
80
}
81
82
/// Generic Avro writer.
83
#[derive(Debug)]
84
pub struct Writer<W: Write, F: AvroFormat> {
85
    writer: W,
86
    schema: Arc<Schema>,
87
    format: F,
88
    compression: Option<CompressionCodec>,
89
    started: bool,
90
}
91
92
/// Alias for an Avro **Object Container File** writer.
93
pub type AvroWriter<W> = Writer<W, AvroOcfFormat>;
94
/// Alias for a raw Avro **binary stream** writer.
95
pub type AvroStreamWriter<W> = Writer<W, AvroBinaryFormat>;
96
97
impl<W: Write> Writer<W, AvroOcfFormat> {
98
    /// Convenience constructor – same as
99
10
    pub fn new(writer: W, schema: Schema) -> Result<Self, ArrowError> {
100
10
        Ok(WriterBuilder::new(schema).build::<W, AvroOcfFormat>(writer))
101
10
    }
102
103
    /// Change the compression codec after construction.
104
4
    pub fn with_compression(mut self, codec: Option<CompressionCodec>) -> Self {
105
4
        self.compression = codec;
106
4
        self
107
4
    }
108
109
    /// Return a reference to the 16‑byte sync marker generated for this file.
110
1
    pub fn sync_marker(&self) -> Option<&[u8; 16]> {
111
1
        self.format.sync_marker()
112
1
    }
113
}
114
115
impl<W: Write> Writer<W, AvroBinaryFormat> {
116
    /// Convenience constructor to create a new [`AvroStreamWriter`].
117
0
    pub fn new(writer: W, schema: Schema) -> Result<Self, ArrowError> {
118
0
        Ok(WriterBuilder::new(schema).build::<W, AvroBinaryFormat>(writer))
119
0
    }
120
}
121
122
impl<W: Write, F: AvroFormat> Writer<W, F> {
123
    /// Serialize one [`RecordBatch`] to the output.
124
9
    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
125
9
        if !self.started {
126
8
            self.format
127
8
                .start_stream(&mut self.writer, &self.schema, self.compression)
?0
;
128
8
            self.started = true;
129
1
        }
130
9
        if batch.schema() != self.schema {
131
1
            return Err(ArrowError::SchemaError(
132
1
                "Schema of RecordBatch differs from Writer schema".to_string(),
133
1
            ));
134
8
        }
135
8
        match self.format.sync_marker() {
136
8
            Some(&sync) => self.write_ocf_block(batch, &sync),
137
0
            None => self.write_stream(batch),
138
        }
139
9
    }
140
141
    /// A convenience method to write a slice of [`RecordBatch`].
142
    ///
143
    /// This is equivalent to calling `write` for each batch in the slice.
144
1
    pub fn write_batches(&mut self, batches: &[&RecordBatch]) -> Result<(), ArrowError> {
145
3
        for 
b2
in batches {
146
2
            self.write(b)
?0
;
147
        }
148
1
        Ok(())
149
1
    }
150
151
    /// Flush remaining buffered data and (for OCF) ensure the header is present.
152
8
    pub fn finish(&mut self) -> Result<(), ArrowError> {
153
8
        if !self.started {
154
1
            self.format
155
1
                .start_stream(&mut self.writer, &self.schema, self.compression)
?0
;
156
1
            self.started = true;
157
7
        }
158
8
        self.writer
159
8
            .flush()
160
8
            .map_err(|e| ArrowError::IoError(
format!0
(
"Error flushing writer: {e}"0
),
e0
))
161
8
    }
162
163
    /// Consume the writer, returning the underlying output object.
164
3
    pub fn into_inner(self) -> W {
165
3
        self.writer
166
3
    }
167
168
8
    fn write_ocf_block(&mut self, batch: &RecordBatch, sync: &[u8; 16]) -> Result<(), ArrowError> {
169
8
        let mut buf = Vec::<u8>::with_capacity(1024);
170
8
        encode_record_batch(batch, &mut buf)
?0
;
171
8
        let encoded = match self.compression {
172
4
            Some(codec) => codec.compress(&buf)
?0
,
173
4
            None => buf,
174
        };
175
8
        write_long(&mut self.writer, batch.num_rows() as i64)
?0
;
176
8
        write_long(&mut self.writer, encoded.len() as i64)
?0
;
177
8
        self.writer
178
8
            .write_all(&encoded)
179
8
            .map_err(|e| ArrowError::IoError(
format!0
(
"Error writing Avro block: {e}"0
),
e0
))
?0
;
180
8
        self.writer
181
8
            .write_all(sync)
182
8
            .map_err(|e| ArrowError::IoError(
format!0
(
"Error writing Avro sync: {e}"0
),
e0
))
?0
;
183
8
        Ok(())
184
8
    }
185
186
0
    fn write_stream(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
187
0
        encode_record_batch(batch, &mut self.writer)
188
0
    }
189
}
190
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reader::ReaderBuilder;
    use crate::test_util::arrow_test_data;
    use arrow_array::{ArrayRef, BinaryArray, Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::fs::File;
    use std::io::{BufReader, Cursor};
    use std::sync::Arc;
    use tempfile::NamedTempFile;

    /// Two-column schema used by the small unit tests: non-nullable
    /// `id: Int32` and `name: Binary`.
    fn make_schema() -> Schema {
        Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Binary, false),
        ])
    }

    /// Three-row batch matching [`make_schema`].
    fn make_batch() -> RecordBatch {
        let ids = Int32Array::from(vec![1, 2, 3]);
        let names = BinaryArray::from_vec(vec![b"a".as_ref(), b"b".as_ref(), b"c".as_ref()]);
        RecordBatch::try_new(
            Arc::new(make_schema()),
            vec![Arc::new(ids) as ArrayRef, Arc::new(names) as ArrayRef],
        )
        .expect("failed to build test RecordBatch")
    }

    /// Returns `true` if `needle` occurs as a contiguous run inside `haystack`.
    fn contains_ascii(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|w| w == needle)
    }

    /// Reads an in-memory Avro OCF back into record batches using this crate's reader.
    fn read_back(bytes: Vec<u8>) -> Result<Vec<RecordBatch>, ArrowError> {
        let reader = ReaderBuilder::new()
            .build(BufReader::new(Cursor::new(bytes)))
            .expect("build reader over in-memory output");
        let batches = reader.collect::<Result<Vec<_>, _>>()?;
        Ok(batches)
    }

    /// Compression codec implied by an `alltypes_plain.<codec>.avro` test file name.
    fn codec_for(rel: &str) -> Option<CompressionCodec> {
        if rel.contains(".snappy.") {
            Some(CompressionCodec::Snappy)
        } else if rel.contains(".zstandard.") {
            Some(CompressionCodec::ZStandard)
        } else if rel.contains(".bzip2.") {
            Some(CompressionCodec::Bzip2)
        } else if rel.contains(".xz.") {
            Some(CompressionCodec::Xz)
        } else {
            None
        }
    }

    #[test]
    fn test_ocf_writer_generates_header_and_sync() -> Result<(), ArrowError> {
        let batch = make_batch();
        let mut writer = AvroWriter::new(Vec::<u8>::new(), make_schema())?;
        writer.write(&batch)?;
        writer.finish()?;
        // Compare against this writer's own marker, not one from a separate writer.
        let sync = *writer
            .sync_marker()
            .expect("OCF writer must expose a sync marker after writing a block");
        let out = writer.into_inner();
        assert_eq!(&out[..4], b"Obj\x01", "OCF magic bytes missing/incorrect");
        assert!(out.len() > 4 + 16, "output too short to hold header and sync marker");
        let trailer = &out[out.len() - 16..];
        assert_eq!(trailer, &sync[..], "last block must end with the file's sync marker");
        // Per the Avro spec, the file header also carries the sync marker.
        assert!(
            contains_ascii(&out[4..out.len() - 16], &sync),
            "sync marker missing from OCF header"
        );
        Ok(())
    }

    #[test]
    fn test_schema_mismatch_yields_error() {
        let batch = make_batch();
        let alt_schema = Schema::new(vec![Field::new("x", DataType::Int32, false)]);
        let mut writer =
            AvroWriter::new(Vec::<u8>::new(), alt_schema).expect("construct writer");
        let err = writer.write(&batch).unwrap_err();
        assert!(matches!(err, ArrowError::SchemaError(_)));
    }

    #[test]
    fn test_write_batches_accumulates_multiple() -> Result<(), ArrowError> {
        let batch1 = make_batch();
        let batch2 = make_batch();
        let mut writer = AvroWriter::new(Vec::<u8>::new(), make_schema())?;
        writer.write_batches(&[&batch1, &batch2])?;
        writer.finish()?;
        let out = writer.into_inner();
        assert!(out.len() > 4, "combined batches produced tiny file");
        let rows: usize = read_back(out)?.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(
            rows,
            batch1.num_rows() + batch2.num_rows(),
            "every row from both batches must be readable"
        );
        Ok(())
    }

    #[test]
    fn test_finish_without_write_adds_header() -> Result<(), ArrowError> {
        let mut writer = AvroWriter::new(Vec::<u8>::new(), make_schema())?;
        writer.finish()?;
        let out = writer.into_inner();
        assert_eq!(&out[..4], b"Obj\x01", "finish() should emit OCF header");
        Ok(())
    }

    #[test]
    fn test_write_long_encodes_zigzag_varint() -> Result<(), ArrowError> {
        let mut buf: Vec<u8> = Vec::new();
        for value in [0, -1, 1, -2, 2147483647] {
            write_long(&mut buf, value)?;
        }
        // Zig-zag maps 0, -1, 1, -2 to 0, 1, 2, 3. i32::MAX maps to 0xFFFF_FFFE,
        // whose little-endian base-128 varint is FE FF FF FF 0F.
        assert_eq!(
            buf,
            vec![0x00, 0x01, 0x02, 0x03, 0xFE, 0xFF, 0xFF, 0xFF, 0x0F],
            "zig‑zag varint encodings incorrect: {buf:?}"
        );
        Ok(())
    }

    #[test]
    fn test_roundtrip_alltypes_roundtrip_writer() -> Result<(), ArrowError> {
        let files = [
            "avro/alltypes_plain.avro",
            "avro/alltypes_plain.snappy.avro",
            "avro/alltypes_plain.zstandard.avro",
            "avro/alltypes_plain.bzip2.avro",
            "avro/alltypes_plain.xz.avro",
        ];
        for rel in files {
            let path = arrow_test_data(rel);
            let rdr_file = File::open(&path).expect("open input avro");
            let mut reader = ReaderBuilder::new()
                .build(BufReader::new(rdr_file))
                .expect("build reader");
            let schema = reader.schema();
            let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
            let original =
                arrow::compute::concat_batches(&schema, &input_batches).expect("concat input");

            let tmp = NamedTempFile::new().expect("create temp file");
            let out_path = tmp.into_temp_path();
            let out_file = File::create(&out_path).expect("create temp avro");
            // Re-encode with the same codec the input file was written with.
            let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?
                .with_compression(codec_for(rel));
            writer.write(&original)?;
            writer.finish()?;
            drop(writer);

            let rt_file = File::open(&out_path).expect("open roundtrip avro");
            let mut rt_reader = ReaderBuilder::new()
                .build(BufReader::new(rt_file))
                .expect("build roundtrip reader");
            let rt_schema = rt_reader.schema();
            let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
            let roundtrip =
                arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
            assert_eq!(
                roundtrip, original,
                "Round-trip batch mismatch for file: {rel}"
            );
        }
        Ok(())
    }
}