/Users/andrewlamb/Software/arrow-rs/arrow-buffer/src/util/bit_mask.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Utils for working with packed bit masks |
19 | | |
20 | | use crate::bit_util::ceil; |
21 | | |
22 | | /// Util function to set bits in a slice of bytes. |
23 | | /// |
24 | | /// This will sets all bits on `write_data` in the range `[offset_write..offset_write+len]` |
25 | | /// to be equal to the bits in `data` in the range `[offset_read..offset_read+len]` |
26 | | /// returns the number of `0` bits `data[offset_read..offset_read+len]` |
27 | | /// `offset_write`, `offset_read`, and `len` are in terms of bits |
28 | 57 | pub fn set_bits( |
29 | 57 | write_data: &mut [u8], |
30 | 57 | data: &[u8], |
31 | 57 | offset_write: usize, |
32 | 57 | offset_read: usize, |
33 | 57 | len: usize, |
34 | 57 | ) -> usize { |
35 | 57 | assert!(offset_write + len <= write_data.len() * 8); |
36 | 57 | assert!(offset_read + len <= data.len() * 8); |
37 | 57 | let mut null_count = 0; |
38 | 57 | let mut acc = 0; |
39 | 114 | while len > acc { |
40 | 57 | // SAFETY: the arguments to `set_upto_64bits` are within the valid range because |
41 | 57 | // (offset_write + acc) + (len - acc) == offset_write + len <= write_data.len() * 8 |
42 | 57 | // (offset_read + acc) + (len - acc) == offset_read + len <= data.len() * 8 |
43 | 57 | let (n, len_set) = unsafe { |
44 | 57 | set_upto_64bits( |
45 | 57 | write_data, |
46 | 57 | data, |
47 | 57 | offset_write + acc, |
48 | 57 | offset_read + acc, |
49 | 57 | len - acc, |
50 | 57 | ) |
51 | 57 | }; |
52 | 57 | null_count += n; |
53 | 57 | acc += len_set; |
54 | 57 | } |
55 | | |
56 | 57 | null_count |
57 | 57 | } |
58 | | |
/// Similar to `set_bits` but sets only up to 64 bits; the actual number of bits
/// set may vary depending on the alignment of the read/write offsets.
///
/// Returns a pair of (number of `0` bits copied, number of bits set), so the
/// caller can advance its offsets by the second value and keep looping.
///
/// Like `set_bits`, this ORs bits into `write_data` and therefore assumes the
/// destination bits are already zero.
///
/// # Safety
/// The caller must ensure all arguments are within the valid range, i.e.
/// `offset_write + len <= write_data.len() * 8` and
/// `offset_read + len <= data.len() * 8`.
#[inline]
unsafe fn set_upto_64bits(
    write_data: &mut [u8],
    data: &[u8],
    offset_write: usize,
    offset_read: usize,
    len: usize,
) -> (usize, usize) {
    // Split each bit offset into a byte index plus a bit shift within that byte.
    let read_byte = offset_read / 8;
    let read_shift = offset_read % 8;
    let write_byte = offset_write / 8;
    let write_shift = offset_write % 8;

    if len >= 64 {
        // Fast path: at least one full 64-bit word remains.
        // SAFETY: offset_read + len <= data.len() * 8 and len >= 64 imply
        // read_byte + 8 <= data.len(), so this unaligned 8-byte load is in bounds.
        let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() };
        if read_shift == 0 {
            if write_shift == 0 {
                // no shifting necessary
                let len = 64;
                let null_count = chunk.count_zeros() as usize;
                unsafe { write_u64_bytes(write_data, write_byte, chunk) };
                (null_count, len)
            } else {
                // only write shifting necessary; the top `write_shift` bits of
                // `chunk` are discarded here and will be copied by the next call
                let len = 64 - write_shift;
                let chunk = chunk << write_shift;
                let null_count = len - chunk.count_ones() as usize;
                unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
                (null_count, len)
            }
        } else if write_shift == 0 {
            // only read shifting necessary
            let len = 64 - 8; // 56 bits so the next set_upto_64bits call will see write_shift == 0
            let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask
            let null_count = len - chunk.count_ones() as usize;
            unsafe { write_u64_bytes(write_data, write_byte, chunk) };
            (null_count, len)
        } else {
            // Both offsets misaligned: realign the source bits to the write
            // shift; only `64 - max(read_shift, write_shift)` bits survive the
            // double shift, and those are exactly the bits counted below.
            let len = 64 - std::cmp::max(read_shift, write_shift);
            let chunk = (chunk >> read_shift) << write_shift;
            let null_count = len - chunk.count_ones() as usize;
            unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
            (null_count, len)
        }
    } else if len == 1 {
        // Single-bit copy: extract the bit and OR it into place.
        let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1;
        unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift };
        // `byte_chunk ^ 1` is 1 exactly when the copied bit was 0 (a null).
        ((byte_chunk ^ 1) as usize, 1)
    } else {
        // Tail path: fewer than 64 bits requested (or remaining). Cap `len` so
        // that both the shifted read and the shifted write fit in one u64.
        let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift));
        let bytes = ceil(len + read_shift, 8);
        // SAFETY: the args of `read_bytes_to_u64` are valid as read_byte + bytes <= data.len()
        let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) };
        let mask = u64::MAX >> (64 - len);
        let chunk = (chunk >> read_shift) & mask; // masking to read `len` bits only
        let chunk = chunk << write_shift; // shifting back to align with `write_data`
        let null_count = len - chunk.count_ones() as usize;
        let bytes = ceil(len + write_shift, 8);
        // OR the affected bytes one at a time; `bytes <= 8` here, and the
        // caller's contract guarantees `write_byte + bytes <= write_data.len()`.
        for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) {
            unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c };
        }
        (null_count, len)
    }
}
128 | | |
/// Reads `count` bytes starting at `data[offset]` into the low-address bytes of
/// a zero-initialized `u64` (native byte order); remaining bytes stay zero.
///
/// # Safety
/// The caller must ensure `offset + count <= data.len()` and `count <= 8`.
#[inline]
unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 {
    debug_assert!(count <= 8);
    // Stage the bytes in a fixed-size buffer, then reinterpret in native order —
    // identical to copying directly into a u64's memory, but easier to read.
    let mut buf = [0u8; 8];
    unsafe {
        std::ptr::copy_nonoverlapping(data.as_ptr().add(offset), buf.as_mut_ptr(), count);
    }
    u64::from_ne_bytes(buf)
}
141 | | |
/// Writes all 8 native-order bytes of `chunk` to `data[offset..offset + 8]`,
/// overwriting whatever was there.
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + 8)` range
#[inline]
unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
    // memcpy of the native-endian representation — equivalent to an unaligned
    // u64 store at `offset`.
    let bytes = chunk.to_ne_bytes();
    unsafe {
        std::ptr::copy_nonoverlapping(bytes.as_ptr(), data.as_mut_ptr().add(offset), 8);
    }
}
149 | | |
/// Similar to `write_u64_bytes`, but first ORs `chunk` with the single byte at
/// `data[offset]` before storing all 8 bytes.
///
/// Only the first byte of the destination is preserved (merged into the
/// numeric low byte of `chunk`); the remaining 7 bytes are overwritten, which
/// is sound because callers guarantee they are still zero.
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + 8)` range
#[inline]
unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
    let dst = data.as_mut_ptr().add(offset);
    // Merge the already-written low bits of the first destination byte.
    let first_byte = unsafe { dst.read() } as u64;
    let merged = chunk | first_byte;
    unsafe { (dst as *mut u64).write_unaligned(merged) };
}
161 | | |
162 | | #[cfg(test)] |
163 | | mod tests { |
164 | | use super::*; |
165 | | use crate::bit_util::{get_bit, set_bit, unset_bit}; |
166 | | use rand::prelude::StdRng; |
167 | | use rand::{Rng, SeedableRng, TryRngCore}; |
168 | | use std::fmt::Display; |
169 | | |
    // Copies 64 bits where both source and destination are byte-aligned
    // (offset_read % 8 == 0, offset_write % 8 == 0): exercises the
    // no-shift fast path of `set_upto_64bits`.
    #[test]
    fn test_set_bits_aligned() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 8,
            offset_read: 0,
            len: 64,
            expected_data: vec![
                0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011,
                0b11100111, 0b10100101, 0,
            ],
            expected_null_count: 24,
        }
        .verify();
    }
189 | | |
    // Copies 64 bits from an aligned source to a destination that starts
    // mid-byte (offset_write = 3): exercises the write-shift-only path.
    #[test]
    fn test_set_bits_unaligned_destination_start() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 3,
            offset_read: 0,
            len: 64,
            expected_data: vec![
                0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, 0b00111110,
                0b00101111, 0b00000101, 0b00000000,
            ],
            expected_null_count: 24,
        }
        .verify();
    }
209 | | |
    // Copies a non-multiple-of-8 length (62 bits) between aligned offsets:
    // the copy ends mid-byte, exercising the sub-64-bit tail path.
    #[test]
    fn test_set_bits_unaligned_destination_end() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101,
            ],
            offset_write: 8,
            offset_read: 0,
            len: 62,
            expected_data: vec![
                0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011,
                0b11100111, 0b00100101, 0,
            ],
            expected_null_count: 23,
        }
        .verify();
    }
229 | | |
    // Copies 95 bits with both offsets unaligned (write 3, read 5): exercises
    // the double-shift path plus a tail, i.e. multiple `set_upto_64bits` calls.
    #[test]
    fn test_set_bits_unaligned() {
        SetBitsTest {
            write_data: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            data: vec![
                0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111,
                0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, 0b11100111, 0b10100101,
                0b10011001, 0b11011011, 0b11101011, 0b11000011,
            ],
            offset_write: 3,
            offset_read: 5,
            len: 95,
            expected_data: vec![
                0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b01111001,
                0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, 0b00000001,
            ],
            expected_null_count: 35,
        }
        .verify();
    }
250 | | |
    // Property test: 100 randomized offset/length combinations, each checked
    // against a bit-by-bit reference implementation built in `regen`.
    // Seeded rng keeps the test deterministic across runs.
    #[test]
    fn set_bits_fuzz() {
        let mut rng = StdRng::seed_from_u64(42);
        let mut data = SetBitsTest::new();
        for _ in 0..100 {
            data.regen(&mut rng);
            data.verify();
        }
    }
260 | | |
    /// Input/expected-output fixture for a single `set_bits` invocation.
    #[derive(Debug, Default)]
    struct SetBitsTest {
        /// target write data
        write_data: Vec<u8>,
        /// source data
        data: Vec<u8>,
        /// bit offset into `write_data` where copying begins
        offset_write: usize,
        /// bit offset into `data` where reading begins
        offset_read: usize,
        /// number of bits to copy
        len: usize,
        /// the expected contents of write_data after the test
        expected_data: Vec<u8>,
        /// the expected number of nulls copied at the end of the test
        expected_null_count: usize,
    }
275 | | |
276 | | /// prints a byte slice as a binary string like "01010101 10101010" |
277 | | struct BinaryFormatter<'a>(&'a [u8]); |
278 | | impl Display for BinaryFormatter<'_> { |
279 | | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
280 | | for byte in self.0 { |
281 | | write!(f, "{byte:08b} ")?; |
282 | | } |
283 | | write!(f, " ")?; |
284 | | Ok(()) |
285 | | } |
286 | | } |
287 | | |
    // Human-readable dump of a fixture, printed by `verify` on assertion
    // failure so a failing fuzz case can be reproduced as a unit test.
    impl Display for SetBitsTest {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            writeln!(f, "SetBitsTest {{")?;
            writeln!(f, "  write_data: {}", BinaryFormatter(&self.write_data))?;
            writeln!(f, "  data: {}", BinaryFormatter(&self.data))?;
            writeln!(
                f,
                "  expected_data: {}",
                BinaryFormatter(&self.expected_data)
            )?;
            writeln!(f, "  offset_write: {}", self.offset_write)?;
            writeln!(f, "  offset_read: {}", self.offset_read)?;
            writeln!(f, "  len: {}", self.len)?;
            writeln!(f, "  expected_null_count: {}", self.expected_null_count)?;
            writeln!(f, "}}")
        }
    }
305 | | |
    impl SetBitsTest {
        /// create a new (all-empty) instance of the fixture
        fn new() -> Self {
            Self::default()
        }

        /// Update this instance's fields with randomly selected values and expected data
        ///
        /// NOTE: the fixture content depends on the exact sequence of rng calls
        /// below — do not reorder them or seeded runs will change.
        fn regen(&mut self, rng: &mut StdRng) {
            // (read) data
            // ------------------+-----------------+-------
            // .. offset_read .. | data | ...
            // ------------------+-----------------+-------

            // Write data
            // -------------------+-----------------+-------
            // .. offset_write .. | (data to write) | ...
            // -------------------+-----------------+-------

            // length of data to copy (in bits)
            let len = rng.random_range(0..=200);

            // randomly pick where we will write to
            let offset_write_bits = rng.random_range(0..=200);
            // ceil(offset_write_bits / 8): bytes fully consumed by the offset
            let offset_write_bytes = if offset_write_bits % 8 == 0 {
                offset_write_bits / 8
            } else {
                (offset_write_bits / 8) + 1
            };
            let extra_write_data_bytes = rng.random_range(0..=5); // ensure 0 shows up often

            // randomly decide where we will read from
            let extra_read_data_bytes = rng.random_range(0..=5); // make sure 0 shows up often
            let offset_read_bits = rng.random_range(0..=200);
            let offset_read_bytes = if offset_read_bits % 8 != 0 {
                (offset_read_bits / 8) + 1
            } else {
                offset_read_bits / 8
            };

            // create space for writing
            // (generous: `len` is a bit count used here as a byte count, so the
            // buffer is always at least large enough for the copied range)
            self.write_data.clear();
            self.write_data
                .resize(offset_write_bytes + len + extra_write_data_bytes, 0);

            // interestingly set_bits seems to assume the output is already zeroed
            // the fuzz tests fail when this is uncommented
            //self.write_data.try_fill(rng).unwrap();
            self.offset_write = offset_write_bits;

            // make source data (same generous byte sizing as above)
            self.data
                .resize(offset_read_bytes + len + extra_read_data_bytes, 0);
            // fill source data with random bytes
            rng.try_fill_bytes(self.data.as_mut_slice()).unwrap();
            self.offset_read = offset_read_bits;

            self.len = len;

            // generate expected output bit-by-bit with the simple
            // get_bit/set_bit reference helpers (not efficient)
            self.expected_data.resize(self.write_data.len(), 0);
            self.expected_data.copy_from_slice(&self.write_data);

            self.expected_null_count = 0;
            for i in 0..self.len {
                let bit = get_bit(&self.data, self.offset_read + i);
                if bit {
                    set_bit(&mut self.expected_data, self.offset_write + i);
                } else {
                    unset_bit(&mut self.expected_data, self.offset_write + i);
                    self.expected_null_count += 1;
                }
            }
        }

        /// call set_bits with the given parameters and compare with the expected output
        fn verify(&self) {
            // call set_bits and compare
            let mut actual = self.write_data.to_vec();
            let null_count = set_bits(
                &mut actual,
                &self.data,
                self.offset_write,
                self.offset_read,
                self.len,
            );

            assert_eq!(actual, self.expected_data, "self: {self}");
            assert_eq!(null_count, self.expected_null_count, "self: {self}");
        }
    }
396 | | |
    // Directly exercises the internal helper: one len >= 64 case (which may
    // set fewer bits than requested) and the single-bit (len == 1) case.
    #[test]
    fn test_set_upto_64bits() {
        // len >= 64: with write_shift == 1, only 63 bits are set per call,
        // so len_set is 63 even though 65 were requested.
        let write_data: &mut [u8] = &mut [0; 9];
        let data: &[u8] = &[
            0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001,
            0b00000001, 0b00000001,
        ];
        let offset_write = 1;
        let offset_read = 0;
        let len = 65;
        let (n, len_set) =
            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
        assert_eq!(n, 55);
        assert_eq!(len_set, 63);
        assert_eq!(
            write_data,
            &[
                0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010,
                0b00000010, 0b00000000
            ]
        );

        // len = 1: a single set bit is copied and shifted into position
        let write_data: &mut [u8] = &mut [0b00000000];
        let data: &[u8] = &[0b00000001];
        let offset_write = 1;
        let offset_read = 0;
        let len = 1;
        let (n, len_set) =
            unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
        assert_eq!(n, 0);
        assert_eq!(len_set, 1);
        assert_eq!(write_data, &[0b00000010]);
    }
432 | | } |