diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index b308e863..ad0d3549 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -1,7 +1,7 @@ """This module provides classes for an SMT abstraction of bit vectors.""" from operator import lshift, rshift, ne, eq -from typing import Union, Set, cast, Any, Optional, Callable +from typing import Union, Set, cast, Any, Optional, Callable, List import z3 @@ -25,12 +25,14 @@ def _padded_operation(a: z3.BitVec, b: z3.BitVec, operator): class BitVec(Expression[z3.BitVecRef]): """A bit vector symbol.""" - def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None): + def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None, concat_args=None): """ :param raw: :param annotations: """ + self.potential_inputs = [] # type: List["BitVec"] + self.concat_args = concat_args or [] super().__init__(raw, annotations) def size(self) -> int: @@ -277,5 +279,34 @@ class BitVec(Expression[z3.BitVecRef]): return self.raw.__hash__() +class BitVecExtract(BitVec): + """A bit vector returned by Extract() that records its bit range (low, high) and the bit vector it was extracted from""" + + def __init__( + self, + raw: z3.BitVecRef, + annotations: Optional[Annotations] = None, + concat_args: List = None, + low=None, + high=None, + parent=None, + ): + """ + :param raw: The raw bit vector symbol + :param annotations: The annotations the bit vector should 
start with + :param concat_args: Forwarded unchanged to BitVec.__init__ + :param low: Stored as self.low; high is stored likewise as self.high (Extract() passes its bit bounds) + :param parent: Stored as self.parent (Extract() passes the bit vector it extracted from) + """ + super().__init__( + raw=raw, + annotations=annotations, + concat_args=concat_args, + ) + self.low = low + self.high = high + self.parent = parent + + # TODO: Fix circular import issues from mythril.laser.smt.bitvecfunc import BitVecFunc diff --git a/mythril/laser/smt/bitvec_helper.py b/mythril/laser/smt/bitvec_helper.py index 02faeab2..68efe609 100644 --- a/mythril/laser/smt/bitvec_helper.py +++ b/mythril/laser/smt/bitvec_helper.py @@ -1,8 +1,8 @@ -from typing import Union, overload, List, Set, cast, Any, Optional, Callable +from typing import Union, overload, List, Set, cast, Any, Callable import z3 -from mythril.laser.smt.bool import Bool, And, Or -from mythril.laser.smt.bitvec import BitVec +from mythril.laser.smt.bool import Bool, Or +from mythril.laser.smt.bitvec import BitVec, BitVecExtract from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.bitvecfunc import BitVecFuncExtract from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper @@ -109,6 +109,10 @@ def ULE(a: BitVec, b: BitVec) -> Bool: return Or(ULT(a, b), a == b) +def check_extracted_var(bv: BitVec): + return isinstance(bv, BitVecFuncExtract) or isinstance(bv, BitVecExtract) + + @overload def Concat(*args: List[BitVec]) -> BitVec: ... 
@@ -131,65 +135,43 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: else: bvs = cast(List[BitVec], args) - concat_list = bvs nraw = z3.Concat([a.raw for a in bvs]) annotations = set() # type: Annotations nested_functions = [] # type: List[BitVecFunc] - bfne_cnt = 0 - parent = None for bv in bvs: annotations = annotations.union(bv.annotations) if isinstance(bv, BitVecFunc): nested_functions += bv.nested_functions nested_functions += [bv] - if isinstance(bv, BitVecFuncExtract): - if parent is None: - parent = bv.parent - if hash(parent.raw) != hash(bv.parent.raw): - continue - bfne_cnt += 1 - - if bfne_cnt == len(bvs): - # check for continuity - fail = True - if bvs[-1].low == 0: - fail = False - for index, bv in enumerate(bvs): - if index == 0: - continue - if bv.high + 1 != bvs[index - 1].low: - fail = True - break - - if fail is False: - if bvs[0].high == bvs[0].parent.size() - 1: - return bvs[0].parent + new_bvs = [] + prev_bv = bvs[0] + # casting everywhere in if's looks quite messy, so I am type ignoring + for bv in bvs[1:]: + if ( + not check_extracted_var(bv) + or bv.high + 1 != prev_bv.low # type: ignore + or bv.parent != prev_bv.parent # type: ignore + ): + if check_extracted_var(prev_bv) and hash(prev_bv) == hash( + prev_bv.parent + ): # type: ignore + new_bvs.append(prev_bv.parent) # type: ignore else: - return BitVecFuncExtract( - raw=nraw, - func_name=bvs[0].func_name, - input_=bvs[0].input_, - nested_functions=nested_functions, - concat_args=concat_list, - low=bvs[-1].low, - high=bvs[0].high, - parent=bvs[0].parent, - ) - + new_bvs.append(prev_bv) + prev_bv = bv + continue + prev_bv = Concat(prev_bv, bv) + new_bvs.append(prev_bv) if nested_functions: - for bv in bvs: - bv.simplify() - return BitVecFunc( raw=nraw, func_name="Hybrid", input_=BitVec(z3.BitVec("", 256), annotations=annotations), nested_functions=nested_functions, - concat_args=concat_list, + concat_args=new_bvs, ) - - return BitVec(nraw, annotations) + return BitVec(raw=nraw, 
annotations=annotations, concat_args=new_bvs) def Extract(high: int, low: int, bv: BitVec) -> BitVec: @@ -202,43 +184,40 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: """ raw = z3.Extract(high, low, bv.raw) - if isinstance(bv, BitVecFunc): - count = 0 - val = None - for small_bv in bv.concat_args[::-1]: - if low == count: - if low + small_bv.size() <= high: - val = small_bv - else: - val = Extract( + count = 0 + val = None + for small_bv in bv.concat_args[::-1]: + if low == count: + if low + small_bv.size() <= high: + val = small_bv + else: + val = Extract( + small_bv.size() - 1, small_bv.size() - (high - low + 1), small_bv + ) + elif high < count: + break + elif low < count: + if low + small_bv.size() <= high: + val = Concat(small_bv, val) + else: + val = Concat( + Extract( small_bv.size() - 1, small_bv.size() - (high - low + 1), small_bv, - ) - elif high < count: - break - elif low < count: - if low + small_bv.size() <= high: - val = Concat(small_bv, val) - else: - val = Concat( - Extract( - small_bv.size() - 1, - small_bv.size() - (high - low + 1), - small_bv, - ), - val, - ) - count += small_bv.size() - if val is not None: - if isinstance(val, BitVecFuncExtract) and z3.simplify( - val.raw == val.parent.raw - ): - val = val.parent - val.simplify() - return val - input_string = "" - # Is there a better value to set func_name and input to in this case? 
+ ), + val, + ) + count += small_bv.size() + if val is not None: + val.simplify() + bv.simplify() + if check_extracted_var(val) and hash(val.raw) == hash(val.parent.raw): + val = val.parent + return val + input_string = "" + bv.simplify() + if isinstance(bv, BitVecFunc): return BitVecFuncExtract( raw=raw, func_name="Hybrid", @@ -248,8 +227,8 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: high=high, parent=bv, ) - - return BitVec(raw, annotations=bv.annotations) + else: + return BitVecExtract(raw=raw, low=low, high=high, parent=bv) def URem(a: BitVec, b: BitVec) -> BitVec: diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py index aba81587..3c3c17a2 100644 --- a/mythril/laser/smt/bitvecfunc.py +++ b/mythril/laser/smt/bitvecfunc.py @@ -3,7 +3,7 @@ from itertools import product from typing import Optional, Union, cast, Callable, List import z3 -from mythril.laser.smt.bitvec import BitVec, Annotations +from mythril.laser.smt.bitvec import BitVec, Annotations, BitVecExtract from mythril.laser.smt.bool import Or, Bool, And @@ -139,10 +139,9 @@ class BitVecFunc(BitVec): self.input_ = input_ self.nested_functions = nested_functions or [] self.nested_functions = list(dict.fromkeys(self.nested_functions)) - self.concat_args = concat_args or [] if isinstance(input_, BitVecFunc): self.nested_functions.extend(input_.nested_functions) - super().__init__(raw, annotations) + super().__init__(raw, annotations, concat_args=concat_args) def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": """Create an addition expression.