diff --git a/mythril/analysis/module/modules/arbitrary_jump.py b/mythril/analysis/module/modules/arbitrary_jump.py index bb4da658..35154f0f 100644 --- a/mythril/analysis/module/modules/arbitrary_jump.py +++ b/mythril/analysis/module/modules/arbitrary_jump.py @@ -1,11 +1,16 @@ """This module contains the detection code for Arbitrary jumps.""" import logging + +from mythril.exceptions import UnsatError + from mythril.analysis.solver import get_transaction_sequence, UnsatError from mythril.analysis.issue_annotation import IssueAnnotation from mythril.analysis.module.base import DetectionModule, Issue, EntryPoint from mythril.analysis.swc_data import ARBITRARY_JUMP + from mythril.laser.ethereum.state.global_state import GlobalState -from mythril.laser.smt import And +from mythril.laser.smt import And, BitVec, symbol_factory +from mythril.support.model import get_model log = logging.getLogger(__name__) @@ -15,6 +20,27 @@ Search for jumps to arbitrary locations in the bytecode """ +def is_unique_jumpdest(jump_dest: BitVec, state: GlobalState) -> bool: + """ + Handles cases where jump_dest evaluates to a single concrete value + """ + + try: + model = get_model(state.world_state.constraints) + except UnsatError: + return True + concrete_jump_dest = model.eval(jump_dest.raw, model_completion=True) + + try: + model = get_model( + state.world_state.constraints + + [symbol_factory.BitVecVal(concrete_jump_dest.as_long(), 256) != jump_dest] + ) + except UnsatError: + return True + return False + + class ArbitraryJump(DetectionModule): """This module searches for JUMPs to a user-specified location.""" @@ -49,7 +75,10 @@ class ArbitraryJump(DetectionModule): jump_dest = state.mstate.stack[-1] if jump_dest.symbolic is False: return [] - # Most probably the jump destination can have multiple locations in these circumstances + + if is_unique_jumpdest(jump_dest, state) is True: + return [] + try: transaction_sequence = get_transaction_sequence( state, state.world_state.constraints diff --git a/tests/analysis_tests/arbitrary_jump_test.py b/tests/analysis_tests/arbitrary_jump_test.py new file mode 100644 index 00000000..f2d0f233 --- /dev/null +++ b/tests/analysis_tests/arbitrary_jump_test.py @@ -0,0 +1,100 @@ +import pytest +from mock import patch + +from mythril.disassembler.disassembly import Disassembly +from mythril.laser.smt import symbol_factory, BitVec +from mythril.laser.ethereum.state.environment import Environment +from mythril.laser.ethereum.state.account import Account +from mythril.laser.ethereum.state.machine_state import MachineState +from mythril.laser.ethereum.state.constraints import Constraints +from mythril.laser.ethereum.state.global_state import GlobalState +from mythril.laser.ethereum.state.world_state import WorldState +from mythril.laser.ethereum.instructions import Instruction +from mythril.laser.ethereum.transaction.symbolic import ACTORS +from mythril.laser.ethereum.transaction.transaction_models import MessageCallTransaction +from mythril.laser.ethereum.call import SymbolicCalldata +from mythril.laser.ethereum.transaction import TransactionStartSignal +from mythril.analysis.module.modules.arbitrary_jump import ( + is_unique_jumpdest, + ArbitraryJump, +) +from mythril.laser.ethereum.time_handler import time_handler + + +def get_global_state(constraints): + """Constructs an arbitrary global state + + Args: + constraints (List[BitVec]): Constraints list for the global state + + Returns: + [GlobalState]: An arbitrary global state + """ + active_account = Account("0x0", code=Disassembly("60606040")) + environment = Environment( + active_account, None, SymbolicCalldata("2"), None, None, None, None + ) + world_state = WorldState() + world_state.put_account(active_account) + state = GlobalState(world_state, environment, None, MachineState(gas_limit=8000000)) + print(world_state.balances) + state.world_state.transaction_sequence = [ + MessageCallTransaction( + world_state=world_state, + gas_limit=8000000, + init_call_data=True, + call_value=symbol_factory.BitVecSym("call_value", 256), + caller=ACTORS.attacker, + callee_account=active_account, + ) + ] + state.transaction_stack.append( + ( + MessageCallTransaction( + world_state=world_state, gas_limit=8000000, init_call_data=True + ), + None, + ) + ) + print(state.world_state.transaction_sequence[0].call_data.calldatasize) + state.mstate.stack = [symbol_factory.BitVecSym("jump_dest", 256)] + + state.world_state.constraints = Constraints(constraints) + return state + + +test_data = ( + ( + get_global_state([symbol_factory.BitVecSym("jump_dest", 256) == 222]), + True, + ), + ( + get_global_state([symbol_factory.BitVecSym("jump_dest", 256) > 222]), + False, + ), +) + + +@pytest.mark.parametrize("global_state, unique", test_data) +def test_unique_jumpdest(global_state, unique): + time_handler.start_execution(10) + assert is_unique_jumpdest(global_state.mstate.stack[-1], global_state) == unique + + +test_data = ( + ( + get_global_state([symbol_factory.BitVecSym("jump_dest", 256) == 222]), + False, + ), + ( + get_global_state([symbol_factory.BitVecSym("jump_dest", 256) > 222]), + True, + ), +) + + +@pytest.mark.parametrize("global_state, has_issue", test_data) +def test_module(global_state, has_issue): + time_handler.start_execution(10) + module = ArbitraryJump() + assert (len(module._analyze_state(global_state)) > 0) == has_issue