diff --git a/mythril/laser/ethereum/evm_exceptions.py b/mythril/laser/ethereum/evm_exceptions.py index fd99aed5..7e2a7b8f 100644 --- a/mythril/laser/ethereum/evm_exceptions.py +++ b/mythril/laser/ethereum/evm_exceptions.py @@ -37,6 +37,12 @@ class OutOfGasException(VmException): pass +class WriteProtection(VmException): + """A VM exception denoting that a write operation is executed on a write protected environment""" + + pass + + class ProgramCounterException(VmException): """A VM exception denoting an invalid PC value (No stop instruction is reached).""" diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 593be817..4bd8436f 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -38,6 +38,7 @@ from mythril.laser.ethereum.evm_exceptions import ( InvalidJumpDestination, InvalidInstruction, OutOfGasException, + WriteProtection, ) from mythril.laser.ethereum.gas import OPCODE_GAS from mythril.laser.ethereum.state.global_state import GlobalState @@ -88,14 +89,18 @@ class StateTransition(object): if `increment_pc=True`. """ - def __init__(self, increment_pc=True, enable_gas=True): + def __init__( + self, increment_pc=True, enable_gas=True, is_state_mutation_instruction=False + ): """ :param increment_pc: :param enable_gas: + :param is_state_mutation_instruction: The function mutates state """ self.increment_pc = increment_pc self.enable_gas = enable_gas + self.is_state_mutation_instruction = is_state_mutation_instruction @staticmethod def call_on_state_copy(func: Callable, func_obj: "Instruction", state: GlobalState): @@ -165,6 +170,13 @@ class StateTransition(object): :param global_state: :return: """ + if self.is_state_mutation_instruction and global_state.environment.static: + raise WriteProtection( + "The function {} cannot be executed in a static call".format( + func.__name__[:-1] + ) + ) + new_global_states = self.call_on_state_copy(func, func_obj, global_state) new_global_states = [ self.accumulate_gas(state) for state in new_global_states @@ -1435,7 +1447,7 @@ class Instruction: state.stack.append(global_state.environment.active_account.storage[index]) return [global_state] - @StateTransition() + @StateTransition(is_state_mutation_instruction=True) def sstore_(self, global_state: GlobalState) -> List[GlobalState]: """ @@ -1602,7 +1614,7 @@ class Instruction: global_state.mstate.stack.append(global_state.new_bitvec("gas", 256)) return [global_state] - @StateTransition() + @StateTransition(is_state_mutation_instruction=True) def log_(self, global_state: GlobalState) -> List[GlobalState]: """ @@ -1617,7 +1629,7 @@ class Instruction: # Not supported return [global_state] - @StateTransition() + @StateTransition(is_state_mutation_instruction=True) def create_(self, global_state: GlobalState) -> List[GlobalState]: """ @@ -1625,13 +1637,14 @@ class Instruction: :return: """ # TODO: implement me + state = global_state.mstate state.stack.pop(), state.stack.pop(), state.stack.pop() # Not supported state.stack.append(0) return [global_state] - @StateTransition() + @StateTransition(is_state_mutation_instruction=True) def create2_(self, global_state: GlobalState) -> List[GlobalState]: """ @@ -1667,7 +1680,7 @@ class Instruction: return_data = state.memory[offset : offset + length] global_state.current_transaction.end(global_state, return_data) - @StateTransition() + @StateTransition(is_state_mutation_instruction=True) def suicide_(self, global_state: GlobalState): """ @@ -1772,6 +1785,21 @@ class Instruction: ) return [global_state] + if environment.static: + if isinstance(value, int) and value > 0: + raise WriteProtection( + "Cannot call with non zero value in a static call" + ) + if isinstance(value, BitVec): + if value.symbolic: + global_state.mstate.constraints.append( + value == symbol_factory.BitVecVal(0, 256) + ) + elif value.value > 0: + raise WriteProtection( + "Cannot call with non zero value in a static call" + ) + native_result = native_call( global_state, callee_address, call_data, memory_out_offset, memory_out_size ) @@ -1786,8 +1814,9 @@ class Instruction: callee_account=callee_account, call_data=call_data, call_value=value, + static=environment.static, ) - raise TransactionStartSignal(transaction, self.op_code) + raise TransactionStartSignal(transaction, self.op_code, global_state) @StateTransition() def call_post(self, global_state: GlobalState) -> List[GlobalState]: @@ -1796,65 +1825,8 @@ class Instruction: :param global_state: :return: """ - instr = global_state.get_current_instruction() - - try: - callee_address, callee_account, call_data, value, gas, memory_out_offset, memory_out_size = get_call_parameters( - global_state, self.dynamic_loader, True - ) - except ValueError as e: - log.debug( - "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format( - e - ) - ) - global_state.mstate.stack.append( - global_state.new_bitvec("retval_" + str(instr["address"]), 256) - ) - return [global_state] - - if global_state.last_return_data is None: - # Put return value on stack - return_value = global_state.new_bitvec( - "retval_" + str(instr["address"]), 256 - ) - global_state.mstate.stack.append(return_value) - global_state.mstate.constraints.append(return_value == 0) - return [global_state] - - try: - memory_out_offset = ( - util.get_concrete_int(memory_out_offset) - if isinstance(memory_out_offset, Expression) - else memory_out_offset - ) - memory_out_size = ( - util.get_concrete_int(memory_out_size) - if isinstance(memory_out_size, Expression) - else memory_out_size - ) - except TypeError: - global_state.mstate.stack.append( - global_state.new_bitvec("retval_" + str(instr["address"]), 256) - ) - return [global_state] - - # Copy memory - global_state.mstate.mem_extend( - memory_out_offset, min(memory_out_size, len(global_state.last_return_data)) - ) - for i in range(min(memory_out_size, len(global_state.last_return_data))): - global_state.mstate.memory[ - i + memory_out_offset - ] = global_state.last_return_data[i] - - # Put return value on stack - return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) - global_state.mstate.stack.append(return_value) - global_state.mstate.constraints.append(return_value == 1) - - return [global_state] + return self.post_handler(global_state, function_name="call") @StateTransition() def callcode_(self, global_state: GlobalState) -> List[GlobalState]: @@ -1870,7 +1842,6 @@ class Instruction: callee_address, callee_account, call_data, value, gas, _, _ = get_call_parameters( global_state, self.dynamic_loader, True ) - if callee_account is not None and callee_account.code.bytecode == "": log.debug("The call is related to ether transfer between accounts") sender = global_state.environment.active_account.address @@ -1903,8 +1874,9 @@ class Instruction: callee_account=environment.active_account, call_data=call_data, call_value=value, + static=environment.static, ) - raise TransactionStartSignal(transaction, self.op_code) + raise TransactionStartSignal(transaction, self.op_code, global_state) @StateTransition() def callcode_post(self, global_state: GlobalState) -> List[GlobalState]: @@ -1988,7 +1960,7 @@ class Instruction: if callee_account is not None and callee_account.code.bytecode == "": log.debug("The call is related to ether transfer between accounts") - sender = environment.active_account.address + sender = global_state.environment.active_account.address receiver = callee_account.address transfer_ether(global_state, sender, receiver, value) @@ -2018,8 +1990,9 @@ class Instruction: callee_account=environment.active_account, call_data=call_data, call_value=environment.callvalue, + static=environment.static, ) - raise TransactionStartSignal(transaction, self.op_code) + raise TransactionStartSignal(transaction, self.op_code, global_state) @StateTransition() def delegatecall_post(self, global_state: GlobalState) -> List[GlobalState]: @@ -2051,6 +2024,7 @@ class Instruction: "retval_" + str(instr["address"]), 256 ) global_state.mstate.stack.append(return_value) + global_state.mstate.constraints.append(return_value == 0) return [global_state] try: @@ -2092,8 +2066,8 @@ class Instruction: :param global_state: :return: """ - # TODO: implement me instr = global_state.get_current_instruction() + environment = global_state.environment try: callee_address, callee_account, call_data, value, gas, memory_out_offset, memory_out_size = get_call_parameters( global_state, self.dynamic_loader @@ -2101,7 +2075,7 @@ class Instruction: if callee_account is not None and callee_account.code.bytecode == "": log.debug("The call is related to ether transfer between accounts") - sender = global_state.environment.active_account.address + sender = environment.active_account.address receiver = callee_account.address transfer_ether(global_state, sender, receiver, value) @@ -2124,10 +2098,82 @@ class Instruction: native_result = native_call( global_state, callee_address, call_data, memory_out_offset, memory_out_size ) + if native_result: return native_result - global_state.mstate.stack.append( - global_state.new_bitvec("retval_" + str(instr["address"]), 256) + transaction = MessageCallTransaction( + world_state=global_state.world_state, + gas_price=environment.gasprice, + gas_limit=gas, + origin=environment.origin, + code=callee_account.code, + caller=environment.address, + callee_account=environment.active_account, + call_data=call_data, + call_value=value, + static=True, ) + raise TransactionStartSignal(transaction, self.op_code, global_state) + + def staticcall_post(self, global_state: GlobalState) -> List[GlobalState]: + return self.post_handler(global_state, function_name="staticcall") + + def post_handler(self, global_state, function_name: str): + instr = global_state.get_current_instruction() + + try: + callee_address, callee_account, call_data, value, gas, memory_out_offset, memory_out_size = get_call_parameters( + global_state, self.dynamic_loader, True + ) + except ValueError as e: + log.debug( + "Could not determine required parameters for {}, putting fresh symbol on the stack. \n{}".format( + function_name, e + ) + ) + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) + ) + return [global_state] + + if global_state.last_return_data is None: + # Put return value on stack + return_value = global_state.new_bitvec( + "retval_" + str(instr["address"]), 256 + ) + global_state.mstate.stack.append(return_value) + return [global_state] + + try: + memory_out_offset = ( + util.get_concrete_int(memory_out_offset) + if isinstance(memory_out_offset, Expression) + else memory_out_offset + ) + memory_out_size = ( + util.get_concrete_int(memory_out_size) + if isinstance(memory_out_size, Expression) + else memory_out_size + ) + except TypeError: + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) + ) + return [global_state] + + # Copy memory + global_state.mstate.mem_extend( + memory_out_offset, min(memory_out_size, len(global_state.last_return_data)) + ) + for i in range(min(memory_out_size, len(global_state.last_return_data))): + global_state.mstate.memory[ + i + memory_out_offset + ] = global_state.last_return_data[i] + + # Put return value on stack + return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) + global_state.mstate.stack.append(return_value) + global_state.mstate.constraints.append(return_value == 1) + return [global_state] diff --git a/mythril/laser/ethereum/plugins/implementations/mutation_pruner.py b/mythril/laser/ethereum/plugins/implementations/mutation_pruner.py index 86755610..fa6aef58 100644 --- a/mythril/laser/ethereum/plugins/implementations/mutation_pruner.py +++ b/mythril/laser/ethereum/plugins/implementations/mutation_pruner.py @@ -15,7 +15,10 @@ class MutationAnnotation(StateAnnotation): This is the annotation used by the MutationPruner plugin to record mutations """ - pass + @property + def persist_to_world_state(self): + # This should persist among calls, and be but as a world state annotation. + return True class MutationPruner(LaserPlugin): @@ -44,10 +47,16 @@ class MutationPruner(LaserPlugin): @symbolic_vm.pre_hook("SSTORE") def sstore_mutator_hook(global_state: GlobalState): global_state.annotate(MutationAnnotation()) + assert len( + list(global_state.world_state.get_annotations(MutationAnnotation)) + ) @symbolic_vm.pre_hook("CALL") def call_mutator_hook(global_state: GlobalState): global_state.annotate(MutationAnnotation()) + assert len( + list(global_state.world_state.get_annotations(MutationAnnotation)) + ) @symbolic_vm.laser_hook("add_world_state") def world_state_filter_hook(global_state: GlobalState): @@ -63,5 +72,8 @@ class MutationPruner(LaserPlugin): global_state.current_transaction, ContractCreationTransaction ): return - if len(list(global_state.get_annotations(MutationAnnotation))) == 0: + if ( + len(list(global_state.world_state.get_annotations(MutationAnnotation))) + == 0 + ): raise PluginSkipWorldState diff --git a/mythril/laser/ethereum/state/annotation.py b/mythril/laser/ethereum/state/annotation.py index 6a321776..0f25a311 100644 --- a/mythril/laser/ethereum/state/annotation.py +++ b/mythril/laser/ethereum/state/annotation.py @@ -12,6 +12,8 @@ class StateAnnotation: traverse the state space themselves. """ + # TODO: Remove this? It seems to be used only in the MutationPruner, and + # we could simply use world state annotations if we want them to be persisted. @property def persist_to_world_state(self) -> bool: """If this function returns true then laser will also annotate the diff --git a/mythril/laser/ethereum/state/environment.py b/mythril/laser/ethereum/state/environment.py index a3300f9a..4cfb51e0 100644 --- a/mythril/laser/ethereum/state/environment.py +++ b/mythril/laser/ethereum/state/environment.py @@ -22,6 +22,7 @@ class Environment: callvalue: ExprRef, origin: ExprRef, code=None, + static=False, ) -> None: """ @@ -32,7 +33,7 @@ class Environment: :param callvalue: :param origin: :param code: - :param calldata_type: + :param static: Makes the environment static. """ # Metadata @@ -50,6 +51,7 @@ class Environment: self.gasprice = gasprice self.origin = origin self.callvalue = callvalue + self.static = static def __str__(self) -> str: """ diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index 88b362c8..8ef03531 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -324,7 +324,9 @@ class LaserEVM: global_state.transaction_stack ) + [(start_signal.transaction, global_state)] new_global_state.node = global_state.node - new_global_state.mstate.constraints = global_state.mstate.constraints + new_global_state.mstate.constraints = ( + start_signal.global_state.mstate.constraints + ) log.debug("Starting new transaction %s", start_signal.transaction) diff --git a/mythril/laser/ethereum/transaction/transaction_models.py b/mythril/laser/ethereum/transaction/transaction_models.py index 97f2f694..31ce0271 100644 --- a/mythril/laser/ethereum/transaction/transaction_models.py +++ b/mythril/laser/ethereum/transaction/transaction_models.py @@ -45,9 +45,11 @@ class TransactionStartSignal(Exception): self, transaction: Union["MessageCallTransaction", "ContractCreationTransaction"], op_code: str, + global_state: GlobalState, ) -> None: self.transaction = transaction self.op_code = op_code + self.global_state = global_state class BaseTransaction: @@ -66,6 +68,7 @@ class BaseTransaction: code=None, call_value=None, init_call_data=True, + static=False, ) -> None: assert isinstance(world_state, WorldState) self.world_state = world_state @@ -101,7 +104,7 @@ class BaseTransaction: if call_value is not None else symbol_factory.BitVecSym("callvalue{}".format(identifier), 256) ) - + self.static = static self.return_data = None # type: str def initial_global_state_from_environment(self, environment, active_function): @@ -159,6 +162,7 @@ class MessageCallTransaction(BaseTransaction): self.call_value, self.origin, code=self.code or self.callee_account.code, + static=self.static, ) return super().initial_global_state_from_environment( environment, active_function="fallback" diff --git a/tests/instructions/static_call_test.py b/tests/instructions/static_call_test.py new file mode 100644 index 00000000..b9bcc26c --- /dev/null +++ b/tests/instructions/static_call_test.py @@ -0,0 +1,124 @@ +import pytest +from mock import patch + +from mythril.disassembler.disassembly import Disassembly +from mythril.laser.smt import symbol_factory +from mythril.laser.ethereum.state.environment import Environment +from mythril.laser.ethereum.state.account import Account +from mythril.laser.ethereum.state.machine_state import MachineState +from mythril.laser.ethereum.state.global_state import GlobalState +from mythril.laser.ethereum.state.world_state import WorldState +from mythril.laser.ethereum.instructions import Instruction +from mythril.laser.ethereum.transaction.transaction_models import MessageCallTransaction +from mythril.laser.ethereum.call import SymbolicCalldata +from mythril.laser.ethereum.transaction import TransactionStartSignal + +from mythril.laser.ethereum.evm_exceptions import WriteProtection + + +def get_global_state(): + active_account = Account("0x0", code=Disassembly("60606040")) + environment = Environment( + active_account, None, SymbolicCalldata("2"), None, None, None + ) + world_state = WorldState() + world_state.put_account(active_account) + state = GlobalState(world_state, environment, None, MachineState(gas_limit=8000000)) + state.transaction_stack.append( + (MessageCallTransaction(world_state=world_state, gas_limit=8000000), None) + ) + return state + + +@patch( + "mythril.laser.ethereum.instructions.get_call_parameters", + return_value=( + "0", + Account(code=Disassembly(code="0x00"), address="0x19"), + 0, + 0, + 0, + 0, + 0, + ), +) +def test_staticcall(f1): + # Arrange + state = get_global_state() + state.mstate.stack = [10, 10, 10, 10, 10, 10, 10, 10, 0] + instruction = Instruction("staticcall", dynamic_loader=None) + + # Act and Assert + with pytest.raises(TransactionStartSignal) as ts: + instruction.evaluate(state) + assert ts.value.transaction.static + assert ts.value.transaction.initial_global_state().environment.static + + +test_data = ( + "suicide", + "create", + "create2", + "log0", + "log1", + "log2", + "log3", + "log4", + "sstore", +) + + +@pytest.mark.parametrize("input", test_data) +def test_staticness(input): + # Arrange + state = get_global_state() + state.environment.static = True + state.mstate.stack = [] + instruction = Instruction(input, dynamic_loader=None) + + # Act and Assert + with pytest.raises(WriteProtection): + instruction.evaluate(state) + + +test_data_call = ((0, True), (100, False)) + + +@pytest.mark.parametrize("input, success", test_data_call) +@patch("mythril.laser.ethereum.instructions.get_call_parameters") +def test_staticness_call_concrete(f1, input, success): + # Arrange + state = get_global_state() + state.environment.static = True + state.mstate.stack = [] + code = Disassembly(code="616263") + f1.return_value = ("0", Account(code=code, address="0x19"), 0, input, 0, 0, 0) + instruction = Instruction("call", dynamic_loader=None) + + # Act and Assert + if success: + with pytest.raises(TransactionStartSignal) as ts: + instruction.evaluate(state) + assert ts.value.transaction.static + else: + with pytest.raises(WriteProtection): + instruction.evaluate(state) + + +@patch("mythril.laser.ethereum.instructions.get_call_parameters") +def test_staticness_call_symbolic(f1): + # Arrange + state = get_global_state() + state.environment.static = True + state.mstate.stack = [] + call_value = symbol_factory.BitVecSym("x", 256) + code = Disassembly(code="616263") + f1.return_value = ("0", Account(code=code, address="0x19"), 0, call_value, 0, 0, 0) + instruction = Instruction("call", dynamic_loader=None) + + # Act and Assert + with pytest.raises(TransactionStartSignal) as ts: + instruction.evaluate(state) + + assert ts.value.transaction.static + assert ts.value.global_state.mstate.constraints[-1] == (call_value == 0)