Refactor code

extract_concat_invariance
Nikhil 5 years ago
parent 55ac452784
commit cfdffffaa0
  1. 95
      mythril/laser/smt/bitvec_helper.py

@ -113,6 +113,41 @@ def check_extracted_var(bv: BitVec) -> bool:
return isinstance(bv, BitVecFuncExtract) or isinstance(bv, BitVecExtract)
def concat_helper(bvs: List[BitVec]) -> List[BitVec]:
"""
Automatically concatenat the adjacent symbols which are
concatenate-able(like Concatenating Extract(10, 8, a) and Extract(7, 0, a) to Extract(10, 0, a))
:param bvs: List of Bitvecs
:return: List of Bitvecs with some concats
"""
prev_bv = copy(bvs[0])
new_bvs = []
# casting everywhere in "if's" will look quite messy, so I am type ignoring them.
for bv in bvs[1:]:
if (
not (check_extracted_var(bv) and check_extracted_var(prev_bv))
or bv.high + 1 != prev_bv.low # type: ignore
or z3.simplify(bv.parent.raw != prev_bv.parent.raw) # type: ignore
):
if check_extracted_var(prev_bv) and hash(prev_bv) == hash(
prev_bv.parent # type: ignore
):
new_bvs.append(prev_bv.parent) # type: ignore
else:
new_bvs.append(prev_bv)
prev_bv = copy(bv)
continue
prev_bv.raw = z3.Concat(prev_bv.raw, bv.raw)
prev_bv.low = bv.low # type: ignore
if check_extracted_var(prev_bv) and hash(z3.simplify(prev_bv.raw)) == hash(
prev_bv.parent # type: ignore
):
prev_bv = prev_bv.parent # type: ignore
new_bvs.append(prev_bv)
return new_bvs
@overload
def Concat(*args: List[BitVec]) -> BitVec:
...
@ -144,34 +179,9 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
if isinstance(bv, BitVecFunc):
nested_functions += bv.nested_functions
nested_functions += [bv]
new_bvs = []
prev_bv = copy(bvs[0])
# casting everywhere in "if's" will look quite messy, so I am type ignoring them.
for bv in bvs[1:]:
if (
not (check_extracted_var(bv) and check_extracted_var(prev_bv))
or bv.high + 1 != prev_bv.low # type: ignore
or z3.simplify(bv.parent.raw != prev_bv.parent.raw) # type: ignore
):
if check_extracted_var(prev_bv) and hash(prev_bv) == hash(
prev_bv.parent # type: ignore
):
new_bvs.append(prev_bv.parent) # type: ignore
else:
new_bvs.append(prev_bv)
prev_bv = copy(bv)
continue
prev_bv.raw = z3.Concat(prev_bv.raw, bv.raw)
prev_bv.low = bv.low # type: ignore
if check_extracted_var(prev_bv) and hash(z3.simplify(prev_bv.raw)) == hash(
prev_bv.parent # type: ignore
):
prev_bv = prev_bv.parent # type: ignore
new_bvs.append(prev_bv)
new_bvs = concat_helper(bvs)
if len(new_bvs) == 1:
return prev_bv
return new_bvs[0]
if nested_functions:
return BitVecFunc(
raw=nraw,
@ -184,16 +194,7 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
return BitVec(raw=nraw, annotations=annotations, concat_args=new_bvs)
def Extract(high: int, low: int, bv: BitVec) -> BitVec:
"""Create an extract expression.
:param high:
:param low:
:param bv:
:return:
"""
raw = z3.Extract(high, low, bv.raw)
def extract_helper(high: int, low: int, bv: BitVec) -> BitVec:
count = 0
val = None
for small_bv in bv.concat_args[::-1]:
@ -219,11 +220,27 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
val,
)
count += small_bv.size()
return val
def Extract(high: int, low: int, bv: BitVec) -> BitVec:
"""Create an extract expression.
:param high:
:param low:
:param bv:
:return:
"""
raw = z3.Extract(high, low, bv.raw)
val = extract_helper(high, low, bv)
if val is not None:
val.simplify()
bv.simplify()
if check_extracted_var(val) and hash(val.raw) == hash(val.parent.raw):
val = val.parent
if check_extracted_var(val) and hash(val.raw) == hash(
val.parent.raw # type: ignore
):
val = val.parent # type: ignore
assert val.size() == high - low + 1
return val
input_string = ""

Loading…
Cancel
Save