refactor DumpCode

update-fork
MITSUNARI Shigeo 4 years ago
parent c29157cc9a
commit f11b3be1ab
  1. 182
      src/fp_generator.hpp
  2. 8
      test/static_code_test.cpp

@ -23,46 +23,34 @@
#pragma warning(disable : 4458)
#endif
//#define MCL_DUMP_JIT
namespace mcl {
#ifdef MCL_STATIC_JIT
// Prototypes of add/addPre routines for Fr and Fp, visible only when MCL_STATIC_JIT is defined.
// NOTE(review): the definitions are not visible here; presumably they are linked in from elsewhere.
typedef fp::Unit Unit;
// C linkage, so the symbols are exactly the mclx_* names written below
extern "C" {
// the *_addPre variants return a Unit, the *_add variants return nothing
Unit mclx_Fr_addPre(Unit*, const Unit*, const Unit*);
void mclx_Fr_add(Unit*, const Unit*, const Unit*);
Unit mclx_Fp_addPre(Unit*, const Unit*, const Unit*);
void mclx_Fp_add(Unit*, const Unit*, const Unit*);
}
#endif
#ifdef MCL_DUMP_JIT
// not profiler, but dump jit code
struct Profiler {
struct DumpCode {
FILE *fp_;
const uint8_t *prev_;
std::string suf_;
Profiler()
DumpCode()
: fp_(stdout)
, prev_(0)
{
}
~Profiler()
void set(const std::string& name, const uint8_t *begin, const size_t size)
{
// if (fp_) fclose(fp_);
}
#if 0
void open(const std::string& fileName)
{
fp_ = fopen(fileName.c_str(), "wb");
}
#endif
void setStartAddr(const uint8_t *addr)
{
prev_ = addr;
}
void setNameSuffix(const std::string& suf)
{
suf_ = suf;
}
void set(const char *name, const uint8_t *end)
{
fprintf(fp_, "global %s%s\n", suf_.c_str(), name);
fprintf(fp_, "segment .text\n");
fprintf(fp_, "global %s\n", name.c_str());
fprintf(fp_, "align 16\n");
fprintf(fp_, "%s%s:\n", suf_.c_str(), name);
const uint8_t *p = prev_;
size_t remain = end - prev_;
fprintf(fp_, "%s:\n", name.c_str());
const uint8_t *p = begin;
size_t remain = size;
while (remain > 0) {
size_t n = remain >= 16 ? 16 : remain;
fprintf(fp_, "db ");
@ -72,7 +60,6 @@ struct Profiler {
fprintf(fp_, "\n");
remain -= n;
}
prev_ = end;
}
void dumpData(const void *begin, const void *end)
{
@ -87,8 +74,19 @@ struct Profiler {
fprintf(fp_, "\n");
}
};
/*
	Pass one generated function to the code dumper.
	@param prof [in,out] dumper that receives the symbol name and the code bytes
	@param suf [in] text placed between "mclx_" and name ; 0 is treated as ""
	@param name [in] function part of the symbol such as "_add"
	@param begin [in] entry address of the function ; 0 means no code was generated
	@param end [in] address just after the last byte of the function
	@note the symbol given to prof is "mclx_" + suf + name
*/
template<class T>
void setFuncInfo(DumpCode& prof, const char *suf, const char *name, const T& begin, const uint8_t* end)
{
	const uint8_t *p = (const uint8_t*)begin;
	// a generator may return 0 for unsupported parameters ; then there is nothing to dump
	// and (end - p) would not be a valid size
	if (p == 0) return;
	// std::string + (const char*)0 is undefined, so use an empty suffix instead
	if (suf == 0) suf = "";
	prof.set(std::string("mclx_") + suf + name, p, end - p);
}
#else
typedef Xbyak::util::Profiler Profiler;
/*
	Pass one generated function to Xbyak's profiler.
	@param prof [in,out] profiler that receives the symbol name, start address and size
	@param suf [in] text placed between "mclx_" and name ; 0 is treated as ""
	@param name [in] function part of the symbol such as "_add"
	@param begin [in] entry address of the function ; 0 means no code was generated
	@param end [in] address just after the last byte of the function
	@note the symbol given to prof is "mclx_" + suf + name
*/
template<class T>
void setFuncInfo(Xbyak::util::Profiler& prof, const char *suf, const char *name, const T& begin, const uint8_t* end)
{
	const uint8_t *p = (const uint8_t*)begin;
	// a generator may return 0 for unsupported parameters ; then there is nothing to register
	// and (end - p) would not be a valid size
	if (p == 0) return;
	// std::string + (const char*)0 is undefined, so use an empty suffix instead
	if (suf == 0) suf = "";
	const std::string symbol = std::string("mclx_") + suf + name;
	prof.set(symbol.c_str(), p, end - p);
}
#endif
namespace fp_gen_local {
@ -269,7 +267,11 @@ struct FpGenerator : Xbyak::CodeGenerator {
int pn_;
int FpByte_;
bool isFullBit_;
Profiler prof_;
#ifdef MCL_DUMP_JIT
DumpCode prof_;
#else
Xbyak::util::Profiler prof_;
#endif
/*
@param op [in] ; use op.p, op.N, op.isFullBit
@ -331,8 +333,6 @@ private:
}
#ifdef MCL_DUMP_JIT
prof_.dumpData(p_, getCurr());
prof_.setStartAddr(getCurr());
prof_.setNameSuffix(std::string("mclx_") + suf);
#endif
rp_ = fp::getMontgomeryCoeff(p_[0]);
pn_ = (int)op.N;
@ -351,97 +351,130 @@ private:
if (profMode) {
prof_.init(profMode);
prof_.setStartAddr(getCurr());
if (suf == 0) suf = "fp";
prof_.setNameSuffix(suf);
suf[1]++;
}
#else
(void)suf;
#endif
align(16);
op.fp_addPre = gen_addSubPre(true, pn_);
prof_.set("_addPre", getCurr());
setFuncInfo(prof_, suf, "_addPre", op.fp_addPre, getCurr());
align(16);
op.fp_subPre = gen_addSubPre(false, pn_);
prof_.set("_subPre", getCurr());
setFuncInfo(prof_, suf, "_subPre", op.fp_subPre, getCurr());
align(16);
op.fp_addA_ = gen_fp_add();
prof_.set("_add", getCurr());
setFuncInfo(prof_, suf, "_add", op.fp_addA_, getCurr());
op.fp_subA_ = gen_fp_sub();
prof_.set("_sub", getCurr());
setFuncInfo(prof_, suf, "_sub", op.fp_subA_, getCurr());
align(16);
op.fp_shr1 = gen_shr1();
prof_.set("_shr1", getCurr());
setFuncInfo(prof_, suf, "_shr1", op.fp_shr1, getCurr());
align(16);
op.fp_negA_ = gen_fp_neg();
prof_.set("_neg", getCurr());
setFuncInfo(prof_, suf, "_neg", op.fp_negA_, getCurr());
align(16);
op.fpDbl_addA_ = gen_fpDbl_add();
prof_.set("Dbl_add", getCurr());
setFuncInfo(prof_, suf, "Dbl_add", op.fpDbl_addA_, getCurr());
align(16);
op.fpDbl_subA_ = gen_fpDbl_sub();
prof_.set("Dbl_sub", getCurr());
setFuncInfo(prof_, suf, "Dbl_sub", op.fpDbl_subA_, getCurr());
align(16);
op.fpDbl_addPre = gen_addSubPre(true, pn_ * 2);
prof_.set("Dbl_addPre", getCurr());
setFuncInfo(prof_, suf, "Dbl_addPre", op.fpDbl_addPre, getCurr());
align(16);
op.fpDbl_subPre = gen_addSubPre(false, pn_ * 2);
prof_.set("Dbl_subPre", getCurr());
setFuncInfo(prof_, suf, "Dbl_subPre", op.fpDbl_subPre, getCurr());
align(16);
op.fpDbl_mulPreA_ = gen_fpDbl_mulPre();
prof_.set("Dbl_mulPre", getCurr());
setFuncInfo(prof_, suf, "Dbl_mulPre", op.fpDbl_mulPreA_, getCurr());
align(16);
op.fpDbl_sqrPreA_ = gen_fpDbl_sqrPre();
prof_.set("Dbl_sqrPre", getCurr());
setFuncInfo(prof_, suf, "Dbl_sqrPre", op.fpDbl_sqrPreA_, getCurr());
align(16);
op.fpDbl_modA_ = gen_fpDbl_mod(op);
prof_.set("Dbl_mod", getCurr());
setFuncInfo(prof_, suf, "Dbl_mod", op.fpDbl_modA_, getCurr());
align(16);
op.fp_mulA_ = gen_mul();
prof_.set("_mul", getCurr());
setFuncInfo(prof_, suf, "_mul", op.fp_mulA_, getCurr());
align(16);
if (op.fp_mulA_) {
op.fp_mul = fp::func_ptr_cast<void4u>(op.fp_mulA_); // used in toMont/fromMont
}
op.fp_sqrA_ = gen_sqr();
prof_.set("_sqr", getCurr());
setFuncInfo(prof_, suf, "_sqr", op.fp_sqrA_, getCurr());
align(16);
if (op.primeMode != PM_NIST_P192 && op.N <= 4) { // support general op.N but not fast for op.N > 4
align(16);
op.fp_preInv = getCurr<int2u>();
gen_preInv();
prof_.set("_preInv", getCurr());
setFuncInfo(prof_, suf, "_preInv", op.fp_preInv, getCurr());
align(16);
}
if (op.xi_a == 0) return; // Fp2 is not used
op.fp2_addA_ = gen_fp2_add();
prof_.set("2_add", getCurr());
setFuncInfo(prof_, suf, "2_add", op.fp2_addA_, getCurr());
align(16);
op.fp2_subA_ = gen_fp2_sub();
prof_.set("2_sub", getCurr());
setFuncInfo(prof_, suf, "2_sub", op.fp2_subA_, getCurr());
align(16);
op.fp2_negA_ = gen_fp2_neg();
prof_.set("2_neg", getCurr());
setFuncInfo(prof_, suf, "2_neg", op.fp2_negA_, getCurr());
align(16);
op.fp2_mulNF = 0;
op.fp2Dbl_mulPreA_ = gen_fp2Dbl_mulPre();
prof_.set("2Dbl_mulPre", getCurr());
if (op.fp2Dbl_mulPreA_) setFuncInfo(prof_, suf, "2Dbl_mulPre", op.fp2Dbl_mulPreA_, getCurr());
align(16);
op.fp2Dbl_sqrPreA_ = gen_fp2Dbl_sqrPre();
prof_.set("2Dbl_sqrPre", getCurr());
if (op.fp2Dbl_sqrPreA_) setFuncInfo(prof_, suf, "2Dbl_sqrPre", op.fp2Dbl_sqrPreA_, getCurr());
align(16);
op.fp2_mulA_ = gen_fp2_mul();
prof_.set("2_mul", getCurr());
setFuncInfo(prof_, suf, "2_mul", op.fp2_mulA_, getCurr());
align(16);
op.fp2_sqrA_ = gen_fp2_sqr();
prof_.set("2_sqr", getCurr());
setFuncInfo(prof_, suf, "2_sqr", op.fp2_sqrA_, getCurr());
align(16);
op.fp2_mul_xiA_ = gen_fp2_mul_xi();
prof_.set("2_mul_xi", getCurr());
setFuncInfo(prof_, suf, "2_mul_xi", op.fp2_mul_xiA_, getCurr());
align(16);
#ifdef MCL_STATIC_JIT
	// Overwrite the entry points set above with the mclx_* routines declared
	// under MCL_STATIC_JIT, choosing the Fp or Fr set by the suffix.
	// A null suf is treated as "not Fp" so that strcmp never receives 0.
	const bool isFp = suf != 0 && strcmp(suf, "Fp") == 0;
	printf("isFp=%d\n", isFp);
	if (isFp) {
		op.fp_addPre = mclx_Fp_addPre;
		// Fp must pair with the Fp routine ; using mclx_Fr_add here would mix the two fields
		op.fp_addA_ = mclx_Fp_add;
	} else {
		op.fp_addPre = mclx_Fr_addPre;
		op.fp_addA_ = mclx_Fr_add;
	}
#endif
}
u3u gen_addSubPre(bool isAdd, int n)
{
// if (isFullBit_) return 0;
align(16);
u3u func = getCurr<u3u>();
StackFrame sf(this, 3);
if (isAdd) {
@ -721,7 +754,6 @@ private:
}
void3u gen_fp_add()
{
align(16);
void3u func = getCurr<void3u>();
if (pn_ <= 4) {
gen_fp_add_le4();
@ -769,7 +801,6 @@ private:
}
void3u gen_fpDbl_add()
{
align(16);
void3u func = getCurr<void3u>();
if (pn_ <= 4) {
int tn = pn_ * 2 + (isFullBit_ ? 1 : 0);
@ -797,7 +828,6 @@ private:
}
void3u gen_fpDbl_sub()
{
align(16);
void3u func = getCurr<void3u>();
if (pn_ <= 4) {
int tn = pn_ * 2;
@ -847,7 +877,6 @@ private:
}
void3u gen_fp_sub()
{
align(16);
void3u func = getCurr<void3u>();
if (pn_ <= 4) {
gen_fp_sub_le4();
@ -872,7 +901,6 @@ private:
}
void2u gen_fp_neg()
{
align(16);
void2u func = getCurr<void2u>();
StackFrame sf(this, 2, UseRDX | pn_);
gen_raw_neg(sf.p[0], sf.p[1], sf.t);
@ -880,7 +908,6 @@ private:
}
void2u gen_shr1()
{
align(16);
void2u func = getCurr<void2u>();
const int c = 1;
StackFrame sf(this, 2, 1);
@ -901,7 +928,6 @@ private:
}
void3u gen_mul()
{
align(16);
void3u func = getCurr<void3u>();
if (op_->primeMode == PM_NIST_P192) {
StackFrame sf(this, 3, 10 | UseRDX, 8 * 6);
@ -1214,7 +1240,6 @@ private:
}
void2u gen_fpDbl_mod(const fp::Op& op)
{
align(16);
void2u func = getCurr<void2u>();
if (op.primeMode == PM_NIST_P192) {
StackFrame sf(this, 2, 6 | UseRDX);
@ -1260,7 +1285,6 @@ private:
}
void2u gen_sqr()
{
align(16);
void2u func = getCurr<void2u>();
if (op_->primeMode == PM_NIST_P192) {
StackFrame sf(this, 3, 10 | UseRDX, 6 * 8);
@ -2364,7 +2388,6 @@ private:
}
void2u gen_fpDbl_sqrPre()
{
align(16);
void2u func = getCurr<void2u>();
if (pn_ == 2 && useMulx_) {
StackFrame sf(this, 2, 7 | UseRDX);
@ -2405,7 +2428,6 @@ private:
}
void3u gen_fpDbl_mulPre()
{
align(16);
void3u func = getCurr<void3u>();
if (pn_ == 2 && useMulx_) {
StackFrame sf(this, 3, 5 | UseRDX);
@ -3446,7 +3468,6 @@ private:
// if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0;
// almost same for pn_ == 6
if (pn_ != 4) return 0;
align(16);
void3u func = getCurr<void3u>();
const RegExp z = rsp + 0 * 8;
@ -3511,7 +3532,6 @@ private:
// if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0;
// almost same for pn_ == 6
if (pn_ != 4) return 0;
align(16);
void2u func = getCurr<void2u>();
// almost same for pn_ == 6
if (pn_ != 4) return 0;
@ -3597,7 +3617,6 @@ private:
}
void3u gen_fp2_add()
{
align(16);
void3u func = getCurr<void3u>();
if (pn_ == 4 && !isFullBit_) {
gen_fp2_add4();
@ -3611,7 +3630,6 @@ private:
}
void3u gen_fp2_sub()
{
align(16);
void3u func = getCurr<void3u>();
if (pn_ == 4 && !isFullBit_) {
gen_fp2_sub4();
@ -3697,7 +3715,6 @@ private:
{
if (isFullBit_) return 0;
if (op_->xi_a != 1) return 0;
align(16);
void2u func = getCurr<void2u>();
if (pn_ == 4) {
gen_fp2_mul_xi4();
@ -3711,7 +3728,6 @@ private:
}
void2u gen_fp2_neg()
{
align(16);
void2u func = getCurr<void2u>();
if (pn_ <= 6) {
StackFrame sf(this, 2, UseRDX | pn_);
@ -3725,7 +3741,6 @@ private:
{
if (isFullBit_) return 0;
if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0;
align(16);
void3u func = getCurr<void3u>();
bool embedded = pn_ == 4;
@ -3802,7 +3817,6 @@ private:
{
if (isFullBit_) return 0;
if (pn_ != 4 && !(pn_ == 6 && useMulx_ && useAdx_)) return 0;
align(16);
void2u func = getCurr<void2u>();
const RegExp y = rsp + 0 * 8;

@ -5,7 +5,11 @@ using namespace mcl::bn;
// Initializes the pairing with mcl::BLS12_381, assigns small integers to field
// elements, adds them and prints the values with getStr(16).
// NOTE(review): getStr(16) presumably formats in base 16 -- confirm against the mcl API.
int main()
{
initPairing(mcl::BLS12_381);
Fr x;
Fp x, y, z;
x = 3;
printf("%s\n", x.getStr(16).c_str());
y = 5;
// z is the sum computed through operator+ of the field element type
z = x + y;
printf("x=%s\n", x.getStr(16).c_str());
printf("y=%s\n", y.getStr(16).c_str());
printf("z=%s\n", z.getStr(16).c_str());
}

Loading…
Cancel
Save