Use threading during solving constraints (#1760)

* Check for 'unchecked' before turning off Integer module

* Use threading during solving constraints

* Hide segfaults
pull/1763/head
Nikhil Parasaram 2 years ago committed by GitHub
parent 9f00f5b4e2
commit fedf68f09d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      mythril/interfaces/cli.py
  2. 95
      mythril/support/model.py

@ -626,7 +626,6 @@ def validate_args(args: Namespace):
coloredlogs.install(
fmt="%(name)s [%(levelname)s]: %(message)s", level=log_levels[args.v]
)
logging.getLogger("mythril").setLevel(log_levels[args.v])
else:
exit_with_error(
args.outform, "Invalid -v value, you can find valid values in usage"

@ -1,16 +1,22 @@
from functools import lru_cache
from z3 import sat, unknown, is_true
from pathlib import Path
from mythril.support.support_utils import ModelCache
from mythril.support.support_args import args
from mythril.laser.smt import Optimize, simplify, And
from mythril.laser.ethereum.time_handler import time_handler
from mythril.exceptions import UnsatError, SolverTimeOutException
import logging
import os
import signal
import sys
from collections import OrderedDict
from copy import deepcopy
from functools import lru_cache
from multiprocessing.pool import ThreadPool
from multiprocessing import TimeoutError
from pathlib import Path
from time import time
from z3 import sat, unknown, is_true
log = logging.getLogger(__name__)
@ -18,12 +24,10 @@ log = logging.getLogger(__name__)
model_cache = ModelCache()
@lru_cache(maxsize=2**23)
def get_model(
def solver_worker(
constraints,
minimize=(),
maximize=(),
enforce_execution_time=True,
solver_timeout=None,
):
"""
@ -31,27 +35,11 @@ def get_model(
:param constraints: Tuple of constraints
:param minimize: Tuple of minimization conditions
:param maximize: Tuple of maximization conditions
:param enforce_execution_time: Bool variable which enforces --execution-timeout's time
:param solver_timeout: The timeout for solver
:return:
"""
s = Optimize()
timeout = solver_timeout or args.solver_timeout
if enforce_execution_time:
timeout = min(timeout, time_handler.time_remaining() - 500)
if timeout <= 0:
raise UnsatError
s.set_timeout(timeout)
for constraint in constraints:
if type(constraint) == bool and not constraint:
raise UnsatError
if type(constraints) != tuple:
constraints = constraints.get_all_constraints()
constraints = [constraint for constraint in constraints if type(constraint) != bool]
if len(maximize) + len(minimize) == 0:
ret_model = model_cache.check_quick_sat(simplify(And(*constraints)).raw)
if ret_model:
return ret_model
s.set_timeout(solver_timeout)
for constraint in constraints:
s.add(constraint)
@ -73,6 +61,63 @@ def get_model(
f.write(s.sexpr())
result = s.check()
return result, s
@lru_cache(maxsize=2**23)
def get_model(
constraints,
minimize=(),
maximize=(),
solver_timeout=None,
):
"""
Returns a model based on given constraints as a tuple
:param constraints: Tuple of constraints
:param minimize: Tuple of minimization conditions
:param maximize: Tuple of maximization conditions
:param solver_timeout: The solver timeout
:return:
"""
solver_timeout = solver_timeout or args.solver_timeout
solver_timeout = min(solver_timeout, time_handler.time_remaining())
if solver_timeout <= 0:
raise SolverTimeOutException
for constraint in constraints:
if type(constraint) == bool and not constraint:
raise UnsatError
if type(constraints) != tuple:
constraints = constraints.get_all_constraints()
constraints = [constraint for constraint in constraints if type(constraint) != bool]
if len(maximize) + len(minimize) == 0:
ret_model = model_cache.check_quick_sat(simplify(And(*constraints)).raw)
if ret_model:
return ret_model
pool = ThreadPool(1)
try:
thread_result = pool.apply_async(
solver_worker, args=(constraints, minimize, maximize, solver_timeout)
)
try:
result, s = thread_result.get(solver_timeout)
except TimeoutError:
log.debug("Timeout/Error encountered while solving expression using z3")
result = unknown
except Exception:
log.warning("Encountered an exception while solving expression using z3")
result = unknown
finally:
# This is to prevent any segmentation faults from being displayed from z3
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
pool.terminate()
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
if result == sat:
model_cache.model_cache.put(s.model(), 1)
return s.model()

Loading…
Cancel
Save