diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index c2c01b4d..d1abbf35 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -1003,7 +1003,14 @@ class Instruction: @StateTransition() def revert_(self, global_state): - return [] + state = global_state.mstate + offset, length = state.stack.pop(), state.stack.pop() + return_data = [global_state.new_bitvec("return_data", 256)] + try: + return_data = state.memory[util.get_concrete_int(offset):util.get_concrete_int(offset + length)] + except AttributeError: + logging.debug("Return with symbolic length or offset. Not supported") + global_state.current_transaction.end(global_state, return_data=return_data, revert=True) @StateTransition() def assert_fail_(self, global_state): diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index 6c48dabe..3a173515 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -149,30 +149,31 @@ class LaserEVM: new_global_states = self._end_message_call(return_global_state, global_state, revert_changes=True, return_data=None) - except TransactionStartSignal as e: + except TransactionStartSignal as start_signal: # Setup new global state - new_global_state = e.transaction.initial_global_state() + new_global_state = start_signal.transaction.initial_global_state() - new_global_state.transaction_stack = copy(global_state.transaction_stack) + [(e.transaction, global_state)] + new_global_state.transaction_stack = copy(global_state.transaction_stack) + [(start_signal.transaction, global_state)] new_global_state.node = global_state.node new_global_state.mstate.constraints = global_state.mstate.constraints return [new_global_state], op_code - except TransactionEndSignal as e: - transaction, return_global_state = e.global_state.transaction_stack.pop() + except TransactionEndSignal as end_signal: + transaction, return_global_state = end_signal.global_state.transaction_stack.pop() if return_global_state is None: - if not isinstance(transaction, ContractCreationTransaction) or transaction.return_data: - e.global_state.world_state.node = global_state.node - self.open_states.append(e.global_state.world_state) + if (not isinstance(transaction, ContractCreationTransaction) or transaction.return_data) and not end_signal.revert: + end_signal.global_state.world_state.node = global_state.node + self.open_states.append(end_signal.global_state.world_state) new_global_states = [] else: # First execute the post hook for the transaction ending instruction - self._execute_post_hook(op_code, [e.global_state]) + self._execute_post_hook(op_code, [end_signal.global_state]) new_global_states = self._end_message_call(return_global_state, global_state, - revert_changes=False, return_data=transaction.return_data) + revert_changes=False or end_signal.revert, + return_data=transaction.return_data) self._execute_post_hook(op_code, new_global_states) diff --git a/mythril/laser/ethereum/transaction/transaction_models.py b/mythril/laser/ethereum/transaction/transaction_models.py index fa60599d..35826bcd 100644 --- a/mythril/laser/ethereum/transaction/transaction_models.py +++ b/mythril/laser/ethereum/transaction/transaction_models.py @@ -12,10 +12,12 @@ def get_next_transaction_id(): _next_transaction_id += 1 return _next_transaction_id + class TransactionEndSignal(Exception): """ Exception raised when a transaction is finalized""" - def __init__(self, global_state): + def __init__(self, global_state, revert=False): self.global_state = global_state + self.revert = revert class TransactionStartSignal(Exception): @@ -70,9 +72,9 @@ class MessageCallTransaction: return global_state - def end(self, global_state, return_data=None): + def end(self, global_state, return_data=None, revert=False): self.return_data = return_data - raise TransactionEndSignal(global_state) + raise TransactionEndSignal(global_state, revert) class ContractCreationTransaction: @@ -125,7 +127,7 @@ class ContractCreationTransaction: return global_state - def end(self, global_state, return_data=None): + def end(self, global_state, return_data=None, revert=False): if not all([isinstance(element, int) for element in return_data]): self.return_data = None @@ -136,4 +138,6 @@ class ContractCreationTransaction: global_state.environment.active_account.code = Disassembly(contract_code) self.return_data = global_state.environment.active_account.address - raise TransactionEndSignal(global_state) + raise TransactionEndSignal(global_state, revert=revert) + +