From 0fac8e3d6dc3055a00375fa4cbfd420347f5a86c Mon Sep 17 00:00:00 2001 From: Nikhil Parasaram Date: Fri, 2 Dec 2022 16:17:04 +0000 Subject: [PATCH] Using cache search strategy (#1714) * Fix JSON serialisation issue * Add cache * Add tests * Remove a debug statement --- mythril/analysis/symbolic.py | 3 + mythril/interfaces/cli.py | 4 +- mythril/laser/ethereum/state/constraints.py | 12 ++++ mythril/laser/ethereum/strategy/__init__.py | 5 +- mythril/laser/ethereum/strategy/basic.py | 64 +++++++++++++++++-- mythril/laser/ethereum/strategy/beam.py | 11 ++++ .../ethereum/strategy/constraint_strategy.py | 47 ++++++++++++++ mythril/laser/ethereum/svm.py | 6 +- mythril/support/model.py | 18 +++++- mythril/support/support_utils.py | 46 ++++++++++++- tests/integration_tests/analysis_tests.py | 14 ++++ 11 files changed, 218 insertions(+), 12 deletions(-) create mode 100644 mythril/laser/ethereum/strategy/constraint_strategy.py diff --git a/mythril/analysis/symbolic.py b/mythril/analysis/symbolic.py index d7b7bc25..1f3a7019 100644 --- a/mythril/analysis/symbolic.py +++ b/mythril/analysis/symbolic.py @@ -13,6 +13,7 @@ from mythril.laser.ethereum.strategy.basic import ( ReturnWeightedRandomStrategy, BasicSearchStrategy, ) +from mythril.laser.ethereum.strategy.constraint_strategy import DelayConstraintStrategy from mythril.laser.ethereum.strategy.beam import BeamSearch from mythril.laser.ethereum.natives import PRECOMPILE_COUNT from mythril.laser.ethereum.transaction.symbolic import ACTORS @@ -94,6 +95,8 @@ class SymExecWrapper: elif "beam-search: " in strategy: beam_width = int(strategy.split("beam-search: ")[1]) s_strategy = BeamSearch + elif "delayed" in strategy: + s_strategy = DelayConstraintStrategy else: raise ValueError("Invalid strategy argument supplied") diff --git a/mythril/interfaces/cli.py b/mythril/interfaces/cli.py index 61369176..8fd192d1 100644 --- a/mythril/interfaces/cli.py +++ b/mythril/interfaces/cli.py @@ -449,7 +449,7 @@ def add_analysis_args(options): options.add_argument( "--strategy", - choices=["dfs", "bfs", "naive-random", "weighted-random"], + choices=["dfs", "bfs", "naive-random", "weighted-random", "delayed"], default="bfs", help="Symbolic execution strategy", ) @@ -622,7 +622,7 @@ def validate_args(args: Namespace): args.outform, "The transaction sequence is in incorrect format, It should be " "[list of possible function hashes in 1st transaction, " - "list of possible func hashes in 2nd tx, ..]" + "list of possible func hashes in 2nd tx, ...] " "If any list is empty then all possible functions are considered for that transaction", ) if len(args.transaction_sequences) != args.transaction_count: diff --git a/mythril/laser/ethereum/state/constraints.py b/mythril/laser/ethereum/state/constraints.py index 1b671388..30efa93e 100644 --- a/mythril/laser/ethereum/state/constraints.py +++ b/mythril/laser/ethereum/state/constraints.py @@ -42,6 +42,18 @@ class Constraints(list): return False return True + def get_model(self, solver_timeout=None) -> bool: + """ + :param solver_timeout: The default timeout uses analysis timeout from args.solver_timeout + :return: True/False based on the existence of solution of constraints + """ + try: + return get_model(self, solver_timeout=solver_timeout) + except SolverTimeOutException: + return None + except UnsatError: + return None + def append(self, constraint: Union[bool, Bool]) -> None: """ diff --git a/mythril/laser/ethereum/strategy/__init__.py b/mythril/laser/ethereum/strategy/__init__.py index 18fb4932..9671de54 100644 --- a/mythril/laser/ethereum/strategy/__init__.py +++ b/mythril/laser/ethereum/strategy/__init__.py @@ -20,6 +20,9 @@ class BasicSearchStrategy(ABC): """""" raise NotImplementedError("Must be implemented by a subclass") + def run_check(self): + return True + def __next__(self): try: global_state = self.get_strategic_global_state() @@ -43,7 +46,7 @@ class CriterionSearchStrategy(BasicSearchStrategy): if self._satisfied_criterion: raise StopIteration try: - global_state = self.get_strategic_global_state() + return self.get_strategic_global_state() except StopIteration: raise StopIteration diff --git a/mythril/laser/ethereum/strategy/basic.py b/mythril/laser/ethereum/strategy/basic.py index 3bdd36ec..ecd81ff7 100644 --- a/mythril/laser/ethereum/strategy/basic.py +++ b/mythril/laser/ethereum/strategy/basic.py @@ -20,6 +20,13 @@ class DepthFirstSearchStrategy(BasicSearchStrategy): """ return self.work_list.pop() + def view_strategic_global_state(self) -> GlobalState: + """ + + :return: + """ + return self.work_list[-1] + class BreadthFirstSearchStrategy(BasicSearchStrategy): """Implements a breadth first search strategy I.E. @@ -34,17 +41,44 @@ class BreadthFirstSearchStrategy(BasicSearchStrategy): """ return self.work_list.pop(0) + def view_strategic_global_state(self) -> GlobalState: + """ + + :return: + """ + return self.work_list[0] + class ReturnRandomNaivelyStrategy(BasicSearchStrategy): """chooses a random state from the worklist with equal likelihood.""" + def __init__(self, work_list, max_depth, **kwargs): + super().__init__(work_list, max_depth, **kwargs) + self.previous_random_value = -1 + def get_strategic_global_state(self) -> GlobalState: """ :return: """ if len(self.work_list) > 0: - return self.work_list.pop(randrange(len(self.work_list))) + if self.previous_random_value == -1: + return self.work_list.pop(randrange(len(self.work_list))) + else: + new_state = self.work_list.pop(self.previous_random_value) + self.previous_random_value = -1 + return new_state + else: + raise IndexError + + def view_strategic_global_state(self) -> GlobalState: + """ + + :return: + """ + if len(self.work_list) > 0: + self.previous_random_value = randrange(len(self.work_list)) + return self.work_list[self.previous_random_value] else: raise IndexError @@ -53,6 +87,10 @@ class ReturnWeightedRandomStrategy(BasicSearchStrategy): """chooses a random state from the worklist with likelihood based on inverse proportion to depth.""" + def __init__(self, work_list, max_depth, **kwargs): + super().__init__(work_list, max_depth, **kwargs) + self.previous_random_value = -1 + def get_strategic_global_state(self) -> GlobalState: """ @@ -61,6 +99,24 @@ class ReturnWeightedRandomStrategy(BasicSearchStrategy): probability_distribution = [ 1 / (global_state.mstate.depth + 1) for global_state in self.work_list ] - return self.work_list.pop( - choices(range(len(self.work_list)), probability_distribution)[0] - ) + if self.previous_random_value != -1: + ns = self.work_list.pop(self.previous_random_value) + self.previous_random_value = -1 + return ns + else: + return self.work_list.pop( + choices(range(len(self.work_list)), probability_distribution)[0] + ) + + def view_strategic_global_state(self) -> GlobalState: + """ + + :return: + """ + probability_distribution = [ + 1 / (global_state.mstate.depth + 1) for global_state in self.work_list + ] + self.previous_random_value = choices( + range(len(self.work_list)), probability_distribution + )[0] + return self.work_list[self.previous_random_value] diff --git a/mythril/laser/ethereum/strategy/beam.py b/mythril/laser/ethereum/strategy/beam.py index 3e92a657..24408105 100644 --- a/mythril/laser/ethereum/strategy/beam.py +++ b/mythril/laser/ethereum/strategy/beam.py @@ -19,6 +19,17 @@ class BeamSearch(BasicSearchStrategy): self.work_list.sort(key=lambda state: self.beam_priority(state), reverse=True) del self.work_list[self.beam_width :] + def view_strategic_global_state(self) -> GlobalState: + """ + + :return: + """ + self.sort_and_eliminate_states() + if len(self.work_list) > 0: + return self.work_list[0] + else: + raise IndexError + def get_strategic_global_state(self) -> GlobalState: """ diff --git a/mythril/laser/ethereum/strategy/constraint_strategy.py b/mythril/laser/ethereum/strategy/constraint_strategy.py new file mode 100644 index 00000000..30a9d8c5 --- /dev/null +++ b/mythril/laser/ethereum/strategy/constraint_strategy.py @@ -0,0 +1,47 @@ +from mythril.laser.ethereum.state.global_state import GlobalState +from mythril.laser.ethereum.strategy.basic import BasicSearchStrategy +from mythril.laser.ethereum.state.annotation import StateAnnotation +from mythril.laser.ethereum.transaction import ContractCreationTransaction +from mythril.laser.smt import And, simplify +from mythril.support.support_utils import ModelCache + +from typing import Dict, cast, List +from collections import OrderedDict +from copy import copy, deepcopy +from functools import lru_cache +import logging +import z3 +from time import time + +log = logging.getLogger(__name__) + + +class DelayConstraintStrategy(BasicSearchStrategy): + def __init__(self, work_list, max_depth, **kwargs): + super().__init__(work_list, max_depth) + self.model_cache = ModelCache() + self.pending_worklist = [] + log.info("Loaded search strategy extension: DelayConstraintStrategy") + + def get_strategic_global_state(self) -> GlobalState: + """Returns the next state + + :return: Global state + """ + while True: + if len(self.work_list) == 0: + state = self.pending_worklist.pop(0) + model = state.world_state.constraints.get_model() + if model is not None: + self.model_cache.put(model, 1) + return state + else: + state = self.work_list[0] + c_val = self.model_cache.check_quick_sat( + simplify(And(*state.world_state.constraints)).raw + ) + if c_val == False: + self.pending_worklist.append(state) + self.work_list.pop(0) + else: + return self.work_list.pop(0) diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index c4d9f390..b5f40c8f 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -17,6 +17,7 @@ from mythril.laser.plugin.signals import PluginSkipWorldState, PluginSkipState from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.strategy.basic import DepthFirstSearchStrategy +from mythril.laser.ethereum.strategy.constraint_strategy import DelayConstraintStrategy from abc import ABCMeta from mythril.laser.ethereum.time_handler import time_handler @@ -309,7 +310,10 @@ class LaserEVM: except NotImplementedError: log.debug("Encountered unimplemented instruction") continue - if len(new_states) > 1 and random.uniform(0, 1) < args.pruning_factor: + + if self.strategy.run_check() and ( + len(new_states) > 1 and random.uniform(0, 1) < args.pruning_factor + ): new_states = [ state for state in new_states diff --git a/mythril/support/model.py b/mythril/support/model.py index 6997d90d..f0d2d5b4 100644 --- a/mythril/support/model.py +++ b/mythril/support/model.py @@ -1,15 +1,21 @@ from functools import lru_cache -from z3 import sat, unknown +from z3 import sat, unknown, is_true from pathlib import Path +from mythril.support.support_utils import ModelCache from mythril.support.support_args import args -from mythril.laser.smt import Optimize +from mythril.laser.smt import Optimize, simplify, And from mythril.laser.ethereum.time_handler import time_handler from mythril.exceptions import UnsatError, SolverTimeOutException import logging +from collections import OrderedDict +from copy import deepcopy +from time import time log = logging.getLogger(__name__) -# LRU cache works great when used in powers of 2 + + +model_cache = ModelCache() @lru_cache(maxsize=2**23) @@ -42,6 +48,11 @@ def get_model( constraints = constraints.get_all_constraints() constraints = [constraint for constraint in constraints if type(constraint) != bool] + if len(maximize) + len(minimize) == 0: + ret_model = model_cache.check_quick_sat(simplify(And(*constraints)).raw) + if ret_model: + return ret_model + for constraint in constraints: s.add(constraint) for e in minimize: @@ -63,6 +74,7 @@ def get_model( result = s.check() if result == sat: + model_cache.model_cache.put(s.model(), 1) return s.model() elif result == unknown: log.debug("Timeout/Error encountered while solving expression using z3") diff --git a/mythril/support/support_utils.py b/mythril/support/support_utils.py index bbabc125..484a3c9e 100644 --- a/mythril/support/support_utils.py +++ b/mythril/support/support_utils.py @@ -1,8 +1,12 @@ """This module contains utility functions for the Mythril support package.""" + +from collections import OrderedDict +from copy import deepcopy +from eth_hash.auto import keccak from functools import lru_cache from typing import Dict +from z3 import is_true, simplify, And import logging -from eth_hash.auto import keccak log = logging.getLogger(__name__) @@ -27,6 +31,46 @@ class Singleton(type): return cls._instances[cls] +class LRUCache: + def __init__(self, size): + self.size = size + self.lru_cache = OrderedDict() + + def get(self, key): + try: + value = self.lru_cache.pop(key) + self.lru_cache[key] = value + return value + except KeyError: + return -1 + + def put(self, key, value): + try: + self.lru_cache.pop(key) + except KeyError: + if len(self.lru_cache) >= self.size: + self.lru_cache.popitem(last=False) + self.lru_cache[key] = value + + +class ModelCache: + def __init__(self): + self.model_cache = LRUCache(size=100) + + @lru_cache(maxsize=2**10) + def check_quick_sat(self, constraints) -> bool: + model_list = list(reversed(self.model_cache.lru_cache.keys())) + for model in reversed(self.model_cache.lru_cache.keys()): + model_copy = deepcopy(model) + if is_true(model_copy.eval(constraints, model_completion=True)): + self.model_cache.put(model, self.model_cache.get(model) + 1) + return model + return False + + def put(self, key, value): + self.model_cache.put(key, value) + + @lru_cache(maxsize=2**10) def get_code_hash(code) -> str: """ diff --git a/tests/integration_tests/analysis_tests.py b/tests/integration_tests/analysis_tests.py index 08854ff3..49a33ab1 100644 --- a/tests/integration_tests/analysis_tests.py +++ b/tests/integration_tests/analysis_tests.py @@ -55,3 +55,17 @@ def test_analysis(file_name, tx_data, calldata): 0 ] assert test_case["steps"][tx_data["TX_OUTPUT"]]["input"] == calldata + + +@pytest.mark.parametrize("file_name, tx_data, calldata", test_data) +def test_analysis_delayed(file_name, tx_data, calldata): + bytecode_file = str(TESTDATA / "inputs" / file_name) + command = f"""python3 {MYTH} analyze -f {bytecode_file} -t {tx_data["TX_COUNT"]} -o jsonv2 -m {tx_data["MODULE"]} --solver-timeout 60000 --strategy delayed""" + output = json.loads(output_of(command)) + + assert len(output[0]["issues"]) == tx_data["ISSUE_COUNT"] + if calldata: + test_case = output[0]["issues"][tx_data["ISSUE_NUMBER"]]["extra"]["testCases"][ + 0 + ] + assert test_case["steps"][tx_data["TX_OUTPUT"]]["input"] == calldata