Enhance the previous cases

extract_concat_invariance
Nikhil 5 years ago
parent bebfa1000f
commit 8322e08eeb
  1. 35
      mythril/laser/smt/bitvec.py
  2. 141
      mythril/laser/smt/bitvec_helper.py
  3. 5
      mythril/laser/smt/bitvecfunc.py

@ -1,7 +1,7 @@
"""This module provides classes for an SMT abstraction of bit vectors.""" """This module provides classes for an SMT abstraction of bit vectors."""
from operator import lshift, rshift, ne, eq from operator import lshift, rshift, ne, eq
from typing import Union, Set, cast, Any, Optional, Callable from typing import Union, Set, cast, Any, Optional, Callable, List
import z3 import z3
@ -25,12 +25,14 @@ def _padded_operation(a: z3.BitVec, b: z3.BitVec, operator):
class BitVec(Expression[z3.BitVecRef]): class BitVec(Expression[z3.BitVecRef]):
"""A bit vector symbol.""" """A bit vector symbol."""
def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None): def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None, concat_args=None):
""" """
:param raw: :param raw:
:param annotations: :param annotations:
""" """
self.potential_inputs = [] # type: List["BitVec"]
self.concat_args = concat_args or []
super().__init__(raw, annotations) super().__init__(raw, annotations)
def size(self) -> int: def size(self) -> int:
@ -277,5 +279,34 @@ class BitVec(Expression[z3.BitVecRef]):
return self.raw.__hash__() return self.raw.__hash__()
class BitVecExtract(BitVec):
"""A bit vector function wrapper, useful for preserving Extract() and Concat() operations"""
def __init__(
self,
raw: z3.BitVecRef,
annotations: Optional[Annotations] = None,
concat_args: List = None,
low=None,
high=None,
parent=None,
):
"""
:param raw: The raw bit vector symbol
:param func_name: The function name. e.g. sha3
:param input: The input to the functions
:param annotations: The annotations the BitVecFunc should start with
"""
super().__init__(
raw=raw,
annotations=annotations,
concat_args=concat_args,
)
self.low = low
self.high = high
self.parent = parent
# TODO: Fix circular import issues # TODO: Fix circular import issues
from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.bitvecfunc import BitVecFunc

@ -1,8 +1,8 @@
from typing import Union, overload, List, Set, cast, Any, Optional, Callable from typing import Union, overload, List, Set, cast, Any, Callable
import z3 import z3
from mythril.laser.smt.bool import Bool, And, Or from mythril.laser.smt.bool import Bool, Or
from mythril.laser.smt.bitvec import BitVec from mythril.laser.smt.bitvec import BitVec, BitVecExtract
from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.bitvecfunc import BitVecFuncExtract from mythril.laser.smt.bitvecfunc import BitVecFuncExtract
from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper
@ -109,6 +109,10 @@ def ULE(a: BitVec, b: BitVec) -> Bool:
return Or(ULT(a, b), a == b) return Or(ULT(a, b), a == b)
def check_extracted_var(bv: BitVec):
return isinstance(bv, BitVecFuncExtract) or isinstance(bv, BitVecExtract)
@overload @overload
def Concat(*args: List[BitVec]) -> BitVec: def Concat(*args: List[BitVec]) -> BitVec:
... ...
@ -131,65 +135,43 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
else: else:
bvs = cast(List[BitVec], args) bvs = cast(List[BitVec], args)
concat_list = bvs
nraw = z3.Concat([a.raw for a in bvs]) nraw = z3.Concat([a.raw for a in bvs])
annotations = set() # type: Annotations annotations = set() # type: Annotations
nested_functions = [] # type: List[BitVecFunc] nested_functions = [] # type: List[BitVecFunc]
bfne_cnt = 0
parent = None
for bv in bvs: for bv in bvs:
annotations = annotations.union(bv.annotations) annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc): if isinstance(bv, BitVecFunc):
nested_functions += bv.nested_functions nested_functions += bv.nested_functions
nested_functions += [bv] nested_functions += [bv]
if isinstance(bv, BitVecFuncExtract): new_bvs = []
if parent is None: prev_bv = bvs[0]
parent = bv.parent # casting everywhere in if's looks quite messy, so I am type ignoring
if hash(parent.raw) != hash(bv.parent.raw): for bv in bvs[1:]:
continue if (
bfne_cnt += 1 not check_extracted_var(bv)
or bv.high + 1 != prev_bv.low # type: ignore
if bfne_cnt == len(bvs): or bv.parent != prev_bv.parent # type: ignore
# check for continuity ):
fail = True if check_extracted_var(prev_bv) and hash(prev_bv) == hash(
if bvs[-1].low == 0: prev_bv.parent
fail = False ): # type: ignore
for index, bv in enumerate(bvs): new_bvs.append(prev_bv.parent) # type: ignore
if index == 0:
continue
if bv.high + 1 != bvs[index - 1].low:
fail = True
break
if fail is False:
if bvs[0].high == bvs[0].parent.size() - 1:
return bvs[0].parent
else: else:
return BitVecFuncExtract( new_bvs.append(prev_bv)
raw=nraw, prev_bv = bv
func_name=bvs[0].func_name, continue
input_=bvs[0].input_, prev_bv = Concat(prev_bv, bv)
nested_functions=nested_functions, new_bvs.append(prev_bv)
concat_args=concat_list,
low=bvs[-1].low,
high=bvs[0].high,
parent=bvs[0].parent,
)
if nested_functions: if nested_functions:
for bv in bvs:
bv.simplify()
return BitVecFunc( return BitVecFunc(
raw=nraw, raw=nraw,
func_name="Hybrid", func_name="Hybrid",
input_=BitVec(z3.BitVec("", 256), annotations=annotations), input_=BitVec(z3.BitVec("", 256), annotations=annotations),
nested_functions=nested_functions, nested_functions=nested_functions,
concat_args=concat_list, concat_args=new_bvs,
) )
return BitVec(raw=nraw, annotations=annotations, concat_args=new_bvs)
return BitVec(nraw, annotations)
def Extract(high: int, low: int, bv: BitVec) -> BitVec: def Extract(high: int, low: int, bv: BitVec) -> BitVec:
@ -202,43 +184,40 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
""" """
raw = z3.Extract(high, low, bv.raw) raw = z3.Extract(high, low, bv.raw)
if isinstance(bv, BitVecFunc): count = 0
count = 0 val = None
val = None for small_bv in bv.concat_args[::-1]:
for small_bv in bv.concat_args[::-1]: if low == count:
if low == count: if low + small_bv.size() <= high:
if low + small_bv.size() <= high: val = small_bv
val = small_bv else:
else: val = Extract(
val = Extract( small_bv.size() - 1, small_bv.size() - (high - low + 1), small_bv
)
elif high < count:
break
elif low < count:
if low + small_bv.size() <= high:
val = Concat(small_bv, val)
else:
val = Concat(
Extract(
small_bv.size() - 1, small_bv.size() - 1,
small_bv.size() - (high - low + 1), small_bv.size() - (high - low + 1),
small_bv, small_bv,
) ),
elif high < count: val,
break )
elif low < count: count += small_bv.size()
if low + small_bv.size() <= high: if val is not None:
val = Concat(small_bv, val) val.simplify()
else: bv.simplify()
val = Concat( if check_extracted_var(val) and hash(val.raw) == hash(val.parent.raw):
Extract( val = val.parent
small_bv.size() - 1, return val
small_bv.size() - (high - low + 1), input_string = ""
small_bv, bv.simplify()
), if isinstance(bv, BitVecFunc):
val,
)
count += small_bv.size()
if val is not None:
if isinstance(val, BitVecFuncExtract) and z3.simplify(
val.raw == val.parent.raw
):
val = val.parent
val.simplify()
return val
input_string = ""
# Is there a better value to set func_name and input to in this case?
return BitVecFuncExtract( return BitVecFuncExtract(
raw=raw, raw=raw,
func_name="Hybrid", func_name="Hybrid",
@ -248,8 +227,8 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
high=high, high=high,
parent=bv, parent=bv,
) )
else:
return BitVec(raw, annotations=bv.annotations) return BitVecExtract(raw=raw, low=low, high=high, parent=bv)
def URem(a: BitVec, b: BitVec) -> BitVec: def URem(a: BitVec, b: BitVec) -> BitVec:

@ -3,7 +3,7 @@ from itertools import product
from typing import Optional, Union, cast, Callable, List from typing import Optional, Union, cast, Callable, List
import z3 import z3
from mythril.laser.smt.bitvec import BitVec, Annotations from mythril.laser.smt.bitvec import BitVec, Annotations, BitVecExtract
from mythril.laser.smt.bool import Or, Bool, And from mythril.laser.smt.bool import Or, Bool, And
@ -139,10 +139,9 @@ class BitVecFunc(BitVec):
self.input_ = input_ self.input_ = input_
self.nested_functions = nested_functions or [] self.nested_functions = nested_functions or []
self.nested_functions = list(dict.fromkeys(self.nested_functions)) self.nested_functions = list(dict.fromkeys(self.nested_functions))
self.concat_args = concat_args or []
if isinstance(input_, BitVecFunc): if isinstance(input_, BitVecFunc):
self.nested_functions.extend(input_.nested_functions) self.nested_functions.extend(input_.nested_functions)
super().__init__(raw, annotations) super().__init__(raw, annotations, concat_args=concat_args)
def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an addition expression. """Create an addition expression.

Loading…
Cancel
Save