diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 677ec626..03f0f9ee 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -890,29 +890,35 @@ class Instruction: state.max_gas_used += max_gas StateTransition.check_gas_usage_limit(global_state) - try: - state.mem_extend(index, length) - data = b"".join( - [ - util.get_concrete_int(i).to_bytes(1, byteorder="big") - for i in state.memory[index : index + length] - ] - ) + state.mem_extend(index, length) + data = [ + b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8) + for b in state.memory[index : index + length] + ] + if len(data) > 1: + data = simplify(Concat(data)) + elif len(data) == 1: + data = data[0] + else: + # length is 0; this only matters for input of the BitVecFuncVal + data = symbol_factory.BitVecVal(0, 1) - except TypeError: - argument = str(state.memory[index]).replace(" ", "_") + if data.symbolic: + argument_str = str(state.memory[index]).replace(" ", "_") + result = symbol_factory.BitVecFuncSym( + "KECCAC[{}]".format(argument_str), "keccak256", 256, input=data + ) + log.debug("Created BitVecFunc hash.") - result = symbol_factory.BitVecSym("KECCAC[{}]".format(argument), 256) keccak_function_manager.add_keccak(result, state.memory[index]) - state.stack.append(result) - return [global_state] - - keccak = utils.sha3(utils.bytearray_to_bytestr(data)) - log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak))) + else: + keccak = utils.sha3(data.value.to_bytes(length, byteorder="big")) + result = symbol_factory.BitVecFuncVal( + "keccak256", util.concrete_int_from_bytes(keccak, 0), 256, input=data + ) + log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak))) - state.stack.append( - symbol_factory.BitVecVal(util.concrete_int_from_bytes(keccak, 0), 256) - ) + state.stack.append(result) return [global_state] @StateTransition() diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index 9cb81244..a16399c2 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -1,5 +1,6 @@ from mythril.laser.smt.bitvec import ( BitVec, + BitVecFunc, If, UGT, ULT, @@ -63,6 +64,37 @@ class SymbolFactory(Generic[T, U]): """ raise NotImplementedError() + @staticmethod + def BitVecFuncVal( + func_name: str, + value: int, + size: int, + annotations: Annotations = None, + input: Union[int, "BitVec"] = None, + ) -> BitVecFunc: + """Creates a new bit vector function with a concrete value. + + :param func_name: The name of the function + :param value: The concrete value to set the bit vector to + :param size: The size of the bit vector + :param annotations: The annotations to initialize the bit vector with + :return: The freshly created bit vector + """ + raise NotImplementedError() + + @staticmethod + def BitVecFuncSym( + name: str, func_name: str, size: int, annotations: Annotations = None + ) -> U: + """Creates a new bit vector with a symbolic value. + + :param name: The name of the symbolic bit vector + :param size: The size of the bit vector + :param annotations: The annotations to initialize the bit vector with + :return: The freshly created bit vector + """ + raise NotImplementedError() + class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): """ @@ -93,6 +125,30 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): raw = z3.BitVec(name, size) return BitVec(raw, annotations) + @staticmethod + def BitVecFuncVal( + func_name: str, + value: int, + size: int, + annotations: Annotations = None, + input: Union[int, "BitVec"] = None, + ) -> BitVecFunc: + """Creates a new bit vector function with a concrete value.""" + raw = z3.BitVecVal(value, size) + return BitVecFunc(raw, func_name, input, annotations) + + @staticmethod + def BitVecFuncSym( + name: str, + func_name: str, + size: int, + annotations: Annotations = None, + input: Union[int, "BitVec"] = None, + ) -> BitVecFunc: + """Creates a new bit vector function with a symbolic value.""" + raw = z3.BitVec(name, size) + return BitVecFunc(raw, func_name, input, annotations) + class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]): """ diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index feeac40b..ad13ed4c 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -4,7 +4,7 @@ from typing import Union, overload, List, cast, Any, Optional import z3 -from mythril.laser.smt.bool import Bool +from mythril.laser.smt.bool import Bool, And, Or from mythril.laser.smt.expression import Expression Annotations = List[Any] @@ -14,7 +14,7 @@ Annotations = List[Any] class BitVec(Expression[z3.BitVecRef]): """A bit vector symbol.""" - def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations]=None): + def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None): """ :param raw: @@ -55,6 +55,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other + self if isinstance(other, int): return BitVec(self.raw + other, annotations=self.annotations) @@ -67,7 +69,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - + if isinstance(other, BitVecFunc): + return other - self if isinstance(other, int): return BitVec(self.raw - other, annotations=self.annotations) @@ -80,6 +83,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other * self union = self.annotations + other.annotations return BitVec(self.raw * other.raw, annotations=union) @@ -89,6 +94,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other / self union = self.annotations + other.annotations return BitVec(self.raw / other.raw, annotations=union) @@ -98,8 +105,10 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other & self if not isinstance(other, BitVec): - other = BitVec(z3.BitVecVal(other, 256)) + other = BitVec(z3.BitVecVal(other, self.size())) union = self.annotations + other.annotations return BitVec(self.raw & other.raw, annotations=union) @@ -109,6 +118,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other | self union = self.annotations + other.annotations return BitVec(self.raw | other.raw, annotations=union) @@ -118,6 +129,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other ^ self union = self.annotations + other.annotations return BitVec(self.raw ^ other.raw, annotations=union) @@ -127,6 +140,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other > self union = self.annotations + other.annotations return Bool(self.raw < other.raw, annotations=union) @@ -136,6 +151,8 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other < self union = self.annotations + other.annotations return Bool(self.raw > other.raw, annotations=union) @@ -146,8 +163,12 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other == self if not isinstance(other, BitVec): - return Bool(cast(z3.BoolRef, self.raw == other), annotations=self.annotations) + return Bool( + cast(z3.BoolRef, self.raw == other), annotations=self.annotations + ) union = self.annotations + other.annotations # MYPY: fix complaints due to z3 overriding __eq__ @@ -160,14 +181,321 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ + if isinstance(other, BitVecFunc): + return other != self if not isinstance(other, BitVec): - return Bool(cast(z3.BoolRef, self.raw != other), annotations=self.annotations) + return Bool( + cast(z3.BoolRef, self.raw != other), annotations=self.annotations + ) union = self.annotations + other.annotations # MYPY: fix complaints due to z3 overriding __eq__ return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union) +class BitVecFunc(BitVec): + """A bit vector symbol.""" + + def __init__( + self, + raw: z3.BitVecRef, + name: str, + input: Union[int, "BitVec"] = None, + annotations: Optional[Annotations] = None, + ): + """ + + :param raw: + :param annotations: + :param input: + """ + + from mythril.laser.smt import symbol_factory + + self.symbol_factory = symbol_factory + + self.name = name + self.input = input + super().__init__(raw, annotations) + + def __add__(self, other: Union[int, "BitVec"]) -> "BitVec": + """Create an addition expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw + other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": + """Create a subtraction expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw - other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __mul__(self, other: "BitVec") -> "BitVecFunc": + """Create a multiplication expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw * other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __truediv__(self, other: "BitVec") -> "BitVecFunc": + """Create a signed division expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw / other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": + """Create an and expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw & other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __or__(self, other: "BitVec") -> "BitVecFunc": + """Create an or expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw | other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __xor__(self, other: "BitVec") -> "BitVec": + """Create a xor expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + raw = (self.raw ^ other.raw,) + union = self.annotations + other.annotations + + if isinstance(other, BitVecFunc): + # TODO: Find better value to set input and name to in this case + return BitVecFunc( + raw=raw, + name=self.name if self.name and self.name == other.name else None, + input=None, + annotations=union, + ) + + return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union) + + def __lt__(self, other: "BitVec") -> Bool: + """Create a signed less than expression. + + :param other: + :return: + """ + # Is there some hack for these comparisons? + + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + union = self.annotations + other.annotations + + if not self.symbolic and not other.symbolic: + return Bool(cast(z3.BoolRef, self.value < other.value), annotations=union) + + if ( + not isinstance(other, BitVecFunc) + or not self.name + or not self.input + or not self.name == other.name + ): + return Bool(False, annotations=union) + + return And( + Bool(cast(z3.BoolRef, self.raw < other.raw), annotations=union), + self.input != other.input, + ) + + def __gt__(self, other: "BitVec") -> Bool: + """Create a signed greater than expression. + + :param other: + :return: + """ + # Is there some hack for these comparisons? + + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + union = self.annotations + other.annotations + + if not self.symbolic and not other.symbolic: + return Bool(cast(z3.BoolRef, self.value > other.value), annotations=union) + + if ( + not isinstance(other, BitVecFunc) + or not self.name + or not self.input + or not self.name == other.name + ): + return Bool(False, annotations=union) + + return And( + Bool(cast(z3.BoolRef, self.raw > other.raw), annotations=union), + self.input != other.input, + ) + + # MYPY: fix complains about overriding __eq__ + def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore + """Create an equality expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + union = self.annotations + other.annotations + + if not self.symbolic and not other.symbolic: + return Bool(cast(z3.BoolRef, self.value == other.value), annotations=union) + + if ( + not isinstance(other, BitVecFunc) + or not self.name + or not self.input + or not self.name == other.name + ): + return Bool(cast(z3.BoolRef, False), annotations=union) + + # MYPY: fix complaints due to z3 overriding __eq__ + return And( + Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union), + self.input == other.input, + ) + + # MYPY: fix complains about overriding __ne__ + def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore + """Create an inequality expression. + + :param other: + :return: + """ + if not isinstance(other, BitVec): + other = BitVec(z3.BitVecVal(other, self.size())) + + union = self.annotations + other.annotations + + if not self.symbolic and not other.symbolic: + return Bool(cast(z3.BoolRef, self.value != other.value), annotations=union) + + if ( + not isinstance(other, BitVecFunc) + or not self.name + or not self.input + or not self.name == other.name + ): + return Bool(cast(z3.BoolRef, True), annotations=union) + + # MYPY: fix complaints due to z3 overriding __eq__ + return Or( + Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union), + self.input != other.input, + ) + + def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec: """Create an if-then-else expression. @@ -176,6 +504,8 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi :param c: :return: """ + # TODO: Handle BitVecFunc + if not isinstance(a, Bool): a = Bool(z3.BoolVal(a)) if not isinstance(b, BitVec): @@ -193,17 +523,21 @@ def UGT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ + # TODO: Handle BitVecFunc + annotations = a.annotations + b.annotations return Bool(z3.UGT(a.raw, b.raw), annotations) -def UGE(a: BitVec, b:BitVec) -> Bool: +def UGE(a: BitVec, b: BitVec) -> Bool: """Create an unsigned greater or equals expression. :param a: :param b: :return: """ + # TODO: Handle BitVecFunc + annotations = a.annotations + b.annotations return Bool(z3.UGE(a.raw, b.raw), annotations) @@ -215,6 +549,8 @@ def ULT(a: BitVec, b: BitVec) -> Bool: :param b: :return: """ + # TODO: Handle BitVecFunc + annotations = a.annotations + b.annotations return Bool(z3.ULT(a.raw, b.raw), annotations) @@ -233,6 +569,7 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: :param args: :return: """ + # TODO: Handle BitVecFunc # The following statement is used if a list is provided as an argument to concat if len(args) == 1 and isinstance(args[0], list): @@ -255,6 +592,8 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec: :param bv: :return: """ + # TODO: Handle BitVecFunc + return BitVec(z3.Extract(high, low, bv.raw), annotations=bv.annotations) @@ -265,6 +604,8 @@ def URem(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ + # TODO: Handle BitVecFunc + union = a.annotations + b.annotations return BitVec(z3.URem(a.raw, b.raw), annotations=union) @@ -276,6 +617,8 @@ def SRem(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ + # TODO: Handle BitVecFunc + union = a.annotations + b.annotations return BitVec(z3.SRem(a.raw, b.raw), annotations=union) @@ -287,6 +630,8 @@ def UDiv(a: BitVec, b: BitVec) -> BitVec: :param b: :return: """ + # TODO: Handle BitVecFunc + union = a.annotations + b.annotations return BitVec(z3.UDiv(a.raw, b.raw), annotations=union) @@ -296,6 +641,8 @@ def Sum(*args: BitVec) -> BitVec: :return: """ + # TODO: Handle BitVecFunc + nraw = z3.Sum([a.raw for a in args]) annotations = [] # type: Annotations for bv in args: @@ -334,7 +681,9 @@ def BVMulNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed)) -def BVSubNoUnderflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool: +def BVSubNoUnderflow( + a: Union[BitVec, int], b: Union[BitVec, int], signed: bool +) -> Bool: """Creates predicate that verifies that the subtraction doesn't overflow. :param a: