optimize Fp2Dbl::mul_xi

update-fork
MITSUNARI Shigeo 4 years ago
parent b46aa28d8e
commit 371399aa14
  1. 25
      include/mcl/fp_tower.hpp
  2. 2
      include/mcl/op.hpp
  3. 60
      src/fp_generator.hpp
  4. 2
      src/fp_static_code.hpp
  5. 23
      test/common_test.hpp

@ -16,6 +16,7 @@ class FpDblT : public fp::Serializable<FpDblT<Fp> > {
Unit v_[Fp::maxSize * 2];
public:
// Number of Unit words in use: twice Fp::op_.N.
static size_t getUnitSize()
{
	return 2 * Fp::op_.N;
}
// Read-only pointer to the first word of the internal array v_.
const fp::Unit *getUnit() const
{
	return &v_[0];
}
void dump() const
{
const size_t n = getUnitSize();
@ -662,15 +663,16 @@ struct Fp2DblT {
FpDbl::neg(y.a, x.a);
FpDbl::neg(y.b, x.b);
}
static void mul_xi(Fp2DblT& y, const Fp2DblT& x)
static void mul_xi_1C(Fp2DblT& y, const Fp2DblT& x)
{
const uint32_t xi_a = Fp2::get_xi_a();
if (xi_a == 1) {
FpDbl t;
FpDbl::add(t, x.a, x.b);
FpDbl::sub(y.a, x.a, x.b);
y.b = t;
} else {
}
static void mul_xi_genericC(Fp2DblT& y, const Fp2DblT& x)
{
const uint32_t xi_a = Fp2::get_xi_a();
FpDbl t;
FpDbl::mulUnit(t, x.a, xi_a);
FpDbl::sub(t, t, x.b);
@ -678,9 +680,9 @@ struct Fp2DblT {
FpDbl::add(y.b, y.b, x.a);
y.a = t;
}
}
// Static function-pointer members used as dispatch targets.
// NOTE(review): mulPre/sqrPre appear to be assigned by the init routine of this struct — confirm.
static void (*mulPre)(Fp2DblT&, const Fp2&, const Fp2&);
static void (*sqrPre)(Fp2DblT&, const Fp2&);
// Set from Fp2::get_xi_a() during init: mul_xi_1C (or op.fp2Dbl_mul_xiA_ when it is non-null)
// for xi_a == 1, mul_xi_genericC otherwise.
static void (*mul_xi)(Fp2DblT&, const Fp2DblT&);
static void mod(Fp2& y, const Fp2DblT& x)
{
FpDbl::mod(y.a, x.a);
@ -716,6 +718,18 @@ struct Fp2DblT {
sqrPre = fp2Dbl_sqrPreW<false>;
}
}
const uint32_t xi_a = Fp2::get_xi_a();
switch (xi_a) {
case 1:
mul_xi = mul_xi_1C;
if (op.fp2Dbl_mul_xiA_) {
mul_xi = fp::func_ptr_cast<void (*)(Fp2DblT&, const Fp2DblT&)>(op.fp2Dbl_mul_xiA_);
}
break;
default:
mul_xi = mul_xi_genericC;
break;
}
}
/*
Fp2Dbl::mulPre by FpDblT
@ -770,6 +784,7 @@ struct Fp2DblT {
// Out-of-class definitions of the static function-pointer members declared in Fp2DblT.
template<class Fp> void (*Fp2DblT<Fp>::mulPre)(Fp2DblT&, const Fp2T<Fp>&, const Fp2T<Fp>&);
template<class Fp> void (*Fp2DblT<Fp>::sqrPre)(Fp2DblT&, const Fp2T<Fp>&);
template<class Fp> void (*Fp2DblT<Fp>::mul_xi)(Fp2DblT<Fp>&, const Fp2DblT<Fp>&);
template<class Fp> Fp2T<Fp> Fp2T<Fp>::g[Fp2T<Fp>::gN];
template<class Fp> Fp2T<Fp> Fp2T<Fp>::g2[Fp2T<Fp>::gN];

@ -226,6 +226,7 @@ struct Op {
void2u fpDbl_modA_;
void3u fp2Dbl_mulPreA_;
void2u fp2Dbl_sqrPreA_;
void2u fp2Dbl_mul_xiA_;
size_t maxN;
size_t N;
size_t bitSize;
@ -314,6 +315,7 @@ struct Op {
fpDbl_modA_ = 0;
fp2Dbl_mulPreA_ = 0;
fp2Dbl_sqrPreA_ = 0;
fp2Dbl_mul_xiA_ = 0;
maxN = 0;
N = 0;
bitSize = 0;

@ -457,6 +457,10 @@ private:
op.fp2Dbl_sqrPreA_ = gen_fp2Dbl_sqrPre();
if (op.fp2Dbl_sqrPreA_) setFuncInfo(prof_, suf, "2Dbl_sqrPre", op.fp2Dbl_sqrPreA_, getCurr());
align(16);
op.fp2Dbl_mul_xiA_ = gen_fp2Dbl_mul_xi();
if (op.fp2Dbl_mul_xiA_) setFuncInfo(prof_, suf, "2Dbl_mul_xi", op.fp2Dbl_mul_xiA_, getCurr());
align(16);
op.fp2_mulA_ = gen_fp2_mul();
setFuncInfo(prof_, suf, "2_mul", op.fp2_mulA_, getCurr());
@ -3173,6 +3177,13 @@ private:
}
}
}
/*
	emit and_(y[i], t) for every register y[i] of the pack,
	i.e. y[i] &= t for i = 0, ..., y.size() - 1
*/
void andPack(const Pack& y, const Reg64& t)
{
	const int regNum = static_cast<int>(y.size());
	for (int idx = 0; idx != regNum; ++idx) {
		and_(y[idx], t);
	}
}
/*
[rdx:x:t0] <- py[1:0] * x
destroy x, t
@ -3647,6 +3658,55 @@ private:
call(mulPreL);
return func;
}
/*
	Emit the JIT routine for Fp2Dbl::mul_xi specialised to xi_a == 1:
	y = (x.a - x.b, x.a + x.b)
	@return entry address of the generated code, or 0 when no code is generated
	(isFullBit_ is set, or op_->xi_a != 1)
	The .b half of each operand is addressed at offset FpByte_ * 2 from its .a half.
	x.a + x.b is written to the stack before anything is stored to y;
	presumably this keeps the routine correct when y and x are the same object
	(the unit test calls mul_xi(x, x)) - TODO confirm.
*/
void2u gen_fp2Dbl_mul_xi()
{
// no specialised code for these configurations; returning 0 means "not generated"
if (isFullBit_) return 0;
if (op_->xi_a != 1) return 0;
void2u func = getCurr<void2u>();
// y = (x.a - x.b, x.a + x.b)
// NOTE(review): arguments look like (2 pointer params, pn_ * 2 temp regs, FpByte_ * 2 stack bytes) - confirm against StackFrame
StackFrame sf(this, 2, pn_ * 2, FpByte_ * 2);
// split the temporaries into two packs of pn_ registers each
Pack t1 = sf.t.sub(0, pn_);
Pack t2 = sf.t.sub(pn_, pn_);
const RegExp& ya = sf.p[0];
const RegExp& yb = sf.p[0] + FpByte_ * 2;
const RegExp& xa = sf.p[1];
const RegExp& xb = sf.p[1] + FpByte_ * 2;
// [rsp] = x.a + x.b over all pn_ * 2 words
// add for word 0, adc afterwards; the mov instructions in between leave the carry flag untouched
for (int i = 0; i < pn_ * 2; i++) {
mov(rax, ptr[xa + i * 8]);
if (i == 0) {
add(rax, ptr[xb + i * 8]);
} else {
adc(rax, ptr[xb + i * 8]);
}
mov(ptr[rsp + i * 8], rax);
}
// low : y.a = x.a - x.b (stored to ya, lower pn_ words)
load_rm(t1, xa);
sub_rm(t1, xb);
store_mr(ya, t1);
// high : y.a = (x.a - x.b) % p (stored to ya + FpByte_)
load_rm(t1, xa + FpByte_);
// NOTE(review): the extra 'true' presumably makes sub_rm continue the borrow from the low half - confirm
sub_rm(t1, xb + FpByte_, true);
lea(rax, ptr[rip + pL_]);
load_rm(t2, rax); // t2 = p
// rax = 0 - CF: all ones if the subtraction borrowed, zero otherwise
// (relies on lea/load_rm not modifying the flags - TODO confirm load_rm emits only mov)
sbb(rax, rax);
// t2 = borrow ? p : 0
andPack(t2, rax);
add_rr(t1, t2); // mod p
store_mr(ya + FpByte_, t1);
// low : y.b = [rsp] (copy the lower pn_ words of the saved sum)
for (int i = 0; i < pn_; i++) {
mov(rax, ptr[rsp + i * 8]);
mov(ptr[yb + i * 8], rax);
}
// high : y.b = (x.a + x.b) % p
load_rm(t1, rsp + FpByte_);
lea(rax, ptr[rip + pL_]);
// NOTE(review): sub_p_mod presumably yields t2 = t1 reduced by p (rax points at p) - confirm
sub_p_mod(t2, t1, rax);
store_mr(yb + FpByte_, t2);
return func;
}
void gen_fp2_add4()
{
assert(!isFullBit_);

@ -38,6 +38,7 @@ void mclx_Fp2_mul(Unit*, const Unit*, const Unit*);
void mclx_Fp2_sqr(Unit*, const Unit*);
void mclx_Fp2_mul2(Unit*, const Unit*);
void mclx_Fp2_mul_xi(Unit*, const Unit*);
void mclx_Fp2Dbl_mul_xi(Unit*, const Unit*);
Unit mclx_Fr_addPre(Unit*, const Unit*, const Unit*);
Unit mclx_Fr_subPre(Unit*, const Unit*, const Unit*);
@ -79,6 +80,7 @@ void setStaticCode(mcl::fp::Op& op)
op.fp2_sqrA_ = mclx_Fp2_sqr;
op.fp2_mul2A_ = mclx_Fp2_mul2;
op.fp2_mul_xiA_ = mclx_Fp2_mul_xi;
op.fp2Dbl_mul_xiA_ = mclx_Fp2Dbl_mul_xi;
op.fp_preInv = mclx_Fp_preInv;
} else {
// Fr, sizeof(Fr) = 32

@ -161,8 +161,31 @@ void testABCD()
}
}
/*
	Check that the dispatched Fp2Dbl::mul_xi gives the same result as
	Fp2Dbl::mul_xi_1C on 100 pseudo-random inputs.
	Does nothing unless Fp2::get_xi_a() == 1.
*/
void testFp2Dbl_mul_xi1()
{
	if (Fp2::get_xi_a() != 1) return;
	puts("testFp2Dbl_mul_xi1");
	cybozu::XorShift rg;
	for (int i = 0; i < 100; i++) {
		// build x.a and x.b as double-width products of random Fp values
		Fp a1, a2;
		a1.setByCSPRNG(rg);
		a2.setByCSPRNG(rg);
		Fp2Dbl x;
		FpDbl::mulPre(x.a, a1, a2);
		a1.setByCSPRNG(rg);
		a2.setByCSPRNG(rg);
		FpDbl::mulPre(x.b, a1, a2);
		Fp2Dbl ok;
		// expected value goes into ok; x must still hold the original input afterwards
		Fp2Dbl::mul_xi_1C(ok, x);
		// in-place call, so the y == x aliasing case is exercised too
		Fp2Dbl::mul_xi(x, x);
		CYBOZU_TEST_EQUAL_ARRAY(ok.a.getUnit(), x.a.getUnit(), ok.a.getUnitSize());
		CYBOZU_TEST_EQUAL_ARRAY(ok.b.getUnit(), x.b.getUnit(), ok.b.getUnitSize());
	}
}
void testCommon(const G1& P, const G2& Q)
{
testFp2Dbl_mul_xi1();
testABCD();
testMul2();
puts("G1");

Loading…
Cancel
Save