diff --git a/include/mcl/gmp_util.hpp b/include/mcl/gmp_util.hpp index a64bed4..81e69e2 100644 --- a/include/mcl/gmp_util.hpp +++ b/include/mcl/gmp_util.hpp @@ -859,4 +859,55 @@ public: #endif }; +/* + Barrett Reduction +*/ +struct Modp { + static const size_t unitBitSize = sizeof(mcl::fp::Unit) * 8; + mpz_class p_; + mpz_class u_; + mpz_class a_; + size_t N_; + // x &= 1 << (unitBitSize * unitSize) + void shrinkSize(mpz_class &x, size_t unitSize) const + { + size_t u = mcl::gmp::getUnitSize(x); + if (u < unitSize) return; + bool b; + mcl::gmp::setArray(&b, x, mcl::gmp::getUnit(x), unitSize); + assert(b); + } + void init(const mpz_class& p, size_t unitSize) + { + p_ = p; + N_ = unitSize; + u_ = (mpz_class(1) << (unitBitSize * 2 * N_)) / p_; + a_ = mpz_class(1) << (unitBitSize * (N_ + 1)); + } + void modp(mpz_class& r, const mpz_class& t) const + { + assert(0 <= t && t < mpz_class(1) << (unitBitSize * 2 * N_)); + if (t < p_) { + r = t; + return; + } + mpz_class q; + q = t; + q >>= unitBitSize * (N_ - 1); + q *= u_; + q >>= unitBitSize * (N_ + 1); + q *= p_; + shrinkSize(q, N_ + 1); + r = t; + shrinkSize(r, N_ + 1); + r -= q; + if (r < 0) { + r += a_; + } + if (r >= p_) { + r -= p_; + } + } +}; + } // mcl diff --git a/test/fp_test.cpp b/test/fp_test.cpp index d8b4742..b7b1e23 100644 --- a/test/fp_test.cpp +++ b/test/fp_test.cpp @@ -777,6 +777,37 @@ void serializeTest() } } +void modpTest() +{ + const mpz_class& p = Fp::getOp().mp; + const mpz_class tbl[] = { + 0, + 1, + p - 1, + p, + p + 1, + p * 2 - 1, + p * 2, + p * 2 + 1, + p * (p - 1) - 1, + p * (p - 1), + p * (p - 1) + 1, + p * p - 1, + p * p, + p * p + 1, + (mpz_class(1) << Fp::getOp().N * mcl::fp::UnitBitSize * 2) - 1, + }; + mcl::Modp modp; + modp.init(p, Fp::getUnitSize()); + for (size_t i = 0; i < CYBOZU_NUM_OF_ARRAY(tbl); i++) { + const mpz_class& x = tbl[i]; + mpz_class r1, r2; + r1 = x % p; + modp.modp(r2, x); + CYBOZU_TEST_EQUAL(r1, r2); + } +} + #include #if (defined(MCL_USE_LLVM) || defined(MCL_USE_XBYAK)) && (MCL_MAX_BIT_SIZE >= 521) CYBOZU_TEST_AUTO(mod_NIST_P521) @@ -886,6 +917,7 @@ void sub(mcl::fp::Mode mode) getStrTest(); setHashOfTest(); serializeTest(); + modpTest(); } anotherFpTest(mode); setArrayTest2(mode);