Fix up Keccak Tree for concretisation

feature/concretise_storage
Nikhil 5 years ago
parent 91101dd41e
commit 0c1453c502
  1. 3
      mythril/laser/ethereum/instructions.py
  2. 73
      mythril/laser/ethereum/state/account.py
  3. 1
      mythril/laser/smt/bitvec.py
  4. 28
      mythril/laser/smt/bitvec_helper.py
  5. 3
      mythril/laser/smt/bitvecfunc.py
  6. 4
      mythril/laser/smt/model.py

@ -1370,6 +1370,8 @@ class Instruction:
index = state.stack.pop() index = state.stack.pop()
state.stack.append(global_state.environment.active_account.storage[index]) state.stack.append(global_state.environment.active_account.storage[index])
if global_state.get_current_instruction()["address"] == 418:
print(state.stack[-1])
return [global_state] return [global_state]
@StateTransition() @StateTransition()
@ -1437,7 +1439,6 @@ class Instruction:
states = [] states = []
op0, condition = state.stack.pop(), state.stack.pop() op0, condition = state.stack.pop(), state.stack.pop()
try: try:
jump_addr = util.get_concrete_int(op0) jump_addr = util.get_concrete_int(op0)
except TypeError: except TypeError:

@ -49,6 +49,8 @@ class ArrayStorageRegion(StorageRegion):
@staticmethod @staticmethod
def _sanitize(input_: BitVec) -> BitVec: def _sanitize(input_: BitVec) -> BitVec:
if input_.potential_value:
input_ = input_.potential_value
if input_.size() == 512: if input_.size() == 512:
return input_ return input_
if input_.size() > 512: if input_.size() > 512:
@ -118,6 +120,7 @@ class ArrayStorageRegion(StorageRegion):
storage, is_keccak_storage = self._get_corresponding_storage(key) storage, is_keccak_storage = self._get_corresponding_storage(key)
if is_keccak_storage: if is_keccak_storage:
key = self._sanitize(key.input_) key = self._sanitize(key.input_)
storage[key] = value storage[key] = value
def __deepcopy__(self, memodict=dict()): def __deepcopy__(self, memodict=dict()):
@ -218,27 +221,26 @@ class Storage:
key.input_, BitVecFunc key.input_, BitVecFunc
): ):
continue continue
new_constraints, key = self._traverse_concretise(key, model) new_constraints, key_concrete = self._traverse_concretise(key, model)
key.potential_value = key_concrete
constraints += new_constraints constraints += new_constraints
self._array_region[key] = value self._array_region[key] = value
self._ite_region.itelist = [] self._ite_region.itelist = []
return constraints return constraints
def calc_sha3(self, val): def calc_sha3(self, val, size):
try: try:
val = int(sha3_256(str(val.as_long()).encode("utf-8")).hexdigest(), 16) val = int(sha3_256(str(val.value).encode("utf-8")).hexdigest(), 16)
except AttributeError: except AttributeError:
val = int( val = int(
sha3_256( sha3_256(str(randint(0, 2 ** size - 1)).encode("utf-8")).hexdigest(), 16
str(randint(0, 2 ** val.input_ - 1)).encode("utf-8")
).hexdigest(),
16,
) )
return symbol_factory.BitVecVal(val, 256) return symbol_factory.BitVecVal(val, 256)
def _find_value(self, symbol, model): def _find_value(self, symbol, model):
modify = symbol modify = symbol
size = min(symbol.size(), 256)
if symbol.size() > 256: if symbol.size() > 256:
index = simplify(Extract(255, 0, symbol)) index = simplify(Extract(255, 0, symbol))
else: else:
@ -250,9 +252,10 @@ class Storage:
modify = modify.as_long() modify = modify.as_long()
except AttributeError: except AttributeError:
modify = randint(0, 2 ** modify.size() - 1) modify = randint(0, 2 ** modify.size() - 1)
modify = symbol_factory.BitVecVal(modify, 256) modify = symbol_factory.BitVecVal(modify, size)
if index and not index.symbolic: if index and not index.symbolic:
modify = Concat(modify, index) modify = Concat(modify, index)
assert modify.size() == symbol.size()
return modify return modify
def _traverse_concretise(self, key, model): def _traverse_concretise(self, key, model):
@ -262,37 +265,18 @@ class Storage:
:param model: :param model:
:return: :return:
""" """
print(simplify(key))
constraints = [] constraints = []
if not isinstance(key, BitVecFunc): if not isinstance(key, BitVecFunc):
concrete_value = self._find_value(key, model) concrete_value = self._find_value(key, model)
constraints.append(concrete_value == key) key.potential_value = concrete_value
return constraints, concrete_value return constraints, concrete_value
if key.size() != 512 and str(key.input_) == "": if key.size() == 512:
print("SHIT") val = simplify(Extract(511, 256, key))
for arg in key.concat_args: new_const, concrete_val = self._traverse_concretise(val, model)
new_const, val = self._traverse_concretise(arg, model) constraints += new_const
constraints += new_const new_const, val2 = self._traverse_concretise(Extract(255, 0, key), model)
else: key.potential_value = Concat(concrete_val, val2)
cnt = 0 constraints += new_const
val = None
i = 0
while cnt != 256 and i < len(key.concat_args):
if val is None:
val = key.concat_args[i]
else:
val = Concat(val, key.concat_args[i])
cnt += key.concat_args[i].size()
i += 1
if val is not None:
val.simplify()
print(key.concat_args[i:], val, "CONCAT")
new_const, concrete_val = self._traverse_concretise(val, model)
constraints += new_const
if i < len(key.concat_args):
for arg in key.concat_args[i:]:
new_const, val = self._traverse_concretise(arg, model)
constraints += new_const
if isinstance(key.input_, BitVec) or ( if isinstance(key.input_, BitVec) or (
isinstance(key.input_, BitVecFunc) and key.input_.func_name == "sha3" isinstance(key.input_, BitVecFunc) and key.input_.func_name == "sha3"
@ -300,10 +284,21 @@ class Storage:
new_const, _ = self._traverse_concretise(key.input_, model) new_const, _ = self._traverse_concretise(key.input_, model)
constraints += new_const constraints += new_const
if key.func_name == "sha3": if isinstance(key, BitVecFunc):
key.potential_value = self.calc_sha3(key.input_) if key.size() == 512:
p1 = Extract(511, 256, key)
return constraints, key p1 = self.calc_sha3(p1.input_.potential_value, p1.input_.size())
key.potential_value = Concat(p1, Extract(255, 0, key))
if str(key.raw) != "":
constraints += [key.raw == key.potential_value.value]
else:
key.potential_value = self.calc_sha3(
key.input_.potential_value, key.input_.size()
)
if str(key.raw) != "":
constraints += [key.raw == key.potential_value.value]
assert key.size() == key.potential_value.size()
return constraints, key.potential_value
class Account: class Account:

@ -31,6 +31,7 @@ class BitVec(Expression[z3.BitVecRef]):
:param raw: :param raw:
:param annotations: :param annotations:
""" """
self.potential_value = None
super().__init__(raw, annotations) super().__init__(raw, annotations)
def size(self) -> int: def size(self) -> int:

@ -147,14 +147,14 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
if isinstance(bv, BitVecFuncExtract): if isinstance(bv, BitVecFuncExtract):
if parent is None: if parent is None:
parent = bv.parent parent = bv.parent
if parent != bv.parent: if hash(parent.raw) != hash(bv.parent.raw):
continue continue
bfne_cnt += 1 bfne_cnt += 1
if bfne_cnt == len(bvs): if bfne_cnt == len(bvs):
# check for continuity # check for continuity
fail = True fail = True
if bvs[0].high == bvs[0].parent.size() - 1 and bvs[-1].low == 0: if bvs[-1].low == 0:
fail = False fail = False
for index, bv in enumerate(bvs): for index, bv in enumerate(bvs):
if index == 0: if index == 0:
@ -162,10 +162,26 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
if bv.high + 1 != bvs[index - 1].low: if bv.high + 1 != bvs[index - 1].low:
fail = True fail = True
break break
if fail is False: if fail is False:
return bvs[0].parent if bvs[0].high == bvs[0].parent.size() - 1:
return bvs[0].parent
else:
return BitVecFuncExtract(
raw=nraw,
func_name=bvs[0].func_name,
input_=bvs[0].input_,
nested_functions=nested_functions,
concat_args=concat_list,
low=bvs[-1].low,
high=bvs[0].high,
parent=bvs[0].parent,
)
if nested_functions: if nested_functions:
for bv in bvs:
bv.simplify()
return BitVecFunc( return BitVecFunc(
raw=nraw, raw=nraw,
func_name="Hybrid", func_name="Hybrid",
@ -185,6 +201,7 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
:param bv: :param bv:
:return: :return:
""" """
raw = z3.Extract(high, low, bv.raw) raw = z3.Extract(high, low, bv.raw)
if isinstance(bv, BitVecFunc): if isinstance(bv, BitVecFunc):
count = 0 count = 0
@ -199,9 +216,9 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
small_bv.size() - (high - low + 1), small_bv.size() - (high - low + 1),
small_bv, small_bv,
) )
if high < count: elif high < count:
break break
if low < count: elif low < count:
if low + small_bv.size() <= high: if low + small_bv.size() <= high:
val = Concat(small_bv, val) val = Concat(small_bv, val)
else: else:
@ -219,6 +236,7 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
val.raw == val.parent.raw val.raw == val.parent.raw
): ):
val = val.parent val = val.parent
val.simplify()
return val return val
input_string = "" input_string = ""
# Is there a better value to set func_name and input to in this case? # Is there a better value to set func_name and input to in this case?

@ -113,8 +113,6 @@ def _comparison_helper(
Bool(condition) if b.nested_functions else Bool(True), Bool(condition) if b.nested_functions else Bool(True),
a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, a.input_ == b.input_ if inputs_equal else a.input_ != b.input_,
) )
if a.potential_value is not None:
return Or(comparision, b == a.potential_value)
return comparision return comparision
@ -142,7 +140,6 @@ class BitVecFunc(BitVec):
self.input_ = input_ self.input_ = input_
self.nested_functions = nested_functions or [] self.nested_functions = nested_functions or []
self.nested_functions = list(dict.fromkeys(self.nested_functions)) self.nested_functions = list(dict.fromkeys(self.nested_functions))
self.potential_value = None
self.concat_args = concat_args or [] self.concat_args = concat_args or []
if isinstance(input_, BitVecFunc): if isinstance(input_, BitVecFunc):
self.nested_functions.extend(input_.nested_functions) self.nested_functions.extend(input_.nested_functions)

@ -2,6 +2,10 @@ import z3
from typing import Union, List from typing import Union, List
z3.set_option(
max_args=10000000, max_lines=10000000, max_depth=10000000, max_visited=1000000
)
class Model: class Model:
""" The model class wraps a z3 model """ The model class wraps a z3 model

Loading…
Cancel
Save