diff --git a/include/mcl/fp.hpp b/include/mcl/fp.hpp index ffb4699..2e69729 100644 --- a/include/mcl/fp.hpp +++ b/include/mcl/fp.hpp @@ -302,10 +302,13 @@ public: } cybozu::write(pb, os, buf + sizeof(buf) - len, len); } + /* + mode = Mod : set x mod p if sizeof(S) * n <= 64 else error + */ template - void setArray(bool *pb, const S *x, size_t n) + void setArray(bool *pb, const S *x, size_t n, mcl::fp::MaskMode mode = fp::NoMask) { - *pb = fp::copyAndMask(v_, x, sizeof(S) * n, op_, fp::NoMask); + *pb = fp::copyAndMask(v_, x, sizeof(S) * n, op_, mode); toMont(); } /* diff --git a/include/mcl/op.hpp b/include/mcl/op.hpp index 347b02f..155d168 100644 --- a/include/mcl/op.hpp +++ b/include/mcl/op.hpp @@ -161,7 +161,8 @@ enum PrimeMode { enum MaskMode { NoMask = 0, // throw if greater or equal SmallMask = 1, // 1-bit smaller mask if greater or equal - MaskAndMod = 2 // mask and substract if greater or equal + MaskAndMod = 2, // mask and substract if greater or equal + Mod = 3 // mod p }; struct Op { @@ -174,6 +175,7 @@ struct Op { mpz_class mp; uint32_t pmod4; mcl::SquareRoot sq; + mcl::Modp modp; Unit half[maxUnitSize]; // (p + 1) / 2 Unit oneRep[maxUnitSize]; // 1(=inv R if Montgomery) /* diff --git a/src/fp.cpp b/src/fp.cpp index ebd9477..df72d6d 100644 --- a/src/fp.cpp +++ b/src/fp.cpp @@ -476,6 +476,7 @@ bool Op::init(const mpz_class& _p, size_t maxBitSize, int _xi_a, Mode mode, size sq.set(&b, mp); if (!b) return false; } + modp.init(mp); return fp::initForMont(*this, p, mode); } @@ -528,6 +529,27 @@ int detectIoMode(int ioMode, const std::ios_base& ios) bool copyAndMask(Unit *y, const void *x, size_t xByteSize, const Op& op, MaskMode maskMode) { const size_t fpByteSize = sizeof(Unit) * op.N; + if (maskMode == Mod) { + if (xByteSize > fpByteSize * 2) return false; + mpz_class mx; + bool b; + gmp::setArray(&b, mx, (const char*)x, xByteSize); + if (!b) return false; +#ifdef MCL_USE_VINT + op.modp.modp(mx, mx); +#else + mx %= op.mp; +#endif + const Unit *pmx = gmp::getUnit(mx); + size_t i = 0; + for (const size_t n = gmp::getUnitSize(mx); i < n; i++) { + y[i] = pmx[i]; + } + for (; i < op.N; i++) { + y[i] = 0; + } + return true; + } if (xByteSize > fpByteSize) { if (maskMode == NoMask) return false; xByteSize = fpByteSize; diff --git a/test/fp_test.cpp b/test/fp_test.cpp index 26f97b1..7e78773 100644 --- a/test/fp_test.cpp +++ b/test/fp_test.cpp @@ -565,6 +565,43 @@ void setArrayMaskTest2(mcl::fp::Mode mode) } } +void setArrayModTest() +{ + const mpz_class& p = Fp::getOp().mp; + const mpz_class tbl[] = { + 0, + 1, + p - 1, + p, + p + 1, + p * 2 - 1, + p * 2, + p * 2 + 1, + p * (p - 1) - 1, + p * (p - 1), + p * (p - 1) + 1, + p * p - 1, + p * p, + p * p + 1, + (mpz_class(1) << Fp::getOp().N * mcl::fp::UnitBitSize * 2) - 1, + }; + const size_t unitByteSize = sizeof(mcl::fp::Unit); + for (size_t i = 0; i < CYBOZU_NUM_OF_ARRAY(tbl); i++) { + const mpz_class& x = tbl[i]; + const mcl::fp::Unit *px = mcl::gmp::getUnit(x); + const size_t xn = mcl::gmp::getUnitSize(x); + const size_t xByteSize = xn * unitByteSize; + const size_t fpByteSize = unitByteSize * Fp::getOp().N; + Fp y; + bool b; + y.setArray(&b, px, xn, mcl::fp::Mod); + bool expected = xByteSize <= fpByteSize * 2; + CYBOZU_TEST_EQUAL(b, expected); + if (!b) continue; + CYBOZU_TEST_EQUAL(y.getMpz(), x % p); + } +} + CYBOZU_TEST_AUTO(set64bit) { Fp::init("0x1000000000000000000f"); @@ -911,6 +948,7 @@ void sub(mcl::fp::Mode mode) powGmp(); setArrayTest1(); setArrayMaskTest1(); + setArrayModTest(); getUint64Test(); getInt64Test(); divBy2Test();