diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 428d70da..72d17e5d 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -891,29 +891,35 @@ class Instruction: state.max_gas_used += max_gas StateTransition.check_gas_usage_limit(global_state) - try: - state.mem_extend(index, length) - data = b"".join( - [ - util.get_concrete_int(i).to_bytes(1, byteorder="big") - for i in state.memory[index : index + length] - ] - ) + state.mem_extend(index, length) + data_list = [ + b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8) + for b in state.memory[index : index + length] + ] + if len(data_list) > 1: + data = simplify(Concat(data_list)) + elif len(data_list) == 1: + data = data_list[0] + else: + # length is 0; this only matters for input of the BitVecFuncVal + data = symbol_factory.BitVecVal(0, 1) - except TypeError: - argument = str(state.memory[index]).replace(" ", "_") + if data.symbolic: + argument_str = str(state.memory[index]).replace(" ", "_") + result = symbol_factory.BitVecFuncSym( + "KECCAC[{}]".format(argument_str), "keccak256", 256, input_=data + ) + log.debug("Created BitVecFunc hash.") - result = symbol_factory.BitVecSym("KECCAC[{}]".format(argument), 256) keccak_function_manager.add_keccak(result, state.memory[index]) - state.stack.append(result) - return [global_state] - - keccak = utils.sha3(utils.bytearray_to_bytestr(data)) - log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak))) + else: + keccak = utils.sha3(data.value.to_bytes(length, byteorder="big")) + result = symbol_factory.BitVecFuncVal( + util.concrete_int_from_bytes(keccak, 0), "keccak256", 256, input_=data + ) + log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak))) - state.stack.append( - symbol_factory.BitVecVal(util.concrete_int_from_bytes(keccak, 0), 256) - ) + state.stack.append(result) return [global_state] @StateTransition() diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index 28f0c42b..f46a8968 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -14,6 +14,7 @@ from mythril.laser.smt.bitvec import ( BVMulNoOverflow, BVSubNoUnderflow, ) +from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.array import K, Array, BaseArray @@ -64,6 +65,44 @@ class SymbolFactory(Generic[T, U]): """ raise NotImplementedError() + @staticmethod + def BitVecFuncVal( + value: int, + func_name: str, + size: int, + annotations: Annotations = None, + input_: Union[int, "BitVec"] = None, + ) -> BitVecFunc: + """Creates a new bit vector function with a symbolic value. + + :param value: The concrete value to set the bit vector to + :param func_name: The name of the bit vector function + :param size: The size of the bit vector + :param annotations: The annotations to initialize the bit vector with + :param input_: The input to the bit vector function + :return: The freshly created bit vector function + """ + raise NotImplementedError() + + @staticmethod + def BitVecFuncSym( + name: str, + func_name: str, + size: int, + annotations: Annotations = None, + input_: Union[int, "BitVec"] = None, + ) -> BitVecFunc: + """Creates a new bit vector function with a symbolic value. + + :param name: The name of the symbolic bit vector + :param func_name: The name of the bit vector function + :param size: The size of the bit vector + :param annotations: The annotations to initialize the bit vector with + :param input_: The input to the bit vector function + :return: The freshly created bit vector function + """ + raise NotImplementedError() + class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): """ @@ -94,6 +133,30 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): raw = z3.BitVec(name, size) return BitVec(raw, annotations) + @staticmethod + def BitVecFuncVal( + value: int, + func_name: str, + size: int, + annotations: Annotations = None, + input_: Union[int, "BitVec"] = None, + ) -> BitVecFunc: + """Creates a new bit vector function with a concrete value.""" + raw = z3.BitVecVal(value, size) + return BitVecFunc(raw, func_name, input_, annotations) + + @staticmethod + def BitVecFuncSym( + name: str, + func_name: str, + size: int, + annotations: Annotations = None, + input_: Union[int, "BitVec"] = None, + ) -> BitVecFunc: + """Creates a new bit vector function with a symbolic value.""" + raw = z3.BitVec(name, size) + return BitVecFunc(raw, func_name, input_, annotations) + class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]): """ diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index 8347352c..00b519a7 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -1,10 +1,10 @@ """This module provides classes for an SMT abstraction of bit vectors.""" -from typing import Union, overload, List, cast, Any, Optional +from typing import Union, overload, List, cast, Any, Optional, Callable import z3 -from mythril.laser.smt.bool import Bool +from mythril.laser.smt.bool import Bool, And, Or from mythril.laser.smt.expression import Expression Annotations = List[Any] @@ -15,7 +15,7 @@ Annotations = List[Any] class BitVec(Expression[z3.BitVecRef]): """A bit vector symbol.""" - def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations]=None): + def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None): """ :param raw: @@ -56,6 +56,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other + self if isinstance(other, int): return BitVec(self.raw + other, annotations=self.annotations) @@ -68,7 +70,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - + if isinstance(other, BitVecFunc): + return other - self if isinstance(other, int): return BitVec(self.raw - other, annotations=self.annotations) @@ -81,6 +84,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other * self union = self.annotations + other.annotations return BitVec(self.raw * other.raw, annotations=union) @@ -90,6 +95,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other / self union = self.annotations + other.annotations return BitVec(self.raw / other.raw, annotations=union) @@ -99,8 +106,10 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other & self if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, 256)) + other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations + other.annotations return BitVec(self.raw & other.raw, annotations=union) @@ -110,6 +119,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other | self union = self.annotations + other.annotations return BitVec(self.raw | other.raw, annotations=union) @@ -119,6 +130,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other ^ self union = self.annotations + other.annotations return BitVec(self.raw ^ other.raw, annotations=union) @@ -128,6 +141,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other > self union = self.annotations + other.annotations return Bool(self.raw < other.raw, annotations=union) @@ -137,6 +152,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other < self union = self.annotations + other.annotations return Bool(self.raw > other.raw, annotations=union) @@ -165,8 +182,12 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other == self if not isinstance(other, BitVec): - return Bool(cast(z3.BoolRef, self.raw == other), annotations=self.annotations) + return Bool( + cast(z3.BoolRef, self.raw == other), annotations=self.annotations + ) union = self.annotations + other.annotations # MYPY: fix complaints due to z3 overriding __eq__ @@ -179,14 +200,60 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other != self if not isinstance(other, BitVec): - return Bool(cast(z3.BoolRef, self.raw != other), annotations=self.annotations) + return Bool( + cast(z3.BoolRef, self.raw != other), annotations=self.annotations + ) union = self.annotations + other.annotations # MYPY: fix complaints due to z3 overriding __eq__ return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union) +def _comparison_helper( + a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool +) -> Bool: + annotations = a.annotations + b.annotations + if isinstance(a, BitVecFunc): + if not a.symbolic and not b.symbolic: + return Bool(operation(a.raw, b.raw), annotations=annotations) + + if ( + not isinstance(b, BitVecFunc) + or not a.func_name + or not a.input_ + or not a.func_name == b.func_name + ): + return Bool(z3.BoolVal(default_value), annotations=annotations) + + return And( + Bool(operation(a.raw, b.raw), annotations=annotations), + a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, + ) + + return Bool(operation(a.raw, b.raw), annotations) + + +def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec: + raw = operation(a.raw, b.raw) + union = a.annotations + b.annotations + + if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): + return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) + elif isinstance(a, BitVecFunc): + return BitVecFunc( + raw=raw, func_name=a.func_name, input_=a.input_, annotations=union + ) + elif isinstance(b, BitVecFunc): + return BitVecFunc( + raw=raw, func_name=b.func_name, input_=b.input_, annotations=union + ) + + return BitVec(raw, annotations=union) + + def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec: """Create an if-then-else expression. @@ -195,6 +262,8 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi :param c: :return: """ + # TODO: Handle BitVecFunc + if not isinstance(a, Bool): a = Bool(z3.BoolVal(a)) if not isinstance(b, BitVec): @@ -212,19 +281,17 @@ def UGT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - annotations = a.annotations + b.annotations - return Bool(z3.UGT(a.raw, b.raw), annotations) + return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False) -def UGE(a: BitVec, b:BitVec) -> Bool: +def UGE(a: BitVec, b: BitVec) -> Bool: """Create an unsigned greater or equals expression. :param a: :param b: :return: """ - annotations = a.annotations + b.annotations - return Bool(z3.UGE(a.raw, b.raw), annotations) + return Or(UGT(a, b), a == b) def ULT(a: BitVec, b: BitVec) -> Bool: @@ -234,8 +301,17 @@ def ULT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - annotations = a.annotations + b.annotations - return Bool(z3.ULT(a.raw, b.raw), annotations) + return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False) + + +def ULE(a: BitVec, b: BitVec) -> Bool: + """Create an unsigned less than expression. + + :param a: + :param b: + :return: + """ + return Or(ULT(a, b), a == b) @overload @@ -252,17 +328,26 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: :param args: :return: """ - # The following statement is used if a list is provided as an argument to concat if len(args) == 1 and isinstance(args[0], list): - bvs = args[0] # type: List[BitVec] + bvs = args[0] # type: List[BitVec] else: bvs = cast(List[BitVec], args) nraw = z3.Concat([a.raw for a in bvs]) - annotations = [] # type: Annotations + annotations = [] # type: Annotations + bitvecfunc = False for bv in bvs: annotations += bv.annotations + if isinstance(bv, BitVecFunc): + bitvecfunc = True + + if bitvecfunc: + # Is there a better value to set func_name and input to in this case? + return BitVecFunc( + raw=nraw, func_name=None, input_=None, annotations=annotations + ) + return BitVec(nraw, annotations) @@ -274,7 +359,14 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: :param bv: :return: """ - return BitVec(z3.Extract(high, low, bv.raw), annotations=bv.annotations) + raw = z3.Extract(high, low, bv.raw) + if isinstance(bv, BitVecFunc): + # Is there a better value to set func_name and input to in this case? + return BitVecFunc( + raw=raw, func_name=None, input_=None, annotations=bv.annotations + ) + + return BitVec(raw, annotations=bv.annotations) def URem(a: BitVec, b: BitVec) -> BitVec: @@ -284,8 +376,7 @@ def URem(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ - union = a.annotations + b.annotations - return BitVec(z3.URem(a.raw, b.raw), annotations=union) + return _arithmetic_helper(a, b, z3.URem) def SRem(a: BitVec, b: BitVec) -> BitVec: @@ -295,8 +386,7 @@ def SRem(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ - union = a.annotations + b.annotations - return BitVec(z3.SRem(a.raw, b.raw), annotations=union) + return _arithmetic_helper(a, b, z3.SRem) def UDiv(a: BitVec, b: BitVec) -> BitVec: @@ -306,8 +396,7 @@ def UDiv(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ - union = a.annotations + b.annotations - return BitVec(z3.UDiv(a.raw, b.raw), annotations=union) + return _arithmetic_helper(a, b, z3.UDiv) def Sum(*args: BitVec) -> BitVec: @@ -315,11 +404,26 @@ def Sum(*args: BitVec) -> BitVec: :return: """ - nraw = z3.Sum([a.raw for a in args]) - annotations = [] # type: Annotations + raw = z3.Sum([a.raw for a in args]) + annotations = [] # type: Annotations + bitvecfuncs = [] + for bv in args: annotations += bv.annotations - return BitVec(nraw, annotations) + if isinstance(bv, BitVecFunc): + bitvecfuncs.append(bv) + + if len(bitvecfuncs) >= 2: + return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=annotations) + elif len(bitvecfuncs) == 1: + return BitVecFunc( + raw=raw, + func_name=bitvecfuncs[0].func_name, + input_=bitvecfuncs[0].input_, + annotations=annotations, + ) + + return BitVec(raw, annotations) def BVAddNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool: @@ -353,7 +457,9 @@ def BVMulNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed)) -def BVSubNoUnderflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool: +def BVSubNoUnderflow( + a: Union[BitVec, int], b: Union[BitVec, int], signed: bool +) -> Bool: """Creates predicate that verifies that the subtraction doesn't overflow. :param a: @@ -367,3 +473,7 @@ def BVSubNoUnderflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) b = BitVec(z3.BitVecVal(b, 256)) return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed)) + + +# TODO: Fix circular import issues +from mythril.laser.smt.bitvecfunc import BitVecFunc diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py new file mode 100644 index 00000000..d3e77601 --- /dev/null +++ b/mythril/laser/smt/bitvecfunc.py @@ -0,0 +1,209 @@ +from typing import Optional, Union, cast, Callable + +import z3 + +from mythril.laser.smt.bitvec import BitVec, Bool, And, Annotations +from mythril.laser.smt.bool import Or + +import operator + + +def _arithmetic_helper( + a: "BitVecFunc", b: Union[BitVec, int], operation: Callable +) -> "BitVecFunc": + """ + Helper function for arithmetic operations on BitVecFuncs. + + :param a: The BitVecFunc to perform the operation on. + :param b: A BitVec or int to perform the operation on. + :param operation: The arithmetic operation to perform. + :return: The resulting BitVecFunc + """ + if isinstance(b, int): + b = BitVec(z3.BitVecVal(b, a.size())) + + raw = operation(a.raw, b.raw) + union = a.annotations + b.annotations + + if isinstance(b, BitVecFunc): + # TODO: Find better value to set input and name to in this case? + return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) + + return BitVecFunc( + raw=raw, func_name=a.func_name, input_=a.input_, annotations=union + ) + + +def _comparison_helper( + a: "BitVecFunc", + b: Union[BitVec, int], + operation: Callable, + default_value: bool, + inputs_equal: bool, +) -> Bool: + """ + Helper function for comparison operations with BitVecFuncs. + + :param a: The BitVecFunc to compare. + :param b: A BitVec or int to compare to. + :param operation: The comparison operation to perform. + :return: The resulting Bool + """ + # Is there some hack for gt/lt comparisons? + if isinstance(b, int): + b = BitVec(z3.BitVecVal(b, a.size())) + + union = a.annotations + b.annotations + + if not a.symbolic and not b.symbolic: + return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union) + + if ( + not isinstance(b, BitVecFunc) + or not a.func_name + or not a.input_ + or not a.func_name == b.func_name + ): + return Bool(z3.BoolVal(default_value), annotations=union) + + return And( + Bool(cast(z3.BoolRef, operation(a.raw, b.raw)), annotations=union), + a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, + ) + + +class BitVecFunc(BitVec): + """A bit vector function symbol. Used in place of functions like sha3.""" + + def __init__( + self, + raw: z3.BitVecRef, + func_name: Optional[str], + input_: Union[int, "BitVec"] = None, + annotations: Optional[Annotations] = None, + ): + """ + + :param raw: The raw bit vector symbol + :param func_name: The function name. e.g. sha3 + :param input: The input to the functions + :param annotations: The annotations the BitVecFunc should start with + """ + + self.func_name = func_name + self.input_ = input_ + super().__init__(raw, annotations) + + def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": + """Create an addition expression. + + :param other: The int or BitVec to add to this BitVecFunc + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.add) + + def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": + """Create a subtraction expression. + + :param other: The int or BitVec to subtract from this BitVecFunc + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.sub) + + def __mul__(self, other: "BitVec") -> "BitVecFunc": + """Create a multiplication expression. + + :param other: The int or BitVec to multiply to this BitVecFunc + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.mul) + + def __truediv__(self, other: "BitVec") -> "BitVecFunc": + """Create a signed division expression. + + :param other: The int or BitVec to divide this BitVecFunc by + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.truediv) + + def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": + """Create an and expression. + + :param other: The int or BitVec to and with this BitVecFunc + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.and_) + + def __or__(self, other: "BitVec") -> "BitVecFunc": + """Create an or expression. + + :param other: The int or BitVec to or with this BitVecFunc + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.or_) + + def __xor__(self, other: "BitVec") -> "BitVecFunc": + """Create a xor expression. + + :param other: The int or BitVec to xor with this BitVecFunc + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.xor) + + def __lt__(self, other: "BitVec") -> Bool: + """Create a signed less than expression. + + :param other: The int or BitVec to compare to this BitVecFunc + :return: The resulting Bool + """ + return _comparison_helper( + self, other, operator.lt, default_value=False, inputs_equal=False + ) + + def __gt__(self, other: "BitVec") -> Bool: + """Create a signed greater than expression. + + :param other: The int or BitVec to compare to this BitVecFunc + :return: The resulting Bool + """ + return _comparison_helper( + self, other, operator.gt, default_value=False, inputs_equal=False + ) + + def __le__(self, other: "BitVec") -> Bool: + """Create a signed less than or equal to expression. + + :param other: The int or BitVec to compare to this BitVecFunc + :return: The resulting Bool + """ + return Or(self < other, self == other) + + def __ge__(self, other: "BitVec") -> Bool: + """Create a signed greater than or equal to expression. + + :param other: The int or BitVec to compare to this BitVecFunc + :return: The resulting Bool + """ + return Or(self > other, self == other) + + # MYPY: fix complains about overriding __eq__ + def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore + """Create an equality expression. + + :param other: The int or BitVec to compare to this BitVecFunc + :return: The resulting Bool + """ + return _comparison_helper( + self, other, operator.eq, default_value=False, inputs_equal=True + ) + + # MYPY: fix complains about overriding __ne__ + def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore + """Create an inequality expression. + + :param other: The int or BitVec to compare to this BitVecFunc + :return: The resulting Bool + """ + return _comparison_helper( + self, other, operator.eq, default_value=True, inputs_equal=False + ) diff --git a/tests/laser/smt/bitvecfunc_test.py b/tests/laser/smt/bitvecfunc_test.py new file mode 100644 index 00000000..ea19dad1 --- /dev/null +++ b/tests/laser/smt/bitvecfunc_test.py @@ -0,0 +1,82 @@ +from mythril.laser.smt import Solver, symbol_factory, bitvec +import z3 +import pytest + +import operator + + +@pytest.mark.parametrize( + "operation,expected", + [ + (operator.add, z3.unsat), + (operator.sub, z3.unsat), + (operator.and_, z3.sat), + (operator.or_, z3.sat), + (operator.xor, z3.unsat), + ], +) +def test_bitvecfunc_arithmetic(operation, expected): + # Arrange + s = Solver() + + input_ = symbol_factory.BitVecVal(1, 8) + bvf = symbol_factory.BitVecFuncSym("bvf", "sha3", 256, input_=input_) + + x = symbol_factory.BitVecSym("x", 256) + y = symbol_factory.BitVecSym("y", 256) + + # Act + s.add(x != y) + s.add(operation(bvf, x) == operation(y, bvf)) + + # Assert + assert s.check() == expected + + +@pytest.mark.parametrize( + "operation,expected", + [ + (operator.eq, z3.sat), + (operator.ne, z3.unsat), + (operator.lt, z3.unsat), + (operator.le, z3.sat), + (operator.gt, z3.unsat), + (operator.ge, z3.sat), + (bitvec.UGT, z3.unsat), + (bitvec.UGE, z3.sat), + (bitvec.ULT, z3.unsat), + (bitvec.ULE, z3.sat), + ], +) +def test_bitvecfunc_bitvecfunc_comparison(operation, expected): + # Arrange + s = Solver() + + input1 = symbol_factory.BitVecSym("input1", 256) + input2 = symbol_factory.BitVecSym("input2", 256) + bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) + bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=input2) + + # Act + s.add(operation(bvf1, bvf2)) + s.add(input1 == input2) + + # Assert + assert s.check() == expected + + +def test_bitvecfunc_bitvecfuncval_comparison(): + # Arrange + s = Solver() + + input1 = symbol_factory.BitVecSym("input1", 256) + input2 = symbol_factory.BitVecVal(1337, 256) + bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) + bvf2 = symbol_factory.BitVecFuncVal(12345678910, "sha3", 256, input_=input2) + + # Act + s.add(bvf1 == bvf2) + + # Assert + assert s.check() == z3.sat + assert s.model().eval(input2.raw) == 1337