Route Fp2DblT::sqrPre through fp::Op::fp2Dbl_sqrPreA_.

Summary of what the hunks below do:
  * fp.hpp: add getOpNonConst(), returning op_ by non-const reference.
  * fp_tower.hpp: move the forward declarations of Fp12T / BNT / Fp2DblT
    above FpDblT and make Fp2DblT a friend of FpDblT, so that
    Fp2DblT::sqrPre can pass y.a.v_ directly.
  * Fp2DblT::sqrPre stops being a static function pointer and becomes a
    static member function that calls op.fp2Dbl_sqrPreA_.
  * Fp2DblT::init() installs fp2Dbl_sqrPreC into op.fp2Dbl_sqrPreA_ when
    that slot is 0 (hence the need for a mutable Op).
  * fp2Dbl_sqrPreW is renamed to fp2Dbl_sqrPreC and now takes raw
    Unit pointers, which it reinterprets as Fp2 / Fp2DblT.

NOTE(review): fp2Dbl_sqrPreC casts py back to Fp2DblT* while sqrPre passes
y.a.v_; this presumably relies on `a` being the first member of Fp2DblT and
v_ the first member of FpDblT -- confirm the layout assumption.
NOTE(review): the line structure, tab indentation, blank context lines and
template argument lists of this patch were reconstructed from a flattened
copy; hunk line counts match the @@ headers, but verify against upstream
before applying.

diff --git a/include/mcl/fp.hpp b/include/mcl/fp.hpp
index c52ce78..a223317 100644
--- a/include/mcl/fp.hpp
+++ b/include/mcl/fp.hpp
@@ -118,6 +118,7 @@ public:
 	static inline size_t getBitSize() { return op_.bitSize; }
 	static inline size_t getByteSize() { return (op_.bitSize + 7) / 8; }
 	static inline const fp::Op& getOp() { return op_; }
+	static inline fp::Op& getOpNonConst() { return op_; }
 	void dump() const
 	{
 		const size_t N = op_.N;
diff --git a/include/mcl/fp_tower.hpp b/include/mcl/fp_tower.hpp
index b5ed9a6..34dd062 100644
--- a/include/mcl/fp_tower.hpp
+++ b/include/mcl/fp_tower.hpp
@@ -10,10 +10,15 @@
 
 namespace mcl {
 
+template<class Fp> struct Fp12T;
+template<class Fp> class BNT;
+template<class Fp> struct Fp2DblT;
+
 template<class Fp>
 class FpDblT : public fp::Serializable<FpDblT<Fp> > {
 	typedef fp::Unit Unit;
 	Unit v_[Fp::maxSize * 2];
+	friend struct Fp2DblT<Fp>;
 public:
 	static size_t getUnitSize() { return Fp::op_.N * 2; }
 	const fp::Unit *getUnit() const { return v_; }
@@ -172,9 +177,6 @@ template<class Fp> void (*FpDblT<Fp>::addPre)(FpDblT&, const FpDblT&, const FpDb
 template<class Fp> void (*FpDblT<Fp>::subPre)(FpDblT&, const FpDblT&, const FpDblT&);
 #endif
 
-template<class Fp> struct Fp12T;
-template<class Fp> class BNT;
-template<class Fp> struct Fp2DblT;
 /*
 	beta = -1
 	Fp2 = F[i] / (i^2 + 1)
@@ -662,7 +664,11 @@ struct Fp2DblT {
 		y.a = t;
 	}
 	static void (*mulPre)(Fp2DblT&, const Fp2&, const Fp2&);
-	static void (*sqrPre)(Fp2DblT&, const Fp2&);
+	static void sqrPre(Fp2DblT& y, const Fp2& x)
+	{
+		const mcl::fp::Op& op = Fp::getOp();
+		op.fp2Dbl_sqrPreA_(y.a.v_, x.getUnit());
+	}
 	static void (*mul_xi)(Fp2DblT&, const Fp2DblT&);
 	static void mod(Fp2& y, const Fp2DblT& x)
 	{
@@ -680,16 +686,14 @@ struct Fp2DblT {
 	static void init()
 	{
 		assert(!Fp::getOp().isFullBit);
-		const mcl::fp::Op& op = Fp::getOp();
+		mcl::fp::Op& op = Fp::getOpNonConst();
 		if (op.fp2Dbl_mulPreA_) {
 			mulPre = fp::func_ptr_cast<void (*)(Fp2DblT&, const Fp2&, const Fp2&)>(op.fp2Dbl_mulPreA_);
 		} else {
 			mulPre = fp2Dbl_mulPreW;
 		}
-		if (op.fp2Dbl_sqrPreA_) {
-			sqrPre = fp::func_ptr_cast<void (*)(Fp2DblT&, const Fp2&)>(op.fp2Dbl_sqrPreA_);
-		} else {
-			sqrPre = fp2Dbl_sqrPreW;
+		if (op.fp2Dbl_sqrPreA_ == 0) {
+			op.fp2Dbl_sqrPreA_ = fp2Dbl_sqrPreC;
 		}
 		const uint32_t xi_a = Fp2::get_xi_a();
 		switch (xi_a) {
@@ -728,9 +732,11 @@ struct Fp2DblT {
 		FpDbl::subPre(d1, d1, d2);
 		FpDbl::sub(d0, d0, d2); // ac - bd
 	}
-	static void fp2Dbl_sqrPreW(Fp2DblT& y, const Fp2& x)
+	static void fp2Dbl_sqrPreC(Unit *py, const Unit *px)
 	{
 		assert(!Fp::getOp().isFullBit);
+		const Fp2& x = *reinterpret_cast<const Fp2*>(px);
+		Fp2DblT& y = *reinterpret_cast<Fp2DblT*>(py);
 		Fp t1, t2;
 		Fp::addPre(t1, x.b, x.b); // 2b
 		Fp::addPre(t2, x.a, x.b); // a + b
@@ -741,7 +747,6 @@ struct Fp2DblT {
 };
 
 template<class Fp> void (*Fp2DblT<Fp>::mulPre)(Fp2DblT&, const Fp2T<Fp>&, const Fp2T<Fp>&);
-template<class Fp> void (*Fp2DblT<Fp>::sqrPre)(Fp2DblT&, const Fp2T<Fp>&);
 template<class Fp> void (*Fp2DblT<Fp>::mul_xi)(Fp2DblT&, const Fp2DblT&);
 
 template<class Fp> Fp2T<Fp> Fp2T<Fp>::g[Fp2T<Fp>::gN];