mirror of https://github.com/ConsenSys/mythril
parent
a4fe7b287e
commit
70d234e4f1
@ -0,0 +1,104 @@ |
|||||||
|
"""This module contains account-related functionality. |
||||||
|
|
||||||
|
This includes classes representing accounts and their storage. |
||||||
|
""" |
||||||
|
import logging |
||||||
|
from copy import copy, deepcopy |
||||||
|
from typing import Any, Dict, Union, Set |
||||||
|
|
||||||
|
|
||||||
|
from mythril.laser.smt import Array, K, BitVec, simplify, BaseArray, If, Bool |
||||||
|
from mythril.disassembler.disassembly import Disassembly |
||||||
|
from mythril.laser.smt import symbol_factory |
||||||
|
from mythril.support.support_args import args |
||||||
|
|
||||||
|
log = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
from pysmt.shortcuts import ( |
||||||
|
Symbol, |
||||||
|
Array, |
||||||
|
Store, |
||||||
|
Select, |
||||||
|
ArrayType, |
||||||
|
Solver, |
||||||
|
BV, |
||||||
|
Equals, |
||||||
|
BVZero, |
||||||
|
BVConcat, |
||||||
|
) |
||||||
|
|
||||||
|
|
||||||
|
class TransientStorage: |
||||||
|
def __init__(self): |
||||||
|
self.checkpoints = [ |
||||||
|
0 |
||||||
|
] # List to store lengths of the journal at each checkpoint |
||||||
|
|
||||||
|
# Define symbolic arrays for current state and journal |
||||||
|
self.current_array = Array("current", ArrayType(BV(256), BV(256))) |
||||||
|
self.journal_array = Array( |
||||||
|
"journal", ArrayType(BV(256), ArrayType(BV(256), BV(256))) |
||||||
|
) |
||||||
|
|
||||||
|
# Function to get value from current state |
||||||
|
self.get_func = Select(self.current_array, addr)[key] |
||||||
|
|
||||||
|
# Function to update current state |
||||||
|
self.update_current_func = Store( |
||||||
|
self.current_array, |
||||||
|
addr, |
||||||
|
Store(Select(self.current_array, addr), key, prevValue), |
||||||
|
) |
||||||
|
|
||||||
|
def get(self, addr, key): |
||||||
|
# Create symbolic variables for address and key |
||||||
|
addr_sym = BVConcat(BVZero(96), addr) |
||||||
|
key_sym = BVConcat(BVZero(96), key) |
||||||
|
|
||||||
|
# Call SMT solver to get value from current state |
||||||
|
solver = Solver() |
||||||
|
solver.add( |
||||||
|
Equals(self.get_func.simplify({addr: addr_sym, key: key_sym}), BVZero(256)) |
||||||
|
) |
||||||
|
if solver.solve(): |
||||||
|
model = solver.get_model() |
||||||
|
return model.get_value(self.get_func) |
||||||
|
else: |
||||||
|
return BVZero(256) # Return symbolic zero if value is not found |
||||||
|
|
||||||
|
def put(self, addr, key, value): |
||||||
|
# Create symbolic variables for address, key, and value |
||||||
|
addr_sym = BVConcat(BVZero(96), addr) |
||||||
|
key_sym = BVConcat(BVZero(96), key) |
||||||
|
value_sym = BVConcat(BVZero(96), value) |
||||||
|
|
||||||
|
# Store the journal entry |
||||||
|
self.journal.append((addr_sym, key_sym, self.get(addr, key))) |
||||||
|
|
||||||
|
# Update the current state |
||||||
|
self.current_array = self.update_current_func.simplify( |
||||||
|
{addr: addr_sym, key: key_sym, prevValue: value_sym} |
||||||
|
) |
||||||
|
|
||||||
|
def commit(self): |
||||||
|
if len(self.checkpoints) == 0: |
||||||
|
raise ValueError("Nothing to commit") |
||||||
|
self.checkpoints.pop() # The last checkpoint is discarded. |
||||||
|
|
||||||
|
def checkpoint(self): |
||||||
|
self.checkpoints.append(len(self.journal)) |
||||||
|
|
||||||
|
def revert(self): |
||||||
|
last_checkpoint = self.checkpoints.pop() |
||||||
|
if last_checkpoint is None: |
||||||
|
raise ValueError("Nothing to revert") |
||||||
|
|
||||||
|
for i in range(len(self.journal) - 1, last_checkpoint - 1, -1): |
||||||
|
(addr, key, prevValue) = self.journal[i] |
||||||
|
self.current_array = Store( |
||||||
|
self.current_array, |
||||||
|
addr, |
||||||
|
Store(Select(self.current_array, addr), key, prevValue), |
||||||
|
) |
||||||
|
self.journal = self.journal[:last_checkpoint] |
@ -0,0 +1 @@ |
|||||||
|
from .state_merge_plugin import StateMergePluginBuilder |
@ -0,0 +1,106 @@ |
|||||||
|
import logging |
||||||
|
from mythril.laser.ethereum.cfg import Node |
||||||
|
from mythril.laser.ethereum.state.world_state import WorldState |
||||||
|
from mythril.laser.ethereum.state.account import Account |
||||||
|
from mythril.laser.ethereum.state.constraints import Constraints |
||||||
|
from mythril.laser.smt import Not |
||||||
|
|
||||||
|
CONSTRAINT_DIFFERENCE_LIMIT = 15 |
||||||
|
|
||||||
|
log = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
def check_node_merge_condition(node1: Node, node2: Node): |
||||||
|
""" |
||||||
|
Checks whether two nodes are merge-able |
||||||
|
:param node1: The node to be merged |
||||||
|
:param node2: The other node to be merged |
||||||
|
:return: Boolean, True if we can merge |
||||||
|
""" |
||||||
|
return all( |
||||||
|
[ |
||||||
|
node1.function_name == node2.function_name, |
||||||
|
node1.contract_name == node2.contract_name, |
||||||
|
node1.start_addr == node2.start_addr, |
||||||
|
] |
||||||
|
) |
||||||
|
|
||||||
|
|
||||||
|
def check_account_merge_condition(account1: Account, account2: Account): |
||||||
|
""" |
||||||
|
Checks whether we can merge accounts |
||||||
|
""" |
||||||
|
return all( |
||||||
|
[ |
||||||
|
account1.nonce == account2.nonce, |
||||||
|
account1.deleted == account2.deleted, |
||||||
|
account1.code.bytecode == account2.code.bytecode, |
||||||
|
] |
||||||
|
) |
||||||
|
|
||||||
|
|
||||||
|
def check_ws_merge_condition(state1: WorldState, state2: WorldState): |
||||||
|
""" |
||||||
|
Checks whether we can merge these states |
||||||
|
""" |
||||||
|
if state1.node and not check_node_merge_condition(state1.node, state2.node): |
||||||
|
return False |
||||||
|
|
||||||
|
for address, account in state2.accounts.items(): |
||||||
|
if ( |
||||||
|
address in state1._accounts |
||||||
|
and check_account_merge_condition(state1._accounts[address], account) |
||||||
|
is False |
||||||
|
): |
||||||
|
return False |
||||||
|
if not _check_merge_annotations(state1, state2): |
||||||
|
return False |
||||||
|
|
||||||
|
return True |
||||||
|
|
||||||
|
|
||||||
|
def _check_merge_annotations(state1: WorldState, state2: WorldState): |
||||||
|
""" |
||||||
|
Checks whether two annotations can be merged |
||||||
|
:param state: |
||||||
|
:return: |
||||||
|
""" |
||||||
|
if len(state2.annotations) != len(state1.annotations): |
||||||
|
return False |
||||||
|
if _check_constraint_merge(state1.constraints, state2.constraints) is False: |
||||||
|
return False |
||||||
|
for v1, v2 in zip(state2.annotations, state1.annotations): |
||||||
|
if type(v1) != type(v2): |
||||||
|
return False |
||||||
|
try: |
||||||
|
if v1.check_merge_annotation(v2) is False: # type: ignore |
||||||
|
return False |
||||||
|
except AttributeError: |
||||||
|
log.error( |
||||||
|
f"check_merge_annotation() method doesn't exist " |
||||||
|
f"for the annotation {type(v1)}. Aborting merge for the state" |
||||||
|
) |
||||||
|
return False |
||||||
|
|
||||||
|
return True |
||||||
|
|
||||||
|
|
||||||
|
def _check_constraint_merge( |
||||||
|
constraints1: Constraints, constraints2: Constraints |
||||||
|
) -> bool: |
||||||
|
""" |
||||||
|
We are merging the states which have a no more than CONSTRAINT_DIFFERENCE_LIMIT |
||||||
|
different constraints. This helps in merging states which are not too different |
||||||
|
""" |
||||||
|
dict1 = {c: True for c in constraints1} |
||||||
|
dict2 = {c: True for c in constraints2} |
||||||
|
c1, c2 = 0, 0 |
||||||
|
for key in dict1: |
||||||
|
if key not in dict2 and Not(key) not in dict2: |
||||||
|
c1 += 1 |
||||||
|
for key in dict2: |
||||||
|
if key not in dict1 and Not(key) not in dict1: |
||||||
|
c2 += 1 |
||||||
|
if c1 + c2 > CONSTRAINT_DIFFERENCE_LIMIT: |
||||||
|
return False |
||||||
|
return True |
@ -0,0 +1,162 @@ |
|||||||
|
import logging |
||||||
|
|
||||||
|
from mythril.laser.ethereum.cfg import Node |
||||||
|
from typing import Tuple, cast |
||||||
|
from mythril.laser.ethereum.state.world_state import WorldState |
||||||
|
from mythril.laser.ethereum.state.account import Account, Storage |
||||||
|
from mythril.laser.ethereum.state.constraints import Constraints |
||||||
|
from mythril.laser.smt import symbol_factory, Array, If, Or, And, Not, Bool |
||||||
|
|
||||||
|
log = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
def merge_states(state1: WorldState, state2: WorldState): |
||||||
|
""" |
||||||
|
Merge state2 into state1 |
||||||
|
:param state1: The state to be merged into |
||||||
|
:param state2: The state which is merged into state1 |
||||||
|
:return: |
||||||
|
""" |
||||||
|
|
||||||
|
# Merge constraints |
||||||
|
state1.constraints, condition1, _ = _merge_constraints( |
||||||
|
state1.constraints, state2.constraints |
||||||
|
) |
||||||
|
|
||||||
|
# Merge balances |
||||||
|
state1.balances = cast(Array, If(condition1, state1.balances, state2.balances)) |
||||||
|
state1.starting_balances = cast( |
||||||
|
Array, If(condition1, state1.starting_balances, state2.starting_balances) |
||||||
|
) |
||||||
|
|
||||||
|
# Merge accounts |
||||||
|
for address, account in state2.accounts.items(): |
||||||
|
if address not in state1._accounts: |
||||||
|
state1.put_account(account) |
||||||
|
else: |
||||||
|
merge_accounts( |
||||||
|
state1._accounts[address], account, condition1, state1.balances |
||||||
|
) |
||||||
|
|
||||||
|
# Merge annotations |
||||||
|
_merge_annotations(state1, state2) |
||||||
|
|
||||||
|
# Merge Node |
||||||
|
merge_nodes(state1.node, state2.node, state1.constraints) |
||||||
|
|
||||||
|
|
||||||
|
def merge_nodes(node1: Node, node2: Node, constraints: Constraints): |
||||||
|
""" |
||||||
|
Merges node2 into node1 |
||||||
|
:param node1: The node to be merged |
||||||
|
:param node2: The other node to be merged |
||||||
|
:param constraints: The merged constraints |
||||||
|
:return: |
||||||
|
""" |
||||||
|
node1.states += node2.states |
||||||
|
node1.uid = hash(node1) |
||||||
|
node1.flags |= node2.flags |
||||||
|
node1.constraints = constraints |
||||||
|
|
||||||
|
|
||||||
|
def merge_accounts( |
||||||
|
account1: Account, |
||||||
|
account2: Account, |
||||||
|
path_condition: Bool, |
||||||
|
merged_balance: Array, |
||||||
|
): |
||||||
|
""" |
||||||
|
Merges account2 into account1 |
||||||
|
:param account1: The account to merge with |
||||||
|
:param account2: The second account to merge |
||||||
|
:param path_condition: The constraint for this account |
||||||
|
:param merged_balance: The merged balance |
||||||
|
:return: |
||||||
|
""" |
||||||
|
if ( |
||||||
|
account1.nonce != account2.nonce |
||||||
|
or account1.deleted != account2.deleted |
||||||
|
or account1.code.bytecode != account2.code.bytecode |
||||||
|
): |
||||||
|
raise ValueError("Un-Mergeable accounts are given to be merged") |
||||||
|
|
||||||
|
account1._balances = merged_balance |
||||||
|
merge_storage(account1.storage, account2.storage, path_condition) |
||||||
|
|
||||||
|
|
||||||
|
def merge_storage(storage1: Storage, storage2: Storage, path_condition: Bool): |
||||||
|
""" |
||||||
|
Merge storage2 into storage1 |
||||||
|
:param storage1: To storage to merge into |
||||||
|
:param storage2: To storage to merge with |
||||||
|
:param path_condition: The constraint for this storage to be executed |
||||||
|
:return: |
||||||
|
""" |
||||||
|
storage1._standard_storage = If( |
||||||
|
path_condition, storage1._standard_storage, storage2._standard_storage |
||||||
|
) |
||||||
|
storage1.storage_keys_loaded = storage1.storage_keys_loaded.union( |
||||||
|
storage2.storage_keys_loaded |
||||||
|
) |
||||||
|
for key, value in storage2.printable_storage.items(): |
||||||
|
if key in storage1.printable_storage: |
||||||
|
storage1.printable_storage[key] = If( |
||||||
|
path_condition, storage1.printable_storage[key], value |
||||||
|
) |
||||||
|
else: |
||||||
|
storage1.printable_storage[key] = If(path_condition, 0, value) |
||||||
|
|
||||||
|
|
||||||
|
def _merge_annotations(state1: "WorldState", state2: "WorldState"): |
||||||
|
""" |
||||||
|
Merges the annotations of the two states into state1 |
||||||
|
:param state1: |
||||||
|
:param state2: |
||||||
|
:return: |
||||||
|
""" |
||||||
|
for v1, v2 in zip(state1.annotations, state2.annotations): |
||||||
|
try: |
||||||
|
v1.merge_annotation(v2) # type: ignore |
||||||
|
except AttributeError: |
||||||
|
log.error( |
||||||
|
f"merge_annotation() method doesn't exist for the annotation {type(v1)}. " |
||||||
|
"Aborting merge for the state" |
||||||
|
) |
||||||
|
return False |
||||||
|
|
||||||
|
|
||||||
|
def _merge_constraints( |
||||||
|
constraints1: Constraints, constraints2: Constraints |
||||||
|
) -> Tuple[Constraints, Bool, Bool]: |
||||||
|
""" |
||||||
|
Merges constraints |
||||||
|
:param constraints1: Constraint2 of state1 |
||||||
|
:param constraints2: Constraints of state2 |
||||||
|
:return: A Tuple of merged constraints, |
||||||
|
conjunction of constraints in state 1 not in state 2, conjunction of constraints |
||||||
|
in state2 not in state1 |
||||||
|
""" |
||||||
|
dict1 = {c: True for c in constraints1} |
||||||
|
dict2 = {c: True for c in constraints2} |
||||||
|
c1, c2 = symbol_factory.Bool(True), symbol_factory.Bool(True) |
||||||
|
new_constraint1, new_constraint2 = ( |
||||||
|
symbol_factory.Bool(True), |
||||||
|
symbol_factory.Bool(True), |
||||||
|
) |
||||||
|
same_constraints = Constraints() |
||||||
|
for key in dict1: |
||||||
|
if key not in dict2: |
||||||
|
c1 = And(c1, key) |
||||||
|
if Not(key) not in dict2: |
||||||
|
new_constraint1 = And(new_constraint1, key) |
||||||
|
else: |
||||||
|
same_constraints.append(key) |
||||||
|
for key in dict2: |
||||||
|
if key not in dict1: |
||||||
|
c2 = And(c2, key) |
||||||
|
if Not(key) not in dict1: |
||||||
|
new_constraint2 = And(new_constraint2, key) |
||||||
|
else: |
||||||
|
same_constraints.append(key) |
||||||
|
merge_constraints = same_constraints + [Or(new_constraint1, new_constraint2)] |
||||||
|
return merge_constraints, c1, c2 |
@ -0,0 +1,97 @@ |
|||||||
|
from copy import copy |
||||||
|
from typing import Set, List |
||||||
|
from mythril.laser.ethereum.svm import LaserEVM |
||||||
|
from mythril.laser.plugin.interface import LaserPlugin |
||||||
|
from .merge_states import merge_states |
||||||
|
from .check_mergeability import check_ws_merge_condition |
||||||
|
from mythril.laser.ethereum.state.world_state import WorldState |
||||||
|
from mythril.laser.ethereum.state.annotation import StateAnnotation |
||||||
|
from mythril.laser.plugin.interface import LaserPlugin |
||||||
|
import logging |
||||||
|
|
||||||
|
log = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
class MergeAnnotation(StateAnnotation): |
||||||
|
pass |
||||||
|
|
||||||
|
|
||||||
|
class StateMergePluginBuilder(LaserPlugin): |
||||||
|
plugin_default_enabled = True |
||||||
|
enabled = True |
||||||
|
|
||||||
|
author = "MythX Development Team" |
||||||
|
name = "MythX State Merge" |
||||||
|
plugin_license = "All rights reserved." |
||||||
|
plugin_type = "Laser Plugin" |
||||||
|
plugin_version = "0.0.1 " |
||||||
|
plugin_description = "This plugin merges states after the end of a transaction" |
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs): |
||||||
|
return StateMergePlugin() |
||||||
|
|
||||||
|
|
||||||
|
class StateMergePlugin(LaserPlugin): |
||||||
|
""" |
||||||
|
Tries to merge states based on their similarity. |
||||||
|
Currently it only tries to merge if everything is same |
||||||
|
except constraints and storage. And there is some tolerance level |
||||||
|
to the constraints. |
||||||
|
A state can be merged only once --> avoids segfaults + better performance |
||||||
|
""" |
||||||
|
|
||||||
|
def initialize(self, symbolic_vm: LaserEVM): |
||||||
|
"""Initializes the State merging plugin |
||||||
|
|
||||||
|
Introduces hooks for stop_sym_trans function |
||||||
|
:param symbolic_vm: |
||||||
|
:return: |
||||||
|
""" |
||||||
|
|
||||||
|
@symbolic_vm.laser_hook("stop_sym_trans") |
||||||
|
def execute_stop_sym_trans_hook(): |
||||||
|
open_states = symbolic_vm.open_states |
||||||
|
if len(open_states) <= 1: |
||||||
|
return |
||||||
|
num_old_states = len(open_states) |
||||||
|
new_states = [] # type: List[WorldState] |
||||||
|
old_size = len(open_states) |
||||||
|
old_states = copy(open_states) |
||||||
|
while old_size != len(new_states): |
||||||
|
old_size = len(new_states) |
||||||
|
new_states = [] |
||||||
|
merged_set = set() # type: Set[int] |
||||||
|
for i, state in enumerate(old_states): |
||||||
|
if i in merged_set: |
||||||
|
continue |
||||||
|
if len(list(state.get_annotations(MergeAnnotation))) > 0: |
||||||
|
new_states.append(state) |
||||||
|
continue |
||||||
|
new_states.append(self._look_for_merges(i, old_states, merged_set)) |
||||||
|
|
||||||
|
old_states = copy(new_states) |
||||||
|
log.info(f"States reduced from {num_old_states} to {len(new_states)}") |
||||||
|
symbolic_vm.open_states = new_states |
||||||
|
|
||||||
|
def _look_for_merges( |
||||||
|
self, |
||||||
|
offset: int, |
||||||
|
states: List[WorldState], |
||||||
|
merged_set: Set[int], |
||||||
|
) -> WorldState: |
||||||
|
""" |
||||||
|
Tries to merge states[offset] with any of the states in states[offset+1:] |
||||||
|
:param offset: The offset of state |
||||||
|
:param states: The List of states |
||||||
|
:param merged_set: Set indicating which states are excluded from merging |
||||||
|
:return: Returns a state |
||||||
|
""" |
||||||
|
state = states[offset] |
||||||
|
for j in range(offset + 1, len(states)): |
||||||
|
if j in merged_set or not check_ws_merge_condition(state, states[j]): |
||||||
|
continue |
||||||
|
merge_states(state, states[j]) |
||||||
|
merged_set.add(j) |
||||||
|
state.annotations.append(MergeAnnotation()) |
||||||
|
return state |
||||||
|
return state |
@ -0,0 +1,34 @@ |
|||||||
|
import pytest |
||||||
|
import os |
||||||
|
import subprocess |
||||||
|
|
||||||
|
from tests import PROJECT_DIR, TESTDATA |
||||||
|
|
||||||
|
|
||||||
|
MYTH = str(PROJECT_DIR / "myth") |
||||||
|
|
||||||
|
|
||||||
|
def output_with_stderr(command): |
||||||
|
return subprocess.run( |
||||||
|
command.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE |
||||||
|
) |
||||||
|
|
||||||
|
|
||||||
|
testfile_path = os.path.split(__file__)[0] |
||||||
|
|
||||||
|
""" |
||||||
|
calls.bin is the bytecode of |
||||||
|
https://github.com/ConsenSys/mythril/blob/develop/solidity_examples/calls.sol |
||||||
|
""" |
||||||
|
swc_test_data = [ |
||||||
|
("114", f"{TESTDATA}/inputs/calls.sol.o", (9, 5)), |
||||||
|
] |
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("swc, code, states_reduction", swc_test_data) |
||||||
|
def test_merge(swc, code, states_reduction): |
||||||
|
output = output_with_stderr( |
||||||
|
f"{MYTH} -v4 a -f {code} -t 1 --solver-timeout 500 -mUncheckedRetval --enable-state-merging" |
||||||
|
) |
||||||
|
output_str = f"States reduced from {states_reduction[0]} to {states_reduction[1]}" |
||||||
|
assert output_str in output.stderr.decode("utf-8") |
Loading…
Reference in new issue