From 0c1453c502acff392256f323eed5b879c453f6bb Mon Sep 17 00:00:00 2001 From: Nikhil Date: Thu, 25 Jul 2019 10:58:46 +0530 Subject: [PATCH] Fix up Keccak Tree for concretisation --- mythril/laser/ethereum/instructions.py | 3 +- mythril/laser/ethereum/state/account.py | 73 ++++++++++++------------- mythril/laser/smt/bitvec.py | 1 + mythril/laser/smt/bitvec_helper.py | 28 ++++++++-- mythril/laser/smt/bitvecfunc.py | 3 - mythril/laser/smt/model.py | 4 ++ 6 files changed, 64 insertions(+), 48 deletions(-) diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 239f7714..97b1411a 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -1370,6 +1370,8 @@ class Instruction: index = state.stack.pop() state.stack.append(global_state.environment.active_account.storage[index]) + if global_state.get_current_instruction()["address"] == 418: + print(state.stack[-1]) return [global_state] @StateTransition() @@ -1437,7 +1439,6 @@ class Instruction: states = [] op0, condition = state.stack.pop(), state.stack.pop() - try: jump_addr = util.get_concrete_int(op0) except TypeError: diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index cbea7d2f..3263050e 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -49,6 +49,8 @@ class ArrayStorageRegion(StorageRegion): @staticmethod def _sanitize(input_: BitVec) -> BitVec: + if input_.potential_value: + input_ = input_.potential_value if input_.size() == 512: return input_ if input_.size() > 512: @@ -118,6 +120,7 @@ class ArrayStorageRegion(StorageRegion): storage, is_keccak_storage = self._get_corresponding_storage(key) if is_keccak_storage: key = self._sanitize(key.input_) + storage[key] = value def __deepcopy__(self, memodict=dict()): @@ -218,27 +221,26 @@ class Storage: key.input_, BitVecFunc ): continue - new_constraints, key = self._traverse_concretise(key, model) + new_constraints, key_concrete = self._traverse_concretise(key, model) + key.potential_value = key_concrete constraints += new_constraints self._array_region[key] = value self._ite_region.itelist = [] return constraints - def calc_sha3(self, val): + def calc_sha3(self, val, size): try: - val = int(sha3_256(str(val.as_long()).encode("utf-8")).hexdigest(), 16) + val = int(sha3_256(str(val.value).encode("utf-8")).hexdigest(), 16) except AttributeError: val = int( - sha3_256( - str(randint(0, 2 ** val.input_ - 1)).encode("utf-8") - ).hexdigest(), - 16, + sha3_256(str(randint(0, 2 ** size - 1)).encode("utf-8")).hexdigest(), 16 ) return symbol_factory.BitVecVal(val, 256) def _find_value(self, symbol, model): modify = symbol + size = min(symbol.size(), 256) if symbol.size() > 256: index = simplify(Extract(255, 0, symbol)) else: @@ -250,9 +252,10 @@ class Storage: modify = modify.as_long() except AttributeError: modify = randint(0, 2 ** modify.size() - 1) - modify = symbol_factory.BitVecVal(modify, 256) + modify = symbol_factory.BitVecVal(modify, size) if index and not index.symbolic: modify = Concat(modify, index) + assert modify.size() == symbol.size() return modify def _traverse_concretise(self, key, model): @@ -262,37 +265,18 @@ class Storage: :param model: :return: """ - print(simplify(key)) constraints = [] if not isinstance(key, BitVecFunc): concrete_value = self._find_value(key, model) - constraints.append(concrete_value == key) + key.potential_value = concrete_value return constraints, concrete_value - if key.size() != 512 and str(key.input_) == "": - print("SHIT") - for arg in key.concat_args: - new_const, val = self._traverse_concretise(arg, model) - constraints += new_const - else: - cnt = 0 - val = None - i = 0 - while cnt != 256 and i < len(key.concat_args): - if val is None: - val = key.concat_args[i] - else: - val = Concat(val, key.concat_args[i]) - cnt += key.concat_args[i].size() - i += 1 - if val is not None: - val.simplify() - print(key.concat_args[i:], val, "CONCAT") - new_const, concrete_val = self._traverse_concretise(val, model) - constraints += new_const - if i < len(key.concat_args): - for arg in key.concat_args[i:]: - new_const, val = self._traverse_concretise(arg, model) - constraints += new_const + if key.size() == 512: + val = simplify(Extract(511, 256, key)) + new_const, concrete_val = self._traverse_concretise(val, model) + constraints += new_const + new_const, val2 = self._traverse_concretise(Extract(255, 0, key), model) + key.potential_value = Concat(concrete_val, val2) + constraints += new_const if isinstance(key.input_, BitVec) or ( isinstance(key.input_, BitVecFunc) and key.input_.func_name == "sha3" @@ -300,10 +284,21 @@ class Storage: new_const, _ = self._traverse_concretise(key.input_, model) constraints += new_const - if key.func_name == "sha3": - key.potential_value = self.calc_sha3(key.input_) - - return constraints, key + if isinstance(key, BitVecFunc): + if key.size() == 512: + p1 = Extract(511, 256, key) + p1 = self.calc_sha3(p1.input_.potential_value, p1.input_.size()) + key.potential_value = Concat(p1, Extract(255, 0, key)) + if str(key.raw) != "": + constraints += [key.raw == key.potential_value.value] + else: + key.potential_value = self.calc_sha3( + key.input_.potential_value, key.input_.size() + ) + if str(key.raw) != "": + constraints += [key.raw == key.potential_value.value] + assert key.size() == key.potential_value.size() + return constraints, key.potential_value class Account: diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index b308e863..fe0a9848 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -31,6 +31,7 @@ class BitVec(Expression[z3.BitVecRef]): :param raw: :param annotations: """ + self.potential_value = None super().__init__(raw, annotations) def size(self) -> int: diff --git a/mythril/laser/smt/bitvec_helper.py b/mythril/laser/smt/bitvec_helper.py index 5555d858..6c0cd59d 100644 --- a/mythril/laser/smt/bitvec_helper.py +++ b/mythril/laser/smt/bitvec_helper.py @@ -147,14 +147,14 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: if isinstance(bv, BitVecFuncExtract): if parent is None: parent = bv.parent - if parent != bv.parent: + if hash(parent.raw) != hash(bv.parent.raw): continue bfne_cnt += 1 if bfne_cnt == len(bvs): # check for continuity fail = True - if bvs[0].high == bvs[0].parent.size() - 1 and bvs[-1].low == 0: + if bvs[-1].low == 0: fail = False for index, bv in enumerate(bvs): if index == 0: @@ -162,10 +162,26 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: if bv.high + 1 != bvs[index - 1].low: fail = True break + if fail is False: - return bvs[0].parent + if bvs[0].high == bvs[0].parent.size() - 1: + return bvs[0].parent + else: + return BitVecFuncExtract( + raw=nraw, + func_name=bvs[0].func_name, + input_=bvs[0].input_, + nested_functions=nested_functions, + concat_args=concat_list, + low=bvs[-1].low, + high=bvs[0].high, + parent=bvs[0].parent, + ) if nested_functions: + for bv in bvs: + bv.simplify() + return BitVecFunc( raw=nraw, func_name="Hybrid", @@ -185,6 +201,7 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: :param bv: :return: """ + raw = z3.Extract(high, low, bv.raw) if isinstance(bv, BitVecFunc): count = 0 @@ -199,9 +216,9 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: small_bv.size() - (high - low + 1), small_bv, ) - if high < count: + elif high < count: break - if low < count: + elif low < count: if low + small_bv.size() <= high: val = Concat(small_bv, val) else: @@ -219,6 +236,7 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: val.raw == val.parent.raw ): val = val.parent + val.simplify() return val input_string = "" # Is there a better value to set func_name and input to in this case? diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py index cf64fdaa..bbce70f7 100644 --- a/mythril/laser/smt/bitvecfunc.py +++ b/mythril/laser/smt/bitvecfunc.py @@ -113,8 +113,6 @@ def _comparison_helper( Bool(condition) if b.nested_functions else Bool(True), a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, ) - if a.potential_value is not None: - return Or(comparision, b == a.potential_value) return comparision @@ -142,7 +140,6 @@ class BitVecFunc(BitVec): self.input_ = input_ self.nested_functions = nested_functions or [] self.nested_functions = list(dict.fromkeys(self.nested_functions)) - self.potential_value = None self.concat_args = concat_args or [] if isinstance(input_, BitVecFunc): self.nested_functions.extend(input_.nested_functions) diff --git a/mythril/laser/smt/model.py b/mythril/laser/smt/model.py index 524683e9..e75e99b8 100644 --- a/mythril/laser/smt/model.py +++ b/mythril/laser/smt/model.py @@ -2,6 +2,10 @@ import z3 from typing import Union, List +z3.set_option( + max_args=10000000, max_lines=10000000, max_depth=10000000, max_visited=1000000 +) + class Model: """ The model class wraps a z3 model