diff --git a/mythril/analysis/modules/exceptions.py b/mythril/analysis/modules/exceptions.py index 804df576..903360e1 100644 --- a/mythril/analysis/modules/exceptions.py +++ b/mythril/analysis/modules/exceptions.py @@ -55,11 +55,9 @@ class ReachableExceptionsModule(DetectionModule): "Note that explicit `assert()` should only be used to check invariants. " "Use `require()` for regular input checking." ) - transaction_sequence = solver.get_transaction_sequence( state, state.mstate.constraints ) - issue = Issue( contract=state.environment.active_account.contract_name, function_name=state.environment.active_function_name, diff --git a/mythril/analysis/modules/suicide.py b/mythril/analysis/modules/suicide.py index cc2d27dc..e3e77360 100644 --- a/mythril/analysis/modules/suicide.py +++ b/mythril/analysis/modules/suicide.py @@ -9,7 +9,7 @@ from mythril.laser.ethereum.transaction.transaction_models import ( ContractCreationTransaction, ) import logging -import json + log = logging.getLogger(__name__) @@ -69,7 +69,6 @@ class SuicideModule(DetectionModule): for tx in state.world_state.transaction_sequence: if not isinstance(tx, ContractCreationTransaction): constraints.append(tx.caller == ACTORS.attacker) - try: try: transaction_sequence = solver.get_transaction_sequence( @@ -85,7 +84,6 @@ class SuicideModule(DetectionModule): state, state.mstate.constraints + constraints ) description_tail = "Arbitrary senders can kill this contract." - issue = Issue( contract=state.environment.active_account.contract_name, function_name=state.environment.active_function_name, diff --git a/mythril/analysis/solver.py b/mythril/analysis/solver.py index 00de6bd4..f0675e4d 100644 --- a/mythril/analysis/solver.py +++ b/mythril/analysis/solver.py @@ -1,12 +1,16 @@ """This module contains analysis module helpers to solve path constraints.""" from functools import lru_cache -from typing import Dict, Tuple, Union +from typing import Dict, List, Tuple, Union from z3 import sat, unknown, FuncInterp import z3 from mythril.analysis.analysis_args import analysis_args from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.constraints import Constraints +from mythril.laser.ethereum.keccak_function_manager import ( + keccak_function_manager, + hash_matcher, +) from mythril.laser.ethereum.transaction import BaseTransaction from mythril.laser.smt import UGE, Optimize, symbol_factory from mythril.laser.ethereum.time_handler import time_handler @@ -18,6 +22,7 @@ import logging log = logging.getLogger(__name__) + # LRU cache works great when used in powers of 2 @lru_cache(maxsize=2 ** 23) def get_model(constraints, minimize=(), maximize=(), enforce_execution_time=True): @@ -48,7 +53,6 @@ def get_model(constraints, minimize=(), maximize=(), enforce_execution_time=True s.minimize(e) for e in maximize: s.maximize(e) - result = s.check() if result == sat: return s.model() @@ -97,7 +101,6 @@ def get_transaction_sequence( tx_constraints, minimize = _set_minimisation_constraints( transaction_sequence, constraints.copy(), [], 5000, global_state.world_state ) - try: model = get_model(tx_constraints, minimize=minimize) except UnsatError: @@ -122,12 +125,59 @@ def get_transaction_sequence( ).as_long() concrete_initial_state = _get_concrete_state(initial_accounts, min_price_dict) - + if isinstance(transaction_sequence[0], ContractCreationTransaction): + code = transaction_sequence[0].code + _replace_with_actual_sha(concrete_transactions, model, code) + else: + _replace_with_actual_sha(concrete_transactions, model) steps = {"initialState": concrete_initial_state, "steps": concrete_transactions} return steps +def _replace_with_actual_sha( + concrete_transactions: List[Dict[str, str]], model: z3.Model, code=None +): + for tx in concrete_transactions: + if hash_matcher not in tx["input"]: + continue + if code is not None and code.bytecode in tx["input"]: + s_index = len(code.bytecode) + 2 + else: + s_index = 10 + for i in range(s_index, len(tx["input"])): + data_slice = tx["input"][i : i + 64] + if hash_matcher not in data_slice or len(data_slice) != 64: + continue + find_input = symbol_factory.BitVecVal(int(data_slice, 16), 256) + input_ = None + for size in keccak_function_manager.store_function: + _, inverse = keccak_function_manager.get_function(size) + try: + input_ = symbol_factory.BitVecVal( + model.eval(inverse(find_input).raw).as_long(), size + ) + except AttributeError: + continue + hex_input = hex(input_.value)[2:] + found = False + for new_tx in concrete_transactions: + if hex_input in new_tx["input"]: + found = True + break + if found: + break + if input_ is None: + continue + keccak = keccak_function_manager.find_concrete_keccak(input_) + hex_keccak = hex(keccak.value)[2:] + if len(hex_keccak) != 64: + hex_keccak = "0" * (64 - len(hex_keccak)) + hex_keccak + tx["input"] = tx["input"][:s_index] + tx["input"][s_index:].replace( + tx["input"][i : 64 + i], hex_keccak + ) + + def _get_concrete_state(initial_accounts: Dict, min_price_dict: Dict[str, int]): """ Gets a concrete state """ accounts = {} diff --git a/mythril/laser/ethereum/cfg.py b/mythril/laser/ethereum/cfg.py index 0578c602..c62f8857 100644 --- a/mythril/laser/ethereum/cfg.py +++ b/mythril/laser/ethereum/cfg.py @@ -2,6 +2,7 @@ from enum import Enum from typing import Dict, List, TYPE_CHECKING +from mythril.laser.ethereum.state.constraints import Constraints from flags import Flags if TYPE_CHECKING: @@ -46,7 +47,7 @@ class Node: :param start_addr: :param constraints: """ - constraints = constraints if constraints else [] + constraints = constraints if constraints else Constraints() self.contract_name = contract_name self.start_addr = start_addr self.states = [] # type: List[GlobalState] diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 3dc72886..52a0dce0 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -34,6 +34,7 @@ from mythril.laser.ethereum.state.calldata import ConcreteCalldata, SymbolicCall import mythril.laser.ethereum.util as helper from mythril.laser.ethereum import util +from mythril.laser.ethereum.keccak_function_manager import keccak_function_manager from mythril.laser.ethereum.call import get_call_parameters, native_call, get_call_data from mythril.laser.ethereum.evm_exceptions import ( VmException, @@ -982,7 +983,7 @@ class Instruction: if isinstance(op0, Expression): op0 = simplify(op0) state.stack.append( - symbol_factory.BitVecSym("KECCAC_mem[" + str(op0) + "]", 256) + symbol_factory.BitVecSym("KECCAC_mem[{}]".format(hash(op0)), 256) ) gas_tuple = get_opcode_gas("SHA3") state.min_gas_used += gas_tuple[0] @@ -996,40 +997,21 @@ class Instruction: b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8) for b in state.memory[index : index + length] ] + if len(data_list) > 1: data = simplify(Concat(data_list)) elif len(data_list) == 1: data = data_list[0] else: - # length is 0; this only matters for input of the BitVecFuncVal - data = symbol_factory.BitVecVal(0, 1) - - if data.symbolic: - - annotations = set() # type: Set[Any] - - for b in state.memory[index : index + length]: - if isinstance(b, BitVec): - annotations = annotations.union(b.annotations) - - argument_hash = hash(state.memory[index]) - result = symbol_factory.BitVecFuncSym( - "KECCAC[invhash({})]".format(hash(argument_hash)), - "keccak256", - 256, - input_=data, - annotations=annotations, - ) - log.debug("Created BitVecFunc hash.") - - else: - keccak = utils.sha3(data.value.to_bytes(length, byteorder="big")) - result = symbol_factory.BitVecFuncVal( - util.concrete_int_from_bytes(keccak, 0), "keccak256", 256, input_=data - ) - log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak))) + # TODO: handle finding x where func(x)==func("") + result = keccak_function_manager.get_empty_keccak_hash() + state.stack.append(result) + return [global_state] + result, condition = keccak_function_manager.create_keccak(data) state.stack.append(result) + state.constraints.append(condition) + return [global_state] @StateTransition() diff --git a/mythril/laser/ethereum/keccak_function_manager.py b/mythril/laser/ethereum/keccak_function_manager.py new file mode 100644 index 00000000..6d5262ee --- /dev/null +++ b/mythril/laser/ethereum/keccak_function_manager.py @@ -0,0 +1,117 @@ +from ethereum import utils +from mythril.laser.smt import ( + BitVec, + Function, + URem, + symbol_factory, + ULE, + And, + ULT, + Bool, + Or, +) +from typing import Dict, Tuple, List + +TOTAL_PARTS = 10 ** 40 +PART = (2 ** 256 - 1) // TOTAL_PARTS +INTERVAL_DIFFERENCE = 10 ** 30 +hash_matcher = "fffffff" # This is usually the prefix for the hash in the output + + +class KeccakFunctionManager: + """ + A bunch of uninterpreted functions are considered like keccak256_160 ,... + where keccak256_160 means the input of keccak256() is 160 bit number. + the range of these functions are constrained to some mutually disjoint intervals + All the hashes modulo 64 are 0 as we need a spread among hashes for array type data structures + All the functions are kind of one to one due to constraint of the existence of inverse + for each encountered input. + For more info https://files.sri.inf.ethz.ch/website/papers/sp20-verx.pdf + """ + + def __init__(self): + self.store_function = {} # type: Dict[int, Tuple[Function, Function]] + self.interval_hook_for_size = {} # type: Dict[int, int] + self._index_counter = TOTAL_PARTS - 34534 + self.quick_inverse = {} # type: Dict[BitVec, BitVec] # This is for VMTests + + @staticmethod + def find_concrete_keccak(data: BitVec) -> BitVec: + """ + Calculates concrete keccak + :param data: input bitvecval + :return: concrete keccak output + """ + keccak = symbol_factory.BitVecVal( + int.from_bytes( + utils.sha3(data.value.to_bytes(data.size() // 8, byteorder="big")), + "big", + ), + 256, + ) + return keccak + + def get_function(self, length: int) -> Tuple[Function, Function]: + """ + Returns the keccak functions for the corresponding length + :param length: input size + :return: tuple of keccak and it's inverse + """ + try: + func, inverse = self.store_function[length] + except KeyError: + func = Function("keccak256_{}".format(length), length, 256) + inverse = Function("keccak256_{}-1".format(length), 256, length) + self.store_function[length] = (func, inverse) + return func, inverse + + @staticmethod + def get_empty_keccak_hash() -> BitVec: + """ + returns sha3("") + :return: + """ + val = 89477152217924674838424037953991966239322087453347756267410168184682657981552 + return symbol_factory.BitVecVal(val, 256) + + def create_keccak(self, data: BitVec) -> Tuple[BitVec, Bool]: + """ + Creates Keccak of the data + :param data: input + :return: Tuple of keccak and the condition it should satisfy + """ + length = data.size() + func, inverse = self.get_function(length) + + condition = self._create_condition(func_input=data) + self.quick_inverse[func(data)] = data + return func(data), condition + + def _create_condition(self, func_input: BitVec) -> Bool: + """ + Creates the constraints for hash + :param func_input: input of the hash + :return: condition + """ + length = func_input.size() + func, inv = self.get_function(length) + try: + index = self.interval_hook_for_size[length] + except KeyError: + self.interval_hook_for_size[length] = self._index_counter + index = self._index_counter + self._index_counter -= INTERVAL_DIFFERENCE + + lower_bound = index * PART + upper_bound = lower_bound + PART + + cond = And( + inv(func(func_input)) == func_input, + ULE(symbol_factory.BitVecVal(lower_bound, 256), func(func_input)), + ULT(func(func_input), symbol_factory.BitVecVal(upper_bound, 256)), + URem(func(func_input), symbol_factory.BitVecVal(64, 256)) == 0, + ) + return cond + + +keccak_function_manager = KeccakFunctionManager() diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index c6afa0c3..6fee240e 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -55,7 +55,6 @@ class Storage: self._standard_storage = K(256, 256, 0) # type: BaseArray else: self._standard_storage = Array("Storage", 256, 256) - self._map_storage = {} # type: Dict[BitVec, BaseArray] self.printable_storage = {} # type: Dict[BitVec, BitVec] @@ -73,11 +72,8 @@ class Storage: return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_) def __getitem__(self, item: BitVec) -> BitVec: - storage, is_keccak_storage = self._get_corresponding_storage(item) - if is_keccak_storage: - sanitized_item = self._sanitize(cast(BitVecFunc, item).input_) - else: - sanitized_item = item + storage = self._standard_storage + sanitized_item = item if ( self.address and self.address.value != 0 @@ -100,7 +96,6 @@ class Storage: self.printable_storage[item] = storage[sanitized_item] except ValueError as e: log.debug("Couldn't read storage at %s: %s", item, e) - return simplify(storage[sanitized_item]) @staticmethod @@ -114,29 +109,12 @@ class Storage: index = Extract(255, 0, key.input_) return simplify(index) - def _get_corresponding_storage(self, key: BitVec) -> Tuple[BaseArray, bool]: - index = self.get_map_index(key) - if index is None: - storage = self._standard_storage - is_keccak_storage = False - else: - storage_map = self._map_storage - try: - storage = storage_map[index] - except KeyError: - if isinstance(self._standard_storage, Array): - storage_map[index] = Array("Storage", 512, 256) - else: - storage_map[index] = K(512, 256, 0) - storage = storage_map[index] - is_keccak_storage = True - return storage, is_keccak_storage + def _get_corresponding_storage(self, key: BitVec) -> BaseArray: + return self._standard_storage def __setitem__(self, key, value: Any) -> None: - storage, is_keccak_storage = self._get_corresponding_storage(key) + storage = self._get_corresponding_storage(key) self.printable_storage[key] = value - if is_keccak_storage: - key = self._sanitize(key.input_) storage[key] = value if key.symbolic is False: self.storage_keys_loaded.add(int(key.value)) @@ -147,7 +125,6 @@ class Storage: concrete=concrete, address=self.address, dynamic_loader=self.dynld ) storage._standard_storage = deepcopy(self._standard_storage) - storage._map_storage = deepcopy(self._map_storage) storage.printable_storage = copy(self.printable_storage) storage.storage_keys_loaded = copy(self.storage_keys_loaded) return storage diff --git a/mythril/laser/ethereum/state/constraints.py b/mythril/laser/ethereum/state/constraints.py index 1f675ca5..419bc873 100644 --- a/mythril/laser/ethereum/state/constraints.py +++ b/mythril/laser/ethereum/state/constraints.py @@ -35,6 +35,7 @@ class Constraints(list): """ :return: True/False based on the existence of solution of constraints """ + if self._is_possible is not None: return self._is_possible solver = Solver() @@ -109,8 +110,8 @@ class Constraints(list): :param constraints: :return: """ - constraints = self._get_smt_bool_list(constraints) - super(Constraints, self).__iadd__(constraints) + list_constraints = self._get_smt_bool_list(constraints) + super(Constraints, self).__iadd__(list_constraints) self._is_possible = None return self diff --git a/mythril/laser/ethereum/strategy/extensions/bounded_loops.py b/mythril/laser/ethereum/strategy/extensions/bounded_loops.py index 56c1f55b..5a7583ab 100644 --- a/mythril/laser/ethereum/strategy/extensions/bounded_loops.py +++ b/mythril/laser/ethereum/strategy/extensions/bounded_loops.py @@ -3,6 +3,7 @@ from mythril.laser.ethereum.strategy.basic import BasicSearchStrategy from mythril.laser.ethereum.state.annotation import StateAnnotation from mythril.laser.ethereum.transaction import ContractCreationTransaction from typing import Dict, cast, List +from copy import copy import logging @@ -16,7 +17,9 @@ class JumpdestCountAnnotation(StateAnnotation): self._reached_count = {} # type: Dict[str, int] def __copy__(self): - return self + result = JumpdestCountAnnotation() + result._reached_count = copy(self._reached_count) + return result class BoundedLoopsStrategy(BasicSearchStrategy): @@ -45,6 +48,7 @@ class BoundedLoopsStrategy(BasicSearchStrategy): :return: Global state """ + while True: state = self.super_strategy.get_strategic_global_state() @@ -56,7 +60,6 @@ class BoundedLoopsStrategy(BasicSearchStrategy): if len(annotations) == 0: annotation = JumpdestCountAnnotation() - log.debug("Adding JumpdestCountAnnotation to GlobalState") state.annotate(annotation) else: annotation = annotations[0] diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index fe679946..7967b0db 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -20,6 +20,7 @@ from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.strategy.basic import DepthFirstSearchStrategy from abc import ABCMeta from mythril.laser.ethereum.time_handler import time_handler + from mythril.laser.ethereum.transaction import ( ContractCreationTransaction, TransactionEndSignal, @@ -29,6 +30,7 @@ from mythril.laser.ethereum.transaction import ( ) from mythril.laser.smt import symbol_factory + log = logging.getLogger(__name__) @@ -206,6 +208,7 @@ class LaserEVM: i, len(self.open_states) ) ) + for hook in self._start_sym_trans_hooks: hook() @@ -245,7 +248,6 @@ class LaserEVM: except NotImplementedError: log.debug("Encountered unimplemented instruction") continue - new_states = [ state for state in new_states if state.mstate.constraints.is_possible ] @@ -357,16 +359,15 @@ class LaserEVM: ] log.debug("Ending transaction %s.", transaction) - if return_global_state is None: if ( not isinstance(transaction, ContractCreationTransaction) or transaction.return_data ) and not end_signal.revert: check_potential_issues(global_state) - end_signal.global_state.world_state.node = global_state.node self._add_world_state(end_signal.global_state) + new_global_states = [] else: # First execute the post hook for the transaction ending instruction diff --git a/mythril/laser/ethereum/transaction/concolic.py b/mythril/laser/ethereum/transaction/concolic.py index 164df8db..f995a584 100644 --- a/mythril/laser/ethereum/transaction/concolic.py +++ b/mythril/laser/ethereum/transaction/concolic.py @@ -88,6 +88,9 @@ def _setup_global_state_for_execution(laser_evm, transaction) -> None: condition=None, ) ) + global_state.mstate.constraints += transaction.world_state.node.constraints + new_node.constraints = global_state.mstate.constraints + global_state.world_state.transaction_sequence.append(transaction) global_state.node = new_node new_node.states.append(global_state) diff --git a/mythril/laser/ethereum/transaction/symbolic.py b/mythril/laser/ethereum/transaction/symbolic.py index 9e8fd558..09ee0173 100644 --- a/mythril/laser/ethereum/transaction/symbolic.py +++ b/mythril/laser/ethereum/transaction/symbolic.py @@ -185,7 +185,7 @@ def _setup_global_state_for_execution(laser_evm, transaction: BaseTransaction) - ) global_state.mstate.constraints += transaction.world_state.node.constraints - new_node.constraints = global_state.mstate.constraints.as_list + new_node.constraints = global_state.mstate.constraints global_state.world_state.transaction_sequence.append(transaction) global_state.node = new_node diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index 6ab752ce..86ded2ed 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -22,6 +22,7 @@ from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.array import K, Array, BaseArray +from mythril.laser.smt.function import Function from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics from mythril.laser.smt.model import Model from mythril.laser.smt.bool import Bool as SMTBool @@ -47,6 +48,16 @@ class SymbolFactory(Generic[T, U]): """ raise NotImplementedError + @staticmethod + def BoolSym(name: str, annotations: Annotations = None) -> T: + """ + Creates a boolean symbol + :param name: The name of the Bool variable + :param annotations: The annotations to initialize the bool with + :return: The freshly created Bool() + """ + raise NotImplementedError + @staticmethod def BitVecVal(value: int, size: int, annotations: Annotations = None) -> U: """Creates a new bit vector with a concrete value. @@ -125,6 +136,17 @@ class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]): raw = z3.BoolVal(value) return SMTBool(raw, annotations) + @staticmethod + def BoolSym(name: str, annotations: Annotations = None) -> SMTBool: + """ + Creates a boolean symbol + :param name: The name of the Bool variable + :param annotations: The annotations to initialize the bool with + :return: The freshly created Bool() + """ + raw = z3.Bool(name) + return SMTBool(raw, annotations) + @staticmethod def BitVecVal(value: int, size: int, annotations: Annotations = None) -> BitVec: """Creates a new bit vector with a concrete value.""" diff --git a/mythril/laser/smt/bitvec_helper.py b/mythril/laser/smt/bitvec_helper.py index 8e68e0c9..c1f60607 100644 --- a/mythril/laser/smt/bitvec_helper.py +++ b/mythril/laser/smt/bitvec_helper.py @@ -1,8 +1,7 @@ -from typing import Union, overload, List, Set, cast, Any, Optional, Callable -from operator import lshift, rshift, ne, eq +from typing import Union, overload, List, Set, cast, Any, Callable import z3 -from mythril.laser.smt.bool import Bool, And, Or +from mythril.laser.smt.bool import Bool, Or from mythril.laser.smt.bitvec import BitVec from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper diff --git a/mythril/laser/smt/function.py b/mythril/laser/smt/function.py new file mode 100644 index 00000000..77451871 --- /dev/null +++ b/mythril/laser/smt/function.py @@ -0,0 +1,25 @@ +from typing import cast +import z3 + +from mythril.laser.smt.bitvec import BitVec + + +class Function: + """An uninterpreted function.""" + + def __init__(self, name: str, domain: int, value_range: int): + """Initializes an uninterpreted function. + + :param name: Name of the Function + :param domain: The domain for the Function (10 -> all the values that a bv of size 10 could take) + :param value_range: The range for the values of the function (10 -> all the values that a bv of size 10 could take) + """ + self.domain = z3.BitVecSort(domain) + self.range = z3.BitVecSort(value_range) + self.raw = z3.Function(name, self.domain, self.range) + + def __call__(self, item: BitVec) -> BitVec: + """Function accessor, item can be symbolic.""" + return BitVec( + cast(z3.BitVecRef, self.raw(item.raw)), annotations=item.annotations + ) diff --git a/mythril/laser/smt/solver/solver.py b/mythril/laser/smt/solver/solver.py index 2f3d4ac1..aa79dc85 100644 --- a/mythril/laser/smt/solver/solver.py +++ b/mythril/laser/smt/solver/solver.py @@ -45,14 +45,14 @@ class BaseSolver(Generic[T]): self.add(*constraints) @stat_smt_query - def check(self) -> z3.CheckSatResult: + def check(self, *args) -> z3.CheckSatResult: """Returns z3 smt check result. Also suppresses the stdout when running z3 library's check() to avoid unnecessary output :return: The evaluated result which is either of sat, unsat or unknown """ old_stdout = sys.stdout sys.stdout = open(os.devnull, "w") - evaluate = self.raw.check() + evaluate = self.raw.check(args) sys.stdout = old_stdout return evaluate diff --git a/tests/instructions/create_test.py b/tests/instructions/create_test.py index 0a97565e..04f0decd 100644 --- a/tests/instructions/create_test.py +++ b/tests/instructions/create_test.py @@ -1,6 +1,6 @@ from mythril.disassembler.disassembly import Disassembly +from mythril.laser.ethereum.cfg import Node from mythril.laser.ethereum.state.environment import Environment -from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.machine_state import MachineState from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.world_state import WorldState @@ -30,11 +30,12 @@ def execute_create(): calldata = ConcreteCalldata(0, code_raw) world_state = WorldState() + world_state.node = Node("Contract") account = world_state.create_account(balance=1000000, address=101) account.code = Disassembly("60a760006000f000") environment = Environment(account, None, calldata, None, None, None) og_state = GlobalState( - world_state, environment, None, MachineState(gas_limit=8000000) + world_state, environment, world_state.node, MachineState(gas_limit=8000000) ) og_state.transaction_stack.append( (MessageCallTransaction(world_state=WorldState(), gas_limit=8000000), None) diff --git a/tests/laser/evm_testsuite/evm_test.py b/tests/laser/evm_testsuite/evm_test.py index 244b5844..8bbca5ca 100644 --- a/tests/laser/evm_testsuite/evm_test.py +++ b/tests/laser/evm_testsuite/evm_test.py @@ -1,10 +1,10 @@ from mythril.laser.ethereum.svm import LaserEVM from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.world_state import WorldState +from mythril.laser.ethereum.keccak_function_manager import keccak_function_manager from mythril.disassembler.disassembly import Disassembly from mythril.laser.ethereum.transaction.concolic import execute_message_call from mythril.laser.smt import Expression, BitVec, symbol_factory -from mythril.analysis.solver import get_model from datetime import datetime import binascii @@ -117,7 +117,6 @@ def test_vmtest( # Arrange if test_name in ignored_test_names: return - world_state = WorldState() for address, details in pre_condition.items(): @@ -178,7 +177,15 @@ def test_vmtest( expected = int(value, 16) actual = account.storage[symbol_factory.BitVecVal(int(index, 16), 256)] if isinstance(actual, Expression): - actual = actual.value + if ( + actual.symbolic + and actual in keccak_function_manager.quick_inverse + ): + actual = keccak_function_manager.find_concrete_keccak( + keccak_function_manager.quick_inverse[actual] + ) + else: + actual = actual.value actual = 1 if actual is True else 0 if actual is False else actual else: if type(actual) == bytes: diff --git a/tests/laser/keccak_tests.py b/tests/laser/keccak_tests.py new file mode 100644 index 00000000..8a3ae00e --- /dev/null +++ b/tests/laser/keccak_tests.py @@ -0,0 +1,138 @@ +from mythril.laser.smt import Solver, symbol_factory, And +from mythril.laser.ethereum.keccak_function_manager import keccak_function_manager +import z3 +import pytest + + +@pytest.mark.parametrize( + "input1, input2, expected", + [ + (symbol_factory.BitVecVal(100, 8), symbol_factory.BitVecVal(101, 8), z3.unsat), + (symbol_factory.BitVecVal(100, 8), symbol_factory.BitVecVal(100, 16), z3.unsat), + (symbol_factory.BitVecVal(100, 8), symbol_factory.BitVecVal(100, 8), z3.sat), + ( + symbol_factory.BitVecSym("N1", 256), + symbol_factory.BitVecSym("N2", 256), + z3.sat, + ), + ( + symbol_factory.BitVecVal(100, 256), + symbol_factory.BitVecSym("N1", 256), + z3.sat, + ), + ( + symbol_factory.BitVecVal(100, 8), + symbol_factory.BitVecSym("N1", 256), + z3.unsat, + ), + ], +) +def test_keccak_basic(input1, input2, expected): + s = Solver() + + o1, c1 = keccak_function_manager.create_keccak(input1) + o2, c2 = keccak_function_manager.create_keccak(input2) + s.add(And(c1, c2)) + + s.add(o1 == o2) + assert s.check() == expected + + +def test_keccak_symbol_and_val(): + """ + check keccak(100) == keccak(n) && n == 10 + :return: + """ + s = Solver() + hundred = symbol_factory.BitVecVal(100, 256) + n = symbol_factory.BitVecSym("n", 256) + o1, c1 = keccak_function_manager.create_keccak(hundred) + o2, c2 = keccak_function_manager.create_keccak(n) + s.add(And(c1, c2)) + s.add(o1 == o2) + s.add(n == symbol_factory.BitVecVal(10, 256)) + assert s.check() == z3.unsat + + +def test_keccak_complex_eq(): + """ + check for keccak(keccak(b)*2) == keccak(keccak(a)*2) && a != b + :return: + """ + s = Solver() + a = symbol_factory.BitVecSym("a", 160) + b = symbol_factory.BitVecSym("b", 160) + o1, c1 = keccak_function_manager.create_keccak(a) + o2, c2 = keccak_function_manager.create_keccak(b) + s.add(And(c1, c2)) + two = symbol_factory.BitVecVal(2, 256) + o1 = two * o1 + o2 = two * o2 + o1, c1 = keccak_function_manager.create_keccak(o1) + o2, c2 = keccak_function_manager.create_keccak(o2) + + s.add(And(c1, c2)) + s.add(o1 == o2) + s.add(a != b) + + assert s.check() == z3.unsat + + +def test_keccak_complex_eq2(): + """ + check for keccak(keccak(b)*2) == keccak(keccak(a)*2) + This isn't combined with prev test because incremental solving here requires extra-extra work + (solution is literally the opposite of prev one) so it will take forever to solve. + :return: + """ + s = Solver() + a = symbol_factory.BitVecSym("a", 160) + b = symbol_factory.BitVecSym("b", 160) + o1, c1 = keccak_function_manager.create_keccak(a) + o2, c2 = keccak_function_manager.create_keccak(b) + s.add(And(c1, c2)) + two = symbol_factory.BitVecVal(2, 256) + o1 = two * o1 + o2 = two * o2 + o1, c1 = keccak_function_manager.create_keccak(o1) + o2, c2 = keccak_function_manager.create_keccak(o2) + + s.add(And(c1, c2)) + s.add(o1 == o2) + + assert s.check() == z3.sat + + +def test_keccak_simple_number(): + """ + check for keccak(b) == 10 + :return: + """ + s = Solver() + a = symbol_factory.BitVecSym("a", 160) + ten = symbol_factory.BitVecVal(10, 256) + o, c = keccak_function_manager.create_keccak(a) + + s.add(c) + s.add(ten == o) + + assert s.check() == z3.unsat + + +def test_keccak_other_num(): + """ + check keccak(keccak(a)*2) == b + :return: + """ + s = Solver() + a = symbol_factory.BitVecSym("a", 160) + b = symbol_factory.BitVecSym("b", 256) + o, c = keccak_function_manager.create_keccak(a) + two = symbol_factory.BitVecVal(2, 256) + o = two * o + s.add(c) + o, c = keccak_function_manager.create_keccak(o) + s.add(c) + s.add(b == o) + + assert s.check() == z3.sat