add Fp::sqrPre for bls12

dev
MITSUNARI Shigeo 6 years ago
parent 832c615796
commit e5894642b1
Changed files:
  1. src/fp_generator.hpp  (104)
  2. test/bls12_test.cpp  (1)

src/fp_generator.hpp

@@ -312,18 +312,15 @@ private:
 func = gen_fpDbl_mod(op);
 if (func) op.fpDbl_modA_ = reinterpret_cast<void2u>(func);
+func = gen_fpDbl_sqrPre(op);
+if (func) op.fpDbl_sqrPreA_ = reinterpret_cast<void2u>(func);
 if (op.N > 4) return;
 align(16);
 op.fp_mul = getCurr<void4u>(); // used in toMont/fromMont
 op.fp_mulA_ = getCurr<void3u>();
 gen_mul();
 if (op.N > 4) return;
-if ((useMulx_ && op.N == 2) || op.N == 3 || op.N == 4) {
-align(16);
-op.fpDbl_sqrPreA_ = getCurr<void2u>();
-gen_fpDbl_sqrPre(op);
-}
-if (op.N > 4) return;
 align(16);
 op.fp_sqrA_ = getCurr<void2u>();
 gen_sqr();
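The hunk above moves fpDbl_sqrPre to the registration pattern already used for fpDbl_mod: the generator returns the entry point of whatever it emitted, or 0 when no specialized path applies, and the caller overwrites the portable fallback only in the former case (the rewritten gen_fpDbl_sqrPre further down shows the other side). A minimal self-contained sketch of that pattern; the struct and helper names below are simplified stand-ins, not mcl's real fp::Op or generator:

    #include <stdint.h>

    typedef void (*void2u)(uint64_t*, const uint64_t*);

    struct Op {
        int N;                  // number of 64-bit limbs
        void2u fpDbl_sqrPreA_;  // starts out pointing at the portable C fallback
    };

    // stand-in for the JIT output; a real generator would emit machine code here
    static void sqrPre_generated(uint64_t*, const uint64_t*) { /* ... */ }

    // stand-in for gen_fpDbl_sqrPre: return an entry point for the limb counts
    // that have a specialized path, 0 otherwise
    static const void* gen_fpDbl_sqrPre_sketch(const Op& op)
    {
        if (op.N == 2 || op.N == 3 || op.N == 4 || op.N == 6) {
            return reinterpret_cast<const void*>(&sqrPre_generated);
        }
        return 0; // nothing emitted, caller keeps the C fallback
    }

    static void registerSqrPre(Op& op)
    {
        const void* func = gen_fpDbl_sqrPre_sketch(op);
        // only override the fallback when specialized code was actually generated
        if (func) op.fpDbl_sqrPreA_ = reinterpret_cast<void2u>(reinterpret_cast<uintptr_t>(func));
    }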
@@ -1114,7 +1111,7 @@ private:
 movq(z, xm0);
 store_mr(z, Pack(t10, t9, t8, t4));
 }
-const void* gen_fpDbl_mod(const mcl::fp::Op& op)
+const void* gen_fpDbl_mod(const fp::Op& op)
 {
 align(16);
 const void* func = getCurr<void*>();
@@ -1163,12 +1160,14 @@ private:
 void gen_sqr()
 {
 if (op_->primeMode == PM_NIST_P192) {
-StackFrame sf(this, 2, 10 | UseRDX | UseRCX, 8 * 6);
+StackFrame sf(this, 3, 10 | UseRDX, 8 * 6);
+Pack t = sf.t;
+t.append(sf.p[2]);
-sqrPre3(rsp, sf.p[1], sf.t);
+sqrPre3(rsp, sf.p[1], t);
 fpDbl_mod_NIST_P192(sf.p[0], rsp, sf.t);
 }
 if (pn_ == 3) {
-gen_montSqr3(p_, rp_);
+gen_montSqr3();
 return;
 }
 // sqr(y, x) = mul(y, x, x)
@@ -1328,7 +1327,7 @@ private:
 z[0..2] <- montgomery(px[0..2], px[0..2])
 destroy gt0, ..., gt9, xm0, xm1, p2
 */
-void gen_montSqr3(const uint64_t *p, uint64_t pp)
+void gen_montSqr3()
 {
 StackFrame sf(this, 3, 10 | UseRDX, 16 * 3);
 const Reg64& pz = sf.p[0];
@@ -1347,23 +1346,23 @@ private:
 const Reg64& t9 = sf.t[9];
 movq(xm0, pz); // save pz
-mov(t7, (uint64_t)p);
+mov(t7, pL_);
 mov(t9, ptr [px]);
 mul3x1_sqr1(px, t9, t3, t2, t1, t0);
 mov(t0, rdx);
-montgomery3_sub(pp, t0, t9, t2, t1, px, t3, t7, t4, t5, t6, t8, pz, true);
+montgomery3_sub(rp_, t0, t9, t2, t1, px, t3, t7, t4, t5, t6, t8, pz, true);
 mov(t3, ptr [px + 8]);
 mul3x1_sqr2(px, t3, t6, t5, t4);
 add_rr(Pack(t1, t0, t9, t2), Pack(rdx, rax, t5, t4));
 if (isFullBit_) setc(pz.cvt8());
-montgomery3_sub(pp, t1, t3, t9, t2, px, t0, t7, t4, t5, t6, t8, pz, false);
+montgomery3_sub(rp_, t1, t3, t9, t2, px, t0, t7, t4, t5, t6, t8, pz, false);
 mov(t0, ptr [px + 16]);
 mul3x1_sqr3(t0, t5, t4);
 add_rr(Pack(t2, t1, t3, t9), Pack(rdx, rax, t5, t4));
 if (isFullBit_) setc(pz.cvt8());
-montgomery3_sub(pp, t2, t0, t3, t9, px, t1, t7, t4, t5, t6, t8, pz, false);
+montgomery3_sub(rp_, t2, t0, t3, t9, px, t1, t7, t4, t5, t6, t8, pz, false);
 // [t9:t2:t0:t3]
 mov(t4, t3);
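gen_montSqr3 above emits a 3-limb Montgomery squaring, z[0..2] <- montgomery(px[0..2], px[0..2]), i.e. x*x*R^-1 mod p with R = 2^192, where rp_ plays the role of -p^-1 mod 2^64 in the interleaved reduction done by montgomery3_sub. Here is a hedged one-limb C++ sketch of the same reduction, just to show what each step computes; the function names and the non-full-bit assumption p < 2^63 are mine (the full-bit case is what the isFullBit_ branches above handle):

    #include <stdint.h>

    typedef unsigned __int128 uint128_t;

    // z = x * y * 2^-64 mod p, assuming p odd, p < 2^63, x, y < p,
    // and rp = -p^-1 mod 2^64 (the role rp_ plays in fp_generator.hpp)
    static uint64_t mont_mul1(uint64_t x, uint64_t y, uint64_t p, uint64_t rp)
    {
        uint128_t t = (uint128_t)x * y;        // full 128-bit product
        uint64_t q = (uint64_t)t * rp;         // q = t * (-p^-1) mod 2^64
        uint128_t u = t + (uint128_t)q * p;    // low 64 bits of u are now zero
        uint64_t z = (uint64_t)(u >> 64);      // divide by R = 2^64
        if (z >= p) z -= p;                    // at most one final subtraction
        return z;
    }

    // Montgomery squaring is the same routine with both inputs equal;
    // the JIT version exploits x == x to reuse partial products.
    static uint64_t mont_sqr1(uint64_t x, uint64_t p, uint64_t rp)
    {
        return mont_mul1(x, x, p, rp);
    }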
@@ -1379,13 +1378,12 @@ private:
 }
 /*
 py[5..0] <- px[2..0]^2
-@note use rax, rdx, rcx!
+@note use rax, rdx
 */
 void sqrPre3(const RegExp& py, const RegExp& px, const Pack& t)
 {
 const Reg64& a = rax;
 const Reg64& d = rdx;
-const Reg64& c = rcx;
 const Reg64& t0 = t[0];
 const Reg64& t1 = t[1];
 const Reg64& t2 = t[2];
@@ -1396,6 +1394,7 @@ private:
 const Reg64& t7 = t[7];
 const Reg64& t8 = t[8];
 const Reg64& t9 = t[9];
+const Reg64& t10 = t[10];
 if (useMulx_) {
 mov(d, ptr [px + 8 * 0]);
@@ -1416,7 +1415,7 @@ private:
 mov(d, t7);
 mulx(t8, t7, d);
-mulx(c, t9, t9);
+mulx(t10, t9, t9);
 } else {
 mov(t9, ptr [px + 8 * 0]);
 mov(a, t9);
@@ -1447,11 +1446,11 @@ private:
 mov(a, ptr [px + 8 * 2]);
 mul(t9);
 mov(t9, a);
-mov(c, d);
+mov(t10, d);
 }
 add(t2, t7);
 adc(t8, t9);
-mov(t7, c);
+mov(t7, t10);
 adc(t7, 0); // [t7:t8:t2:t1]
 add(t0, t1);
@@ -1463,7 +1462,7 @@ private:
 mov(a, ptr [px + 8 * 2]);
 mul(a);
 add(t4, t9);
-adc(a, c);
+adc(a, t10);
 adc(d, 0); // [d:a:t4:t3]
 add(t2, t3);
@@ -1752,6 +1751,38 @@ private:
 adc(t5, 0);
 store_mr(py + 8 * 2, Pack(t5, t4, t3, t2, t1, t0));
 }
+/*
+py[11..0] = px[5..0] ^ 2
+use stack[6 * 8]
+*/
+void sqrPre6(const RegExp& py, const RegExp& px, const Pack& t)
+{
+const Reg64& t0 = t[0];
+const Reg64& t1 = t[1];
+const Reg64& t2 = t[2];
+/*
+(aN + b)^2 = a^2 N^2 + 2ab N + b^2
+*/
+sqrPre3(py, px, t); // [py] <- b^2
+sqrPre3(py + 6 * 8, px + 3 * 8, t); // [py + 6 * 8] <- a^2
+mulPre3(rsp, px, px + 3 * 8, t); // ab
+Pack ab = t.sub(0, 6);
+load_rm(ab, py + 3 * 8);
+for (int i = 0; i < 6; i++) {
+if (i == 0) {
+add(ab[i], ab[i]);
+} else {
+adc(ab[i], ab[i]);
+}
+}
+add_rm(ab, rsp);
+store_mr(py + 3 * 8, ab);
+load_rm(Pack(t2, t1, t0), py + 9 * 8);
+adc(t0, 0);
+adc(t1, 0);
+adc(t2, 0);
+store_mr(py + 9 * 8, Pack(t2, t1, t0));
+}
 /*
 pz[7..0] <- px[3..0] * py[3..0]
 */
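The new sqrPre6 splits a 6-limb square using the identity in its comment, (aN + b)^2 = a^2 N^2 + 2ab N + b^2 with N = 2^192: two 3-limb squares (sqrPre3), one 3-limb product (mulPre3), and a doubled cross term folded into the middle limbs. The same identity at 128-bit scale, as a hypothetical plain-C++ sketch (sqr128 and its limb layout are mine, not mcl code):

    #include <stdint.h>

    typedef unsigned __int128 uint128_t;

    // out[0..3] (little-endian 64-bit limbs) = (a*2^64 + b)^2
    static void sqr128(uint64_t out[4], uint64_t a, uint64_t b)
    {
        uint128_t bb = (uint128_t)b * b;   // b^2 -> low two limbs
        uint128_t aa = (uint128_t)a * a;   // a^2 -> high two limbs
        uint128_t ab = (uint128_t)a * b;   // cross product, added twice
        out[0] = (uint64_t)bb;
        out[1] = (uint64_t)(bb >> 64);
        out[2] = (uint64_t)aa;
        out[3] = (uint64_t)(aa >> 64);
        // 2*a*b is a 129-bit value; split it into three limbs at offset 1
        uint64_t cross[3] = {
            (uint64_t)(ab << 1),           // bits 0..63 of 2ab
            (uint64_t)(ab >> 63),          // bits 64..127 of 2ab
            (uint64_t)(ab >> 127)          // bit 128 of 2ab
        };
        uint128_t c = 0;
        for (int i = 0; i < 3; i++) {      // add with carry propagation
            c += (uint128_t)out[1 + i] + cross[i];
            out[1 + i] = (uint64_t)c;
            c >>= 64;
        }
        // no carry can remain: the square of a 128-bit value fits in 256 bits
    }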
@@ -2067,29 +2098,44 @@ private:
 movq(z, xm0);
 store_mr(z, zp);
 }
-void gen_fpDbl_sqrPre(mcl::fp::Op& op)
+const void* gen_fpDbl_sqrPre(const fp::Op&/* op */)
 {
-if (useMulx_ && pn_ == 2) {
+align(16);
+const void* func = getCurr<void*>();
+if (pn_ == 2 && useMulx_) {
 StackFrame sf(this, 2, 7 | UseRDX);
 sqrPre2(sf.p[0], sf.p[1], sf.t);
-return;
+return func;
 }
 if (pn_ == 3) {
-StackFrame sf(this, 2, 10 | UseRDX | UseRCX);
-sqrPre3(sf.p[0], sf.p[1], sf.t);
-return;
+StackFrame sf(this, 3, 10 | UseRDX);
+Pack t = sf.t;
+t.append(sf.p[2]);
+sqrPre3(sf.p[0], sf.p[1], t);
+return func;
 }
-if (useMulx_ && pn_ == 4) {
-StackFrame sf(this, 2, 10 | UseRDX | UseRCX);
+if (pn_ == 4 && useMulx_) {
+StackFrame sf(this, 2, 10 | UseRDX);
 sqrPre4(sf.p[0], sf.p[1], sf.t);
-return;
+return func;
 }
+if (pn_ == 6 && useMulx_ && useAdx_) {
+StackFrame sf(this, 3, 10 | UseRDX, 6 * 8);
+Pack t = sf.t;
+t.append(sf.p[2]);
+sqrPre6(sf.p[0], sf.p[1], t);
+return func;
+}
+return 0;
+#if 0
 #ifdef XBYAK64_WIN
 mov(r8, rdx);
 #else
 mov(rdx, rsi);
 #endif
 jmp((void*)op.fpDbl_mulPreA_);
+return func;
+#endif
 }
 const void* gen_fpDbl_mulPre()
 {

test/bls12_test.cpp

@@ -698,6 +698,7 @@ if(0){
 // CYBOZU_BENCH_C("subDbl", 10000000, FpDbl::sub, dx, dx, dx);
 CYBOZU_BENCH_C("mul", 10000000 / n, f, xv, yv, xv);
 CYBOZU_BENCH_C("mulPre", 100000000, FpDbl::mulPre, dx, xv[0], yv[0]);
+CYBOZU_BENCH_C("sqrPre", 100000000, FpDbl::sqrPre, dx, xv[0]);
 return 0;
 #endif
 return cybozu::test::autoRun.run(argc, argv);
