avoid cast of Fp2Dbl::mulPre

2merge^2
MITSUNARI Shigeo 4 years ago
parent 0f141988bd
commit de21f2ea4c
  1. 33
      include/mcl/fp_tower.hpp

@ -656,6 +656,7 @@ private:
template<class Fp>
struct Fp2DblT {
typedef Fp2DblT<Fp> Fp2Dbl;
typedef FpDblT<Fp> FpDbl;
typedef Fp2T<Fp> Fp2;
typedef fp::Unit Unit;
@ -711,11 +712,13 @@ struct Fp2DblT {
FpDbl::add(y.b, y.b, x.a);
y.a = t;
}
static void (*mulPre)(Fp2DblT&, const Fp2&, const Fp2&);
static void mulPre(Fp2DblT& z, const Fp2& x, const Fp2& y)
{
Fp::getOp().fp2Dbl_mulPreA_(z.a.v_, x.getUnit(), y.getUnit());
}
static void sqrPre(Fp2DblT& y, const Fp2& x)
{
const mcl::fp::Op& op = Fp::getOp();
op.fp2Dbl_sqrPreA_(y.a.v_, x.getUnit());
Fp::getOp().fp2Dbl_sqrPreA_(y.a.v_, x.getUnit());
}
static void (*mul_xi)(Fp2DblT&, const Fp2DblT&);
static void mod(Fp2& y, const Fp2DblT& x)
@ -735,13 +738,11 @@ struct Fp2DblT {
{
assert(!Fp::getOp().isFullBit);
mcl::fp::Op& op = Fp::getOpNonConst();
if (op.fp2Dbl_mulPreA_) {
mulPre = fp::func_ptr_cast<void (*)(Fp2DblT&, const Fp2&, const Fp2&)>(op.fp2Dbl_mulPreA_);
} else {
mulPre = fp2Dbl_mulPreW;
if (op.fp2Dbl_mulPreA_ == 0) {
op.fp2Dbl_mulPreA_ = mulPreA;
}
if (op.fp2Dbl_sqrPreA_ == 0) {
op.fp2Dbl_sqrPreA_ = fp2Dbl_sqrPreC;
op.fp2Dbl_sqrPreA_ = sqrPreA;
}
const uint32_t xi_a = Fp2::get_xi_a();
switch (xi_a) {
@ -756,12 +757,19 @@ struct Fp2DblT {
break;
}
}
private:
static Fp2 cast(Unit *x) { return *reinterpret_cast<Fp2*>(x); }
static const Fp2 cast(const Unit *x) { return *reinterpret_cast<const Fp2*>(x); }
static Fp2Dbl& castD(Unit *x) { return *reinterpret_cast<Fp2Dbl*>(x); }
/*
Fp2Dbl::mulPre by FpDblT
@note mod of NIST_P192 is fast
*/
static void fp2Dbl_mulPreW(Fp2DblT& z, const Fp2& x, const Fp2& y)
static void mulPreA(Unit *pz, const Unit *px, const Unit *py)
{
Fp2Dbl& z = castD(pz);
const Fp2& x = cast(px);
const Fp2& y = cast(py);
assert(!Fp::getOp().isFullBit);
const Fp& a = x.a;
const Fp& b = x.b;
@ -780,11 +788,11 @@ struct Fp2DblT {
FpDbl::subPre(d1, d1, d2);
FpDbl::sub(d0, d0, d2); // ac - bd
}
static void fp2Dbl_sqrPreC(Unit *py, const Unit *px)
static void sqrPreA(Unit *py, const Unit *px)
{
assert(!Fp::getOp().isFullBit);
const Fp2& x = *reinterpret_cast<const Fp2*>(px);
Fp2DblT& y = *reinterpret_cast<Fp2DblT*>(py);
Fp2Dbl& y = castD(py);
const Fp2& x = cast(px);
Fp t1, t2;
Fp::addPre(t1, x.b, x.b); // 2b
Fp::addPre(t2, x.a, x.b); // a + b
@ -794,7 +802,6 @@ struct Fp2DblT {
}
};
template<class Fp> void (*Fp2DblT<Fp>::mulPre)(Fp2DblT&, const Fp2T<Fp>&, const Fp2T<Fp>&);
template<class Fp> void (*Fp2DblT<Fp>::mul_xi)(Fp2DblT<Fp>&, const Fp2DblT<Fp>&);
template<class Fp> Fp2T<Fp> Fp2T<Fp>::g[Fp2T<Fp>::gN];

Loading…
Cancel
Save