Merge pull request #901 from nbanmp/sha3_symbols

Implement BitVecFunc for sha3
check_existing_annotations
JoranHonig 6 years ago committed by GitHub
commit c1ae1ee552
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 44
      mythril/laser/ethereum/instructions.py
  2. 63
      mythril/laser/smt/__init__.py
  3. 166
      mythril/laser/smt/bitvec.py
  4. 209
      mythril/laser/smt/bitvecfunc.py
  5. 82
      tests/laser/smt/bitvecfunc_test.py

@ -891,29 +891,35 @@ class Instruction:
state.max_gas_used += max_gas state.max_gas_used += max_gas
StateTransition.check_gas_usage_limit(global_state) StateTransition.check_gas_usage_limit(global_state)
try: state.mem_extend(index, length)
state.mem_extend(index, length) data_list = [
data = b"".join( b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8)
[ for b in state.memory[index : index + length]
util.get_concrete_int(i).to_bytes(1, byteorder="big") ]
for i in state.memory[index : index + length] if len(data_list) > 1:
] data = simplify(Concat(data_list))
) elif len(data_list) == 1:
data = data_list[0]
else:
# length is 0; this only matters for input of the BitVecFuncVal
data = symbol_factory.BitVecVal(0, 1)
except TypeError: if data.symbolic:
argument = str(state.memory[index]).replace(" ", "_") argument_str = str(state.memory[index]).replace(" ", "_")
result = symbol_factory.BitVecFuncSym(
"KECCAC[{}]".format(argument_str), "keccak256", 256, input_=data
)
log.debug("Created BitVecFunc hash.")
result = symbol_factory.BitVecSym("KECCAC[{}]".format(argument), 256)
keccak_function_manager.add_keccak(result, state.memory[index]) keccak_function_manager.add_keccak(result, state.memory[index])
state.stack.append(result) else:
return [global_state] keccak = utils.sha3(data.value.to_bytes(length, byteorder="big"))
result = symbol_factory.BitVecFuncVal(
keccak = utils.sha3(utils.bytearray_to_bytestr(data)) util.concrete_int_from_bytes(keccak, 0), "keccak256", 256, input_=data
log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak))) )
log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak)))
state.stack.append( state.stack.append(result)
symbol_factory.BitVecVal(util.concrete_int_from_bytes(keccak, 0), 256)
)
return [global_state] return [global_state]
@StateTransition() @StateTransition()

@ -14,6 +14,7 @@ from mythril.laser.smt.bitvec import (
BVMulNoOverflow, BVMulNoOverflow,
BVSubNoUnderflow, BVSubNoUnderflow,
) )
from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.expression import Expression, simplify
from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And
from mythril.laser.smt.array import K, Array, BaseArray from mythril.laser.smt.array import K, Array, BaseArray
@ -64,6 +65,44 @@ class SymbolFactory(Generic[T, U]):
""" """
raise NotImplementedError() raise NotImplementedError()
@staticmethod
def BitVecFuncVal(
value: int,
func_name: str,
size: int,
annotations: Annotations = None,
input_: Union[int, "BitVec"] = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value.
:param value: The concrete value to set the bit vector to
:param func_name: The name of the bit vector function
:param size: The size of the bit vector
:param annotations: The annotations to initialize the bit vector with
:param input_: The input to the bit vector function
:return: The freshly created bit vector function
"""
raise NotImplementedError()
@staticmethod
def BitVecFuncSym(
name: str,
func_name: str,
size: int,
annotations: Annotations = None,
input_: Union[int, "BitVec"] = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value.
:param name: The name of the symbolic bit vector
:param func_name: The name of the bit vector function
:param size: The size of the bit vector
:param annotations: The annotations to initialize the bit vector with
:param input_: The input to the bit vector function
:return: The freshly created bit vector function
"""
raise NotImplementedError()
class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]):
""" """
@ -94,6 +133,30 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]):
raw = z3.BitVec(name, size) raw = z3.BitVec(name, size)
return BitVec(raw, annotations) return BitVec(raw, annotations)
@staticmethod
def BitVecFuncVal(
value: int,
func_name: str,
size: int,
annotations: Annotations = None,
input_: Union[int, "BitVec"] = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a concrete value."""
raw = z3.BitVecVal(value, size)
return BitVecFunc(raw, func_name, input_, annotations)
@staticmethod
def BitVecFuncSym(
name: str,
func_name: str,
size: int,
annotations: Annotations = None,
input_: Union[int, "BitVec"] = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value."""
raw = z3.BitVec(name, size)
return BitVecFunc(raw, func_name, input_, annotations)
class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]): class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]):
""" """

@ -1,10 +1,10 @@
"""This module provides classes for an SMT abstraction of bit vectors.""" """This module provides classes for an SMT abstraction of bit vectors."""
from typing import Union, overload, List, cast, Any, Optional from typing import Union, overload, List, cast, Any, Optional, Callable
import z3 import z3
from mythril.laser.smt.bool import Bool from mythril.laser.smt.bool import Bool, And, Or
from mythril.laser.smt.expression import Expression from mythril.laser.smt.expression import Expression
Annotations = List[Any] Annotations = List[Any]
@ -15,7 +15,7 @@ Annotations = List[Any]
class BitVec(Expression[z3.BitVecRef]): class BitVec(Expression[z3.BitVecRef]):
"""A bit vector symbol.""" """A bit vector symbol."""
def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations]=None): def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None):
""" """
:param raw: :param raw:
@ -56,6 +56,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other + self
if isinstance(other, int): if isinstance(other, int):
return BitVec(self.raw + other, annotations=self.annotations) return BitVec(self.raw + other, annotations=self.annotations)
@ -68,7 +70,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other - self
if isinstance(other, int): if isinstance(other, int):
return BitVec(self.raw - other, annotations=self.annotations) return BitVec(self.raw - other, annotations=self.annotations)
@ -81,6 +84,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other * self
union = self.annotations + other.annotations union = self.annotations + other.annotations
return BitVec(self.raw * other.raw, annotations=union) return BitVec(self.raw * other.raw, annotations=union)
@ -90,6 +95,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other / self
union = self.annotations + other.annotations union = self.annotations + other.annotations
return BitVec(self.raw / other.raw, annotations=union) return BitVec(self.raw / other.raw, annotations=union)
@ -99,8 +106,10 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other & self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, 256)) other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations + other.annotations union = self.annotations + other.annotations
return BitVec(self.raw & other.raw, annotations=union) return BitVec(self.raw & other.raw, annotations=union)
@ -110,6 +119,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other | self
union = self.annotations + other.annotations union = self.annotations + other.annotations
return BitVec(self.raw | other.raw, annotations=union) return BitVec(self.raw | other.raw, annotations=union)
@ -119,6 +130,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other ^ self
union = self.annotations + other.annotations union = self.annotations + other.annotations
return BitVec(self.raw ^ other.raw, annotations=union) return BitVec(self.raw ^ other.raw, annotations=union)
@ -128,6 +141,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other > self
union = self.annotations + other.annotations union = self.annotations + other.annotations
return Bool(self.raw < other.raw, annotations=union) return Bool(self.raw < other.raw, annotations=union)
@ -137,6 +152,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other < self
union = self.annotations + other.annotations union = self.annotations + other.annotations
return Bool(self.raw > other.raw, annotations=union) return Bool(self.raw > other.raw, annotations=union)
@ -165,8 +182,12 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other == self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
return Bool(cast(z3.BoolRef, self.raw == other), annotations=self.annotations) return Bool(
cast(z3.BoolRef, self.raw == other), annotations=self.annotations
)
union = self.annotations + other.annotations union = self.annotations + other.annotations
# MYPY: fix complaints due to z3 overriding __eq__ # MYPY: fix complaints due to z3 overriding __eq__
@ -179,14 +200,60 @@ class BitVec(Expression[z3.BitVecRef]):
:param other: :param other:
:return: :return:
""" """
if isinstance(other, BitVecFunc):
return other != self
if not isinstance(other, BitVec): if not isinstance(other, BitVec):
return Bool(cast(z3.BoolRef, self.raw != other), annotations=self.annotations) return Bool(
cast(z3.BoolRef, self.raw != other), annotations=self.annotations
)
union = self.annotations + other.annotations union = self.annotations + other.annotations
# MYPY: fix complaints due to z3 overriding __eq__ # MYPY: fix complaints due to z3 overriding __eq__
return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union) return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union)
def _comparison_helper(
a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool
) -> Bool:
annotations = a.annotations + b.annotations
if isinstance(a, BitVecFunc):
if not a.symbolic and not b.symbolic:
return Bool(operation(a.raw, b.raw), annotations=annotations)
if (
not isinstance(b, BitVecFunc)
or not a.func_name
or not a.input_
or not a.func_name == b.func_name
):
return Bool(z3.BoolVal(default_value), annotations=annotations)
return And(
Bool(operation(a.raw, b.raw), annotations=annotations),
a.input_ == b.input_ if inputs_equal else a.input_ != b.input_,
)
return Bool(operation(a.raw, b.raw), annotations)
def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec:
raw = operation(a.raw, b.raw)
union = a.annotations + b.annotations
if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc):
return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union)
elif isinstance(a, BitVecFunc):
return BitVecFunc(
raw=raw, func_name=a.func_name, input_=a.input_, annotations=union
)
elif isinstance(b, BitVecFunc):
return BitVecFunc(
raw=raw, func_name=b.func_name, input_=b.input_, annotations=union
)
return BitVec(raw, annotations=union)
def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec: def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec:
"""Create an if-then-else expression. """Create an if-then-else expression.
@ -195,6 +262,8 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi
:param c: :param c:
:return: :return:
""" """
# TODO: Handle BitVecFunc
if not isinstance(a, Bool): if not isinstance(a, Bool):
a = Bool(z3.BoolVal(a)) a = Bool(z3.BoolVal(a))
if not isinstance(b, BitVec): if not isinstance(b, BitVec):
@ -212,19 +281,17 @@ def UGT(a: BitVec, b: BitVec) -> Bool:
:param b: :param b:
:return: :return:
""" """
annotations = a.annotations + b.annotations return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False)
return Bool(z3.UGT(a.raw, b.raw), annotations)
def UGE(a: BitVec, b:BitVec) -> Bool: def UGE(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned greater or equals expression. """Create an unsigned greater or equals expression.
:param a: :param a:
:param b: :param b:
:return: :return:
""" """
annotations = a.annotations + b.annotations return Or(UGT(a, b), a == b)
return Bool(z3.UGE(a.raw, b.raw), annotations)
def ULT(a: BitVec, b: BitVec) -> Bool: def ULT(a: BitVec, b: BitVec) -> Bool:
@ -234,8 +301,17 @@ def ULT(a: BitVec, b: BitVec) -> Bool:
:param b: :param b:
:return: :return:
""" """
annotations = a.annotations + b.annotations return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False)
return Bool(z3.ULT(a.raw, b.raw), annotations)
def ULE(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned less than expression.
:param a:
:param b:
:return:
"""
return Or(ULT(a, b), a == b)
@overload @overload
@ -252,17 +328,26 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
:param args: :param args:
:return: :return:
""" """
# The following statement is used if a list is provided as an argument to concat # The following statement is used if a list is provided as an argument to concat
if len(args) == 1 and isinstance(args[0], list): if len(args) == 1 and isinstance(args[0], list):
bvs = args[0] # type: List[BitVec] bvs = args[0] # type: List[BitVec]
else: else:
bvs = cast(List[BitVec], args) bvs = cast(List[BitVec], args)
nraw = z3.Concat([a.raw for a in bvs]) nraw = z3.Concat([a.raw for a in bvs])
annotations = [] # type: Annotations annotations = [] # type: Annotations
bitvecfunc = False
for bv in bvs: for bv in bvs:
annotations += bv.annotations annotations += bv.annotations
if isinstance(bv, BitVecFunc):
bitvecfunc = True
if bitvecfunc:
# Is there a better value to set func_name and input to in this case?
return BitVecFunc(
raw=nraw, func_name=None, input_=None, annotations=annotations
)
return BitVec(nraw, annotations) return BitVec(nraw, annotations)
@ -274,7 +359,14 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
:param bv: :param bv:
:return: :return:
""" """
return BitVec(z3.Extract(high, low, bv.raw), annotations=bv.annotations) raw = z3.Extract(high, low, bv.raw)
if isinstance(bv, BitVecFunc):
# Is there a better value to set func_name and input to in this case?
return BitVecFunc(
raw=raw, func_name=None, input_=None, annotations=bv.annotations
)
return BitVec(raw, annotations=bv.annotations)
def URem(a: BitVec, b: BitVec) -> BitVec: def URem(a: BitVec, b: BitVec) -> BitVec:
@ -284,8 +376,7 @@ def URem(a: BitVec, b: BitVec) -> BitVec:
:param b: :param b:
:return: :return:
""" """
union = a.annotations + b.annotations return _arithmetic_helper(a, b, z3.URem)
return BitVec(z3.URem(a.raw, b.raw), annotations=union)
def SRem(a: BitVec, b: BitVec) -> BitVec: def SRem(a: BitVec, b: BitVec) -> BitVec:
@ -295,8 +386,7 @@ def SRem(a: BitVec, b: BitVec) -> BitVec:
:param b: :param b:
:return: :return:
""" """
union = a.annotations + b.annotations return _arithmetic_helper(a, b, z3.SRem)
return BitVec(z3.SRem(a.raw, b.raw), annotations=union)
def UDiv(a: BitVec, b: BitVec) -> BitVec: def UDiv(a: BitVec, b: BitVec) -> BitVec:
@ -306,8 +396,7 @@ def UDiv(a: BitVec, b: BitVec) -> BitVec:
:param b: :param b:
:return: :return:
""" """
union = a.annotations + b.annotations return _arithmetic_helper(a, b, z3.UDiv)
return BitVec(z3.UDiv(a.raw, b.raw), annotations=union)
def Sum(*args: BitVec) -> BitVec: def Sum(*args: BitVec) -> BitVec:
@ -315,11 +404,26 @@ def Sum(*args: BitVec) -> BitVec:
:return: :return:
""" """
nraw = z3.Sum([a.raw for a in args]) raw = z3.Sum([a.raw for a in args])
annotations = [] # type: Annotations annotations = [] # type: Annotations
bitvecfuncs = []
for bv in args: for bv in args:
annotations += bv.annotations annotations += bv.annotations
return BitVec(nraw, annotations) if isinstance(bv, BitVecFunc):
bitvecfuncs.append(bv)
if len(bitvecfuncs) >= 2:
return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=annotations)
elif len(bitvecfuncs) == 1:
return BitVecFunc(
raw=raw,
func_name=bitvecfuncs[0].func_name,
input_=bitvecfuncs[0].input_,
annotations=annotations,
)
return BitVec(raw, annotations)
def BVAddNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool: def BVAddNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool:
@ -353,7 +457,9 @@ def BVMulNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool)
return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed)) return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed))
def BVSubNoUnderflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool: def BVSubNoUnderflow(
a: Union[BitVec, int], b: Union[BitVec, int], signed: bool
) -> Bool:
"""Creates predicate that verifies that the subtraction doesn't overflow. """Creates predicate that verifies that the subtraction doesn't overflow.
:param a: :param a:
@ -367,3 +473,7 @@ def BVSubNoUnderflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool)
b = BitVec(z3.BitVecVal(b, 256)) b = BitVec(z3.BitVecVal(b, 256))
return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed)) return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed))
# TODO: Fix circular import issues
from mythril.laser.smt.bitvecfunc import BitVecFunc

@ -0,0 +1,209 @@
from typing import Optional, Union, cast, Callable
import z3
from mythril.laser.smt.bitvec import BitVec, Bool, And, Annotations
from mythril.laser.smt.bool import Or
import operator
def _arithmetic_helper(
a: "BitVecFunc", b: Union[BitVec, int], operation: Callable
) -> "BitVecFunc":
"""
Helper function for arithmetic operations on BitVecFuncs.
:param a: The BitVecFunc to perform the operation on.
:param b: A BitVec or int to perform the operation on.
:param operation: The arithmetic operation to perform.
:return: The resulting BitVecFunc
"""
if isinstance(b, int):
b = BitVec(z3.BitVecVal(b, a.size()))
raw = operation(a.raw, b.raw)
union = a.annotations + b.annotations
if isinstance(b, BitVecFunc):
# TODO: Find better value to set input and name to in this case?
return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union)
return BitVecFunc(
raw=raw, func_name=a.func_name, input_=a.input_, annotations=union
)
def _comparison_helper(
a: "BitVecFunc",
b: Union[BitVec, int],
operation: Callable,
default_value: bool,
inputs_equal: bool,
) -> Bool:
"""
Helper function for comparison operations with BitVecFuncs.
:param a: The BitVecFunc to compare.
:param b: A BitVec or int to compare to.
:param operation: The comparison operation to perform.
:return: The resulting Bool
"""
# Is there some hack for gt/lt comparisons?
if isinstance(b, int):
b = BitVec(z3.BitVecVal(b, a.size()))
union = a.annotations + b.annotations
if not a.symbolic and not b.symbolic:
return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union)
if (
not isinstance(b, BitVecFunc)
or not a.func_name
or not a.input_
or not a.func_name == b.func_name
):
return Bool(z3.BoolVal(default_value), annotations=union)
return And(
Bool(cast(z3.BoolRef, operation(a.raw, b.raw)), annotations=union),
a.input_ == b.input_ if inputs_equal else a.input_ != b.input_,
)
class BitVecFunc(BitVec):
"""A bit vector function symbol. Used in place of functions like sha3."""
def __init__(
self,
raw: z3.BitVecRef,
func_name: Optional[str],
input_: Union[int, "BitVec"] = None,
annotations: Optional[Annotations] = None,
):
"""
:param raw: The raw bit vector symbol
:param func_name: The function name. e.g. sha3
:param input: The input to the functions
:param annotations: The annotations the BitVecFunc should start with
"""
self.func_name = func_name
self.input_ = input_
super().__init__(raw, annotations)
def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an addition expression.
:param other: The int or BitVec to add to this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.add)
def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create a subtraction expression.
:param other: The int or BitVec to subtract from this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.sub)
def __mul__(self, other: "BitVec") -> "BitVecFunc":
"""Create a multiplication expression.
:param other: The int or BitVec to multiply to this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.mul)
def __truediv__(self, other: "BitVec") -> "BitVecFunc":
"""Create a signed division expression.
:param other: The int or BitVec to divide this BitVecFunc by
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.truediv)
def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an and expression.
:param other: The int or BitVec to and with this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.and_)
def __or__(self, other: "BitVec") -> "BitVecFunc":
"""Create an or expression.
:param other: The int or BitVec to or with this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.or_)
def __xor__(self, other: "BitVec") -> "BitVecFunc":
"""Create a xor expression.
:param other: The int or BitVec to xor with this BitVecFunc
:return: The resulting BitVecFunc
"""
return _arithmetic_helper(self, other, operator.xor)
def __lt__(self, other: "BitVec") -> Bool:
"""Create a signed less than expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
return _comparison_helper(
self, other, operator.lt, default_value=False, inputs_equal=False
)
def __gt__(self, other: "BitVec") -> Bool:
"""Create a signed greater than expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
return _comparison_helper(
self, other, operator.gt, default_value=False, inputs_equal=False
)
def __le__(self, other: "BitVec") -> Bool:
"""Create a signed less than or equal to expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
return Or(self < other, self == other)
def __ge__(self, other: "BitVec") -> Bool:
"""Create a signed greater than or equal to expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
return Or(self > other, self == other)
# MYPY: fix complains about overriding __eq__
def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
"""Create an equality expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
return _comparison_helper(
self, other, operator.eq, default_value=False, inputs_equal=True
)
# MYPY: fix complains about overriding __ne__
def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
"""Create an inequality expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
return _comparison_helper(
self, other, operator.eq, default_value=True, inputs_equal=False
)

@ -0,0 +1,82 @@
from mythril.laser.smt import Solver, symbol_factory, bitvec
import z3
import pytest
import operator
@pytest.mark.parametrize(
"operation,expected",
[
(operator.add, z3.unsat),
(operator.sub, z3.unsat),
(operator.and_, z3.sat),
(operator.or_, z3.sat),
(operator.xor, z3.unsat),
],
)
def test_bitvecfunc_arithmetic(operation, expected):
# Arrange
s = Solver()
input_ = symbol_factory.BitVecVal(1, 8)
bvf = symbol_factory.BitVecFuncSym("bvf", "sha3", 256, input_=input_)
x = symbol_factory.BitVecSym("x", 256)
y = symbol_factory.BitVecSym("y", 256)
# Act
s.add(x != y)
s.add(operation(bvf, x) == operation(y, bvf))
# Assert
assert s.check() == expected
@pytest.mark.parametrize(
"operation,expected",
[
(operator.eq, z3.sat),
(operator.ne, z3.unsat),
(operator.lt, z3.unsat),
(operator.le, z3.sat),
(operator.gt, z3.unsat),
(operator.ge, z3.sat),
(bitvec.UGT, z3.unsat),
(bitvec.UGE, z3.sat),
(bitvec.ULT, z3.unsat),
(bitvec.ULE, z3.sat),
],
)
def test_bitvecfunc_bitvecfunc_comparison(operation, expected):
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=input2)
# Act
s.add(operation(bvf1, bvf2))
s.add(input1 == input2)
# Assert
assert s.check() == expected
def test_bitvecfunc_bitvecfuncval_comparison():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecVal(1337, 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncVal(12345678910, "sha3", 256, input_=input2)
# Act
s.add(bvf1 == bvf2)
# Assert
assert s.check() == z3.sat
assert s.model().eval(input2.raw) == 1337
Loading…
Cancel
Save