/Users/andrewlamb/Software/arrow-rs/arrow-buffer/src/bigint/div.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! N-digit division |
19 | | //! |
20 | | //! Implementation heavily inspired by [uint] |
21 | | //! |
22 | | //! [uint]: https://github.com/paritytech/parity-common/blob/d3a9327124a66e52ca1114bb8640c02c18c134b8/uint/src/uint.rs#L844 |
23 | | |
24 | | /// Unsigned, little-endian, n-digit division with remainder |
25 | | /// |
26 | | /// # Panics |
27 | | /// |
28 | | /// Panics if divisor is zero |
29 | 0 | pub fn div_rem<const N: usize>(numerator: &[u64; N], divisor: &[u64; N]) -> ([u64; N], [u64; N]) { |
30 | 0 | let numerator_bits = bits(numerator); |
31 | 0 | let divisor_bits = bits(divisor); |
32 | 0 | assert_ne!(divisor_bits, 0, "division by zero"); |
33 | | |
34 | 0 | if numerator_bits < divisor_bits { |
35 | 0 | return ([0; N], *numerator); |
36 | 0 | } |
37 | | |
38 | 0 | if divisor_bits <= 64 { |
39 | 0 | return div_rem_small(numerator, divisor[0]); |
40 | 0 | } |
41 | | |
42 | 0 | let numerator_words = numerator_bits.div_ceil(64); |
43 | 0 | let divisor_words = divisor_bits.div_ceil(64); |
44 | 0 | let n = divisor_words; |
45 | 0 | let m = numerator_words - divisor_words; |
46 | | |
47 | 0 | div_rem_knuth(numerator, divisor, n, m) |
48 | 0 | } |
49 | | |
/// Returns the minimum number of bits required to represent the little-endian
/// number `arr`, i.e. the 1-based position of its highest set bit
///
/// Returns 0 if every word is zero (including when `arr` is empty)
fn bits(arr: &[u64]) -> usize {
    arr.iter()
        // Index of the most significant non-zero word, if any
        .rposition(|&word| word != 0)
        .map_or(0, |idx| 64 * (idx + 1) - arr[idx].leading_zeros() as usize)
}
59 | | |
60 | | /// Division of numerator by a u64 divisor |
61 | 0 | fn div_rem_small<const N: usize>(numerator: &[u64; N], divisor: u64) -> ([u64; N], [u64; N]) { |
62 | 0 | let mut rem = 0u64; |
63 | 0 | let mut numerator = *numerator; |
64 | 0 | numerator.iter_mut().rev().for_each(|d| { |
65 | 0 | let (q, r) = div_rem_word(rem, *d, divisor); |
66 | 0 | *d = q; |
67 | 0 | rem = r; |
68 | 0 | }); |
69 | | |
70 | 0 | let mut rem_padded = [0; N]; |
71 | 0 | rem_padded[0] = rem; |
72 | 0 | (numerator, rem_padded) |
73 | 0 | } |
74 | | |
75 | | /// Use Knuth Algorithm D to compute `numerator / divisor` returning the |
76 | | /// quotient and remainder |
77 | | /// |
78 | | /// `n` is the number of non-zero 64-bit words in `divisor` |
79 | | /// `m` is the number of non-zero 64-bit words present in `numerator` beyond `divisor`, and |
80 | | /// therefore the number of words in the quotient |
81 | | /// |
82 | | /// A good explanation of the algorithm can be found [here](https://ridiculousfish.com/blog/posts/labor-of-division-episode-iv.html) |
83 | 0 | fn div_rem_knuth<const N: usize>( |
84 | 0 | numerator: &[u64; N], |
85 | 0 | divisor: &[u64; N], |
86 | 0 | n: usize, |
87 | 0 | m: usize, |
88 | 0 | ) -> ([u64; N], [u64; N]) { |
89 | 0 | assert!(n + m <= N); |
90 | | |
91 | | // The algorithm works by incrementally generating guesses `q_hat`, for the next digit |
92 | | // of the quotient, starting from the most significant digit. |
93 | | // |
94 | | // This relies on the property that for any `q_hat` where |
95 | | // |
96 | | // (q_hat << (j * 64)) * divisor <= numerator` |
97 | | // |
98 | | // We can set |
99 | | // |
100 | | // q += q_hat << (j * 64) |
101 | | // numerator -= (q_hat << (j * 64)) * divisor |
102 | | // |
103 | | // And then iterate until `numerator < divisor` |
104 | | |
105 | | // We normalize the divisor so that the highest bit in the highest digit of the |
106 | | // divisor is set, this ensures our initial guess of `q_hat` is at most 2 off from |
107 | | // the correct value for q[j] |
108 | 0 | let shift = divisor[n - 1].leading_zeros(); |
109 | | // As the shift is computed based on leading zeros, don't need to perform full_shl |
110 | 0 | let divisor = shl_word(divisor, shift); |
111 | | // numerator may have fewer leading zeros than divisor, so must add another digit |
112 | 0 | let mut numerator = full_shl(numerator, shift); |
113 | | |
114 | | // The two most significant digits of the divisor |
115 | 0 | let b0 = divisor[n - 1]; |
116 | 0 | let b1 = divisor[n - 2]; |
117 | | |
118 | 0 | let mut q = [0; N]; |
119 | | |
120 | 0 | for j in (0..=m).rev() { |
121 | 0 | let a0 = numerator[j + n]; |
122 | 0 | let a1 = numerator[j + n - 1]; |
123 | | |
124 | 0 | let mut q_hat = if a0 < b0 { |
125 | | // The first estimate is [a1, a0] / b0, it may be too large by at most 2 |
126 | 0 | let (mut q_hat, mut r_hat) = div_rem_word(a0, a1, b0); |
127 | | |
128 | | // r_hat = [a1, a0] - q_hat * b0 |
129 | | // |
130 | | // Now we want to compute a more precise estimate [a2,a1,a0] / [b1,b0] |
131 | | // which can only be less or equal to the current q_hat |
132 | | // |
133 | | // q_hat is too large if: |
134 | | // [a2,a1,a0] < q_hat * [b1,b0] |
135 | | // [a2,r_hat] < q_hat * b1 |
136 | 0 | let a2 = numerator[j + n - 2]; |
137 | | loop { |
138 | 0 | let r = u128::from(q_hat) * u128::from(b1); |
139 | 0 | let (lo, hi) = (r as u64, (r >> 64) as u64); |
140 | 0 | if (hi, lo) <= (r_hat, a2) { |
141 | 0 | break; |
142 | 0 | } |
143 | | |
144 | 0 | q_hat -= 1; |
145 | 0 | let (new_r_hat, overflow) = r_hat.overflowing_add(b0); |
146 | 0 | r_hat = new_r_hat; |
147 | | |
148 | 0 | if overflow { |
149 | 0 | break; |
150 | 0 | } |
151 | | } |
152 | 0 | q_hat |
153 | | } else { |
154 | 0 | u64::MAX |
155 | | }; |
156 | | |
157 | | // q_hat is now either the correct quotient digit, or in rare cases 1 too large |
158 | | |
159 | | // Compute numerator -= (q_hat * divisor) << (j * 64) |
160 | 0 | let q_hat_v = full_mul_u64(&divisor, q_hat); |
161 | 0 | let c = sub_assign(&mut numerator[j..], &q_hat_v[..n + 1]); |
162 | | |
163 | | // If underflow, q_hat was too large by 1 |
164 | 0 | if c { |
165 | 0 | // Reduce q_hat by 1 |
166 | 0 | q_hat -= 1; |
167 | 0 |
|
168 | 0 | // Add back one multiple of divisor |
169 | 0 | let c = add_assign(&mut numerator[j..], &divisor[..n]); |
170 | 0 | numerator[j + n] = numerator[j + n].wrapping_add(u64::from(c)); |
171 | 0 | } |
172 | | |
173 | | // q_hat is the correct value for q[j] |
174 | 0 | q[j] = q_hat; |
175 | | } |
176 | | |
177 | | // The remainder is what is left in numerator, with the initial normalization shl reversed |
178 | 0 | let remainder = full_shr(&numerator, shift); |
179 | 0 | (q, remainder) |
180 | 0 | } |
181 | | |
/// Divides the 128-bit value `(hi << 64) | lo` by a 64-bit `divisor`, returning
/// `(quotient, remainder)`
///
/// This method may trap or panic if hi >= divisor, i.e. the quotient would not fit
/// into a 64-bit integer
fn div_rem_word(hi: u64, lo: u64, divisor: u64) -> (u64, u64) {
    debug_assert!(hi < divisor);
    debug_assert_ne!(divisor, 0);

    // LLVM cannot prove that hi < divisor, so it will not emit a bare `div`
    // for the u128 division below; on x86_64 the instruction is issued directly
    #[cfg(all(target_arch = "x86_64", not(miri)))]
    {
        // `div` takes its dividend in rdx:rax and leaves the quotient in rax
        // and the remainder in rdx
        let (mut quotient, mut remainder) = (lo, hi);
        // SAFETY: the instruction only reads and writes the listed registers;
        // it faults if the quotient does not fit in 64 bits, which is the
        // documented precondition `hi < divisor` of this function
        unsafe {
            std::arch::asm!(
                "div {d}",
                d = in(reg) divisor,
                inout("rax") quotient,
                inout("rdx") remainder,
                options(pure, nomem, nostack)
            );
        }
        (quotient, remainder)
    }
    #[cfg(any(not(target_arch = "x86_64"), miri))]
    {
        let dividend = (u128::from(hi) << 64) | u128::from(lo);
        let wide_divisor = u128::from(divisor);
        (
            (dividend / wide_divisor) as u64,
            (dividend % wide_divisor) as u64,
        )
    }
}
212 | | |
213 | | /// Perform `a += b` |
214 | 0 | fn add_assign(a: &mut [u64], b: &[u64]) -> bool { |
215 | 0 | binop_slice(a, b, u64::overflowing_add) |
216 | 0 | } |
217 | | |
218 | | /// Perform `a -= b` |
219 | 0 | fn sub_assign(a: &mut [u64], b: &[u64]) -> bool { |
220 | 0 | binop_slice(a, b, u64::overflowing_sub) |
221 | 0 | } |
222 | | |
/// Lifts an overflowing scalar operation (`overflowing_add` / `overflowing_sub`)
/// to little-endian digit slices, applying it in place to `a`
///
/// Digits are paired up from index 0 and processing stops at the end of the
/// shorter slice. The carry (or borrow) from each digit is folded into the
/// next digit of `b` before `binop` is applied. Returns the carry out of the
/// last processed digit.
fn binop_slice(a: &mut [u64], b: &[u64], binop: impl Fn(u64, u64) -> (u64, bool) + Copy) -> bool {
    a.iter_mut().zip(b).fold(false, |carry, (lhs, &rhs)| {
        // Adding the incoming carry to the right-hand digit may itself wrap
        let (rhs_with_carry, wrapped) = rhs.overflowing_add(u64::from(carry));
        let (result, overflowed) = binop(*lhs, rhs_with_carry);
        *lhs = result;
        wrapped || overflowed
    })
}
234 | | |
235 | | /// Widening multiplication of an N-digit array with a u64 |
236 | 0 | fn full_mul_u64<const N: usize>(a: &[u64; N], b: u64) -> ArrayPlusOne<u64, N> { |
237 | 0 | let mut carry = 0; |
238 | 0 | let mut out = [0; N]; |
239 | 0 | out.iter_mut().zip(a).for_each(|(o, v)| { |
240 | 0 | let r = *v as u128 * b as u128 + carry as u128; |
241 | 0 | *o = r as u64; |
242 | 0 | carry = (r >> 64) as u64; |
243 | 0 | }); |
244 | 0 | ArrayPlusOne(out, carry) |
245 | 0 | } |
246 | | |
247 | | /// Left shift of an N-digit array by at most 63 bits |
248 | 0 | fn shl_word<const N: usize>(v: &[u64; N], shift: u32) -> [u64; N] { |
249 | 0 | full_shl(v, shift).0 |
250 | 0 | } |
251 | | |
252 | | /// Widening left shift of an N-digit array by at most 63 bits |
253 | 0 | fn full_shl<const N: usize>(v: &[u64; N], shift: u32) -> ArrayPlusOne<u64, N> { |
254 | 0 | debug_assert!(shift < 64); |
255 | 0 | if shift == 0 { |
256 | 0 | return ArrayPlusOne(*v, 0); |
257 | 0 | } |
258 | 0 | let mut out = [0u64; N]; |
259 | 0 | out[0] = v[0] << shift; |
260 | 0 | for i in 1..N { |
261 | 0 | out[i] = (v[i - 1] >> (64 - shift)) | (v[i] << shift) |
262 | | } |
263 | 0 | let carry = v[N - 1] >> (64 - shift); |
264 | 0 | ArrayPlusOne(out, carry) |
265 | 0 | } |
266 | | |
267 | | /// Narrowing right shift of an (N+1)-digit array by at most 63 bits |
268 | 0 | fn full_shr<const N: usize>(a: &ArrayPlusOne<u64, N>, shift: u32) -> [u64; N] { |
269 | 0 | debug_assert!(shift < 64); |
270 | 0 | if shift == 0 { |
271 | 0 | return a.0; |
272 | 0 | } |
273 | 0 | let mut out = [0; N]; |
274 | 0 | for i in 0..N - 1 { |
275 | 0 | out[i] = (a[i] >> shift) | (a[i + 1] << (64 - shift)) |
276 | | } |
277 | 0 | out[N - 1] = a[N - 1] >> shift; |
278 | 0 | out |
279 | 0 | } |
280 | | |
/// An array of N + 1 elements, stored as an `[T; N]` followed by one more `T`
///
/// This is a hack around lack of support for const arithmetic. It dereferences
/// to a slice covering all `N + 1` elements.
#[repr(C)]
struct ArrayPlusOne<T, const N: usize>([T; N], T);

impl<T, const N: usize> std::ops::Deref for ArrayPlusOne<T, N> {
    type Target = [T];

    #[inline]
    fn deref(&self) -> &Self::Target {
        let base = (self as *const Self).cast::<T>();
        // SAFETY: `#[repr(C)]` places the `[T; N]` at offset 0 with the trailing
        // `T` directly after it (both have the alignment of `T`, so there is no
        // padding), giving N + 1 contiguous, initialized `T`s borrowed for the
        // lifetime of `&self`
        unsafe { std::slice::from_raw_parts(base, N + 1) }
    }
}

impl<T, const N: usize> std::ops::DerefMut for ArrayPlusOne<T, N> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let base = (self as *mut Self).cast::<T>();
        // SAFETY: same layout argument as in `deref`; `&mut self` guarantees
        // exclusive access to all N + 1 elements for the returned lifetime
        unsafe { std::slice::from_raw_parts_mut(base, N + 1) }
    }
}