diff --git a/include/mcl/fp.hpp b/include/mcl/fp.hpp
index 09669d1..3857b8c 100644
--- a/include/mcl/fp.hpp
+++ b/include/mcl/fp.hpp
@@ -463,9 +463,15 @@ private:
 	{
 		op_.fpDbl_subP(z, x, y, op_.p);
 	}
-	static inline void fp_modW(Unit *y, const Unit *x)
+	// z[N] <- xy[N * 2] % p[N]
+	static inline void fp_modW(Unit *z, const Unit *xy)
 	{
-		op_.fp_modP(y, x, op_.p);
+		op_.fp_modP(z, xy, op_.p);
+	}
+	// z[N] <- montRed(xy[N * 2])
+	static inline void fp_montRedW(Unit *z, const Unit *xy)
+	{
+		op_.montRedPU(z, xy, op_.p, op_.rp);
 	}
 	static inline void fp_mulW(Unit *z, const Unit *x, const Unit *y)
 	{
@@ -486,11 +492,23 @@ private:
 	// wrapper function for mcl_fp_mont by LLVM
 	static inline void fp_montW(Unit *z, const Unit *x, const Unit *y)
 	{
-		op_.mont(z, x, y, op_.p, op_.rp);
+#if 1
+		op_.montPU(z, x, y, op_.p, op_.rp);
+#else
+		Unit xy[maxSize * 2];
+		op_.fp_mulPre(xy, x, y);
+		fp_montRedW(z, xy);
+#endif
 	}
 	static inline void fp_montSqrW(Unit *y, const Unit *x)
 	{
-		op_.mont(y, x, x, op_.p, op_.rp);
+#if 1
+		op_.montPU(y, x, x, op_.p, op_.rp);
+#else
+		Unit xx[maxSize * 2];
+		op_.fp_sqrPre(xx, x);
+		fp_montRedW(y, xx);
+#endif
 	}
 };
 
diff --git a/include/mcl/fp_proto.hpp b/include/mcl/fp_proto.hpp
index d576242..5bcbe8d 100644
--- a/include/mcl/fp_proto.hpp
+++ b/include/mcl/fp_proto.hpp
@@ -21,7 +21,8 @@ void mcl_fp_addNC ## len(mcl::fp::Unit* z, const mcl::fp::Unit* x, const mcl::fp
 void mcl_fp_subNC ## len(mcl::fp::Unit* z, const mcl::fp::Unit* x, const mcl::fp::Unit* y); \
 void mcl_fp_mulPre ## len(mcl::fp::Unit* z, const mcl::fp::Unit* x, const mcl::fp::Unit* y); \
 void mcl_fp_sqrPre ## len(mcl::fp::Unit* y, const mcl::fp::Unit* x); \
-void mcl_fp_mont ## len(mcl::fp::Unit* z, const mcl::fp::Unit* x, const mcl::fp::Unit* y, const mcl::fp::Unit* p, mcl::fp::Unit r);
+void mcl_fp_mont ## len(mcl::fp::Unit* z, const mcl::fp::Unit* x, const mcl::fp::Unit* y, const mcl::fp::Unit* p, mcl::fp::Unit r); \
+void mcl_fp_montRed ## len(mcl::fp::Unit* z, const mcl::fp::Unit* xy, const mcl::fp::Unit* p, mcl::fp::Unit r);
 
 MCL_FP_DEF_FUNC(128)
 MCL_FP_DEF_FUNC(192)
@@ -52,6 +53,8 @@ void mcl_fpDbl_sub224(mcl::fp::Unit*, const mcl::fp::Unit*, const mcl::fp::Unit*
 MCL_FP_DEF_FUNC(576)
 #endif
 
+#undef MCL_FP_DEF_FUNC
+
 void mcl_fp_mul_NIST_P192(mcl::fp::Unit*, const mcl::fp::Unit*, const mcl::fp::Unit*);
 
 }
diff --git a/include/mcl/op.hpp b/include/mcl/op.hpp
index c093822..cd059f9 100644
--- a/include/mcl/op.hpp
+++ b/include/mcl/op.hpp
@@ -104,7 +104,10 @@ struct Op {
 	int2u fp_preInv;
 	// these two members are for mcl_fp_mont
 	Unit rp;
-	void (*mont)(Unit *z, const Unit *x, const Unit *y, const Unit *p, Unit rp);
+	// z = montRed(xy)
+	void (*montRedPU)(Unit *z, const Unit *xy, const Unit *p, Unit rp);
+	// z = mont(x, y) = montRed(fp_mulPre(x, y))
+	void (*montPU)(Unit *z, const Unit *x, const Unit *y, const Unit *p, Unit rp);
 
 	// require p
 	void3u fp_negP;
@@ -150,7 +153,7 @@ struct Op {
 		, fp_neg(0), fp_sqr(0), fp_add(0), fp_sub(0), fp_mul(0)
 		, isFullBit(true), fp_addNC(0), fp_subNC(0)
 		, isMont(false), fp_preInv(0)
-		, rp(0), mont(0)
+		, rp(0), montRedPU(0), montPU(0)
 		, fp_negP(0), fp_invOp(0), fp_addP(0), fp_subP(0), fp_modP(0)
 		, fg(createFpGenerator())
 		, fpDbl_add(0), fpDbl_sub()
diff --git a/src/fp.cpp b/src/fp.cpp
index 343ce4e..24bafff 100644
--- a/src/fp.cpp
+++ b/src/fp.cpp
@@ -262,7 +262,8 @@ struct OpeFunc {
 		if (n <= 256) { \
 			fp_sqrPre = mcl_fp_sqrPre ## n; \
 		} \
-		mont = mcl_fp_mont ## n; \
+		montPU = mcl_fp_mont ## n; \
+		montRedPU = mcl_fp_montRed ## n; \
 	}
 #define SET_OP_DBL_LLVM(n, n2) \
 	if (mode == FP_LLVM || mode == FP_LLVM_MONT) { \
diff --git a/src/mul.txt b/src/mul.txt
index 10f094d..d62c53e 100644
--- a/src/mul.txt
+++ b/src/mul.txt
@@ -102,3 +102,29 @@ define void @mcl_fp_mont$(bit)(i$(bit)* %pz, i$(bit)* %px, i$(unit)* %py, i$(bit
 	ret void
 }
 
+@define b2 = bit * 2
+@define b2u = b2 + unit
+define void @mcl_fp_montRed$(bit)(i$(bit)* %pz, i$(b2)* %pxy, i$(bit)* %pp, i$(unit) %r) {
+	%p = load i$(bit)* %pp
+	%xy = load i$(b2)* %pxy
+	%t0 = zext i$(b2) %xy to i$(b2+unit)
+
+@for i, 0, N
+	%z0$(i+1) = trunc i$(b2u - unit * i) %t$(i) to i$(unit)
+	%q$(i) = mul i$(unit) %z0$(i+1), %r
+	%pq$(i) = call i$(bu) @mul$(bit)x$(unit)(i$(bit) %p, i$(unit) %q$(i))
+	%pqe$(i) = zext i$(bu) %pq$(i) to i$(b2u - unit * i)
+	%z$(i+1) = add i$(b2u - unit * i) %t$(i), %pqe$(i)
+	%zt$(i+1) = lshr i$(b2u - unit * i) %z$(i+1), $(unit)
+	%t$(i+1) = trunc i$(b2u - unit * i) %zt$(i+1) to i$(b2 - unit * i)
+@endfor
+	%pe = zext i$(bit) %p to i$(bu)
+	%vc = sub i$(bu) %t$(N), %pe
+	%c = lshr i$(bu) %vc, $(bit)
+	%c1 = trunc i$(bu) %c to i1
+	%z = select i1 %c1, i$(bu) %t$(N), i$(bu) %vc
+	%zt = trunc i$(bu) %z to i$(bit)
+	store i$(bit) %zt, i$(bit)* %pz
+	ret void
+}
+
diff --git a/test/mont_fp_test.cpp b/test/mont_fp_test.cpp
index e7927f6..8cbb618 100644
--- a/test/mont_fp_test.cpp
+++ b/test/mont_fp_test.cpp
@@ -13,13 +13,13 @@ struct Montgomery {
 	mpz_class p_;
 	mpz_class R_; // (1 << (pn_ * 64)) % p
 	mpz_class RR_; // (R * R) % p
-	Unit pp_; // p * pp = -1 mod M = 1 << 64
+	Unit rp_; // rp * p = -1 mod M = 1 << 64
 	size_t pn_;
 	Montgomery() {}
 	explicit Montgomery(const mpz_class& p)
 	{
 		p_ = p;
-		pp_ = mcl::fp::getMontgomeryCoeff(mcl::Gmp::getUnit(p, 0));
+		rp_ = mcl::fp::getMontgomeryCoeff(mcl::Gmp::getUnit(p, 0));
 		pn_ = mcl::Gmp::getUnitSize(p);
 		R_ = 1;
 		R_ = (R_ << (pn_ * 64)) % p_;
@@ -34,14 +34,14 @@
 #if 0
 		const size_t ySize = mcl::Gmp::getUnitSize(y);
 		mpz_class c = x * mcl::Gmp::getUnit(y, 0);
-		Unit q = mcl::Gmp::getUnit(c, 0) * pp_;
+		Unit q = mcl::Gmp::getUnit(c, 0) * rp_;
 		c += p_ * q;
 		c >>= sizeof(Unit) * 8;
 		for (size_t i = 1; i < pn_; i++) {
 			if (i < ySize) {
 				c += x * mcl::Gmp::getUnit(y, i);
 			}
-			Unit q = mcl::Gmp::getUnit(c, 0) * pp_;
+			Unit q = mcl::Gmp::getUnit(c, 0) * rp_;
 			c += p_ * q;
 			c >>= sizeof(Unit) * 8;
 		}
@@ -52,7 +52,7 @@
 #else
 		z = x * y;
 		for (size_t i = 0; i < pn_; i++) {
-			Unit q = mcl::Gmp::getUnit(z, 0) * pp_;
+			Unit q = mcl::Gmp::getUnit(z, 0) * rp_;
 			z += p_ * (mp_limb_t)q;
 			z >>= sizeof(Unit) * 8;
 		}
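
Reviewer note (not part of the patch): below is a minimal C++ sketch of the word-serial Montgomery reduction that the new mcl_fp_montRed$(bit) template and the Op::montRedPU pointer expose, i.e. z[N] <- xy[N * 2] * R^-1 mod p with R = 2^(64 * N) and rp = -p^-1 mod 2^64, so that mont(x, y) == montRed(mulPre(x, y)). It assumes 64-bit limbs and the GCC/Clang unsigned __int128 extension; the name montRedRef and the loop structure are illustrative only and are not taken from mcl.

// Illustrative reference only (not mcl code): word-serial Montgomery reduction.
// z[N] = xy[2N] * R^-1 mod p, R = 2^(64*N), rp = -p^-1 mod 2^64, assuming xy < p * R.
#include <cstddef>
#include <cstdint>

typedef uint64_t Unit;
typedef unsigned __int128 U128; // GCC/Clang extension

template<size_t N>
void montRedRef(Unit z[N], const Unit xy[N * 2], const Unit p[N], Unit rp)
{
	Unit t[N * 2 + 1] = {}; // one spare limb for the final carry
	for (size_t i = 0; i < N * 2; i++) t[i] = xy[i];
	for (size_t i = 0; i < N; i++) {
		Unit q = t[i] * rp; // chosen so that t[i] + q * p[0] == 0 mod 2^64
		U128 carry = 0;
		for (size_t j = 0; j < N; j++) { // t += (q * p) << (64 * i)
			U128 v = (U128)q * p[j] + t[i + j] + carry;
			t[i + j] = (Unit)v;
			carry = v >> 64;
		}
		for (size_t j = i + N; carry != 0 && j < N * 2 + 1; j++) { // propagate carry
			U128 v = (U128)t[j] + carry;
			t[j] = (Unit)v;
			carry = v >> 64;
		}
	}
	// the value in t[N..2N] is < 2p; conditionally subtract p to finish
	Unit sub[N];
	Unit borrow = 0;
	for (size_t j = 0; j < N; j++) {
		U128 v = (U128)t[N + j] - p[j] - borrow;
		sub[j] = (Unit)v;
		borrow = (Unit)(v >> 64) & 1;
	}
	bool ge = (t[N * 2] != 0) || (borrow == 0); // is t[N..2N] >= p ?
	for (size_t j = 0; j < N; j++) z[j] = ge ? sub[j] : t[N + j];
}

A Montgomery multiplication built from this primitive is mulPre followed by the reduction, which is exactly the cross-check kept behind the #else branch of fp_montW/fp_montSqrW above (op_.fp_mulPre then fp_montRedW) for validating the fused mcl_fp_mont code path.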