From 634d59caa537d2433aceb68360b890fe2023adab Mon Sep 17 00:00:00 2001 From: Nikhil Parasaram Date: Wed, 13 Nov 2019 13:56:22 +0000 Subject: [PATCH] Remove unused code (#1263) * Fix a regression * Handle new black format * Bye Bye BitVecFuncs * Bye Bye bitvecfunc tests * Bye Bye complex code --- mythril/laser/ethereum/state/account.py | 58 +---- mythril/laser/smt/__init__.py | 63 ----- mythril/laser/smt/bitvec.py | 28 --- mythril/laser/smt/bitvec_helper.py | 82 +------ mythril/laser/smt/bitvecfunc.py | 297 ------------------------ tests/laser/smt/bitvecfunc_test.py | 237 ------------------- 6 files changed, 8 insertions(+), 757 deletions(-) delete mode 100644 mythril/laser/smt/bitvecfunc.py delete mode 100644 tests/laser/smt/bitvecfunc_test.py diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index 6fee240e..3d9e52ef 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -4,7 +4,7 @@ This includes classes representing accounts and their storage. """ import logging from copy import copy, deepcopy -from typing import Any, Dict, Union, Tuple, Set, cast +from typing import Any, Dict, Union, Set from mythril.laser.smt import ( @@ -12,10 +12,7 @@ from mythril.laser.smt import ( K, BitVec, simplify, - BitVecFunc, - Extract, BaseArray, - Concat, ) from mythril.disassembler.disassembly import Disassembly from mythril.laser.smt import symbol_factory @@ -23,26 +20,6 @@ from mythril.laser.smt import symbol_factory log = logging.getLogger(__name__) -class StorageRegion: - def __getitem__(self, item): - raise NotImplementedError - - def __setitem__(self, key, value): - raise NotImplementedError - - -class ArrayStorageRegion(StorageRegion): - """ An ArrayStorageRegion is a storage region that leverages smt array theory to resolve expressions""" - - pass - - -class IteStorageRegion(StorageRegion): - """ An IteStorageRegion is a storage region that uses Ite statements to implement a storage""" - - pass - - class Storage: """Storage class represents the storage of an Account.""" @@ -62,18 +39,8 @@ class Storage: self.storage_keys_loaded = set() # type: Set[int] self.address = address - @staticmethod - def _sanitize(input_: BitVec) -> BitVec: - if input_.size() == 512: - return input_ - if input_.size() > 512: - return Extract(511, 0, input_) - else: - return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_) - def __getitem__(self, item: BitVec) -> BitVec: storage = self._standard_storage - sanitized_item = item if ( self.address and self.address.value != 0 @@ -82,7 +49,7 @@ class Storage: and (self.dynld and self.dynld.storage_loading) ): try: - storage[sanitized_item] = symbol_factory.BitVecVal( + storage[item] = symbol_factory.BitVecVal( int( self.dynld.read_storage( contract_address="0x{:040X}".format(self.address.value), @@ -93,29 +60,14 @@ class Storage: 256, ) self.storage_keys_loaded.add(int(item.value)) - self.printable_storage[item] = storage[sanitized_item] + self.printable_storage[item] = storage[item] except ValueError as e: log.debug("Couldn't read storage at %s: %s", item, e) - return simplify(storage[sanitized_item]) - - @staticmethod - def get_map_index(key: BitVec) -> BitVec: - if ( - not isinstance(key, BitVecFunc) - or key.func_name != "keccak256" - or key.input_ is None - ): - return None - index = Extract(255, 0, key.input_) - return simplify(index) - - def _get_corresponding_storage(self, key: BitVec) -> BaseArray: - return self._standard_storage + return simplify(storage[item]) def __setitem__(self, key, value: Any) -> None: - storage = self._get_corresponding_storage(key) self.printable_storage[key] = value - storage[key] = value + self._standard_storage[key] = value if key.symbolic is False: self.storage_keys_loaded.add(int(key.value)) diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index 86ded2ed..4dba9e6f 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -18,7 +18,6 @@ from mythril.laser.smt.bitvec_helper import ( LShR, ) -from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.array import K, Array, BaseArray @@ -80,44 +79,6 @@ class SymbolFactory(Generic[T, U]): """ raise NotImplementedError() - @staticmethod - def BitVecFuncVal( - value: int, - func_name: str, - size: int, - annotations: Annotations = None, - input_: "BitVec" = None, - ) -> BitVecFunc: - """Creates a new bit vector function with a symbolic value. - - :param value: The concrete value to set the bit vector to - :param func_name: The name of the bit vector function - :param size: The size of the bit vector - :param annotations: The annotations to initialize the bit vector with - :param input_: The input to the bit vector function - :return: The freshly created bit vector function - """ - raise NotImplementedError() - - @staticmethod - def BitVecFuncSym( - name: str, - func_name: str, - size: int, - annotations: Annotations = None, - input_: "BitVec" = None, - ) -> BitVecFunc: - """Creates a new bit vector function with a symbolic value. - - :param name: The name of the symbolic bit vector - :param func_name: The name of the bit vector function - :param size: The size of the bit vector - :param annotations: The annotations to initialize the bit vector with - :param input_: The input to the bit vector function - :return: The freshly created bit vector function - """ - raise NotImplementedError() - class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]): """ @@ -159,30 +120,6 @@ class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]): raw = z3.BitVec(name, size) return BitVec(raw, annotations) - @staticmethod - def BitVecFuncVal( - value: int, - func_name: str, - size: int, - annotations: Annotations = None, - input_: "BitVec" = None, - ) -> BitVecFunc: - """Creates a new bit vector function with a concrete value.""" - raw = z3.BitVecVal(value, size) - return BitVecFunc(raw, func_name, input_, annotations) - - @staticmethod - def BitVecFuncSym( - name: str, - func_name: str, - size: int, - annotations: Annotations = None, - input_: "BitVec" = None, - ) -> BitVecFunc: - """Creates a new bit vector function with a symbolic value.""" - raw = z3.BitVec(name, size) - return BitVecFunc(raw, func_name, input_, annotations) - class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]): """ diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index b308e863..22acc1c3 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -66,8 +66,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other + self if isinstance(other, int): return BitVec(self.raw + other, annotations=self.annotations) @@ -80,8 +78,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other - self if isinstance(other, int): return BitVec(self.raw - other, annotations=self.annotations) @@ -94,8 +90,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other * self union = self.annotations.union(other.annotations) return BitVec(self.raw * other.raw, annotations=union) @@ -105,8 +99,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other / self union = self.annotations.union(other.annotations) return BitVec(self.raw / other.raw, annotations=union) @@ -116,8 +108,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other & self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -129,8 +119,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other | self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -142,8 +130,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other ^ self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -155,8 +141,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other > self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -168,8 +152,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other < self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -204,8 +186,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other == self if not isinstance(other, BitVec): return Bool( cast(z3.BoolRef, self.raw == other), annotations=self.annotations @@ -224,8 +204,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other != self if not isinstance(other, BitVec): return Bool( cast(z3.BoolRef, self.raw != other), annotations=self.annotations @@ -244,8 +222,6 @@ class BitVec(Expression[z3.BitVecRef]): :param operator: The shift operator :return: the resulting output """ - if isinstance(other, BitVecFunc): - return operator(other, self) if not isinstance(other, BitVec): return BitVec( operator(self.raw, other), annotations=self.annotations @@ -275,7 +251,3 @@ class BitVec(Expression[z3.BitVecRef]): :return: """ return self.raw.__hash__() - - -# TODO: Fix circular import issues -from mythril.laser.smt.bitvecfunc import BitVecFunc diff --git a/mythril/laser/smt/bitvec_helper.py b/mythril/laser/smt/bitvec_helper.py index c1f60607..a8fd8e8d 100644 --- a/mythril/laser/smt/bitvec_helper.py +++ b/mythril/laser/smt/bitvec_helper.py @@ -3,31 +3,18 @@ import z3 from mythril.laser.smt.bool import Bool, Or from mythril.laser.smt.bitvec import BitVec -from mythril.laser.smt.bitvecfunc import BitVecFunc -from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper -from mythril.laser.smt.bitvecfunc import _comparison_helper as _func_comparison_helper Annotations = Set[Any] -def _comparison_helper( - a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool -) -> Bool: +def _comparison_helper(a: BitVec, b: BitVec, operation: Callable) -> Bool: annotations = a.annotations.union(b.annotations) - if isinstance(a, BitVecFunc): - return _func_comparison_helper(a, b, operation, default_value, inputs_equal) return Bool(operation(a.raw, b.raw), annotations) def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec: raw = operation(a.raw, b.raw) union = a.annotations.union(b.annotations) - - if isinstance(a, BitVecFunc): - return _func_arithmetic_helper(a, b, operation) - elif isinstance(b, BitVecFunc): - return _func_arithmetic_helper(b, a, operation) - return BitVec(raw, annotations=union) @@ -43,8 +30,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi :param c: :return: """ - # TODO: Handle BitVecFunc - if not isinstance(a, Bool): a = Bool(z3.BoolVal(a)) if not isinstance(b, BitVec): @@ -52,19 +37,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi if not isinstance(c, BitVec): c = BitVec(z3.BitVecVal(c, 256)) union = a.annotations.union(b.annotations).union(c.annotations) - - bvf = [] # type: List[BitVecFunc] - if isinstance(a, BitVecFunc): - bvf += [a] - if isinstance(b, BitVecFunc): - bvf += [b] - if isinstance(c, BitVecFunc): - bvf += [c] - if bvf: - raw = z3.If(a.raw, b.raw, c.raw) - nested_functions = [nf for func in bvf for nf in func.nested_functions] + bvf - return BitVecFunc(raw, func_name="Hybrid", nested_functions=nested_functions) - return BitVec(z3.If(a.raw, b.raw, c.raw), union) @@ -75,7 +47,7 @@ def UGT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False) + return _comparison_helper(a, b, z3.UGT) def UGE(a: BitVec, b: BitVec) -> Bool: @@ -95,7 +67,7 @@ def ULT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False) + return _comparison_helper(a, b, z3.ULT) def ULE(a: BitVec, b: BitVec) -> Bool: @@ -133,21 +105,8 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: nraw = z3.Concat([a.raw for a in bvs]) annotations = set() # type: Annotations - nested_functions = [] # type: List[BitVecFunc] for bv in bvs: annotations = annotations.union(bv.annotations) - if isinstance(bv, BitVecFunc): - nested_functions += bv.nested_functions - nested_functions += [bv] - - if nested_functions: - return BitVecFunc( - raw=nraw, - func_name="Hybrid", - input_=BitVec(z3.BitVec("", 256), annotations=annotations), - nested_functions=nested_functions, - ) - return BitVec(nraw, annotations) @@ -160,16 +119,6 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: :return: """ raw = z3.Extract(high, low, bv.raw) - if isinstance(bv, BitVecFunc): - input_string = "" - # Is there a better value to set func_name and input to in this case? - return BitVecFunc( - raw=raw, - func_name="Hybrid", - input_=BitVec(z3.BitVec(input_string, 256), annotations=bv.annotations), - nested_functions=bv.nested_functions + [bv], - ) - return BitVec(raw, annotations=bv.annotations) @@ -210,34 +159,9 @@ def Sum(*args: BitVec) -> BitVec: """ raw = z3.Sum([a.raw for a in args]) annotations = set() # type: Annotations - bitvecfuncs = [] for bv in args: annotations = annotations.union(bv.annotations) - if isinstance(bv, BitVecFunc): - bitvecfuncs.append(bv) - - nested_functions = [ - nf for func in bitvecfuncs for nf in func.nested_functions - ] + bitvecfuncs - - if len(bitvecfuncs) >= 2: - return BitVecFunc( - raw=raw, - func_name="Hybrid", - input_=None, - annotations=annotations, - nested_functions=nested_functions, - ) - elif len(bitvecfuncs) == 1: - return BitVecFunc( - raw=raw, - func_name=bitvecfuncs[0].func_name, - input_=bitvecfuncs[0].input_, - annotations=annotations, - nested_functions=nested_functions, - ) - return BitVec(raw, annotations) diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py deleted file mode 100644 index e5bdfec4..00000000 --- a/mythril/laser/smt/bitvecfunc.py +++ /dev/null @@ -1,297 +0,0 @@ -import operator -from itertools import product -from typing import Optional, Union, cast, Callable, List -import z3 - -from mythril.laser.smt.bitvec import BitVec, Annotations, _padded_operation -from mythril.laser.smt.bool import Or, Bool, And - - -def _arithmetic_helper( - a: "BitVecFunc", b: Union[BitVec, int], operation: Callable -) -> "BitVecFunc": - """ - Helper function for arithmetic operations on BitVecFuncs. - - :param a: The BitVecFunc to perform the operation on. - :param b: A BitVec or int to perform the operation on. - :param operation: The arithmetic operation to perform. - :return: The resulting BitVecFunc - """ - if isinstance(b, int): - b = BitVec(z3.BitVecVal(b, a.size())) - - raw = operation(a.raw, b.raw) - union = a.annotations.union(b.annotations) - - if isinstance(b, BitVecFunc): - return BitVecFunc( - raw=raw, - func_name="Hybrid", - input_=BitVec(z3.BitVec("", 256), annotations=union), - nested_functions=a.nested_functions + b.nested_functions + [a, b], - ) - - return BitVecFunc( - raw=raw, - func_name=a.func_name, - input_=a.input_, - annotations=union, - nested_functions=a.nested_functions + [a], - ) - - -def _comparison_helper( - a: "BitVecFunc", - b: Union[BitVec, int], - operation: Callable, - default_value: bool, - inputs_equal: bool, -) -> Bool: - """ - Helper function for comparison operations with BitVecFuncs. - - :param a: The BitVecFunc to compare. - :param b: A BitVec or int to compare to. - :param operation: The comparison operation to perform. - :return: The resulting Bool - """ - # Is there some hack for gt/lt comparisons? - if isinstance(b, int): - b = BitVec(z3.BitVecVal(b, a.size())) - union = a.annotations.union(b.annotations) - - if not a.symbolic and not b.symbolic: - if operation == z3.UGT: - operation = operator.gt - if operation == z3.ULT: - operation = operator.lt - return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union) - if ( - not isinstance(b, BitVecFunc) - or not a.func_name - or not a.input_ - or not a.func_name == b.func_name - or str(operation) not in ("", "") - ): - return Bool(z3.BoolVal(default_value), annotations=union) - - condition = True - for a_nest, b_nest in product(a.nested_functions, b.nested_functions): - if a_nest.func_name != b_nest.func_name: - continue - if a_nest.func_name == "Hybrid": - continue - # a.input (eq/neq) b.input ==> a == b - if inputs_equal: - condition = z3.And( - condition, - z3.Or( - z3.Not((a_nest.input_ == b_nest.input_).raw), - (a_nest.raw == b_nest.raw), - ), - z3.Or( - z3.Not((a_nest.raw == b_nest.raw)), - (a_nest.input_ == b_nest.input_).raw, - ), - ) - else: - condition = z3.And( - condition, - z3.Or( - z3.Not((a_nest.input_ != b_nest.input_).raw), - (a_nest.raw == b_nest.raw), - ), - z3.Or( - z3.Not((a_nest.raw == b_nest.raw)), - (a_nest.input_ != b_nest.input_).raw, - ), - ) - - return And( - Bool( - cast(z3.BoolRef, _padded_operation(a.raw, b.raw, operation)), - annotations=union, - ), - Bool(condition) if b.nested_functions else Bool(True), - a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, - ) - - -class BitVecFunc(BitVec): - """A bit vector function symbol. Used in place of functions like sha3.""" - - def __init__( - self, - raw: z3.BitVecRef, - func_name: Optional[str], - input_: "BitVec" = None, - annotations: Optional[Annotations] = None, - nested_functions: Optional[List["BitVecFunc"]] = None, - ): - """ - - :param raw: The raw bit vector symbol - :param func_name: The function name. e.g. sha3 - :param input: The input to the functions - :param annotations: The annotations the BitVecFunc should start with - """ - - self.func_name = func_name - self.input_ = input_ - self.nested_functions = nested_functions or [] - self.nested_functions = list(dict.fromkeys(self.nested_functions)) - if isinstance(input_, BitVecFunc): - self.nested_functions.extend(input_.nested_functions) - super().__init__(raw, annotations) - - def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create an addition expression. - - :param other: The int or BitVec to add to this BitVecFunc - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.add) - - def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create a subtraction expression. - - :param other: The int or BitVec to subtract from this BitVecFunc - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.sub) - - def __mul__(self, other: "BitVec") -> "BitVecFunc": - """Create a multiplication expression. - - :param other: The int or BitVec to multiply to this BitVecFunc - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.mul) - - def __truediv__(self, other: "BitVec") -> "BitVecFunc": - """Create a signed division expression. - - :param other: The int or BitVec to divide this BitVecFunc by - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.truediv) - - def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create an and expression. - - :param other: The int or BitVec to and with this BitVecFunc - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.and_) - - def __or__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create an or expression. - - :param other: The int or BitVec to or with this BitVecFunc - :return: The resulting BitVecFunc - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _arithmetic_helper(self, other, operator.or_) - - def __xor__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create a xor expression. - - :param other: The int or BitVec to xor with this BitVecFunc - :return: The resulting BitVecFunc - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _arithmetic_helper(self, other, operator.xor) - - def __lt__(self, other: Union[int, "BitVec"]) -> Bool: - """Create a signed less than expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _comparison_helper( - self, other, operator.lt, default_value=False, inputs_equal=False - ) - - def __gt__(self, other: Union[int, "BitVec"]) -> Bool: - """Create a signed greater than expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _comparison_helper( - self, other, operator.gt, default_value=False, inputs_equal=False - ) - - def __le__(self, other: Union[int, "BitVec"]) -> Bool: - """Create a signed less than or equal to expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return Or(self < other, self == other) - - def __ge__(self, other: Union[int, "BitVec"]) -> Bool: - """Create a signed greater than or equal to expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return Or(self > other, self == other) - - # MYPY: fix complains about overriding __eq__ - def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore - """Create an equality expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - return _comparison_helper( - self, other, operator.eq, default_value=False, inputs_equal=True - ) - - # MYPY: fix complains about overriding __ne__ - def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore - """Create an inequality expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _comparison_helper( - self, other, operator.ne, default_value=True, inputs_equal=False - ) - - def __lshift__(self, other: Union[int, "BitVec"]) -> "BitVec": - """ - Left shift operation - :param other: The int or BitVec to shift on - :return The resulting left shifted output - """ - return _arithmetic_helper(self, other, operator.lshift) - - def __rshift__(self, other: Union[int, "BitVec"]) -> "BitVec": - """ - Right shift operation - :param other: The int or BitVec to shift on - :return The resulting right shifted output: - """ - return _arithmetic_helper(self, other, operator.rshift) - - def __hash__(self) -> int: - return self.raw.__hash__() diff --git a/tests/laser/smt/bitvecfunc_test.py b/tests/laser/smt/bitvecfunc_test.py deleted file mode 100644 index 37217c73..00000000 --- a/tests/laser/smt/bitvecfunc_test.py +++ /dev/null @@ -1,237 +0,0 @@ -from mythril.laser.smt import Solver, symbol_factory, UGT, UGE, ULT, ULE -import z3 -import pytest - -import operator - - -@pytest.mark.parametrize( - "operation,expected", - [ - (operator.add, z3.unsat), - (operator.sub, z3.unsat), - (operator.and_, z3.sat), - (operator.or_, z3.sat), - (operator.xor, z3.unsat), - ], -) -def test_bitvecfunc_arithmetic(operation, expected): - # Arrange - s = Solver() - - input_ = symbol_factory.BitVecVal(1, 8) - bvf = symbol_factory.BitVecFuncSym("bvf", "sha3", 256, input_=input_) - - x = symbol_factory.BitVecSym("x", 256) - y = symbol_factory.BitVecSym("y", 256) - - # Act - s.add(x != y) - s.add(operation(bvf, x) == operation(y, bvf)) - - # Assert - assert s.check() == expected - - -@pytest.mark.parametrize( - "operation,expected", - [ - (operator.eq, z3.sat), - (operator.ne, z3.unsat), - (operator.lt, z3.unsat), - (operator.le, z3.sat), - (operator.gt, z3.unsat), - (operator.ge, z3.sat), - (UGT, z3.unsat), - (UGE, z3.sat), - (ULT, z3.unsat), - (ULE, z3.sat), - ], -) -def test_bitvecfunc_bitvecfunc_comparison(operation, expected): - # Arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=input2) - - # Act - s.add(operation(bvf1, bvf2)) - s.add(input1 == input2) - - # Assert - assert s.check() == expected - - -def test_bitvecfunc_bitvecfuncval_comparison(): - # Arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecVal(1337, 256) - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncVal(12345678910, "sha3", 256, input_=input2) - - # Act - s.add(bvf1 == bvf2) - - # Assert - assert s.check() == z3.sat - assert s.model().eval(input2.raw) == 1337 - - -def test_bitvecfunc_nested_comparison(): - # arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1) - - bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3) - - # Act - s.add(input1 == input2) - s.add(bvf2 == bvf4) - - # Assert - assert s.check() == z3.sat - - -def test_bitvecfunc_unequal_nested_comparison(): - # arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1) - - bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3) - - # Act - s.add(input1 != input2) - s.add(bvf2 == bvf4) - - # Assert - assert s.check() == z3.unsat - - -def test_bitvecfunc_ext_nested_comparison(): - # arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - input3 = symbol_factory.BitVecSym("input3", 256) - input4 = symbol_factory.BitVecSym("input4", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3) - - bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4) - - # Act - s.add(input1 == input2) - s.add(input3 == input4) - s.add(bvf2 == bvf4) - - # Assert - assert s.check() == z3.sat - - -def test_bitvecfunc_ext_unequal_nested_comparison(): - # Arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - input3 = symbol_factory.BitVecSym("input3", 256) - input4 = symbol_factory.BitVecSym("input4", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3) - - bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4) - - # Act - s.add(input1 == input2) - s.add(input3 != input4) - s.add(bvf2 == bvf4) - - # Assert - assert s.check() == z3.unsat - - -def test_bitvecfunc_ext_unequal_nested_comparison_f(): - # Arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - input3 = symbol_factory.BitVecSym("input3", 256) - input4 = symbol_factory.BitVecSym("input4", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3) - - bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4) - - # Act - s.add(input1 != input2) - s.add(input3 == input4) - s.add(bvf2 == bvf4) - - # Assert - assert s.check() == z3.unsat - - -def test_bitvecfunc_find_input(): - # Arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - - # Act - s.add(input1 == symbol_factory.BitVecVal(1, 256)) - s.add(bvf1 == bvf2) - - # Assert - assert s.check() == z3.sat - assert s.model()[input2.raw] == 1 - - -def test_bitvecfunc_nested_find_input(): - # Arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1) - - bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3) - - # Act - s.add(input1 == symbol_factory.BitVecVal(123, 256)) - s.add(bvf2 == bvf4) - - # Assert - assert s.check() == z3.sat - assert s.model()[input2.raw] == 123