refactor Vint

dev
MITSUNARI Shigeo 7 years ago
parent e01ffa8db7
commit c19ef69cb4
  1. 500
      include/mcl/vint.hpp

@ -152,7 +152,7 @@ template<class T>
struct Empty {};
template<class T, class E = Empty<T> >
struct comparable : E {
struct Operator : E {
inline friend bool operator<(const T& x, const T& y) { return T::compare(x, y) < 0; }
inline friend bool operator>=(const T& x, const T& y) { return !operator<(x, y); }
@ -160,10 +160,6 @@ struct comparable : E {
inline friend bool operator<=(const T& x, const T& y) { return !operator>(x, y); }
inline friend bool operator==(const T& x, const T& y) { return T::compare(x, y) == 0; }
inline friend bool operator!=(const T& x, const T& y) { return !operator==(x, y); }
};
template<class T, class E = Empty<T> >
struct addsubmul : E {
template<class N>
inline T& operator+=(const N& rhs) { T::add(static_cast<T&>(*this), static_cast<T&>(*this), rhs); return static_cast<T&>(*this); }
inline T& operator-=(const T& rhs) { T::sub(static_cast<T&>(*this), static_cast<T&>(*this), rhs); return static_cast<T&>(*this); }
@ -171,24 +167,12 @@ struct addsubmul : E {
inline friend T operator+(const T& a, const T& b) { T c; T::add(c, a, b); return c; }
inline friend T operator-(const T& a, const T& b) { T c; T::sub(c, a, b); return c; }
inline friend T operator*(const T& a, const T& b) { T c; T::mul(c, a, b); return c; }
};
template<class T, class E = Empty<T> >
struct dividable : E {
inline T& operator/=(const T& rhs) { T rdummy; T::div(static_cast<T*>(this), rdummy, static_cast<const T&>(*this), rhs); return static_cast<T&>(*this); }
inline T& operator%=(const T& rhs) { T::div(0, static_cast<T&>(*this), static_cast<const T&>(*this), rhs); return static_cast<T&>(*this); }
inline friend T operator/(const T& a, const T& b) { T q, r; T::div(&q, r, a, b); return q; }
inline friend T operator%(const T& a, const T& b) { T r; T::div(0, r, a, b); return r; }
};
template<class T, class E = Empty<T> >
struct hasNegative : E {
inline T operator-() const { T c; T::neg(c, static_cast<const T&>(*this)); return c; }
};
template<class T, class E = Empty<T> >
struct shiftable : E {
inline T operator<<(size_t n) const { T out; T::shl(out, static_cast<const T&>(*this), n); return out; }
inline T operator>>(size_t n) const { T out; T::shr(out, static_cast<const T&>(*this), n); return out; }
@ -198,13 +182,6 @@ struct shiftable : E {
inline T& operator>>=(size_t n) { T::shr(static_cast<T&>(*this), static_cast<const T&>(*this), n); return static_cast<T&>(*this); }
};
template<class T, class E = Empty<T> >
struct inversible : E {
inline void inverse() { T& self = static_cast<T&>(*this);T out; T::inv(out, self); self = out; }
inline friend T operator/(const T& x, const T& y) { T out; T::inv(out, y); out *= x; return out; }
inline T& operator/=(const T& x) { T rx; T::inv(rx, x); T& self = static_cast<T&>(*this); self *= rx; return self; }
};
/*
compare x[] and y[]
@retval positive if x > y
@ -671,7 +648,7 @@ public:
T& operator[](size_t n) { return v_[n]; }
};
template<class T, size_t = 0>
template<class T>
class Buffer {
size_t allocSize_;
T *ptr_;
@ -815,283 +792,9 @@ public:
typedef _Buffer Buffer;
Buffer buf;
size_t size_;
void alloc(size_t n) { return buf.alloc(n); }
typedef typename Buffer::value_type value_type;
typedef value_type T;
static const size_t unitBitSize = sizeof(T) * 8;
VuintT(T x = 0) : size_(0)
{
operator=(x);
}
explicit VuintT(const std::string& str) : size_(0)
{
setStr(str);
}
VuintT& operator=(T x)
{
buf.alloc(1);
buf[0] = x;
size_ = 1;
return *this;
}
// @note assume little endian system
template<class S>
void setArray(const S *x, size_t size)
{
if (size == 0) {
*this = 0;
return;
}
size_t unitSize = (sizeof(S) * size + sizeof(T) - 1) / sizeof(T);
buf.alloc(unitSize);
buf[unitSize - 1] = 0;
memcpy(&buf[0], x, sizeof(S) * size);
trim(unitSize);
}
/*
buf[0, size) = x
buf[size, maxSize) with zero
@note assume little endian system
*/
void getArray(T *x, size_t maxSize) const
{
if (size_ > maxSize) throw cybozu::Exception("Vint:getArray:small maxSize") << maxSize << size_;
local::copyN(x, &buf[0], size_);
local::clearN(x + size_, maxSize - size_);
}
void clear() { *this = 0; }
#if 0
std::string getStr(int base = 10) const
{
std::ostringstream os;
switch (base) {
case 10:
{
const uint32_t i1e9 = 1000000000U;
VuintT x = *this;
std::vector<uint32_t> t;
while (!x.isZero()) {
uint32_t r = (uint32_t)div1(&x, x, i1e9);
t.push_back(r);
}
if (t.empty()) {
return "0";
}
os << t[t.size() - 1];
for (size_t i = 1, n = t.size(); i < n; i++) {
os << std::setfill('0') << std::setw(9) << t[n - 1 - i];
}
}
break;
case 16:
{
os << "0x" << std::hex;
const size_t n = size();
os << (*this)[n - 1];
for (size_t i = 1; i < n; i++) {
os << std::setfill('0') << std::setw(sizeof(Unit) * 2) << (*this)[n - 1 - i];
}
}
break;
default:
throw cybozu::Exception("getStr:not supported base") << base;
}
return os.str();
}
#endif
/*
@param str [in] number string
@note "0x..." => base = 16
"0b..." => base = 2
otherwise => base = 10
*/
void setStr(std::string str, int base = 0)
{
if (str.size() >= 2 && str[0] == '0') {
switch (str[1]) {
case 'x':
if (base != 0 && base != 16) throw cybozu::Exception("bad base in setStr(str)") << base;
base = 16;
str = str.substr(2);
break;
default:
throw cybozu::Exception("not support base in setStr(str) 0") << str[1];
}
}
if (base == 0) {
base = 10;
}
if (str.empty()) throw cybozu::Exception("empty string");
switch (base) {
case 16:
{
std::vector<uint32_t> x;
while (!str.empty()) {
size_t remain = std::min((int)str.size(), 8);
char *endp;
uint32_t v = strtoul(&str[str.size() - remain], &endp, 16);
if (*endp) goto ERR;
x.push_back(v);
str = str.substr(0, str.size() - remain);
}
setArray(&x[0], x.size());
}
break;
default:
case 10:
decStr2Int(*this, str);
break;
}
return;
ERR:
throw std::invalid_argument(std::string("bad digit `") + str + "`");
}
static int compare(const VuintT& x, const VuintT& y)
{
return local::compareNM(&x[0], x.size(), &y[0], y.size());
}
size_t size() const { return size_; }
bool isZero() const
{
return size() == 1 && buf[0] == 0;
}
T& operator[](size_t n) { return buf[n]; }
const T& operator[](size_t n) const { return buf[n]; }
void swap(VuintT& rhs) { buf.swap(rhs.buf); }
size_t bitLen() const
{
if (isZero()) return 0;
size_t size = size_;
T v = buf[size - 1];
assert(v);
return (size - 1) * sizeof(T) * 8 + 1 + cybozu::bsr<T>(v);
}
bool testBit(size_t i) const
{
size_t unit_pos = i / (sizeof(T) * 8);
size_t bit_pos = i % (sizeof(T) * 8);
T mask = T(1) << bit_pos;
return (buf[unit_pos] & mask) != 0;
}
static void add(VuintT& z, const VuintT& x, T y)
{
size_t xn = x.size();
size_t zn = xn + 1;
z.alloc(zn);
z[zn - 1] = local::add1(&z[0], &x[0], xn, y);
z.trim(zn);
}
static void sub(VuintT& z, const VuintT& x, const VuintT& y)
{
const size_t xn = x.size();
const size_t yn = y.size();
assert(xn >= yn);
z.alloc(xn);
T c = local::subN(&z[0], &x[0], &y[0], yn);
if (xn > yn) {
c = local::sub1(&z[yn], &x[yn], xn - yn, c);
}
if (c) throw cybozu::Exception("can't sub");
z .trim(xn);
}
static void mul1(VuintT& z, const VuintT& x, T y)
{
const size_t xn = x.size();
z.alloc(xn + 1);
z[xn] = local::mul1(&z[0], &x[0], xn, y);
z.trim(xn + 1);
}
static void mul(VuintT& z, const VuintT& x, const VuintT& y)
{
const size_t xn = x.size();
const size_t yn = y.size();
z.alloc(xn + yn);
local::mulNM(&z[0], &x[0], xn, &y[0], yn);
z.trim(xn + yn);
}
/**
@param q [out] q = x / y
@param x [in]
@param y [in] must be not zero
@return x % y
*/
#if 0
static T div1(VuintT *q, const VuintT& x, T y)
{
const size_t xn = x.size();
T r;
if (q) {
q->alloc(xn); // assume q is not destroyed if q == x
r = local::div1(&(*q)[0], &x[0], xn, y);
q->trim(xn);
} else {
r = local::mod1(&x[0], xn, y);
}
return r;
}
#endif
/**
@param q [out] x / y if q != 0
@param r [out] x % y
@retval true if y != 0
@retavl false if y == 0
*/
static bool div(VuintT* q, VuintT& r, const VuintT& x, const VuintT& y)
{
assert(q != &r);
const size_t xn = x.size();
const size_t yn = y.size();
if (q) {
q->alloc(xn - yn + 1);
}
r.alloc(xn);
local::divNM(q ? &(*q)[0] : 0, &r[0], &x[0], xn, &y[0], yn);
if (q) {
q->trim(xn - yn + 1);
}
r.trim(xn);
return true;
}
#if 0
static inline void shl(VuintT& y, const VuintT& x, size_t n)
{
size_t xn = x.size();
const size_t unitSize = sizeof(T) * 8;
size_t yn = xn + (n + unitSize - 1) / unitSize;
y.alloc(yn);
local::shlN(&y[0], &x[0], xn, n);
y.trim(yn);
}
#endif
static inline void shr(VuintT& y, const VuintT& x, size_t n)
{
size_t xn = x.size();
const size_t unitSize = sizeof(T) * 8;
if (xn < n / unitSize) {
y.clear();
return;
}
size_t yn = xn - n / unitSize;
y.alloc(yn);
local::shrN(&y[0], &x[0], xn, n);
y.trim(yn);
}
void trim(size_t n)
{
if (n == 0) throw cybozu::Exception("trim zero");
int i = (int)n - 1;
for (; i > 0; i--) {
if (buf[i]) break;
}
size_ = i ? i + 1 : 1;
}
void setSize(size_t n)
{
size_ = n;
@ -1099,14 +802,10 @@ public:
};
template<class V>
struct VintT : public local::addsubmul<VintT<V>,
local::comparable<VintT<V>,
local::dividable<VintT<V>,
local::hasNegative<VintT<V>,
local::shiftable<VintT<V> > > > > > {
struct VintT : public local::Operator<VintT<V> > {
typedef typename V::Buffer Buffer;
typedef typename V::value_type value_type;
typedef value_type T;
typedef typename Buffer::value_type value_type;
typedef typename Buffer::value_type T;
static const size_t unitBitSize = sizeof(T) * 8;
V v_;
bool isNeg_;
@ -1121,6 +820,10 @@ struct VintT : public local::addsubmul<VintT<V>,
}
return 1;
}
static int ucompare(const Buffer& x, size_t xn, const Buffer& y, size_t yn)
{
return local::compareNM(&x[0], xn, &y[0], yn);
}
static size_t uadd(Buffer& z, const Buffer& x, size_t xn, const Buffer& y, size_t yn)
{
size_t zn = std::max(xn, yn) + 1;
@ -1135,6 +838,24 @@ struct VintT : public local::addsubmul<VintT<V>,
z[zn - 1] = local::add1(&z[0], &x[0], xn, y);
return realSize(z, zn);
}
static size_t umul1(Buffer& z, const Buffer& x, size_t xn, T y)
{
size_t zn = xn + 1;
z.alloc(zn);
z[zn - 1] = local::mul1(&z[0], &x[0], xn, y);
return realSize(z, zn);
}
static size_t usub(Buffer& z, const Buffer& x, size_t xn, const Buffer& y, size_t yn)
{
assert(xn >= yn);
z.alloc(xn);
T c = local::subN(&z[0], &x[0], &y[0], yn);
if (xn > yn) {
c = local::sub1(&z[yn], &x[yn], xn - yn, c);
}
assert(!c);
return realSize(z, xn);
}
static size_t ushl(Buffer& y, const Buffer& x, size_t xn, size_t shiftBit)
{
const size_t unitSize = sizeof(T) * 8;
@ -1143,6 +864,19 @@ struct VintT : public local::addsubmul<VintT<V>,
local::shlN(&y[0], &x[0], xn, shiftBit);
return realSize(y, yn);
}
static size_t ushr(Buffer& y, const Buffer& x, size_t xn, size_t shiftBit)
{
const size_t unitSize = sizeof(T) * 8;
if (xn < shiftBit / unitSize) {
y.alloc(1);
y[0] = 0;
return 1;
}
size_t yn = xn - shiftBit / unitSize;
y.alloc(yn);
local::shrN(&y[0], &x[0], xn, shiftBit);
return realSize(y, yn);
}
public:
/**
@param q [out] q = x / y
@ -1162,6 +896,26 @@ public:
}
return r;
}
/**
@param q [out] x / y if q != 0
@param r [out] x % y
@retval true if y != 0
@retavl false if y == 0
*/
static void udiv(Buffer* q, size_t *qn, Buffer& r, size_t& rn, const Buffer& x, size_t xn, const Buffer& y, size_t yn)
{
assert(q != &r);
if (q) {
*qn = xn - yn + 1;
q->alloc(*qn);
}
r.alloc(xn);
local::divNM(q ? &(*q)[0] : 0, &r[0], &x[0], xn, &y[0], yn);
if (q) {
*qn = realSize(*q, *qn);
}
rn = realSize(r, xn);
}
std::string getStr(int base = 10) const
{
std::ostringstream os;
@ -1213,20 +967,21 @@ public:
z.isNeg_ = xNeg;
return;
}
int r = V::compare(x, y);
int r = ucompare(x.buf, x.size(), y.buf, y.size());
if (r >= 0) {
V::sub(z.v_, x, y);
size_t zn = usub(z.v_.buf, x.buf, x.size(), y.buf, y.size());
z.v_.setSize(zn);
z.isNeg_ = xNeg;
} else {
V::sub(z.v_, y, x);
size_t zn = usub(z.v_.buf, y.buf, y.size(), x.buf, x.size());
z.v_.setSize(zn);
z.isNeg_ = yNeg;
}
}
public:
VintT(int x = 0)
: v_(::abs(x))
, isNeg_(x < 0)
{
*this = x;
}
explicit VintT(const std::string& str)
{
@ -1235,7 +990,9 @@ public:
VintT& operator=(int x)
{
isNeg_ = x < 0;
v_ = (isNeg_ ? -x : x);
v_.buf.alloc(1);
v_.buf[0] = std::abs(x);
v_.setSize(1);
return *this;
}
/*
@ -1269,7 +1026,7 @@ public:
local::copyN(x, &v_.buf[0], size);
local::clearN(x + size, maxSize - size);
}
void clear() { v_ = 0; isNeg_ = false; }
void clear() { *this = 0; }
/*
return bitLen(abs(*this))
*/
@ -1281,6 +1038,7 @@ public:
assert(v);
return (size - 1) * sizeof(T) * 8 + 1 + cybozu::bsr<T>(v);
}
// ignore sign
bool testBit(size_t i) const
{
size_t unit_pos = i / (sizeof(T) * 8);
@ -1288,16 +1046,61 @@ public:
T mask = T(1) << bit_pos;
return (v_.buf[unit_pos] & mask) != 0;
}
void setStr(const std::string& str, int base = 0)
/*
@param str [in] number string
@note "0x..." => base = 16
"0b..." => base = 2
otherwise => base = 10
*/
void setStr(std::string str, int base = 0)
{
isNeg_ = false;
if (str.size() > 0 && str[0] == '-') {
bool neg = false;
if (!str.empty() && str[0] == '-') {
neg = true;
str = str.substr(1);
}
if (str.size() >= 2 && str[0] == '0') {
switch (str[1]) {
case 'x':
if (base != 0 && base != 16) throw cybozu::Exception("bad base in setStr(str)") << base;
base = 16;
str = str.substr(2);
break;
default:
throw cybozu::Exception("not support base in setStr(str) 0") << str[1];
}
}
if (base == 0) {
base = 10;
}
if (str.empty()) throw cybozu::Exception("empty string");
switch (base) {
case 16:
{
std::vector<uint32_t> x;
while (!str.empty()) {
size_t remain = std::min((int)str.size(), 8);
char *endp;
uint32_t v = strtoul(&str[str.size() - remain], &endp, 16);
if (*endp) goto ERR;
x.push_back(v);
str = str.substr(0, str.size() - remain);
}
setArray(&x[0], x.size());
}
break;
default:
case 10:
decStr2Int(*this, str);
break;
}
if (!isZero() && neg) {
isNeg_ = true;
v_.setStr(&str[1], base);
} else {
v_.setStr(str, base);
}
return;
ERR:
throw std::invalid_argument(std::string("bad digit `") + str + "`");
}
static inline int compare(const VintT& x, const VintT& y)
{
@ -1306,11 +1109,11 @@ public:
return x.isNeg_ ? -1 : 1;
} else {
// same sign
return V::compare(x.v_, y.v_) * (x.isNeg_ ? -1 : 1);
return ucompare(x.v_.buf, x.size(), y.v_.buf, y.size()) * (x.isNeg_ ? -1 : 1);
}
}
size_t size() const { return v_.size(); }
bool isZero() const { return v_.isZero(); }
bool isZero() const { return size() == 1 && v_.buf[0] == 0; }
bool isNegative() const { return isNeg_; }
static inline void add(VintT& z, const VintT& x, const VintT& y)
{
@ -1322,12 +1125,20 @@ public:
}
static inline void mul(VintT& z, const VintT& x, const VintT& y)
{
V::mul(z.v_, x.v_, y.v_);
const size_t xn = x.size();
const size_t yn = y.size();
size_t zn = xn + yn;
z.v_.buf.alloc(zn);
local::mulNM(&z.v_.buf[0], &x.v_.buf[0], xn, &y.v_.buf[0], yn);
zn = realSize(z.v_.buf, zn);
z.v_.setSize(zn);
z.isNeg_ = x.isNeg_ ^ y.isNeg_;
}
static void mul(VintT& z, const VintT& x, T y)
{
V::mul1(z.v_, x.v_, y);
// V::mul1(z.v_, x.v_, y);
size_t zn = umul1(z.v_.buf, x.v_.buf, x.v_.size(), y);
z.v_.setSize(zn);
z.isNeg_ = x.isNeg_;
}
/*
@ -1345,40 +1156,39 @@ public:
return udiv1(0, 0, x.v_.buf, x.size(), y);
}
}
static inline bool div(VintT *q, VintT& r, const VintT& x, const VintT& y)
static void div(VintT *q, VintT& r, const VintT& x, const VintT& y)
{
#if 1
// like Python
// 13 / -5 = -3 ... -2
// -13 / 5 = -3 ... 2
// -13 / -5 = 2 ... -3
V yy = y.v_;
bool ret = V::div(q ? &(q->v_) : 0, r.v_, x.v_, y.v_);
if (!ret) return false;
size_t rn = 0;
if (q) {
size_t qn = 0;
udiv(&q->v_.buf, &qn, r.v_.buf, rn, x.v_.buf, x.v_.size(), y.v_.buf, y.v_.size());
q->v_.setSize(qn);
} else {
udiv(0, 0, r.v_.buf, rn, x.v_.buf, x.v_.size(), y.v_.buf, y.v_.size());
}
r.v_.setSize(rn);
bool qsign = x.isNeg_ ^ y.isNeg_;
if (r.v_.isZero()) {
if (r.isZero()) {
r.isNeg_ = false;
} else {
if (qsign) {
if (q) {
V::add(q->v_, q->v_, 1);
// V::add(q->v_, q->v_, 1);
size_t n = uadd1(q->v_.buf, q->v_.buf, q->v_.size(), 1);
q->v_.setSize(n);
}
V::sub(r.v_, yy, r.v_);
// V::sub(r.v_, yy, r.v_);
size_t n = usub(r.v_.buf, yy.buf, yy.size(), r.v_.buf, r.v_.size());
r.v_.setSize(n);
}
r.isNeg_ = y.isNeg_;
}
if (q) q->isNeg_ = qsign;
return true;
#else
// 13 / -5 = -2 ... 3
// -13 / 5 = -2 ... -3
// -13 / -5 = 2 ... -3
bool ret = V::div(q ? &(q->v_) : 0, r.v_, x.v_, y.v_);
bool qsign = x.isNeg_ ^ y.isNeg_;
r.isNeg_ = x.isNeg_;
if (q) q->isNeg_ = qsign;
return ret;
#endif
}
static inline void neg(VintT& z, const VintT& x)
{
@ -1403,9 +1213,11 @@ public:
z.v_.setSize(zn);
z.isNeg_ = x.isNeg_;
}
static inline void shr(VintT& z, const VintT& x, size_t n)
static inline void shr(VintT& z, const VintT& x, size_t shiftBit)
{
V::shr(z.v_, x.v_, n);
// V::shr(z.v_, x.v_, n);
size_t zn = ushr(z.v_.buf, x.v_.buf, x.size(), shiftBit);
z.v_.setSize(zn);
z.isNeg_ = x.isNeg_;
}
static inline void abs(VintT& z, const VintT& in)

Loading…
Cancel
Save