Fix some Concat tests

extract_concat_invariance
Nikhil 5 years ago
parent c268838d60
commit f806c0abcb
  1. 29
      mythril/laser/smt/bitvec_helper.py

@ -1,6 +1,6 @@
from typing import Union, overload, List, Set, cast, Any, Callable from typing import Union, overload, List, Set, cast, Any, Callable
import z3 import z3
from copy import copy
from mythril.laser.smt.bool import Bool, Or from mythril.laser.smt.bool import Bool, Or
from mythril.laser.smt.bitvec import BitVec, BitVecExtract from mythril.laser.smt.bitvec import BitVec, BitVecExtract
from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.bitvecfunc import BitVecFunc
@ -83,8 +83,7 @@ def UGE(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned greater or equals expression. """Create an unsigned greater or equals expression.
:param a: :param a:
:param b:/home/nikhil/Work/Mythril/mythril2/mythril/laser/smt/bitvec_helper.py", line 154, in Concat :param b:
or bv.parent
:return: :return:
""" """
return Or(UGT(a, b), a == b) return Or(UGT(a, b), a == b)
@ -146,26 +145,33 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
nested_functions += bv.nested_functions nested_functions += bv.nested_functions
nested_functions += [bv] nested_functions += [bv]
new_bvs = [] new_bvs = []
prev_bv = bvs[0] prev_bv = copy(bvs[0])
# casting everywhere in if's looks quite messy, so I am type ignoring
# casting everywhere in "if's" will look quite messy, so I am type ignoring them.
for bv in bvs[1:]: for bv in bvs[1:]:
if ( if (
not check_extracted_var(bv) not check_extracted_var(bv)
or bv.high + 1 != prev_bv.low # type: ignore or bv.high + 1 != prev_bv.low # type: ignore
or z3.simplify(bv.parent.raw != prev_bv.parent.raw) # type: ignore or z3.simplify(bv.parent.raw != prev_bv.parent.raw) # type: ignore
): ):
if check_extracted_var(prev_bv) and hash(prev_bv) == hash( if check_extracted_var(prev_bv) and hash(prev_bv) == hash(
prev_bv.parent prev_bv.parent # type: ignore
): # type: ignore ):
new_bvs.append(prev_bv.parent) # type: ignore new_bvs.append(prev_bv.parent) # type: ignore
else: else:
new_bvs.append(prev_bv) new_bvs.append(prev_bv)
prev_bv = bv prev_bv = copy(bv)
continue continue
prev_bv.raw = z3.Concat(prev_bv.raw, bv.raw) prev_bv.raw = z3.Concat(prev_bv.raw, bv.raw)
prev_bv.low = bv.low prev_bv.low = bv.low # type: ignore
if check_extracted_var(prev_bv) and hash(z3.simplify(prev_bv.raw)) == hash(
prev_bv.parent # type: ignore
):
prev_bv = prev_bv.parent # type: ignore
new_bvs.append(prev_bv) new_bvs.append(prev_bv)
if len(new_bvs) == 1:
return prev_bv
if nested_functions: if nested_functions:
return BitVecFunc( return BitVecFunc(
raw=nraw, raw=nraw,
@ -174,6 +180,7 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
nested_functions=nested_functions, nested_functions=nested_functions,
concat_args=new_bvs, concat_args=new_bvs,
) )
return BitVec(raw=nraw, annotations=annotations, concat_args=new_bvs) return BitVec(raw=nraw, annotations=annotations, concat_args=new_bvs)
@ -217,9 +224,11 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
bv.simplify() bv.simplify()
if check_extracted_var(val) and hash(val.raw) == hash(val.parent.raw): if check_extracted_var(val) and hash(val.raw) == hash(val.parent.raw):
val = val.parent val = val.parent
assert val.size() == high - low + 1
return val return val
input_string = "" input_string = ""
bv.simplify() bv.simplify()
assert raw.size() == high - low + 1
if isinstance(bv, BitVecFunc): if isinstance(bv, BitVecFunc):
return BitVecFuncExtract( return BitVecFuncExtract(
raw=raw, raw=raw,

Loading…
Cancel
Save