Merge pull request #1121 from ConsenSys/annotations_to_union

Refactor Bitvec annotations to sets
fix_integers
Bernhard Mueller 6 years ago committed by GitHub
commit beea142fb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      mythril/laser/ethereum/instructions.py
  2. 4
      mythril/laser/smt/__init__.py
  3. 46
      mythril/laser/smt/bitvec.py
  4. 4
      mythril/laser/smt/bitvecfunc.py
  5. 20
      mythril/laser/smt/bool.py
  6. 17
      mythril/laser/smt/expression.py

@ -4,7 +4,7 @@ import binascii
import logging import logging
from copy import copy, deepcopy from copy import copy, deepcopy
from typing import cast, Callable, List, Union, Tuple from typing import cast, Callable, List, Set, Union, Tuple, Any
from datetime import datetime from datetime import datetime
from math import ceil from math import ceil
from ethereum import utils from ethereum import utils
@ -553,7 +553,7 @@ class Instruction:
+ str(hash(simplify(exponent))) + str(hash(simplify(exponent)))
+ ")", + ")",
256, 256,
base.annotations + exponent.annotations, base.annotations.union(exponent.annotations),
) )
) # Hash is used because str(symbol) takes a long time to be converted to a string ) # Hash is used because str(symbol) takes a long time to be converted to a string
else: else:
@ -562,7 +562,7 @@ class Instruction:
symbol_factory.BitVecVal( symbol_factory.BitVecVal(
pow(base.value, exponent.value, 2 ** 256), pow(base.value, exponent.value, 2 ** 256),
256, 256,
annotations=base.annotations + exponent.annotations, annotations=base.annotations.union(exponent.annotations),
) )
) )
@ -925,11 +925,11 @@ class Instruction:
if data.symbolic: if data.symbolic:
annotations = [] annotations = set() # type: Set[Any]
for b in state.memory[index : index + length]: for b in state.memory[index : index + length]:
if isinstance(b, BitVec): if isinstance(b, BitVec):
annotations += b.annotations annotations = annotations.union(b.annotations)
argument_str = str(state.memory[index]).replace(" ", "_") argument_str = str(state.memory[index]).replace(" ", "_")
result = symbol_factory.BitVecFuncSym( result = symbol_factory.BitVecFuncSym(

@ -22,11 +22,11 @@ from mythril.laser.smt.array import K, Array, BaseArray
from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics
from mythril.laser.smt.model import Model from mythril.laser.smt.model import Model
from typing import Union, Any, Optional, List, TypeVar, Generic from typing import Union, Any, Optional, Set, TypeVar, Generic
import z3 import z3
Annotations = Optional[List[Any]] Annotations = Optional[Set[Any]]
T = TypeVar("T", bound=Union[bool.Bool, z3.BoolRef]) T = TypeVar("T", bound=Union[bool.Bool, z3.BoolRef])
U = TypeVar("U", bound=Union[BitVec, z3.BitVecRef]) U = TypeVar("U", bound=Union[BitVec, z3.BitVecRef])

@ -1,13 +1,13 @@
"""This module provides classes for an SMT abstraction of bit vectors.""" """This module provides classes for an SMT abstraction of bit vectors."""
from typing import Union, overload, List, cast, Any, Optional, Callable from typing import Union, overload, List, Set, cast, Any, Optional, Callable
from operator import lshift, rshift from operator import lshift, rshift
import z3 import z3
from mythril.laser.smt.bool import Bool, And, Or from mythril.laser.smt.bool import Bool, And, Or
from mythril.laser.smt.expression import Expression from mythril.laser.smt.expression import Expression
Annotations = List[Any] Annotations = Set[Any]
# fmt: off # fmt: off
@ -61,7 +61,7 @@ class BitVec(Expression[z3.BitVecRef]):
if isinstance(other, int): if isinstance(other, int):
return BitVec(self.raw + other, annotations=self.annotations) return BitVec(self.raw + other, annotations=self.annotations)
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
return BitVec(self.raw + other.raw, annotations=union) return BitVec(self.raw + other.raw, annotations=union)
def __sub__(self, other: Union[int, "BitVec"]) -> "BitVec": def __sub__(self, other: Union[int, "BitVec"]) -> "BitVec":
@ -75,7 +75,7 @@ class BitVec(Expression[z3.BitVecRef]):
if isinstance(other, int): if isinstance(other, int):
return BitVec(self.raw - other, annotations=self.annotations) return BitVec(self.raw - other, annotations=self.annotations)
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
return BitVec(self.raw - other.raw, annotations=union) return BitVec(self.raw - other.raw, annotations=union)
def __mul__(self, other: "BitVec") -> "BitVec": def __mul__(self, other: "BitVec") -> "BitVec":
@ -86,7 +86,7 @@ class BitVec(Expression[z3.BitVecRef]):
""" """
if isinstance(other, BitVecFunc): if isinstance(other, BitVecFunc):
return other * self return other * self
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
return BitVec(self.raw * other.raw, annotations=union) return BitVec(self.raw * other.raw, annotations=union)
def __truediv__(self, other: "BitVec") -> "BitVec": def __truediv__(self, other: "BitVec") -> "BitVec":
@ -97,7 +97,7 @@ class BitVec(Expression[z3.BitVecRef]):
""" """
if isinstance(other, BitVecFunc): if isinstance(other, BitVecFunc):
return other / self return other / self
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
return BitVec(self.raw / other.raw, annotations=union) return BitVec(self.raw / other.raw, annotations=union)
def __and__(self, other: Union[int, "BitVec"]) -> "BitVec": def __and__(self, other: Union[int, "BitVec"]) -> "BitVec":
@ -110,7 +110,7 @@ class BitVec(Expression[z3.BitVecRef]):
return other & self return other & self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
return BitVec(self.raw & other.raw, annotations=union) return BitVec(self.raw & other.raw, annotations=union)
def __or__(self, other: "BitVec") -> "BitVec": def __or__(self, other: "BitVec") -> "BitVec":
@ -121,7 +121,7 @@ class BitVec(Expression[z3.BitVecRef]):
""" """
if isinstance(other, BitVecFunc): if isinstance(other, BitVecFunc):
return other | self return other | self
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
return BitVec(self.raw | other.raw, annotations=union) return BitVec(self.raw | other.raw, annotations=union)
def __xor__(self, other: "BitVec") -> "BitVec": def __xor__(self, other: "BitVec") -> "BitVec":
@ -132,7 +132,7 @@ class BitVec(Expression[z3.BitVecRef]):
""" """
if isinstance(other, BitVecFunc): if isinstance(other, BitVecFunc):
return other ^ self return other ^ self
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
return BitVec(self.raw ^ other.raw, annotations=union) return BitVec(self.raw ^ other.raw, annotations=union)
def __lt__(self, other: "BitVec") -> Bool: def __lt__(self, other: "BitVec") -> Bool:
@ -143,7 +143,7 @@ class BitVec(Expression[z3.BitVecRef]):
""" """
if isinstance(other, BitVecFunc): if isinstance(other, BitVecFunc):
return other > self return other > self
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
return Bool(self.raw < other.raw, annotations=union) return Bool(self.raw < other.raw, annotations=union)
def __gt__(self, other: "BitVec") -> Bool: def __gt__(self, other: "BitVec") -> Bool:
@ -154,7 +154,7 @@ class BitVec(Expression[z3.BitVecRef]):
""" """
if isinstance(other, BitVecFunc): if isinstance(other, BitVecFunc):
return other < self return other < self
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
return Bool(self.raw > other.raw, annotations=union) return Bool(self.raw > other.raw, annotations=union)
def __le__(self, other: "BitVec") -> Bool: def __le__(self, other: "BitVec") -> Bool:
@ -163,7 +163,7 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
return Bool(self.raw <= other.raw, annotations=union) return Bool(self.raw <= other.raw, annotations=union)
def __ge__(self, other: "BitVec") -> Bool: def __ge__(self, other: "BitVec") -> Bool:
@ -172,7 +172,7 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
return Bool(self.raw >= other.raw, annotations=union) return Bool(self.raw >= other.raw, annotations=union)
# MYPY: fix complains about overriding __eq__ # MYPY: fix complains about overriding __eq__
@ -189,7 +189,7 @@ class BitVec(Expression[z3.BitVecRef]):
cast(z3.BoolRef, self.raw == other), annotations=self.annotations cast(z3.BoolRef, self.raw == other), annotations=self.annotations
) )
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
# MYPY: fix complaints due to z3 overriding __eq__ # MYPY: fix complaints due to z3 overriding __eq__
return Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union) return Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union)
@ -207,7 +207,7 @@ class BitVec(Expression[z3.BitVecRef]):
cast(z3.BoolRef, self.raw != other), annotations=self.annotations cast(z3.BoolRef, self.raw != other), annotations=self.annotations
) )
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
# MYPY: fix complaints due to z3 overriding __eq__ # MYPY: fix complaints due to z3 overriding __eq__
return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union) return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union)
@ -224,7 +224,7 @@ class BitVec(Expression[z3.BitVecRef]):
return BitVec( return BitVec(
operator(self.raw, other), annotations=self.annotations operator(self.raw, other), annotations=self.annotations
) )
union = self.annotations + other.annotations union = self.annotations.union(other.annotations)
return BitVec(operator(self.raw, other.raw), annotations=union) return BitVec(operator(self.raw, other.raw), annotations=union)
def __lshift__(self, other: Union[int, "BitVec"]) -> "BitVec": def __lshift__(self, other: Union[int, "BitVec"]) -> "BitVec":
@ -254,7 +254,7 @@ class BitVec(Expression[z3.BitVecRef]):
def _comparison_helper( def _comparison_helper(
a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool
) -> Bool: ) -> Bool:
annotations = a.annotations + b.annotations annotations = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc): if isinstance(a, BitVecFunc):
if not a.symbolic and not b.symbolic: if not a.symbolic and not b.symbolic:
return Bool(operation(a.raw, b.raw), annotations=annotations) return Bool(operation(a.raw, b.raw), annotations=annotations)
@ -277,7 +277,7 @@ def _comparison_helper(
def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec: def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec:
raw = operation(a.raw, b.raw) raw = operation(a.raw, b.raw)
union = a.annotations + b.annotations union = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc):
return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union)
@ -313,7 +313,7 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi
b = BitVec(z3.BitVecVal(b, 256)) b = BitVec(z3.BitVecVal(b, 256))
if not isinstance(c, BitVec): if not isinstance(c, BitVec):
c = BitVec(z3.BitVecVal(c, 256)) c = BitVec(z3.BitVecVal(c, 256))
union = a.annotations + b.annotations + c.annotations union = a.annotations.union(b.annotations).union(c.annotations)
return BitVec(z3.If(a.raw, b.raw, c.raw), union) return BitVec(z3.If(a.raw, b.raw, c.raw), union)
@ -378,10 +378,10 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
bvs = cast(List[BitVec], args) bvs = cast(List[BitVec], args)
nraw = z3.Concat([a.raw for a in bvs]) nraw = z3.Concat([a.raw for a in bvs])
annotations = [] # type: Annotations annotations = set() # type: Annotations
bitvecfunc = False bitvecfunc = False
for bv in bvs: for bv in bvs:
annotations += bv.annotations annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc): if isinstance(bv, BitVecFunc):
bitvecfunc = True bitvecfunc = True
@ -448,11 +448,11 @@ def Sum(*args: BitVec) -> BitVec:
:return: :return:
""" """
raw = z3.Sum([a.raw for a in args]) raw = z3.Sum([a.raw for a in args])
annotations = [] # type: Annotations annotations = set() # type: Annotations
bitvecfuncs = [] bitvecfuncs = []
for bv in args: for bv in args:
annotations += bv.annotations annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc): if isinstance(bv, BitVecFunc):
bitvecfuncs.append(bv) bitvecfuncs.append(bv)

@ -23,7 +23,7 @@ def _arithmetic_helper(
b = BitVec(z3.BitVecVal(b, a.size())) b = BitVec(z3.BitVecVal(b, a.size()))
raw = operation(a.raw, b.raw) raw = operation(a.raw, b.raw)
union = a.annotations + b.annotations union = a.annotations.union(b.annotations)
if isinstance(b, BitVecFunc): if isinstance(b, BitVecFunc):
# TODO: Find better value to set input and name to in this case? # TODO: Find better value to set input and name to in this case?
@ -53,7 +53,7 @@ def _comparison_helper(
if isinstance(b, int): if isinstance(b, int):
b = BitVec(z3.BitVecVal(b, a.size())) b = BitVec(z3.BitVecVal(b, a.size()))
union = a.annotations + b.annotations union = a.annotations.union(b.annotations)
if not a.symbolic and not b.symbolic: if not a.symbolic and not b.symbolic:
return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union) return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union)

@ -1,7 +1,7 @@
"""This module provides classes for an SMT abstraction of boolean """This module provides classes for an SMT abstraction of boolean
expressions.""" expressions."""
from typing import Union, cast, List from typing import Union, cast, List, Set
import z3 import z3
@ -55,7 +55,7 @@ class Bool(Expression[z3.BoolRef]):
""" """
if isinstance(other, Expression): if isinstance(other, Expression):
return Bool(cast(z3.BoolRef, self.raw == other.raw), return Bool(cast(z3.BoolRef, self.raw == other.raw),
self.annotations + other.annotations) self.annotations.union(other.annotations))
return Bool(cast(z3.BoolRef, self.raw == other), self.annotations) return Bool(cast(z3.BoolRef, self.raw == other), self.annotations)
# MYPY: complains about overloading __ne__ # noqa # MYPY: complains about overloading __ne__ # noqa
@ -67,7 +67,7 @@ class Bool(Expression[z3.BoolRef]):
""" """
if isinstance(other, Expression): if isinstance(other, Expression):
return Bool(cast(z3.BoolRef, self.raw != other.raw), return Bool(cast(z3.BoolRef, self.raw != other.raw),
self.annotations + other.annotations) self.annotations.union(other.annotations))
return Bool(cast(z3.BoolRef, self.raw != other), self.annotations) return Bool(cast(z3.BoolRef, self.raw != other), self.annotations)
def __bool__(self) -> bool: def __bool__(self) -> bool:
@ -86,17 +86,17 @@ class Bool(Expression[z3.BoolRef]):
def And(*args: Union[Bool, bool]) -> Bool: def And(*args: Union[Bool, bool]) -> Bool:
"""Create an And expression.""" """Create an And expression."""
union = [] # type: List annotations = set() # type: Set
args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args] args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args]
for arg in args_list: for arg in args_list:
union += arg.annotations annotations = annotations.union(arg.annotations)
return Bool(z3.And([a.raw for a in args_list]), union) return Bool(z3.And([a.raw for a in args_list]), annotations)
def Xor(a: Bool, b: Bool) -> Bool: def Xor(a: Bool, b: Bool) -> Bool:
"""Create an And expression.""" """Create an And expression."""
union = a.annotations + b.annotations union = a.annotations.union(b.annotations)
return Bool(z3.Xor(a.raw, b.raw), union) return Bool(z3.Xor(a.raw, b.raw), union)
@ -108,10 +108,10 @@ def Or(*args: Union[Bool, bool]) -> Bool:
:return: :return:
""" """
args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args] args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args]
union = [] # type: List annotations = set() # type: Set
for arg in args_list: for arg in args_list:
union += arg.annotations annotations = annotations.union(arg.annotations)
return Bool(z3.Or([a.raw for a in args_list]), annotations=union) return Bool(z3.Or([a.raw for a in args_list]), annotations=annotations)
def Not(a: Bool) -> Bool: def Not(a: Bool) -> Bool:

@ -1,9 +1,9 @@
"""This module contains the SMT abstraction for a basic symbol expression.""" """This module contains the SMT abstraction for a basic symbol expression."""
from typing import Optional, List, Any, TypeVar, Generic, cast from typing import Optional, Set, Any, TypeVar, Generic, cast
import z3 import z3
Annotations = List[Any] Annotations = Set[Any]
T = TypeVar("T", bound=z3.ExprRef) T = TypeVar("T", bound=z3.ExprRef)
@ -18,7 +18,11 @@ class Expression(Generic[T]):
:param annotations: :param annotations:
""" """
self.raw = raw self.raw = raw
self._annotations = annotations or []
if annotations:
assert isinstance(annotations, set)
self._annotations = annotations or set()
@property @property
def annotations(self) -> Annotations: def annotations(self) -> Annotations:
@ -26,6 +30,7 @@ class Expression(Generic[T]):
:return: :return:
""" """
return self._annotations return self._annotations
def annotate(self, annotation: Any) -> None: def annotate(self, annotation: Any) -> None:
@ -33,10 +38,8 @@ class Expression(Generic[T]):
:param annotation: :param annotation:
""" """
if isinstance(annotation, list):
self._annotations += annotation self._annotations.add(annotation)
else:
self._annotations.append(annotation)
def simplify(self) -> None: def simplify(self) -> None:
"""Simplify this expression.""" """Simplify this expression."""

Loading…
Cancel
Save