diff --git a/src/biguint/ntt.rs b/src/biguint/ntt.rs index 060704da..d96bee93 100644 --- a/src/biguint/ntt.rs +++ b/src/biguint/ntt.rs @@ -35,40 +35,23 @@ mod arith { struct Arith {} impl Arith

{ - pub const FACTOR_TWO: u32 = (P-1).trailing_zeros(); - pub const FACTOR_THREE: u32 = Self::factor_three(); - pub const FACTOR_FIVE: u32 = Self::factor_five(); - pub const MAX_NTT_LEN: u64 = Self::max_ntt_len(); - pub const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P - pub const R2: u64 = (Self::R as u128 * Self::R as u128 % P as u128) as u64; // R^2 mod P - pub const R4: u64 = Self::mpowmod(Self::R2, 3); // R^4 mod P - pub const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64 - pub const ROOTR: u64 = Self::ntt_root_r(); // ROOT * R mod P (ROOT: MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN) - const fn factor_three() -> u32 { - let (mut tmp, mut ans) = (P-1, 0); - while tmp % 3 == 0 { tmp /= 3; ans += 1; } - ans - } - const fn factor_five() -> u32 { - let (mut tmp, mut ans) = (P-1, 0); - while tmp % 5 == 0 { tmp /= 5; ans += 1; } - ans - } - const fn max_ntt_len() -> u64 { - let ans = 2u64.pow(Self::FACTOR_TWO) * 3u64.pow(Self::FACTOR_THREE) * 5u64.pow(Self::FACTOR_FIVE); - assert!(ans % 4050 == 0); - ans - } - const fn ntt_root_r() -> u64 { + const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P + const R2: u64 = (Self::R as u128 * Self::R as u128 % P as u128) as u64; // R^2 mod P + const R4: u64 = Self::mpowmod(Self::R2, 3); // R^4 mod P + const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64 + const MAX_NTT_LEN: u64 = 2u64.pow(Self::factors(2)) * 3u64.pow(Self::factors(3)) * 5u64.pow(Self::factors(5)); + const ROOTR: u64 = { + // ROOT * R mod P (ROOT: MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN) + assert!(Self::MAX_NTT_LEN % 4050 == 0); let mut p = 2; 'outer: loop { let root = Self::powmod_naive(p, P/Self::MAX_NTT_LEN); let mut j = 0; - while j <= Self::FACTOR_TWO { + while j <= Self::factors(2) { let mut k = 0; - while k <= Self::FACTOR_THREE { + while k <= Self::factors(3) { let mut l = 0; - while l <= Self::FACTOR_FIVE { + while l <= Self::factors(5) { let exponent = 2u64.pow(j) * 3u64.pow(k) * 5u64.pow(l); if exponent < Self::MAX_NTT_LEN && Self::powmod_naive(root, exponent) == 1 { p += 1; @@ -82,59 +65,63 @@ impl Arith

{ } break Self::mmulmod(Self::R2, root) } + }; + // Counts the number of `divisor` factors in P-1. + const fn factors(divisor: u64) -> u32 { + let (mut tmp, mut ans) = (P-1, 0); + while tmp % divisor == 0 { tmp /= divisor; ans += 1; } + ans } // Computes base^exponent mod P - const fn powmod_naive(base: u64, exponent: u64) -> u64 { + const fn powmod_naive(base: u64, mut exponent: u64) -> u64 { let mut cur = 1; let mut pow = base as u128; - let mut p = exponent; - while p > 0 { - if p % 2 > 0 { + while exponent > 0 { + if exponent % 2 > 0 { cur = (cur * pow) % P as u128; } - p /= 2; + exponent /= 2; pow = (pow * pow) % P as u128; } cur as u64 } // Montgomery reduction: // x * R^-1 mod P - pub const fn mreduce(x: u128) -> u64 { + const fn mreduce(x: u128) -> u64 { let m = (x as u64).wrapping_mul(Self::PINV); - let y = ((m as u128 * P as u128) >> 64) as u64; + let y = (m as u128 * P as u128 >> 64) as u64; let (out, overflow) = ((x >> 64) as u64).overflowing_sub(y); if overflow { out.wrapping_add(P) } else { out } } // Multiplication with Montgomery reduction: // a * b * R^-1 mod P - pub const fn mmulmod(a: u64, b: u64) -> u64 { + const fn mmulmod(a: u64, b: u64) -> u64 { Self::mreduce(a as u128 * b as u128) } // Multiplication with Montgomery reduction: // a * b * R^-1 mod P // This function only applies the multiplication when INV && TWIDDLE, // otherwise it just returns b. - pub const fn mmulmod_invtw(a: u64, b: u64) -> u64 { + const fn mmulmod_invtw(a: u64, b: u64) -> u64 { if INV && TWIDDLE { Self::mmulmod(a, b) } else { b } } // Fused-multiply-sub with Montgomery reduction: // a * b * R^-1 - c mod P - pub const fn mmulsubmod(a: u64, b: u64, c: u64) -> u64 { + const fn mmulsubmod(a: u64, b: u64, c: u64) -> u64 { let x = a as u128 * b as u128; let lo = x as u64; let hi = Self::submod((x >> 64) as u64, c); Self::mreduce(lo as u128 | ((hi as u128) << 64)) } // Computes base^exponent mod P with Montgomery reduction - pub const fn mpowmod(base: u64, exponent: u64) -> u64 { + const fn mpowmod(base: u64, mut exponent: u64) -> u64 { let mut cur = Self::R; let mut pow = base; - let mut p = exponent; - while p > 0 { - if p % 2 > 0 { + while exponent > 0 { + if exponent % 2 > 0 { cur = Self::mmulmod(cur, pow); } - p /= 2; + exponent /= 2; pow = Self::mmulmod(pow, pow); } cur @@ -143,32 +130,32 @@ impl Arith

{ // using d: u64 = mmulmod(P-1, c). // It is caller's responsibility to ensure that d is correct. // Note that d can be computed by calling mreducelo(c). - pub const fn mmulmod_noreduce(v: u128, c: u64, d: u64) -> u128 { + const fn mmulmod_noreduce(v: u128, c: u64, d: u64) -> u128 { let a: u128 = c as u128 * (v >> 64); let b: u128 = d as u128 * (v as u64 as u128); let (w, overflow) = a.overflowing_sub(b); if overflow { w.wrapping_add((P as u128) << 64) } else { w } } // Computes submod(0, mreduce(x as u128)) fast. - pub const fn mreducelo(x: u64) -> u64 { + const fn mreducelo(x: u64) -> u64 { let m = x.wrapping_mul(Self::PINV); - ((m as u128 * P as u128) >> 64) as u64 + (m as u128 * P as u128 >> 64) as u64 } // Computes a + b mod P, output range [0, P) - pub const fn addmod(a: u64, b: u64) -> u64 { + const fn addmod(a: u64, b: u64) -> u64 { Self::submod(a, P.wrapping_sub(b)) } // Computes a + b mod P, output range [0, 2^64) - pub const fn addmod64(a: u64, b: u64) -> u64 { + const fn addmod64(a: u64, b: u64) -> u64 { let (out, overflow) = a.overflowing_add(b); if overflow { out.wrapping_sub(P) } else { out } } // Computes a + b mod P, selects addmod64 or addmod depending on INV && TWIDDLE - pub const fn addmodopt_invtw(a: u64, b: u64) -> u64 { + const fn addmodopt_invtw(a: u64, b: u64) -> u64 { if INV && TWIDDLE { Self::addmod64(a, b) } else { Self::addmod(a, b) } } // Computes a - b mod P, output range [0, P) - pub const fn submod(a: u64, b: u64) -> u64 { + const fn submod(a: u64, b: u64) -> u64 { let (out, overflow) = a.overflowing_sub(b); if overflow { out.wrapping_add(P) } else { out } } @@ -183,16 +170,16 @@ struct NttPlan { pub s_list: Vec<(usize, usize)>, } impl NttPlan { - pub fn build(min_len: usize) -> Self { + fn build(min_len: usize) -> Self { assert!(min_len as u64 <= Arith::

::MAX_NTT_LEN); let (mut len_max, mut len_max_cost, mut g) = (usize::MAX, usize::MAX, 1); for m7 in 0..=1 { - for m5 in 0..=Arith::

::FACTOR_FIVE { - for m3 in 0..=Arith::

::FACTOR_THREE { + for m5 in 0..=Arith::

::factors(5) { + for m3 in 0..=Arith::

::factors(3) { let len = 7u64.pow(m7) * 5u64.pow(m5) * 3u64.pow(m3); if len >= 2 * min_len as u64 { break; } let (mut len, mut m2) = (len as usize, 0); - while len < min_len && m2 < Arith::

::FACTOR_TWO { len *= 2; m2 += 1; } + while len < min_len && m2 < Arith::

::factors(2) { len *= 2; m2 += 1; } if len >= min_len && len < len_max_cost { let (mut tmp, mut cost) = (len, 0); let mut g_new = 1; @@ -254,19 +241,19 @@ impl NttPlan { fn conv_base(n: usize, x: *mut u64, y: *mut u64, c: u64) { unsafe { let c2 = Arith::

::mreducelo(c); - let out = x.wrapping_sub(n); + let out = x.sub(n); for i in 0..n { let mut v: u128 = 0; for j in i+1..n { - let (w, overflow) = v.overflowing_sub(*x.wrapping_add(j) as u128 * *y.wrapping_add(i+n-j) as u128); + let (w, overflow) = v.overflowing_sub(*x.add(j) as u128 * *y.add(i+n-j) as u128); v = if overflow { w.wrapping_add((P as u128) << 64) } else { w }; } v = Arith::

::mmulmod_noreduce(v, c, c2); for j in 0..=i { - let (w, overflow) = v.overflowing_sub(*x.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128); + let (w, overflow) = v.overflowing_sub(*x.add(j) as u128 * *y.add(i-j) as u128); v = if overflow { w.wrapping_add((P as u128) << 64) } else { w }; } - *out.wrapping_add(i) = Arith::

::mreduce(v); + *out.add(i) = Arith::

::mreduce(v); } } } @@ -614,7 +601,7 @@ const P1INV_R_MOD_P2: u64 = Arith::::mmulmod(Arith::::R2, arith::invmod( const P1P2INV_R_MOD_P3: u64 = Arith::::mmulmod(Arith::::R2, arith::invmod((P1 as u128 * P2 as u128 % P3 as u128) as u64, P3)); const P1_R_MOD_P3: u64 = Arith::::mmulmod(Arith::::R2, P1); const P1P2_LO: u64 = (P1 as u128 * P2 as u128) as u64; -const P1P2_HI: u64 = ((P1 as u128 * P2 as u128) >> 64) as u64; +const P1P2_HI: u64 = (P1 as u128 * P2 as u128 >> 64) as u64; // Propagates carry from the beginning to the end of acc, // and returns the resulting carry if it is nonzero. @@ -642,7 +629,7 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) { if p + q >= bits { unsafe { let out = x & mask; *pdst1 = out; *pdst2 = out; } x = 0; - (pdst1, pdst2, k, p) = (pdst1.wrapping_add(1), pdst2.wrapping_add(1), k + bits - p, 0); + unsafe { (pdst1, pdst2, k, p) = (pdst1.add(1), pdst2.add(1), k + bits - p, 0); } } else { p += q; break;