diff --git a/include/mcl/fp_generator.hpp b/include/mcl/fp_generator.hpp index d82f911..019b282 100644 --- a/include/mcl/fp_generator.hpp +++ b/include/mcl/fp_generator.hpp @@ -258,10 +258,10 @@ struct FpGenerator : Xbyak::CodeGenerator { op.fpDbl_mulPre = getCurr(); gen_fpDbl_mulPre(); } - if (op.N == 3) { + if (op.N == 3 || op.N == 4) { align(16); op.fpDbl_sqrPre = getCurr(); - gen_fpDbl_sqrPre(); + gen_fpDbl_sqrPre(op); } } void gen_addSubNC(bool isAdd, int n) @@ -1056,7 +1056,7 @@ struct FpGenerator : Xbyak::CodeGenerator { /* py[5..0] <- px[2..0]^2 */ - void sqr3x3(const RegExp& py, const RegExp& px, const Pack& t) + void sqrPre3(const RegExp& py, const RegExp& px, const Pack& t) { const Reg64& a = rax; const Reg64& d = rdx; @@ -1072,38 +1072,58 @@ struct FpGenerator : Xbyak::CodeGenerator { const Reg64& t8 = t[8]; const Reg64& t9 = t[9]; - mov(t9, ptr [px + 8 * 0]); - mov(a, t9); - mul(t9); - mov(ptr [py + 8 * 0], a); - mov(t0, d); - mov(a, ptr [px + 8 * 1]); - mul(t9); - mov(t1, a); - mov(t2, d); - mov(a, ptr [px + 8 * 2]); - mul(t9); - mov(t3, a); - mov(t4, d); + if (useMulx_) { + mov(d, ptr [px + 8 * 0]); + mulx(t0, a, d); + mov(ptr [py + 8 * 0], a); - mov(t5, t2); - mov(t6, t4); + mov(t7, ptr [px + 8 * 1]); + mov(t9, ptr [px + 8 * 2]); + mulx(t2, t1, t7); + mulx(t4, t3, t9); - add(t0, t1); - adc(t5, t3); - adc(t6, 0); // [t6:t5:t0] + mov(t5, t2); + mov(t6, t4); + add(t0, t1); + adc(t5, t3); + adc(t6, 0); // [t6:t5:t0] - mov(t9, ptr [px + 8 * 1]); - mov(a, t9); - mul(t9); - mov(t7, a); - mov(t8, d); - mov(a, ptr [px + 8 * 2]); - mul(t9); - mov(t9, a); - mov(c, d); + mov(d, t7); + mulx(t8, t7, d); + mulx(c, t9, t9); + } else { + mov(t9, ptr [px + 8 * 0]); + mov(a, t9); + mul(t9); + mov(ptr [py + 8 * 0], a); + mov(t0, d); + mov(a, ptr [px + 8 * 1]); + mul(t9); + mov(t1, a); + mov(t2, d); + mov(a, ptr [px + 8 * 2]); + mul(t9); + mov(t3, a); + mov(t4, d); + mov(t5, t2); + mov(t6, t4); + + add(t0, t1); + adc(t5, t3); + adc(t6, 0); // [t6:t5:t0] + + mov(t9, ptr [px + 8 * 1]); + mov(a, t9); + mul(t9); + mov(t7, a); + mov(t8, d); + mov(a, ptr [px + 8 * 2]); + mul(t9); + mov(t9, a); + mov(c, d); + } add(t2, t7); adc(t8, t9); mov(t7, c); @@ -1130,7 +1150,7 @@ struct FpGenerator : Xbyak::CodeGenerator { /* pz[5..0] <- px[2..0] * py[2..0] */ - void mul3x3(const RegExp& pz, const RegExp& px, const RegExp& py, const Pack& t) + void mulPre3(const RegExp& pz, const RegExp& px, const RegExp& py, const Pack& t) { const Reg64& a = rax; const Reg64& d = rdx; @@ -1197,7 +1217,7 @@ struct FpGenerator : Xbyak::CodeGenerator { /* pz[7..0] <- px[3..0] * py[3..0] */ - void mul4x4(const RegExp& pz, const RegExp& px, const RegExp& py, const Pack& t) + void mulPre4(const RegExp& pz, const RegExp& px, const RegExp& py, const Pack& t) { const Reg64& a = rax; const Reg64& d = rdx; @@ -1276,21 +1296,28 @@ struct FpGenerator : Xbyak::CodeGenerator { store_mr(pz + 8 * 3, Pack(t7, t8, t3, t2)); mov(ptr [pz + 8 * 7], d); } - void gen_fpDbl_sqrPre() + void gen_fpDbl_sqrPre(mcl::fp::Op& op) { if (pn_ == 3) { StackFrame sf(this, 2, 10 | UseRDX | UseRCX); - sqr3x3(sf.p[0], sf.p[1], sf.t); + sqrPre3(sf.p[0], sf.p[1], sf.t); + return; } +#ifdef XBYAK64_WIN + mov(r8, rdx); +#else + mov(rdx, rsi); +#endif + jmp((void*)op.fpDbl_mulPre); } void gen_fpDbl_mulPre() { if (pn_ == 3) { StackFrame sf(this, 3, 10 | UseRDX); - mul3x3(sf.p[0], sf.p[1], sf.p[2], sf.t); + mulPre3(sf.p[0], sf.p[1], sf.p[2], sf.t); } else if (pn_ == 4) { StackFrame sf(this, 3, 10 | UseRDX); - mul4x4(sf.p[0], sf.p[1], sf.p[2], sf.t); + mulPre4(sf.p[0], sf.p[1], sf.p[2], sf.t); } } static inline void debug_put_inner(const uint64_t *ptr, int n) diff --git a/sample/rawbench.cpp b/sample/rawbench.cpp index a5b3958..7870bf2 100644 --- a/sample/rawbench.cpp +++ b/sample/rawbench.cpp @@ -29,21 +29,20 @@ void benchRaw(const char *p, mcl::fp::Mode mode) typedef mcl::fp::Unit Unit; const size_t maxN = sizeof(Fp) / sizeof(Unit); const mcl::fp::Op& op = Fp::getOp(); - Fp fx = -1, fy; - mpz_class mp(p); - fy.setMpz(mp / 2); + cybozu::XorShift rg; + Fp fx, fy; + fx.setRand(rg); + fy.setRand(rg); Unit ux[maxN * 2] = {}; Unit uy[maxN * 2] = {}; memcpy(ux, fx.getUnit(), sizeof(Unit) * op.N); + memcpy(ux + op.N, fx.getUnit(), sizeof(Unit) * op.N); memcpy(uy, fy.getUnit(), sizeof(Unit) * op.N); - fy.setMpz(mp - 1); - memcpy(uy + op.N, fy.getUnit(), sizeof(Unit) * op.N); + memcpy(ux + op.N, fx.getUnit(), sizeof(Unit) * op.N); double fp_sqrT, fp_addT, fp_subT, fp_mulT; double fpDbl_addT, fpDbl_subT; double fpDbl_sqrPreT, fpDbl_mulPreT, fpDbl_modT; double fp2_sqrT, fp2_mulT, fp2_mul2T; -// double fp2_mulT, fp2_sqrT; -// double fp_addNCT, fp_subNCT, fpDbl_addNCT,fpDbl_subNCT; CYBOZU_BENCH_T(fp_sqrT, op.fp_sqr, ux, ux); CYBOZU_BENCH_T(fp_addT, op.fp_add, ux, ux, ux); CYBOZU_BENCH_T(fp_subT, op.fp_sub, ux, uy, ux);