Merge pull request #902 from JoranHonig/features/independence_optimization

Independence optimization
pull/913/head
JoranHonig 6 years ago committed by GitHub
commit 990bff2a0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      mythril/laser/smt/__init__.py
  2. 150
      mythril/laser/smt/independence_solver.py
  3. 59
      mythril/laser/smt/model.py
  4. 19
      mythril/laser/smt/solver.py
  5. 0
      tests/laser/smt/__init__.py
  6. 145
      tests/laser/smt/independece_solver_test.py
  7. 56
      tests/laser/smt/model_test.py
  8. 4
      tests/laser/state/calldata_test.py

@ -18,6 +18,7 @@ from mythril.laser.smt.expression import Expression, simplify
from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And
from mythril.laser.smt.array import K, Array, BaseArray
from mythril.laser.smt.solver import Solver, Optimize
from mythril.laser.smt.model import Model
from typing import Union, Any, Optional, List, TypeVar, Generic
import z3

@ -0,0 +1,150 @@
import z3
from mythril.laser.smt.model import Model
from mythril.laser.smt.bool import Bool
from typing import Set, Tuple, Dict, List, cast
def _get_expr_variables(expression: z3.ExprRef) -> List[z3.ExprRef]:
"""
Gets the variables that make up the current expression
:param expression:
:return:
"""
result = []
if not expression.children() and not isinstance(expression, z3.BitVecNumRef):
result.append(expression)
for child in expression.children():
c_children = _get_expr_variables(child)
result.extend(c_children)
return result
class DependenceBucket:
""" Bucket object to contain a set of conditions that are dependent on each other """
def __init__(self, variables=None, conditions=None):
"""
Initializes a DependenceBucket object
:param variables: Variables contained in the conditions
:param conditions: The conditions that are dependent on each other
"""
self.variables = variables or [] # type: List[z3.ExprRef]
self.conditions = conditions or [] # type: List[z3.ExprRef]
class DependenceMap:
""" DependenceMap object that maintains a set of dependence buckets, used to separate independent smt queries """
def __init__(self):
""" Initializes a DependenceMap object """
self.buckets = [] # type: List[DependenceBucket]
self.variable_map = {} # type: Dict[str, DependenceBucket]
def add_condition(self, condition: z3.BoolRef) -> None:
"""
Add condition to the dependence map
:param condition: The condition that is to be added to the dependence map
"""
variables = set(_get_expr_variables(condition))
relevant_buckets = set()
for variable in variables:
try:
bucket = self.variable_map[str(variable)]
relevant_buckets.add(bucket)
except KeyError:
continue
new_bucket = DependenceBucket(variables, [condition])
self.buckets.append(new_bucket)
if relevant_buckets:
# Merge buckets, and rewrite variable map accordingly
relevant_buckets.add(new_bucket)
new_bucket = self._merge_buckets(relevant_buckets)
for variable in variables:
self.variable_map[str(variable)] = new_bucket
def _merge_buckets(self, bucket_list: Set[DependenceBucket]) -> DependenceBucket:
""" Merges the buckets in bucket list """
variables = [] # type: List[str]
conditions = [] # type: List[z3.BoolRef]
for bucket in bucket_list:
self.buckets.remove(bucket)
variables += bucket.variables
conditions += bucket.conditions
new_bucket = DependenceBucket(variables, conditions)
self.buckets.append(new_bucket)
return new_bucket
class IndependenceSolver:
"""An SMT solver object that uses independence optimization"""
def __init__(self):
""""""
self.raw = z3.Solver()
self.constraints = []
self.models = []
def set_timeout(self, timeout: int) -> None:
"""Sets the timeout that will be used by this solver, timeout is in
milliseconds.
:param timeout:
"""
self.raw.set(timeout=timeout)
def add(self, *constraints: Tuple[Bool]) -> None:
"""Adds the constraints to this solver.
:param constraints: constraints to add
"""
raw_constraints = [
c.raw for c in cast(Tuple[Bool], constraints)
] # type: List[z3.BoolRef]
self.constraints.extend(raw_constraints)
def append(self, *constraints: Tuple[Bool]) -> None:
"""Adds the constraints to this solver.
:param constraints: constraints to add
"""
raw_constraints = [
c.raw for c in cast(Tuple[Bool], constraints)
] # type: List[z3.BoolRef]
self.constraints.extend(raw_constraints)
def check(self) -> z3.CheckSatResult:
"""Returns z3 smt check result. """
dependence_map = DependenceMap()
for constraint in self.constraints:
dependence_map.add_condition(constraint)
self.models = []
for bucket in dependence_map.buckets:
self.raw.reset()
self.raw.append(*bucket.conditions)
check_result = self.raw.check()
if check_result == z3.sat:
self.models.append(self.raw.model())
else:
return check_result
return z3.sat
def model(self) -> Model:
"""Returns z3 model for a solution. """
return Model(self.models)
def reset(self) -> None:
"""Reset this solver."""
self.constraints = []
def pop(self, num) -> None:
"""Pop num constraints from this solver.
"""
self.constraints.pop(num)

@ -0,0 +1,59 @@
import z3
from typing import Union, List
class Model:
""" The model class wraps a z3 model
This implementation allows for multiple internal models, this is required for the use of an independence solver
which has models for multiple queries which need an uniform output.
"""
def __init__(self, models: List[z3.ModelRef] = None):
"""
Initialize a model object
:param models: the internal z3 models that this model should use
"""
self.raw = models or []
def decls(self) -> List[z3.ExprRef]:
"""Get the declarations for this model"""
result = [] # type: List[z3.ExprRef]
for internal_model in self.raw:
result.extend(internal_model.decls())
return result
def __getitem__(self, item) -> Union[None, z3.ExprRef]:
""" Get declaration from model
If item is an int, then the declaration at offset item is returned
If item is a declaration, then the interpretation is returned
"""
for internal_model in self.raw:
is_last_model = self.raw.index(internal_model) == len(self.raw) - 1
try:
result = internal_model[item]
if result is not None:
return result
except IndexError:
if is_last_model:
raise
continue
return None
def eval(
self, expression: z3.ExprRef, model_completion: bool = False
) -> Union[None, z3.ExprRef]:
""" Evaluate the expression using this model
:param expression: The expression to evaluate
:param model_completion: Use the default value if the model has no interpretation of the given expression
:return: The evaluated expression
"""
for internal_model in self.raw:
is_last_model = self.raw.index(internal_model) == len(self.raw) - 1
is_relevant_model = expression.decl() in list(internal_model.decls())
if is_relevant_model or is_last_model:
return internal_model.eval(expression, model_completion)
return None

@ -3,6 +3,7 @@ import z3
from typing import Union, cast, TypeVar, Generic, List, Sequence
from mythril.laser.smt.expression import Expression
from mythril.laser.smt.model import Model
from mythril.laser.smt.bool import Bool
@ -20,28 +21,26 @@ class BaseSolver(Generic[T]):
:param timeout:
"""
assert timeout > 0 # timeout <= 0 isn't supported by z3
self.raw.set(timeout=timeout)
def add(self, constraints: Union[Bool, List[Bool]]) -> None:
def add(self, *constraints: List[Bool]) -> None:
"""Adds the constraints to this solver.
:param constraints:
:return:
"""
if not isinstance(constraints, list):
self.raw.add(constraints.raw)
return
z3_constraints = [c.raw for c in constraints] # type: Sequence[z3.BoolRef]
z3_constraints = [
c.raw for c in cast(List[Bool], constraints)
] # type: Sequence[z3.BoolRef]
self.raw.add(z3_constraints)
def append(self, constraints: Union[Bool, List[Bool]]) -> None:
def append(self, *constraints: List[Bool]) -> None:
"""Adds the constraints to this solver.
:param constraints:
:return:
"""
self.add(constraints)
self.add(*constraints)
def check(self) -> z3.CheckSatResult:
"""Returns z3 smt check result.
@ -50,12 +49,12 @@ class BaseSolver(Generic[T]):
"""
return self.raw.check()
def model(self) -> z3.ModelRef:
def model(self) -> Model:
"""Returns z3 model for a solution.
:return:
"""
return self.raw.model()
return Model([self.raw.model()])
class Solver(BaseSolver[z3.Solver]):

@ -0,0 +1,145 @@
from mythril.laser.smt.independence_solver import (
_get_expr_variables,
DependenceBucket,
DependenceMap,
IndependenceSolver,
)
from mythril.laser.smt import symbol_factory
import z3
def test_get_expr_variables():
# Arrange
x = z3.Bool("x")
y = z3.BitVec("y", 256)
z = z3.BitVec("z", 256)
b = z3.BitVec("b", 256)
expression = z3.If(x, y, z + b)
# Act
variables = list(map(str, _get_expr_variables(expression)))
# Assert
assert str(x) in variables
assert str(y) in variables
assert str(z) in variables
assert str(b) in variables
def test_get_expr_variables_num():
# Arrange
b = z3.BitVec("b", 256)
expression = b + z3.BitVecVal(2, 256)
# Act
variables = _get_expr_variables(expression)
# Assert
assert [b] == variables
def test_create_bucket():
# Arrange
x = z3.Bool("x")
# Act
bucket = DependenceBucket([x], [x])
# Assert
assert [x] == bucket.variables
assert [x] == bucket.conditions
def test_dependence_map():
# Arrange
x = z3.BitVec("x", 256)
y = z3.BitVec("y", 256)
z = z3.BitVec("z", 256)
a = z3.BitVec("a", 256)
b = z3.BitVec("b", 256)
conditions = [x > y, y == z, a == b]
dependence_map = DependenceMap()
# Act
for condition in conditions:
dependence_map.add_condition(condition)
# Assert
assert 2 == len(dependence_map.buckets)
assert x in dependence_map.buckets[0].variables
assert y in dependence_map.buckets[0].variables
assert z in dependence_map.buckets[0].variables
assert len(set(dependence_map.buckets[0].variables)) == 3
assert conditions[0] in dependence_map.buckets[0].conditions
assert conditions[1] in dependence_map.buckets[0].conditions
assert a in dependence_map.buckets[1].variables
assert b in dependence_map.buckets[1].variables
assert len(set(dependence_map.buckets[1].variables)) == 2
assert conditions[2] in dependence_map.buckets[1].conditions
def test_Independence_solver_unsat():
# Arrange
x = symbol_factory.BitVecSym("x", 256)
y = symbol_factory.BitVecSym("y", 256)
z = symbol_factory.BitVecSym("z", 256)
a = symbol_factory.BitVecSym("a", 256)
b = symbol_factory.BitVecSym("b", 256)
conditions = [x > y, y == z, y != z, a == b]
solver = IndependenceSolver()
# Act
solver.add(*conditions)
result = solver.check()
# Assert
assert z3.unsat == result
def test_independence_solver_unsat_in_second_bucket():
# Arrange
x = symbol_factory.BitVecSym("x", 256)
y = symbol_factory.BitVecSym("y", 256)
z = symbol_factory.BitVecSym("z", 256)
a = symbol_factory.BitVecSym("a", 256)
b = symbol_factory.BitVecSym("b", 256)
conditions = [x > y, y == z, a == b, a != b]
solver = IndependenceSolver()
# Act
solver.add(*conditions)
result = solver.check()
# Assert
assert z3.unsat == result
def test_independence_solver_sat():
# Arrange
x = symbol_factory.BitVecSym("x", 256)
y = symbol_factory.BitVecSym("y", 256)
z = symbol_factory.BitVecSym("z", 256)
a = symbol_factory.BitVecSym("a", 256)
b = symbol_factory.BitVecSym("b", 256)
conditions = [x > y, y == z, a == b]
solver = IndependenceSolver()
# Act
solver.add(*conditions)
result = solver.check()
# Assert
assert z3.sat == result

@ -0,0 +1,56 @@
from mythril.laser.smt import Solver, symbol_factory
import z3
def test_decls():
# Arrange
solver = Solver()
x = symbol_factory.BitVecSym("x", 256)
expression = x == symbol_factory.BitVecVal(2, 256)
# Act
solver.add(expression)
result = solver.check()
model = solver.model()
decls = model.decls()
# Assert
assert z3.sat == result
assert x.raw.decl() in decls
def test_get_item():
# Arrange
solver = Solver()
x = symbol_factory.BitVecSym("x", 256)
expression = x == symbol_factory.BitVecVal(2, 256)
# Act
solver.add(expression)
result = solver.check()
model = solver.model()
x_concrete = model[x.raw.decl()]
# Assert
assert z3.sat == result
assert 2 == x_concrete
def test_as_long():
# Arrange
solver = Solver()
x = symbol_factory.BitVecSym("x", 256)
expression = x == symbol_factory.BitVecVal(2, 256)
# Act
solver.add(expression)
result = solver.check()
model = solver.model()
x_concrete = model.eval(x.raw).as_long()
# Assert
assert z3.sat == result
assert 2 == x_concrete

@ -48,7 +48,7 @@ def test_concrete_calldata_constrain_index():
value = calldata[2]
constraint = value == 3
solver.add([constraint])
solver.add(constraint)
result = solver.check()
# Assert
@ -65,7 +65,7 @@ def test_symbolic_calldata_constrain_index():
constraints = [value == 1, calldata.calldatasize == 50]
solver.add(constraints)
solver.add(*constraints)
result = solver.check()

Loading…
Cancel
Save