diff --git a/mythril/analysis/modules/ether_thief.py b/mythril/analysis/modules/ether_thief.py index 293b4f69..16b3bdb9 100644 --- a/mythril/analysis/modules/ether_thief.py +++ b/mythril/analysis/modules/ether_thief.py @@ -7,14 +7,16 @@ from copy import copy from mythril.analysis import solver from mythril.analysis.modules.base import DetectionModule from mythril.analysis.report import Issue -from mythril.laser.ethereum.transaction.symbolic import ATTACKER_ADDRESS -from mythril.laser.ethereum.transaction.transaction_models import ( - ContractCreationTransaction, +from mythril.laser.ethereum.transaction.symbolic import ( + ATTACKER_ADDRESS, + CREATOR_ADDRESS, ) from mythril.analysis.swc_data import UNPROTECTED_ETHER_WITHDRAWAL from mythril.exceptions import UnsatError +from mythril.laser.ethereum.transaction import ContractCreationTransaction from mythril.laser.ethereum.state.global_state import GlobalState -from mythril.laser.smt import UGT, BVAddNoOverflow, Sum, symbol_factory +from mythril.laser.smt import UGT, Sum, symbol_factory, BVAddNoOverflow +from mythril.laser.smt.bitvec import If log = logging.getLogger(__name__) @@ -77,21 +79,44 @@ class EtherThief(DetectionModule): address = instruction["address"] if self._cache_addresses.get(address, False): return [] - call_value = state.mstate.stack[-3] + value = state.mstate.stack[-3] target = state.mstate.stack[-2] - eth_sent_total = symbol_factory.BitVecVal(0, 256) + eth_sent_by_attacker = symbol_factory.BitVecVal(0, 256) constraints = copy(state.mstate.constraints) for tx in state.world_state.transaction_sequence: - constraints += [BVAddNoOverflow(eth_sent_total, tx.call_value, False)] - eth_sent_total = Sum(eth_sent_total, tx.call_value) + """ + Constraint: The call value must be greater than the sum of Ether sent by the attacker over all + transactions. This prevents false positives caused by legitimate refund functions. + Also constrain the addition from overflowing (otherwise the solver produces solutions with + ridiculously high call values). + """ + constraints += [BVAddNoOverflow(eth_sent_by_attacker, tx.call_value, False)] + eth_sent_by_attacker = Sum( + eth_sent_by_attacker, + tx.call_value * If(tx.caller == ATTACKER_ADDRESS, 1, 0), + ) + + """ + Constraint: All transactions must originate from regular users (not the creator/owner). + This prevents false positives where the owner willingly transfers ownership to another address. + """ if not isinstance(tx, ContractCreationTransaction): - constraints.append(tx.caller == ATTACKER_ADDRESS) + constraints += [tx.caller != CREATOR_ADDRESS] + + """ + Require that the current transaction is sent by the attacker and + that the Ether is sent to the attacker's address. + """ - constraints += [UGT(call_value, eth_sent_total), target == ATTACKER_ADDRESS] + constraints += [ + UGT(value, eth_sent_by_attacker), + target == ATTACKER_ADDRESS, + state.current_transaction.caller == ATTACKER_ADDRESS, + ] try: