From 70d234e4f1bd16ad1c07178eb174fac23f4f505f Mon Sep 17 00:00:00 2001 From: Nikhil Parasaram Date: Sun, 10 Mar 2024 21:49:23 +0000 Subject: [PATCH] Add improved state merging (#1843) * Add improved state merging * Fix namespace --- mythril/analysis/issue_annotation.py | 13 ++ mythril/analysis/symbolic.py | 10 +- mythril/interfaces/cli.py | 5 + .../laser/ethereum/state/transient_storage.py | 104 +++++++++++ mythril/laser/plugin/plugins/__init__.py | 7 +- .../plugin/plugins/state_merge/__init__.py | 1 + .../plugins/state_merge/check_mergeability.py | 106 ++++++++++++ .../plugins/state_merge/merge_states.py | 162 ++++++++++++++++++ .../plugins/state_merge/state_merge_plugin.py | 97 +++++++++++ mythril/mythril/mythril_analyzer.py | 1 + mythril/support/support_args.py | 1 + tests/graph_test.py | 1 + tests/integration_tests/state_merge_tests.py | 34 ++++ tests/integration_tests/summary_test.py | 14 +- tests/mythril/mythril_analyzer_test.py | 1 + tests/statespace_test.py | 1 + 16 files changed, 539 insertions(+), 19 deletions(-) create mode 100644 mythril/laser/ethereum/state/transient_storage.py create mode 100644 mythril/laser/plugin/plugins/state_merge/__init__.py create mode 100644 mythril/laser/plugin/plugins/state_merge/check_mergeability.py create mode 100644 mythril/laser/plugin/plugins/state_merge/merge_states.py create mode 100644 mythril/laser/plugin/plugins/state_merge/state_merge_plugin.py create mode 100644 tests/integration_tests/state_merge_tests.py diff --git a/mythril/analysis/issue_annotation.py b/mythril/analysis/issue_annotation.py index 7f5f431b..6e741de6 100644 --- a/mythril/analysis/issue_annotation.py +++ b/mythril/analysis/issue_annotation.py @@ -32,3 +32,16 @@ class IssueAnnotation(StateAnnotation): issue=self.issue, detector=self.detector, ) + + def check_merge_annotation(self, annotation: "IssueAnnotation") -> bool: + if self.conditions == annotation.conditions: + return False + if self.issue.address != annotation.issue.address: + return False + if type(self.detector) != type(annotation.detector): + return False + + return True + + def merge_annotation(self, annotation: "IssueAnnotation") -> "IssueAnnotation": + return self diff --git a/mythril/analysis/symbolic.py b/mythril/analysis/symbolic.py index a1de5476..31847292 100644 --- a/mythril/analysis/symbolic.py +++ b/mythril/analysis/symbolic.py @@ -21,17 +21,19 @@ from mythril.laser.ethereum.tx_prioritiser import RfTxPrioritiser from mythril.laser.plugin.loader import LaserPluginLoader from mythril.laser.plugin.plugins import ( - MutationPrunerBuilder, - DependencyPrunerBuilder, + CallDepthLimitBuilder, CoveragePluginBuilder, CoverageMetricsPluginBuilder, - CallDepthLimitBuilder, + DependencyPrunerBuilder, InstructionProfilerBuilder, + MutationPrunerBuilder, + StateMergePluginBuilder, SymbolicSummaryPluginBuilder, ) from mythril.laser.ethereum.strategy.extensions.bounded_loops import ( BoundedLoopsStrategy, ) +from mythril.laser.plugin.plugins.state_merge.state_merge_plugin import StateMergePlugin from mythril.laser.smt import symbol_factory, BitVec from mythril.support.support_args import args from typing import Union, List, Type, Optional @@ -145,6 +147,8 @@ class SymExecWrapper: plugin_loader = LaserPluginLoader() plugin_loader.load(CoverageMetricsPluginBuilder()) + if args.enable_state_merge: + plugin_loader.load(StateMergePluginBuilder()) if not args.disable_coverage_strategy: plugin_loader.load(CoveragePluginBuilder()) if not args.disable_mutation_pruner: diff --git a/mythril/interfaces/cli.py b/mythril/interfaces/cli.py index 0e462cff..40348674 100644 --- a/mythril/interfaces/cli.py +++ b/mythril/interfaces/cli.py @@ -571,6 +571,11 @@ def add_analysis_args(options): action="store_true", help="Disable mutation pruner", ) + options.add_argument( + "--enable-state-merging", + action="store_true", + help="Enable State Merging", + ) options.add_argument( "--enable-summaries", action="store_true", diff --git a/mythril/laser/ethereum/state/transient_storage.py b/mythril/laser/ethereum/state/transient_storage.py new file mode 100644 index 00000000..0c2b82bb --- /dev/null +++ b/mythril/laser/ethereum/state/transient_storage.py @@ -0,0 +1,104 @@ +"""This module contains account-related functionality. + +This includes classes representing accounts and their storage. +""" +import logging +from copy import copy, deepcopy +from typing import Any, Dict, Union, Set + + +from mythril.laser.smt import Array, K, BitVec, simplify, BaseArray, If, Bool +from mythril.disassembler.disassembly import Disassembly +from mythril.laser.smt import symbol_factory +from mythril.support.support_args import args + +log = logging.getLogger(__name__) + + +from pysmt.shortcuts import ( + Symbol, + Array, + Store, + Select, + ArrayType, + Solver, + BV, + Equals, + BVZero, + BVConcat, +) + + +class TransientStorage: + def __init__(self): + self.checkpoints = [ + 0 + ] # List to store lengths of the journal at each checkpoint + + # Define symbolic arrays for current state and journal + self.current_array = Array("current", ArrayType(BV(256), BV(256))) + self.journal_array = Array( + "journal", ArrayType(BV(256), ArrayType(BV(256), BV(256))) + ) + + # Function to get value from current state + self.get_func = Select(self.current_array, addr)[key] + + # Function to update current state + self.update_current_func = Store( + self.current_array, + addr, + Store(Select(self.current_array, addr), key, prevValue), + ) + + def get(self, addr, key): + # Create symbolic variables for address and key + addr_sym = BVConcat(BVZero(96), addr) + key_sym = BVConcat(BVZero(96), key) + + # Call SMT solver to get value from current state + solver = Solver() + solver.add( + Equals(self.get_func.simplify({addr: addr_sym, key: key_sym}), BVZero(256)) + ) + if solver.solve(): + model = solver.get_model() + return model.get_value(self.get_func) + else: + return BVZero(256) # Return symbolic zero if value is not found + + def put(self, addr, key, value): + # Create symbolic variables for address, key, and value + addr_sym = BVConcat(BVZero(96), addr) + key_sym = BVConcat(BVZero(96), key) + value_sym = BVConcat(BVZero(96), value) + + # Store the journal entry + self.journal.append((addr_sym, key_sym, self.get(addr, key))) + + # Update the current state + self.current_array = self.update_current_func.simplify( + {addr: addr_sym, key: key_sym, prevValue: value_sym} + ) + + def commit(self): + if len(self.checkpoints) == 0: + raise ValueError("Nothing to commit") + self.checkpoints.pop() # The last checkpoint is discarded. + + def checkpoint(self): + self.checkpoints.append(len(self.journal)) + + def revert(self): + last_checkpoint = self.checkpoints.pop() + if last_checkpoint is None: + raise ValueError("Nothing to revert") + + for i in range(len(self.journal) - 1, last_checkpoint - 1, -1): + (addr, key, prevValue) = self.journal[i] + self.current_array = Store( + self.current_array, + addr, + Store(Select(self.current_array, addr), key, prevValue), + ) + self.journal = self.journal[:last_checkpoint] diff --git a/mythril/laser/plugin/plugins/__init__.py b/mythril/laser/plugin/plugins/__init__.py index 190ac298..b5c577f0 100644 --- a/mythril/laser/plugin/plugins/__init__.py +++ b/mythril/laser/plugin/plugins/__init__.py @@ -6,11 +6,12 @@ This module contains the implementation of some features - pruning """ from mythril.laser.plugin.plugins.benchmark import BenchmarkPluginBuilder +from mythril.laser.plugin.plugins.call_depth_limiter import CallDepthLimitBuilder from mythril.laser.plugin.plugins.coverage.coverage_plugin import CoveragePluginBuilder +from mythril.laser.plugin.plugins.coverage_metrics import CoverageMetricsPluginBuilder from mythril.laser.plugin.plugins.dependency_pruner import DependencyPrunerBuilder -from mythril.laser.plugin.plugins.mutation_pruner import MutationPrunerBuilder -from mythril.laser.plugin.plugins.call_depth_limiter import CallDepthLimitBuilder from mythril.laser.plugin.plugins.instruction_profiler import InstructionProfilerBuilder +from mythril.laser.plugin.plugins.mutation_pruner import MutationPrunerBuilder +from mythril.laser.plugin.plugins.state_merge import StateMergePluginBuilder from mythril.laser.plugin.plugins.summary import SymbolicSummaryPluginBuilder from mythril.laser.plugin.plugins.trace import TraceFinderBuilder -from mythril.laser.plugin.plugins.coverage_metrics import CoverageMetricsPluginBuilder diff --git a/mythril/laser/plugin/plugins/state_merge/__init__.py b/mythril/laser/plugin/plugins/state_merge/__init__.py new file mode 100644 index 00000000..cc723428 --- /dev/null +++ b/mythril/laser/plugin/plugins/state_merge/__init__.py @@ -0,0 +1 @@ +from .state_merge_plugin import StateMergePluginBuilder diff --git a/mythril/laser/plugin/plugins/state_merge/check_mergeability.py b/mythril/laser/plugin/plugins/state_merge/check_mergeability.py new file mode 100644 index 00000000..84b1b41a --- /dev/null +++ b/mythril/laser/plugin/plugins/state_merge/check_mergeability.py @@ -0,0 +1,106 @@ +import logging +from mythril.laser.ethereum.cfg import Node +from mythril.laser.ethereum.state.world_state import WorldState +from mythril.laser.ethereum.state.account import Account +from mythril.laser.ethereum.state.constraints import Constraints +from mythril.laser.smt import Not + +CONSTRAINT_DIFFERENCE_LIMIT = 15 + +log = logging.getLogger(__name__) + + +def check_node_merge_condition(node1: Node, node2: Node): + """ + Checks whether two nodes are merge-able + :param node1: The node to be merged + :param node2: The other node to be merged + :return: Boolean, True if we can merge + """ + return all( + [ + node1.function_name == node2.function_name, + node1.contract_name == node2.contract_name, + node1.start_addr == node2.start_addr, + ] + ) + + +def check_account_merge_condition(account1: Account, account2: Account): + """ + Checks whether we can merge accounts + """ + return all( + [ + account1.nonce == account2.nonce, + account1.deleted == account2.deleted, + account1.code.bytecode == account2.code.bytecode, + ] + ) + + +def check_ws_merge_condition(state1: WorldState, state2: WorldState): + """ + Checks whether we can merge these states + """ + if state1.node and not check_node_merge_condition(state1.node, state2.node): + return False + + for address, account in state2.accounts.items(): + if ( + address in state1._accounts + and check_account_merge_condition(state1._accounts[address], account) + is False + ): + return False + if not _check_merge_annotations(state1, state2): + return False + + return True + + +def _check_merge_annotations(state1: WorldState, state2: WorldState): + """ + Checks whether two annotations can be merged + :param state: + :return: + """ + if len(state2.annotations) != len(state1.annotations): + return False + if _check_constraint_merge(state1.constraints, state2.constraints) is False: + return False + for v1, v2 in zip(state2.annotations, state1.annotations): + if type(v1) != type(v2): + return False + try: + if v1.check_merge_annotation(v2) is False: # type: ignore + return False + except AttributeError: + log.error( + f"check_merge_annotation() method doesn't exist " + f"for the annotation {type(v1)}. Aborting merge for the state" + ) + return False + + return True + + +def _check_constraint_merge( + constraints1: Constraints, constraints2: Constraints +) -> bool: + """ + We are merging the states which have a no more than CONSTRAINT_DIFFERENCE_LIMIT + different constraints. This helps in merging states which are not too different + """ + dict1 = {c: True for c in constraints1} + dict2 = {c: True for c in constraints2} + c1, c2 = 0, 0 + for key in dict1: + if key not in dict2 and Not(key) not in dict2: + c1 += 1 + for key in dict2: + if key not in dict1 and Not(key) not in dict1: + c2 += 1 + if c1 + c2 > CONSTRAINT_DIFFERENCE_LIMIT: + return False + return True diff --git a/mythril/laser/plugin/plugins/state_merge/merge_states.py b/mythril/laser/plugin/plugins/state_merge/merge_states.py new file mode 100644 index 00000000..2b78295a --- /dev/null +++ b/mythril/laser/plugin/plugins/state_merge/merge_states.py @@ -0,0 +1,162 @@ +import logging + +from mythril.laser.ethereum.cfg import Node +from typing import Tuple, cast +from mythril.laser.ethereum.state.world_state import WorldState +from mythril.laser.ethereum.state.account import Account, Storage +from mythril.laser.ethereum.state.constraints import Constraints +from mythril.laser.smt import symbol_factory, Array, If, Or, And, Not, Bool + +log = logging.getLogger(__name__) + + +def merge_states(state1: WorldState, state2: WorldState): + """ + Merge state2 into state1 + :param state1: The state to be merged into + :param state2: The state which is merged into state1 + :return: + """ + + # Merge constraints + state1.constraints, condition1, _ = _merge_constraints( + state1.constraints, state2.constraints + ) + + # Merge balances + state1.balances = cast(Array, If(condition1, state1.balances, state2.balances)) + state1.starting_balances = cast( + Array, If(condition1, state1.starting_balances, state2.starting_balances) + ) + + # Merge accounts + for address, account in state2.accounts.items(): + if address not in state1._accounts: + state1.put_account(account) + else: + merge_accounts( + state1._accounts[address], account, condition1, state1.balances + ) + + # Merge annotations + _merge_annotations(state1, state2) + + # Merge Node + merge_nodes(state1.node, state2.node, state1.constraints) + + +def merge_nodes(node1: Node, node2: Node, constraints: Constraints): + """ + Merges node2 into node1 + :param node1: The node to be merged + :param node2: The other node to be merged + :param constraints: The merged constraints + :return: + """ + node1.states += node2.states + node1.uid = hash(node1) + node1.flags |= node2.flags + node1.constraints = constraints + + +def merge_accounts( + account1: Account, + account2: Account, + path_condition: Bool, + merged_balance: Array, +): + """ + Merges account2 into account1 + :param account1: The account to merge with + :param account2: The second account to merge + :param path_condition: The constraint for this account + :param merged_balance: The merged balance + :return: + """ + if ( + account1.nonce != account2.nonce + or account1.deleted != account2.deleted + or account1.code.bytecode != account2.code.bytecode + ): + raise ValueError("Un-Mergeable accounts are given to be merged") + + account1._balances = merged_balance + merge_storage(account1.storage, account2.storage, path_condition) + + +def merge_storage(storage1: Storage, storage2: Storage, path_condition: Bool): + """ + Merge storage2 into storage1 + :param storage1: To storage to merge into + :param storage2: To storage to merge with + :param path_condition: The constraint for this storage to be executed + :return: + """ + storage1._standard_storage = If( + path_condition, storage1._standard_storage, storage2._standard_storage + ) + storage1.storage_keys_loaded = storage1.storage_keys_loaded.union( + storage2.storage_keys_loaded + ) + for key, value in storage2.printable_storage.items(): + if key in storage1.printable_storage: + storage1.printable_storage[key] = If( + path_condition, storage1.printable_storage[key], value + ) + else: + storage1.printable_storage[key] = If(path_condition, 0, value) + + +def _merge_annotations(state1: "WorldState", state2: "WorldState"): + """ + Merges the annotations of the two states into state1 + :param state1: + :param state2: + :return: + """ + for v1, v2 in zip(state1.annotations, state2.annotations): + try: + v1.merge_annotation(v2) # type: ignore + except AttributeError: + log.error( + f"merge_annotation() method doesn't exist for the annotation {type(v1)}. " + "Aborting merge for the state" + ) + return False + + +def _merge_constraints( + constraints1: Constraints, constraints2: Constraints +) -> Tuple[Constraints, Bool, Bool]: + """ + Merges constraints + :param constraints1: Constraint2 of state1 + :param constraints2: Constraints of state2 + :return: A Tuple of merged constraints, + conjunction of constraints in state 1 not in state 2, conjunction of constraints + in state2 not in state1 + """ + dict1 = {c: True for c in constraints1} + dict2 = {c: True for c in constraints2} + c1, c2 = symbol_factory.Bool(True), symbol_factory.Bool(True) + new_constraint1, new_constraint2 = ( + symbol_factory.Bool(True), + symbol_factory.Bool(True), + ) + same_constraints = Constraints() + for key in dict1: + if key not in dict2: + c1 = And(c1, key) + if Not(key) not in dict2: + new_constraint1 = And(new_constraint1, key) + else: + same_constraints.append(key) + for key in dict2: + if key not in dict1: + c2 = And(c2, key) + if Not(key) not in dict1: + new_constraint2 = And(new_constraint2, key) + else: + same_constraints.append(key) + merge_constraints = same_constraints + [Or(new_constraint1, new_constraint2)] + return merge_constraints, c1, c2 diff --git a/mythril/laser/plugin/plugins/state_merge/state_merge_plugin.py b/mythril/laser/plugin/plugins/state_merge/state_merge_plugin.py new file mode 100644 index 00000000..0d6a2883 --- /dev/null +++ b/mythril/laser/plugin/plugins/state_merge/state_merge_plugin.py @@ -0,0 +1,97 @@ +from copy import copy +from typing import Set, List +from mythril.laser.ethereum.svm import LaserEVM +from mythril.laser.plugin.interface import LaserPlugin +from .merge_states import merge_states +from .check_mergeability import check_ws_merge_condition +from mythril.laser.ethereum.state.world_state import WorldState +from mythril.laser.ethereum.state.annotation import StateAnnotation +from mythril.laser.plugin.interface import LaserPlugin +import logging + +log = logging.getLogger(__name__) + + +class MergeAnnotation(StateAnnotation): + pass + + +class StateMergePluginBuilder(LaserPlugin): + plugin_default_enabled = True + enabled = True + + author = "MythX Development Team" + name = "MythX State Merge" + plugin_license = "All rights reserved." + plugin_type = "Laser Plugin" + plugin_version = "0.0.1 " + plugin_description = "This plugin merges states after the end of a transaction" + + def __call__(self, *args, **kwargs): + return StateMergePlugin() + + +class StateMergePlugin(LaserPlugin): + """ + Tries to merge states based on their similarity. + Currently it only tries to merge if everything is same + except constraints and storage. And there is some tolerance level + to the constraints. + A state can be merged only once --> avoids segfaults + better performance + """ + + def initialize(self, symbolic_vm: LaserEVM): + """Initializes the State merging plugin + + Introduces hooks for stop_sym_trans function + :param symbolic_vm: + :return: + """ + + @symbolic_vm.laser_hook("stop_sym_trans") + def execute_stop_sym_trans_hook(): + open_states = symbolic_vm.open_states + if len(open_states) <= 1: + return + num_old_states = len(open_states) + new_states = [] # type: List[WorldState] + old_size = len(open_states) + old_states = copy(open_states) + while old_size != len(new_states): + old_size = len(new_states) + new_states = [] + merged_set = set() # type: Set[int] + for i, state in enumerate(old_states): + if i in merged_set: + continue + if len(list(state.get_annotations(MergeAnnotation))) > 0: + new_states.append(state) + continue + new_states.append(self._look_for_merges(i, old_states, merged_set)) + + old_states = copy(new_states) + log.info(f"States reduced from {num_old_states} to {len(new_states)}") + symbolic_vm.open_states = new_states + + def _look_for_merges( + self, + offset: int, + states: List[WorldState], + merged_set: Set[int], + ) -> WorldState: + """ + Tries to merge states[offset] with any of the states in states[offset+1:] + :param offset: The offset of state + :param states: The List of states + :param merged_set: Set indicating which states are excluded from merging + :return: Returns a state + """ + state = states[offset] + for j in range(offset + 1, len(states)): + if j in merged_set or not check_ws_merge_condition(state, states[j]): + continue + merge_states(state, states[j]) + merged_set.add(j) + state.annotations.append(MergeAnnotation()) + return state + return state diff --git a/mythril/mythril/mythril_analyzer.py b/mythril/mythril/mythril_analyzer.py index bd5cf04a..8f724ec9 100644 --- a/mythril/mythril/mythril_analyzer.py +++ b/mythril/mythril/mythril_analyzer.py @@ -73,6 +73,7 @@ class MythrilAnalyzer: args.disable_coverage_strategy = cmd_args.disable_coverage_strategy args.disable_mutation_pruner = cmd_args.disable_mutation_pruner args.enable_summaries = cmd_args.enable_summaries + args.enable_state_merge = cmd_args.enable_state_merging if args.pruning_factor is None: if self.execution_timeout > LARGE_TIME: diff --git a/mythril/support/support_args.py b/mythril/support/support_args.py index d4ccb4f3..8c6b5559 100644 --- a/mythril/support/support_args.py +++ b/mythril/support/support_args.py @@ -24,6 +24,7 @@ class Args(object, metaclass=Singleton): self.disable_mutation_pruner = False self.incremental_txs = True self.enable_summaries = False + self.enable_state_merge = False args = Args() diff --git a/tests/graph_test.py b/tests/graph_test.py index 3b92c358..13935631 100644 --- a/tests/graph_test.py +++ b/tests/graph_test.py @@ -35,6 +35,7 @@ def test_generate_graph(): disable_coverage_strategy=False, disable_mutation_pruner=False, enable_summaries=False, + enable_state_merging=False, ) analyzer = MythrilAnalyzer( disassembler=disassembler, diff --git a/tests/integration_tests/state_merge_tests.py b/tests/integration_tests/state_merge_tests.py new file mode 100644 index 00000000..da714bd9 --- /dev/null +++ b/tests/integration_tests/state_merge_tests.py @@ -0,0 +1,34 @@ +import pytest +import os +import subprocess + +from tests import PROJECT_DIR, TESTDATA + + +MYTH = str(PROJECT_DIR / "myth") + + +def output_with_stderr(command): + return subprocess.run( + command.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + +testfile_path = os.path.split(__file__)[0] + +""" +calls.bin is the bytecode of +https://github.com/ConsenSys/mythril/blob/develop/solidity_examples/calls.sol +""" +swc_test_data = [ + ("114", f"{TESTDATA}/inputs/calls.sol.o", (9, 5)), +] + + +@pytest.mark.parametrize("swc, code, states_reduction", swc_test_data) +def test_merge(swc, code, states_reduction): + output = output_with_stderr( + f"{MYTH} -v4 a -f {code} -t 1 --solver-timeout 500 -mUncheckedRetval --enable-state-merging" + ) + output_str = f"States reduced from {states_reduction[0]} to {states_reduction[1]}" + assert output_str in output.stderr.decode("utf-8") diff --git a/tests/integration_tests/summary_test.py b/tests/integration_tests/summary_test.py index 46d7f6f8..ce8dcc5b 100644 --- a/tests/integration_tests/summary_test.py +++ b/tests/integration_tests/summary_test.py @@ -4,23 +4,11 @@ import sys import os from tests import PROJECT_DIR, TESTDATA -from subprocess import check_output, CalledProcessError +from utils import output_of MYTH = str(PROJECT_DIR / "myth") -def output_of(command): - """ - - :param command: - :return: - """ - try: - return check_output(command, shell=True).decode("UTF-8") - except CalledProcessError as exc: - return exc.output.decode("UTF-8") - - test_data = ( # TODO: The commented tests should be sped up! # ( diff --git a/tests/mythril/mythril_analyzer_test.py b/tests/mythril/mythril_analyzer_test.py index 90ea5756..9f3f56d8 100644 --- a/tests/mythril/mythril_analyzer_test.py +++ b/tests/mythril/mythril_analyzer_test.py @@ -42,6 +42,7 @@ def test_fire_lasers(mock_sym, mock_fire_lasers, mock_code_info): disable_coverage_strategy=False, disable_mutation_pruner=False, enable_summaries=False, + enable_state_merging=False, ) analyzer = MythrilAnalyzer(disassembler, cmd_args=args) diff --git a/tests/statespace_test.py b/tests/statespace_test.py index af8e3cb0..449b0580 100644 --- a/tests/statespace_test.py +++ b/tests/statespace_test.py @@ -32,6 +32,7 @@ def test_statespace_dump(): disable_coverage_strategy=False, disable_mutation_pruner=False, enable_summaries=False, + enable_state_merging=False, ) analyzer = MythrilAnalyzer( disassembler=disassembler,