diff --git a/include/mcl/vint.hpp b/include/mcl/vint.hpp index 13c4483..3bed527 100644 --- a/include/mcl/vint.hpp +++ b/include/mcl/vint.hpp @@ -90,35 +90,19 @@ inline uint32_t mulUnit(uint32_t *pH, uint32_t x, uint32_t y) inline uint64_t mulUnit(uint64_t *pH, uint64_t x, uint64_t y) { #ifdef MCL_VINT_64BIT_PORTABLE - uint32_t a = uint32_t(x >> 32); - uint32_t b = uint32_t(x); - uint32_t c = uint32_t(y >> 32); - uint32_t d = uint32_t(y); - - uint64_t ad = uint64_t(d) * a; - uint64_t bd = uint64_t(d) * b; - uint64_t L = uint32_t(bd); - ad += bd >> 32; // [ad:L] - - uint64_t ac = uint64_t(c) * a; - uint64_t bc = uint64_t(c) * b; - uint64_t H = uint32_t(bc); - ac += bc >> 32; // [ac:H] - /* - adL - acH - */ - uint64_t t = (ac << 32) | H; - ac >>= 32; - H = t + ad; - if (H < t) { - ac++; - } - /* - ac:H:L - */ + const uint64_t mask = 0xffffffff; + uint64_t v = (x & mask) * (y & mask); + uint64_t L = uint32_t(v); + uint64_t H = v >> 32; + uint64_t ad = (x & mask) * uint32_t(y >> 32); + uint64_t bc = uint32_t(x >> 32) * (y & mask); + H += uint32_t(ad); + H += uint32_t(bc); L |= H << 32; - H = (ac << 32) | uint32_t(H >> 32); + H >>= 32; + H += ad >> 32; + H += bc >> 32; + H += (x >> 32) * (y >> 32); *pH = H; return L; #elif defined(_WIN64) && !defined(__INTEL_COMPILER)