Add improved state merging (#1843)

* Add improved state merging

* Fix namespace
pull/1845/head
Nikhil Parasaram 8 months ago committed by GitHub
parent a4fe7b287e
commit 70d234e4f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 13
      mythril/analysis/issue_annotation.py
  2. 10
      mythril/analysis/symbolic.py
  3. 5
      mythril/interfaces/cli.py
  4. 104
      mythril/laser/ethereum/state/transient_storage.py
  5. 7
      mythril/laser/plugin/plugins/__init__.py
  6. 1
      mythril/laser/plugin/plugins/state_merge/__init__.py
  7. 106
      mythril/laser/plugin/plugins/state_merge/check_mergeability.py
  8. 162
      mythril/laser/plugin/plugins/state_merge/merge_states.py
  9. 97
      mythril/laser/plugin/plugins/state_merge/state_merge_plugin.py
  10. 1
      mythril/mythril/mythril_analyzer.py
  11. 1
      mythril/support/support_args.py
  12. 1
      tests/graph_test.py
  13. 34
      tests/integration_tests/state_merge_tests.py
  14. 14
      tests/integration_tests/summary_test.py
  15. 1
      tests/mythril/mythril_analyzer_test.py
  16. 1
      tests/statespace_test.py

@ -32,3 +32,16 @@ class IssueAnnotation(StateAnnotation):
issue=self.issue, issue=self.issue,
detector=self.detector, detector=self.detector,
) )
def check_merge_annotation(self, annotation: "IssueAnnotation") -> bool:
if self.conditions == annotation.conditions:
return False
if self.issue.address != annotation.issue.address:
return False
if type(self.detector) != type(annotation.detector):
return False
return True
def merge_annotation(self, annotation: "IssueAnnotation") -> "IssueAnnotation":
return self

@ -21,17 +21,19 @@ from mythril.laser.ethereum.tx_prioritiser import RfTxPrioritiser
from mythril.laser.plugin.loader import LaserPluginLoader from mythril.laser.plugin.loader import LaserPluginLoader
from mythril.laser.plugin.plugins import ( from mythril.laser.plugin.plugins import (
MutationPrunerBuilder, CallDepthLimitBuilder,
DependencyPrunerBuilder,
CoveragePluginBuilder, CoveragePluginBuilder,
CoverageMetricsPluginBuilder, CoverageMetricsPluginBuilder,
CallDepthLimitBuilder, DependencyPrunerBuilder,
InstructionProfilerBuilder, InstructionProfilerBuilder,
MutationPrunerBuilder,
StateMergePluginBuilder,
SymbolicSummaryPluginBuilder, SymbolicSummaryPluginBuilder,
) )
from mythril.laser.ethereum.strategy.extensions.bounded_loops import ( from mythril.laser.ethereum.strategy.extensions.bounded_loops import (
BoundedLoopsStrategy, BoundedLoopsStrategy,
) )
from mythril.laser.plugin.plugins.state_merge.state_merge_plugin import StateMergePlugin
from mythril.laser.smt import symbol_factory, BitVec from mythril.laser.smt import symbol_factory, BitVec
from mythril.support.support_args import args from mythril.support.support_args import args
from typing import Union, List, Type, Optional from typing import Union, List, Type, Optional
@ -145,6 +147,8 @@ class SymExecWrapper:
plugin_loader = LaserPluginLoader() plugin_loader = LaserPluginLoader()
plugin_loader.load(CoverageMetricsPluginBuilder()) plugin_loader.load(CoverageMetricsPluginBuilder())
if args.enable_state_merge:
plugin_loader.load(StateMergePluginBuilder())
if not args.disable_coverage_strategy: if not args.disable_coverage_strategy:
plugin_loader.load(CoveragePluginBuilder()) plugin_loader.load(CoveragePluginBuilder())
if not args.disable_mutation_pruner: if not args.disable_mutation_pruner:

@ -571,6 +571,11 @@ def add_analysis_args(options):
action="store_true", action="store_true",
help="Disable mutation pruner", help="Disable mutation pruner",
) )
options.add_argument(
"--enable-state-merging",
action="store_true",
help="Enable State Merging",
)
options.add_argument( options.add_argument(
"--enable-summaries", "--enable-summaries",
action="store_true", action="store_true",

@ -0,0 +1,104 @@
"""This module contains account-related functionality.
This includes classes representing accounts and their storage.
"""
import logging
from copy import copy, deepcopy
from typing import Any, Dict, Union, Set
from mythril.laser.smt import Array, K, BitVec, simplify, BaseArray, If, Bool
from mythril.disassembler.disassembly import Disassembly
from mythril.laser.smt import symbol_factory
from mythril.support.support_args import args
log = logging.getLogger(__name__)
from pysmt.shortcuts import (
Symbol,
Array,
Store,
Select,
ArrayType,
Solver,
BV,
Equals,
BVZero,
BVConcat,
)
class TransientStorage:
def __init__(self):
self.checkpoints = [
0
] # List to store lengths of the journal at each checkpoint
# Define symbolic arrays for current state and journal
self.current_array = Array("current", ArrayType(BV(256), BV(256)))
self.journal_array = Array(
"journal", ArrayType(BV(256), ArrayType(BV(256), BV(256)))
)
# Function to get value from current state
self.get_func = Select(self.current_array, addr)[key]
# Function to update current state
self.update_current_func = Store(
self.current_array,
addr,
Store(Select(self.current_array, addr), key, prevValue),
)
def get(self, addr, key):
# Create symbolic variables for address and key
addr_sym = BVConcat(BVZero(96), addr)
key_sym = BVConcat(BVZero(96), key)
# Call SMT solver to get value from current state
solver = Solver()
solver.add(
Equals(self.get_func.simplify({addr: addr_sym, key: key_sym}), BVZero(256))
)
if solver.solve():
model = solver.get_model()
return model.get_value(self.get_func)
else:
return BVZero(256) # Return symbolic zero if value is not found
def put(self, addr, key, value):
# Create symbolic variables for address, key, and value
addr_sym = BVConcat(BVZero(96), addr)
key_sym = BVConcat(BVZero(96), key)
value_sym = BVConcat(BVZero(96), value)
# Store the journal entry
self.journal.append((addr_sym, key_sym, self.get(addr, key)))
# Update the current state
self.current_array = self.update_current_func.simplify(
{addr: addr_sym, key: key_sym, prevValue: value_sym}
)
def commit(self):
if len(self.checkpoints) == 0:
raise ValueError("Nothing to commit")
self.checkpoints.pop() # The last checkpoint is discarded.
def checkpoint(self):
self.checkpoints.append(len(self.journal))
def revert(self):
last_checkpoint = self.checkpoints.pop()
if last_checkpoint is None:
raise ValueError("Nothing to revert")
for i in range(len(self.journal) - 1, last_checkpoint - 1, -1):
(addr, key, prevValue) = self.journal[i]
self.current_array = Store(
self.current_array,
addr,
Store(Select(self.current_array, addr), key, prevValue),
)
self.journal = self.journal[:last_checkpoint]

@ -6,11 +6,12 @@ This module contains the implementation of some features
- pruning - pruning
""" """
from mythril.laser.plugin.plugins.benchmark import BenchmarkPluginBuilder from mythril.laser.plugin.plugins.benchmark import BenchmarkPluginBuilder
from mythril.laser.plugin.plugins.call_depth_limiter import CallDepthLimitBuilder
from mythril.laser.plugin.plugins.coverage.coverage_plugin import CoveragePluginBuilder from mythril.laser.plugin.plugins.coverage.coverage_plugin import CoveragePluginBuilder
from mythril.laser.plugin.plugins.coverage_metrics import CoverageMetricsPluginBuilder
from mythril.laser.plugin.plugins.dependency_pruner import DependencyPrunerBuilder from mythril.laser.plugin.plugins.dependency_pruner import DependencyPrunerBuilder
from mythril.laser.plugin.plugins.mutation_pruner import MutationPrunerBuilder
from mythril.laser.plugin.plugins.call_depth_limiter import CallDepthLimitBuilder
from mythril.laser.plugin.plugins.instruction_profiler import InstructionProfilerBuilder from mythril.laser.plugin.plugins.instruction_profiler import InstructionProfilerBuilder
from mythril.laser.plugin.plugins.mutation_pruner import MutationPrunerBuilder
from mythril.laser.plugin.plugins.state_merge import StateMergePluginBuilder
from mythril.laser.plugin.plugins.summary import SymbolicSummaryPluginBuilder from mythril.laser.plugin.plugins.summary import SymbolicSummaryPluginBuilder
from mythril.laser.plugin.plugins.trace import TraceFinderBuilder from mythril.laser.plugin.plugins.trace import TraceFinderBuilder
from mythril.laser.plugin.plugins.coverage_metrics import CoverageMetricsPluginBuilder

@ -0,0 +1 @@
from .state_merge_plugin import StateMergePluginBuilder

@ -0,0 +1,106 @@
import logging
from mythril.laser.ethereum.cfg import Node
from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.constraints import Constraints
from mythril.laser.smt import Not
CONSTRAINT_DIFFERENCE_LIMIT = 15
log = logging.getLogger(__name__)
def check_node_merge_condition(node1: Node, node2: Node):
"""
Checks whether two nodes are merge-able
:param node1: The node to be merged
:param node2: The other node to be merged
:return: Boolean, True if we can merge
"""
return all(
[
node1.function_name == node2.function_name,
node1.contract_name == node2.contract_name,
node1.start_addr == node2.start_addr,
]
)
def check_account_merge_condition(account1: Account, account2: Account):
"""
Checks whether we can merge accounts
"""
return all(
[
account1.nonce == account2.nonce,
account1.deleted == account2.deleted,
account1.code.bytecode == account2.code.bytecode,
]
)
def check_ws_merge_condition(state1: WorldState, state2: WorldState):
"""
Checks whether we can merge these states
"""
if state1.node and not check_node_merge_condition(state1.node, state2.node):
return False
for address, account in state2.accounts.items():
if (
address in state1._accounts
and check_account_merge_condition(state1._accounts[address], account)
is False
):
return False
if not _check_merge_annotations(state1, state2):
return False
return True
def _check_merge_annotations(state1: WorldState, state2: WorldState):
"""
Checks whether two annotations can be merged
:param state:
:return:
"""
if len(state2.annotations) != len(state1.annotations):
return False
if _check_constraint_merge(state1.constraints, state2.constraints) is False:
return False
for v1, v2 in zip(state2.annotations, state1.annotations):
if type(v1) != type(v2):
return False
try:
if v1.check_merge_annotation(v2) is False: # type: ignore
return False
except AttributeError:
log.error(
f"check_merge_annotation() method doesn't exist "
f"for the annotation {type(v1)}. Aborting merge for the state"
)
return False
return True
def _check_constraint_merge(
constraints1: Constraints, constraints2: Constraints
) -> bool:
"""
We are merging the states which have a no more than CONSTRAINT_DIFFERENCE_LIMIT
different constraints. This helps in merging states which are not too different
"""
dict1 = {c: True for c in constraints1}
dict2 = {c: True for c in constraints2}
c1, c2 = 0, 0
for key in dict1:
if key not in dict2 and Not(key) not in dict2:
c1 += 1
for key in dict2:
if key not in dict1 and Not(key) not in dict1:
c2 += 1
if c1 + c2 > CONSTRAINT_DIFFERENCE_LIMIT:
return False
return True

@ -0,0 +1,162 @@
import logging
from mythril.laser.ethereum.cfg import Node
from typing import Tuple, cast
from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.state.account import Account, Storage
from mythril.laser.ethereum.state.constraints import Constraints
from mythril.laser.smt import symbol_factory, Array, If, Or, And, Not, Bool
log = logging.getLogger(__name__)
def merge_states(state1: WorldState, state2: WorldState):
"""
Merge state2 into state1
:param state1: The state to be merged into
:param state2: The state which is merged into state1
:return:
"""
# Merge constraints
state1.constraints, condition1, _ = _merge_constraints(
state1.constraints, state2.constraints
)
# Merge balances
state1.balances = cast(Array, If(condition1, state1.balances, state2.balances))
state1.starting_balances = cast(
Array, If(condition1, state1.starting_balances, state2.starting_balances)
)
# Merge accounts
for address, account in state2.accounts.items():
if address not in state1._accounts:
state1.put_account(account)
else:
merge_accounts(
state1._accounts[address], account, condition1, state1.balances
)
# Merge annotations
_merge_annotations(state1, state2)
# Merge Node
merge_nodes(state1.node, state2.node, state1.constraints)
def merge_nodes(node1: Node, node2: Node, constraints: Constraints):
"""
Merges node2 into node1
:param node1: The node to be merged
:param node2: The other node to be merged
:param constraints: The merged constraints
:return:
"""
node1.states += node2.states
node1.uid = hash(node1)
node1.flags |= node2.flags
node1.constraints = constraints
def merge_accounts(
account1: Account,
account2: Account,
path_condition: Bool,
merged_balance: Array,
):
"""
Merges account2 into account1
:param account1: The account to merge with
:param account2: The second account to merge
:param path_condition: The constraint for this account
:param merged_balance: The merged balance
:return:
"""
if (
account1.nonce != account2.nonce
or account1.deleted != account2.deleted
or account1.code.bytecode != account2.code.bytecode
):
raise ValueError("Un-Mergeable accounts are given to be merged")
account1._balances = merged_balance
merge_storage(account1.storage, account2.storage, path_condition)
def merge_storage(storage1: Storage, storage2: Storage, path_condition: Bool):
"""
Merge storage2 into storage1
:param storage1: To storage to merge into
:param storage2: To storage to merge with
:param path_condition: The constraint for this storage to be executed
:return:
"""
storage1._standard_storage = If(
path_condition, storage1._standard_storage, storage2._standard_storage
)
storage1.storage_keys_loaded = storage1.storage_keys_loaded.union(
storage2.storage_keys_loaded
)
for key, value in storage2.printable_storage.items():
if key in storage1.printable_storage:
storage1.printable_storage[key] = If(
path_condition, storage1.printable_storage[key], value
)
else:
storage1.printable_storage[key] = If(path_condition, 0, value)
def _merge_annotations(state1: "WorldState", state2: "WorldState"):
"""
Merges the annotations of the two states into state1
:param state1:
:param state2:
:return:
"""
for v1, v2 in zip(state1.annotations, state2.annotations):
try:
v1.merge_annotation(v2) # type: ignore
except AttributeError:
log.error(
f"merge_annotation() method doesn't exist for the annotation {type(v1)}. "
"Aborting merge for the state"
)
return False
def _merge_constraints(
constraints1: Constraints, constraints2: Constraints
) -> Tuple[Constraints, Bool, Bool]:
"""
Merges constraints
:param constraints1: Constraint2 of state1
:param constraints2: Constraints of state2
:return: A Tuple of merged constraints,
conjunction of constraints in state 1 not in state 2, conjunction of constraints
in state2 not in state1
"""
dict1 = {c: True for c in constraints1}
dict2 = {c: True for c in constraints2}
c1, c2 = symbol_factory.Bool(True), symbol_factory.Bool(True)
new_constraint1, new_constraint2 = (
symbol_factory.Bool(True),
symbol_factory.Bool(True),
)
same_constraints = Constraints()
for key in dict1:
if key not in dict2:
c1 = And(c1, key)
if Not(key) not in dict2:
new_constraint1 = And(new_constraint1, key)
else:
same_constraints.append(key)
for key in dict2:
if key not in dict1:
c2 = And(c2, key)
if Not(key) not in dict1:
new_constraint2 = And(new_constraint2, key)
else:
same_constraints.append(key)
merge_constraints = same_constraints + [Or(new_constraint1, new_constraint2)]
return merge_constraints, c1, c2

@ -0,0 +1,97 @@
from copy import copy
from typing import Set, List
from mythril.laser.ethereum.svm import LaserEVM
from mythril.laser.plugin.interface import LaserPlugin
from .merge_states import merge_states
from .check_mergeability import check_ws_merge_condition
from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.state.annotation import StateAnnotation
from mythril.laser.plugin.interface import LaserPlugin
import logging
log = logging.getLogger(__name__)
class MergeAnnotation(StateAnnotation):
pass
class StateMergePluginBuilder(LaserPlugin):
plugin_default_enabled = True
enabled = True
author = "MythX Development Team"
name = "MythX State Merge"
plugin_license = "All rights reserved."
plugin_type = "Laser Plugin"
plugin_version = "0.0.1 "
plugin_description = "This plugin merges states after the end of a transaction"
def __call__(self, *args, **kwargs):
return StateMergePlugin()
class StateMergePlugin(LaserPlugin):
"""
Tries to merge states based on their similarity.
Currently it only tries to merge if everything is same
except constraints and storage. And there is some tolerance level
to the constraints.
A state can be merged only once --> avoids segfaults + better performance
"""
def initialize(self, symbolic_vm: LaserEVM):
"""Initializes the State merging plugin
Introduces hooks for stop_sym_trans function
:param symbolic_vm:
:return:
"""
@symbolic_vm.laser_hook("stop_sym_trans")
def execute_stop_sym_trans_hook():
open_states = symbolic_vm.open_states
if len(open_states) <= 1:
return
num_old_states = len(open_states)
new_states = [] # type: List[WorldState]
old_size = len(open_states)
old_states = copy(open_states)
while old_size != len(new_states):
old_size = len(new_states)
new_states = []
merged_set = set() # type: Set[int]
for i, state in enumerate(old_states):
if i in merged_set:
continue
if len(list(state.get_annotations(MergeAnnotation))) > 0:
new_states.append(state)
continue
new_states.append(self._look_for_merges(i, old_states, merged_set))
old_states = copy(new_states)
log.info(f"States reduced from {num_old_states} to {len(new_states)}")
symbolic_vm.open_states = new_states
def _look_for_merges(
self,
offset: int,
states: List[WorldState],
merged_set: Set[int],
) -> WorldState:
"""
Tries to merge states[offset] with any of the states in states[offset+1:]
:param offset: The offset of state
:param states: The List of states
:param merged_set: Set indicating which states are excluded from merging
:return: Returns a state
"""
state = states[offset]
for j in range(offset + 1, len(states)):
if j in merged_set or not check_ws_merge_condition(state, states[j]):
continue
merge_states(state, states[j])
merged_set.add(j)
state.annotations.append(MergeAnnotation())
return state
return state

@ -73,6 +73,7 @@ class MythrilAnalyzer:
args.disable_coverage_strategy = cmd_args.disable_coverage_strategy args.disable_coverage_strategy = cmd_args.disable_coverage_strategy
args.disable_mutation_pruner = cmd_args.disable_mutation_pruner args.disable_mutation_pruner = cmd_args.disable_mutation_pruner
args.enable_summaries = cmd_args.enable_summaries args.enable_summaries = cmd_args.enable_summaries
args.enable_state_merge = cmd_args.enable_state_merging
if args.pruning_factor is None: if args.pruning_factor is None:
if self.execution_timeout > LARGE_TIME: if self.execution_timeout > LARGE_TIME:

@ -24,6 +24,7 @@ class Args(object, metaclass=Singleton):
self.disable_mutation_pruner = False self.disable_mutation_pruner = False
self.incremental_txs = True self.incremental_txs = True
self.enable_summaries = False self.enable_summaries = False
self.enable_state_merge = False
args = Args() args = Args()

@ -35,6 +35,7 @@ def test_generate_graph():
disable_coverage_strategy=False, disable_coverage_strategy=False,
disable_mutation_pruner=False, disable_mutation_pruner=False,
enable_summaries=False, enable_summaries=False,
enable_state_merging=False,
) )
analyzer = MythrilAnalyzer( analyzer = MythrilAnalyzer(
disassembler=disassembler, disassembler=disassembler,

@ -0,0 +1,34 @@
import pytest
import os
import subprocess
from tests import PROJECT_DIR, TESTDATA
MYTH = str(PROJECT_DIR / "myth")
def output_with_stderr(command):
return subprocess.run(
command.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
testfile_path = os.path.split(__file__)[0]
"""
calls.bin is the bytecode of
https://github.com/ConsenSys/mythril/blob/develop/solidity_examples/calls.sol
"""
swc_test_data = [
("114", f"{TESTDATA}/inputs/calls.sol.o", (9, 5)),
]
@pytest.mark.parametrize("swc, code, states_reduction", swc_test_data)
def test_merge(swc, code, states_reduction):
output = output_with_stderr(
f"{MYTH} -v4 a -f {code} -t 1 --solver-timeout 500 -mUncheckedRetval --enable-state-merging"
)
output_str = f"States reduced from {states_reduction[0]} to {states_reduction[1]}"
assert output_str in output.stderr.decode("utf-8")

@ -4,23 +4,11 @@ import sys
import os import os
from tests import PROJECT_DIR, TESTDATA from tests import PROJECT_DIR, TESTDATA
from subprocess import check_output, CalledProcessError from utils import output_of
MYTH = str(PROJECT_DIR / "myth") MYTH = str(PROJECT_DIR / "myth")
def output_of(command):
"""
:param command:
:return:
"""
try:
return check_output(command, shell=True).decode("UTF-8")
except CalledProcessError as exc:
return exc.output.decode("UTF-8")
test_data = ( test_data = (
# TODO: The commented tests should be sped up! # TODO: The commented tests should be sped up!
# ( # (

@ -42,6 +42,7 @@ def test_fire_lasers(mock_sym, mock_fire_lasers, mock_code_info):
disable_coverage_strategy=False, disable_coverage_strategy=False,
disable_mutation_pruner=False, disable_mutation_pruner=False,
enable_summaries=False, enable_summaries=False,
enable_state_merging=False,
) )
analyzer = MythrilAnalyzer(disassembler, cmd_args=args) analyzer = MythrilAnalyzer(disassembler, cmd_args=args)

@ -32,6 +32,7 @@ def test_statespace_dump():
disable_coverage_strategy=False, disable_coverage_strategy=False,
disable_mutation_pruner=False, disable_mutation_pruner=False,
enable_summaries=False, enable_summaries=False,
enable_state_merging=False,
) )
analyzer = MythrilAnalyzer( analyzer = MythrilAnalyzer(
disassembler=disassembler, disassembler=disassembler,

Loading…
Cancel
Save