diff --git a/mythril/analysis/solver.py b/mythril/analysis/solver.py index 89b3db6e..93e60106 100644 --- a/mythril/analysis/solver.py +++ b/mythril/analysis/solver.py @@ -1,8 +1,10 @@ """This module contains analysis module helpers to solve path constraints.""" -import logging -from typing import Dict, List, Tuple, Union + +from typing import Dict, List, Tuple, Union, Any import z3 +import logging + from z3 import FuncInterp from mythril.exceptions import UnsatError @@ -48,8 +50,10 @@ def pretty_print_model(model): def get_transaction_sequence( global_state: GlobalState, constraints: Constraints -) -> Dict: +) -> Dict[str, Any]: """Generate concrete transaction sequence. + Note: This function only considers the constraints in constraint argument, + which in some cases is expected to differ from global_state's constraints :param global_state: GlobalState to generate transaction sequence for :param constraints: list of constraints used to generate transaction sequence @@ -66,16 +70,19 @@ def get_transaction_sequence( model = get_model(tx_constraints, minimize=minimize) except UnsatError: raise UnsatError - # Include creation account in initial state - # Note: This contains the code, which should not exist until after the first tx - initial_world_state = transaction_sequence[0].world_state + + if isinstance(transaction_sequence[0], ContractCreationTransaction): + initial_world_state = transaction_sequence[0].prev_world_state + else: + initial_world_state = transaction_sequence[0].world_state + initial_accounts = initial_world_state.accounts for transaction in transaction_sequence: concrete_transaction = _get_concrete_transaction(model, transaction) concrete_transactions.append(concrete_transaction) - min_price_dict = {} # type: Dict[str, int] + min_price_dict: Dict[str, int] = {} for address in initial_accounts.keys(): min_price_dict[address] = model.eval( initial_world_state.starting_balances[ @@ -159,13 +166,15 @@ def _replace_with_actual_sha( ) -def _get_concrete_state(initial_accounts: Dict, min_price_dict: Dict[str, int]): +def _get_concrete_state( + initial_accounts: Dict, min_price_dict: Dict[str, int] +) -> Dict[str, Dict]: """Gets a concrete state""" accounts = {} for address, account in initial_accounts.items(): # Skip empty default account - data = dict() # type: Dict[str, Union[int, str]] + data: Dict[str, Union[int, str]] = {} data["nonce"] = account.nonce data["code"] = account.serialised_code() data["storage"] = str(account.storage) diff --git a/mythril/concolic/__init__.py b/mythril/concolic/__init__.py new file mode 100644 index 00000000..5f9fa523 --- /dev/null +++ b/mythril/concolic/__init__.py @@ -0,0 +1,2 @@ +from mythril.concolic.concolic_execution import concolic_execution +from mythril.concolic.find_trace import concrete_execution diff --git a/mythril/concolic/concolic_execution.py b/mythril/concolic/concolic_execution.py new file mode 100644 index 00000000..4dfed873 --- /dev/null +++ b/mythril/concolic/concolic_execution.py @@ -0,0 +1,85 @@ +import json +import binascii + +from datetime import datetime, timedelta +from typing import Dict, List, Any +from copy import deepcopy + +from mythril.concolic.concrete_data import ConcreteData +from mythril.concolic.find_trace import concrete_execution +from mythril.disassembler.disassembly import Disassembly +from mythril.laser.ethereum.strategy.concolic import ConcolicStrategy +from mythril.laser.ethereum.svm import LaserEVM +from mythril.laser.ethereum.state.world_state import WorldState +from mythril.laser.ethereum.state.account import Account +from mythril.laser.ethereum.transaction.symbolic import execute_transaction +from mythril.laser.ethereum.transaction.transaction_models import tx_id_manager +from mythril.laser.smt import Expression, BitVec, symbol_factory +from mythril.laser.ethereum.time_handler import time_handler +from mythril.support.support_args import args + + +def flip_branches( + init_state: WorldState, + concrete_data: ConcreteData, + jump_addresses: List[str], + trace: List, +) -> List[Dict[str, Dict[str, Any]]]: + """ + Flips branches and prints the input required for branch flip + + :param concrete_data: Concrete data + :param jump_addresses: Jump addresses to flip + :param trace: trace to follow + """ + tx_id_manager.restart_counter() + output_list = [] + laser_evm = LaserEVM( + execution_timeout=600, use_reachability_check=False, transaction_count=10 + ) + laser_evm.open_states = [deepcopy(init_state)] + laser_evm.strategy = ConcolicStrategy( + work_list=laser_evm.work_list, + max_depth=100, + trace=trace, + flip_branch_addresses=jump_addresses, + ) + + time_handler.start_execution(laser_evm.execution_timeout) + laser_evm.time = datetime.now() + + for transaction in concrete_data["steps"]: + execute_transaction( + laser_evm, + callee_address=transaction["address"], + caller_address=symbol_factory.BitVecVal( + int(transaction["origin"], 16), 256 + ), + data=transaction["input"][2:], + ) + + if laser_evm.strategy.results: + for addr in jump_addresses: + output_list.append(laser_evm.strategy.results[addr]) + return output_list + + +def concolic_execution( + concrete_data: ConcreteData, jump_addresses: List, solver_timeout=100000 +) -> List[Dict[str, Dict[str, Any]]]: + """ + Executes codes and prints input required to cover the branch flips + :param input_file: Input file + :param jump_addresses: Jump addresses to flip + :param solver_timeout: Solver timeout + + """ + init_state, trace = concrete_execution(concrete_data) + args.solver_timeout = solver_timeout + output_list = flip_branches( + init_state=init_state, + concrete_data=concrete_data, + jump_addresses=jump_addresses, + trace=trace, + ) + return output_list diff --git a/mythril/concolic/concrete_data.py b/mythril/concolic/concrete_data.py new file mode 100644 index 00000000..c2c4f02f --- /dev/null +++ b/mythril/concolic/concrete_data.py @@ -0,0 +1,34 @@ +from typing import Dict, List +from typing_extensions import TypedDict + + +class AccountData(TypedDict): + balance: str + code: str + nonce: int + storage: dict + + +class InitialState(TypedDict): + accounts: Dict[str, AccountData] + + +class TransactionData(TypedDict): + address: str + blockCoinbase: str + blockDifficulty: str + blockGasLimit: str + blockNumber: str + blockTime: str + calldata: str + gasLimit: str + gasPrice: str + input: str + name: str + origin: str + value: str + + +class ConcreteData(TypedDict): + initialState: InitialState + steps: List[TransactionData] diff --git a/mythril/concolic/find_trace.py b/mythril/concolic/find_trace.py new file mode 100644 index 00000000..9738de99 --- /dev/null +++ b/mythril/concolic/find_trace.py @@ -0,0 +1,78 @@ +import json +import binascii + +from copy import deepcopy +from datetime import datetime +from typing import Dict, List, Tuple + +from mythril.concolic.concrete_data import ConcreteData + +from mythril.disassembler.disassembly import Disassembly +from mythril.laser.ethereum.svm import LaserEVM +from mythril.laser.ethereum.state.world_state import WorldState +from mythril.laser.ethereum.state.account import Account +from mythril.laser.ethereum.transaction.concolic import execute_transaction +from mythril.laser.plugin.loader import LaserPluginLoader +from mythril.laser.smt import Expression, BitVec, symbol_factory +from mythril.laser.ethereum.transaction.transaction_models import tx_id_manager +from mythril.plugin.discovery import PluginDiscovery + + +def setup_concrete_initial_state(concrete_data: ConcreteData) -> WorldState: + """ + Sets up concrete initial state + :param concrete_data: Concrete data + :return: initialised world state + """ + world_state = WorldState() + for address, details in concrete_data["initialState"]["accounts"].items(): + account = Account(address, concrete_storage=True) + account.code = Disassembly(details["code"][2:]) + account.nonce = details["nonce"] + if type(details["storage"]) == str: + details["storage"] = eval(details["storage"]) # type: ignore + for key, value in details["storage"].items(): + key_bitvec = symbol_factory.BitVecVal(int(key, 16), 256) + account.storage[key_bitvec] = symbol_factory.BitVecVal(int(value, 16), 256) + + world_state.put_account(account) + account.set_balance(int(details["balance"], 16)) + return world_state + + +def concrete_execution(concrete_data: ConcreteData) -> Tuple[WorldState, List]: + """ + Executes code concretely to find the path to be followed by concolic executor + :param concrete_data: Concrete data + :return: path trace + """ + tx_id_manager.restart_counter() + init_state = setup_concrete_initial_state(concrete_data) + laser_evm = LaserEVM(execution_timeout=1000) + laser_evm.open_states = [deepcopy(init_state)] + plugin_loader = LaserPluginLoader() + assert PluginDiscovery().is_installed("myth_concolic_execution") + trace_plugin = PluginDiscovery().installed_plugins["myth_concolic_execution"]() + + plugin_loader.load(trace_plugin) + laser_evm.time = datetime.now() + plugin_loader.instrument_virtual_machine(laser_evm, None) + for transaction in concrete_data["steps"]: + execute_transaction( + laser_evm, + callee_address=transaction["address"], + caller_address=symbol_factory.BitVecVal( + int(transaction["origin"], 16), 256 + ), + origin_address=symbol_factory.BitVecVal( + int(transaction["origin"], 16), 256 + ), + gas_limit=int(transaction.get("gasLimit", "0x9999999999999999999999"), 16), + data=binascii.a2b_hex(transaction["input"][2:]), + gas_price=int(transaction.get("gasPrice", "0x773594000"), 16), + value=int(transaction["value"], 16), + track_gas=False, + ) + + tx_id_manager.restart_counter() + return init_state, plugin_loader.plugin_list["MythX Trace Finder"].tx_trace # type: ignore diff --git a/mythril/disassembler/disassembly.py b/mythril/disassembler/disassembly.py index 6c96803c..9f31f2a3 100644 --- a/mythril/disassembler/disassembly.py +++ b/mythril/disassembler/disassembly.py @@ -23,8 +23,10 @@ class Disassembly(object): :param enable_online_lookup: """ self.bytecode = code - self.instruction_list = asm.disassemble(util.safe_decode(code)) - + if type(code) == str: + self.instruction_list = asm.disassemble(util.safe_decode(code)) + else: + self.instruction_list = asm.disassemble(code) self.func_hashes = [] # type: List[str] self.function_name_to_address = {} # type: Dict[str, int] self.address_to_function_name = {} # type: Dict[int, str] diff --git a/mythril/exceptions.py b/mythril/exceptions.py index 3b3c968c..cb708213 100644 --- a/mythril/exceptions.py +++ b/mythril/exceptions.py @@ -39,3 +39,9 @@ class DetectorNotFoundError(MythrilBaseException): detection module.""" pass + + +class IllegalArgumentError(ValueError): + """The argument used does not exist""" + + pass diff --git a/mythril/interfaces/cli.py b/mythril/interfaces/cli.py index 2edb1603..8e260269 100644 --- a/mythril/interfaces/cli.py +++ b/mythril/interfaces/cli.py @@ -16,12 +16,15 @@ import traceback import mythril.support.signatures as sigs from argparse import ArgumentParser, Namespace, RawTextHelpFormatter +from mythril.concolic import concolic_execution from mythril.exceptions import ( DetectorNotFoundError, CriticalError, ) from mythril.laser.ethereum.transaction.symbolic import ACTORS +from mythril.plugin.discovery import PluginDiscovery from mythril.plugin.loader import MythrilPluginLoader + from mythril.mythril import MythrilAnalyzer, MythrilDisassembler, MythrilConfig from mythril.analysis.module import ModuleLoader @@ -34,6 +37,8 @@ _ = MythrilPluginLoader() ANALYZE_LIST = ("analyze", "a") DISASSEMBLE_LIST = ("disassemble", "d") + +CONCOLIC_LIST = ("concolic", "c") SAFE_FUNCTIONS_COMMAND = "safe-functions" READ_STORAGE_COMNAND = "read-storage" FUNCTION_TO_HASH_COMMAND = "function-to-hash" @@ -47,6 +52,7 @@ log = logging.getLogger(__name__) COMMAND_LIST = ( ANALYZE_LIST + DISASSEMBLE_LIST + + CONCOLIC_LIST + ( READ_STORAGE_COMNAND, SAFE_FUNCTIONS_COMMAND, @@ -203,6 +209,29 @@ def get_utilities_parser() -> ArgumentParser: return parser +def create_concolic_parser(parser: ArgumentParser) -> ArgumentParser: + """ + Get parser which handles arguments for concolic branch flipping + """ + parser.add_argument( + "input", + help="The input jsonv2 file with concrete data", + ) + parser.add_argument( + "--branches", + help="branch addresses to be flipped. usage: --branches 34,6f8,16a", + required=True, + metavar="BRANCH", + ) + parser.add_argument( + "--solver-timeout", + type=int, + default=100000, + help="The maximum amount of time(in milli seconds) the solver spends for queries from analysis modules", + ) + return parser + + def main() -> None: """The main CLI interface entry point.""" @@ -211,6 +240,7 @@ def main() -> None: runtime_input_parser = get_runtime_input_parser() creation_input_parser = get_creation_input_parser() output_parser = get_output_parser() + parser = argparse.ArgumentParser( description="Security analysis of Ethereum smart contracts" ) @@ -262,6 +292,16 @@ def main() -> None: ) create_disassemble_parser(disassemble_parser) + if PluginDiscovery().is_installed("myth_concolic_execution"): + concolic_parser = subparsers.add_parser( + CONCOLIC_LIST[0], + help="Runs concolic execution to flip the desired branches", + aliases=CONCOLIC_LIST[1:], + parents=[], + formatter_class=RawTextHelpFormatter, + ) + create_concolic_parser(concolic_parser) + subparsers.add_parser( LIST_DETECTORS_COMMAND, parents=[output_parser], @@ -936,6 +976,15 @@ def parse_args_and_execute(parser: ArgumentParser, args: Namespace) -> None: parser.print_help() sys.exit() + if args.command in CONCOLIC_LIST: + with open(args.input) as f: + concrete_data = json.load(f) + output_list = concolic_execution( + concrete_data, args.branches.split(","), args.solver_timeout + ) + json.dump(output_list, sys.stdout, indent=4) + sys.exit() + # Parse cmdline args validate_args(args) try: diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 40cf81cb..8037ef77 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -56,7 +56,7 @@ from mythril.laser.ethereum.transaction import ( MessageCallTransaction, TransactionStartSignal, ContractCreationTransaction, - get_next_transaction_id, + tx_id_manager, ) from mythril.support.support_utils import get_code_hash @@ -194,6 +194,7 @@ class StateTransition(object): new_global_states = [ self.accumulate_gas(state) for state in new_global_states ] + return self.increment_states_pc(new_global_states) return wrapper @@ -726,7 +727,6 @@ class Instruction: op1 = state.stack.pop() op2 = state.stack.pop() - if isinstance(op1, Bool): op1 = If( op1, symbol_factory.BitVecVal(1, 256), symbol_factory.BitVecVal(0, 256) @@ -1599,6 +1599,7 @@ class Instruction: new_state.mstate.max_gas_used += max_gas # manually increment PC + new_state.mstate.depth += 1 new_state.mstate.pc += 1 new_state.world_state.constraints.append(negated) @@ -1606,8 +1607,6 @@ class Instruction: else: log.debug("Pruned unreachable states.") - # True case - # Get jump destination index = util.get_instruction_index(disassembly.instruction_list, jump_addr) @@ -1750,7 +1749,7 @@ class Instruction: code_str = bytes.hex(bytes(code_raw)) - next_transaction_id = get_next_transaction_id() + next_transaction_id = tx_id_manager.get_next_tx_id() constructor_arguments = ConcreteCalldata( next_transaction_id, call_data[code_end:] ) diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index 53aff3f3..bc79c7bf 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -109,17 +109,18 @@ class Account: balances: Array = None, concrete_storage=False, dynamic_loader=None, + nonce=0, ) -> None: """Constructor for account. :param address: Address of the account :param code: The contract code of the account :param contract_name: The name associated with the account - :param balance: The balance for the account + :param balances: The balance for the account :param concrete_storage: Interpret storage as concrete """ self.concrete_storage = concrete_storage - self.nonce = 0 + self.nonce = nonce self.code = code or Disassembly("") self.address = ( address @@ -162,6 +163,16 @@ class Account: assert self._balances is not None self._balances[self.address] = balance + def set_storage(self, storage: Dict): + """ + Sets concrete storage + """ + for key, value in storage.items(): + concrete_key, concrete_value = int(key, 16), int(value, 16) + self.storage[ + symbol_factory.BitVecVal(concrete_key, 256) + ] = symbol_factory.BitVecVal(concrete_value, 256) + def add_balance(self, balance: Union[int, BitVec]) -> None: """ @@ -205,6 +216,7 @@ class Account: contract_name=self.contract_name, balances=self._balances, concrete_storage=self.concrete_storage, + nonce=self.nonce, ) new_account.storage = deepcopy(self.storage) new_account.code = self.code diff --git a/mythril/laser/ethereum/state/constraints.py b/mythril/laser/ethereum/state/constraints.py index 4ca59bf3..1c6a61ec 100644 --- a/mythril/laser/ethereum/state/constraints.py +++ b/mythril/laser/ethereum/state/constraints.py @@ -47,13 +47,6 @@ class Constraints(list): ) super(Constraints, self).append(constraint) - def pop(self, index: int = -1) -> None: - """ - - :param index: Index to be popped from the list - """ - raise NotImplementedError - @property def as_list(self) -> List[Bool]: """ diff --git a/mythril/laser/ethereum/state/machine_state.py b/mythril/laser/ethereum/state/machine_state.py index ac87eaf2..bf3c043e 100644 --- a/mythril/laser/ethereum/state/machine_state.py +++ b/mythril/laser/ethereum/state/machine_state.py @@ -107,7 +107,6 @@ class MachineState: depth=0, max_gas_used=0, min_gas_used=0, - prev_pc=-1, ) -> None: """Constructor for machineState. @@ -119,9 +118,8 @@ class MachineState: :param depth: :param max_gas_used: :param min_gas_used: - :param prev_pc: """ - self._pc = pc + self.pc = pc self.stack = MachineStack(stack) self.subroutine_stack = MachineStack(subroutine_stack) self.memory = memory or Memory() @@ -129,7 +127,6 @@ class MachineState: self.min_gas_used = min_gas_used # lower gas usage bound self.max_gas_used = max_gas_used # upper gas usage bound self.depth = depth - self.prev_pc = prev_pc # holds context of current pc def calculate_extension_size(self, start: int, size: int) -> int: """ @@ -225,11 +222,10 @@ class MachineState: gas_limit=self.gas_limit, max_gas_used=self.max_gas_used, min_gas_used=self.min_gas_used, - pc=self._pc, + pc=self.pc, stack=copy(self.stack), memory=copy(self.memory), depth=self.depth, - prev_pc=self.prev_pc, subroutine_stack=copy(self.subroutine_stack), ) @@ -240,19 +236,6 @@ class MachineState: """ return str(self.as_dict) - @property - def pc(self) -> int: - """ - - :return: - """ - return self._pc - - @pc.setter - def pc(self, value): - self.prev_pc = self._pc - self._pc = value - @property def memory_size(self) -> int: """ @@ -268,7 +251,7 @@ class MachineState: :return: """ return dict( - pc=self._pc, + pc=self.pc, stack=self.stack, subroutine_stack=self.subroutine_stack, memory=self.memory, @@ -276,5 +259,4 @@ class MachineState: gas=self.gas_limit, max_gas_used=self.max_gas_used, min_gas_used=self.min_gas_used, - prev_pc=self.prev_pc, ) diff --git a/mythril/laser/ethereum/state/world_state.py b/mythril/laser/ethereum/state/world_state.py index 017633cd..44237770 100644 --- a/mythril/laser/ethereum/state/world_state.py +++ b/mythril/laser/ethereum/state/world_state.py @@ -125,6 +125,7 @@ class WorldState: dynamic_loader=None, creator=None, code=None, + nonce=0, ) -> Account: """Create non-contract account. @@ -134,14 +135,21 @@ class WorldState: :param dynamic_loader: used for dynamically loading storage from the block chain :param creator: The address of the creator of the contract if it's a contract :param code: The code of the contract, if it's a contract + :param nonce: Nonce of the account :return: The new account """ + if creator in self.accounts: + nonce = self.accounts[creator].nonce + elif creator: + self.create_account(address=creator) + address = ( symbol_factory.BitVecVal(address, 256) - if address - else self._generate_new_address(creator) + if address is not None + else self._generate_new_address(creator, nonce=self.accounts[creator].nonce) ) - + if creator: + self.accounts[creator].nonce += 1 new_account = Account( address=address, balances=self.balances, @@ -150,7 +158,7 @@ class WorldState: ) if code: new_account.code = code - + new_account.nonce = nonce new_account.set_balance(symbol_factory.BitVecVal(balance, 256)) self.put_account(new_account) @@ -197,14 +205,14 @@ class WorldState: """ return filter(lambda x: isinstance(x, annotation_type), self.annotations) - def _generate_new_address(self, creator=None) -> BitVec: + def _generate_new_address(self, creator=None, nonce=0) -> BitVec: """Generates a new address for the global state. :return: """ if creator: # TODO: Use nounce - address = "0x" + str(generate_contract_address(creator, 0).hex()) + address = "0x" + str(generate_contract_address(creator, nonce).hex()) return symbol_factory.BitVecVal(int(address, 16), 256) while True: address = "0x" + "".join([str(hex(randint(0, 16)))[-1] for _ in range(40)]) diff --git a/mythril/laser/ethereum/strategy/__init__.py b/mythril/laser/ethereum/strategy/__init__.py index 140880b0..76e14725 100644 --- a/mythril/laser/ethereum/strategy/__init__.py +++ b/mythril/laser/ethereum/strategy/__init__.py @@ -4,9 +4,9 @@ from mythril.laser.ethereum.state.global_state import GlobalState class BasicSearchStrategy(ABC): - """""" - - __slots__ = "work_list", "max_depth" + """ + A basic search strategy which halts based on depth + """ def __init__(self, work_list, max_depth): self.work_list = work_list # type: List[GlobalState] @@ -26,5 +26,26 @@ class BasicSearchStrategy(ABC): if global_state.mstate.depth >= self.max_depth: return self.__next__() return global_state - except IndexError: + except (IndexError, StopIteration): + raise StopIteration + + +class CriterionSearchStrategy(BasicSearchStrategy): + """ + If a criterion is satisfied, the search halts + """ + + def __init__(self, work_list, max_depth): + super().__init__(work_list, max_depth) + self._satisfied_criterion = False + + def get_strategic_global_state(self): + if self._satisfied_criterion: + raise StopIteration + try: + global_state = self.get_strategic_global_state() + except StopIteration: raise StopIteration + + def set_criterion_satisfied(self): + self._satisfied_criterion = True diff --git a/mythril/laser/ethereum/strategy/basic.py b/mythril/laser/ethereum/strategy/basic.py index 5930f4e6..3bdd36ec 100644 --- a/mythril/laser/ethereum/strategy/basic.py +++ b/mythril/laser/ethereum/strategy/basic.py @@ -4,33 +4,7 @@ from typing import List from mythril.laser.ethereum.state.global_state import GlobalState from . import BasicSearchStrategy - -try: - from random import choices -except ImportError: - - # This is for supporting python versions < 3.6 - from itertools import accumulate - from random import random - from bisect import bisect - - # TODO: Remove ignore after this has been fixed: https://github.com/python/mypy/issues/1297 - def choices( # type: ignore - population: List, weights: List[int] = None - ) -> List[int]: - """Returns a random element out of the population based on weight. - - If the relative weights or cumulative weights are not specified, - the selections are made with equal probability. - """ - if weights is None: - return [population[int(random() * len(population))]] - cum_weights = list(accumulate(weights)) - return [ - population[ - bisect(cum_weights, random() * cum_weights[-1], 0, len(population) - 1) - ] - ] +from random import choices class DepthFirstSearchStrategy(BasicSearchStrategy): diff --git a/mythril/laser/ethereum/strategy/concolic.py b/mythril/laser/ethereum/strategy/concolic.py new file mode 100644 index 00000000..d3b5e9cc --- /dev/null +++ b/mythril/laser/ethereum/strategy/concolic.py @@ -0,0 +1,133 @@ +from mythril.laser.ethereum.state.global_state import GlobalState +from mythril.laser.ethereum.state.constraints import Constraints +from mythril.laser.ethereum.strategy.basic import BasicSearchStrategy +from mythril.laser.ethereum.state.annotation import StateAnnotation +from mythril.laser.ethereum.transaction import ContractCreationTransaction +from mythril.laser.ethereum.util import get_instruction_index +from mythril.analysis.solver import get_transaction_sequence +from mythril.laser.smt import Not +from mythril.exceptions import UnsatError + +from functools import reduce +from typing import Dict, cast, List, Any, Tuple +from copy import copy +from . import CriterionSearchStrategy +import logging +import operator + +log = logging.getLogger(__name__) + + +class TraceAnnotation(StateAnnotation): + """ + This is the annotation used by the ConcolicStrategy to store concolic traces. + """ + + def __init__(self, trace=None): + self.trace = trace or [] + + @property + def persist_over_calls(self) -> bool: + return True + + def __copy__(self): + return TraceAnnotation(copy(self.trace)) + + +class ConcolicStrategy(CriterionSearchStrategy): + """ + Executes program concolically using the input trace till a specific branch + """ + + def __init__( + self, + work_list: List[GlobalState], + max_depth: int, + trace: List[List[Tuple[int, str]]], + flip_branch_addresses: List[str], + ): + """ + + work_list: The work-list of states + max_depth: The maximum depth for the strategy to go through + trace: This is the trace to be followed, each element is of the type Tuple(program counter, tx_id) + flip_branch_addresses: Branch addresses to be flipped. + """ + super().__init__(work_list, max_depth) + self.trace: List[Tuple[int, str]] = reduce(operator.iconcat, trace, []) + self.last_tx_count: int = len(trace) + self.flip_branch_addresses: List[str] = flip_branch_addresses + self.results: Dict[str, Dict[str, Any]] = {} + + def check_completion_criterion(self): + if len(self.flip_branch_addresses) == len(self.results): + self.set_criterion_satisfied() + + def get_strategic_global_state(self) -> GlobalState: + """ + This function does the following:- + 1) Choose the states by following the concolic trace. + 2) In case we have an executed JUMPI that is in flip_branch_addresses, flip that branch. + :return: + """ + while len(self.work_list) > 0: + state = self.work_list.pop() + seq_id = len(state.world_state.transaction_sequence) + + trace_annotations = cast( + List[TraceAnnotation], + list(state.world_state.get_annotations(TraceAnnotation)), + ) + + if len(trace_annotations) == 0: + annotation = TraceAnnotation() + state.world_state.annotate(annotation) + else: + annotation = trace_annotations[0] + + # Appends trace + annotation.trace.append((state.mstate.pc, state.current_transaction.id)) + + # If length of trace is 1 then it is not a JUMPI + if len(annotation.trace) < 2: + # If trace does not follow the specified path, ignore the state + if annotation.trace != self.trace[: len(annotation.trace)]: + continue + return state + + # Get the address of the previous pc + addr: str = str( + state.environment.code.instruction_list[annotation.trace[-2][0]][ + "address" + ] + ) + if ( + annotation.trace == self.trace[: len(annotation.trace)] + and seq_id == self.last_tx_count + and addr in self.flip_branch_addresses + and addr not in self.results + ): + if ( + state.environment.code.instruction_list[annotation.trace[-2][0]][ + "opcode" + ] + != "JUMPI" + ): + log.error( + f"The branch {addr} does not lead " + "to a jump address, skipping this branch" + ) + continue + + constraints = Constraints(state.world_state.constraints[:-1]) + constraints.append(Not(state.world_state.constraints[-1])) + + try: + self.results[addr] = get_transaction_sequence(state, constraints) + except UnsatError: + self.results[addr] = None + elif annotation.trace != self.trace[: len(annotation.trace)]: + continue + self.check_completion_criterion() + return state + raise StopIteration diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index 85aa7434..cfe6c55b 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -60,6 +60,7 @@ class LaserEVM: transaction_count=2, requires_statespace=True, iprof=None, + use_reachability_check=True, ) -> None: """ Initializes the laser evm object @@ -73,14 +74,15 @@ class LaserEVM: :param requires_statespace: Variable indicating whether the statespace should be recorded :param iprof: Instruction Profiler """ - self.execution_info = [] # type: List[ExecutionInfo] + self.execution_info: List[ExecutionInfo] = [] - self.open_states = [] # type: List[WorldState] + self.open_states: List[WorldState] = [] self.total_states = 0 self.dynamic_loader = dynamic_loader + self.use_reachability_check = use_reachability_check # TODO: What about using a deque here? - self.work_list = [] # type: List[GlobalState] + self.work_list: List[GlobalState] = [] self.strategy = strategy(self.work_list, max_depth) self.max_depth = max_depth self.transaction_count = transaction_count @@ -90,30 +92,45 @@ class LaserEVM: self.requires_statespace = requires_statespace if self.requires_statespace: - self.nodes = {} # type: Dict[int, Node] - self.edges = [] # type: List[Edge] + self.nodes: Dict[int, Node] = {} + self.edges: List[Edge] = [] - self.time = None # type: datetime + self.time: datetime = None - self.pre_hooks = defaultdict(list) # type: DefaultDict[str, List[Callable]] - self.post_hooks = defaultdict(list) # type: DefaultDict[str, List[Callable]] + self.pre_hooks: DefaultDict[str, List[Callable]] = defaultdict(list) + self.post_hooks: DefaultDict[str, List[Callable]] = defaultdict(list) - self._add_world_state_hooks = [] # type: List[Callable] - self._execute_state_hooks = [] # type: List[Callable] + self._add_world_state_hooks: List[Callable] = [] + self._execute_state_hooks: List[Callable] = [] - self._start_sym_trans_hooks = [] # type: List[Callable] - self._stop_sym_trans_hooks = [] # type: List[Callable] + self._start_sym_trans_hooks: List[Callable] = [] + self._stop_sym_trans_hooks: List[Callable] = [] - self._start_sym_exec_hooks = [] # type: List[Callable] - self._stop_sym_exec_hooks = [] # type: List[Callable] + self._start_sym_exec_hooks: List[Callable] = [] + self._stop_sym_exec_hooks: List[Callable] = [] + + self._start_exec_hooks: List[Callable] = [] + self._stop_exec_hooks: List[Callable] = [] self._transaction_end_hooks: List[Callable] = [] + self.iprof = iprof - self.instr_pre_hook = {} # type: Dict[str, List[Callable]] - self.instr_post_hook = {} # type: Dict[str, List[Callable]] + self.instr_pre_hook: Dict[str, List[Callable]] = {} + self.instr_post_hook: Dict[str, List[Callable]] = {} for op in OPCODES: self.instr_pre_hook[op] = [] self.instr_post_hook[op] = [] + self.hook_type_map = { + "add_world_state": self._add_world_state_hooks, + "execute_state": self._execute_state_hooks, + "start_sym_exec": self._start_sym_exec_hooks, + "stop_sym_exec": self._stop_sym_exec_hooks, + "start_sym_trans": self._start_sym_trans_hooks, + "stop_sym_trans": self._stop_sym_trans_hooks, + "start_exec": self._start_exec_hooks, + "stop_exec": self._stop_exec_hooks, + "transaction_end": self._transaction_end_hooks, + } log.info("LASER EVM initialized with dynamic loader: " + str(dynamic_loader)) def extend_strategy(self, extension: ABCMeta, *args) -> None: @@ -200,12 +217,13 @@ class LaserEVM: if len(self.open_states) == 0: break old_states_count = len(self.open_states) - self.open_states = [ - state for state in self.open_states if state.constraints.is_possible - ] - prune_count = old_states_count - len(self.open_states) - if prune_count: - log.info("Pruned {} unreachable states".format(prune_count)) + if self.use_reachability_check: + self.open_states = [ + state for state in self.open_states if state.constraints.is_possible + ] + prune_count = old_states_count - len(self.open_states) + if prune_count: + log.info("Pruned {} unreachable states".format(prune_count)) log.info( "Starting message call transaction, iteration: {}, {} initial states".format( i, len(self.open_states) @@ -242,6 +260,8 @@ class LaserEVM: :return: """ final_states = [] # type: List[GlobalState] + for hook in self._start_exec_hooks: + hook() for global_state in self.strategy: if create and self._check_create_termination(): @@ -272,6 +292,9 @@ class LaserEVM: final_states.append(global_state) self.total_states += len(new_states) + for hook in self._stop_exec_hooks: + hook() + return final_states if track_gas else None def _add_world_state(self, global_state: GlobalState): @@ -581,20 +604,9 @@ class LaserEVM: def register_laser_hooks(self, hook_type: str, hook: Callable): """registers the hook with this Laser VM""" - if hook_type == "add_world_state": - self._add_world_state_hooks.append(hook) - elif hook_type == "execute_state": - self._execute_state_hooks.append(hook) - elif hook_type == "start_sym_exec": - self._start_sym_exec_hooks.append(hook) - elif hook_type == "stop_sym_exec": - self._stop_sym_exec_hooks.append(hook) - elif hook_type == "start_sym_trans": - self._start_sym_trans_hooks.append(hook) - elif hook_type == "stop_sym_trans": - self._stop_sym_trans_hooks.append(hook) - elif hook_type == "transaction_end": - self._transaction_end_hooks.append(hook) + + if hook_type in self.hook_type_map: + self.hook_type_map[hook_type].append(hook) else: raise ValueError(f"Invalid hook type {hook_type}") diff --git a/mythril/laser/ethereum/transaction/concolic.py b/mythril/laser/ethereum/transaction/concolic.py index 07d97816..dfc22d1d 100644 --- a/mythril/laser/ethereum/transaction/concolic.py +++ b/mythril/laser/ethereum/transaction/concolic.py @@ -1,27 +1,87 @@ """This module contains functions to set up and execute concolic message calls.""" +import binascii + from typing import List, Union +from copy import deepcopy +from mythril.exceptions import IllegalArgumentError from mythril.disassembler.disassembly import Disassembly from mythril.laser.ethereum.cfg import Node, Edge, JumpType +from mythril.laser.smt import symbol_factory +from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.calldata import ConcreteCalldata from mythril.laser.ethereum.state.global_state import GlobalState +from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.transaction.transaction_models import ( MessageCallTransaction, - get_next_transaction_id, + ContractCreationTransaction, + tx_id_manager, ) +def execute_contract_creation( + laser_evm, + callee_address, + caller_address, + origin_address, + data, + gas_limit, + gas_price, + value, + code=None, + track_gas=False, + contract_name=None, +): + """Executes a contract creation transaction concretely. + + :param laser_evm: + :param callee_address: + :param caller_address: + :param origin_address: + :param code: + :param data: + :param gas_limit: + :param gas_price: + :param value: + :param track_gas: + :return: + """ + + open_states: List[WorldState] = laser_evm.open_states[:] + del laser_evm.open_states[:] + + data = binascii.b2a_hex(data).decode("utf-8") + + for open_world_state in open_states: + next_transaction_id = tx_id_manager.get_next_tx_id() + transaction = ContractCreationTransaction( + world_state=open_world_state, + identifier=next_transaction_id, + gas_price=gas_price, + gas_limit=gas_limit, # block gas limit + origin=origin_address, + code=Disassembly(data), + caller=caller_address, + contract_name=contract_name, + call_data=None, + call_value=value, + ) + _setup_global_state_for_execution(laser_evm, transaction) + + return laser_evm.exec(True, track_gas=track_gas) + + def execute_message_call( laser_evm, callee_address, caller_address, origin_address, - code, data, gas_limit, gas_price, value, + code=None, track_gas=False, ) -> Union[None, List[GlobalState]]: """Execute a message call transaction from all open states. @@ -38,12 +98,12 @@ def execute_message_call( :param track_gas: :return: """ - # TODO: Resolve circular import between .transaction and ..svm to import LaserEVM here - open_states = laser_evm.open_states[:] - del laser_evm.open_states[:] + open_states: List[WorldState] = laser_evm.open_states[:] + del laser_evm.open_states[:] for open_world_state in open_states: - next_transaction_id = get_next_transaction_id() + next_transaction_id = tx_id_manager.get_next_tx_id() + code = code or open_world_state[callee_address].code.bytecode transaction = MessageCallTransaction( world_state=open_world_state, identifier=next_transaction_id, @@ -94,3 +154,21 @@ def _setup_global_state_for_execution(laser_evm, transaction) -> None: global_state.node = new_node new_node.states.append(global_state) laser_evm.work_list.append(global_state) + + +def execute_transaction(*args, **kwargs) -> Union[None, List[GlobalState]]: + """ + Chooses the transaction type based on callee address and + executes the transaction + """ + try: + if kwargs["callee_address"] == "": + if kwargs["caller_address"] == "": + kwargs["caller_address"] = kwargs["origin"] + return execute_contract_creation(*args, **kwargs) + kwargs["callee_address"] = symbol_factory.BitVecVal( + int(kwargs["callee_address"], 16), 256 + ) + except KeyError as k: + raise IllegalArgumentError(f"Argument not found: {k}") + return execute_message_call(*args, **kwargs) diff --git a/mythril/laser/ethereum/transaction/symbolic.py b/mythril/laser/ethereum/transaction/symbolic.py index c4e67fca..c4307266 100644 --- a/mythril/laser/ethereum/transaction/symbolic.py +++ b/mythril/laser/ethereum/transaction/symbolic.py @@ -2,6 +2,7 @@ symbolic values.""" import logging from typing import Optional +from copy import deepcopy from mythril.disassembler.disassembly import Disassembly from mythril.laser.ethereum.cfg import Node, Edge, JumpType @@ -11,9 +12,10 @@ from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.transaction.transaction_models import ( MessageCallTransaction, ContractCreationTransaction, - get_next_transaction_id, + tx_id_manager, BaseTransaction, ) +from typing import List, Union from mythril.laser.smt import symbol_factory, Or, BitVec log = logging.getLogger(__name__) @@ -82,7 +84,7 @@ def execute_message_call(laser_evm, callee_address: BitVec) -> None: log.debug("Can not execute dead contract, skipping.") continue - next_transaction_id = get_next_transaction_id() + next_transaction_id = tx_id_manager.get_next_tx_id() external_sender = symbol_factory.BitVecSym( "sender_{}".format(next_transaction_id), 256 @@ -109,7 +111,12 @@ def execute_message_call(laser_evm, callee_address: BitVec) -> None: def execute_contract_creation( - laser_evm, contract_initialization_code, contract_name=None, world_state=None + laser_evm, + contract_initialization_code, + contract_name=None, + world_state=None, + origin=ACTORS["CREATOR"], + caller=ACTORS["CREATOR"], ) -> Account: """Executes a contract creation transaction from all open states. @@ -118,14 +125,13 @@ def execute_contract_creation( :param contract_name: :return: """ - # TODO: Resolve circular import between .transaction and ..svm to import LaserEVM here - del laser_evm.open_states[:] world_state = world_state or WorldState() open_states = [world_state] + del laser_evm.open_states[:] new_account = None for open_world_state in open_states: - next_transaction_id = get_next_transaction_id() + next_transaction_id = tx_id_manager.get_next_tx_id() # call_data "should" be '[]', but it is easier to model the calldata symbolically # and add logic in codecopy/codesize/calldatacopy/calldatasize than to model code "correctly" transaction = ContractCreationTransaction( @@ -135,9 +141,9 @@ def execute_contract_creation( "gas_price{}".format(next_transaction_id), 256 ), gas_limit=8000000, # block gas limit - origin=ACTORS["CREATOR"], + origin=origin, code=Disassembly(contract_initialization_code), - caller=ACTORS["CREATOR"], + caller=caller, contract_name=contract_name, call_data=None, call_value=symbol_factory.BitVecSym( @@ -189,3 +195,24 @@ def _setup_global_state_for_execution(laser_evm, transaction: BaseTransaction) - global_state.node = new_node new_node.states.append(global_state) laser_evm.work_list.append(global_state) + + +def execute_transaction(*args, **kwargs): + """ + Chooses the transaction type based on callee address and + executes the transaction + """ + laser_evm = args[0] + if kwargs["callee_address"] == "": + for ws in laser_evm.open_states[:]: + execute_contract_creation( + laser_evm=laser_evm, + contract_initialization_code=kwargs["data"], + world_state=ws, + ) + return + + execute_message_call( + laser_evm=laser_evm, + callee_address=symbol_factory.BitVecVal(int(kwargs["callee_address"], 16), 256), + ) diff --git a/mythril/laser/ethereum/transaction/transaction_models.py b/mythril/laser/ethereum/transaction/transaction_models.py index ae8183dd..f354402e 100644 --- a/mythril/laser/ethereum/transaction/transaction_models.py +++ b/mythril/laser/ethereum/transaction/transaction_models.py @@ -4,7 +4,7 @@ execution.""" from copy import deepcopy from z3 import ExprRef from typing import Union, Optional - +from mythril.support.support_utils import Singleton from mythril.laser.ethereum.state.calldata import ConcreteCalldata from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.calldata import BaseCalldata, SymbolicCalldata @@ -16,17 +16,20 @@ import logging log = logging.getLogger(__name__) -_next_transaction_id = 0 +class TxIdManager(object, metaclass=Singleton): + def __init__(self): + self._next_transaction_id = 0 + + def get_next_tx_id(self): + self._next_transaction_id += 1 + return str(self._next_transaction_id) + + def restart_counter(self): + self._next_transaction_id = 0 -def get_next_transaction_id() -> str: - """ - :return: - """ - global _next_transaction_id - _next_transaction_id += 1 - return str(_next_transaction_id) +tx_id_manager = TxIdManager() class TransactionEndSignal(Exception): @@ -72,7 +75,7 @@ class BaseTransaction: ) -> None: assert isinstance(world_state, WorldState) self.world_state = world_state - self.id = identifier or get_next_transaction_id() + self.id = identifier or tx_id_manager.get_next_tx_id() self.gas_price = ( gas_price diff --git a/mythril/laser/plugin/loader.py b/mythril/laser/plugin/loader.py index 01b03fbc..f2cc1172 100644 --- a/mythril/laser/plugin/loader.py +++ b/mythril/laser/plugin/loader.py @@ -2,6 +2,7 @@ import logging from typing import Dict, List, Optional from mythril.laser.ethereum.svm import LaserEVM +from mythril.laser.plugin.interface import LaserPlugin from mythril.laser.plugin.builder import PluginBuilder from mythril.support.support_utils import Singleton @@ -18,6 +19,7 @@ class LaserPluginLoader(object, metaclass=Singleton): """Initializes the plugin loader""" self.laser_plugin_builders = {} # type: Dict[str, PluginBuilder] self.plugin_args = {} # type: Dict[str, Dict] + self.plugin_list = {} # type: Dict[str, LaserPlugin] def add_args(self, plugin_name, **kwargs): self.plugin_args[plugin_name] = kwargs @@ -70,3 +72,4 @@ class LaserPluginLoader(object, metaclass=Singleton): log.info(f"Instrumenting symbolic vm with plugin: {plugin_name}") plugin = plugin_builder(**self.plugin_args.get(plugin_name, {})) plugin.initialize(symbolic_vm) + self.plugin_list[plugin_name] = plugin diff --git a/mythril/mythril/mythril_disassembler.py b/mythril/mythril/mythril_disassembler.py index 36feaa29..b33a1044 100644 --- a/mythril/mythril/mythril_disassembler.py +++ b/mythril/mythril/mythril_disassembler.py @@ -89,6 +89,7 @@ class MythrilDisassembler: """ if address is None: address = util.get_indexed_address(0) + if bin_runtime: self.contracts.append( EVMContract( diff --git a/mythril/plugin/discovery.py b/mythril/plugin/discovery.py index 96eff676..f6d66905 100644 --- a/mythril/plugin/discovery.py +++ b/mythril/plugin/discovery.py @@ -36,7 +36,6 @@ class PluginDiscovery(object, metaclass=Singleton): raise ValueError(f"Plugin with name: `{plugin_name}` is not installed") plugin = self.installed_plugins.get(plugin_name) - if plugin is None or not issubclass(plugin, MythrilPlugin): raise ValueError(f"No valid plugin was found for {plugin_name}") diff --git a/mythril/plugin/loader.py b/mythril/plugin/loader.py index 3b611e60..5879f8ea 100644 --- a/mythril/plugin/loader.py +++ b/mythril/plugin/loader.py @@ -46,7 +46,6 @@ class MythrilPluginLoader(object, metaclass=Singleton): logging.info(f"Loading plugin: {plugin.name}") log.info(f"Loading plugin: {str(plugin)}") - if isinstance(plugin, DetectionModule): self._load_detection_module(plugin) elif isinstance(plugin, MythrilLaserPlugin):