Introduce call depth limit

call_depth
Bernhard Mueller 5 years ago
parent 698b1d423f
commit 861b2ecfb3
  1. 107
      mythril/analysis/module/modules/user_assertions.py
  2. 45
      mythril/laser/ethereum/instructions.py
  3. 4
      mythril/laser/ethereum/svm.py
  4. 8
      mythril/laser/ethereum/transaction/transaction_models.py

@ -1,13 +1,12 @@
"""This module contains the detection code for potentially insecure low-level """This module contains the detection code for potentially insecure low-level
calls.""" calls."""
from mythril.analysis.potential_issues import ( from mythril.analysis import solver
PotentialIssue, from mythril.analysis.potential_issues import Issue
get_potential_issues_annotation,
)
from mythril.analysis.swc_data import ASSERT_VIOLATION from mythril.analysis.swc_data import ASSERT_VIOLATION
from mythril.analysis.module.base import DetectionModule, EntryPoint from mythril.analysis.module.base import DetectionModule, EntryPoint
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.exceptions import UnsatError
import logging import logging
import eth_abi import eth_abi
@ -40,10 +39,12 @@ class UserAssertions(DetectionModule):
:param state: :param state:
:return: :return:
""" """
potential_issues = self._analyze_state(state)
annotation = get_potential_issues_annotation(state) issues = self._analyze_state(state)
annotation.potential_issues.extend(potential_issues) for issue in issues:
self.cache.add(issue.address)
self.issues.extend(issues)
def _analyze_state(self, state: GlobalState): def _analyze_state(self, state: GlobalState):
""" """
@ -51,46 +52,60 @@ class UserAssertions(DetectionModule):
:param state: :param state:
:return: :return:
""" """
topic, size, mem_start = state.mstate.stack[-3:]
try:
if topic.symbolic or topic.value != assertion_failed_hash:
return [] transaction_sequence = solver.get_transaction_sequence(
state, state.world_state.constraints
message = None
if not mem_start.symbolic and not size.symbolic:
message = eth_abi.decode_single(
"string",
bytes(
state.mstate.memory[
mem_start.value + 32 : mem_start.value + size.value
]
),
).decode("utf8")
description_head = "A user-provided assertion failed."
if message:
description_tail = "A user-provided assertion failed with the message '{}'".format(
message
) )
else:
description_tail = "A user-provided assertion failed." topic, size, mem_start = state.mstate.stack[-3:]
address = state.get_current_instruction()["address"] if topic.symbolic or topic.value != assertion_failed_hash:
issue = PotentialIssue( return []
contract=state.environment.active_account.contract_name,
function_name=state.environment.active_function_name, message = None
address=address, if not mem_start.symbolic and not size.symbolic:
swc_id=ASSERT_VIOLATION, message = eth_abi.decode_single(
title="Assertion Failed", "string",
bytecode=state.environment.code.bytecode, bytes(
severity="Medium", state.mstate.memory[
description_head=description_head, mem_start.value + 32 : mem_start.value + size.value
description_tail=description_tail, ]
constraints=[], ),
detector=self, ).decode("utf8")
)
if message:
return [issue] description_tail = "A user-provided assertion failed with the message '{}'".format(
message
)
log.info("MythX assertion emitted: {}".format(message))
else:
description_tail = "A user-provided assertion failed."
address = state.get_current_instruction()["address"]
issue = Issue(
contract=state.environment.active_account.contract_name,
function_name=state.environment.active_function_name,
address=address,
swc_id=ASSERT_VIOLATION,
title="Exception State",
severity="Medium",
description_head="A user-provided assertion failed.",
description_tail=description_tail,
bytecode=state.environment.code.bytecode,
transaction_sequence=transaction_sequence,
gas_used=(state.mstate.min_gas_used, state.mstate.max_gas_used),
)
return [issue]
except UnsatError:
log.debug("no model found")
return []
detector = UserAssertions() detector = UserAssertions()

@ -62,6 +62,7 @@ log = logging.getLogger(__name__)
TT256 = 2 ** 256 TT256 = 2 ** 256
TT256M1 = 2 ** 256 - 1 TT256M1 = 2 ** 256 - 1
MAX_CALL_DEPTH = 5
def transfer_ether( def transfer_ether(
global_state: GlobalState, global_state: GlobalState,
@ -1909,6 +1910,21 @@ class Instruction:
environment = global_state.environment environment = global_state.environment
memory_out_size, memory_out_offset = global_state.mstate.stack[-7:-5] memory_out_size, memory_out_offset = global_state.mstate.stack[-7:-5]
if(len(global_state.transaction_stack) >= MAX_CALL_DEPTH):
log.info(
"Max call depth reached"
)
self._write_symbolic_returndata(
global_state, memory_out_offset, memory_out_size
)
global_state.mstate.stack.append(
global_state.new_bitvec("retval_" + str(instr["address"]), 256)
)
return [global_state]
try: try:
( (
callee_address, callee_address,
@ -2137,6 +2153,20 @@ class Instruction:
environment = global_state.environment environment = global_state.environment
memory_out_size, memory_out_offset = global_state.mstate.stack[-6:-4] memory_out_size, memory_out_offset = global_state.mstate.stack[-6:-4]
if(len(global_state.transaction_stack) >= MAX_CALL_DEPTH):
log.info(
"Max call depth reached"
)
self._write_symbolic_returndata(
global_state, memory_out_offset, memory_out_size
)
global_state.mstate.stack.append(
global_state.new_bitvec("retval_" + str(instr["address"]), 256)
)
return [global_state]
try: try:
( (
callee_address, callee_address,
@ -2272,6 +2302,21 @@ class Instruction:
instr = global_state.get_current_instruction() instr = global_state.get_current_instruction()
environment = global_state.environment environment = global_state.environment
memory_out_size, memory_out_offset = global_state.mstate.stack[-6:-4] memory_out_size, memory_out_offset = global_state.mstate.stack[-6:-4]
if(len(global_state.transaction_stack) >= MAX_CALL_DEPTH):
log.info(
"Max call depth reached"
)
self._write_symbolic_returndata(
global_state, memory_out_offset, memory_out_size
)
global_state.mstate.stack.append(
global_state.new_bitvec("retval_" + str(instr["address"]), 256)
)
return [global_state]
try: try:
( (
callee_address, callee_address,

@ -144,7 +144,7 @@ class LaserEVM:
if pre_configuration_mode: if pre_configuration_mode:
self.open_states = [world_state] self.open_states = [world_state]
log.info("Starting message call transaction to {}".format(target_address)) log.info("Starting message call transaction to {:#42x}".format(target_address))
self._execute_transactions(symbol_factory.BitVecVal(target_address, 256)) self._execute_transactions(symbol_factory.BitVecVal(target_address, 256))
elif scratch_mode: elif scratch_mode:
@ -356,7 +356,7 @@ class LaserEVM:
start_signal.transaction.call_value, start_signal.transaction.call_value,
) )
log.debug("Starting new transaction %s", start_signal.transaction) log.debug("Starting new transaction %s, call stack size %d", start_signal.transaction, len(new_global_state.transaction_stack))
return [new_global_state], op_code return [new_global_state], op_code

@ -138,10 +138,16 @@ class BaseTransaction:
raise NotImplementedError raise NotImplementedError
def __str__(self) -> str: def __str__(self) -> str:
try:
_caller = "{:#42x}".format(int(str(self.caller)))
except:
_caller = str(self.caller)
return "{} {} from {} to {:#42x}".format( return "{} {} from {} to {:#42x}".format(
self.__class__.__name__, self.__class__.__name__,
self.id, self.id,
self.caller, _caller,
int(str(self.callee_account.address)) if self.callee_account else -1, int(str(self.callee_account.address)) if self.callee_account else -1,
) )

Loading…
Cancel
Save