diff --git a/include/mcl/fp.hpp b/include/mcl/fp.hpp index 236820b..aadf8c4 100644 --- a/include/mcl/fp.hpp +++ b/include/mcl/fp.hpp @@ -119,8 +119,13 @@ public: #else if (mode == fp::FP_LLVM || mode == fp::FP_LLVM_MONT) mode = fp::FP_AUTO; #endif - if (mode == fp::FP_AUTO) mode = fp::FP_GMP_MONT; - if (maxBitSize > 576 && mode != fp::FP_GMP_MONT) mode = fp::FP_GMP; + if (mode == fp::FP_AUTO) { + if (maxBitSize > 576) { + mode = fp::FP_GMP; // QQQ : slower than FP_GMP_MONT if maxBitSize == 768 + } else { + mode = fp::FP_GMP_MONT; + } + } op_.isMont = mode == fp::FP_GMP_MONT || mode == fp::FP_LLVM_MONT || mode == fp::FP_XBYAK; if (mode == fp::FP_GMP_MONT || mode == fp::FP_LLVM_MONT) { diff --git a/sample/large.cpp b/sample/large.cpp index f7fbd86..fcfd93b 100644 --- a/sample/large.cpp +++ b/sample/large.cpp @@ -23,6 +23,10 @@ int main() { test(mcl::fp::FP_GMP); test(mcl::fp::FP_GMP_MONT); +#ifdef MCL_USE_LLVM + test(mcl::fp::FP_LLVM); + test(mcl::fp::FP_LLVM_MONT); +#endif } catch (std::exception& e) { printf("err %s\n", e.what()); puts("make clean"); diff --git a/src/fp_proto.hpp b/src/fp_proto.hpp index 87c728e..e68d19b 100644 --- a/src/fp_proto.hpp +++ b/src/fp_proto.hpp @@ -50,6 +50,9 @@ void mcl_fpDbl_add224(mcl::fp::Unit*, const mcl::fp::Unit*, const mcl::fp::Unit* void mcl_fpDbl_sub224(mcl::fp::Unit*, const mcl::fp::Unit*, const mcl::fp::Unit*, const mcl::fp::Unit*); #else MCL_FP_DEF_FUNC(576) +MCL_FP_DEF_FUNC(640) +MCL_FP_DEF_FUNC(704) +MCL_FP_DEF_FUNC(768) #endif #undef MCL_FP_DEF_FUNC diff --git a/src/gen.cpp b/src/gen.cpp index abb128c..f453004 100644 --- a/src/gen.cpp +++ b/src/gen.cpp @@ -1,5 +1,6 @@ #include "llvm_gen.hpp" #include +#include #include #include #include @@ -810,11 +811,11 @@ struct Code : public mcl::Generator { unit2 = unit * 2; unitStr = cybozu::itoa(unit); } - void gen(const StrSet& privateFuncList) + void gen(const StrSet& privateFuncList, uint32_t maxBitSize) { this->privateFuncList = &privateFuncList; gen_once(); - uint32_t end = ((576 + unit - 1) / unit) * unit; + uint32_t end = ((maxBitSize + unit - 1) / unit) * unit; for (uint32_t i = 64; i <= end; i += unit) { setBit(i); gen_all(); @@ -854,7 +855,8 @@ int main(int argc, char *argv[]) c.setOldLLVM(); } c.setUnit(unit); - c.gen(privateFuncList); + uint32_t maxBitSize = MCL_MAX_OP_BIT_SIZE; + c.gen(privateFuncList, maxBitSize); } catch (std::exception& e) { printf("ERR %s\n", e.what()); return 1;