refactor fp_generator.hpp

dev
MITSUNARI Shigeo 6 years ago
parent 87afbcc756
commit 194bdc0b47
  1. 26
      include/mcl/fp_tower.hpp
  2. 4
      include/mcl/op.hpp
  3. 272
      src/fp_generator.hpp
  4. 2
      test/bench.hpp
  5. 2
      test/bls12_test.cpp

@ -121,20 +121,22 @@ public:
static void (*add)(FpDblT& z, const FpDblT& x, const FpDblT& y);
static void (*sub)(FpDblT& z, const FpDblT& x, const FpDblT& y);
static void (*mod)(Fp& z, const FpDblT& xy);
static void (*addPre)(FpDblT& z, const FpDblT& x, const FpDblT& y);
static void (*subPre)(FpDblT& z, const FpDblT& x, const FpDblT& y);
static void addC(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_add(z.v_, x.v_, y.v_, Fp::op_.p); }
static void subC(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_sub(z.v_, x.v_, y.v_, Fp::op_.p); }
static void modC(Fp& z, const FpDblT& xy) { Fp::op_.fpDbl_mod(z.v_, xy.v_, Fp::op_.p); }
static void addPreC(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_addPre(z.v_, x.v_, y.v_); }
static void subPreC(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_subPre(z.v_, x.v_, y.v_); }
#else
static void add(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_add(z.v_, x.v_, y.v_, Fp::op_.p); }
static void sub(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_sub(z.v_, x.v_, y.v_, Fp::op_.p); }
static void mod(Fp& z, const FpDblT& xy) { Fp::op_.fpDbl_mod(z.v_, xy.v_, Fp::op_.p); }
static void addPre(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_addPre(z.v_, x.v_, y.v_); }
static void subPre(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_subPre(z.v_, x.v_, y.v_); }
#endif
static void addPreC(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_addPre(z.v_, x.v_, y.v_); }
static void subPreC(FpDblT& z, const FpDblT& x, const FpDblT& y) { Fp::op_.fpDbl_subPre(z.v_, x.v_, y.v_); }
static void mulPreC(FpDblT& xy, const Fp& x, const Fp& y) { Fp::op_.fpDbl_mulPre(xy.v_, x.v_, y.v_); }
static void sqrPreC(FpDblT& xx, const Fp& x) { Fp::op_.fpDbl_sqrPre(xx.v_, x.v_); }
static void (*addPre)(FpDblT& z, const FpDblT& x, const FpDblT& y);
static void (*subPre)(FpDblT& z, const FpDblT& x, const FpDblT& y);
/*
mul(z, x, y) = mulPre(xy, x, y) + mod(z, xy)
*/
@ -155,17 +157,11 @@ public:
if (sub == 0) sub = subC;
mod = (void (*)(Fp&, const FpDblT&))op.fpDbl_modA_;
if (mod == 0) mod = modC;
addPre = (void (*)(FpDblT&, const FpDblT&, const FpDblT&))op.fpDbl_addPre;
if (addPre == 0) addPre = addPreC;
subPre = (void (*)(FpDblT&, const FpDblT&, const FpDblT&))op.fpDbl_subPre;
if (subPre == 0) subPre = subPreC;
#endif
if (op.fpDbl_addPreA_) {
addPre = (void (*)(FpDblT&, const FpDblT&, const FpDblT&))op.fpDbl_addPreA_;
} else {
addPre = addPreC;
}
if (op.fpDbl_subPreA_) {
subPre = (void (*)(FpDblT&, const FpDblT&, const FpDblT&))op.fpDbl_subPreA_;
} else {
subPre = subPreC;
}
if (op.fpDbl_mulPreA_) {
mulPre = (void (*)(FpDblT&, const Fp&, const Fp&))op.fpDbl_mulPreA_;
} else {
@ -185,9 +181,9 @@ public:
template<class Fp> void (*FpDblT<Fp>::add)(FpDblT&, const FpDblT&, const FpDblT&);
template<class Fp> void (*FpDblT<Fp>::sub)(FpDblT&, const FpDblT&, const FpDblT&);
template<class Fp> void (*FpDblT<Fp>::mod)(Fp&, const FpDblT&);
#endif
template<class Fp> void (*FpDblT<Fp>::addPre)(FpDblT&, const FpDblT&, const FpDblT&);
template<class Fp> void (*FpDblT<Fp>::subPre)(FpDblT&, const FpDblT&, const FpDblT&);
#endif
template<class Fp> void (*FpDblT<Fp>::mulPre)(FpDblT&, const Fp&, const Fp&);
template<class Fp> void (*FpDblT<Fp>::sqrPre)(FpDblT&, const Fp&);

@ -193,8 +193,6 @@ struct Op {
void2u fp2_sqrA_;
void3u fpDbl_addA_;
void3u fpDbl_subA_;
void3u fpDbl_addPreA_;
void3u fpDbl_subPreA_;
void3u fpDbl_mulPreA_;
void2u fpDbl_sqrPreA_;
void2u fpDbl_modA_;
@ -282,8 +280,6 @@ struct Op {
fp2_sqrA_ = 0;
fpDbl_addA_ = 0;
fpDbl_subA_ = 0;
fpDbl_addPreA_ = 0;
fpDbl_subPreA_ = 0;
fpDbl_mulPreA_ = 0;
fpDbl_sqrPreA_ = 0;
fpDbl_modA_ = 0;

@ -287,97 +287,31 @@ private:
// code from here
setSize(4096);
assert((getCurr<size_t>() & 4095) == 0);
op.fp_addPre = getCurr<u3u>();
gen_addSubPre(true, pn_);
align(16);
op.fp_subPre = getCurr<u3u>();
gen_addSubPre(false, pn_);
align(16);
op.fp_sub = getCurr<void4u>();
op.fp_subA_ = getCurr<void3u>();
gen_fp_sub();
align(16);
op.fp_add = getCurr<void4u>();
op.fp_addA_ = getCurr<void3u>();
gen_fp_add();
op.fp_addPre = gen_addSubPre(true, pn_);
op.fp_subPre = gen_addSubPre(false, pn_);
op.fp_subA_ = gen_fp_sub();
op.fp_addA_ = gen_fp_add();
align(16);
op.fp_shr1 = getCurr<void2u>();
gen_shr1();
op.fp_shr1 = gen_shr1();
align(16);
op.fp_negA_ = getCurr<void2u>();
gen_fp_neg();
op.fp_negA_ = gen_fp_neg();
const void* func = 0;
// setup fp_tower
op.fp2_mulNF = 0;
if (pn_ <= 4 || (pn_ == 6 && !isFullBit_)) {
align(16);
op.fpDbl_addA_ = getCurr<void3u>();
gen_fpDbl_add();
align(16);
op.fpDbl_subA_ = getCurr<void3u>();
gen_fpDbl_sub();
}
if (op.isFullBit) {
op.fpDbl_addPre = 0;
op.fpDbl_subPre = 0;
} else {
align(16);
op.fpDbl_addPreA_ = getCurr<void3u>();
gen_addSubPre(true, pn_ * 2);
align(16);
op.fpDbl_subPreA_ = getCurr<void3u>();
gen_addSubPre(false, pn_ * 2);
}
if ((useMulx_ && op.N == 2) || op.N == 3 || op.N == 4 || (useAdx_ && op.N == 6)) {
align(16);
op.fpDbl_mulPreA_ = getCurr<void3u>();
if (op.N == 4) {
/*
fpDbl_mulPre is available as C function
this function calls mulPreL directly.
*/
StackFrame sf(this, 3, 10 | UseRDX, 0, false);
mulPre4(gp0, gp1, gp2, sf.t);
sf.close(); // make epilog
L(mulPreL); // called only from asm code
mulPre4(gp0, gp1, gp2, sf.t);
ret();
} else if (op.N == 6 && useAdx_) {
StackFrame sf(this, 3, 10 | UseRDX, 0, false);
call(mulPreL);
sf.close(); // make epilog
L(mulPreL); // called only from asm code
mulPre6(sf.t);
ret();
} else {
gen_fpDbl_mulPre();
}
}
if (op.N == 2 || op.N == 3 || op.N == 4 || (op.N == 6 && !isFullBit_ && useAdx_)) {
align(16);
op.fpDbl_modA_ = getCurr<void2u>();
if (op.N == 4) {
StackFrame sf(this, 3, 10 | UseRDX, 0, false);
call(fpDbl_modL);
sf.close();
L(fpDbl_modL);
gen_fpDbl_mod4(gp0, gp1, sf.t, gp2);
ret();
} else if (op.N == 6 && !isFullBit_ && useAdx_) {
StackFrame sf(this, 3, 10 | UseRDX, 0, false);
call(fpDbl_modL);
sf.close();
L(fpDbl_modL);
Pack t = sf.t;
t.append(gp2);
gen_fpDbl_mod6(gp0, gp1, t);
ret();
} else {
gen_fpDbl_mod(op);
}
}
func = gen_fpDbl_add();
if (func) op.fpDbl_addA_ = reinterpret_cast<void3u>(func);
func = gen_fpDbl_sub();
if (func) op.fpDbl_subA_ = reinterpret_cast<void3u>(func);
op.fpDbl_addPre = gen_addSubPre(true, pn_ * 2);
op.fpDbl_subPre = gen_addSubPre(false, pn_ * 2);
func = gen_fpDbl_mulPre();
if (func) op.fpDbl_mulPreA_ = reinterpret_cast<void3u>(func);
func = gen_fpDbl_mod(op);
if (func) op.fpDbl_modA_ = reinterpret_cast<void2u>(func);
if (op.N > 4) return;
align(16);
op.fp_mul = getCurr<void4u>(); // used in toMont/fromMont
@ -389,7 +323,7 @@ private:
op.fpDbl_sqrPreA_ = getCurr<void2u>();
gen_fpDbl_sqrPre(op);
}
// if (op.N > 4) return;
if (op.N > 4) return;
align(16);
op.fp_sqrA_ = getCurr<void2u>();
gen_sqr();
@ -425,14 +359,18 @@ private:
gen_fp2_mul_xi4();
}
}
void gen_addSubPre(bool isAdd, int n)
u3u gen_addSubPre(bool isAdd, int n)
{
// if (isFullBit_) return 0;
align(16);
u3u func = getCurr<u3u>();
StackFrame sf(this, 3);
if (isAdd) {
gen_raw_add(sf.p[0], sf.p[1], sf.p[2], rax, n);
} else {
gen_raw_sub(sf.p[0], sf.p[1], sf.p[2], rax, n);
}
return func;
}
/*
pz[] = px[] + py[]
@ -702,15 +640,17 @@ private:
t2.append(px); // destory after used
gen_raw_fp_add6(pz, px, py, 0, t1, t2, false);
}
void gen_fp_add()
void3u gen_fp_add()
{
align(16);
void3u func = getCurr<void3u>();
if (pn_ <= 4) {
gen_fp_add_le4();
return;
return func;
}
if (pn_ == 6) {
gen_fp_add6();
return;
return func;
}
StackFrame sf(this, 3, 0, pn_ * 8);
const Reg64& pz = sf.p[0];
@ -746,9 +686,12 @@ private:
L(".exit");
#endif
outLocalLabel();
return func;
}
void gen_fpDbl_add()
const void* gen_fpDbl_add()
{
align(16);
const void* func = getCurr<void*>();
if (pn_ <= 4) {
int tn = pn_ * 2 + (isFullBit_ ? 1 : 0);
StackFrame sf(this, 3, tn);
@ -757,6 +700,7 @@ private:
const Reg64& py = sf.p[2];
gen_raw_add(pz, px, py, rax, pn_);
gen_raw_fp_add(pz + 8 * pn_, px + 8 * pn_, py + 8 * pn_, sf.t, true);
return func;
} else if (pn_ == 6 && !isFullBit_) {
StackFrame sf(this, 3, 10);
const Reg64& pz = sf.p[0];
@ -768,13 +712,14 @@ private:
t2.append(rax);
t2.append(py);
gen_raw_fp_add6(pz, px, py, pn_ * 8, t1, t2, true);
} else {
assert(0);
exit(1);
return func;
}
return 0;
}
void gen_fpDbl_sub()
const void* gen_fpDbl_sub()
{
align(16);
const void* func = getCurr<void*>();
if (pn_ <= 4) {
int tn = pn_ * 2;
StackFrame sf(this, 3, tn);
@ -783,6 +728,7 @@ private:
const Reg64& py = sf.p[2];
gen_raw_sub(pz, px, py, rax, pn_);
gen_raw_fp_sub(pz + 8 * pn_, px + 8 * pn_, py + 8 * pn_, sf.t, true);
return func;
} else if (pn_ == 6) {
StackFrame sf(this, 3, 4);
const Reg64& pz = sf.p[0];
@ -793,10 +739,9 @@ private:
t.append(rax);
t.append(px);
gen_raw_fp_sub6(pz, px, py, pn_ * 8, t, true);
} else {
assert(0);
exit(1);
return func;
}
return 0;
}
void gen_raw_fp_sub6(const Reg64& pz, const Reg64& px, const Reg64& py, int offset, const Pack& t, bool withCarry)
{
@ -821,15 +766,17 @@ private:
t.append(px); // |t| = 6
gen_raw_fp_sub6(pz, px, py, 0, t, false);
}
void gen_fp_sub()
void3u gen_fp_sub()
{
align(16);
void3u func = getCurr<void3u>();
if (pn_ <= 4) {
gen_fp_sub_le4();
return;
return func;
}
if (pn_ == 6) {
gen_fp_sub6();
return;
return func;
}
StackFrame sf(this, 3);
const Reg64& pz = sf.p[0];
@ -842,14 +789,20 @@ private:
mov(px, (size_t)p_);
gen_raw_add(pz, pz, px, rax, pn_);
L(exit);
return func;
}
void gen_fp_neg()
void2u gen_fp_neg()
{
align(16);
void2u func = getCurr<void2u>();
StackFrame sf(this, 2, UseRDX | pn_);
gen_raw_neg(sf.p[0], sf.p[1], sf.t);
return func;
}
void gen_shr1()
void2u gen_shr1()
{
align(16);
void2u func = getCurr<void2u>();
const int c = 1;
StackFrame sf(this, 2, 1);
const Reg64 *t0 = &rax;
@ -865,6 +818,7 @@ private:
}
shr(*t0, c);
mov(ptr [pz + (pn_ - 1) * 8], *t0);
return func;
}
void gen_mul()
{
@ -874,10 +828,10 @@ private:
fpDbl_mod_NIST_P192(sf.p[0], rsp, sf.t);
}
if (pn_ == 3) {
gen_montMul3(p_, rp_);
gen_montMul3();
} else if (pn_ == 4) {
gen_montMul4(p_, rp_);
#if 0
gen_montMul4();
#if 1
} else if (pn_ == 6 && useAdx_) {
// gen_montMul6(p_, rp_);
StackFrame sf(this, 3, 10 | UseRDX, (1 + 12) * 8);
@ -1160,38 +1114,51 @@ private:
movq(z, xm0);
store_mr(z, Pack(t10, t9, t8, t4));
}
void gen_fpDbl_mod(const mcl::fp::Op& op)
const void* gen_fpDbl_mod(const mcl::fp::Op& op)
{
align(16);
const void* func = getCurr<void*>();
if (op.primeMode == PM_NIST_P192) {
StackFrame sf(this, 2, 6 | UseRDX);
fpDbl_mod_NIST_P192(sf.p[0], sf.p[1], sf.t);
return;
return func;
}
#if 0
if (op.primeMode == PM_NIST_P521) {
StackFrame sf(this, 2, 8 | UseRDX);
fpDbl_mod_NIST_P521(sf.p[0], sf.p[1], sf.t);
return;
return func;
}
#endif
switch (pn_) {
case 2:
if (pn_ == 2) {
gen_fpDbl_mod2();
break;
case 3:
return func;
}
if (pn_ == 3) {
gen_fpDbl_mod3();
break;
#if 0
case 4:
{
StackFrame sf(this, 3, 10 | UseRDX);
return func;
}
if (pn_ == 4) {
StackFrame sf(this, 3, 10 | UseRDX, 0, false);
call(fpDbl_modL);
sf.close();
L(fpDbl_modL);
gen_fpDbl_mod4(gp0, gp1, sf.t, gp2);
ret();
return func;
}
break;
#endif
default:
throw cybozu::Exception("gen_fpDbl_mod:not support") << pn_;
if (pn_ == 6 && !isFullBit_ && useAdx_) {
StackFrame sf(this, 3, 10 | UseRDX, 0, false);
call(fpDbl_modL);
sf.close();
L(fpDbl_modL);
Pack t = sf.t;
t.append(gp2);
gen_fpDbl_mod6(gp0, gp1, t);
ret();
return func;
}
return 0;
}
void gen_sqr()
{
@ -1255,7 +1222,7 @@ private:
z[0..3] <- montgomery(x[0..3], y[0..3])
destroy gt0, ..., gt9, xm0, xm1, p2
*/
void gen_montMul4(const uint64_t *p, uint64_t pp)
void gen_montMul4()
{
StackFrame sf(this, 3, 10 | UseRDX, 0, false);
call(fp_mulL);
@ -1277,22 +1244,22 @@ private:
L(fp_mulL);
movq(xm0, p0); // save p0
mov(p0, (uint64_t)p);
mov(p0, pL_);
movq(xm1, p2);
mov(p2, ptr [p2]);
montgomery4_1(pp, t0, t7, t3, t2, t1, p1, p2, p0, t4, t5, t6, t8, t9, true, xm2);
montgomery4_1(rp_, t0, t7, t3, t2, t1, p1, p2, p0, t4, t5, t6, t8, t9, true, xm2);
movq(p2, xm1);
mov(p2, ptr [p2 + 8]);
montgomery4_1(pp, t1, t0, t7, t3, t2, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2);
montgomery4_1(rp_, t1, t0, t7, t3, t2, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2);
movq(p2, xm1);
mov(p2, ptr [p2 + 16]);
montgomery4_1(pp, t2, t1, t0, t7, t3, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2);
montgomery4_1(rp_, t2, t1, t0, t7, t3, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2);
movq(p2, xm1);
mov(p2, ptr [p2 + 24]);
montgomery4_1(pp, t3, t2, t1, t0, t7, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2);
montgomery4_1(rp_, t3, t2, t1, t0, t7, p1, p2, p0, t4, t5, t6, t8, t9, false, xm2);
// [t7:t3:t2:t1:t0]
mov(t4, t0);
@ -1315,7 +1282,7 @@ private:
z[0..2] <- montgomery(x[0..2], y[0..2])
destroy gt0, ..., gt9, xm0, xm1, p2
*/
void gen_montMul3(const uint64_t *p, uint64_t pp)
void gen_montMul3()
{
StackFrame sf(this, 3, 10 | UseRDX);
const Reg64& p0 = sf.p[0];
@ -1334,15 +1301,15 @@ private:
const Reg64& t9 = sf.t[9];
movq(xm0, p0); // save p0
mov(t7, (uint64_t)p);
mov(t7, pL_);
mov(t9, ptr [p2]);
// c3, c2, c1, c0, px, y, p,
montgomery3_1(pp, t0, t3, t2, t1, p1, t9, t7, t4, t5, t6, t8, p0, true);
montgomery3_1(rp_, t0, t3, t2, t1, p1, t9, t7, t4, t5, t6, t8, p0, true);
mov(t9, ptr [p2 + 8]);
montgomery3_1(pp, t1, t0, t3, t2, p1, t9, t7, t4, t5, t6, t8, p0, false);
montgomery3_1(rp_, t1, t0, t3, t2, p1, t9, t7, t4, t5, t6, t8, p0, false);
mov(t9, ptr [p2 + 16]);
montgomery3_1(pp, t2, t1, t0, t3, p1, t9, t7, t4, t5, t6, t8, p0, false);
montgomery3_1(rp_, t2, t1, t0, t3, p1, t9, t7, t4, t5, t6, t8, p0, false);
// [(t3):t2:t1:t0]
mov(t4, t0);
@ -1607,6 +1574,7 @@ private:
if (useMulx_) {
mulPack(pz, px, py, Pack(t2, t1, t0));
#if 0 // a little slow
if (useAdx_) {
// [t2:t1:t0]
mulPackAdd(pz + 8 * 1, px + 8 * 1, py, t3, Pack(t2, t1, t0));
@ -1616,6 +1584,7 @@ private:
store_mr(pz + 8 * 3, Pack(t4, t3, t2));
return;
}
#endif
} else {
mov(t5, ptr [px]);
mov(a, ptr [py + 8 * 0]);
@ -2122,20 +2091,43 @@ private:
#endif
jmp((void*)op.fpDbl_mulPreA_);
}
void gen_fpDbl_mulPre()
const void* gen_fpDbl_mulPre()
{
if (useMulx_ && pn_ == 2) {
align(16);
const void* func = getCurr<void*>();
if (pn_ == 2 && useMulx_) {
StackFrame sf(this, 3, 5 | UseRDX);
mulPre2(sf.p[0], sf.p[1], sf.p[2], sf.t);
return;
return func;
}
if (pn_ == 3) {
StackFrame sf(this, 3, 10 | UseRDX);
mulPre3(sf.p[0], sf.p[1], sf.p[2], sf.t);
return;
return func;
}
if (pn_ == 4) {
/*
fpDbl_mulPre is available as C function
this function calls mulPreL directly.
*/
StackFrame sf(this, 3, 10 | UseRDX, 0, false);
mulPre4(gp0, gp1, gp2, sf.t);
sf.close(); // make epilog
L(mulPreL); // called only from asm code
mulPre4(gp0, gp1, gp2, sf.t);
ret();
return func;
}
if (pn_ == 6 && useAdx_) {
StackFrame sf(this, 3, 10 | UseRDX, 0, false);
call(mulPreL);
sf.close(); // make epilog
L(mulPreL); // called only from asm code
mulPre6(sf.t);
ret();
return func;
}
assert(0);
exit(1);
return 0;
}
static inline void debug_put_inner(const uint64_t *ptr, int n)
{

@ -8,7 +8,7 @@ void testBench(const G1& P, const G2& Q)
pairing(e1, P, Q);
Fp12::pow(e2, e1, 12345);
const int C = 500;
const int C3 = 10000;
const int C3 = 3000;
Fp x, y;
x.setHashOf("abc");
y.setHashOf("xyz");

@ -697,7 +697,7 @@ if(0){
}
// CYBOZU_BENCH_C("subDbl", 10000000, FpDbl::sub, dx, dx, dx);
CYBOZU_BENCH_C("mul", 10000000 / n, f, xv, yv, xv);
CYBOZU_BENCH_C("mulPre", 10000000, FpDbl::mulPre, dx, xv[0], yv[0]);
CYBOZU_BENCH_C("mulPre", 100000000, FpDbl::mulPre, dx, xv[0], yv[0]);
return 0;
#endif
return cybozu::test::autoRun.run(argc, argv);

Loading…
Cancel
Save