diff --git a/src/fp_generator.hpp b/src/fp_generator.hpp index a7536e6..2feaf7b 100644 --- a/src/fp_generator.hpp +++ b/src/fp_generator.hpp @@ -28,11 +28,53 @@ namespace mcl { #ifdef MCL_STATIC_JIT typedef fp::Unit Unit; extern "C" { -Unit mclx_Fr_addPre(Unit*, const Unit*, const Unit*); -void mclx_Fr_add(Unit*, const Unit*, const Unit*); - Unit mclx_Fp_addPre(Unit*, const Unit*, const Unit*); +Unit mclx_Fp_subPre(Unit*, const Unit*, const Unit*); void mclx_Fp_add(Unit*, const Unit*, const Unit*); +void mclx_Fp_sub(Unit*, const Unit*, const Unit*); +void mclx_Fp_shr1(Unit*, const Unit*); +void mclx_Fp_neg(Unit*, const Unit*); +void mclx_FpDbl_add(Unit*, const Unit*, const Unit*); +void mclx_FpDbl_sub(Unit*, const Unit*, const Unit*); +void mclx_FpDbl_add(Unit*, const Unit*, const Unit*); +void mclx_FpDbl_sub(Unit*, const Unit*, const Unit*); +Unit mclx_FpDbl_addPre(Unit*, const Unit*, const Unit*); +Unit mclx_FpDbl_subPre(Unit*, const Unit*, const Unit*); +void mclx_FpDbl_mulPre(Unit*, const Unit*, const Unit*); +void mclx_FpDbl_sqrPre(Unit*, const Unit*); +void mclx_FpDbl_mod(Unit*, const Unit*); +void mclx_Fp_mul(Unit*, const Unit*, const Unit*); +void mclx_Fp_sqr(Unit*, const Unit*); +void mclx_Fp2_add(Unit*, const Unit*, const Unit*); +void mclx_Fp2_sub(Unit*, const Unit*, const Unit*); +void mclx_Fp2_neg(Unit*, const Unit*); +void mclx_Fp2_mul(Unit*, const Unit*, const Unit*); +void mclx_Fp2_sqr(Unit*, const Unit*); +void mclx_Fp2_mul_xi(Unit*, const Unit*); + +Unit mclx_Fr_addPre(Unit*, const Unit*, const Unit*); +Unit mclx_Fr_subPre(Unit*, const Unit*, const Unit*); +void mclx_Fr_add(Unit*, const Unit*, const Unit*); +void mclx_Fr_sub(Unit*, const Unit*, const Unit*); +void mclx_Fr_shr1(Unit*, const Unit*); +void mclx_Fr_neg(Unit*, const Unit*); +void mclx_FrDbl_add(Unit*, const Unit*, const Unit*); +void mclx_FrDbl_sub(Unit*, const Unit*, const Unit*); +void mclx_FrDbl_add(Unit*, const Unit*, const Unit*); +void mclx_FrDbl_sub(Unit*, const Unit*, const Unit*); +Unit mclx_FrDbl_addPre(Unit*, const Unit*, const Unit*); +Unit mclx_FrDbl_subPre(Unit*, const Unit*, const Unit*); +void mclx_FrDbl_mulPre(Unit*, const Unit*, const Unit*); +void mclx_FrDbl_sqrPre(Unit*, const Unit*); +void mclx_FrDbl_mod(Unit*, const Unit*); +void mclx_Fr_mul(Unit*, const Unit*, const Unit*); +void mclx_Fr_sqr(Unit*, const Unit*); +void mclx_Fr2_add(Unit*, const Unit*, const Unit*); +void mclx_Fr2_sub(Unit*, const Unit*, const Unit*); +void mclx_Fr2_neg(Unit*, const Unit*); +void mclx_Fr2_mul(Unit*, const Unit*, const Unit*); +void mclx_Fr2_sqr(Unit*, const Unit*); +void mclx_Fr2_mul_xi(Unit*, const Unit*); } #endif @@ -327,6 +369,7 @@ struct FpGenerator : Xbyak::CodeGenerator { private: void init_inner(Op& op, const char *suf) { + const bool isFp = suf && suf[0] == 'F' && suf[1] == 'p'; op_ = &op; L(pL_); p_ = reinterpret_cast(getCurr()); @@ -382,6 +425,7 @@ private: setFuncInfo(prof_, suf, "_neg", op.fp_negA_, getCurr()); align(16); +if (op.xi_a) { op.fpDbl_addA_ = gen_fpDbl_add(); setFuncInfo(prof_, suf, "Dbl_add", op.fpDbl_addA_, getCurr()); align(16); @@ -409,6 +453,7 @@ private: op.fpDbl_modA_ = gen_fpDbl_mod(op); setFuncInfo(prof_, suf, "Dbl_mod", op.fpDbl_modA_, getCurr()); align(16); +} op.fp_mulA_ = gen_mul(); setFuncInfo(prof_, suf, "_mul", op.fp_mulA_, getCurr()); @@ -463,14 +508,50 @@ private: align(16); #ifdef MCL_STATIC_JIT - const bool isFp = strcmp(suf, "Fp") == 0; -printf("isFp=%d\n", isFp); if (isFp) { + // Fp, sizeof(Fp) = 48 op.fp_addPre = mclx_Fp_addPre; - op.fp_addA_ = mclx_Fr_add; + op.fp_subPre = mclx_Fp_subPre; + op.fp_addA_ = mclx_Fp_add; + op.fp_subA_ = mclx_Fp_sub; + op.fp_shr1 = mclx_Fp_shr1; + op.fp_negA_ = mclx_Fp_neg; + op.fpDbl_addA_ = mclx_FpDbl_add; + op.fpDbl_subA_ = mclx_FpDbl_sub; + op.fpDbl_addPre = mclx_FpDbl_addPre; + op.fpDbl_subPre = mclx_FpDbl_subPre; + op.fpDbl_mulPreA_ = mclx_FpDbl_mulPre; + op.fpDbl_sqrPreA_ = mclx_FpDbl_sqrPre; + op.fpDbl_modA_ = mclx_FpDbl_mod; + op.fp_mulA_ = mclx_Fp_mul; + op.fp_sqrA_ = mclx_Fp_sqr; +#if 0 +// op.fp_preInv = mclx_Fp_preInv; + op.fp2_addA_ = mclx_Fp2_add; + op.fp2_subA_ = mclx_Fp2_sub; + op.fp2_negA_ = mclx_Fp2_neg; + op.fp2_mulA_ = mclx_Fp2_mul; + op.fp2_sqrA_ = mclx_Fp2_sqr; + op.fp2_mul_xiA_ = mclx_Fp2_mul_xi; +#endif } else { + // Fr, sizeof(Fr) = 32 op.fp_addPre = mclx_Fr_addPre; + op.fp_subPre = mclx_Fr_subPre; op.fp_addA_ = mclx_Fr_add; + op.fp_subA_ = mclx_Fr_sub; + op.fp_shr1 = mclx_Fr_shr1; + op.fp_negA_ = mclx_Fr_neg; + op.fpDbl_addA_ = mclx_FpDbl_add; + op.fpDbl_subA_ = mclx_FpDbl_sub; + op.fpDbl_addPre = mclx_FpDbl_addPre; + op.fpDbl_subPre = mclx_FpDbl_subPre; + op.fpDbl_mulPreA_ = mclx_FpDbl_mulPre; + op.fpDbl_sqrPreA_ = mclx_FpDbl_sqrPre; + op.fpDbl_modA_ = mclx_FpDbl_mod; + op.fp_mulA_ = mclx_Fr_mul; + op.fp_sqrA_ = mclx_Fr_sqr; + op.fp_preInv = mclx_Fr_preInv; } #endif } diff --git a/test/bench.hpp b/test/bench.hpp index c8c3911..b4a8bd2 100644 --- a/test/bench.hpp +++ b/test/bench.hpp @@ -100,6 +100,7 @@ void testBench(const G1& P, const G2& Q) CYBOZU_BENCH_C("Fp::mul ", C3, Fp::mul, x, x, y); CYBOZU_BENCH_C("Fp::sqr ", C3, Fp::sqr, x, x); CYBOZU_BENCH_C("Fp::inv ", C3, Fp::inv, x, x); + CYBOZU_BENCH_C("Fp::pow ", C3, Fp::pow, x, x, y); Fp2 xx, yy; xx.a = x; xx.b = 3; diff --git a/test/static_code_test.cpp b/test/static_code_test.cpp index 93dc223..e69fda7 100644 --- a/test/static_code_test.cpp +++ b/test/static_code_test.cpp @@ -2,14 +2,39 @@ using namespace mcl::bn; -int main() +void testFr() +{ + Fr x, y, z; + x = 3; + y = 5; + z = x + y; + printf("x=%s\n", x.getStr().c_str()); + printf("y=%s\n", y.getStr().c_str()); + printf("z=%s\n", z.getStr().c_str()); + z = x * y; + printf("z=%s\n", z.getStr().c_str()); + Fr::sqr(z, x); + printf("z=%s\n", z.getStr().c_str()); +} + +void testFp() { - initPairing(mcl::BLS12_381); Fp x, y, z; x = 3; y = 5; z = x + y; - printf("x=%s\n", x.getStr(16).c_str()); - printf("y=%s\n", y.getStr(16).c_str()); - printf("z=%s\n", z.getStr(16).c_str()); + printf("x=%s\n", x.getStr().c_str()); + printf("y=%s\n", y.getStr().c_str()); + printf("z=%s\n", z.getStr().c_str()); + z = x * y; + printf("z=%s\n", z.getStr().c_str()); + Fp::sqr(z, x); + printf("z=%s\n", z.getStr().c_str()); +} + +int main() +{ + initPairing(mcl::BLS12_381); + testFr(); + testFp(); }