use setProtectModeRE

dev
MITSUNARI Shigeo 6 years ago
parent c47fdddc17
commit 6b1f2f7e81
  1. 122
      src/fp_generator.hpp
  2. 4
      test/fp_generator_test.cpp

@ -126,7 +126,7 @@ if (rm.isReg()) { \
namespace fp { namespace fp {
struct Code : Xbyak::CodeGenerator { struct FpGenerator : Xbyak::CodeGenerator {
typedef Xbyak::RegExp RegExp; typedef Xbyak::RegExp RegExp;
typedef Xbyak::Reg64 Reg64; typedef Xbyak::Reg64 Reg64;
typedef Xbyak::Xmm Xmm; typedef Xbyak::Xmm Xmm;
@ -213,16 +213,12 @@ struct Code : Xbyak::CodeGenerator {
typedef int (*int2op)(uint64_t*, const uint64_t*); typedef int (*int2op)(uint64_t*, const uint64_t*);
void4u mul_; void4u mul_;
uint3opI mulUnit_; uint3opI mulUnit_;
// the following labels assume sf(this, 3, 10 | UseRDX)
Label mulPreL_;
Label fpDbl_modL_;
Label fp_mulL_;
/* /*
@param op [in] ; use op.p, op.N, op.isFullBit @param op [in] ; use op.p, op.N, op.isFullBit
*/ */
Code(size_t codeSize, uint8_t *mem, Op& op) FpGenerator()
: CodeGenerator(codeSize, mem) : CodeGenerator(4096 * 8, Xbyak::DontSetProtectRWE)
#ifdef XBYAK64_WIN #ifdef XBYAK64_WIN
, gp0(rcx) , gp0(rcx)
, gp1(r11) , gp1(r11)
@ -256,7 +252,21 @@ struct Code : Xbyak::CodeGenerator {
{ {
useMulx_ = cpu_.has(Xbyak::util::Cpu::tBMI2); useMulx_ = cpu_.has(Xbyak::util::Cpu::tBMI2);
useAdx_ = cpu_.has(Xbyak::util::Cpu::tADX); useAdx_ = cpu_.has(Xbyak::util::Cpu::tADX);
}
void init(Op& op)
{
reset(); // reset jit code for reuse
setProtectModeRW(); // read/write memory
init_inner(op);
setProtectModeRE(); // set read/exec memory
}
private:
void init_inner(Op& op)
{
// the following labels assume sf(this, 3, 10 | UseRDX)
Label mulPreL;
Label fpDbl_modL;
Label fp_mulL;
op_ = &op; op_ = &op;
p_ = op.p; p_ = op.p;
rp_ = fp::getMontgomeryCoeff(p_[0]); rp_ = fp::getMontgomeryCoeff(p_[0]);
@ -293,7 +303,7 @@ struct Code : Xbyak::CodeGenerator {
align(16); align(16);
op.fp_mul = getCurr<void4u>(); // used in toMont/fromMont op.fp_mul = getCurr<void4u>(); // used in toMont/fromMont
op.fp_mulA_ = getCurr<void3u>(); op.fp_mulA_ = getCurr<void3u>();
gen_mul(); gen_mul(fp_mulL);
align(16); align(16);
op.fp_sqrA_ = getCurr<void2u>(); op.fp_sqrA_ = getCurr<void2u>();
gen_sqr(); gen_sqr();
@ -327,9 +337,9 @@ struct Code : Xbyak::CodeGenerator {
op.fpDbl_modA_ = getCurr<void2u>(); op.fpDbl_modA_ = getCurr<void2u>();
if (op.N == 4) { if (op.N == 4) {
StackFrame sf(this, 3, 10 | UseRDX, 0, false); StackFrame sf(this, 3, 10 | UseRDX, 0, false);
call(fpDbl_modL_); call(fpDbl_modL);
sf.close(); sf.close();
L(fpDbl_modL_); L(fpDbl_modL);
gen_fpDbl_mod4(gp0, gp1, sf.t, gp2); gen_fpDbl_mod4(gp0, gp1, sf.t, gp2);
ret(); ret();
} else { } else {
@ -342,16 +352,16 @@ struct Code : Xbyak::CodeGenerator {
if (op.N == 4) { if (op.N == 4) {
/* /*
fpDbl_mulPre is available as C function fpDbl_mulPre is available as C function
this function calls mulPreL_ directly. this function calls mulPreL directly.
*/ */
StackFrame sf(this, 3, 10 | UseRDX, 0, false); StackFrame sf(this, 3, 10 | UseRDX, 0, false);
#if 0 #if 0
call(mulPreL_); call(mulPreL);
#else #else
mulPre4(gp0, gp1, gp2, sf.t); mulPre4(gp0, gp1, gp2, sf.t);
#endif #endif
sf.close(); // make epilog sf.close(); // make epilog
L(mulPreL_); // called only from asm code L(mulPreL); // called only from asm code
mulPre4(gp0, gp1, gp2, sf.t); mulPre4(gp0, gp1, gp2, sf.t);
ret(); ret();
} else { } else {
@ -375,13 +385,13 @@ struct Code : Xbyak::CodeGenerator {
gen_fp2_neg4(); gen_fp2_neg4();
align(16); align(16);
op.fp2Dbl_mulPreA_ = getCurr<void3u>(); op.fp2Dbl_mulPreA_ = getCurr<void3u>();
gen_fp2Dbl_mulPre(); gen_fp2Dbl_mulPre(mulPreL);
align(16); align(16);
op.fp2_mulA_ = getCurr<void3u>(); op.fp2_mulA_ = getCurr<void3u>();
gen_fp2_mul4(); gen_fp2_mul4(fpDbl_modL);
align(16); align(16);
op.fp2_sqrA_ = getCurr<void2u>(); op.fp2_sqrA_ = getCurr<void2u>();
gen_fp2_sqr4(); gen_fp2_sqr4(fp_mulL);
align(16); align(16);
op.fp2_mul_xiA_ = getCurr<void2u>(); op.fp2_mul_xiA_ = getCurr<void2u>();
gen_fp2_mul_xi4(); gen_fp2_mul_xi4();
@ -730,7 +740,7 @@ struct Code : Xbyak::CodeGenerator {
shr(*t0, c); shr(*t0, c);
mov(ptr [pz + (pn_ - 1) * 8], *t0); mov(ptr [pz + (pn_ - 1) * 8], *t0);
} }
void gen_mul() void gen_mul(Label& fp_mulL)
{ {
if (op_->primeMode == PM_NIST_P192) { if (op_->primeMode == PM_NIST_P192) {
StackFrame sf(this, 3, 10 | UseRDX, 8 * 6); StackFrame sf(this, 3, 10 | UseRDX, 8 * 6);
@ -740,11 +750,11 @@ struct Code : Xbyak::CodeGenerator {
if (pn_ == 3) { if (pn_ == 3) {
gen_montMul3(p_, rp_); gen_montMul3(p_, rp_);
} else if (pn_ == 4) { } else if (pn_ == 4) {
gen_montMul4(p_, rp_); gen_montMul4(fp_mulL, p_, rp_);
} else if (pn_ <= 9) { } else if (pn_ <= 9) {
gen_montMulN(p_, rp_, pn_); gen_montMulN(p_, rp_, pn_);
} else { } else {
throw cybozu::Exception("mcl:Code:gen_mul:not implemented for") << pn_; throw cybozu::Exception("mcl:FpGenerator:gen_mul:not implemented for") << pn_;
} }
} }
/* /*
@ -1108,10 +1118,10 @@ struct Code : Xbyak::CodeGenerator {
z[0..3] <- montgomery(x[0..3], y[0..3]) z[0..3] <- montgomery(x[0..3], y[0..3])
destroy gt0, ..., gt9, xm0, xm1, p2 destroy gt0, ..., gt9, xm0, xm1, p2
*/ */
void gen_montMul4(const uint64_t *p, uint64_t pp) void gen_montMul4(Label& fp_mulL, const uint64_t *p, uint64_t pp)
{ {
StackFrame sf(this, 3, 10 | UseRDX, 0, false); StackFrame sf(this, 3, 10 | UseRDX, 0, false);
call(fp_mulL_); call(fp_mulL);
sf.close(); sf.close();
const Reg64& p0 = sf.p[0]; const Reg64& p0 = sf.p[0];
const Reg64& p1 = sf.p[1]; const Reg64& p1 = sf.p[1];
@ -1128,7 +1138,7 @@ struct Code : Xbyak::CodeGenerator {
const Reg64& t8 = sf.t[8]; const Reg64& t8 = sf.t[8];
const Reg64& t9 = sf.t[9]; const Reg64& t9 = sf.t[9];
L(fp_mulL_); L(fp_mulL);
movq(xm0, p0); // save p0 movq(xm0, p0); // save p0
mov(p0, (uint64_t)p); mov(p0, (uint64_t)p);
movq(xm1, p2); movq(xm1, p2);
@ -2001,7 +2011,7 @@ struct Code : Xbyak::CodeGenerator {
assert(pn_ >= 1); assert(pn_ >= 1);
const int freeRegNum = 13; const int freeRegNum = 13;
if (pn_ > 9) { if (pn_ > 9) {
throw cybozu::Exception("mcl:Code:gen_preInv:large pn_") << pn_; throw cybozu::Exception("mcl:FpGenerator:gen_preInv:large pn_") << pn_;
} }
StackFrame sf(this, 2, 10 | UseRDX | UseRCX, (std::max<int>(0, pn_ * 5 - freeRegNum) + 1 + (isFullBit_ ? 1 : 0)) * 8); StackFrame sf(this, 2, 10 | UseRDX | UseRCX, (std::max<int>(0, pn_ * 5 - freeRegNum) + 1 + (isFullBit_ ? 1 : 0)) * 8);
const Reg64& pr = sf.p[0]; const Reg64& pr = sf.p[0];
@ -2275,8 +2285,8 @@ struct Code : Xbyak::CodeGenerator {
L("@@"); L("@@");
} }
private: private:
Code(const Code&); FpGenerator(const FpGenerator&);
void operator=(const Code&); void operator=(const FpGenerator&);
void make_op_rm(void (Xbyak::CodeGenerator::*op)(const Xbyak::Operand&, const Xbyak::Operand&), const Reg64& op1, const MemReg& op2) void make_op_rm(void (Xbyak::CodeGenerator::*op)(const Xbyak::Operand&, const Xbyak::Operand&), const Reg64& op1, const MemReg& op2)
{ {
if (op2.isReg()) { if (op2.isReg()) {
@ -2804,7 +2814,7 @@ private:
} }
} }
} }
void gen_fp2Dbl_mulPre() void gen_fp2Dbl_mulPre(Label& mulPreL)
{ {
assert(!isFullBit_); assert(!isFullBit_);
const RegExp z = rsp + 0 * 8; const RegExp z = rsp + 0 * 8;
@ -2827,12 +2837,12 @@ private:
add(gp0, FpByte_ * 2); // d1 add(gp0, FpByte_ * 2); // d1
lea(gp1, ptr [s]); lea(gp1, ptr [s]);
lea(gp2, ptr [t]); lea(gp2, ptr [t]);
call(mulPreL_); call(mulPreL);
// d0 = a c // d0 = a c
mov(gp0, ptr [z]); mov(gp0, ptr [z]);
mov(gp1, ptr [x]); mov(gp1, ptr [x]);
mov(gp2, ptr [y]); mov(gp2, ptr [y]);
call(mulPreL_); call(mulPreL);
// d2 = b d // d2 = b d
lea(gp0, ptr [d2]); lea(gp0, ptr [d2]);
@ -2840,7 +2850,7 @@ private:
add(gp1, FpByte_); add(gp1, FpByte_);
mov(gp2, ptr [y]); mov(gp2, ptr [y]);
add(gp2, FpByte_); add(gp2, FpByte_);
call(mulPreL_); call(mulPreL);
mov(gp0, ptr [z]); mov(gp0, ptr [z]);
add(gp0, FpByte_ * 2); // d1 add(gp0, FpByte_ * 2); // d1
@ -2928,7 +2938,7 @@ private:
gen_raw_neg(sf.p[0], sf.p[1], sf.t); gen_raw_neg(sf.p[0], sf.p[1], sf.t);
gen_raw_neg(sf.p[0] + FpByte_, sf.p[1] + FpByte_, sf.t); gen_raw_neg(sf.p[0] + FpByte_, sf.p[1] + FpByte_, sf.t);
} }
void gen_fp2_mul4() void gen_fp2_mul4(Label& fpDbl_modL)
{ {
assert(!isFullBit_); assert(!isFullBit_);
const RegExp z = rsp + 0 * 8; const RegExp z = rsp + 0 * 8;
@ -2969,14 +2979,14 @@ private:
mov(gp0, ptr [z]); mov(gp0, ptr [z]);
lea(gp1, ptr[d0]); lea(gp1, ptr[d0]);
call(fpDbl_modL_); call(fpDbl_modL);
mov(gp0, ptr [z]); mov(gp0, ptr [z]);
add(gp0, FpByte_); add(gp0, FpByte_);
lea(gp1, ptr[d1]); lea(gp1, ptr[d1]);
call(fpDbl_modL_); call(fpDbl_modL);
} }
void gen_fp2_sqr4() void gen_fp2_sqr4(Label& fp_mulL)
{ {
assert(!isFullBit_); assert(!isFullBit_);
const RegExp y = rsp + 0 * 8; const RegExp y = rsp + 0 * 8;
@ -3006,7 +3016,7 @@ private:
// t1 = 2ab // t1 = 2ab
mov(gp1, gp0); mov(gp1, gp0);
mov(gp2, ptr [x]); mov(gp2, ptr [x]);
call(fp_mulL_); call(fp_mulL);
if (nocarry) { if (nocarry) {
Pack a = sf.t.sub(0, 4); Pack a = sf.t.sub(0, 4);
@ -3038,7 +3048,7 @@ private:
mov(gp0, ptr [y]); mov(gp0, ptr [y]);
lea(gp1, ptr [t2]); lea(gp1, ptr [t2]);
lea(gp2, ptr [t3]); lea(gp2, ptr [t3]);
call(fp_mulL_); call(fp_mulL);
mov(gp0, ptr [y]); mov(gp0, ptr [y]);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
mov(rax, ptr [(RegExp)t1 + i * 8]); mov(rax, ptr [(RegExp)t1 + i * 8]);
@ -3047,46 +3057,6 @@ private:
} }
}; };
struct FpGenerator {
static const size_t codeSize = 4096 * 8;
static const size_t pageSize = 4096;
uint8_t *mem;
Code *code;
FpGenerator()
: mem((uint8_t*)cybozu::AlignedMalloc(codeSize, pageSize))
, code(0)
{
}
void init(Op& op)
{
if (code) {
setRW();
delete code;
}
code = new Code(codeSize, mem, op);
setRE();
}
~FpGenerator()
{
setRW();
delete code;
cybozu::AlignedFree(mem);
}
private:
FpGenerator(const FpGenerator&);
void operator==(const FpGenerator&);
void setRW()
{
Xbyak::CodeArray::protect(mem, codeSize, Xbyak::CodeArray::PROTECT_RW);
}
void setRE()
{
if (!Xbyak::CodeArray::protect(mem, codeSize, Xbyak::CodeArray::PROTECT_RE)) {
throw cybozu::Exception("err protect read/exec");
}
}
};
} } // mcl::fp } } // mcl::fp
#ifdef _MSC_VER #ifdef _MSC_VER

@ -152,7 +152,7 @@ void testMulI(const mcl::fp::FpGenerator& fg, int pn)
mpz_class my; mpz_class my;
mcl::gmp::set(my, y); mcl::gmp::set(my, y);
mx *= my; mx *= my;
uint64_t d = fg.code->mulUnit_(z, x, y); uint64_t d = fg.mulUnit_(z, x, y);
z[pn] = d; z[pn] = d;
mcl::gmp::setArray(my, z, pn + 1); mcl::gmp::setArray(my, z, pn + 1);
CYBOZU_TEST_EQUAL(mx, my); CYBOZU_TEST_EQUAL(mx, my);
@ -162,7 +162,7 @@ void testMulI(const mcl::fp::FpGenerator& fg, int pn)
uint64_t z[MAX_N + 1]; uint64_t z[MAX_N + 1];
rg.read(x, pn); rg.read(x, pn);
uint64_t y = rg.get64(); uint64_t y = rg.get64();
CYBOZU_BENCH_C("mulUnit", 10000000, fg.code->mulUnit_, z, x, y); CYBOZU_BENCH_C("mulUnit", 10000000, fg.mulUnit_, z, x, y);
} }
} }

Loading…
Cancel
Save