diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index 6fee240e..3d9e52ef 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -4,7 +4,7 @@ This includes classes representing accounts and their storage. """ import logging from copy import copy, deepcopy -from typing import Any, Dict, Union, Tuple, Set, cast +from typing import Any, Dict, Union, Set from mythril.laser.smt import ( @@ -12,10 +12,7 @@ from mythril.laser.smt import ( K, BitVec, simplify, - BitVecFunc, - Extract, BaseArray, - Concat, ) from mythril.disassembler.disassembly import Disassembly from mythril.laser.smt import symbol_factory @@ -23,26 +20,6 @@ from mythril.laser.smt import symbol_factory log = logging.getLogger(__name__) -class StorageRegion: - def __getitem__(self, item): - raise NotImplementedError - - def __setitem__(self, key, value): - raise NotImplementedError - - -class ArrayStorageRegion(StorageRegion): - """ An ArrayStorageRegion is a storage region that leverages smt array theory to resolve expressions""" - - pass - - -class IteStorageRegion(StorageRegion): - """ An IteStorageRegion is a storage region that uses Ite statements to implement a storage""" - - pass - - class Storage: """Storage class represents the storage of an Account.""" @@ -62,18 +39,8 @@ class Storage: self.storage_keys_loaded = set() # type: Set[int] self.address = address - @staticmethod - def _sanitize(input_: BitVec) -> BitVec: - if input_.size() == 512: - return input_ - if input_.size() > 512: - return Extract(511, 0, input_) - else: - return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_) - def __getitem__(self, item: BitVec) -> BitVec: storage = self._standard_storage - sanitized_item = item if ( self.address and self.address.value != 0 @@ -82,7 +49,7 @@ class Storage: and (self.dynld and self.dynld.storage_loading) ): try: - storage[sanitized_item] = symbol_factory.BitVecVal( + storage[item] = symbol_factory.BitVecVal( int( self.dynld.read_storage( contract_address="0x{:040X}".format(self.address.value), @@ -93,29 +60,14 @@ class Storage: 256, ) self.storage_keys_loaded.add(int(item.value)) - self.printable_storage[item] = storage[sanitized_item] + self.printable_storage[item] = storage[item] except ValueError as e: log.debug("Couldn't read storage at %s: %s", item, e) - return simplify(storage[sanitized_item]) - - @staticmethod - def get_map_index(key: BitVec) -> BitVec: - if ( - not isinstance(key, BitVecFunc) - or key.func_name != "keccak256" - or key.input_ is None - ): - return None - index = Extract(255, 0, key.input_) - return simplify(index) - - def _get_corresponding_storage(self, key: BitVec) -> BaseArray: - return self._standard_storage + return simplify(storage[item]) def __setitem__(self, key, value: Any) -> None: - storage = self._get_corresponding_storage(key) self.printable_storage[key] = value - storage[key] = value + self._standard_storage[key] = value if key.symbolic is False: self.storage_keys_loaded.add(int(key.value)) diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index cc506ae1..4dba9e6f 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -79,44 +79,6 @@ class SymbolFactory(Generic[T, U]): """ raise NotImplementedError() - @staticmethod - def BitVecFuncVal( - value: int, - func_name: str, - size: int, - annotations: Annotations = None, - input_: "BitVec" = None, - ) -> BitVecFunc: - """Creates a new bit vector function with a symbolic value. - - :param value: The concrete value to set the bit vector to - :param func_name: The name of the bit vector function - :param size: The size of the bit vector - :param annotations: The annotations to initialize the bit vector with - :param input_: The input to the bit vector function - :return: The freshly created bit vector function - """ - raise NotImplementedError() - - @staticmethod - def BitVecFuncSym( - name: str, - func_name: str, - size: int, - annotations: Annotations = None, - input_: "BitVec" = None, - ) -> BitVecFunc: - """Creates a new bit vector function with a symbolic value. - - :param name: The name of the symbolic bit vector - :param func_name: The name of the bit vector function - :param size: The size of the bit vector - :param annotations: The annotations to initialize the bit vector with - :param input_: The input to the bit vector function - :return: The freshly created bit vector function - """ - raise NotImplementedError() - class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]): """ @@ -158,30 +120,6 @@ class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]): raw = z3.BitVec(name, size) return BitVec(raw, annotations) - @staticmethod - def BitVecFuncVal( - value: int, - func_name: str, - size: int, - annotations: Annotations = None, - input_: "BitVec" = None, - ) -> BitVecFunc: - """Creates a new bit vector function with a concrete value.""" - raw = z3.BitVecVal(value, size) - return BitVecFunc(raw, func_name, input_, annotations) - - @staticmethod - def BitVecFuncSym( - name: str, - func_name: str, - size: int, - annotations: Annotations = None, - input_: "BitVec" = None, - ) -> BitVecFunc: - """Creates a new bit vector function with a symbolic value.""" - raw = z3.BitVec(name, size) - return BitVecFunc(raw, func_name, input_, annotations) - class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]): """ diff --git a/mythril/laser/smt/bitvec_helper.py b/mythril/laser/smt/bitvec_helper.py index e2d2c54d..a8fd8e8d 100644 --- a/mythril/laser/smt/bitvec_helper.py +++ b/mythril/laser/smt/bitvec_helper.py @@ -7,9 +7,7 @@ from mythril.laser.smt.bitvec import BitVec Annotations = Set[Any] -def _comparison_helper( - a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool -) -> Bool: +def _comparison_helper(a: BitVec, b: BitVec, operation: Callable) -> Bool: annotations = a.annotations.union(b.annotations) return Bool(operation(a.raw, b.raw), annotations) @@ -49,7 +47,7 @@ def UGT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False) + return _comparison_helper(a, b, z3.UGT) def UGE(a: BitVec, b: BitVec) -> Bool: @@ -69,7 +67,7 @@ def ULT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False) + return _comparison_helper(a, b, z3.ULT) def ULE(a: BitVec, b: BitVec) -> Bool: