diff --git a/README.md b/README.md index fa11a9c3f..83f48ee64 100644 --- a/README.md +++ b/README.md @@ -239,7 +239,7 @@ $ java -cp classes com.williamfiset.algorithms.search.BinarySearch # Mathematics -- [[UNTESTED] Chinese remainder theorem](src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java) +- [Chinese remainder theorem](src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java) - [Prime number sieve (sieve of Eratosthenes)](src/main/java/com/williamfiset/algorithms/math/SieveOfEratosthenes.java) **- O(nlog(log(n)))** - [Prime number sieve (sieve of Eratosthenes, compressed)](src/main/java/com/williamfiset/algorithms/math/CompressedPrimeSieve.java) **- O(nlog(log(n)))** - [Totient function (phi function, relatively prime number count)](src/main/java/com/williamfiset/algorithms/math/EulerTotientFunction.java) **- O(√n)** @@ -248,9 +248,9 @@ $ java -cp classes com.williamfiset.algorithms.search.BinarySearch - [Fast Fourier transform (quick polynomial multiplication)](src/main/java/com/williamfiset/algorithms/math/FastFourierTransform.java) **- O(nlog(n))** - [Fast Fourier transform (quick polynomial multiplication, complex numbers)](src/main/java/com/williamfiset/algorithms/math/FastFourierTransformComplexNumbers.java) **- O(nlog(n))** - [Primality check](src/main/java/com/williamfiset/algorithms/math/PrimalityCheck.java) **- O(√n)** -- [Primality check (Rabin-Miller)](src/main/java/com/williamfiset/algorithms/math/RabinMillerPrimalityTest.py) **- O(k)** - [Least Common Multiple (LCM)](src/main/java/com/williamfiset/algorithms/math/Lcm.java) **- ~O(log(a + b))** - [Modular inverse](src/main/java/com/williamfiset/algorithms/math/ModularInverse.java) **- ~O(log(a + b))** +- [Modular exponentiation](src/main/java/com/williamfiset/algorithms/math/ModPow.java) **- O(log(n))** - [Prime factorization (pollard rho)](src/main/java/com/williamfiset/algorithms/math/PrimeFactorization.java) **- O(n1/4)** - [Relatively prime check (coprimality check)](src/main/java/com/williamfiset/algorithms/math/RelativelyPrime.java) **- ~O(log(a + b))** diff --git a/src/main/java/com/williamfiset/algorithms/math/BUILD b/src/main/java/com/williamfiset/algorithms/math/BUILD index d99722864..d7ed1b927 100644 --- a/src/main/java/com/williamfiset/algorithms/math/BUILD +++ b/src/main/java/com/williamfiset/algorithms/math/BUILD @@ -63,13 +63,6 @@ java_binary( runtime_deps = [":math"], ) -# bazel run //src/main/java/com/williamfiset/algorithms/math:NChooseRModPrime -java_binary( - name = "NChooseRModPrime", - main_class = "com.williamfiset.algorithms.math.NChooseRModPrime", - runtime_deps = [":math"], -) - # bazel run //src/main/java/com/williamfiset/algorithms/math:PrimeFactorization java_binary( name = "PrimeFactorization", diff --git a/src/main/java/com/williamfiset/algorithms/math/BinomialCoefficientModPrime.java b/src/main/java/com/williamfiset/algorithms/math/BinomialCoefficientModPrime.java new file mode 100644 index 000000000..c1bd75e61 --- /dev/null +++ b/src/main/java/com/williamfiset/algorithms/math/BinomialCoefficientModPrime.java @@ -0,0 +1,48 @@ +/** + * Computes the binomial coefficient C(n, r) mod p using Fermat's Little Theorem. + * + * Given a prime p, the binomial coefficient C(n, r) = n! / (r! * (n-r)!) can be computed modulo p + * by precomputing factorials mod p and using modular inverses for the denominator. Fermat's Little + * Theorem gives a^(p-1) ≡ 1 (mod p) for prime p, so the modular inverse of x is x^(p-2) mod p. + * Here we use the extended Euclidean algorithm via ModularInverse instead. + * + * Requires p to be prime so that modular inverses exist for all non-zero values mod p, and n < p + * so that factorials are non-zero mod p. + * + * Time Complexity: O(n) for factorial precomputation, O(log(p)) for each modular inverse. + * + * @author Rohit Mazumder, mazumder.rohit7@gmail.com + */ +package com.williamfiset.algorithms.math; + +public class BinomialCoefficientModPrime { + + /** + * Computes C(n, r) mod p. + * + * @param n total items (must be >= 0 and < p). + * @param r items to choose (must be >= 0 and <= n). + * @param p a prime modulus. + * @return C(n, r) mod p. + * @throws IllegalArgumentException if parameters are out of range. + */ + public static long compute(int n, int r, int p) { + if (n < 0 || r < 0 || r > n) + throw new IllegalArgumentException("Requires 0 <= r <= n, got n=" + n + ", r=" + r); + if (p <= 1) + throw new IllegalArgumentException("Modulus p must be > 1, got p=" + p); + + if (r == 0 || r == n) + return 1; + + long[] factorial = new long[n + 1]; + factorial[0] = 1; + for (int i = 1; i <= n; i++) + factorial[i] = factorial[i - 1] * i % p; + + return factorial[n] + % p * ModularInverse.modInv(factorial[r], p) + % p * ModularInverse.modInv(factorial[n - r], p) + % p; + } +} diff --git a/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java b/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java index a00d29538..a8e237386 100644 --- a/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java +++ b/src/main/java/com/williamfiset/algorithms/math/ChineseRemainderTheorem.java @@ -1,20 +1,29 @@ /** - * Use the chinese remainder theorem to solve a set of congruence equations. + * Solve a set of congruence equations using the Chinese Remainder Theorem (CRT). * - *
The first method (eliminateCoefficient) is used to reduce an equation of the form cx≡a(mod - * m)cx≡a(mod m) to the form x≡a_new(mod m_new)x≡anew(mod m_new), which gets rids of the - * coefficient. A value of null is returned if the coefficient cannot be eliminated. + * Given a system of simultaneous congruences: * - *
The second method (reduce) is used to reduce a set of equations so that the moduli become - * pairwise co-prime (which means that we can apply the Chinese Remainder Theorem). The input and - * output are of the form x≡a_0(mod m_0),...,x≡a_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡a_n−1(mod - * m_n−1). Note that the number of equations may change during this process. A value of null is - * returned if the set of equations cannot be reduced to co-prime moduli. + * x ≡ a_0 (mod m_0) + * x ≡ a_1 (mod m_1) + * ... + * x ≡ a_{n-1} (mod m_{n-1}) * - *
The third method (crt) is the actual Chinese Remainder Theorem. It assumes that all pairs of
- * moduli are co-prime to one another. This solves a set of equations of the form x≡a_0(mod
- * m_0),...,x≡v_n−1(mod m_n−1)x≡a_0(mod m_0),...,x≡v_n−1(mod m_n−1). It's output is of the form
- * x≡a_new(mod m_new)x≡a_new(mod m_new).
+ * where all moduli m_i are pairwise coprime (gcd(m_i, m_j) = 1 for i ≠ j), the CRT guarantees a
+ * unique solution x modulo M = m_0 * m_1 * ... * m_{n-1}.
+ *
+ * The solution is constructed as x = sum of a_i * M_i * y_i (mod M), where M_i = M / m_i and y_i
+ * is the modular inverse of M_i modulo m_i (found via the extended Euclidean algorithm). Each term
+ * contributes a_i for the i-th congruence and vanishes (mod m_j) for all j ≠ i, so the sum
+ * satisfies every equation simultaneously.
+ *
+ * When moduli are not pairwise coprime, the system must first be reduced. Each modulus is split
+ * into prime-power factors (e.g. 12 = 4 * 3), converting one equation into several with
+ * prime-power moduli. Redundant equations are removed and conflicting ones detected. After
+ * reduction, the moduli are pairwise coprime and the standard CRT applies.
+ *
+ * The eliminateCoefficient method handles equations of the form cx ≡ a (mod m) by dividing through
+ * by gcd(c, m) — which is only possible when gcd(c, m) divides a — and then multiplying by the
+ * modular inverse of the reduced coefficient.
*
* @author Micah Stairs
*/
@@ -24,12 +33,16 @@
public class ChineseRemainderTheorem {
- // eliminateCoefficient() takes cx≡a(mod m) and gives x≡a_new(mod m_new).
+ /**
+ * Reduces cx ≡ a (mod m) to x ≡ a' (mod m').
+ *
+ * @return {a', m'} or null if unsolvable.
+ */
public static long[] eliminateCoefficient(long c, long a, long m) {
-
long d = egcd(c, m)[0];
- if (a % d != 0) return null;
+ if (a % d != 0)
+ return null;
c /= d;
a /= d;
@@ -42,36 +55,35 @@ public static long[] eliminateCoefficient(long c, long a, long m) {
return new long[] {a, m};
}
- // reduce() takes a set of equations and reduces them to an equivalent
- // set with pairwise co-prime moduli (or null if not solvable).
+ /**
+ * Reduces a system x ≡ a[i] (mod m[i]) to an equivalent system with pairwise coprime moduli.
+ *
+ * @return {a[], m[]} with coprime moduli, or null if the system is inconsistent.
+ */
public static long[][] reduce(long[] a, long[] m) {
+ List An implementation of the modPow(a, n, mod) operation. This implementation is substantially
- * faster than Java's BigInteger class because it only uses primitive types.
+ * Supports negative exponents via modular inverse (requires gcd(a, m) = 1) and negative bases.
+ * Uses overflow-safe modular multiplication to handle the full range of long values.
*
- * Time Complexity O(lg(n))
+ * Time Complexity: O(log(n))
*
* @author William Fiset, william.alexandre.fiset@gmail.com
*/
package com.williamfiset.algorithms.math;
-import java.math.BigInteger;
-
public class ModPow {
- // The values placed into the modPow function cannot be greater
- // than MAX or less than MIN otherwise long overflow will
- // happen when the values get squared (they will exceed 2^63-1)
- private static final long MAX = (long) Math.sqrt(Long.MAX_VALUE);
- private static final long MIN = -MAX;
-
- // Computes the Greatest Common Divisor (GCD) of a & b
- private static long gcd(long a, long b) {
- return b == 0 ? (a < 0 ? -a : a) : gcd(b, a % b);
- }
-
- // This function performs the extended euclidean algorithm on two numbers a and b.
- // The function returns the gcd(a,b) as well as the numbers x and y such
- // that ax + by = gcd(a,b). This calculation is important in number theory
- // and can be used for several things such as finding modular inverses and
- // solutions to linear Diophantine equations.
- private static long[] egcd(long a, long b) {
- if (b == 0) return new long[] {a < 0 ? -a : a, 1L, 0L};
- long[] v = egcd(b, a % b);
- long tmp = v[1] - v[2] * (a / b);
- v[1] = v[2];
- v[2] = tmp;
- return v;
- }
-
- // Returns the modular inverse of 'a' mod 'm'
- // Make sure m > 0 and 'a' & 'm' are relatively prime.
- private static long modInv(long a, long m) {
-
- a = ((a % m) + m) % m;
-
- long[] v = egcd(a, m);
- long x = v[1];
-
- return ((x % m) + m) % m;
- }
-
- // Computes a^n modulo mod very efficiently in O(lg(n)) time.
- // This function supports negative exponent values and a negative
- // base, however the modulus must be positive.
+ /**
+ * Computes a^n mod m.
+ *
+ * @throws ArithmeticException if mod <= 0, or if n < 0 and gcd(a, mod) != 1.
+ */
public static long modPow(long a, long n, long mod) {
+ if (mod <= 0)
+ throw new ArithmeticException("mod must be > 0");
- if (mod <= 0) throw new ArithmeticException("mod must be > 0");
- if (a > MAX || mod > MAX)
- throw new IllegalArgumentException("Long overflow is upon you, mod or base is too high!");
- if (a < MIN || mod < MIN)
- throw new IllegalArgumentException("Long overflow is upon you, mod or base is too low!");
-
- // To handle negative exponents we can use the modular
- // inverse of a to our advantage since: a^-n mod m = (a^-1)^n mod m
+ // a^-n mod m = (a^-1)^n mod m
if (n < 0) {
if (gcd(a, mod) != 1)
throw new ArithmeticException("If n < 0 then must have gcd(a, mod) = 1");
return modPow(modInv(a, mod), -n, mod);
}
- if (n == 0L) return 1L;
- long p = a, r = 1L;
+ // Normalize base into [0, mod)
+ a = ((a % mod) + mod) % mod;
- for (long i = 0; n != 0; i++) {
- long mask = 1L << i;
- if ((n & mask) == mask) {
- r = (((r * p) % mod) + mod) % mod;
- n -= mask;
- }
- p = ((p * p) % mod + mod) % mod;
+ long result = 1;
+ while (n > 0) {
+ if ((n & 1) == 1)
+ result = mulMod(result, a, mod);
+ a = mulMod(a, a, mod);
+ n >>= 1;
}
-
- return ((r % mod) + mod) % mod;
+ return result;
}
- // Example usage
- public static void main(String[] args) {
-
- BigInteger A, N, M, r1;
- long a, n, m, r2;
-
- A = BigInteger.valueOf(3);
- N = BigInteger.valueOf(4);
- M = BigInteger.valueOf(1000000);
- a = A.longValue();
- n = N.longValue();
- m = M.longValue();
-
- // 3 ^ 4 mod 1000000
- r1 = A.modPow(N, M); // 81
- r2 = modPow(a, n, m); // 81
- System.out.println(r1 + " " + r2);
-
- A = BigInteger.valueOf(-45);
- N = BigInteger.valueOf(12345);
- M = BigInteger.valueOf(987654321);
- a = A.longValue();
- n = N.longValue();
- m = M.longValue();
-
- // Finds -45 ^ 12345 mod 987654321
- r1 = A.modPow(N, M); // 323182557
- r2 = modPow(a, n, m); // 323182557
- System.out.println(r1 + " " + r2);
-
- A = BigInteger.valueOf(6);
- N = BigInteger.valueOf(-66);
- M = BigInteger.valueOf(101);
- a = A.longValue();
- n = N.longValue();
- m = M.longValue();
-
- // Finds 6 ^ -66 mod 101
- r1 = A.modPow(N, M); // 84
- r2 = modPow(a, n, m); // 84
- System.out.println(r1 + " " + r2);
-
- A = BigInteger.valueOf(-5);
- N = BigInteger.valueOf(-7);
- M = BigInteger.valueOf(1009);
- a = A.longValue();
- n = N.longValue();
- m = M.longValue();
-
- // Finds -5 ^ -7 mod 1009
- r1 = A.modPow(N, M); // 675
- r2 = modPow(a, n, m); // 675
- System.out.println(r1 + " " + r2);
-
- for (int i = 0; i < 1000; i++) {
- A = BigInteger.valueOf(a);
- N = BigInteger.valueOf(n);
- M = BigInteger.valueOf(m);
- a = Math.random() < 0.5 ? randLong(MAX) : -randLong(MAX);
- n = randLong();
- m = randLong(MAX);
- try {
- r1 = A.modPow(N, M);
- r2 = modPow(a, n, m);
- if (r1.longValue() != r2)
- System.out.printf("Broke with: a = %d, n = %d, m = %d\n", a, n, m);
- } catch (ArithmeticException e) {
- }
- }
+ private static long modInv(long a, long m) {
+ a = ((a % m) + m) % m;
+ long x = egcd(a, m)[1];
+ return ((x % m) + m) % m;
}
- /* TESTING RELATED METHODS */
-
- static final java.util.Random RANDOM = new java.util.Random();
+ private static long[] egcd(long a, long b) {
+ if (b == 0)
+ return new long[] {a < 0 ? -a : a, 1L, 0L};
+ long[] v = egcd(b, a % b);
+ long tmp = v[1] - v[2] * (a / b);
+ v[1] = v[2];
+ v[2] = tmp;
+ return v;
+ }
- // Returns long between [1, bound]
- public static long randLong(long bound) {
- return java.util.concurrent.ThreadLocalRandom.current().nextLong(1, bound + 1);
+ private static long gcd(long a, long b) {
+ a = Math.abs(a);
+ b = Math.abs(b);
+ return b == 0 ? a : gcd(b, a % b);
}
- public static long randLong() {
- return RANDOM.nextLong();
+ /** Overflow-safe modular multiplication: (a * b) % mod. */
+ private static long mulMod(long a, long b, long mod) {
+ return java.math.BigInteger.valueOf(a)
+ .multiply(java.math.BigInteger.valueOf(b))
+ .mod(java.math.BigInteger.valueOf(mod))
+ .longValue();
}
}
diff --git a/src/main/java/com/williamfiset/algorithms/math/NChooseRModPrime.java b/src/main/java/com/williamfiset/algorithms/math/NChooseRModPrime.java
deleted file mode 100644
index a3574edac..000000000
--- a/src/main/java/com/williamfiset/algorithms/math/NChooseRModPrime.java
+++ /dev/null
@@ -1,63 +0,0 @@
-/**
- * @author Rohit Mazumder, mazumder.rohit7@gmai.com
- */
-package com.williamfiset.algorithms.math;
-
-import java.math.BigInteger;
-
-public class NChooseRModPrime {
- /**
- * Calculate the value of C(N, R) % P using Fermat's Little Theorem.
- *
- * @param N
- * @param R
- * @param P
- * @return The value of N choose R Modulus P
- */
- public static long compute(int N, int R, int P) {
- if (R == 0) return 1;
-
- long[] factorial = new long[N + 1];
- factorial[0] = 1;
-
- for (int i = 1; i <= N; i++) {
- factorial[i] = factorial[i - 1] * i % P;
- }
-
- return (factorial[N]
- * ModularInverse.modInv(factorial[R], P)
- % P
- * ModularInverse.modInv(factorial[N - R], P)
- % P)
- % P;
- }
-
- // Method for testing output against the output generated by the compute(int,int,int) function
- private static String bigIntegerNChooseRModP(int N, int R, int P) {
- if (R == 0) return "1";
- BigInteger num = BigInteger.ONE;
- BigInteger den = BigInteger.ONE;
- while (R > 0) {
- num = num.multiply(BigInteger.valueOf(N));
- den = den.multiply(BigInteger.valueOf(R));
- BigInteger gcd = num.gcd(den);
- num = num.divide(gcd);
- den = den.divide(gcd);
- N--;
- R--;
- }
- num = num.divide(den);
- num = num.mod(BigInteger.valueOf(P));
- return num.toString();
- }
-
- public static void main(String args[]) {
- int N = 500;
- int R = 250;
- int P = 1000000007;
- int expected = Integer.parseInt(bigIntegerNChooseRModP(N, R, P));
- long actual = compute(N, R, P);
- System.out.println(expected); // 515561345
- System.out.println(actual); // 515561345
- }
-}
diff --git a/src/main/java/com/williamfiset/algorithms/math/RabinMillerPrimalityTest.py b/src/main/java/com/williamfiset/algorithms/math/RabinMillerPrimalityTest.py
deleted file mode 100644
index 3e0a82938..000000000
--- a/src/main/java/com/williamfiset/algorithms/math/RabinMillerPrimalityTest.py
+++ /dev/null
@@ -1,24 +0,0 @@
-
-import random
-# Rabin_Miller primality check. Tests whether or not a number
-# is prime with a failure rate of: (1/2)^certainty
-def isPrime(n, certainty = 12 ):
- if(n < 2): return False
- if(n != 2 and (n & 1) == 0): return False
- s = n-1
- while((s & 1) == 0): s >>= 1
- for _ in range(certainty):
- r = random.randrange(n-1) + 1
- tmp = s
- mod = pow(r,tmp,n)
- while(tmp != n-1 and mod != 1 and mod != n-1):
- mod = (mod*mod) % n
- tmp <<= 1
- if (mod != n-1 and (tmp & 1) == 0): return False
- return True
-
-print(isPrime(5))
-print(isPrime(1433))
-print(isPrime(567887653))
-print(isPrime(75611592179197710043))
-print(isPrime(205561530235962095930138512256047424384916810786171737181163))
diff --git a/src/test/java/com/williamfiset/algorithms/math/BUILD b/src/test/java/com/williamfiset/algorithms/math/BUILD
index 138ee1134..f51d8d3d8 100644
--- a/src/test/java/com/williamfiset/algorithms/math/BUILD
+++ b/src/test/java/com/williamfiset/algorithms/math/BUILD
@@ -47,3 +47,36 @@ java_test(
runtime_deps = JUNIT5_RUNTIME_DEPS,
deps = TEST_DEPS,
)
+
+# bazel test //src/test/java/com/williamfiset/algorithms/math:ChineseRemainderTheoremTest
+java_test(
+ name = "ChineseRemainderTheoremTest",
+ srcs = ["ChineseRemainderTheoremTest.java"],
+ main_class = "org.junit.platform.console.ConsoleLauncher",
+ use_testrunner = False,
+ args = ["--select-class=com.williamfiset.algorithms.math.ChineseRemainderTheoremTest"],
+ runtime_deps = JUNIT5_RUNTIME_DEPS,
+ deps = TEST_DEPS,
+)
+
+# bazel test //src/test/java/com/williamfiset/algorithms/math:ModPowTest
+java_test(
+ name = "ModPowTest",
+ srcs = ["ModPowTest.java"],
+ main_class = "org.junit.platform.console.ConsoleLauncher",
+ use_testrunner = False,
+ args = ["--select-class=com.williamfiset.algorithms.math.ModPowTest"],
+ runtime_deps = JUNIT5_RUNTIME_DEPS,
+ deps = TEST_DEPS,
+)
+
+# bazel test //src/test/java/com/williamfiset/algorithms/math:BinomialCoefficientModPrimeTest
+java_test(
+ name = "BinomialCoefficientModPrimeTest",
+ srcs = ["BinomialCoefficientModPrimeTest.java"],
+ main_class = "org.junit.platform.console.ConsoleLauncher",
+ use_testrunner = False,
+ args = ["--select-class=com.williamfiset.algorithms.math.BinomialCoefficientModPrimeTest"],
+ runtime_deps = JUNIT5_RUNTIME_DEPS,
+ deps = TEST_DEPS,
+)
diff --git a/src/test/java/com/williamfiset/algorithms/math/BinomialCoefficientModPrimeTest.java b/src/test/java/com/williamfiset/algorithms/math/BinomialCoefficientModPrimeTest.java
new file mode 100644
index 000000000..73893eced
--- /dev/null
+++ b/src/test/java/com/williamfiset/algorithms/math/BinomialCoefficientModPrimeTest.java
@@ -0,0 +1,88 @@
+package com.williamfiset.algorithms.math;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.math.BigInteger;
+import org.junit.jupiter.api.*;
+
+public class BinomialCoefficientModPrimeTest {
+
+ private static final int MOD = 1_000_000_007;
+
+ /** Computes C(n, r) mod p using BigInteger as a reference. */
+ private static long bigIntCnr(int n, int r, int p) {
+ BigInteger num = BigInteger.ONE;
+ BigInteger den = BigInteger.ONE;
+ for (int i = 0; i < r; i++) {
+ num = num.multiply(BigInteger.valueOf(n - i));
+ den = den.multiply(BigInteger.valueOf(i + 1));
+ }
+ return num.divide(den).mod(BigInteger.valueOf(p)).longValue();
+ }
+
+ @Test
+ public void rEqualsZero() {
+ assertThat(BinomialCoefficientModPrime.compute(10, 0, MOD)).isEqualTo(1);
+ }
+
+ @Test
+ public void rEqualsN() {
+ assertThat(BinomialCoefficientModPrime.compute(10, 10, MOD)).isEqualTo(1);
+ }
+
+ @Test
+ public void smallValues() {
+ assertThat(BinomialCoefficientModPrime.compute(5, 2, MOD)).isEqualTo(10);
+ assertThat(BinomialCoefficientModPrime.compute(6, 3, MOD)).isEqualTo(20);
+ assertThat(BinomialCoefficientModPrime.compute(10, 4, MOD)).isEqualTo(210);
+ }
+
+ @Test
+ public void knownLargeValue() {
+ // C(500, 250) mod 10^9+7 = 515561345
+ assertThat(BinomialCoefficientModPrime.compute(500, 250, MOD)).isEqualTo(515561345L);
+ }
+
+ @Test
+ public void symmetry() {
+ // C(n, r) == C(n, n-r)
+ assertThat(BinomialCoefficientModPrime.compute(100, 30, MOD))
+ .isEqualTo(BinomialCoefficientModPrime.compute(100, 70, MOD));
+ }
+
+ @Test
+ public void smallPrime() {
+ // C(6, 2) = 15, mod 7 = 1
+ assertThat(BinomialCoefficientModPrime.compute(6, 2, 7)).isEqualTo(1);
+ }
+
+ @Test
+ public void matchesBigInteger() {
+ int[] ns = {20, 50, 100, 200, 500};
+ for (int n : ns) {
+ for (int r = 0; r <= n; r += Math.max(1, n / 10)) {
+ long expected = bigIntCnr(n, r, MOD);
+ assertThat(BinomialCoefficientModPrime.compute(n, r, MOD)).isEqualTo(expected);
+ }
+ }
+ }
+
+ @Test
+ public void negativeNThrows() {
+ assertThrows(IllegalArgumentException.class,
+ () -> BinomialCoefficientModPrime.compute(-1, 0, MOD));
+ }
+
+ @Test
+ public void rGreaterThanNThrows() {
+ assertThrows(IllegalArgumentException.class,
+ () -> BinomialCoefficientModPrime.compute(5, 6, MOD));
+ }
+
+ @Test
+ public void invalidModulusThrows() {
+ assertThrows(IllegalArgumentException.class,
+ () -> BinomialCoefficientModPrime.compute(5, 2, 1));
+ }
+}
diff --git a/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java b/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java
new file mode 100644
index 000000000..ed312dcb3
--- /dev/null
+++ b/src/test/java/com/williamfiset/algorithms/math/ChineseRemainderTheoremTest.java
@@ -0,0 +1,167 @@
+package com.williamfiset.algorithms.math;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+import org.junit.jupiter.api.*;
+
+public class ChineseRemainderTheoremTest {
+
+ // --- eliminateCoefficient tests ---
+
+ @Test
+ public void eliminateCoefficient_simple() {
+ // 3x ≡ 6 (mod 9) → x ≡ 2 (mod 3)
+ long[] result = ChineseRemainderTheorem.eliminateCoefficient(3, 6, 9);
+ assertThat(result).isNotNull();
+ assertThat(result[0]).isEqualTo(2);
+ assertThat(result[1]).isEqualTo(3);
+ }
+
+ @Test
+ public void eliminateCoefficient_coefficientOne() {
+ // 1x ≡ 5 (mod 7) → x ≡ 5 (mod 7)
+ long[] result = ChineseRemainderTheorem.eliminateCoefficient(1, 5, 7);
+ assertThat(result).isNotNull();
+ assertThat(result[0]).isEqualTo(5);
+ assertThat(result[1]).isEqualTo(7);
+ }
+
+ @Test
+ public void eliminateCoefficient_unsolvable() {
+ // 2x ≡ 3 (mod 4) — no solution since gcd(2,4)=2 does not divide 3
+ assertNull(ChineseRemainderTheorem.eliminateCoefficient(2, 3, 4));
+ }
+
+ @Test
+ public void eliminateCoefficient_coprime() {
+ // 3x ≡ 1 (mod 7) → x ≡ 5 (mod 7) since 3*5=15≡1 (mod 7)
+ long[] result = ChineseRemainderTheorem.eliminateCoefficient(3, 1, 7);
+ assertThat(result).isNotNull();
+ assertThat(result[0]).isEqualTo(5);
+ assertThat(result[1]).isEqualTo(7);
+ }
+
+ // --- crt tests ---
+
+ @Test
+ public void crt_classicExample() {
+ // x ≡ 2 (mod 3), x ≡ 3 (mod 5), x ≡ 2 (mod 7) → x ≡ 23 (mod 105)
+ long[] result = ChineseRemainderTheorem.crt(new long[] {2, 3, 2}, new long[] {3, 5, 7});
+ assertThat(result[0]).isEqualTo(23);
+ assertThat(result[1]).isEqualTo(105);
+ }
+
+ @Test
+ public void crt_twoEquations() {
+ // x ≡ 1 (mod 2), x ≡ 2 (mod 3) → x ≡ 5 (mod 6)
+ long[] result = ChineseRemainderTheorem.crt(new long[] {1, 2}, new long[] {2, 3});
+ assertThat(result[0]).isEqualTo(5);
+ assertThat(result[1]).isEqualTo(6);
+ }
+
+ @Test
+ public void crt_singleEquation() {
+ long[] result = ChineseRemainderTheorem.crt(new long[] {3}, new long[] {7});
+ assertThat(result[0]).isEqualTo(3);
+ assertThat(result[1]).isEqualTo(7);
+ }
+
+ @Test
+ public void crt_resultSatisfiesAllCongruences() {
+ long[] a = {1, 2, 3};
+ long[] m = {5, 7, 11};
+ long[] result = ChineseRemainderTheorem.crt(a, m);
+ for (int i = 0; i < a.length; i++)
+ assertThat(result[0] % m[i]).isEqualTo(a[i]);
+ }
+
+ @Test
+ public void crt_zeroRemainders() {
+ // x ≡ 0 (mod 3), x ≡ 0 (mod 5) → x ≡ 0 (mod 15)
+ long[] result = ChineseRemainderTheorem.crt(new long[] {0, 0}, new long[] {3, 5});
+ assertThat(result[0]).isEqualTo(0);
+ assertThat(result[1]).isEqualTo(15);
+ }
+
+ // --- reduce tests ---
+
+ @Test
+ public void reduce_alreadyCoprime() {
+ long[][] result = ChineseRemainderTheorem.reduce(new long[] {1, 2}, new long[] {3, 5});
+ assertThat(result).isNotNull();
+ // Should preserve the equations since 3 and 5 are already coprime
+ assertThat(result[0]).asList().containsExactly(1L, 2L);
+ assertThat(result[1]).asList().containsExactly(3L, 5L);
+ }
+
+ @Test
+ public void reduce_sharedPrimeFactor() {
+ // x ≡ 1 (mod 6), x ≡ 3 (mod 10) — share factor 2
+ // 6 = 2·3, 10 = 2·5 → split to mod {2,3,2,5}
+ long[][] result = ChineseRemainderTheorem.reduce(new long[] {1, 3}, new long[] {6, 10});
+ assertThat(result).isNotNull();
+ // The reduced system should be solvable via CRT
+ long[] crtResult = ChineseRemainderTheorem.crt(result[0], result[1]);
+ // Verify solution satisfies original congruences
+ assertThat(crtResult[0] % 6).isEqualTo(1);
+ assertThat(crtResult[0] % 10).isEqualTo(3);
+ }
+
+ @Test
+ public void reduce_inconsistent() {
+ // x ≡ 1 (mod 4), x ≡ 2 (mod 8) — inconsistent since 2 mod 4 = 2 ≠ 1
+ assertNull(ChineseRemainderTheorem.reduce(new long[] {1, 2}, new long[] {4, 8}));
+ }
+
+ @Test
+ public void reduce_redundantEquation() {
+ // x ≡ 1 (mod 2), x ≡ 1 (mod 4) — second subsumes first
+ long[][] result = ChineseRemainderTheorem.reduce(new long[] {1, 1}, new long[] {2, 4});
+ assertThat(result).isNotNull();
+ long[] crtResult = ChineseRemainderTheorem.crt(result[0], result[1]);
+ assertThat(crtResult[0] % 4).isEqualTo(1);
+ }
+
+ // --- egcd tests ---
+
+ @Test
+ public void egcd_basicProperties() {
+ long[] result = ChineseRemainderTheorem.egcd(35, 15);
+ assertThat(result[0]).isEqualTo(5); // gcd(35,15) = 5
+ // Verify Bezout's identity: 35*x + 15*y = 5
+ assertThat(35 * result[1] + 15 * result[2]).isEqualTo(5);
+ }
+
+ @Test
+ public void egcd_coprime() {
+ long[] result = ChineseRemainderTheorem.egcd(7, 11);
+ assertThat(result[0]).isEqualTo(1);
+ assertThat(7 * result[1] + 11 * result[2]).isEqualTo(1);
+ }
+
+ // --- Integration: reduce + crt ---
+
+ @Test
+ public void reduceAndCrt_fullPipeline() {
+ // Solve x ≡ 2 (mod 12), x ≡ 8 (mod 10)
+ // 12 = 4·3, 10 = 2·5 — share factor 2, consistent since 2 ≡ 0 (mod 2) and 8 ≡ 0 (mod 2)
+ long[][] reduced = ChineseRemainderTheorem.reduce(new long[] {2, 8}, new long[] {12, 10});
+ assertThat(reduced).isNotNull();
+ long[] result = ChineseRemainderTheorem.crt(reduced[0], reduced[1]);
+ assertThat(result[0] % 12).isEqualTo(2);
+ assertThat(result[0] % 10).isEqualTo(8);
+ }
+
+ @Test
+ public void reduceAndCrt_threeCoprime() {
+ // Already coprime — reduce should pass through, CRT solves directly
+ long[] a = {2, 3, 2};
+ long[] m = {3, 5, 7};
+ long[][] reduced = ChineseRemainderTheorem.reduce(a, m);
+ assertThat(reduced).isNotNull();
+ long[] result = ChineseRemainderTheorem.crt(reduced[0], reduced[1]);
+ assertThat(result[0]).isEqualTo(23);
+ assertThat(result[1]).isEqualTo(105);
+ }
+}
diff --git a/src/test/java/com/williamfiset/algorithms/math/ModPowTest.java b/src/test/java/com/williamfiset/algorithms/math/ModPowTest.java
new file mode 100644
index 000000000..d69e9ad3d
--- /dev/null
+++ b/src/test/java/com/williamfiset/algorithms/math/ModPowTest.java
@@ -0,0 +1,100 @@
+package com.williamfiset.algorithms.math;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.math.BigInteger;
+import java.util.concurrent.ThreadLocalRandom;
+import org.junit.jupiter.api.*;
+
+public class ModPowTest {
+
+ @Test
+ public void basicPositiveExponent() {
+ // 3^4 mod 1000000 = 81
+ assertThat(ModPow.modPow(3, 4, 1000000)).isEqualTo(81);
+ }
+
+ @Test
+ public void negativeBase() {
+ // (-45)^12345 mod 987654321
+ long expected =
+ BigInteger.valueOf(-45)
+ .modPow(BigInteger.valueOf(12345), BigInteger.valueOf(987654321))
+ .longValue();
+ assertThat(ModPow.modPow(-45, 12345, 987654321)).isEqualTo(expected);
+ }
+
+ @Test
+ public void negativeExponent() {
+ // 6^-66 mod 101 = 84
+ long expected =
+ BigInteger.valueOf(6)
+ .modPow(BigInteger.valueOf(-66), BigInteger.valueOf(101))
+ .longValue();
+ assertThat(ModPow.modPow(6, -66, 101)).isEqualTo(expected);
+ }
+
+ @Test
+ public void negativeBaseAndExponent() {
+ // (-5)^-7 mod 1009
+ long expected =
+ BigInteger.valueOf(-5)
+ .modPow(BigInteger.valueOf(-7), BigInteger.valueOf(1009))
+ .longValue();
+ assertThat(ModPow.modPow(-5, -7, 1009)).isEqualTo(expected);
+ }
+
+ @Test
+ public void exponentZero() {
+ assertThat(ModPow.modPow(123, 0, 7)).isEqualTo(1);
+ assertThat(ModPow.modPow(0, 0, 5)).isEqualTo(1);
+ }
+
+ @Test
+ public void baseZero() {
+ assertThat(ModPow.modPow(0, 10, 7)).isEqualTo(0);
+ }
+
+ @Test
+ public void modOne() {
+ // Anything mod 1 = 0
+ assertThat(ModPow.modPow(999, 999, 1)).isEqualTo(0);
+ }
+
+ @Test
+ public void largeValues() {
+ // Test with values that would overflow without safe multiplication
+ long a = 1_000_000_000L;
+ long n = 1_000_000_000L;
+ long mod = 999_999_937L;
+ long expected =
+ BigInteger.valueOf(a).modPow(BigInteger.valueOf(n), BigInteger.valueOf(mod)).longValue();
+ assertThat(ModPow.modPow(a, n, mod)).isEqualTo(expected);
+ }
+
+ @Test
+ public void modNonPositiveThrows() {
+ assertThrows(ArithmeticException.class, () -> ModPow.modPow(2, 3, 0));
+ assertThrows(ArithmeticException.class, () -> ModPow.modPow(2, 3, -5));
+ }
+
+ @Test
+ public void negativeExponentNotCoprime() {
+ // gcd(4, 8) = 4 ≠ 1, so no modular inverse
+ assertThrows(ArithmeticException.class, () -> ModPow.modPow(4, -1, 8));
+ }
+
+ @Test
+ public void matchesBigIntegerRandomized() {
+ ThreadLocalRandom rng = ThreadLocalRandom.current();
+ for (int i = 0; i < 500; i++) {
+ long a = rng.nextLong(-1_000_000_000L, 1_000_000_000L);
+ long n = rng.nextLong(0, 1_000_000_000L);
+ long mod = rng.nextLong(1, 1_000_000_000L);
+ long expected =
+ BigInteger.valueOf(a).modPow(BigInteger.valueOf(n), BigInteger.valueOf(mod)).longValue();
+ assertThat(ModPow.modPow(a, n, mod)).isEqualTo(expected);
+ }
+ }
+}