diff --git a/mythril/analysis/potential_issues.py b/mythril/analysis/potential_issues.py index 51be9799..1965edf6 100644 --- a/mythril/analysis/potential_issues.py +++ b/mythril/analysis/potential_issues.py @@ -54,6 +54,10 @@ class PotentialIssuesAnnotation(StateAnnotation): def __init__(self): self.potential_issues = [] + @property + def search_importance(self): + return 10 * len(self.potential_issues) + def get_potential_issues_annotation(state: GlobalState) -> PotentialIssuesAnnotation: """ diff --git a/mythril/analysis/symbolic.py b/mythril/analysis/symbolic.py index 3e5329e0..d7b7bc25 100644 --- a/mythril/analysis/symbolic.py +++ b/mythril/analysis/symbolic.py @@ -13,7 +13,7 @@ from mythril.laser.ethereum.strategy.basic import ( ReturnWeightedRandomStrategy, BasicSearchStrategy, ) - +from mythril.laser.ethereum.strategy.beam import BeamSearch from mythril.laser.ethereum.natives import PRECOMPILE_COUNT from mythril.laser.ethereum.transaction.symbolic import ACTORS @@ -82,7 +82,7 @@ class SymExecWrapper: address = symbol_factory.BitVecVal(int(address, 16), 256) if isinstance(address, int): address = symbol_factory.BitVecVal(address, 256) - + beam_width = None if strategy == "dfs": s_strategy = DepthFirstSearchStrategy # type: Type[BasicSearchStrategy] elif strategy == "bfs": @@ -91,6 +91,9 @@ class SymExecWrapper: s_strategy = ReturnRandomNaivelyStrategy elif strategy == "weighted-random": s_strategy = ReturnWeightedRandomStrategy + elif "beam-search: " in strategy: + beam_width = int(strategy.split("beam-search: ")[1]) + s_strategy = BeamSearch else: raise ValueError("Invalid strategy argument supplied") @@ -121,10 +124,13 @@ class SymExecWrapper: create_timeout=create_timeout, transaction_count=transaction_count, requires_statespace=requires_statespace, + beam_width=beam_width, ) if loop_bound is not None: - self.laser.extend_strategy(BoundedLoopsStrategy, loop_bound) + self.laser.extend_strategy( + BoundedLoopsStrategy, loop_bound=loop_bound, beam_width=beam_width + ) plugin_loader = LaserPluginLoader() plugin_loader.load(CoveragePluginBuilder()) diff --git a/mythril/interfaces/cli.py b/mythril/interfaces/cli.py index 27b1b226..d46729d7 100644 --- a/mythril/interfaces/cli.py +++ b/mythril/interfaces/cli.py @@ -437,6 +437,13 @@ def create_safe_functions_parser(parser: ArgumentParser): default="bfs", help="Symbolic execution strategy", ) + options.add_argument( + "--beam-search", + type=int, + default=None, + help="Beam search with with", + ) + options.add_argument( "-b", "--loop-bound", @@ -545,6 +552,12 @@ def create_analyzer_parser(analyzer_parser: ArgumentParser): default="bfs", help="Symbolic execution strategy", ) + options.add_argument( + "--beam-search", + type=int, + default=None, + help="Beam search with with", + ) options.add_argument( "-b", "--loop-bound", @@ -670,6 +683,7 @@ def validate_args(args: Namespace): exit_with_error( args.outform, "Invalid -v value, you can find valid values in usage" ) + if args.command in DISASSEMBLE_LIST and len(args.solidity_files) > 1: exit_with_error("text", "Only a single arg is supported for using disassemble") @@ -783,6 +797,10 @@ def execute_command( :param args: :return: """ + if args.__dict__.get("beam_search"): + strategy = f"beam-search: {args.beam_search}" + else: + strategy = args.__dict__.get("strategy") if args.command == "read-storage": storage = disassembler.get_state_variable_from_storage( @@ -799,7 +817,7 @@ def execute_command( elif args.command == SAFE_FUNCTIONS_COMMAND: function_analyzer = MythrilAnalyzer( - strategy=args.strategy, + strategy=strategy, disassembler=disassembler, address=address, max_depth=args.max_depth, @@ -834,7 +852,7 @@ def execute_command( elif args.command in ANALYZE_LIST: analyzer = MythrilAnalyzer( - strategy=args.strategy, + strategy=strategy, disassembler=disassembler, address=address, max_depth=args.max_depth, diff --git a/mythril/laser/ethereum/state/annotation.py b/mythril/laser/ethereum/state/annotation.py index 19f255e3..8ffbbcb9 100644 --- a/mythril/laser/ethereum/state/annotation.py +++ b/mythril/laser/ethereum/state/annotation.py @@ -36,6 +36,14 @@ class StateAnnotation: """ return False + @property + def search_importance(self) -> int: + """ + Used in estimating the priority of a state annotated with the corresponding annotation. + Default is 1 + """ + return 1 + class MergeableStateAnnotation(StateAnnotation): """This class allows a base annotation class for annotations that diff --git a/mythril/laser/ethereum/strategy/__init__.py b/mythril/laser/ethereum/strategy/__init__.py index 76e14725..18fb4932 100644 --- a/mythril/laser/ethereum/strategy/__init__.py +++ b/mythril/laser/ethereum/strategy/__init__.py @@ -8,7 +8,7 @@ class BasicSearchStrategy(ABC): A basic search strategy which halts based on depth """ - def __init__(self, work_list, max_depth): + def __init__(self, work_list, max_depth, **kwargs): self.work_list = work_list # type: List[GlobalState] self.max_depth = max_depth @@ -35,8 +35,8 @@ class CriterionSearchStrategy(BasicSearchStrategy): If a criterion is satisfied, the search halts """ - def __init__(self, work_list, max_depth): - super().__init__(work_list, max_depth) + def __init__(self, work_list, max_depth, **kwargs): + super().__init__(work_list, max_depth, **kwargs) self._satisfied_criterion = False def get_strategic_global_state(self): diff --git a/mythril/laser/ethereum/strategy/beam.py b/mythril/laser/ethereum/strategy/beam.py new file mode 100644 index 00000000..3e92a657 --- /dev/null +++ b/mythril/laser/ethereum/strategy/beam.py @@ -0,0 +1,31 @@ +from typing import List + +from mythril.laser.ethereum.state.global_state import GlobalState +from . import BasicSearchStrategy + + +class BeamSearch(BasicSearchStrategy): + """chooses a random state from the worklist with equal likelihood.""" + + def __init__(self, work_list, max_depth, beam_width, **kwargs): + super().__init__(work_list, max_depth) + self.beam_width = beam_width + + @staticmethod + def beam_priority(state): + return sum([annotation.search_importance for annotation in state._annotations]) + + def sort_and_eliminate_states(self): + self.work_list.sort(key=lambda state: self.beam_priority(state), reverse=True) + del self.work_list[self.beam_width :] + + def get_strategic_global_state(self) -> GlobalState: + """ + + :return: + """ + self.sort_and_eliminate_states() + if len(self.work_list) > 0: + return self.work_list.pop(0) + else: + raise IndexError diff --git a/mythril/laser/ethereum/strategy/extensions/bounded_loops.py b/mythril/laser/ethereum/strategy/extensions/bounded_loops.py index 85c6eb25..4d195f35 100644 --- a/mythril/laser/ethereum/strategy/extensions/bounded_loops.py +++ b/mythril/laser/ethereum/strategy/extensions/bounded_loops.py @@ -29,11 +29,11 @@ class BoundedLoopsStrategy(BasicSearchStrategy): Ignores JUMPI instruction if the destination was targeted >JUMPDEST_LIMIT times. """ - def __init__(self, super_strategy: BasicSearchStrategy, *args) -> None: + def __init__(self, super_strategy: BasicSearchStrategy, **kwargs) -> None: """""" self.super_strategy = super_strategy - self.bound = args[0][0] + self.bound = kwargs["loop_bound"] log.info( "Loaded search strategy extension: Loop bounds (limit = {})".format( @@ -42,7 +42,7 @@ class BoundedLoopsStrategy(BasicSearchStrategy): ) BasicSearchStrategy.__init__( - self, super_strategy.work_list, super_strategy.max_depth + self, super_strategy.work_list, super_strategy.max_depth, **kwargs ) @staticmethod diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index cfe6c55b..9c994832 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -61,6 +61,7 @@ class LaserEVM: requires_statespace=True, iprof=None, use_reachability_check=True, + beam_width=None, ) -> None: """ Initializes the laser evm object @@ -81,9 +82,8 @@ class LaserEVM: self.dynamic_loader = dynamic_loader self.use_reachability_check = use_reachability_check - # TODO: What about using a deque here? self.work_list: List[GlobalState] = [] - self.strategy = strategy(self.work_list, max_depth) + self.strategy = strategy(self.work_list, max_depth, beam_width=beam_width) self.max_depth = max_depth self.transaction_count = transaction_count @@ -133,8 +133,8 @@ class LaserEVM: } log.info("LASER EVM initialized with dynamic loader: " + str(dynamic_loader)) - def extend_strategy(self, extension: ABCMeta, *args) -> None: - self.strategy = extension(self.strategy, args) + def extend_strategy(self, extension: ABCMeta, **kwargs) -> None: + self.strategy = extension(self.strategy, **kwargs) def sym_exec( self, diff --git a/tests/instructions/berlin_fork_opcodes_test.py b/tests/instructions/berlin_fork_opcodes_test.py index 7c6ba133..af36ffc8 100644 --- a/tests/instructions/berlin_fork_opcodes_test.py +++ b/tests/instructions/berlin_fork_opcodes_test.py @@ -7,7 +7,6 @@ from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.machine_state import MachineState from mythril.laser.ethereum.state.global_state import GlobalState -from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.instructions import Instruction from mythril.laser.ethereum.transaction.transaction_models import MessageCallTransaction from mythril.laser.smt import symbol_factory, simplify diff --git a/tests/laser/strategy/test_beam.py b/tests/laser/strategy/test_beam.py new file mode 100644 index 00000000..fe5f3517 --- /dev/null +++ b/tests/laser/strategy/test_beam.py @@ -0,0 +1,108 @@ +import pytest +from mythril.laser.ethereum.strategy.beam import ( + BeamSearch, +) +from mythril.disassembler.disassembly import Disassembly +from mythril.laser.ethereum.state.environment import Environment +from mythril.laser.ethereum.state.machine_state import MachineState +from mythril.laser.ethereum.state.global_state import GlobalState +from mythril.laser.ethereum.state.world_state import WorldState +from mythril.analysis.potential_issues import PotentialIssuesAnnotation + +world_state = WorldState() +account = world_state.create_account(balance=10, address=101) +account.code = Disassembly("60606040") +environment = Environment(account, None, None, None, None, None, None) +potential_issues = PotentialIssuesAnnotation() +# It is a hassle to construct multiple issues +potential_issues.potential_issues = [0, 0] + + +@pytest.mark.parametrize( + "state, priority", + [ + ( + GlobalState( + world_state, + environment, + None, + MachineState(gas_limit=8000000), + annotations=[PotentialIssuesAnnotation()], + ), + 0, + ), + ( + GlobalState( + world_state, + environment, + None, + MachineState(gas_limit=8000000), + annotations=[potential_issues], + ), + 20, + ), + ], +) +def test_priority_sum(state, priority): + assert priority == BeamSearch.beam_priority(state) + + +@pytest.mark.parametrize( + "states, width", + [ + ( + [ + GlobalState( + world_state, + environment, + None, + MachineState(gas_limit=8000000), + annotations=[PotentialIssuesAnnotation()], + ), + GlobalState( + world_state, + environment, + None, + MachineState(gas_limit=8000000), + annotations=[potential_issues], + ), + ], + 1, + ), + ( + 100 + * [ + GlobalState( + world_state, + environment, + None, + MachineState(gas_limit=8000000), + annotations=[PotentialIssuesAnnotation()], + ) + ], + 1, + ), + ( + 100 + * [ + GlobalState( + world_state, + environment, + None, + MachineState(gas_limit=8000000), + annotations=[PotentialIssuesAnnotation()], + ) + ], + 0, + ), + ], +) +def test_elimination(states, width): + strategy = BeamSearch(states, max_depth=100, beam_width=width) + strategy.sort_and_eliminate_states() + + assert len(strategy.work_list) <= width + for i in range(len(strategy.work_list) - 1): + assert strategy.beam_priority(strategy.work_list[i]) >= strategy.beam_priority( + strategy.work_list[i + 1] + )