Enhance the previous cases

extract_concat_invariance
Nikhil 5 years ago
parent bebfa1000f
commit 8322e08eeb
  1. 35
      mythril/laser/smt/bitvec.py
  2. 141
      mythril/laser/smt/bitvec_helper.py
  3. 5
      mythril/laser/smt/bitvecfunc.py

@ -1,7 +1,7 @@
"""This module provides classes for an SMT abstraction of bit vectors."""
from operator import lshift, rshift, ne, eq
from typing import Union, Set, cast, Any, Optional, Callable
from typing import Union, Set, cast, Any, Optional, Callable, List
import z3
@ -25,12 +25,14 @@ def _padded_operation(a: z3.BitVec, b: z3.BitVec, operator):
class BitVec(Expression[z3.BitVecRef]):
"""A bit vector symbol."""
def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None):
def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None, concat_args=None):
"""
:param raw:
:param annotations:
"""
self.potential_inputs = [] # type: List["BitVec"]
self.concat_args = concat_args or []
super().__init__(raw, annotations)
def size(self) -> int:
@ -277,5 +279,34 @@ class BitVec(Expression[z3.BitVecRef]):
return self.raw.__hash__()
class BitVecExtract(BitVec):
"""A bit vector function wrapper, useful for preserving Extract() and Concat() operations"""
def __init__(
self,
raw: z3.BitVecRef,
annotations: Optional[Annotations] = None,
concat_args: List = None,
low=None,
high=None,
parent=None,
):
"""
:param raw: The raw bit vector symbol
:param func_name: The function name. e.g. sha3
:param input: The input to the functions
:param annotations: The annotations the BitVecFunc should start with
"""
super().__init__(
raw=raw,
annotations=annotations,
concat_args=concat_args,
)
self.low = low
self.high = high
self.parent = parent
# TODO: Fix circular import issues
from mythril.laser.smt.bitvecfunc import BitVecFunc

@ -1,8 +1,8 @@
from typing import Union, overload, List, Set, cast, Any, Optional, Callable
from typing import Union, overload, List, Set, cast, Any, Callable
import z3
from mythril.laser.smt.bool import Bool, And, Or
from mythril.laser.smt.bitvec import BitVec
from mythril.laser.smt.bool import Bool, Or
from mythril.laser.smt.bitvec import BitVec, BitVecExtract
from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.bitvecfunc import BitVecFuncExtract
from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper
@ -109,6 +109,10 @@ def ULE(a: BitVec, b: BitVec) -> Bool:
return Or(ULT(a, b), a == b)
def check_extracted_var(bv: BitVec):
return isinstance(bv, BitVecFuncExtract) or isinstance(bv, BitVecExtract)
@overload
def Concat(*args: List[BitVec]) -> BitVec:
...
@ -131,65 +135,43 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
else:
bvs = cast(List[BitVec], args)
concat_list = bvs
nraw = z3.Concat([a.raw for a in bvs])
annotations = set() # type: Annotations
nested_functions = [] # type: List[BitVecFunc]
bfne_cnt = 0
parent = None
for bv in bvs:
annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
nested_functions += bv.nested_functions
nested_functions += [bv]
if isinstance(bv, BitVecFuncExtract):
if parent is None:
parent = bv.parent
if hash(parent.raw) != hash(bv.parent.raw):
continue
bfne_cnt += 1
if bfne_cnt == len(bvs):
# check for continuity
fail = True
if bvs[-1].low == 0:
fail = False
for index, bv in enumerate(bvs):
if index == 0:
continue
if bv.high + 1 != bvs[index - 1].low:
fail = True
break
if fail is False:
if bvs[0].high == bvs[0].parent.size() - 1:
return bvs[0].parent
new_bvs = []
prev_bv = bvs[0]
# casting everywhere in if's looks quite messy, so I am type ignoring
for bv in bvs[1:]:
if (
not check_extracted_var(bv)
or bv.high + 1 != prev_bv.low # type: ignore
or bv.parent != prev_bv.parent # type: ignore
):
if check_extracted_var(prev_bv) and hash(prev_bv) == hash(
prev_bv.parent
): # type: ignore
new_bvs.append(prev_bv.parent) # type: ignore
else:
return BitVecFuncExtract(
raw=nraw,
func_name=bvs[0].func_name,
input_=bvs[0].input_,
nested_functions=nested_functions,
concat_args=concat_list,
low=bvs[-1].low,
high=bvs[0].high,
parent=bvs[0].parent,
)
new_bvs.append(prev_bv)
prev_bv = bv
continue
prev_bv = Concat(prev_bv, bv)
new_bvs.append(prev_bv)
if nested_functions:
for bv in bvs:
bv.simplify()
return BitVecFunc(
raw=nraw,
func_name="Hybrid",
input_=BitVec(z3.BitVec("", 256), annotations=annotations),
nested_functions=nested_functions,
concat_args=concat_list,
concat_args=new_bvs,
)
return BitVec(nraw, annotations)
return BitVec(raw=nraw, annotations=annotations, concat_args=new_bvs)
def Extract(high: int, low: int, bv: BitVec) -> BitVec:
@ -202,43 +184,40 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
"""
raw = z3.Extract(high, low, bv.raw)
if isinstance(bv, BitVecFunc):
count = 0
val = None
for small_bv in bv.concat_args[::-1]:
if low == count:
if low + small_bv.size() <= high:
val = small_bv
else:
val = Extract(
count = 0
val = None
for small_bv in bv.concat_args[::-1]:
if low == count:
if low + small_bv.size() <= high:
val = small_bv
else:
val = Extract(
small_bv.size() - 1, small_bv.size() - (high - low + 1), small_bv
)
elif high < count:
break
elif low < count:
if low + small_bv.size() <= high:
val = Concat(small_bv, val)
else:
val = Concat(
Extract(
small_bv.size() - 1,
small_bv.size() - (high - low + 1),
small_bv,
)
elif high < count:
break
elif low < count:
if low + small_bv.size() <= high:
val = Concat(small_bv, val)
else:
val = Concat(
Extract(
small_bv.size() - 1,
small_bv.size() - (high - low + 1),
small_bv,
),
val,
)
count += small_bv.size()
if val is not None:
if isinstance(val, BitVecFuncExtract) and z3.simplify(
val.raw == val.parent.raw
):
val = val.parent
val.simplify()
return val
input_string = ""
# Is there a better value to set func_name and input to in this case?
),
val,
)
count += small_bv.size()
if val is not None:
val.simplify()
bv.simplify()
if check_extracted_var(val) and hash(val.raw) == hash(val.parent.raw):
val = val.parent
return val
input_string = ""
bv.simplify()
if isinstance(bv, BitVecFunc):
return BitVecFuncExtract(
raw=raw,
func_name="Hybrid",
@ -248,8 +227,8 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
high=high,
parent=bv,
)
return BitVec(raw, annotations=bv.annotations)
else:
return BitVecExtract(raw=raw, low=low, high=high, parent=bv)
def URem(a: BitVec, b: BitVec) -> BitVec:

@ -3,7 +3,7 @@ from itertools import product
from typing import Optional, Union, cast, Callable, List
import z3
from mythril.laser.smt.bitvec import BitVec, Annotations
from mythril.laser.smt.bitvec import BitVec, Annotations, BitVecExtract
from mythril.laser.smt.bool import Or, Bool, And
@ -139,10 +139,9 @@ class BitVecFunc(BitVec):
self.input_ = input_
self.nested_functions = nested_functions or []
self.nested_functions = list(dict.fromkeys(self.nested_functions))
self.concat_args = concat_args or []
if isinstance(input_, BitVecFunc):
self.nested_functions.extend(input_.nested_functions)
super().__init__(raw, annotations)
super().__init__(raw, annotations, concat_args=concat_args)
def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an addition expression.

Loading…
Cancel
Save