diff --git a/include/mcl/fp_tower.hpp b/include/mcl/fp_tower.hpp index c311aae..f79cba7 100644 --- a/include/mcl/fp_tower.hpp +++ b/include/mcl/fp_tower.hpp @@ -16,6 +16,7 @@ class FpDblT : public fp::Serializable > { Unit v_[Fp::maxSize * 2]; public: static size_t getUnitSize() { return Fp::op_.N * 2; } + const fp::Unit *getUnit() const { return v_; } void dump() const { const size_t n = getUnitSize(); @@ -662,25 +663,26 @@ struct Fp2DblT { FpDbl::neg(y.a, x.a); FpDbl::neg(y.b, x.b); } - static void mul_xi(Fp2DblT& y, const Fp2DblT& x) + static void mul_xi_1C(Fp2DblT& y, const Fp2DblT& x) + { + FpDbl t; + FpDbl::add(t, x.a, x.b); + FpDbl::sub(y.a, x.a, x.b); + y.b = t; + } + static void mul_xi_genericC(Fp2DblT& y, const Fp2DblT& x) { const uint32_t xi_a = Fp2::get_xi_a(); - if (xi_a == 1) { - FpDbl t; - FpDbl::add(t, x.a, x.b); - FpDbl::sub(y.a, x.a, x.b); - y.b = t; - } else { - FpDbl t; - FpDbl::mulUnit(t, x.a, xi_a); - FpDbl::sub(t, t, x.b); - FpDbl::mulUnit(y.b, x.b, xi_a); - FpDbl::add(y.b, y.b, x.a); - y.a = t; - } + FpDbl t; + FpDbl::mulUnit(t, x.a, xi_a); + FpDbl::sub(t, t, x.b); + FpDbl::mulUnit(y.b, x.b, xi_a); + FpDbl::add(y.b, y.b, x.a); + y.a = t; } static void (*mulPre)(Fp2DblT&, const Fp2&, const Fp2&); static void (*sqrPre)(Fp2DblT&, const Fp2&); + static void (*mul_xi)(Fp2DblT&, const Fp2DblT&); static void mod(Fp2& y, const Fp2DblT& x) { FpDbl::mod(y.a, x.a); @@ -716,6 +718,18 @@ struct Fp2DblT { sqrPre = fp2Dbl_sqrPreW; } } + const uint32_t xi_a = Fp2::get_xi_a(); + switch (xi_a) { + case 1: + mul_xi = mul_xi_1C; + if (op.fp2Dbl_mul_xiA_) { + mul_xi = fp::func_ptr_cast(op.fp2Dbl_mul_xiA_); + } + break; + default: + mul_xi = mul_xi_genericC; + break; + } } /* Fp2Dbl::mulPre by FpDblT @@ -770,6 +784,7 @@ struct Fp2DblT { template void (*Fp2DblT::mulPre)(Fp2DblT&, const Fp2T&, const Fp2T&); template void (*Fp2DblT::sqrPre)(Fp2DblT&, const Fp2T&); +template void (*Fp2DblT::mul_xi)(Fp2DblT&, const Fp2DblT&); template Fp2T Fp2T::g[Fp2T::gN]; template Fp2T Fp2T::g2[Fp2T::gN]; diff --git a/include/mcl/op.hpp b/include/mcl/op.hpp index 7e314ae..76cf7f0 100644 --- a/include/mcl/op.hpp +++ b/include/mcl/op.hpp @@ -226,6 +226,7 @@ struct Op { void2u fpDbl_modA_; void3u fp2Dbl_mulPreA_; void2u fp2Dbl_sqrPreA_; + void2u fp2Dbl_mul_xiA_; size_t maxN; size_t N; size_t bitSize; @@ -314,6 +315,7 @@ struct Op { fpDbl_modA_ = 0; fp2Dbl_mulPreA_ = 0; fp2Dbl_sqrPreA_ = 0; + fp2Dbl_mul_xiA_ = 0; maxN = 0; N = 0; bitSize = 0; diff --git a/src/fp_generator.hpp b/src/fp_generator.hpp index f5e34cf..cfbb53f 100644 --- a/src/fp_generator.hpp +++ b/src/fp_generator.hpp @@ -457,6 +457,10 @@ private: op.fp2Dbl_sqrPreA_ = gen_fp2Dbl_sqrPre(); if (op.fp2Dbl_sqrPreA_) setFuncInfo(prof_, suf, "2Dbl_sqrPre", op.fp2Dbl_sqrPreA_, getCurr()); + align(16); + op.fp2Dbl_mul_xiA_ = gen_fp2Dbl_mul_xi(); + if (op.fp2Dbl_mul_xiA_) setFuncInfo(prof_, suf, "2Dbl_mul_xi", op.fp2Dbl_mul_xiA_, getCurr()); + align(16); op.fp2_mulA_ = gen_fp2_mul(); setFuncInfo(prof_, suf, "2_mul", op.fp2_mulA_, getCurr()); @@ -3173,6 +3177,13 @@ private: } } } + // y[i] &= t + void andPack(const Pack& y, const Reg64& t) + { + for (int i = 0; i < (int)y.size(); i++) { + and_(y[i], t); + } + } /* [rdx:x:t0] <- py[1:0] * x destroy x, t @@ -3647,6 +3658,55 @@ private: call(mulPreL); return func; } + void2u gen_fp2Dbl_mul_xi() + { + if (isFullBit_) return 0; + if (op_->xi_a != 1) return 0; + void2u func = getCurr(); + // y = (x.a - x.b, x.a + x.b) + StackFrame sf(this, 2, pn_ * 2, FpByte_ * 2); + Pack t1 = sf.t.sub(0, pn_); + Pack t2 = sf.t.sub(pn_, pn_); + const RegExp& ya = sf.p[0]; + const RegExp& yb = sf.p[0] + FpByte_ * 2; + const RegExp& xa = sf.p[1]; + const RegExp& xb = sf.p[1] + FpByte_ * 2; + // [rsp] = x.a + x.b + for (int i = 0; i < pn_ * 2; i++) { + mov(rax, ptr[xa + i * 8]); + if (i == 0) { + add(rax, ptr[xb + i * 8]); + } else { + adc(rax, ptr[xb + i * 8]); + } + mov(ptr[rsp + i * 8], rax); + } + // low : x.a = x.a - x.b + load_rm(t1, xa); + sub_rm(t1, xb); + store_mr(ya, t1); + // high : x.a = (x.a - x.b) % p + load_rm(t1, xa + FpByte_); + sub_rm(t1, xb + FpByte_, true); + lea(rax, ptr[rip + pL_]); + load_rm(t2, rax); // t2 = p + sbb(rax, rax); + andPack(t2, rax); + add_rr(t1, t2); // mod p + store_mr(ya + FpByte_, t1); + + // low : y.b = [rsp] + for (int i = 0; i < pn_; i++) { + mov(rax, ptr[rsp + i * 8]); + mov(ptr[yb + i * 8], rax); + } + // high : y.b = (x.a + x.b) % p + load_rm(t1, rsp + FpByte_); + lea(rax, ptr[rip + pL_]); + sub_p_mod(t2, t1, rax); + store_mr(yb + FpByte_, t2); + return func; + } void gen_fp2_add4() { assert(!isFullBit_); diff --git a/src/fp_static_code.hpp b/src/fp_static_code.hpp index d396526..ef27f1b 100644 --- a/src/fp_static_code.hpp +++ b/src/fp_static_code.hpp @@ -38,6 +38,7 @@ void mclx_Fp2_mul(Unit*, const Unit*, const Unit*); void mclx_Fp2_sqr(Unit*, const Unit*); void mclx_Fp2_mul2(Unit*, const Unit*); void mclx_Fp2_mul_xi(Unit*, const Unit*); +void mclx_Fp2Dbl_mul_xi(Unit*, const Unit*); Unit mclx_Fr_addPre(Unit*, const Unit*, const Unit*); Unit mclx_Fr_subPre(Unit*, const Unit*, const Unit*); @@ -79,6 +80,7 @@ void setStaticCode(mcl::fp::Op& op) op.fp2_sqrA_ = mclx_Fp2_sqr; op.fp2_mul2A_ = mclx_Fp2_mul2; op.fp2_mul_xiA_ = mclx_Fp2_mul_xi; + op.fp2Dbl_mul_xiA_ = mclx_Fp2Dbl_mul_xi; op.fp_preInv = mclx_Fp_preInv; } else { // Fr, sizeof(Fr) = 32 diff --git a/test/common_test.hpp b/test/common_test.hpp index afe23ee..0781ebb 100644 --- a/test/common_test.hpp +++ b/test/common_test.hpp @@ -161,8 +161,31 @@ void testABCD() } } +void testFp2Dbl_mul_xi1() +{ + if (Fp2::get_xi_a() != 1) return; + puts("testFp2Dbl_mul_xi1"); + cybozu::XorShift rg; + for (int i = 0; i < 100; i++) { + Fp a1, a2; + a1.setByCSPRNG(rg); + a2.setByCSPRNG(rg); + Fp2Dbl x; + FpDbl::mulPre(x.a, a1, a2); + a1.setByCSPRNG(rg); + a2.setByCSPRNG(rg); + FpDbl::mulPre(x.b, a1, a2); + Fp2Dbl ok; + Fp2Dbl::mul_xi_1C(x, x); + Fp2Dbl::mul_xi(x, x); + CYBOZU_TEST_EQUAL_ARRAY(ok.a.getUnit(), x.a.getUnit(), ok.a.getUnitSize()); + CYBOZU_TEST_EQUAL_ARRAY(ok.b.getUnit(), x.b.getUnit(), ok.b.getUnitSize()); + } +} + void testCommon(const G1& P, const G2& Q) { + testFp2Dbl_mul_xi1(); testABCD(); testMul2(); puts("G1");