Move constraints to world state (#1276)

* Add constraints to worldstate

* Remove propogation of constraints from Node

* Remove a usage of state.constraints()

* Ignore mypy error

* Add a missing file
pull/1283/head
Nikhil Parasaram 5 years ago committed by GitHub
parent 634d59caa5
commit eeeb9bf639
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      mythril/analysis/modules/dependence_on_predictable_vars.py
  2. 2
      mythril/analysis/modules/ether_thief.py
  3. 2
      mythril/analysis/modules/exceptions.py
  4. 6
      mythril/analysis/modules/external_calls.py
  5. 6
      mythril/analysis/modules/integer.py
  6. 2
      mythril/analysis/modules/multiple_sends.py
  7. 6
      mythril/analysis/modules/state_change_external_calls.py
  8. 6
      mythril/analysis/modules/suicide.py
  9. 2
      mythril/analysis/modules/unchecked_retval.py
  10. 2
      mythril/analysis/potential_issues.py
  11. 2
      mythril/laser/ethereum/call.py
  12. 22
      mythril/laser/ethereum/instructions.py
  13. 2
      mythril/laser/ethereum/plugins/implementations/mutation_pruner.py
  14. 3
      mythril/laser/ethereum/state/machine_state.py
  15. 8
      mythril/laser/ethereum/state/world_state.py
  16. 21
      mythril/laser/ethereum/svm.py
  17. 3
      mythril/laser/ethereum/transaction/concolic.py
  18. 6
      mythril/laser/ethereum/transaction/symbolic.py
  19. 2
      mythril/laser/ethereum/transaction/transaction_models.py
  20. 2
      tests/instructions/static_call_test.py

@ -102,10 +102,11 @@ class PredictableDependenceModule(DetectionModule):
if isinstance(annotation, PredictablePathAnnotation): if isinstance(annotation, PredictablePathAnnotation):
if annotation.add_constraints: if annotation.add_constraints:
constraints = ( constraints = (
state.mstate.constraints + annotation.add_constraints state.world_state.constraints
+ annotation.add_constraints
) )
else: else:
constraints = copy(state.mstate.constraints) constraints = copy(state.world_state.constraints)
try: try:
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, constraints state, constraints
@ -192,7 +193,7 @@ class PredictableDependenceModule(DetectionModule):
# Why the second constraint? Because without it Z3 returns a solution where param overflows. # Why the second constraint? Because without it Z3 returns a solution where param overflows.
solver.get_model(state.mstate.constraints + constraint) solver.get_model(state.world_state.constraints + constraint) # type: ignore
state.annotate(OldBlockNumberUsedAnnotation(constraint)) state.annotate(OldBlockNumberUsedAnnotation(constraint))
except UnsatError: except UnsatError:

@ -77,7 +77,7 @@ class EtherThief(DetectionModule):
value = state.mstate.stack[-3] value = state.mstate.stack[-3]
target = state.mstate.stack[-2] target = state.mstate.stack[-2]
constraints = copy(state.mstate.constraints) constraints = copy(state.world_state.constraints)
""" """
Require that the current transaction is sent by the attacker and Require that the current transaction is sent by the attacker and

@ -58,7 +58,7 @@ class ReachableExceptionsModule(DetectionModule):
"Use `require()` for regular input checking." "Use `require()` for regular input checking."
) )
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, state.mstate.constraints state, state.world_state.constraints
) )
issue = Issue( issue = Issue(
contract=state.environment.active_account.contract_name, contract=state.environment.active_account.contract_name,

@ -33,7 +33,7 @@ an informational issue.
def _is_precompile_call(global_state: GlobalState): def _is_precompile_call(global_state: GlobalState):
to = global_state.mstate.stack[-2] # type: BitVec to = global_state.mstate.stack[-2] # type: BitVec
constraints = copy(global_state.mstate.constraints) constraints = copy(global_state.world_state.constraints)
constraints += [ constraints += [
Or( Or(
to < symbol_factory.BitVecVal(1, 256), to < symbol_factory.BitVecVal(1, 256),
@ -88,7 +88,7 @@ class ExternalCalls(DetectionModule):
constraints = Constraints([UGT(gas, symbol_factory.BitVecVal(2300, 256))]) constraints = Constraints([UGT(gas, symbol_factory.BitVecVal(2300, 256))])
solver.get_transaction_sequence( solver.get_transaction_sequence(
state, constraints + state.mstate.constraints state, constraints + state.world_state.constraints
) )
# Check whether we can also set the callee address # Check whether we can also set the callee address
@ -101,7 +101,7 @@ class ExternalCalls(DetectionModule):
constraints.append(tx.caller == ACTORS.attacker) constraints.append(tx.caller == ACTORS.attacker)
solver.get_transaction_sequence( solver.get_transaction_sequence(
state, constraints + state.mstate.constraints state, constraints + state.world_state.constraints
) )
description_head = "A call to a user-supplied address is executed." description_head = "A call to a user-supplied address is executed."

@ -291,7 +291,9 @@ class IntegerOverflowUnderflowModule(DetectionModule):
if ostate not in self._ostates_satisfiable: if ostate not in self._ostates_satisfiable:
try: try:
constraints = ostate.mstate.constraints + [annotation.constraint] constraints = ostate.world_state.constraints + [
annotation.constraint
]
solver.get_model(constraints) solver.get_model(constraints)
self._ostates_satisfiable.add(ostate) self._ostates_satisfiable.add(ostate)
except: except:
@ -308,7 +310,7 @@ class IntegerOverflowUnderflowModule(DetectionModule):
try: try:
constraints = state.mstate.constraints + [annotation.constraint] constraints = state.world_state.constraints + [annotation.constraint]
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, constraints state, constraints
) )

@ -81,7 +81,7 @@ class MultipleSendsModule(DetectionModule):
for offset in call_offsets[1:]: for offset in call_offsets[1:]:
try: try:
transaction_sequence = get_transaction_sequence( transaction_sequence = get_transaction_sequence(
state, state.mstate.constraints state, state.world_state.constraints
) )
except UnsatError: except UnsatError:
continue continue

@ -56,7 +56,7 @@ class StateChangeCallsAnnotation(StateAnnotation):
try: try:
solver.get_transaction_sequence( solver.get_transaction_sequence(
global_state, constraints + global_state.mstate.constraints global_state, constraints + global_state.world_state.constraints
) )
except UnsatError: except UnsatError:
return None return None
@ -124,7 +124,7 @@ class StateChange(DetectionModule):
gas = global_state.mstate.stack[-1] gas = global_state.mstate.stack[-1]
to = global_state.mstate.stack[-2] to = global_state.mstate.stack[-2]
try: try:
constraints = copy(global_state.mstate.constraints) constraints = copy(global_state.world_state.constraints)
solver.get_model( solver.get_model(
constraints constraints
+ [ + [
@ -190,7 +190,7 @@ class StateChange(DetectionModule):
return value.value > 0 return value.value > 0
else: else:
constraints = copy(global_state.mstate.constraints) constraints = copy(global_state.world_state.constraints)
try: try:
solver.get_model( solver.get_model(

@ -73,7 +73,9 @@ class SuicideModule(DetectionModule):
try: try:
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, state,
state.mstate.constraints + constraints + [to == ACTORS.attacker], state.world_state.constraints
+ constraints
+ [to == ACTORS.attacker],
) )
description_tail = ( description_tail = (
"Anyone can kill this contract and withdraw its balance to an arbitrary " "Anyone can kill this contract and withdraw its balance to an arbitrary "
@ -81,7 +83,7 @@ class SuicideModule(DetectionModule):
) )
except UnsatError: except UnsatError:
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, state.mstate.constraints + constraints state, state.world_state.constraints + constraints
) )
description_tail = "Arbitrary senders can kill this contract." description_tail = "Arbitrary senders can kill this contract."
issue = Issue( issue = Issue(

@ -83,7 +83,7 @@ class UncheckedRetvalModule(DetectionModule):
for retval in retvals: for retval in retvals:
try: try:
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, state.mstate.constraints + [retval["retval"] == 0] state, state.world_state.constraints + [retval["retval"] == 0]
) )
except UnsatError: except UnsatError:
continue continue

@ -84,7 +84,7 @@ def check_potential_issues(state: GlobalState) -> None:
for potential_issue in annotation.potential_issues: for potential_issue in annotation.potential_issues:
try: try:
transaction_sequence = get_transaction_sequence( transaction_sequence = get_transaction_sequence(
state, state.mstate.constraints + potential_issue.constraints state, state.world_state.constraints + potential_issue.constraints
) )
except UnsatError: except UnsatError:
continue continue

@ -270,5 +270,5 @@ def native_call(
"retval_" + str(global_state.get_current_instruction()["address"]), 256 "retval_" + str(global_state.get_current_instruction()["address"]), 256
) )
global_state.mstate.stack.append(retval) global_state.mstate.stack.append(retval)
global_state.node.constraints.append(retval == 1) global_state.world_state.constraints.append(retval == 1)
return [global_state] return [global_state]

@ -80,7 +80,7 @@ def transfer_ether(
""" """
value = value if isinstance(value, BitVec) else symbol_factory.BitVecVal(value, 256) value = value if isinstance(value, BitVec) else symbol_factory.BitVecVal(value, 256)
global_state.mstate.constraints.append( global_state.world_state.constraints.append(
UGE(global_state.world_state.balances[sender], value) UGE(global_state.world_state.balances[sender], value)
) )
global_state.world_state.balances[receiver] += value global_state.world_state.balances[receiver] += value
@ -948,7 +948,7 @@ class Instruction:
no_of_bytes += calldata.size no_of_bytes += calldata.size
else: else:
no_of_bytes += 0x200 # space for 16 32-byte arguments no_of_bytes += 0x200 # space for 16 32-byte arguments
global_state.mstate.constraints.append( global_state.world_state.constraints.append(
global_state.environment.calldata.size == no_of_bytes global_state.environment.calldata.size == no_of_bytes
) )
@ -1010,7 +1010,7 @@ class Instruction:
result, condition = keccak_function_manager.create_keccak(data) result, condition = keccak_function_manager.create_keccak(data)
state.stack.append(result) state.stack.append(result)
state.constraints.append(condition) global_state.world_state.constraints.append(condition)
return [global_state] return [global_state]
@ -1563,7 +1563,7 @@ class Instruction:
# manually increment PC # manually increment PC
new_state.mstate.depth += 1 new_state.mstate.depth += 1
new_state.mstate.pc += 1 new_state.mstate.pc += 1
new_state.mstate.constraints.append(negated) new_state.world_state.constraints.append(negated)
states.append(new_state) states.append(new_state)
else: else:
log.debug("Pruned unreachable states.") log.debug("Pruned unreachable states.")
@ -1589,7 +1589,7 @@ class Instruction:
# manually set PC to destination # manually set PC to destination
new_state.mstate.pc = index new_state.mstate.pc = index
new_state.mstate.depth += 1 new_state.mstate.depth += 1
new_state.mstate.constraints.append(condi) new_state.world_state.constraints.append(condi)
states.append(new_state) states.append(new_state)
else: else:
log.debug("Pruned unreachable states.") log.debug("Pruned unreachable states.")
@ -1898,7 +1898,7 @@ class Instruction:
) )
if isinstance(value, BitVec): if isinstance(value, BitVec):
if value.symbolic: if value.symbolic:
global_state.mstate.constraints.append( global_state.world_state.constraints.append(
value == symbol_factory.BitVecVal(0, 256) value == symbol_factory.BitVecVal(0, 256)
) )
elif value.value > 0: elif value.value > 0:
@ -2026,7 +2026,7 @@ class Instruction:
"retval_" + str(instr["address"]), 256 "retval_" + str(instr["address"]), 256
) )
global_state.mstate.stack.append(return_value) global_state.mstate.stack.append(return_value)
global_state.mstate.constraints.append(return_value == 0) global_state.world_state.constraints.append(return_value == 0)
return [global_state] return [global_state]
try: try:
@ -2058,7 +2058,7 @@ class Instruction:
# Put return value on stack # Put return value on stack
return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256)
global_state.mstate.stack.append(return_value) global_state.mstate.stack.append(return_value)
global_state.mstate.constraints.append(return_value == 1) global_state.world_state.constraints.append(return_value == 1)
return [global_state] return [global_state]
@StateTransition() @StateTransition()
@ -2154,7 +2154,7 @@ class Instruction:
"retval_" + str(instr["address"]), 256 "retval_" + str(instr["address"]), 256
) )
global_state.mstate.stack.append(return_value) global_state.mstate.stack.append(return_value)
global_state.mstate.constraints.append(return_value == 0) global_state.world_state.constraints.append(return_value == 0)
return [global_state] return [global_state]
try: try:
@ -2186,7 +2186,7 @@ class Instruction:
# Put return value on stack # Put return value on stack
return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256)
global_state.mstate.stack.append(return_value) global_state.mstate.stack.append(return_value)
global_state.mstate.constraints.append(return_value == 1) global_state.world_state.constraints.append(return_value == 1)
return [global_state] return [global_state]
@StateTransition() @StateTransition()
@ -2318,6 +2318,6 @@ class Instruction:
# Put return value on stack # Put return value on stack
return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256)
global_state.mstate.stack.append(return_value) global_state.mstate.stack.append(return_value)
global_state.mstate.constraints.append(return_value == 1) global_state.world_state.constraints.append(return_value == 1)
return [global_state] return [global_state]

@ -45,7 +45,7 @@ class MutationPruner(LaserPlugin):
@symbolic_vm.laser_hook("add_world_state") @symbolic_vm.laser_hook("add_world_state")
def world_state_filter_hook(global_state: GlobalState): def world_state_filter_hook(global_state: GlobalState):
if And( if And(
*global_state.mstate.constraints[:] *global_state.world_state.constraints[:]
+ [ + [
global_state.environment.callvalue global_state.environment.callvalue
> symbol_factory.BitVecVal(0, 256) > symbol_factory.BitVecVal(0, 256)

@ -11,7 +11,6 @@ from mythril.laser.ethereum.evm_exceptions import (
StackUnderflowException, StackUnderflowException,
OutOfGasException, OutOfGasException,
) )
from mythril.laser.ethereum.state.constraints import Constraints
from mythril.laser.ethereum.state.memory import Memory from mythril.laser.ethereum.state.memory import Memory
@ -115,7 +114,6 @@ class MachineState:
self.gas_limit = gas_limit self.gas_limit = gas_limit
self.min_gas_used = min_gas_used # lower gas usage bound self.min_gas_used = min_gas_used # lower gas usage bound
self.max_gas_used = max_gas_used # upper gas usage bound self.max_gas_used = max_gas_used # upper gas usage bound
self.constraints = constraints or Constraints()
self.depth = depth self.depth = depth
self.prev_pc = prev_pc # holds context of current pc self.prev_pc = prev_pc # holds context of current pc
@ -216,7 +214,6 @@ class MachineState:
pc=self._pc, pc=self._pc,
stack=copy(self.stack), stack=copy(self.stack),
memory=copy(self.memory), memory=copy(self.memory),
constraints=copy(self.constraints),
depth=self.depth, depth=self.depth,
prev_pc=self.prev_pc, prev_pc=self.prev_pc,
) )

@ -8,6 +8,7 @@ from mythril.laser.smt import symbol_factory, Array, BitVec
from ethereum.utils import mk_contract_address from ethereum.utils import mk_contract_address
from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.annotation import StateAnnotation from mythril.laser.ethereum.state.annotation import StateAnnotation
from mythril.laser.ethereum.state.constraints import Constraints
if TYPE_CHECKING: if TYPE_CHECKING:
from mythril.laser.ethereum.cfg import Node from mythril.laser.ethereum.cfg import Node
@ -18,7 +19,10 @@ class WorldState:
yellow paper.""" yellow paper."""
def __init__( def __init__(
self, transaction_sequence=None, annotations: List[StateAnnotation] = None self,
transaction_sequence=None,
annotations: List[StateAnnotation] = None,
constraints: Constraints = None,
) -> None: ) -> None:
"""Constructor for the world state. Initializes the accounts record. """Constructor for the world state. Initializes the accounts record.
@ -28,6 +32,7 @@ class WorldState:
self._accounts = {} # type: Dict[int, Account] self._accounts = {} # type: Dict[int, Account]
self.balances = Array("balance", 256, 256) self.balances = Array("balance", 256, 256)
self.starting_balances = copy(self.balances) self.starting_balances = copy(self.balances)
self.constraints = constraints or Constraints()
self.node = None # type: Optional['Node'] self.node = None # type: Optional['Node']
self.transaction_sequence = transaction_sequence or [] self.transaction_sequence = transaction_sequence or []
@ -65,6 +70,7 @@ class WorldState:
for account in self._accounts.values(): for account in self._accounts.values():
new_world_state.put_account(copy(account)) new_world_state.put_account(copy(account))
new_world_state.node = self.node new_world_state.node = self.node
new_world_state.constraints = copy(self.constraints)
return new_world_state return new_world_state
def accounts_exist_or_load(self, addr: str, dynamic_loader: DynLoader) -> str: def accounts_exist_or_load(self, addr: str, dynamic_loader: DynLoader) -> str:

@ -249,7 +249,9 @@ class LaserEVM:
log.debug("Encountered unimplemented instruction") log.debug("Encountered unimplemented instruction")
continue continue
new_states = [ new_states = [
state for state in new_states if state.mstate.constraints.is_possible state
for state in new_states
if state.world_state.constraints.is_possible
] ]
self.manage_cfg(op_code, new_states) # TODO: What about op_code is None? self.manage_cfg(op_code, new_states) # TODO: What about op_code is None?
@ -345,8 +347,8 @@ class LaserEVM:
global_state.transaction_stack global_state.transaction_stack
) + [(start_signal.transaction, global_state)] ) + [(start_signal.transaction, global_state)]
new_global_state.node = global_state.node new_global_state.node = global_state.node
new_global_state.mstate.constraints = ( new_global_state.world_state.constraints = (
start_signal.global_state.mstate.constraints start_signal.global_state.world_state.constraints
) )
log.debug("Starting new transaction %s", start_signal.transaction) log.debug("Starting new transaction %s", start_signal.transaction)
@ -367,9 +369,6 @@ class LaserEVM:
) and not end_signal.revert: ) and not end_signal.revert:
check_potential_issues(global_state) check_potential_issues(global_state)
end_signal.global_state.world_state.node = global_state.node end_signal.global_state.world_state.node = global_state.node
end_signal.global_state.world_state.node.constraints += (
end_signal.global_state.mstate.constraints
)
self._add_world_state(end_signal.global_state) self._add_world_state(end_signal.global_state)
new_global_states = [] new_global_states = []
@ -417,7 +416,9 @@ class LaserEVM:
:return: :return:
""" """
return_global_state.mstate.constraints += global_state.mstate.constraints return_global_state.world_state.constraints += (
global_state.world_state.constraints
)
# Resume execution of the transaction initializing instruction # Resume execution of the transaction initializing instruction
op_code = return_global_state.environment.code.instruction_list[ op_code = return_global_state.environment.code.instruction_list[
return_global_state.mstate.pc return_global_state.mstate.pc
@ -465,12 +466,12 @@ class LaserEVM:
assert len(new_states) <= 2 assert len(new_states) <= 2
for state in new_states: for state in new_states:
self._new_node_state( self._new_node_state(
state, JumpType.CONDITIONAL, state.mstate.constraints[-1] state, JumpType.CONDITIONAL, state.world_state.constraints[-1]
) )
elif opcode in ("SLOAD", "SSTORE") and len(new_states) > 1: elif opcode in ("SLOAD", "SSTORE") and len(new_states) > 1:
for state in new_states: for state in new_states:
self._new_node_state( self._new_node_state(
state, JumpType.CONDITIONAL, state.mstate.constraints[-1] state, JumpType.CONDITIONAL, state.world_state.constraints[-1]
) )
elif opcode == "RETURN": elif opcode == "RETURN":
for state in new_states: for state in new_states:
@ -491,7 +492,7 @@ class LaserEVM:
new_node = Node(state.environment.active_account.contract_name) new_node = Node(state.environment.active_account.contract_name)
old_node = state.node old_node = state.node
state.node = new_node state.node = new_node
new_node.constraints = state.mstate.constraints new_node.constraints = state.world_state.constraints
if self.requires_statespace: if self.requires_statespace:
self.nodes[new_node.uid] = new_node self.nodes[new_node.uid] = new_node
self.edges.append( self.edges.append(

@ -88,8 +88,7 @@ def _setup_global_state_for_execution(laser_evm, transaction) -> None:
condition=None, condition=None,
) )
) )
global_state.mstate.constraints += transaction.world_state.node.constraints new_node.constraints = global_state.world_state.constraints
new_node.constraints = global_state.mstate.constraints
global_state.world_state.transaction_sequence.append(transaction) global_state.world_state.transaction_sequence.append(transaction)
global_state.node = new_node global_state.node = new_node

@ -162,7 +162,7 @@ def _setup_global_state_for_execution(laser_evm, transaction: BaseTransaction) -
global_state = transaction.initial_global_state() global_state = transaction.initial_global_state()
global_state.transaction_stack.append((transaction, None)) global_state.transaction_stack.append((transaction, None))
global_state.mstate.constraints.append( global_state.world_state.constraints.append(
Or(*[transaction.caller == actor for actor in ACTORS.addresses.values()]) Or(*[transaction.caller == actor for actor in ACTORS.addresses.values()])
) )
@ -183,9 +183,7 @@ def _setup_global_state_for_execution(laser_evm, transaction: BaseTransaction) -
condition=None, condition=None,
) )
) )
new_node.constraints = global_state.world_state.constraints
global_state.mstate.constraints += transaction.world_state.node.constraints
new_node.constraints = global_state.mstate.constraints
global_state.world_state.transaction_sequence.append(transaction) global_state.world_state.transaction_sequence.append(transaction)
global_state.node = new_node global_state.node = new_node

@ -126,7 +126,7 @@ class BaseTransaction:
else symbol_factory.BitVecVal(environment.callvalue, 256) else symbol_factory.BitVecVal(environment.callvalue, 256)
) )
global_state.mstate.constraints.append( global_state.world_state.constraints.append(
UGE(global_state.world_state.balances[sender], value) UGE(global_state.world_state.balances[sender], value)
) )
global_state.world_state.balances[receiver] += value global_state.world_state.balances[receiver] += value

@ -121,4 +121,4 @@ def test_staticness_call_symbolic(f1):
instruction.evaluate(state) instruction.evaluate(state)
assert ts.value.transaction.static assert ts.value.transaction.static
assert ts.value.global_state.mstate.constraints[-1] == (call_value == 0) assert ts.value.global_state.world_state.constraints[-1] == (call_value == 0)

Loading…
Cancel
Save