Fix up Keccak Tree for concretisation

feature/concretise_storage
Nikhil 5 years ago
parent 91101dd41e
commit 0c1453c502
  1. 3
      mythril/laser/ethereum/instructions.py
  2. 67
      mythril/laser/ethereum/state/account.py
  3. 1
      mythril/laser/smt/bitvec.py
  4. 26
      mythril/laser/smt/bitvec_helper.py
  5. 3
      mythril/laser/smt/bitvecfunc.py
  6. 4
      mythril/laser/smt/model.py

@ -1370,6 +1370,8 @@ class Instruction:
index = state.stack.pop()
state.stack.append(global_state.environment.active_account.storage[index])
if global_state.get_current_instruction()["address"] == 418:
print(state.stack[-1])
return [global_state]
@StateTransition()
@ -1437,7 +1439,6 @@ class Instruction:
states = []
op0, condition = state.stack.pop(), state.stack.pop()
try:
jump_addr = util.get_concrete_int(op0)
except TypeError:

@ -49,6 +49,8 @@ class ArrayStorageRegion(StorageRegion):
@staticmethod
def _sanitize(input_: BitVec) -> BitVec:
if input_.potential_value:
input_ = input_.potential_value
if input_.size() == 512:
return input_
if input_.size() > 512:
@ -118,6 +120,7 @@ class ArrayStorageRegion(StorageRegion):
storage, is_keccak_storage = self._get_corresponding_storage(key)
if is_keccak_storage:
key = self._sanitize(key.input_)
storage[key] = value
def __deepcopy__(self, memodict=dict()):
@ -218,27 +221,26 @@ class Storage:
key.input_, BitVecFunc
):
continue
new_constraints, key = self._traverse_concretise(key, model)
new_constraints, key_concrete = self._traverse_concretise(key, model)
key.potential_value = key_concrete
constraints += new_constraints
self._array_region[key] = value
self._ite_region.itelist = []
return constraints
def calc_sha3(self, val):
def calc_sha3(self, val, size):
try:
val = int(sha3_256(str(val.as_long()).encode("utf-8")).hexdigest(), 16)
val = int(sha3_256(str(val.value).encode("utf-8")).hexdigest(), 16)
except AttributeError:
val = int(
sha3_256(
str(randint(0, 2 ** val.input_ - 1)).encode("utf-8")
).hexdigest(),
16,
sha3_256(str(randint(0, 2 ** size - 1)).encode("utf-8")).hexdigest(), 16
)
return symbol_factory.BitVecVal(val, 256)
def _find_value(self, symbol, model):
modify = symbol
size = min(symbol.size(), 256)
if symbol.size() > 256:
index = simplify(Extract(255, 0, symbol))
else:
@ -250,9 +252,10 @@ class Storage:
modify = modify.as_long()
except AttributeError:
modify = randint(0, 2 ** modify.size() - 1)
modify = symbol_factory.BitVecVal(modify, 256)
modify = symbol_factory.BitVecVal(modify, size)
if index and not index.symbolic:
modify = Concat(modify, index)
assert modify.size() == symbol.size()
return modify
def _traverse_concretise(self, key, model):
@ -262,36 +265,17 @@ class Storage:
:param model:
:return:
"""
print(simplify(key))
constraints = []
if not isinstance(key, BitVecFunc):
concrete_value = self._find_value(key, model)
constraints.append(concrete_value == key)
key.potential_value = concrete_value
return constraints, concrete_value
if key.size() != 512 and str(key.input_) == "":
print("SHIT")
for arg in key.concat_args:
new_const, val = self._traverse_concretise(arg, model)
constraints += new_const
else:
cnt = 0
val = None
i = 0
while cnt != 256 and i < len(key.concat_args):
if val is None:
val = key.concat_args[i]
else:
val = Concat(val, key.concat_args[i])
cnt += key.concat_args[i].size()
i += 1
if val is not None:
val.simplify()
print(key.concat_args[i:], val, "CONCAT")
if key.size() == 512:
val = simplify(Extract(511, 256, key))
new_const, concrete_val = self._traverse_concretise(val, model)
constraints += new_const
if i < len(key.concat_args):
for arg in key.concat_args[i:]:
new_const, val = self._traverse_concretise(arg, model)
new_const, val2 = self._traverse_concretise(Extract(255, 0, key), model)
key.potential_value = Concat(concrete_val, val2)
constraints += new_const
if isinstance(key.input_, BitVec) or (
@ -300,10 +284,21 @@ class Storage:
new_const, _ = self._traverse_concretise(key.input_, model)
constraints += new_const
if key.func_name == "sha3":
key.potential_value = self.calc_sha3(key.input_)
return constraints, key
if isinstance(key, BitVecFunc):
if key.size() == 512:
p1 = Extract(511, 256, key)
p1 = self.calc_sha3(p1.input_.potential_value, p1.input_.size())
key.potential_value = Concat(p1, Extract(255, 0, key))
if str(key.raw) != "":
constraints += [key.raw == key.potential_value.value]
else:
key.potential_value = self.calc_sha3(
key.input_.potential_value, key.input_.size()
)
if str(key.raw) != "":
constraints += [key.raw == key.potential_value.value]
assert key.size() == key.potential_value.size()
return constraints, key.potential_value
class Account:

@ -31,6 +31,7 @@ class BitVec(Expression[z3.BitVecRef]):
:param raw:
:param annotations:
"""
self.potential_value = None
super().__init__(raw, annotations)
def size(self) -> int:

@ -147,14 +147,14 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
if isinstance(bv, BitVecFuncExtract):
if parent is None:
parent = bv.parent
if parent != bv.parent:
if hash(parent.raw) != hash(bv.parent.raw):
continue
bfne_cnt += 1
if bfne_cnt == len(bvs):
# check for continuity
fail = True
if bvs[0].high == bvs[0].parent.size() - 1 and bvs[-1].low == 0:
if bvs[-1].low == 0:
fail = False
for index, bv in enumerate(bvs):
if index == 0:
@ -162,10 +162,26 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
if bv.high + 1 != bvs[index - 1].low:
fail = True
break
if fail is False:
if bvs[0].high == bvs[0].parent.size() - 1:
return bvs[0].parent
else:
return BitVecFuncExtract(
raw=nraw,
func_name=bvs[0].func_name,
input_=bvs[0].input_,
nested_functions=nested_functions,
concat_args=concat_list,
low=bvs[-1].low,
high=bvs[0].high,
parent=bvs[0].parent,
)
if nested_functions:
for bv in bvs:
bv.simplify()
return BitVecFunc(
raw=nraw,
func_name="Hybrid",
@ -185,6 +201,7 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
:param bv:
:return:
"""
raw = z3.Extract(high, low, bv.raw)
if isinstance(bv, BitVecFunc):
count = 0
@ -199,9 +216,9 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
small_bv.size() - (high - low + 1),
small_bv,
)
if high < count:
elif high < count:
break
if low < count:
elif low < count:
if low + small_bv.size() <= high:
val = Concat(small_bv, val)
else:
@ -219,6 +236,7 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
val.raw == val.parent.raw
):
val = val.parent
val.simplify()
return val
input_string = ""
# Is there a better value to set func_name and input to in this case?

@ -113,8 +113,6 @@ def _comparison_helper(
Bool(condition) if b.nested_functions else Bool(True),
a.input_ == b.input_ if inputs_equal else a.input_ != b.input_,
)
if a.potential_value is not None:
return Or(comparision, b == a.potential_value)
return comparision
@ -142,7 +140,6 @@ class BitVecFunc(BitVec):
self.input_ = input_
self.nested_functions = nested_functions or []
self.nested_functions = list(dict.fromkeys(self.nested_functions))
self.potential_value = None
self.concat_args = concat_args or []
if isinstance(input_, BitVecFunc):
self.nested_functions.extend(input_.nested_functions)

@ -2,6 +2,10 @@ import z3
from typing import Union, List
z3.set_option(
max_args=10000000, max_lines=10000000, max_depth=10000000, max_visited=1000000
)
class Model:
""" The model class wraps a z3 model

Loading…
Cancel
Save