diff --git a/mythril/analysis/modules/arbitrary_jump.py b/mythril/analysis/modules/arbitrary_jump.py new file mode 100644 index 00000000..bff27aa5 --- /dev/null +++ b/mythril/analysis/modules/arbitrary_jump.py @@ -0,0 +1,81 @@ +"""This module contains the detection code for Arbitrary jumps.""" +import logging +from mythril.analysis.solver import get_transaction_sequence, UnsatError +from mythril.analysis.modules.base import DetectionModule, Issue +from mythril.analysis.swc_data import ARBITRARY_JUMP +from mythril.laser.ethereum.state.global_state import GlobalState + +log = logging.getLogger(__name__) + +DESCRIPTION = """ + +Search for any writes to an arbitrary storage slot +""" + + +class ArbitraryJump(DetectionModule): + """This module searches for JUMPs to an arbitrary instruction.""" + + def __init__(self): + """""" + super().__init__( + name="Jump to an arbitrary line", + swc_id=ARBITRARY_JUMP, + description=DESCRIPTION, + entrypoint="callback", + pre_hooks=["JUMP", "JUMPI"], + ) + + def reset_module(self): + """ + Resets the module by clearing everything + :return: + """ + super().reset_module() + + def _execute(self, state: GlobalState) -> None: + """ + + :param state: + :return: + """ + if state.get_current_instruction()["address"] in self.cache: + return + self.issues.extend(self._analyze_state(state)) + + @staticmethod + def _analyze_state(state): + """ + + :param state: + :return: + """ + + jump_dest = state.mstate.stack[-1] + if jump_dest.symbolic is False: + return [] + # Most probably the jump destination can have multiple locations in these circumstances + try: + transaction_sequence = get_transaction_sequence( + state, state.mstate.constraints + ) + except UnsatError: + return [] + issue = Issue( + contract=state.environment.active_account.contract_name, + function_name=state.environment.active_function_name, + address=state.get_current_instruction()["address"], + swc_id=ARBITRARY_JUMP, + title="Jump to an arbitrary instruction", + severity="Medium", + bytecode=state.environment.code.bytecode, + description_head="The caller can jump to any point in the code.", + description_tail="This can lead to unintended consequences." + "Please avoid using low level code as much as possible", + gas_used=(state.mstate.min_gas_used, state.mstate.max_gas_used), + transaction_sequence=transaction_sequence, + ) + return [issue] + + +detector = ArbitraryJump() diff --git a/mythril/analysis/modules/dependence_on_predictable_vars.py b/mythril/analysis/modules/dependence_on_predictable_vars.py index a9c730fa..b3379263 100644 --- a/mythril/analysis/modules/dependence_on_predictable_vars.py +++ b/mythril/analysis/modules/dependence_on_predictable_vars.py @@ -102,10 +102,11 @@ class PredictableDependenceModule(DetectionModule): if isinstance(annotation, PredictablePathAnnotation): if annotation.add_constraints: constraints = ( - state.mstate.constraints + annotation.add_constraints + state.world_state.constraints + + annotation.add_constraints ) else: - constraints = copy(state.mstate.constraints) + constraints = copy(state.world_state.constraints) try: transaction_sequence = solver.get_transaction_sequence( state, constraints @@ -192,7 +193,7 @@ class PredictableDependenceModule(DetectionModule): # Why the second constraint? Because without it Z3 returns a solution where param overflows. - solver.get_model(state.mstate.constraints + constraint) + solver.get_model(state.world_state.constraints + constraint) # type: ignore state.annotate(OldBlockNumberUsedAnnotation(constraint)) except UnsatError: diff --git a/mythril/analysis/modules/dos.py b/mythril/analysis/modules/dos.py deleted file mode 100644 index cb9ac2f3..00000000 --- a/mythril/analysis/modules/dos.py +++ /dev/null @@ -1,137 +0,0 @@ -"""This module contains the detection code SWC-128 - DOS with block gas limit.""" - -import logging -from typing import Dict, cast, List - -from mythril.analysis.swc_data import DOS_WITH_BLOCK_GAS_LIMIT -from mythril.analysis.report import Issue -from mythril.analysis.modules.base import DetectionModule -from mythril.analysis.solver import get_transaction_sequence, UnsatError -from mythril.analysis.analysis_args import analysis_args -from mythril.laser.ethereum.state.global_state import GlobalState -from mythril.laser.ethereum.state.annotation import StateAnnotation -from mythril.laser.ethereum import util -from copy import copy - -log = logging.getLogger(__name__) - - -class VisitsAnnotation(StateAnnotation): - """State annotation that stores the addresses of state-modifying operations""" - - def __init__(self) -> None: - self.loop_start = None # type: int - self.jump_targets = {} # type: Dict[int, int] - - def __copy__(self): - result = VisitsAnnotation() - - result.loop_start = self.loop_start - result.jump_targets = copy(self.jump_targets) - return result - - -class DosModule(DetectionModule): - """This module consists of a makeshift loop detector that annotates the state with - a list of byte ranges likely to be loops. If a CALL or SSTORE detection is found in - one of the ranges it creates a low-severity issue. This is not super precise but - good enough to identify places that warrant a closer look. Checking the loop condition - would be a possible improvement. - """ - - def __init__(self) -> None: - """""" - super().__init__( - name="DOS", - swc_id=DOS_WITH_BLOCK_GAS_LIMIT, - description="Check for DOS", - entrypoint="callback", - pre_hooks=["JUMP", "JUMPI", "CALL", "SSTORE"], - ) - - def _execute(self, state: GlobalState) -> None: - - """ - :param state: - :return: - """ - issues = self._analyze_state(state) - self.issues.extend(issues) - - def _analyze_state(self, state: GlobalState) -> List[Issue]: - """ - :param state: the current state - :return: returns the issues for that corresponding state - """ - - opcode = state.get_current_instruction()["opcode"] - address = state.get_current_instruction()["address"] - - annotations = cast( - List[VisitsAnnotation], list(state.get_annotations(VisitsAnnotation)) - ) - - if len(annotations) == 0: - annotation = VisitsAnnotation() - state.annotate(annotation) - else: - annotation = annotations[0] - - if opcode in ["JUMP", "JUMPI"]: - - if annotation.loop_start is not None: - return [] - try: - target = util.get_concrete_int(state.mstate.stack[-1]) - except TypeError: - log.debug("Symbolic target encountered in dos module") - return [] - if target in annotation.jump_targets: - annotation.jump_targets[target] += 1 - else: - annotation.jump_targets[target] = 1 - - if annotation.jump_targets[target] > min(2, analysis_args.loop_bound - 1): - annotation.loop_start = address - - elif annotation.loop_start is not None: - - if opcode == "CALL": - operation = "A message call" - else: - operation = "A storage modification" - - description_head = ( - "Potential denial-of-service if block gas limit is reached." - ) - description_tail = "{} is executed in a loop. Be aware that the transaction may fail to execute if the loop is unbounded and the necessary gas exceeds the block gas limit.".format( - operation - ) - - try: - transaction_sequence = get_transaction_sequence( - state, state.mstate.constraints - ) - except UnsatError: - return [] - - issue = Issue( - contract=state.environment.active_account.contract_name, - function_name=state.environment.active_function_name, - address=annotation.loop_start, - swc_id=DOS_WITH_BLOCK_GAS_LIMIT, - bytecode=state.environment.code.bytecode, - title="Potential denial-of-service if block gas limit is reached", - severity="Low", - description_head=description_head, - description_tail=description_tail, - gas_used=(state.mstate.min_gas_used, state.mstate.max_gas_used), - transaction_sequence=transaction_sequence, - ) - - return [issue] - - return [] - - -detector = DosModule() diff --git a/mythril/analysis/modules/ether_thief.py b/mythril/analysis/modules/ether_thief.py index 6a1e30cc..c0648167 100644 --- a/mythril/analysis/modules/ether_thief.py +++ b/mythril/analysis/modules/ether_thief.py @@ -77,7 +77,7 @@ class EtherThief(DetectionModule): value = state.mstate.stack[-3] target = state.mstate.stack[-2] - constraints = copy(state.mstate.constraints) + constraints = copy(state.world_state.constraints) """ Require that the current transaction is sent by the attacker and diff --git a/mythril/analysis/modules/exceptions.py b/mythril/analysis/modules/exceptions.py index 1ac639eb..c2c8c61e 100644 --- a/mythril/analysis/modules/exceptions.py +++ b/mythril/analysis/modules/exceptions.py @@ -58,7 +58,7 @@ class ReachableExceptionsModule(DetectionModule): "Use `require()` for regular input checking." ) transaction_sequence = solver.get_transaction_sequence( - state, state.mstate.constraints + state, state.world_state.constraints ) issue = Issue( contract=state.environment.active_account.contract_name, diff --git a/mythril/analysis/modules/external_calls.py b/mythril/analysis/modules/external_calls.py index f227f7b2..a09cf517 100644 --- a/mythril/analysis/modules/external_calls.py +++ b/mythril/analysis/modules/external_calls.py @@ -33,7 +33,7 @@ an informational issue. def _is_precompile_call(global_state: GlobalState): to = global_state.mstate.stack[-2] # type: BitVec - constraints = copy(global_state.mstate.constraints) + constraints = copy(global_state.world_state.constraints) constraints += [ Or( to < symbol_factory.BitVecVal(1, 256), @@ -88,7 +88,7 @@ class ExternalCalls(DetectionModule): constraints = Constraints([UGT(gas, symbol_factory.BitVecVal(2300, 256))]) solver.get_transaction_sequence( - state, constraints + state.mstate.constraints + state, constraints + state.world_state.constraints ) # Check whether we can also set the callee address @@ -101,7 +101,7 @@ class ExternalCalls(DetectionModule): constraints.append(tx.caller == ACTORS.attacker) solver.get_transaction_sequence( - state, constraints + state.mstate.constraints + state, constraints + state.world_state.constraints ) description_head = "A call to a user-supplied address is executed." diff --git a/mythril/analysis/modules/integer.py b/mythril/analysis/modules/integer.py index 2879a331..3aa6ab50 100644 --- a/mythril/analysis/modules/integer.py +++ b/mythril/analysis/modules/integer.py @@ -291,7 +291,9 @@ class IntegerOverflowUnderflowModule(DetectionModule): if ostate not in self._ostates_satisfiable: try: - constraints = ostate.mstate.constraints + [annotation.constraint] + constraints = ostate.world_state.constraints + [ + annotation.constraint + ] solver.get_model(constraints) self._ostates_satisfiable.add(ostate) except: @@ -308,7 +310,7 @@ class IntegerOverflowUnderflowModule(DetectionModule): try: - constraints = state.mstate.constraints + [annotation.constraint] + constraints = state.world_state.constraints + [annotation.constraint] transaction_sequence = solver.get_transaction_sequence( state, constraints ) diff --git a/mythril/analysis/modules/multiple_sends.py b/mythril/analysis/modules/multiple_sends.py index e6b95085..67051080 100644 --- a/mythril/analysis/modules/multiple_sends.py +++ b/mythril/analysis/modules/multiple_sends.py @@ -81,7 +81,7 @@ class MultipleSendsModule(DetectionModule): for offset in call_offsets[1:]: try: transaction_sequence = get_transaction_sequence( - state, state.mstate.constraints + state, state.world_state.constraints ) except UnsatError: continue diff --git a/mythril/analysis/modules/state_change_external_calls.py b/mythril/analysis/modules/state_change_external_calls.py index 55b1d488..c3819c51 100644 --- a/mythril/analysis/modules/state_change_external_calls.py +++ b/mythril/analysis/modules/state_change_external_calls.py @@ -56,7 +56,7 @@ class StateChangeCallsAnnotation(StateAnnotation): try: solver.get_transaction_sequence( - global_state, constraints + global_state.mstate.constraints + global_state, constraints + global_state.world_state.constraints ) except UnsatError: return None @@ -124,7 +124,7 @@ class StateChange(DetectionModule): gas = global_state.mstate.stack[-1] to = global_state.mstate.stack[-2] try: - constraints = copy(global_state.mstate.constraints) + constraints = copy(global_state.world_state.constraints) solver.get_model( constraints + [ @@ -190,7 +190,7 @@ class StateChange(DetectionModule): return value.value > 0 else: - constraints = copy(global_state.mstate.constraints) + constraints = copy(global_state.world_state.constraints) try: solver.get_model( diff --git a/mythril/analysis/modules/suicide.py b/mythril/analysis/modules/suicide.py index e3e77360..6ef87d8e 100644 --- a/mythril/analysis/modules/suicide.py +++ b/mythril/analysis/modules/suicide.py @@ -73,7 +73,9 @@ class SuicideModule(DetectionModule): try: transaction_sequence = solver.get_transaction_sequence( state, - state.mstate.constraints + constraints + [to == ACTORS.attacker], + state.world_state.constraints + + constraints + + [to == ACTORS.attacker], ) description_tail = ( "Anyone can kill this contract and withdraw its balance to an arbitrary " @@ -81,7 +83,7 @@ class SuicideModule(DetectionModule): ) except UnsatError: transaction_sequence = solver.get_transaction_sequence( - state, state.mstate.constraints + constraints + state, state.world_state.constraints + constraints ) description_tail = "Arbitrary senders can kill this contract." issue = Issue( diff --git a/mythril/analysis/modules/unchecked_retval.py b/mythril/analysis/modules/unchecked_retval.py index 5aa8f005..8b5d1293 100644 --- a/mythril/analysis/modules/unchecked_retval.py +++ b/mythril/analysis/modules/unchecked_retval.py @@ -83,7 +83,7 @@ class UncheckedRetvalModule(DetectionModule): for retval in retvals: try: transaction_sequence = solver.get_transaction_sequence( - state, state.mstate.constraints + [retval["retval"] == 0] + state, state.world_state.constraints + [retval["retval"] == 0] ) except UnsatError: continue diff --git a/mythril/analysis/potential_issues.py b/mythril/analysis/potential_issues.py index 99340330..6d9159b8 100644 --- a/mythril/analysis/potential_issues.py +++ b/mythril/analysis/potential_issues.py @@ -84,7 +84,7 @@ def check_potential_issues(state: GlobalState) -> None: for potential_issue in annotation.potential_issues: try: transaction_sequence = get_transaction_sequence( - state, state.mstate.constraints + potential_issue.constraints + state, state.world_state.constraints + potential_issue.constraints ) except UnsatError: continue diff --git a/mythril/analysis/symbolic.py b/mythril/analysis/symbolic.py index ddded1b0..8064d8fe 100644 --- a/mythril/analysis/symbolic.py +++ b/mythril/analysis/symbolic.py @@ -133,8 +133,8 @@ class SymExecWrapper: plugin_loader.load(PluginFactory.build_state_merge_plugin()) plugin_loader.load(instruction_laser_plugin) - # if not disable_dependency_pruning: - # plugin_loader.load(PluginFactory.build_dependency_pruner_plugin()) + if not disable_dependency_pruning: + plugin_loader.load(PluginFactory.build_dependency_pruner_plugin()) world_state = WorldState() for account in self.accounts.values(): diff --git a/mythril/laser/ethereum/call.py b/mythril/laser/ethereum/call.py index 5d0672c9..52b35711 100644 --- a/mythril/laser/ethereum/call.py +++ b/mythril/laser/ethereum/call.py @@ -270,5 +270,5 @@ def native_call( "retval_" + str(global_state.get_current_instruction()["address"]), 256 ) global_state.mstate.stack.append(retval) - global_state.node.constraints.append(retval == 1) + global_state.world_state.constraints.append(retval == 1) return [global_state] diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index a91acd8e..21f5f4b5 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -80,7 +80,7 @@ def transfer_ether( """ value = value if isinstance(value, BitVec) else symbol_factory.BitVecVal(value, 256) - global_state.mstate.constraints.append( + global_state.world_state.constraints.append( UGE(global_state.world_state.balances[sender], value) ) global_state.world_state.balances[receiver] += value @@ -948,7 +948,7 @@ class Instruction: no_of_bytes += calldata.size else: no_of_bytes += 0x200 # space for 16 32-byte arguments - global_state.mstate.constraints.append( + global_state.world_state.constraints.append( global_state.environment.calldata.size == no_of_bytes ) @@ -1010,7 +1010,7 @@ class Instruction: result, condition = keccak_function_manager.create_keccak(data) state.stack.append(result) - state.constraints.append(condition) + global_state.world_state.constraints.append(condition) return [global_state] @@ -1563,7 +1563,7 @@ class Instruction: # manually increment PC new_state.mstate.depth += 1 new_state.mstate.pc += 1 - new_state.mstate.constraints.append(negated) + new_state.world_state.constraints.append(negated) states.append(new_state) else: log.debug("Pruned unreachable states.") @@ -1589,7 +1589,7 @@ class Instruction: # manually set PC to destination new_state.mstate.pc = index new_state.mstate.depth += 1 - new_state.mstate.constraints.append(condi) + new_state.world_state.constraints.append(condi) states.append(new_state) else: log.debug("Pruned unreachable states.") @@ -1648,7 +1648,7 @@ class Instruction: def _create_transaction_helper( self, global_state, call_value, mem_offset, mem_size, create2_salt=None - ): + ) -> List[GlobalState]: mstate = global_state.mstate environment = global_state.environment world_state = global_state.world_state @@ -1673,7 +1673,7 @@ class Instruction: if len(code_raw) < 1: global_state.mstate.stack.append(1) log.debug("No code found for trying to execute a create type instruction.") - return global_state + return [global_state] code_str = bytes.hex(bytes(code_raw)) @@ -1695,7 +1695,7 @@ class Instruction: addr = hex(caller.value)[2:] addr = "0" * (40 - len(addr)) + addr - Instruction._sha3_gas_helper(global_state, len(code_str[2:] // 2)) + Instruction._sha3_gas_helper(global_state, len(code_str[2:]) // 2) contract_address = int( get_code_hash("0xff" + addr + salt + get_code_hash(code_str)[2:])[26:], @@ -1898,7 +1898,7 @@ class Instruction: ) if isinstance(value, BitVec): if value.symbolic: - global_state.mstate.constraints.append( + global_state.world_state.constraints.append( value == symbol_factory.BitVecVal(0, 256) ) elif value.value > 0: @@ -2026,7 +2026,7 @@ class Instruction: "retval_" + str(instr["address"]), 256 ) global_state.mstate.stack.append(return_value) - global_state.mstate.constraints.append(return_value == 0) + global_state.world_state.constraints.append(return_value == 0) return [global_state] try: @@ -2058,7 +2058,7 @@ class Instruction: # Put return value on stack return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) global_state.mstate.stack.append(return_value) - global_state.mstate.constraints.append(return_value == 1) + global_state.world_state.constraints.append(return_value == 1) return [global_state] @StateTransition() @@ -2154,7 +2154,7 @@ class Instruction: "retval_" + str(instr["address"]), 256 ) global_state.mstate.stack.append(return_value) - global_state.mstate.constraints.append(return_value == 0) + global_state.world_state.constraints.append(return_value == 0) return [global_state] try: @@ -2186,7 +2186,7 @@ class Instruction: # Put return value on stack return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) global_state.mstate.stack.append(return_value) - global_state.mstate.constraints.append(return_value == 1) + global_state.world_state.constraints.append(return_value == 1) return [global_state] @StateTransition() @@ -2318,6 +2318,6 @@ class Instruction: # Put return value on stack return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) global_state.mstate.stack.append(return_value) - global_state.mstate.constraints.append(return_value == 1) + global_state.world_state.constraints.append(return_value == 1) return [global_state] diff --git a/mythril/laser/ethereum/plugins/implementations/mutation_pruner.py b/mythril/laser/ethereum/plugins/implementations/mutation_pruner.py index 6bc13d9e..fe8c53ac 100644 --- a/mythril/laser/ethereum/plugins/implementations/mutation_pruner.py +++ b/mythril/laser/ethereum/plugins/implementations/mutation_pruner.py @@ -45,7 +45,7 @@ class MutationPruner(LaserPlugin): @symbolic_vm.laser_hook("add_world_state") def world_state_filter_hook(global_state: GlobalState): if And( - *global_state.mstate.constraints[:] + *global_state.world_state.constraints[:] + [ global_state.environment.callvalue > symbol_factory.BitVecVal(0, 256) diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index 1365c671..18ceb9b2 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -4,7 +4,7 @@ This includes classes representing accounts and their storage. """ import logging from copy import copy, deepcopy -from typing import Any, Dict, Union, Tuple, Set, cast +from typing import Any, Dict, Union, Set from mythril.laser.smt import ( @@ -13,8 +13,6 @@ from mythril.laser.smt import ( BitVec, Bool, simplify, - BitVecFunc, - Extract, BaseArray, Concat, If, @@ -25,26 +23,6 @@ from mythril.laser.smt import symbol_factory log = logging.getLogger(__name__) -class StorageRegion: - def __getitem__(self, item): - raise NotImplementedError - - def __setitem__(self, key, value): - raise NotImplementedError - - -class ArrayStorageRegion(StorageRegion): - """ An ArrayStorageRegion is a storage region that leverages smt array theory to resolve expressions""" - - pass - - -class IteStorageRegion(StorageRegion): - """ An IteStorageRegion is a storage region that uses Ite statements to implement a storage""" - - pass - - class Storage: """Storage class represents the storage of an Account.""" @@ -64,18 +42,8 @@ class Storage: self.storage_keys_loaded = set() # type: Set[int] self.address = address - @staticmethod - def _sanitize(input_: BitVec) -> BitVec: - if input_.size() == 512: - return input_ - if input_.size() > 512: - return Extract(511, 0, input_) - else: - return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_) - def __getitem__(self, item: BitVec) -> BitVec: storage = self._standard_storage - sanitized_item = item if ( self.address and self.address.value != 0 @@ -84,7 +52,7 @@ class Storage: and (self.dynld and self.dynld.storage_loading) ): try: - storage[sanitized_item] = symbol_factory.BitVecVal( + storage[item] = symbol_factory.BitVecVal( int( self.dynld.read_storage( contract_address="0x{:040X}".format(self.address.value), @@ -95,29 +63,14 @@ class Storage: 256, ) self.storage_keys_loaded.add(int(item.value)) - self.printable_storage[item] = storage[sanitized_item] + self.printable_storage[item] = storage[item] except ValueError as e: log.debug("Couldn't read storage at %s: %s", item, e) - return simplify(storage[sanitized_item]) - - @staticmethod - def get_map_index(key: BitVec) -> BitVec: - if ( - not isinstance(key, BitVecFunc) - or key.func_name != "keccak256" - or key.input_ is None - ): - return None - index = Extract(255, 0, key.input_) - return simplify(index) - - def _get_corresponding_storage(self, key: BitVec) -> BaseArray: - return self._standard_storage + return simplify(storage[item]) def __setitem__(self, key, value: Any) -> None: - storage = self._get_corresponding_storage(key) self.printable_storage[key] = value - storage[key] = value + self._standard_storage[key] = value if key.symbolic is False: self.storage_keys_loaded.add(int(key.value)) @@ -135,6 +88,13 @@ class Storage: # TODO: Do something better here return str(self.printable_storage) + def merge_storage(self, storage: "Storage", path_condition: Bool): + Lambda([x], If(And(lo <= x, x <= hi), y, Select(m, x))) + + self._standard_storage = If(path_condition, self._standard_storage, storage._standard_storage) + + + class Account: """Account class representing ethereum accounts.""" @@ -205,9 +165,8 @@ class Account: self._balances[self.address] += balance def merge_accounts(self, account: "Account", path_condition: Bool): - self.nonce = If(path_condition, self.nonce, account.nonce) - # self.storage.merge_storage(account.storage) - ## Merge Storage + assert self.nonce == account.nonce + self.storage.merge_storage(account.storage, path_condition) @property def as_dict(self) -> Dict: diff --git a/mythril/laser/ethereum/state/machine_state.py b/mythril/laser/ethereum/state/machine_state.py index a30c8afa..4fc734ea 100644 --- a/mythril/laser/ethereum/state/machine_state.py +++ b/mythril/laser/ethereum/state/machine_state.py @@ -11,7 +11,6 @@ from mythril.laser.ethereum.evm_exceptions import ( StackUnderflowException, OutOfGasException, ) -from mythril.laser.ethereum.state.constraints import Constraints from mythril.laser.ethereum.state.memory import Memory @@ -115,7 +114,6 @@ class MachineState: self.gas_limit = gas_limit self.min_gas_used = min_gas_used # lower gas usage bound self.max_gas_used = max_gas_used # upper gas usage bound - self.constraints = constraints or Constraints() self.depth = depth self.prev_pc = prev_pc # holds context of current pc @@ -216,7 +214,6 @@ class MachineState: pc=self._pc, stack=copy(self.stack), memory=copy(self.memory), - constraints=copy(self.constraints), depth=self.depth, prev_pc=self.prev_pc, ) diff --git a/mythril/laser/ethereum/state/world_state.py b/mythril/laser/ethereum/state/world_state.py index f7aeda86..4d3f6212 100644 --- a/mythril/laser/ethereum/state/world_state.py +++ b/mythril/laser/ethereum/state/world_state.py @@ -9,6 +9,7 @@ from ethereum.utils import mk_contract_address from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.constraints import Constraints from mythril.laser.ethereum.state.annotation import StateAnnotation +from mythril.laser.ethereum.state.constraints import Constraints if TYPE_CHECKING: from mythril.laser.ethereum.cfg import Node @@ -27,14 +28,19 @@ class Balances: balance = self.balance[item] for bs, pc in self.balance_merge_list: balance = If(pc, bs[item], balance) - + return balance + def __setitem__(self, key, value): + pass class WorldState: """The WorldState class represents the world state as described in the yellow paper.""" def __init__( - self, transaction_sequence=None, annotations: List[StateAnnotation] = None + self, + transaction_sequence=None, + annotations: List[StateAnnotation] = None, + constraints: Constraints = None, ) -> None: """Constructor for the world state. Initializes the accounts record. @@ -44,17 +50,22 @@ class WorldState: self._accounts = {} # type: Dict[int, Account] self.balances = Array("balance", 256, 256) self.starting_balances = copy(self.balances) + self.constraints = constraints or Constraints() self.node = None # type: Optional['Node'] self.transaction_sequence = transaction_sequence or [] self._annotations = annotations or [] def merge_states(self, state: "WorldState"): + # combine annotations self._annotations += state._annotations - c1 = self.node.constraints.compress() - c2 = state.node.constraints.compress() - self.node.constraints = Constraints([Or(c1, c2)]) + # Merge constraints + c1 = self.constraints.compress() + c2 = state.constraints.compress() + self.constraints = Constraints([Or(c1, c2)]) + + # Merge accounts for address, account in state.accounts.items(): if address not in self._accounts: self.put_account(account) @@ -63,6 +74,8 @@ class WorldState: ## Merge balances + + @property def accounts(self): return self._accounts @@ -95,6 +108,7 @@ class WorldState: for account in self._accounts.values(): new_world_state.put_account(copy(account)) new_world_state.node = self.node + new_world_state.constraints = copy(self.constraints) return new_world_state def accounts_exist_or_load(self, addr: str, dynamic_loader: DynLoader) -> str: diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index 9e3a9fdc..f3977997 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -249,7 +249,9 @@ class LaserEVM: log.debug("Encountered unimplemented instruction") continue new_states = [ - state for state in new_states if state.mstate.constraints.is_possible + state + for state in new_states + if state.world_state.constraints.is_possible ] self.manage_cfg(op_code, new_states) # TODO: What about op_code is None? @@ -345,8 +347,8 @@ class LaserEVM: global_state.transaction_stack ) + [(start_signal.transaction, global_state)] new_global_state.node = global_state.node - new_global_state.mstate.constraints = ( - start_signal.global_state.mstate.constraints + new_global_state.world_state.constraints = ( + start_signal.global_state.world_state.constraints ) log.debug("Starting new transaction %s", start_signal.transaction) @@ -367,9 +369,6 @@ class LaserEVM: ) and not end_signal.revert: check_potential_issues(global_state) end_signal.global_state.world_state.node = global_state.node - end_signal.global_state.world_state.node.constraints += ( - end_signal.global_state.mstate.constraints - ) self._add_world_state(end_signal.global_state) new_global_states = [] @@ -417,7 +416,9 @@ class LaserEVM: :return: """ - return_global_state.mstate.constraints += global_state.mstate.constraints + return_global_state.world_state.constraints += ( + global_state.world_state.constraints + ) # Resume execution of the transaction initializing instruction op_code = return_global_state.environment.code.instruction_list[ return_global_state.mstate.pc @@ -465,12 +466,12 @@ class LaserEVM: assert len(new_states) <= 2 for state in new_states: self._new_node_state( - state, JumpType.CONDITIONAL, state.mstate.constraints[-1] + state, JumpType.CONDITIONAL, state.world_state.constraints[-1] ) elif opcode in ("SLOAD", "SSTORE") and len(new_states) > 1: for state in new_states: self._new_node_state( - state, JumpType.CONDITIONAL, state.mstate.constraints[-1] + state, JumpType.CONDITIONAL, state.world_state.constraints[-1] ) elif opcode == "RETURN": for state in new_states: @@ -491,7 +492,7 @@ class LaserEVM: new_node = Node(state.environment.active_account.contract_name) old_node = state.node state.node = new_node - new_node.constraints = state.mstate.constraints + new_node.constraints = state.world_state.constraints if self.requires_statespace: self.nodes[new_node.uid] = new_node self.edges.append( diff --git a/mythril/laser/ethereum/transaction/concolic.py b/mythril/laser/ethereum/transaction/concolic.py index f995a584..07d97816 100644 --- a/mythril/laser/ethereum/transaction/concolic.py +++ b/mythril/laser/ethereum/transaction/concolic.py @@ -88,8 +88,7 @@ def _setup_global_state_for_execution(laser_evm, transaction) -> None: condition=None, ) ) - global_state.mstate.constraints += transaction.world_state.node.constraints - new_node.constraints = global_state.mstate.constraints + new_node.constraints = global_state.world_state.constraints global_state.world_state.transaction_sequence.append(transaction) global_state.node = new_node diff --git a/mythril/laser/ethereum/transaction/symbolic.py b/mythril/laser/ethereum/transaction/symbolic.py index 09ee0173..1a5178e3 100644 --- a/mythril/laser/ethereum/transaction/symbolic.py +++ b/mythril/laser/ethereum/transaction/symbolic.py @@ -162,7 +162,7 @@ def _setup_global_state_for_execution(laser_evm, transaction: BaseTransaction) - global_state = transaction.initial_global_state() global_state.transaction_stack.append((transaction, None)) - global_state.mstate.constraints.append( + global_state.world_state.constraints.append( Or(*[transaction.caller == actor for actor in ACTORS.addresses.values()]) ) @@ -183,9 +183,7 @@ def _setup_global_state_for_execution(laser_evm, transaction: BaseTransaction) - condition=None, ) ) - - global_state.mstate.constraints += transaction.world_state.node.constraints - new_node.constraints = global_state.mstate.constraints + new_node.constraints = global_state.world_state.constraints global_state.world_state.transaction_sequence.append(transaction) global_state.node = new_node diff --git a/mythril/laser/ethereum/transaction/transaction_models.py b/mythril/laser/ethereum/transaction/transaction_models.py index b36cc61d..71630a07 100644 --- a/mythril/laser/ethereum/transaction/transaction_models.py +++ b/mythril/laser/ethereum/transaction/transaction_models.py @@ -126,7 +126,7 @@ class BaseTransaction: else symbol_factory.BitVecVal(environment.callvalue, 256) ) - global_state.mstate.constraints.append( + global_state.world_state.constraints.append( UGE(global_state.world_state.balances[sender], value) ) global_state.world_state.balances[receiver] += value diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index 86ded2ed..4dba9e6f 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -18,7 +18,6 @@ from mythril.laser.smt.bitvec_helper import ( LShR, ) -from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.array import K, Array, BaseArray @@ -80,44 +79,6 @@ class SymbolFactory(Generic[T, U]): """ raise NotImplementedError() - @staticmethod - def BitVecFuncVal( - value: int, - func_name: str, - size: int, - annotations: Annotations = None, - input_: "BitVec" = None, - ) -> BitVecFunc: - """Creates a new bit vector function with a symbolic value. - - :param value: The concrete value to set the bit vector to - :param func_name: The name of the bit vector function - :param size: The size of the bit vector - :param annotations: The annotations to initialize the bit vector with - :param input_: The input to the bit vector function - :return: The freshly created bit vector function - """ - raise NotImplementedError() - - @staticmethod - def BitVecFuncSym( - name: str, - func_name: str, - size: int, - annotations: Annotations = None, - input_: "BitVec" = None, - ) -> BitVecFunc: - """Creates a new bit vector function with a symbolic value. - - :param name: The name of the symbolic bit vector - :param func_name: The name of the bit vector function - :param size: The size of the bit vector - :param annotations: The annotations to initialize the bit vector with - :param input_: The input to the bit vector function - :return: The freshly created bit vector function - """ - raise NotImplementedError() - class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]): """ @@ -159,30 +120,6 @@ class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]): raw = z3.BitVec(name, size) return BitVec(raw, annotations) - @staticmethod - def BitVecFuncVal( - value: int, - func_name: str, - size: int, - annotations: Annotations = None, - input_: "BitVec" = None, - ) -> BitVecFunc: - """Creates a new bit vector function with a concrete value.""" - raw = z3.BitVecVal(value, size) - return BitVecFunc(raw, func_name, input_, annotations) - - @staticmethod - def BitVecFuncSym( - name: str, - func_name: str, - size: int, - annotations: Annotations = None, - input_: "BitVec" = None, - ) -> BitVecFunc: - """Creates a new bit vector function with a symbolic value.""" - raw = z3.BitVec(name, size) - return BitVecFunc(raw, func_name, input_, annotations) - class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]): """ diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index b308e863..22acc1c3 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -66,8 +66,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other + self if isinstance(other, int): return BitVec(self.raw + other, annotations=self.annotations) @@ -80,8 +78,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other - self if isinstance(other, int): return BitVec(self.raw - other, annotations=self.annotations) @@ -94,8 +90,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other * self union = self.annotations.union(other.annotations) return BitVec(self.raw * other.raw, annotations=union) @@ -105,8 +99,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other / self union = self.annotations.union(other.annotations) return BitVec(self.raw / other.raw, annotations=union) @@ -116,8 +108,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other & self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -129,8 +119,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other | self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -142,8 +130,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other ^ self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -155,8 +141,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other > self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -168,8 +152,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other < self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations.union(other.annotations) @@ -204,8 +186,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other == self if not isinstance(other, BitVec): return Bool( cast(z3.BoolRef, self.raw == other), annotations=self.annotations @@ -224,8 +204,6 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - if isinstance(other, BitVecFunc): - return other != self if not isinstance(other, BitVec): return Bool( cast(z3.BoolRef, self.raw != other), annotations=self.annotations @@ -244,8 +222,6 @@ class BitVec(Expression[z3.BitVecRef]): :param operator: The shift operator :return: the resulting output """ - if isinstance(other, BitVecFunc): - return operator(other, self) if not isinstance(other, BitVec): return BitVec( operator(self.raw, other), annotations=self.annotations @@ -275,7 +251,3 @@ class BitVec(Expression[z3.BitVecRef]): :return: """ return self.raw.__hash__() - - -# TODO: Fix circular import issues -from mythril.laser.smt.bitvecfunc import BitVecFunc diff --git a/mythril/laser/smt/bitvec_helper.py b/mythril/laser/smt/bitvec_helper.py index c1f60607..3c09a62f 100644 --- a/mythril/laser/smt/bitvec_helper.py +++ b/mythril/laser/smt/bitvec_helper.py @@ -3,31 +3,18 @@ import z3 from mythril.laser.smt.bool import Bool, Or from mythril.laser.smt.bitvec import BitVec -from mythril.laser.smt.bitvecfunc import BitVecFunc -from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper -from mythril.laser.smt.bitvecfunc import _comparison_helper as _func_comparison_helper Annotations = Set[Any] -def _comparison_helper( - a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool -) -> Bool: +def _comparison_helper(a: BitVec, b: BitVec, operation: Callable) -> Bool: annotations = a.annotations.union(b.annotations) - if isinstance(a, BitVecFunc): - return _func_comparison_helper(a, b, operation, default_value, inputs_equal) return Bool(operation(a.raw, b.raw), annotations) def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec: raw = operation(a.raw, b.raw) union = a.annotations.union(b.annotations) - - if isinstance(a, BitVecFunc): - return _func_arithmetic_helper(a, b, operation) - elif isinstance(b, BitVecFunc): - return _func_arithmetic_helper(b, a, operation) - return BitVec(raw, annotations=union) @@ -43,8 +30,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi :param c: :return: """ - # TODO: Handle BitVecFunc - if not isinstance(a, Bool): a = Bool(z3.BoolVal(a)) if not isinstance(b, BitVec): @@ -52,19 +37,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi if not isinstance(c, BitVec): c = BitVec(z3.BitVecVal(c, 256)) union = a.annotations.union(b.annotations).union(c.annotations) - - bvf = [] # type: List[BitVecFunc] - if isinstance(a, BitVecFunc): - bvf += [a] - if isinstance(b, BitVecFunc): - bvf += [b] - if isinstance(c, BitVecFunc): - bvf += [c] - if bvf: - raw = z3.If(a.raw, b.raw, c.raw) - nested_functions = [nf for func in bvf for nf in func.nested_functions] + bvf - return BitVecFunc(raw, func_name="Hybrid", nested_functions=nested_functions) - return BitVec(z3.If(a.raw, b.raw, c.raw), union) @@ -75,7 +47,7 @@ def UGT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False) + return _comparison_helper(a, b, z3.UGT) def UGE(a: BitVec, b: BitVec) -> Bool: @@ -95,7 +67,7 @@ def ULT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False) + return _comparison_helper(a, b, z3.ULT) def ULE(a: BitVec, b: BitVec) -> Bool: @@ -133,21 +105,8 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: nraw = z3.Concat([a.raw for a in bvs]) annotations = set() # type: Annotations - nested_functions = [] # type: List[BitVecFunc] for bv in bvs: annotations = annotations.union(bv.annotations) - if isinstance(bv, BitVecFunc): - nested_functions += bv.nested_functions - nested_functions += [bv] - - if nested_functions: - return BitVecFunc( - raw=nraw, - func_name="Hybrid", - input_=BitVec(z3.BitVec("", 256), annotations=annotations), - nested_functions=nested_functions, - ) - return BitVec(nraw, annotations) @@ -160,16 +119,6 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: :return: """ raw = z3.Extract(high, low, bv.raw) - if isinstance(bv, BitVecFunc): - input_string = "" - # Is there a better value to set func_name and input to in this case? - return BitVecFunc( - raw=raw, - func_name="Hybrid", - input_=BitVec(z3.BitVec(input_string, 256), annotations=bv.annotations), - nested_functions=bv.nested_functions + [bv], - ) - return BitVec(raw, annotations=bv.annotations) @@ -210,34 +159,9 @@ def Sum(*args: BitVec) -> BitVec: """ raw = z3.Sum([a.raw for a in args]) annotations = set() # type: Annotations - bitvecfuncs = [] for bv in args: annotations = annotations.union(bv.annotations) - if isinstance(bv, BitVecFunc): - bitvecfuncs.append(bv) - - nested_functions = [ - nf for func in bitvecfuncs for nf in func.nested_functions - ] + bitvecfuncs - - if len(bitvecfuncs) >= 2: - return BitVecFunc( - raw=raw, - func_name="Hybrid", - input_=None, - annotations=annotations, - nested_functions=nested_functions, - ) - elif len(bitvecfuncs) == 1: - return BitVecFunc( - raw=raw, - func_name=bitvecfuncs[0].func_name, - input_=bitvecfuncs[0].input_, - annotations=annotations, - nested_functions=nested_functions, - ) - return BitVec(raw, annotations) @@ -288,3 +212,5 @@ def BVSubNoUnderflow( b = BitVec(z3.BitVecVal(b, 256)) return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed)) + +def Lambda(var: BitVec, ) diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py deleted file mode 100644 index e5bdfec4..00000000 --- a/mythril/laser/smt/bitvecfunc.py +++ /dev/null @@ -1,297 +0,0 @@ -import operator -from itertools import product -from typing import Optional, Union, cast, Callable, List -import z3 - -from mythril.laser.smt.bitvec import BitVec, Annotations, _padded_operation -from mythril.laser.smt.bool import Or, Bool, And - - -def _arithmetic_helper( - a: "BitVecFunc", b: Union[BitVec, int], operation: Callable -) -> "BitVecFunc": - """ - Helper function for arithmetic operations on BitVecFuncs. - - :param a: The BitVecFunc to perform the operation on. - :param b: A BitVec or int to perform the operation on. - :param operation: The arithmetic operation to perform. - :return: The resulting BitVecFunc - """ - if isinstance(b, int): - b = BitVec(z3.BitVecVal(b, a.size())) - - raw = operation(a.raw, b.raw) - union = a.annotations.union(b.annotations) - - if isinstance(b, BitVecFunc): - return BitVecFunc( - raw=raw, - func_name="Hybrid", - input_=BitVec(z3.BitVec("", 256), annotations=union), - nested_functions=a.nested_functions + b.nested_functions + [a, b], - ) - - return BitVecFunc( - raw=raw, - func_name=a.func_name, - input_=a.input_, - annotations=union, - nested_functions=a.nested_functions + [a], - ) - - -def _comparison_helper( - a: "BitVecFunc", - b: Union[BitVec, int], - operation: Callable, - default_value: bool, - inputs_equal: bool, -) -> Bool: - """ - Helper function for comparison operations with BitVecFuncs. - - :param a: The BitVecFunc to compare. - :param b: A BitVec or int to compare to. - :param operation: The comparison operation to perform. - :return: The resulting Bool - """ - # Is there some hack for gt/lt comparisons? - if isinstance(b, int): - b = BitVec(z3.BitVecVal(b, a.size())) - union = a.annotations.union(b.annotations) - - if not a.symbolic and not b.symbolic: - if operation == z3.UGT: - operation = operator.gt - if operation == z3.ULT: - operation = operator.lt - return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union) - if ( - not isinstance(b, BitVecFunc) - or not a.func_name - or not a.input_ - or not a.func_name == b.func_name - or str(operation) not in ("", "") - ): - return Bool(z3.BoolVal(default_value), annotations=union) - - condition = True - for a_nest, b_nest in product(a.nested_functions, b.nested_functions): - if a_nest.func_name != b_nest.func_name: - continue - if a_nest.func_name == "Hybrid": - continue - # a.input (eq/neq) b.input ==> a == b - if inputs_equal: - condition = z3.And( - condition, - z3.Or( - z3.Not((a_nest.input_ == b_nest.input_).raw), - (a_nest.raw == b_nest.raw), - ), - z3.Or( - z3.Not((a_nest.raw == b_nest.raw)), - (a_nest.input_ == b_nest.input_).raw, - ), - ) - else: - condition = z3.And( - condition, - z3.Or( - z3.Not((a_nest.input_ != b_nest.input_).raw), - (a_nest.raw == b_nest.raw), - ), - z3.Or( - z3.Not((a_nest.raw == b_nest.raw)), - (a_nest.input_ != b_nest.input_).raw, - ), - ) - - return And( - Bool( - cast(z3.BoolRef, _padded_operation(a.raw, b.raw, operation)), - annotations=union, - ), - Bool(condition) if b.nested_functions else Bool(True), - a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, - ) - - -class BitVecFunc(BitVec): - """A bit vector function symbol. Used in place of functions like sha3.""" - - def __init__( - self, - raw: z3.BitVecRef, - func_name: Optional[str], - input_: "BitVec" = None, - annotations: Optional[Annotations] = None, - nested_functions: Optional[List["BitVecFunc"]] = None, - ): - """ - - :param raw: The raw bit vector symbol - :param func_name: The function name. e.g. sha3 - :param input: The input to the functions - :param annotations: The annotations the BitVecFunc should start with - """ - - self.func_name = func_name - self.input_ = input_ - self.nested_functions = nested_functions or [] - self.nested_functions = list(dict.fromkeys(self.nested_functions)) - if isinstance(input_, BitVecFunc): - self.nested_functions.extend(input_.nested_functions) - super().__init__(raw, annotations) - - def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create an addition expression. - - :param other: The int or BitVec to add to this BitVecFunc - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.add) - - def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create a subtraction expression. - - :param other: The int or BitVec to subtract from this BitVecFunc - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.sub) - - def __mul__(self, other: "BitVec") -> "BitVecFunc": - """Create a multiplication expression. - - :param other: The int or BitVec to multiply to this BitVecFunc - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.mul) - - def __truediv__(self, other: "BitVec") -> "BitVecFunc": - """Create a signed division expression. - - :param other: The int or BitVec to divide this BitVecFunc by - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.truediv) - - def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create an and expression. - - :param other: The int or BitVec to and with this BitVecFunc - :return: The resulting BitVecFunc - """ - return _arithmetic_helper(self, other, operator.and_) - - def __or__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create an or expression. - - :param other: The int or BitVec to or with this BitVecFunc - :return: The resulting BitVecFunc - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _arithmetic_helper(self, other, operator.or_) - - def __xor__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create a xor expression. - - :param other: The int or BitVec to xor with this BitVecFunc - :return: The resulting BitVecFunc - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _arithmetic_helper(self, other, operator.xor) - - def __lt__(self, other: Union[int, "BitVec"]) -> Bool: - """Create a signed less than expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _comparison_helper( - self, other, operator.lt, default_value=False, inputs_equal=False - ) - - def __gt__(self, other: Union[int, "BitVec"]) -> Bool: - """Create a signed greater than expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _comparison_helper( - self, other, operator.gt, default_value=False, inputs_equal=False - ) - - def __le__(self, other: Union[int, "BitVec"]) -> Bool: - """Create a signed less than or equal to expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return Or(self < other, self == other) - - def __ge__(self, other: Union[int, "BitVec"]) -> Bool: - """Create a signed greater than or equal to expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return Or(self > other, self == other) - - # MYPY: fix complains about overriding __eq__ - def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore - """Create an equality expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - return _comparison_helper( - self, other, operator.eq, default_value=False, inputs_equal=True - ) - - # MYPY: fix complains about overriding __ne__ - def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore - """Create an inequality expression. - - :param other: The int or BitVec to compare to this BitVecFunc - :return: The resulting Bool - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - return _comparison_helper( - self, other, operator.ne, default_value=True, inputs_equal=False - ) - - def __lshift__(self, other: Union[int, "BitVec"]) -> "BitVec": - """ - Left shift operation - :param other: The int or BitVec to shift on - :return The resulting left shifted output - """ - return _arithmetic_helper(self, other, operator.lshift) - - def __rshift__(self, other: Union[int, "BitVec"]) -> "BitVec": - """ - Right shift operation - :param other: The int or BitVec to shift on - :return The resulting right shifted output: - """ - return _arithmetic_helper(self, other, operator.rshift) - - def __hash__(self) -> int: - return self.raw.__hash__() diff --git a/tests/instructions/static_call_test.py b/tests/instructions/static_call_test.py index b9bcc26c..70cbccdc 100644 --- a/tests/instructions/static_call_test.py +++ b/tests/instructions/static_call_test.py @@ -121,4 +121,4 @@ def test_staticness_call_symbolic(f1): instruction.evaluate(state) assert ts.value.transaction.static - assert ts.value.global_state.mstate.constraints[-1] == (call_value == 0) + assert ts.value.global_state.world_state.constraints[-1] == (call_value == 0) diff --git a/tests/laser/smt/bitvecfunc_test.py b/tests/laser/smt/bitvecfunc_test.py deleted file mode 100644 index 37217c73..00000000 --- a/tests/laser/smt/bitvecfunc_test.py +++ /dev/null @@ -1,237 +0,0 @@ -from mythril.laser.smt import Solver, symbol_factory, UGT, UGE, ULT, ULE -import z3 -import pytest - -import operator - - -@pytest.mark.parametrize( - "operation,expected", - [ - (operator.add, z3.unsat), - (operator.sub, z3.unsat), - (operator.and_, z3.sat), - (operator.or_, z3.sat), - (operator.xor, z3.unsat), - ], -) -def test_bitvecfunc_arithmetic(operation, expected): - # Arrange - s = Solver() - - input_ = symbol_factory.BitVecVal(1, 8) - bvf = symbol_factory.BitVecFuncSym("bvf", "sha3", 256, input_=input_) - - x = symbol_factory.BitVecSym("x", 256) - y = symbol_factory.BitVecSym("y", 256) - - # Act - s.add(x != y) - s.add(operation(bvf, x) == operation(y, bvf)) - - # Assert - assert s.check() == expected - - -@pytest.mark.parametrize( - "operation,expected", - [ - (operator.eq, z3.sat), - (operator.ne, z3.unsat), - (operator.lt, z3.unsat), - (operator.le, z3.sat), - (operator.gt, z3.unsat), - (operator.ge, z3.sat), - (UGT, z3.unsat), - (UGE, z3.sat), - (ULT, z3.unsat), - (ULE, z3.sat), - ], -) -def test_bitvecfunc_bitvecfunc_comparison(operation, expected): - # Arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=input2) - - # Act - s.add(operation(bvf1, bvf2)) - s.add(input1 == input2) - - # Assert - assert s.check() == expected - - -def test_bitvecfunc_bitvecfuncval_comparison(): - # Arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecVal(1337, 256) - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncVal(12345678910, "sha3", 256, input_=input2) - - # Act - s.add(bvf1 == bvf2) - - # Assert - assert s.check() == z3.sat - assert s.model().eval(input2.raw) == 1337 - - -def test_bitvecfunc_nested_comparison(): - # arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1) - - bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3) - - # Act - s.add(input1 == input2) - s.add(bvf2 == bvf4) - - # Assert - assert s.check() == z3.sat - - -def test_bitvecfunc_unequal_nested_comparison(): - # arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1) - - bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3) - - # Act - s.add(input1 != input2) - s.add(bvf2 == bvf4) - - # Assert - assert s.check() == z3.unsat - - -def test_bitvecfunc_ext_nested_comparison(): - # arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - input3 = symbol_factory.BitVecSym("input3", 256) - input4 = symbol_factory.BitVecSym("input4", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3) - - bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4) - - # Act - s.add(input1 == input2) - s.add(input3 == input4) - s.add(bvf2 == bvf4) - - # Assert - assert s.check() == z3.sat - - -def test_bitvecfunc_ext_unequal_nested_comparison(): - # Arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - input3 = symbol_factory.BitVecSym("input3", 256) - input4 = symbol_factory.BitVecSym("input4", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3) - - bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4) - - # Act - s.add(input1 == input2) - s.add(input3 != input4) - s.add(bvf2 == bvf4) - - # Assert - assert s.check() == z3.unsat - - -def test_bitvecfunc_ext_unequal_nested_comparison_f(): - # Arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - input3 = symbol_factory.BitVecSym("input3", 256) - input4 = symbol_factory.BitVecSym("input4", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3) - - bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4) - - # Act - s.add(input1 != input2) - s.add(input3 == input4) - s.add(bvf2 == bvf4) - - # Assert - assert s.check() == z3.unsat - - -def test_bitvecfunc_find_input(): - # Arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - - # Act - s.add(input1 == symbol_factory.BitVecVal(1, 256)) - s.add(bvf1 == bvf2) - - # Assert - assert s.check() == z3.sat - assert s.model()[input2.raw] == 1 - - -def test_bitvecfunc_nested_find_input(): - # Arrange - s = Solver() - - input1 = symbol_factory.BitVecSym("input1", 256) - input2 = symbol_factory.BitVecSym("input2", 256) - - bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) - bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1) - - bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) - bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3) - - # Act - s.add(input1 == symbol_factory.BitVecVal(123, 256)) - s.add(bvf2 == bvf4) - - # Assert - assert s.check() == z3.sat - assert s.model()[input2.raw] == 123