Merge branch 'develop' into fix/symbolic_dos

pull/1148/head
Bernhard Mueller 5 years ago committed by GitHub
commit ad4cad6b1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 102
      mythril/laser/ethereum/plugins/implementations/dependency_pruner.py

@ -7,7 +7,6 @@ from mythril.laser.ethereum.transaction.transaction_models import (
ContractCreationTransaction, ContractCreationTransaction,
) )
from mythril.exceptions import UnsatError from mythril.exceptions import UnsatError
from z3.z3types import Z3Exception
from mythril.analysis import solver from mythril.analysis import solver
from typing import cast, List, Dict, Set from typing import cast, List, Dict, Set
from copy import copy from copy import copy
@ -26,15 +25,15 @@ class DependencyAnnotation(StateAnnotation):
def __init__(self): def __init__(self):
self.storage_loaded = [] # type: List self.storage_loaded = [] # type: List
self.storage_written = {} # type: Dict[int, List] self.storage_written = {} # type: Dict[int, List]
self.has_call = False
self.path = [0] # type: List self.path = [0] # type: List
self.blocks_seen = set() # type: Set[int]
def __copy__(self): def __copy__(self):
result = DependencyAnnotation() result = DependencyAnnotation()
result.storage_loaded = copy(self.storage_loaded) result.storage_loaded = copy(self.storage_loaded)
result.storage_written = copy(self.storage_written) result.storage_written = copy(self.storage_written)
result.path = copy(self.path) result.path = copy(self.path)
result.has_call = self.has_call result.blocks_seen = copy(self.blocks_seen)
return result return result
def get_storage_write_cache(self, iteration: int): def get_storage_write_cache(self, iteration: int):
@ -44,15 +43,10 @@ class DependencyAnnotation(StateAnnotation):
return self.storage_written[iteration] return self.storage_written[iteration]
def extend_storage_write_cache(self, iteration: int, value: object): def extend_storage_write_cache(self, iteration: int, value: object):
try:
if iteration not in self.storage_written: if iteration not in self.storage_written:
self.storage_written[iteration] = [value] self.storage_written[iteration] = [value]
else: elif value not in self.storage_written[iteration]:
if value not in self.storage_written[iteration]:
self.storage_written[iteration].append(value) self.storage_written[iteration].append(value)
except Z3Exception as e:
# FIXME: This should not happen unless there's a bug in laser such as a BitVec512 being generated
log.debug("Error updating storage write cache: {}".format(e))
class WSDependencyAnnotation(StateAnnotation): class WSDependencyAnnotation(StateAnnotation):
@ -140,41 +134,37 @@ class DependencyPruner(LaserPlugin):
def _reset(self): def _reset(self):
self.iteration = 0 self.iteration = 0
self.dependency_map = {} # type: Dict[int, List[object]] self.sloads_on_path = {} # type: Dict[int, List[object]]
self.protected_addresses = set() # type: Set[int] self.sstores_on_path = {} # type: Dict[int, List[object]]
self.storage_accessed_global = set() # type: Set
def update_dependency_map(self, path: List[int], target_location: object) -> None: def update_sloads(self, path: List[int], target_location: object) -> None:
"""Update the dependency map for the block offsets on the given path. """Update the dependency map for the block offsets on the given path.
:param path :param path
:param target_location :param target_location
""" """
log.debug(
"Updating dependency map for path: {} with location: {}".format(
path, target_location
)
)
try:
for address in path: for address in path:
if address in self.dependency_map: if address in self.sloads_on_path:
if target_location not in self.dependency_map[address]: if target_location not in self.sloads_on_path[address]:
self.dependency_map[address].append(target_location) self.sloads_on_path[address].append(target_location)
else: else:
self.dependency_map[address] = [target_location] self.sloads_on_path[address] = [target_location]
except Z3Exception as e:
# FIXME: This should not happen unless there's a bug in laser such as a BitVec512 being generated
log.debug("Error updating dependency map: {}".format(e))
def protect_path(self, path: List[int]) -> None: def update_sstores(self, path: List[int], target_location: object) -> None:
"""Prevent an execution path of being pruned. """Update the dependency map for the block offsets on the given path.
:param path :param path
:param target_location
""" """
for address in path: for address in path:
self.protected_addresses.add(address) if address in self.sstores_on_path:
if target_location not in self.sstores_on_path[address]:
self.sstores_on_path[address].append(target_location)
else:
self.sstores_on_path[address] = [target_location]
def wanna_execute(self, address: int, annotation: DependencyAnnotation) -> bool: def wanna_execute(self, address: int, annotation: DependencyAnnotation) -> bool:
"""Decide whether the basic block starting at 'address' should be executed. """Decide whether the basic block starting at 'address' should be executed.
@ -185,14 +175,25 @@ class DependencyPruner(LaserPlugin):
storage_write_cache = annotation.get_storage_write_cache(self.iteration - 1) storage_write_cache = annotation.get_storage_write_cache(self.iteration - 1)
# Execute the block if it's marked as "protected" or doesn't yet have an entry in the dependency map. # Skip "pure" paths that don't have any dependencies.
if address in self.protected_addresses or address not in self.dependency_map: if address not in self.sloads_on_path:
return False
# Execute the path if there are state modifications along it that *could* be relevant
if address in self.storage_accessed_global:
for location in self.sstores_on_path:
try:
solver.get_model((location == address,))
return True return True
dependencies = self.dependency_map[address] except UnsatError:
continue
# Return if *any* dependency is found dependencies = self.sloads_on_path[address]
# Execute the path if there's any dependency on state modified in the previous transaction
for location in storage_write_cache: for location in storage_write_cache:
for dependency in dependencies: for dependency in dependencies:
@ -228,13 +229,6 @@ class DependencyPruner(LaserPlugin):
def start_sym_trans_hook(): def start_sym_trans_hook():
self.iteration += 1 self.iteration += 1
@symbolic_vm.post_hook("CALL")
def call_hook(state: GlobalState):
annotation = get_dependency_annotation(state)
annotation.has_call = True
self.protect_path(annotation.path)
@symbolic_vm.post_hook("JUMP") @symbolic_vm.post_hook("JUMP")
def jump_hook(state: GlobalState): def jump_hook(state: GlobalState):
address = state.get_current_instruction()["address"] address = state.get_current_instruction()["address"]
@ -257,9 +251,10 @@ class DependencyPruner(LaserPlugin):
def sstore_hook(state: GlobalState): def sstore_hook(state: GlobalState):
annotation = get_dependency_annotation(state) annotation = get_dependency_annotation(state)
annotation.extend_storage_write_cache( location = state.mstate.stack[-1]
self.iteration, state.mstate.stack[-1]
) self.update_sstores(annotation.path, location)
annotation.extend_storage_write_cache(self.iteration, location)
@symbolic_vm.pre_hook("SLOAD") @symbolic_vm.pre_hook("SLOAD")
def sload_hook(state: GlobalState): def sload_hook(state: GlobalState):
@ -272,7 +267,8 @@ class DependencyPruner(LaserPlugin):
# We backwards-annotate the path here as sometimes execution never reaches a stop or return # We backwards-annotate the path here as sometimes execution never reaches a stop or return
# (and this may change in a future transaction). # (and this may change in a future transaction).
self.update_dependency_map(annotation.path, location) self.update_sloads(annotation.path, location)
self.storage_accessed_global.add(location)
@symbolic_vm.pre_hook("STOP") @symbolic_vm.pre_hook("STOP")
def stop_hook(state: GlobalState): def stop_hook(state: GlobalState):
@ -291,11 +287,11 @@ class DependencyPruner(LaserPlugin):
annotation = get_dependency_annotation(state) annotation = get_dependency_annotation(state)
if annotation.has_call:
self.protect_path(annotation.path)
for index in annotation.storage_loaded: for index in annotation.storage_loaded:
self.update_dependency_map(annotation.path, index) self.update_sloads(annotation.path, index)
for index in annotation.storage_written:
self.update_sstores(annotation.path, index)
def _check_basic_block(address: int, annotation: DependencyAnnotation): def _check_basic_block(address: int, annotation: DependencyAnnotation):
"""This method is where the actual pruning happens. """This method is where the actual pruning happens.
@ -305,7 +301,12 @@ class DependencyPruner(LaserPlugin):
""" """
# Don't skip any blocks in the contract creation transaction # Don't skip any blocks in the contract creation transaction
if self.iteration < 2: if self.iteration < 1:
return
# Don't skip newly discovered blocks
if address not in annotation.blocks_seen:
annotation.blocks_seen.add(address)
return return
if self.wanna_execute(address, annotation): if self.wanna_execute(address, annotation):
@ -335,7 +336,6 @@ class DependencyPruner(LaserPlugin):
annotation.path = [0] annotation.path = [0]
annotation.storage_loaded = [] annotation.storage_loaded = []
annotation.has_call = False
world_state_annotation.annotations_stack.append(annotation) world_state_annotation.annotations_stack.append(annotation)
@ -344,7 +344,7 @@ class DependencyPruner(LaserPlugin):
self.iteration, self.iteration,
state.get_current_instruction()["address"], state.get_current_instruction()["address"],
state.node.function_name, state.node.function_name,
self.dependency_map, self.sloads_on_path,
annotation.storage_written[self.iteration], annotation.storage_written[self.iteration],
) )
) )

Loading…
Cancel
Save