diff --git a/mythril/analysis/modules/integer.py b/mythril/analysis/modules/integer.py index 9bd354a7..515a0e85 100644 --- a/mythril/analysis/modules/integer.py +++ b/mythril/analysis/modules/integer.py @@ -167,33 +167,19 @@ class IntegerOverflowUnderflowModule(DetectionModule): constraints = copy.deepcopy(node.constraints) - # Filter for patterns that indicate benign underflows - - # Pattern 1: (96 + calldatasize_MAIN) - (96), where (96 + calldatasize_MAIN) would underflow if calldatasize is very large. - # Pattern 2: (256*If(1 & storage_0 == 0, 1, 0)) - 1, this would underlow if storage_0 = 0 if type(op0) == int and type(op1) == int: return [] - if re.search(r"calldatasize_", str(op0)): - return [] - if re.search(r"256\*.*If\(1", str(op0), re.DOTALL) or re.search( - r"256\*.*If\(1", str(op1), re.DOTALL - ): - return [] - if re.search(r"32 \+.*calldata", str(op0), re.DOTALL) or re.search( - r"32 \+.*calldata", str(op1), re.DOTALL - ): - return [] logging.debug( "[INTEGER_UNDERFLOW] Checking SUB {0}, {1} at address {2}".format( str(op0), str(op1), str(instruction["address"]) ) ) + allowed_types = [int, BitVecRef, BitVecNumRef] if type(op0) in allowed_types and type(op1) in allowed_types: - constraints.append(UGT(op1, op0)) - + constraints.append(Not(BVSubNoUnderflow(op0, op1, signed=False))) try: model = solver.get_model(constraints) diff --git a/mythril/analysis/solver.py b/mythril/analysis/solver.py index 2848f205..b91cc498 100644 --- a/mythril/analysis/solver.py +++ b/mythril/analysis/solver.py @@ -103,7 +103,7 @@ def get_transaction_sequence(global_state, constraints): concrete_transactions[tx_id]["calldata"] = "0x" + "".join( [ hex(b)[2:] if len(hex(b)) % 2 == 0 else "0" + hex(b)[2:] - for b in transaction.call_data.concretized(model) + for b in transaction.call_data.concrete(model) ] ) diff --git a/mythril/laser/ethereum/call.py b/mythril/laser/ethereum/call.py index cc567b89..de14d42e 100644 --- a/mythril/laser/ethereum/call.py +++ b/mythril/laser/ethereum/call.py @@ -3,7 +3,11 @@ from typing import Union from z3 import simplify, ExprRef, Extract import mythril.laser.ethereum.util as util from mythril.laser.ethereum.state.account import Account -from mythril.laser.ethereum.state.calldata import CalldataType, Calldata +from mythril.laser.ethereum.state.calldata import ( + CalldataType, + SymbolicCalldata, + ConcreteCalldata, +) from mythril.laser.ethereum.state.global_state import GlobalState from mythril.support.loader import DynLoader import re @@ -174,12 +178,12 @@ def get_call_data( starting_calldata.append(Extract(j + 7, j, elem)) i += 1 - call_data = Calldata(transaction_id, starting_calldata) + call_data = ConcreteCalldata(transaction_id, starting_calldata) call_data_type = CalldataType.CONCRETE logging.debug("Calldata: " + str(call_data)) except TypeError: logging.debug("Unsupported symbolic calldata offset") call_data_type = CalldataType.SYMBOLIC - call_data = Calldata("{}_internalcall".format(transaction_id)) + call_data = SymbolicCalldata("{}_internalcall".format(transaction_id)) return call_data, call_data_type diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 7ade087d..005244e2 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -42,7 +42,7 @@ from mythril.laser.ethereum.evm_exceptions import ( ) from mythril.laser.ethereum.gas import OPCODE_GAS from mythril.laser.ethereum.keccak import KeccakFunctionManager -from mythril.laser.ethereum.state.calldata import CalldataType, Calldata +from mythril.laser.ethereum.state.calldata import CalldataType from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.transaction import ( MessageCallTransaction, @@ -458,10 +458,9 @@ class Instruction: environment = global_state.environment op0 = state.stack.pop() - value, constraints = environment.calldata.get_word_at(op0) + value = environment.calldata.get_word_at(op0) state.stack.append(value) - state.constraints.extend(constraints) return [global_state] @@ -541,9 +540,8 @@ class Instruction: i_data = dstart new_memory = [] for i in range(size): - value, constraints = environment.calldata[i_data] + value = environment.calldata[i_data] new_memory.append(value) - state.constraints.extend(constraints) i_data = ( i_data + 1 if isinstance(i_data, int) else simplify(i_data + 1) diff --git a/mythril/laser/ethereum/natives.py b/mythril/laser/ethereum/natives.py index ed8a0c4c..f6293831 100644 --- a/mythril/laser/ethereum/natives.py +++ b/mythril/laser/ethereum/natives.py @@ -8,7 +8,7 @@ from ethereum.utils import ecrecover_to_pub from py_ecc.secp256k1 import N as secp256k1n from rlp.utils import ALL_BYTES -from mythril.laser.ethereum.state.calldata import Calldata +from mythril.laser.ethereum.state.calldata import BaseCalldata from mythril.laser.ethereum.util import bytearray_to_int, sha3, get_concrete_int from z3 import Concat, simplify @@ -88,7 +88,7 @@ def identity(data: Union[bytes, str, List[int]]) -> bytes: return result -def native_contracts(address: int, data: Calldata): +def native_contracts(address: int, data: BaseCalldata): """ takes integer address 1, 2, 3, 4 """ diff --git a/mythril/laser/ethereum/state/calldata.py b/mythril/laser/ethereum/state/calldata.py index 1fd83316..3cbbf834 100644 --- a/mythril/laser/ethereum/state/calldata.py +++ b/mythril/laser/ethereum/state/calldata.py @@ -1,17 +1,7 @@ from enum import Enum from typing import Union, Any -from z3 import ( - BitVecVal, - BitVecRef, - BitVecSort, - BitVec, - Implies, - simplify, - Concat, - UGT, - Array, -) -from z3.z3types import Z3Exception +from z3 import BitVecVal, BitVecRef, BitVec, simplify, Concat, If, ExprRef +from z3.z3types import Z3Exception, Model from mythril.laser.ethereum.util import get_concrete_int @@ -21,84 +11,133 @@ class CalldataType(Enum): SYMBOLIC = 2 -class Calldata: +class BaseCalldata: """ - Calldata class representing the calldata of a transaction + Base calldata class + This represents the calldata provided when sending a transaction to a contract """ - def __init__(self, tx_id, starting_calldata=None): - """ - Constructor for Calldata - :param tx_id: unique value representing the transaction the calldata is for - :param starting_calldata: byte array representing the concrete calldata of a transaction - """ + def __init__(self, tx_id): self.tx_id = tx_id - if starting_calldata is not None: - self._calldata = [] - self.calldatasize = BitVecVal(len(starting_calldata), 256) - self.concrete = True - else: - self._calldata = Array( - "{}_calldata".format(self.tx_id), BitVecSort(256), BitVecSort(8) - ) - self.calldatasize = BitVec("{}_calldatasize".format(self.tx_id), 256) - self.concrete = False - - if self.concrete: - for calldata_byte in starting_calldata: - if type(calldata_byte) == int: - self._calldata.append(BitVecVal(calldata_byte, 8)) - else: - self._calldata.append(calldata_byte) - - def concretized(self, model): - result = [] - for i in range( - get_concrete_int(model.eval(self.calldatasize, model_completion=True)) - ): - result.append( - get_concrete_int(model.eval(self._calldata[i], model_completion=True)) - ) + @property + def calldatasize(self) -> ExprRef: + """ + :return: Calldata size for this calldata object + """ + result = self.size + if isinstance(result, int): + return BitVecVal(result, 256) return result - def get_word_at(self, index: int): - return self[index : index + 32] + def get_word_at(self, offset: int) -> ExprRef: + """ Gets word at offset""" + return self[offset : offset + 32] def __getitem__(self, item: Union[int, slice]) -> Any: + if isinstance(item, int) or isinstance(item, ExprRef): + return self._load(item) + if isinstance(item, slice): - start, step, stop = item.start, item.step, item.stop + start = 0 if item.start is None else item.start + step = 1 if item.step is None else item.step + stop = self.size if item.stop is None else item.stop + try: - if start is None: - start = 0 - if step is None: - step = 1 - if stop is None: - stop = self.calldatasize current_index = ( start if isinstance(start, BitVecRef) else BitVecVal(start, 256) ) - dataparts = [] + parts = [] while simplify(current_index != stop): - dataparts.append(self[current_index]) + parts.append(self._load(current_index)) current_index = simplify(current_index + step) except Z3Exception: raise IndexError("Invalid Calldata Slice") - values, constraints = zip(*dataparts) - result_constraints = [] - for c in constraints: - result_constraints.extend(c) - return simplify(Concat(values)), result_constraints + return simplify(Concat(parts)) + + raise ValueError + + def _load(self, item: Union[int, ExprRef]) -> Any: + raise NotImplementedError() + + @property + def size(self) -> Union[ExprRef, int]: + """ Returns the exact size of this calldata, this is not normalized""" + raise NotImplementedError() + + def concrete(self, model: Model) -> list: + """ Returns a concrete version of the calldata using the provided model""" + raise NotImplementedError + + +class ConcreteCalldata(BaseCalldata): + def __init__(self, tx_id: int, calldata: list): + """ + Initializes the ConcreteCalldata object + :param tx_id: Id of the transaction that the calldata is for. + :param calldata: The concrete calldata content + """ + self._calldata = calldata + super().__init__(tx_id) - if self.concrete: + def _load(self, item: Union[int, ExprRef]) -> Any: + if isinstance(item, int): try: - return self._calldata[get_concrete_int(item)], () + return self._calldata[item] except IndexError: - return BitVecVal(0, 8), () - else: - constraints = [ - Implies(self._calldata[item] != 0, UGT(self.calldatasize, item)) - ] + return 0 + + value = BitVecVal(0x0, 8) + for i in range(self.size): + value = If(item == i, self._calldata[i], value) + return value + + def concrete(self, model: Model) -> list: + return self._calldata + + @property + def size(self) -> int: + return len(self._calldata) + + +class SymbolicCalldata(BaseCalldata): + def __init__(self, tx_id: int): + """ + Initializes the SymbolicCalldata object + :param tx_id: Id of the transaction that the calldata is for. + """ + self._reads = [] + self._size = BitVec("calldatasize", 256) + super().__init__(tx_id) + + def _load(self, item: Union[int, ExprRef], clean=False) -> Any: + x = BitVecVal(item, 256) if isinstance(item, int) else item + + symbolic_base_value = If( + x > self._size, + BitVecVal(0, 8), + BitVec("{}_calldata_{}".format(self.tx_id, str(item)), 8), + ) + + return_value = symbolic_base_value + for r_index, r_value in self._reads: + return_value = If(r_index == item, r_value, return_value) + + if not clean: + self._reads.append((item, symbolic_base_value)) + return simplify(return_value) + + def concrete(self, model: Model) -> list: + concrete_length = get_concrete_int(model.eval(self.size, model_completion=True)) + result = [] + for i in range(concrete_length): + value = self._load(i, clean=True) + c_value = get_concrete_int(model.eval(value, model_completion=True)) + result.append(c_value) + + return result - return self._calldata[item], constraints + @property + def size(self) -> ExprRef: + return self._size diff --git a/mythril/laser/ethereum/state/environment.py b/mythril/laser/ethereum/state/environment.py index 91c67d5a..a54e5c42 100644 --- a/mythril/laser/ethereum/state/environment.py +++ b/mythril/laser/ethereum/state/environment.py @@ -3,7 +3,7 @@ from typing import Dict from z3 import ExprRef, BitVecVal from mythril.laser.ethereum.state.account import Account -from mythril.laser.ethereum.state.calldata import Calldata, CalldataType +from mythril.laser.ethereum.state.calldata import CalldataType, BaseCalldata class Environment: @@ -15,7 +15,7 @@ class Environment: self, active_account: Account, sender: ExprRef, - calldata: Calldata, + calldata: BaseCalldata, gasprice: ExprRef, callvalue: ExprRef, origin: ExprRef, diff --git a/mythril/laser/ethereum/transaction/concolic.py b/mythril/laser/ethereum/transaction/concolic.py index 0ec88794..6c7a406d 100644 --- a/mythril/laser/ethereum/transaction/concolic.py +++ b/mythril/laser/ethereum/transaction/concolic.py @@ -6,7 +6,7 @@ from mythril.laser.ethereum.transaction.transaction_models import ( ) from z3 import BitVec from mythril.laser.ethereum.state.environment import Environment -from mythril.laser.ethereum.state.calldata import Calldata, CalldataType +from mythril.laser.ethereum.state.calldata import CalldataType, ConcreteCalldata from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.state.global_state import GlobalState @@ -42,7 +42,7 @@ def execute_message_call( code=Disassembly(code), caller=caller_address, callee_account=open_world_state[callee_address], - call_data=Calldata(next_transaction_id, data), + call_data=ConcreteCalldata(next_transaction_id, data), call_data_type=CalldataType.SYMBOLIC, call_value=value, ) diff --git a/mythril/laser/ethereum/transaction/symbolic.py b/mythril/laser/ethereum/transaction/symbolic.py index 88e16780..be1c0ec6 100644 --- a/mythril/laser/ethereum/transaction/symbolic.py +++ b/mythril/laser/ethereum/transaction/symbolic.py @@ -3,7 +3,11 @@ from logging import debug from mythril.disassembler.disassembly import Disassembly from mythril.laser.ethereum.cfg import Node, Edge, JumpType -from mythril.laser.ethereum.state.calldata import CalldataType, Calldata +from mythril.laser.ethereum.state.calldata import ( + CalldataType, + BaseCalldata, + SymbolicCalldata, +) from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.transaction.transaction_models import ( MessageCallTransaction, @@ -32,7 +36,7 @@ def execute_message_call(laser_evm, callee_address: str) -> None: origin=BitVec("origin{}".format(next_transaction_id), 256), caller=BitVec("caller{}".format(next_transaction_id), 256), callee_account=open_world_state[callee_address], - call_data=Calldata(next_transaction_id), + call_data=SymbolicCalldata(next_transaction_id), call_data_type=CalldataType.SYMBOLIC, call_value=BitVec("call_value{}".format(next_transaction_id), 256), ) diff --git a/mythril/laser/ethereum/transaction/transaction_models.py b/mythril/laser/ethereum/transaction/transaction_models.py index bef83a2a..7fe73328 100644 --- a/mythril/laser/ethereum/transaction/transaction_models.py +++ b/mythril/laser/ethereum/transaction/transaction_models.py @@ -2,7 +2,11 @@ import logging from typing import Union from mythril.disassembler.disassembly import Disassembly from mythril.laser.ethereum.state.environment import Environment -from mythril.laser.ethereum.state.calldata import Calldata +from mythril.laser.ethereum.state.calldata import ( + BaseCalldata, + ConcreteCalldata, + SymbolicCalldata, +) from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.state.global_state import GlobalState @@ -75,9 +79,9 @@ class BaseTransaction: self.caller = caller self.callee_account = callee_account if call_data is None and init_call_data: - self.call_data = Calldata(self.id, call_data) + self.call_data = ConcreteCalldata(self.id, call_data) else: - self.call_data = call_data if isinstance(call_data, Calldata) else None + self.call_data = call_data if isinstance(call_data, BaseCalldata) else None self.call_data_type = ( call_data_type if call_data_type is not None diff --git a/tests/laser/state/calldata_test.py b/tests/laser/state/calldata_test.py index 9bce4a50..3c013f76 100644 --- a/tests/laser/state/calldata_test.py +++ b/tests/laser/state/calldata_test.py @@ -1,6 +1,6 @@ import pytest -from mythril.laser.ethereum.state.calldata import Calldata -from z3 import Solver, simplify +from mythril.laser.ethereum.state.calldata import ConcreteCalldata, SymbolicCalldata +from z3 import Solver, simplify, BitVec, sat, unsat from z3.z3types import Z3Exception from mock import MagicMock @@ -13,21 +13,11 @@ uninitialized_test_data = [ @pytest.mark.parametrize("starting_calldata", uninitialized_test_data) def test_concrete_calldata_uninitialized_index(starting_calldata): # Arrange - calldata = Calldata(0, starting_calldata) - solver = Solver() + calldata = ConcreteCalldata(0, starting_calldata) # Act - value, constraint1 = calldata[100] - value2, constraint2 = calldata.get_word_at(200) - - solver.add(constraint1) - solver.add(constraint2) - - solver.check() - model = solver.model() - - value = model.eval(value) - value2 = model.eval(value2) + value = calldata[100] + value2 = calldata.get_word_at(200) # Assert assert value == 0 @@ -36,73 +26,65 @@ def test_concrete_calldata_uninitialized_index(starting_calldata): def test_concrete_calldata_calldatasize(): # Arrange - calldata = Calldata(0, [1, 4, 7, 3, 7, 2, 9]) + calldata = ConcreteCalldata(0, [1, 4, 7, 3, 7, 2, 9]) solver = Solver() # Act solver.check() model = solver.model() - result = model.eval(calldata.calldatasize) # Assert assert result == 7 -def test_symbolic_calldata_constrain_index(): +def test_concrete_calldata_constrain_index(): # Arrange - calldata = Calldata(0) + calldata = ConcreteCalldata(0, [1, 4, 7, 3, 7, 2, 9]) solver = Solver() # Act - value, calldata_constraints = calldata[100] - constraint = value == 50 - - solver.add([constraint] + calldata_constraints) - - solver.check() - model = solver.model() + value = calldata[2] + constraint = value == 3 - value = model.eval(value) - calldatasize = model.eval(calldata.calldatasize) + solver.add([constraint]) + result = solver.check() # Assert - assert value == 50 - assert simplify(calldatasize >= 100) + assert str(result) == "unsat" -def test_concrete_calldata_constrain_index(): +def test_symbolic_calldata_constrain_index(): # Arrange - calldata = Calldata(0, [1, 4, 7, 3, 7, 2, 9]) + calldata = SymbolicCalldata(0) solver = Solver() # Act - value, calldata_constraints = calldata[2] - constraint = value == 3 + value = calldata[51] + + constraints = [value == 1, calldata.calldatasize == 50] + + solver.add(constraints) - solver.add([constraint] + calldata_constraints) result = solver.check() # Assert assert str(result) == "unsat" -def test_concrete_calldata_constrain_index(): - # Arrange - calldata = Calldata(0) - mstate = MagicMock() - mstate.constraints = [] - solver = Solver() +def test_symbolic_calldata_equal_indices(): + calldata = SymbolicCalldata(0) - # Act - constraints = [] - value, calldata_constraints = calldata[51] - constraints.append(value == 1) - constraints.append(calldata.calldatasize == 50) + index_a = BitVec("index_a", 256) + index_b = BitVec("index_b", 256) - solver.add(constraints + calldata_constraints) + # Act + a = calldata[index_a] + b = calldata[index_b] - result = solver.check() + s = Solver() + s.append(index_a == index_b) + s.append(a != b) # Assert - assert str(result) == "unsat" + assert unsat == s.check() diff --git a/tests/testdata/input_contracts/overflow.sol b/tests/testdata/input_contracts/overflow.sol index d1d3d875..d434b917 100644 --- a/tests/testdata/input_contracts/overflow.sol +++ b/tests/testdata/input_contracts/overflow.sol @@ -11,7 +11,7 @@ contract Over { } function sendeth(address _to, uint _value) public returns (bool) { - require(balances[msg.sender] - _value >= 0); + // require(balances[msg.sender] - _value >= 0); balances[msg.sender] -= _value; balances[_to] += _value; return true;