Skip to content

Commit 02ecc33

Browse files
committed
Use const generics to remove BitTree heap allocations
1 parent a010cc0 commit 02ecc33

File tree

5 files changed

+117
-97
lines changed

5 files changed

+117
-97
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ env_logger = { version = "^0.8.3", optional = true }
1919

2020
[dev-dependencies]
2121
rust-lzma = "0.5"
22+
seq-macro = "0.3"
2223

2324
[features]
2425
enable_logging = ["env_logger", "log"]

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
[![Documentation](https://docs.rs/lzma-rs/badge.svg)](https://docs.rs/lzma-rs)
55
[![Safety Dance](https://img.shields.io/badge/unsafe-forbidden-success.svg)](https://github.com/rust-secure-code/safety-dance/)
66
![Build Status](https://github.com/gendx/lzma-rs/workflows/Build%20and%20run%20tests/badge.svg)
7-
[![Minimum rust 1.50](https://img.shields.io/badge/rust-1.50%2B-orange.svg)](https://github.com/rust-lang/rust/blob/master/RELEASES.md#version-1500-2021-02-11)
7+
[![Minimum rust 1.51](https://img.shields.io/badge/rust-1.51%2B-orange.svg)](https://github.com/rust-lang/rust/blob/master/RELEASES.md#version-1510-2021-03-25)
88

99
This project is a decoder for LZMA and its variants written in pure Rust, with focus on clarity.
1010
It already supports LZMA, LZMA2 and a subset of the `.xz` file format.

src/decode/lzma.rs

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::decode::lzbuffer::{LzBuffer, LzCircularBuffer};
2-
use crate::decode::rangecoder::{BitTree, LenDecoder, RangeDecoder};
2+
use crate::decode::rangecoder::{bittree_probs_len, BitTree, LenDecoder, RangeDecoder};
33
use crate::decompress::{Options, UnpackedSize};
44
use crate::error;
55
use crate::util::vec2d::Vec2D;
@@ -167,8 +167,8 @@ pub(crate) struct DecoderState {
167167
pub(crate) lzma_props: LzmaProperties,
168168
unpacked_size: Option<u64>,
169169
literal_probs: Vec2D<u16>,
170-
pos_slot_decoder: [BitTree; 4],
171-
align_decoder: BitTree,
170+
pos_slot_decoder: [BitTree<6, { bittree_probs_len::<6>() }>; 4],
171+
align_decoder: BitTree<4, { bittree_probs_len::<4>() }>,
172172
pos_decoders: [u16; 115],
173173
is_match: [u16; 192], // true = LZ, false = literal
174174
is_rep: [u16; 12],
@@ -191,12 +191,12 @@ impl DecoderState {
191191
unpacked_size,
192192
literal_probs: Vec2D::init(0x400, (1 << (lzma_props.lc + lzma_props.lp), 0x300)),
193193
pos_slot_decoder: [
194-
BitTree::new(6),
195-
BitTree::new(6),
196-
BitTree::new(6),
197-
BitTree::new(6),
194+
BitTree::new(),
195+
BitTree::new(),
196+
BitTree::new(),
197+
BitTree::new(),
198198
],
199-
align_decoder: BitTree::new(4),
199+
align_decoder: BitTree::new(),
200200
pos_decoders: [0x400; 115],
201201
is_match: [0x400; 192],
202202
is_rep: [0x400; 12],
@@ -222,8 +222,13 @@ impl DecoderState {
222222
}
223223

224224
self.lzma_props = new_props;
225-
self.pos_slot_decoder.iter_mut().for_each(|t| t.reset());
226-
self.align_decoder.reset();
225+
self.pos_slot_decoder = [
226+
BitTree::new(),
227+
BitTree::new(),
228+
BitTree::new(),
229+
BitTree::new(),
230+
];
231+
self.align_decoder = BitTree::new();
227232
self.pos_decoders = [0x400; 115];
228233
self.is_match = [0x400; 192];
229234
self.is_rep = [0x400; 12];
@@ -233,8 +238,8 @@ impl DecoderState {
233238
self.is_rep_0long = [0x400; 192];
234239
self.state = 0;
235240
self.rep = [0; 4];
236-
self.len_decoder.reset();
237-
self.rep_len_decoder.reset();
241+
self.len_decoder = LenDecoder::new();
242+
self.rep_len_decoder = LenDecoder::new();
238243
}
239244

240245
pub fn set_unpacked_size(&mut self, unpacked_size: Option<u64>) {

src/decode/rangecoder.rs

Lines changed: 63 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -150,18 +150,35 @@ where
150150
}
151151
}
152152

153-
// TODO: parametrize by constant and use [u16; 1 << num_bits] as soon as Rust supports this
153+
/// Macro for compile-time const assertions.
///
/// The build fails when the asserted boolean constant expression is false.
/// The check relies on `0 - !($expr) as u8`: a true expression yields `0 - 0`,
/// while a false one yields `0 - 1`, which overflows `u8` during constant
/// evaluation and is rejected by the compiler.
///
/// Two forms are accepted:
/// * `const_assert!(A: TyA, B: TyB => expr)` for expressions that mention the
///   const generic parameters `A`, `B`, ... of the enclosing item.
/// * `const_assert!(expr)` for a plain constant expression.
154+
macro_rules! const_assert {
155+
// Generic form: the listed `name: type` pairs are re-declared as const
// parameters of a local helper struct so that `$expr` can refer to them.
($($list:ident : $ty:ty),* => $expr:expr) => {{
156+
struct Assert<$(const $list: $ty,)*>;
157+
impl<$(const $list: $ty,)*> Assert<$($list,)*> {
158+
// Underflows `u8` (a constant-evaluation error) when `$expr` is false.
const OK: u8 = 0 - !($expr) as u8;
159+
}
160+
// The block evaluates to the associated constant, instantiated with the
// caller's const parameters of the same names.
// NOTE(review): presumably the error only surfaces once this constant is
// evaluated for concrete parameter values - confirm.
Assert::<$($list,)*>::OK
161+
}};
162+
// Non-generic form: expands to a `const OK` item in the invoking scope.
($expr:expr) => {
163+
const OK: u8 = 0 - !($expr) as u8;
164+
};
165+
}
166+
167+
/// Const fn helper to parameterize the length of the bittree probability array.
///
/// Returns `1 << NUM_BITS`. Because it is a `const fn`, it can be called in
/// const generic argument position, e.g.
/// `BitTree<6, { bittree_probs_len::<6>() }>`.
168+
pub const fn bittree_probs_len<const NUM_BITS: usize>() -> usize {
169+
1 << NUM_BITS
170+
}
171+
154172
#[derive(Debug, Clone)]
155-
pub struct BitTree {
156-
num_bits: usize,
157-
probs: Vec<u16>,
173+
pub struct BitTree<const NUM_BITS: usize, const PROBS_ARRAY_LEN: usize> {
174+
probs: [u16; PROBS_ARRAY_LEN],
158175
}
159176

160-
impl BitTree {
161-
pub fn new(num_bits: usize) -> Self {
177+
impl<const NUM_BITS: usize, const PROBS_ARRAY_LEN: usize> BitTree<NUM_BITS, PROBS_ARRAY_LEN> {
178+
pub fn new() -> Self {
179+
const_assert!(NUM_BITS: usize, PROBS_ARRAY_LEN: usize => PROBS_ARRAY_LEN == bittree_probs_len::<NUM_BITS>());
162180
BitTree {
163-
num_bits,
164-
probs: vec![0x400; 1 << num_bits],
181+
probs: [0x400; PROBS_ARRAY_LEN],
165182
}
166183
}
167184

@@ -170,29 +187,25 @@ impl BitTree {
170187
rangecoder: &mut RangeDecoder<R>,
171188
update: bool,
172189
) -> io::Result<u32> {
173-
rangecoder.parse_bit_tree(self.num_bits, self.probs.as_mut_slice(), update)
190+
rangecoder.parse_bit_tree(NUM_BITS, &mut self.probs, update)
174191
}
175192

176193
pub fn parse_reverse<R: io::BufRead>(
177194
&mut self,
178195
rangecoder: &mut RangeDecoder<R>,
179196
update: bool,
180197
) -> io::Result<u32> {
181-
rangecoder.parse_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0, update)
182-
}
183-
184-
pub fn reset(&mut self) {
185-
self.probs.fill(0x400);
198+
rangecoder.parse_reverse_bit_tree(NUM_BITS, &mut self.probs, 0, update)
186199
}
187200
}
188201

189202
#[derive(Debug)]
190203
pub struct LenDecoder {
191204
choice: u16,
192205
choice2: u16,
193-
low_coder: [BitTree; 16],
194-
mid_coder: [BitTree; 16],
195-
high_coder: BitTree,
206+
low_coder: [BitTree<3, { bittree_probs_len::<3>() }>; 16],
207+
mid_coder: [BitTree<3, { bittree_probs_len::<3>() }>; 16],
208+
high_coder: BitTree<8, { bittree_probs_len::<8>() }>,
196209
}
197210

198211
impl LenDecoder {
@@ -201,42 +214,42 @@ impl LenDecoder {
201214
choice: 0x400,
202215
choice2: 0x400,
203216
low_coder: [
204-
BitTree::new(3),
205-
BitTree::new(3),
206-
BitTree::new(3),
207-
BitTree::new(3),
208-
BitTree::new(3),
209-
BitTree::new(3),
210-
BitTree::new(3),
211-
BitTree::new(3),
212-
BitTree::new(3),
213-
BitTree::new(3),
214-
BitTree::new(3),
215-
BitTree::new(3),
216-
BitTree::new(3),
217-
BitTree::new(3),
218-
BitTree::new(3),
219-
BitTree::new(3),
217+
BitTree::new(),
218+
BitTree::new(),
219+
BitTree::new(),
220+
BitTree::new(),
221+
BitTree::new(),
222+
BitTree::new(),
223+
BitTree::new(),
224+
BitTree::new(),
225+
BitTree::new(),
226+
BitTree::new(),
227+
BitTree::new(),
228+
BitTree::new(),
229+
BitTree::new(),
230+
BitTree::new(),
231+
BitTree::new(),
232+
BitTree::new(),
220233
],
221234
mid_coder: [
222-
BitTree::new(3),
223-
BitTree::new(3),
224-
BitTree::new(3),
225-
BitTree::new(3),
226-
BitTree::new(3),
227-
BitTree::new(3),
228-
BitTree::new(3),
229-
BitTree::new(3),
230-
BitTree::new(3),
231-
BitTree::new(3),
232-
BitTree::new(3),
233-
BitTree::new(3),
234-
BitTree::new(3),
235-
BitTree::new(3),
236-
BitTree::new(3),
237-
BitTree::new(3),
235+
BitTree::new(),
236+
BitTree::new(),
237+
BitTree::new(),
238+
BitTree::new(),
239+
BitTree::new(),
240+
BitTree::new(),
241+
BitTree::new(),
242+
BitTree::new(),
243+
BitTree::new(),
244+
BitTree::new(),
245+
BitTree::new(),
246+
BitTree::new(),
247+
BitTree::new(),
248+
BitTree::new(),
249+
BitTree::new(),
250+
BitTree::new(),
238251
],
239-
high_coder: BitTree::new(8),
252+
high_coder: BitTree::new(),
240253
}
241254
}
242255

@@ -254,12 +267,4 @@ impl LenDecoder {
254267
Ok(self.high_coder.parse(rangecoder, update)? as usize + 16)
255268
}
256269
}
257-
258-
pub fn reset(&mut self) {
259-
self.choice = 0x400;
260-
self.choice2 = 0x400;
261-
self.low_coder.iter_mut().for_each(|t| t.reset());
262-
self.mid_coder.iter_mut().for_each(|t| t.reset());
263-
self.high_coder.reset();
264-
}
265270
}

src/encode/rangecoder.rs

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ mod test {
222222
use super::*;
223223
use crate::decode::rangecoder::{LenDecoder, RangeDecoder};
224224
use crate::{decode, encode};
225+
use seq_macro::seq;
225226
use std::io::BufReader;
226227

227228
fn encode_decode(prob_init: u16, bits: &[bool]) {
@@ -253,19 +254,19 @@ mod test {
253254
encode_decode(0x400, &[true; 10000]);
254255
}
255256

256-
fn encode_decode_bittree(num_bits: usize, values: &[u32]) {
257+
fn encode_decode_bittree<const NUM_BITS: usize, const PROBS_LEN: usize>(values: &[u32]) {
257258
let mut buf: Vec<u8> = Vec::new();
258259

259260
let mut encoder = RangeEncoder::new(&mut buf);
260-
let mut tree = encode::rangecoder::BitTree::new(num_bits);
261+
let mut tree = encode::rangecoder::BitTree::new(NUM_BITS);
261262
for &v in values {
262263
tree.encode(&mut encoder, v).unwrap();
263264
}
264265
encoder.finish().unwrap();
265266

266267
let mut bufread = BufReader::new(buf.as_slice());
267268
let mut decoder = RangeDecoder::new(&mut bufread).unwrap();
268-
let mut tree = decode::rangecoder::BitTree::new(num_bits);
269+
let mut tree = decode::rangecoder::BitTree::<NUM_BITS, PROBS_LEN>::new();
269270
for &v in values {
270271
assert_eq!(tree.parse(&mut decoder, true).unwrap(), v);
271272
}
@@ -274,40 +275,45 @@ mod test {
274275

275276
#[test]
276277
fn test_encode_decode_bittree_zeros() {
277-
for num_bits in 0..16 {
278-
encode_decode_bittree(num_bits, &[0; 10000]);
279-
}
278+
seq!(NUM_BITS in 0..16 {
279+
encode_decode_bittree::<NUM_BITS, {decode::rangecoder::bittree_probs_len::<NUM_BITS>()}>
280+
(&[0; 10000]);
281+
});
280282
}
281283

282284
#[test]
283285
fn test_encode_decode_bittree_ones() {
284-
for num_bits in 0..16 {
285-
encode_decode_bittree(num_bits, &[(1 << num_bits) - 1; 10000]);
286-
}
286+
seq!(NUM_BITS in 0..16 {
287+
encode_decode_bittree::<NUM_BITS, {decode::rangecoder::bittree_probs_len::<NUM_BITS>()}>
288+
(&[(1 << NUM_BITS) - 1; 10000]);
289+
});
287290
}
288291

289292
#[test]
290293
fn test_encode_decode_bittree_all() {
291-
for num_bits in 0..16 {
292-
let max = 1 << num_bits;
294+
seq!(NUM_BITS in 0..16 {
295+
let max = 1 << NUM_BITS;
293296
let values: Vec<u32> = (0..max).collect();
294-
encode_decode_bittree(num_bits, &values);
295-
}
297+
encode_decode_bittree::<NUM_BITS, {decode::rangecoder::bittree_probs_len::<NUM_BITS>()}>
298+
(&values);
299+
});
296300
}
297301

298-
fn encode_decode_reverse_bittree(num_bits: usize, values: &[u32]) {
302+
fn encode_decode_reverse_bittree<const NUM_BITS: usize, const PROBS_LEN: usize>(
303+
values: &[u32],
304+
) {
299305
let mut buf: Vec<u8> = Vec::new();
300306

301307
let mut encoder = RangeEncoder::new(&mut buf);
302-
let mut tree = encode::rangecoder::BitTree::new(num_bits);
308+
let mut tree = encode::rangecoder::BitTree::new(NUM_BITS);
303309
for &v in values {
304310
tree.encode_reverse(&mut encoder, v).unwrap();
305311
}
306312
encoder.finish().unwrap();
307313

308314
let mut bufread = BufReader::new(buf.as_slice());
309315
let mut decoder = RangeDecoder::new(&mut bufread).unwrap();
310-
let mut tree = decode::rangecoder::BitTree::new(num_bits);
316+
let mut tree = decode::rangecoder::BitTree::<NUM_BITS, PROBS_LEN>::new();
311317
for &v in values {
312318
assert_eq!(tree.parse_reverse(&mut decoder, true).unwrap(), v);
313319
}
@@ -316,25 +322,28 @@ mod test {
316322

317323
#[test]
318324
fn test_encode_decode_reverse_bittree_zeros() {
319-
for num_bits in 0..16 {
320-
encode_decode_reverse_bittree(num_bits, &[0; 10000]);
321-
}
325+
seq!(NUM_BITS in 0..16 {
326+
encode_decode_reverse_bittree::<NUM_BITS, {decode::rangecoder::bittree_probs_len::<NUM_BITS>()}>
327+
(&[0; 10000]);
328+
});
322329
}
323330

324331
#[test]
325332
fn test_encode_decode_reverse_bittree_ones() {
326-
for num_bits in 0..16 {
327-
encode_decode_reverse_bittree(num_bits, &[(1 << num_bits) - 1; 10000]);
328-
}
333+
seq!(NUM_BITS in 0..16 {
334+
encode_decode_reverse_bittree::<NUM_BITS, {decode::rangecoder::bittree_probs_len::<NUM_BITS>()}>
335+
(&[(1 << NUM_BITS) - 1; 10000]);
336+
});
329337
}
330338

331339
#[test]
332340
fn test_encode_decode_reverse_bittree_all() {
333-
for num_bits in 0..16 {
334-
let max = 1 << num_bits;
341+
seq!(NUM_BITS in 0..16 {
342+
let max = 1 << NUM_BITS;
335343
let values: Vec<u32> = (0..max).collect();
336-
encode_decode_reverse_bittree(num_bits, &values);
337-
}
344+
encode_decode_reverse_bittree::<NUM_BITS, {decode::rangecoder::bittree_probs_len::<NUM_BITS>()}>
345+
(&values);
346+
});
338347
}
339348

340349
fn encode_decode_length(pos_state: usize, values: &[u32]) {

0 commit comments

Comments
 (0)