diff --git a/include/mcl/fp_tower.hpp b/include/mcl/fp_tower.hpp index 4c25e87..d85e636 100644 --- a/include/mcl/fp_tower.hpp +++ b/include/mcl/fp_tower.hpp @@ -226,12 +226,14 @@ public: static void (*neg)(Fp2T& y, const Fp2T& x); static void (*mul)(Fp2T& z, const Fp2T& x, const Fp2T& y); static void (*sqr)(Fp2T& y, const Fp2T& x); + static void (*mul2)(Fp2T& y, const Fp2T& x); #else static void add(Fp2T& z, const Fp2T& x, const Fp2T& y) { addC(z, x, y); } static void sub(Fp2T& z, const Fp2T& x, const Fp2T& y) { subC(z, x, y); } static void neg(Fp2T& y, const Fp2T& x) { negC(y, x); } static void mul(Fp2T& z, const Fp2T& x, const Fp2T& y) { mulC(z, x, y); } static void sqr(Fp2T& y, const Fp2T& x) { sqrC(y, x); } + static void mul2(Fp2T& y, const Fp2T& x) { mul2C(y, x); } #endif static void (*mul_xi)(Fp2T& y, const Fp2T& x); static void addPre(Fp2T& z, const Fp2T& x, const Fp2T& y) { Fp::addPre(z.a, x.a, y.a); Fp::addPre(z.b, x.b, y.b); } @@ -386,6 +388,8 @@ public: if (mul == 0) mul = mulC; sqr = fp::func_ptr_cast(op.fp2_sqrA_); if (sqr == 0) sqr = sqrC; + mul2 = fp::func_ptr_cast(op.fp2_mul2A_); + if (mul2 == 0) mul2 = mul2C; mul_xi = fp::func_ptr_cast(op.fp2_mul_xiA_); #endif op.fp2_inv = fp2_invW; @@ -483,6 +487,11 @@ private: Fp::neg(y.a, x.a); Fp::neg(y.b, x.b); } + static void mul2C(Fp2T& y, const Fp2T& x) + { + Fp::mul2(y.a, x.a); + Fp::mul2(y.b, x.b); + } #if 0 /* x = a + bi, y = c + di, i^2 = -1 @@ -607,6 +616,7 @@ template void (*Fp2T::sub)(Fp2T& z, const Fp2T& x, const Fp2T& y template void (*Fp2T::neg)(Fp2T& y, const Fp2T& x); template void (*Fp2T::mul)(Fp2T& z, const Fp2T& x, const Fp2T& y); template void (*Fp2T::sqr)(Fp2T& y, const Fp2T& x); +template void (*Fp2T::mul2)(Fp2T& y, const Fp2T& x); #endif template void (*Fp2T::mul_xi)(Fp2T& y, const Fp2T& x); diff --git a/include/mcl/op.hpp b/include/mcl/op.hpp index ea5c379..fc61798 100644 --- a/include/mcl/op.hpp +++ b/include/mcl/op.hpp @@ -220,6 +220,7 @@ struct Op { void2u fp2_negA_; void3u fp2_mulA_; void2u fp2_sqrA_; + void2u fp2_mul2A_; void3u fpDbl_addA_; void3u fpDbl_subA_; void2u fpDbl_modA_; @@ -307,6 +308,7 @@ struct Op { fp2_negA_ = 0; fp2_mulA_ = 0; fp2_sqrA_ = 0; + fp2_mul2A_ = 0; fpDbl_addA_ = 0; fpDbl_subA_ = 0; fpDbl_modA_ = 0; diff --git a/src/fp_generator.hpp b/src/fp_generator.hpp index 684ecb3..64e6ebe 100644 --- a/src/fp_generator.hpp +++ b/src/fp_generator.hpp @@ -444,6 +444,10 @@ private: op.fp2_negA_ = gen_fp2_neg(); setFuncInfo(prof_, suf, "2_neg", op.fp2_negA_, getCurr()); + align(16); + op.fp2_mul2A_ = gen_fp2_mul2(); + setFuncInfo(prof_, suf, "2_mul2", op.fp2_mul2A_, getCurr()); + op.fp2_mulNF = 0; align(16); op.fp2Dbl_mulPreA_ = gen_fp2Dbl_mulPre(); @@ -919,33 +923,56 @@ private: mov(ptr [pz + (pn_ - 1) * 8], *t0); return func; } + // x = x << 1 + void shl1(const Pack& x) + { + for (int i = x.size() - 1; i > 0; i--) { + shld(x[i], x[i - 1], 1); + } + shl(x[0], 1); + } + /* + y = (x >= p[]) x - p[] : x + */ + void sub_mod(const Pack& y, const Pack& x, const RegExp& p) + { + mov_rr(y, x); + sub_rm(y, p); + cmovc_rr(y, x); + } void2u gen_mul2() { - if (isFullBit_) return 0; - if (!(pn_ == 4 || pn_ == 6)) return 0; + if (isFullBit_ || pn_ > 6) return 0; void2u func = getCurr(); - const int n = pn_ * 2 - 2; + const int n = pn_ * 2 - 1; StackFrame sf(this, 2, n); Pack x = sf.t.sub(0, pn_); load_rm(x, sf.p[1]); -#if 0 - add_rr(x, x); -#else - for (int i = pn_ - 1; i > 0; i--) { - shld(x[i], x[i - 1], 1); - } - shl(x[0], 1); -#endif + shl1(x); Pack t = sf.t.sub(pn_, n - pn_); t.append(sf.p[1]); - t.append(rax); // destroy last - mov_rr(t, x); lea(rax, ptr[rip + pL_]); - sub_rm(t, rax); - cmovc_rr(t, x); + sub_mod(t, x, rax); store_mr(sf.p[0], t); return func; } + void2u gen_fp2_mul2() + { + if (isFullBit_ || pn_ > 6) return 0; + void2u func = getCurr(); + const int n = pn_ * 2; + StackFrame sf(this, 2, n); + Pack x = sf.t.sub(0, pn_); + Pack t = sf.t.sub(pn_, pn_); + lea(rax, ptr[rip + pL_]); + for (int i = 0; i < 2; i++) { + load_rm(x, sf.p[1] + FpByte_ * i); + shl1(x); + sub_mod(t, x, rax); + store_mr(sf.p[0] + FpByte_ * i, t); + } + return func; + } void3u gen_mul() { void3u func = getCurr(); diff --git a/src/fp_static_code.hpp b/src/fp_static_code.hpp index 705e46e..d396526 100644 --- a/src/fp_static_code.hpp +++ b/src/fp_static_code.hpp @@ -36,6 +36,7 @@ void mclx_Fp2_sub(Unit*, const Unit*, const Unit*); void mclx_Fp2_neg(Unit*, const Unit*); void mclx_Fp2_mul(Unit*, const Unit*, const Unit*); void mclx_Fp2_sqr(Unit*, const Unit*); +void mclx_Fp2_mul2(Unit*, const Unit*); void mclx_Fp2_mul_xi(Unit*, const Unit*); Unit mclx_Fr_addPre(Unit*, const Unit*, const Unit*); @@ -76,6 +77,7 @@ void setStaticCode(mcl::fp::Op& op) op.fp2_mulNF = 0; op.fp2_mulA_ = mclx_Fp2_mul; op.fp2_sqrA_ = mclx_Fp2_sqr; + op.fp2_mul2A_ = mclx_Fp2_mul2; op.fp2_mul_xiA_ = mclx_Fp2_mul_xi; op.fp_preInv = mclx_Fp_preInv; } else { diff --git a/test/bench.hpp b/test/bench.hpp index 7359181..660aa13 100644 --- a/test/bench.hpp +++ b/test/bench.hpp @@ -135,6 +135,7 @@ void testBench(const G1& P, const G2& Q) CYBOZU_BENCH_C("Fp2::add ", C3, Fp2::add, xx, xx, yy); CYBOZU_BENCH_C("Fp2::sub ", C3, Fp2::sub, xx, xx, yy); CYBOZU_BENCH_C("Fp2::neg ", C3, Fp2::neg, xx, xx); + CYBOZU_BENCH_C("Fp2::mul2 ", C3, Fp2::mul2, xx, xx); CYBOZU_BENCH_C("Fp2::mul ", C3, Fp2::mul, xx, xx, yy); CYBOZU_BENCH_C("Fp2::mul_xi ", C3, Fp2::mul_xi, xx, xx); CYBOZU_BENCH_C("Fp2::sqr ", C3, Fp2::sqr, xx, xx); diff --git a/test/common_test.hpp b/test/common_test.hpp index 54d3bed..b35b9bc 100644 --- a/test/common_test.hpp +++ b/test/common_test.hpp @@ -103,8 +103,32 @@ void testMulCT(const G& P) } } +void testMul2() +{ + puts("testMul2"); + cybozu::XorShift rg; + Fp x1, x2; + x1.setByCSPRNG(rg); + x2 = x1; + for (int i = 0; i < 100; i++) { + Fp::mul2(x1, x1); + x2 += x2; + CYBOZU_TEST_EQUAL(x1, x2); + } + Fp2 y1; + y1.a = x1; + y1.b = -x1; + Fp2 y2 = y1; + for (int i = 0; i < 100; i++) { + Fp2::mul2(y1, y1); + y2 += y2; + CYBOZU_TEST_EQUAL(y1, y2); + } +} + void testCommon(const G1& P, const G2& Q) { + testMul2(); puts("G1"); testMulVec(P); puts("G2");