add Fp2::sqr for bls12

dev
MITSUNARI Shigeo 6 years ago
parent a2bdc82c97
commit 2f68f703e4
  1. 59
      src/fp_generator.hpp
  2. 9
      test/bls12_test.cpp

@ -581,7 +581,7 @@ private:
gen_raw_fp_sub(pz, px, py, sf.t, false); gen_raw_fp_sub(pz, px, py, sf.t, false);
} }
/* /*
add(pz + offset, px + offset, py + offset); add(pz, px, py);
size of t1, t2 == 6 size of t1, t2 == 6
destroy t0, t1 destroy t0, t1
*/ */
@ -723,7 +723,7 @@ private:
} }
return 0; return 0;
} }
void gen_raw_fp_sub6(const RegExp& pz, const Reg64& px, const Reg64& py, int offset, const Pack& t, bool withCarry) void gen_raw_fp_sub6(const RegExp& pz, const RegExp& px, const RegExp& py, int offset, const Pack& t, bool withCarry)
{ {
load_rm(t, px + offset); load_rm(t, px + offset);
sub_rm(t, py + offset, withCarry); sub_rm(t, py + offset, withCarry);
@ -3654,7 +3654,6 @@ private:
{ {
if (isFullBit_) return 0; if (isFullBit_) return 0;
if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0; if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0;
// if (pn_ != 4) return 0;
align(16); align(16);
void3u func = getCurr<void3u>(); void3u func = getCurr<void3u>();
bool embedded = pn_ == 4; bool embedded = pn_ == 4;
@ -3730,17 +3729,11 @@ private:
} }
void2u gen_fp2_sqr() void2u gen_fp2_sqr()
{ {
if (isFullBit_) return 0;
if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0;
align(16); align(16);
void2u func = getCurr<void2u>(); void2u func = getCurr<void2u>();
if (pn_ == 4 && !isFullBit_) {
gen_fp2_sqr4();
return func;
}
return 0;
}
void gen_fp2_sqr4()
{
assert(!isFullBit_);
const RegExp y = rsp + 0 * 8; const RegExp y = rsp + 0 * 8;
const RegExp x = rsp + 1 * 8; const RegExp x = rsp + 1 * 8;
const Ext1 t1(FpByte_, rsp, 2 * 8); const Ext1 t1(FpByte_, rsp, 2 * 8);
@ -3753,7 +3746,7 @@ private:
// t1 = b + b // t1 = b + b
lea(gp0, ptr [t1]); lea(gp0, ptr [t1]);
if (nocarry) { if (nocarry) {
for (int i = 0; i < 4; i++) { for (int i = 0; i < pn_; i++) {
mov(rax, ptr [gp1 + FpByte_ + i * 8]); mov(rax, ptr [gp1 + FpByte_ + i * 8]);
if (i == 0) { if (i == 0) {
add(rax, rax); add(rax, rax);
@ -3763,7 +3756,15 @@ private:
mov(ptr [gp0 + i * 8], rax); mov(ptr [gp0 + i * 8], rax);
} }
} else { } else {
gen_raw_fp_add(gp0, gp1 + FpByte_, gp1 + FpByte_, sf.t, false); if (pn_ == 4) {
gen_raw_fp_add(gp0, gp1 + FpByte_, gp1 + FpByte_, sf.t, false);
} else {
assert(pn_ == 6);
Pack t = sf.t.sub(6, 4);
t.append(rax);
t.append(rdx);
gen_raw_fp_add6(gp0, gp1 + FpByte_, gp1 + FpByte_, sf.t.sub(0, 6), t, false);
}
} }
// t1 = 2ab // t1 = 2ab
mov(gp1, gp0); mov(gp1, gp0);
@ -3771,13 +3772,16 @@ private:
call(fp_mulL); call(fp_mulL);
if (nocarry) { if (nocarry) {
Pack a = sf.t.sub(0, 4); Pack t = sf.t;
Pack b = sf.t.sub(4, 4); t.append(rdx);
t.append(gp1);
Pack a = t.sub(0, pn_);
Pack b = t.sub(pn_, pn_);
mov(gp0, ptr [x]); mov(gp0, ptr [x]);
load_rm(a, gp0); load_rm(a, gp0);
load_rm(b, gp0 + FpByte_); load_rm(b, gp0 + FpByte_);
// t2 = a + b // t2 = a + b
for (int i = 0; i < 4; i++) { for (int i = 0; i < pn_; i++) {
mov(rax, a[i]); mov(rax, a[i]);
if (i == 0) { if (i == 0) {
add(rax, b[i]); add(rax, b[i]);
@ -3787,14 +3791,24 @@ private:
mov(ptr [(RegExp)t2 + i * 8], rax); mov(ptr [(RegExp)t2 + i * 8], rax);
} }
// t3 = a + p - b // t3 = a + p - b
mov(gp1, (size_t)p_); mov(rax, pL_);
add_rm(a, gp1); add_rm(a, rax);
sub_rr(a, b); sub_rr(a, b);
store_mr(t3, a); store_mr(t3, a);
} else { } else {
mov(gp0, ptr [x]); mov(gp0, ptr [x]);
gen_raw_fp_add(t2, gp0, gp0 + FpByte_, sf.t, false); if (pn_ == 4) {
gen_raw_fp_sub(t3, gp0, gp0 + FpByte_, sf.t, false); gen_raw_fp_add(t2, gp0, gp0 + FpByte_, sf.t, false);
gen_raw_fp_sub(t3, gp0, gp0 + FpByte_, sf.t, false);
} else {
assert(pn_ == 6);
Pack p1 = sf.t.sub(0, 6);
Pack p2 = sf.t.sub(6, 4);
p2.append(rax);
p2.append(rdx);
gen_raw_fp_add6(t2, gp0, gp0 + FpByte_, p1, p2, false);
gen_raw_fp_sub6(t3, gp0, gp0 + FpByte_, 0, p1, false);
}
} }
mov(gp0, ptr [y]); mov(gp0, ptr [y]);
@ -3802,10 +3816,11 @@ private:
lea(gp2, ptr [t3]); lea(gp2, ptr [t3]);
call(fp_mulL); call(fp_mulL);
mov(gp0, ptr [y]); mov(gp0, ptr [y]);
for (int i = 0; i < 4; i++) { for (int i = 0; i < pn_; i++) {
mov(rax, ptr [(RegExp)t1 + i * 8]); mov(rax, ptr [(RegExp)t1 + i * 8]);
mov(ptr [gp0 + FpByte_ + i * 8], rax); mov(ptr [gp0 + FpByte_ + i * 8], rax);
} }
return func;
} }
}; };

@ -692,6 +692,9 @@ int main(int argc, char *argv[])
x2.b.setByCSPRNG(rg); x2.b.setByCSPRNG(rg);
y2.a.setByCSPRNG(rg); y2.a.setByCSPRNG(rg);
y2.b.setByCSPRNG(rg); y2.b.setByCSPRNG(rg);
Fp2Dbl x2d, y2d;
Fp2Dbl::mulPre(x2d, x2, x2);
Fp2Dbl::mulPre(y2d, x2, y2);
if(0){ if(0){
puts("----------"); puts("----------");
xv[0].dump(); xv[0].dump();
@ -700,8 +703,10 @@ if(0){
puts("----------"); puts("----------");
// exit(1); // exit(1);
} }
CYBOZU_BENCH_C("Fp2::neg", 10000000, Fp2::neg, x2, x2); // CYBOZU_BENCH_C("Fp2::neg", 10000000, Fp2::neg, x2, x2);
// CYBOZU_BENCH_C("mulPre", 100000000, FpDbl::mulPre, dx, xv[0], yv[0]); CYBOZU_BENCH_C("Fp2::sqr", 10000000, Fp2::sqr, x2, x2);
// CYBOZU_BENCH_C("Fp2::sqrPre", 100000000, Fp2Dbl::sqrPre, x2d, x2);
// CYBOZU_BENCH_C("Fp2::mulPre", 100000000, Fp2Dbl::mulPre, x2d, x2, y2);
// CYBOZU_BENCH_C("sqrPre", 100000000, FpDbl::sqrPre, dx, xv[0]); // CYBOZU_BENCH_C("sqrPre", 100000000, FpDbl::sqrPre, dx, xv[0]);
// CYBOZU_BENCH_C("mod ", 100000000, FpDbl::mod, xv[0], dx); // CYBOZU_BENCH_C("mod ", 100000000, FpDbl::mod, xv[0], dx);
// CYBOZU_BENCH_C("mul ", 100000000, Fp::mul, xv[0], yv[0], xv[0]); // CYBOZU_BENCH_C("mul ", 100000000, Fp::mul, xv[0], yv[0], xv[0]);

Loading…
Cancel
Save