diff --git a/misc/low_test.cpp b/misc/low_test.cpp index 856be61..a77ed36 100644 --- a/misc/low_test.cpp +++ b/misc/low_test.cpp @@ -69,6 +69,39 @@ CYBOZU_TEST_AUTO(mulT) { mulTest<8>(); mulTest<12>(); + mulTest<16>(); +} + +template +void sqrTest() +{ + printf("N=%zd (%zdbit)\n", N, N * 32); + cybozu::XorShift rg; + uint32_t x[N]; + uint32_t y[N * 2]; + for (size_t i = 0; i < 1000; i++) { + setRand(x, N, rg); + // remove MSB + x[N - 1] &= 0x7fffffff; + mcl::Vint vx; + vx.setArray(x, N); + vx *= vx; + mcl::sqrT(y, x); + CYBOZU_TEST_EQUAL_ARRAY(y, vx.getUnit(), N * 2); +#if 0 + memset(z, 0, sizeof(z)); + mcl::karatsubaT(z, x, y); + CYBOZU_TEST_EQUAL_ARRAY(z, vx.getUnit(), N * 2); +#endif + } + CYBOZU_BENCH_C("sqrT", 10000, mcl::sqrT, y, x); +} + +CYBOZU_TEST_AUTO(sqrT) +{ + sqrTest<8>(); + sqrTest<12>(); + sqrTest<16>(); } struct Montgomery { diff --git a/src/low_funct.hpp b/src/low_funct.hpp index 5db871b..f25559c 100644 --- a/src/low_funct.hpp +++ b/src/low_funct.hpp @@ -101,6 +101,54 @@ void mulT(uint32_t z[N * 2], const uint32_t x[N], const uint32_t y[N]) } } +template +uint32_t mulUnitWithTblT(uint32_t z[N], const uint64_t *tbl_j) +{ + uint32_t H = 0; + for (size_t i = 0; i < N; i++) { + uint64_t v = tbl_j[i]; + v += H; + z[i] = uint32_t(v); + H = uint32_t(v >> 32); + } + return H; +} + +template +uint32_t addMulUnitWithTblT(uint32_t z[N], const uint64_t *tbl_j) +{ + uint32_t H = 0; + for (size_t i = 0; i < N; i++) { + uint64_t v = tbl_j[i]; + v += H; + v += z[i]; + z[i] = uint32_t(v); + H = uint32_t(v >> 32); + } + return H; +} + + +// y[N * 2] = x[N] * x[N] +template +void sqrT(uint32_t y[N * 2], const uint32_t x[N]) +{ + uint64_t tbl[N * N]; // x[i]x[j] + for (size_t i = 0; i < N; i++) { + uint64_t xi = x[i]; + tbl[i * N + i] = xi * xi; + for (size_t j = i + 1; j < N; j++) { + uint64_t v = xi * x[j]; + tbl[i * N + j] = v; + tbl[j * N + i] = v; + } + } + y[N] = mulUnitWithTblT(y, tbl); + for (size_t i = 1; i < N; i++) { + y[N + i] = addMulUnitWithTblT(&y[i], tbl + N * i); + } +} + /* z[N * 2] = x[N] * y[N] H = N/2