Add z3 mypy stubs; Add types to mythril/laser/smt/*.py

pull/885/head
Dimitar Bounov 6 years ago
parent 7779c0d5f5
commit 4c1ddb9fbe
  1. 283
      mypy-stubs/z3/__init__.pyi
  2. 3
      mypy-stubs/z3/z3core.pyi
  3. 12
      mypy-stubs/z3/z3types.pyi
  4. 26
      mythril/laser/smt/__init__.py
  5. 10
      mythril/laser/smt/array.py
  6. 92
      mythril/laser/smt/bitvec.py
  7. 26
      mythril/laser/smt/bool.py
  8. 22
      mythril/laser/smt/expression.py
  9. 25
      mythril/laser/smt/solver.py

@ -0,0 +1,283 @@
from typing import overload, Tuple, Any, List, Iterable, Iterator, Optional, TypeVar, Union
from .z3types import Ast, ContextObj
class Context:
...
class Z3PPObject:
...
class AstRef(Z3PPObject):
@overload
def __init__(self, ast: Ast, ctx: Context) -> None:
self.ast: Ast = ...
self.ctx: Context= ...
@overload
def __init__(self, ast: Ast) -> None:
self.ast: Ast = ...
self.ctx: Context= ...
def ctx_ref(self) -> ContextObj: ...
def as_ast(self) -> Ast: ...
def children(self) -> List[AstRef]: ...
def eq(self, other: AstRef) -> bool: ...
# TODO: Cannot add __eq__ currently: mypy complains conflict with
# object.__eq__ signature
#def __eq__(self, other: object) -> ArithRef: ...
class SortRef(AstRef):
...
class FuncDeclRef(AstRef):
def arity(self) -> int: ...
def name(self) -> str: ...
def __call__(self, *args: ExprRef) -> ExprRef: ...
class ExprRef(AstRef):
def sort(self) -> SortRef: ...
def decl(self) -> FuncDeclRef: ...
class BoolSortRef(SortRef):
...
class ArraySortRef(SortRef):
...
class BoolRef(ExprRef):
...
def is_true(a: BoolRef) -> bool: ...
def is_false(a: BoolRef) -> bool: ...
def is_int_value(a: AstRef) -> bool: ...
def substitute(a: AstRef, *m: Tuple[AstRef, AstRef]) -> AstRef: ...
def simplify(a: AstRef, *args: Any, **kwargs: Any) -> AstRef: ...
class ArithSortRef(SortRef):
...
class ArithRef(ExprRef):
def __neg__(self) -> ExprRef: ...
def __le__(self, other: ArithRef) -> BoolRef: ...
def __lt__(self, other: ArithRef) -> BoolRef: ...
def __ge__(self, other: ArithRef) -> BoolRef: ...
def __gt__(self, other: ArithRef) -> BoolRef: ...
def __add__(self, other: ArithRef) -> ArithRef: ...
def __sub__(self, other: ArithRef) -> ArithRef: ...
def __mul__(self, other: ArithRef) -> ArithRef: ...
def __div__(self, other: ArithRef) -> ArithRef: ...
def __truediv__(self, other: ArithRef) -> ArithRef: ...
def __mod__(self, other: ArithRef) -> ArithRef: ...
class BitVecSortRef(SortRef):
...
class BitVecRef(ExprRef):
def size(self) -> int: ...
def __add__(self, other: Union[BitVecRef, int]) -> BitVecRef: ...
def __radd__(self, other: Union[BitVecRef, int]) -> BitVecRef: ...
def __mul__(self, other: Union[BitVecRef, int]) -> BitVecRef: ...
def __rmul__(self, other: Union[BitVecRef, int]) -> BitVecRef: ...
def __sub__(self, other: Union[BitVecRef, int]) -> BitVecRef: ...
def __rsub__(self, other: Union[BitVecRef, int]) -> BitVecRef: ...
def __or__(self, other: Union[BitVecRef, int]) -> BitVecRef: ...
def __ror__(self, other: Union[BitVecRef, int]) -> BitVecRef: ...
def __and__(self, other: Union[BitVecRef, int]) -> BitVecRef: ...
def __rand__(self, other: Union[BitVecRef, int]) -> BitVecRef: ...
def __xor__(self, other: Union[BitVecRef, int]) -> BitVecRef: ...
def __rxor__(self, other: Union[BitVecRef, int]) -> BitVecRef: ...
def __pos__(self) -> BitVecRef: ...
def __neg__(self) -> BitVecRef: ...
def __invert__(self) -> BitVecRef: ...
def __div__(self, other: BitVecRef) -> BitVecRef: ...
def __rdiv__(self, other: BitVecRef) -> BitVecRef: ...
def __truediv__(self, other: BitVecRef) -> BitVecRef: ...
def __rtruediv__(self, other: BitVecRef) -> BitVecRef: ...
def __mod__(self, other: BitVecRef) -> BitVecRef: ...
def __rmod__(self, other: BitVecRef) -> BitVecRef: ...
def __le__(self, other: BitVecRef) -> BoolRef: ...
def __lt__(self, other: BitVecRef) -> BoolRef: ...
def __ge__(self, other: BitVecRef) -> BoolRef: ...
def __gt__(self, other: BitVecRef) -> BoolRef: ...
def __rshift__(self, other: BitVecRef) -> BitVecRef: ...
def __lshift__(self, other: BitVecRef) -> BitVecRef: ...
def __rrshift__(self, other: BitVecRef) -> BitVecRef: ...
def __rlshift__(self, other: BitVecRef) -> BitVecRef: ...
class BitVecNumRef(BitVecRef):
def as_long(self) -> int: ...
def as_signed_long(self) -> int: ...
def as_string(self) -> str: ...
class IntNumRef(ArithRef):
def as_long(self) -> int: ...
def as_string(self) -> str: ...
class SeqSortRef(ExprRef):
...
class SeqRef(ExprRef):
...
class ReSortRef(ExprRef):
...
class ReRef(ExprRef):
...
class ArrayRef(ExprRef):
...
class CheckSatResult: ...
class ModelRef(Z3PPObject):
def __getitem__(self, k: FuncDeclRef) -> IntNumRef: ...
def decls(self) -> Iterable[FuncDeclRef]: ...
def __iter__(self) -> Iterator[FuncDeclRef]: ...
class FuncEntry:
def num_args(self) -> int: ...
def arg_value(self, idx: int) -> ExprRef: ...
def value(self) -> ExprRef: ...
class FuncInterp(Z3PPObject):
def else_value(self) -> ExprRef: ...
def num_entries(self) -> int: ...
def arity(self) -> int: ...
def entry(self, idx: int) -> FuncEntry: ...
class Goal(Z3PPObject):
...
class Solver(Z3PPObject):
ctx: Context
def __init__(self, ctx:Optional[Context] = None) -> None: ...
def to_smt2(self) -> str: ...
def check(self) -> CheckSatResult: ...
def push(self) -> None: ...
def pop(self, num: Optional[int] = 1) -> None: ...
def model(self) -> ModelRef: ...
def set(self, *args: Any, **kwargs: Any) -> None: ...
@overload
def add(self, *args: Union[BoolRef, Goal]) -> None: ...
@overload
def add(self, args: List[Union[BoolRef, Goal]]) -> None: ...
def reset(self) -> None: ...
class Optimize(Z3PPObject):
ctx: Context
def __init__(self, ctx:Optional[Context] = None) -> None: ...
def check(self) -> CheckSatResult: ...
def push(self) -> None: ...
def pop(self) -> None: ...
def model(self) -> ModelRef: ...
def set(self, *args: Any, **kwargs: Any) -> None: ...
@overload
def add(self, *args: Union[BoolRef, Goal]) -> None: ...
@overload
def add(self, args: List[Union[BoolRef, Goal]]) -> None: ...
def minimize(self, element: ExprRef) -> None: ...
def maximize(self, element: ExprRef) -> None: ...
sat: CheckSatResult = ...
unsat: CheckSatResult = ...
@overload
def Int(name: str) -> ArithRef: ...
@overload
def Int(name: str, ctx: Context) -> ArithRef: ...
@overload
def Bool(name: str) -> BoolRef: ...
@overload
def Bool(name: str, ctx: Context) -> BoolRef: ...
@overload
def parse_smt2_string(s: str) -> ExprRef: ...
@overload
def parse_smt2_string(s: str, ctx: Context) -> ExprRef: ...
def Array(name: str, domain: SortRef, range: SortRef) -> ArrayRef: ...
def K(domain: SortRef, v: Union[ExprRef, int, bool, str]) -> ArrayRef: ...
# Can't give more precise types here since func signature is
# a vararg list of ExprRef optionally followed by a Context
def Or(*args: Any) -> BoolRef: ...
def And(*args: Any) -> BoolRef: ...
def Not(p: BoolRef, ctx: Optional[Context] = None) -> BoolRef: ...
def Implies(a: BoolRef, b: BoolRef, ctx:Context) -> BoolRef: ...
T=TypeVar("T", bound=ExprRef)
def If(a: BoolRef, b: T, c: T, ctx: Optional[Context] = None) -> T: ...
def ULE(a: T, b: T) -> BoolRef: ...
def ULT(a: T, b: T) -> BoolRef: ...
def UGE(a: T, b: T) -> BoolRef: ...
def UGT(a: T, b: T) -> BoolRef: ...
def UDiv(a: T, b: T) -> T: ...
def URem(a: T, b: T) -> T: ...
def SRem(a: T, b: T) -> T: ...
def LShR(a: T, b: T) -> T: ...
def RotateLeft(a: T, b: T) -> T: ...
def RotateRight(a: T, b: T) -> T: ...
def SignExt(n: int, a: BitVecRef) -> BitVecRef: ...
def ZeroExt(n: int, a: BitVecRef) -> BitVecRef: ...
@overload
def Concat(args: List[Union[SeqRef, str]]) -> SeqRef: ...
@overload
def Concat(*args: Union[SeqRef, str]) -> SeqRef: ...
@overload
def Concat(args: List[ReRef]) -> ReRef: ...
@overload
def Concat(*args: ReRef) -> ReRef: ...
@overload
def Concat(args: List[BitVecRef]) -> BitVecRef: ...
@overload
def Concat(*args: BitVecRef) -> BitVecRef: ...
@overload
def Extract(high: Union[SeqRef], lo: Union[int, ArithRef], a: Union[int, ArithRef]) -> SeqRef: ...
@overload
def Extract(high: Union[int, ArithRef], lo: Union[int, ArithRef], a: BitVecRef) -> BitVecRef: ...
@overload
def Sum(arg: BitVecRef, *args: Union[BitVecRef, int]) -> BitVecRef: ...
@overload
def Sum(arg: Union[List[BitVecRef], int]) -> BitVecRef: ...
@overload
def Sum(arg: ArithRef, *args: Union[ArithRef, int]) -> ArithRef: ...
# Can't include this overload as it overlaps with the second overload.
#@overload
#def Sum(arg: Union[List[ArithRef], int]) -> ArithRef: ...
def Function(name: str, *sig: SortRef) -> FuncDeclRef: ...
def IntVal(val: int, ctx: Optional[Context] = None) -> IntNumRef: ...
def BoolVal(val: bool, ctx: Optional[Context] = None) -> BoolRef: ...
def BitVecVal(val: int, bv: Union[int, BitVecSortRef], ctx: Optional[Context] = None) -> BitVecRef: ...
def BitVec(val: str, bv: Union[int, BitVecSortRef], ctx: Optional[Context] = None) -> BitVecRef: ...
def IntSort(ctx: Optional[Context] = None) -> ArithSortRef: ...
def BoolSort(ctx:Optional[Context] = None) -> BoolSortRef: ...
def ArraySort(domain: SortRef, range: SortRef) -> ArraySortRef: ...
def BitVecSort(domain: int, ctx:Optional[Context] = None) -> BoolSortRef: ...
def ForAll(vs: List[ExprRef], expr: ExprRef) -> ExprRef: ...
def Select(arr: ExprRef, ind: ExprRef) -> ExprRef: ...
def Update(arr: ArrayRef, ind: ExprRef, newVal: ExprRef) -> ArrayRef: ...
def Store(arr: ArrayRef, ind: ExprRef, newVal: ExprRef) -> ArrayRef: ...
def BVAddNoOverflow(a: BitVecRef, b: BitVecRef, signed: bool) -> BoolRef: ...
def BVAddNoUnderflow(a: BitVecRef, b: BitVecRef) -> BoolRef: ...
def BVSubNoOverflow(a: BitVecRef, b: BitVecRef) -> BoolRef: ...
def BVSubNoUnderflow(a: BitVecRef, b: BitVecRef, signed: bool) -> BoolRef: ...
def BVSDivNoOverflow(a: BitVecRef, b: BitVecRef) -> BoolRef: ...
def BVSNegNoOverflow(a: BitVecRef) -> BoolRef: ...
def BVMulNoOverflow(a: BitVecRef, b: BitVecRef, signed: bool) -> BoolRef: ...
def BVMulNoUnderflow(a: BitVecRef, b: BitVecRef) -> BoolRef: ...

@ -0,0 +1,3 @@
from .z3types import Ast, ContextObj
def Z3_mk_eq(ctx: ContextObj, a: Ast, b: Ast) -> Ast: ...
def Z3_mk_div(ctx: ContextObj, a: Ast, b: Ast) -> Ast: ...

@ -0,0 +1,12 @@
from typing import Any
class Z3Exception(Exception):
def __init__(self, a: Any) -> None:
self.value = a
...
class ContextObj:
...
class Ast:
...

@ -19,14 +19,18 @@ from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not
from mythril.laser.smt.array import K, Array, BaseArray
from mythril.laser.smt.solver import Solver, Optimize
from typing import Union, Any, Optional, List
import z3
Annotations = Optional[List[Any]]
class SymbolFactory:
"""A symbol factory provides a default interface for all the components of mythril to create symbols"""
@staticmethod
def Bool(value: bool, annotations=None):
def Bool(value: __builtins__.bool, annotations: Annotations=None) -> Union[bool.Bool, z3.BoolRef]:
"""
Creates a Bool with concrete value
:param value: The boolean value
@ -36,7 +40,7 @@ class SymbolFactory:
raise NotImplementedError
@staticmethod
def BitVecVal(value: int, size: int, annotations=None):
def BitVecVal(value: int, size: int, annotations: Annotations=None) -> Union[BitVec, z3.BitVecRef]:
"""Creates a new bit vector with a concrete value.
:param value: The concrete value to set the bit vector to
@ -47,7 +51,7 @@ class SymbolFactory:
raise NotImplementedError()
@staticmethod
def BitVecSym(name: str, size: int, annotations=None):
def BitVecSym(name: str, size: int, annotations: Annotations=None) -> Union[BitVec, z3.BitVecRef]:
"""Creates a new bit vector with a symbolic value.
:param name: The name of the symbolic bit vector
@ -65,24 +69,24 @@ class _SmtSymbolFactory(SymbolFactory):
"""
@staticmethod
def Bool(value: bool, annotations=None):
def Bool(value: __builtins__.bool, annotations: Annotations=None) -> Union[bool.Bool, z3.BoolRef]:
"""
Creates a Bool with concrete value
:param value: The boolean value
:param annotations: The annotations to initialize the bool with
:return: The freshly created Bool()
"""
raw = z3.Bool(value)
raw = z3.BoolVal(value)
return Bool(raw, annotations)
@staticmethod
def BitVecVal(value: int, size: int, annotations=None):
def BitVecVal(value: int, size: int, annotations: Annotations=None) -> Union[BitVec, z3.BitVecRef]:
"""Creates a new bit vector with a concrete value."""
raw = z3.BitVecVal(value, size)
return BitVec(raw, annotations)
@staticmethod
def BitVecSym(name: str, size: int, annotations=None):
def BitVecSym(name: str, size: int, annotations: Annotations=None) -> Union[BitVec, z3.BitVecRef]:
"""Creates a new bit vector with a symbolic value."""
raw = z3.BitVec(name, size)
return BitVec(raw, annotations)
@ -95,17 +99,17 @@ class _Z3SymbolFactory(SymbolFactory):
"""
@staticmethod
def Bool(value: bool, annotations=None):
def Bool(value: __builtins__.bool, annotations: Annotations=None) -> Union[bool.Bool, z3.BoolRef]:
""" Creates a new bit vector with a concrete value """
return z3.Bool(value)
return z3.BoolVal(value)
@staticmethod
def BitVecVal(value: int, size: int, annotations=None):
def BitVecVal(value: int, size: int, annotations: Annotations=None) -> Union[BitVec, z3.BitVecRef]:
"""Creates a new bit vector with a concrete value."""
return z3.BitVecVal(value, size)
@staticmethod
def BitVecSym(name: str, size: int, annotations=None):
def BitVecSym(name: str, size: int, annotations: Annotations=None) -> Union[BitVec, z3.BitVecRef]:
"""Creates a new bit vector with a symbolic value."""
return z3.BitVec(name, size)

@ -5,6 +5,7 @@ operations, as well as as a K-array, which can be initialized with
default values over a certain range.
"""
from typing import cast
import z3
from mythril.laser.smt.bitvec import BitVec
@ -12,16 +13,19 @@ from mythril.laser.smt.bitvec import BitVec
class BaseArray:
"""Base array type, which implements basic store and set operations."""
domain: z3.SortRef
range: z3.SortRef
raw: z3.ArrayRef
def __getitem__(self, item: BitVec):
def __getitem__(self, item: BitVec) -> BitVec:
"""Gets item from the array, item can be symbolic."""
if isinstance(item, slice):
raise ValueError(
"Instance of BaseArray, does not support getitem with slices"
)
return BitVec(z3.Select(self.raw, item.raw))
return BitVec(cast(z3.BitVecRef, z3.Select(self.raw, item.raw)))
def __setitem__(self, key: BitVec, value: BitVec):
def __setitem__(self, key: BitVec, value: BitVec) -> None:
"""Sets an item in the array, key can be symbolic."""
self.raw = z3.Store(self.raw, key.raw, value.raw)

@ -1,20 +1,20 @@
"""This module provides classes for an SMT abstraction of bit vectors."""
from typing import Union
from typing import Union, overload, List, cast, Any, Optional
import z3
from mythril.laser.smt.bool import Bool
from mythril.laser.smt.expression import Expression
Annotations = List[Any]
# fmt: off
class BitVec(Expression):
class BitVec(Expression[z3.BitVecRef]):
"""A bit vector symbol."""
def __init__(self, raw, annotations=None):
def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations]=None):
"""
:param raw:
@ -22,7 +22,7 @@ class BitVec(Expression):
"""
super().__init__(raw, annotations)
def size(self):
def size(self) -> int:
"""TODO: documentation
:return:
@ -30,7 +30,7 @@ class BitVec(Expression):
return self.raw.size()
@property
def symbolic(self):
def symbolic(self) -> bool:
"""Returns whether this symbol doesn't have a concrete value.
:return:
@ -39,7 +39,7 @@ class BitVec(Expression):
return not isinstance(self.raw, z3.BitVecNumRef)
@property
def value(self):
def value(self) -> Optional[int]:
"""Returns the value of this symbol if concrete, otherwise None.
:return:
@ -49,7 +49,7 @@ class BitVec(Expression):
assert isinstance(self.raw, z3.BitVecNumRef)
return self.raw.as_long()
def __add__(self, other) -> "BitVec":
def __add__(self, other: Union[int, BitVec]) -> "BitVec":
"""Create an addition expression.
:param other:
@ -61,7 +61,7 @@ class BitVec(Expression):
union = self.annotations + other.annotations
return BitVec(self.raw + other.raw, annotations=union)
def __sub__(self, other) -> "BitVec":
def __sub__(self, other: Union[int, BitVec]) -> "BitVec":
"""Create a subtraction expression.
:param other:
@ -74,7 +74,7 @@ class BitVec(Expression):
union = self.annotations + other.annotations
return BitVec(self.raw - other.raw, annotations=union)
def __mul__(self, other) -> "BitVec":
def __mul__(self, other: BitVec) -> "BitVec":
"""Create a multiplication expression.
:param other:
@ -83,7 +83,7 @@ class BitVec(Expression):
union = self.annotations + other.annotations
return BitVec(self.raw * other.raw, annotations=union)
def __truediv__(self, other) -> "BitVec":
def __truediv__(self, other: BitVec) -> "BitVec":
"""Create a signed division expression.
:param other:
@ -92,7 +92,7 @@ class BitVec(Expression):
union = self.annotations + other.annotations
return BitVec(self.raw / other.raw, annotations=union)
def __and__(self, other) -> "BitVec":
def __and__(self, other: Union[int, BitVec]) -> "BitVec":
"""Create an and expression.
:param other:
@ -103,7 +103,7 @@ class BitVec(Expression):
union = self.annotations + other.annotations
return BitVec(self.raw & other.raw, annotations=union)
def __or__(self, other) -> "BitVec":
def __or__(self, other: BitVec) -> "BitVec":
"""Create an or expression.
:param other:
@ -112,7 +112,7 @@ class BitVec(Expression):
union = self.annotations + other.annotations
return BitVec(self.raw | other.raw, annotations=union)
def __xor__(self, other) -> "BitVec":
def __xor__(self, other: BitVec) -> "BitVec":
"""Create a xor expression.
:param other:
@ -121,7 +121,7 @@ class BitVec(Expression):
union = self.annotations + other.annotations
return BitVec(self.raw ^ other.raw, annotations=union)
def __lt__(self, other) -> Bool:
def __lt__(self, other: BitVec) -> Bool:
"""Create a signed less than expression.
:param other:
@ -130,7 +130,7 @@ class BitVec(Expression):
union = self.annotations + other.annotations
return Bool(self.raw < other.raw, annotations=union)
def __gt__(self, other) -> Bool:
def __gt__(self, other: BitVec) -> Bool:
"""Create a signed greater than expression.
:param other:
@ -139,32 +139,36 @@ class BitVec(Expression):
union = self.annotations + other.annotations
return Bool(self.raw > other.raw, annotations=union)
def __eq__(self, other) -> Bool:
#MYPY: fix complains about overriding __eq__
def __eq__(self, other: Union[int, BitVec]) -> Bool: # type: ignore
"""Create an equality expression.
:param other:
:return:
"""
if not isinstance(other, BitVec):
return Bool(self.raw == other, annotations=self.annotations)
return Bool(cast(z3.BoolRef, self.raw == other), annotations=self.annotations)
union = self.annotations + other.annotations
return Bool(self.raw == other.raw, annotations=union)
# MYPY: fix complaints due to z3 overriding __eq__
return Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union)
def __ne__(self, other) -> Bool:
#MYPY: fix complains about overriding __ne__
def __ne__(self, other: Union[int, BitVec]) -> Bool: # type: ignore
"""Create an inequality expression.
:param other:
:return:
"""
if not isinstance(other, BitVec):
return Bool(self.raw != other, annotations=self.annotations)
return Bool(cast(z3.BoolRef, self.raw != other), annotations=self.annotations)
union = self.annotations + other.annotations
return Bool(self.raw != other.raw, annotations=union)
# MYPY: fix complaints due to z3 overriding __eq__
return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union)
def If(a: Bool, b: BitVec, c: BitVec):
def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec:
"""Create an if-then-else expression.
:param a:
@ -172,11 +176,11 @@ def If(a: Bool, b: BitVec, c: BitVec):
:param c:
:return:
"""
if not isinstance(a, Expression):
if not isinstance(a, Bool):
a = Bool(z3.BoolVal(a))
if not isinstance(b, Expression):
if not isinstance(b, BitVec):
b = BitVec(z3.BitVecVal(b, 256))
if not isinstance(c, Expression):
if not isinstance(c, BitVec):
c = BitVec(z3.BitVecVal(c, 256))
union = a.annotations + b.annotations + c.annotations
return BitVec(z3.If(a.raw, b.raw, c.raw), union)
@ -215,7 +219,15 @@ def ULT(a: BitVec, b: BitVec) -> Bool:
return Bool(z3.ULT(a.raw, b.raw), annotations)
def Concat(*args) -> BitVec:
@overload
def Concat(*args: List[BitVec]) -> BitVec: ...
@overload
def Concat(*args: BitVec) -> BitVec: ...
def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
"""Create a concatenation expression.
:param args:
@ -224,11 +236,13 @@ def Concat(*args) -> BitVec:
# The following statement is used if a list is provided as an argument to concat
if len(args) == 1 and isinstance(args[0], list):
args = args[0]
bvs = args[0] # type: List[BitVec]
else:
bvs = cast(List[BitVec], args)
nraw = z3.Concat([a.raw for a in args])
annotations = []
for bv in args:
nraw = z3.Concat([a.raw for a in bvs])
annotations = [] # type: Annotations
for bv in bvs:
annotations += bv.annotations
return BitVec(nraw, annotations)
@ -277,13 +291,13 @@ def UDiv(a: BitVec, b: BitVec) -> BitVec:
return BitVec(z3.UDiv(a.raw, b.raw), annotations=union)
def Sum(*args) -> BitVec:
def Sum(*args: BitVec) -> BitVec:
"""Create sum expression.
:return:
"""
nraw = z3.Sum([a.raw for a in args])
annotations = []
annotations = [] # type: Annotations
for bv in args:
annotations += bv.annotations
return BitVec(nraw, annotations)
@ -297,9 +311,9 @@ def BVAddNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool)
:param signed:
:return:
"""
if not isinstance(a, Expression):
if not isinstance(a, BitVec):
a = BitVec(z3.BitVecVal(a, 256))
if not isinstance(b, Expression):
if not isinstance(b, BitVec):
b = BitVec(z3.BitVecVal(b, 256))
return Bool(z3.BVAddNoOverflow(a.raw, b.raw, signed))
@ -313,9 +327,9 @@ def BVMulNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool)
:param signed:
:return:
"""
if not isinstance(a, Expression):
if not isinstance(a, BitVec):
a = BitVec(z3.BitVecVal(a, 256))
if not isinstance(b, Expression):
if not isinstance(b, BitVec):
b = BitVec(z3.BitVecVal(b, 256))
return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed))
@ -328,9 +342,9 @@ def BVSubNoUnderflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool)
:param signed:
:return:
"""
if not isinstance(a, Expression):
if not isinstance(a, BitVec):
a = BitVec(z3.BitVecVal(a, 256))
if not isinstance(b, Expression):
if not isinstance(b, BitVec):
b = BitVec(z3.BitVecVal(b, 256))
return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed))

@ -1,7 +1,7 @@
"""This module provides classes for an SMT abstraction of boolean
expressions."""
from typing import Union
from typing import Union, cast
import z3
@ -11,7 +11,7 @@ from mythril.laser.smt.expression import Expression
# fmt: off
class Bool(Expression):
class Bool(Expression[z3.BoolRef]):
"""This is a Bool expression."""
@property
@ -46,27 +46,31 @@ class Bool(Expression):
else:
return None
def __eq__(self, other):
#MYPY: complains about overloading __eq__ # noqa
def __eq__(self, other: object) -> Bool: # type: ignore
"""
:param other:
:return:
"""
if isinstance(other, Expression):
return Bool(self.raw == other.raw, self.annotations + other.annotations)
return Bool(self.raw == other, self.annotations)
return Bool(cast(z3.BoolRef, self.raw == other.raw),
self.annotations + other.annotations)
return Bool(cast(z3.BoolRef, self.raw == other), self.annotations)
def __ne__(self, other):
#MYPY: complains about overloading __ne__ # noqa
def __ne__(self, other: object) -> Bool: # type: ignore
"""
:param other:
:return:
"""
if isinstance(other, Expression):
return Bool(self.raw != other.raw, self.annotations + other.annotations)
return Bool(self.raw != other, self.annotations)
return Bool(cast(z3.BoolRef, self.raw != other.raw),
self.annotations + other.annotations)
return Bool(cast(z3.BoolRef, self.raw != other), self.annotations)
def __bool__(self):
def __bool__(self) -> bool:
"""
:return:
@ -77,7 +81,7 @@ class Bool(Expression):
return False
def Or(a: Bool, b: Bool):
def Or(a: Bool, b: Bool) -> Bool:
"""Create an or expression.
:param a:
@ -88,7 +92,7 @@ def Or(a: Bool, b: Bool):
return Bool(z3.Or(a.raw, b.raw), annotations=union)
def Not(a: Bool):
def Not(a: Bool) -> Bool:
"""Create a Not expression.
:param a:

@ -1,13 +1,17 @@
"""This module contains the SMT abstraction for a basic symbol expression."""
from typing import Optional, List, Any, TypeVar, Generic, cast
import z3
class Expression:
Annotations=List[Any]
T = TypeVar('T', bound=z3.ExprRef)
class Expression(Generic[T]):
"""This is the base symbol class and maintains functionality for
simplification and annotations."""
def __init__(self, raw, annotations=None):
def __init__(self, raw: T, annotations: Optional[Annotations]=None):
"""
:param raw:
@ -17,14 +21,14 @@ class Expression:
self._annotations = annotations or []
@property
def annotations(self):
def annotations(self) -> Annotations:
"""Gets the annotations for this expression.
:return:
"""
return self._annotations
def annotate(self, annotation):
def annotate(self, annotation: Any) -> None:
"""Annotates this expression with the given annotation.
:param annotation:
@ -34,15 +38,15 @@ class Expression:
else:
self._annotations.append(annotation)
def simplify(self):
def simplify(self) -> None:
"""Simplify this expression."""
self.raw = z3.simplify(self.raw)
self.raw = cast(T, z3.simplify(self.raw))
def __repr__(self):
def __repr__(self) -> str:
return repr(self.raw)
def simplify(expression: Expression):
def simplify(expression: Expression) -> Expression:
"""Simplify the expression .
:param expression:

@ -1,5 +1,6 @@
"""This module contains an abstract SMT representation of an SMT solver."""
import z3
from typing import Union, cast
from mythril.laser.smt.expression import Expression
@ -7,9 +8,9 @@ from mythril.laser.smt.expression import Expression
class Solver:
"""An SMT solver object."""
def __init__(self):
def __init__(self) -> None:
""""""
self.raw = z3.Solver()
self.raw = z3.Solver() # type: Union[z3.Solver, z3.Optimize]
def set_timeout(self, timeout: int) -> None:
"""Sets the timeout that will be used by this solver, timeout is in
@ -43,14 +44,14 @@ class Solver:
constraints = [c.raw for c in constraints]
self.raw.add(constraints)
def check(self):
def check(self) -> z3.CheckSatResult:
"""Returns z3 smt check result.
:return:
"""
return self.raw.check()
def model(self):
def model(self) -> z3.ModelRef:
"""Returns z3 model for a solution.
:return:
@ -59,34 +60,34 @@ class Solver:
def reset(self) -> None:
"""Reset this solver."""
self.raw.reset()
cast(z3.Solver, self.raw).reset()
def pop(self, num) -> None:
def pop(self, num: int) -> None:
"""Pop num constraints from this solver.
:param num:
"""
self.raw.pop(num)
cast(z3.Solver, self.raw).pop(num)
class Optimize(Solver):
"""An optimizing smt solver."""
def __init__(self):
def __init__(self) -> None:
"""Create a new optimizing solver instance."""
super().__init__()
self.raw = z3.Optimize()
def minimize(self, element: Expression):
def minimize(self, element: Expression) -> None:
"""In solving this solver will try to minimize the passed expression.
:param element:
"""
self.raw.minimize(element.raw)
cast(z3.Optimize, self.raw).minimize(element.raw)
def maximize(self, element: Expression):
def maximize(self, element: Expression) -> None:
"""In solving this solver will try to maximize the passed expression.
:param element:
"""
self.raw.maximize(element.raw)
cast(z3.Optimize, self.raw).maximize(element.raw)

Loading…
Cancel
Save