diff --git a/mythril/analysis/call_helpers.py b/mythril/analysis/call_helpers.py index 270ff5af..cb6f55d7 100644 --- a/mythril/analysis/call_helpers.py +++ b/mythril/analysis/call_helpers.py @@ -18,7 +18,7 @@ def get_call_from_state(state: GlobalState) -> Union[Call, None]: op = instruction["opcode"] stack = state.mstate.stack - if op in ("CALL", "CALLCODE", "STATICCALL"): + if op in ("CALL", "CALLCODE"): gas, to, value, meminstart, meminsz, memoutstart, memoutsz = ( get_variable(stack[-1]), get_variable(stack[-2]), diff --git a/mythril/analysis/modules/delegatecall.py b/mythril/analysis/modules/delegatecall.py index 4b97228b..aa0cae6a 100644 --- a/mythril/analysis/modules/delegatecall.py +++ b/mythril/analysis/modules/delegatecall.py @@ -1,19 +1,18 @@ """This module contains the detection code for insecure delegate call usage.""" -import json import logging -from copy import copy -from typing import List, cast, Dict +from typing import List -from mythril.analysis import solver +from mythril.analysis.potential_issues import ( + get_potential_issues_annotation, + PotentialIssue, +) from mythril.analysis.swc_data import DELEGATECALL_TO_UNTRUSTED_CONTRACT from mythril.laser.ethereum.transaction.symbolic import ATTACKER_ADDRESS from mythril.laser.ethereum.transaction.transaction_models import ( ContractCreationTransaction, ) -from mythril.analysis.report import Issue from mythril.analysis.modules.base import DetectionModule from mythril.exceptions import UnsatError -from mythril.laser.ethereum.state.annotation import StateAnnotation from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.smt import symbol_factory, UGT @@ -41,19 +40,16 @@ class DelegateCallModule(DetectionModule): """ if state.get_current_instruction()["address"] in self.cache: return - issues = self._analyze_state(state) - for issue in issues: - self.cache.add(issue.address) - self.issues.extend(issues) + potential_issues = self._analyze_state(state) + + annotation = get_potential_issues_annotation(state) + annotation.potential_issues.extend(potential_issues) - @staticmethod - def _analyze_state(state: GlobalState) -> List[Issue]: + def _analyze_state(self, state: GlobalState) -> List[PotentialIssue]: """ :param state: the current state :return: returns the issues for that corresponding state """ - op_code = state.get_current_instruction()["opcode"] - gas = state.mstate.stack[-1] to = state.mstate.stack[-2] @@ -67,16 +63,14 @@ class DelegateCallModule(DetectionModule): constraints.append(tx.caller == ATTACKER_ADDRESS) try: - transaction_sequence = solver.get_transaction_sequence( - state, state.mstate.constraints + constraints - ) - address = state.get_current_instruction()["address"] + logging.debug( - "[DELEGATECALL] Detected delegatecall to a user-supplied address : {}".format( + "[DELEGATECALL] Detected potential delegatecall to a user-supplied address : {}".format( address ) ) + description_head = "The contract delegates execution to another contract with a user-supplied address." description_tail = ( "The smart contract delegates execution to a user-supplied address. Note that callers " @@ -85,7 +79,7 @@ class DelegateCallModule(DetectionModule): ) return [ - Issue( + PotentialIssue( contract=state.environment.active_account.contract_name, function_name=state.environment.active_function_name, address=address, @@ -95,8 +89,8 @@ class DelegateCallModule(DetectionModule): severity="Medium", description_head=description_head, description_tail=description_tail, - transaction_sequence=transaction_sequence, - gas_used=(state.mstate.min_gas_used, state.mstate.max_gas_used), + constraints=constraints, + detector=self, ) ] diff --git a/mythril/laser/ethereum/gas.py b/mythril/laser/ethereum/gas.py index 7e283d15..0a0c0a38 100644 --- a/mythril/laser/ethereum/gas.py +++ b/mythril/laser/ethereum/gas.py @@ -180,10 +180,7 @@ OPCODE_GAS = { "LOG3": (4 * 375, 4 * 375 + 8 * 32), "LOG4": (5 * 375, 5 * 375 + 8 * 32), "CREATE": (32000, 32000), - "CREATE2": ( - 32000, - 32000, - ), # TODO: The gas value is dynamic, to be done while implementing create2 + "CREATE2": (32000, 32000), # TODO: Make the gas values dynamic "CALL": (700, 700 + 9000 + 25000), "NATIVE_COST": calculate_native_gas, "CALLCODE": (700, 700 + 9000 + 25000), diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 4bd8436f..2f1c7f03 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -6,7 +6,6 @@ import logging from copy import copy, deepcopy from typing import cast, Callable, List, Set, Union, Tuple, Any from datetime import datetime -from math import ceil from ethereum import utils from mythril.laser.smt import ( @@ -29,9 +28,13 @@ from mythril.laser.smt import ( ) from mythril.laser.smt import symbol_factory +from mythril.disassembler.disassembly import Disassembly + +from mythril.laser.ethereum.state.calldata import ConcreteCalldata, SymbolicCalldata + import mythril.laser.ethereum.util as helper from mythril.laser.ethereum import util -from mythril.laser.ethereum.call import get_call_parameters, native_call +from mythril.laser.ethereum.call import get_call_parameters, native_call, get_call_data from mythril.laser.ethereum.evm_exceptions import ( VmException, StackUnderflowException, @@ -46,8 +49,11 @@ from mythril.laser.ethereum.transaction import ( MessageCallTransaction, TransactionStartSignal, ContractCreationTransaction, + get_next_transaction_id, ) +from mythril.support.support_utils import get_code_hash + from mythril.support.loader import DynLoader log = logging.getLogger(__name__) @@ -780,7 +786,7 @@ class Instruction: return [global_state] @staticmethod - def _calldata_copy_helper(global_state, state, mstart, dstart, size): + def _calldata_copy_helper(global_state, mstate, mstart, dstart, size): environment = global_state.environment try: @@ -804,11 +810,11 @@ class Instruction: size = cast(int, size) if size > 0: try: - state.mem_extend(mstart, size) + mstate.mem_extend(mstart, size) except TypeError as e: log.debug("Memory allocation error: {}".format(e)) - state.mem_extend(mstart, 1) - state.memory[mstart] = global_state.new_bitvec( + mstate.mem_extend(mstart, 1) + mstate.memory[mstart] = global_state.new_bitvec( "calldata_" + str(environment.active_account.contract_name) + "[" @@ -833,12 +839,12 @@ class Instruction: else simplify(cast(BitVec, i_data) + 1) ) for i in range(len(new_memory)): - state.memory[i + mstart] = new_memory[i] + mstate.memory[i + mstart] = new_memory[i] except IndexError: log.debug("Exception copying calldata to memory") - state.memory[mstart] = global_state.new_bitvec( + mstate.memory[mstart] = global_state.new_bitvec( "calldata_" + str(environment.active_account.contract_name) + "[" @@ -929,19 +935,31 @@ class Instruction: state = global_state.mstate environment = global_state.environment disassembly = environment.code + calldata = global_state.environment.calldata if isinstance(global_state.current_transaction, ContractCreationTransaction): # Hacky way to ensure constructor arguments work - Pick some reasonably large size. - no_of_bytes = ( - len(disassembly.bytecode) // 2 + 0x200 - ) # space for 16 32-byte arguments - global_state.mstate.constraints.append( - global_state.environment.calldata.size == no_of_bytes - ) + no_of_bytes = len(disassembly.bytecode) // 2 + if isinstance(calldata, ConcreteCalldata): + no_of_bytes += calldata.size + else: + no_of_bytes += 0x200 # space for 16 32-byte arguments + global_state.mstate.constraints.append( + global_state.environment.calldata.size == no_of_bytes + ) + else: no_of_bytes = len(disassembly.bytecode) // 2 state.stack.append(no_of_bytes) return [global_state] + @staticmethod + def _sha3_gas_helper(global_state, length): + min_gas, max_gas = cast(Callable, OPCODE_GAS["SHA3_FUNC"])(length) + global_state.mstate.min_gas_used += min_gas + global_state.mstate.max_gas_used += max_gas + StateTransition.check_gas_usage_limit(global_state) + return global_state + @StateTransition(enable_gas=False) def sha3_(self, global_state: GlobalState) -> List[GlobalState]: """ @@ -967,10 +985,7 @@ class Instruction: state.max_gas_used += gas_tuple[1] return [global_state] - min_gas, max_gas = cast(Callable, OPCODE_GAS["SHA3_FUNC"])(length) - state.min_gas_used += min_gas - state.max_gas_used += max_gas - StateTransition.check_gas_usage_limit(global_state) + Instruction._sha3_gas_helper(global_state, length) state.mem_extend(index, length) data_list = [ @@ -1023,34 +1038,6 @@ class Instruction: global_state.mstate.stack.append(global_state.environment.gasprice) return [global_state] - @staticmethod - def _handle_symbolic_args( - global_state: GlobalState, concrete_memory_offset: int - ) -> None: - """ - In contract creation transaction with dynamic arguments(like arrays, maps) solidity will try to - execute CODECOPY with code size as len(with_args) - len(without_args) which in our case - would be 0, hence we are writing 10 symbol words onto the memory on the assumption that - no one would use 10 array/map arguments for constructor. - :param global_state: The global state - :param concrete_memory_offset: The memory offset on which symbols should be written - """ - no_of_words = ceil( - min(len(global_state.environment.code.bytecode) / 2, 320) / 32 - ) - global_state.mstate.mem_extend(concrete_memory_offset, 32 * no_of_words) - for i in range(no_of_words): - global_state.mstate.memory.write_word_at( - concrete_memory_offset + i * 32, - global_state.new_bitvec( - "code_{}({})".format( - concrete_memory_offset + i * 32, - global_state.environment.active_account.contract_name, - ), - 256, - ), - ) - @StateTransition() def codecopy_(self, global_state: GlobalState) -> List[GlobalState]: """ @@ -1063,28 +1050,71 @@ class Instruction: global_state.mstate.stack.pop(), global_state.mstate.stack.pop(), ) + code = global_state.environment.code.bytecode + if code[0:2] == "0x": + code = code[2:] + code_size = len(code) // 2 - if ( - isinstance(global_state.current_transaction, ContractCreationTransaction) - and code_offset >= len(global_state.environment.code.bytecode) // 2 - ): + if isinstance(global_state.current_transaction, ContractCreationTransaction): # Treat creation code after the expected disassembly as calldata. # This is a slightly hacky way to ensure that symbolic constructor # arguments work correctly. - offset = code_offset - len(global_state.environment.code.bytecode) // 2 - log.warning("Doing hacky thing offset: {} size: {}".format(offset, size)) - return self._calldata_copy_helper( - global_state, global_state.mstate, memory_offset, offset, size - ) - else: - return self._code_copy_helper( - code=global_state.environment.code.bytecode, - memory_offset=memory_offset, - code_offset=code_offset, - size=size, - op="CODECOPY", - global_state=global_state, - ) + mstate = global_state.mstate + offset = code_offset - code_size + log.debug("Copying from code offset: {} with size: {}".format(offset, size)) + + if isinstance(global_state.environment.calldata, SymbolicCalldata): + if code_offset >= code_size: + return self._calldata_copy_helper( + global_state, mstate, memory_offset, offset, size + ) + else: + # Copy from both code and calldata appropriately. + concrete_code_offset = helper.get_concrete_int(code_offset) + concrete_size = helper.get_concrete_int(size) + + code_copy_offset = concrete_code_offset + code_copy_size = ( + concrete_size + if concrete_code_offset + concrete_size <= code_size + else code_size - concrete_code_offset + ) + code_copy_size = code_copy_size if code_copy_size >= 0 else 0 + + calldata_copy_offset = ( + concrete_code_offset - code_size + if concrete_code_offset - code_size > 0 + else 0 + ) + calldata_copy_size = concrete_code_offset + concrete_size - code_size + calldata_copy_size = ( + calldata_copy_size if calldata_copy_size >= 0 else 0 + ) + + [global_state] = self._code_copy_helper( + code=global_state.environment.code.bytecode, + memory_offset=memory_offset, + code_offset=code_copy_offset, + size=code_copy_size, + op="CODECOPY", + global_state=global_state, + ) + return self._calldata_copy_helper( + global_state=global_state, + mstate=mstate, + mstart=memory_offset + code_copy_size, + dstart=calldata_copy_offset, + size=calldata_copy_size, + ) + + return self._code_copy_helper( + code=global_state.environment.code.bytecode, + memory_offset=memory_offset, + code_offset=code_offset, + size=size, + op="CODECOPY", + global_state=global_state, + ) @StateTransition() def extcodesize_(self, global_state: GlobalState) -> List[GlobalState]: @@ -1117,9 +1147,9 @@ class Instruction: @staticmethod def _code_copy_helper( code: str, - memory_offset: BitVec, - code_offset: BitVec, - size: BitVec, + memory_offset: Union[int, BitVec], + code_offset: Union[int, BitVec], + size: Union[int, BitVec], op: str, global_state: GlobalState, ) -> List[GlobalState]: @@ -1165,13 +1195,6 @@ class Instruction: if code[0:2] == "0x": code = code[2:] - if concrete_size == 0 and isinstance( - global_state.current_transaction, ContractCreationTransaction - ): - if concrete_code_offset >= len(code) // 2: - Instruction._handle_symbolic_args(global_state, concrete_memory_offset) - return [global_state] - for i in range(concrete_size): if 2 * (concrete_code_offset + i + 1) <= len(code): global_state.mstate.memory[concrete_memory_offset + i] = int( @@ -1231,18 +1254,26 @@ class Instruction: global_state=global_state, ) - @StateTransition + @StateTransition() def extcodehash_(self, global_state: GlobalState) -> List[GlobalState]: """ :param global_state: :return: List of global states possible, list of size 1 in this case """ - # TODO: To be implemented - address = global_state.mstate.stack.pop() - global_state.mstate.stack.append( - global_state.new_bitvec("extcodehash_{}".format(str(address)), 256) - ) + world_state = global_state.world_state + stack = global_state.mstate.stack + address = Extract(159, 0, stack.pop()) + + if address.symbolic: + code_hash = symbol_factory.BitVecVal(int(get_code_hash(""), 16), 256) + elif address.value not in world_state.accounts: + code_hash = symbol_factory.BitVecVal(0, 256) + else: + addr = "0" * (40 - len(hex(address.value)[2:])) + hex(address.value)[2:] + code = world_state.accounts_exist_or_load(addr, self.dynamic_loader) + code_hash = symbol_factory.BitVecVal(int(get_code_hash(code), 16), 256) + stack.append(code_hash) return [global_state] @StateTransition() @@ -1629,6 +1660,62 @@ class Instruction: # Not supported return [global_state] + def _create_transaction_helper( + self, global_state, call_value, mem_offset, mem_size, create2_salt=None + ): + mstate = global_state.mstate + environment = global_state.environment + world_state = global_state.world_state + + call_data = get_call_data(global_state, mem_offset, mem_offset + mem_size) + + code_raw = [] + code_end = call_data.size + for i in range(call_data.size): + if call_data[i].symbolic: + code_end = i + break + code_raw.append(call_data[i].value) + + code_str = bytes.hex(bytes(code_raw)) + + next_transaction_id = get_next_transaction_id() + constructor_arguments = ConcreteCalldata( + next_transaction_id, call_data[code_end:] + ) + code = Disassembly(code_str) + + caller = environment.active_account.address + gas_price = environment.gasprice + origin = environment.origin + + contract_address = None + if create2_salt: + salt = hex(create2_salt)[2:] + salt = "0" * (64 - len(salt)) + salt + + addr = hex(caller.value)[2:] + addr = "0" * (40 - len(addr)) + addr + + Instruction._sha3_gas_helper(global_state, len(code_str[2:] // 2)) + + contract_address = int( + get_code_hash("0xff" + addr + salt + get_code_hash(code_str)[2:])[26:], + 16, + ) + transaction = ContractCreationTransaction( + world_state=world_state, + caller=caller, + code=code, + call_data=constructor_arguments, + gas_price=gas_price, + gas_limit=mstate.gas_limit, + origin=origin, + call_value=call_value, + contract_address=contract_address, + ) + raise TransactionStartSignal(transaction, self.op_code, global_state) + @StateTransition(is_state_mutation_instruction=True) def create_(self, global_state: GlobalState) -> List[GlobalState]: """ @@ -1636,12 +1723,23 @@ class Instruction: :param global_state: :return: """ - # TODO: implement me + call_value, mem_offset, mem_size = global_state.mstate.pop(3) - state = global_state.mstate - state.stack.pop(), state.stack.pop(), state.stack.pop() - # Not supported - state.stack.append(0) + return self._create_transaction_helper( + global_state, call_value, mem_offset, mem_size + ) + + @StateTransition() + def create_post(self, global_state: GlobalState) -> List[GlobalState]: + call_value, mem_offset, mem_size = global_state.mstate.pop(3) + call_data = get_call_data(global_state, mem_offset, mem_offset + mem_size) + if global_state.last_return_data: + return_val = symbol_factory.BitVecVal( + int(global_state.last_return_data, 16), 256 + ) + else: + return_val = symbol_factory.BitVecVal(0, 256) + global_state.mstate.stack.append(return_val) return [global_state] @StateTransition(is_state_mutation_instruction=True) @@ -1651,16 +1749,23 @@ class Instruction: :param global_state: :return: """ - # TODO: implement me - state = global_state.mstate - endowment, memory_start, memory_length, salt = ( - state.stack.pop(), - state.stack.pop(), - state.stack.pop(), - state.stack.pop(), + call_value, mem_offset, mem_size, salt = global_state.mstate.pop(4) + + return self._create_transaction_helper( + global_state, call_value, mem_offset, mem_size, salt ) - # Not supported - state.stack.append(0) + + @StateTransition() + def create2_post(self, global_state: GlobalState) -> List[GlobalState]: + call_value, mem_offset, mem_size, salt = global_state.mstate.pop(4) + call_data = get_call_data(global_state, mem_offset, mem_offset + mem_size) + if global_state.last_return_data: + return_val = symbol_factory.BitVecVal( + int(global_state.last_return_data), 256 + ) + else: + return_val = symbol_factory.BitVecVal(0, 256) + global_state.mstate.stack.append(return_val) return [global_state] @StateTransition() @@ -2109,13 +2214,14 @@ class Instruction: origin=environment.origin, code=callee_account.code, caller=environment.address, - callee_account=environment.active_account, + callee_account=callee_account, call_data=call_data, call_value=value, static=True, ) raise TransactionStartSignal(transaction, self.op_code, global_state) + @StateTransition() def staticcall_post(self, global_state: GlobalState) -> List[GlobalState]: return self.post_handler(global_state, function_name="staticcall") @@ -2123,8 +2229,9 @@ class Instruction: instr = global_state.get_current_instruction() try: + with_value = function_name is not "staticcall" callee_address, callee_account, call_data, value, gas, memory_out_offset, memory_out_size = get_call_parameters( - global_state, self.dynamic_loader, True + global_state, self.dynamic_loader, with_value ) except ValueError as e: log.debug( diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index 822d1c55..af52c12b 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -396,6 +396,7 @@ class LaserEVM: :param return_data: :return: """ + return_global_state.mstate.constraints += global_state.mstate.constraints # Resume execution of the transaction initializing instruction op_code = return_global_state.environment.code.instruction_list[ @@ -409,6 +410,15 @@ class LaserEVM: return_global_state.environment.active_account = global_state.accounts[ return_global_state.environment.active_account.address.value ] + if isinstance( + global_state.current_transaction, ContractCreationTransaction + ): + return_global_state.mstate.min_gas_used += ( + global_state.mstate.min_gas_used + ) + return_global_state.mstate.max_gas_used += ( + global_state.mstate.max_gas_used + ) # Execute the post instruction handler new_global_states = Instruction( diff --git a/mythril/laser/ethereum/transaction/transaction_models.py b/mythril/laser/ethereum/transaction/transaction_models.py index 31ce0271..b36cc61d 100644 --- a/mythril/laser/ethereum/transaction/transaction_models.py +++ b/mythril/laser/ethereum/transaction/transaction_models.py @@ -195,12 +195,16 @@ class ContractCreationTransaction(BaseTransaction): code=None, call_value=None, contract_name=None, + contract_address=None, ) -> None: self.prev_world_state = deepcopy(world_state) + contract_address = ( + contract_address if isinstance(contract_address, int) else None + ) callee_account = world_state.create_account( - 0, concrete_storage=True, creator=caller.value + 0, concrete_storage=True, creator=caller.value, address=contract_address ) - callee_account.contract_name = contract_name + callee_account.contract_name = contract_name or callee_account.contract_name # init_call_data "should" be false, but it is easier to model the calldata symbolically # and add logic in codecopy/codesize/calldatacopy/calldatasize than to model code "correctly" super().__init__( diff --git a/tests/instructions/create_test.py b/tests/instructions/create_test.py new file mode 100644 index 00000000..0a97565e --- /dev/null +++ b/tests/instructions/create_test.py @@ -0,0 +1,72 @@ +from mythril.disassembler.disassembly import Disassembly +from mythril.laser.ethereum.state.environment import Environment +from mythril.laser.ethereum.state.account import Account +from mythril.laser.ethereum.state.machine_state import MachineState +from mythril.laser.ethereum.state.global_state import GlobalState +from mythril.laser.ethereum.state.world_state import WorldState +from mythril.laser.ethereum.instructions import Instruction +from mythril.laser.ethereum.transaction.transaction_models import MessageCallTransaction +from mythril.laser.ethereum.state.calldata import ConcreteCalldata +from mythril.laser.ethereum.svm import LaserEVM +from mythril.laser.smt import symbol_factory + +# contract A { +# uint256 public val = 10; +# } +contract_init_code = "6080604052600a600055348015601457600080fd5b506084806100236000396000f3fe6080604052348015600f57600080fd5b506004361060285760003560e01c80633c6bb43614602d575b600080fd5b60336049565b6040518082815260200191505060405180910390f35b6000548156fea265627a7a72315820d3cfe7a909450a953cbd7e47d8c17477f2609baa5208d043e965efec69d1ed9864736f6c634300050b0032" +contract_runtime_code = "6080604052348015600f57600080fd5b506004361060285760003560e01c80633c6bb43614602d575b600080fd5b60336049565b6040518082815260200191505060405180910390f35b6000548156fea265627a7a72315820d3cfe7a909450a953cbd7e47d8c17477f2609baa5208d043e965efec69d1ed9864736f6c634300050b0032" + +last_state = None +created_contract_account = None + + +def execute_create(): + global last_state + global created_contract_account + if not last_state and not created_contract_account: + code_raw = [] + for i in range(len(contract_init_code) // 2): + code_raw.append(int(contract_init_code[2 * i : 2 * (i + 1)], 16)) + calldata = ConcreteCalldata(0, code_raw) + + world_state = WorldState() + account = world_state.create_account(balance=1000000, address=101) + account.code = Disassembly("60a760006000f000") + environment = Environment(account, None, calldata, None, None, None) + og_state = GlobalState( + world_state, environment, None, MachineState(gas_limit=8000000) + ) + og_state.transaction_stack.append( + (MessageCallTransaction(world_state=WorldState(), gas_limit=8000000), None) + ) + + laser = LaserEVM() + states = [og_state] + last_state = og_state + for state in states: + new_states, op_code = laser.execute_state(state) + last_state = state + if op_code == "STOP": + break + states.extend(new_states) + + created_contract_address = last_state.mstate.stack[-1].value + created_contract_account = last_state.world_state.accounts[ + created_contract_address + ] + + return last_state, created_contract_account + + +def test_create_has_code(): + last_state, created_contract_account = execute_create() + assert created_contract_account.code.bytecode == contract_runtime_code + + +def test_create_has_storage(): + last_state, created_contract_account = execute_create() + storage = created_contract_account.storage + # From contract, val = 10. + assert storage[symbol_factory.BitVecVal(0, 256)] == symbol_factory.BitVecVal( + 10, 256 + ) diff --git a/tests/instructions/extcodehash_test.py b/tests/instructions/extcodehash_test.py new file mode 100644 index 00000000..14f2ad65 --- /dev/null +++ b/tests/instructions/extcodehash_test.py @@ -0,0 +1,48 @@ +from mythril.disassembler.disassembly import Disassembly +from mythril.laser.ethereum.state.environment import Environment +from mythril.laser.ethereum.state.account import Account +from mythril.laser.ethereum.state.machine_state import MachineState +from mythril.laser.ethereum.state.global_state import GlobalState +from mythril.laser.ethereum.state.world_state import WorldState +from mythril.laser.ethereum.instructions import Instruction +from mythril.laser.ethereum.transaction.transaction_models import MessageCallTransaction + +from mythril.support.support_utils import get_code_hash + +from mythril.laser.smt import symbol_factory + +# Arrange +world_state = WorldState() +account = world_state.create_account(balance=10, address=101) +account.code = Disassembly("60606040") +world_state.create_account(balance=10, address=1000) +environment = Environment(account, None, None, None, None, None) +og_state = GlobalState(world_state, environment, None, MachineState(gas_limit=8000000)) +og_state.transaction_stack.append( + (MessageCallTransaction(world_state=WorldState(), gas_limit=8000000), None) +) + +instruction = Instruction("extcodehash", dynamic_loader=None) + + +def test_extcodehash_no_account(): + + # If account does not exist, return 0 + og_state.mstate.stack = [symbol_factory.BitVecVal(1, 256)] + new_state = instruction.evaluate(og_state)[0] + assert new_state.mstate.stack[-1] == 0 + + +def test_extcodehash_no_code(): + + # If account code does not exist, return hash of empty set. + og_state.mstate.stack = [symbol_factory.BitVecVal(1000, 256)] + new_state = instruction.evaluate(og_state)[0] + assert hex(new_state.mstate.stack[-1].value) == get_code_hash("") + + +def test_extcodehash_return_hash(): + # If account code exists, return hash of the code. + og_state.mstate.stack = [symbol_factory.BitVecVal(101, 256)] + new_state = instruction.evaluate(og_state)[0] + assert hex(new_state.mstate.stack[-1].value) == get_code_hash("60606040")