Implement BitVecFunc for sha3

pull/901/head
Nathan 6 years ago
parent 98d0239360
commit 5384602d00
  1. 38
      mythril/laser/ethereum/instructions.py
  2. 56
      mythril/laser/smt/__init__.py
  3. 365
      mythril/laser/smt/bitvec.py

@ -890,29 +890,35 @@ class Instruction:
state.max_gas_used += max_gas
StateTransition.check_gas_usage_limit(global_state)
try:
state.mem_extend(index, length)
data = b"".join(
[
util.get_concrete_int(i).to_bytes(1, byteorder="big")
for i in state.memory[index : index + length]
data = [
b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8)
for b in state.memory[index : index + length]
]
)
if len(data) > 1:
data = simplify(Concat(data))
elif len(data) == 1:
data = data[0]
else:
# length is 0; this only matters for input of the BitVecFuncVal
data = symbol_factory.BitVecVal(0, 1)
except TypeError:
argument = str(state.memory[index]).replace(" ", "_")
if data.symbolic:
argument_str = str(state.memory[index]).replace(" ", "_")
result = symbol_factory.BitVecFuncSym(
"KECCAC[{}]".format(argument_str), "keccak256", 256, input=data
)
log.debug("Created BitVecFunc hash.")
result = symbol_factory.BitVecSym("KECCAC[{}]".format(argument), 256)
keccak_function_manager.add_keccak(result, state.memory[index])
state.stack.append(result)
return [global_state]
keccak = utils.sha3(utils.bytearray_to_bytestr(data))
else:
keccak = utils.sha3(data.value.to_bytes(length, byteorder="big"))
result = symbol_factory.BitVecFuncVal(
"keccak256", util.concrete_int_from_bytes(keccak, 0), 256, input=data
)
log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak)))
state.stack.append(
symbol_factory.BitVecVal(util.concrete_int_from_bytes(keccak, 0), 256)
)
state.stack.append(result)
return [global_state]
@StateTransition()

@ -1,5 +1,6 @@
from mythril.laser.smt.bitvec import (
BitVec,
BitVecFunc,
If,
UGT,
ULT,
@ -63,6 +64,37 @@ class SymbolFactory(Generic[T, U]):
"""
raise NotImplementedError()
@staticmethod
def BitVecFuncVal(
func_name: str,
value: int,
size: int,
annotations: Annotations = None,
input: Union[int, "BitVec"] = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a concrete value.
:param func_name: The name of the function
:param value: The concrete value to set the bit vector to
:param size: The size of the bit vector
:param annotations: The annotations to initialize the bit vector with
:return: The freshly created bit vector
"""
raise NotImplementedError()
@staticmethod
def BitVecFuncSym(
name: str, func_name: str, size: int, annotations: Annotations = None
) -> U:
"""Creates a new bit vector with a symbolic value.
:param name: The name of the symbolic bit vector
:param size: The size of the bit vector
:param annotations: The annotations to initialize the bit vector with
:return: The freshly created bit vector
"""
raise NotImplementedError()
class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]):
"""
@ -93,6 +125,30 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]):
raw = z3.BitVec(name, size)
return BitVec(raw, annotations)
@staticmethod
def BitVecFuncVal(
func_name: str,
value: int,
size: int,
annotations: Annotations = None,
input: Union[int, "BitVec"] = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a concrete value."""
raw = z3.BitVecVal(value, size)
return BitVecFunc(raw, func_name, input, annotations)
@staticmethod
def BitVecFuncSym(
name: str,
func_name: str,
size: int,
annotations: Annotations = None,
input: Union[int, "BitVec"] = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value."""
raw = z3.BitVec(name, size)
return BitVecFunc(raw, func_name, input, annotations)
class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]):
"""

@ -4,7 +4,7 @@ from typing import Union, overload, List, cast, Any, Optional
import z3
from mythril.laser.smt.bool import Bool
from mythril.laser.smt.bool import Bool, And, Or
from mythril.laser.smt.expression import Expression
Annotations = List[Any]
@ -14,7 +14,7 @@ Annotations = List[Any]
class BitVec(Expression[z3.BitVecRef]):
"""A bit vector symbol."""
def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations]=None):
def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations] = None):
"""
:param raw:
@ -55,6 +55,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
if isinstance(other, BitVecFunc):
return other + self
if isinstance(other, int):
return BitVec(self.raw + other, annotations=self.annotations)
@ -67,7 +69,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
if isinstance(other, BitVecFunc):
return other - self
if isinstance(other, int):
return BitVec(self.raw - other, annotations=self.annotations)
@ -80,6 +83,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
if isinstance(other, BitVecFunc):
return other * self
union = self.annotations + other.annotations
return BitVec(self.raw * other.raw, annotations=union)
@ -89,6 +94,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
if isinstance(other, BitVecFunc):
return other / self
union = self.annotations + other.annotations
return BitVec(self.raw / other.raw, annotations=union)
@ -98,8 +105,10 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
if isinstance(other, BitVecFunc):
return other & self
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, 256))
other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations + other.annotations
return BitVec(self.raw & other.raw, annotations=union)
@ -109,6 +118,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
if isinstance(other, BitVecFunc):
return other | self
union = self.annotations + other.annotations
return BitVec(self.raw | other.raw, annotations=union)
@ -118,6 +129,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
if isinstance(other, BitVecFunc):
return other ^ self
union = self.annotations + other.annotations
return BitVec(self.raw ^ other.raw, annotations=union)
@ -127,6 +140,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
if isinstance(other, BitVecFunc):
return other > self
union = self.annotations + other.annotations
return Bool(self.raw < other.raw, annotations=union)
@ -136,6 +151,8 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
if isinstance(other, BitVecFunc):
return other < self
union = self.annotations + other.annotations
return Bool(self.raw > other.raw, annotations=union)
@ -146,8 +163,12 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
if isinstance(other, BitVecFunc):
return other == self
if not isinstance(other, BitVec):
return Bool(cast(z3.BoolRef, self.raw == other), annotations=self.annotations)
return Bool(
cast(z3.BoolRef, self.raw == other), annotations=self.annotations
)
union = self.annotations + other.annotations
# MYPY: fix complaints due to z3 overriding __eq__
@ -160,14 +181,321 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
if isinstance(other, BitVecFunc):
return other != self
if not isinstance(other, BitVec):
return Bool(cast(z3.BoolRef, self.raw != other), annotations=self.annotations)
return Bool(
cast(z3.BoolRef, self.raw != other), annotations=self.annotations
)
union = self.annotations + other.annotations
# MYPY: fix complaints due to z3 overriding __eq__
return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union)
class BitVecFunc(BitVec):
"""A bit vector symbol."""
def __init__(
self,
raw: z3.BitVecRef,
name: str,
input: Union[int, "BitVec"] = None,
annotations: Optional[Annotations] = None,
):
"""
:param raw:
:param annotations:
:param input:
"""
from mythril.laser.smt import symbol_factory
self.symbol_factory = symbol_factory
self.name = name
self.input = input
super().__init__(raw, annotations)
def __add__(self, other: Union[int, "BitVec"]) -> "BitVec":
"""Create an addition expression.
:param other:
:return:
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
raw = (self.raw + other.raw,)
union = self.annotations + other.annotations
if isinstance(other, BitVecFunc):
# TODO: Find better value to set input and name to in this case
return BitVecFunc(
raw=raw,
name=self.name if self.name and self.name == other.name else None,
input=None,
annotations=union,
)
return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union)
def __sub__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create a subtraction expression.
:param other:
:return:
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
raw = (self.raw - other.raw,)
union = self.annotations + other.annotations
if isinstance(other, BitVecFunc):
# TODO: Find better value to set input and name to in this case
return BitVecFunc(
raw=raw,
name=self.name if self.name and self.name == other.name else None,
input=None,
annotations=union,
)
return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union)
def __mul__(self, other: "BitVec") -> "BitVecFunc":
"""Create a multiplication expression.
:param other:
:return:
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
raw = (self.raw * other.raw,)
union = self.annotations + other.annotations
if isinstance(other, BitVecFunc):
# TODO: Find better value to set input and name to in this case
return BitVecFunc(
raw=raw,
name=self.name if self.name and self.name == other.name else None,
input=None,
annotations=union,
)
return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union)
def __truediv__(self, other: "BitVec") -> "BitVecFunc":
"""Create a signed division expression.
:param other:
:return:
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
raw = (self.raw / other.raw,)
union = self.annotations + other.annotations
if isinstance(other, BitVecFunc):
# TODO: Find better value to set input and name to in this case
return BitVecFunc(
raw=raw,
name=self.name if self.name and self.name == other.name else None,
input=None,
annotations=union,
)
return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union)
def __and__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an and expression.
:param other:
:return:
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
raw = (self.raw & other.raw,)
union = self.annotations + other.annotations
if isinstance(other, BitVecFunc):
# TODO: Find better value to set input and name to in this case
return BitVecFunc(
raw=raw,
name=self.name if self.name and self.name == other.name else None,
input=None,
annotations=union,
)
return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union)
def __or__(self, other: "BitVec") -> "BitVecFunc":
"""Create an or expression.
:param other:
:return:
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
raw = (self.raw | other.raw,)
union = self.annotations + other.annotations
if isinstance(other, BitVecFunc):
# TODO: Find better value to set input and name to in this case
return BitVecFunc(
raw=raw,
name=self.name if self.name and self.name == other.name else None,
input=None,
annotations=union,
)
return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union)
def __xor__(self, other: "BitVec") -> "BitVec":
"""Create a xor expression.
:param other:
:return:
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
raw = (self.raw ^ other.raw,)
union = self.annotations + other.annotations
if isinstance(other, BitVecFunc):
# TODO: Find better value to set input and name to in this case
return BitVecFunc(
raw=raw,
name=self.name if self.name and self.name == other.name else None,
input=None,
annotations=union,
)
return BitVecFunc(raw=raw, name=self.name, input=self.input, annotations=union)
def __lt__(self, other: "BitVec") -> Bool:
"""Create a signed less than expression.
:param other:
:return:
"""
# Is there some hack for these comparisons?
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations + other.annotations
if not self.symbolic and not other.symbolic:
return Bool(cast(z3.BoolRef, self.value < other.value), annotations=union)
if (
not isinstance(other, BitVecFunc)
or not self.name
or not self.input
or not self.name == other.name
):
return Bool(False, annotations=union)
return And(
Bool(cast(z3.BoolRef, self.raw < other.raw), annotations=union),
self.input != other.input,
)
def __gt__(self, other: "BitVec") -> Bool:
"""Create a signed greater than expression.
:param other:
:return:
"""
# Is there some hack for these comparisons?
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations + other.annotations
if not self.symbolic and not other.symbolic:
return Bool(cast(z3.BoolRef, self.value > other.value), annotations=union)
if (
not isinstance(other, BitVecFunc)
or not self.name
or not self.input
or not self.name == other.name
):
return Bool(False, annotations=union)
return And(
Bool(cast(z3.BoolRef, self.raw > other.raw), annotations=union),
self.input != other.input,
)
# MYPY: fix complains about overriding __eq__
def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
"""Create an equality expression.
:param other:
:return:
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations + other.annotations
if not self.symbolic and not other.symbolic:
return Bool(cast(z3.BoolRef, self.value == other.value), annotations=union)
if (
not isinstance(other, BitVecFunc)
or not self.name
or not self.input
or not self.name == other.name
):
return Bool(cast(z3.BoolRef, False), annotations=union)
# MYPY: fix complaints due to z3 overriding __eq__
return And(
Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union),
self.input == other.input,
)
# MYPY: fix complains about overriding __ne__
def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
"""Create an inequality expression.
:param other:
:return:
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations + other.annotations
if not self.symbolic and not other.symbolic:
return Bool(cast(z3.BoolRef, self.value != other.value), annotations=union)
if (
not isinstance(other, BitVecFunc)
or not self.name
or not self.input
or not self.name == other.name
):
return Bool(cast(z3.BoolRef, True), annotations=union)
# MYPY: fix complaints due to z3 overriding __eq__
return Or(
Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union),
self.input != other.input,
)
def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec:
"""Create an if-then-else expression.
@ -176,6 +504,8 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi
:param c:
:return:
"""
# TODO: Handle BitVecFunc
if not isinstance(a, Bool):
a = Bool(z3.BoolVal(a))
if not isinstance(b, BitVec):
@ -193,17 +523,21 @@ def UGT(a: BitVec, b: BitVec) -> Bool:
:param b:
:return:
"""
# TODO: Handle BitVecFunc
annotations = a.annotations + b.annotations
return Bool(z3.UGT(a.raw, b.raw), annotations)
def UGE(a: BitVec, b:BitVec) -> Bool:
def UGE(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned greater or equals expression.
:param a:
:param b:
:return:
"""
# TODO: Handle BitVecFunc
annotations = a.annotations + b.annotations
return Bool(z3.UGE(a.raw, b.raw), annotations)
@ -215,6 +549,8 @@ def ULT(a: BitVec, b: BitVec) -> Bool:
:param b:
:return:
"""
# TODO: Handle BitVecFunc
annotations = a.annotations + b.annotations
return Bool(z3.ULT(a.raw, b.raw), annotations)
@ -233,6 +569,7 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
:param args:
:return:
"""
# TODO: Handle BitVecFunc
# The following statement is used if a list is provided as an argument to concat
if len(args) == 1 and isinstance(args[0], list):
@ -255,6 +592,8 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
:param bv:
:return:
"""
# TODO: Handle BitVecFunc
return BitVec(z3.Extract(high, low, bv.raw), annotations=bv.annotations)
@ -265,6 +604,8 @@ def URem(a: BitVec, b: BitVec) -> BitVec:
:param b:
:return:
"""
# TODO: Handle BitVecFunc
union = a.annotations + b.annotations
return BitVec(z3.URem(a.raw, b.raw), annotations=union)
@ -276,6 +617,8 @@ def SRem(a: BitVec, b: BitVec) -> BitVec:
:param b:
:return:
"""
# TODO: Handle BitVecFunc
union = a.annotations + b.annotations
return BitVec(z3.SRem(a.raw, b.raw), annotations=union)
@ -287,6 +630,8 @@ def UDiv(a: BitVec, b: BitVec) -> BitVec:
:param b:
:return:
"""
# TODO: Handle BitVecFunc
union = a.annotations + b.annotations
return BitVec(z3.UDiv(a.raw, b.raw), annotations=union)
@ -296,6 +641,8 @@ def Sum(*args: BitVec) -> BitVec:
:return:
"""
# TODO: Handle BitVecFunc
nraw = z3.Sum([a.raw for a in args])
annotations = [] # type: Annotations
for bv in args:
@ -334,7 +681,9 @@ def BVMulNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool)
return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed))
def BVSubNoUnderflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool:
def BVSubNoUnderflow(
a: Union[BitVec, int], b: Union[BitVec, int], signed: bool
) -> Bool:
"""Creates predicate that verifies that the subtraction doesn't overflow.
:param a:

Loading…
Cancel
Save