diff --git a/mythril/laser/ethereum/iprof.py b/mythril/laser/ethereum/iprof.py deleted file mode 100644 index 83a49416..00000000 --- a/mythril/laser/ethereum/iprof.py +++ /dev/null @@ -1,78 +0,0 @@ -from collections import namedtuple -from datetime import datetime -from typing import Dict, List, Tuple - -# Type annotations: -# start_time: datetime -# end_time: datetime -_InstrExecRecord = namedtuple("_InstrExecRecord", ["start_time", "end_time"]) - -# Type annotations: -# total_time: float -# total_nr: float -# min_time: float -# max_time: float -_InstrExecStatistic = namedtuple( - "_InstrExecStatistic", ["total_time", "total_nr", "min_time", "max_time"] -) - -# Map the instruction opcode to its records if all execution times -_InstrExecRecords = Dict[str, List[_InstrExecRecord]] - -# Map the instruction opcode to the statistic of its execution times -_InstrExecStatistics = Dict[str, _InstrExecStatistic] - - -class InstructionProfiler: - """Performance profile for the execution of each instruction.""" - - def __init__(self): - self.records = dict() - - def record(self, op: int, start_time: datetime, end_time: datetime): - try: - self.records[op].append(_InstrExecRecord(start_time, end_time)) - except KeyError: - self.records[op] = [_InstrExecRecord(start_time, end_time)] - - def _make_stats(self) -> Tuple[float, _InstrExecStatistics]: - periods = { - op: list( - map(lambda r: r.end_time.timestamp() - r.start_time.timestamp(), rs) - ) - for op, rs in self.records.items() - } - - stats = dict() - total_time = 0 - - for _, (op, times) in enumerate(periods.items()): - stat = _InstrExecStatistic( - total_time=sum(times), - total_nr=len(times), - min_time=min(times), - max_time=max(times), - ) - total_time += stat.total_time - stats[op] = stat - - return total_time, stats - - def __str__(self): - total, stats = self._make_stats() - - s = "Total: {} s\n".format(total) - - for op in sorted(stats): - stat = stats[op] - s += "[{:12s}] {:>8.4f} %, nr {:>6}, total {:>8.4f} s, avg {:>8.4f} s, min {:>8.4f} s, max {:>8.4f} s\n".format( - op, - stat.total_time * 100 / total, - stat.total_nr, - stat.total_time, - stat.total_time / stat.total_nr, - stat.min_time, - stat.max_time, - ) - - return s diff --git a/mythril/laser/ethereum/state/global_state.py b/mythril/laser/ethereum/state/global_state.py index 22d5d414..6dd92aab 100644 --- a/mythril/laser/ethereum/state/global_state.py +++ b/mythril/laser/ethereum/state/global_state.py @@ -70,6 +70,7 @@ class GlobalState: mstate = deepcopy(self.mstate) transaction_stack = copy(self.transaction_stack) environment.active_account = world_state[environment.active_account.address] + return GlobalState( world_state, environment, diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index 00213041..85aa7434 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -107,6 +107,7 @@ class LaserEVM: self._start_sym_exec_hooks = [] # type: List[Callable] self._stop_sym_exec_hooks = [] # type: List[Callable] + self._transaction_end_hooks: List[Callable] = [] self.iprof = iprof self.instr_pre_hook = {} # type: Dict[str, List[Callable]] self.instr_post_hook = {} # type: Dict[str, List[Callable]] @@ -312,8 +313,12 @@ class LaserEVM: :return: A list of successor states. """ # Execute hooks - for hook in self._execute_state_hooks: - hook(global_state) + try: + for hook in self._execute_state_hooks: + hook(global_state) + except PluginSkipState: + self._add_world_state(global_state) + return [], None instructions = global_state.environment.code.instruction_list try: @@ -374,6 +379,10 @@ class LaserEVM: ) = end_signal.global_state.transaction_stack[-1] log.debug("Ending transaction %s.", transaction) + + for hook in self._transaction_end_hooks: + hook(global_state, transaction, return_global_state, end_signal.revert) + if return_global_state is None: if ( not isinstance(transaction, ContractCreationTransaction) @@ -584,10 +593,10 @@ class LaserEVM: self._start_sym_trans_hooks.append(hook) elif hook_type == "stop_sym_trans": self._stop_sym_trans_hooks.append(hook) + elif hook_type == "transaction_end": + self._transaction_end_hooks.append(hook) else: - raise ValueError( - "Invalid hook type %s. Must be one of {add_world_state}", hook_type - ) + raise ValueError(f"Invalid hook type {hook_type}") def register_instr_hooks(self, hook_type: str, opcode: str, hook: Callable): """Registers instructions hooks from plugins""" diff --git a/mythril/laser/smt/array.py b/mythril/laser/smt/array.py index c2d658d2..732b292e 100644 --- a/mythril/laser/smt/array.py +++ b/mythril/laser/smt/array.py @@ -14,17 +14,32 @@ from mythril.laser.smt.bitvec import BitVec class BaseArray: """Base array type, which implements basic store and set operations.""" + def __init__(self, raw): + self.raw = raw + def __getitem__(self, item: BitVec) -> BitVec: """Gets item from the array, item can be symbolic.""" if isinstance(item, slice): raise ValueError( "Instance of BaseArray, does not support getitem with slices" ) - return BitVec(cast(z3.BitVecRef, z3.Select(self.raw, item.raw))) # type: ignore + return BitVec(cast(z3.BitVecRef, z3.Select(self.raw, item.raw))) def __setitem__(self, key: BitVec, value: BitVec) -> None: """Sets an item in the array, key can be symbolic.""" - self.raw = z3.Store(self.raw, key.raw, value.raw) # type: ignore + self.raw = z3.Store(self.raw, key.raw, value.raw) + + def substitute(self, original_expression: "BaseArray", new_expression: "BaseArray"): + """ + + :param original_expression: + :param new_expression: + """ + if self.raw is None: + return + original_z3 = original_expression.raw + new_z3 = new_expression.raw + self.raw = z3.substitute(self.raw, (original_z3, new_z3)) class Array(BaseArray): @@ -39,7 +54,7 @@ class Array(BaseArray): """ self.domain = z3.BitVecSort(domain) self.range = z3.BitVecSort(value_range) - self.raw = z3.Array(name, self.domain, self.range) + super(Array, self).__init__(z3.Array(name, self.domain, self.range)) class K(BaseArray): diff --git a/mythril/laser/smt/bool.py b/mythril/laser/smt/bool.py index bc5f369a..b98e10c4 100644 --- a/mythril/laser/smt/bool.py +++ b/mythril/laser/smt/bool.py @@ -7,7 +7,6 @@ import z3 from mythril.laser.smt.expression import Expression - # fmt: off @@ -80,6 +79,18 @@ class Bool(Expression[z3.BoolRef]): else: return False + def substitute(self, original_expression, new_expression): + """ + + :param original_expression: + :param new_expression: + """ + if self.raw is None: + return + original_z3 = original_expression.raw + new_z3 = new_expression.raw + self.raw = z3.substitute(self.raw, (original_z3, new_z3)) + def __hash__(self) -> int: return self.raw.__hash__()