diff --git a/include/mcl/fp.hpp b/include/mcl/fp.hpp index 54807e0..f10fe14 100644 --- a/include/mcl/fp.hpp +++ b/include/mcl/fp.hpp @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include namespace mcl { @@ -32,6 +32,10 @@ namespace fp { struct TagDefault; +void arrayToStr(std::string& str, const Unit *x, size_t n, int base, bool withPrefix); + +void strToGmp(mpz_class& x, bool *isMinus, const std::string& str, int base); + } // mcl::fp template @@ -58,7 +62,7 @@ public: { assert(maxBitN <= MCL_MAX_OP_BIT_N); bool isMinus; - inFromStr(op_.mp, &isMinus, mstr, base); + fp::strToGmp(op_.mp, &isMinus, mstr, base); if (isMinus) throw cybozu::Exception("mcl:FpT:setModulo:mstr is not minus") << mstr; const size_t bitLen = Gmp::getBitLen(op_.mp); if (bitLen > maxBitN) throw cybozu::Exception("mcl:FpT:setModulo:too large bitLen") << bitLen << maxBitN; @@ -140,7 +144,7 @@ public: { bool isMinus; mpz_class x; - inFromStr(x, &isMinus, str, base); + fp::strToGmp(x, &isMinus, str, base); if (x >= op_.mp) throw cybozu::Exception("fp:FpT:fromStr:large str") << str << op_.mp; fp::toArray(v_, op_.N, x.get_mpz_t()); if (isMinus) { @@ -192,31 +196,11 @@ public: fp::getRandVal(v_, rg, op_.p, op_.bitLen); fromMont(*this, *this); } - static inline void toStr(std::string& str, const Unit *x, size_t n, int base = 10, bool withPrefix = false) - { - switch (base) { - case 10: - { - mpz_class t; - Gmp::setRaw(t, x, n); - Gmp::toStr(str, t, 10); - } - return; - case 16: - mcl::fp::toStr16(str, x, n, withPrefix); - return; - case 2: - mcl::fp::toStr2(str, x, n, withPrefix); - return; - default: - throw cybozu::Exception("fp:FpT:toStr:bad base") << base; - } - } void toStr(std::string& str, int base = 10, bool withPrefix = false) const { fp::Block b; getBlock(b); - toStr(str, b.p, b.n, base, withPrefix); + fp::arrayToStr(str, b.p, b.n, base, withPrefix); } std::string toStr(int base = 10, bool withPrefix = false) const { @@ -369,13 +353,6 @@ public: op_.negP(y, x, op_.p); } private: - static inline void inFromStr(mpz_class& x, bool *isMinus, const std::string& str, int base) - { - const char *p = fp::verifyStr(isMinus, &base, str); - if (!Gmp::fromStr(x, p, base)) { - throw cybozu::Exception("fp:FpT:inFromStr") << str; - } - } }; template fp::Op FpT::op_; diff --git a/include/mcl/fp_util.hpp b/include/mcl/fp_util.hpp index 22748d0..7712e9c 100644 --- a/include/mcl/fp_util.hpp +++ b/include/mcl/fp_util.hpp @@ -2,6 +2,7 @@ #include #include #include +#include /** @file @brief utility of Fp @@ -12,32 +13,6 @@ namespace mcl { namespace fp { -#if defined(CYBOZU_OS_BIT) && (CYBOZU_OS_BIT == 32) - typedef uint32_t BlockType; -#else - typedef uint64_t BlockType; -#endif - -template -void setBlockBit(S *buf, size_t bitLen, bool b) -{ - const size_t unitSize = sizeof(S) * 8; - const size_t q = bitLen / unitSize; - const size_t r = bitLen % unitSize; - if (b) { - buf[q] |= S(1) << r; - } else { - buf[q] &= ~(S(1) << r); - } -} -template -bool getBlockBit(const S *buf, size_t bitLen) -{ - const size_t unitSize = sizeof(S) * 8; - const size_t q = bitLen / unitSize; - const size_t r = bitLen % unitSize; - return (buf[q] & (S(1) << r)) != 0; -} /* convert x[0..n) to hex string start "0x" if withPrefix @@ -126,80 +101,5 @@ void fromStr16(T *x, size_t xn, const char *str, size_t strLen) for (size_t i = requireSize; i < xn; i++) x[i] = 0; } -/* - @param base [inout] -*/ -inline const char *verifyStr(bool *isMinus, int *base, const std::string& str) -{ - const char *p = str.c_str(); - if (*p == '-') { - *isMinus = true; - p++; - } else { - *isMinus = false; - } - if (p[0] == '0') { - if (p[1] == 'x') { - if (*base != 0 && *base != 16) { - throw cybozu::Exception("fp:verifyStr:bad base") << *base << str; - } - *base = 16; - p += 2; - } else if (p[1] == 'b') { - if (*base != 0 && *base != 2) { - throw cybozu::Exception("fp:verifyStr:bad base") << *base << str; - } - *base = 2; - p += 2; - } - } - if (*base == 0) *base = 10; - if (*p == '\0') throw cybozu::Exception("fp:verifyStr:str is empty"); - return p; -} - -template -size_t getRoundNum(size_t x) -{ - const size_t size = sizeof(S) * 8; - return (x + size - 1) / size; -} - -/* - compare x[0, n) with y[0, n) -*/ -template -int compareArray(const S* x, const S* y, size_t n) -{ - for (size_t i = 0; i < n; i++) { - const S a = x[n - 1 - i]; - const S b = y[n - 1 - i]; - if (a > b) return 1; - if (a < b) return -1; - } - return 0; -} - -/* - get random value less than in[] - n = (bitLen + sizeof(S) * 8) / (sizeof(S) * 8) - input in[0..n) - output out[n..n) - 0 <= out < in -*/ -template -inline void getRandVal(S *out, RG& rg, const S *in, size_t bitLen) -{ - const size_t unitBitSize = sizeof(S) * 8; - const size_t n = getRoundNum(bitLen); - const size_t rem = bitLen & (unitBitSize - 1); - for (;;) { - rg.read(out, n); - if (rem > 0) out[n - 1] &= (S(1) << rem) - 1; - if (compareArray(out, in, n) < 0) return; - } -} - -} // mcl::fp -} // fp +} } // mcl::fp diff --git a/include/mcl/gmp_util.hpp b/include/mcl/gmp_util.hpp index a038c75..c08b789 100644 --- a/include/mcl/gmp_util.hpp +++ b/include/mcl/gmp_util.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #ifdef _MSC_VER #pragma warning(push) #pragma warning(disable : 4616) @@ -51,7 +52,6 @@ #endif #endif #endif -#include namespace mcl { diff --git a/include/mcl/op.hpp b/include/mcl/op.hpp index 5d4ab12..e1b92a3 100644 --- a/include/mcl/op.hpp +++ b/include/mcl/op.hpp @@ -7,21 +7,24 @@ http://opensource.org/licenses/BSD-3-Clause */ #include -#include -#include +#include #ifndef MCL_MAX_OP_BIT_N #define MCL_MAX_OP_BIT_N 521 #endif -namespace mcl { +namespace mcl { namespace fp { -namespace fp { - -struct FpGenerator; +#if defined(CYBOZU_OS_BIT) && (CYBOZU_OS_BIT == 32) +typedef uint32_t Unit; +#else +typedef uint64_t Unit; +#endif +const size_t UnitBitN = sizeof(Unit) * 8; const size_t maxOpUnitN = (MCL_MAX_OP_BIT_N + UnitBitN - 1) / UnitBitN; +struct FpGenerator; struct Op; typedef void (*void1u)(Unit*); diff --git a/include/mcl/unit.hpp b/include/mcl/unit.hpp index 19c4b1b..4acb6d1 100644 --- a/include/mcl/unit.hpp +++ b/include/mcl/unit.hpp @@ -6,29 +6,22 @@ @license modified new BSD license http://opensource.org/licenses/BSD-3-Clause */ -#include #include namespace mcl { namespace fp { -#if defined(CYBOZU_OS_BIT) && (CYBOZU_OS_BIT == 32) -typedef uint32_t Unit; -#else -typedef uint64_t Unit; -#endif -const size_t UnitBitN = sizeof(Unit) * 8; - /* get pp such that p * pp = -1 mod M, where p is prime and M = 1 << 64(or 32). @param pLow [in] p mod M */ -inline Unit getMontgomeryCoeff(Unit pLow) +template +T getMontgomeryCoeff(T pLow) { - Unit ret = 0; - Unit t = 0; - Unit x = 1; - for (size_t i = 0; i < UnitBitN; i++) { + T ret = 0; + T t = 0; + T x = 1; + for (size_t i = 0; i < sizeof(T) * 8; i++) { if ((t & 1) == 0) { t += pLow; ret += x; @@ -39,7 +32,8 @@ inline Unit getMontgomeryCoeff(Unit pLow) return ret; } -inline int compareArray(const Unit* x, const Unit* y, size_t n) +template +int compareArray(const T* x, const T* y, size_t n) { for (size_t i = n - 1; i != size_t(-1); i--) { if (x[i] < y[i]) return -1; @@ -48,7 +42,8 @@ inline int compareArray(const Unit* x, const Unit* y, size_t n) return 0; } -inline bool isEqualArray(const Unit* x, const Unit* y, size_t n) +template +bool isEqualArray(const T* x, const T* y, size_t n) { for (size_t i = 0; i < n; i++) { if (x[i] != y[i]) return false; @@ -56,7 +51,8 @@ inline bool isEqualArray(const Unit* x, const Unit* y, size_t n) return true; } -inline bool isZeroArray(const Unit *x, size_t n) +template +bool isZeroArray(const T *x, size_t n) { for (size_t i = 0; i < n; i++) { if (x[i]) return false; @@ -64,25 +60,48 @@ inline bool isZeroArray(const Unit *x, size_t n) return true; } -inline void clearArray(Unit *x, size_t begin, size_t end) +template +void clearArray(T *x, size_t begin, size_t end) { for (size_t i = begin; i < end; i++) x[i] = 0; } -inline void copyArray(Unit *y, const Unit *x, size_t n) +template +void copyArray(T *y, const T *x, size_t n) { for (size_t i = 0; i < n; i++) y[i] = x[i]; } -inline void toArray(Unit *y, size_t yn, const mpz_srcptr x) +template +void toArray(T *y, size_t yn, const mpz_srcptr x) { const int xn = x->_mp_size; assert(xn >= 0); - const Unit* xp = (const Unit*)x->_mp_d; + const T* xp = (const T*)x->_mp_d; assert(xn <= (int)yn); copyArray(y, xp, xn); clearArray(y, xn, yn); } +/* + get random value less than in[] + n = (bitLen + sizeof(T) * 8) / (sizeof(T) * 8) + input in[0..n) + output out[n..n) + 0 <= out < in +*/ +template +void getRandVal(T *out, RG& rg, const T *in, size_t bitLen) +{ + const size_t TBitN = sizeof(T) * 8; + const size_t n = (bitLen + TBitN - 1) / TBitN; + const size_t rem = bitLen & (TBitN - 1); + for (;;) { + rg.read(out, n); + if (rem > 0) out[n - 1] &= (T(1) << rem) - 1; + if (compareArray(out, in, n) < 0) return; + } +} + } } // mcl::fp diff --git a/src/fp.cpp b/src/fp.cpp index dd86cba..77af38f 100644 --- a/src/fp.cpp +++ b/src/fp.cpp @@ -1,4 +1,6 @@ #include +#include +#include #ifdef USE_MONT_FP #include #endif @@ -249,5 +251,63 @@ void Op::init(const Unit* p, size_t bitLen) #endif } +void arrayToStr(std::string& str, const Unit *x, size_t n, int base, bool withPrefix) +{ + switch (base) { + case 10: + { + mpz_class t; + Gmp::setRaw(t, x, n); + Gmp::toStr(str, t, 10); + } + return; + case 16: + mcl::fp::toStr16(str, x, n, withPrefix); + return; + case 2: + mcl::fp::toStr2(str, x, n, withPrefix); + return; + default: + throw cybozu::Exception("fp:arrayToStr:bad base") << base; + } +} + +inline const char *verifyStr(bool *isMinus, int *base, const std::string& str) +{ + const char *p = str.c_str(); + if (*p == '-') { + *isMinus = true; + p++; + } else { + *isMinus = false; + } + if (p[0] == '0') { + if (p[1] == 'x') { + if (*base != 0 && *base != 16) { + throw cybozu::Exception("fp:verifyStr:bad base") << *base << str; + } + *base = 16; + p += 2; + } else if (p[1] == 'b') { + if (*base != 0 && *base != 2) { + throw cybozu::Exception("fp:verifyStr:bad base") << *base << str; + } + *base = 2; + p += 2; + } + } + if (*base == 0) *base = 10; + if (*p == '\0') throw cybozu::Exception("fp:verifyStr:str is empty"); + return p; +} + +void strToGmp(mpz_class& x, bool *isMinus, const std::string& str, int base) +{ + const char *p = fp::verifyStr(isMinus, &base, str); + if (!Gmp::fromStr(x, p, base)) { + throw cybozu::Exception("fp:FpT:inFromStr") << str; + } +} + } } // mcl::fp diff --git a/test/fp_generator_test.cpp b/test/fp_generator_test.cpp index 28ad51c..9848047 100644 --- a/test/fp_generator_test.cpp +++ b/test/fp_generator_test.cpp @@ -6,8 +6,8 @@ #include #include #include -#include #include +#include #include #include #include