From a3293c2c85582e4874e5e8599948711babfafdea Mon Sep 17 00:00:00 2001 From: MITSUNARI Shigeo Date: Wed, 3 Feb 2021 17:06:29 +0900 Subject: [PATCH] add sqrT --- misc/Makefile | 2 +- misc/low_test.cpp | 7 ------ src/low_funct.hpp | 57 ++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/misc/Makefile b/misc/Makefile index 1a727f3..25a7c27 100644 --- a/misc/Makefile +++ b/misc/Makefile @@ -2,5 +2,5 @@ all: low_test CFLAGS=-I ../include/ -m32 -Ofast -Wall -Wextra -DNDEBUG -low_test: low_test.cpp +low_test: low_test.cpp ../src/low_funct.hpp $(CXX) -o low_test low_test.cpp $(CFLAGS) diff --git a/misc/low_test.cpp b/misc/low_test.cpp index a77ed36..540e3d9 100644 --- a/misc/low_test.cpp +++ b/misc/low_test.cpp @@ -69,7 +69,6 @@ CYBOZU_TEST_AUTO(mulT) { mulTest<8>(); mulTest<12>(); - mulTest<16>(); } template @@ -88,11 +87,6 @@ void sqrTest() vx *= vx; mcl::sqrT(y, x); CYBOZU_TEST_EQUAL_ARRAY(y, vx.getUnit(), N * 2); -#if 0 - memset(z, 0, sizeof(z)); - mcl::karatsubaT(z, x, y); - CYBOZU_TEST_EQUAL_ARRAY(z, vx.getUnit(), N * 2); -#endif } CYBOZU_BENCH_C("sqrT", 10000, mcl::sqrT, y, x); } @@ -101,7 +95,6 @@ CYBOZU_TEST_AUTO(sqrT) { sqrTest<8>(); sqrTest<12>(); - sqrTest<16>(); } struct Montgomery { diff --git a/src/low_funct.hpp b/src/low_funct.hpp index f25559c..2f3a2a1 100644 --- a/src/low_funct.hpp +++ b/src/low_funct.hpp @@ -22,6 +22,23 @@ void copyT(uint32_t y[N], const uint32_t x[N]) } } +template +uint32_t shlT(uint32_t y[N], const uint32_t x[N], size_t bit) +{ + assert(0 < bit && bit < 32); + assert((N % 2) == 0); + size_t rBit = sizeof(uint32_t) * 8 - bit; + uint32_t keep = x[N - 1]; + uint32_t prev = keep; + for (size_t i = N - 1; i > 0; i--) { + uint32_t t = x[i - 1]; + y[i] = (prev << bit) | (t >> rBit); + prev = t; + } + y[0] = prev << bit; + return keep >> rBit; +} + // [return:y[N]] += x template inline bool addUnitT(uint32_t y[N], uint32_t x) @@ -101,6 +118,8 @@ void mulT(uint32_t z[N * 2], const uint32_t x[N], const uint32_t y[N]) } } +#if 0 +// slower than mulT template uint32_t mulUnitWithTblT(uint32_t z[N], const uint64_t *tbl_j) { @@ -128,7 +147,6 @@ uint32_t addMulUnitWithTblT(uint32_t z[N], const uint64_t *tbl_j) return H; } - // y[N * 2] = x[N] * x[N] template void sqrT(uint32_t y[N * 2], const uint32_t x[N]) @@ -148,6 +166,7 @@ void sqrT(uint32_t y[N * 2], const uint32_t x[N]) y[N + i] = addMulUnitWithTblT(&y[i], tbl + N * i); } } +#endif /* z[N * 2] = x[N] * y[N] @@ -157,6 +176,7 @@ void sqrT(uint32_t y[N * 2], const uint32_t x[N]) assume a < W/2, c < W/2 (aW + b)(cW + d) = acW^2 + (ad + bc)W + bd ad + bc = (a + b)(c + d) - ac - bd < (1 << (N * 32)) + slower than mulT on Core i7 with -m32 for N <= 12 */ template void karatsubaT(uint32_t z[N * 2], const uint32_t x[N], const uint32_t y[N]) @@ -165,8 +185,6 @@ void karatsubaT(uint32_t z[N * 2], const uint32_t x[N], const uint32_t y[N]) assert((x[N - 1] & 0x80000000) == 0); assert((y[N - 1] & 0x80000000) == 0); const size_t H = N / 2; - mulT(z, x, y); // bd - mulT(z + N, x + H, y + H); // ac uint32_t a_b[H]; uint32_t c_d[H]; bool c1 = addT(a_b, x, x + H); // a + b @@ -179,6 +197,8 @@ void karatsubaT(uint32_t z[N * 2], const uint32_t x[N], const uint32_t y[N]) if (c2) { addT(tmp + H, tmp + H, a_b); } + mulT(z, x, y); // bd + mulT(z + N, x + H, y + H); // ac // c:tmp[N] = (a + b)(c + d) subT(tmp, tmp, z); subT(tmp, tmp, z + N); @@ -188,6 +208,37 @@ void karatsubaT(uint32_t z[N * 2], const uint32_t x[N], const uint32_t y[N]) } } +/* + y[N * 2] = x[N] * x[N] + (aW + b)^2 = a^2 W + b^2 + 2abW + (a+b)^2 - a^2 - b^2 +*/ +template +void sqrT(uint32_t y[N * 2], const uint32_t x[N]) +{ + assert((N % 2) == 0); + assert((x[N - 1] & 0x80000000) == 0); + const size_t H = N / 2; + uint32_t a_b[H]; + bool c = addT(a_b, x, x + H); // a + b + uint32_t tmp[N]; + mulT(tmp, a_b, a_b); + if (c) { +// addT(a_b, a_b, a_b); + shlT(a_b, a_b, 1); + addT(tmp + H, tmp + H, a_b); + } + mulT(y, x, x); // b^2 + mulT(y + N, x + H, x + H); // a^2 + // tmp[N] = (a + b)^2 + subT(tmp, tmp, y); + subT(tmp, tmp, y + N); + // tmp[N] = 2ab + if (addT(y + H, y + H, tmp)) { + addUnitT(y + N + H, 1); + } +} + template void addModT(uint32_t z[N], const uint32_t x[N], const uint32_t y[N], const uint32_t p[N]) {