diff --git a/include/mcl/fp.hpp b/include/mcl/fp.hpp index d49b6be..f53c7a6 100644 --- a/include/mcl/fp.hpp +++ b/include/mcl/fp.hpp @@ -509,6 +509,16 @@ public: #endif static inline void addPre(FpT& z, const FpT& x, const FpT& y) { op_.fp_addPre(z.v_, x.v_, y.v_); } static inline void subPre(FpT& z, const FpT& x, const FpT& y) { op_.fp_subPre(z.v_, x.v_, y.v_); } + static inline void mulSmall(FpT& z, const FpT& x, const uint32_t y) + { + assert(y <= op_.smallModp.maxMulN); + Unit xy[maxSize + 1]; + op_.fp_mulUnitPre(xy, x.v_, y); + int v = op_.smallModp.approxMul(xy); + const Unit *pv = op_.smallModp.getPmul(v); + op_.fp_subPre(z.v_, xy, pv); + op_.fp_sub(z.v_, z.v_, op_.p, op_.p); + } static inline void mulUnit(FpT& z, const FpT& x, const Unit y) { if (mulSmallUnit(z, x, y)) return; diff --git a/include/mcl/gmp_util.hpp b/include/mcl/gmp_util.hpp index ed0880b..9309bcd 100644 --- a/include/mcl/gmp_util.hpp +++ b/include/mcl/gmp_util.hpp @@ -942,6 +942,85 @@ public: #endif }; +/* + x mod p for a small value x < (pMulTblN * p). +*/ +struct SmallModp { + typedef mcl::fp::Unit Unit; + static const size_t unitBitSize = sizeof(Unit) * 8; + static const size_t maxTblSize = (MCL_MAX_BIT_SIZE + unitBitSize - 1) / unitBitSize + 1; + static const size_t maxMulN = 9; + static const size_t pMulTblN = maxMulN + 1; + int N_; + int shiftL_; + int shiftR_; + int maxIdx_; + // pMulTbl_[i] = (p * i) >> (pBitSize_ - 1) + Unit pMulTbl_[pMulTblN][maxTblSize]; + // idxTbl_[x] = (x << (pBitSize_ - 1)) / p + int8_t idxTbl_[pMulTblN * 2]; + // return x >> (pBitSize_ - 1) + SmallModp() + : N_(0) + , shiftL_(0) + , shiftR_(0) + , maxIdx_(0) + , pMulTbl_() + , idxTbl_() + { + } + // return argmax { i : x > i * p } + int approxMul(const Unit *x) const + { + int top = getTop(x); + assert(top <= maxIdx_); + return idxTbl_[top]; + } + const Unit *getPmul(size_t v) const + { + assert(v < pMulTblN); + return pMulTbl_[v]; + } + int getTop(const Unit *x) const + { + return (x[N_ - 1] >> shiftR_) | (x[N_] << shiftL_); + } + int cvtInt(const mpz_class& x) const + { + assert(mcl::gmp::getUnitSize(x) <= 1); + if (x == 0) { + return 0; + } else { + return int(mcl::gmp::getUnit(x)[0]); + } + } + void init(const mpz_class& p) + { + size_t pBitSize = mcl::gmp::getBitSize(p); + N_ = (pBitSize + unitBitSize - 1) / unitBitSize; + shiftR_ = (pBitSize - 1) % unitBitSize; + shiftL_ = unitBitSize - shiftR_; + mpz_class t = 0; + for (size_t i = 0; i < pMulTblN; i++) { + bool b; + mcl::gmp::getArray(&b, pMulTbl_[i], maxTblSize, t); + assert(b); + (void)b; + if (i == pMulTblN - 1) { + maxIdx_ = getTop(pMulTbl_[i]); + assert(maxIdx_ < CYBOZU_NUM_OF_ARRAY(idxTbl_)); + break; + } + t += p; + } + + for (int i = 0; i <= maxIdx_; i++) { + idxTbl_[i] = cvtInt((mpz_class(i) << (pBitSize - 1)) / p); + } + } +}; + + /* Barrett Reduction for non GMP version diff --git a/include/mcl/op.hpp b/include/mcl/op.hpp index 13fd0c5..2dd358c 100644 --- a/include/mcl/op.hpp +++ b/include/mcl/op.hpp @@ -191,6 +191,7 @@ struct Op { uint32_t pmod4; mcl::SquareRoot sq; mcl::Modp modp; + mcl::SmallModp smallModp; Unit half[maxUnitSize]; // (p + 1) / 2 Unit oneRep[maxUnitSize]; // 1(=inv R if Montgomery) /* diff --git a/test/bench.hpp b/test/bench.hpp index 9a28db7..8378674 100644 --- a/test/bench.hpp +++ b/test/bench.hpp @@ -116,6 +116,10 @@ void testBench(const G1& P, const G2& Q) CYBOZU_BENCH_C("Fp::sub ", C3, Fp::sub, x, x, y); CYBOZU_BENCH_C("Fp::add 2 ", C3, Fp::add, x, x, x); CYBOZU_BENCH_C("Fp::mul2 ", C3, Fp::mul2, x, x); + CYBOZU_BENCH_C("Fp::mulSmall8 ", C3, Fp::mulSmall, x, x, 8); + CYBOZU_BENCH_C("Fp::mulUnit8 ", C3, Fp::mulUnit, x, x, 8); + CYBOZU_BENCH_C("Fp::mulSmall9 ", C3, Fp::mulSmall, x, x, 9); + CYBOZU_BENCH_C("Fp::mulUnit9 ", C3, Fp::mulUnit, x, x, 9); CYBOZU_BENCH_C("Fp::neg ", C3, Fp::neg, x, x); CYBOZU_BENCH_C("Fp::mul ", C3, Fp::mul, x, x, y); CYBOZU_BENCH_C("Fp::sqr ", C3, Fp::sqr, x, x); diff --git a/test/common_test.hpp b/test/common_test.hpp index 338e7d3..74a745c 100644 --- a/test/common_test.hpp +++ b/test/common_test.hpp @@ -183,8 +183,24 @@ void testFp2Dbl_mul_xi1() } } +void testMulSmall() +{ + puts("testMulSmall"); + cybozu::XorShift rg; + for (int y = 0; y < 10; y++) { + for (int i = 0; i < 40; i++) { + Fp x, z1, z2; + x.setByCSPRNG(rg); + Fp::mulSmall(z1, x, y); + z2 = x * y; + CYBOZU_TEST_EQUAL(z1, z2); + } + } +} + void testCommon(const G1& P, const G2& Q) { + testMulSmall(); testFp2Dbl_mul_xi1(); testABCD(); testMul2();