Beam search (#1606)

* Init config dir

* Fix solc optimizer

* Add beam search

* Use dict over reference
pull/1608/head
Nikhil Parasaram 3 years ago committed by GitHub
parent 8fbe2e2748
commit 7d3f9b5842
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      mythril/analysis/potential_issues.py
  2. 12
      mythril/analysis/symbolic.py
  3. 22
      mythril/interfaces/cli.py
  4. 8
      mythril/laser/ethereum/state/annotation.py
  5. 6
      mythril/laser/ethereum/strategy/__init__.py
  6. 31
      mythril/laser/ethereum/strategy/beam.py
  7. 6
      mythril/laser/ethereum/strategy/extensions/bounded_loops.py
  8. 8
      mythril/laser/ethereum/svm.py
  9. 1
      tests/instructions/berlin_fork_opcodes_test.py
  10. 108
      tests/laser/strategy/test_beam.py

@ -54,6 +54,10 @@ class PotentialIssuesAnnotation(StateAnnotation):
def __init__(self): def __init__(self):
self.potential_issues = [] self.potential_issues = []
@property
def search_importance(self):
return 10 * len(self.potential_issues)
def get_potential_issues_annotation(state: GlobalState) -> PotentialIssuesAnnotation: def get_potential_issues_annotation(state: GlobalState) -> PotentialIssuesAnnotation:
""" """

@ -13,7 +13,7 @@ from mythril.laser.ethereum.strategy.basic import (
ReturnWeightedRandomStrategy, ReturnWeightedRandomStrategy,
BasicSearchStrategy, BasicSearchStrategy,
) )
from mythril.laser.ethereum.strategy.beam import BeamSearch
from mythril.laser.ethereum.natives import PRECOMPILE_COUNT from mythril.laser.ethereum.natives import PRECOMPILE_COUNT
from mythril.laser.ethereum.transaction.symbolic import ACTORS from mythril.laser.ethereum.transaction.symbolic import ACTORS
@ -82,7 +82,7 @@ class SymExecWrapper:
address = symbol_factory.BitVecVal(int(address, 16), 256) address = symbol_factory.BitVecVal(int(address, 16), 256)
if isinstance(address, int): if isinstance(address, int):
address = symbol_factory.BitVecVal(address, 256) address = symbol_factory.BitVecVal(address, 256)
beam_width = None
if strategy == "dfs": if strategy == "dfs":
s_strategy = DepthFirstSearchStrategy # type: Type[BasicSearchStrategy] s_strategy = DepthFirstSearchStrategy # type: Type[BasicSearchStrategy]
elif strategy == "bfs": elif strategy == "bfs":
@ -91,6 +91,9 @@ class SymExecWrapper:
s_strategy = ReturnRandomNaivelyStrategy s_strategy = ReturnRandomNaivelyStrategy
elif strategy == "weighted-random": elif strategy == "weighted-random":
s_strategy = ReturnWeightedRandomStrategy s_strategy = ReturnWeightedRandomStrategy
elif "beam-search: " in strategy:
beam_width = int(strategy.split("beam-search: ")[1])
s_strategy = BeamSearch
else: else:
raise ValueError("Invalid strategy argument supplied") raise ValueError("Invalid strategy argument supplied")
@ -121,10 +124,13 @@ class SymExecWrapper:
create_timeout=create_timeout, create_timeout=create_timeout,
transaction_count=transaction_count, transaction_count=transaction_count,
requires_statespace=requires_statespace, requires_statespace=requires_statespace,
beam_width=beam_width,
) )
if loop_bound is not None: if loop_bound is not None:
self.laser.extend_strategy(BoundedLoopsStrategy, loop_bound) self.laser.extend_strategy(
BoundedLoopsStrategy, loop_bound=loop_bound, beam_width=beam_width
)
plugin_loader = LaserPluginLoader() plugin_loader = LaserPluginLoader()
plugin_loader.load(CoveragePluginBuilder()) plugin_loader.load(CoveragePluginBuilder())

@ -437,6 +437,13 @@ def create_safe_functions_parser(parser: ArgumentParser):
default="bfs", default="bfs",
help="Symbolic execution strategy", help="Symbolic execution strategy",
) )
options.add_argument(
"--beam-search",
type=int,
default=None,
help="Beam search with with",
)
options.add_argument( options.add_argument(
"-b", "-b",
"--loop-bound", "--loop-bound",
@ -545,6 +552,12 @@ def create_analyzer_parser(analyzer_parser: ArgumentParser):
default="bfs", default="bfs",
help="Symbolic execution strategy", help="Symbolic execution strategy",
) )
options.add_argument(
"--beam-search",
type=int,
default=None,
help="Beam search with with",
)
options.add_argument( options.add_argument(
"-b", "-b",
"--loop-bound", "--loop-bound",
@ -670,6 +683,7 @@ def validate_args(args: Namespace):
exit_with_error( exit_with_error(
args.outform, "Invalid -v value, you can find valid values in usage" args.outform, "Invalid -v value, you can find valid values in usage"
) )
if args.command in DISASSEMBLE_LIST and len(args.solidity_files) > 1: if args.command in DISASSEMBLE_LIST and len(args.solidity_files) > 1:
exit_with_error("text", "Only a single arg is supported for using disassemble") exit_with_error("text", "Only a single arg is supported for using disassemble")
@ -783,6 +797,10 @@ def execute_command(
:param args: :param args:
:return: :return:
""" """
if args.__dict__.get("beam_search"):
strategy = f"beam-search: {args.beam_search}"
else:
strategy = args.__dict__.get("strategy")
if args.command == "read-storage": if args.command == "read-storage":
storage = disassembler.get_state_variable_from_storage( storage = disassembler.get_state_variable_from_storage(
@ -799,7 +817,7 @@ def execute_command(
elif args.command == SAFE_FUNCTIONS_COMMAND: elif args.command == SAFE_FUNCTIONS_COMMAND:
function_analyzer = MythrilAnalyzer( function_analyzer = MythrilAnalyzer(
strategy=args.strategy, strategy=strategy,
disassembler=disassembler, disassembler=disassembler,
address=address, address=address,
max_depth=args.max_depth, max_depth=args.max_depth,
@ -834,7 +852,7 @@ def execute_command(
elif args.command in ANALYZE_LIST: elif args.command in ANALYZE_LIST:
analyzer = MythrilAnalyzer( analyzer = MythrilAnalyzer(
strategy=args.strategy, strategy=strategy,
disassembler=disassembler, disassembler=disassembler,
address=address, address=address,
max_depth=args.max_depth, max_depth=args.max_depth,

@ -36,6 +36,14 @@ class StateAnnotation:
""" """
return False return False
@property
def search_importance(self) -> int:
"""
Used in estimating the priority of a state annotated with the corresponding annotation.
Default is 1
"""
return 1
class MergeableStateAnnotation(StateAnnotation): class MergeableStateAnnotation(StateAnnotation):
"""This class allows a base annotation class for annotations that """This class allows a base annotation class for annotations that

@ -8,7 +8,7 @@ class BasicSearchStrategy(ABC):
A basic search strategy which halts based on depth A basic search strategy which halts based on depth
""" """
def __init__(self, work_list, max_depth): def __init__(self, work_list, max_depth, **kwargs):
self.work_list = work_list # type: List[GlobalState] self.work_list = work_list # type: List[GlobalState]
self.max_depth = max_depth self.max_depth = max_depth
@ -35,8 +35,8 @@ class CriterionSearchStrategy(BasicSearchStrategy):
If a criterion is satisfied, the search halts If a criterion is satisfied, the search halts
""" """
def __init__(self, work_list, max_depth): def __init__(self, work_list, max_depth, **kwargs):
super().__init__(work_list, max_depth) super().__init__(work_list, max_depth, **kwargs)
self._satisfied_criterion = False self._satisfied_criterion = False
def get_strategic_global_state(self): def get_strategic_global_state(self):

@ -0,0 +1,31 @@
from typing import List
from mythril.laser.ethereum.state.global_state import GlobalState
from . import BasicSearchStrategy
class BeamSearch(BasicSearchStrategy):
"""chooses a random state from the worklist with equal likelihood."""
def __init__(self, work_list, max_depth, beam_width, **kwargs):
super().__init__(work_list, max_depth)
self.beam_width = beam_width
@staticmethod
def beam_priority(state):
return sum([annotation.search_importance for annotation in state._annotations])
def sort_and_eliminate_states(self):
self.work_list.sort(key=lambda state: self.beam_priority(state), reverse=True)
del self.work_list[self.beam_width :]
def get_strategic_global_state(self) -> GlobalState:
"""
:return:
"""
self.sort_and_eliminate_states()
if len(self.work_list) > 0:
return self.work_list.pop(0)
else:
raise IndexError

@ -29,11 +29,11 @@ class BoundedLoopsStrategy(BasicSearchStrategy):
Ignores JUMPI instruction if the destination was targeted >JUMPDEST_LIMIT times. Ignores JUMPI instruction if the destination was targeted >JUMPDEST_LIMIT times.
""" """
def __init__(self, super_strategy: BasicSearchStrategy, *args) -> None: def __init__(self, super_strategy: BasicSearchStrategy, **kwargs) -> None:
"""""" """"""
self.super_strategy = super_strategy self.super_strategy = super_strategy
self.bound = args[0][0] self.bound = kwargs["loop_bound"]
log.info( log.info(
"Loaded search strategy extension: Loop bounds (limit = {})".format( "Loaded search strategy extension: Loop bounds (limit = {})".format(
@ -42,7 +42,7 @@ class BoundedLoopsStrategy(BasicSearchStrategy):
) )
BasicSearchStrategy.__init__( BasicSearchStrategy.__init__(
self, super_strategy.work_list, super_strategy.max_depth self, super_strategy.work_list, super_strategy.max_depth, **kwargs
) )
@staticmethod @staticmethod

@ -61,6 +61,7 @@ class LaserEVM:
requires_statespace=True, requires_statespace=True,
iprof=None, iprof=None,
use_reachability_check=True, use_reachability_check=True,
beam_width=None,
) -> None: ) -> None:
""" """
Initializes the laser evm object Initializes the laser evm object
@ -81,9 +82,8 @@ class LaserEVM:
self.dynamic_loader = dynamic_loader self.dynamic_loader = dynamic_loader
self.use_reachability_check = use_reachability_check self.use_reachability_check = use_reachability_check
# TODO: What about using a deque here?
self.work_list: List[GlobalState] = [] self.work_list: List[GlobalState] = []
self.strategy = strategy(self.work_list, max_depth) self.strategy = strategy(self.work_list, max_depth, beam_width=beam_width)
self.max_depth = max_depth self.max_depth = max_depth
self.transaction_count = transaction_count self.transaction_count = transaction_count
@ -133,8 +133,8 @@ class LaserEVM:
} }
log.info("LASER EVM initialized with dynamic loader: " + str(dynamic_loader)) log.info("LASER EVM initialized with dynamic loader: " + str(dynamic_loader))
def extend_strategy(self, extension: ABCMeta, *args) -> None: def extend_strategy(self, extension: ABCMeta, **kwargs) -> None:
self.strategy = extension(self.strategy, args) self.strategy = extension(self.strategy, **kwargs)
def sym_exec( def sym_exec(
self, self,

@ -7,7 +7,6 @@ from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.machine_state import MachineState from mythril.laser.ethereum.state.machine_state import MachineState
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.instructions import Instruction from mythril.laser.ethereum.instructions import Instruction
from mythril.laser.ethereum.transaction.transaction_models import MessageCallTransaction from mythril.laser.ethereum.transaction.transaction_models import MessageCallTransaction
from mythril.laser.smt import symbol_factory, simplify from mythril.laser.smt import symbol_factory, simplify

@ -0,0 +1,108 @@
import pytest
from mythril.laser.ethereum.strategy.beam import (
BeamSearch,
)
from mythril.disassembler.disassembly import Disassembly
from mythril.laser.ethereum.state.environment import Environment
from mythril.laser.ethereum.state.machine_state import MachineState
from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.state.world_state import WorldState
from mythril.analysis.potential_issues import PotentialIssuesAnnotation
world_state = WorldState()
account = world_state.create_account(balance=10, address=101)
account.code = Disassembly("60606040")
environment = Environment(account, None, None, None, None, None, None)
potential_issues = PotentialIssuesAnnotation()
# It is a hassle to construct multiple issues
potential_issues.potential_issues = [0, 0]
@pytest.mark.parametrize(
"state, priority",
[
(
GlobalState(
world_state,
environment,
None,
MachineState(gas_limit=8000000),
annotations=[PotentialIssuesAnnotation()],
),
0,
),
(
GlobalState(
world_state,
environment,
None,
MachineState(gas_limit=8000000),
annotations=[potential_issues],
),
20,
),
],
)
def test_priority_sum(state, priority):
assert priority == BeamSearch.beam_priority(state)
@pytest.mark.parametrize(
"states, width",
[
(
[
GlobalState(
world_state,
environment,
None,
MachineState(gas_limit=8000000),
annotations=[PotentialIssuesAnnotation()],
),
GlobalState(
world_state,
environment,
None,
MachineState(gas_limit=8000000),
annotations=[potential_issues],
),
],
1,
),
(
100
* [
GlobalState(
world_state,
environment,
None,
MachineState(gas_limit=8000000),
annotations=[PotentialIssuesAnnotation()],
)
],
1,
),
(
100
* [
GlobalState(
world_state,
environment,
None,
MachineState(gas_limit=8000000),
annotations=[PotentialIssuesAnnotation()],
)
],
0,
),
],
)
def test_elimination(states, width):
strategy = BeamSearch(states, max_depth=100, beam_width=width)
strategy.sort_and_eliminate_states()
assert len(strategy.work_list) <= width
for i in range(len(strategy.work_list) - 1):
assert strategy.beam_priority(strategy.work_list[i]) >= strategy.beam_priority(
strategy.work_list[i + 1]
)
Loading…
Cancel
Save