diff --git a/mythril/laser/smt/bitvec_helper.py b/mythril/laser/smt/bitvec_helper.py index 1b5ce691..deaddb21 100644 --- a/mythril/laser/smt/bitvec_helper.py +++ b/mythril/laser/smt/bitvec_helper.py @@ -113,6 +113,41 @@ def check_extracted_var(bv: BitVec) -> bool: return isinstance(bv, BitVecFuncExtract) or isinstance(bv, BitVecExtract) +def concat_helper(bvs: List[BitVec]) -> List[BitVec]: + """ + Automatically concatenat the adjacent symbols which are + concatenate-able(like Concatenating Extract(10, 8, a) and Extract(7, 0, a) to Extract(10, 0, a)) + :param bvs: List of Bitvecs + :return: List of Bitvecs with some concats + """ + prev_bv = copy(bvs[0]) + new_bvs = [] + # casting everywhere in "if's" will look quite messy, so I am type ignoring them. + for bv in bvs[1:]: + if ( + not (check_extracted_var(bv) and check_extracted_var(prev_bv)) + or bv.high + 1 != prev_bv.low # type: ignore + or z3.simplify(bv.parent.raw != prev_bv.parent.raw) # type: ignore + ): + + if check_extracted_var(prev_bv) and hash(prev_bv) == hash( + prev_bv.parent # type: ignore + ): + new_bvs.append(prev_bv.parent) # type: ignore + else: + new_bvs.append(prev_bv) + prev_bv = copy(bv) + continue + prev_bv.raw = z3.Concat(prev_bv.raw, bv.raw) + prev_bv.low = bv.low # type: ignore + if check_extracted_var(prev_bv) and hash(z3.simplify(prev_bv.raw)) == hash( + prev_bv.parent # type: ignore + ): + prev_bv = prev_bv.parent # type: ignore + new_bvs.append(prev_bv) + return new_bvs + + @overload def Concat(*args: List[BitVec]) -> BitVec: ... @@ -144,34 +179,9 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: if isinstance(bv, BitVecFunc): nested_functions += bv.nested_functions nested_functions += [bv] - new_bvs = [] - prev_bv = copy(bvs[0]) - - # casting everywhere in "if's" will look quite messy, so I am type ignoring them. - for bv in bvs[1:]: - if ( - not (check_extracted_var(bv) and check_extracted_var(prev_bv)) - or bv.high + 1 != prev_bv.low # type: ignore - or z3.simplify(bv.parent.raw != prev_bv.parent.raw) # type: ignore - ): - - if check_extracted_var(prev_bv) and hash(prev_bv) == hash( - prev_bv.parent # type: ignore - ): - new_bvs.append(prev_bv.parent) # type: ignore - else: - new_bvs.append(prev_bv) - prev_bv = copy(bv) - continue - prev_bv.raw = z3.Concat(prev_bv.raw, bv.raw) - prev_bv.low = bv.low # type: ignore - if check_extracted_var(prev_bv) and hash(z3.simplify(prev_bv.raw)) == hash( - prev_bv.parent # type: ignore - ): - prev_bv = prev_bv.parent # type: ignore - new_bvs.append(prev_bv) + new_bvs = concat_helper(bvs) if len(new_bvs) == 1: - return prev_bv + return new_bvs[0] if nested_functions: return BitVecFunc( raw=nraw, @@ -184,16 +194,7 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: return BitVec(raw=nraw, annotations=annotations, concat_args=new_bvs) -def Extract(high: int, low: int, bv: BitVec) -> BitVec: - """Create an extract expression. - - :param high: - :param low: - :param bv: - :return: - """ - - raw = z3.Extract(high, low, bv.raw) +def extract_helper(high: int, low: int, bv: BitVec) -> BitVec: count = 0 val = None for small_bv in bv.concat_args[::-1]: @@ -219,11 +220,27 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: val, ) count += small_bv.size() + return val + + +def Extract(high: int, low: int, bv: BitVec) -> BitVec: + """Create an extract expression. + + :param high: + :param low: + :param bv: + :return: + """ + + raw = z3.Extract(high, low, bv.raw) + val = extract_helper(high, low, bv) if val is not None: val.simplify() bv.simplify() - if check_extracted_var(val) and hash(val.raw) == hash(val.parent.raw): - val = val.parent + if check_extracted_var(val) and hash(val.raw) == hash( + val.parent.raw # type: ignore + ): + val = val.parent # type: ignore assert val.size() == high - low + 1 return val input_string = ""