diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py
index 9cb81244..98f52d7e 100644
--- a/mythril/laser/smt/__init__.py
+++ b/mythril/laser/smt/__init__.py
@@ -18,6 +18,7 @@ from mythril.laser.smt.expression import Expression, simplify
 from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And
 from mythril.laser.smt.array import K, Array, BaseArray
 from mythril.laser.smt.solver import Solver, Optimize
+from mythril.laser.smt.model import Model
 from typing import Union, Any, Optional, List, TypeVar, Generic
 import z3
 
diff --git a/mythril/laser/smt/independence_solver.py b/mythril/laser/smt/independence_solver.py
new file mode 100644
index 00000000..781a26fd
--- /dev/null
+++ b/mythril/laser/smt/independence_solver.py
@@ -0,0 +1,153 @@
+import z3
+
+from mythril.laser.smt.model import Model
+from mythril.laser.smt.bool import Bool
+from typing import Set, Tuple, Dict, List, cast
+
+
+def _get_expr_variables(expression: z3.ExprRef) -> List[z3.ExprRef]:
+    """
+    Gets the variables that make up the current expression
+    :param expression: The z3 expression to walk
+    :return: The leaf expressions that are not bitvector numerals
+    """
+    result = []
+    if not expression.children() and not isinstance(expression, z3.BitVecNumRef):
+        result.append(expression)
+    for child in expression.children():
+        c_children = _get_expr_variables(child)
+        result.extend(c_children)
+    return result
+
+
+class DependenceBucket:
+    """ Bucket object to contain a set of conditions that are dependent on each other """
+
+    def __init__(self, variables=None, conditions=None):
+        """
+        Initializes a DependenceBucket object
+        :param variables: Variables contained in the conditions
+        :param conditions: The conditions that are dependent on each other
+        """
+        self.variables = variables or []  # type: List[z3.ExprRef]
+        self.conditions = conditions or []  # type: List[z3.ExprRef]
+
+
+class DependenceMap:
+    """ DependenceMap object that maintains a set of dependence buckets, used to separate independent smt queries """
+
+    def __init__(self):
+        """ Initializes a DependenceMap object """
+        self.buckets = []  # type: List[DependenceBucket]
+        self.variable_map = {}  # type: Dict[str, DependenceBucket]
+
+    def add_condition(self, condition: z3.BoolRef) -> None:
+        """
+        Add condition to the dependence map
+        :param condition: The condition that is to be added to the dependence map
+        """
+        variables = set(_get_expr_variables(condition))
+        relevant_buckets = set()
+        for variable in variables:
+            try:
+                bucket = self.variable_map[str(variable)]
+                relevant_buckets.add(bucket)
+            except KeyError:
+                continue
+
+        new_bucket = DependenceBucket(variables, [condition])
+        self.buckets.append(new_bucket)
+
+        if relevant_buckets:
+            # Merge buckets, and rewrite variable map accordingly
+            relevant_buckets.add(new_bucket)
+            new_bucket = self._merge_buckets(relevant_buckets)
+
+        # Re-point every variable of the resulting bucket, not only the ones in
+        # this condition; otherwise variables of merged buckets keep referring
+        # to buckets that were already removed from self.buckets.
+        for variable in new_bucket.variables:
+            self.variable_map[str(variable)] = new_bucket
+
+    def _merge_buckets(self, bucket_list: Set[DependenceBucket]) -> DependenceBucket:
+        """ Merges the buckets in bucket list """
+        variables = []  # type: List[z3.ExprRef]
+        conditions = []  # type: List[z3.BoolRef]
+        for bucket in bucket_list:
+            self.buckets.remove(bucket)
+            variables += bucket.variables
+            conditions += bucket.conditions
+
+        new_bucket = DependenceBucket(variables, conditions)
+        self.buckets.append(new_bucket)
+
+        return new_bucket
+
+
+class IndependenceSolver:
+    """An SMT solver object that uses independence optimization"""
+
+    def __init__(self):
+        """Initializes the solver with no constraints and no models."""
+        self.raw = z3.Solver()
+        self.constraints = []
+        self.models = []
+
+    def set_timeout(self, timeout: int) -> None:
+        """Sets the timeout that will be used by this solver, timeout is in
+        milliseconds.
+
+        :param timeout:
+        """
+        self.raw.set(timeout=timeout)
+
+    def add(self, *constraints: Bool) -> None:
+        """Adds the constraints to this solver.
+
+        :param constraints: constraints to add
+        """
+        raw_constraints = [c.raw for c in constraints]  # type: List[z3.BoolRef]
+        self.constraints.extend(raw_constraints)
+
+    def append(self, *constraints: Bool) -> None:
+        """Adds the constraints to this solver.
+
+        :param constraints: constraints to add
+        """
+        self.add(*constraints)
+
+    def check(self) -> z3.CheckSatResult:
+        """Returns z3 smt check result. """
+        dependence_map = DependenceMap()
+        for constraint in self.constraints:
+            dependence_map.add_condition(constraint)
+
+        self.models = []
+        for bucket in dependence_map.buckets:
+            self.raw.reset()
+            self.raw.append(*bucket.conditions)
+            check_result = self.raw.check()
+            if check_result == z3.sat:
+                self.models.append(self.raw.model())
+            else:
+                return check_result
+
+        return z3.sat
+
+    def model(self) -> Model:
+        """Returns z3 model for a solution. """
+        return Model(self.models)
+
+    def reset(self) -> None:
+        """Reset this solver."""
+        self.constraints = []
+        self.models = []
+
+    def pop(self, num) -> None:
+        """Pop num constraints from this solver.
+
+        :param num: The number of most recently added constraints to remove
+        """
+        if num > 0:
+            # list.pop(num) would remove the single element at index num instead
+            del self.constraints[-num:]
diff --git a/mythril/laser/smt/model.py b/mythril/laser/smt/model.py
new file mode 100644
index 00000000..524683e9
--- /dev/null
+++ b/mythril/laser/smt/model.py
@@ -0,0 +1,59 @@
+import z3
+
+from typing import Union, List
+
+
+class Model:
+    """ The model class wraps a z3 model
+
+    This implementation allows for multiple internal models, this is required for the use of an independence solver
+    which has models for multiple queries which need a uniform output.
+    """
+
+    def __init__(self, models: List[z3.ModelRef] = None):
+        """
+        Initialize a model object
+        :param models: the internal z3 models that this model should use
+        """
+        self.raw = models or []
+
+    def decls(self) -> List[z3.FuncDeclRef]:
+        """Get the declarations for this model"""
+        result = []  # type: List[z3.FuncDeclRef]
+        for internal_model in self.raw:
+            result.extend(internal_model.decls())
+        return result
+
+    def __getitem__(self, item) -> Union[None, z3.ExprRef]:
+        """ Get declaration from model
+        If item is an int, then the declaration at offset item is returned
+        If item is a declaration, then the interpretation is returned
+        """
+        last_index = len(self.raw) - 1
+        for index, internal_model in enumerate(self.raw):
+            try:
+                result = internal_model[item]
+            except IndexError:
+                # Only an IndexError raised by the last model is propagated
+                if index == last_index:
+                    raise
+                continue
+            if result is not None:
+                return result
+        return None
+
+    def eval(
+        self, expression: z3.ExprRef, model_completion: bool = False
+    ) -> Union[None, z3.ExprRef]:
+        """ Evaluate the expression using this model
+
+        :param expression: The expression to evaluate
+        :param model_completion: Use the default value if the model has no interpretation of the given expression
+        :return: The evaluated expression
+        """
+        last_index = len(self.raw) - 1
+        for index, internal_model in enumerate(self.raw):
+            is_relevant_model = expression.decl() in internal_model.decls()
+            if is_relevant_model or index == last_index:
+                return internal_model.eval(expression, model_completion)
+        return None
diff --git a/mythril/laser/smt/solver.py b/mythril/laser/smt/solver.py
index 21c18195..b3754e94 100644
--- a/mythril/laser/smt/solver.py
+++ b/mythril/laser/smt/solver.py
@@ -3,6 +3,7 @@ import z3
 from typing import Union, cast, TypeVar, Generic, List, Sequence
 
 from mythril.laser.smt.expression import Expression
+from mythril.laser.smt.model import Model
 from mythril.laser.smt.bool import Bool
 
 
@@ -20,28 +21,24 @@ class BaseSolver(Generic[T]):
 
         :param timeout:
         """
-        assert timeout > 0  # timeout <= 0 isn't supported by z3
         self.raw.set(timeout=timeout)
 
-    def add(self, constraints: Union[Bool, List[Bool]]) -> None:
+    def add(self, *constraints: Bool) -> None:
         """Adds the constraints to this solver.
 
         :param constraints:
         :return:
         """
-        if not isinstance(constraints, list):
-            self.raw.add(constraints.raw)
-            return
         z3_constraints = [c.raw for c in constraints]  # type: Sequence[z3.BoolRef]
         self.raw.add(z3_constraints)
 
-    def append(self, constraints: Union[Bool, List[Bool]]) -> None:
+    def append(self, *constraints: Bool) -> None:
         """Adds the constraints to this solver.
 
         :param constraints:
         :return:
         """
-        self.add(constraints)
+        self.add(*constraints)
 
     def check(self) -> z3.CheckSatResult:
         """Returns z3 smt check result.
@@ -50,12 +47,12 @@ class BaseSolver(Generic[T]):
         """
         return self.raw.check()
 
-    def model(self) -> z3.ModelRef:
+    def model(self) -> Model:
         """Returns z3 model for a solution.
 
         :return:
         """
-        return self.raw.model()
+        return Model([self.raw.model()])
 
 
 class Solver(BaseSolver[z3.Solver]):
diff --git a/tests/laser/smt/__init__.py b/tests/laser/smt/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/laser/smt/independece_solver_test.py b/tests/laser/smt/independece_solver_test.py
new file mode 100644
index 00000000..829b4707
--- /dev/null
+++ b/tests/laser/smt/independece_solver_test.py
@@ -0,0 +1,182 @@
+from mythril.laser.smt.independence_solver import (
+    _get_expr_variables,
+    DependenceBucket,
+    DependenceMap,
+    IndependenceSolver,
+)
+from mythril.laser.smt import symbol_factory
+
+import z3
+
+
+def test_get_expr_variables():
+    # Arrange
+    x = z3.Bool("x")
+    y = z3.BitVec("y", 256)
+    z = z3.BitVec("z", 256)
+    b = z3.BitVec("b", 256)
+    expression = z3.If(x, y, z + b)
+
+    # Act
+    variables = list(map(str, _get_expr_variables(expression)))
+
+    # Assert
+    assert str(x) in variables
+    assert str(y) in variables
+    assert str(z) in variables
+    assert str(b) in variables
+
+
+def test_get_expr_variables_num():
+    # Arrange
+    b = z3.BitVec("b", 256)
+    expression = b + z3.BitVecVal(2, 256)
+
+    # Act
+    variables = _get_expr_variables(expression)
+
+    # Assert
+    assert [b] == variables
+
+
+def test_create_bucket():
+    # Arrange
+    x = z3.Bool("x")
+
+    # Act
+    bucket = DependenceBucket([x], [x])
+
+    # Assert
+    assert [x] == bucket.variables
+    assert [x] == bucket.conditions
+
+
+def test_dependence_map():
+    # Arrange
+    x = z3.BitVec("x", 256)
+    y = z3.BitVec("y", 256)
+    z = z3.BitVec("z", 256)
+    a = z3.BitVec("a", 256)
+    b = z3.BitVec("b", 256)
+
+    conditions = [x > y, y == z, a == b]
+
+    dependence_map = DependenceMap()
+
+    # Act
+    for condition in conditions:
+        dependence_map.add_condition(condition)
+
+    # Assert
+    assert 2 == len(dependence_map.buckets)
+
+    assert x in dependence_map.buckets[0].variables
+    assert y in dependence_map.buckets[0].variables
+    assert z in dependence_map.buckets[0].variables
+    assert len(set(dependence_map.buckets[0].variables)) == 3
+
+    assert conditions[0] in dependence_map.buckets[0].conditions
+    assert conditions[1] in dependence_map.buckets[0].conditions
+
+    assert a in dependence_map.buckets[1].variables
+    assert b in dependence_map.buckets[1].variables
+    assert len(set(dependence_map.buckets[1].variables)) == 2
+
+    assert conditions[2] in dependence_map.buckets[1].conditions
+
+
+def test_dependence_map_transitive_merge():
+    # Arrange
+    x = z3.BitVec("x", 256)
+    y = z3.BitVec("y", 256)
+    z = z3.BitVec("z", 256)
+    a = z3.BitVec("a", 256)
+
+    conditions = [x > y, y == z, x == a]
+
+    dependence_map = DependenceMap()
+
+    # Act
+    for condition in conditions:
+        dependence_map.add_condition(condition)
+
+    # Assert
+    assert 1 == len(dependence_map.buckets)
+    for condition in conditions:
+        assert condition in dependence_map.buckets[0].conditions
+
+
+def test_Independence_solver_unsat():
+    # Arrange
+    x = symbol_factory.BitVecSym("x", 256)
+    y = symbol_factory.BitVecSym("y", 256)
+    z = symbol_factory.BitVecSym("z", 256)
+    a = symbol_factory.BitVecSym("a", 256)
+    b = symbol_factory.BitVecSym("b", 256)
+
+    conditions = [x > y, y == z, y != z, a == b]
+
+    solver = IndependenceSolver()
+
+    # Act
+    solver.add(*conditions)
+    result = solver.check()
+
+    # Assert
+    assert z3.unsat == result
+
+
+def test_independence_solver_unsat_in_second_bucket():
+    # Arrange
+    x = symbol_factory.BitVecSym("x", 256)
+    y = symbol_factory.BitVecSym("y", 256)
+    z = symbol_factory.BitVecSym("z", 256)
+    a = symbol_factory.BitVecSym("a", 256)
+    b = symbol_factory.BitVecSym("b", 256)
+
+    conditions = [x > y, y == z, a == b, a != b]
+
+    solver = IndependenceSolver()
+
+    # Act
+    solver.add(*conditions)
+    result = solver.check()
+
+    # Assert
+    assert z3.unsat == result
+
+
+def test_independence_solver_sat():
+    # Arrange
+    x = symbol_factory.BitVecSym("x", 256)
+    y = symbol_factory.BitVecSym("y", 256)
+    z = symbol_factory.BitVecSym("z", 256)
+    a = symbol_factory.BitVecSym("a", 256)
+    b = symbol_factory.BitVecSym("b", 256)
+
+    conditions = [x > y, y == z, a == b]
+
+    solver = IndependenceSolver()
+
+    # Act
+    solver.add(*conditions)
+    result = solver.check()
+
+    # Assert
+    assert z3.sat == result
+
+
+def test_independence_solver_pop():
+    # Arrange
+    a = symbol_factory.BitVecSym("a", 256)
+    b = symbol_factory.BitVecSym("b", 256)
+
+    solver = IndependenceSolver()
+    solver.add(a == b, a != b, a != b)
+
+    # Act
+    solver.pop(2)
+    result = solver.check()
+
+    # Assert
+    assert z3.sat == result
diff --git a/tests/laser/smt/model_test.py b/tests/laser/smt/model_test.py
new file mode 100644
index 00000000..7f9ae1d3
--- /dev/null
+++ b/tests/laser/smt/model_test.py
@@ -0,0 +1,56 @@
+from mythril.laser.smt import Solver, symbol_factory
+import z3
+
+
+def test_decls():
+    # Arrange
+    solver = Solver()
+    x = symbol_factory.BitVecSym("x", 256)
+    expression = x == symbol_factory.BitVecVal(2, 256)
+
+    # Act
+    solver.add(expression)
+    result = solver.check()
+    model = solver.model()
+
+    decls = model.decls()
+
+    # Assert
+    assert z3.sat == result
+    assert x.raw.decl() in decls
+
+
+def test_get_item():
+    # Arrange
+    solver = Solver()
+    x = symbol_factory.BitVecSym("x", 256)
+    expression = x == symbol_factory.BitVecVal(2, 256)
+
+    # Act
+    solver.add(expression)
+    result = solver.check()
+    model = solver.model()
+
+    x_concrete = model[x.raw.decl()]
+
+    # Assert
+    assert z3.sat == result
+    assert 2 == x_concrete
+
+
+def test_as_long():
+    # Arrange
+    solver = Solver()
+    x = symbol_factory.BitVecSym("x", 256)
+    expression = x == symbol_factory.BitVecVal(2, 256)
+
+    # Act
+    solver.add(expression)
+    result = solver.check()
+    model = solver.model()
+
+    x_concrete = model.eval(x.raw).as_long()
+
+    # Assert
+    assert z3.sat == result
+    assert 2 == x_concrete
diff --git a/tests/laser/state/calldata_test.py b/tests/laser/state/calldata_test.py
index f32f297a..46c28585 100644
--- a/tests/laser/state/calldata_test.py
+++ b/tests/laser/state/calldata_test.py
@@ -48,7 +48,7 @@ def test_concrete_calldata_constrain_index():
     value = calldata[2]
     constraint = value == 3
 
-    solver.add([constraint])
+    solver.add(constraint)
     result = solver.check()
 
     # Assert
@@ -65,7 +65,7 @@ def test_symbolic_calldata_constrain_index():
 
     constraints = [value == 1, calldata.calldatasize == 50]
 
-    solver.add(constraints)
+    solver.add(*constraints)
 
     result = solver.check()
 