diff --git a/include/mcl/bn.hpp b/include/mcl/bn.hpp index 3c5959b..d1cb9f7 100644 --- a/include/mcl/bn.hpp +++ b/include/mcl/bn.hpp @@ -561,30 +561,13 @@ struct MapTo { }; typedef mcl::FixedArray NafArray; -template -void addTbl(G& Q, const G *tbl, const NafArray& naf, size_t i) -{ - if (i >= naf.size()) return; - int n = naf[i]; - if (n > 0) { - Q += tbl[(n - 1) >> 1]; - } else if (n < 0) { - Q -= tbl[(-n - 1) >> 1]; - } -} /* Software implementation of Attribute-Based Encryption: Appendixes GLV for G1 on BN/BLS12 */ -template -struct GLV1T { - static F rw; // rw = 1 / w = (-1 - sqrt(-3)) / 2 - static size_t rBitSize; - static mpz_class v0, v1; - static mpz_class B[2][2]; - static mpz_class r; -private: + +struct GLV1 : mcl::GLV1T { static bool usePrecomputedTable(int curveType) { if (curveType < 0) return false; @@ -620,37 +603,21 @@ private: bool b; rw.setStr(&b, tbl[i].rw, 16); if (!b) continue; rBitSize = tbl[i].rBitSize; - mcl::gmp::setStr(&b, v0, tbl[i].v0, 16); if (!b) continue; - mcl::gmp::setStr(&b, v1, tbl[i].v1, 16); if (!b) continue; - mcl::gmp::setStr(&b, B[0][0], tbl[i].B[0][0], 16); if (!b) continue; - mcl::gmp::setStr(&b, B[0][1], tbl[i].B[0][1], 16); if (!b) continue; - mcl::gmp::setStr(&b, B[1][0], tbl[i].B[1][0], 16); if (!b) continue; - mcl::gmp::setStr(&b, B[1][1], tbl[i].B[1][1], 16); if (!b) continue; - mcl::gmp::setStr(&b, r, tbl[i].r, 16); if (!b) continue; + gmp::setStr(&b, v0, tbl[i].v0, 16); if (!b) continue; + gmp::setStr(&b, v1, tbl[i].v1, 16); if (!b) continue; + gmp::setStr(&b, B[0][0], tbl[i].B[0][0], 16); if (!b) continue; + gmp::setStr(&b, B[0][1], tbl[i].B[0][1], 16); if (!b) continue; + gmp::setStr(&b, B[1][0], tbl[i].B[1][0], 16); if (!b) continue; + gmp::setStr(&b, B[1][1], tbl[i].B[1][1], 16); if (!b) continue; + gmp::setStr(&b, r, tbl[i].r, 16); if (!b) continue; return true; } return false; } -public: -#ifndef CYBOZU_DONT_USE_STRING - static void dump(const mpz_class& x) - { - printf("\"%s\",\n", mcl::gmp::getStr(x, 16).c_str()); - } - static void dump() - { - printf("\"%s\",\n", rw.getStr(16).c_str()); - printf("%d,\n", (int)rBitSize); - dump(v0); - dump(v1); - dump(B[0][0]); dump(B[0][1]); dump(B[1][0]); dump(B[1][1]); - dump(r); - } -#endif - static void init(const mpz_class& _r, const mpz_class& z, bool isBLS12 = false, int curveType = -1) + static void initForBN(const mpz_class& _r, const mpz_class& z, bool isBLS12 = false, int curveType = -1) { if (usePrecomputedTable(curveType)) return; - bool b = F::squareRoot(rw, -3); + bool b = Fp::squareRoot(rw, -3); assert(b); (void)b; rw = -(rw + 1) / 2; @@ -684,90 +651,7 @@ public: v0 = ((-B[1][1]) << rBitSize) / r; v1 = ((B[1][0]) << rBitSize) / r; } - /* - L = lambda = p^4 - L (x, y) = (rw x, y) - */ - static void mulLambda(G& Q, const G& P) - { - F::mul(Q.x, P.x, rw); - Q.y = P.y; - Q.z = P.z; - } - /* - x = a + b * lambda mod r - */ - static void split(mpz_class& a, mpz_class& b, const mpz_class& x) - { - mpz_class t; - t = (x * v0) >> rBitSize; - b = (x * v1) >> rBitSize; - a = x - (t * B[0][0] + b * B[1][0]); - b = - (t * B[0][1] + b * B[1][1]); - } - static void mul(G& Q, const G& P, mpz_class x, bool constTime = false) - { - const int w = 5; - const size_t tblSize = 1 << (w - 2); - NafArray naf[2]; - mpz_class u[2]; - G tbl[2][tblSize]; - bool b; - - x %= r; - if (x == 0) { - Q.clear(); - if (!constTime) return; - } - if (x < 0) { - x += r; - } - split(u[0], u[1], x); - gmp::getNAFwidth(&b, naf[0], u[0], w); - assert(b); (void)b; - gmp::getNAFwidth(&b, naf[1], u[1], w); - assert(b); (void)b; - - tbl[0][0] = P; - mulLambda(tbl[1][0], tbl[0][0]); - { - G P2; - G::dbl(P2, P); - for (size_t i = 1; i < tblSize; i++) { - G::add(tbl[0][i], tbl[0][i - 1], P2); - mulLambda(tbl[1][i], tbl[0][i]); - } - } - const size_t maxBit = fp::max_(naf[0].size(), naf[1].size()); - Q.clear(); - for (size_t i = 0; i < maxBit; i++) { - G::dbl(Q, Q); - addTbl(Q, tbl[0], naf[0], maxBit - 1 - i); - addTbl(Q, tbl[1], naf[1], maxBit - 1 - i); - } - } - static void mulArray(G& z, const G& x, const mcl::fp::Unit *y, size_t yn, bool isNegative, bool constTime) - { - mpz_class s; - bool b; - mcl::gmp::setArray(&b, s, y, yn); - assert(b); - if (isNegative) s = -s; - mul(z, x, s, constTime); - } }; -template -F GLV1T::rw; // rw = 1 / w = (-1 - sqrt(-3)) / 2 -template -size_t GLV1T::rBitSize; -template -mpz_class GLV1T::v0; -template -mpz_class GLV1T::v1; -template -mpz_class GLV1T::B[2][2]; -template -mpz_class GLV1T::r; /* GLV method for G2 and GT on BN/BLS12 @@ -787,7 +671,7 @@ struct GLV2 { this->z = z; this->abs_z = z < 0 ? -z : z; this->isBLS12 = isBLS12; - rBitSize = mcl::gmp::getBitSize(r); + rBitSize = gmp::getBitSize(r); rBitSize = (rBitSize + mcl::fp::UnitBitSize - 1) & ~(mcl::fp::UnitBitSize - 1);// a little better size mpz_class z2p1 = z * 2 + 1; B[0][0] = z + 1; @@ -874,7 +758,6 @@ struct GLV2 { template void mul(T& Q, const T& P, mpz_class x, bool constTime = false) const { -#if 1 const int w = 5; const size_t tblSize = 1 << (w - 2); const size_t splitN = 4; @@ -917,120 +800,11 @@ struct GLV2 { Q.clear(); for (size_t i = 0; i < maxBit; i++) { T::dbl(Q, Q); - addTbl(Q, tbl[0], naf[0], maxBit - 1 - i); - addTbl(Q, tbl[1], naf[1], maxBit - 1 - i); - addTbl(Q, tbl[2], naf[2], maxBit - 1 - i); - addTbl(Q, tbl[3], naf[3], maxBit - 1 - i); - } -#else -#if 0 // #ifndef NDEBUG - { - T R; - T::mulGeneric(R, P, r); - assert(R.isZero()); - } -#endif - typedef mcl::fp::Unit Unit; - const size_t maxUnit = 512 / 2 / mcl::fp::UnitBitSize; - const int splitN = 4; - mpz_class u[splitN]; - T in[splitN]; - T tbl[16]; - int bitTbl[splitN]; // bit size of u[i] - Unit w[splitN][maxUnit]; // unit array of u[i] - int maxBit = 0; // max bit of u[i] - int maxN = 0; - int remainBit = 0; - - x %= r; - if (x == 0) { - Q.clear(); - if (constTime) goto DummyLoop; - return; + mcl::local::addTbl(Q, tbl[0], naf[0], maxBit - 1 - i); + mcl::local::addTbl(Q, tbl[1], naf[1], maxBit - 1 - i); + mcl::local::addTbl(Q, tbl[2], naf[2], maxBit - 1 - i); + mcl::local::addTbl(Q, tbl[3], naf[3], maxBit - 1 - i); } - if (x < 0) { - x += r; - } - split(u, x); - in[0] = P; - Frobenius(in[1], in[0]); - Frobenius(in[2], in[1]); - Frobenius(in[3], in[2]); - for (int i = 0; i < splitN; i++) { - if (u[i] < 0) { - u[i] = -u[i]; - T::neg(in[i], in[i]); - } -// in[i].normalize(); // slow - } -#if 0 - for (int i = 0; i < splitN; i++) { - T::mulGeneric(in[i], in[i], u[i]); - } - T::add(Q, in[0], in[1]); - Q += in[2]; - Q += in[3]; - return; -#else - tbl[0] = in[0]; - for (size_t i = 1; i < 16; i++) { - tbl[i].clear(); - if (i & 1) { - tbl[i] += in[0]; - } - if (i & 2) { - tbl[i] += in[1]; - } - if (i & 4) { - tbl[i] += in[2]; - } - if (i & 8) { - tbl[i] += in[3]; - } -// tbl[i].normalize(); - } - for (int i = 0; i < splitN; i++) { - bool b; - mcl::gmp::getArray(&b, w[i], maxUnit, u[i]); - assert(b); - bitTbl[i] = (int)mcl::gmp::getBitSize(u[i]); - maxBit = fp::max_(maxBit, bitTbl[i]); - } - maxBit--; - /* - maxBit = maxN * UnitBitSize + remainBit - 0 < remainBit <= UnitBitSize - */ - maxN = maxBit / mcl::fp::UnitBitSize; - remainBit = maxBit % mcl::fp::UnitBitSize; - remainBit++; - Q.clear(); - for (int i = maxN; i >= 0; i--) { - for (int j = remainBit - 1; j >= 0; j--) { - T::dbl(Q, Q); - uint32_t b0 = (w[0][i] >> j) & 1; - uint32_t b1 = (w[1][i] >> j) & 1; - uint32_t b2 = (w[2][i] >> j) & 1; - uint32_t b3 = (w[3][i] >> j) & 1; - uint32_t c = b3 * 8 + b2 * 4 + b1 * 2 + b0; - if (c == 0) { - if (constTime) tbl[0] += tbl[1]; - } else { - Q += tbl[c]; - } - } - remainBit = (int)mcl::fp::UnitBitSize; - } -#endif - DummyLoop: - if (!constTime) return; - const int limitBit = (int)rBitSize / splitN; - T D = tbl[0]; - for (int i = maxBit + 1; i < limitBit; i++) { - T::dbl(D, D); - D += tbl[0]; - } -#endif } void pow(Fp12& z, const Fp12& x, mpz_class y, bool constTime = false) const { @@ -1050,7 +824,6 @@ struct Param { mpz_class p; mpz_class r; local::MapTo mapTo; - typedef local::GLV1T GLV1; local::GLV2 glv2; // for G2 Frobenius Fp2 g2; @@ -1166,7 +939,7 @@ struct Param { } else { mapTo.init(2 * p - r, z, cp.curveType); } - GLV1::init(r, z, isBLS12, cp.curveType); + GLV1::initForBN(r, z, isBLS12, cp.curveType); glv2.init(r, z, isBLS12); basePoint.clear(); *pb = true; @@ -2233,7 +2006,7 @@ inline void init(bool *pb, const mcl::CurveParam& cp = mcl::BN254, fp::Mode mode { local::StaticVar<>::param.init(pb, cp, mode); if (!*pb) return; - G1::setMulArrayGLV(bn::local::Param::GLV1::mulArray); + G1::setMulArrayGLV(local::GLV1::mulArray); G2::setMulArrayGLV(local::mulArrayGLV2); Fp12::setPowArrayGLV(local::powArrayGLV2); G1::setCompressedExpression(); diff --git a/include/mcl/ec.hpp b/include/mcl/ec.hpp index ad6e6db..115a8de 100644 --- a/include/mcl/ec.hpp +++ b/include/mcl/ec.hpp @@ -10,6 +10,7 @@ #include #include #include +#include //#define MCL_EC_USE_AFFINE @@ -1068,6 +1069,130 @@ template void (*EcT::mulArrayGLV)(EcT& z, const EcT& x, const fp:: template int EcT::mode_; #endif +namespace local { + +template +void addTbl(G& Q, const G *tbl, const Vec& naf, size_t i) +{ + if (i >= naf.size()) return; + int n = naf[i]; + if (n > 0) { + Q += tbl[(n - 1) >> 1]; + } else if (n < 0) { + Q -= tbl[(-n - 1) >> 1]; + } +} + +} // mcl::local + +template +struct GLV1T { + static F rw; // rw = 1 / w = (-1 - sqrt(-3)) / 2 + static size_t rBitSize; + static mpz_class v0, v1; + static mpz_class B[2][2]; + static mpz_class r; +public: +#ifndef CYBOZU_DONT_USE_STRING + static void dump(const mpz_class& x) + { + printf("\"%s\",\n", mcl::gmp::getStr(x, 16).c_str()); + } + static void dump() + { + printf("\"%s\",\n", rw.getStr(16).c_str()); + printf("%d,\n", (int)rBitSize); + dump(v0); + dump(v1); + dump(B[0][0]); dump(B[0][1]); dump(B[1][0]); dump(B[1][1]); + dump(r); + } +#endif + /* + initGLV1() is defined in bn.hpp + */ + /* + L = lambda = p^4 + L (x, y) = (rw x, y) + */ + static void mulLambda(G& Q, const G& P) + { + F::mul(Q.x, P.x, rw); + Q.y = P.y; + Q.z = P.z; + } + /* + x = a + b * lambda mod r + */ + static void split(mpz_class& a, mpz_class& b, const mpz_class& x) + { + mpz_class t; + t = (x * v0) >> rBitSize; + b = (x * v1) >> rBitSize; + a = x - (t * B[0][0] + b * B[1][0]); + b = - (t * B[0][1] + b * B[1][1]); + } + static void mul(G& Q, const G& P, mpz_class x, bool constTime = false) + { + const int w = 5; + const size_t tblSize = 1 << (w - 2); + typedef mcl::FixedArray NafArray; + NafArray naf[2]; + mpz_class u[2]; + G tbl[2][tblSize]; + bool b; + + x %= r; + if (x == 0) { + Q.clear(); + if (!constTime) return; + } + if (x < 0) { + x += r; + } + split(u[0], u[1], x); + gmp::getNAFwidth(&b, naf[0], u[0], w); + assert(b); (void)b; + gmp::getNAFwidth(&b, naf[1], u[1], w); + assert(b); (void)b; + + tbl[0][0] = P; + mulLambda(tbl[1][0], tbl[0][0]); + { + G P2; + G::dbl(P2, P); + for (size_t i = 1; i < tblSize; i++) { + G::add(tbl[0][i], tbl[0][i - 1], P2); + mulLambda(tbl[1][i], tbl[0][i]); + } + } + const size_t maxBit = fp::max_(naf[0].size(), naf[1].size()); + Q.clear(); + for (size_t i = 0; i < maxBit; i++) { + G::dbl(Q, Q); + local::addTbl(Q, tbl[0], naf[0], maxBit - 1 - i); + local::addTbl(Q, tbl[1], naf[1], maxBit - 1 - i); + } + } + static void mulArray(G& z, const G& x, const mcl::fp::Unit *y, size_t yn, bool isNegative, bool constTime) + { + mpz_class s; + bool b; + mcl::gmp::setArray(&b, s, y, yn); + assert(b); + if (isNegative) s = -s; + mul(z, x, s, constTime); + } +}; + +// rw = 1 / w = (-1 - sqrt(-3)) / 2 +template F GLV1T::rw; +template size_t GLV1T::rBitSize; +template mpz_class GLV1T::v0; +template mpz_class GLV1T::v1; +template mpz_class GLV1T::B[2][2]; +template mpz_class GLV1T::r; + struct EcParam { const char *name; const char *p; diff --git a/test/glv_test.cpp b/test/glv_test.cpp index 79d378f..61f2062 100644 --- a/test/glv_test.cpp +++ b/test/glv_test.cpp @@ -88,7 +88,7 @@ void compareLength(const GLV2& lhs) for (int i = 1; i < 1000; i++) { r.setRand(rg); x = r.getMpz(); - GLV1::split(R0, R1, x); + mcl::bn::local::GLV1::split(R0, R1, x); lhs.split(L0, L1, x); size_t R0n = mcl::gmp::getBitSize(R0); @@ -121,8 +121,8 @@ void testGLV1() oldGlv.init(BN::param.r, BN::param.z); } - typedef mcl::bn::local::Param::GLV1 GLV1; - GLV1::init(BN::param.r, BN::param.z, BN::param.isBLS12); + typedef mcl::bn::local::GLV1 GLV1; + GLV1::initForBN(BN::param.r, BN::param.z, BN::param.isBLS12); if (!BN::param.isBLS12) { compareLength(oldGlv); }