diff --git a/include/mcl/fp_tower.hpp b/include/mcl/fp_tower.hpp index 13d922c..d7ab91b 100644 --- a/include/mcl/fp_tower.hpp +++ b/include/mcl/fp_tower.hpp @@ -656,6 +656,7 @@ private: template struct Fp2DblT { + typedef Fp2DblT Fp2Dbl; typedef FpDblT FpDbl; typedef Fp2T Fp2; typedef fp::Unit Unit; @@ -711,11 +712,13 @@ struct Fp2DblT { FpDbl::add(y.b, y.b, x.a); y.a = t; } - static void (*mulPre)(Fp2DblT&, const Fp2&, const Fp2&); + static void mulPre(Fp2DblT& z, const Fp2& x, const Fp2& y) + { + Fp::getOp().fp2Dbl_mulPreA_(z.a.v_, x.getUnit(), y.getUnit()); + } static void sqrPre(Fp2DblT& y, const Fp2& x) { - const mcl::fp::Op& op = Fp::getOp(); - op.fp2Dbl_sqrPreA_(y.a.v_, x.getUnit()); + Fp::getOp().fp2Dbl_sqrPreA_(y.a.v_, x.getUnit()); } static void (*mul_xi)(Fp2DblT&, const Fp2DblT&); static void mod(Fp2& y, const Fp2DblT& x) @@ -735,13 +738,11 @@ struct Fp2DblT { { assert(!Fp::getOp().isFullBit); mcl::fp::Op& op = Fp::getOpNonConst(); - if (op.fp2Dbl_mulPreA_) { - mulPre = fp::func_ptr_cast(op.fp2Dbl_mulPreA_); - } else { - mulPre = fp2Dbl_mulPreW; + if (op.fp2Dbl_mulPreA_ == 0) { + op.fp2Dbl_mulPreA_ = mulPreA; } if (op.fp2Dbl_sqrPreA_ == 0) { - op.fp2Dbl_sqrPreA_ = fp2Dbl_sqrPreC; + op.fp2Dbl_sqrPreA_ = sqrPreA; } const uint32_t xi_a = Fp2::get_xi_a(); switch (xi_a) { @@ -756,12 +757,19 @@ struct Fp2DblT { break; } } +private: + static Fp2 cast(Unit *x) { return *reinterpret_cast(x); } + static const Fp2 cast(const Unit *x) { return *reinterpret_cast(x); } + static Fp2Dbl& castD(Unit *x) { return *reinterpret_cast(x); } /* Fp2Dbl::mulPre by FpDblT @note mod of NIST_P192 is fast */ - static void fp2Dbl_mulPreW(Fp2DblT& z, const Fp2& x, const Fp2& y) + static void mulPreA(Unit *pz, const Unit *px, const Unit *py) { + Fp2Dbl& z = castD(pz); + const Fp2& x = cast(px); + const Fp2& y = cast(py); assert(!Fp::getOp().isFullBit); const Fp& a = x.a; const Fp& b = x.b; @@ -780,11 +788,11 @@ struct Fp2DblT { FpDbl::subPre(d1, d1, d2); FpDbl::sub(d0, d0, d2); // ac - bd } - static void fp2Dbl_sqrPreC(Unit *py, const Unit *px) + static void sqrPreA(Unit *py, const Unit *px) { assert(!Fp::getOp().isFullBit); - const Fp2& x = *reinterpret_cast(px); - Fp2DblT& y = *reinterpret_cast(py); + Fp2Dbl& y = castD(py); + const Fp2& x = cast(px); Fp t1, t2; Fp::addPre(t1, x.b, x.b); // 2b Fp::addPre(t2, x.a, x.b); // a + b @@ -794,7 +802,6 @@ struct Fp2DblT { } }; -template void (*Fp2DblT::mulPre)(Fp2DblT&, const Fp2T&, const Fp2T&); template void (*Fp2DblT::mul_xi)(Fp2DblT&, const Fp2DblT&); template Fp2T Fp2T::g[Fp2T::gN];