diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index 86ded2ed..cc506ae1 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -18,7 +18,6 @@ from mythril.laser.smt.bitvec_helper import ( LShR, ) -from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.array import K, Array, BaseArray diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index b308e863..22acc1c3 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -66,8 +66,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other + self if isinstance(other, int): return BitVec(self.raw + other, annotations=self.annotations) @@ -80,8 +78,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other - self if isinstance(other, int): return BitVec(self.raw - other, annotations=self.annotations) @@ -94,8 +90,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other * self union = self.annotations.union(other.annotations) return BitVec(self.raw * other.raw, annotations=union) @@ -105,8 +99,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other / self union = self.annotations.union(other.annotations) return BitVec(self.raw / other.raw, annotations=union) @@ -116,8 +108,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other & self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -129,8 +119,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other | self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -142,8 +130,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other ^ self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -155,8 +141,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other > self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -168,8 +152,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other < self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -204,8 +186,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other == self if not isinstance(other, BitVec): return Bool( cast(z3.BoolRef, self.raw == other), annotations=self.annotations @@ -224,8 +204,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other != self if not isinstance(other, BitVec): return Bool( cast(z3.BoolRef, self.raw != other), annotations=self.annotations @@ -244,8 +222,6 @@ class BitVec(Expression[z3.BitVecRef]): :param operator: The shift operator :return: the resulting output """ - if isinstance(other, BitVecFunc): - return operator(other, self) if not isinstance(other, BitVec): return BitVec( operator(self.raw, other), annotations=self.annotations @@ -275,7 +251,3 @@ class BitVec(Expression[z3.BitVecRef]): :return: """ return self.raw.__hash__() - - -# TODO: Fix circular import issues -from mythril.laser.smt.bitvecfunc import BitVecFunc diff --git a/mythril/laser/smt/bitvec_helper.py b/mythril/laser/smt/bitvec_helper.py index c1f60607..e2d2c54d 100644 --- a/mythril/laser/smt/bitvec_helper.py +++ b/mythril/laser/smt/bitvec_helper.py @@ -3,9 +3,6 @@ import z3 from mythril.laser.smt.bool import Bool, Or from mythril.laser.smt.bitvec import BitVec -from mythril.laser.smt.bitvecfunc import BitVecFunc -from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper -from mythril.laser.smt.bitvecfunc import _comparison_helper as _func_comparison_helper Annotations = Set[Any] @@ -14,20 +11,12 @@ def _comparison_helper( a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool ) -> Bool: annotations = a.annotations.union(b.annotations) - if isinstance(a, BitVecFunc): - return _func_comparison_helper(a, b, operation, default_value, inputs_equal) return Bool(operation(a.raw, b.raw), annotations) def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec: raw = operation(a.raw, b.raw) union = a.annotations.union(b.annotations) - - if isinstance(a, BitVecFunc): - return _func_arithmetic_helper(a, b, operation) - elif isinstance(b, BitVecFunc): - return _func_arithmetic_helper(b, a, operation) - return BitVec(raw, annotations=union) @@ -43,8 +32,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi :param c: :return: """ - # TODO: Handle BitVecFunc - if not isinstance(a, Bool): a = Bool(z3.BoolVal(a)) if not isinstance(b, BitVec): @@ -52,19 +39,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi if not isinstance(c, BitVec): c = BitVec(z3.BitVecVal(c, 256)) union = a.annotations.union(b.annotations).union(c.annotations) - - bvf = [] # type: List[BitVecFunc] - if isinstance(a, BitVecFunc): - bvf += [a] - if isinstance(b, BitVecFunc): - bvf += [b] - if isinstance(c, BitVecFunc): - bvf += [c] - if bvf: - raw = z3.If(a.raw, b.raw, c.raw) - nested_functions = [nf for func in bvf for nf in func.nested_functions] + bvf - return BitVecFunc(raw, func_name="Hybrid", nested_functions=nested_functions) - return BitVec(z3.If(a.raw, b.raw, c.raw), union) @@ -133,21 +107,8 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: nraw = z3.Concat([a.raw for a in bvs]) annotations = set() # type: Annotations - nested_functions = [] # type: List[BitVecFunc] for bv in bvs: annotations = annotations.union(bv.annotations) - if isinstance(bv, BitVecFunc): - nested_functions += bv.nested_functions - nested_functions += [bv] - - if nested_functions: - return BitVecFunc( - raw=nraw, - func_name="Hybrid", - input_=BitVec(z3.BitVec("", 256), annotations=annotations), - nested_functions=nested_functions, - ) - return BitVec(nraw, annotations) @@ -160,16 +121,6 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: :return: """ raw = z3.Extract(high, low, bv.raw) - if isinstance(bv, BitVecFunc): - input_string = "" - # Is there a better value to set func_name and input to in this case? - return BitVecFunc( - raw=raw, - func_name="Hybrid", - input_=BitVec(z3.BitVec(input_string, 256), annotations=bv.annotations), - nested_functions=bv.nested_functions + [bv], - ) - return BitVec(raw, annotations=bv.annotations) @@ -210,34 +161,9 @@ def Sum(*args: BitVec) -> BitVec: """ raw = z3.Sum([a.raw for a in args]) annotations = set() # type: Annotations - bitvecfuncs = [] for bv in args: annotations = annotations.union(bv.annotations) - if isinstance(bv, BitVecFunc): - bitvecfuncs.append(bv) - - nested_functions = [ - nf for func in bitvecfuncs for nf in func.nested_functions - ] + bitvecfuncs - - if len(bitvecfuncs) >= 2: - return BitVecFunc( - raw=raw, - func_name="Hybrid", - input_=None, - annotations=annotations, - nested_functions=nested_functions, - ) - elif len(bitvecfuncs) == 1: - return BitVecFunc( - raw=raw, - func_name=bitvecfuncs[0].func_name, - input_=bitvecfuncs[0].input_, - annotations=annotations, - nested_functions=nested_functions, - ) - return BitVec(raw, annotations) diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py deleted file mode 100644 index e5bdfec4..00000000 --- a/mythril/laser/smt/bitvecfunc.py +++ /dev/null @@ -1,297 +0,0 @@ -import operator -from itertools import product -from typing import Optional, Union, cast, Callable, List -import z3 - -from mythril.laser.smt.bitvec import BitVec, Annotations, _padded_operation -from mythril.laser.smt.bool import Or, Bool, And - - -def _arithmetic_helper( - a: "BitVecFunc", b: Union[BitVec, int], operation: Callable -) -> "BitVecFunc": - """ - Helper function for arithmetic operations on BitVecFuncs. - - :param a: The BitVecFunc to perform the operation on. - :param b: A BitVec or int to perform the operation on. - :param operation: The arithmetic operation to perform. - :return: The resulting BitVecFunc - """ - if isinstance(b, int): - b = BitVec(z3.BitVecVal(b, a.size())) - - raw = operation(a.raw, b.raw) - union = a.annotations.union(b.annotations) - - if isinstance(b, BitVecFunc): - return BitVecFunc( - raw=raw, - func_name="Hybrid", - input_=BitVec(z3.BitVec("", 256), annotations=union), - nested_functions=a.nested_functions + b.nested_functions + [a, b], - ) - - return BitVecFunc( - raw=raw, - func_name=a.func_name, - input_=a.input_, - annotations=union, - nested_functions=a.nested_functions + [a], - ) - - -def _comparison_helper( - a: "BitVecFunc", - b: Union[BitVec, int], - operation: Callable, - default_value: bool, - inputs_equal: bool, -) -> Bool: - """ - Helper function for comparison operations with BitVecFuncs. - - :param a: The BitVecFunc to compare. - :param b: A BitVec or int to compare to. - :param operation: The comparison operation to perform. - :return: The resulting Bool - """ - # Is there some hack for gt/lt comparisons? - if isinstance(b, int): - b = BitVec(z3.BitVecVal(b, a.size())) - union = a.annotations.union(b.annotations) - - if not a.symbolic and not b.symbolic: - if operation == z3.UGT: - operation = operator.gt - if operation == z3.ULT: - operation = operator.lt - return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union) - if ( - not isinstance(b, BitVecFunc) - or not a.func_name - or not a.input_ - or not a.func_name == b.func_name - or str(operation) not in ("", "") - ): - return Bool(z3.BoolVal(default_value), annotations=union) - - condition = True - for a_nest, b_nest in product(a.nested_functions, b.nested_functions): - if a_nest.func_name != b_nest.func_name: - continue - if a_nest.func_name == "Hybrid": - continue - # a.input (eq/neq) b.input ==> a == b - if inputs_equal: - condition = z3.And( - condition, - z3.Or( - z3.Not((a_nest.input_ == b_nest.input_).raw), - (a_nest.raw == b_nest.raw), - ), - z3.Or( - z3.Not((a_nest.raw == b_nest.raw)), - (a_nest.input_ == b_nest.input_).raw, - ), - ) - else: - condition = z3.And( - condition, - z3.Or( - z3.Not((a_nest.input_ != b_nest.input_).raw), - (a_nest.raw == b_nest.raw), - ), - z3.Or( - z3.Not((a_nest.raw == b_nest.raw)), - (a_nest.input_ != b_nest.input_).raw, - ), - ) - - return And( - Bool( - cast(z3.BoolRef, _padded_operation(a.raw, b.raw, operation)), - annotations=union, - ), - Bool(condition) if b.nested_functions else Bool(True), - a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, - ) - - -class BitVecFunc(BitVec): - """A bit vector function symbol. Used in place of functions like sha3.""" - - def __init__( - self, - raw: z3.BitVecRef, - func_name: Optional[str], - input_: "BitVec" = None, - annotations: Optional[Annotations] = None, - nested_functions: Optional[List["BitVecFunc"]] = None, - ): - """ - - :param raw: The raw bit vector symbol - :param func_name: The function name. e.g. sha3 - :param input: The input to the functions - :param annotations: The annotations the BitVecFunc should start with - """ - - self.func_name = func_name - self.input_ = input_ - self.nested_functions = nested_functions or [] - self.nested_functions = list(dict.fromkeys(self.nested_functions)) - if isinstance(input_, BitVecFunc): - self.nested_functions.extend(input_.nested_functions) - super().__init__(raw, annotations) - - def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create an addition expression. - - :param other: The int or BitVec to add to this BitVecFunc - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.add) - - def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create a subtraction expression. - - :param other: The int or BitVec to subtract from this BitVecFunc - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.sub) - - def __mul__(self, other: "BitVec") -> "BitVecFunc": - """Create a multiplication expression. - - :param other: The int or BitVec to multiply to this BitVecFunc - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.mul) - - def __truediv__(self, other: "BitVec") -> "BitVecFunc": - """Create a signed division expression. - - :param other: The int or BitVec to divide this BitVecFunc by - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.truediv) - - def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create an and expression. - - :param other: The int or BitVec to and with this BitVecFunc - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.and_) - - def __or__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create an or expression. - - :param other: The int or BitVec to or with this BitVecFunc - :return: The resulting BitVecFunc - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _arithmetic_helper(self, other, operator.or_) - - def __xor__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create a xor expression. - - :param other: The int or BitVec to xor with this BitVecFunc - :return: The resulting BitVecFunc - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _arithmetic_helper(self, other, operator.xor) - - def __lt__(self, other: Union[int, "BitVec"]) -> Bool: - """Create a signed less than expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _comparison_helper( - self, other, operator.lt, default_value=False, inputs_equal=False - ) - - def __gt__(self, other: Union[int, "BitVec"]) -> Bool: - """Create a signed greater than expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _comparison_helper( - self, other, operator.gt, default_value=False, inputs_equal=False - ) - - def __le__(self, other: Union[int, "BitVec"]) -> Bool: - """Create a signed less than or equal to expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return Or(self < other, self == other) - - def __ge__(self, other: Union[int, "BitVec"]) -> Bool: - """Create a signed greater than or equal to expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return Or(self > other, self == other) - - # MYPY: fix complains about overriding __eq__ - def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore - """Create an equality expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - return _comparison_helper( - self, other, operator.eq, default_value=False, inputs_equal=True - ) - - # MYPY: fix complains about overriding __ne__ - def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore - """Create an inequality expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _comparison_helper( - self, other, operator.ne, default_value=True, inputs_equal=False - ) - - def __lshift__(self, other: Union[int, "BitVec"]) -> "BitVec": - """ - Left shift operation - :param other: The int or BitVec to shift on - :return The resulting left shifted output - """ - return _arithmetic_helper(self, other, operator.lshift) - - def __rshift__(self, other: Union[int, "BitVec"]) -> "BitVec": - """ - Right shift operation - :param other: The int or BitVec to shift on - :return The resulting right shifted output: - """ - return _arithmetic_helper(self, other, operator.rshift) - - def __hash__(self) -> int: - return self.raw.__hash__()