Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 47 additions & 60 deletions src/biguint/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,40 +35,23 @@ mod arith {

struct Arith<const P: u64> {}
impl<const P: u64> Arith<P> {
pub const FACTOR_TWO: u32 = (P-1).trailing_zeros();
pub const FACTOR_THREE: u32 = Self::factor_three();
pub const FACTOR_FIVE: u32 = Self::factor_five();
pub const MAX_NTT_LEN: u64 = Self::max_ntt_len();
pub const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P
pub const R2: u64 = (Self::R as u128 * Self::R as u128 % P as u128) as u64; // R^2 mod P
pub const R4: u64 = Self::mpowmod(Self::R2, 3); // R^4 mod P
pub const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64
pub const ROOTR: u64 = Self::ntt_root_r(); // ROOT * R mod P (ROOT: MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN)
const fn factor_three() -> u32 {
let (mut tmp, mut ans) = (P-1, 0);
while tmp % 3 == 0 { tmp /= 3; ans += 1; }
ans
}
const fn factor_five() -> u32 {
let (mut tmp, mut ans) = (P-1, 0);
while tmp % 5 == 0 { tmp /= 5; ans += 1; }
ans
}
const fn max_ntt_len() -> u64 {
let ans = 2u64.pow(Self::FACTOR_TWO) * 3u64.pow(Self::FACTOR_THREE) * 5u64.pow(Self::FACTOR_FIVE);
assert!(ans % 4050 == 0);
ans
}
const fn ntt_root_r() -> u64 {
const R: u64 = ((1u128 << 64) % P as u128) as u64; // 2^64 mod P
const R2: u64 = (Self::R as u128 * Self::R as u128 % P as u128) as u64; // R^2 mod P
const R4: u64 = Self::mpowmod(Self::R2, 3); // R^4 mod P
const PINV: u64 = arith::invmod(P, 0); // P^-1 mod 2^64
const MAX_NTT_LEN: u64 = 2u64.pow(Self::factors(2)) * 3u64.pow(Self::factors(3)) * 5u64.pow(Self::factors(5));
const ROOTR: u64 = {
// ROOT * R mod P (ROOT: MultiplicativeOrder[ROOT, P] == MAX_NTT_LEN)
assert!(Self::MAX_NTT_LEN % 4050 == 0);
let mut p = 2;
'outer: loop {
let root = Self::powmod_naive(p, P/Self::MAX_NTT_LEN);
let mut j = 0;
while j <= Self::FACTOR_TWO {
while j <= Self::factors(2) {
let mut k = 0;
while k <= Self::FACTOR_THREE {
while k <= Self::factors(3) {
let mut l = 0;
while l <= Self::FACTOR_FIVE {
while l <= Self::factors(5) {
let exponent = 2u64.pow(j) * 3u64.pow(k) * 5u64.pow(l);
if exponent < Self::MAX_NTT_LEN && Self::powmod_naive(root, exponent) == 1 {
p += 1;
Expand All @@ -82,59 +65,63 @@ impl<const P: u64> Arith<P> {
}
break Self::mmulmod(Self::R2, root)
}
};
// Counts the number of `divisor` factors in P-1.
const fn factors(divisor: u64) -> u32 {
let (mut tmp, mut ans) = (P-1, 0);
while tmp % divisor == 0 { tmp /= divisor; ans += 1; }
ans
}
// Computes base^exponent mod P
const fn powmod_naive(base: u64, exponent: u64) -> u64 {
const fn powmod_naive(base: u64, mut exponent: u64) -> u64 {
let mut cur = 1;
let mut pow = base as u128;
let mut p = exponent;
while p > 0 {
if p % 2 > 0 {
while exponent > 0 {
if exponent % 2 > 0 {
cur = (cur * pow) % P as u128;
}
p /= 2;
exponent /= 2;
pow = (pow * pow) % P as u128;
}
cur as u64
}
// Montgomery reduction:
// x * R^-1 mod P
pub const fn mreduce(x: u128) -> u64 {
const fn mreduce(x: u128) -> u64 {
let m = (x as u64).wrapping_mul(Self::PINV);
let y = ((m as u128 * P as u128) >> 64) as u64;
let y = (m as u128 * P as u128 >> 64) as u64;
let (out, overflow) = ((x >> 64) as u64).overflowing_sub(y);
if overflow { out.wrapping_add(P) } else { out }
}
// Multiplication with Montgomery reduction:
// a * b * R^-1 mod P
pub const fn mmulmod(a: u64, b: u64) -> u64 {
const fn mmulmod(a: u64, b: u64) -> u64 {
Self::mreduce(a as u128 * b as u128)
}
// Multiplication with Montgomery reduction:
// a * b * R^-1 mod P
// This function only applies the multiplication when INV && TWIDDLE,
// otherwise it just returns b.
pub const fn mmulmod_invtw<const INV: bool, const TWIDDLE: bool>(a: u64, b: u64) -> u64 {
const fn mmulmod_invtw<const INV: bool, const TWIDDLE: bool>(a: u64, b: u64) -> u64 {
if INV && TWIDDLE { Self::mmulmod(a, b) } else { b }
}
// Fused-multiply-sub with Montgomery reduction:
// a * b * R^-1 - c mod P
pub const fn mmulsubmod(a: u64, b: u64, c: u64) -> u64 {
const fn mmulsubmod(a: u64, b: u64, c: u64) -> u64 {
let x = a as u128 * b as u128;
let lo = x as u64;
let hi = Self::submod((x >> 64) as u64, c);
Self::mreduce(lo as u128 | ((hi as u128) << 64))
}
// Computes base^exponent mod P with Montgomery reduction
pub const fn mpowmod(base: u64, exponent: u64) -> u64 {
const fn mpowmod(base: u64, mut exponent: u64) -> u64 {
let mut cur = Self::R;
let mut pow = base;
let mut p = exponent;
while p > 0 {
if p % 2 > 0 {
while exponent > 0 {
if exponent % 2 > 0 {
cur = Self::mmulmod(cur, pow);
}
p /= 2;
exponent /= 2;
pow = Self::mmulmod(pow, pow);
}
cur
Expand All @@ -143,32 +130,32 @@ impl<const P: u64> Arith<P> {
// using d: u64 = mmulmod(P-1, c).
// It is caller's responsibility to ensure that d is correct.
// Note that d can be computed by calling mreducelo(c).
pub const fn mmulmod_noreduce(v: u128, c: u64, d: u64) -> u128 {
const fn mmulmod_noreduce(v: u128, c: u64, d: u64) -> u128 {
let a: u128 = c as u128 * (v >> 64);
let b: u128 = d as u128 * (v as u64 as u128);
let (w, overflow) = a.overflowing_sub(b);
if overflow { w.wrapping_add((P as u128) << 64) } else { w }
}
// Computes submod(0, mreduce(x as u128)) fast.
pub const fn mreducelo(x: u64) -> u64 {
const fn mreducelo(x: u64) -> u64 {
let m = x.wrapping_mul(Self::PINV);
((m as u128 * P as u128) >> 64) as u64
(m as u128 * P as u128 >> 64) as u64
}
// Computes a + b mod P, output range [0, P)
pub const fn addmod(a: u64, b: u64) -> u64 {
const fn addmod(a: u64, b: u64) -> u64 {
Self::submod(a, P.wrapping_sub(b))
}
// Computes a + b mod P, output range [0, 2^64)
pub const fn addmod64(a: u64, b: u64) -> u64 {
const fn addmod64(a: u64, b: u64) -> u64 {
let (out, overflow) = a.overflowing_add(b);
if overflow { out.wrapping_sub(P) } else { out }
}
// Computes a + b mod P, selects addmod64 or addmod depending on INV && TWIDDLE
pub const fn addmodopt_invtw<const INV: bool, const TWIDDLE: bool>(a: u64, b: u64) -> u64 {
const fn addmodopt_invtw<const INV: bool, const TWIDDLE: bool>(a: u64, b: u64) -> u64 {
if INV && TWIDDLE { Self::addmod64(a, b) } else { Self::addmod(a, b) }
}
// Computes a - b mod P, output range [0, P)
pub const fn submod(a: u64, b: u64) -> u64 {
const fn submod(a: u64, b: u64) -> u64 {
let (out, overflow) = a.overflowing_sub(b);
if overflow { out.wrapping_add(P) } else { out }
}
Expand All @@ -183,16 +170,16 @@ struct NttPlan {
pub s_list: Vec<(usize, usize)>,
}
impl NttPlan {
pub fn build<const P: u64>(min_len: usize) -> Self {
fn build<const P: u64>(min_len: usize) -> Self {
assert!(min_len as u64 <= Arith::<P>::MAX_NTT_LEN);
let (mut len_max, mut len_max_cost, mut g) = (usize::MAX, usize::MAX, 1);
for m7 in 0..=1 {
for m5 in 0..=Arith::<P>::FACTOR_FIVE {
for m3 in 0..=Arith::<P>::FACTOR_THREE {
for m5 in 0..=Arith::<P>::factors(5) {
for m3 in 0..=Arith::<P>::factors(3) {
let len = 7u64.pow(m7) * 5u64.pow(m5) * 3u64.pow(m3);
if len >= 2 * min_len as u64 { break; }
let (mut len, mut m2) = (len as usize, 0);
while len < min_len && m2 < Arith::<P>::FACTOR_TWO { len *= 2; m2 += 1; }
while len < min_len && m2 < Arith::<P>::factors(2) { len *= 2; m2 += 1; }
if len >= min_len && len < len_max_cost {
let (mut tmp, mut cost) = (len, 0);
let mut g_new = 1;
Expand Down Expand Up @@ -254,19 +241,19 @@ impl NttPlan {
fn conv_base<const P: u64>(n: usize, x: *mut u64, y: *mut u64, c: u64) {
unsafe {
let c2 = Arith::<P>::mreducelo(c);
let out = x.wrapping_sub(n);
let out = x.sub(n);
for i in 0..n {
let mut v: u128 = 0;
for j in i+1..n {
let (w, overflow) = v.overflowing_sub(*x.wrapping_add(j) as u128 * *y.wrapping_add(i+n-j) as u128);
let (w, overflow) = v.overflowing_sub(*x.add(j) as u128 * *y.add(i+n-j) as u128);
v = if overflow { w.wrapping_add((P as u128) << 64) } else { w };
}
v = Arith::<P>::mmulmod_noreduce(v, c, c2);
for j in 0..=i {
let (w, overflow) = v.overflowing_sub(*x.wrapping_add(j) as u128 * *y.wrapping_add(i-j) as u128);
let (w, overflow) = v.overflowing_sub(*x.add(j) as u128 * *y.add(i-j) as u128);
v = if overflow { w.wrapping_add((P as u128) << 64) } else { w };
}
*out.wrapping_add(i) = Arith::<P>::mreduce(v);
*out.add(i) = Arith::<P>::mreduce(v);
}
}
}
Expand Down Expand Up @@ -614,7 +601,7 @@ const P1INV_R_MOD_P2: u64 = Arith::<P2>::mmulmod(Arith::<P2>::R2, arith::invmod(
const P1P2INV_R_MOD_P3: u64 = Arith::<P3>::mmulmod(Arith::<P3>::R2, arith::invmod((P1 as u128 * P2 as u128 % P3 as u128) as u64, P3));
const P1_R_MOD_P3: u64 = Arith::<P3>::mmulmod(Arith::<P3>::R2, P1);
const P1P2_LO: u64 = (P1 as u128 * P2 as u128) as u64;
const P1P2_HI: u64 = ((P1 as u128 * P2 as u128) >> 64) as u64;
const P1P2_HI: u64 = (P1 as u128 * P2 as u128 >> 64) as u64;

// Propagates carry from the beginning to the end of acc,
// and returns the resulting carry if it is nonzero.
Expand Down Expand Up @@ -642,7 +629,7 @@ fn mac3_two_primes(acc: &mut [u64], b: &[u64], c: &[u64], bits: u64) {
if p + q >= bits {
unsafe { let out = x & mask; *pdst1 = out; *pdst2 = out; }
x = 0;
(pdst1, pdst2, k, p) = (pdst1.wrapping_add(1), pdst2.wrapping_add(1), k + bits - p, 0);
unsafe { (pdst1, pdst2, k, p) = (pdst1.add(1), pdst2.add(1), k + bits - p, 0); }
} else {
p += q;
break;
Expand Down