diff --git a/mythril/analysis/solver.py b/mythril/analysis/solver.py index 6a14fb68..2f2ef85f 100644 --- a/mythril/analysis/solver.py +++ b/mythril/analysis/solver.py @@ -133,7 +133,7 @@ def _get_concrete_state(initial_accounts: Dict, min_price_dict: Dict[str, int]): data = dict() # type: Dict[str, Union[int, str]] data["nonce"] = account.nonce data["code"] = account.code.bytecode - data["storage"] = account.storage.printable_storage + data["storage"] = str(account.storage) data["balance"] = hex(min_price_dict.get(address, 0)) accounts[hex(address)] = data return {"accounts": accounts} diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 16cb4d38..bddd1727 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -25,6 +25,7 @@ from mythril.laser.smt import ( Bool, Not, LShR, + BitVecFunc ) from mythril.laser.smt import symbol_factory @@ -939,6 +940,8 @@ class Instruction: input_=data, annotations=annotations, ) + if hash(argument_hash) == 1443016052: + print(data) log.debug("Created BitVecFunc hash.") else: diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index 4c47516a..5ab770e8 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -5,7 +5,8 @@ This includes classes representing accounts and their storage. import logging from copy import copy, deepcopy from typing import Any, Dict, Union, Tuple, cast - +from sha3 import keccak_256 +from random import randint from mythril.laser.smt import ( Array, @@ -16,6 +17,8 @@ from mythril.laser.smt import ( Extract, BaseArray, Concat, + And, + If ) from mythril.disassembler.disassembly import Disassembly from mythril.laser.smt import symbol_factory @@ -32,18 +35,6 @@ class StorageRegion: class ArrayStorageRegion(StorageRegion): - """ An ArrayStorageRegion is a storage region that leverages smt array theory to resolve expressions""" - - pass - - -class IteStorageRegion(StorageRegion): - """ An IteStorageRegion is a storage region that uses Ite statements to implement a storage""" - - pass - - -class Storage: """Storage class represents the storage of an Account.""" def __init__(self, concrete=False, address=None, dynamic_loader=None) -> None: @@ -57,13 +48,13 @@ class Storage: self._standard_storage = Array("Storage", 256, 256) self._map_storage = {} # type: Dict[BitVec, BaseArray] - self.printable_storage = {} # type: Dict[BitVec, BitVec] - self.dynld = dynamic_loader self.address = address @staticmethod def _sanitize(input_: BitVec) -> BitVec: + if input_.potential_values: + input_ = input_.potential_values if input_.size() == 512: return input_ if input_.size() > 512: @@ -77,7 +68,7 @@ class Storage: item = self._sanitize(cast(BitVecFunc, item).input_) value = storage[item] if ( - (value.value == 0 or value.value is None) # 0 for Array, None for K + value.value == 0 and self.address and item.symbolic is False and self.address.value != 0 @@ -94,10 +85,9 @@ class Storage: ), 256, ) - self.printable_storage[item] = storage[item] return storage[item] - except ValueError as e: - log.debug("Couldn't read storage at %s: %s", item, e) + except ValueError: + pass return simplify(storage[item]) @@ -132,24 +122,212 @@ class Storage: def __setitem__(self, key, value: Any) -> None: storage, is_keccak_storage = self._get_corresponding_storage(key) - self.printable_storage[key] = value if is_keccak_storage: key = self._sanitize(key.input_) storage[key] = value def __deepcopy__(self, memodict=dict()): concrete = isinstance(self._standard_storage, K) - storage = Storage( + storage = ArrayStorageRegion( concrete=concrete, address=self.address, dynamic_loader=self.dynld ) storage._standard_storage = deepcopy(self._standard_storage) storage._map_storage = deepcopy(self._map_storage) - storage.print_storage = copy(self.printable_storage) + + return storage + + +class IteStorageRegion(StorageRegion): + """ An IteStorageRegion is a storage region that uses Ite statements to implement a storage""" + + def __init__(self) -> None: + """Constructor for Storage. + """ + self.itedict = {} # type: Dict[Tuple[BitVecFunc, Any]] + + def __getitem__(self, item: BitVecFunc): + storage = symbol_factory.BitVecVal(0, 256) + for key, val in self.itedict.items(): + storage = If(item == key, val, storage) + return storage + + def __setitem__(self, key: BitVecFunc, value): + self.itedict[key] = value + + def __deepcopy__(self, memodict={}): + ite_copy = IteStorageRegion() + ite_copy.itedict = copy(self.itedict) + return ite_copy + + +class Storage: + """Storage class represents the storage of an Account.""" + + def __init__( + self, concrete=False, address=None, dynamic_loader=None, copy_call=False + ) -> None: + """Constructor for Storage. + + :param concrete: bool indicating whether to interpret uninitialized storage as concrete versus symbolic + """ + if copy_call: + # This is done because it was costly create these instances + self._array_region = None + self._ite_region = None + else: + self._array_region = ArrayStorageRegion(concrete, address, dynamic_loader) + self._ite_region = IteStorageRegion() + self._printable_storage = {} # type: Dict[BitVec, BitVec] + + @staticmethod + def _array_condition(key: BitVec): + return not isinstance(key, BitVecFunc) or ( + isinstance(key, BitVecFunc) + and key.func_name == "keccak256" + and len(key.nested_functions) <= 1 + ) + + def __getitem__(self, key: BitVec) -> BitVec: + ite_get = self._ite_region[cast(BitVecFunc, key)] + array_get = self._array_region[key] + if self._array_condition(key): + return If(ite_get, ite_get, array_get) + else: + return ite_get + + def __setitem__(self, key: BitVec, value: Any) -> None: + self._printable_storage[key] = value + if self._array_condition(key): + self._array_region[key] = value + + self._ite_region[cast(BitVecFunc, key)] = value + + def __deepcopy__(self, memodict=dict()): + storage = Storage(copy_call=True) + storage._array_region = deepcopy(self._array_region) + storage._ite_region = deepcopy(self._ite_region) + storage._printable_storage = copy(self._printable_storage) return storage def __str__(self) -> str: # TODO: Do something better here - return str(self.printable_storage) + return str(self._printable_storage) + + def concretize(self, models): + for key, value in self._ite_region.itedict.items(): + key_concrete = self._traverse_concretise(key, models) + key.potential_values = key_concrete + + def calc_sha3(self, val, size): + try: + hex_val = hex(val.value)[2:] + if len(hex_val) % 2 != 0: + hex_val += "0" + val = int(keccak_256(bytes.fromhex(hex_val)).hexdigest(), 16) + except (AttributeError, TypeError): + ran = hex(randint(0, 2 ** size - 1))[2:] + if len(ran) % 2 != 0: + ran += "0" + val = int(keccak_256(bytes.fromhex(ran)).hexdigest(), 16) + return symbol_factory.BitVecVal(val, 256) + + def _find_value(self, symbol, model): + if model is None: + return + modify = symbol + size = min(symbol.size(), 256) + if symbol.size() > 256: + index = simplify(Extract(255, 0, symbol)) + else: + index = None + if index and not index.symbolic: + modify = Extract(511, 256, modify) + modify = model.eval(modify.raw) + try: + modify = modify.as_long() + except AttributeError: + modify = randint(0, 2 ** modify.size() - 1) + modify = symbol_factory.BitVecVal(modify, size) + if index and not index.symbolic: + modify = Concat(modify, index) + + assert modify.size() == symbol.size() + return modify + + def _traverse_concretise(self, key, models): + """ + Breadth first Search + :param key: + :param model: + :return: + """ + if not isinstance(key, BitVecFunc): + concrete_values = [self._find_value(key, model[0]) for model in models] + if key.size() == 512: + ex_key = Extract(511, 256, key) + else: + ex_key = key + potential_values = concrete_values + key.potential_values = [] + for i, val in enumerate(potential_values): + key.potential_values.append( + (val, And(models[i][1], BitVec(key.raw) == val)) + ) + + return key.potential_values + if key.size() == 512: + val = simplify(Extract(511, 256, key)) + concrete_vals = self._traverse_concretise(val, models) + vals2 = self._traverse_concretise(Extract(255, 0, key), models) + key.potential_values = [] + i = 0 + for val1, val2 in zip(concrete_vals, vals2): + if val2 and val1: + c_val = Concat(val1[0], val2[0]) + condition = And( + models[i][1], BitVec(key.raw) == c_val, val1[1], val2[1] + ) + key.potential_values.append((c_val, condition)) + else: + key.potential_values.append((None, None)) + + if isinstance(key.input_, BitVec) or ( + isinstance(key.input_, BitVecFunc) and key.input_.func_name == "sha3" + ): + self._traverse_concretise(key.input_, models) + + if isinstance(key, BitVecFunc): + if key.size() == 512: + p1 = Extract(511, 256, key) + if not isinstance(p1, BitVecFunc): + p1 = Extract(255, 0, key) + p1 = [ + (self.calc_sha3(val[0], p1.input_.size()), val[1]) + for val in p1.input_.potential_values + ] + key.potential_values = [] + for i, val in enumerate(p1): + if val[0]: + c_val = Concat(val[0], Extract(255, 0, key)) + condition = And(models[i][1], val[1], BitVec(key.raw) == c_val) + key.potential_values.append((c_val, condition)) + else: + key.potential_values.append((None, None)) + else: + key.potential_values = [] + for i, val in enumerate(key.input_.potential_values): + if val[0]: + concrete_val = self.calc_sha3(val[0], key.input_.size()) + condition = And( + models[i][1], val[1], BitVec(key.raw) == concrete_val + ) + key.potential_values.append((concrete_val, condition)) + else: + key.potential_values.append((None, None)) + if key.potential_values[0][0] is not None: + assert key.size() == key.potential_values[0][0].size() + + return key.potential_values class Account: diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index c9fc393c..e6d9d02d 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -3,7 +3,7 @@ import logging from collections import defaultdict from copy import copy from datetime import datetime, timedelta -from typing import Callable, Dict, DefaultDict, List, Tuple, Optional +from typing import Callable, Dict, DefaultDict, List, Tuple, Union, Optional from mythril.laser.ethereum.cfg import NodeFlags, Node, Edge, JumpType from mythril.laser.ethereum.evm_exceptions import StackUnderflowException @@ -16,6 +16,8 @@ from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.strategy.basic import DepthFirstSearchStrategy from abc import ABCMeta from mythril.laser.ethereum.time_handler import time_handler +from mythril.analysis.solver import get_model, UnsatError + from mythril.laser.ethereum.transaction import ( ContractCreationTransaction, TransactionEndSignal, @@ -23,8 +25,13 @@ from mythril.laser.ethereum.transaction import ( execute_contract_creation, execute_message_call, ) -from mythril.laser.smt import symbol_factory +from mythril.laser.smt import symbol_factory, And, BitVecFunc, BitVec, Extract, simplify, is_true +ACTOR_ADDRESSES = [ + symbol_factory.BitVecVal(0xAFFEAFFEAFFEAFFEAFFEAFFEAFFEAFFEAFFEAFFE, 256), + symbol_factory.BitVecVal(0xDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEF, 256), + symbol_factory.BitVecVal(0xDEADBEEFDEADBEEFDEADBEEFDEADBEEFDEADBEEE, 256), +] log = logging.getLogger(__name__) @@ -331,7 +338,7 @@ class LaserEVM: transaction, return_global_state = end_signal.global_state.transaction_stack[ -1 ] - + self.concretize_ite_storage(end_signal.global_state) if return_global_state is None: if ( not isinstance(transaction, ContractCreationTransaction) @@ -355,6 +362,76 @@ class LaserEVM: return new_global_states, op_code + def concretize_ite_storage(self, global_state): + sender = global_state.environment.sender + models_tuple = [] + sat = False + for actor in ACTOR_ADDRESSES: + try: + models_tuple.append( + ( + get_model( + constraints=global_state.mstate.constraints + + [sender == actor] + ), + sender == actor, + ) + ) + sat = True + except UnsatError: + models_tuple.append((None, sender == actor)) + import random, sha3 + + calldata_cond = True + for account in global_state.world_state.accounts.values(): + for key in account.storage._ite_region.itedict: + if ( + isinstance(key, BitVecFunc) + and not isinstance(key.input_, BitVecFunc) + and isinstance(key.input_, BitVec) + and key.input_.symbolic + and key.input_.size() == 512 + ): + pseudo_input = random.randint(0, 2 ** 256 - 1) + hex_v = hex(pseudo_input)[2:] + if len(hex_v) % 2 == 1: + hex_v += "0" + hash_val = symbol_factory.BitVecVal( + int(sha3.keccak_256(bytes.fromhex(hex_v)).hexdigest()[2:], 16), + 256, + ) + pseudo_input = symbol_factory.BitVecVal(pseudo_input, 256) + calldata_cond = And( + calldata_cond, + Extract(511, 256, key.input_) == hash_val, + Extract(511, 256, key.input_).potential_input == pseudo_input, + ) + Extract(511, 256, key.input_).potential_input_cond = calldata_cond + if not is_true(simplify(Extract(511, 256, key.input_).potential_input_cond != calldata_cond)): + print(key.input_, Extract(511, 256, key.input_).concat_args) + assert Extract(511, 256, key.input_).potential_input_cond == calldata_cond + print(Extract(511, 256, key.input_), calldata_cond, "CONDED") + for actor in ACTOR_ADDRESSES: + try: + models_tuple.append( + ( + get_model( + constraints=global_state.mstate.constraints + + [sender == actor, calldata_cond] + ), + And(calldata_cond, sender == actor), + ) + ) + sat = True + except UnsatError: + models_tuple.append((None, And(calldata_cond, sender == actor))) + + if not sat: + return [False] + + for account in global_state.world_state.accounts.values(): + account.storage.concretize(models_tuple) + def _end_message_call( self, return_global_state: GlobalState, diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index ad0d3549..ef1a1f5b 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -25,14 +25,21 @@ def _padded_operation(a: z3.BitVec, b: z3.BitVec, operator): class BitVec(Expression[z3.BitVecRef]): """A bit vector symbol.""" - def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None, concat_args=None): + def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None, concat_args=None, pot_input=True): """ :param raw: :param annotations: """ - self.potential_inputs = [] # type: List["BitVec"] self.concat_args = concat_args or [] + self.potential_values = [] + from random import randint + if pot_input: + self.potential_input = BitVec(z3.BitVec("rn{}_input_".format(randint(0, 1000000)), 256), pot_input=False) + self.potential_input_cond = True + else: + self.potential_input = None + self.potential_input_cond = False super().__init__(raw, annotations) def size(self) -> int: diff --git a/mythril/laser/smt/bitvec_helper.py b/mythril/laser/smt/bitvec_helper.py index b80b4bda..4b4b6cb6 100644 --- a/mythril/laser/smt/bitvec_helper.py +++ b/mythril/laser/smt/bitvec_helper.py @@ -145,6 +145,8 @@ def concat_helper(bvs: List[BitVec]) -> List[BitVec]: ): prev_bv = prev_bv.parent # type: ignore new_bvs.append(prev_bv) + if len(new_bvs) == 1 and check_extracted_var(new_bvs[0]) and hash(z3.simplify(new_bvs[0].raw)) == hash(z3.simplify(new_bvs[0].parent.raw)): + return new_bvs[0].parent return new_bvs @@ -206,7 +208,7 @@ def extract_helper(high: int, low: int, bv: BitVec) -> BitVec: val = None for small_bv in bv.concat_args[::-1]: if low == count: - if low + small_bv.size() <= high: + if small_bv.size() <= high - low + 1: val = small_bv else: val = Extract( @@ -238,9 +240,9 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: :param bv: :return: """ - raw = z3.Extract(high, low, bv.raw) val = extract_helper(high, low, bv) + dc = copy(bv) if val is not None: val.simplify() bv.simplify() @@ -249,6 +251,9 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: ): val = val.parent # type: ignore assert val.size() == high - low + 1 + if val.size() == 256 and isinstance(val, BitVecFunc): + pass + #print(val, val.input_) return val input_string = "" bv.simplify() diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py index 3c3c17a2..3120c952 100644 --- a/mythril/laser/smt/bitvecfunc.py +++ b/mythril/laser/smt/bitvecfunc.py @@ -67,6 +67,37 @@ def _comparison_helper( if operation == z3.ULT: operation = operator.lt return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union) + if ( + a.size() == 512 + and b.size() == 512 + and z3.is_true( + z3.simplify(z3.Extract(255, 0, a.raw) == z3.Extract(255, 0, b.raw)) + ) + ): + from mythril.laser.smt.bitvec_helper import Extract + a = Extract(511, 256, a) + b = Extract(511, 256, b) + + if not isinstance(b, BitVecFunc): + paddded_cond = True + if b.potential_input and b.potential_input.size() >= a.input_.size(): + if b.potential_input.size() > a.input_.size(): + padded_a = z3.Concat( + z3.BitVecVal(0, b.potential_input.size() - a.input_.size()), + a.input_.raw, + ) + else: + padded_a = a.input_.raw + paddded_cond = And(operation(padded_a, b.potential_input.raw), b.potential_input_cond) + if a.potential_values: + condition = False + for value, cond in a.potential_values: + if value is not None: + condition = Or(condition, And(operation(b.raw, value.value), cond)) + ret = And(condition, operation(a.raw, b.raw), paddded_cond) + return ret + + return And(operation(a.raw, b.raw), paddded_cond) if ( not isinstance(b, BitVecFunc) or not a.func_name @@ -108,11 +139,17 @@ def _comparison_helper( ), ) - return And( + comparision = And( Bool(cast(z3.BoolRef, operation(a.raw, b.raw)), annotations=union), Bool(condition) if b.nested_functions else Bool(True), a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, ) + if a.potential_values: + for i, val in enumerate(a.potential_values): + comparision = Or(comparision, And(operation(val[0].raw, b.raw), val[1])) + + comparision.simplify() + return comparision class BitVecFunc(BitVec): @@ -134,6 +171,9 @@ class BitVecFunc(BitVec): :param input: The input to the functions :param annotations: The annotations the BitVecFunc should start with """ + if str(z3.simplify(input_.raw)) == "" and str(z3.simplify(raw)) == "KECCAC[invhash(1443016052)]": + import traceback + print(traceback.extract_stack(), z3.simplify(raw), z3.simplify(input_.raw)) self.func_name = func_name self.input_ = input_ diff --git a/tests/laser/smt/concat_extract_tests/concat_extract_assignment.py b/tests/laser/smt/concat_extract_tests/concat_extract_assignment.py index 7a446c8f..92e9a215 100644 --- a/tests/laser/smt/concat_extract_tests/concat_extract_assignment.py +++ b/tests/laser/smt/concat_extract_tests/concat_extract_assignment.py @@ -1,4 +1,4 @@ -from mythril.laser.smt import Extract, Concat, symbol_factory, simplify +from mythril.laser.smt import Extract, Concat, symbol_factory, simplify, And def test_concat_extract_assignment(): @@ -8,9 +8,23 @@ def test_concat_extract_assignment(): "Keccak[input]", size=256, func_name="keccak256", input_=Concat(inp1, inp2) ) output = Concat(output1, symbol_factory.BitVecVal(0, 256)) - Extract(511, 256, output).potential_inputs = [inp2, inp2] + cond = And(output1 == inp2, inp1 == inp2) + Extract(511, 256, output).potential_input_cond = cond + + assert Extract(511, 256, output).potential_input_cond == cond + +def test_concat_extract_input_assignment(): + inp1 = symbol_factory.BitVecSym("input1", 256) + inp2 = symbol_factory.BitVecSym("input2", 256) + output1 = symbol_factory.BitVecFuncSym( + "Keccak[input]", size=256, func_name="keccak256", input_=Concat(inp1, inp2) + ) + inp3 = Concat(inp2, inp1) + cond = And(output1 == inp2, inp1 == inp2) + Extract(511, 256, inp3).potential_input_cond = cond + + assert Extract(511, 256, inp3).potential_input_cond == cond - assert Extract(511, 256, output).potential_inputs == [inp2, inp2] def test_concat_extract_assignment_nested(): @@ -27,15 +41,12 @@ def test_concat_extract_assignment_nested(): func_name="keccak256", input_=Concat(o1, symbol_factory.BitVecVal(0, 256)), ) - - Extract(511, 256, Extract(511, 256, output1.input_).input_).potential_inputs = [ - inp1, - inp1, - ] + cond = And(output1 == o1, inp1 == inp1) + Extract(511, 256, Extract(511, 256, output1.input_).input_).potential_input_cond = cond assert Extract( 511, 256, Extract(511, 256, output1.input_).input_ - ).potential_inputs == [inp1, inp1] + ).potential_input_cond == cond def test_concat_extract_same_instance(): @@ -54,10 +65,10 @@ def test_concat_extract_same_instance(): ) id1 = id( - Extract(511, 256, Extract(511, 256, output1.input_).input_).potential_inputs + Extract(511, 256, Extract(511, 256, output1.input_).input_).potential_input_cond ) id2 = id( - Extract(511, 256, Extract(511, 256, output1.input_).input_).potential_inputs + Extract(511, 256, Extract(511, 256, output1.input_).input_).potential_input_cond ) assert id1 == id2 diff --git a/tests/native_tests.sol b/tests/native_tests.sol index f786c1bd..5832b3ea 100644 --- a/tests/native_tests.sol +++ b/tests/native_tests.sol @@ -1,4 +1,3 @@ -pragma solidity 0.5.0; contract Caller {