diff --git a/include/mcl/fp_tower.hpp b/include/mcl/fp_tower.hpp
index 63738a3..8267caa 100644
--- a/include/mcl/fp_tower.hpp
+++ b/include/mcl/fp_tower.hpp
@@ -121,20 +121,22 @@ public:
 	static void (*add)(FpDblT& z, const FpDblT& x, const FpDblT& y);
 	static void (*sub)(FpDblT& z, const FpDblT& x, const FpDblT& y);
 	static void (*mod)(Fp& z, const FpDblT& xy);
+	static void (*addPre)(FpDblT& z, const FpDblT& x, const FpDblT& y);
+	static void (*subPre)(FpDblT& z, const FpDblT& x, const FpDblT& y);
 	static void addC(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_add(z.v_, x.v_, y.v_, Fp::op_.p); }
 	static void subC(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_sub(z.v_, x.v_, y.v_, Fp::op_.p); }
 	static void modC(Fp& z, const FpDblT& xy) { Fp::op_.fpDbl_mod(z.v_, xy.v_, Fp::op_.p); }
+	static void addPreC(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_addPre(z.v_, x.v_, y.v_); }
+	static void subPreC(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_subPre(z.v_, x.v_, y.v_); }
 #else
 	static void add(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_add(z.v_, x.v_, y.v_, Fp::op_.p); }
 	static void sub(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_sub(z.v_, x.v_, y.v_, Fp::op_.p); }
 	static void mod(Fp& z, const FpDblT& xy) { Fp::op_.fpDbl_mod(z.v_, xy.v_, Fp::op_.p); }
+	static void addPre(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_addPre(z.v_, x.v_, y.v_); }
+	static void subPre(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_subPre(z.v_, x.v_, y.v_); }
 #endif
-	static void addPreC(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_addPre(z.v_, x.v_, y.v_); }
-	static void subPreC(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_subPre(z.v_, x.v_, y.v_); }
 	static void mulPreC(FpDblT& xy, const Fp& x, const Fp& y) { Fp::op_.fpDbl_mulPre(xy.v_, x.v_, y.v_); }
 	static void sqrPreC(FpDblT& xx, const Fp& x) { Fp::op_.fpDbl_sqrPre(xx.v_, x.v_); }
-	static void (*addPre)(FpDblT& z, const FpDblT& x, const FpDblT& y);
-	static void (*subPre)(FpDblT& z, const FpDblT& x, const FpDblT& y);
 	/*
 		mul(z, x, y) = mulPre(xy, x, y) + mod(z, xy)
 	*/
@@ -155,17 +157,11 @@ public:
 		if (sub == 0) sub = subC;
 		mod = (void (*)(Fp&, const FpDblT&))op.fpDbl_modA_;
 		if (mod == 0) mod = modC;
+		addPre = (void (*)(FpDblT&, const FpDblT&, const FpDblT&))op.fpDbl_addPre;
+		if (addPre == 0) addPre = addPreC;
+		subPre = (void (*)(FpDblT&, const FpDblT&, const FpDblT&))op.fpDbl_subPre;
+		if (subPre == 0) subPre = subPreC;
 #endif
-		if (op.fpDbl_addPreA_) {
-			addPre = (void (*)(FpDblT&, const FpDblT&, const FpDblT&))op.fpDbl_addPreA_;
-		} else {
-			addPre = addPreC;
-		}
-		if (op.fpDbl_subPreA_) {
-			subPre = (void (*)(FpDblT&, const FpDblT&, const FpDblT&))op.fpDbl_subPreA_;
-		} else {
-			subPre = subPreC;
-		}
 		if (op.fpDbl_mulPreA_) {
 			mulPre = (void (*)(FpDblT&, const Fp&, const Fp&))op.fpDbl_mulPreA_;
 		} else {
@@ -185,9 +181,9 @@ public:
 template<class Fp> void (*FpDblT<Fp>::add)(FpDblT&, const FpDblT&, const FpDblT&);
 template<class Fp> void (*FpDblT<Fp>::sub)(FpDblT&, const FpDblT&, const FpDblT&);
 template<class Fp> void (*FpDblT<Fp>::mod)(Fp&, const FpDblT&);
-#endif
 template<class Fp> void (*FpDblT<Fp>::addPre)(FpDblT&, const FpDblT&, const FpDblT&);
 template<class Fp> void (*FpDblT<Fp>::subPre)(FpDblT&, const FpDblT&, const FpDblT&);
+#endif
 template<class Fp> void (*FpDblT<Fp>::mulPre)(FpDblT&, const Fp&, const Fp&);
 template<class Fp> void (*FpDblT<Fp>::sqrPre)(FpDblT&, const Fp&);
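The fp_tower.hpp hunks above drop the dedicated fpDbl_addPreA_/fpDbl_subPreA_ slots: FpDblT<Fp>::init now casts op.fpDbl_addPre/op.fpDbl_subPre directly and keeps addPreC/subPreC as the portable fallback. The selection logic amounts to the following standalone sketch (hypothetical names and a two-limb toy type, not mcl code):

#include <stdio.h>
#include <stdint.h>

typedef uint64_t Unit;
struct Dbl { Unit v_[2]; };
typedef void (*AddPreFunc)(Dbl&, const Dbl&, const Dbl&);

// portable fallback, the analogue of FpDblT::addPreC (no modular reduction, just limb add)
static void addPreC(Dbl& z, const Dbl& x, const Dbl& y)
{
	Unit c = 0;
	for (int i = 0; i < 2; i++) {
		Unit t = x.v_[i] + c;
		c = (t < c);
		z.v_[i] = t + y.v_[i];
		c += (z.v_[i] < t);
	}
}

static AddPreFunc addPre; // the slot FpDblT::addPre corresponds to

static void initAddPre(AddPreFunc backend) // backend may be 0 when no asm routine was generated
{
	addPre = backend ? backend : addPreC; // mcl writes this as: if (addPre == 0) addPre = addPreC;
}

int main()
{
	initAddPre(0); // no JIT routine available, so the C fallback is selected
	Dbl z, x = {{~0ull, 1}}, y = {{2, 3}};
	addPre(z, x, y);
	printf("%llx %llx\n", (unsigned long long)z.v_[0], (unsigned long long)z.v_[1]); // prints: 1 5
	return 0;
}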
diff --git a/include/mcl/op.hpp b/include/mcl/op.hpp
index d108d1a..afbc8a4 100644
--- a/include/mcl/op.hpp
+++ b/include/mcl/op.hpp
@@ -193,8 +193,6 @@ struct Op {
 	void2u fp2_sqrA_;
 	void3u fpDbl_addA_;
 	void3u fpDbl_subA_;
-	void3u fpDbl_addPreA_;
-	void3u fpDbl_subPreA_;
 	void3u fpDbl_mulPreA_;
 	void2u fpDbl_sqrPreA_;
 	void2u fpDbl_modA_;
@@ -282,8 +280,6 @@ struct Op {
 		fp2_sqrA_ = 0;
 		fpDbl_addA_ = 0;
 		fpDbl_subA_ = 0;
-		fpDbl_addPreA_ = 0;
-		fpDbl_subPreA_ = 0;
 		fpDbl_mulPreA_ = 0;
 		fpDbl_sqrPreA_ = 0;
 		fpDbl_modA_ = 0;
diff --git a/src/fp_generator.hpp b/src/fp_generator.hpp
index 80410e0..df058fa 100644
--- a/src/fp_generator.hpp
+++ b/src/fp_generator.hpp
@@ -287,97 +287,31 @@ private:
 		// code from here
 		setSize(4096);
 		assert((getCurr<size_t>() & 4095) == 0);
-		op.fp_addPre = getCurr<u3u>();
-		gen_addSubPre(true, pn_);
-		align(16);
-		op.fp_subPre = getCurr<u3u>();
-		gen_addSubPre(false, pn_);
-		align(16);
-		op.fp_sub = getCurr<void4u>();
-		op.fp_subA_ = getCurr<void3u>();
-		gen_fp_sub();
-		align(16);
-		op.fp_add = getCurr<void4u>();
-		op.fp_addA_ = getCurr<void3u>();
-		gen_fp_add();
+		op.fp_addPre = gen_addSubPre(true, pn_);
+		op.fp_subPre = gen_addSubPre(false, pn_);
+		op.fp_subA_ = gen_fp_sub();
+		op.fp_addA_ = gen_fp_add();
 
-		align(16);
-		op.fp_shr1 = getCurr<void2u>();
-		gen_shr1();
+		op.fp_shr1 = gen_shr1();
 
-		align(16);
-		op.fp_negA_ = getCurr<void2u>();
-		gen_fp_neg();
+		op.fp_negA_ = gen_fp_neg();
 
+		const void* func = 0;
 		// setup fp_tower
 		op.fp2_mulNF = 0;
-		if (pn_ <= 4 || (pn_ == 6 && !isFullBit_)) {
-			align(16);
-			op.fpDbl_addA_ = getCurr<void3u>();
-			gen_fpDbl_add();
-			align(16);
-			op.fpDbl_subA_ = getCurr<void3u>();
-			gen_fpDbl_sub();
-		}
-		if (op.isFullBit) {
-			op.fpDbl_addPre = 0;
-			op.fpDbl_subPre = 0;
-		} else {
-			align(16);
-			op.fpDbl_addPreA_ = getCurr<void3u>();
-			gen_addSubPre(true, pn_ * 2);
-			align(16);
-			op.fpDbl_subPreA_ = getCurr<void3u>();
-			gen_addSubPre(false, pn_ * 2);
-		}
-		if ((useMulx_ && op.N == 2) || op.N == 3 || op.N == 4 || (useAdx_ && op.N == 6)) {
-			align(16);
-			op.fpDbl_mulPreA_ = getCurr<void3u>();
-			if (op.N == 4) {
-				/*
-					fpDbl_mulPre is available as C function
-					this function calls mulPreL directly.
-				*/
-				StackFrame sf(this, 3, 10 | UseRDX, 0, false);
-				mulPre4(gp0, gp1, gp2, sf.t);
-				sf.close(); // make epilog
-			L(mulPreL); // called only from asm code
-				mulPre4(gp0, gp1, gp2, sf.t);
-				ret();
-			} else if (op.N == 6 && useAdx_) {
-				StackFrame sf(this, 3, 10 | UseRDX, 0, false);
-				call(mulPreL);
-				sf.close(); // make epilog
-			L(mulPreL); // called only from asm code
-				mulPre6(sf.t);
-				ret();
-			} else {
-				gen_fpDbl_mulPre();
-			}
-		}
-		if (op.N == 2 || op.N == 3 || op.N == 4 || (op.N == 6 && !isFullBit_ && useAdx_)) {
-			align(16);
-			op.fpDbl_modA_ = getCurr<void2u>();
-			if (op.N == 4) {
-				StackFrame sf(this, 3, 10 | UseRDX, 0, false);
-				call(fpDbl_modL);
-				sf.close();
-			L(fpDbl_modL);
-				gen_fpDbl_mod4(gp0, gp1, sf.t, gp2);
-				ret();
-			} else if (op.N == 6 && !isFullBit_ && useAdx_) {
-				StackFrame sf(this, 3, 10 | UseRDX, 0, false);
-				call(fpDbl_modL);
-				sf.close();
-			L(fpDbl_modL);
-				Pack t = sf.t;
-				t.append(gp2);
-				gen_fpDbl_mod6(gp0, gp1, t);
-				ret();
-			} else {
-				gen_fpDbl_mod(op);
-			}
-		}
+		func = gen_fpDbl_add();
+		if (func) op.fpDbl_addA_ = reinterpret_cast<void3u>(func);
+		func = gen_fpDbl_sub();
+		if (func) op.fpDbl_subA_ = reinterpret_cast<void3u>(func);
+		op.fpDbl_addPre = gen_addSubPre(true, pn_ * 2);
+		op.fpDbl_subPre = gen_addSubPre(false, pn_ * 2);
+
+		func = gen_fpDbl_mulPre();
+		if (func) op.fpDbl_mulPreA_ = reinterpret_cast<void3u>(func);
+
+		func = gen_fpDbl_mod(op);
+		if (func) op.fpDbl_modA_ = reinterpret_cast<void2u>(func);
+		if (op.N > 4) return;
 		align(16);
 		op.fp_mul = getCurr<void4u>(); // used in toMont/fromMont
@@ -389,7 +323,7 @@ private:
 			op.fpDbl_sqrPreA_ = getCurr<void2u>();
 			gen_fpDbl_sqrPre(op);
 		}
-//		if (op.N > 4) return;
+		if (op.N > 4) return;
 		align(16);
 		op.fp_sqrA_ = getCurr<void2u>();
 		gen_sqr();
@@ -425,14 +359,18 @@ private:
 			gen_fp2_mul_xi4();
 		}
 	}
-	void gen_addSubPre(bool isAdd, int n)
+	u3u gen_addSubPre(bool isAdd, int n)
 	{
+//		if (isFullBit_) return 0;
+		align(16);
+		u3u func = getCurr<u3u>();
 		StackFrame sf(this, 3);
 		if (isAdd) {
 			gen_raw_add(sf.p[0], sf.p[1], sf.p[2], rax, n);
 		} else {
 			gen_raw_sub(sf.p[0], sf.p[1], sf.p[2], rax, n);
 		}
+		return func;
 	}
 	/*
 		pz[] = px[] + py[]
@@ -702,15 +640,17 @@ private:
 		t2.append(px); // destory after used
 		gen_raw_fp_add6(pz, px, py, 0, t1, t2, false);
 	}
-	void gen_fp_add()
+	void3u gen_fp_add()
 	{
+		align(16);
+		void3u func = getCurr<void3u>();
 		if (pn_ <= 4) {
 			gen_fp_add_le4();
-			return;
+			return func;
 		}
 		if (pn_ == 6) {
 			gen_fp_add6();
-			return;
+			return func;
 		}
 		StackFrame sf(this, 3, 0, pn_ * 8);
 		const Reg64& pz = sf.p[0];
@@ -746,9 +686,12 @@ private:
 		L(".exit");
 #endif
 		outLocalLabel();
+		return func;
 	}
-	void gen_fpDbl_add()
+	const void* gen_fpDbl_add()
 	{
+		align(16);
+		const void* func = getCurr();
 		if (pn_ <= 4) {
 			int tn = pn_ * 2 + (isFullBit_ ? 1 : 0);
 			StackFrame sf(this, 3, tn);
@@ -757,6 +700,7 @@
 			const Reg64& py = sf.p[2];
 			gen_raw_add(pz, px, py, rax, pn_);
 			gen_raw_fp_add(pz + 8 * pn_, px + 8 * pn_, py + 8 * pn_, sf.t, true);
+			return func;
 		} else if (pn_ == 6 && !isFullBit_) {
 			StackFrame sf(this, 3, 10);
 			const Reg64& pz = sf.p[0];
@@ -768,13 +712,14 @@
 			t2.append(rax);
 			t2.append(py);
 			gen_raw_fp_add6(pz, px, py, pn_ * 8, t1, t2, true);
-		} else {
-			assert(0);
-			exit(1);
+			return func;
 		}
+		return 0;
 	}
-	void gen_fpDbl_sub()
+	const void* gen_fpDbl_sub()
 	{
+		align(16);
+		const void* func = getCurr();
 		if (pn_ <= 4) {
 			int tn = pn_ * 2;
 			StackFrame sf(this, 3, tn);
 			const Reg64& pz = sf.p[0];
@@ -783,6 +728,7 @@
 			const Reg64& py = sf.p[2];
 			gen_raw_sub(pz, px, py, rax, pn_);
 			gen_raw_fp_sub(pz + 8 * pn_, px + 8 * pn_, py + 8 * pn_, sf.t, true);
+			return func;
 		} else if (pn_ == 6) {
 			StackFrame sf(this, 3, 4);
 			const Reg64& pz = sf.p[0];
@@ -793,10 +739,9 @@
 			t.append(rax);
 			t.append(px);
 			gen_raw_fp_sub6(pz, px, py, pn_ * 8, t, true);
-		} else {
-			assert(0);
-			exit(1);
+			return func;
 		}
+		return 0;
 	}
 	void gen_raw_fp_sub6(const Reg64& pz, const Reg64& px, const Reg64& py, int offset, const Pack& t, bool withCarry)
 	{
@@ -821,15 +766,17 @@
 		t.append(px); // |t| = 6
 		gen_raw_fp_sub6(pz, px, py, 0, t, false);
 	}
-	void gen_fp_sub()
+	void3u gen_fp_sub()
 	{
+		align(16);
+		void3u func = getCurr<void3u>();
 		if (pn_ <= 4) {
 			gen_fp_sub_le4();
-			return;
+			return func;
 		}
 		if (pn_ == 6) {
 			gen_fp_sub6();
-			return;
+			return func;
 		}
 		StackFrame sf(this, 3);
 		const Reg64& pz = sf.p[0];
@@ -842,14 +789,20 @@
 		mov(px, (size_t)p_);
 		gen_raw_add(pz, pz, px, rax, pn_);
 	L(exit);
+		return func;
 	}
-	void gen_fp_neg()
+	void2u gen_fp_neg()
 	{
+		align(16);
+		void2u func = getCurr<void2u>();
 		StackFrame sf(this, 2, UseRDX | pn_);
 		gen_raw_neg(sf.p[0], sf.p[1], sf.t);
+		return func;
 	}
-	void gen_shr1()
+	void2u gen_shr1()
 	{
+		align(16);
+		void2u func = getCurr<void2u>();
 		const int c = 1;
 		StackFrame sf(this, 2, 1);
 		const Reg64 *t0 = &rax;
@@ -865,6 +818,7 @@
 		}
 		shr(*t0, c);
 		mov(ptr [pz + (pn_ - 1) * 8], *t0);
+		return func;
 	}
 	void gen_mul()
 	{
@@ -874,10 +828,10 @@
 			fpDbl_mod_NIST_P192(sf.p[0], rsp, sf.t);
 		}
 		if (pn_ == 3) {
-			gen_montMul3(p_, rp_);
+			gen_montMul3();
 		} else if (pn_ == 4) {
-			gen_montMul4(p_, rp_);
-#if 0
+			gen_montMul4();
+#if 1
 		} else if (pn_ == 6 && useAdx_) {
 //			gen_montMul6(p_, rp_);
 			StackFrame sf(this, 3, 10 | UseRDX, (1 + 12) * 8);
@@ -1160,38 +1114,51 @@
 		movq(z, xm0);
 		store_mr(z, Pack(t10, t9, t8, t4));
 	}
-	void gen_fpDbl_mod(const mcl::fp::Op& op)
+	const void* gen_fpDbl_mod(const mcl::fp::Op& op)
 	{
+		align(16);
+		const void* func = getCurr();
 		if (op.primeMode == PM_NIST_P192) {
 			StackFrame sf(this, 2, 6 | UseRDX);
 			fpDbl_mod_NIST_P192(sf.p[0], sf.p[1], sf.t);
-			return;
+			return func;
 		}
 #if 0
 		if (op.primeMode == PM_NIST_P521) {
 			StackFrame sf(this, 2, 8 | UseRDX);
 			fpDbl_mod_NIST_P521(sf.p[0], sf.p[1], sf.t);
-			return;
+			return func;
 		}
 #endif
-		switch (pn_) {
-		case 2:
+		if (pn_ == 2) {
 			gen_fpDbl_mod2();
-			break;
-		case 3:
+			return func;
+		}
+		if (pn_ == 3) {
 			gen_fpDbl_mod3();
-			break;
-#if 0
-		case 4:
-			{
-				StackFrame sf(this, 3, 10 | UseRDX);
-				gen_fpDbl_mod4(gp0, gp1, sf.t, gp2);
-			}
-			break;
-#endif
-		default:
-			throw cybozu::Exception("gen_fpDbl_mod:not support") << pn_;
+			return func;
 		}
+		if (pn_ == 4) {
+			StackFrame sf(this, 3, 10 | UseRDX, 0, false);
+			call(fpDbl_modL);
+			sf.close();
+		L(fpDbl_modL);
+			gen_fpDbl_mod4(gp0, gp1, sf.t, gp2);
+			ret();
+			return func;
+		}
+		if (pn_ == 6 && !isFullBit_ && useAdx_) {
+			StackFrame sf(this, 3, 10 | UseRDX, 0, false);
+			call(fpDbl_modL);
+			sf.close();
+		L(fpDbl_modL);
+			Pack t = sf.t;
+			t.append(gp2);
+			gen_fpDbl_mod6(gp0, gp1, t);
+			ret();
+			return func;
+		}
+		return 0;
 	}
 	void gen_sqr()
 	{
@@ -1255,7 +1222,7 @@
 		z[0..3] <- montgomery(x[0..3], y[0..3])
 		destroy gt0, ..., gt9, xm0, xm1, p2
 	*/
-	void gen_montMul4(const uint64_t *p, uint64_t pp)
+	void gen_montMul4()
 	{
 		StackFrame sf(this, 3, 10 | UseRDX, 0, false);
 		call(fp_mulL);
@@ -1277,22 +1244,22 @@
 
 	L(fp_mulL);
 		movq(xm0, p0); // save p0
-		mov(p0, (uint64_t)p);
+		mov(p0, pL_);
 		movq(xm1, p2);
 		mov(p2, ptr [p2]);
-		montgomery4_1(pp, t0, t7, t3, t2, t1, p1, p2, p0, t4, t5, t6, t8, t9, true, xm2);
+		montgomery4_1(rp_, t0, t7, t3, t2, t1, p1, p2, p0, t4, t5, t6, t8, t9, true, xm2);
 
 		movq(p2, xm1);
 		mov(p2, ptr [p2 + 8]);
-		montgomery4_1(pp, t1, t0, t7, t3, t2, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2);
+		montgomery4_1(rp_, t1, t0, t7, t3, t2, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2);
 
 		movq(p2, xm1);
 		mov(p2, ptr [p2 + 16]);
-		montgomery4_1(pp, t2, t1, t0, t7, t3, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2);
+		montgomery4_1(rp_, t2, t1, t0, t7, t3, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2);
 
 		movq(p2, xm1);
 		mov(p2, ptr [p2 + 24]);
-		montgomery4_1(pp, t3, t2, t1, t0, t7, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2);
+		montgomery4_1(rp_, t3, t2, t1, t0, t7, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2);
 
 		// [t7:t3:t2:t1:t0]
 		mov(t4, t0);
@@ -1315,7 +1282,7 @@
 		z[0..2] <- montgomery(x[0..2], y[0..2])
 		destroy gt0, ..., gt9, xm0, xm1, p2
 	*/
-	void gen_montMul3(const uint64_t *p, uint64_t pp)
+	void gen_montMul3()
 	{
 		StackFrame sf(this, 3, 10 | UseRDX);
 		const Reg64& p0 = sf.p[0];
@@ -1334,15 +1301,15 @@
 		const Reg64& t9 = sf.t[9];
 
 		movq(xm0, p0); // save p0
-		mov(t7, (uint64_t)p);
+		mov(t7, pL_);
 		mov(t9, ptr [p2]);
 		// c3, c2, c1, c0, px, y, p,
-		montgomery3_1(pp, t0, t3, t2, t1, p1, t9, t7, t4, t5, t6, t8, p0, true);
+		montgomery3_1(rp_, t0, t3, t2, t1, p1, t9, t7, t4, t5, t6, t8, p0, true);
 
 		mov(t9, ptr [p2 + 8]);
-		montgomery3_1(pp, t1, t0, t3, t2, p1, t9, t7, t4, t5, t6, t8, p0, false);
+		montgomery3_1(rp_, t1, t0, t3, t2, p1, t9, t7, t4, t5, t6, t8, p0, false);
 
 		mov(t9, ptr [p2 + 16]);
-		montgomery3_1(pp, t2, t1, t0, t3, p1, t9, t7, t4, t5, t6, t8, p0, false);
+		montgomery3_1(rp_, t2, t1, t0, t3, p1, t9, t7, t4, t5, t6, t8, p0, false);
 		// [(t3):t2:t1:t0]
 		mov(t4, t0);
@@ -1607,6 +1574,7 @@
 		if (useMulx_) {
 			mulPack(pz, px, py, Pack(t2, t1, t0));
+#if 0 // a little slow
 			if (useAdx_) {
 				// [t2:t1:t0]
 				mulPackAdd(pz + 8 * 1, px + 8 * 1, py, t3, Pack(t2, t1, t0));
@@ -1616,6 +1584,7 @@
 				store_mr(pz + 8 * 3, Pack(t4, t3, t2));
 				return;
 			}
+#endif
 		} else {
 			mov(t5, ptr [px]);
 			mov(a, ptr [py + 8 * 0]);
@@ -2122,20 +2091,43 @@
 #endif
 		jmp((void*)op.fpDbl_mulPreA_);
 	}
-	void gen_fpDbl_mulPre()
+	const void* gen_fpDbl_mulPre()
 	{
-		if (useMulx_ && pn_ == 2) {
+		align(16);
+		const void* func = getCurr();
+		if (pn_ == 2 && useMulx_) {
 			StackFrame sf(this, 3, 5 | UseRDX);
 			mulPre2(sf.p[0], sf.p[1], sf.p[2], sf.t);
-			return;
+			return func;
 		}
 		if (pn_ == 3) {
 			StackFrame sf(this, 3, 10 | UseRDX);
 			mulPre3(sf.p[0], sf.p[1], sf.p[2], sf.t);
-			return;
+			return func;
 		}
-		assert(0);
-		exit(1);
+		if (pn_ == 4) {
+			/*
+				fpDbl_mulPre is available as C function
+				this function calls mulPreL directly.
+			*/
+			StackFrame sf(this, 3, 10 | UseRDX, 0, false);
+			mulPre4(gp0, gp1, gp2, sf.t);
+			sf.close(); // make epilog
+		L(mulPreL); // called only from asm code
+			mulPre4(gp0, gp1, gp2, sf.t);
+			ret();
+			return func;
+		}
+		if (pn_ == 6 && useAdx_) {
+			StackFrame sf(this, 3, 10 | UseRDX, 0, false);
+			call(mulPreL);
+			sf.close(); // make epilog
+		L(mulPreL); // called only from asm code
+			mulPre6(sf.t);
+			ret();
+			return func;
+		}
+		return 0;
 	}
 	static inline void debug_put_inner(const uint64_t *ptr, int n)
 	{
diff --git a/test/bench.hpp b/test/bench.hpp
index 65850fa..1ca9e5c 100644
--- a/test/bench.hpp
+++ b/test/bench.hpp
@@ -8,7 +8,7 @@ void testBench(const G1& P, const G2& Q)
 	pairing(e1, P, Q);
 	Fp12::pow(e2, e1, 12345);
 	const int C = 500;
-	const int C3 = 10000;
+	const int C3 = 3000;
 	Fp x, y;
 	x.setHashOf("abc");
 	y.setHashOf("xyz");
diff --git a/test/bls12_test.cpp b/test/bls12_test.cpp
index 501603a..c7e7615 100644
--- a/test/bls12_test.cpp
+++ b/test/bls12_test.cpp
@@ -697,7 +697,7 @@ if(0){
 	}
 //	CYBOZU_BENCH_C("subDbl", 10000000, FpDbl::sub, dx, dx, dx);
 	CYBOZU_BENCH_C("mul", 10000000 / n, f, xv, yv, xv);
-	CYBOZU_BENCH_C("mulPre", 10000000, FpDbl::mulPre, dx, xv[0], yv[0]);
+	CYBOZU_BENCH_C("mulPre", 100000000, FpDbl::mulPre, dx, xv[0], yv[0]);
 	return 0;
 #endif
 	return cybozu::test::autoRun.run(argc, argv);
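On the fp_generator.hpp side, each gen_xxx() now aligns, records its own entry point via getCurr(), emits the code, and returns that pointer (or 0 when the limb count or CPU features rule the routine out), so the setup code simply assigns non-null results instead of wrapping every call in align(16)/getCurr(). A toy sketch of that control flow, with hypothetical names and no Xbyak:

#include <stdio.h>
#include <stdint.h>

typedef void (*void3u)(uint64_t*, const uint64_t*, const uint64_t*);

struct MiniOp {
	void3u fpDbl_addA_;
	MiniOp() : fpDbl_addA_(0) {}
};

// stands in for the JIT-emitted routine whose entry point getCurr() would report
static void addImpl(uint64_t* z, const uint64_t* x, const uint64_t* y) { z[0] = x[0] + y[0]; }

struct MiniGenerator {
	int pn_; // limb count, as in FpGenerator::pn_
	explicit MiniGenerator(int pn) : pn_(pn) {}
	// emit (here: pick) the routine and return its entry point, or 0 when unsupported;
	// the patch does this as: align(16); func = getCurr(); ...; return func;
	void3u gen_fpDbl_add() const
	{
		if (pn_ > 6) return 0; // same shape as the pn_/isFullBit_/useAdx_ guards above
		return addImpl;
	}
	void init(MiniOp& op) const
	{
		void3u func = gen_fpDbl_add();
		if (func) op.fpDbl_addA_ = func; // unsupported sizes keep 0 and the C path is used
	}
};

int main()
{
	MiniOp op;
	MiniGenerator(4).init(op);
	uint64_t z = 0, x = 2, y = 3;
	if (op.fpDbl_addA_) op.fpDbl_addA_(&z, &x, &y);
	printf("%llu\n", (unsigned long long)z); // prints: 5
	return 0;
}

The real patch returns const void* from the fpDbl_* generators and reinterpret_casts it to the matching void3u/void2u slot; the sketch keeps a single typed pointer to stay self-contained.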