Bye Bye BitVecFuncs

remove/dos
Nikhil Parasaram 5 years ago
parent d20e601a6b
commit 76a6cbe042
  1. 1
      mythril/laser/smt/__init__.py
  2. 28
      mythril/laser/smt/bitvec.py
  3. 74
      mythril/laser/smt/bitvec_helper.py
  4. 297
      mythril/laser/smt/bitvecfunc.py

@ -18,7 +18,6 @@ from mythril.laser.smt.bitvec_helper import (
LShR, LShR,
) )
from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.expression import Expression, simplify
from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And
from mythril.laser.smt.array import K, Array, BaseArray from mythril.laser.smt.array import K, Array, BaseArray

@ -66,8 +66,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other + self
if isinstance(other, int): if isinstance(other, int):
return BitVec(self.raw + other, annotations=self.annotations) return BitVec(self.raw + other, annotations=self.annotations)
@ -80,8 +78,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other - self
if isinstance(other, int): if isinstance(other, int):
return BitVec(self.raw - other, annotations=self.annotations) return BitVec(self.raw - other, annotations=self.annotations)
@ -94,8 +90,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other * self
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
return BitVec(self.raw * other.raw, annotations=union) return BitVec(self.raw * other.raw, annotations=union)
@ -105,8 +99,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other / self
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
return BitVec(self.raw / other.raw, annotations=union) return BitVec(self.raw / other.raw, annotations=union)
@ -116,8 +108,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other & self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -129,8 +119,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other | self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -142,8 +130,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other ^ self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -155,8 +141,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other > self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -168,8 +152,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other < self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -204,8 +186,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other == self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
return Bool( return Bool(
cast(z3.BoolRef, self.raw == other), annotations=self.annotations cast(z3.BoolRef, self.raw == other), annotations=self.annotations
@ -224,8 +204,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other != self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
return Bool( return Bool(
cast(z3.BoolRef, self.raw != other), annotations=self.annotations cast(z3.BoolRef, self.raw != other), annotations=self.annotations
@ -244,8 +222,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param operator: The shift operator :param operator: The shift operator
:return: the resulting output :return: the resulting output
""" """
if isinstance(other, BitVecFunc):
return operator(other, self)
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
return BitVec( return BitVec(
operator(self.raw, other), annotations=self.annotations operator(self.raw, other), annotations=self.annotations
@ -275,7 +251,3 @@ class BitVec(Expression[z3.BitVecRef]):
:return: :return:
""" """
return self.raw.__hash__() return self.raw.__hash__()
# TODO: Fix circular import issues
from mythril.laser.smt.bitvecfunc import BitVecFunc

@ -3,9 +3,6 @@ import z3
from mythril.laser.smt.bool import Bool, Or from mythril.laser.smt.bool import Bool, Or
from mythril.laser.smt.bitvec import BitVec from mythril.laser.smt.bitvec import BitVec
from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper
from mythril.laser.smt.bitvecfunc import _comparison_helper as _func_comparison_helper
Annotations = Set[Any] Annotations = Set[Any]
@ -14,20 +11,12 @@ def _comparison_helper(
a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool
) -> Bool: ) -> Bool:
annotations = a.annotations.union(b.annotations) annotations = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc):
return _func_comparison_helper(a, b, operation, default_value, inputs_equal)
return Bool(operation(a.raw, b.raw), annotations) return Bool(operation(a.raw, b.raw), annotations)
def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec: def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec:
raw = operation(a.raw, b.raw) raw = operation(a.raw, b.raw)
union = a.annotations.union(b.annotations) union = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc):
return _func_arithmetic_helper(a, b, operation)
elif isinstance(b, BitVecFunc):
return _func_arithmetic_helper(b, a, operation)
return BitVec(raw, annotations=union) return BitVec(raw, annotations=union)
@ -43,8 +32,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi
:param c: :param c:
:return: :return:
""" """
# TODO: Handle BitVecFunc
if not isinstance(a, Bool): if not isinstance(a, Bool):
a = Bool(z3.BoolVal(a)) a = Bool(z3.BoolVal(a))
if not isinstance(b, BitVec): if not isinstance(b, BitVec):
@ -52,19 +39,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi
if not isinstance(c, BitVec): if not isinstance(c, BitVec):
c = BitVec(z3.BitVecVal(c, 256)) c = BitVec(z3.BitVecVal(c, 256))
union = a.annotations.union(b.annotations).union(c.annotations) union = a.annotations.union(b.annotations).union(c.annotations)
bvf = [] # type: List[BitVecFunc]
if isinstance(a, BitVecFunc):
bvf += [a]
if isinstance(b, BitVecFunc):
bvf += [b]
if isinstance(c, BitVecFunc):
bvf += [c]
if bvf:
raw = z3.If(a.raw, b.raw, c.raw)
nested_functions = [nf for func in bvf for nf in func.nested_functions] + bvf
return BitVecFunc(raw, func_name="Hybrid", nested_functions=nested_functions)
return BitVec(z3.If(a.raw, b.raw, c.raw), union) return BitVec(z3.If(a.raw, b.raw, c.raw), union)
@ -133,21 +107,8 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
nraw = z3.Concat([a.raw for a in bvs]) nraw = z3.Concat([a.raw for a in bvs])
annotations = set() # type: Annotations annotations = set() # type: Annotations
nested_functions = [] # type: List[BitVecFunc]
for bv in bvs: for bv in bvs:
annotations = annotations.union(bv.annotations) annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
nested_functions += bv.nested_functions
nested_functions += [bv]
if nested_functions:
return BitVecFunc(
raw=nraw,
func_name="Hybrid",
input_=BitVec(z3.BitVec("", 256), annotations=annotations),
nested_functions=nested_functions,
)
return BitVec(nraw, annotations) return BitVec(nraw, annotations)
@ -160,16 +121,6 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
:return: :return:
""" """
raw = z3.Extract(high, low, bv.raw) raw = z3.Extract(high, low, bv.raw)
if isinstance(bv, BitVecFunc):
input_string = ""
# Is there a better value to set func_name and input to in this case?
return BitVecFunc(
raw=raw,
func_name="Hybrid",
input_=BitVec(z3.BitVec(input_string, 256), annotations=bv.annotations),
nested_functions=bv.nested_functions + [bv],
)
return BitVec(raw, annotations=bv.annotations) return BitVec(raw, annotations=bv.annotations)
@ -210,34 +161,9 @@ def Sum(*args: BitVec) -> BitVec:
""" """
raw = z3.Sum([a.raw for a in args]) raw = z3.Sum([a.raw for a in args])
annotations = set() # type: Annotations annotations = set() # type: Annotations
bitvecfuncs = []
for bv in args: for bv in args:
annotations = annotations.union(bv.annotations) annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
bitvecfuncs.append(bv)
nested_functions = [
nf for func in bitvecfuncs for nf in func.nested_functions
] + bitvecfuncs
if len(bitvecfuncs) >= 2:
return BitVecFunc(
raw=raw,
func_name="Hybrid",
input_=None,
annotations=annotations,
nested_functions=nested_functions,
)
elif len(bitvecfuncs) == 1:
return BitVecFunc(
raw=raw,
func_name=bitvecfuncs[0].func_name,
input_=bitvecfuncs[0].input_,
annotations=annotations,
nested_functions=nested_functions,
)
return BitVec(raw, annotations) return BitVec(raw, annotations)

@ -1,297 +0,0 @@
import operator
from itertools import product
from typing import Optional, Union, cast, Callable, List
import z3
from mythril.laser.smt.bitvec import BitVec, Annotations, _padded_operation
from mythril.laser.smt.bool import Or, Bool, And
def _arithmetic_helper(
a: "BitVecFunc", b: Union[BitVec, int], operation: Callable
) -> "BitVecFunc":
"""
Helper function for arithmetic operations on BitVecFuncs.
:param a: The BitVecFunc to perform the operation on.
:param b: A BitVec or int to perform the operation on.
:param operation: The arithmetic operation to perform.
:return: The resulting BitVecFunc
"""
if isinstance(b, int):
b = BitVec(z3.BitVecVal(b, a.size()))
raw = operation(a.raw, b.raw)
union = a.annotations.union(b.annotations)
if isinstance(b, BitVecFunc):
return BitVecFunc(
raw=raw,
func_name="Hybrid",
input_=BitVec(z3.BitVec("", 256), annotations=union),
nested_functions=a.nested_functions + b.nested_functions + [a, b],
)
return BitVecFunc(
raw=raw,
func_name=a.func_name,
input_=a.input_,
annotations=union,
nested_functions=a.nested_functions + [a],
)
def _comparison_helper(
a: "BitVecFunc",
b: Union[BitVec, int],
operation: Callable,
default_value: bool,
inputs_equal: bool,
) -> Bool:
"""
Helper function for comparison operations with BitVecFuncs.
:param a: The BitVecFunc to compare.
:param b: A BitVec or int to compare to.
:param operation: The comparison operation to perform.
:return: The resulting Bool
"""
# Is there some hack for gt/lt comparisons?
if isinstance(b, int):
b = BitVec(z3.BitVecVal(b, a.size()))
union = a.annotations.union(b.annotations)
if not a.symbolic and not b.symbolic:
if operation == z3.UGT:
operation = operator.gt
if operation == z3.ULT:
operation = operator.lt
return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union)
if (
not isinstance(b, BitVecFunc)
or not a.func_name
or not a.input_
or not a.func_name == b.func_name
or str(operation) not in ("<built-in function eq>", "<built-in function ne>")
):
return Bool(z3.BoolVal(default_value), annotations=union)
condition = True
for a_nest, b_nest in product(a.nested_functions, b.nested_functions):
if a_nest.func_name != b_nest.func_name:
continue
if a_nest.func_name == "Hybrid":
continue
# a.input (eq/neq) b.input ==> a == b
if inputs_equal:
condition = z3.And(
condition,
z3.Or(
z3.Not((a_nest.input_ == b_nest.input_).raw),
(a_nest.raw == b_nest.raw),
),
z3.Or(
z3.Not((a_nest.raw == b_nest.raw)),
(a_nest.input_ == b_nest.input_).raw,
),
)
else:
condition = z3.And(
condition,
z3.Or(
z3.Not((a_nest.input_ != b_nest.input_).raw),
(a_nest.raw == b_nest.raw),
),
z3.Or(
z3.Not((a_nest.raw == b_nest.raw)),
(a_nest.input_ != b_nest.input_).raw,
),
)
return And(
Bool(
cast(z3.BoolRef, _padded_operation(a.raw, b.raw, operation)),
annotations=union,
),
Bool(condition) if b.nested_functions else Bool(True),
a.input_ == b.input_ if inputs_equal else a.input_ != b.input_,
)
class BitVecFunc(BitVec):
"""A bit vector function symbol. Used in place of functions like sha3."""
def __init__(
self,
raw: z3.BitVecRef,
func_name: Optional[str],
input_: "BitVec" = None,
annotations: Optional[Annotations] = None,
nested_functions: Optional[List["BitVecFunc"]] = None,
):
"""
:param raw: The raw bit vector symbol
:param func_name: The function name. e.g. sha3
:param input: The input to the functions
:param annotations: The annotations the BitVecFunc should start with
"""
self.func_name = func_name
self.input_ = input_
self.nested_functions = nested_functions or []
self.nested_functions = list(dict.fromkeys(self.nested_functions))
if isinstance(input_, BitVecFunc):
self.nested_functions.extend(input_.nested_functions)
super().__init__(raw, annotations)
def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an addition expression.
:param other: The int or BitVec to add to this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.add)
def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create a subtraction expression.
:param other: The int or BitVec to subtract from this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.sub)
def __mul__(self, other: "BitVec") -> "BitVecFunc":
"""Create a multiplication expression.
:param other: The int or BitVec to multiply to this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.mul)
def __truediv__(self, other: "BitVec") -> "BitVecFunc":
"""Create a signed division expression.
:param other: The int or BitVec to divide this BitVecFunc by
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.truediv)
def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an and expression.
:param other: The int or BitVec to and with this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.and_)
def __or__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an or expression.
:param other: The int or BitVec to or with this BitVecFunc
:return: The resulting BitVecFunc
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _arithmetic_helper(self, other, operator.or_)
def __xor__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create a xor expression.
:param other: The int or BitVec to xor with this BitVecFunc
:return: The resulting BitVecFunc
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _arithmetic_helper(self, other, operator.xor)
def __lt__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed less than expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.lt, default_value=False, inputs_equal=False
)
def __gt__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed greater than expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.gt, default_value=False, inputs_equal=False
)
def __le__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed less than or equal to expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return Or(self < other, self == other)
def __ge__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed greater than or equal to expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return Or(self > other, self == other)
# MYPY: fix complains about overriding __eq__
def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
"""Create an equality expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.eq, default_value=False, inputs_equal=True
)
# MYPY: fix complains about overriding __ne__
def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
"""Create an inequality expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.ne, default_value=True, inputs_equal=False
)
def __lshift__(self, other: Union[int, "BitVec"]) -> "BitVec":
"""
Left shift operation
:param other: The int or BitVec to shift on
:return The resulting left shifted output
"""
return _arithmetic_helper(self, other, operator.lshift)
def __rshift__(self, other: Union[int, "BitVec"]) -> "BitVec":
"""
Right shift operation
:param other: The int or BitVec to shift on
:return The resulting right shifted output:
"""
return _arithmetic_helper(self, other, operator.rshift)
def __hash__(self) -> int:
return self.raw.__hash__()
Loading…
Cancel
Save