add fp2Dbl_mulPre

dev
MITSUNARI Shigeo 6 years ago
parent 602a2df220
commit c57760ea54
  1. 154
      include/mcl/fp_tower.hpp
  2. 2
      include/mcl/op.hpp
  3. 303
      src/fp_generator.hpp
  4. 2
      test/bench.hpp

@ -134,6 +134,55 @@ public:
}
void operator+=(const FpDblT& x) { add(*this, *this, x); }
void operator-=(const FpDblT& x) { sub(*this, *this, x); }
/*
Fp2Dbl::mulPre by FpDblT
@note mod of NIST_P192 is fast
*/
static void fp2Dbl_mulPreW(Unit *z, const Unit *x, const Unit *y)
{
const Fp *px = reinterpret_cast<const Fp*>(x);
const Fp *py = reinterpret_cast<const Fp*>(y);
const Fp& a = px[0];
const Fp& b = px[1];
const Fp& c = py[0];
const Fp& d = py[1];
FpDblT& d0 = reinterpret_cast<FpDblT*>(z)[0];
FpDblT& d1 = reinterpret_cast<FpDblT*>(z)[1];
FpDblT d2;
Fp s, t;
Fp::add(s, a, b);
Fp::add(t, c, d);
FpDblT::mulPre(d1, s, t); // (a + b)(c + d)
FpDblT::mulPre(d0, a, c);
FpDblT::mulPre(d2, b, d);
FpDblT::sub(d1, d1, d0); // (a + b)(c + d) - ac
FpDblT::sub(d1, d1, d2); // (a + b)(c + d) - ac - bd
FpDblT::sub(d0, d0, d2); // ac - bd
}
/*
Fp2Dbl::mulPre by FpDblT with No Carry
*/
static void fp2Dbl_mulPreNoCarryW(Unit *z, const Unit *x, const Unit *y)
{
const Fp *px = reinterpret_cast<const Fp*>(x);
const Fp *py = reinterpret_cast<const Fp*>(y);
const Fp& a = px[0];
const Fp& b = px[1];
const Fp& c = py[0];
const Fp& d = py[1];
FpDblT& d0 = reinterpret_cast<FpDblT*>(z)[0];
FpDblT& d1 = reinterpret_cast<FpDblT*>(z)[1];
FpDblT d2;
Fp s, t;
Fp::addPre(s, a, b);
Fp::addPre(t, c, d);
FpDblT::mulPre(d1, s, t); // (a + b)(c + d)
FpDblT::mulPre(d0, a, c);
FpDblT::mulPre(d2, b, d);
FpDblT::subPre(d1, d1, d0); // (a + b)(c + d) - ac
FpDblT::subPre(d1, d1, d2); // (a + b)(c + d) - ac - bd
FpDblT::sub(d0, d0, d2); // ac - bd
}
};
template<class Fp> struct Fp12T;
@ -328,16 +377,25 @@ public:
mcl::fp::Op& op = Fp::op_;
op.fp2_add = fp2_addW;
op.fp2_sub = fp2_subW;
if (op.isFastMod) {
op.fp2_mul = fp2_mulW;
} else if (!op.isFullBit) {
if (0 && sizeof(Fp) * 8 == op.N * fp::UnitBitSize && op.fp2_mulNF) {
op.fp2_mul = fp2_mulNFW;
if (op.fp2Dbl_mulPre == 0) {
if (op.isFullBit) {
op.fp2Dbl_mulPre = FpDblT<Fp>::fp2Dbl_mulPreW;
} else {
op.fp2_mul = fp2_mulUseDblUseNCW;
op.fp2Dbl_mulPre = FpDblT<Fp>::fp2Dbl_mulPreNoCarryW;
}
}
if (op.fp2_mul == 0) {
if (op.isFastMod) {
op.fp2_mul = fp2_mulW;
} else if (!op.isFullBit) {
if (0 && sizeof(Fp) * 8 == op.N * fp::UnitBitSize && op.fp2_mulNF) {
op.fp2_mul = fp2_mulNFW;
} else {
op.fp2_mul = fp2_mulW;
}
} else {
op.fp2_mul = fp2_mulW;
}
} else {
op.fp2_mul = fp2_mulUseDblW;
}
op.fp2_neg = fp2_negW;
op.fp2_inv = fp2_invW;
@ -439,6 +497,7 @@ private:
Fp::neg(py[0], px[0]);
Fp::neg(py[1], px[1]);
}
#if 0
/*
x = a + bi, y = c + di, i^2 = -1
z = xy = (a + bi)(c + di) = (ac - bd) + (ad + bc)i
@ -464,58 +523,19 @@ private:
Fp::sub(pz[1], t1, ac);
pz[1] -= bd;
}
/*
# of mod = 2
@note mod of NIST_P192 is fast
*/
static void fp2_mulUseDblW(Unit *z, const Unit *x, const Unit *y)
{
const Fp *px = reinterpret_cast<const Fp*>(x);
const Fp *py = reinterpret_cast<const Fp*>(y);
const Fp& a = px[0];
const Fp& b = px[1];
const Fp& c = py[0];
const Fp& d = py[1];
FpDbl d0, d1, d2;
Fp s, t;
Fp::add(s, a, b);
Fp::add(t, c, d);
FpDbl::mulPre(d0, s, t); // (a + b)(c + d)
FpDbl::mulPre(d1, a, c);
FpDbl::mulPre(d2, b, d);
FpDbl::sub(d0, d0, d1); // (a + b)(c + d) - ac
FpDbl::sub(d0, d0, d2); // (a + b)(c + d) - ac - bd
Fp *pz = reinterpret_cast<Fp*>(z);
FpDbl::mod(pz[1], d0);
FpDbl::sub(d1, d1, d2); // ac - bd
FpDbl::mod(pz[0], d1); // set z0
}
#endif
static void fp2_mulNFW(Unit *z, const Unit *x, const Unit *y)
{
const fp::Op& op = Fp::op_;
op.fp2_mulNF(z, x, y, op.p);
}
static void fp2_mulUseDblUseNCW(Unit *z, const Unit *x, const Unit *y)
static void fp2_mulW(Unit *z, const Unit *x, const Unit *y)
{
const Fp *px = reinterpret_cast<const Fp*>(x);
const Fp *py = reinterpret_cast<const Fp*>(y);
const Fp& a = px[0];
const Fp& b = px[1];
const Fp& c = py[0];
const Fp& d = py[1];
FpDbl d0, d1, d2;
Fp s, t;
Fp::addPre(s, a, b);
Fp::addPre(t, c, d);
FpDbl::mulPre(d0, s, t); // (a + b)(c + d)
FpDbl::mulPre(d1, a, c);
FpDbl::mulPre(d2, b, d);
FpDbl::subPre(d0, d0, d1); // (a + b)(c + d) - ac
FpDbl::subPre(d0, d0, d2); // (a + b)(c + d) - ac - bd
FpDbl d[2];
Fp::getOp().fp2Dbl_mulPre(reinterpret_cast<Unit*>(d), x, y);
Fp *pz = reinterpret_cast<Fp*>(z);
FpDbl::mod(pz[1], d0);
FpDbl::sub(d1, d1, d2); // ac - bd
FpDbl::mod(pz[0], d1); // set z0
FpDbl::mod(pz[0], d[0]);
FpDbl::mod(pz[1], d[1]);
}
/*
x = a + bi, i^2 = -1
@ -665,33 +685,7 @@ struct Fp2DblT {
}
static void mulPre(Fp2DblT& z, const Fp2& x, const Fp2& y)
{
const Fp& a = x.a;
const Fp& b = x.b;
const Fp& c = y.a;
const Fp& d = y.b;
if (Fp::isFullBit()) {
FpDbl BD;
Fp s, t;
Fp::add(s, a, b); // s = a + b
Fp::add(t, c, d); // t = c + d
FpDbl::mulPre(BD, b, d); // BD = bd
FpDbl::mulPre(z.a, a, c); // z.a = ac
FpDbl::mulPre(z.b, s, t); // z.b = st
FpDbl::sub(z.b, z.b, z.a); // z.b = st - ac
FpDbl::sub(z.b, z.b, BD); // z.b = st - ac - bd = ad + bc
FpDbl::sub(z.a, z.a, BD); // ac - bd
} else {
FpDbl BD;
Fp s, t;
Fp::addPre(s, a, b); // s = a + b
Fp::addPre(t, c, d); // t = c + d
FpDbl::mulPre(BD, b, d); // BD = bd
FpDbl::mulPre(z.a, a, c); // z.a = ac
FpDbl::mulPre(z.b, s, t); // z.b = st
FpDbl::subPre(z.b, z.b, z.a); // z.b = st - ac
FpDbl::subPre(z.b, z.b, BD); // z.b = st - ac - bd = ad + bc
FpDbl::sub(z.a, z.a, BD); // ac - bd
}
Fp::getOp().fp2Dbl_mulPre((fp::Unit*)&z, (const fp::Unit*)&x, (const fp::Unit*)&y);
}
static void mod(Fp2& y, const Fp2DblT& x)
{

@ -208,6 +208,7 @@ struct Op {
u3u fp_subPre; // without modulo p
u3u fpDbl_addPre;
u3u fpDbl_subPre;
void3u fp2Dbl_mulPre;
/*
for Fp2 = F[u] / (u^2 + 1)
x = a + bu
@ -284,6 +285,7 @@ struct Op {
fp_subPre = 0;
fpDbl_addPre = 0;
fpDbl_subPre = 0;
fp2Dbl_mulPre = 0;
xi_a = 0;
fp2_add = 0;

@ -130,19 +130,71 @@ struct FpGenerator : Xbyak::CodeGenerator {
typedef Xbyak::Reg64 Reg64;
typedef Xbyak::Xmm Xmm;
typedef Xbyak::Operand Operand;
typedef Xbyak::Label Label;
typedef Xbyak::util::StackFrame StackFrame;
typedef Xbyak::util::Pack Pack;
typedef fp_gen_local::MixPack MixPack;
typedef fp_gen_local::MemReg MemReg;
static const int UseRDX = Xbyak::util::UseRDX;
static const int UseRCX = Xbyak::util::UseRCX;
/*
classes to calculate offset and size
*/
struct Ext1 {
Ext1(size_t FpByte, const Reg64& r, int n = 0)
: r_(r)
, n_(n)
, next(FpByte + n)
{
}
operator RegExp() const { return r_ + n_; }
const Reg64& r_;
const int n_;
const int next;
private:
Ext1(const Ext1&);
void operator=(const Ext1&);
};
struct Ext2 {
Ext2(size_t FpByte, const Reg64& r, int n = 0)
: r_(r)
, n_(n)
, next(FpByte * 2 + n)
, a(FpByte, r, n)
, b(FpByte, r, n + FpByte)
{
}
operator RegExp() const { return r_ + n_; }
const Reg64& r_;
const int n_;
const int next;
Ext1 a;
Ext1 b;
private:
Ext2(const Ext2&);
void operator=(const Ext2&);
};
Xbyak::util::Cpu cpu_;
bool useMulx_;
bool useAdx_;
const Reg64& gp0;
const Reg64& gp1;
const Reg64& gp2;
const Reg64& gt0;
const Reg64& gt1;
const Reg64& gt2;
const Reg64& gt3;
const Reg64& gt4;
const Reg64& gt5;
const Reg64& gt6;
const Reg64& gt7;
const Reg64& gt8;
const Reg64& gt9;
const mcl::fp::Op *op_;
const uint64_t *p_;
uint64_t rp_;
int pn_;
int FpByte_;
bool isFullBit_;
// add/sub without carry. return true if overflow
typedef bool (*bool3op)(uint64_t*, const uint64_t*, const uint64_t*);
@ -160,13 +212,40 @@ struct FpGenerator : Xbyak::CodeGenerator {
typedef int (*int2op)(uint64_t*, const uint64_t*);
void4u mul_;
uint3opI mulUnit_;
// the following labels assume sf(this, 3, 10 | UseRDX)
Label mulPreL_;
Label fpDbl_modL_;
FpGenerator()
: CodeGenerator(4096 * 8)
#ifdef _MSC_VER
, gp0(rcx)
, gp1(r11)
, gp2(r8)
, gt0(r9)
, gt1(r10)
, gt2(rdi)
, gt3(rsi)
#else
, gp0(rdi)
, gp1(rsi)
, gp2(r11)
, gt0(rcx)
, gt1(r8)
, gt2(r9)
, gt3(r10)
#endif
, gt4(rbx)
, gt5(rbp)
, gt6(r12)
, gt7(r13)
, gt8(r14)
, gt9(r15)
, op_(0)
, p_(0)
, rp_(0)
, pn_(0)
, isFullBit_(0)
, FpByte_(0)
, mul_(0)
, mulUnit_(0)
{
@ -182,7 +261,8 @@ struct FpGenerator : Xbyak::CodeGenerator {
p_ = op.p;
rp_ = fp::getMontgomeryCoeff(p_[0]);
pn_ = (int)op.N;
isFullBit_ = (p_[pn_ - 1] >> 63) != 0;
FpByte_ = int(op.maxN * sizeof(uint64_t));
isFullBit_ = op.isFullBit;
// printf("p=%p, pn_=%d, isFullBit_=%d\n", p_, pn_, isFullBit_);
setSize(0); // reset code
@ -245,18 +325,45 @@ struct FpGenerator : Xbyak::CodeGenerator {
if (op.N == 2 || op.N == 3 || op.N == 4) {
align(16);
op.fpDbl_mod = getCurr<void3u>();
gen_fpDbl_mod(op);
if (op.N == 4) {
StackFrame sf(this, 3, 10 | UseRDX, 0, false);
call(fpDbl_modL_);
sf.close();
L(fpDbl_modL_);
gen_fpDbl_mod4(gp0, gp1, sf.t, gp2);
ret();
} else {
gen_fpDbl_mod(op);
}
}
if ((useMulx_ && op.N == 2) || op.N == 3 || op.N == 4 || (useAdx_ && op.N == 6)) {
align(16);
op.fpDbl_mulPre = getCurr<void3u>();
gen_fpDbl_mulPre();
if (op.N == 4) {
/*
fpDbl_mulPre is available as C function
this function calls mulPreL_ directly.
*/
StackFrame sf(this, 3, 10 | UseRDX, 0, false);
call(mulPreL_);
sf.close(); // make epilog
L(mulPreL_); // called only from asm code
mulPre4(gp0, gp1, gp2, sf.t);
ret();
} else {
gen_fpDbl_mulPre();
}
}
if ((useMulx_ && op.N == 2) || op.N == 3 || op.N == 4) {
align(16);
op.fpDbl_sqrPre = getCurr<void2u>();
gen_fpDbl_sqrPre(op);
}
if (op.N == 4 && !isFullBit_) {
align(16);
op.fp2Dbl_mulPre = getCurr<void3u>();
gen_fp2Dbl_mulPre();
}
}
void gen_addSubPre(bool isAdd, int n)
{
@ -767,23 +874,18 @@ struct FpGenerator : Xbyak::CodeGenerator {
@note destroy rax, rdx, t0, ..., t10, xm0, xm1
xm2 if isFullBit_
*/
void gen_fpDbl_mod4()
void gen_fpDbl_mod4(const Reg64& z, const Reg64& xy, const Pack& t, const Reg64& t10)
{
StackFrame sf(this, 3, 10 | UseRDX);
const Reg64& z = sf.p[0];
const Reg64& xy = sf.p[1];
const Reg64& t0 = sf.t[0];
const Reg64& t1 = sf.t[1];
const Reg64& t2 = sf.t[2];
const Reg64& t3 = sf.t[3];
const Reg64& t4 = sf.t[4];
const Reg64& t5 = sf.t[5];
const Reg64& t6 = sf.t[6];
const Reg64& t7 = sf.t[7];
const Reg64& t8 = sf.t[8];
const Reg64& t9 = sf.t[9];
const Reg64& t10 = sf.p[2];
const Reg64& t0 = t[0];
const Reg64& t1 = t[1];
const Reg64& t2 = t[2];
const Reg64& t3 = t[3];
const Reg64& t4 = t[4];
const Reg64& t5 = t[5];
const Reg64& t6 = t[6];
const Reg64& t7 = t[7];
const Reg64& t8 = t[8];
const Reg64& t9 = t[9];
const Reg64& a = rax;
const Reg64& d = rdx;
@ -903,9 +1005,14 @@ struct FpGenerator : Xbyak::CodeGenerator {
case 3:
gen_fpDbl_mod3();
break;
#if 0
case 4:
gen_fpDbl_mod4();
{
StackFrame sf(this, 3, 10 | UseRDX);
gen_fpDbl_mod4(gp0, gp1, sf.t, gp2);
}
break;
#endif
default:
throw cybozu::Exception("gen_fpDbl_mod:not support") << pn_;
}
@ -1608,18 +1715,27 @@ struct FpGenerator : Xbyak::CodeGenerator {
if (useMulx_ && pn_ == 2) {
StackFrame sf(this, 3, 5 | UseRDX);
mulPre2(sf.p[0], sf.p[1], sf.p[2], sf.t);
} else if (pn_ == 3) {
return;
}
if (pn_ == 3) {
StackFrame sf(this, 3, 10 | UseRDX);
mulPre3(sf.p[0], sf.p[1], sf.p[2], sf.t);
} else if (pn_ == 4) {
return;
}
assert(0);
#if 1
if (pn_ == 4) {
StackFrame sf(this, 3, 10 | UseRDX);
mulPre4(sf.p[0], sf.p[1], sf.p[2], sf.t);
return;
}
#endif
#if 0 // slow?
} else if (pn_ == 6 && useAdx_) {
if (pn_ == 6 && useAdx_) {
StackFrame sf(this, 3, 7 | UseRDX);
mulPre6(sf.p[0], sf.p[1], sf.p[2], sf.t);
#endif
}
#endif
}
static inline void debug_put_inner(const uint64_t *ptr, int n)
{
@ -2655,6 +2771,143 @@ private:
}
}
}
void gen_fp2Dbl_mulPre()
{
#if 1
assert(!isFullBit_);
const RegExp z = rsp + 0 * 8;
const RegExp x = rsp + 1 * 8;
const RegExp y = rsp + 2 * 8;
const Ext1 s(FpByte_, rsp, 3 * 8);
const Ext1 t(FpByte_, rsp, s.next);
const Ext1 d2(FpByte_ * 2, rsp, t.next);
const int SS = d2.next;
StackFrame sf(this, 3, 10 | UseRDX, SS);
mov(ptr [z], gp0);
mov(ptr [x], gp1);
mov(ptr [y], gp2);
// s = a + b
gen_raw_add(s, gp1, gp1 + FpByte_, rax, 4);
// t = c + d
gen_raw_add(t, gp2, gp2 + FpByte_, rax, 4);
// d1 = (a + b)(c + d)
mov(gp0, ptr [z]);
add(gp0, FpByte_ * 2); // d1
lea(gp1, ptr [s]);
lea(gp2, ptr [t]);
call(mulPreL_);
// d0 = a c
mov(gp0, ptr [z]);
mov(gp1, ptr [x]);
mov(gp2, ptr [y]);
call(mulPreL_);
// d2 = b d
lea(gp0, ptr [d2]);
mov(gp1, ptr [x]);
add(gp1, FpByte_);
mov(gp2, ptr [y]);
add(gp2, FpByte_);
call(mulPreL_);
mov(gp0, ptr [z]);
add(gp0, FpByte_ * 2); // d1
mov(gp1, gp0);
mov(gp2, ptr [z]);
gen_raw_sub(gp0, gp1, gp2, rax, 8);
lea(gp2, ptr [d2]);
gen_raw_sub(gp0, gp1, gp2, rax, 8);
// pz[1]
mov(gp0, ptr [z]);
mov(gp1, gp0);
lea(gp2, ptr [d2]);
gen_raw_sub(gp0, gp1, gp2, rax, 4);
gen_raw_fp_sub(gp0 + 8 * 4, gp1 + 8 * 4, gp2 + 8 * 4, Pack(gt0, gt1, gt2, gt3, gt4, gt5, gt6, gt7), true);
#else
assert(!isFullBit_);
/*
x = a + bi, y = c + di
xy = (ac - bd) + (ad + bc)i
= (ac - bd) + ((a + b)(c + d) - ac - bd)i
*/
const RegExp z = rsp + 0 * 8;
const RegExp x = rsp + 1 * 8;
const RegExp y = rsp + 2 * 8;
const Ext1 s(FpByte_, rsp, 3 * 8);
const Ext1 t(FpByte_, rsp, s.next);
const Ext1 d0(FpByte_ * 2, rsp, t.next);
const Ext1 d1(FpByte_ * 2, rsp, d0.next);
const Ext1 d2(FpByte_ * 2, rsp, d1.next);
const int SS = d2.next;
StackFrame sf(this, 3, 10 | UseRDX, SS);
mov(ptr[z], gp0);
mov(ptr[x], gp1);
mov(ptr[y], gp2);
/*
FpDbl d0, d1, d2;
Fp s, t;
Fp::addPre(s, a, b);
Fp::addPre(t, c, d);
FpDbl::mulPre(d0, s, t); // (a + b)(c + d)
FpDbl::mulPre(d1, a, c);
FpDbl::mulPre(d2, b, d);
FpDbl::subPre(d0, d0, d1); // (a + b)(c + d) - ac
FpDbl::subPre(d0, d0, d2); // (a + b)(c + d) - ac - bd
Fp *pz = reinterpret_cast<Fp*>(z);
FpDbl::mod(pz[1], d0);
FpDbl::sub(d1, d1, d2); // ac - bd
FpDbl::mod(pz[0], d1); // set z0
*/
// s = a + b
gen_raw_add(s, x, x + FpByte_, rax, 4);
// t = c + d
gen_raw_add(t, y, y + FpByte_, rax, 4);
// d0 = (a + b)(c + d)
lea(gp0, ptr [d0]);
lea(gp1, ptr [s]);
lea(gp2, ptr [t]);
call(mulPreL_);
// d1 = a c
lea(gp0, ptr [d1]);
mov(gp1, ptr [x]);
mov(gp2, ptr [y]);
call(mulPreL_);
// d2 = b d
lea(gp0, ptr [d2]);
mov(gp1, ptr [x]);
add(gp1, FpByte_);
mov(gp2, ptr [y]);
add(gp2, FpByte_);
call(mulPreL_);
lea(gp0, ptr [d0]);
mov(gp1, gp0);
lea(gp2, ptr [d1]);
gen_raw_sub(gp0, gp1, gp2, rax, 8);
lea(gp2, ptr [d2]);
gen_raw_sub(gp0, gp1, gp2, rax, 8);
// pz[1]
mov(gp0, ptr [z]);
add(gp0, FpByte_);
lea(gp1, ptr[d0]);
call(fpDbl_modL_);
lea(gp0, ptr [d1]);
mov(gp1, gp0);
lea(gp2, ptr [d2]);
gen_raw_sub(gp0, gp1, gp2, rax, 4);
gen_raw_fp_sub(gp0 + 8 * 4, gp1 + 8 * 4, gp2 + 8 * 4, Pack(gt0, gt1, gt2, gt3, gt4, gt5, gt6, gt7), true);
mov(gp0, ptr [z]);
lea(gp1, ptr[d1]);
call(fpDbl_modL_);
#endif
}
};
} } // mcl::fp

@ -11,6 +11,7 @@ void testBench(const G1& P, const G2& Q)
Fp x, y;
x.setHashOf("abc");
y.setHashOf("xyz");
#if 1
mpz_class z = 3;
mpz_class a = x.getMpz();
CYBOZU_BENCH_C("G1::mulCT ", C, G1::mulCT, Pa, P, a);
@ -46,6 +47,7 @@ void testBench(const G1& P, const G2& Q)
CYBOZU_BENCH_C("Fp::mul ", C3, Fp::mul, x, x, y);
CYBOZU_BENCH_C("Fp::sqr ", C3, Fp::sqr, x, x);
CYBOZU_BENCH_C("Fp::inv ", C3, Fp::inv, x, x);
#endif
Fp2 xx, yy;
xx.a = x;
xx.b = 3;

Loading…
Cancel
Save