From 67672f4050347afc834f966633daa3134e60c926 Mon Sep 17 00:00:00 2001 From: MITSUNARI Shigeo Date: Tue, 2 Jun 2020 16:07:05 +0900 Subject: [PATCH] refactor mul --- include/mcl/bn.hpp | 73 +++++++++++++++++++++++++--------------------- include/mcl/ec.hpp | 26 ++++++++++++----- test/bn_test.cpp | 2 +- test/glv_test.cpp | 19 +++++++----- 4 files changed, 71 insertions(+), 49 deletions(-) diff --git a/include/mcl/bn.hpp b/include/mcl/bn.hpp index 2927ec2..2154051 100644 --- a/include/mcl/bn.hpp +++ b/include/mcl/bn.hpp @@ -701,22 +701,22 @@ struct GLV1 : mcl::GLV1T { GLV method for G2 and GT on BN/BLS12 */ template -struct GLV2 { +struct GLV2T { + typedef GLV2T<_Fr> GLV2; typedef _Fr Fr; typedef mcl::FixedArray NafArray; - size_t rBitSize; - mpz_class B[4][4]; - mpz_class v[4]; - mpz_class z; - mpz_class abs_z; - bool isBLS12; - GLV2() : rBitSize(0), isBLS12(false) {} - void init(const mpz_class& z, bool isBLS12 = false) + static size_t rBitSize; + static mpz_class B[4][4]; + static mpz_class v[4]; + static mpz_class z; + static mpz_class abs_z; + static bool isBLS12; + static void init(const mpz_class& z, bool isBLS12 = false) { const mpz_class& r = Fr::getOp().mp; - this->z = z; - this->abs_z = z < 0 ? -z : z; - this->isBLS12 = isBLS12; + GLV2::z = z; + GLV2::abs_z = z < 0 ? -z : z; + GLV2::isBLS12 = isBLS12; rBitSize = Fr::getOp().bitSize; rBitSize = (rBitSize + mcl::fp::UnitBitSize - 1) & ~(mcl::fp::UnitBitSize - 1);// a little better size mpz_class z2p1 = z * 2 + 1; @@ -767,7 +767,7 @@ struct GLV2 { /* u[] = [x, 0, 0, 0] - v[] * x * B */ - void split(mpz_class u[4], const mpz_class& x) const + static void split(mpz_class u[4], const mpz_class& x) { if (isBLS12) { /* @@ -802,7 +802,7 @@ struct GLV2 { } } template - void mul(T& Q, const T& P, const mpz_class& x, bool constTime = false) const + static void mul(T& Q, const T& P, const mpz_class& x, bool constTime = false) { mulVecNGLV(Q, &P, &x, 1, constTime); } @@ -812,10 +812,10 @@ struct GLV2 { Frobenius(Q, P); } template - size_t mulVecNGLV(T& z, const T *xVec, const mpz_class *yVec, size_t n, bool constTime) const + static size_t mulVecNGLV(T& z, const T *xVec, const mpz_class *yVec, size_t n, bool constTime) { if (n == 1 && constTime) { - ec::local::mul1CT, T, Fr, 4, 4>(*this, z, *xVec, *yVec); + ec::local::mul1CT(z, *xVec, *yVec); return 1; } const mpz_class& r = Fr::getOp().mp; @@ -847,14 +847,14 @@ struct GLV2 { T P2; T::dbl(P2, xVec[i]); tbl[i][0][0] = xVec[i]; - Frobenius(tbl[i][1][0], tbl[i][0][0]); - Frobenius(tbl[i][2][0], tbl[i][1][0]); - Frobenius(tbl[i][3][0], tbl[i][2][0]); + for (int k = 1; k < w; k++) { + mulLambda(tbl[i][k][0], tbl[i][k - 1][0]); + } for (size_t j = 1; j < tblSize; j++) { T::add(tbl[i][0][j], tbl[i][0][j - 1], P2); - Frobenius(tbl[i][1][j], tbl[i][0][j]); - Frobenius(tbl[i][2][j], tbl[i][1][j]); - Frobenius(tbl[i][3][j], tbl[i][2][j]); + for (int k = 1; k < w; k++) { + mulLambda(tbl[i][k][j], tbl[i][k - 1][j]); + } } } z.clear(); @@ -862,16 +862,15 @@ struct GLV2 { const size_t bit = maxBit - 1 - i; T::dbl(z, z); for (size_t j = 0; j < n; j++) { - mcl::local::addTbl(z, tbl[j][0], naf[j][0], bit); - mcl::local::addTbl(z, tbl[j][1], naf[j][1], bit); - mcl::local::addTbl(z, tbl[j][2], naf[j][2], bit); - mcl::local::addTbl(z, tbl[j][3], naf[j][3], bit); + for (int k = 0; k < w; k++) { + mcl::local::addTbl(z, tbl[j][k], naf[j][k], bit); + } } } return n; } - void pow(Fp12& z, const Fp12& x, const mpz_class& y, bool constTime = false) const + static void pow(Fp12& z, const Fp12& x, const mpz_class& y, bool constTime = false) { typedef GroupMtoA AG; // as additive group AG& _z = static_cast(z); @@ -880,6 +879,13 @@ struct GLV2 { } }; +template size_t GLV2T::rBitSize = 0; +template mpz_class GLV2T::B[4][4]; +template mpz_class GLV2T::v[4]; +template mpz_class GLV2T::z; +template mpz_class GLV2T::abs_z; +template bool GLV2T::isBLS12 = false; + struct Param { CurveParam cp; mpz_class z; @@ -889,7 +895,6 @@ struct Param { mpz_class p; mpz_class r; local::MapTo mapTo; - local::GLV2 glv2; // for G2 Frobenius Fp2 g2; Fp2 g3; @@ -1001,7 +1006,7 @@ struct Param { mapTo.init(2 * p - r, z, cp.curveType); } GLV1::initForBN(z, isBLS12, cp.curveType); - glv2.init(z, isBLS12); + GLV2T::init(z, isBLS12); basePoint.clear(); *pb = true; } @@ -1049,6 +1054,8 @@ static local::Param& nonConstParam = local::StaticVar<>::param; namespace local { +typedef GLV2T GLV2; + inline void mulArrayGLV2(G2& z, const G2& x, const mcl::fp::Unit *y, size_t yn, bool isNegative, bool constTime) { mpz_class s; @@ -1056,7 +1063,7 @@ inline void mulArrayGLV2(G2& z, const G2& x, const mcl::fp::Unit *y, size_t yn, mcl::gmp::setArray(&b, s, y, yn); assert(b); if (isNegative) s = -s; - BN::param.glv2.mul(z, x, s, constTime); + GLV2::mul(z, x, s, constTime); } inline void powArrayGLV2(Fp12& z, const Fp12& x, const mcl::fp::Unit *y, size_t yn, bool isNegative, bool constTime) { @@ -1065,12 +1072,12 @@ inline void powArrayGLV2(Fp12& z, const Fp12& x, const mcl::fp::Unit *y, size_t mcl::gmp::setArray(&b, s, y, yn); assert(b); if (isNegative) s = -s; - BN::param.glv2.pow(z, x, s, constTime); + GLV2::pow(z, x, s, constTime); } inline size_t mulVecNGLV2(G2& z, const G2 *xVec, const mpz_class *yVec, size_t n, bool constTime) { - return BN::param.glv2.mulVecNGLV(z, xVec, yVec, n, constTime); + return GLV2::mulVecNGLV(z, xVec, yVec, n, constTime); } inline size_t powVecNGLV2(Fp12& z, const Fp12 *xVec, const mpz_class *yVec, size_t n, bool constTime) @@ -1078,7 +1085,7 @@ inline size_t powVecNGLV2(Fp12& z, const Fp12 *xVec, const mpz_class *yVec, size typedef GroupMtoA AG; // as additive group AG& _z = static_cast(z); const AG *_xVec = static_cast(xVec); - return BN::param.glv2.mulVecNGLV(_z, _xVec, yVec, n, constTime); + return GLV2::mulVecNGLV(_z, _xVec, yVec, n, constTime); } /* diff --git a/include/mcl/ec.hpp b/include/mcl/ec.hpp index 3678993..33e02df 100644 --- a/include/mcl/ec.hpp +++ b/include/mcl/ec.hpp @@ -72,19 +72,25 @@ bool get_a_flag(const mcl::Fp2T& x) return get_a_flag(x.b); // x = a + bi } +/* + Q = x P + splitN = 2(G1) or 4(G2) + w : window size +*/ template -void mul1CT(const GLV& glv, G& Q, const G& P, const mpz_class& x) +void mul1CT(G& Q, const G& P, const mpz_class& x) { const mpz_class& r = F::getOp().mp; const size_t tblSize = 1 << w; G tbl[splitN][tblSize]; bool negTbl[splitN]; mpz_class u[splitN]; - mpz_class y = x % r; + mpz_class y; + F::getOp().modp.modp(y, x); if (y < 0) { y += r; } - glv.split(u, y); + GLV::split(u, y); for (int i = 0; i < splitN; i++) { if (u[i] < 0) { gmp::neg(u[i], u[i]); @@ -1532,10 +1538,12 @@ public: Q.z = P.z; } /* - x = a + b * lambda mod r + x = u[0] + u[1] * lambda mod r */ - static void split(mpz_class& a, mpz_class& b, const mpz_class& x) + static void split(mpz_class u[2], const mpz_class& x) { + mpz_class& a = u[0]; + mpz_class& b = u[1]; mpz_class t; t = (x * v0) >> rBitSize; b = (x * v1) >> rBitSize; @@ -1546,8 +1554,12 @@ public: { mulVecNGLV(Q, &P, &x, 1, constTime); } - static inline size_t mulVecNGLV(Ec& z, const Ec *xVec, const mpz_class *yVec, size_t n, bool /*constTime*/ = false) + static inline size_t mulVecNGLV(Ec& z, const Ec *xVec, const mpz_class *yVec, size_t n, bool constTime) { + if (n == 1 && constTime) { + ec::local::mul1CT, Ec, _Fr, 2, 4>(z, *xVec, *yVec); + return 1; + } const size_t N = mcl::fp::maxMulVecNGLV; if (n > N) n = N; const int w = 5; @@ -1565,7 +1577,7 @@ public: if (y < 0) { y += r; } - split(u[0], u[1], y); + split(u, y); for (int j = 0; j < 2; j++) { gmp::getNAFwidth(&b, naf[i][j], u[j], w); diff --git a/test/bn_test.cpp b/test/bn_test.cpp index e6139b3..1a503c5 100644 --- a/test/bn_test.cpp +++ b/test/bn_test.cpp @@ -211,7 +211,7 @@ void testFp12pow(const G1& P, const G2& Q) x.setRand(rg); mpz_class xm = x.getMpz(); Fp12::pow(e1, e, xm); - BN::param.glv2.pow(e2, e, xm); + local::GLV2::pow(e2, e, xm); CYBOZU_TEST_EQUAL(e1, e2); } } diff --git a/test/glv_test.cpp b/test/glv_test.cpp index 78bb821..59bdcdd 100644 --- a/test/glv_test.cpp +++ b/test/glv_test.cpp @@ -83,12 +83,15 @@ void compareLength(const GLV2& lhs) int lt = 0; int eq = 0; int gt = 0; - mpz_class R0, R1, L0, L1, x; + mpz_class R[2]; + mpz_class L0, L1, x; + mpz_class& R0 = R[0]; + mpz_class& R1 = R[1]; Fr r; for (int i = 1; i < 1000; i++) { r.setRand(rg); x = r.getMpz(); - mcl::bn::local::GLV1::split(R0, R1, x); + mcl::bn::local::GLV1::split(R,x); lhs.split(L0, L1, x); size_t R0n = mcl::gmp::getBitSize(R0); @@ -162,33 +165,33 @@ void testGLV1() */ void testGLV2() { + typedef local::GLV2 GLV2; G2 Q0, Q1, Q2; mpz_class z = BN::param.z; mpz_class r = BN::param.r; - mcl::bn::local::GLV2 glv2; - glv2.init(z, BN::param.isBLS12); + GLV2::init(z, BN::param.isBLS12); mpz_class n; cybozu::XorShift rg; mapToG2(Q0, 1); for (int i = -10; i < 10; i++) { n = i; G2::mulGeneric(Q1, Q0, n); - glv2.mul(Q2, Q0, n); + GLV2::mul(Q2, Q0, n); CYBOZU_TEST_EQUAL(Q1, Q2); } for (int i = 1; i < 100; i++) { - mcl::gmp::getRand(n, glv2.rBitSize, rg); + mcl::gmp::getRand(n, GLV2::rBitSize, rg); n %= r; n -= r/2; mapToG2(Q0, i); G2::mulGeneric(Q1, Q0, n); - glv2.mul(Q2, Q0, n); + GLV2::mul(Q2, Q0, n); CYBOZU_TEST_EQUAL(Q1, Q2); } Fr s; mapToG2(Q0, 123); CYBOZU_BENCH_C("G2::mul", 1000, Q2 = Q0; s.setRand(rg); G2::mulGeneric, Q2, Q1, s.getMpz()); - CYBOZU_BENCH_C("G2::glv", 1000, Q1 = Q0; s.setRand(rg); glv2.mul, Q2, Q1, s.getMpz()); + CYBOZU_BENCH_C("G2::glv", 1000, Q1 = Q0; s.setRand(rg); GLV2::mul, Q2, Q1, s.getMpz()); } void testGT()