From 5384602d00ff4cae5b9310aa4cbba889752fe55a Mon Sep 17 00:00:00 2001 From: Nathan Date: Sun, 27 Jan 2019 16:34:09 -0500 Subject: [PATCH 01/10] Implement BitVecFunc for sha3 --- mythril/laser/ethereum/instructions.py | 44 +-- mythril/laser/smt/__init__.py | 56 ++++ mythril/laser/smt/bitvec.py | 365 ++++++++++++++++++++++++- 3 files changed, 438 insertions(+), 27 deletions(-) diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 677ec626..03f0f9ee 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -890,29 +890,35 @@ class Instruction: state.max_gas_used += max_gas StateTransition.check_gas_usage_limit(global_state) - try: - state.mem_extend(index, length) - data = b"".join( - [ - util.get_concrete_int(i).to_bytes(1, byteorder="big") - for i in state.memory[index : index + length] - ] - ) + state.mem_extend(index, length) + data = [ + b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8) + for b in state.memory[index : index + length] + ] + if len(data) > 1: + data = simplify(Concat(data)) + elif len(data) == 1: + data = data[0] + else: + # length is 0; this only matters for input of the BitVecFuncVal + data = symbol_factory.BitVecVal(0, 1) - except TypeError: - argument = str(state.memory[index]).replace(" ", "_") + if data.symbolic: + argument_str = str(state.memory[index]).replace(" ", "_") + result = symbol_factory.BitVecFuncSym( + "KECCAC[{}]".format(argument_str), "keccak256", 256, input=data + ) + log.debug("Created BitVecFunc hash.") - result = symbol_factory.BitVecSym("KECCAC[{}]".format(argument), 256) keccak_function_manager.add_keccak(result, state.memory[index]) - state.stack.append(result) - return [global_state] - - keccak = utils.sha3(utils.bytearray_to_bytestr(data)) - log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak))) + else: + keccak = utils.sha3(data.value.to_bytes(length, byteorder="big")) + result = symbol_factory.BitVecFuncVal( + "keccak256", util.concrete_int_from_bytes(keccak, 0), 256, input=data + ) + log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak))) - state.stack.append( - symbol_factory.BitVecVal(util.concrete_int_from_bytes(keccak, 0), 256) - ) + state.stack.append(result) return [global_state] @StateTransition() diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index 9cb81244..a16399c2 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -1,5 +1,6 @@ from mythril.laser.smt.bitvec import ( BitVec, + BitVecFunc, If, UGT, ULT, @@ -63,6 +64,37 @@ class SymbolFactory(Generic[T, U]): """ raise NotImplementedError() + @staticmethod + def BitVecFuncVal( + func_name: str, + value: int, + size: int, + annotations: Annotations = None, + input: Union[int, "BitVec"] = None, + ) -> BitVecFunc: + """Creates a new bit vector function with a concrete value. + + :param func_name: The name of the function + :param value: The concrete value to set the bit vector to + :param size: The size of the bit vector + :param annotations: The annotations to initialize the bit vector with + :return: The freshly created bit vector + """ + raise NotImplementedError() + + @staticmethod + def BitVecFuncSym( + name: str, func_name: str, size: int, annotations: Annotations = None + ) -> U: + """Creates a new bit vector with a symbolic value. + + :param name: The name of the symbolic bit vector + :param size: The size of the bit vector + :param annotations: The annotations to initialize the bit vector with + :return: The freshly created bit vector + """ + raise NotImplementedError() + class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): """ @@ -93,6 +125,30 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): raw = z3.BitVec(name, size) return BitVec(raw, annotations) + @staticmethod + def BitVecFuncVal( + func_name: str, + value: int, + size: int, + annotations: Annotations = None, + input: Union[int, "BitVec"] = None, + ) -> BitVecFunc: + """Creates a new bit vector function with a concrete value.""" + raw = z3.BitVecVal(value, size) + return BitVecFunc(raw, func_name, input, annotations) + + @staticmethod + def BitVecFuncSym( + name: str, + func_name: str, + size: int, + annotations: Annotations = None, + input: Union[int, "BitVec"] = None, + ) -> BitVecFunc: + """Creates a new bit vector function with a symbolic value.""" + raw = z3.BitVec(name, size) + return BitVecFunc(raw, func_name, input, annotations) + class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]): """ diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index feeac40b..ad13ed4c 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -4,7 +4,7 @@ from typing import Union, overload, List, cast, Any, Optional import z3 -from mythril.laser.smt.bool import Bool +from mythril.laser.smt.bool import Bool, And, Or from mythril.laser.smt.expression import Expression Annotations = List[Any] @@ -14,7 +14,7 @@ Annotations = List[Any] class BitVec(Expression[z3.BitVecRef]): """A bit vector symbol.""" - def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations]=None): + def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None): """ :param raw: @@ -55,6 +55,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other + self if isinstance(other, int): return BitVec(self.raw + other, annotations=self.annotations) @@ -67,7 +69,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - + if isinstance(other, BitVecFunc): + return other - self if isinstance(other, int): return BitVec(self.raw - other, annotations=self.annotations) @@ -80,6 +83,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other * self union = self.annotations + other.annotations return BitVec(self.raw * other.raw, annotations=union) @@ -89,6 +94,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other / self union = self.annotations + other.annotations return BitVec(self.raw / other.raw, annotations=union) @@ -98,8 +105,10 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other & self if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, 256)) + other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations + other.annotations return BitVec(self.raw & other.raw, annotations=union) @@ -109,6 +118,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other | self union = self.annotations + other.annotations return BitVec(self.raw | other.raw, annotations=union) @@ -118,6 +129,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other ^ self union = self.annotations + other.annotations return BitVec(self.raw ^ other.raw, annotations=union) @@ -127,6 +140,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other > self union = self.annotations + other.annotations return Bool(self.raw < other.raw, annotations=union) @@ -136,6 +151,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other < self union = self.annotations + other.annotations return Bool(self.raw > other.raw, annotations=union) @@ -146,8 +163,12 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other == self if not isinstance(other, BitVec): - return Bool(cast(z3.BoolRef, self.raw == other), annotations=self.annotations) + return Bool( + cast(z3.BoolRef, self.raw == other), annotations=self.annotations + ) union = self.annotations + other.annotations # MYPY: fix complaints due to z3 overriding __eq__ @@ -160,14 +181,321 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other != self if not isinstance(other, BitVec): - return Bool(cast(z3.BoolRef, self.raw != other), annotations=self.annotations) + return Bool( + cast(z3.BoolRef, self.raw != other), annotations=self.annotations + ) union = self.annotations + other.annotations # MYPY: fix complaints due to z3 overriding __eq__ return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union) +class BitVecFunc(BitVec): + """A bit vector symbol.""" + + def __init__( + self, + raw: z3.BitVecRef, + name: str, + input: Union[int, "BitVec"] = None, + annotations: Optional[Annotations] = None, + ): + """ + + :param raw: + :param annotations: + :param input: + """ + + from mythril.laser.smt import symbol_factory + + self.symbol_factory = symbol_factory + + self.name = name + self.input = input + super().__init__(raw, annotations) + + def __add__(self, other: Union[int, "BitVec"]) -> "BitVec": + """Create an addition expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw + other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": + """Create a subtraction expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw - other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __mul__(self, other: "BitVec") -> "BitVecFunc": + """Create a multiplication expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw * other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __truediv__(self, other: "BitVec") -> "BitVecFunc": + """Create a signed division expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw / other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": + """Create an and expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw & other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __or__(self, other: "BitVec") -> "BitVecFunc": + """Create an or expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw | other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __xor__(self, other: "BitVec") -> "BitVec": + """Create a xor expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw ^ other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __lt__(self, other: "BitVec") -> Bool: + """Create a signed less than expression. + + :param other: + :return: + """ + # Is there some hack for these comparisons? + + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + union = self.annotations + other.annotations + + if not self.symbolic and not other.symbolic: + return Bool(cast(z3.BoolRef, self.value < other.value), annotations=union) + + if ( + not isinstance(other, BitVecFunc) + or not self.name + or not self.input + or not self.name == other.name + ): + return Bool(False, annotations=union) + + return And( + Bool(cast(z3.BoolRef, self.raw < other.raw), annotations=union), + self.input != other.input, + ) + + def __gt__(self, other: "BitVec") -> Bool: + """Create a signed greater than expression. + + :param other: + :return: + """ + # Is there some hack for these comparisons? + + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + union = self.annotations + other.annotations + + if not self.symbolic and not other.symbolic: + return Bool(cast(z3.BoolRef, self.value > other.value), annotations=union) + + if ( + not isinstance(other, BitVecFunc) + or not self.name + or not self.input + or not self.name == other.name + ): + return Bool(False, annotations=union) + + return And( + Bool(cast(z3.BoolRef, self.raw > other.raw), annotations=union), + self.input != other.input, + ) + + # MYPY: fix complains about overriding __eq__ + def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore + """Create an equality expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + union = self.annotations + other.annotations + + if not self.symbolic and not other.symbolic: + return Bool(cast(z3.BoolRef, self.value == other.value), annotations=union) + + if ( + not isinstance(other, BitVecFunc) + or not self.name + or not self.input + or not self.name == other.name + ): + return Bool(cast(z3.BoolRef, False), annotations=union) + + # MYPY: fix complaints due to z3 overriding __eq__ + return And( + Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union), + self.input == other.input, + ) + + # MYPY: fix complains about overriding __ne__ + def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore + """Create an inequality expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + union = self.annotations + other.annotations + + if not self.symbolic and not other.symbolic: + return Bool(cast(z3.BoolRef, self.value != other.value), annotations=union) + + if ( + not isinstance(other, BitVecFunc) + or not self.name + or not self.input + or not self.name == other.name + ): + return Bool(cast(z3.BoolRef, True), annotations=union) + + # MYPY: fix complaints due to z3 overriding __eq__ + return Or( + Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union), + self.input != other.input, + ) + + def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec: """Create an if-then-else expression. @@ -176,6 +504,8 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi :param c: :return: """ + # TODO: Handle BitVecFunc + if not isinstance(a, Bool): a = Bool(z3.BoolVal(a)) if not isinstance(b, BitVec): @@ -193,17 +523,21 @@ def UGT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ + # TODO: Handle BitVecFunc + annotations = a.annotations + b.annotations return Bool(z3.UGT(a.raw, b.raw), annotations) -def UGE(a: BitVec, b:BitVec) -> Bool: +def UGE(a: BitVec, b: BitVec) -> Bool: """Create an unsigned greater or equals expression. :param a: :param b: :return: """ + # TODO: Handle BitVecFunc + annotations = a.annotations + b.annotations return Bool(z3.UGE(a.raw, b.raw), annotations) @@ -215,6 +549,8 @@ def ULT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ + # TODO: Handle BitVecFunc + annotations = a.annotations + b.annotations return Bool(z3.ULT(a.raw, b.raw), annotations) @@ -233,6 +569,7 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: :param args: :return: """ + # TODO: Handle BitVecFunc # The following statement is used if a list is provided as an argument to concat if len(args) == 1 and isinstance(args[0], list): @@ -255,6 +592,8 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: :param bv: :return: """ + # TODO: Handle BitVecFunc + return BitVec(z3.Extract(high, low, bv.raw), annotations=bv.annotations) @@ -265,6 +604,8 @@ def URem(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ + # TODO: Handle BitVecFunc + union = a.annotations + b.annotations return BitVec(z3.URem(a.raw, b.raw), annotations=union) @@ -276,6 +617,8 @@ def SRem(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ + # TODO: Handle BitVecFunc + union = a.annotations + b.annotations return BitVec(z3.SRem(a.raw, b.raw), annotations=union) @@ -287,6 +630,8 @@ def UDiv(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ + # TODO: Handle BitVecFunc + union = a.annotations + b.annotations return BitVec(z3.UDiv(a.raw, b.raw), annotations=union) @@ -296,6 +641,8 @@ def Sum(*args: BitVec) -> BitVec: :return: """ + # TODO: Handle BitVecFunc + nraw = z3.Sum([a.raw for a in args]) annotations = [] # type: Annotations for bv in args: @@ -334,7 +681,9 @@ def BVMulNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed)) -def BVSubNoUnderflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool: +def BVSubNoUnderflow( + a: Union[BitVec, int], b: Union[BitVec, int], signed: bool +) -> Bool: """Creates predicate that verifies that the subtraction doesn't overflow. :param a: From 59a851f09a4fb772cae9518519f5c8033013f401 Mon Sep 17 00:00:00 2001 From: Nathan Date: Sun, 27 Jan 2019 17:00:01 -0500 Subject: [PATCH 02/10] Remove unused code --- mythril/laser/smt/bitvec.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index ad13ed4c..9797a4db 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -210,10 +210,6 @@ class BitVecFunc(BitVec): :param input: """ - from mythril.laser.smt import symbol_factory - - self.symbol_factory = symbol_factory - self.name = name self.input = input super().__init__(raw, annotations) From 7da8cc29743485d3fb3f3d420981477ff18b5cf0 Mon Sep 17 00:00:00 2001 From: Nathan Date: Tue, 29 Jan 2019 14:44:34 -0500 Subject: [PATCH 03/10] Add cases for BitVecFuncs --- mythril/laser/smt/bitvec.py | 210 ++++++++++++++++++++++-------------- 1 file changed, 130 insertions(+), 80 deletions(-) diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index 9797a4db..1f588ee9 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -223,17 +223,12 @@ class BitVecFunc(BitVec): if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) - raw = (self.raw + other.raw,) + raw = self.raw + other.raw union = self.annotations + other.annotations if isinstance(other, BitVecFunc): # TODO: Find better value to set input and name to in this case - return BitVecFunc( - raw=raw, - name=self.name if self.name and self.name == other.name else None, - input=None, - annotations=union, - ) + return BitVecFunc(raw=raw, name=None, input=None, annotations=union) return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) @@ -246,17 +241,12 @@ class BitVecFunc(BitVec): if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) - raw = (self.raw - other.raw,) + raw = self.raw - other.raw union = self.annotations + other.annotations if isinstance(other, BitVecFunc): # TODO: Find better value to set input and name to in this case - return BitVecFunc( - raw=raw, - name=self.name if self.name and self.name == other.name else None, - input=None, - annotations=union, - ) + return BitVecFunc(raw=raw, name=None, input=None, annotations=union) return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) @@ -269,17 +259,12 @@ class BitVecFunc(BitVec): if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) - raw = (self.raw * other.raw,) + raw = self.raw * other.raw union = self.annotations + other.annotations if isinstance(other, BitVecFunc): # TODO: Find better value to set input and name to in this case - return BitVecFunc( - raw=raw, - name=self.name if self.name and self.name == other.name else None, - input=None, - annotations=union, - ) + return BitVecFunc(raw=raw, name=None, input=None, annotations=union) return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) @@ -292,17 +277,12 @@ class BitVecFunc(BitVec): if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) - raw = (self.raw / other.raw,) + raw = self.raw / other.raw union = self.annotations + other.annotations if isinstance(other, BitVecFunc): # TODO: Find better value to set input and name to in this case - return BitVecFunc( - raw=raw, - name=self.name if self.name and self.name == other.name else None, - input=None, - annotations=union, - ) + return BitVecFunc(raw=raw, name=None, input=None, annotations=union) return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) @@ -315,17 +295,12 @@ class BitVecFunc(BitVec): if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) - raw = (self.raw & other.raw,) + raw = self.raw & other.raw union = self.annotations + other.annotations if isinstance(other, BitVecFunc): # TODO: Find better value to set input and name to in this case - return BitVecFunc( - raw=raw, - name=self.name if self.name and self.name == other.name else None, - input=None, - annotations=union, - ) + return BitVecFunc(raw=raw, name=None, input=None, annotations=union) return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) @@ -338,17 +313,12 @@ class BitVecFunc(BitVec): if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) - raw = (self.raw | other.raw,) + raw = self.raw | other.raw union = self.annotations + other.annotations if isinstance(other, BitVecFunc): # TODO: Find better value to set input and name to in this case - return BitVecFunc( - raw=raw, - name=self.name if self.name and self.name == other.name else None, - input=None, - annotations=union, - ) + return BitVecFunc(raw=raw, name=None, input=None, annotations=union) return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) @@ -361,17 +331,12 @@ class BitVecFunc(BitVec): if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) - raw = (self.raw ^ other.raw,) + raw = self.raw ^ other.raw union = self.annotations + other.annotations if isinstance(other, BitVecFunc): # TODO: Find better value to set input and name to in this case - return BitVecFunc( - raw=raw, - name=self.name if self.name and self.name == other.name else None, - input=None, - annotations=union, - ) + return BitVecFunc(raw=raw, name=None, input=None, annotations=union) return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) @@ -389,7 +354,7 @@ class BitVecFunc(BitVec): union = self.annotations + other.annotations if not self.symbolic and not other.symbolic: - return Bool(cast(z3.BoolRef, self.value < other.value), annotations=union) + return Bool(z3.BoolVal(self.value < other.value), annotations=union) if ( not isinstance(other, BitVecFunc) @@ -397,7 +362,7 @@ class BitVecFunc(BitVec): or not self.input or not self.name == other.name ): - return Bool(False, annotations=union) + return Bool(z3.BoolVal(False), annotations=union) return And( Bool(cast(z3.BoolRef, self.raw < other.raw), annotations=union), @@ -418,7 +383,7 @@ class BitVecFunc(BitVec): union = self.annotations + other.annotations if not self.symbolic and not other.symbolic: - return Bool(cast(z3.BoolRef, self.value > other.value), annotations=union) + return Bool(z3.BoolVal(self.value > other.value), annotations=union) if ( not isinstance(other, BitVecFunc) @@ -426,7 +391,7 @@ class BitVecFunc(BitVec): or not self.input or not self.name == other.name ): - return Bool(False, annotations=union) + return Bool(z3.BoolVal(False), annotations=union) return And( Bool(cast(z3.BoolRef, self.raw > other.raw), annotations=union), @@ -446,7 +411,7 @@ class BitVecFunc(BitVec): union = self.annotations + other.annotations if not self.symbolic and not other.symbolic: - return Bool(cast(z3.BoolRef, self.value == other.value), annotations=union) + return Bool(z3.BoolVal(self.value == other.value), annotations=union) if ( not isinstance(other, BitVecFunc) @@ -454,7 +419,7 @@ class BitVecFunc(BitVec): or not self.input or not self.name == other.name ): - return Bool(cast(z3.BoolRef, False), annotations=union) + return Bool(z3.BoolVal(True), annotations=union) # MYPY: fix complaints due to z3 overriding __eq__ return And( @@ -475,7 +440,7 @@ class BitVecFunc(BitVec): union = self.annotations + other.annotations if not self.symbolic and not other.symbolic: - return Bool(cast(z3.BoolRef, self.value != other.value), annotations=union) + return Bool(z3.BoolVal(self.value != other.value), annotations=union) if ( not isinstance(other, BitVecFunc) @@ -483,7 +448,7 @@ class BitVecFunc(BitVec): or not self.input or not self.name == other.name ): - return Bool(cast(z3.BoolRef, True), annotations=union) + return Bool(z3.BoolVal(True), annotations=union) # MYPY: fix complaints due to z3 overriding __eq__ return Or( @@ -519,9 +484,23 @@ def UGT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - # TODO: Handle BitVecFunc - annotations = a.annotations + b.annotations + if isinstance(a, BitVecFunc): + if not a.symbolic and not b.symbolic: + return Bool(z3.UGT(a.raw, b.raw), annotations=annotations) + + if ( + not isinstance(b, BitVecFunc) + or not a.name + or not a.input + or not a.name == b.name + ): + return Bool(z3.BoolVal(False), annotations=annotations) + + return And( + Bool(z3.UGT(a.raw, b.raw), annotations=annotations), a.input != b.input + ) + return Bool(z3.UGT(a.raw, b.raw), annotations) @@ -532,9 +511,23 @@ def UGE(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - # TODO: Handle BitVecFunc - annotations = a.annotations + b.annotations + if isinstance(a, BitVecFunc): + if not a.symbolic and not b.symbolic: + return Bool(z3.UGE(a.raw, b.raw), annotations=annotations) + + if ( + not isinstance(b, BitVecFunc) + or not a.name + or not a.input + or not a.name == b.name + ): + return Bool(z3.BoolVal(False), annotations=annotations) + + return And( + Bool(z3.UGE(a.raw, b.raw), annotations=annotations), a.input != b.input + ) + return Bool(z3.UGE(a.raw, b.raw), annotations) @@ -545,9 +538,23 @@ def ULT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - # TODO: Handle BitVecFunc - annotations = a.annotations + b.annotations + if isinstance(a, BitVecFunc): + if not a.symbolic and not b.symbolic: + return Bool(z3.ULT(a.raw, b.raw), annotations=annotations) + + if ( + not isinstance(b, BitVecFunc) + or not a.name + or not a.input + or not a.name == b.name + ): + return Bool(z3.BoolVal(False), annotations=annotations) + + return And( + Bool(z3.ULT(a.raw, b.raw), annotations=annotations), a.input != b.input + ) + return Bool(z3.ULT(a.raw, b.raw), annotations) @@ -565,8 +572,6 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: :param args: :return: """ - # TODO: Handle BitVecFunc - # The following statement is used if a list is provided as an argument to concat if len(args) == 1 and isinstance(args[0], list): bvs = args[0] # type: List[BitVec] @@ -575,8 +580,16 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: nraw = z3.Concat([a.raw for a in bvs]) annotations = [] # type: Annotations + bitvecfunc = False for bv in bvs: annotations += bv.annotations + if isinstance(bv, BitVecFunc): + bitvecfunc = True + + if bitvecfunc: + # Is there a better value to set name and input to in this case? + return BitVecFunc(raw=nraw, name=None, input=None, annotations=bv.annotations) + return BitVec(nraw, annotations) @@ -588,9 +601,12 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: :param bv: :return: """ - # TODO: Handle BitVecFunc + raw = z3.Extract(high, low, bv.raw) + if isinstance(bv, BitVecFunc): + # Is there a better value to set name and input to in this case? + return BitVecFunc(raw=raw, name=None, input=None, annotations=bv.annotations) - return BitVec(z3.Extract(high, low, bv.raw), annotations=bv.annotations) + return BitVec(raw, annotations=bv.annotations) def URem(a: BitVec, b: BitVec) -> BitVec: @@ -600,10 +616,17 @@ def URem(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ - # TODO: Handle BitVecFunc - + raw = z3.URem(a.raw, b.raw) union = a.annotations + b.annotations - return BitVec(z3.URem(a.raw, b.raw), annotations=union) + + if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): + return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + elif isinstance(a, BitVecFunc): + return BitVecFunc(raw=raw, name=a.name, input=a.input, annotations=union) + elif isinstance(b, BitVecFunc): + return BitVecFunc(raw=raw, name=b.name, input=b.input, annotations=union) + + return BitVec(raw, annotations=union) def SRem(a: BitVec, b: BitVec) -> BitVec: @@ -613,10 +636,17 @@ def SRem(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ - # TODO: Handle BitVecFunc - + raw = z3.SRem(a.raw, b.raw) union = a.annotations + b.annotations - return BitVec(z3.SRem(a.raw, b.raw), annotations=union) + + if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): + return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + elif isinstance(a, BitVecFunc): + return BitVecFunc(raw=raw, name=a.name, input=a.input, annotations=union) + elif isinstance(b, BitVecFunc): + return BitVecFunc(raw=raw, name=b.name, input=b.input, annotations=union) + + return BitVec(raw, annotations=union) def UDiv(a: BitVec, b: BitVec) -> BitVec: @@ -626,10 +656,17 @@ def UDiv(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ - # TODO: Handle BitVecFunc - + raw = z3.UDiv(a.raw, b.raw) union = a.annotations + b.annotations - return BitVec(z3.UDiv(a.raw, b.raw), annotations=union) + + if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): + return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + elif isinstance(a, BitVecFunc): + return BitVecFunc(raw=raw, name=a.name, input=a.input, annotations=union) + elif isinstance(b, BitVecFunc): + return BitVecFunc(raw=raw, name=b.name, input=b.input, annotations=union) + + return BitVec(raw, annotations=union) def Sum(*args: BitVec) -> BitVec: @@ -637,13 +674,26 @@ def Sum(*args: BitVec) -> BitVec: :return: """ - # TODO: Handle BitVecFunc - - nraw = z3.Sum([a.raw for a in args]) + raw = z3.Sum([a.raw for a in args]) annotations = [] # type: Annotations + bitvecfuncs = [] + for bv in args: annotations += bv.annotations - return BitVec(nraw, annotations) + if isinstance(bv, BitVecFunc): + bitvecfuncs.append(bv) + + if len(bitvecfuncs) >= 2: + return BitVecFunc(raw=raw, name=None, input=None, annotations=annotations) + elif len(bitvecfuncs) == 1: + return BitVecFunc( + raw=raw, + name=bitvecfuncs[0].name, + input=bitvecfuncs[0].input, + annotations=annotations, + ) + + return BitVec(raw, annotations) def BVAddNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool: From 1dc622f90a592ae02ea9c060a55027973a007bc9 Mon Sep 17 00:00:00 2001 From: Nathan Date: Thu, 31 Jan 2019 15:20:57 -0500 Subject: [PATCH 04/10] Improve typehints, documentation, and implement <= and >= for BitVecFunc --- mythril/laser/smt/bitvec.py | 69 ++++++++++++++++++++++++++++++++++--- 1 file changed, 65 insertions(+), 4 deletions(-) diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index 1f588ee9..a53aec02 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -194,7 +194,7 @@ class BitVec(Expression[z3.BitVecRef]): class BitVecFunc(BitVec): - """A bit vector symbol.""" + """A bit vector function symbol. Used in place of functions like sha3.""" def __init__( self, @@ -214,7 +214,7 @@ class BitVecFunc(BitVec): self.input = input super().__init__(raw, annotations) - def __add__(self, other: Union[int, "BitVec"]) -> "BitVec": + def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": """Create an addition expression. :param other: @@ -322,7 +322,7 @@ class BitVecFunc(BitVec): return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) - def __xor__(self, other: "BitVec") -> "BitVec": + def __xor__(self, other: "BitVec") -> "BitVecFunc": """Create a xor expression. :param other: @@ -398,6 +398,63 @@ class BitVecFunc(BitVec): self.input != other.input, ) + def __le__(self, other: "BitVec") -> Bool: + """Create a signed less than expression. + + :param other: + :return: + """ + # Is there some hack for these comparisons? + + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + union = self.annotations + other.annotations + + if not self.symbolic and not other.symbolic: + return Bool(z3.BoolVal(self.value <= other.value), annotations=union) + + if ( + not isinstance(other, BitVecFunc) + or not self.name + or not self.input + or not self.name == other.name + ): + return Bool(z3.BoolVal(False), annotations=union) + + return And( + Bool(cast(z3.BoolRef, self.raw <= other.raw), annotations=union), + self.input != other.input, + ) + + def __ge__(self, other: "BitVec") -> Bool: + """Create a signed greater than expression. + + :param other: + :return: + """ + # Is there some hack for these comparisons? + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + union = self.annotations + other.annotations + + if not self.symbolic and not other.symbolic: + return Bool(z3.BoolVal(self.value >= other.value), annotations=union) + + if ( + not isinstance(other, BitVecFunc) + or not self.name + or not self.input + or not self.name == other.name + ): + return Bool(z3.BoolVal(False), annotations=union) + + return And( + Bool(cast(z3.BoolRef, self.raw >= other.raw), annotations=union), + self.input != other.input, + ) + # MYPY: fix complains about overriding __eq__ def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore """Create an equality expression. @@ -466,6 +523,10 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi :return: """ # TODO: Handle BitVecFunc + if isinstance(b, BitVecFunc): + die() + elif isinstance(c, BitVecFunc): + die() if not isinstance(a, Bool): a = Bool(z3.BoolVal(a)) @@ -588,7 +649,7 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: if bitvecfunc: # Is there a better value to set name and input to in this case? - return BitVecFunc(raw=nraw, name=None, input=None, annotations=bv.annotations) + return BitVecFunc(raw=nraw, name=None, input=None, annotations=annotations) return BitVec(nraw, annotations) From c3aad236fa1a06b654ae1e91afe4edcc383b2410 Mon Sep 17 00:00:00 2001 From: Nathan Date: Fri, 1 Feb 2019 10:21:02 -0500 Subject: [PATCH 05/10] Fix bug and documentation for BitVecFunc --- mythril/laser/smt/bitvec.py | 119 ++++++++++++++++++------------------ 1 file changed, 58 insertions(+), 61 deletions(-) diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index a53aec02..733e6be1 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -199,18 +199,19 @@ class BitVecFunc(BitVec): def __init__( self, raw: z3.BitVecRef, - name: str, + func_name: Optional[str], input: Union[int, "BitVec"] = None, annotations: Optional[Annotations] = None, ): """ - :param raw: - :param annotations: - :param input: + :param raw: The raw bit vector symbol + :param func_name: The function name. e.g. sha3 + :param input: The input to the functions + :param annotations: The annotations the BitVecFunc should start with """ - self.name = name + self.func_name = func_name self.input = input super().__init__(raw, annotations) @@ -228,9 +229,9 @@ class BitVecFunc(BitVec): if isinstance(other, BitVecFunc): # TODO: Find better value to set input and name to in this case - return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": """Create a subtraction expression. @@ -246,9 +247,9 @@ class BitVecFunc(BitVec): if isinstance(other, BitVecFunc): # TODO: Find better value to set input and name to in this case - return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) def __mul__(self, other: "BitVec") -> "BitVecFunc": """Create a multiplication expression. @@ -263,10 +264,10 @@ class BitVecFunc(BitVec): union = self.annotations + other.annotations if isinstance(other, BitVecFunc): - # TODO: Find better value to set input and name to in this case - return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + # TODO: Find better value to set input and func_name to in this case + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) def __truediv__(self, other: "BitVec") -> "BitVecFunc": """Create a signed division expression. @@ -281,10 +282,10 @@ class BitVecFunc(BitVec): union = self.annotations + other.annotations if isinstance(other, BitVecFunc): - # TODO: Find better value to set input and name to in this case - return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + # TODO: Find better value to set input and func_name to in this case + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": """Create an and expression. @@ -299,10 +300,10 @@ class BitVecFunc(BitVec): union = self.annotations + other.annotations if isinstance(other, BitVecFunc): - # TODO: Find better value to set input and name to in this case - return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + # TODO: Find better value to set input and func_name to in this case + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) def __or__(self, other: "BitVec") -> "BitVecFunc": """Create an or expression. @@ -317,10 +318,10 @@ class BitVecFunc(BitVec): union = self.annotations + other.annotations if isinstance(other, BitVecFunc): - # TODO: Find better value to set input and name to in this case - return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + # TODO: Find better value to set input and func_name to in this case + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) def __xor__(self, other: "BitVec") -> "BitVecFunc": """Create a xor expression. @@ -335,10 +336,10 @@ class BitVecFunc(BitVec): union = self.annotations + other.annotations if isinstance(other, BitVecFunc): - # TODO: Find better value to set input and name to in this case - return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + # TODO: Find better value to set input and func_name to in this case + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) def __lt__(self, other: "BitVec") -> Bool: """Create a signed less than expression. @@ -358,9 +359,9 @@ class BitVecFunc(BitVec): if ( not isinstance(other, BitVecFunc) - or not self.name + or not self.func_name or not self.input - or not self.name == other.name + or not self.func_name == other.func_name ): return Bool(z3.BoolVal(False), annotations=union) @@ -387,9 +388,9 @@ class BitVecFunc(BitVec): if ( not isinstance(other, BitVecFunc) - or not self.name + or not self.func_name or not self.input - or not self.name == other.name + or not self.func_name == other.func_name ): return Bool(z3.BoolVal(False), annotations=union) @@ -416,9 +417,9 @@ class BitVecFunc(BitVec): if ( not isinstance(other, BitVecFunc) - or not self.name + or not self.func_name or not self.input - or not self.name == other.name + or not self.func_name == other.func_name ): return Bool(z3.BoolVal(False), annotations=union) @@ -444,9 +445,9 @@ class BitVecFunc(BitVec): if ( not isinstance(other, BitVecFunc) - or not self.name + or not self.func_name or not self.input - or not self.name == other.name + or not self.func_name == other.func_name ): return Bool(z3.BoolVal(False), annotations=union) @@ -472,9 +473,9 @@ class BitVecFunc(BitVec): if ( not isinstance(other, BitVecFunc) - or not self.name + or not self.func_name or not self.input - or not self.name == other.name + or not self.func_name == other.func_name ): return Bool(z3.BoolVal(True), annotations=union) @@ -501,9 +502,9 @@ class BitVecFunc(BitVec): if ( not isinstance(other, BitVecFunc) - or not self.name + or not self.func_name or not self.input - or not self.name == other.name + or not self.func_name == other.func_name ): return Bool(z3.BoolVal(True), annotations=union) @@ -523,10 +524,6 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi :return: """ # TODO: Handle BitVecFunc - if isinstance(b, BitVecFunc): - die() - elif isinstance(c, BitVecFunc): - die() if not isinstance(a, Bool): a = Bool(z3.BoolVal(a)) @@ -552,9 +549,9 @@ def UGT(a: BitVec, b: BitVec) -> Bool: if ( not isinstance(b, BitVecFunc) - or not a.name + or not a.func_name or not a.input - or not a.name == b.name + or not a.func_name == b.func_name ): return Bool(z3.BoolVal(False), annotations=annotations) @@ -579,9 +576,9 @@ def UGE(a: BitVec, b: BitVec) -> Bool: if ( not isinstance(b, BitVecFunc) - or not a.name + or not a.func_name or not a.input - or not a.name == b.name + or not a.func_name == b.func_name ): return Bool(z3.BoolVal(False), annotations=annotations) @@ -606,9 +603,9 @@ def ULT(a: BitVec, b: BitVec) -> Bool: if ( not isinstance(b, BitVecFunc) - or not a.name + or not a.func_name or not a.input - or not a.name == b.name + or not a.func_name == b.func_name ): return Bool(z3.BoolVal(False), annotations=annotations) @@ -648,8 +645,8 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: bitvecfunc = True if bitvecfunc: - # Is there a better value to set name and input to in this case? - return BitVecFunc(raw=nraw, name=None, input=None, annotations=annotations) + # Is there a better value to set func_name and input to in this case? + return BitVecFunc(raw=nraw, func_name=None, input=None, annotations=annotations) return BitVec(nraw, annotations) @@ -664,8 +661,8 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: """ raw = z3.Extract(high, low, bv.raw) if isinstance(bv, BitVecFunc): - # Is there a better value to set name and input to in this case? - return BitVecFunc(raw=raw, name=None, input=None, annotations=bv.annotations) + # Is there a better value to set func_name and input to in this case? + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=bv.annotations) return BitVec(raw, annotations=bv.annotations) @@ -681,11 +678,11 @@ def URem(a: BitVec, b: BitVec) -> BitVec: union = a.annotations + b.annotations if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) elif isinstance(a, BitVecFunc): - return BitVecFunc(raw=raw, name=a.name, input=a.input, annotations=union) + return BitVecFunc(raw=raw, func_name=a.func_name, input=a.input, annotations=union) elif isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, name=b.name, input=b.input, annotations=union) + return BitVecFunc(raw=raw, func_name=b.func_name, input=b.input, annotations=union) return BitVec(raw, annotations=union) @@ -701,11 +698,11 @@ def SRem(a: BitVec, b: BitVec) -> BitVec: union = a.annotations + b.annotations if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) elif isinstance(a, BitVecFunc): - return BitVecFunc(raw=raw, name=a.name, input=a.input, annotations=union) + return BitVecFunc(raw=raw, func_name=a.func_name, input=a.input, annotations=union) elif isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, name=b.name, input=b.input, annotations=union) + return BitVecFunc(raw=raw, func_name=b.func_name, input=b.input, annotations=union) return BitVec(raw, annotations=union) @@ -721,11 +718,11 @@ def UDiv(a: BitVec, b: BitVec) -> BitVec: union = a.annotations + b.annotations if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, name=None, input=None, annotations=union) + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) elif isinstance(a, BitVecFunc): - return BitVecFunc(raw=raw, name=a.name, input=a.input, annotations=union) + return BitVecFunc(raw=raw, func_name=a.func_name, input=a.input, annotations=union) elif isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, name=b.name, input=b.input, annotations=union) + return BitVecFunc(raw=raw, func_name=b.func_name, input=b.input, annotations=union) return BitVec(raw, annotations=union) @@ -745,11 +742,11 @@ def Sum(*args: BitVec) -> BitVec: bitvecfuncs.append(bv) if len(bitvecfuncs) >= 2: - return BitVecFunc(raw=raw, name=None, input=None, annotations=annotations) + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=annotations) elif len(bitvecfuncs) == 1: return BitVecFunc( raw=raw, - name=bitvecfuncs[0].name, + func_name=bitvecfuncs[0].func_name, input=bitvecfuncs[0].input, annotations=annotations, ) From 25a7774b1280c9e7d7aebf6aae6d94e93ee416b1 Mon Sep 17 00:00:00 2001 From: Nathan Date: Mon, 4 Feb 2019 15:15:35 -0500 Subject: [PATCH 06/10] Move BitVecFunc to bitvecfunc.py and add helper functions --- mythril/laser/smt/__init__.py | 2 +- mythril/laser/smt/bitvec.py | 328 +------------------------------- mythril/laser/smt/bitvecfunc.py | 207 ++++++++++++++++++++ 3 files changed, 213 insertions(+), 324 deletions(-) create mode 100644 mythril/laser/smt/bitvecfunc.py diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index a16399c2..0f7ee8e2 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -1,6 +1,5 @@ from mythril.laser.smt.bitvec import ( BitVec, - BitVecFunc, If, UGT, ULT, @@ -15,6 +14,7 @@ from mythril.laser.smt.bitvec import ( BVMulNoOverflow, BVSubNoUnderflow, ) +from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.array import K, Array, BaseArray diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index 733e6be1..c111fae8 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -4,7 +4,7 @@ from typing import Union, overload, List, cast, Any, Optional import z3 -from mythril.laser.smt.bool import Bool, And, Or +from mythril.laser.smt.bool import Bool, And from mythril.laser.smt.expression import Expression Annotations = List[Any] @@ -193,328 +193,6 @@ class BitVec(Expression[z3.BitVecRef]): return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union) -class BitVecFunc(BitVec): - """A bit vector function symbol. Used in place of functions like sha3.""" - - def __init__( - self, - raw: z3.BitVecRef, - func_name: Optional[str], - input: Union[int, "BitVec"] = None, - annotations: Optional[Annotations] = None, - ): - """ - - :param raw: The raw bit vector symbol - :param func_name: The function name. e.g. sha3 - :param input: The input to the functions - :param annotations: The annotations the BitVecFunc should start with - """ - - self.func_name = func_name - self.input = input - super().__init__(raw, annotations) - - def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create an addition expression. - - :param other: - :return: - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - raw = self.raw + other.raw - union = self.annotations + other.annotations - - if isinstance(other, BitVecFunc): - # TODO: Find better value to set input and name to in this case - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - - return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) - - def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create a subtraction expression. - - :param other: - :return: - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - raw = self.raw - other.raw - union = self.annotations + other.annotations - - if isinstance(other, BitVecFunc): - # TODO: Find better value to set input and name to in this case - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - - return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) - - def __mul__(self, other: "BitVec") -> "BitVecFunc": - """Create a multiplication expression. - - :param other: - :return: - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - raw = self.raw * other.raw - union = self.annotations + other.annotations - - if isinstance(other, BitVecFunc): - # TODO: Find better value to set input and func_name to in this case - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - - return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) - - def __truediv__(self, other: "BitVec") -> "BitVecFunc": - """Create a signed division expression. - - :param other: - :return: - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - raw = self.raw / other.raw - union = self.annotations + other.annotations - - if isinstance(other, BitVecFunc): - # TODO: Find better value to set input and func_name to in this case - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - - return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) - - def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": - """Create an and expression. - - :param other: - :return: - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - raw = self.raw & other.raw - union = self.annotations + other.annotations - - if isinstance(other, BitVecFunc): - # TODO: Find better value to set input and func_name to in this case - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - - return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) - - def __or__(self, other: "BitVec") -> "BitVecFunc": - """Create an or expression. - - :param other: - :return: - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - raw = self.raw | other.raw - union = self.annotations + other.annotations - - if isinstance(other, BitVecFunc): - # TODO: Find better value to set input and func_name to in this case - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - - return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) - - def __xor__(self, other: "BitVec") -> "BitVecFunc": - """Create a xor expression. - - :param other: - :return: - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - raw = self.raw ^ other.raw - union = self.annotations + other.annotations - - if isinstance(other, BitVecFunc): - # TODO: Find better value to set input and func_name to in this case - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) - - return BitVecFunc(raw=raw, func_name=self.func_name, input=self.input, annotations=union) - - def __lt__(self, other: "BitVec") -> Bool: - """Create a signed less than expression. - - :param other: - :return: - """ - # Is there some hack for these comparisons? - - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - union = self.annotations + other.annotations - - if not self.symbolic and not other.symbolic: - return Bool(z3.BoolVal(self.value < other.value), annotations=union) - - if ( - not isinstance(other, BitVecFunc) - or not self.func_name - or not self.input - or not self.func_name == other.func_name - ): - return Bool(z3.BoolVal(False), annotations=union) - - return And( - Bool(cast(z3.BoolRef, self.raw < other.raw), annotations=union), - self.input != other.input, - ) - - def __gt__(self, other: "BitVec") -> Bool: - """Create a signed greater than expression. - - :param other: - :return: - """ - # Is there some hack for these comparisons? - - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - union = self.annotations + other.annotations - - if not self.symbolic and not other.symbolic: - return Bool(z3.BoolVal(self.value > other.value), annotations=union) - - if ( - not isinstance(other, BitVecFunc) - or not self.func_name - or not self.input - or not self.func_name == other.func_name - ): - return Bool(z3.BoolVal(False), annotations=union) - - return And( - Bool(cast(z3.BoolRef, self.raw > other.raw), annotations=union), - self.input != other.input, - ) - - def __le__(self, other: "BitVec") -> Bool: - """Create a signed less than expression. - - :param other: - :return: - """ - # Is there some hack for these comparisons? - - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - union = self.annotations + other.annotations - - if not self.symbolic and not other.symbolic: - return Bool(z3.BoolVal(self.value <= other.value), annotations=union) - - if ( - not isinstance(other, BitVecFunc) - or not self.func_name - or not self.input - or not self.func_name == other.func_name - ): - return Bool(z3.BoolVal(False), annotations=union) - - return And( - Bool(cast(z3.BoolRef, self.raw <= other.raw), annotations=union), - self.input != other.input, - ) - - def __ge__(self, other: "BitVec") -> Bool: - """Create a signed greater than expression. - - :param other: - :return: - """ - # Is there some hack for these comparisons? - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - union = self.annotations + other.annotations - - if not self.symbolic and not other.symbolic: - return Bool(z3.BoolVal(self.value >= other.value), annotations=union) - - if ( - not isinstance(other, BitVecFunc) - or not self.func_name - or not self.input - or not self.func_name == other.func_name - ): - return Bool(z3.BoolVal(False), annotations=union) - - return And( - Bool(cast(z3.BoolRef, self.raw >= other.raw), annotations=union), - self.input != other.input, - ) - - # MYPY: fix complains about overriding __eq__ - def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore - """Create an equality expression. - - :param other: - :return: - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - union = self.annotations + other.annotations - - if not self.symbolic and not other.symbolic: - return Bool(z3.BoolVal(self.value == other.value), annotations=union) - - if ( - not isinstance(other, BitVecFunc) - or not self.func_name - or not self.input - or not self.func_name == other.func_name - ): - return Bool(z3.BoolVal(True), annotations=union) - - # MYPY: fix complaints due to z3 overriding __eq__ - return And( - Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union), - self.input == other.input, - ) - - # MYPY: fix complains about overriding __ne__ - def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore - """Create an inequality expression. - - :param other: - :return: - """ - if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, self.size())) - - union = self.annotations + other.annotations - - if not self.symbolic and not other.symbolic: - return Bool(z3.BoolVal(self.value != other.value), annotations=union) - - if ( - not isinstance(other, BitVecFunc) - or not self.func_name - or not self.input - or not self.func_name == other.func_name - ): - return Bool(z3.BoolVal(True), annotations=union) - - # MYPY: fix complaints due to z3 overriding __eq__ - return Or( - Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union), - self.input != other.input, - ) - - def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec: """Create an if-then-else expression. @@ -801,3 +479,7 @@ def BVSubNoUnderflow( b = BitVec(z3.BitVecVal(b, 256)) return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed)) + + +# TODO: Fix circular import issues +from mythril.laser.smt.bitvecfunc import BitVecFunc diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py new file mode 100644 index 00000000..dabf9d56 --- /dev/null +++ b/mythril/laser/smt/bitvecfunc.py @@ -0,0 +1,207 @@ +from typing import Optional, Union, cast + +import z3 + +from mythril.laser.smt.bitvec import BitVec, Bool, And, Annotations +from mythril.laser.smt.bool import Or + +import operator + + +def _arithmetic_helper( + a: "BitVecFunc", b: Union[BitVec, int], operation: callable +) -> "BitVecFunc": + """ + Helper function for arithmetic operations on BitVecFuncs. + + :param a: The BitVecFunc to perform the operation on. + :param b: A BitVec or int to perform the operation on. + :param operation: The arithmetic operation to perform. + :return: The resulting BitVecFunc + """ + if isinstance(b, int): + b = BitVec(z3.BitVecVal(b, a.size())) + + raw = operation(a.raw, b.raw) + union = a.annotations + b.annotations + + if isinstance(b, BitVecFunc): + # TODO: Find better value to set input and name to in this case? + return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) + + return BitVecFunc(raw=raw, func_name=a.func_name, input=a.input, annotations=union) + + +def _comparison_helper( + a: "BitVecFunc", + b: Union[BitVec, int], + operation: callable, + default_value: bool, + inputs_equal: bool, +) -> Bool: + """ + Helper function for comparison operations with BitVecFuncs. + + :param a: The BitVecFunc to compare. + :param b: A BitVec or int to compare to. + :param operation: The comparison operation to perform. + :return: The resulting Bool + """ + # Is there some hack for gt/lt comparisons? + if isinstance(b, int): + b = BitVec(z3.BitVecVal(b, a.size())) + + union = a.annotations + b.annotations + + if not a.symbolic and not b.symbolic: + return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union) + + if ( + not isinstance(b, BitVecFunc) + or not a.func_name + or not a.input + or not a.func_name == b.func_name + ): + return Bool(z3.BoolVal(default_value), annotations=union) + + return And( + Bool(cast(z3.BoolRef, operation(a.raw, b.raw)), annotations=union), + a.input == b.input if inputs_equal else a.input != b.input, + ) + + +class BitVecFunc(BitVec): + """A bit vector function symbol. Used in place of functions like sha3.""" + + def __init__( + self, + raw: z3.BitVecRef, + func_name: Optional[str], + input: Union[int, "BitVec"] = None, + annotations: Optional[Annotations] = None, + ): + """ + + :param raw: The raw bit vector symbol + :param func_name: The function name. e.g. sha3 + :param input: The input to the functions + :param annotations: The annotations the BitVecFunc should start with + """ + + self.func_name = func_name + self.input = input + super().__init__(raw, annotations) + + def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": + """Create an addition expression. + + :param other: The int or BitVec to add to this BitVecFunc + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.add) + + def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": + """Create a subtraction expression. + + :param other: The int or BitVec to subtract from this BitVecFunc + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.sub) + + def __mul__(self, other: "BitVec") -> "BitVecFunc": + """Create a multiplication expression. + + :param other: The int or BitVec to multiply to this BitVecFunc + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.mul) + + def __truediv__(self, other: "BitVec") -> "BitVecFunc": + """Create a signed division expression. + + :param other: The int or BitVec to divide this BitVecFunc by + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.truediv) + + def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": + """Create an and expression. + + :param other: The int or BitVec to and with this BitVecFunc + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.and_) + + def __or__(self, other: "BitVec") -> "BitVecFunc": + """Create an or expression. + + :param other: The int or BitVec to or with this BitVecFunc + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.or_) + + def __xor__(self, other: "BitVec") -> "BitVecFunc": + """Create a xor expression. + + :param other: The int or BitVec to xor with this BitVecFunc + :return: The resulting BitVecFunc + """ + return _arithmetic_helper(self, other, operator.xor) + + def __lt__(self, other: "BitVec") -> Bool: + """Create a signed less than expression. + + :param other: The int or BitVec to compare to this BitVecFunc + :return: The resulting Bool + """ + return _comparison_helper( + self, other, operator.lt, default_value=False, inputs_equal=False + ) + + def __gt__(self, other: "BitVec") -> Bool: + """Create a signed greater than expression. + + :param other: The int or BitVec to compare to this BitVecFunc + :return: The resulting Bool + """ + return _comparison_helper( + self, other, operator.gt, default_value=False, inputs_equal=False + ) + + def __le__(self, other: "BitVec") -> Bool: + """Create a signed less than or equal to expression. + + :param other: The int or BitVec to compare to this BitVecFunc + :return: The resulting Bool + """ + return Or(self < other, self == other) + + def __ge__(self, other: "BitVec") -> Bool: + """Create a signed greater than or equal to expression. + + :param other: The int or BitVec to compare to this BitVecFunc + :return: The resulting Bool + """ + return Or(self > other, self == other) + + # MYPY: fix complains about overriding __eq__ + def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore + """Create an equality expression. + + :param other: The int or BitVec to compare to this BitVecFunc + :return: The resulting Bool + """ + return _comparison_helper( + self, other, operator.eq, default_value=False, inputs_equal=True + ) + + # MYPY: fix complains about overriding __ne__ + def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore + """Create an inequality expression. + + :param other: The int or BitVec to compare to this BitVecFunc + :return: The resulting Bool + """ + return _comparison_helper( + self, other, operator.eq, default_value=True, inputs_equal=False + ) From 5bfd7fe39215a1175b4b18516e9d57dfb2b85764 Mon Sep 17 00:00:00 2001 From: Nathan Date: Tue, 5 Feb 2019 13:43:16 -0500 Subject: [PATCH 07/10] Rename BitVecFunc.input to BitVecFunc.input_ --- mythril/laser/ethereum/instructions.py | 4 +-- mythril/laser/smt/__init__.py | 35 ++++++++++++++---------- mythril/laser/smt/bitvec.py | 38 +++++++++++++------------- mythril/laser/smt/bitvecfunc.py | 14 ++++++---- 4 files changed, 50 insertions(+), 41 deletions(-) diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 03f0f9ee..cc718e4d 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -906,7 +906,7 @@ class Instruction: if data.symbolic: argument_str = str(state.memory[index]).replace(" ", "_") result = symbol_factory.BitVecFuncSym( - "KECCAC[{}]".format(argument_str), "keccak256", 256, input=data + "KECCAC[{}]".format(argument_str), "keccak256", 256, input_=data ) log.debug("Created BitVecFunc hash.") @@ -914,7 +914,7 @@ class Instruction: else: keccak = utils.sha3(data.value.to_bytes(length, byteorder="big")) result = symbol_factory.BitVecFuncVal( - "keccak256", util.concrete_int_from_bytes(keccak, 0), 256, input=data + util.concrete_int_from_bytes(keccak, 0), "keccak256", 256, input_=data ) log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak))) diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index 0f7ee8e2..e2f055e2 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -66,32 +66,39 @@ class SymbolFactory(Generic[T, U]): @staticmethod def BitVecFuncVal( - func_name: str, value: int, + func_name: str, size: int, annotations: Annotations = None, - input: Union[int, "BitVec"] = None, + input_: Union[int, "BitVec"] = None, ) -> BitVecFunc: - """Creates a new bit vector function with a concrete value. + """Creates a new bit vector function with a symbolic value. - :param func_name: The name of the function :param value: The concrete value to set the bit vector to + :param func_name: The name of the bit vector function :param size: The size of the bit vector :param annotations: The annotations to initialize the bit vector with - :return: The freshly created bit vector + :param input_: The input to the bit vector function + :return: The freshly created bit vector function """ raise NotImplementedError() @staticmethod def BitVecFuncSym( - name: str, func_name: str, size: int, annotations: Annotations = None - ) -> U: - """Creates a new bit vector with a symbolic value. + name: str, + func_name: str, + size: int, + annotations: Annotations = None, + input_: Union[int, "BitVec"] = None, + ) -> BitVecFunc: + """Creates a new bit vector function with a symbolic value. :param name: The name of the symbolic bit vector + :param func_name: The name of the bit vector function :param size: The size of the bit vector :param annotations: The annotations to initialize the bit vector with - :return: The freshly created bit vector + :param input_: The input to the bit vector function + :return: The freshly created bit vector function """ raise NotImplementedError() @@ -127,15 +134,15 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): @staticmethod def BitVecFuncVal( - func_name: str, value: int, + func_name: str, size: int, annotations: Annotations = None, - input: Union[int, "BitVec"] = None, + input_: Union[int, "BitVec"] = None, ) -> BitVecFunc: """Creates a new bit vector function with a concrete value.""" raw = z3.BitVecVal(value, size) - return BitVecFunc(raw, func_name, input, annotations) + return BitVecFunc(raw, func_name, input_, annotations) @staticmethod def BitVecFuncSym( @@ -143,11 +150,11 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): func_name: str, size: int, annotations: Annotations = None, - input: Union[int, "BitVec"] = None, + input_: Union[int, "BitVec"] = None, ) -> BitVecFunc: """Creates a new bit vector function with a symbolic value.""" raw = z3.BitVec(name, size) - return BitVecFunc(raw, func_name, input, annotations) + return BitVecFunc(raw, func_name, input_, annotations) class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]): diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index c111fae8..c936bad4 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -228,13 +228,13 @@ def UGT(a: BitVec, b: BitVec) -> Bool: if ( not isinstance(b, BitVecFunc) or not a.func_name - or not a.input + or not a.input_ or not a.func_name == b.func_name ): return Bool(z3.BoolVal(False), annotations=annotations) return And( - Bool(z3.UGT(a.raw, b.raw), annotations=annotations), a.input != b.input + Bool(z3.UGT(a.raw, b.raw), annotations=annotations), a.input_ != b.input_ ) return Bool(z3.UGT(a.raw, b.raw), annotations) @@ -255,13 +255,13 @@ def UGE(a: BitVec, b: BitVec) -> Bool: if ( not isinstance(b, BitVecFunc) or not a.func_name - or not a.input + or not a.input_ or not a.func_name == b.func_name ): return Bool(z3.BoolVal(False), annotations=annotations) return And( - Bool(z3.UGE(a.raw, b.raw), annotations=annotations), a.input != b.input + Bool(z3.UGE(a.raw, b.raw), annotations=annotations), a.input_ != b.input_ ) return Bool(z3.UGE(a.raw, b.raw), annotations) @@ -282,13 +282,13 @@ def ULT(a: BitVec, b: BitVec) -> Bool: if ( not isinstance(b, BitVecFunc) or not a.func_name - or not a.input + or not a.input_ or not a.func_name == b.func_name ): return Bool(z3.BoolVal(False), annotations=annotations) return And( - Bool(z3.ULT(a.raw, b.raw), annotations=annotations), a.input != b.input + Bool(z3.ULT(a.raw, b.raw), annotations=annotations), a.input_ != b.input_ ) return Bool(z3.ULT(a.raw, b.raw), annotations) @@ -324,7 +324,7 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: if bitvecfunc: # Is there a better value to set func_name and input to in this case? - return BitVecFunc(raw=nraw, func_name=None, input=None, annotations=annotations) + return BitVecFunc(raw=nraw, func_name=None, input_=None, annotations=annotations) return BitVec(nraw, annotations) @@ -340,7 +340,7 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: raw = z3.Extract(high, low, bv.raw) if isinstance(bv, BitVecFunc): # Is there a better value to set func_name and input to in this case? - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=bv.annotations) + return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=bv.annotations) return BitVec(raw, annotations=bv.annotations) @@ -356,11 +356,11 @@ def URem(a: BitVec, b: BitVec) -> BitVec: union = a.annotations + b.annotations if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) + return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) elif isinstance(a, BitVecFunc): - return BitVecFunc(raw=raw, func_name=a.func_name, input=a.input, annotations=union) + return BitVecFunc(raw=raw, func_name=a.func_name, input_=a.input_, annotations=union) elif isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=b.func_name, input=b.input, annotations=union) + return BitVecFunc(raw=raw, func_name=b.func_name, input_=b.input_, annotations=union) return BitVec(raw, annotations=union) @@ -376,11 +376,11 @@ def SRem(a: BitVec, b: BitVec) -> BitVec: union = a.annotations + b.annotations if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) + return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) elif isinstance(a, BitVecFunc): - return BitVecFunc(raw=raw, func_name=a.func_name, input=a.input, annotations=union) + return BitVecFunc(raw=raw, func_name=a.func_name, input_=a.input_, annotations=union) elif isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=b.func_name, input=b.input, annotations=union) + return BitVecFunc(raw=raw, func_name=b.func_name, input_=b.input_, annotations=union) return BitVec(raw, annotations=union) @@ -396,11 +396,11 @@ def UDiv(a: BitVec, b: BitVec) -> BitVec: union = a.annotations + b.annotations if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) + return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) elif isinstance(a, BitVecFunc): - return BitVecFunc(raw=raw, func_name=a.func_name, input=a.input, annotations=union) + return BitVecFunc(raw=raw, func_name=a.func_name, input_=a.input_, annotations=union) elif isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=b.func_name, input=b.input, annotations=union) + return BitVecFunc(raw=raw, func_name=b.func_name, input_=b.input_, annotations=union) return BitVec(raw, annotations=union) @@ -420,12 +420,12 @@ def Sum(*args: BitVec) -> BitVec: bitvecfuncs.append(bv) if len(bitvecfuncs) >= 2: - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=annotations) + return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=annotations) elif len(bitvecfuncs) == 1: return BitVecFunc( raw=raw, func_name=bitvecfuncs[0].func_name, - input=bitvecfuncs[0].input, + input_=bitvecfuncs[0].input_, annotations=annotations, ) diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py index dabf9d56..e1ff3525 100644 --- a/mythril/laser/smt/bitvecfunc.py +++ b/mythril/laser/smt/bitvecfunc.py @@ -27,9 +27,11 @@ def _arithmetic_helper( if isinstance(b, BitVecFunc): # TODO: Find better value to set input and name to in this case? - return BitVecFunc(raw=raw, func_name=None, input=None, annotations=union) + return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) - return BitVecFunc(raw=raw, func_name=a.func_name, input=a.input, annotations=union) + return BitVecFunc( + raw=raw, func_name=a.func_name, input_=a.input_, annotations=union + ) def _comparison_helper( @@ -59,14 +61,14 @@ def _comparison_helper( if ( not isinstance(b, BitVecFunc) or not a.func_name - or not a.input + or not a.input_ or not a.func_name == b.func_name ): return Bool(z3.BoolVal(default_value), annotations=union) return And( Bool(cast(z3.BoolRef, operation(a.raw, b.raw)), annotations=union), - a.input == b.input if inputs_equal else a.input != b.input, + a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, ) @@ -77,7 +79,7 @@ class BitVecFunc(BitVec): self, raw: z3.BitVecRef, func_name: Optional[str], - input: Union[int, "BitVec"] = None, + input_: Union[int, "BitVec"] = None, annotations: Optional[Annotations] = None, ): """ @@ -89,7 +91,7 @@ class BitVecFunc(BitVec): """ self.func_name = func_name - self.input = input + self.input_ = input_ super().__init__(raw, annotations) def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": From bfed409e4534114613873ca3ed91474e4d72af38 Mon Sep 17 00:00:00 2001 From: Nathan Date: Tue, 5 Feb 2019 13:43:20 -0500 Subject: [PATCH 08/10] Add bitvecfunc tests --- tests/laser/smt/bitvecfunc_test.py | 78 ++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/laser/smt/bitvecfunc_test.py diff --git a/tests/laser/smt/bitvecfunc_test.py b/tests/laser/smt/bitvecfunc_test.py new file mode 100644 index 00000000..0b3f5367 --- /dev/null +++ b/tests/laser/smt/bitvecfunc_test.py @@ -0,0 +1,78 @@ +from mythril.laser.smt import Solver, symbol_factory +import z3 +import pytest + +import operator + + +@pytest.mark.parametrize( + "operation,expected", + [ + (operator.add, z3.unsat), + (operator.sub, z3.unsat), + (operator.and_, z3.sat), + (operator.or_, z3.sat), + (operator.xor, z3.unsat), + ], +) +def test_bitvecfunc_arithmetic(operation, expected): + # Arrange + s = Solver() + + input_ = symbol_factory.BitVecVal(1, 8) + bvf = symbol_factory.BitVecFuncSym("bvf", "sha3", 256, input_=input_) + + x = symbol_factory.BitVecSym("x", 256) + y = symbol_factory.BitVecSym("y", 256) + + # Act + s.add(x != y) + s.add(operation(bvf, x) == operation(y, bvf)) + + # Assert + assert s.check() == expected + + +@pytest.mark.parametrize( + "operation,expected", + [ + (operator.eq, z3.sat), + (operator.ne, z3.unsat), + (operator.lt, z3.unsat), + (operator.le, z3.sat), + (operator.gt, z3.unsat), + (operator.ge, z3.sat), + ], +) +def test_bitvecfunc_bitvecfunc_comparison(operation, expected): + # Arrange + s = Solver() + + input1 = symbol_factory.BitVecSym("input1", 256) + input2 = symbol_factory.BitVecSym("input2", 256) + bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) + bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=input2) + + # Act + s.add(operation(bvf1, bvf2)) + s.add(input1 == input2) + + # Assert + assert s.check() == expected + + +def test_bitvecfunc_bitvecfuncval_comparison(): + # Arrange + s = Solver() + + input1 = symbol_factory.BitVecSym("input1", 256) + input2 = symbol_factory.BitVecVal(1337, 256) + bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) + bvf2 = symbol_factory.BitVecFuncVal(12345678910, "sha3", 256, input_=input2) + + # Act + s.add(bvf1 == bvf2) + + # Assert + assert s.check() == z3.sat + assert s.model().eval(input2.raw) == 1337 From b9214bcb4184ee2d63f0aa20515bde275274139a Mon Sep 17 00:00:00 2001 From: Nathan Date: Thu, 7 Feb 2019 14:17:09 -0500 Subject: [PATCH 09/10] Fix bitvecfunc typing --- mythril/laser/ethereum/instructions.py | 10 +++++----- mythril/laser/smt/bitvecfunc.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index a3a07d2b..72d17e5d 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -892,14 +892,14 @@ class Instruction: StateTransition.check_gas_usage_limit(global_state) state.mem_extend(index, length) - data = [ + data_list = [ b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8) for b in state.memory[index : index + length] ] - if len(data) > 1: - data = simplify(Concat(data)) - elif len(data) == 1: - data = data[0] + if len(data_list) > 1: + data = simplify(Concat(data_list)) + elif len(data_list) == 1: + data = data_list[0] else: # length is 0; this only matters for input of the BitVecFuncVal data = symbol_factory.BitVecVal(0, 1) diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py index e1ff3525..d3e77601 100644 --- a/mythril/laser/smt/bitvecfunc.py +++ b/mythril/laser/smt/bitvecfunc.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, cast +from typing import Optional, Union, cast, Callable import z3 @@ -9,7 +9,7 @@ import operator def _arithmetic_helper( - a: "BitVecFunc", b: Union[BitVec, int], operation: callable + a: "BitVecFunc", b: Union[BitVec, int], operation: Callable ) -> "BitVecFunc": """ Helper function for arithmetic operations on BitVecFuncs. @@ -37,7 +37,7 @@ def _arithmetic_helper( def _comparison_helper( a: "BitVecFunc", b: Union[BitVec, int], - operation: callable, + operation: Callable, default_value: bool, inputs_equal: bool, ) -> Bool: From 0a992bb6d939727958228936ffe3b057ee7e1e06 Mon Sep 17 00:00:00 2001 From: Nathan Date: Mon, 11 Feb 2019 08:27:56 -0500 Subject: [PATCH 10/10] Create helper functions for operations in bitvec.py --- mythril/laser/smt/bitvec.py | 157 ++++++++++++----------------- tests/laser/smt/bitvecfunc_test.py | 6 +- 2 files changed, 71 insertions(+), 92 deletions(-) diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index 8873d18a..00b519a7 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -1,10 +1,10 @@ """This module provides classes for an SMT abstraction of bit vectors.""" -from typing import Union, overload, List, cast, Any, Optional +from typing import Union, overload, List, cast, Any, Optional, Callable import z3 -from mythril.laser.smt.bool import Bool, And +from mythril.laser.smt.bool import Bool, And, Or from mythril.laser.smt.expression import Expression Annotations = List[Any] @@ -212,6 +212,48 @@ class BitVec(Expression[z3.BitVecRef]): return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union) +def _comparison_helper( + a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool +) -> Bool: + annotations = a.annotations + b.annotations + if isinstance(a, BitVecFunc): + if not a.symbolic and not b.symbolic: + return Bool(operation(a.raw, b.raw), annotations=annotations) + + if ( + not isinstance(b, BitVecFunc) + or not a.func_name + or not a.input_ + or not a.func_name == b.func_name + ): + return Bool(z3.BoolVal(default_value), annotations=annotations) + + return And( + Bool(operation(a.raw, b.raw), annotations=annotations), + a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, + ) + + return Bool(operation(a.raw, b.raw), annotations) + + +def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec: + raw = operation(a.raw, b.raw) + union = a.annotations + b.annotations + + if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): + return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) + elif isinstance(a, BitVecFunc): + return BitVecFunc( + raw=raw, func_name=a.func_name, input_=a.input_, annotations=union + ) + elif isinstance(b, BitVecFunc): + return BitVecFunc( + raw=raw, func_name=b.func_name, input_=b.input_, annotations=union + ) + + return BitVec(raw, annotations=union) + + def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec: """Create an if-then-else expression. @@ -239,24 +281,7 @@ def UGT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - annotations = a.annotations + b.annotations - if isinstance(a, BitVecFunc): - if not a.symbolic and not b.symbolic: - return Bool(z3.UGT(a.raw, b.raw), annotations=annotations) - - if ( - not isinstance(b, BitVecFunc) - or not a.func_name - or not a.input_ - or not a.func_name == b.func_name - ): - return Bool(z3.BoolVal(False), annotations=annotations) - - return And( - Bool(z3.UGT(a.raw, b.raw), annotations=annotations), a.input_ != b.input_ - ) - - return Bool(z3.UGT(a.raw, b.raw), annotations) + return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False) def UGE(a: BitVec, b: BitVec) -> Bool: @@ -266,24 +291,7 @@ def UGE(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - annotations = a.annotations + b.annotations - if isinstance(a, BitVecFunc): - if not a.symbolic and not b.symbolic: - return Bool(z3.UGE(a.raw, b.raw), annotations=annotations) - - if ( - not isinstance(b, BitVecFunc) - or not a.func_name - or not a.input_ - or not a.func_name == b.func_name - ): - return Bool(z3.BoolVal(False), annotations=annotations) - - return And( - Bool(z3.UGE(a.raw, b.raw), annotations=annotations), a.input_ != b.input_ - ) - - return Bool(z3.UGE(a.raw, b.raw), annotations) + return Or(UGT(a, b), a == b) def ULT(a: BitVec, b: BitVec) -> Bool: @@ -293,24 +301,17 @@ def ULT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ - annotations = a.annotations + b.annotations - if isinstance(a, BitVecFunc): - if not a.symbolic and not b.symbolic: - return Bool(z3.ULT(a.raw, b.raw), annotations=annotations) + return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False) - if ( - not isinstance(b, BitVecFunc) - or not a.func_name - or not a.input_ - or not a.func_name == b.func_name - ): - return Bool(z3.BoolVal(False), annotations=annotations) - return And( - Bool(z3.ULT(a.raw, b.raw), annotations=annotations), a.input_ != b.input_ - ) +def ULE(a: BitVec, b: BitVec) -> Bool: + """Create an unsigned less than expression. - return Bool(z3.ULT(a.raw, b.raw), annotations) + :param a: + :param b: + :return: + """ + return Or(ULT(a, b), a == b) @overload @@ -329,12 +330,12 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: """ # The following statement is used if a list is provided as an argument to concat if len(args) == 1 and isinstance(args[0], list): - bvs = args[0] # type: List[BitVec] + bvs = args[0] # type: List[BitVec] else: bvs = cast(List[BitVec], args) nraw = z3.Concat([a.raw for a in bvs]) - annotations = [] # type: Annotations + annotations = [] # type: Annotations bitvecfunc = False for bv in bvs: annotations += bv.annotations @@ -343,7 +344,9 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: if bitvecfunc: # Is there a better value to set func_name and input to in this case? - return BitVecFunc(raw=nraw, func_name=None, input_=None, annotations=annotations) + return BitVecFunc( + raw=nraw, func_name=None, input_=None, annotations=annotations + ) return BitVec(nraw, annotations) @@ -359,7 +362,9 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: raw = z3.Extract(high, low, bv.raw) if isinstance(bv, BitVecFunc): # Is there a better value to set func_name and input to in this case? - return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=bv.annotations) + return BitVecFunc( + raw=raw, func_name=None, input_=None, annotations=bv.annotations + ) return BitVec(raw, annotations=bv.annotations) @@ -371,17 +376,7 @@ def URem(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ - raw = z3.URem(a.raw, b.raw) - union = a.annotations + b.annotations - - if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) - elif isinstance(a, BitVecFunc): - return BitVecFunc(raw=raw, func_name=a.func_name, input_=a.input_, annotations=union) - elif isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=b.func_name, input_=b.input_, annotations=union) - - return BitVec(raw, annotations=union) + return _arithmetic_helper(a, b, z3.URem) def SRem(a: BitVec, b: BitVec) -> BitVec: @@ -391,17 +386,7 @@ def SRem(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ - raw = z3.SRem(a.raw, b.raw) - union = a.annotations + b.annotations - - if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) - elif isinstance(a, BitVecFunc): - return BitVecFunc(raw=raw, func_name=a.func_name, input_=a.input_, annotations=union) - elif isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=b.func_name, input_=b.input_, annotations=union) - - return BitVec(raw, annotations=union) + return _arithmetic_helper(a, b, z3.SRem) def UDiv(a: BitVec, b: BitVec) -> BitVec: @@ -411,17 +396,7 @@ def UDiv(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ - raw = z3.UDiv(a.raw, b.raw) - union = a.annotations + b.annotations - - if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) - elif isinstance(a, BitVecFunc): - return BitVecFunc(raw=raw, func_name=a.func_name, input_=a.input_, annotations=union) - elif isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=b.func_name, input_=b.input_, annotations=union) - - return BitVec(raw, annotations=union) + return _arithmetic_helper(a, b, z3.UDiv) def Sum(*args: BitVec) -> BitVec: @@ -430,7 +405,7 @@ def Sum(*args: BitVec) -> BitVec: :return: """ raw = z3.Sum([a.raw for a in args]) - annotations = [] # type: Annotations + annotations = [] # type: Annotations bitvecfuncs = [] for bv in args: diff --git a/tests/laser/smt/bitvecfunc_test.py b/tests/laser/smt/bitvecfunc_test.py index 0b3f5367..ea19dad1 100644 --- a/tests/laser/smt/bitvecfunc_test.py +++ b/tests/laser/smt/bitvecfunc_test.py @@ -1,4 +1,4 @@ -from mythril.laser.smt import Solver, symbol_factory +from mythril.laser.smt import Solver, symbol_factory, bitvec import z3 import pytest @@ -42,6 +42,10 @@ def test_bitvecfunc_arithmetic(operation, expected): (operator.le, z3.sat), (operator.gt, z3.unsat), (operator.ge, z3.sat), + (bitvec.UGT, z3.unsat), + (bitvec.UGE, z3.sat), + (bitvec.ULT, z3.unsat), + (bitvec.ULE, z3.sat), ], ) def test_bitvecfunc_bitvecfunc_comparison(operation, expected):