From 91a696e53db30c402667a9e46a156e46ed621713 Mon Sep 17 00:00:00 2001
From: MITSUNARI Shigeo
Date: Tue, 31 May 2016 17:44:20 +0900
Subject: [PATCH] add karatuba(not enabled)

---
 src/gen.cpp | 112 ++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 92 insertions(+), 20 deletions(-)

diff --git a/src/gen.cpp b/src/gen.cpp
index efcee45..a17ac99 100644
--- a/src/gen.cpp
+++ b/src/gen.cpp
@@ -42,7 +42,7 @@ struct Code : public mcl::Generator {
 
 		x = zext(x, unit2);
 		y = zext(y, unit2);
-		z= mul(x, y);
+		z = mul(x, y);
 		ret(z);
 		endFunc();
 	}
@@ -549,27 +549,99 @@ struct Code : public mcl::Generator {
 	}
 	void generic_fpDbl_mul(const Operand& pz, const Operand& px, const Operand& py)
 	{
-		const int bu = bit + unit;
-		Operand y = load(py);
-		Operand xy = call(mulPvM[bit], px, y);
-		store(trunc(xy, unit), pz);
-		Operand t = lshr(xy, unit);
-		Operand z, pzi;
-		for (uint32_t i = 1; i < N; i++) {
-			Operand pyi = getelementptr(py, makeImm(32, i));
-			y = load(pyi);
-			xy = call(mulPvM[bit], px, y);
-			t = add(t, xy);
-			z = trunc(t, unit);
-			pzi = getelementptr(pz, makeImm(32, i));
-			if (i < N - 1) {
-				store(z, pzi);
-				t = lshr(t, unit);
+		if (N == 1) { // single element: widen both operands and use one mul
+			Operand x = load(px);
+			Operand y = load(py);
+			x = zext(x, unit * 2);
+			y = zext(y, unit * 2);
+			Operand z = mul(x, y);
+			store(z, bitcast(pz, Operand(IntPtr, unit * 2))); // write the whole 2*unit-bit product through pz
+			ret(Void);
+		} else if (N >= 32 && (N % 2) == 0) { // Karatsuba path; commit subject says "not enabled" - NOTE(review): presumably no generated N reaches 32, confirm
+			/*
+				Karatsuba: W = 1 << half, x = aW + b, y = cW + d
+				(aW + b)(cW + d) = acW^2 + (ad + bc)W + bd
+				ad + bc = (a + b)(c + d) - ac - bd
+			*/
+			const int half = bit / 2; // bit width used for each of a, b, c, d
+			Operand pxW = getelementptr(px, makeImm(32, N / 2)); // px advanced by N/2 elements (upper half of x; presumably unit-sized elements)
+			Operand pyW = getelementptr(py, makeImm(32, N / 2)); // same offset into y
+			Operand pzWW = getelementptr(pz, makeImm(32, N)); // pz advanced by N elements (upper half of the result)
+			call(mcl_fpDbl_mulPreM[half], pz, px, py); // bd, written at pz
+			call(mcl_fpDbl_mulPreM[half], pzWW, pxW, pyW); // ac, written at pzWW
+
+			Operand pa = bitcast(pxW, Operand(IntPtr, half));
+			Operand pb = bitcast(px, Operand(IntPtr, half));
+			Operand pc = bitcast(pyW, Operand(IntPtr, half));
+			Operand pd = bitcast(py, Operand(IntPtr, half));
+			Operand a = zext(load(pa), half + unit); // loaded after pz was written above - NOTE(review): assumes pz does not alias px/py
+			Operand b = zext(load(pb), half + unit);
+			Operand c = zext(load(pc), half + unit);
+			Operand d = zext(load(pd), half + unit);
+			Operand t1 = add(a, b); // a + b; a carry, if any, lands in bit index half
+			Operand t2 = add(c, d); // c + d
+			Operand buf = _alloca(unit, N); // scratch that receives t1L * t2L below
+			Operand t1L = trunc(t1, half); // low half bits of a + b
+			Operand t2L = trunc(t2, half); // low half bits of c + d
+			Operand c1 = trunc(lshr(t1, half), 1); // carry bit of a + b
+			Operand c2 = trunc(lshr(t2, half), 1); // carry bit of c + d
+			Operand c0 = _and(c1, c2); // both carries set: contributes the W^2 term
+			c1 = select(c1, t2L, makeImm(half, 0)); // c1 ? t2L : 0
+			c2 = select(c2, t1L, makeImm(half, 0)); // c2 ? t1L : 0
+			Operand buf1 = _alloca(half, 1); // t1L and t2L are stored to memory so they can be passed as pointers
+			Operand buf2 = _alloca(half, 1);
+			store(t1L, buf1);
+			store(t2L, buf2);
+			buf1 = bitcast(buf1, Operand(IntPtr, unit));
+			buf2 = bitcast(buf2, Operand(IntPtr, unit));
+			call(mcl_fpDbl_mulPreM[half], buf, buf1, buf2); // buf = t1L * t2L
+			buf = bitcast(buf, Operand(IntPtr, bit));
+			Operand t = load(buf);
+			t = zext(t, bit + unit);
+			c0 = zext(c0, bit + unit);
+			c0 = shl(c0, bit);
+			t = _or(t, c0); // bit index bit of t is zero after the zext, so or behaves as add here
+			c1 = zext(c1, bit + unit);
+			c2 = zext(c2, bit + unit);
+			c1 = shl(c1, half); // scale the carry cross terms by W
+			c2 = shl(c2, half);
+			t = add(t, c1);
+			t = add(t, c2); // t = (a + b)(c + d)
+			Operand pzL = bitcast(pz, Operand(IntPtr, bit)); // bd viewed as one bit-wide integer
+			Operand pzH = getelementptr(pzL, makeImm(32, 1)); // ac, the next bit-wide integer
+			t = sub(t, zext(load(pzL), bit + unit));
+			t = sub(t, zext(load(pzH), bit + unit)); // t = ad + bc
+			pzL = getelementptr(pz, makeImm(32, N / 2)); // reposition at the W term of the result
+			pzL = bitcast(pzL, Operand(IntPtr, bit + half));
+			if (bit + half > t.bit) { // widen t to the (bit + half)-wide integer loaded below
+				t = zext(t, bit + half);
 			}
+			t = add(t, load(pzL)); // add (ad + bc) into the result at the W offset
+			store(t, pzL);
+			ret(Void);
+		} else { // generic path: one mulPvM call per element of y
+			const int bu = bit + unit; // width of the final store
+			Operand y = load(py);
+			Operand xy = call(mulPvM[bit], px, y);
+			store(trunc(xy, unit), pz);
+			Operand t = lshr(xy, unit); // running accumulator above the element already stored
+			Operand z, pzi;
+			for (uint32_t i = 1; i < N; i++) {
+				Operand pyi = getelementptr(py, makeImm(32, i));
+				y = load(pyi);
+				xy = call(mulPvM[bit], px, y);
+				t = add(t, xy);
+				z = trunc(t, unit);
+				pzi = getelementptr(pz, makeImm(32, i));
+				if (i < N - 1) { // the last iteration leaves t intact for the wide store below
+					store(z, pzi);
+					t = lshr(t, unit);
+				}
+			}
+			pzi = bitcast(pzi, Operand(IntPtr, bu)); // remaining bu bits are written with one store
+			store(t, pzi);
+			ret(Void);
 		}
-		pzi = bitcast(pzi, Operand(IntPtr, bu));
-		store(t, pzi);
-		ret(Void);
 	}
 	void gen_mcl_fpDbl_mulPre()
 	{