From ce39e97c2d4cca7513f0b454df0ecd7caa6f9aa2 Mon Sep 17 00:00:00 2001 From: Nikhil Date: Thu, 1 Aug 2019 10:13:11 +0530 Subject: [PATCH] Add a few changes --- mythril/laser/ethereum/state/account.py | 51 +++--- mythril/laser/ethereum/svm.py | 53 ++++++- mythril/laser/smt/bitvec.py | 5 +- mythril/laser/smt/bitvec_helper.py | 2 +- mythril/laser/smt/bitvecfunc.py | 202 +++++++++++++++++++++++- mythril/laser/smt/bitvecfuncextract.py | 35 ---- 6 files changed, 285 insertions(+), 63 deletions(-) diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index cc873b02..f2f4d66b 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -17,7 +17,7 @@ from mythril.laser.smt import ( Concat, If, Or, - And + And, ) from mythril.disassembler.disassembly import Disassembly from mythril.laser.smt import symbol_factory @@ -179,19 +179,19 @@ class Storage: @staticmethod def _array_condition(key: BitVec): - return ( - not isinstance(key, BitVecFunc) - or ( - isinstance(key, BitVecFunc) - and key.func_name == "keccak256" - and len(key.nested_functions) <= 1 - ) + return not isinstance(key, BitVecFunc) or ( + isinstance(key, BitVecFunc) + and key.func_name == "keccak256" + and len(key.nested_functions) <= 1 ) def __getitem__(self, key: BitVec) -> BitVec: ite_get = self._ite_region[cast(BitVecFunc, key)] array_get = self._array_region[key] - return If(ite_get, ite_get, array_get) + if self._array_condition(key): + return If(ite_get, ite_get, array_get) + else: + return ite_get def __setitem__(self, key: BitVec, value: Any) -> None: self._printable_storage[key] = value @@ -226,9 +226,7 @@ class Storage: ran = hex(randint(0, 2 ** size - 1))[2:] if len(ran) % 2 != 0: ran += "0" - val = int( - keccak_256(bytes.fromhex(ran)).hexdigest(), 16 - ) + val = int(keccak_256(bytes.fromhex(ran)).hexdigest(), 16) return symbol_factory.BitVecVal(val, 256) def _find_value(self, symbol, model): @@ -263,10 +261,22 @@ class Storage: """ if not isinstance(key, BitVecFunc): concrete_values = [self._find_value(key, model[0]) for model in models] + ex_key = Extract(511, 256, key) + if ex_key.pseudo_input: + lis = [ + self._find_value(ex_key.pseudo_input, model[0]) + for model in models[4:] + ] + for val in lis: + if val is not None: + ex_key.pseudo_input = val + break potential_values = concrete_values key.potential_value = [] for i, val in enumerate(potential_values): - key.potential_value.append((val, And(models[i][1], BitVec(key.raw) == val))) + key.potential_value.append( + (val, And(models[i][1], BitVec(key.raw) == val)) + ) return key.potential_value if key.size() == 512: @@ -278,7 +288,9 @@ class Storage: for val1, val2 in zip(concrete_vals, vals2): if val2 and val1: c_val = Concat(val1[0], val2[0]) - condition = And(models[i][1], BitVec(key.raw) == c_val, val1[1], val2[1]) + condition = And( + models[i][1], BitVec(key.raw) == c_val, val1[1], val2[1] + ) key.potential_value.append((c_val, condition)) else: key.potential_value.append((None, None)) @@ -291,7 +303,10 @@ class Storage: if isinstance(key, BitVecFunc): if key.size() == 512: p1 = Extract(511, 256, key) - p1 = [(self.calc_sha3(val[0], p1.input_.size()), val[1]) for val in p1.input_.potential_value] + p1 = [ + (self.calc_sha3(val[0], p1.input_.size()), val[1]) + for val in p1.input_.potential_value + ] key.potential_value = [] for i, val in enumerate(p1): if val[0]: @@ -304,10 +319,10 @@ class Storage: key.potential_value = [] for i, val in enumerate(key.input_.potential_value): if val[0]: - concrete_val = self.calc_sha3( - val[0], key.input_.size() + concrete_val = self.calc_sha3(val[0], key.input_.size()) + condition = And( + models[i][1], val[1], BitVec(key.raw) == concrete_val ) - condition = And(models[i][1], val[1], BitVec(key.raw) == concrete_val) key.potential_value.append((concrete_val, condition)) else: key.potential_value.append((None, None)) diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index 7581b299..dd9ff675 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -25,7 +25,7 @@ from mythril.laser.ethereum.transaction import ( execute_contract_creation, execute_message_call, ) -from mythril.laser.smt import symbol_factory +from mythril.laser.smt import symbol_factory, And, BitVecFunc, BitVec, Extract ACTOR_ADDRESSES = [ symbol_factory.BitVecVal(0xAFFEAFFEAFFEAFFEAFFEAFFEAFFEAFFEAFFEAFFE, 256), @@ -358,10 +358,59 @@ class LaserEVM: sat = False for actor in ACTOR_ADDRESSES: try: - models_tuple.append((get_model(constraints=global_state.mstate.constraints + [sender == actor]), sender == actor)) + models_tuple.append( + ( + get_model( + constraints=global_state.mstate.constraints + + [sender == actor] + ), + sender == actor, + ) + ) sat = True except UnsatError: models_tuple.append((None, sender == actor)) + import random, sha3 + + calldata_cond = True + for account in global_state.world_state.accounts.values(): + for key in account.storage._ite_region.itedict: + if ( + isinstance(key, BitVecFunc) + and not isinstance(key.input_, BitVecFunc) + and isinstance(key.input_, BitVec) + and key.input_.symbolic + and key.input_.size() == 512 + ): + pseudo_input = random.randint(0, 2 ** 256 - 1) + hex_v = hex(pseudo_input)[2:] + if len(hex_v) % 2 == 1: + hex_v += "0" + hash_val = symbol_factory.BitVecVal( + int(sha3.keccak_256(bytes.fromhex(hex_v)).hexdigest()[2:], 16), + 256, + ) + pseudo_input = symbol_factory.BitVecVal(pseudo_input, 256) + calldata_cond = And( + calldata_cond, + key.input_ == hash_val, + Extract(511, 256, key.input_).pseudo_input == pseudo_input, + ) + for actor in ACTOR_ADDRESSES: + try: + models_tuple.append( + ( + get_model( + constraints=global_state.mstate.constraints + + [sender == actor, calldata_cond] + ), + And(calldata_cond, sender == actor), + ) + ) + sat = True + except UnsatError: + models_tuple.append((None, And(calldata_cond, sender == actor))) + if not sat: return [False] diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index fe0a9848..4df6a379 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -21,17 +21,20 @@ def _padded_operation(a: z3.BitVec, b: z3.BitVec, operator): b = z3.Concat(z3.BitVecVal(0, a.size() - b.size()), b) return operator(a, b) +from random import randint class BitVec(Expression[z3.BitVecRef]): """A bit vector symbol.""" - def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None): + def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None, has_pseudo_input=True): """ :param raw: :param annotations: """ self.potential_value = None + if has_pseudo_input: + self.pseudo_input = BitVec(z3.BitVec("{}_pseudoinput".format(randint(0, 2**100)), 256), annotations=annotations, has_pseudo_input=False) super().__init__(raw, annotations) def size(self) -> int: diff --git a/mythril/laser/smt/bitvec_helper.py b/mythril/laser/smt/bitvec_helper.py index 6c0cd59d..fd618390 100644 --- a/mythril/laser/smt/bitvec_helper.py +++ b/mythril/laser/smt/bitvec_helper.py @@ -7,7 +7,7 @@ from mythril.laser.smt.bitvec import BitVec from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper from mythril.laser.smt.bitvecfunc import _comparison_helper as _func_comparison_helper -from mythril.laser.smt.bitvecfuncextract import BitVecFuncExtract +from mythril.laser.smt.bitvecfunc import BitVecFuncExtract Annotations = Set[Any] diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py index 902bee25..d34cbe06 100644 --- a/mythril/laser/smt/bitvecfunc.py +++ b/mythril/laser/smt/bitvecfunc.py @@ -7,6 +7,139 @@ from mythril.laser.smt.bitvec import BitVec, Annotations from mythril.laser.smt.bool import Or, Bool, And +def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: + """Create a concatenation expression. + + :param args: + :return: + """ + # The following statement is used if a list is provided as an argument to concat + if len(args) == 1 and isinstance(args[0], list): + bvs = args[0] # type: List[BitVec] + else: + bvs = cast(List[BitVec], args) + + concat_list = bvs + nraw = z3.Concat([a.raw for a in bvs]) + annotations = set() # type: Annotations + + nested_functions = [] # type: List[BitVecFunc] + bfne_cnt = 0 + parent = None + for bv in bvs: + annotations = annotations.union(bv.annotations) + if isinstance(bv, BitVecFunc): + nested_functions += bv.nested_functions + nested_functions += [bv] + if isinstance(bv, BitVecFuncExtract): + if parent is None: + parent = bv.parent + if hash(parent.raw) != hash(bv.parent.raw): + continue + bfne_cnt += 1 + + if bfne_cnt == len(bvs): + # check for continuity + fail = True + if bvs[-1].low == 0: + fail = False + for index, bv in enumerate(bvs): + if index == 0: + continue + if bv.high + 1 != bvs[index - 1].low: + fail = True + break + + if fail is False: + if bvs[0].high == bvs[0].parent.size() - 1: + return bvs[0].parent + else: + return BitVecFuncExtract( + raw=nraw, + func_name=bvs[0].func_name, + input_=bvs[0].input_, + nested_functions=nested_functions, + concat_args=concat_list, + low=bvs[-1].low, + high=bvs[0].high, + parent=bvs[0].parent, + ) + + if nested_functions: + for bv in bvs: + bv.simplify() + + return BitVecFunc( + raw=nraw, + func_name="Hybrid", + input_=BitVec(z3.BitVec("", 256), annotations=annotations), + nested_functions=nested_functions, + concat_args=concat_list, + ) + + return BitVec(nraw, annotations) + + +def Extract(high: int, low: int, bv: BitVec) -> BitVec: + """Create an extract expression. + + :param high: + :param low: + :param bv: + :return: + """ + + raw = z3.Extract(high, low, bv.raw) + if isinstance(bv, BitVecFunc): + count = 0 + val = None + for small_bv in bv.concat_args[::-1]: + if low == count: + if low + small_bv.size() <= high: + val = small_bv + else: + val = Extract( + small_bv.size() - 1, + small_bv.size() - (high - low + 1), + small_bv, + ) + elif high < count: + break + elif low < count: + if low + small_bv.size() <= high: + val = Concat(small_bv, val) + else: + val = Concat( + Extract( + small_bv.size() - 1, + small_bv.size() - (high - low + 1), + small_bv, + ), + val, + ) + count += small_bv.size() + if val is not None: + if isinstance(val, BitVecFuncExtract) and z3.simplify( + val.raw == val.parent.raw + ): + val = val.parent + val.simplify() + return val + input_string = "" + # Is there a better value to set func_name and input to in this case? + return BitVecFuncExtract( + raw=raw, + func_name="Hybrid", + input_=BitVec(z3.BitVec(input_string, 256), annotations=bv.annotations), + nested_functions=bv.nested_functions + [bv], + low=low, + high=high, + parent=bv, + ) + + return BitVec(raw, annotations=bv.annotations) + + def _arithmetic_helper( a: "BitVecFunc", b: Union[BitVec, int], operation: Callable ) -> "BitVecFunc": @@ -68,12 +201,34 @@ def _comparison_helper( operation = operator.lt return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union) - if not isinstance(b, BitVecFunc) and a.potential_value: - condition = False - for value, cond in a.potential_value: - if value is not None: - condition = Or(condition, And(b == value, cond)) - return And(condition, operation(a.raw, b.raw)) + if ( + a.size() == 512 + and b.size() == 512 + and z3.is_true( + z3.simplify(z3.Extract(255, 0, a.raw) == z3.Extract(255, 0, b.raw)) + ) + ): + a = Extract(511, 256, a) + b = Extract(511, 256, b) + + if not isinstance(b, BitVecFunc): + if a.potential_value: + condition = False + for value, cond in a.potential_value: + if value is not None: + condition = Or(condition, And(operation(b, value), cond)) + return And(condition, operation(a.raw, b.raw)) + + if b.pseudo_input and b.pseudo_input.size() >= a.input_.size(): + if b.pseudo_input.size() > a.input_.size(): + padded_a = z3.Concat( + z3.BitVecVal(0, b.pseudo_input.size() - a.input_.size()), + a.input_.raw, + ) + else: + padded_a = a.input_.raw + print(b.pseudo_input.raw) + return And(operation(a.raw, b.raw), operation(padded_a, b.pseudo_input.raw)) if ( not isinstance(b, BitVecFunc) or not a.func_name @@ -307,3 +462,38 @@ class BitVecFunc(BitVec): def __hash__(self) -> int: return self.raw.__hash__() + + +class BitVecFuncExtract(BitVecFunc): + """A bit vector function wrapper, useful for preserving Extract() and Concat() operations""" + + def __init__( + self, + raw: z3.BitVecRef, + func_name: Optional[str], + input_: "BitVec" = None, + annotations: Optional[Annotations] = None, + nested_functions: Optional[List["BitVecFunc"]] = None, + concat_args: List = None, + low=None, + high=None, + parent=None, + ): + """ + + :param raw: The raw bit vector symbol + :param func_name: The function name. e.g. sha3 + :param input: The input to the functions + :param annotations: The annotations the BitVecFunc should start with + """ + super().__init__( + raw=raw, + func_name=func_name, + input_=input_, + annotations=annotations, + nested_functions=nested_functions, + concat_args=concat_args, + ) + self.low = low + self.high = high + self.parent = parent diff --git a/mythril/laser/smt/bitvecfuncextract.py b/mythril/laser/smt/bitvecfuncextract.py index 4fa320be..e8547346 100644 --- a/mythril/laser/smt/bitvecfuncextract.py +++ b/mythril/laser/smt/bitvecfuncextract.py @@ -3,38 +3,3 @@ import z3 from typing import Optional, List from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.bitvec import Annotations, BitVec - - -class BitVecFuncExtract(BitVecFunc): - """A bit vector function wrapper, useful for preserving Extract() and Concat() operations""" - - def __init__( - self, - raw: z3.BitVecRef, - func_name: Optional[str], - input_: "BitVec" = None, - annotations: Optional[Annotations] = None, - nested_functions: Optional[List["BitVecFunc"]] = None, - concat_args: List = None, - low=None, - high=None, - parent=None, - ): - """ - - :param raw: The raw bit vector symbol - :param func_name: The function name. e.g. sha3 - :param input: The input to the functions - :param annotations: The annotations the BitVecFunc should start with - """ - super().__init__( - raw=raw, - func_name=func_name, - input_=input_, - annotations=annotations, - nested_functions=nested_functions, - concat_args=concat_args, - ) - self.low = low - self.high = high - self.parent = parent