Merge pull request #761 from JoranHonig/features/no_more_arrays

Refactor of calldata
pull/767/head
Bernhard Mueller 6 years ago committed by GitHub
commit 3e5f5cbb30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 18
      mythril/analysis/modules/integer.py
  2. 2
      mythril/analysis/solver.py
  3. 10
      mythril/laser/ethereum/call.py
  4. 8
      mythril/laser/ethereum/instructions.py
  5. 4
      mythril/laser/ethereum/natives.py
  6. 179
      mythril/laser/ethereum/state/calldata.py
  7. 4
      mythril/laser/ethereum/state/environment.py
  8. 4
      mythril/laser/ethereum/transaction/concolic.py
  9. 8
      mythril/laser/ethereum/transaction/symbolic.py
  10. 10
      mythril/laser/ethereum/transaction/transaction_models.py
  11. 80
      tests/laser/state/calldata_test.py
  12. 2
      tests/testdata/input_contracts/overflow.sol

@ -167,33 +167,19 @@ class IntegerOverflowUnderflowModule(DetectionModule):
constraints = copy.deepcopy(node.constraints) constraints = copy.deepcopy(node.constraints)
# Filter for patterns that indicate benign underflows
# Pattern 1: (96 + calldatasize_MAIN) - (96), where (96 + calldatasize_MAIN) would underflow if calldatasize is very large.
# Pattern 2: (256*If(1 & storage_0 == 0, 1, 0)) - 1, this would underlow if storage_0 = 0
if type(op0) == int and type(op1) == int: if type(op0) == int and type(op1) == int:
return [] return []
if re.search(r"calldatasize_", str(op0)):
return []
if re.search(r"256\*.*If\(1", str(op0), re.DOTALL) or re.search(
r"256\*.*If\(1", str(op1), re.DOTALL
):
return []
if re.search(r"32 \+.*calldata", str(op0), re.DOTALL) or re.search(
r"32 \+.*calldata", str(op1), re.DOTALL
):
return []
logging.debug( logging.debug(
"[INTEGER_UNDERFLOW] Checking SUB {0}, {1} at address {2}".format( "[INTEGER_UNDERFLOW] Checking SUB {0}, {1} at address {2}".format(
str(op0), str(op1), str(instruction["address"]) str(op0), str(op1), str(instruction["address"])
) )
) )
allowed_types = [int, BitVecRef, BitVecNumRef] allowed_types = [int, BitVecRef, BitVecNumRef]
if type(op0) in allowed_types and type(op1) in allowed_types: if type(op0) in allowed_types and type(op1) in allowed_types:
constraints.append(UGT(op1, op0)) constraints.append(Not(BVSubNoUnderflow(op0, op1, signed=False)))
try: try:
model = solver.get_model(constraints) model = solver.get_model(constraints)

@ -103,7 +103,7 @@ def get_transaction_sequence(global_state, constraints):
concrete_transactions[tx_id]["calldata"] = "0x" + "".join( concrete_transactions[tx_id]["calldata"] = "0x" + "".join(
[ [
hex(b)[2:] if len(hex(b)) % 2 == 0 else "0" + hex(b)[2:] hex(b)[2:] if len(hex(b)) % 2 == 0 else "0" + hex(b)[2:]
for b in transaction.call_data.concretized(model) for b in transaction.call_data.concrete(model)
] ]
) )

@ -3,7 +3,11 @@ from typing import Union
from z3 import simplify, ExprRef, Extract from z3 import simplify, ExprRef, Extract
import mythril.laser.ethereum.util as util import mythril.laser.ethereum.util as util
from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.calldata import CalldataType, Calldata from mythril.laser.ethereum.state.calldata import (
CalldataType,
SymbolicCalldata,
ConcreteCalldata,
)
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.support.loader import DynLoader from mythril.support.loader import DynLoader
import re import re
@ -174,12 +178,12 @@ def get_call_data(
starting_calldata.append(Extract(j + 7, j, elem)) starting_calldata.append(Extract(j + 7, j, elem))
i += 1 i += 1
call_data = Calldata(transaction_id, starting_calldata) call_data = ConcreteCalldata(transaction_id, starting_calldata)
call_data_type = CalldataType.CONCRETE call_data_type = CalldataType.CONCRETE
logging.debug("Calldata: " + str(call_data)) logging.debug("Calldata: " + str(call_data))
except TypeError: except TypeError:
logging.debug("Unsupported symbolic calldata offset") logging.debug("Unsupported symbolic calldata offset")
call_data_type = CalldataType.SYMBOLIC call_data_type = CalldataType.SYMBOLIC
call_data = Calldata("{}_internalcall".format(transaction_id)) call_data = SymbolicCalldata("{}_internalcall".format(transaction_id))
return call_data, call_data_type return call_data, call_data_type

@ -42,7 +42,7 @@ from mythril.laser.ethereum.evm_exceptions import (
) )
from mythril.laser.ethereum.gas import OPCODE_GAS from mythril.laser.ethereum.gas import OPCODE_GAS
from mythril.laser.ethereum.keccak import KeccakFunctionManager from mythril.laser.ethereum.keccak import KeccakFunctionManager
from mythril.laser.ethereum.state.calldata import CalldataType, Calldata from mythril.laser.ethereum.state.calldata import CalldataType
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.transaction import ( from mythril.laser.ethereum.transaction import (
MessageCallTransaction, MessageCallTransaction,
@ -458,10 +458,9 @@ class Instruction:
environment = global_state.environment environment = global_state.environment
op0 = state.stack.pop() op0 = state.stack.pop()
value, constraints = environment.calldata.get_word_at(op0) value = environment.calldata.get_word_at(op0)
state.stack.append(value) state.stack.append(value)
state.constraints.extend(constraints)
return [global_state] return [global_state]
@ -541,9 +540,8 @@ class Instruction:
i_data = dstart i_data = dstart
new_memory = [] new_memory = []
for i in range(size): for i in range(size):
value, constraints = environment.calldata[i_data] value = environment.calldata[i_data]
new_memory.append(value) new_memory.append(value)
state.constraints.extend(constraints)
i_data = ( i_data = (
i_data + 1 if isinstance(i_data, int) else simplify(i_data + 1) i_data + 1 if isinstance(i_data, int) else simplify(i_data + 1)

@ -8,7 +8,7 @@ from ethereum.utils import ecrecover_to_pub
from py_ecc.secp256k1 import N as secp256k1n from py_ecc.secp256k1 import N as secp256k1n
from rlp.utils import ALL_BYTES from rlp.utils import ALL_BYTES
from mythril.laser.ethereum.state.calldata import Calldata from mythril.laser.ethereum.state.calldata import BaseCalldata
from mythril.laser.ethereum.util import bytearray_to_int, sha3, get_concrete_int from mythril.laser.ethereum.util import bytearray_to_int, sha3, get_concrete_int
from z3 import Concat, simplify from z3 import Concat, simplify
@ -88,7 +88,7 @@ def identity(data: Union[bytes, str, List[int]]) -> bytes:
return result return result
def native_contracts(address: int, data: Calldata): def native_contracts(address: int, data: BaseCalldata):
""" """
takes integer address 1, 2, 3, 4 takes integer address 1, 2, 3, 4
""" """

@ -1,17 +1,7 @@
from enum import Enum from enum import Enum
from typing import Union, Any from typing import Union, Any
from z3 import ( from z3 import BitVecVal, BitVecRef, BitVec, simplify, Concat, If, ExprRef
BitVecVal, from z3.z3types import Z3Exception, Model
BitVecRef,
BitVecSort,
BitVec,
Implies,
simplify,
Concat,
UGT,
Array,
)
from z3.z3types import Z3Exception
from mythril.laser.ethereum.util import get_concrete_int from mythril.laser.ethereum.util import get_concrete_int
@ -21,84 +11,133 @@ class CalldataType(Enum):
SYMBOLIC = 2 SYMBOLIC = 2
class Calldata: class BaseCalldata:
""" """
Calldata class representing the calldata of a transaction Base calldata class
This represents the calldata provided when sending a transaction to a contract
""" """
def __init__(self, tx_id, starting_calldata=None): def __init__(self, tx_id):
"""
Constructor for Calldata
:param tx_id: unique value representing the transaction the calldata is for
:param starting_calldata: byte array representing the concrete calldata of a transaction
"""
self.tx_id = tx_id self.tx_id = tx_id
if starting_calldata is not None:
self._calldata = []
self.calldatasize = BitVecVal(len(starting_calldata), 256)
self.concrete = True
else:
self._calldata = Array(
"{}_calldata".format(self.tx_id), BitVecSort(256), BitVecSort(8)
)
self.calldatasize = BitVec("{}_calldatasize".format(self.tx_id), 256)
self.concrete = False
if self.concrete:
for calldata_byte in starting_calldata:
if type(calldata_byte) == int:
self._calldata.append(BitVecVal(calldata_byte, 8))
else:
self._calldata.append(calldata_byte)
def concretized(self, model):
result = []
for i in range(
get_concrete_int(model.eval(self.calldatasize, model_completion=True))
):
result.append(
get_concrete_int(model.eval(self._calldata[i], model_completion=True))
)
@property
def calldatasize(self) -> ExprRef:
"""
:return: Calldata size for this calldata object
"""
result = self.size
if isinstance(result, int):
return BitVecVal(result, 256)
return result return result
def get_word_at(self, index: int): def get_word_at(self, offset: int) -> ExprRef:
return self[index : index + 32] """ Gets word at offset"""
return self[offset : offset + 32]
def __getitem__(self, item: Union[int, slice]) -> Any: def __getitem__(self, item: Union[int, slice]) -> Any:
if isinstance(item, int) or isinstance(item, ExprRef):
return self._load(item)
if isinstance(item, slice): if isinstance(item, slice):
start, step, stop = item.start, item.step, item.stop start = 0 if item.start is None else item.start
step = 1 if item.step is None else item.step
stop = self.size if item.stop is None else item.stop
try: try:
if start is None:
start = 0
if step is None:
step = 1
if stop is None:
stop = self.calldatasize
current_index = ( current_index = (
start if isinstance(start, BitVecRef) else BitVecVal(start, 256) start if isinstance(start, BitVecRef) else BitVecVal(start, 256)
) )
dataparts = [] parts = []
while simplify(current_index != stop): while simplify(current_index != stop):
dataparts.append(self[current_index]) parts.append(self._load(current_index))
current_index = simplify(current_index + step) current_index = simplify(current_index + step)
except Z3Exception: except Z3Exception:
raise IndexError("Invalid Calldata Slice") raise IndexError("Invalid Calldata Slice")
values, constraints = zip(*dataparts) return simplify(Concat(parts))
result_constraints = []
for c in constraints: raise ValueError
result_constraints.extend(c)
return simplify(Concat(values)), result_constraints def _load(self, item: Union[int, ExprRef]) -> Any:
raise NotImplementedError()
@property
def size(self) -> Union[ExprRef, int]:
""" Returns the exact size of this calldata, this is not normalized"""
raise NotImplementedError()
def concrete(self, model: Model) -> list:
""" Returns a concrete version of the calldata using the provided model"""
raise NotImplementedError
class ConcreteCalldata(BaseCalldata):
def __init__(self, tx_id: int, calldata: list):
"""
Initializes the ConcreteCalldata object
:param tx_id: Id of the transaction that the calldata is for.
:param calldata: The concrete calldata content
"""
self._calldata = calldata
super().__init__(tx_id)
if self.concrete: def _load(self, item: Union[int, ExprRef]) -> Any:
if isinstance(item, int):
try: try:
return self._calldata[get_concrete_int(item)], () return self._calldata[item]
except IndexError: except IndexError:
return BitVecVal(0, 8), () return 0
else:
constraints = [ value = BitVecVal(0x0, 8)
Implies(self._calldata[item] != 0, UGT(self.calldatasize, item)) for i in range(self.size):
] value = If(item == i, self._calldata[i], value)
return value
def concrete(self, model: Model) -> list:
return self._calldata
@property
def size(self) -> int:
return len(self._calldata)
class SymbolicCalldata(BaseCalldata):
def __init__(self, tx_id: int):
"""
Initializes the SymbolicCalldata object
:param tx_id: Id of the transaction that the calldata is for.
"""
self._reads = []
self._size = BitVec("calldatasize", 256)
super().__init__(tx_id)
def _load(self, item: Union[int, ExprRef], clean=False) -> Any:
x = BitVecVal(item, 256) if isinstance(item, int) else item
symbolic_base_value = If(
x > self._size,
BitVecVal(0, 8),
BitVec("{}_calldata_{}".format(self.tx_id, str(item)), 8),
)
return_value = symbolic_base_value
for r_index, r_value in self._reads:
return_value = If(r_index == item, r_value, return_value)
if not clean:
self._reads.append((item, symbolic_base_value))
return simplify(return_value)
def concrete(self, model: Model) -> list:
concrete_length = get_concrete_int(model.eval(self.size, model_completion=True))
result = []
for i in range(concrete_length):
value = self._load(i, clean=True)
c_value = get_concrete_int(model.eval(value, model_completion=True))
result.append(c_value)
return result
return self._calldata[item], constraints @property
def size(self) -> ExprRef:
return self._size

@ -3,7 +3,7 @@ from typing import Dict
from z3 import ExprRef, BitVecVal from z3 import ExprRef, BitVecVal
from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.calldata import Calldata, CalldataType from mythril.laser.ethereum.state.calldata import CalldataType, BaseCalldata
class Environment: class Environment:
@ -15,7 +15,7 @@ class Environment:
self, self,
active_account: Account, active_account: Account,
sender: ExprRef, sender: ExprRef,
calldata: Calldata, calldata: BaseCalldata,
gasprice: ExprRef, gasprice: ExprRef,
callvalue: ExprRef, callvalue: ExprRef,
origin: ExprRef, origin: ExprRef,

@ -6,7 +6,7 @@ from mythril.laser.ethereum.transaction.transaction_models import (
) )
from z3 import BitVec from z3 import BitVec
from mythril.laser.ethereum.state.environment import Environment from mythril.laser.ethereum.state.environment import Environment
from mythril.laser.ethereum.state.calldata import Calldata, CalldataType from mythril.laser.ethereum.state.calldata import CalldataType, ConcreteCalldata
from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
@ -42,7 +42,7 @@ def execute_message_call(
code=Disassembly(code), code=Disassembly(code),
caller=caller_address, caller=caller_address,
callee_account=open_world_state[callee_address], callee_account=open_world_state[callee_address],
call_data=Calldata(next_transaction_id, data), call_data=ConcreteCalldata(next_transaction_id, data),
call_data_type=CalldataType.SYMBOLIC, call_data_type=CalldataType.SYMBOLIC,
call_value=value, call_value=value,
) )

@ -3,7 +3,11 @@ from logging import debug
from mythril.disassembler.disassembly import Disassembly from mythril.disassembler.disassembly import Disassembly
from mythril.laser.ethereum.cfg import Node, Edge, JumpType from mythril.laser.ethereum.cfg import Node, Edge, JumpType
from mythril.laser.ethereum.state.calldata import CalldataType, Calldata from mythril.laser.ethereum.state.calldata import (
CalldataType,
BaseCalldata,
SymbolicCalldata,
)
from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.transaction.transaction_models import ( from mythril.laser.ethereum.transaction.transaction_models import (
MessageCallTransaction, MessageCallTransaction,
@ -32,7 +36,7 @@ def execute_message_call(laser_evm, callee_address: str) -> None:
origin=BitVec("origin{}".format(next_transaction_id), 256), origin=BitVec("origin{}".format(next_transaction_id), 256),
caller=BitVec("caller{}".format(next_transaction_id), 256), caller=BitVec("caller{}".format(next_transaction_id), 256),
callee_account=open_world_state[callee_address], callee_account=open_world_state[callee_address],
call_data=Calldata(next_transaction_id), call_data=SymbolicCalldata(next_transaction_id),
call_data_type=CalldataType.SYMBOLIC, call_data_type=CalldataType.SYMBOLIC,
call_value=BitVec("call_value{}".format(next_transaction_id), 256), call_value=BitVec("call_value{}".format(next_transaction_id), 256),
) )

@ -2,7 +2,11 @@ import logging
from typing import Union from typing import Union
from mythril.disassembler.disassembly import Disassembly from mythril.disassembler.disassembly import Disassembly
from mythril.laser.ethereum.state.environment import Environment from mythril.laser.ethereum.state.environment import Environment
from mythril.laser.ethereum.state.calldata import Calldata from mythril.laser.ethereum.state.calldata import (
BaseCalldata,
ConcreteCalldata,
SymbolicCalldata,
)
from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
@ -75,9 +79,9 @@ class BaseTransaction:
self.caller = caller self.caller = caller
self.callee_account = callee_account self.callee_account = callee_account
if call_data is None and init_call_data: if call_data is None and init_call_data:
self.call_data = Calldata(self.id, call_data) self.call_data = ConcreteCalldata(self.id, call_data)
else: else:
self.call_data = call_data if isinstance(call_data, Calldata) else None self.call_data = call_data if isinstance(call_data, BaseCalldata) else None
self.call_data_type = ( self.call_data_type = (
call_data_type call_data_type
if call_data_type is not None if call_data_type is not None

@ -1,6 +1,6 @@
import pytest import pytest
from mythril.laser.ethereum.state.calldata import Calldata from mythril.laser.ethereum.state.calldata import ConcreteCalldata, SymbolicCalldata
from z3 import Solver, simplify from z3 import Solver, simplify, BitVec, sat, unsat
from z3.z3types import Z3Exception from z3.z3types import Z3Exception
from mock import MagicMock from mock import MagicMock
@ -13,21 +13,11 @@ uninitialized_test_data = [
@pytest.mark.parametrize("starting_calldata", uninitialized_test_data) @pytest.mark.parametrize("starting_calldata", uninitialized_test_data)
def test_concrete_calldata_uninitialized_index(starting_calldata): def test_concrete_calldata_uninitialized_index(starting_calldata):
# Arrange # Arrange
calldata = Calldata(0, starting_calldata) calldata = ConcreteCalldata(0, starting_calldata)
solver = Solver()
# Act # Act
value, constraint1 = calldata[100] value = calldata[100]
value2, constraint2 = calldata.get_word_at(200) value2 = calldata.get_word_at(200)
solver.add(constraint1)
solver.add(constraint2)
solver.check()
model = solver.model()
value = model.eval(value)
value2 = model.eval(value2)
# Assert # Assert
assert value == 0 assert value == 0
@ -36,73 +26,65 @@ def test_concrete_calldata_uninitialized_index(starting_calldata):
def test_concrete_calldata_calldatasize(): def test_concrete_calldata_calldatasize():
# Arrange # Arrange
calldata = Calldata(0, [1, 4, 7, 3, 7, 2, 9]) calldata = ConcreteCalldata(0, [1, 4, 7, 3, 7, 2, 9])
solver = Solver() solver = Solver()
# Act # Act
solver.check() solver.check()
model = solver.model() model = solver.model()
result = model.eval(calldata.calldatasize) result = model.eval(calldata.calldatasize)
# Assert # Assert
assert result == 7 assert result == 7
def test_symbolic_calldata_constrain_index(): def test_concrete_calldata_constrain_index():
# Arrange # Arrange
calldata = Calldata(0) calldata = ConcreteCalldata(0, [1, 4, 7, 3, 7, 2, 9])
solver = Solver() solver = Solver()
# Act # Act
value, calldata_constraints = calldata[100] value = calldata[2]
constraint = value == 50 constraint = value == 3
solver.add([constraint] + calldata_constraints)
solver.check()
model = solver.model()
value = model.eval(value) solver.add([constraint])
calldatasize = model.eval(calldata.calldatasize) result = solver.check()
# Assert # Assert
assert value == 50 assert str(result) == "unsat"
assert simplify(calldatasize >= 100)
def test_concrete_calldata_constrain_index(): def test_symbolic_calldata_constrain_index():
# Arrange # Arrange
calldata = Calldata(0, [1, 4, 7, 3, 7, 2, 9]) calldata = SymbolicCalldata(0)
solver = Solver() solver = Solver()
# Act # Act
value, calldata_constraints = calldata[2] value = calldata[51]
constraint = value == 3
constraints = [value == 1, calldata.calldatasize == 50]
solver.add(constraints)
solver.add([constraint] + calldata_constraints)
result = solver.check() result = solver.check()
# Assert # Assert
assert str(result) == "unsat" assert str(result) == "unsat"
def test_concrete_calldata_constrain_index(): def test_symbolic_calldata_equal_indices():
# Arrange calldata = SymbolicCalldata(0)
calldata = Calldata(0)
mstate = MagicMock()
mstate.constraints = []
solver = Solver()
# Act index_a = BitVec("index_a", 256)
constraints = [] index_b = BitVec("index_b", 256)
value, calldata_constraints = calldata[51]
constraints.append(value == 1)
constraints.append(calldata.calldatasize == 50)
solver.add(constraints + calldata_constraints) # Act
a = calldata[index_a]
b = calldata[index_b]
result = solver.check() s = Solver()
s.append(index_a == index_b)
s.append(a != b)
# Assert # Assert
assert str(result) == "unsat" assert unsat == s.check()

@ -11,7 +11,7 @@ contract Over {
} }
function sendeth(address _to, uint _value) public returns (bool) { function sendeth(address _to, uint _value) public returns (bool) {
require(balances[msg.sender] - _value >= 0); // require(balances[msg.sender] - _value >= 0);
balances[msg.sender] -= _value; balances[msg.sender] -= _value;
balances[_to] += _value; balances[_to] += _value;
return true; return true;

Loading…
Cancel
Save