Remove unused code (#1263)

* Fix a regression

* Handle new black format

* Bye Bye BitVecFuncs

* Bye Bye bitvecfunc tests

* Bye Bye complex code
pull/1283/head
Nikhil Parasaram 5 years ago committed by GitHub
parent 8d5a51a619
commit 634d59caa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 58
      mythril/laser/ethereum/state/account.py
  2. 63
      mythril/laser/smt/__init__.py
  3. 28
      mythril/laser/smt/bitvec.py
  4. 82
      mythril/laser/smt/bitvec_helper.py
  5. 297
      mythril/laser/smt/bitvecfunc.py
  6. 237
      tests/laser/smt/bitvecfunc_test.py

@ -4,7 +4,7 @@ This includes classes representing accounts and their storage.
""" """
import logging import logging
from copy import copy, deepcopy from copy import copy, deepcopy
from typing import Any, Dict, Union, Tuple, Set, cast from typing import Any, Dict, Union, Set
from mythril.laser.smt import ( from mythril.laser.smt import (
@ -12,10 +12,7 @@ from mythril.laser.smt import (
K, K,
BitVec, BitVec,
simplify, simplify,
BitVecFunc,
Extract,
BaseArray, BaseArray,
Concat,
) )
from mythril.disassembler.disassembly import Disassembly from mythril.disassembler.disassembly import Disassembly
from mythril.laser.smt import symbol_factory from mythril.laser.smt import symbol_factory
@ -23,26 +20,6 @@ from mythril.laser.smt import symbol_factory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class StorageRegion:
def __getitem__(self, item):
raise NotImplementedError
def __setitem__(self, key, value):
raise NotImplementedError
class ArrayStorageRegion(StorageRegion):
""" An ArrayStorageRegion is a storage region that leverages smt array theory to resolve expressions"""
pass
class IteStorageRegion(StorageRegion):
""" An IteStorageRegion is a storage region that uses Ite statements to implement a storage"""
pass
class Storage: class Storage:
"""Storage class represents the storage of an Account.""" """Storage class represents the storage of an Account."""
@ -62,18 +39,8 @@ class Storage:
self.storage_keys_loaded = set() # type: Set[int] self.storage_keys_loaded = set() # type: Set[int]
self.address = address self.address = address
@staticmethod
def _sanitize(input_: BitVec) -> BitVec:
if input_.size() == 512:
return input_
if input_.size() > 512:
return Extract(511, 0, input_)
else:
return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_)
def __getitem__(self, item: BitVec) -> BitVec: def __getitem__(self, item: BitVec) -> BitVec:
storage = self._standard_storage storage = self._standard_storage
sanitized_item = item
if ( if (
self.address self.address
and self.address.value != 0 and self.address.value != 0
@ -82,7 +49,7 @@ class Storage:
and (self.dynld and self.dynld.storage_loading) and (self.dynld and self.dynld.storage_loading)
): ):
try: try:
storage[sanitized_item] = symbol_factory.BitVecVal( storage[item] = symbol_factory.BitVecVal(
int( int(
self.dynld.read_storage( self.dynld.read_storage(
contract_address="0x{:040X}".format(self.address.value), contract_address="0x{:040X}".format(self.address.value),
@ -93,29 +60,14 @@ class Storage:
256, 256,
) )
self.storage_keys_loaded.add(int(item.value)) self.storage_keys_loaded.add(int(item.value))
self.printable_storage[item] = storage[sanitized_item] self.printable_storage[item] = storage[item]
except ValueError as e: except ValueError as e:
log.debug("Couldn't read storage at %s: %s", item, e) log.debug("Couldn't read storage at %s: %s", item, e)
return simplify(storage[sanitized_item]) return simplify(storage[item])
@staticmethod
def get_map_index(key: BitVec) -> BitVec:
if (
not isinstance(key, BitVecFunc)
or key.func_name != "keccak256"
or key.input_ is None
):
return None
index = Extract(255, 0, key.input_)
return simplify(index)
def _get_corresponding_storage(self, key: BitVec) -> BaseArray:
return self._standard_storage
def __setitem__(self, key, value: Any) -> None: def __setitem__(self, key, value: Any) -> None:
storage = self._get_corresponding_storage(key)
self.printable_storage[key] = value self.printable_storage[key] = value
storage[key] = value self._standard_storage[key] = value
if key.symbolic is False: if key.symbolic is False:
self.storage_keys_loaded.add(int(key.value)) self.storage_keys_loaded.add(int(key.value))

@ -18,7 +18,6 @@ from mythril.laser.smt.bitvec_helper import (
LShR, LShR,
) )
from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.expression import Expression, simplify
from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And
from mythril.laser.smt.array import K, Array, BaseArray from mythril.laser.smt.array import K, Array, BaseArray
@ -80,44 +79,6 @@ class SymbolFactory(Generic[T, U]):
""" """
raise NotImplementedError() raise NotImplementedError()
@staticmethod
def BitVecFuncVal(
value: int,
func_name: str,
size: int,
annotations: Annotations = None,
input_: "BitVec" = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value.
:param value: The concrete value to set the bit vector to
:param func_name: The name of the bit vector function
:param size: The size of the bit vector
:param annotations: The annotations to initialize the bit vector with
:param input_: The input to the bit vector function
:return: The freshly created bit vector function
"""
raise NotImplementedError()
@staticmethod
def BitVecFuncSym(
name: str,
func_name: str,
size: int,
annotations: Annotations = None,
input_: "BitVec" = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value.
:param name: The name of the symbolic bit vector
:param func_name: The name of the bit vector function
:param size: The size of the bit vector
:param annotations: The annotations to initialize the bit vector with
:param input_: The input to the bit vector function
:return: The freshly created bit vector function
"""
raise NotImplementedError()
class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]): class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]):
""" """
@ -159,30 +120,6 @@ class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]):
raw = z3.BitVec(name, size) raw = z3.BitVec(name, size)
return BitVec(raw, annotations) return BitVec(raw, annotations)
@staticmethod
def BitVecFuncVal(
value: int,
func_name: str,
size: int,
annotations: Annotations = None,
input_: "BitVec" = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a concrete value."""
raw = z3.BitVecVal(value, size)
return BitVecFunc(raw, func_name, input_, annotations)
@staticmethod
def BitVecFuncSym(
name: str,
func_name: str,
size: int,
annotations: Annotations = None,
input_: "BitVec" = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value."""
raw = z3.BitVec(name, size)
return BitVecFunc(raw, func_name, input_, annotations)
class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]): class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]):
""" """

@ -66,8 +66,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other + self
if isinstance(other, int): if isinstance(other, int):
return BitVec(self.raw + other, annotations=self.annotations) return BitVec(self.raw + other, annotations=self.annotations)
@ -80,8 +78,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other - self
if isinstance(other, int): if isinstance(other, int):
return BitVec(self.raw - other, annotations=self.annotations) return BitVec(self.raw - other, annotations=self.annotations)
@ -94,8 +90,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other * self
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
return BitVec(self.raw * other.raw, annotations=union) return BitVec(self.raw * other.raw, annotations=union)
@ -105,8 +99,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other / self
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
return BitVec(self.raw / other.raw, annotations=union) return BitVec(self.raw / other.raw, annotations=union)
@ -116,8 +108,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other & self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -129,8 +119,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other | self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -142,8 +130,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other ^ self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -155,8 +141,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other > self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -168,8 +152,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other < self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size())) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
@ -204,8 +186,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other == self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
return Bool( return Bool(
cast(z3.BoolRef, self.raw == other), annotations=self.annotations cast(z3.BoolRef, self.raw == other), annotations=self.annotations
@ -224,8 +204,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other != self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
return Bool( return Bool(
cast(z3.BoolRef, self.raw != other), annotations=self.annotations cast(z3.BoolRef, self.raw != other), annotations=self.annotations
@ -244,8 +222,6 @@ class BitVec(Expression[z3.BitVecRef]):
:param operator: The shift operator :param operator: The shift operator
:return: the resulting output :return: the resulting output
""" """
if isinstance(other, BitVecFunc):
return operator(other, self)
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
return BitVec( return BitVec(
operator(self.raw, other), annotations=self.annotations operator(self.raw, other), annotations=self.annotations
@ -275,7 +251,3 @@ class BitVec(Expression[z3.BitVecRef]):
:return: :return:
""" """
return self.raw.__hash__() return self.raw.__hash__()
# TODO: Fix circular import issues
from mythril.laser.smt.bitvecfunc import BitVecFunc

@ -3,31 +3,18 @@ import z3
from mythril.laser.smt.bool import Bool, Or from mythril.laser.smt.bool import Bool, Or
from mythril.laser.smt.bitvec import BitVec from mythril.laser.smt.bitvec import BitVec
from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper
from mythril.laser.smt.bitvecfunc import _comparison_helper as _func_comparison_helper
Annotations = Set[Any] Annotations = Set[Any]
def _comparison_helper( def _comparison_helper(a: BitVec, b: BitVec, operation: Callable) -> Bool:
a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool
) -> Bool:
annotations = a.annotations.union(b.annotations) annotations = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc):
return _func_comparison_helper(a, b, operation, default_value, inputs_equal)
return Bool(operation(a.raw, b.raw), annotations) return Bool(operation(a.raw, b.raw), annotations)
def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec: def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec:
raw = operation(a.raw, b.raw) raw = operation(a.raw, b.raw)
union = a.annotations.union(b.annotations) union = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc):
return _func_arithmetic_helper(a, b, operation)
elif isinstance(b, BitVecFunc):
return _func_arithmetic_helper(b, a, operation)
return BitVec(raw, annotations=union) return BitVec(raw, annotations=union)
@ -43,8 +30,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi
:param c: :param c:
:return: :return:
""" """
# TODO: Handle BitVecFunc
if not isinstance(a, Bool): if not isinstance(a, Bool):
a = Bool(z3.BoolVal(a)) a = Bool(z3.BoolVal(a))
if not isinstance(b, BitVec): if not isinstance(b, BitVec):
@ -52,19 +37,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi
if not isinstance(c, BitVec): if not isinstance(c, BitVec):
c = BitVec(z3.BitVecVal(c, 256)) c = BitVec(z3.BitVecVal(c, 256))
union = a.annotations.union(b.annotations).union(c.annotations) union = a.annotations.union(b.annotations).union(c.annotations)
bvf = [] # type: List[BitVecFunc]
if isinstance(a, BitVecFunc):
bvf += [a]
if isinstance(b, BitVecFunc):
bvf += [b]
if isinstance(c, BitVecFunc):
bvf += [c]
if bvf:
raw = z3.If(a.raw, b.raw, c.raw)
nested_functions = [nf for func in bvf for nf in func.nested_functions] + bvf
return BitVecFunc(raw, func_name="Hybrid", nested_functions=nested_functions)
return BitVec(z3.If(a.raw, b.raw, c.raw), union) return BitVec(z3.If(a.raw, b.raw, c.raw), union)
@ -75,7 +47,7 @@ def UGT(a: BitVec, b: BitVec) -> Bool:
:param b: :param b:
:return: :return:
""" """
return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False) return _comparison_helper(a, b, z3.UGT)
def UGE(a: BitVec, b: BitVec) -> Bool: def UGE(a: BitVec, b: BitVec) -> Bool:
@ -95,7 +67,7 @@ def ULT(a: BitVec, b: BitVec) -> Bool:
:param b: :param b:
:return: :return:
""" """
return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False) return _comparison_helper(a, b, z3.ULT)
def ULE(a: BitVec, b: BitVec) -> Bool: def ULE(a: BitVec, b: BitVec) -> Bool:
@ -133,21 +105,8 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
nraw = z3.Concat([a.raw for a in bvs]) nraw = z3.Concat([a.raw for a in bvs])
annotations = set() # type: Annotations annotations = set() # type: Annotations
nested_functions = [] # type: List[BitVecFunc]
for bv in bvs: for bv in bvs:
annotations = annotations.union(bv.annotations) annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
nested_functions += bv.nested_functions
nested_functions += [bv]
if nested_functions:
return BitVecFunc(
raw=nraw,
func_name="Hybrid",
input_=BitVec(z3.BitVec("", 256), annotations=annotations),
nested_functions=nested_functions,
)
return BitVec(nraw, annotations) return BitVec(nraw, annotations)
@ -160,16 +119,6 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
:return: :return:
""" """
raw = z3.Extract(high, low, bv.raw) raw = z3.Extract(high, low, bv.raw)
if isinstance(bv, BitVecFunc):
input_string = ""
# Is there a better value to set func_name and input to in this case?
return BitVecFunc(
raw=raw,
func_name="Hybrid",
input_=BitVec(z3.BitVec(input_string, 256), annotations=bv.annotations),
nested_functions=bv.nested_functions + [bv],
)
return BitVec(raw, annotations=bv.annotations) return BitVec(raw, annotations=bv.annotations)
@ -210,34 +159,9 @@ def Sum(*args: BitVec) -> BitVec:
""" """
raw = z3.Sum([a.raw for a in args]) raw = z3.Sum([a.raw for a in args])
annotations = set() # type: Annotations annotations = set() # type: Annotations
bitvecfuncs = []
for bv in args: for bv in args:
annotations = annotations.union(bv.annotations) annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
bitvecfuncs.append(bv)
nested_functions = [
nf for func in bitvecfuncs for nf in func.nested_functions
] + bitvecfuncs
if len(bitvecfuncs) >= 2:
return BitVecFunc(
raw=raw,
func_name="Hybrid",
input_=None,
annotations=annotations,
nested_functions=nested_functions,
)
elif len(bitvecfuncs) == 1:
return BitVecFunc(
raw=raw,
func_name=bitvecfuncs[0].func_name,
input_=bitvecfuncs[0].input_,
annotations=annotations,
nested_functions=nested_functions,
)
return BitVec(raw, annotations) return BitVec(raw, annotations)

@ -1,297 +0,0 @@
import operator
from itertools import product
from typing import Optional, Union, cast, Callable, List
import z3
from mythril.laser.smt.bitvec import BitVec, Annotations, _padded_operation
from mythril.laser.smt.bool import Or, Bool, And
def _arithmetic_helper(
a: "BitVecFunc", b: Union[BitVec, int], operation: Callable
) -> "BitVecFunc":
"""
Helper function for arithmetic operations on BitVecFuncs.
:param a: The BitVecFunc to perform the operation on.
:param b: A BitVec or int to perform the operation on.
:param operation: The arithmetic operation to perform.
:return: The resulting BitVecFunc
"""
if isinstance(b, int):
b = BitVec(z3.BitVecVal(b, a.size()))
raw = operation(a.raw, b.raw)
union = a.annotations.union(b.annotations)
if isinstance(b, BitVecFunc):
return BitVecFunc(
raw=raw,
func_name="Hybrid",
input_=BitVec(z3.BitVec("", 256), annotations=union),
nested_functions=a.nested_functions + b.nested_functions + [a, b],
)
return BitVecFunc(
raw=raw,
func_name=a.func_name,
input_=a.input_,
annotations=union,
nested_functions=a.nested_functions + [a],
)
def _comparison_helper(
a: "BitVecFunc",
b: Union[BitVec, int],
operation: Callable,
default_value: bool,
inputs_equal: bool,
) -> Bool:
"""
Helper function for comparison operations with BitVecFuncs.
:param a: The BitVecFunc to compare.
:param b: A BitVec or int to compare to.
:param operation: The comparison operation to perform.
:return: The resulting Bool
"""
# Is there some hack for gt/lt comparisons?
if isinstance(b, int):
b = BitVec(z3.BitVecVal(b, a.size()))
union = a.annotations.union(b.annotations)
if not a.symbolic and not b.symbolic:
if operation == z3.UGT:
operation = operator.gt
if operation == z3.ULT:
operation = operator.lt
return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union)
if (
not isinstance(b, BitVecFunc)
or not a.func_name
or not a.input_
or not a.func_name == b.func_name
or str(operation) not in ("<built-in function eq>", "<built-in function ne>")
):
return Bool(z3.BoolVal(default_value), annotations=union)
condition = True
for a_nest, b_nest in product(a.nested_functions, b.nested_functions):
if a_nest.func_name != b_nest.func_name:
continue
if a_nest.func_name == "Hybrid":
continue
# a.input (eq/neq) b.input ==> a == b
if inputs_equal:
condition = z3.And(
condition,
z3.Or(
z3.Not((a_nest.input_ == b_nest.input_).raw),
(a_nest.raw == b_nest.raw),
),
z3.Or(
z3.Not((a_nest.raw == b_nest.raw)),
(a_nest.input_ == b_nest.input_).raw,
),
)
else:
condition = z3.And(
condition,
z3.Or(
z3.Not((a_nest.input_ != b_nest.input_).raw),
(a_nest.raw == b_nest.raw),
),
z3.Or(
z3.Not((a_nest.raw == b_nest.raw)),
(a_nest.input_ != b_nest.input_).raw,
),
)
return And(
Bool(
cast(z3.BoolRef, _padded_operation(a.raw, b.raw, operation)),
annotations=union,
),
Bool(condition) if b.nested_functions else Bool(True),
a.input_ == b.input_ if inputs_equal else a.input_ != b.input_,
)
class BitVecFunc(BitVec):
"""A bit vector function symbol. Used in place of functions like sha3."""
def __init__(
self,
raw: z3.BitVecRef,
func_name: Optional[str],
input_: "BitVec" = None,
annotations: Optional[Annotations] = None,
nested_functions: Optional[List["BitVecFunc"]] = None,
):
"""
:param raw: The raw bit vector symbol
:param func_name: The function name. e.g. sha3
:param input: The input to the functions
:param annotations: The annotations the BitVecFunc should start with
"""
self.func_name = func_name
self.input_ = input_
self.nested_functions = nested_functions or []
self.nested_functions = list(dict.fromkeys(self.nested_functions))
if isinstance(input_, BitVecFunc):
self.nested_functions.extend(input_.nested_functions)
super().__init__(raw, annotations)
def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an addition expression.
:param other: The int or BitVec to add to this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.add)
def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create a subtraction expression.
:param other: The int or BitVec to subtract from this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.sub)
def __mul__(self, other: "BitVec") -> "BitVecFunc":
"""Create a multiplication expression.
:param other: The int or BitVec to multiply to this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.mul)
def __truediv__(self, other: "BitVec") -> "BitVecFunc":
"""Create a signed division expression.
:param other: The int or BitVec to divide this BitVecFunc by
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.truediv)
def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an and expression.
:param other: The int or BitVec to and with this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.and_)
def __or__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an or expression.
:param other: The int or BitVec to or with this BitVecFunc
:return: The resulting BitVecFunc
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _arithmetic_helper(self, other, operator.or_)
def __xor__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create a xor expression.
:param other: The int or BitVec to xor with this BitVecFunc
:return: The resulting BitVecFunc
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _arithmetic_helper(self, other, operator.xor)
def __lt__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed less than expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.lt, default_value=False, inputs_equal=False
)
def __gt__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed greater than expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.gt, default_value=False, inputs_equal=False
)
def __le__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed less than or equal to expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return Or(self < other, self == other)
def __ge__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed greater than or equal to expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return Or(self > other, self == other)
# MYPY: fix complains about overriding __eq__
def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
"""Create an equality expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.eq, default_value=False, inputs_equal=True
)
# MYPY: fix complains about overriding __ne__
def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
"""Create an inequality expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.ne, default_value=True, inputs_equal=False
)
def __lshift__(self, other: Union[int, "BitVec"]) -> "BitVec":
"""
Left shift operation
:param other: The int or BitVec to shift on
:return The resulting left shifted output
"""
return _arithmetic_helper(self, other, operator.lshift)
def __rshift__(self, other: Union[int, "BitVec"]) -> "BitVec":
"""
Right shift operation
:param other: The int or BitVec to shift on
:return The resulting right shifted output:
"""
return _arithmetic_helper(self, other, operator.rshift)
def __hash__(self) -> int:
return self.raw.__hash__()

@ -1,237 +0,0 @@
from mythril.laser.smt import Solver, symbol_factory, UGT, UGE, ULT, ULE
import z3
import pytest
import operator
@pytest.mark.parametrize(
"operation,expected",
[
(operator.add, z3.unsat),
(operator.sub, z3.unsat),
(operator.and_, z3.sat),
(operator.or_, z3.sat),
(operator.xor, z3.unsat),
],
)
def test_bitvecfunc_arithmetic(operation, expected):
# Arrange
s = Solver()
input_ = symbol_factory.BitVecVal(1, 8)
bvf = symbol_factory.BitVecFuncSym("bvf", "sha3", 256, input_=input_)
x = symbol_factory.BitVecSym("x", 256)
y = symbol_factory.BitVecSym("y", 256)
# Act
s.add(x != y)
s.add(operation(bvf, x) == operation(y, bvf))
# Assert
assert s.check() == expected
@pytest.mark.parametrize(
"operation,expected",
[
(operator.eq, z3.sat),
(operator.ne, z3.unsat),
(operator.lt, z3.unsat),
(operator.le, z3.sat),
(operator.gt, z3.unsat),
(operator.ge, z3.sat),
(UGT, z3.unsat),
(UGE, z3.sat),
(ULT, z3.unsat),
(ULE, z3.sat),
],
)
def test_bitvecfunc_bitvecfunc_comparison(operation, expected):
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=input2)
# Act
s.add(operation(bvf1, bvf2))
s.add(input1 == input2)
# Assert
assert s.check() == expected
def test_bitvecfunc_bitvecfuncval_comparison():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecVal(1337, 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncVal(12345678910, "sha3", 256, input_=input2)
# Act
s.add(bvf1 == bvf2)
# Assert
assert s.check() == z3.sat
assert s.model().eval(input2.raw) == 1337
def test_bitvecfunc_nested_comparison():
# arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3)
# Act
s.add(input1 == input2)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.sat
def test_bitvecfunc_unequal_nested_comparison():
# arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3)
# Act
s.add(input1 != input2)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.unsat
def test_bitvecfunc_ext_nested_comparison():
# arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
input3 = symbol_factory.BitVecSym("input3", 256)
input4 = symbol_factory.BitVecSym("input4", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4)
# Act
s.add(input1 == input2)
s.add(input3 == input4)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.sat
def test_bitvecfunc_ext_unequal_nested_comparison():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
input3 = symbol_factory.BitVecSym("input3", 256)
input4 = symbol_factory.BitVecSym("input4", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4)
# Act
s.add(input1 == input2)
s.add(input3 != input4)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.unsat
def test_bitvecfunc_ext_unequal_nested_comparison_f():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
input3 = symbol_factory.BitVecSym("input3", 256)
input4 = symbol_factory.BitVecSym("input4", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4)
# Act
s.add(input1 != input2)
s.add(input3 == input4)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.unsat
def test_bitvecfunc_find_input():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
# Act
s.add(input1 == symbol_factory.BitVecVal(1, 256))
s.add(bvf1 == bvf2)
# Assert
assert s.check() == z3.sat
assert s.model()[input2.raw] == 1
def test_bitvecfunc_nested_find_input():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3)
# Act
s.add(input1 == symbol_factory.BitVecVal(123, 256))
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.sat
assert s.model()[input2.raw] == 123
Loading…
Cancel
Save