diff --git a/CHANGELOG.md b/CHANGELOG.md index 73985060..17f44655 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,12 +19,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - MSRV bumped to 1.85 ([#503]) - Made `*next_power_of_two` and `*next_multiple_of` `const` ([#533]) - Reimplemented `TryFrom` for `Uint` to speed it up, fixing edge cases and removing `std` requirements ([#524]) +- Make `mul*` functions `const` ([#449]) [#503]: https://github.com/recmo/uint/pull/503 [#516]: https://github.com/recmo/uint/pull/516 [#526]: https://github.com/recmo/uint/pull/526 [#533]: https://github.com/recmo/uint/pull/533 [#524]: https://github.com/recmo/uint/pull/524 +[#449]: https://github.com/recmo/uint/pull/449 ## [1.16.0] - 2025-08-04 diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 90c353b7..84619989 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -34,7 +34,7 @@ pub use self::{ }; pub(crate) trait DoubleWord: Sized + Copy { - /// `high << 64 | low` + /// `high << T::BITS | low` fn join(high: T, low: T) -> Self; /// `(low, high)` fn split(self) -> (T, T); @@ -42,11 +42,6 @@ pub(crate) trait DoubleWord: Sized + Copy { /// `a * b + c + d` fn muladd2(a: T, b: T, c: T, d: T) -> Self; - /// `a + b` - #[inline(always)] - fn add(a: T, b: T) -> Self { - Self::muladd2(T::default(), T::default(), a, b) - } /// `a * b` #[inline(always)] fn mul(a: T, b: T) -> Self { @@ -94,6 +89,64 @@ impl DoubleWord for u128 { } } +#[derive(Clone, Copy)] +struct ConstDoubleWord(T); + +impl ConstDoubleWord { + #[inline(always)] + const fn ext(a: u64) -> u128 { + a as u128 + } + + #[inline(always)] + #[allow(dead_code)] + const fn join(high: u64, low: u64) -> Self { + Self(Self::ext(high) << 64 | Self::ext(low)) + } + + #[inline(always)] + const fn split(self) -> (u64, u64) { + (self.low(), self.high()) + } + + /// `a + b` + #[inline(always)] + const fn add(a: u64, b: u64) -> Self { + Self(Self::ext(a) + Self::ext(b)) + } + + /// `a * b + c` + #[inline(always)] + const fn carrying_mul(a: u64, b: u64, c: u64) -> Self { + Self::carrying_mul_add(a, b, c, 0) + } + + /// `a * b + c + d` + #[inline(always)] + const fn carrying_mul_add(a: u64, b: u64, c: u64, d: u64) -> Self { + #[cfg(feature = "nightly")] + { + let (low, high) = u64::carrying_mul_add(a, b, c, d); + Self::join(high, low) + } + #[cfg(not(feature = "nightly"))] + { + Self(Self::ext(a) * Self::ext(b) + Self::ext(c) + Self::ext(d)) + } + } + + #[inline(always)] + const fn high(self) -> u64 { + (self.0 >> 64) as u64 + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + const fn low(self) -> u64 { + self.0 as u64 + } +} + /// ⚠️ Compare two limb slices in reverse order. #[doc = crate::algorithms::unstable_warning!()] /// Assumes that if the slices are of different length, the longer slice is diff --git a/src/algorithms/mul.rs b/src/algorithms/mul.rs index 47a014f2..d94367be 100644 --- a/src/algorithms/mul.rs +++ b/src/algorithms/mul.rs @@ -1,6 +1,6 @@ #![allow(clippy::module_name_repetitions)] -use crate::algorithms::{borrowing_sub, DoubleWord}; +use crate::algorithms::{borrowing_sub, ConstDoubleWord as DW}; /// ⚠️ Computes `result += a * b` and checks for overflow. #[doc = crate::algorithms::unstable_warning!()] @@ -21,7 +21,7 @@ use crate::algorithms::{borrowing_sub, DoubleWord}; /// assert_eq!(result, [12]); /// ``` #[inline(always)] -pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool { +pub const fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool { // Trim zeros from `a` while let [0, rest @ ..] = a { a = rest; @@ -53,8 +53,10 @@ pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool { let (a, b) = if b.len() > a.len() { (b, a) } else { (a, b) }; // Iterate over limbs of `b` and add partial products to `lhs`. + let mut i = 0; let mut overflow = false; - for &b in b { + while i < b.len() { + let b = b[i]; if lhs.len() >= a.len() { let (target, rest) = lhs.split_at_mut(a.len()); let carry = addmul_nx1(target, a, b); @@ -65,9 +67,10 @@ pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool { if lhs.is_empty() { break; } - addmul_nx1(lhs, &a[..lhs.len()], b); + addmul_nx1(lhs, a.split_at(lhs.len()).0, b); } - lhs = &mut lhs[1..]; + lhs = lhs.split_at_mut(1).1; + i += 1; } overflow } @@ -79,7 +82,7 @@ const ADDMUL_N_SMALL_LIMIT: usize = 8; #[doc = crate::algorithms::unstable_warning!()] /// See [`addmul`] for more details. #[inline(always)] -pub fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) { +pub const fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) { let n = lhs.len(); if n <= ADDMUL_N_SMALL_LIMIT && a.len() == n && b.len() == n { addmul_n_small(lhs, a, b); @@ -89,32 +92,38 @@ pub fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) { } #[inline(always)] -fn addmul_n_small(lhs: &mut [u64], a: &[u64], b: &[u64]) { +const fn addmul_n_small(lhs: &mut [u64], a: &[u64], b: &[u64]) { let n = lhs.len(); assume!(n <= ADDMUL_N_SMALL_LIMIT); assume!(a.len() == n); assume!(b.len() == n); - for j in 0..n { + let mut j = 0; + while j < n { let mut carry = 0; - for i in 0..(n - j) { - (lhs[j + i], carry) = u128::muladd2(a[i], b[j], carry, lhs[j + i]).split(); + let mut i = 0; + while i < (n - j) { + (lhs[j + i], carry) = DW::carrying_mul_add(a[i], b[j], carry, lhs[j + i]).split(); + i += 1; } + j += 1; } } /// ⚠️ Computes `lhs += a` and returns the carry. #[doc = crate::algorithms::unstable_warning!()] #[inline(always)] -pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 { +pub const fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 { if a == 0 { return 0; } - for lhs in lhs { - (*lhs, a) = u128::add(*lhs, a).split(); + let mut i = 0; + while i < lhs.len() { + (lhs[i], a) = DW::add(lhs[i], a).split(); if a == 0 { return 0; } + i += 1; } a } @@ -122,10 +131,12 @@ pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 { /// ⚠️ Computes `lhs *= a` and returns the carry. #[doc = crate::algorithms::unstable_warning!()] #[inline(always)] -pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 { +pub const fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 { let mut carry = 0; - for lhs in lhs { - (*lhs, carry) = u128::muladd(*lhs, a, carry).split(); + let mut i = 0; + while i < lhs.len() { + (lhs[i], carry) = DW::carrying_mul(lhs[i], a, carry).split(); + i += 1; } carry } @@ -141,11 +152,13 @@ pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 { /// }{2^{64⋅N}}} \end{aligned} /// $$ #[inline(always)] -pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 { +pub const fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 { assume!(lhs.len() == a.len()); let mut carry = 0; - for i in 0..a.len() { - (lhs[i], carry) = u128::muladd2(a[i], b, carry, lhs[i]).split(); + let mut i = 0; + while i < a.len() { + (lhs[i], carry) = DW::carrying_mul_add(a[i], b, carry, lhs[i]).split(); + i += 1; } carry } @@ -162,17 +175,19 @@ pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 { /// $$ // OPT: `carry` and `borrow` can probably be merged into a single var. #[inline(always)] -pub fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 { +pub const fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 { assume!(lhs.len() == a.len()); let mut carry = 0; let mut borrow = false; - for i in 0..a.len() { + let mut i = 0; + while i < a.len() { // Compute product limbs let limb; - (limb, carry) = u128::muladd(a[i], b, carry).split(); + (limb, carry) = DW::carrying_mul(a[i], b, carry).split(); // Subtract (lhs[i], borrow) = borrowing_sub(lhs[i], limb, borrow); + i += 1; } borrow as u64 + carry } diff --git a/src/div.rs b/src/div.rs index daa5d683..8f2b0d25 100644 --- a/src/div.rs +++ b/src/div.rs @@ -5,7 +5,6 @@ impl Uint { /// Computes `self / rhs`, returning [`None`] if `rhs == 0`. #[inline] #[must_use] - #[allow(clippy::missing_const_for_fn)] // False positive pub fn checked_div(self, rhs: Self) -> Option { if rhs.is_zero() { return None; @@ -16,7 +15,6 @@ impl Uint { /// Computes `self % rhs`, returning [`None`] if `rhs == 0`. #[inline] #[must_use] - #[allow(clippy::missing_const_for_fn)] // False positive pub fn checked_rem(self, rhs: Self) -> Option { if rhs.is_zero() { return None; diff --git a/src/macros.rs b/src/macros.rs index a8ff8fb1..7b27ad7e 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -92,7 +92,7 @@ macro_rules! assume { macro_rules! debug_unreachable { ($($t:tt)*) => { if cfg!(debug_assertions) { - unreachable!($($t)*); + panic!($($t)*); } else { unsafe { core::hint::unreachable_unchecked() }; } diff --git a/src/mul.rs b/src/mul.rs index 661fb548..fd322f32 100644 --- a/src/mul.rs +++ b/src/mul.rs @@ -9,7 +9,7 @@ impl Uint { /// Computes `self * rhs`, returning [`None`] if overflow occurred. #[inline(always)] #[must_use] - pub fn checked_mul(self, rhs: Self) -> Option { + pub const fn checked_mul(self, rhs: Self) -> Option { match self.overflowing_mul(rhs) { (value, false) => Some(value), _ => None, @@ -36,7 +36,7 @@ impl Uint { /// ``` #[inline] #[must_use] - pub fn overflowing_mul(self, rhs: Self) -> (Self, bool) { + pub const fn overflowing_mul(self, rhs: Self) -> (Self, bool) { let mut result = Self::ZERO; let mut overflow = algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs()); if Self::SHOULD_MASK { @@ -50,7 +50,7 @@ impl Uint { /// overflowing. #[inline(always)] #[must_use] - pub fn saturating_mul(self, rhs: Self) -> Self { + pub const fn saturating_mul(self, rhs: Self) -> Self { match self.overflowing_mul(rhs) { (value, false) => value, _ => Self::MAX, @@ -60,7 +60,7 @@ impl Uint { /// Computes `self * rhs`, wrapping around at the boundary of the type. #[inline(always)] #[must_use] - pub fn wrapping_mul(self, rhs: Self) -> Self { + pub const fn wrapping_mul(self, rhs: Self) -> Self { let mut result = Self::ZERO; algorithms::addmul_n(&mut result.limbs, self.as_limbs(), rhs.as_limbs()); result.apply_mask(); @@ -125,8 +125,7 @@ impl Uint { /// ``` #[inline] #[must_use] - #[allow(clippy::similar_names)] // Don't confuse `res` and `rhs`. - pub fn widening_mul< + pub const fn widening_mul< const BITS_RHS: usize, const LIMBS_RHS: usize, const BITS_RES: usize, @@ -135,14 +134,13 @@ impl Uint { self, rhs: Uint, ) -> Uint { - assert_eq!(BITS_RES, BITS + BITS_RHS); - assert_eq!(LIMBS_RES, nlimbs(BITS_RES)); + assert!(BITS_RES == BITS + BITS_RHS); + assert!(LIMBS_RES == nlimbs(BITS_RES)); let mut result = Uint::::ZERO; algorithms::addmul(&mut result.limbs, self.as_limbs(), rhs.as_limbs()); if LIMBS_RES > 0 { debug_assert!(result.limbs[LIMBS_RES - 1] <= Uint::::MASK); } - result } }