state_merge
Nikhil Parasaram 5 years ago
commit f74f7602cd
  1. 81
      mythril/analysis/modules/arbitrary_jump.py
  2. 7
      mythril/analysis/modules/dependence_on_predictable_vars.py
  3. 137
      mythril/analysis/modules/dos.py
  4. 2
      mythril/analysis/modules/ether_thief.py
  5. 2
      mythril/analysis/modules/exceptions.py
  6. 6
      mythril/analysis/modules/external_calls.py
  7. 6
      mythril/analysis/modules/integer.py
  8. 2
      mythril/analysis/modules/multiple_sends.py
  9. 6
      mythril/analysis/modules/state_change_external_calls.py
  10. 6
      mythril/analysis/modules/suicide.py
  11. 2
      mythril/analysis/modules/unchecked_retval.py
  12. 2
      mythril/analysis/potential_issues.py
  13. 4
      mythril/analysis/symbolic.py
  14. 2
      mythril/laser/ethereum/call.py
  15. 28
      mythril/laser/ethereum/instructions.py
  16. 2
      mythril/laser/ethereum/plugins/implementations/mutation_pruner.py
  17. 69
      mythril/laser/ethereum/state/account.py
  18. 3
      mythril/laser/ethereum/state/machine_state.py
  19. 24
      mythril/laser/ethereum/state/world_state.py
  20. 21
      mythril/laser/ethereum/svm.py
  21. 3
      mythril/laser/ethereum/transaction/concolic.py
  22. 6
      mythril/laser/ethereum/transaction/symbolic.py
  23. 2
      mythril/laser/ethereum/transaction/transaction_models.py
  24. 63
      mythril/laser/smt/__init__.py
  25. 28
      mythril/laser/smt/bitvec.py
  26. 84
      mythril/laser/smt/bitvec_helper.py
  27. 297
      mythril/laser/smt/bitvecfunc.py
  28. 2
      tests/instructions/static_call_test.py
  29. 237
      tests/laser/smt/bitvecfunc_test.py

@ -0,0 +1,81 @@
"""This module contains the detection code for Arbitrary jumps."""
import logging
from mythril.analysis.solver import get_transaction_sequence, UnsatError
from mythril.analysis.modules.base import DetectionModule, Issue
from mythril.analysis.swc_data import ARBITRARY_JUMP
from mythril.laser.ethereum.state.global_state import GlobalState
log = logging.getLogger(__name__)
DESCRIPTION = """
Search for any writes to an arbitrary storage slot
"""
class ArbitraryJump(DetectionModule):
"""This module searches for JUMPs to an arbitrary instruction."""
def __init__(self):
""""""
super().__init__(
name="Jump to an arbitrary line",
swc_id=ARBITRARY_JUMP,
description=DESCRIPTION,
entrypoint="callback",
pre_hooks=["JUMP", "JUMPI"],
)
def reset_module(self):
"""
Resets the module by clearing everything
:return:
"""
super().reset_module()
def _execute(self, state: GlobalState) -> None:
"""
:param state:
:return:
"""
if state.get_current_instruction()["address"] in self.cache:
return
self.issues.extend(self._analyze_state(state))
@staticmethod
def _analyze_state(state):
"""
:param state:
:return:
"""
jump_dest = state.mstate.stack[-1]
if jump_dest.symbolic is False:
return []
# Most probably the jump destination can have multiple locations in these circumstances
try:
transaction_sequence = get_transaction_sequence(
state, state.mstate.constraints
)
except UnsatError:
return []
issue = Issue(
contract=state.environment.active_account.contract_name,
function_name=state.environment.active_function_name,
address=state.get_current_instruction()["address"],
swc_id=ARBITRARY_JUMP,
title="Jump to an arbitrary instruction",
severity="Medium",
bytecode=state.environment.code.bytecode,
description_head="The caller can jump to any point in the code.",
description_tail="This can lead to unintended consequences."
"Please avoid using low level code as much as possible",
gas_used=(state.mstate.min_gas_used, state.mstate.max_gas_used),
transaction_sequence=transaction_sequence,
)
return [issue]
detector = ArbitraryJump()

@ -102,10 +102,11 @@ class PredictableDependenceModule(DetectionModule):
if isinstance(annotation, PredictablePathAnnotation): if isinstance(annotation, PredictablePathAnnotation):
if annotation.add_constraints: if annotation.add_constraints:
constraints = ( constraints = (
state.mstate.constraints + annotation.add_constraints state.world_state.constraints
+ annotation.add_constraints
) )
else: else:
constraints = copy(state.mstate.constraints) constraints = copy(state.world_state.constraints)
try: try:
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, constraints state, constraints
@ -192,7 +193,7 @@ class PredictableDependenceModule(DetectionModule):
# Why the second constraint? Because without it Z3 returns a solution where param overflows. # Why the second constraint? Because without it Z3 returns a solution where param overflows.
solver.get_model(state.mstate.constraints + constraint) solver.get_model(state.world_state.constraints + constraint) # type: ignore
state.annotate(OldBlockNumberUsedAnnotation(constraint)) state.annotate(OldBlockNumberUsedAnnotation(constraint))
except UnsatError: except UnsatError:

@ -1,137 +0,0 @@
"""This module contains the detection code SWC-128 - DOS with block gas limit."""
import logging
from typing import Dict, cast, List
from mythril.analysis.swc_data import DOS_WITH_BLOCK_GAS_LIMIT
from mythril.analysis.report import Issue
from mythril.analysis.modules.base import DetectionModule
from mythril.analysis.solver import get_transaction_sequence, UnsatError
from mythril.analysis.analysis_args import analysis_args
from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.state.annotation import StateAnnotation
from mythril.laser.ethereum import util
from copy import copy
log = logging.getLogger(__name__)
class VisitsAnnotation(StateAnnotation):
"""State annotation that stores the addresses of state-modifying operations"""
def __init__(self) -> None:
self.loop_start = None # type: int
self.jump_targets = {} # type: Dict[int, int]
def __copy__(self):
result = VisitsAnnotation()
result.loop_start = self.loop_start
result.jump_targets = copy(self.jump_targets)
return result
class DosModule(DetectionModule):
"""This module consists of a makeshift loop detector that annotates the state with
a list of byte ranges likely to be loops. If a CALL or SSTORE detection is found in
one of the ranges it creates a low-severity issue. This is not super precise but
good enough to identify places that warrant a closer look. Checking the loop condition
would be a possible improvement.
"""
def __init__(self) -> None:
""""""
super().__init__(
name="DOS",
swc_id=DOS_WITH_BLOCK_GAS_LIMIT,
description="Check for DOS",
entrypoint="callback",
pre_hooks=["JUMP", "JUMPI", "CALL", "SSTORE"],
)
def _execute(self, state: GlobalState) -> None:
"""
:param state:
:return:
"""
issues = self._analyze_state(state)
self.issues.extend(issues)
def _analyze_state(self, state: GlobalState) -> List[Issue]:
"""
:param state: the current state
:return: returns the issues for that corresponding state
"""
opcode = state.get_current_instruction()["opcode"]
address = state.get_current_instruction()["address"]
annotations = cast(
List[VisitsAnnotation], list(state.get_annotations(VisitsAnnotation))
)
if len(annotations) == 0:
annotation = VisitsAnnotation()
state.annotate(annotation)
else:
annotation = annotations[0]
if opcode in ["JUMP", "JUMPI"]:
if annotation.loop_start is not None:
return []
try:
target = util.get_concrete_int(state.mstate.stack[-1])
except TypeError:
log.debug("Symbolic target encountered in dos module")
return []
if target in annotation.jump_targets:
annotation.jump_targets[target] += 1
else:
annotation.jump_targets[target] = 1
if annotation.jump_targets[target] > min(2, analysis_args.loop_bound - 1):
annotation.loop_start = address
elif annotation.loop_start is not None:
if opcode == "CALL":
operation = "A message call"
else:
operation = "A storage modification"
description_head = (
"Potential denial-of-service if block gas limit is reached."
)
description_tail = "{} is executed in a loop. Be aware that the transaction may fail to execute if the loop is unbounded and the necessary gas exceeds the block gas limit.".format(
operation
)
try:
transaction_sequence = get_transaction_sequence(
state, state.mstate.constraints
)
except UnsatError:
return []
issue = Issue(
contract=state.environment.active_account.contract_name,
function_name=state.environment.active_function_name,
address=annotation.loop_start,
swc_id=DOS_WITH_BLOCK_GAS_LIMIT,
bytecode=state.environment.code.bytecode,
title="Potential denial-of-service if block gas limit is reached",
severity="Low",
description_head=description_head,
description_tail=description_tail,
gas_used=(state.mstate.min_gas_used, state.mstate.max_gas_used),
transaction_sequence=transaction_sequence,
)
return [issue]
return []
detector = DosModule()

@ -77,7 +77,7 @@ class EtherThief(DetectionModule):
value = state.mstate.stack[-3] value = state.mstate.stack[-3]
target = state.mstate.stack[-2] target = state.mstate.stack[-2]
constraints = copy(state.mstate.constraints) constraints = copy(state.world_state.constraints)
""" """
Require that the current transaction is sent by the attacker and Require that the current transaction is sent by the attacker and

@ -58,7 +58,7 @@ class ReachableExceptionsModule(DetectionModule):
"Use `require()` for regular input checking." "Use `require()` for regular input checking."
) )
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, state.mstate.constraints state, state.world_state.constraints
) )
issue = Issue( issue = Issue(
contract=state.environment.active_account.contract_name, contract=state.environment.active_account.contract_name,

@ -33,7 +33,7 @@ an informational issue.
def _is_precompile_call(global_state: GlobalState): def _is_precompile_call(global_state: GlobalState):
to = global_state.mstate.stack[-2] # type: BitVec to = global_state.mstate.stack[-2] # type: BitVec
constraints = copy(global_state.mstate.constraints) constraints = copy(global_state.world_state.constraints)
constraints += [ constraints += [
Or( Or(
to < symbol_factory.BitVecVal(1, 256), to < symbol_factory.BitVecVal(1, 256),
@ -88,7 +88,7 @@ class ExternalCalls(DetectionModule):
constraints = Constraints([UGT(gas, symbol_factory.BitVecVal(2300, 256))]) constraints = Constraints([UGT(gas, symbol_factory.BitVecVal(2300, 256))])
solver.get_transaction_sequence( solver.get_transaction_sequence(
state, constraints + state.mstate.constraints state, constraints + state.world_state.constraints
) )
# Check whether we can also set the callee address # Check whether we can also set the callee address
@ -101,7 +101,7 @@ class ExternalCalls(DetectionModule):
constraints.append(tx.caller == ACTORS.attacker) constraints.append(tx.caller == ACTORS.attacker)
solver.get_transaction_sequence( solver.get_transaction_sequence(
state, constraints + state.mstate.constraints state, constraints + state.world_state.constraints
) )
description_head = "A call to a user-supplied address is executed." description_head = "A call to a user-supplied address is executed."

@ -291,7 +291,9 @@ class IntegerOverflowUnderflowModule(DetectionModule):
if ostate not in self._ostates_satisfiable: if ostate not in self._ostates_satisfiable:
try: try:
constraints = ostate.mstate.constraints + [annotation.constraint] constraints = ostate.world_state.constraints + [
annotation.constraint
]
solver.get_model(constraints) solver.get_model(constraints)
self._ostates_satisfiable.add(ostate) self._ostates_satisfiable.add(ostate)
except: except:
@ -308,7 +310,7 @@ class IntegerOverflowUnderflowModule(DetectionModule):
try: try:
constraints = state.mstate.constraints + [annotation.constraint] constraints = state.world_state.constraints + [annotation.constraint]
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, constraints state, constraints
) )

@ -81,7 +81,7 @@ class MultipleSendsModule(DetectionModule):
for offset in call_offsets[1:]: for offset in call_offsets[1:]:
try: try:
transaction_sequence = get_transaction_sequence( transaction_sequence = get_transaction_sequence(
state, state.mstate.constraints state, state.world_state.constraints
) )
except UnsatError: except UnsatError:
continue continue

@ -56,7 +56,7 @@ class StateChangeCallsAnnotation(StateAnnotation):
try: try:
solver.get_transaction_sequence( solver.get_transaction_sequence(
global_state, constraints + global_state.mstate.constraints global_state, constraints + global_state.world_state.constraints
) )
except UnsatError: except UnsatError:
return None return None
@ -124,7 +124,7 @@ class StateChange(DetectionModule):
gas = global_state.mstate.stack[-1] gas = global_state.mstate.stack[-1]
to = global_state.mstate.stack[-2] to = global_state.mstate.stack[-2]
try: try:
constraints = copy(global_state.mstate.constraints) constraints = copy(global_state.world_state.constraints)
solver.get_model( solver.get_model(
constraints constraints
+ [ + [
@ -190,7 +190,7 @@ class StateChange(DetectionModule):
return value.value > 0 return value.value > 0
else: else:
constraints = copy(global_state.mstate.constraints) constraints = copy(global_state.world_state.constraints)
try: try:
solver.get_model( solver.get_model(

@ -73,7 +73,9 @@ class SuicideModule(DetectionModule):
try: try:
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, state,
state.mstate.constraints + constraints + [to == ACTORS.attacker], state.world_state.constraints
+ constraints
+ [to == ACTORS.attacker],
) )
description_tail = ( description_tail = (
"Anyone can kill this contract and withdraw its balance to an arbitrary " "Anyone can kill this contract and withdraw its balance to an arbitrary "
@ -81,7 +83,7 @@ class SuicideModule(DetectionModule):
) )
except UnsatError: except UnsatError:
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, state.mstate.constraints + constraints state, state.world_state.constraints + constraints
) )
description_tail = "Arbitrary senders can kill this contract." description_tail = "Arbitrary senders can kill this contract."
issue = Issue( issue = Issue(

@ -83,7 +83,7 @@ class UncheckedRetvalModule(DetectionModule):
for retval in retvals: for retval in retvals:
try: try:
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, state.mstate.constraints + [retval["retval"] == 0] state, state.world_state.constraints + [retval["retval"] == 0]
) )
except UnsatError: except UnsatError:
continue continue

@ -84,7 +84,7 @@ def check_potential_issues(state: GlobalState) -> None:
for potential_issue in annotation.potential_issues: for potential_issue in annotation.potential_issues:
try: try:
transaction_sequence = get_transaction_sequence( transaction_sequence = get_transaction_sequence(
state, state.mstate.constraints + potential_issue.constraints state, state.world_state.constraints + potential_issue.constraints
) )
except UnsatError: except UnsatError:
continue continue

@ -133,8 +133,8 @@ class SymExecWrapper:
plugin_loader.load(PluginFactory.build_state_merge_plugin()) plugin_loader.load(PluginFactory.build_state_merge_plugin())
plugin_loader.load(instruction_laser_plugin) plugin_loader.load(instruction_laser_plugin)
# if not disable_dependency_pruning: if not disable_dependency_pruning:
# plugin_loader.load(PluginFactory.build_dependency_pruner_plugin()) plugin_loader.load(PluginFactory.build_dependency_pruner_plugin())
world_state = WorldState() world_state = WorldState()
for account in self.accounts.values(): for account in self.accounts.values():

@ -270,5 +270,5 @@ def native_call(
"retval_" + str(global_state.get_current_instruction()["address"]), 256 "retval_" + str(global_state.get_current_instruction()["address"]), 256
) )
global_state.mstate.stack.append(retval) global_state.mstate.stack.append(retval)
global_state.node.constraints.append(retval == 1) global_state.world_state.constraints.append(retval == 1)
return [global_state] return [global_state]

@ -80,7 +80,7 @@ def transfer_ether(
""" """
value = value if isinstance(value, BitVec) else symbol_factory.BitVecVal(value, 256) value = value if isinstance(value, BitVec) else symbol_factory.BitVecVal(value, 256)
global_state.mstate.constraints.append( global_state.world_state.constraints.append(
UGE(global_state.world_state.balances[sender], value) UGE(global_state.world_state.balances[sender], value)
) )
global_state.world_state.balances[receiver] += value global_state.world_state.balances[receiver] += value
@ -948,7 +948,7 @@ class Instruction:
no_of_bytes += calldata.size no_of_bytes += calldata.size
else: else:
no_of_bytes += 0x200 # space for 16 32-byte arguments no_of_bytes += 0x200 # space for 16 32-byte arguments
global_state.mstate.constraints.append( global_state.world_state.constraints.append(
global_state.environment.calldata.size == no_of_bytes global_state.environment.calldata.size == no_of_bytes
) )
@ -1010,7 +1010,7 @@ class Instruction:
result, condition = keccak_function_manager.create_keccak(data) result, condition = keccak_function_manager.create_keccak(data)
state.stack.append(result) state.stack.append(result)
state.constraints.append(condition) global_state.world_state.constraints.append(condition)
return [global_state] return [global_state]
@ -1563,7 +1563,7 @@ class Instruction:
# manually increment PC # manually increment PC
new_state.mstate.depth += 1 new_state.mstate.depth += 1
new_state.mstate.pc += 1 new_state.mstate.pc += 1
new_state.mstate.constraints.append(negated) new_state.world_state.constraints.append(negated)
states.append(new_state) states.append(new_state)
else: else:
log.debug("Pruned unreachable states.") log.debug("Pruned unreachable states.")
@ -1589,7 +1589,7 @@ class Instruction:
# manually set PC to destination # manually set PC to destination
new_state.mstate.pc = index new_state.mstate.pc = index
new_state.mstate.depth += 1 new_state.mstate.depth += 1
new_state.mstate.constraints.append(condi) new_state.world_state.constraints.append(condi)
states.append(new_state) states.append(new_state)
else: else:
log.debug("Pruned unreachable states.") log.debug("Pruned unreachable states.")
@ -1648,7 +1648,7 @@ class Instruction:
def _create_transaction_helper( def _create_transaction_helper(
self, global_state, call_value, mem_offset, mem_size, create2_salt=None self, global_state, call_value, mem_offset, mem_size, create2_salt=None
): ) -> List[GlobalState]:
mstate = global_state.mstate mstate = global_state.mstate
environment = global_state.environment environment = global_state.environment
world_state = global_state.world_state world_state = global_state.world_state
@ -1673,7 +1673,7 @@ class Instruction:
if len(code_raw) < 1: if len(code_raw) < 1:
global_state.mstate.stack.append(1) global_state.mstate.stack.append(1)
log.debug("No code found for trying to execute a create type instruction.") log.debug("No code found for trying to execute a create type instruction.")
return global_state return [global_state]
code_str = bytes.hex(bytes(code_raw)) code_str = bytes.hex(bytes(code_raw))
@ -1695,7 +1695,7 @@ class Instruction:
addr = hex(caller.value)[2:] addr = hex(caller.value)[2:]
addr = "0" * (40 - len(addr)) + addr addr = "0" * (40 - len(addr)) + addr
Instruction._sha3_gas_helper(global_state, len(code_str[2:] // 2)) Instruction._sha3_gas_helper(global_state, len(code_str[2:]) // 2)
contract_address = int( contract_address = int(
get_code_hash("0xff" + addr + salt + get_code_hash(code_str)[2:])[26:], get_code_hash("0xff" + addr + salt + get_code_hash(code_str)[2:])[26:],
@ -1898,7 +1898,7 @@ class Instruction:
) )
if isinstance(value, BitVec): if isinstance(value, BitVec):
if value.symbolic: if value.symbolic:
global_state.mstate.constraints.append( global_state.world_state.constraints.append(
value == symbol_factory.BitVecVal(0, 256) value == symbol_factory.BitVecVal(0, 256)
) )
elif value.value > 0: elif value.value > 0:
@ -2026,7 +2026,7 @@ class Instruction:
"retval_" + str(instr["address"]), 256 "retval_" + str(instr["address"]), 256
) )
global_state.mstate.stack.append(return_value) global_state.mstate.stack.append(return_value)
global_state.mstate.constraints.append(return_value == 0) global_state.world_state.constraints.append(return_value == 0)
return [global_state] return [global_state]
try: try:
@ -2058,7 +2058,7 @@ class Instruction:
# Put return value on stack # Put return value on stack
return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256)
global_state.mstate.stack.append(return_value) global_state.mstate.stack.append(return_value)
global_state.mstate.constraints.append(return_value == 1) global_state.world_state.constraints.append(return_value == 1)
return [global_state] return [global_state]
@StateTransition() @StateTransition()
@ -2154,7 +2154,7 @@ class Instruction:
"retval_" + str(instr["address"]), 256 "retval_" + str(instr["address"]), 256
) )
global_state.mstate.stack.append(return_value) global_state.mstate.stack.append(return_value)
global_state.mstate.constraints.append(return_value == 0) global_state.world_state.constraints.append(return_value == 0)
return [global_state] return [global_state]
try: try:
@ -2186,7 +2186,7 @@ class Instruction:
# Put return value on stack # Put return value on stack
return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256)
global_state.mstate.stack.append(return_value) global_state.mstate.stack.append(return_value)
global_state.mstate.constraints.append(return_value == 1) global_state.world_state.constraints.append(return_value == 1)
return [global_state] return [global_state]
@StateTransition() @StateTransition()
@ -2318,6 +2318,6 @@ class Instruction:
# Put return value on stack # Put return value on stack
return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256)
global_state.mstate.stack.append(return_value) global_state.mstate.stack.append(return_value)
global_state.mstate.constraints.append(return_value == 1) global_state.world_state.constraints.append(return_value == 1)
return [global_state] return [global_state]

@ -45,7 +45,7 @@ class MutationPruner(LaserPlugin):
@symbolic_vm.laser_hook("add_world_state") @symbolic_vm.laser_hook("add_world_state")
def world_state_filter_hook(global_state: GlobalState): def world_state_filter_hook(global_state: GlobalState):
if And( if And(
*global_state.mstate.constraints[:] *global_state.world_state.constraints[:]
+ [ + [
global_state.environment.callvalue global_state.environment.callvalue
> symbol_factory.BitVecVal(0, 256) > symbol_factory.BitVecVal(0, 256)

@ -4,7 +4,7 @@ This includes classes representing accounts and their storage.
""" """
import logging import logging
from copy import copy, deepcopy from copy import copy, deepcopy
from typing import Any, Dict, Union, Tuple, Set, cast from typing import Any, Dict, Union, Set
from mythril.laser.smt import ( from mythril.laser.smt import (
@ -13,8 +13,6 @@ from mythril.laser.smt import (
BitVec, BitVec,
Bool, Bool,
simplify, simplify,
BitVecFunc,
Extract,
BaseArray, BaseArray,
Concat, Concat,
If, If,
@ -25,26 +23,6 @@ from mythril.laser.smt import symbol_factory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class StorageRegion:
def __getitem__(self, item):
raise NotImplementedError
def __setitem__(self, key, value):
raise NotImplementedError
class ArrayStorageRegion(StorageRegion):
""" An ArrayStorageRegion is a storage region that leverages smt array theory to resolve expressions"""
pass
class IteStorageRegion(StorageRegion):
""" An IteStorageRegion is a storage region that uses Ite statements to implement a storage"""
pass
class Storage: class Storage:
"""Storage class represents the storage of an Account.""" """Storage class represents the storage of an Account."""
@ -64,18 +42,8 @@ class Storage:
self.storage_keys_loaded = set() # type: Set[int] self.storage_keys_loaded = set() # type: Set[int]
self.address = address self.address = address
@staticmethod
def _sanitize(input_: BitVec) -> BitVec:
if input_.size() == 512:
return input_
if input_.size() > 512:
return Extract(511, 0, input_)
else:
return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_)
def __getitem__(self, item: BitVec) -> BitVec: def __getitem__(self, item: BitVec) -> BitVec:
storage = self._standard_storage storage = self._standard_storage
sanitized_item = item
if ( if (
self.address self.address
and self.address.value != 0 and self.address.value != 0
@ -84,7 +52,7 @@ class Storage:
and (self.dynld and self.dynld.storage_loading) and (self.dynld and self.dynld.storage_loading)
): ):
try: try:
storage[sanitized_item] = symbol_factory.BitVecVal( storage[item] = symbol_factory.BitVecVal(
int( int(
self.dynld.read_storage( self.dynld.read_storage(
contract_address="0x{:040X}".format(self.address.value), contract_address="0x{:040X}".format(self.address.value),
@ -95,29 +63,14 @@ class Storage:
256, 256,
) )
self.storage_keys_loaded.add(int(item.value)) self.storage_keys_loaded.add(int(item.value))
self.printable_storage[item] = storage[sanitized_item] self.printable_storage[item] = storage[item]
except ValueError as e: except ValueError as e:
log.debug("Couldn't read storage at %s: %s", item, e) log.debug("Couldn't read storage at %s: %s", item, e)
return simplify(storage[sanitized_item]) return simplify(storage[item])
@staticmethod
def get_map_index(key: BitVec) -> BitVec:
if (
not isinstance(key, BitVecFunc)
or key.func_name != "keccak256"
or key.input_ is None
):
return None
index = Extract(255, 0, key.input_)
return simplify(index)
def _get_corresponding_storage(self, key: BitVec) -> BaseArray:
return self._standard_storage
def __setitem__(self, key, value: Any) -> None: def __setitem__(self, key, value: Any) -> None:
storage = self._get_corresponding_storage(key)
self.printable_storage[key] = value self.printable_storage[key] = value
storage[key] = value self._standard_storage[key] = value
if key.symbolic is False: if key.symbolic is False:
self.storage_keys_loaded.add(int(key.value)) self.storage_keys_loaded.add(int(key.value))
@ -135,6 +88,13 @@ class Storage:
# TODO: Do something better here # TODO: Do something better here
return str(self.printable_storage) return str(self.printable_storage)
def merge_storage(self, storage: "Storage", path_condition: Bool):
Lambda([x], If(And(lo <= x, x <= hi), y, Select(m, x)))
self._standard_storage = If(path_condition, self._standard_storage, storage._standard_storage)
class Account: class Account:
"""Account class representing ethereum accounts.""" """Account class representing ethereum accounts."""
@ -205,9 +165,8 @@ class Account:
self._balances[self.address] += balance self._balances[self.address] += balance
def merge_accounts(self, account: "Account", path_condition: Bool): def merge_accounts(self, account: "Account", path_condition: Bool):
self.nonce = If(path_condition, self.nonce, account.nonce) assert self.nonce == account.nonce
# self.storage.merge_storage(account.storage) self.storage.merge_storage(account.storage, path_condition)
## Merge Storage
@property @property
def as_dict(self) -> Dict: def as_dict(self) -> Dict:

@ -11,7 +11,6 @@ from mythril.laser.ethereum.evm_exceptions import (
StackUnderflowException, StackUnderflowException,
OutOfGasException, OutOfGasException,
) )
from mythril.laser.ethereum.state.constraints import Constraints
from mythril.laser.ethereum.state.memory import Memory from mythril.laser.ethereum.state.memory import Memory
@ -115,7 +114,6 @@ class MachineState:
self.gas_limit = gas_limit self.gas_limit = gas_limit
self.min_gas_used = min_gas_used # lower gas usage bound self.min_gas_used = min_gas_used # lower gas usage bound
self.max_gas_used = max_gas_used # upper gas usage bound self.max_gas_used = max_gas_used # upper gas usage bound
self.constraints = constraints or Constraints()
self.depth = depth self.depth = depth
self.prev_pc = prev_pc # holds context of current pc self.prev_pc = prev_pc # holds context of current pc
@ -216,7 +214,6 @@ class MachineState:
pc=self._pc, pc=self._pc,
stack=copy(self.stack), stack=copy(self.stack),
memory=copy(self.memory), memory=copy(self.memory),
constraints=copy(self.constraints),
depth=self.depth, depth=self.depth,
prev_pc=self.prev_pc, prev_pc=self.prev_pc,
) )

@ -9,6 +9,7 @@ from ethereum.utils import mk_contract_address
from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.constraints import Constraints from mythril.laser.ethereum.state.constraints import Constraints
from mythril.laser.ethereum.state.annotation import StateAnnotation from mythril.laser.ethereum.state.annotation import StateAnnotation
from mythril.laser.ethereum.state.constraints import Constraints
if TYPE_CHECKING: if TYPE_CHECKING:
from mythril.laser.ethereum.cfg import Node from mythril.laser.ethereum.cfg import Node
@ -27,14 +28,19 @@ class Balances:
balance = self.balance[item] balance = self.balance[item]
for bs, pc in self.balance_merge_list: for bs, pc in self.balance_merge_list:
balance = If(pc, bs[item], balance) balance = If(pc, bs[item], balance)
return balance
def __setitem__(self, key, value):
pass
class WorldState: class WorldState:
"""The WorldState class represents the world state as described in the """The WorldState class represents the world state as described in the
yellow paper.""" yellow paper."""
def __init__( def __init__(
self, transaction_sequence=None, annotations: List[StateAnnotation] = None self,
transaction_sequence=None,
annotations: List[StateAnnotation] = None,
constraints: Constraints = None,
) -> None: ) -> None:
"""Constructor for the world state. Initializes the accounts record. """Constructor for the world state. Initializes the accounts record.
@ -44,17 +50,22 @@ class WorldState:
self._accounts = {} # type: Dict[int, Account] self._accounts = {} # type: Dict[int, Account]
self.balances = Array("balance", 256, 256) self.balances = Array("balance", 256, 256)
self.starting_balances = copy(self.balances) self.starting_balances = copy(self.balances)
self.constraints = constraints or Constraints()
self.node = None # type: Optional['Node'] self.node = None # type: Optional['Node']
self.transaction_sequence = transaction_sequence or [] self.transaction_sequence = transaction_sequence or []
self._annotations = annotations or [] self._annotations = annotations or []
def merge_states(self, state: "WorldState"): def merge_states(self, state: "WorldState"):
# combine annotations
self._annotations += state._annotations self._annotations += state._annotations
c1 = self.node.constraints.compress()
c2 = state.node.constraints.compress()
self.node.constraints = Constraints([Or(c1, c2)])
# Merge constraints
c1 = self.constraints.compress()
c2 = state.constraints.compress()
self.constraints = Constraints([Or(c1, c2)])
# Merge accounts
for address, account in state.accounts.items(): for address, account in state.accounts.items():
if address not in self._accounts: if address not in self._accounts:
self.put_account(account) self.put_account(account)
@ -63,6 +74,8 @@ class WorldState:
## Merge balances ## Merge balances
@property @property
def accounts(self): def accounts(self):
return self._accounts return self._accounts
@ -95,6 +108,7 @@ class WorldState:
for account in self._accounts.values(): for account in self._accounts.values():
new_world_state.put_account(copy(account)) new_world_state.put_account(copy(account))
new_world_state.node = self.node new_world_state.node = self.node
new_world_state.constraints = copy(self.constraints)
return new_world_state return new_world_state
def accounts_exist_or_load(self, addr: str, dynamic_loader: DynLoader) -> str: def accounts_exist_or_load(self, addr: str, dynamic_loader: DynLoader) -> str:

@ -249,7 +249,9 @@ class LaserEVM:
log.debug("Encountered unimplemented instruction") log.debug("Encountered unimplemented instruction")
continue continue
new_states = [ new_states = [
state for state in new_states if state.mstate.constraints.is_possible state
for state in new_states
if state.world_state.constraints.is_possible
] ]
self.manage_cfg(op_code, new_states) # TODO: What about op_code is None? self.manage_cfg(op_code, new_states) # TODO: What about op_code is None?
@ -345,8 +347,8 @@ class LaserEVM:
global_state.transaction_stack global_state.transaction_stack
) + [(start_signal.transaction, global_state)] ) + [(start_signal.transaction, global_state)]
new_global_state.node = global_state.node new_global_state.node = global_state.node
new_global_state.mstate.constraints = ( new_global_state.world_state.constraints = (
start_signal.global_state.mstate.constraints start_signal.global_state.world_state.constraints
) )
log.debug("Starting new transaction %s", start_signal.transaction) log.debug("Starting new transaction %s", start_signal.transaction)
@ -367,9 +369,6 @@ class LaserEVM:
) and not end_signal.revert: ) and not end_signal.revert:
check_potential_issues(global_state) check_potential_issues(global_state)
end_signal.global_state.world_state.node = global_state.node end_signal.global_state.world_state.node = global_state.node
end_signal.global_state.world_state.node.constraints += (
end_signal.global_state.mstate.constraints
)
self._add_world_state(end_signal.global_state) self._add_world_state(end_signal.global_state)
new_global_states = [] new_global_states = []
@ -417,7 +416,9 @@ class LaserEVM:
:return: :return:
""" """
return_global_state.mstate.constraints += global_state.mstate.constraints return_global_state.world_state.constraints += (
global_state.world_state.constraints
)
# Resume execution of the transaction initializing instruction # Resume execution of the transaction initializing instruction
op_code = return_global_state.environment.code.instruction_list[ op_code = return_global_state.environment.code.instruction_list[
return_global_state.mstate.pc return_global_state.mstate.pc
@ -465,12 +466,12 @@ class LaserEVM:
assert len(new_states) <= 2 assert len(new_states) <= 2
for state in new_states: for state in new_states:
self._new_node_state( self._new_node_state(
state, JumpType.CONDITIONAL, state.mstate.constraints[-1] state, JumpType.CONDITIONAL, state.world_state.constraints[-1]
) )
elif opcode in ("SLOAD", "SSTORE") and len(new_states) > 1: elif opcode in ("SLOAD", "SSTORE") and len(new_states) > 1:
for state in new_states: for state in new_states:
self._new_node_state( self._new_node_state(
state, JumpType.CONDITIONAL, state.mstate.constraints[-1] state, JumpType.CONDITIONAL, state.world_state.constraints[-1]
) )
elif opcode == "RETURN": elif opcode == "RETURN":
for state in new_states: for state in new_states:
@ -491,7 +492,7 @@ class LaserEVM:
new_node = Node(state.environment.active_account.contract_name) new_node = Node(state.environment.active_account.contract_name)
old_node = state.node old_node = state.node
state.node = new_node state.node = new_node
new_node.constraints = state.mstate.constraints new_node.constraints = state.world_state.constraints
if self.requires_statespace: if self.requires_statespace:
self.nodes[new_node.uid] = new_node self.nodes[new_node.uid] = new_node
self.edges.append( self.edges.append(

@ -88,8 +88,7 @@ def _setup_global_state_for_execution(laser_evm, transaction) -> None:
condition=None, condition=None,
) )
) )
global_state.mstate.constraints += transaction.world_state.node.constraints new_node.constraints = global_state.world_state.constraints
new_node.constraints = global_state.mstate.constraints
global_state.world_state.transaction_sequence.append(transaction) global_state.world_state.transaction_sequence.append(transaction)
global_state.node = new_node global_state.node = new_node

@ -162,7 +162,7 @@ def _setup_global_state_for_execution(laser_evm, transaction: BaseTransaction) -
global_state = transaction.initial_global_state() global_state = transaction.initial_global_state()
global_state.transaction_stack.append((transaction, None)) global_state.transaction_stack.append((transaction, None))
global_state.mstate.constraints.append( global_state.world_state.constraints.append(
Or(*[transaction.caller == actor for actor in ACTORS.addresses.values()]) Or(*[transaction.caller == actor for actor in ACTORS.addresses.values()])
) )
@ -183,9 +183,7 @@ def _setup_global_state_for_execution(laser_evm, transaction: BaseTransaction) -
condition=None, condition=None,
) )
) )
new_node.constraints = global_state.world_state.constraints
global_state.mstate.constraints += transaction.world_state.node.constraints
new_node.constraints = global_state.mstate.constraints
global_state.world_state.transaction_sequence.append(transaction) global_state.world_state.transaction_sequence.append(transaction)
global_state.node = new_node global_state.node = new_node

@ -126,7 +126,7 @@ class BaseTransaction:
else symbol_factory.BitVecVal(environment.callvalue, 256) else symbol_factory.BitVecVal(environment.callvalue, 256)
) )
global_state.mstate.constraints.append( global_state.world_state.constraints.append(
UGE(global_state.world_state.balances[sender], value) UGE(global_state.world_state.balances[sender], value)
) )
global_state.world_state.balances[receiver] += value global_state.world_state.balances[receiver] += value

@ -18,7 +18,6 @@ from mythril.laser.smt.bitvec_helper import (
LShR, LShR,
) )
from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.expression import Expression, simplify
from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And
from mythril.laser.smt.array import K, Array, BaseArray from mythril.laser.smt.array import K, Array, BaseArray
@ -80,44 +79,6 @@ class SymbolFactory(Generic[T, U]):
""" """
raise NotImplementedError() raise NotImplementedError()
@staticmethod
def BitVecFuncVal(
value: int,
func_name: str,
size: int,
annotations: Annotations = None,
input_: "BitVec" = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value.
:param value: The concrete value to set the bit vector to
:param func_name: The name of the bit vector function
:param size: The size of the bit vector
:param annotations: The annotations to initialize the bit vector with
:param input_: The input to the bit vector function
:return: The freshly created bit vector function
"""
raise NotImplementedError()
@staticmethod
def BitVecFuncSym(
name: str,
func_name: str,
size: int,
annotations: Annotations = None,
input_: "BitVec" = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value.
:param name: The name of the symbolic bit vector
:param func_name: The name of the bit vector function
:param size: The size of the bit vector
:param annotations: The annotations to initialize the bit vector with
:param input_: The input to the bit vector function
:return: The freshly created bit vector function
"""
raise NotImplementedError()
class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]): class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]):
""" """
@ -159,30 +120,6 @@ class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]):
raw = z3.BitVec(name, size) raw = z3.BitVec(name, size)
return BitVec(raw, annotations) return BitVec(raw, annotations)
@staticmethod
def BitVecFuncVal(
value: int,
func_name: str,
size: int,
annotations: Annotations = None,
input_: "BitVec" = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a concrete value."""
raw = z3.BitVecVal(value, size)
return BitVecFunc(raw, func_name, input_, annotations)
@staticmethod
def BitVecFuncSym(
name: str,
func_name: str,
size: int,
annotations: Annotations = None,
input_: "BitVec" = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value."""
raw = z3.BitVec(name, size)
return BitVecFunc(raw, func_name, input_, annotations)
class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]): class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]):
""" """

@ -66,8 +66,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other + self
if isinstance(other, int): if isinstance(other, int):
return BitVec(self.raw + other, annotations=self.annotations) return BitVec(self.raw + other, annotations=self.annotations)
@ -80,8 +78,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other - self
if isinstance(other, int): if isinstance(other, int):
return BitVec(self.raw - other, annotations=self.annotations) return BitVec(self.raw - other, annotations=self.annotations)
@ -94,8 +90,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other * self
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
return BitVec(self.raw * other.raw, annotations=union) return BitVec(self.raw * other.raw, annotations=union)
@ -105,8 +99,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other / self
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
return BitVec(self.raw / other.raw, annotations=union) return BitVec(self.raw / other.raw, annotations=union)
@ -116,8 +108,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other & self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -129,8 +119,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other | self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -142,8 +130,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other ^ self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -155,8 +141,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other > self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -168,8 +152,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other < self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -204,8 +186,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other == self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
return Bool( return Bool(
cast(z3.BoolRef, self.raw == other), annotations=self.annotations cast(z3.BoolRef, self.raw == other), annotations=self.annotations
@ -224,8 +204,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other != self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
return Bool( return Bool(
cast(z3.BoolRef, self.raw != other), annotations=self.annotations cast(z3.BoolRef, self.raw != other), annotations=self.annotations
@ -244,8 +222,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param operator: The shift operator :param operator: The shift operator
:return: the resulting output :return: the resulting output
""" """
if isinstance(other, BitVecFunc):
return operator(other, self)
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
return BitVec( return BitVec(
operator(self.raw, other), annotations=self.annotations operator(self.raw, other), annotations=self.annotations
@ -275,7 +251,3 @@ class BitVec(Expression[z3.BitVecRef]):
:return: :return:
""" """
return self.raw.__hash__() return self.raw.__hash__()
# TODO: Fix circular import issues
from mythril.laser.smt.bitvecfunc import BitVecFunc

@ -3,31 +3,18 @@ import z3
from mythril.laser.smt.bool import Bool, Or from mythril.laser.smt.bool import Bool, Or
from mythril.laser.smt.bitvec import BitVec from mythril.laser.smt.bitvec import BitVec
from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper
from mythril.laser.smt.bitvecfunc import _comparison_helper as _func_comparison_helper
Annotations = Set[Any] Annotations = Set[Any]
def _comparison_helper( def _comparison_helper(a: BitVec, b: BitVec, operation: Callable) -> Bool:
a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool
) -> Bool:
annotations = a.annotations.union(b.annotations) annotations = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc):
return _func_comparison_helper(a, b, operation, default_value, inputs_equal)
return Bool(operation(a.raw, b.raw), annotations) return Bool(operation(a.raw, b.raw), annotations)
def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec: def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec:
raw = operation(a.raw, b.raw) raw = operation(a.raw, b.raw)
union = a.annotations.union(b.annotations) union = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc):
return _func_arithmetic_helper(a, b, operation)
elif isinstance(b, BitVecFunc):
return _func_arithmetic_helper(b, a, operation)
return BitVec(raw, annotations=union) return BitVec(raw, annotations=union)
@ -43,8 +30,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi
:param c: :param c:
:return: :return:
""" """
# TODO: Handle BitVecFunc
if not isinstance(a, Bool): if not isinstance(a, Bool):
a = Bool(z3.BoolVal(a)) a = Bool(z3.BoolVal(a))
if not isinstance(b, BitVec): if not isinstance(b, BitVec):
@ -52,19 +37,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi
if not isinstance(c, BitVec): if not isinstance(c, BitVec):
c = BitVec(z3.BitVecVal(c, 256)) c = BitVec(z3.BitVecVal(c, 256))
union = a.annotations.union(b.annotations).union(c.annotations) union = a.annotations.union(b.annotations).union(c.annotations)
bvf = [] # type: List[BitVecFunc]
if isinstance(a, BitVecFunc):
bvf += [a]
if isinstance(b, BitVecFunc):
bvf += [b]
if isinstance(c, BitVecFunc):
bvf += [c]
if bvf:
raw = z3.If(a.raw, b.raw, c.raw)
nested_functions = [nf for func in bvf for nf in func.nested_functions] + bvf
return BitVecFunc(raw, func_name="Hybrid", nested_functions=nested_functions)
return BitVec(z3.If(a.raw, b.raw, c.raw), union) return BitVec(z3.If(a.raw, b.raw, c.raw), union)
@ -75,7 +47,7 @@ def UGT(a: BitVec, b: BitVec) -> Bool:
:param b: :param b:
:return: :return:
""" """
return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False) return _comparison_helper(a, b, z3.UGT)
def UGE(a: BitVec, b: BitVec) -> Bool: def UGE(a: BitVec, b: BitVec) -> Bool:
@ -95,7 +67,7 @@ def ULT(a: BitVec, b: BitVec) -> Bool:
:param b: :param b:
:return: :return:
""" """
return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False) return _comparison_helper(a, b, z3.ULT)
def ULE(a: BitVec, b: BitVec) -> Bool: def ULE(a: BitVec, b: BitVec) -> Bool:
@ -133,21 +105,8 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
nraw = z3.Concat([a.raw for a in bvs]) nraw = z3.Concat([a.raw for a in bvs])
annotations = set() # type: Annotations annotations = set() # type: Annotations
nested_functions = [] # type: List[BitVecFunc]
for bv in bvs: for bv in bvs:
annotations = annotations.union(bv.annotations) annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
nested_functions += bv.nested_functions
nested_functions += [bv]
if nested_functions:
return BitVecFunc(
raw=nraw,
func_name="Hybrid",
input_=BitVec(z3.BitVec("", 256), annotations=annotations),
nested_functions=nested_functions,
)
return BitVec(nraw, annotations) return BitVec(nraw, annotations)
@ -160,16 +119,6 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
:return: :return:
""" """
raw = z3.Extract(high, low, bv.raw) raw = z3.Extract(high, low, bv.raw)
if isinstance(bv, BitVecFunc):
input_string = ""
# Is there a better value to set func_name and input to in this case?
return BitVecFunc(
raw=raw,
func_name="Hybrid",
input_=BitVec(z3.BitVec(input_string, 256), annotations=bv.annotations),
nested_functions=bv.nested_functions + [bv],
)
return BitVec(raw, annotations=bv.annotations) return BitVec(raw, annotations=bv.annotations)
@ -210,34 +159,9 @@ def Sum(*args: BitVec) -> BitVec:
""" """
raw = z3.Sum([a.raw for a in args]) raw = z3.Sum([a.raw for a in args])
annotations = set() # type: Annotations annotations = set() # type: Annotations
bitvecfuncs = []
for bv in args: for bv in args:
annotations = annotations.union(bv.annotations) annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
bitvecfuncs.append(bv)
nested_functions = [
nf for func in bitvecfuncs for nf in func.nested_functions
] + bitvecfuncs
if len(bitvecfuncs) >= 2:
return BitVecFunc(
raw=raw,
func_name="Hybrid",
input_=None,
annotations=annotations,
nested_functions=nested_functions,
)
elif len(bitvecfuncs) == 1:
return BitVecFunc(
raw=raw,
func_name=bitvecfuncs[0].func_name,
input_=bitvecfuncs[0].input_,
annotations=annotations,
nested_functions=nested_functions,
)
return BitVec(raw, annotations) return BitVec(raw, annotations)
@ -288,3 +212,5 @@ def BVSubNoUnderflow(
b = BitVec(z3.BitVecVal(b, 256)) b = BitVec(z3.BitVecVal(b, 256))
return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed)) return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed))
def Lambda(var: BitVec, )

@ -1,297 +0,0 @@
import operator
from itertools import product
from typing import Optional, Union, cast, Callable, List
import z3
from mythril.laser.smt.bitvec import BitVec, Annotations, _padded_operation
from mythril.laser.smt.bool import Or, Bool, And
def _arithmetic_helper(
a: "BitVecFunc", b: Union[BitVec, int], operation: Callable
) -> "BitVecFunc":
"""
Helper function for arithmetic operations on BitVecFuncs.
:param a: The BitVecFunc to perform the operation on.
:param b: A BitVec or int to perform the operation on.
:param operation: The arithmetic operation to perform.
:return: The resulting BitVecFunc
"""
if isinstance(b, int):
b = BitVec(z3.BitVecVal(b, a.size()))
raw = operation(a.raw, b.raw)
union = a.annotations.union(b.annotations)
if isinstance(b, BitVecFunc):
return BitVecFunc(
raw=raw,
func_name="Hybrid",
input_=BitVec(z3.BitVec("", 256), annotations=union),
nested_functions=a.nested_functions + b.nested_functions + [a, b],
)
return BitVecFunc(
raw=raw,
func_name=a.func_name,
input_=a.input_,
annotations=union,
nested_functions=a.nested_functions + [a],
)
def _comparison_helper(
a: "BitVecFunc",
b: Union[BitVec, int],
operation: Callable,
default_value: bool,
inputs_equal: bool,
) -> Bool:
"""
Helper function for comparison operations with BitVecFuncs.
:param a: The BitVecFunc to compare.
:param b: A BitVec or int to compare to.
:param operation: The comparison operation to perform.
:return: The resulting Bool
"""
# Is there some hack for gt/lt comparisons?
if isinstance(b, int):
b = BitVec(z3.BitVecVal(b, a.size()))
union = a.annotations.union(b.annotations)
if not a.symbolic and not b.symbolic:
if operation == z3.UGT:
operation = operator.gt
if operation == z3.ULT:
operation = operator.lt
return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union)
if (
not isinstance(b, BitVecFunc)
or not a.func_name
or not a.input_
or not a.func_name == b.func_name
or str(operation) not in ("<built-in function eq>", "<built-in function ne>")
):
return Bool(z3.BoolVal(default_value), annotations=union)
condition = True
for a_nest, b_nest in product(a.nested_functions, b.nested_functions):
if a_nest.func_name != b_nest.func_name:
continue
if a_nest.func_name == "Hybrid":
continue
# a.input (eq/neq) b.input ==> a == b
if inputs_equal:
condition = z3.And(
condition,
z3.Or(
z3.Not((a_nest.input_ == b_nest.input_).raw),
(a_nest.raw == b_nest.raw),
),
z3.Or(
z3.Not((a_nest.raw == b_nest.raw)),
(a_nest.input_ == b_nest.input_).raw,
),
)
else:
condition = z3.And(
condition,
z3.Or(
z3.Not((a_nest.input_ != b_nest.input_).raw),
(a_nest.raw == b_nest.raw),
),
z3.Or(
z3.Not((a_nest.raw == b_nest.raw)),
(a_nest.input_ != b_nest.input_).raw,
),
)
return And(
Bool(
cast(z3.BoolRef, _padded_operation(a.raw, b.raw, operation)),
annotations=union,
),
Bool(condition) if b.nested_functions else Bool(True),
a.input_ == b.input_ if inputs_equal else a.input_ != b.input_,
)
class BitVecFunc(BitVec):
"""A bit vector function symbol. Used in place of functions like sha3."""
def __init__(
self,
raw: z3.BitVecRef,
func_name: Optional[str],
input_: "BitVec" = None,
annotations: Optional[Annotations] = None,
nested_functions: Optional[List["BitVecFunc"]] = None,
):
"""
:param raw: The raw bit vector symbol
:param func_name: The function name. e.g. sha3
:param input: The input to the functions
:param annotations: The annotations the BitVecFunc should start with
"""
self.func_name = func_name
self.input_ = input_
self.nested_functions = nested_functions or []
self.nested_functions = list(dict.fromkeys(self.nested_functions))
if isinstance(input_, BitVecFunc):
self.nested_functions.extend(input_.nested_functions)
super().__init__(raw, annotations)
def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an addition expression.
:param other: The int or BitVec to add to this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.add)
def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create a subtraction expression.
:param other: The int or BitVec to subtract from this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.sub)
def __mul__(self, other: "BitVec") -> "BitVecFunc":
"""Create a multiplication expression.
:param other: The int or BitVec to multiply to this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.mul)
def __truediv__(self, other: "BitVec") -> "BitVecFunc":
"""Create a signed division expression.
:param other: The int or BitVec to divide this BitVecFunc by
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.truediv)
def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an and expression.
:param other: The int or BitVec to and with this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.and_)
def __or__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an or expression.
:param other: The int or BitVec to or with this BitVecFunc
:return: The resulting BitVecFunc
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _arithmetic_helper(self, other, operator.or_)
def __xor__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create a xor expression.
:param other: The int or BitVec to xor with this BitVecFunc
:return: The resulting BitVecFunc
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _arithmetic_helper(self, other, operator.xor)
def __lt__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed less than expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.lt, default_value=False, inputs_equal=False
)
def __gt__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed greater than expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.gt, default_value=False, inputs_equal=False
)
def __le__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed less than or equal to expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return Or(self < other, self == other)
def __ge__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed greater than or equal to expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return Or(self > other, self == other)
# MYPY: fix complains about overriding __eq__
def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
"""Create an equality expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.eq, default_value=False, inputs_equal=True
)
# MYPY: fix complains about overriding __ne__
def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
"""Create an inequality expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.ne, default_value=True, inputs_equal=False
)
def __lshift__(self, other: Union[int, "BitVec"]) -> "BitVec":
"""
Left shift operation
:param other: The int or BitVec to shift on
:return The resulting left shifted output
"""
return _arithmetic_helper(self, other, operator.lshift)
def __rshift__(self, other: Union[int, "BitVec"]) -> "BitVec":
"""
Right shift operation
:param other: The int or BitVec to shift on
:return The resulting right shifted output:
"""
return _arithmetic_helper(self, other, operator.rshift)
def __hash__(self) -> int:
return self.raw.__hash__()

@ -121,4 +121,4 @@ def test_staticness_call_symbolic(f1):
instruction.evaluate(state) instruction.evaluate(state)
assert ts.value.transaction.static assert ts.value.transaction.static
assert ts.value.global_state.mstate.constraints[-1] == (call_value == 0) assert ts.value.global_state.world_state.constraints[-1] == (call_value == 0)

@ -1,237 +0,0 @@
from mythril.laser.smt import Solver, symbol_factory, UGT, UGE, ULT, ULE
import z3
import pytest
import operator
@pytest.mark.parametrize(
"operation,expected",
[
(operator.add, z3.unsat),
(operator.sub, z3.unsat),
(operator.and_, z3.sat),
(operator.or_, z3.sat),
(operator.xor, z3.unsat),
],
)
def test_bitvecfunc_arithmetic(operation, expected):
# Arrange
s = Solver()
input_ = symbol_factory.BitVecVal(1, 8)
bvf = symbol_factory.BitVecFuncSym("bvf", "sha3", 256, input_=input_)
x = symbol_factory.BitVecSym("x", 256)
y = symbol_factory.BitVecSym("y", 256)
# Act
s.add(x != y)
s.add(operation(bvf, x) == operation(y, bvf))
# Assert
assert s.check() == expected
@pytest.mark.parametrize(
"operation,expected",
[
(operator.eq, z3.sat),
(operator.ne, z3.unsat),
(operator.lt, z3.unsat),
(operator.le, z3.sat),
(operator.gt, z3.unsat),
(operator.ge, z3.sat),
(UGT, z3.unsat),
(UGE, z3.sat),
(ULT, z3.unsat),
(ULE, z3.sat),
],
)
def test_bitvecfunc_bitvecfunc_comparison(operation, expected):
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=input2)
# Act
s.add(operation(bvf1, bvf2))
s.add(input1 == input2)
# Assert
assert s.check() == expected
def test_bitvecfunc_bitvecfuncval_comparison():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecVal(1337, 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncVal(12345678910, "sha3", 256, input_=input2)
# Act
s.add(bvf1 == bvf2)
# Assert
assert s.check() == z3.sat
assert s.model().eval(input2.raw) == 1337
def test_bitvecfunc_nested_comparison():
# arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3)
# Act
s.add(input1 == input2)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.sat
def test_bitvecfunc_unequal_nested_comparison():
# arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3)
# Act
s.add(input1 != input2)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.unsat
def test_bitvecfunc_ext_nested_comparison():
# arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
input3 = symbol_factory.BitVecSym("input3", 256)
input4 = symbol_factory.BitVecSym("input4", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4)
# Act
s.add(input1 == input2)
s.add(input3 == input4)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.sat
def test_bitvecfunc_ext_unequal_nested_comparison():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
input3 = symbol_factory.BitVecSym("input3", 256)
input4 = symbol_factory.BitVecSym("input4", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4)
# Act
s.add(input1 == input2)
s.add(input3 != input4)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.unsat
def test_bitvecfunc_ext_unequal_nested_comparison_f():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
input3 = symbol_factory.BitVecSym("input3", 256)
input4 = symbol_factory.BitVecSym("input4", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4)
# Act
s.add(input1 != input2)
s.add(input3 == input4)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.unsat
def test_bitvecfunc_find_input():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
# Act
s.add(input1 == symbol_factory.BitVecVal(1, 256))
s.add(bvf1 == bvf2)
# Assert
assert s.check() == z3.sat
assert s.model()[input2.raw] == 1
def test_bitvecfunc_nested_find_input():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3)
# Act
s.add(input1 == symbol_factory.BitVecVal(123, 256))
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.sat
assert s.model()[input2.raw] == 123
Loading…
Cancel
Save