/Users/andrewlamb/Software/arrow-rs/arrow-buffer/src/util/bit_mask.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Utils for working with packed bit masks |
19 | | |
20 | | use crate::bit_util::ceil; |
21 | | |
22 | | /// Util function to set bits in a slice of bytes. |
23 | | /// |
24 | | /// This will sets all bits on `write_data` in the range `[offset_write..offset_write+len]` |
25 | | /// to be equal to the bits in `data` in the range `[offset_read..offset_read+len]` |
26 | | /// returns the number of `0` bits `data[offset_read..offset_read+len]` |
27 | | /// `offset_write`, `offset_read`, and `len` are in terms of bits |
28 | 73 | pub fn set_bits( |
29 | 73 | write_data: &mut [u8], |
30 | 73 | data: &[u8], |
31 | 73 | offset_write: usize, |
32 | 73 | offset_read: usize, |
33 | 73 | len: usize, |
34 | 73 | ) -> usize { |
35 | 73 | assert!(offset_write + len <= write_data.len() * 8); |
36 | 73 | assert!(offset_read + len <= data.len() * 8); |
37 | 73 | let mut null_count = 0; |
38 | 73 | let mut acc = 0; |
39 | 145 | while len > acc { |
40 | 72 | // SAFETY: the arguments to `set_upto_64bits` are within the valid range because |
41 | 72 | // (offset_write + acc) + (len - acc) == offset_write + len <= write_data.len() * 8 |
42 | 72 | // (offset_read + acc) + (len - acc) == offset_read + len <= data.len() * 8 |
43 | 72 | let (n, len_set) = unsafe { |
44 | 72 | set_upto_64bits( |
45 | 72 | write_data, |
46 | 72 | data, |
47 | 72 | offset_write + acc, |
48 | 72 | offset_read + acc, |
49 | 72 | len - acc, |
50 | 72 | ) |
51 | 72 | }; |
52 | 72 | null_count += n; |
53 | 72 | acc += len_set; |
54 | 72 | } |
55 | | |
56 | 73 | null_count |
57 | 73 | } |
58 | | |
/// Similar to `set_bits` but sets only up to 64 bits per call; the actual number of bits
/// set may be smaller than `len`.
///
/// Copies `n` bits from `data` (starting at bit `offset_read`) into `write_data`
/// (starting at bit `offset_write`) and returns `(number of 0 bits copied, n)`, where
/// `n <= min(len, 64)` depends on the bit shifts `offset_read % 8` and `offset_write % 8`.
/// `set_bits` calls this repeatedly, advancing both offsets by `n`, until `len` bits are set.
///
/// Source bits are combined with the destination via `|=` / `or_write_u64_bytes`, so the
/// destination bits being set are expected to be zero beforehand. When `len >= 64`,
/// whole 8-byte words starting at byte `offset_write / 8` are stored; bits of that word
/// past the `n` set here may be overwritten with zeros. That is only correct because,
/// with `len >= 64`, those bits still lie inside `[offset_write, offset_write + len)` and
/// are filled in by a later call.
///
/// NOTE(review): the `len >= 64` path loads/stores `u64`s in native byte order while the
/// short path writes `to_le_bytes()`; the shift arithmetic appears to assume a
/// little-endian target — TODO confirm big-endian behaviour.
///
/// # Safety
/// The caller must ensure all arguments are within the valid range:
/// `offset_write + len <= write_data.len() * 8`, `offset_read + len <= data.len() * 8`,
/// and `len > 0` (for `len == 0` the mask below would shift a `u64` by 64 bits).
#[inline]
unsafe fn set_upto_64bits(
    write_data: &mut [u8],
    data: &[u8],
    offset_write: usize,
    offset_read: usize,
    len: usize,
) -> (usize, usize) {
    let read_byte = offset_read / 8;
    let read_shift = offset_read % 8;
    let write_byte = offset_write / 8;
    let write_shift = offset_write % 8;

    if len >= 64 {
        // SAFETY: len >= 64 and offset_read + len <= data.len() * 8 imply
        // read_byte + 8 <= data.len(), so the unaligned 8-byte load is in bounds.
        let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() };
        if read_shift == 0 {
            if write_shift == 0 {
                // no shifting necessary: copy the whole 64-bit word as-is
                let len = 64;
                let null_count = chunk.count_zeros() as usize;
                // SAFETY: write_byte + 8 <= write_data.len() by the same argument as the read.
                unsafe { write_u64_bytes(write_data, write_byte, chunk) };
                (null_count, len)
            } else {
                // only write shifting necessary: the top `write_shift` source bits are
                // shifted out and left for the next call
                let len = 64 - write_shift;
                let chunk = chunk << write_shift;
                let null_count = len - chunk.count_ones() as usize;
                // SAFETY: write_byte + 8 <= write_data.len(); the OR keeps the bits below
                // `write_shift` in the first destination byte.
                unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
                (null_count, len)
            }
        } else if write_shift == 0 {
            // only read shifting necessary
            let len = 64 - 8; // 56 bits so the next set_upto_64bits call will see write_shift == 0
            let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask
            let null_count = len - chunk.count_ones() as usize;
            // SAFETY: write_byte + 8 <= write_data.len(); the 8th byte is written as zero
            // and is filled by the next call (it is still inside the destination range).
            unsafe { write_u64_bytes(write_data, write_byte, chunk) };
            (null_count, len)
        } else {
            // both offsets unaligned: take as many bits as fit after both shifts
            let len = 64 - std::cmp::max(read_shift, write_shift);
            let chunk = (chunk >> read_shift) << write_shift;
            let null_count = len - chunk.count_ones() as usize;
            // SAFETY: write_byte + 8 <= write_data.len() (see above).
            unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
            (null_count, len)
        }
    } else if len == 1 {
        // single bit: read it, OR it into place, and report 1 null if it was 0
        // SAFETY: offset_read < data.len() * 8 and offset_write < write_data.len() * 8,
        // so both byte indices are in bounds.
        let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1;
        unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift };
        ((byte_chunk ^ 1) as usize, 1)
    } else {
        // 1 < len < 64: cap `len` so the shifted bits still fit in one u64
        let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift));
        let bytes = ceil(len + read_shift, 8);
        // SAFETY: `bytes <= 8` because len + read_shift <= 64, and
        // read_byte + bytes <= data.len() follows from offset_read + len <= data.len() * 8.
        let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) };
        let mask = u64::MAX >> (64 - len);
        let chunk = (chunk >> read_shift) & mask; // masking to read `len` bits only
        let chunk = chunk << write_shift; // shifting back to align with `write_data`
        let null_count = len - chunk.count_ones() as usize;
        // only the bytes that actually hold copied bits are touched, and they are OR-ed,
        // so neighbouring destination bits are preserved
        let bytes = ceil(len + write_shift, 8);
        for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) {
            // SAFETY: write_byte + bytes <= write_data.len() follows from
            // offset_write + len <= write_data.len() * 8.
            unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c };
        }
        (null_count, len)
    }
}
128 | | |
/// Reads `count` bytes of `data` starting at byte `offset` as a little-endian `u64`.
///
/// Byte `data[offset]` becomes the least significant byte of the result. Bytes beyond
/// `count` are treated as zero, so e.g. `count == 2` yields a value `< 1 << 16`.
/// Decoding explicitly as little-endian (rather than copying into the raw bytes of a
/// `u64`) keeps the result independent of the target's byte order, matching the
/// `to_le_bytes()` used when the bits are written back.
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + count)` range, and `count <= 8`.
#[inline]
unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 {
    debug_assert!(count <= 8);
    let mut buf = [0u8; 8];
    // SAFETY: the caller guarantees `offset + count <= data.len()`, so the source range
    // is in bounds, and `count <= 8`, so it fits in `buf`; the regions cannot overlap
    // because `buf` is a fresh local.
    unsafe { std::ptr::copy_nonoverlapping(data.as_ptr().add(offset), buf.as_mut_ptr(), count) };
    u64::from_le_bytes(buf)
}
139 | | |
/// Stores `chunk` in native byte order into `data[offset..offset + 8]`, overwriting
/// those 8 bytes. No alignment is required.
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + 8)` range
#[inline]
unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
    // SAFETY: the caller guarantees `offset + 8 <= data.len()`, so the 8-byte
    // destination is in bounds; `write_unaligned` imposes no alignment requirement.
    unsafe {
        data.as_mut_ptr()
            .add(offset)
            .cast::<u64>()
            .write_unaligned(chunk)
    };
}
147 | | |
/// Like `write_u64_bytes`, but merges the existing byte at `data[offset]` into the low
/// 8 bits of `chunk` before storing it. The remaining 7 bytes are overwritten, not OR-ed.
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + 8)` range
#[inline]
unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
    // SAFETY: the caller guarantees `offset + 8 <= data.len()`, so reading the first
    // byte and storing an unaligned u64 at `offset` both stay in bounds.
    unsafe {
        let dst = data.as_mut_ptr().add(offset);
        let merged = chunk | u64::from(dst.read());
        dst.cast::<u64>().write_unaligned(merged);
    }
}
159 | | |
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bit_util::{get_bit, set_bit, unset_bit};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng, TryRngCore};
    use std::fmt::Display;

    #[test]
    fn test_set_bits_aligned() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 8,
            offset_read: 0,
            len: 64,
            expected_data: vec![
                0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011,
                0b11100111, 0b10100101, 0,
            ],
            expected_null_count: 24,
        }
        .verify();
    }

    #[test]
    fn test_set_bits_unaligned_destination_start() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 3,
            offset_read: 0,
            len: 64,
            expected_data: vec![
                0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, 0b00111110,
                0b00101111, 0b00000101, 0b00000000,
            ],
            expected_null_count: 24,
        }
        .verify();
    }

    #[test]
    fn test_set_bits_unaligned_destination_end() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 8,
            offset_read: 0,
            len: 62,
            expected_data: vec![
                0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011,
                0b11100111, 0b00100101, 0,
            ],
            expected_null_count: 23,
        }
        .verify();
    }

    #[test]
    fn test_set_bits_unaligned() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, 0b10100101,
                0b10011001, 0b11011011, 0b11101011, 0b11000011,
            ],
            offset_write: 3,
            offset_read: 5,
            len: 95,
            expected_data: vec![
                0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b01111001,
                0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b00000001,
            ],
            expected_null_count: 35,
        }
        .verify();
    }

    #[test]
    fn test_set_bits_preserves_surrounding_bits() {
        // bits 1..5 of the first byte are the (zeroed) destination; every other bit is set
        // or cleared in a pattern that must survive the copy unchanged
        SetBitsTest {
            write_data: vec![0b1110_0001, 0b1111_1111],
            data: vec![0b0000_1011],
            offset_write: 1,
            offset_read: 0,
            len: 4,
            expected_data: vec![0b1111_0111, 0b1111_1111],
            expected_null_count: 1,
        }
        .verify();
    }

    #[test]
    fn set_bits_fuzz() {
        let mut rng = StdRng::seed_from_u64(42);
        let mut data = SetBitsTest::new();
        for _ in 0..100 {
            data.regen(&mut rng);
            data.verify();
        }
    }

    #[derive(Debug, Default)]
    struct SetBitsTest {
        /// target write data
        write_data: Vec<u8>,
        /// source data
        data: Vec<u8>,
        offset_write: usize,
        offset_read: usize,
        len: usize,
        /// the expected contents of write_data after the test
        expected_data: Vec<u8>,
        /// the expected number of nulls copied at the end of the test
        expected_null_count: usize,
    }

    /// prints a byte slice as a binary string like "01010101 10101010"
    struct BinaryFormatter<'a>(&'a [u8]);
    impl Display for BinaryFormatter<'_> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            for byte in self.0 {
                write!(f, "{byte:08b} ")?;
            }
            write!(f, " ")?;
            Ok(())
        }
    }

    impl Display for SetBitsTest {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            writeln!(f, "SetBitsTest {{")?;
            writeln!(f, " write_data: {}", BinaryFormatter(&self.write_data))?;
            writeln!(f, " data: {}", BinaryFormatter(&self.data))?;
            writeln!(
                f,
                " expected_data: {}",
                BinaryFormatter(&self.expected_data)
            )?;
            writeln!(f, " offset_write: {}", self.offset_write)?;
            writeln!(f, " offset_read: {}", self.offset_read)?;
            writeln!(f, " len: {}", self.len)?;
            writeln!(f, " expected_null_count: {}", self.expected_null_count)?;
            writeln!(f, "}}")
        }
    }

    impl SetBitsTest {
        /// create a new, empty `SetBitsTest`; its fields are filled in by `regen`
        fn new() -> Self {
            Self::default()
        }

        /// Update this instance's fields with randomly selected values and expected data
        fn regen(&mut self, rng: &mut StdRng) {
            // (read) data
            // ------------------+-----------------+-------
            // .. offset_read .. | data | ...
            // ------------------+-----------------+-------

            // Write data
            // -------------------+-----------------+-------
            // .. offset_write .. | (data to write) | ...
            // -------------------+-----------------+-------

            // length of data to copy
            let len = rng.random_range(0..=200);

            // randomly pick where we will write to
            let offset_write_bits = rng.random_range(0..=200);
            let offset_write_bytes = if offset_write_bits % 8 == 0 {
                offset_write_bits / 8
            } else {
                (offset_write_bits / 8) + 1
            };
            let extra_write_data_bytes = rng.random_range(0..=5); // ensure 0 shows up often

            // randomly decide where we will read from
            let extra_read_data_bytes = rng.random_range(0..=5); // make sure 0 shows up often
            let offset_read_bits = rng.random_range(0..=200);
            let offset_read_bytes = if offset_read_bits % 8 != 0 {
                (offset_read_bits / 8) + 1
            } else {
                offset_read_bits / 8
            };

            // create space for writing, filled with random bytes so the test also checks
            // that bits outside the destination range are left untouched
            self.write_data.clear();
            self.write_data
                .resize(offset_write_bytes + len + extra_write_data_bytes, 0);
            rng.try_fill_bytes(self.write_data.as_mut_slice()).unwrap();
            self.offset_write = offset_write_bits;

            // `set_bits` ORs source bits into the destination, so it requires the
            // destination range itself to be zeroed: clear exactly that range
            for i in 0..len {
                unset_bit(&mut self.write_data, self.offset_write + i);
            }

            // make source data
            self.data
                .resize(offset_read_bytes + len + extra_read_data_bytes, 0);
            // fill source data with random bytes
            rng.try_fill_bytes(self.data.as_mut_slice()).unwrap();
            self.offset_read = offset_read_bits;

            self.len = len;

            // generate expected output (not efficient)
            self.expected_data.resize(self.write_data.len(), 0);
            self.expected_data.copy_from_slice(&self.write_data);

            self.expected_null_count = 0;
            for i in 0..self.len {
                let bit = get_bit(&self.data, self.offset_read + i);
                if bit {
                    set_bit(&mut self.expected_data, self.offset_write + i);
                } else {
                    unset_bit(&mut self.expected_data, self.offset_write + i);
                    self.expected_null_count += 1;
                }
            }
        }

        /// call set_bits with the given parameters and compare with the expected output
        fn verify(&self) {
            // call set_bits and compare
            let mut actual = self.write_data.clone();
            let null_count = set_bits(
                &mut actual,
                &self.data,
                self.offset_write,
                self.offset_read,
                self.len,
            );

            assert_eq!(actual, self.expected_data, "self: {self}");
            assert_eq!(null_count, self.expected_null_count, "self: {self}");
        }
    }

    #[test]
    fn test_set_upto_64bits() {
        // len >= 64
        let write_data: &mut [u8] = &mut [0; 9];
        let data: &[u8] = &[
            0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001,
            0b00000001, 0b00000001,
        ];
        let offset_write = 1;
        let offset_read = 0;
        let len = 65;
        let (n, len_set) =
            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
        assert_eq!(n, 55);
        assert_eq!(len_set, 63);
        assert_eq!(
            write_data,
            &[
                0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010,
                0b00000010, 0b00000000
            ]
        );

        // len = 1
        let write_data: &mut [u8] = &mut [0b00000000];
        let data: &[u8] = &[0b00000001];
        let offset_write = 1;
        let offset_read = 0;
        let len = 1;
        let (n, len_set) =
            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
        assert_eq!(n, 0);
        assert_eq!(len_set, 1);
        assert_eq!(write_data, &[0b00000010]);
    }
}