diff --git a/misc/Makefile b/misc/Makefile new file mode 100644 index 0000000..25a7c27 --- /dev/null +++ b/misc/Makefile @@ -0,0 +1,6 @@ +all: low_test + +CFLAGS=-I ../include/ -m32 -Ofast -Wall -Wextra -DNDEBUG + +low_test: low_test.cpp ../src/low_funct.hpp + $(CXX) -o low_test low_test.cpp $(CFLAGS) diff --git a/misc/low_test.cpp b/misc/low_test.cpp new file mode 100644 index 0000000..06418f9 --- /dev/null +++ b/misc/low_test.cpp @@ -0,0 +1,234 @@ +#include +#include + +void dump(const char *msg, const uint32_t *x, size_t n) +{ + printf("%s", msg); + for (size_t i = 0; i < n; i++) { + printf("%08x", x[n - 1 - i]); + } + printf("\n"); +} +#include "../src/low_funct.hpp" + +#define MCL_USE_VINT +#define MCL_VINT_FIXED_BUFFER +#define MCL_SIZEOF_UNIT 4 +#define MCL_MAX_BIT_SIZE 768 +#include +#include +#include +#include +#include + +const int C = 10000; + +template +void setRand(uint32_t *x, size_t n, RG& rg) +{ + for (size_t i = 0; i < n; i++) { + x[i] = rg.get32(); + } +} + +/* +g++ -Ofast -DNDEBUG -Wall -Wextra -m32 -I ./include/ misc/low_test.cpp +Core i7-8700 + mulT karatsuba +N = 6, 182clk 225clk +N = 8, 300clk 350clk +N = 12, 594clk 730clk +*/ +template +void mulTest() +{ + printf("N=%zd (%zdbit)\n", N, N * 32); + cybozu::XorShift rg; + uint32_t x[N]; + uint32_t y[N]; + uint32_t z[N * 2]; + for (size_t i = 0; i < 1000; i++) { + setRand(x, N, rg); + setRand(y, N, rg); + // remove MSB + x[N - 1] &= 0x7fffffff; + y[N - 1] &= 0x7fffffff; + mcl::Vint vx, vy; + vx.setArray(x, N); + vy.setArray(y, N); + vx *= vy; + mcl::mulT(z, x, y); + CYBOZU_TEST_EQUAL_ARRAY(z, vx.getUnit(), N * 2); + memset(z, 0, sizeof(z)); + mcl::karatsubaT(z, x, y); + CYBOZU_TEST_EQUAL_ARRAY(z, vx.getUnit(), N * 2); + } + CYBOZU_BENCH_C("mulT", C, mcl::mulT, z, x, y); + CYBOZU_BENCH_C("kara", C, mcl::karatsubaT, z, x, y); +} + +CYBOZU_TEST_AUTO(mulT) +{ + mulTest<8>(); + mulTest<12>(); +} + +template +void sqrTest() +{ + printf("N=%zd (%zdbit)\n", N, N * 32); + cybozu::XorShift rg; + uint32_t x[N]; + uint32_t y[N * 2]; + for (size_t i = 0; i < 1000; i++) { + setRand(x, N, rg); + // remove MSB + x[N - 1] &= 0x7fffffff; + mcl::Vint vx; + vx.setArray(x, N); + vx *= vx; + mcl::sqrT(y, x); + CYBOZU_TEST_EQUAL_ARRAY(y, vx.getUnit(), N * 2); + } + CYBOZU_BENCH_C("sqrT", C, mcl::sqrT, y, x); +} + +CYBOZU_TEST_AUTO(sqrT) +{ + sqrTest<8>(); + sqrTest<12>(); +} + +struct Montgomery { + mcl::Vint p_; + mcl::Vint R_; // (1 << (pn_ * 64)) % p + mcl::Vint RR_; // (R * R) % p + uint32_t rp_; // rp * p = -1 mod M = 1 << 64 + size_t pn_; + Montgomery() {} + explicit Montgomery(const mcl::Vint& p) + { + p_ = p; + rp_ = mcl::fp::getMontgomeryCoeff(p.getUnit()[0]); + pn_ = p.getUnitSize(); + R_ = 1; + R_ = (R_ << (pn_ * 64)) % p_; + RR_ = (R_ * R_) % p_; + } + + void toMont(mcl::Vint& x) const { mul(x, x, RR_); } + void fromMont(mcl::Vint& x) const { mul(x, x, 1); } + + void mul(mcl::Vint& z, const mcl::Vint& x, const mcl::Vint& y) const + { + const size_t ySize = y.getUnitSize(); + mcl::Vint c = x * y.getUnit()[0]; + uint32_t q = c.getUnit()[0] * rp_; + c += p_ * q; + c >>= sizeof(uint32_t) * 8; + for (size_t i = 1; i < pn_; i++) { + if (i < ySize) { + c += x * y.getUnit()[i]; + } + uint32_t q = c.getUnit()[0] * rp_; + c += p_ * q; + c >>= sizeof(uint32_t) * 8; + } + if (c >= p_) { + c -= p_; + } + z = c; + } + void mod(mcl::Vint& z, const mcl::Vint& xy) const + { + z = xy; + for (size_t i = 0; i < pn_; i++) { + uint32_t q = z.getUnit()[0] * rp_; + mcl::Vint t = q; + z += p_ * t; + z >>= 32; + } + if (z >= p_) { + z -= p_; + } + } +}; + +template +void mulMontTest(const char *pStr) +{ + mcl::Vint vp; + vp.setStr(pStr); + Montgomery mont(vp); + + cybozu::XorShift rg; + uint32_t x[N]; + uint32_t y[N]; + uint32_t z[N]; + uint32_t _p[N + 1]; + uint32_t *const p = _p + 1; + vp.getArray(p, N); + p[-1] = mont.rp_; + + for (size_t i = 0; i < 1000; i++) { + setRand(x, N, rg); + setRand(y, N, rg); + // remove MSB + x[N - 1] &= 0x7fffffff; + y[N - 1] &= 0x7fffffff; + mcl::Vint vx, vy, vz; + vx.setArray(x, N); + vy.setArray(y, N); + mont.mul(vz, vx, vy); + mcl::mulMontT(z, x, y, p); + CYBOZU_TEST_EQUAL_ARRAY(z, vz.getUnit(), N); + + mont.mul(vz, vx, vx); + mcl::sqrMontT(z, x, p); + CYBOZU_TEST_EQUAL_ARRAY(z, vz.getUnit(), N); + } + CYBOZU_BENCH_C("mulMontT", C, mcl::mulMontT, x, x, y, p); + CYBOZU_BENCH_C("sqrMontT", C, mcl::sqrMontT, x, x, p); +} + +template +void modTest(const char *pStr) +{ + mcl::Vint vp; + vp.setStr(pStr); + Montgomery mont(vp); + + cybozu::XorShift rg; + uint32_t xy[N * 2]; + uint32_t z[N]; + uint32_t _p[N + 1]; + uint32_t *const p = _p + 1; + vp.getArray(p, N); + p[-1] = mont.rp_; + + for (size_t i = 0; i < 1000; i++) { + setRand(xy, N * 2, rg); + // remove MSB + xy[N * 2 - 1] &= 0x7fffffff; + mcl::Vint vxy, vz; + vxy.setArray(xy, N * 2); + mont.mod(vz, vxy); + mcl::modT(z, xy, p); + CYBOZU_TEST_EQUAL_ARRAY(z, vz.getUnit(), N); + } + CYBOZU_BENCH_C("modT", C, mcl::modT, z, xy, p); +} + +CYBOZU_TEST_AUTO(mont) +{ + const char *pBN254 = "0x2523648240000001ba344d80000000086121000000000013a700000000000013"; + puts("BN254"); + mulMontTest<8>(pBN254); + modTest<8>(pBN254); + + const char *pBLS12_381 = "0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab"; + puts("BLS12"); + mulMontTest<12>(pBLS12_381); + modTest<12>(pBLS12_381); +} + diff --git a/src/fp.cpp b/src/fp.cpp index eb8a7de..484ad43 100644 --- a/src/fp.cpp +++ b/src/fp.cpp @@ -3,6 +3,10 @@ #include #include #include +#if defined(__EMSCRIPTEN__) && MCL_SIZEOF_UNIT == 4 +#define FOR_WASM +#include "low_funct.hpp" +#endif #if defined(MCL_STATIC_CODE) || defined(MCL_USE_XBYAK) || (defined(MCL_USE_LLVM) && (CYBOZU_HOST == CYBOZU_HOST_INTEL)) @@ -407,6 +411,25 @@ static bool initForMont(Op& op, const Unit *p, Mode mode) return true; } +#ifdef FOR_WASM +template +void setWasmOp(Op& op) +{ + if (!(op.isMont && !op.isFullBit)) return; +EM_ASM({console.log($0)}, N); +// op.fp_addPre = mcl::addT; +// op.fp_subPre = mcl::subT; +// op.fpDbl_addPre = mcl::addT; +// op.fpDbl_subPre = mcl::subT; + op.fp_add = mcl::addModT; + op.fp_sub = mcl::subModT; + op.fp_mul = mcl::mulMontT; + op.fp_sqr = mcl::sqrMontT; + op.fpDbl_mulPre = mulT; + op.fpDbl_mod = modT; +} +#endif + bool Op::init(const mpz_class& _p, size_t maxBitSize, int _xi_a, Mode mode, size_t mclMaxBitSize) { if (mclMaxBitSize != MCL_MAX_BIT_SIZE) return false; @@ -547,6 +570,13 @@ bool Op::init(const mpz_class& _p, size_t maxBitSize, int _xi_a, Mode mode, size default: return false; } +#ifdef FOR_WASM + if (N == 8) { + setWasmOp<8>(*this); + } else if (N == 12) { + setWasmOp<12>(*this); + } +#endif #ifdef MCL_USE_LLVM if (primeMode == PM_NIST_P192) { fp_mul = &mcl_fp_mulNIST_P192L; diff --git a/src/low_funct.hpp b/src/low_funct.hpp new file mode 100644 index 0000000..885b16a --- /dev/null +++ b/src/low_funct.hpp @@ -0,0 +1,342 @@ +#pragma once +/** + @file + @author MITSUNARI Shigeo(@herumi) + @license modified new BSD license + http://opensource.org/licenses/BSD-3-Clause + @note for only 32bit not full bit prime version + assert((p[N - 1] & 0x80000000) == 0); +*/ +#include +#include +#include + + +namespace mcl { + +template +void copyT(uint32_t y[N], const uint32_t x[N]) +{ + for (size_t i = 0; i < N; i++) { + y[i] = x[i]; + } +} + +template +uint32_t shlT(uint32_t y[N], const uint32_t x[N], size_t bit) +{ + assert(0 < bit && bit < 32); + assert((N % 2) == 0); + size_t rBit = sizeof(uint32_t) * 8 - bit; + uint32_t keep = x[N - 1]; + uint32_t prev = keep; + for (size_t i = N - 1; i > 0; i--) { + uint32_t t = x[i - 1]; + y[i] = (prev << bit) | (t >> rBit); + prev = t; + } + y[0] = prev << bit; + return keep >> rBit; +} + +// [return:y[N]] += x +template +inline uint32_t addUnitT(uint32_t y[N], uint32_t x) +{ + uint64_t v = uint64_t(y[0]) + x; + y[0] = uint32_t(v); + uint32_t c = v >> 32; + if (c == 0) return 0; + for (size_t i = 1; i < N; i++) { + v = uint64_t(y[i]) + 1; + y[i] = uint32_t(v); + if ((v >> 32) == 0) return 0; + } + return 1; +} + +template +uint32_t addT(uint32_t z[N], const uint32_t x[N], const uint32_t y[N]) +{ + uint32_t c = 0; + for (size_t i = 0; i < N; i++) { + uint64_t v = uint64_t(x[i]) + y[i] + c; + z[i] = uint32_t(v); + c = uint32_t(v >> 32); + } + return c; +} + +template +uint32_t subT(uint32_t z[N], const uint32_t x[N], const uint32_t y[N]) +{ + uint32_t c = 0; + for (size_t i = 0; i < N; i++) { + uint64_t v = uint64_t(x[i]) - y[i] - c; + z[i] = uint32_t(v); + c = uint32_t(v >> 63); + } + return c; +} + +// [return:z[N]] = x[N] * y +template +uint32_t mulUnitT(uint32_t z[N], const uint32_t x[N], uint32_t y) +{ + uint32_t H = 0; + for (size_t i = 0; i < N; i++) { + uint64_t v = uint64_t(x[i]) * y; + v += H; + z[i] = uint32_t(v); + H = uint32_t(v >> 32); + } + return H; +} + +// [return:z[N]] = z[N] + x[N] * y +template +uint32_t addMulUnitT(uint32_t z[N], const uint32_t x[N], uint32_t y) +{ + uint32_t H = 0; + for (size_t i = 0; i < N; i++) { + uint64_t v = uint64_t(x[i]) * y; + v += H; + v += z[i]; + z[i] = uint32_t(v); + H = uint32_t(v >> 32); + } + return H; +} + +// z[N * 2] = x[N] * y[N] +template +void mulT(uint32_t z[N * 2], const uint32_t x[N], const uint32_t y[N]) +{ + z[N] = mulUnitT(z, x, y[0]); + for (size_t i = 1; i < N; i++) { + z[N + i] = addMulUnitT(&z[i], x, y[i]); + } +} + +#if 0 +// slower than mulT +template +uint32_t mulUnitWithTblT(uint32_t z[N], const uint64_t *tbl_j) +{ + uint32_t H = 0; + for (size_t i = 0; i < N; i++) { + uint64_t v = tbl_j[i]; + v += H; + z[i] = uint32_t(v); + H = uint32_t(v >> 32); + } + return H; +} + +template +uint32_t addMulUnitWithTblT(uint32_t z[N], const uint64_t *tbl_j) +{ + uint32_t H = 0; + for (size_t i = 0; i < N; i++) { + uint64_t v = tbl_j[i]; + v += H; + v += z[i]; + z[i] = uint32_t(v); + H = uint32_t(v >> 32); + } + return H; +} + +// y[N * 2] = x[N] * x[N] +template +void sqrT(uint32_t y[N * 2], const uint32_t x[N]) +{ + uint64_t tbl[N * N]; // x[i]x[j] + for (size_t i = 0; i < N; i++) { + uint64_t xi = x[i]; + tbl[i * N + i] = xi * xi; + for (size_t j = i + 1; j < N; j++) { + uint64_t v = xi * x[j]; + tbl[i * N + j] = v; + tbl[j * N + i] = v; + } + } + y[N] = mulUnitWithTblT(y, tbl); + for (size_t i = 1; i < N; i++) { + y[N + i] = addMulUnitWithTblT(&y[i], tbl + N * i); + } +} +#endif + +/* + z[N * 2] = x[N] * y[N] + H = N/2 + W = 1 << (H * 32) + x = aW + b, y = cW + d + assume a < W/2, c < W/2 + (aW + b)(cW + d) = acW^2 + (ad + bc)W + bd + ad + bc = (a + b)(c + d) - ac - bd < (1 << (N * 32)) + slower than mulT on Core i7 with -m32 for N <= 12 +*/ +template +void karatsubaT(uint32_t z[N * 2], const uint32_t x[N], const uint32_t y[N]) +{ + assert((N % 2) == 0); + assert((x[N - 1] & 0x80000000) == 0); + assert((y[N - 1] & 0x80000000) == 0); + const size_t H = N / 2; + uint32_t a_b[H]; + uint32_t c_d[H]; + uint32_t c1 = addT(a_b, x, x + H); // a + b + uint32_t c2 = addT(c_d, y, y + H); // c + d + uint32_t tmp[N]; + mulT(tmp, a_b, c_d); + if (c1) { + addT(tmp + H, tmp + H, c_d); + } + if (c2) { + addT(tmp + H, tmp + H, a_b); + } + mulT(z, x, y); // bd + mulT(z + N, x + H, y + H); // ac + // c:tmp[N] = (a + b)(c + d) + subT(tmp, tmp, z); + subT(tmp, tmp, z + N); + // c:tmp[N] = ad + bc + if (addT(z + H, z + H, tmp)) { + addUnitT(z + N + H, 1); + } +} + +/* + y[N * 2] = x[N] * x[N] + (aW + b)^2 = a^2 W + b^2 + 2abW + (a+b)^2 - a^2 - b^2 +*/ +template +void sqrT(uint32_t y[N * 2], const uint32_t x[N]) +{ + assert((N % 2) == 0); + assert((x[N - 1] & 0x80000000) == 0); + const size_t H = N / 2; + uint32_t a_b[H]; + uint32_t c = addT(a_b, x, x + H); // a + b + uint32_t tmp[N]; + mulT(tmp, a_b, a_b); + if (c) { + shlT(a_b, a_b, 1); + addT(tmp + H, tmp + H, a_b); + } + mulT(y, x, x); // b^2 + mulT(y + N, x + H, x + H); // a^2 + // tmp[N] = (a + b)^2 + subT(tmp, tmp, y); + subT(tmp, tmp, y + N); + // tmp[N] = 2ab + if (addT(y + H, y + H, tmp)) { + addUnitT(y + N + H, 1); + } +} + +template +void addModT(uint32_t z[N], const uint32_t x[N], const uint32_t y[N], const uint32_t p[N]) +{ + uint32_t t[N]; + addT(z, x, y); + uint32_t c = subT(t, z, p); + if (!c) { + copyT(z, t); + } +} + +template +void subModT(uint32_t z[N], const uint32_t x[N], const uint32_t y[N], const uint32_t p[N]) +{ + uint32_t c = subT(z, x, y); + if (c) { + addT(z, z, p); + } +} + +/* + z[N] = Montgomery(x[N], y[N], p[N]) + @remark : assume p[-1] = rp +*/ +template +void mulMontT(uint32_t z[N], const uint32_t x[N], const uint32_t y[N], const uint32_t p[N]) +{ + const uint32_t rp = p[-1]; + assert((p[N - 1] & 0x80000000) == 0); + uint32_t buf[N * 2]; + buf[N] = mulUnitT(buf, x, y[0]); + uint32_t q = buf[0] * rp; + buf[N] += addMulUnitT(buf, p, q); + for (size_t i = 1; i < N; i++) { + buf[N + i] = addMulUnitT(buf + i, x, y[i]); + uint32_t q = buf[i] * rp; + buf[N + i] += addMulUnitT(buf + i, p, q); + } + if (subT(z, buf + N, p)) { + copyT(z, buf + N); + } +} + +// [return:z[N+1]] = z[N+1] + x[N] * y + (cc << (N * 32)) +template +uint32_t addMulUnit2T(uint32_t z[N + 1], const uint32_t x[N], uint32_t y, const uint32_t *cc = 0) +{ + uint32_t H = 0; + for (size_t i = 0; i < N; i++) { + uint64_t v = uint64_t(x[i]) * y; + v += H; + v += z[i]; + z[i] = uint32_t(v); + H = uint32_t(v >> 32); + } + if (cc) H += *cc; + uint64_t v = uint64_t(z[N]); + v += H; + z[N] = uint32_t(v); + return uint32_t(v >> 32); +} + +/* + z[N] = Montgomery reduction(y[N], xy[N], p[N]) + @remark : assume p[-1] = rp +*/ +template +void modT(uint32_t y[N], const uint32_t xy[N * 2], const uint32_t p[N]) +{ + const uint32_t rp = p[-1]; + assert((p[N - 1] & 0x80000000) == 0); + uint32_t buf[N * 2]; + copyT(buf, xy); + uint32_t c = 0; + for (size_t i = 0; i < N; i++) { + uint32_t q = buf[i] * rp; + c = addMulUnit2T(buf + i, p, q, &c); + } + if (subT(y, buf + N, p)) { + copyT(y, buf + N); + } +} + +/* + z[N] = Montgomery(x[N], y[N], p[N]) + @remark : assume p[-1] = rp +*/ +template +void sqrMontT(uint32_t y[N], const uint32_t x[N], const uint32_t p[N]) +{ +#if 1 + mulMontT(y, x, x, p); +#else + // slower + uint32_t xx[N * 2]; + sqrT(xx, x); + modT(y, xx, p); +#endif +} + +} // mcl +