refactor DumpCode

update-fork
MITSUNARI Shigeo 4 years ago
parent c29157cc9a
commit f11b3be1ab
  1. 182
      src/fp_generator.hpp
  2. 8
      test/static_code_test.cpp

@ -23,46 +23,34 @@
#pragma warning(disable : 4458) #pragma warning(disable : 4458)
#endif #endif
//#define MCL_DUMP_JIT
namespace mcl { namespace mcl {
#ifdef MCL_STATIC_JIT
typedef fp::Unit Unit;
extern "C" {
Unit mclx_Fr_addPre(Unit*, const Unit*, const Unit*);
void mclx_Fr_add(Unit*, const Unit*, const Unit*);
Unit mclx_Fp_addPre(Unit*, const Unit*, const Unit*);
void mclx_Fp_add(Unit*, const Unit*, const Unit*);
}
#endif
#ifdef MCL_DUMP_JIT #ifdef MCL_DUMP_JIT
// not profiler, but dump jit code struct DumpCode {
struct Profiler {
FILE *fp_; FILE *fp_;
const uint8_t *prev_; DumpCode()
std::string suf_;
Profiler()
: fp_(stdout) : fp_(stdout)
, prev_(0)
{
}
~Profiler()
{
// if (fp_) fclose(fp_);
}
#if 0
void open(const std::string& fileName)
{ {
fp_ = fopen(fileName.c_str(), "wb");
} }
#endif void set(const std::string& name, const uint8_t *begin, const size_t size)
void setStartAddr(const uint8_t *addr)
{ {
prev_ = addr; fprintf(fp_, "segment .text\n");
} fprintf(fp_, "global %s\n", name.c_str());
void setNameSuffix(const std::string& suf)
{
suf_ = suf;
}
void set(const char *name, const uint8_t *end)
{
fprintf(fp_, "global %s%s\n", suf_.c_str(), name);
fprintf(fp_, "align 16\n"); fprintf(fp_, "align 16\n");
fprintf(fp_, "%s%s:\n", suf_.c_str(), name); fprintf(fp_, "%s:\n", name.c_str());
const uint8_t *p = prev_; const uint8_t *p = begin;
size_t remain = end - prev_; size_t remain = size;
while (remain > 0) { while (remain > 0) {
size_t n = remain >= 16 ? 16 : remain; size_t n = remain >= 16 ? 16 : remain;
fprintf(fp_, "db "); fprintf(fp_, "db ");
@ -72,7 +60,6 @@ struct Profiler {
fprintf(fp_, "\n"); fprintf(fp_, "\n");
remain -= n; remain -= n;
} }
prev_ = end;
} }
void dumpData(const void *begin, const void *end) void dumpData(const void *begin, const void *end)
{ {
@ -87,8 +74,19 @@ struct Profiler {
fprintf(fp_, "\n"); fprintf(fp_, "\n");
} }
}; };
template<class T>
void setFuncInfo(DumpCode& prof, const char *suf, const char *name, const T& begin, const uint8_t* end)
{
const uint8_t*p = (const uint8_t*)begin;
prof.set(std::string("mclx_") + suf + name, p, end - p);
}
#else #else
typedef Xbyak::util::Profiler Profiler; template<class T>
void setFuncInfo(Xbyak::util::Profiler& prof, const char *suf, const char *name, const T& begin, const uint8_t* end)
{
const uint8_t*p = (const uint8_t*)begin;
prof.set((std::string("mclx_") + suf + name).c_str(), p, end - p);
}
#endif #endif
namespace fp_gen_local { namespace fp_gen_local {
@ -269,7 +267,11 @@ struct FpGenerator : Xbyak::CodeGenerator {
int pn_; int pn_;
int FpByte_; int FpByte_;
bool isFullBit_; bool isFullBit_;
Profiler prof_; #ifdef MCL_DUMP_JIT
DumpCode prof_;
#else
Xbyak::util::Profiler prof_;
#endif
/* /*
@param op [in] ; use op.p, op.N, op.isFullBit @param op [in] ; use op.p, op.N, op.isFullBit
@ -331,8 +333,6 @@ private:
} }
#ifdef MCL_DUMP_JIT #ifdef MCL_DUMP_JIT
prof_.dumpData(p_, getCurr()); prof_.dumpData(p_, getCurr());
prof_.setStartAddr(getCurr());
prof_.setNameSuffix(std::string("mclx_") + suf);
#endif #endif
rp_ = fp::getMontgomeryCoeff(p_[0]); rp_ = fp::getMontgomeryCoeff(p_[0]);
pn_ = (int)op.N; pn_ = (int)op.N;
@ -351,97 +351,130 @@ private:
if (profMode) { if (profMode) {
prof_.init(profMode); prof_.init(profMode);
prof_.setStartAddr(getCurr()); prof_.setStartAddr(getCurr());
if (suf == 0) suf = "fp";
prof_.setNameSuffix(suf);
suf[1]++;
} }
#else #else
(void)suf; (void)suf;
#endif #endif
align(16);
op.fp_addPre = gen_addSubPre(true, pn_); op.fp_addPre = gen_addSubPre(true, pn_);
prof_.set("_addPre", getCurr()); setFuncInfo(prof_, suf, "_addPre", op.fp_addPre, getCurr());
align(16);
op.fp_subPre = gen_addSubPre(false, pn_); op.fp_subPre = gen_addSubPre(false, pn_);
prof_.set("_subPre", getCurr()); setFuncInfo(prof_, suf, "_subPre", op.fp_subPre, getCurr());
align(16);
op.fp_addA_ = gen_fp_add(); op.fp_addA_ = gen_fp_add();
prof_.set("_add", getCurr()); setFuncInfo(prof_, suf, "_add", op.fp_addA_, getCurr());
op.fp_subA_ = gen_fp_sub(); op.fp_subA_ = gen_fp_sub();
prof_.set("_sub", getCurr()); setFuncInfo(prof_, suf, "_sub", op.fp_subA_, getCurr());
align(16);
op.fp_shr1 = gen_shr1(); op.fp_shr1 = gen_shr1();
prof_.set("_shr1", getCurr()); setFuncInfo(prof_, suf, "_shr1", op.fp_shr1, getCurr());
align(16);
op.fp_negA_ = gen_fp_neg(); op.fp_negA_ = gen_fp_neg();
prof_.set("_neg", getCurr()); setFuncInfo(prof_, suf, "_neg", op.fp_negA_, getCurr());
align(16);
op.fpDbl_addA_ = gen_fpDbl_add(); op.fpDbl_addA_ = gen_fpDbl_add();
prof_.set("Dbl_add", getCurr()); setFuncInfo(prof_, suf, "Dbl_add", op.fpDbl_addA_, getCurr());
align(16);
op.fpDbl_subA_ = gen_fpDbl_sub(); op.fpDbl_subA_ = gen_fpDbl_sub();
prof_.set("Dbl_sub", getCurr()); setFuncInfo(prof_, suf, "Dbl_sub", op.fpDbl_subA_, getCurr());
align(16);
op.fpDbl_addPre = gen_addSubPre(true, pn_ * 2); op.fpDbl_addPre = gen_addSubPre(true, pn_ * 2);
prof_.set("Dbl_addPre", getCurr()); setFuncInfo(prof_, suf, "Dbl_addPre", op.fpDbl_addPre, getCurr());
align(16);
op.fpDbl_subPre = gen_addSubPre(false, pn_ * 2); op.fpDbl_subPre = gen_addSubPre(false, pn_ * 2);
prof_.set("Dbl_subPre", getCurr()); setFuncInfo(prof_, suf, "Dbl_subPre", op.fpDbl_subPre, getCurr());
align(16);
op.fpDbl_mulPreA_ = gen_fpDbl_mulPre(); op.fpDbl_mulPreA_ = gen_fpDbl_mulPre();
prof_.set("Dbl_mulPre", getCurr()); setFuncInfo(prof_, suf, "Dbl_mulPre", op.fpDbl_mulPreA_, getCurr());
align(16);
op.fpDbl_sqrPreA_ = gen_fpDbl_sqrPre(); op.fpDbl_sqrPreA_ = gen_fpDbl_sqrPre();
prof_.set("Dbl_sqrPre", getCurr()); setFuncInfo(prof_, suf, "Dbl_sqrPre", op.fpDbl_sqrPreA_, getCurr());
align(16);
op.fpDbl_modA_ = gen_fpDbl_mod(op); op.fpDbl_modA_ = gen_fpDbl_mod(op);
prof_.set("Dbl_mod", getCurr()); setFuncInfo(prof_, suf, "Dbl_mod", op.fpDbl_modA_, getCurr());
align(16);
op.fp_mulA_ = gen_mul(); op.fp_mulA_ = gen_mul();
prof_.set("_mul", getCurr()); setFuncInfo(prof_, suf, "_mul", op.fp_mulA_, getCurr());
align(16);
if (op.fp_mulA_) { if (op.fp_mulA_) {
op.fp_mul = fp::func_ptr_cast<void4u>(op.fp_mulA_); // used in toMont/fromMont op.fp_mul = fp::func_ptr_cast<void4u>(op.fp_mulA_); // used in toMont/fromMont
} }
op.fp_sqrA_ = gen_sqr(); op.fp_sqrA_ = gen_sqr();
prof_.set("_sqr", getCurr()); setFuncInfo(prof_, suf, "_sqr", op.fp_sqrA_, getCurr());
align(16);
if (op.primeMode != PM_NIST_P192 && op.N <= 4) { // support general op.N but not fast for op.N > 4 if (op.primeMode != PM_NIST_P192 && op.N <= 4) { // support general op.N but not fast for op.N > 4
align(16);
op.fp_preInv = getCurr<int2u>(); op.fp_preInv = getCurr<int2u>();
gen_preInv(); gen_preInv();
prof_.set("_preInv", getCurr()); setFuncInfo(prof_, suf, "_preInv", op.fp_preInv, getCurr());
align(16);
} }
if (op.xi_a == 0) return; // Fp2 is not used if (op.xi_a == 0) return; // Fp2 is not used
op.fp2_addA_ = gen_fp2_add(); op.fp2_addA_ = gen_fp2_add();
prof_.set("2_add", getCurr()); setFuncInfo(prof_, suf, "2_add", op.fp2_addA_, getCurr());
align(16);
op.fp2_subA_ = gen_fp2_sub(); op.fp2_subA_ = gen_fp2_sub();
prof_.set("2_sub", getCurr()); setFuncInfo(prof_, suf, "2_sub", op.fp2_subA_, getCurr());
align(16);
op.fp2_negA_ = gen_fp2_neg(); op.fp2_negA_ = gen_fp2_neg();
prof_.set("2_neg", getCurr()); setFuncInfo(prof_, suf, "2_neg", op.fp2_negA_, getCurr());
align(16);
op.fp2_mulNF = 0; op.fp2_mulNF = 0;
op.fp2Dbl_mulPreA_ = gen_fp2Dbl_mulPre(); op.fp2Dbl_mulPreA_ = gen_fp2Dbl_mulPre();
prof_.set("2Dbl_mulPre", getCurr()); if (op.fp2Dbl_mulPreA_) setFuncInfo(prof_, suf, "2Dbl_mulPre", op.fp2Dbl_mulPreA_, getCurr());
align(16);
op.fp2Dbl_sqrPreA_ = gen_fp2Dbl_sqrPre(); op.fp2Dbl_sqrPreA_ = gen_fp2Dbl_sqrPre();
prof_.set("2Dbl_sqrPre", getCurr()); if (op.fp2Dbl_sqrPreA_) setFuncInfo(prof_, suf, "2Dbl_sqrPre", op.fp2Dbl_sqrPreA_, getCurr());
align(16);
op.fp2_mulA_ = gen_fp2_mul(); op.fp2_mulA_ = gen_fp2_mul();
prof_.set("2_mul", getCurr()); setFuncInfo(prof_, suf, "2_mul", op.fp2_mulA_, getCurr());
align(16);
op.fp2_sqrA_ = gen_fp2_sqr(); op.fp2_sqrA_ = gen_fp2_sqr();
prof_.set("2_sqr", getCurr()); setFuncInfo(prof_, suf, "2_sqr", op.fp2_sqrA_, getCurr());
align(16);
op.fp2_mul_xiA_ = gen_fp2_mul_xi(); op.fp2_mul_xiA_ = gen_fp2_mul_xi();
prof_.set("2_mul_xi", getCurr()); setFuncInfo(prof_, suf, "2_mul_xi", op.fp2_mul_xiA_, getCurr());
align(16);
#ifdef MCL_STATIC_JIT
const bool isFp = strcmp(suf, "Fp") == 0;
printf("isFp=%d\n", isFp);
if (isFp) {
op.fp_addPre = mclx_Fp_addPre;
op.fp_addA_ = mclx_Fr_add;
} else {
op.fp_addPre = mclx_Fr_addPre;
op.fp_addA_ = mclx_Fr_add;
}
#endif
} }
u3u gen_addSubPre(bool isAdd, int n) u3u gen_addSubPre(bool isAdd, int n)
{ {
// if (isFullBit_) return 0; // if (isFullBit_) return 0;
align(16);
u3u func = getCurr<u3u>(); u3u func = getCurr<u3u>();
StackFrame sf(this, 3); StackFrame sf(this, 3);
if (isAdd) { if (isAdd) {
@ -721,7 +754,6 @@ private:
} }
void3u gen_fp_add() void3u gen_fp_add()
{ {
align(16);
void3u func = getCurr<void3u>(); void3u func = getCurr<void3u>();
if (pn_ <= 4) { if (pn_ <= 4) {
gen_fp_add_le4(); gen_fp_add_le4();
@ -769,7 +801,6 @@ private:
} }
void3u gen_fpDbl_add() void3u gen_fpDbl_add()
{ {
align(16);
void3u func = getCurr<void3u>(); void3u func = getCurr<void3u>();
if (pn_ <= 4) { if (pn_ <= 4) {
int tn = pn_ * 2 + (isFullBit_ ? 1 : 0); int tn = pn_ * 2 + (isFullBit_ ? 1 : 0);
@ -797,7 +828,6 @@ private:
} }
void3u gen_fpDbl_sub() void3u gen_fpDbl_sub()
{ {
align(16);
void3u func = getCurr<void3u>(); void3u func = getCurr<void3u>();
if (pn_ <= 4) { if (pn_ <= 4) {
int tn = pn_ * 2; int tn = pn_ * 2;
@ -847,7 +877,6 @@ private:
} }
void3u gen_fp_sub() void3u gen_fp_sub()
{ {
align(16);
void3u func = getCurr<void3u>(); void3u func = getCurr<void3u>();
if (pn_ <= 4) { if (pn_ <= 4) {
gen_fp_sub_le4(); gen_fp_sub_le4();
@ -872,7 +901,6 @@ private:
} }
void2u gen_fp_neg() void2u gen_fp_neg()
{ {
align(16);
void2u func = getCurr<void2u>(); void2u func = getCurr<void2u>();
StackFrame sf(this, 2, UseRDX | pn_); StackFrame sf(this, 2, UseRDX | pn_);
gen_raw_neg(sf.p[0], sf.p[1], sf.t); gen_raw_neg(sf.p[0], sf.p[1], sf.t);
@ -880,7 +908,6 @@ private:
} }
void2u gen_shr1() void2u gen_shr1()
{ {
align(16);
void2u func = getCurr<void2u>(); void2u func = getCurr<void2u>();
const int c = 1; const int c = 1;
StackFrame sf(this, 2, 1); StackFrame sf(this, 2, 1);
@ -901,7 +928,6 @@ private:
} }
void3u gen_mul() void3u gen_mul()
{ {
align(16);
void3u func = getCurr<void3u>(); void3u func = getCurr<void3u>();
if (op_->primeMode == PM_NIST_P192) { if (op_->primeMode == PM_NIST_P192) {
StackFrame sf(this, 3, 10 | UseRDX, 8 * 6); StackFrame sf(this, 3, 10 | UseRDX, 8 * 6);
@ -1214,7 +1240,6 @@ private:
} }
void2u gen_fpDbl_mod(const fp::Op& op) void2u gen_fpDbl_mod(const fp::Op& op)
{ {
align(16);
void2u func = getCurr<void2u>(); void2u func = getCurr<void2u>();
if (op.primeMode == PM_NIST_P192) { if (op.primeMode == PM_NIST_P192) {
StackFrame sf(this, 2, 6 | UseRDX); StackFrame sf(this, 2, 6 | UseRDX);
@ -1260,7 +1285,6 @@ private:
} }
void2u gen_sqr() void2u gen_sqr()
{ {
align(16);
void2u func = getCurr<void2u>(); void2u func = getCurr<void2u>();
if (op_->primeMode == PM_NIST_P192) { if (op_->primeMode == PM_NIST_P192) {
StackFrame sf(this, 3, 10 | UseRDX, 6 * 8); StackFrame sf(this, 3, 10 | UseRDX, 6 * 8);
@ -2364,7 +2388,6 @@ private:
} }
void2u gen_fpDbl_sqrPre() void2u gen_fpDbl_sqrPre()
{ {
align(16);
void2u func = getCurr<void2u>(); void2u func = getCurr<void2u>();
if (pn_ == 2 && useMulx_) { if (pn_ == 2 && useMulx_) {
StackFrame sf(this, 2, 7 | UseRDX); StackFrame sf(this, 2, 7 | UseRDX);
@ -2405,7 +2428,6 @@ private:
} }
void3u gen_fpDbl_mulPre() void3u gen_fpDbl_mulPre()
{ {
align(16);
void3u func = getCurr<void3u>(); void3u func = getCurr<void3u>();
if (pn_ == 2 && useMulx_) { if (pn_ == 2 && useMulx_) {
StackFrame sf(this, 3, 5 | UseRDX); StackFrame sf(this, 3, 5 | UseRDX);
@ -3446,7 +3468,6 @@ private:
// if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0; // if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0;
// almost same for pn_ == 6 // almost same for pn_ == 6
if (pn_ != 4) return 0; if (pn_ != 4) return 0;
align(16);
void3u func = getCurr<void3u>(); void3u func = getCurr<void3u>();
const RegExp z = rsp + 0 * 8; const RegExp z = rsp + 0 * 8;
@ -3511,7 +3532,6 @@ private:
// if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0; // if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0;
// almost same for pn_ == 6 // almost same for pn_ == 6
if (pn_ != 4) return 0; if (pn_ != 4) return 0;
align(16);
void2u func = getCurr<void2u>(); void2u func = getCurr<void2u>();
// almost same for pn_ == 6 // almost same for pn_ == 6
if (pn_ != 4) return 0; if (pn_ != 4) return 0;
@ -3597,7 +3617,6 @@ private:
} }
void3u gen_fp2_add() void3u gen_fp2_add()
{ {
align(16);
void3u func = getCurr<void3u>(); void3u func = getCurr<void3u>();
if (pn_ == 4 && !isFullBit_) { if (pn_ == 4 && !isFullBit_) {
gen_fp2_add4(); gen_fp2_add4();
@ -3611,7 +3630,6 @@ private:
} }
void3u gen_fp2_sub() void3u gen_fp2_sub()
{ {
align(16);
void3u func = getCurr<void3u>(); void3u func = getCurr<void3u>();
if (pn_ == 4 && !isFullBit_) { if (pn_ == 4 && !isFullBit_) {
gen_fp2_sub4(); gen_fp2_sub4();
@ -3697,7 +3715,6 @@ private:
{ {
if (isFullBit_) return 0; if (isFullBit_) return 0;
if (op_->xi_a != 1) return 0; if (op_->xi_a != 1) return 0;
align(16);
void2u func = getCurr<void2u>(); void2u func = getCurr<void2u>();
if (pn_ == 4) { if (pn_ == 4) {
gen_fp2_mul_xi4(); gen_fp2_mul_xi4();
@ -3711,7 +3728,6 @@ private:
} }
void2u gen_fp2_neg() void2u gen_fp2_neg()
{ {
align(16);
void2u func = getCurr<void2u>(); void2u func = getCurr<void2u>();
if (pn_ <= 6) { if (pn_ <= 6) {
StackFrame sf(this, 2, UseRDX | pn_); StackFrame sf(this, 2, UseRDX | pn_);
@ -3725,7 +3741,6 @@ private:
{ {
if (isFullBit_) return 0; if (isFullBit_) return 0;
if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0; if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0;
align(16);
void3u func = getCurr<void3u>(); void3u func = getCurr<void3u>();
bool embedded = pn_ == 4; bool embedded = pn_ == 4;
@ -3802,7 +3817,6 @@ private:
{ {
if (isFullBit_) return 0; if (isFullBit_) return 0;
if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0; if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0;
align(16);
void2u func = getCurr<void2u>(); void2u func = getCurr<void2u>();
const RegExp y = rsp + 0 * 8; const RegExp y = rsp + 0 * 8;

@ -5,7 +5,11 @@ using namespace mcl::bn;
int main() int main()
{ {
initPairing(mcl::BLS12_381); initPairing(mcl::BLS12_381);
Fr x; Fp x, y, z;
x = 3; x = 3;
printf("%s\n", x.getStr(16).c_str()); y = 5;
z = x + y;
printf("x=%s\n", x.getStr(16).c_str());
printf("y=%s\n", y.getStr(16).c_str());
printf("z=%s\n", z.getStr(16).c_str());
} }

Loading…
Cancel
Save