diff --git a/src/fp_generator.hpp b/src/fp_generator.hpp index e4cbba0..5406abb 100644 --- a/src/fp_generator.hpp +++ b/src/fp_generator.hpp @@ -297,29 +297,15 @@ private: align(16); op.fp_negA_ = getCurr(); gen_fp_neg(); - if (op.N > 4) return; -// align(16); -// mulUnit_ = getCurr(); -// gen_mulUnit(); - align(16); - op.fp_mul = getCurr(); // used in toMont/fromMont - op.fp_mulA_ = getCurr(); - gen_mul(fp_mulL); - align(16); - op.fp_sqrA_ = getCurr(); - gen_sqr(); - if (op.primeMode != PM_NIST_P192 && op.N <= 4) { // support general op.N but not fast for op.N > 4 + // setup fp_tower + op.fp2_mulNF = 0; + if (pn_ <= 4 || (pn_ == 6 && !isFullBit_)) { align(16); - op.fp_preInv = getCurr(); - gen_preInv(); + op.fpDbl_addA_ = getCurr(); + gen_fpDbl_add(); } - // setup fp_tower if (op.N > 4) return; - op.fp2_mulNF = 0; - align(16); - op.fpDbl_addA_ = getCurr(); - gen_fpDbl_add(); align(16); op.fpDbl_subA_ = getCurr(); gen_fpDbl_sub(); @@ -375,6 +361,19 @@ private: op.fpDbl_sqrPreA_ = getCurr(); gen_fpDbl_sqrPre(op); } + align(16); + op.fp_mul = getCurr(); // used in toMont/fromMont + op.fp_mulA_ = getCurr(); + gen_mul(fp_mulL); +// if (op.N > 4) return; + align(16); + op.fp_sqrA_ = getCurr(); + gen_sqr(); + if (op.primeMode != PM_NIST_P192 && op.N <= 4) { // support general op.N but not fast for op.N > 4 + align(16); + op.fp_preInv = getCurr(); + gen_preInv(); + } if (op.N == 4 && !isFullBit_) { align(16); op.fp2_addA_ = getCurr(); @@ -639,21 +638,19 @@ private: const Reg64& py = sf.p[2]; gen_raw_fp_sub(pz, px, py, sf.t, false); } - void gen_fp_add6() + /* + add(pz + offset, px + offset, py + offset); + t.size() == 10 + destroy px, py, rax + */ + void gen_raw_fp_add6(const Reg64& pz, const Reg64& px, const Reg64& py, int offset, Pack t, bool withCarry) { - /* - cmov is faster than jmp - */ - StackFrame sf(this, 3, 10); - const Reg64& pz = sf.p[0]; - const Reg64& px = sf.p[1]; - const Reg64& py = sf.p[2]; - Pack t = sf.t.sub(0, 6); - Pack t2 = sf.t.sub(6); + Pack t2 = 
t.sub(6); + t = t.sub(0, 6); t2.append(rax); t2.append(px); - load_rm(t, px); - add_rm(t, py); + load_rm(t, px + offset); + add_rm(t, py + offset, withCarry); Label exit; if (isFullBit_) { jnc("@f"); @@ -662,14 +659,25 @@ private: jmp(exit); L("@@"); } - mov_rr(t2, t); + mov_rr(t2, t); // destroy px mov(py, (size_t)p_); sub_rm(t2, py); for (int i = 0; i < 6; i++) { cmovnc(t[i], t2[i]); } L(exit); - store_mr(pz, t); + store_mr(pz + offset, t); + } + void gen_fp_add6() + { + /* + cmov is faster than jmp + */ + StackFrame sf(this, 3, 10); + const Reg64& pz = sf.p[0]; + const Reg64& px = sf.p[1]; + const Reg64& py = sf.p[2]; + gen_raw_fp_add6(pz, px, py, 0, sf.t, false); } void gen_fp_add() { @@ -718,14 +726,25 @@ private: } void gen_fpDbl_add() { - assert(pn_ <= 4); - int tn = pn_ * 2 + (isFullBit_ ? 1 : 0); - StackFrame sf(this, 3, tn); - const Reg64& pz = sf.p[0]; - const Reg64& px = sf.p[1]; - const Reg64& py = sf.p[2]; - gen_raw_add(pz, px, py, rax, pn_); - gen_raw_fp_add(pz + 8 * pn_, px + 8 * pn_, py + 8 * pn_, sf.t, true); + if (pn_ <= 4) { + int tn = pn_ * 2 + (isFullBit_ ? 
1 : 0); + StackFrame sf(this, 3, tn); + const Reg64& pz = sf.p[0]; + const Reg64& px = sf.p[1]; + const Reg64& py = sf.p[2]; + gen_raw_add(pz, px, py, rax, pn_); + gen_raw_fp_add(pz + 8 * pn_, px + 8 * pn_, py + 8 * pn_, sf.t, true); + } else if (pn_ == 6 && !isFullBit_) { + StackFrame sf(this, 3, 10); + const Reg64& pz = sf.p[0]; + const Reg64& px = sf.p[1]; + const Reg64& py = sf.p[2]; + gen_raw_add(pz, px, py, rax, pn_); + gen_raw_fp_add6(pz, px, py, pn_ * 8, sf.t, true); + } else { + assert(0); + exit(1); + } } void gen_fpDbl_sub() { @@ -814,6 +833,8 @@ private: gen_montMul3(p_, rp_); } else if (pn_ == 4) { gen_montMul4(fp_mulL, p_, rp_); +// } else if (pn_ == 6 && useAdx_) { +// gen_montMul6(fp_mulL, p_, rp_); } else if (pn_ <= 9) { gen_montMulN(p_, rp_, pn_); } else { @@ -1828,20 +1849,16 @@ private: mulPre3(sf.p[0], sf.p[1], sf.p[2], sf.t); return; } - assert(0); -#if 1 if (pn_ == 4) { StackFrame sf(this, 3, 10 | UseRDX); mulPre4(sf.p[0], sf.p[1], sf.p[2], sf.t); return; } -#endif -#if 0 // slow? 
+ // 64clk -> 56clk if (pn_ == 6 && useAdx_) { StackFrame sf(this, 3, 7 | UseRDX); mulPre6(sf.p[0], sf.p[1], sf.p[2], sf.t); } -#endif } static inline void debug_put_inner(const uint64_t *ptr, int n) { diff --git a/test/bench.hpp b/test/bench.hpp index 19c04b6..8693a71 100644 --- a/test/bench.hpp +++ b/test/bench.hpp @@ -52,6 +52,7 @@ void testBench(const G1& P, const G2& Q) xx.b = 3; yy.a = y; yy.b = -5; +#if 1 CYBOZU_BENCH_C("Fp2::add ", C3, Fp2::add, xx, xx, yy); CYBOZU_BENCH_C("Fp2::sub ", C3, Fp2::sub, xx, xx, yy); CYBOZU_BENCH_C("Fp2::neg ", C3, Fp2::neg, xx, xx); @@ -77,6 +78,7 @@ void testBench(const G1& P, const G2& Q) CYBOZU_BENCH_C("GT::mul ", C2, GT::mul, e1, e1, e2); CYBOZU_BENCH_C("GT::sqr ", C2, GT::sqr, e1, e1); CYBOZU_BENCH_C("GT::inv ", C2, GT::inv, e1, e1); +#endif CYBOZU_BENCH_C("pairing ", C, pairing, e1, P, Q); CYBOZU_BENCH_C("millerLoop ", C, millerLoop, e1, P, Q); CYBOZU_BENCH_C("finalExp ", C, finalExp, e1, e1); diff --git a/test/bls12_test.cpp b/test/bls12_test.cpp index 2c4d98c..ea90813 100644 --- a/test/bls12_test.cpp +++ b/test/bls12_test.cpp @@ -654,9 +654,16 @@ CYBOZU_TEST_AUTO(multi) CYBOZU_TEST_AUTO(BLS12_G1mulCofactor) { if (BN::param.cp.curveType != MCL_BLS12_381) return; - } +typedef std::vector<Fp> FpVec; + +void f(FpVec& zv, const FpVec& xv, const FpVec& yv) +{ + for (size_t i = 0; i < zv.size(); i++) { + Fp::mul(zv[i], xv[i], yv[i]); + } +} int main(int argc, char *argv[]) try { @@ -669,6 +676,22 @@ int main(int argc, char *argv[]) } g_mode = mcl::fp::StrToMode(mode); printf("JIT %d\n", mcl::fp::isEnableJIT()); +#if 0 + initPairing(mcl::BLS12_381); + cybozu::XorShift rg; + const int n = 1; + std::vector<Fp> xv(n), yv(n), zv(n); + for (int i = 0; i < n; i++) { + xv[i].setByCSPRNG(rg); + yv[i].setByCSPRNG(rg); + } + FpDbl dx; + FpDbl::mulPre(dx, xv[0], xv[0]); + CYBOZU_BENCH_C("addDbl", 10000000, FpDbl::add, dx, dx, dx); +// CYBOZU_BENCH_C("mul", 10000000 / n, f, xv, yv, xv); +// CYBOZU_BENCH_C("mulPre", 10000000, FpDbl::mulPre, dx, 
xv[0], yv[0]); + return 0; +#endif return cybozu::test::autoRun.run(argc, argv); } catch (std::exception& e) { printf("ERR %s\n", e.what());