Coverage Report

Created: 2025-08-26 07:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-buffer/src/bigint/div.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! N-digit division
19
//!
20
//! Implementation heavily inspired by [uint]
21
//!
22
//! [uint]: https://github.com/paritytech/parity-common/blob/d3a9327124a66e52ca1114bb8640c02c18c134b8/uint/src/uint.rs#L844
23
24
/// Unsigned, little-endian, n-digit division with remainder
25
///
26
/// # Panics
27
///
28
/// Panics if divisor is zero
29
0
pub fn div_rem<const N: usize>(numerator: &[u64; N], divisor: &[u64; N]) -> ([u64; N], [u64; N]) {
30
0
    let numerator_bits = bits(numerator);
31
0
    let divisor_bits = bits(divisor);
32
0
    assert_ne!(divisor_bits, 0, "division by zero");
33
34
0
    if numerator_bits < divisor_bits {
35
0
        return ([0; N], *numerator);
36
0
    }
37
38
0
    if divisor_bits <= 64 {
39
0
        return div_rem_small(numerator, divisor[0]);
40
0
    }
41
42
0
    let numerator_words = numerator_bits.div_ceil(64);
43
0
    let divisor_words = divisor_bits.div_ceil(64);
44
0
    let n = divisor_words;
45
0
    let m = numerator_words - divisor_words;
46
47
0
    div_rem_knuth(numerator, divisor, n, m)
48
0
}
49
50
/// Return the least number of bits needed to represent the number
///
/// `arr` is read as little-endian 64-bit digits. An all-zero or empty slice
/// needs `0` bits.
fn bits(arr: &[u64]) -> usize {
    // Locate the most significant non-zero digit; every digit below it
    // contributes a full 64 bits
    arr.iter()
        .rposition(|&digit| digit != 0)
        .map_or(0, |idx| 64 - arr[idx].leading_zeros() as usize + 64 * idx)
}
59
60
/// Division of numerator by a u64 divisor
61
0
fn div_rem_small<const N: usize>(numerator: &[u64; N], divisor: u64) -> ([u64; N], [u64; N]) {
62
0
    let mut rem = 0u64;
63
0
    let mut numerator = *numerator;
64
0
    numerator.iter_mut().rev().for_each(|d| {
65
0
        let (q, r) = div_rem_word(rem, *d, divisor);
66
0
        *d = q;
67
0
        rem = r;
68
0
    });
69
70
0
    let mut rem_padded = [0; N];
71
0
    rem_padded[0] = rem;
72
0
    (numerator, rem_padded)
73
0
}
74
75
/// Use Knuth Algorithm D to compute `numerator / divisor` returning the
76
/// quotient and remainder
77
///
78
/// `n` is the number of non-zero 64-bit words in `divisor`
79
/// `m` is the number of non-zero 64-bit words present in `numerator` beyond `divisor`, and
80
/// therefore the number of words in the quotient
81
///
82
/// A good explanation of the algorithm can be found [here](https://ridiculousfish.com/blog/posts/labor-of-division-episode-iv.html)
83
0
fn div_rem_knuth<const N: usize>(
84
0
    numerator: &[u64; N],
85
0
    divisor: &[u64; N],
86
0
    n: usize,
87
0
    m: usize,
88
0
) -> ([u64; N], [u64; N]) {
89
0
    assert!(n + m <= N);
90
91
    // The algorithm works by incrementally generating guesses `q_hat`, for the next digit
92
    // of the quotient, starting from the most significant digit.
93
    //
94
    // This relies on the property that for any `q_hat` where
95
    //
96
    //      (q_hat << (j * 64)) * divisor <= numerator`
97
    //
98
    // We can set
99
    //
100
    //      q += q_hat << (j * 64)
101
    //      numerator -= (q_hat << (j * 64)) * divisor
102
    //
103
    // And then iterate until `numerator < divisor`
104
105
    // We normalize the divisor so that the highest bit in the highest digit of the
106
    // divisor is set, this ensures our initial guess of `q_hat` is at most 2 off from
107
    // the correct value for q[j]
108
0
    let shift = divisor[n - 1].leading_zeros();
109
    // As the shift is computed based on leading zeros, don't need to perform full_shl
110
0
    let divisor = shl_word(divisor, shift);
111
    // numerator may have fewer leading zeros than divisor, so must add another digit
112
0
    let mut numerator = full_shl(numerator, shift);
113
114
    // The two most significant digits of the divisor
115
0
    let b0 = divisor[n - 1];
116
0
    let b1 = divisor[n - 2];
117
118
0
    let mut q = [0; N];
119
120
0
    for j in (0..=m).rev() {
121
0
        let a0 = numerator[j + n];
122
0
        let a1 = numerator[j + n - 1];
123
124
0
        let mut q_hat = if a0 < b0 {
125
            // The first estimate is [a1, a0] / b0, it may be too large by at most 2
126
0
            let (mut q_hat, mut r_hat) = div_rem_word(a0, a1, b0);
127
128
            // r_hat = [a1, a0] - q_hat * b0
129
            //
130
            // Now we want to compute a more precise estimate [a2,a1,a0] / [b1,b0]
131
            // which can only be less or equal to the current q_hat
132
            //
133
            // q_hat is too large if:
134
            // [a2,a1,a0] < q_hat * [b1,b0]
135
            // [a2,r_hat] < q_hat * b1
136
0
            let a2 = numerator[j + n - 2];
137
            loop {
138
0
                let r = u128::from(q_hat) * u128::from(b1);
139
0
                let (lo, hi) = (r as u64, (r >> 64) as u64);
140
0
                if (hi, lo) <= (r_hat, a2) {
141
0
                    break;
142
0
                }
143
144
0
                q_hat -= 1;
145
0
                let (new_r_hat, overflow) = r_hat.overflowing_add(b0);
146
0
                r_hat = new_r_hat;
147
148
0
                if overflow {
149
0
                    break;
150
0
                }
151
            }
152
0
            q_hat
153
        } else {
154
0
            u64::MAX
155
        };
156
157
        // q_hat is now either the correct quotient digit, or in rare cases 1 too large
158
159
        // Compute numerator -= (q_hat * divisor) << (j * 64)
160
0
        let q_hat_v = full_mul_u64(&divisor, q_hat);
161
0
        let c = sub_assign(&mut numerator[j..], &q_hat_v[..n + 1]);
162
163
        // If underflow, q_hat was too large by 1
164
0
        if c {
165
0
            // Reduce q_hat by 1
166
0
            q_hat -= 1;
167
0
168
0
            // Add back one multiple of divisor
169
0
            let c = add_assign(&mut numerator[j..], &divisor[..n]);
170
0
            numerator[j + n] = numerator[j + n].wrapping_add(u64::from(c));
171
0
        }
172
173
        // q_hat is the correct value for q[j]
174
0
        q[j] = q_hat;
175
    }
176
177
    // The remainder is what is left in numerator, with the initial normalization shl reversed
178
0
    let remainder = full_shr(&numerator, shift);
179
0
    (q, remainder)
180
0
}
181
182
/// Perform narrowing division of a u128 by a u64 divisor, returning the quotient and remainder
///
/// The dividend is `(hi << 64) | lo`. Returns `(quotient, remainder)`.
///
/// This method may trap or panic if hi >= divisor, i.e. the quotient would not fit
/// into a 64-bit integer
fn div_rem_word(hi: u64, lo: u64, divisor: u64) -> (u64, u64) {
    debug_assert!(hi < divisor);
    debug_assert_ne!(divisor, 0);

    // LLVM fails to use the div instruction as it is not able to prove
    // that hi < divisor, and therefore the result will fit into 64-bits
    //
    // SAFETY: `div` divides the 128-bit value in rdx:rax by the operand and
    // writes the quotient to rax and the remainder to rdx. Both registers are
    // declared `inout` and the instruction touches neither memory nor the
    // stack, so the asm options are accurate. Callers must uphold
    // `hi < divisor` (checked above in debug builds), which keeps the quotient
    // within 64 bits so the instruction does not fault.
    #[cfg(all(target_arch = "x86_64", not(miri)))]
    unsafe {
        let mut quot = lo;
        let mut rem = hi;
        std::arch::asm!(
            "div {divisor}",
            divisor = in(reg) divisor,
            inout("rax") quot,
            inout("rdx") rem,
            options(pure, nomem, nostack)
        );
        (quot, rem)
    }
    // Portable fallback (and the path used under miri): widen to u128
    #[cfg(any(not(target_arch = "x86_64"), miri))]
    {
        let x = (u128::from(hi) << 64) + u128::from(lo);
        let y = u128::from(divisor);
        ((x / y) as u64, (x % y) as u64)
    }
}
212
213
/// Perform `a += b`
214
0
fn add_assign(a: &mut [u64], b: &[u64]) -> bool {
215
0
    binop_slice(a, b, u64::overflowing_add)
216
0
}
217
218
/// Perform `a -= b`
219
0
fn sub_assign(a: &mut [u64], b: &[u64]) -> bool {
220
0
    binop_slice(a, b, u64::overflowing_sub)
221
0
}
222
223
/// Converts an overflowing binary operation on scalars to one on slices
///
/// Applies `binop` digit by digit to `a` and `b`, least significant digit
/// first, storing each result back into `a`. Only as many digits as the
/// shorter slice holds are processed. Returns the carry (or borrow) flag left
/// after the last digit processed.
fn binop_slice(a: &mut [u64], b: &[u64], binop: impl Fn(u64, u64) -> (u64, bool) + Copy) -> bool {
    let mut carry = false;
    for (x, y) in a.iter_mut().zip(b) {
        // Fold the incoming carry into the right-hand digit first; this can
        // itself wrap when the digit is u64::MAX
        let (rhs, carry_wrapped) = y.overflowing_add(u64::from(carry));
        let (res, op_overflowed) = binop(*x, rhs);
        *x = res;
        carry = carry_wrapped || op_overflowed;
    }
    carry
}
234
235
/// Widening multiplication of an N-digit array with a u64
236
0
fn full_mul_u64<const N: usize>(a: &[u64; N], b: u64) -> ArrayPlusOne<u64, N> {
237
0
    let mut carry = 0;
238
0
    let mut out = [0; N];
239
0
    out.iter_mut().zip(a).for_each(|(o, v)| {
240
0
        let r = *v as u128 * b as u128 + carry as u128;
241
0
        *o = r as u64;
242
0
        carry = (r >> 64) as u64;
243
0
    });
244
0
    ArrayPlusOne(out, carry)
245
0
}
246
247
/// Left shift of an N-digit array by at most 63 bits
248
0
fn shl_word<const N: usize>(v: &[u64; N], shift: u32) -> [u64; N] {
249
0
    full_shl(v, shift).0
250
0
}
251
252
/// Widening left shift of an N-digit array by at most 63 bits
253
0
fn full_shl<const N: usize>(v: &[u64; N], shift: u32) -> ArrayPlusOne<u64, N> {
254
0
    debug_assert!(shift < 64);
255
0
    if shift == 0 {
256
0
        return ArrayPlusOne(*v, 0);
257
0
    }
258
0
    let mut out = [0u64; N];
259
0
    out[0] = v[0] << shift;
260
0
    for i in 1..N {
261
0
        out[i] = (v[i - 1] >> (64 - shift)) | (v[i] << shift)
262
    }
263
0
    let carry = v[N - 1] >> (64 - shift);
264
0
    ArrayPlusOne(out, carry)
265
0
}
266
267
/// Narrowing right shift of an (N+1)-digit array by at most 63 bits
268
0
fn full_shr<const N: usize>(a: &ArrayPlusOne<u64, N>, shift: u32) -> [u64; N] {
269
0
    debug_assert!(shift < 64);
270
0
    if shift == 0 {
271
0
        return a.0;
272
0
    }
273
0
    let mut out = [0; N];
274
0
    for i in 0..N - 1 {
275
0
        out[i] = (a[i] >> shift) | (a[i + 1] << (64 - shift))
276
    }
277
0
    out[N - 1] = a[N - 1] >> shift;
278
0
    out
279
0
}
280
281
/// An array of N + 1 elements
///
/// This is a hack around lack of support for const arithmetic
///
/// The first field holds the low `N` elements and the second field holds the
/// final element. It dereferences to a single slice of `N + 1` elements.
#[repr(C)]
struct ArrayPlusOne<T, const N: usize>([T; N], T);

impl<T, const N: usize> std::ops::Deref for ArrayPlusOne<T, N> {
    type Target = [T];

    #[inline]
    fn deref(&self) -> &Self::Target {
        let x = self as *const Self;
        // SAFETY: the struct is `#[repr(C)]`, so `[T; N]` sits at offset 0 and
        // the trailing `T` follows it directly; no padding is needed because
        // both fields share `T`'s alignment and the array's size is a multiple
        // of it. The struct is therefore laid out as `N + 1` contiguous `T`s,
        // all borrowed for the lifetime of `&self`.
        unsafe { std::slice::from_raw_parts(x as *const T, N + 1) }
    }
}

impl<T, const N: usize> std::ops::DerefMut for ArrayPlusOne<T, N> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let x = self as *mut Self;
        // SAFETY: same layout argument as in `deref`; `&mut self` guarantees
        // exclusive access to all `N + 1` elements for the returned lifetime.
        unsafe { std::slice::from_raw_parts_mut(x as *mut T, N + 1) }
    }
}