diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 5d2b9e93..b0c63a3c 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -31,6 +31,7 @@ from mythril.laser.smt import symbol_factory import mythril.laser.ethereum.util as helper from mythril.laser.ethereum import util +from mythril.laser.ethereum.keccak_function_manager import keccak_function_manager from mythril.laser.ethereum.call import get_call_parameters, native_call from mythril.laser.ethereum.evm_exceptions import ( VmException, @@ -947,33 +948,10 @@ class Instruction: else: # length is 0; this only matters for input of the BitVecFuncVal data = symbol_factory.BitVecVal(0, 1) - - if data.symbolic: - - annotations = set() # type: Set[Any] - - for b in state.memory[index : index + length]: - if isinstance(b, BitVec): - annotations = annotations.union(b.annotations) - - argument_hash = hash(state.memory[index]) - result = symbol_factory.BitVecFuncSym( - "KECCAC[invhash({})]".format(hash(argument_hash)), - "keccak256", - 256, - input_=data, - annotations=annotations, - ) - log.debug("Created BitVecFunc hash.") - - else: - keccak = utils.sha3(data.value.to_bytes(length, byteorder="big")) - result = symbol_factory.BitVecFuncVal( - util.concrete_int_from_bytes(keccak, 0), "keccak256", 256, input_=data - ) - log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak))) - + result, constraints = keccak_function_manager.create_keccak(data, length) state.stack.append(result) + state.constraints += constraints + return [global_state] @StateTransition() diff --git a/mythril/laser/ethereum/keccak_function_manager.py b/mythril/laser/ethereum/keccak_function_manager.py new file mode 100644 index 00000000..ca7e406b --- /dev/null +++ b/mythril/laser/ethereum/keccak_function_manager.py @@ -0,0 +1,58 @@ +from ethereum import utils +from mythril.laser.smt import BitVec, Function, URem, symbol_factory, ULE, And, ULT, Or + +TOTAL_PARTS = 10 ** 40 +PART = (2 ** 256 - 1) // TOTAL_PARTS +INTERVAL_DIFFERENCE = 10 ** 30 + + +class KeccakFunctionManager: + def __init__(self): + self.sizes = {} + self.size_index = {} + self.index_counter = TOTAL_PARTS - 34534 + self.size_values = {} + + def create_keccak(self, data: BitVec, length: int): + length = length * 8 + assert length == data.size() + try: + func, inverse = self.sizes[length] + except KeyError: + func = Function("keccak256_{}".format(length), length, 256) + inverse = Function("keccak256_{}-1".format(length), 256, length) + self.sizes[length] = (func, inverse) + self.size_values[length] = [] + constraints = [] + if data.symbolic is False: + keccak = symbol_factory.BitVecVal( + utils.sha3(data.value.to_bytes(length // 8, byteorder="big")), 256 + ) + constraints.append(func(data) == keccak) + + constraints.append(inverse(func(data)) == data) + if data.symbolic is False: + return func(data), constraints + + constraints.append(URem(func(data), symbol_factory.BitVecVal(63, 256)) == 0) + try: + index = self.size_index[length] + except KeyError: + self.size_index[length] = self.index_counter + index = self.index_counter + self.index_counter -= INTERVAL_DIFFERENCE + + lower_bound = index * PART + upper_bound = (index + 1) * PART + + condition = And( + ULE(symbol_factory.BitVecVal(lower_bound, 256), func(data)), + ULT(func(data), symbol_factory.BitVecVal(upper_bound, 256)), + ) + for val in self.size_values[length]: + condition = Or(condition, func(data) == val) + constraints.append(condition) + return func(data), constraints + + +keccak_function_manager = KeccakFunctionManager() diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index c6afa0c3..6fee240e 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -55,7 +55,6 @@ class Storage: self._standard_storage = K(256, 256, 0) # type: BaseArray else: self._standard_storage = Array("Storage", 256, 256) - self._map_storage = {} # type: Dict[BitVec, BaseArray] self.printable_storage = {} # type: Dict[BitVec, BitVec] @@ -73,11 +72,8 @@ class Storage: return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_) def __getitem__(self, item: BitVec) -> BitVec: - storage, is_keccak_storage = self._get_corresponding_storage(item) - if is_keccak_storage: - sanitized_item = self._sanitize(cast(BitVecFunc, item).input_) - else: - sanitized_item = item + storage = self._standard_storage + sanitized_item = item if ( self.address and self.address.value != 0 @@ -100,7 +96,6 @@ class Storage: self.printable_storage[item] = storage[sanitized_item] except ValueError as e: log.debug("Couldn't read storage at %s: %s", item, e) - return simplify(storage[sanitized_item]) @staticmethod @@ -114,29 +109,12 @@ class Storage: index = Extract(255, 0, key.input_) return simplify(index) - def _get_corresponding_storage(self, key: BitVec) -> Tuple[BaseArray, bool]: - index = self.get_map_index(key) - if index is None: - storage = self._standard_storage - is_keccak_storage = False - else: - storage_map = self._map_storage - try: - storage = storage_map[index] - except KeyError: - if isinstance(self._standard_storage, Array): - storage_map[index] = Array("Storage", 512, 256) - else: - storage_map[index] = K(512, 256, 0) - storage = storage_map[index] - is_keccak_storage = True - return storage, is_keccak_storage + def _get_corresponding_storage(self, key: BitVec) -> BaseArray: + return self._standard_storage def __setitem__(self, key, value: Any) -> None: - storage, is_keccak_storage = self._get_corresponding_storage(key) + storage = self._get_corresponding_storage(key) self.printable_storage[key] = value - if is_keccak_storage: - key = self._sanitize(key.input_) storage[key] = value if key.symbolic is False: self.storage_keys_loaded.add(int(key.value)) @@ -147,7 +125,6 @@ class Storage: concrete=concrete, address=self.address, dynamic_loader=self.dynld ) storage._standard_storage = deepcopy(self._standard_storage) - storage._map_storage = deepcopy(self._map_storage) storage.printable_storage = copy(self.printable_storage) storage.storage_keys_loaded = copy(self.storage_keys_loaded) return storage diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index 88b362c8..21e4c682 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -241,7 +241,6 @@ class LaserEVM: except NotImplementedError: log.debug("Encountered unimplemented instruction") continue - new_states = [ state for state in new_states if state.mstate.constraints.is_possible ] diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index 6ab752ce..b0793562 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -22,6 +22,7 @@ from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.array import K, Array, BaseArray +from mythril.laser.smt.function import Function from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics from mythril.laser.smt.model import Model from mythril.laser.smt.bool import Bool as SMTBool diff --git a/mythril/laser/smt/function.py b/mythril/laser/smt/function.py new file mode 100644 index 00000000..93115c88 --- /dev/null +++ b/mythril/laser/smt/function.py @@ -0,0 +1,25 @@ +from typing import cast +import z3 + +from mythril.laser.smt.bitvec import BitVec + + +class Function: + """An uninterpreted function.""" + + def __init__(self, name: str, domain: int, value_range: int): + """Initializes an uninterpreted function. + + :param name: Name of the Function + :param domain: The domain for the Function (10 -> all the values that a bv of size 10 could take) + :param value_range: The range for the values of the function (10 -> all the values that a bv of size 10 could take) + """ + self.domain = z3.BitVecSort(domain) + self.range = z3.BitVecSort(value_range) + self.raw = z3.Function(name, self.domain, self.range) + + def __call__(self, item: BitVec) -> BitVec: + """Function accessor, item can be symbolic.""" + return BitVec( + cast(z3.BitVecRef, self.raw(item.raw)), annotations=item.annotations + ) # type: ignore