diff --git a/Makefile b/Makefile index c4c5eba..7cb811a 100644 --- a/Makefile +++ b/Makefile @@ -261,7 +261,7 @@ endif emcc -o $@ src/fp.cpp src/she_c384.cpp $(EMCC_OPT) -DMCL_MAX_BIT_SIZE=384 -s TOTAL_MEMORY=67108864 -s DISABLE_EXCEPTION_CATCHING=0 ../mcl-wasm/mcl_c.js: src/bn_c256.cpp $(MCL_C_DEP) - emcc -o $@ src/fp.cpp src/bn_c256.cpp $(EMCC_OPT) -DMCL_MAX_BIT_SIZE=256 -DMCL_USE_WEB_CRYPTO_API -s DISABLE_EXCEPTION_CATCHING=1 + emcc -o $@ src/fp.cpp src/bn_c256.cpp $(EMCC_OPT) -DMCL_MAX_BIT_SIZE=256 -DMCL_USE_WEB_CRYPTO_API -s DISABLE_EXCEPTION_CATCHING=1 #-DCYBOZU_DONT_USE_EXCEPTION -DCYBOZU_DONT_USE_STRING ../mcl-wasm/mcl_c512.js: src/bn_c512.cpp $(MCL_C_DEP) emcc -o $@ src/fp.cpp src/bn_c512.cpp $(EMCC_OPT) -DMCL_MAX_BIT_SIZE=512 -DMCL_USE_WEB_CRYPTO_API -s DISABLE_EXCEPTION_CATCHING=1 diff --git a/include/mcl/bn.hpp b/include/mcl/bn.hpp index a686999..98020ac 100644 --- a/include/mcl/bn.hpp +++ b/include/mcl/bn.hpp @@ -890,7 +890,8 @@ struct Param { { this->cp = cp; isBLS12 = cp.curveType == MCL_BLS12_381; - z = mpz_class(cp.z); + gmp::setStr(pb, z, cp.z); + if (!*pb) return; isNegative = z < 0; if (isNegative) { abs_z = -z; @@ -970,12 +971,14 @@ struct Param { glv2.init(r, z, isBLS12); *pb = true; } +#ifndef CYBOZU_DONT_EXCEPTION void init(const mcl::CurveParam& cp, fp::Mode mode) { bool b; init(&b, cp, mode); if (!b) throw cybozu::Exception("Param:init"); } +#endif }; template @@ -1828,6 +1831,7 @@ inline void precomputedMillerLoop2(Fp12& f, const G1& P1, const mcl::Array& } inline void mapToG1(bool *pb, G1& P, const Fp& x) { *pb = BN::param.mapTo.calcG1(P, x); } inline void mapToG2(bool *pb, G2& P, const Fp2& x) { *pb = BN::param.mapTo.calcG2(P, x); } +#ifndef CYBOZU_DONT_EXCEPTION inline void mapToG1(G1& P, const Fp& x) { bool b; @@ -1840,6 +1844,7 @@ inline void mapToG2(G2& P, const Fp2& x) mapToG2(&b, P, x); if (!b) throw cybozu::Exception("mapToG2:bad value") << x; } +#endif inline void hashAndMapToG1(G1& P, const void *buf, size_t bufSize) { Fp t; @@ -1861,6 +1866,7 @@ inline void hashAndMapToG2(G2& P, const void *buf, size_t bufSize) assert(b); (void)b; } +#ifndef CYBOZU_DONT_USE_STRING inline void hashAndMapToG1(G1& P, const std::string& str) { hashAndMapToG1(P, str.c_str(), str.size()); @@ -1869,6 +1875,7 @@ inline void hashAndMapToG2(G2& P, const std::string& str) { hashAndMapToG2(P, str.c_str(), str.size()); } +#endif inline void verifyOrderG1(bool doVerify) { if (BN::param.isBLS12) { @@ -1936,12 +1943,14 @@ inline void init(bool *pb, const mcl::CurveParam& cp = mcl::BN254, fp::Mode mode *pb = true; } +#ifndef CYBOZU_DONT_EXCEPTION inline void init(const mcl::CurveParam& cp = mcl::BN254, fp::Mode mode = fp::FP_AUTO) { bool b; init(&b, cp, mode); if (!b) throw cybozu::Exception("BN:init"); } +#endif } // mcl::bn::BN @@ -1950,12 +1959,14 @@ inline void initPairing(bool *pb, const mcl::CurveParam& cp = mcl::BN254, fp::Mo BN::init(pb, cp, mode); } +#ifndef CYBOZU_DONT_EXCEPTION inline void initPairing(const mcl::CurveParam& cp = mcl::BN254, fp::Mode mode = fp::FP_AUTO) { bool b; BN::init(&b, cp, mode); if (!b) throw cybozu::Exception("bn:initPairing"); } +#endif } } // mcl::bn diff --git a/include/mcl/conversion.hpp b/include/mcl/conversion.hpp index 1000257..b5faa50 100644 --- a/include/mcl/conversion.hpp +++ b/include/mcl/conversion.hpp @@ -26,6 +26,7 @@ bool skipSpace(char *c, InputStream& is) } } +#ifndef CYBOZU_DONT_USE_STRING template void loadWord(std::string& s, InputStream& is) { @@ -39,6 +40,7 @@ void loadWord(std::string& s, InputStream& is) s += c; } } +#endif template size_t loadWord(char *buf, size_t bufSize, InputStream& is) diff --git a/include/mcl/fp.hpp b/include/mcl/fp.hpp index b114a1f..caa67c9 100644 --- a/include/mcl/fp.hpp +++ b/include/mcl/fp.hpp @@ -48,7 +48,14 @@ int64_t getInt64(bool *pb, fp::Block& b, const fp::Op& op); const char *ModeToStr(Mode mode); -Mode StrToMode(const std::string& s); +Mode StrToMode(const char *s); + +#ifndef CYBOZU_DONT_USE_STRING +inline Mode StrToMode(const std::string& s) +{ + return StrToMode(s.c_str()); +} +#endif inline void dumpUnit(Unit x) { @@ -124,36 +131,14 @@ public: static inline void init(bool *pb, const char *mstr, fp::Mode mode = fp::FP_AUTO) { mpz_class p; - gmp::setStr(pb, p, mstr, strlen(mstr)); + gmp::setStr(pb, p, mstr); if (!*pb) return; init(pb, p, mode); } - static inline void init(const mpz_class& _p, fp::Mode mode = fp::FP_AUTO) - { - bool b; - init(&b, _p, mode); - if (!b) throw cybozu::Exception("Fp:init"); - } - static inline void init(const std::string& mstr, fp::Mode mode = fp::FP_AUTO) - { - bool b; - init(&b, mstr.c_str(), mode); - if (!b) throw cybozu::Exception("Fp:init"); - } static inline size_t getModulo(char *buf, size_t bufSize) { return gmp::getStr(buf, bufSize, op_.mp); } - static inline void getModulo(std::string& pstr) - { - gmp::getStr(pstr, op_.mp); - } - static std::string getModulo() - { - std::string s; - getModulo(s); - return s; - } static inline bool isFullBit() { return op_.isFullBit; } /* binary patter of p @@ -176,8 +161,8 @@ public: x.getMpz(mx); bool b = op_.sq.get(my, mx); if (!b) return false; - y.setMpz(my); - return true; + y.setMpz(&b, my); + return b; } FpT() {} FpT(const FpT& x) @@ -194,10 +179,6 @@ public: op_.fp_clear(v_); } FpT(int64_t x) { operator=(x); } - explicit FpT(const std::string& str, int base = 0) - { - Serializer::setStr(str, base); - } FpT& operator=(int64_t x) { if (x == 1) { @@ -290,36 +271,12 @@ public: } cybozu::write(pb, os, buf + sizeof(buf) - len, len); } - template - void save(OutputStream& os, int ioMode = IoSerialize) const - { - bool b; - save(&b, os, ioMode); - if (!b) throw cybozu::Exception("fp:save") << ioMode; - } - template - void load(InputStream& is, int ioMode = IoSerialize) - { - bool b; - load(&b, is, ioMode); - if (!b) throw cybozu::Exception("fp:load") << ioMode; - } template void setArray(bool *pb, const S *x, size_t n) { *pb = fp::copyAndMask(v_, x, sizeof(S) * n, op_, fp::NoMask); toMont(); } - /* - throw exception if x >= p - */ - template - void setArray(const S *x, size_t n) - { - bool b; - setArray(&b, x, n); - if (!b) throw cybozu::Exception("Fp:setArray"); - } /* mask x with (1 << bitLen) and subtract p if x >= p */ @@ -368,10 +325,6 @@ public: uint32_t size = op_.hash(buf, static_cast(sizeof(buf)), msg, static_cast(msgSize)); setArrayMask(buf, size); } - void setHashOf(const std::string& msg) - { - setHashOf(msg.data(), msg.size()); - } void getMpz(mpz_class& x) const { fp::Block b; @@ -392,12 +345,6 @@ public: } setArray(pb, gmp::getUnit(x), gmp::getUnitSize(x)); } - void setMpz(const mpz_class& x) - { - bool b; - setMpz(&b, x); - if (!b) throw cybozu::Exception("Fp:setMpz:neg"); - } static inline void add(FpT& z, const FpT& x, const FpT& y) { op_.fp_add(z.v_, x.v_, y.v_, op_.p); } static inline void sub(FpT& z, const FpT& x, const FpT& y) { op_.fp_sub(z.v_, x.v_, y.v_, op_.p); } static inline void addPre(FpT& z, const FpT& x, const FpT& y) { op_.fp_addPre(z.v_, x.v_, y.v_); } @@ -458,20 +405,6 @@ public: getBlock(b); return fp::getInt64(pb, b, op_); } - uint64_t getUint64() const - { - bool b; - uint64_t v = getUint64(&b); - if (!b) throw cybozu::Exception("Fp:getUint64:large value"); - return v; - } - int64_t getInt64() const - { - bool b; - int64_t v = getInt64(&b); - if (!b) throw cybozu::Exception("Fp:getInt64:large value"); - return v; - } bool operator==(const FpT& rhs) const { return fp::isEqualArray(v_, rhs.v_, op_.N); } bool operator!=(const FpT& rhs) const { return !operator==(rhs); } friend inline std::ostream& operator<<(std::ostream& os, const FpT& self) @@ -526,16 +459,94 @@ public: ioMode_ = ioMode; } static inline int getIoMode() { return ioMode_; } + static inline size_t getModBitLen() { return getBitSize(); } + static inline void setHashFunc(uint32_t hash(void *out, uint32_t maxOutSize, const void *msg, uint32_t msgSize)) + { + op_.hash = hash; + } +#ifndef CYBOZU_DONT_USE_STRING + explicit FpT(const std::string& str, int base = 0) + { + Serializer::setStr(str, base); + } + static inline void getModulo(std::string& pstr) + { + gmp::getStr(pstr, op_.mp); + } + static std::string getModulo() + { + std::string s; + getModulo(s); + return s; + } + void setHashOf(const std::string& msg) + { + setHashOf(msg.data(), msg.size()); + } // backward compatibility static inline void setModulo(const std::string& mstr, fp::Mode mode = fp::FP_AUTO) { init(mstr, mode); } - static inline size_t getModBitLen() { return getBitSize(); } - static inline void setHashFunc(uint32_t hash(void *out, uint32_t maxOutSize, const void *msg, uint32_t msgSize)) +#endif +#ifndef CYBOZU_DONT_USE_EXCEPTION + static inline void init(const mpz_class& _p, fp::Mode mode = fp::FP_AUTO) { - op_.hash = hash; + bool b; + init(&b, _p, mode); + if (!b) throw cybozu::Exception("Fp:init"); + } + static inline void init(const std::string& mstr, fp::Mode mode = fp::FP_AUTO) + { + bool b; + init(&b, mstr.c_str(), mode); + if (!b) throw cybozu::Exception("Fp:init"); + } + template + void save(OutputStream& os, int ioMode = IoSerialize) const + { + bool b; + save(&b, os, ioMode); + if (!b) throw cybozu::Exception("fp:save") << ioMode; + } + template + void load(InputStream& is, int ioMode = IoSerialize) + { + bool b; + load(&b, is, ioMode); + if (!b) throw cybozu::Exception("fp:load") << ioMode; + } + /* + throw exception if x >= p + */ + template + void setArray(const S *x, size_t n) + { + bool b; + setArray(&b, x, n); + if (!b) throw cybozu::Exception("Fp:setArray"); + } + void setMpz(const mpz_class& x) + { + bool b; + setMpz(&b, x); + if (!b) throw cybozu::Exception("Fp:setMpz:neg"); + } + uint64_t getUint64() const + { + bool b; + uint64_t v = getUint64(&b); + if (!b) throw cybozu::Exception("Fp:getUint64:large value"); + return v; + } + int64_t getInt64() const + { + bool b; + int64_t v = getInt64(&b); + if (!b) throw cybozu::Exception("Fp:getInt64:large value"); + return v; } +#endif }; template fp::Op FpT::op_; diff --git a/include/mcl/gmp_util.hpp b/include/mcl/gmp_util.hpp index 399041c..bb461af 100644 --- a/include/mcl/gmp_util.hpp +++ b/include/mcl/gmp_util.hpp @@ -64,12 +64,13 @@ typedef mpz_class ImplType; // z = [buf[n-1]:..:buf[1]:buf[0]] // eg. buf[] = {0x12345678, 0xaabbccdd}; => z = 0xaabbccdd12345678; template -void setArray(mpz_class& z, const T *buf, size_t n) +void setArray(bool *pb, mpz_class& z, const T *buf, size_t n) { #ifdef MCL_USE_VINT - z.setArray(buf, n); + z.setArray(pb, buf, n); #else mpz_import(z.get_mpz_t(), n, -1, sizeof(*buf), 0, 0, buf); + *pb = true; #endif } /* @@ -78,44 +79,43 @@ void setArray(mpz_class& z, const T *buf, size_t n) */ #ifndef MCL_USE_VINT template -void getArray(T *buf, size_t maxSize, const mpz_srcptr x) +bool getArray_(T *buf, size_t maxSize, const mpz_srcptr x) { const size_t bufByteSize = sizeof(T) * maxSize; const int xn = x->_mp_size; - if (xn < 0) throw cybozu::Exception("gmp:getArray:x is negative"); + if (xn < 0) return false; size_t xByteSize = sizeof(*x->_mp_d) * xn; - if (xByteSize > bufByteSize) throw cybozu::Exception("gmp:getArray:too small") << xn << maxSize; + if (xByteSize > bufByteSize) return false; memcpy(buf, x->_mp_d, xByteSize); memset((char*)buf + xByteSize, 0, bufByteSize - xByteSize); + return true; } #endif template -void getArray(T *buf, size_t maxSize, const mpz_class& x) +void getArray(bool *pb, T *buf, size_t maxSize, const mpz_class& x) { #ifdef MCL_USE_VINT - x.getArray(buf, maxSize); + x.getArray(pb, buf, maxSize); #else - getArray(buf, maxSize, x.get_mpz_t()); + *pb = getArray_(buf, maxSize, x.get_mpz_t()); #endif } inline void set(mpz_class& z, uint64_t x) { - setArray(z, &x, 1); + bool b; + setArray(&b, z, &x, 1); + assert(b); + (void)b; } -inline void setStr(bool *pb, mpz_class& z, const char *str, size_t strSize, int base = 0) +inline void setStr(bool *pb, mpz_class& z, const char *str, int base = 0) { #ifdef MCL_USE_VINT - z.setStr(pb, str, strSize, base); + z.setStr(pb, str, base); #else - *pb = z.set_str(std::string(str, strSize), base) == 0; + *pb = z.set_str(str, base) == 0; #endif } -inline void setStr(mpz_class& z, const std::string& str, int base = 0) -{ - bool b; - setStr(&b, z, str.c_str(), str.size(), base); - if (!b) throw cybozu::Exception("gmp:setStr"); -} + /* set buf with string terminated by '\0' return strlen(buf) if success else 0 @@ -125,18 +125,19 @@ inline size_t getStr(char *buf, size_t bufSize, const mpz_class& z, int base = 1 #ifdef MCL_USE_VINT return z.getStr(buf, bufSize, base); #else - std::string str = z.get_str(base); - if (str.size() < bufSize) { - memcpy(buf, str.c_str(), str.size() + 1); - return str.size(); - } - return 0; + __gmp_alloc_cstring tmp(mpz_get_str(0, base, z.get_mpz_t())); + size_t n = strlen(tmp.str); + if (n + 1 > bufSize) return 0; + memcpy(buf, tmp.str, n + 1); + return n; #endif } + +#ifndef CYBOZU_DONT_USE_STRING inline void getStr(std::string& str, const mpz_class& z, int base = 10) { #ifdef MCL_USE_VINT - str = z.getStr(base); + z.getStr(str, base); #else str = z.get_str(base); #endif @@ -144,9 +145,11 @@ inline void getStr(std::string& str, const mpz_class& z, int base = 10) inline std::string getStr(const mpz_class& z, int base = 10) { std::string s; - getStr(s, z, base); + gmp::getStr(s, z, base); return s; } +#endif + inline void add(mpz_class& z, const mpz_class& x, const mpz_class& y) { #ifdef MCL_USE_VINT @@ -365,11 +368,12 @@ inline int legendre(const mpz_class& a, const mpz_class& p) return mpz_legendre(a.get_mpz_t(), p.get_mpz_t()); #endif } -inline bool isPrime(const mpz_class& x) +inline bool isPrime(bool *pb, const mpz_class& x) { #ifdef MCL_USE_VINT - return x.isPrime(32); + return x.isPrime(pb, 32); #else + *pb = true; return mpz_probab_prime_p(x.get_mpz_t(), 32) != 0; #endif } @@ -438,7 +442,7 @@ inline mpz_class abs(const mpz_class& x) #endif } -inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen()) +inline void getRand(bool *pb, mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen()) { if (rg.isZero()) rg = fp::RandGen::get(); assert(bitSize > 1); @@ -447,7 +451,7 @@ inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen() uint32_t buf[128]; assert(n <= CYBOZU_NUM_OF_ARRAY(buf)); if (n > CYBOZU_NUM_OF_ARRAY(buf)) { - z = 0; + *pb = false; return; } rg.read(buf, n * sizeof(buf[0])); @@ -459,22 +463,26 @@ inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen() v |= 1U << (rem - 1); } buf[n - 1] = v; - setArray(z, buf, n); + setArray(pb, z, buf, n); } -inline void getRandPrime(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false) +inline void getRandPrime(bool *pb, mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false) { if (rg.isZero()) rg = fp::RandGen::get(); assert(bitSize > 2); - do { - getRand(z, bitSize, rg); + for (;;) { + getRand(pb, z, bitSize, rg); + if (!*pb) return; if (setSecondBit) { z |= mpz_class(1) << (bitSize - 2); } if (mustBe3mod4) { z |= 3; } - } while (!(isPrime(z))); + bool ret = isPrime(pb, z); + if (!*pb) return; + if (ret) return; + } } inline mpz_class getQuadraticNonResidue(const mpz_class& p) { @@ -566,6 +574,49 @@ bool getNAF(Vec& v, const mpz_class& x) } } +#ifndef CYBOZU_DONT_USE_EXCEPTION +inline void setStr(mpz_class& z, const std::string& str, int base = 0) +{ + bool b; + setStr(&b, z, str.c_str(), base); + if (!b) throw cybozu::Exception("gmp:setStr"); +} +template +void setArray(mpz_class& z, const T *buf, size_t n) +{ + bool b; + setArray(&b, z, buf, n); + if (!b) throw cybozu::Exception("gmp:setArray"); +} +template +void getArray(T *buf, size_t maxSize, const mpz_class& x) +{ + bool b; + getArray(&b, buf, maxSize, x); + if (!b) throw cybozu::Exception("gmp:getArray"); +} +inline bool isPrime(const mpz_class& x) +{ + bool b; + bool ret = isPrime(&b, x); + if (!b) throw cybozu::Exception("gmp:isPrime"); + return ret; +} +inline void getRand(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen()) +{ + bool b; + getRand(&b, z, bitSize, rg); + if (!b) throw cybozu::Exception("gmp:getRand"); +} +inline void getRandPrime(mpz_class& z, size_t bitSize, fp::RandGen rg = fp::RandGen(), bool setSecondBit = false, bool mustBe3mod4 = false) +{ + bool b; + getRandPrime(&b, z, bitSize, rg, setSecondBit, mustBe3mod4); + if (!b) throw cybozu::Exception("gmp:getRandPrime"); +} +#endif + + } // mcl::gmp /* @@ -591,12 +642,19 @@ public: s = 0; q_add_1_div_2 = 0; } - void set(const mpz_class& _p) + void set(bool *pb, const mpz_class& _p) { p = _p; - if (p <= 2) throw cybozu::Exception("SquareRoot:bad p") << p; - isPrime = gmp::isPrime(p); - if (!isPrime) return; // don't throw until get() is called + if (p <= 2) { + *pb = false; + return; + } + isPrime = gmp::isPrime(pb, p); + if (!*pb) return; + if (!isPrime) { + *pb = false; + return; + } g = gmp::getQuadraticNonResidue(p); // p - 1 = 2^r q, q is odd r = 0; @@ -607,13 +665,18 @@ public: } gmp::powMod(s, g, q, p); q_add_1_div_2 = (q + 1) / 2; + *pb = true; } /* solve x^2 = a mod p */ - bool get(mpz_class& x, const mpz_class& a) const + bool get(bool *pb, mpz_class& x, const mpz_class& a) const { - if (!isPrime) throw cybozu::Exception("SquareRoot:get:not prime") << p; + if (!isPrime) { + *pb = false; + return false; + } + *pb = true; if (a == 0) { x = 0; return true; @@ -653,7 +716,7 @@ public: template bool get(Fp& x, const Fp& a) const { - if (Fp::getOp().mp != p) throw cybozu::Exception("bad Fp") << Fp::getOp().mp << p; + assert(Fp::getOp().mp == p); if (a == 0) { x = 0; return true; @@ -691,6 +754,21 @@ public: } return true; } +#ifndef CYBOZU_DONT_USE_EXCEPTION + void set(const mpz_class& _p) + { + bool b; + set(&b, _p); + if (!b) throw cybozu::Exception("gmp:SquareRoot:set"); + } + bool get(mpz_class& x, const mpz_class& a) const + { + bool b; + bool ret = get(&b, x, a); + if (!b) throw cybozu::Exception("gmp:SquareRoot:get:not prime"); + return ret; + } +#endif }; } // mcl diff --git a/include/mcl/operator.hpp b/include/mcl/operator.hpp index 7198929..29a66f5 100644 --- a/include/mcl/operator.hpp +++ b/include/mcl/operator.hpp @@ -136,6 +136,7 @@ struct Serializable : public E { buf[n] = '\0'; return n; } +#ifndef CYBOZU_DONT_USE_STRING void setStr(const std::string& str, int ioMode = 0) { cybozu::StringInputStream is(str); @@ -153,6 +154,7 @@ struct Serializable : public E { getStr(str, ioMode); return str; } +#endif // return written bytes size_t serialize(void *buf, size_t maxBufSize, int ioMode = IoSerialize) const { diff --git a/include/mcl/vint.hpp b/include/mcl/vint.hpp index a7e9f7a..8fa24b9 100644 --- a/include/mcl/vint.hpp +++ b/include/mcl/vint.hpp @@ -618,6 +618,7 @@ void divNM(T *q, size_t qn, T *r, const T *x, size_t xn, const T *y, size_t yn) } } +#ifndef MCL_VINT_FIXED_BUFFER template class Buffer { size_t allocSize_; @@ -651,21 +652,6 @@ public: std::swap(allocSize_, rhs.allocSize_); std::swap(ptr_, rhs.ptr_); } -#if 0 -#if CYBOZU_CPP_VERSION >= CYBOZU_CPP_VERSION_CPP11 - Buffer(Buffer&& rhs) noexcept - : allocSize_(0) - , ptr_(0) - { - swap(rhs); - } - Buffer& operator=(Buffer&& rhs) noexcept - { - swap(rhs); - return *this; - } -#endif -#endif void clear() { allocSize_ = 0; @@ -676,17 +662,29 @@ public: /* @note extended buffer may be not cleared */ - void alloc(size_t n) + void alloc(bool *pb, size_t n) { if (n > allocSize_) { T *p = (T*)malloc(n * sizeof(T)); - if (p == 0) throw cybozu::Exception("Buffer:alloc:malloc:") << n; + if (p == 0) { + *pb = false; + return; + } copyN(p, ptr_, allocSize_); free(ptr_); ptr_ = p; allocSize_ = n; } + *pb = true; } +#ifndef CYBOZU_DONT_USE_EXCEPTION + void alloc(size_t n) + { + bool b; + alloc(&b, n); + if (!b) throw cybozu::Exception("Buffer:alloc"); + } +#endif /* *this = rhs rhs may be destroyed @@ -694,6 +692,7 @@ public: const T& operator[](size_t n) const { return ptr_[n]; } T& operator[](size_t n) { return ptr_[n]; } }; +#endif template class FixedBuffer { @@ -721,11 +720,23 @@ public: return *this; } void clear() { size_ = 0; } - void alloc(size_t n) + void alloc(bool *pb, size_t n) { - verify(n); + if (n > N) { + *pb = false; + return; + } size_ = n; + *pb = true; } +#ifndef CYBOZU_DONT_USE_EXCEPTION + void alloc(size_t n) + { + bool b; + alloc(&b, n); + if (!b) throw cybozu::Exception("FixedBuffer:alloc"); + } +#endif void swap(FixedBuffer& rhs) { FixedBuffer *p1 = this; @@ -745,9 +756,8 @@ public: // to avoid warning of gcc void verify(size_t n) const { - if (n > N) { - throw cybozu::Exception("verify:too large size") << n << (int)N; - } + assert(n <= N); + (void)n; } const T& operator[](size_t n) const { verify(n); return v_[n]; } T& operator[](size_t n) { verify(n); return v_[n]; } @@ -946,6 +956,19 @@ private: } r.trim(yn); } + /* + @param x [inout] x <- d + @retval s for x = 2^s d where d is odd + */ + static uint32_t countTrailingZero(VintT& x) + { + uint32_t s = 0; + while (x.isEven()) { + x >>= 1; + s++; + } + return s; + } struct MulMod { const VintT *pm; void operator()(VintT& z, const VintT& x, const VintT& y) const @@ -973,11 +996,6 @@ public: { *this = x; } - explicit VintT(const std::string& str) - : size_(0) - { - setStr(str); - } VintT(const VintT& rhs) : buf_(rhs.buf_) , size_(rhs.size_) @@ -1057,12 +1075,13 @@ public: /* set [0, max) randomly */ - void setRand(const VintT& max) + void setRand(bool *pb, const VintT& max) { + assert(max > 0); fp::RandGen& rg = fp::RandGen::get(); - if (max <= 0) throw cybozu::Exception("Vint:setRand:bad value") << max; size_t n = max.size(); - buf_.alloc(n); + buf_.alloc(pb, n); + if (!*pb) return; rg.read(&buf_[0], n * sizeof(buf_[0])); trim(n); *this %= max; @@ -1073,12 +1092,16 @@ public: buf_[size, maxSize) with zero @note assume little endian system */ - void getArray(Unit *x, size_t maxSize) const + void getArray(bool *pb, Unit *x, size_t maxSize) const { size_t n = size(); - if (n > maxSize) throw cybozu::Exception("Vint:getArray:small maxSize") << maxSize << n; + if (n > maxSize) { + *pb = false; + return; + } vint::copyN(x, &buf_[0], n); vint::clearN(x + n, maxSize - n); + *pb = true; } void clear() { *this = 0; } template @@ -1093,13 +1116,6 @@ public: } cybozu::write(pb, os, buf + sizeof(buf) - n, n); } - template - void save(OutputStream& os, int base = 10) const - { - bool b; - save(&b, os, base); - if (!b) throw cybozu::Exception("Vint:save"); - } /* set buf with string terminated by '\0' return strlen(buf) if success else 0 @@ -1114,13 +1130,6 @@ public: buf[n] = '\0'; return n; } - std::string getStr(int base = 10) const - { - std::string s; - cybozu::StringOutputStream os(s); - save(os, base); - return s; - } /* return bitSize(abs(*this)) @note return 1 if zero @@ -1138,7 +1147,7 @@ public: { size_t q = i / unitBitSize; size_t r = i % unitBitSize; - if (q > size()) throw cybozu::Exception("Vint:testBit:large i") << q << size(); + assert(q <= size()); Unit mask = Unit(1) << r; return (buf_[q] & mask) != 0; } @@ -1146,7 +1155,7 @@ public: { size_t q = i / unitBitSize; size_t r = i % unitBitSize; - if (q > size()) throw cybozu::Exception("Vint:setBit:large i") << q << size(); + assert(q <= size()); buf_.alloc(q + 1); Unit mask = Unit(1) << r; if (v) { @@ -1162,23 +1171,19 @@ public: "0b..." => base = 2 otherwise => base = 10 */ - void setStr(bool *pb, const char *str, size_t strSize, int base = 0) + void setStr(bool *pb, const char *str, int base = 0) { const size_t maxN = MCL_MAX_BIT_SIZE / (sizeof(MCL_SIZEOF_UNIT) * 8); - buf_.alloc(maxN); + buf_.alloc(pb, maxN); + if (!*pb) return; *pb = false; isNeg_ = false; - size_t n = fp::strToArray(&isNeg_, &buf_[0], maxN, str, strSize, base); + size_t len = strlen(str); + size_t n = fp::strToArray(&isNeg_, &buf_[0], maxN, str, len, base); if (n == 0) return; trim(n); *pb = true; } - void setStr(std::string str, int base = 0) - { - bool b; - setStr(&b, str.c_str(), str.size(), base); - if (!b) throw cybozu::Exception("Vint:setStr") << str; - } static int compare(const VintT& x, const VintT& y) { if (x.isNeg_ ^ y.isNeg_) { @@ -1372,10 +1377,6 @@ public: usub(r, yy.buf_, yy.size(), r.buf_, r.size()); } } - inline friend std::ostream& operator<<(std::ostream& os, const VintT& x) - { - return os << x.getStr(os.flags() & std::ios_base::hex ? 16 : 10); - } template void load(bool *pb, InputStream& is, int ioMode) { @@ -1391,18 +1392,6 @@ public: trim(n); *pb = true; } - template - void load(InputStream& is, int ioMode = 0) - { - bool b; - load(&b, is, ioMode); - if (!b) throw cybozu::Exception("Vint:load"); - } - inline friend std::istream& operator>>(std::istream& is, VintT& x) - { - x.load(is); - return is; - } // logical left shift (copy sign) static void shl(VintT& y, const VintT& x, size_t shiftBit) { @@ -1575,26 +1564,12 @@ public: b -= a * q; } } -private: - /* - @param x [inout] x <- d - @retval s for x = 2^s d where d is odd - */ - static uint32_t countTrailingZero(VintT& x) - { - uint32_t s = 0; - while (x.isEven()) { - x >>= 1; - s++; - } - return s; - } -public: /* Miller-Rabin */ - static bool isPrime(const VintT& n, int tryNum = 32) + static bool isPrime(bool *pb, const VintT& n, int tryNum = 32) { + *pb = true; if (n <= 1) return false; if (n == 2 || n == 3) return true; if (n.isEven()) return false; @@ -1604,7 +1579,8 @@ public: // n - 1 = 2^r d VintT a, x; for (int i = 0; i < tryNum; i++) { - a.setRand(n - 3); + a.setRand(pb, n - 3); + if (!*pb) return false; a += 2; // a in [2, n - 2] powMod(x, a, d, n); if (x == 1 || x == nm1) { @@ -1621,9 +1597,9 @@ public: } return true; } - bool isPrime(int tryNum = 32) const + bool isPrime(bool *pb, int tryNum = 32) const { - return isPrime(*this, tryNum); + return isPrime(pb, *this, tryNum); } static void gcd(VintT& z, VintT x, VintT y) { @@ -1665,7 +1641,7 @@ public: */ static int jacobi(VintT m, VintT n) { - if (n.isEven()) throw cybozu::Exception(); + assert(n.isOdd()); if (n == 1) return 1; if (m < 0 || m > n) { quotRem(0, m, m, n); // m = m mod n @@ -1693,6 +1669,81 @@ public: } return j; } +#ifndef CYBOZU_DONT_USE_STRING + explicit VintT(const std::string& str) + : size_(0) + { + setStr(str); + } + void getStr(std::string& s, int base = 10) const + { + cybozu::StringOutputStream os(s); + save(os, base); + } + std::string getStr(int base = 10) const + { + std::string s; + getStr(s, base); + return s; + } + inline friend std::ostream& operator<<(std::ostream& os, const VintT& x) + { + return os << x.getStr(os.flags() & std::ios_base::hex ? 16 : 10); + } + inline friend std::istream& operator>>(std::istream& is, VintT& x) + { + x.load(is); + return is; + } +#endif +#ifndef CYBOZU_DONT_USE_EXCEPTION + void setStr(const std::string& str, int base = 0) + { + bool b; + setStr(&b, str.c_str(), base); + if (!b) throw cybozu::Exception("Vint:setStr") << str; + } + void setRand(const VintT& max) + { + bool b; + setRand(&b, max); + if (!b) throw cybozu::Exception("Vint:setRand"); + } + void getArray(Unit *x, size_t maxSize) const + { + bool b; + getArray(&b, x, maxSize); + if (!b) throw cybozu::Exception("Vint:getArray"); + } + template + void load(InputStream& is, int ioMode = 0) + { + bool b; + load(&b, is, ioMode); + if (!b) throw cybozu::Exception("Vint:load"); + } + template + void save(OutputStream& os, int base = 10) const + { + bool b; + save(&b, os, base); + if (!b) throw cybozu::Exception("Vint:save"); + } + static bool isPrime(const VintT& n, int tryNum = 32) + { + bool b; + bool ret = isPrime(&b, n, tryNum); + if (!b) throw cybozu::Exception("Vint:isPrime"); + return ret; + } + bool isPrime(int tryNum = 32) const + { + bool b; + bool ret = isPrime(&b, *this, tryNum); + if (!b) throw cybozu::Exception("Vint:isPrime"); + return ret; + } +#endif VintT& operator++() { adds1(*this, *this, 1); return *this; } VintT& operator--() { subs1(*this, *this, 1); return *this; } VintT operator++(int) { VintT c = *this; adds1(*this, *this, 1); return c; } diff --git a/src/fp.cpp b/src/fp.cpp index e45217b..ba9e484 100644 --- a/src/fp.cpp +++ b/src/fp.cpp @@ -77,7 +77,7 @@ const char *ModeToStr(Mode mode) } } -Mode StrToMode(const std::string& s) +Mode StrToMode(const char *s) { static const struct { const char *s; @@ -91,7 +91,7 @@ Mode StrToMode(const std::string& s) { "xbyak", FP_XBYAK }, }; for (size_t i = 0; i < CYBOZU_NUM_OF_ARRAY(tbl); i++) { - if (s == tbl[i].s) return tbl[i].mode; + if (strcmp(s, tbl[i].s) == 0) return tbl[i].mode; } return FP_AUTO; } @@ -176,19 +176,24 @@ static inline void set_mpz_t(mpz_t& z, const Unit* p, int n) static inline void fp_invOpC(Unit *y, const Unit *x, const Op& op) { const int N = (int)op.N; + bool b; #ifdef MCL_USE_VINT Vint vx, vy, vp; - vx.setArray(x, N); - vp.setArray(op.p, N); + vx.setArray(&b, x, N); + assert(b); + vp.setArray(&b, op.p, N); + assert(b); Vint::invMod(vy, vx, vp); - vy.getArray(y, N); + vy.getArray(&b, y, N); + assert(b); #else mpz_class my; mpz_t mx, mp; set_mpz_t(mx, x, N); set_mpz_t(mp, op.p, N); mpz_invert(my.get_mpz_t(), mx, mp); - gmp::getArray(y, N, my); + gmp::getArray(&b, y, N, my); + assert(b); #endif } @@ -323,20 +328,24 @@ static void initInvTbl(Op& op) } #endif -static void initForMont(Op& op, const Unit *p, Mode mode) +static bool initForMont(Op& op, const Unit *p, Mode mode) { const size_t N = op.N; + bool b; { mpz_class t = 1, R; - gmp::getArray(op.one, N, t); + gmp::getArray(&b, op.one, N, t); + if (!b) return false; R = (t << (N * UnitBitSize)) % op.mp; t = (R * R) % op.mp; - gmp::getArray(op.R2, N, t); + gmp::getArray(&b, op.R2, N, t); + if (!b) return false; t = (t * R) % op.mp; - gmp::getArray(op.R3, N, t); + gmp::getArray(&b, op.R3, N, t); + if (!b) return false; } op.rp = getMontgomeryCoeff(p[0]); - if (mode != FP_XBYAK) return; + if (mode != FP_XBYAK) return true; #ifdef MCL_USE_XBYAK if (op.fg == 0) op.fg = Op::createFpGenerator(); op.fg->init(op); @@ -346,6 +355,7 @@ static void initForMont(Op& op, const Unit *p, Mode mode) initInvTbl(op); } #endif + return true; } bool Op::init(const mpz_class& _p, size_t maxBitSize, Mode mode, size_t mclMaxBitSize) @@ -359,11 +369,13 @@ bool Op::init(const mpz_class& _p, size_t maxBitSize, Mode mode, size_t mclMaxBi if (maxBitSize > MCL_MAX_BIT_SIZE) return false; if (_p <= 0) return false; clear(); + bool b; { const size_t maxN = (maxBitSize + fp::UnitBitSize - 1) / fp::UnitBitSize; N = gmp::getUnitSize(_p); if (N > maxN) return false; - gmp::getArray(p, N, _p); + gmp::getArray(&b, p, N, _p); + if (!b) return false; mp = _p; } bitSize = gmp::getBitSize(mp); @@ -417,10 +429,16 @@ bool Op::init(const mpz_class& _p, size_t maxBitSize, Mode mode, size_t mclMaxBi } #endif #if defined(MCL_USE_VINT) && MCL_SIZEOF_UNIT == 8 - if (mp == mpz_class("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f")) { - primeMode = PM_SECP256K1; - isMont = false; - isFastMod = true; + { + const char *secp256k1Str = "0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f"; + bool b; + mpz_class secp256k1; + gmp::setStr(&b, secp256k1, secp256k1Str); + if (b && mp == secp256k1) { + primeMode = PM_SECP256K1; + isMont = false; + isFastMod = true; + } } #endif switch (N) { @@ -477,8 +495,9 @@ bool Op::init(const mpz_class& _p, size_t maxBitSize, Mode mode, size_t mclMaxBi fpDbl_mod = &mcl::vint::mcl_fpDbl_mod_SECP256K1; } #endif - fp::initForMont(*this, p, mode); - sq.set(mp); + if (!fp::initForMont(*this, p, mode)) return false; + sq.set(&b, mp); + if (!b) return false; if (N * UnitBitSize <= 256) { hash = sha256; } else { diff --git a/test/fp_test.cpp b/test/fp_test.cpp index f81e0c8..71d6986 100644 --- a/test/fp_test.cpp +++ b/test/fp_test.cpp @@ -351,6 +351,7 @@ void compareTest() void moduloTest(const char *pStr) { +std::cout << std::hex; std::string str; Fp::getModulo(str); CYBOZU_TEST_EQUAL(str, mcl::gmp::getStr(mpz_class(pStr))); diff --git a/test/gmp_test.cpp b/test/gmp_test.cpp index 22c80dd..1fe9d4e 100644 --- a/test/gmp_test.cpp +++ b/test/gmp_test.cpp @@ -21,6 +21,45 @@ CYBOZU_TEST_AUTO(testBit) } } +CYBOZU_TEST_AUTO(getStr) +{ + const struct { + int x; + const char *dec; + const char *hex; + } tbl[] = { + { 0, "0", "0" }, + { 1, "1", "1" }, + { 10, "10", "a" }, + { 16, "16", "10" }, + { 123456789, "123456789", "75bcd15" }, + { -1, "-1", "-1" }, + { -10, "-10", "-a" }, + { -16, "-16", "-10" }, + { -100000000, "-100000000", "-5f5e100" }, + { -987654321, "-987654321", "-3ade68b1" }, + { -2147483647, "-2147483647", "-7fffffff" }, + }; + for (size_t i = 0; i < CYBOZU_NUM_OF_ARRAY(tbl); i++) { + mpz_class x = tbl[i].x; + char buf[32]; + size_t n, len; + len = strlen(tbl[i].dec); + n = mcl::gmp::getStr(buf, len, x, 10); + CYBOZU_TEST_EQUAL(n, 0); + n = mcl::gmp::getStr(buf, len + 1, x, 10); + CYBOZU_TEST_EQUAL(n, len); + CYBOZU_TEST_EQUAL_ARRAY(buf, tbl[i].dec, n); + + len = strlen(tbl[i].hex); + n = mcl::gmp::getStr(buf, len, x, 16); + CYBOZU_TEST_EQUAL(n, 0); + n = mcl::gmp::getStr(buf, len + 1, x, 16); + CYBOZU_TEST_EQUAL(n, len); + CYBOZU_TEST_EQUAL_ARRAY(buf, tbl[i].hex, n); + } +} + CYBOZU_TEST_AUTO(getRandPrime) { for (int i = 0; i < 10; i++) {