From 4c1ddb9fbe6638fb1dd9c751ad87905b358ede37 Mon Sep 17 00:00:00 2001 From: Dimitar Bounov Date: Mon, 14 Jan 2019 17:54:07 +0200 Subject: [PATCH] Add z3 mypy stubs; Add types to mythril/laser/smt/*.py --- mypy-stubs/z3/__init__.pyi | 283 ++++++++++++++++++++++++++++++++ mypy-stubs/z3/z3core.pyi | 3 + mypy-stubs/z3/z3types.pyi | 12 ++ mythril/laser/smt/__init__.py | 26 +-- mythril/laser/smt/array.py | 10 +- mythril/laser/smt/bitvec.py | 92 ++++++----- mythril/laser/smt/bool.py | 26 +-- mythril/laser/smt/expression.py | 22 ++- mythril/laser/smt/solver.py | 25 +-- 9 files changed, 414 insertions(+), 85 deletions(-) create mode 100644 mypy-stubs/z3/__init__.pyi create mode 100644 mypy-stubs/z3/z3core.pyi create mode 100644 mypy-stubs/z3/z3types.pyi diff --git a/mypy-stubs/z3/__init__.pyi b/mypy-stubs/z3/__init__.pyi new file mode 100644 index 00000000..202a6835 --- /dev/null +++ b/mypy-stubs/z3/__init__.pyi @@ -0,0 +1,283 @@ +from typing import overload, Tuple, Any, List, Iterable, Iterator, Optional, TypeVar, Union +from .z3types import Ast, ContextObj + +class Context: + ... + +class Z3PPObject: + ... + +class AstRef(Z3PPObject): + @overload + def __init__(self, ast: Ast, ctx: Context) -> None: + self.ast: Ast = ... + self.ctx: Context= ... + + @overload + def __init__(self, ast: Ast) -> None: + self.ast: Ast = ... + self.ctx: Context= ... + def ctx_ref(self) -> ContextObj: ... + def as_ast(self) -> Ast: ... + def children(self) -> List[AstRef]: ... + def eq(self, other: AstRef) -> bool: ... + # TODO: Cannot add __eq__ currently: mypy complains conflict with + # object.__eq__ signature + #def __eq__(self, other: object) -> ArithRef: ... + +class SortRef(AstRef): + ... + +class FuncDeclRef(AstRef): + def arity(self) -> int: ... + def name(self) -> str: ... + def __call__(self, *args: ExprRef) -> ExprRef: ... + +class ExprRef(AstRef): + def sort(self) -> SortRef: ... + def decl(self) -> FuncDeclRef: ... + +class BoolSortRef(SortRef): + ... + +class ArraySortRef(SortRef): + ... + +class BoolRef(ExprRef): + ... + + +def is_true(a: BoolRef) -> bool: ... +def is_false(a: BoolRef) -> bool: ... +def is_int_value(a: AstRef) -> bool: ... +def substitute(a: AstRef, *m: Tuple[AstRef, AstRef]) -> AstRef: ... +def simplify(a: AstRef, *args: Any, **kwargs: Any) -> AstRef: ... + + +class ArithSortRef(SortRef): + ... + +class ArithRef(ExprRef): + def __neg__(self) -> ExprRef: ... + def __le__(self, other: ArithRef) -> BoolRef: ... + def __lt__(self, other: ArithRef) -> BoolRef: ... + def __ge__(self, other: ArithRef) -> BoolRef: ... + def __gt__(self, other: ArithRef) -> BoolRef: ... + def __add__(self, other: ArithRef) -> ArithRef: ... + def __sub__(self, other: ArithRef) -> ArithRef: ... + def __mul__(self, other: ArithRef) -> ArithRef: ... + def __div__(self, other: ArithRef) -> ArithRef: ... + def __truediv__(self, other: ArithRef) -> ArithRef: ... + def __mod__(self, other: ArithRef) -> ArithRef: ... + +class BitVecSortRef(SortRef): + ... + +class BitVecRef(ExprRef): + def size(self) -> int: ... + def __add__(self, other: Union[BitVecRef, int]) -> BitVecRef: ... + def __radd__(self, other: Union[BitVecRef, int]) -> BitVecRef: ... + def __mul__(self, other: Union[BitVecRef, int]) -> BitVecRef: ... + def __rmul__(self, other: Union[BitVecRef, int]) -> BitVecRef: ... + def __sub__(self, other: Union[BitVecRef, int]) -> BitVecRef: ... + def __rsub__(self, other: Union[BitVecRef, int]) -> BitVecRef: ... + + def __or__(self, other: Union[BitVecRef, int]) -> BitVecRef: ... + def __ror__(self, other: Union[BitVecRef, int]) -> BitVecRef: ... + def __and__(self, other: Union[BitVecRef, int]) -> BitVecRef: ... + def __rand__(self, other: Union[BitVecRef, int]) -> BitVecRef: ... + def __xor__(self, other: Union[BitVecRef, int]) -> BitVecRef: ... + def __rxor__(self, other: Union[BitVecRef, int]) -> BitVecRef: ... + def __pos__(self) -> BitVecRef: ... + def __neg__(self) -> BitVecRef: ... + def __invert__(self) -> BitVecRef: ... + def __div__(self, other: BitVecRef) -> BitVecRef: ... + def __rdiv__(self, other: BitVecRef) -> BitVecRef: ... + def __truediv__(self, other: BitVecRef) -> BitVecRef: ... + def __rtruediv__(self, other: BitVecRef) -> BitVecRef: ... + def __mod__(self, other: BitVecRef) -> BitVecRef: ... + def __rmod__(self, other: BitVecRef) -> BitVecRef: ... + + def __le__(self, other: BitVecRef) -> BoolRef: ... + def __lt__(self, other: BitVecRef) -> BoolRef: ... + def __ge__(self, other: BitVecRef) -> BoolRef: ... + def __gt__(self, other: BitVecRef) -> BoolRef: ... + + + def __rshift__(self, other: BitVecRef) -> BitVecRef: ... + def __lshift__(self, other: BitVecRef) -> BitVecRef: ... + def __rrshift__(self, other: BitVecRef) -> BitVecRef: ... + def __rlshift__(self, other: BitVecRef) -> BitVecRef: ... + +class BitVecNumRef(BitVecRef): + def as_long(self) -> int: ... + def as_signed_long(self) -> int: ... + def as_string(self) -> str: ... + +class IntNumRef(ArithRef): + def as_long(self) -> int: ... + def as_string(self) -> str: ... + +class SeqSortRef(ExprRef): + ... + +class SeqRef(ExprRef): + ... + +class ReSortRef(ExprRef): + ... + +class ReRef(ExprRef): + ... + +class ArrayRef(ExprRef): + ... + +class CheckSatResult: ... + +class ModelRef(Z3PPObject): + def __getitem__(self, k: FuncDeclRef) -> IntNumRef: ... + def decls(self) -> Iterable[FuncDeclRef]: ... + def __iter__(self) -> Iterator[FuncDeclRef]: ... + +class FuncEntry: + def num_args(self) -> int: ... + def arg_value(self, idx: int) -> ExprRef: ... + def value(self) -> ExprRef: ... + +class FuncInterp(Z3PPObject): + def else_value(self) -> ExprRef: ... + def num_entries(self) -> int: ... + def arity(self) -> int: ... + def entry(self, idx: int) -> FuncEntry: ... + +class Goal(Z3PPObject): + ... + +class Solver(Z3PPObject): + ctx: Context + def __init__(self, ctx:Optional[Context] = None) -> None: ... + + def to_smt2(self) -> str: ... + def check(self) -> CheckSatResult: ... + def push(self) -> None: ... + def pop(self, num: Optional[int] = 1) -> None: ... + def model(self) -> ModelRef: ... + def set(self, *args: Any, **kwargs: Any) -> None: ... + @overload + def add(self, *args: Union[BoolRef, Goal]) -> None: ... + @overload + def add(self, args: List[Union[BoolRef, Goal]]) -> None: ... + def reset(self) -> None: ... + +class Optimize(Z3PPObject): + ctx: Context + def __init__(self, ctx:Optional[Context] = None) -> None: ... + + def check(self) -> CheckSatResult: ... + def push(self) -> None: ... + def pop(self) -> None: ... + def model(self) -> ModelRef: ... + def set(self, *args: Any, **kwargs: Any) -> None: ... + @overload + def add(self, *args: Union[BoolRef, Goal]) -> None: ... + @overload + def add(self, args: List[Union[BoolRef, Goal]]) -> None: ... + def minimize(self, element: ExprRef) -> None: ... + def maximize(self, element: ExprRef) -> None: ... + +sat: CheckSatResult = ... +unsat: CheckSatResult = ... + +@overload +def Int(name: str) -> ArithRef: ... +@overload +def Int(name: str, ctx: Context) -> ArithRef: ... + +@overload +def Bool(name: str) -> BoolRef: ... +@overload +def Bool(name: str, ctx: Context) -> BoolRef: ... + +@overload +def parse_smt2_string(s: str) -> ExprRef: ... +@overload +def parse_smt2_string(s: str, ctx: Context) -> ExprRef: ... + +def Array(name: str, domain: SortRef, range: SortRef) -> ArrayRef: ... +def K(domain: SortRef, v: Union[ExprRef, int, bool, str]) -> ArrayRef: ... + +# Can't give more precise types here since func signature is +# a vararg list of ExprRef optionally followed by a Context +def Or(*args: Any) -> BoolRef: ... +def And(*args: Any) -> BoolRef: ... +def Not(p: BoolRef, ctx: Optional[Context] = None) -> BoolRef: ... +def Implies(a: BoolRef, b: BoolRef, ctx:Context) -> BoolRef: ... + +T=TypeVar("T", bound=ExprRef) +def If(a: BoolRef, b: T, c: T, ctx: Optional[Context] = None) -> T: ... +def ULE(a: T, b: T) -> BoolRef: ... +def ULT(a: T, b: T) -> BoolRef: ... +def UGE(a: T, b: T) -> BoolRef: ... +def UGT(a: T, b: T) -> BoolRef: ... +def UDiv(a: T, b: T) -> T: ... +def URem(a: T, b: T) -> T: ... +def SRem(a: T, b: T) -> T: ... +def LShR(a: T, b: T) -> T: ... +def RotateLeft(a: T, b: T) -> T: ... +def RotateRight(a: T, b: T) -> T: ... +def SignExt(n: int, a: BitVecRef) -> BitVecRef: ... +def ZeroExt(n: int, a: BitVecRef) -> BitVecRef: ... + +@overload +def Concat(args: List[Union[SeqRef, str]]) -> SeqRef: ... +@overload +def Concat(*args: Union[SeqRef, str]) -> SeqRef: ... +@overload +def Concat(args: List[ReRef]) -> ReRef: ... +@overload +def Concat(*args: ReRef) -> ReRef: ... +@overload +def Concat(args: List[BitVecRef]) -> BitVecRef: ... +@overload +def Concat(*args: BitVecRef) -> BitVecRef: ... + +@overload +def Extract(high: Union[SeqRef], lo: Union[int, ArithRef], a: Union[int, ArithRef]) -> SeqRef: ... +@overload +def Extract(high: Union[int, ArithRef], lo: Union[int, ArithRef], a: BitVecRef) -> BitVecRef: ... + +@overload +def Sum(arg: BitVecRef, *args: Union[BitVecRef, int]) -> BitVecRef: ... +@overload +def Sum(arg: Union[List[BitVecRef], int]) -> BitVecRef: ... +@overload +def Sum(arg: ArithRef, *args: Union[ArithRef, int]) -> ArithRef: ... +# Can't include this overload as it overlaps with the second overload. +#@overload +#def Sum(arg: Union[List[ArithRef], int]) -> ArithRef: ... + +def Function(name: str, *sig: SortRef) -> FuncDeclRef: ... +def IntVal(val: int, ctx: Optional[Context] = None) -> IntNumRef: ... +def BoolVal(val: bool, ctx: Optional[Context] = None) -> BoolRef: ... +def BitVecVal(val: int, bv: Union[int, BitVecSortRef], ctx: Optional[Context] = None) -> BitVecRef: ... +def BitVec(val: str, bv: Union[int, BitVecSortRef], ctx: Optional[Context] = None) -> BitVecRef: ... + +def IntSort(ctx: Optional[Context] = None) -> ArithSortRef: ... +def BoolSort(ctx:Optional[Context] = None) -> BoolSortRef: ... +def ArraySort(domain: SortRef, range: SortRef) -> ArraySortRef: ... +def BitVecSort(domain: int, ctx:Optional[Context] = None) -> BoolSortRef: ... + +def ForAll(vs: List[ExprRef], expr: ExprRef) -> ExprRef: ... +def Select(arr: ExprRef, ind: ExprRef) -> ExprRef: ... +def Update(arr: ArrayRef, ind: ExprRef, newVal: ExprRef) -> ArrayRef: ... +def Store(arr: ArrayRef, ind: ExprRef, newVal: ExprRef) -> ArrayRef: ... + +def BVAddNoOverflow(a: BitVecRef, b: BitVecRef, signed: bool) -> BoolRef: ... +def BVAddNoUnderflow(a: BitVecRef, b: BitVecRef) -> BoolRef: ... +def BVSubNoOverflow(a: BitVecRef, b: BitVecRef) -> BoolRef: ... +def BVSubNoUnderflow(a: BitVecRef, b: BitVecRef, signed: bool) -> BoolRef: ... +def BVSDivNoOverflow(a: BitVecRef, b: BitVecRef) -> BoolRef: ... +def BVSNegNoOverflow(a: BitVecRef) -> BoolRef: ... +def BVMulNoOverflow(a: BitVecRef, b: BitVecRef, signed: bool) -> BoolRef: ... +def BVMulNoUnderflow(a: BitVecRef, b: BitVecRef) -> BoolRef: ... diff --git a/mypy-stubs/z3/z3core.pyi b/mypy-stubs/z3/z3core.pyi new file mode 100644 index 00000000..36f1f887 --- /dev/null +++ b/mypy-stubs/z3/z3core.pyi @@ -0,0 +1,3 @@ +from .z3types import Ast, ContextObj +def Z3_mk_eq(ctx: ContextObj, a: Ast, b: Ast) -> Ast: ... +def Z3_mk_div(ctx: ContextObj, a: Ast, b: Ast) -> Ast: ... diff --git a/mypy-stubs/z3/z3types.pyi b/mypy-stubs/z3/z3types.pyi new file mode 100644 index 00000000..fa8fc446 --- /dev/null +++ b/mypy-stubs/z3/z3types.pyi @@ -0,0 +1,12 @@ +from typing import Any + +class Z3Exception(Exception): + def __init__(self, a: Any) -> None: + self.value = a + ... + +class ContextObj: + ... + +class Ast: + ... diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index 3b41f7d2..80cc7308 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -19,14 +19,18 @@ from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not from mythril.laser.smt.array import K, Array, BaseArray from mythril.laser.smt.solver import Solver, Optimize +from typing import Union, Any, Optional, List import z3 +Annotations = Optional[List[Any]] + + class SymbolFactory: """A symbol factory provides a default interface for all the components of mythril to create symbols""" @staticmethod - def Bool(value: bool, annotations=None): + def Bool(value: __builtins__.bool, annotations: Annotations=None) -> Union[bool.Bool, z3.BoolRef]: """ Creates a Bool with concrete value :param value: The boolean value @@ -36,7 +40,7 @@ class SymbolFactory: raise NotImplementedError @staticmethod - def BitVecVal(value: int, size: int, annotations=None): + def BitVecVal(value: int, size: int, annotations: Annotations=None) -> Union[BitVec, z3.BitVecRef]: """Creates a new bit vector with a concrete value. :param value: The concrete value to set the bit vector to @@ -47,7 +51,7 @@ class SymbolFactory: raise NotImplementedError() @staticmethod - def BitVecSym(name: str, size: int, annotations=None): + def BitVecSym(name: str, size: int, annotations: Annotations=None) -> Union[BitVec, z3.BitVecRef]: """Creates a new bit vector with a symbolic value. :param name: The name of the symbolic bit vector @@ -65,24 +69,24 @@ class _SmtSymbolFactory(SymbolFactory): """ @staticmethod - def Bool(value: bool, annotations=None): + def Bool(value: __builtins__.bool, annotations: Annotations=None) -> Union[bool.Bool, z3.BoolRef]: """ Creates a Bool with concrete value :param value: The boolean value :param annotations: The annotations to initialize the bool with :return: The freshly created Bool() """ - raw = z3.Bool(value) + raw = z3.BoolVal(value) return Bool(raw, annotations) @staticmethod - def BitVecVal(value: int, size: int, annotations=None): + def BitVecVal(value: int, size: int, annotations: Annotations=None) -> Union[BitVec, z3.BitVecRef]: """Creates a new bit vector with a concrete value.""" raw = z3.BitVecVal(value, size) return BitVec(raw, annotations) @staticmethod - def BitVecSym(name: str, size: int, annotations=None): + def BitVecSym(name: str, size: int, annotations: Annotations=None) -> Union[BitVec, z3.BitVecRef]: """Creates a new bit vector with a symbolic value.""" raw = z3.BitVec(name, size) return BitVec(raw, annotations) @@ -95,17 +99,17 @@ class _Z3SymbolFactory(SymbolFactory): """ @staticmethod - def Bool(value: bool, annotations=None): + def Bool(value: __builtins__.bool, annotations: Annotations=None) -> Union[bool.Bool, z3.BoolRef]: """ Creates a new bit vector with a concrete value """ - return z3.Bool(value) + return z3.BoolVal(value) @staticmethod - def BitVecVal(value: int, size: int, annotations=None): + def BitVecVal(value: int, size: int, annotations: Annotations=None) -> Union[BitVec, z3.BitVecRef]: """Creates a new bit vector with a concrete value.""" return z3.BitVecVal(value, size) @staticmethod - def BitVecSym(name: str, size: int, annotations=None): + def BitVecSym(name: str, size: int, annotations: Annotations=None) -> Union[BitVec, z3.BitVecRef]: """Creates a new bit vector with a symbolic value.""" return z3.BitVec(name, size) diff --git a/mythril/laser/smt/array.py b/mythril/laser/smt/array.py index a568ea4c..3df06301 100644 --- a/mythril/laser/smt/array.py +++ b/mythril/laser/smt/array.py @@ -5,6 +5,7 @@ operations, as well as as a K-array, which can be initialized with default values over a certain range. """ +from typing import cast import z3 from mythril.laser.smt.bitvec import BitVec @@ -12,16 +13,19 @@ from mythril.laser.smt.bitvec import BitVec class BaseArray: """Base array type, which implements basic store and set operations.""" + domain: z3.SortRef + range: z3.SortRef + raw: z3.ArrayRef - def __getitem__(self, item: BitVec): + def __getitem__(self, item: BitVec) -> BitVec: """Gets item from the array, item can be symbolic.""" if isinstance(item, slice): raise ValueError( "Instance of BaseArray, does not support getitem with slices" ) - return BitVec(z3.Select(self.raw, item.raw)) + return BitVec(cast(z3.BitVecRef, z3.Select(self.raw, item.raw))) - def __setitem__(self, key: BitVec, value: BitVec): + def __setitem__(self, key: BitVec, value: BitVec) -> None: """Sets an item in the array, key can be symbolic.""" self.raw = z3.Store(self.raw, key.raw, value.raw) diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index bfbe1c9f..98e5d4d1 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -1,20 +1,20 @@ """This module provides classes for an SMT abstraction of bit vectors.""" -from typing import Union +from typing import Union, overload, List, cast, Any, Optional import z3 from mythril.laser.smt.bool import Bool from mythril.laser.smt.expression import Expression +Annotations = List[Any] # fmt: off - -class BitVec(Expression): +class BitVec(Expression[z3.BitVecRef]): """A bit vector symbol.""" - def __init__(self, raw, annotations=None): + def __init__(self, raw: z3.BitVecRef, annotations: Optional[Annotations]=None): """ :param raw: @@ -22,7 +22,7 @@ class BitVec(Expression): """ super().__init__(raw, annotations) - def size(self): + def size(self) -> int: """TODO: documentation :return: @@ -30,7 +30,7 @@ class BitVec(Expression): return self.raw.size() @property - def symbolic(self): + def symbolic(self) -> bool: """Returns whether this symbol doesn't have a concrete value. :return: @@ -39,7 +39,7 @@ class BitVec(Expression): return not isinstance(self.raw, z3.BitVecNumRef) @property - def value(self): + def value(self) -> Optional[int]: """Returns the value of this symbol if concrete, otherwise None. :return: @@ -49,7 +49,7 @@ class BitVec(Expression): assert isinstance(self.raw, z3.BitVecNumRef) return self.raw.as_long() - def __add__(self, other) -> "BitVec": + def __add__(self, other: Union[int, BitVec]) -> "BitVec": """Create an addition expression. :param other: @@ -61,7 +61,7 @@ class BitVec(Expression): union = self.annotations + other.annotations return BitVec(self.raw + other.raw, annotations=union) - def __sub__(self, other) -> "BitVec": + def __sub__(self, other: Union[int, BitVec]) -> "BitVec": """Create a subtraction expression. :param other: @@ -74,7 +74,7 @@ class BitVec(Expression): union = self.annotations + other.annotations return BitVec(self.raw - other.raw, annotations=union) - def __mul__(self, other) -> "BitVec": + def __mul__(self, other: BitVec) -> "BitVec": """Create a multiplication expression. :param other: @@ -83,7 +83,7 @@ class BitVec(Expression): union = self.annotations + other.annotations return BitVec(self.raw * other.raw, annotations=union) - def __truediv__(self, other) -> "BitVec": + def __truediv__(self, other: BitVec) -> "BitVec": """Create a signed division expression. :param other: @@ -92,7 +92,7 @@ class BitVec(Expression): union = self.annotations + other.annotations return BitVec(self.raw / other.raw, annotations=union) - def __and__(self, other) -> "BitVec": + def __and__(self, other: Union[int, BitVec]) -> "BitVec": """Create an and expression. :param other: @@ -103,7 +103,7 @@ class BitVec(Expression): union = self.annotations + other.annotations return BitVec(self.raw & other.raw, annotations=union) - def __or__(self, other) -> "BitVec": + def __or__(self, other: BitVec) -> "BitVec": """Create an or expression. :param other: @@ -112,7 +112,7 @@ class BitVec(Expression): union = self.annotations + other.annotations return BitVec(self.raw | other.raw, annotations=union) - def __xor__(self, other) -> "BitVec": + def __xor__(self, other: BitVec) -> "BitVec": """Create a xor expression. :param other: @@ -121,7 +121,7 @@ class BitVec(Expression): union = self.annotations + other.annotations return BitVec(self.raw ^ other.raw, annotations=union) - def __lt__(self, other) -> Bool: + def __lt__(self, other: BitVec) -> Bool: """Create a signed less than expression. :param other: @@ -130,7 +130,7 @@ class BitVec(Expression): union = self.annotations + other.annotations return Bool(self.raw < other.raw, annotations=union) - def __gt__(self, other) -> Bool: + def __gt__(self, other: BitVec) -> Bool: """Create a signed greater than expression. :param other: @@ -139,32 +139,36 @@ class BitVec(Expression): union = self.annotations + other.annotations return Bool(self.raw > other.raw, annotations=union) - def __eq__(self, other) -> Bool: + #MYPY: fix complains about overriding __eq__ + def __eq__(self, other: Union[int, BitVec]) -> Bool: # type: ignore """Create an equality expression. :param other: :return: """ if not isinstance(other, BitVec): - return Bool(self.raw == other, annotations=self.annotations) + return Bool(cast(z3.BoolRef, self.raw == other), annotations=self.annotations) union = self.annotations + other.annotations - return Bool(self.raw == other.raw, annotations=union) + # MYPY: fix complaints due to z3 overriding __eq__ + return Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union) - def __ne__(self, other) -> Bool: + #MYPY: fix complains about overriding __ne__ + def __ne__(self, other: Union[int, BitVec]) -> Bool: # type: ignore """Create an inequality expression. :param other: :return: """ if not isinstance(other, BitVec): - return Bool(self.raw != other, annotations=self.annotations) + return Bool(cast(z3.BoolRef, self.raw != other), annotations=self.annotations) union = self.annotations + other.annotations - return Bool(self.raw != other.raw, annotations=union) + # MYPY: fix complaints due to z3 overriding __eq__ + return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union) -def If(a: Bool, b: BitVec, c: BitVec): +def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec: """Create an if-then-else expression. :param a: @@ -172,11 +176,11 @@ def If(a: Bool, b: BitVec, c: BitVec): :param c: :return: """ - if not isinstance(a, Expression): + if not isinstance(a, Bool): a = Bool(z3.BoolVal(a)) - if not isinstance(b, Expression): + if not isinstance(b, BitVec): b = BitVec(z3.BitVecVal(b, 256)) - if not isinstance(c, Expression): + if not isinstance(c, BitVec): c = BitVec(z3.BitVecVal(c, 256)) union = a.annotations + b.annotations + c.annotations return BitVec(z3.If(a.raw, b.raw, c.raw), union) @@ -215,7 +219,15 @@ def ULT(a: BitVec, b: BitVec) -> Bool: return Bool(z3.ULT(a.raw, b.raw), annotations) -def Concat(*args) -> BitVec: +@overload +def Concat(*args: List[BitVec]) -> BitVec: ... + + +@overload +def Concat(*args: BitVec) -> BitVec: ... + + +def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: """Create a concatenation expression. :param args: @@ -224,11 +236,13 @@ def Concat(*args) -> BitVec: # The following statement is used if a list is provided as an argument to concat if len(args) == 1 and isinstance(args[0], list): - args = args[0] + bvs = args[0] # type: List[BitVec] + else: + bvs = cast(List[BitVec], args) - nraw = z3.Concat([a.raw for a in args]) - annotations = [] - for bv in args: + nraw = z3.Concat([a.raw for a in bvs]) + annotations = [] # type: Annotations + for bv in bvs: annotations += bv.annotations return BitVec(nraw, annotations) @@ -277,13 +291,13 @@ def UDiv(a: BitVec, b: BitVec) -> BitVec: return BitVec(z3.UDiv(a.raw, b.raw), annotations=union) -def Sum(*args) -> BitVec: +def Sum(*args: BitVec) -> BitVec: """Create sum expression. :return: """ nraw = z3.Sum([a.raw for a in args]) - annotations = [] + annotations = [] # type: Annotations for bv in args: annotations += bv.annotations return BitVec(nraw, annotations) @@ -297,9 +311,9 @@ def BVAddNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) :param signed: :return: """ - if not isinstance(a, Expression): + if not isinstance(a, BitVec): a = BitVec(z3.BitVecVal(a, 256)) - if not isinstance(b, Expression): + if not isinstance(b, BitVec): b = BitVec(z3.BitVecVal(b, 256)) return Bool(z3.BVAddNoOverflow(a.raw, b.raw, signed)) @@ -313,9 +327,9 @@ def BVMulNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) :param signed: :return: """ - if not isinstance(a, Expression): + if not isinstance(a, BitVec): a = BitVec(z3.BitVecVal(a, 256)) - if not isinstance(b, Expression): + if not isinstance(b, BitVec): b = BitVec(z3.BitVecVal(b, 256)) return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed)) @@ -328,9 +342,9 @@ def BVSubNoUnderflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) :param signed: :return: """ - if not isinstance(a, Expression): + if not isinstance(a, BitVec): a = BitVec(z3.BitVecVal(a, 256)) - if not isinstance(b, Expression): + if not isinstance(b, BitVec): b = BitVec(z3.BitVecVal(b, 256)) return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed)) diff --git a/mythril/laser/smt/bool.py b/mythril/laser/smt/bool.py index ac17caa8..a6442e01 100644 --- a/mythril/laser/smt/bool.py +++ b/mythril/laser/smt/bool.py @@ -1,7 +1,7 @@ """This module provides classes for an SMT abstraction of boolean expressions.""" -from typing import Union +from typing import Union, cast import z3 @@ -11,7 +11,7 @@ from mythril.laser.smt.expression import Expression # fmt: off -class Bool(Expression): +class Bool(Expression[z3.BoolRef]): """This is a Bool expression.""" @property @@ -46,27 +46,31 @@ class Bool(Expression): else: return None - def __eq__(self, other): + #MYPY: complains about overloading __eq__ # noqa + def __eq__(self, other: object) -> Bool: # type: ignore """ :param other: :return: """ if isinstance(other, Expression): - return Bool(self.raw == other.raw, self.annotations + other.annotations) - return Bool(self.raw == other, self.annotations) + return Bool(cast(z3.BoolRef, self.raw == other.raw), + self.annotations + other.annotations) + return Bool(cast(z3.BoolRef, self.raw == other), self.annotations) - def __ne__(self, other): + #MYPY: complains about overloading __ne__ # noqa + def __ne__(self, other: object) -> Bool: # type: ignore """ :param other: :return: """ if isinstance(other, Expression): - return Bool(self.raw != other.raw, self.annotations + other.annotations) - return Bool(self.raw != other, self.annotations) + return Bool(cast(z3.BoolRef, self.raw != other.raw), + self.annotations + other.annotations) + return Bool(cast(z3.BoolRef, self.raw != other), self.annotations) - def __bool__(self): + def __bool__(self) -> bool: """ :return: @@ -77,7 +81,7 @@ class Bool(Expression): return False -def Or(a: Bool, b: Bool): +def Or(a: Bool, b: Bool) -> Bool: """Create an or expression. :param a: @@ -88,7 +92,7 @@ def Or(a: Bool, b: Bool): return Bool(z3.Or(a.raw, b.raw), annotations=union) -def Not(a: Bool): +def Not(a: Bool) -> Bool: """Create a Not expression. :param a: diff --git a/mythril/laser/smt/expression.py b/mythril/laser/smt/expression.py index fa8ca579..0a054af0 100644 --- a/mythril/laser/smt/expression.py +++ b/mythril/laser/smt/expression.py @@ -1,13 +1,17 @@ """This module contains the SMT abstraction for a basic symbol expression.""" - +from typing import Optional, List, Any, TypeVar, Generic, cast import z3 -class Expression: +Annotations=List[Any] +T = TypeVar('T', bound=z3.ExprRef) + + +class Expression(Generic[T]): """This is the base symbol class and maintains functionality for simplification and annotations.""" - def __init__(self, raw, annotations=None): + def __init__(self, raw: T, annotations: Optional[Annotations]=None): """ :param raw: @@ -17,14 +21,14 @@ class Expression: self._annotations = annotations or [] @property - def annotations(self): + def annotations(self) -> Annotations: """Gets the annotations for this expression. :return: """ return self._annotations - def annotate(self, annotation): + def annotate(self, annotation: Any) -> None: """Annotates this expression with the given annotation. :param annotation: @@ -34,15 +38,15 @@ class Expression: else: self._annotations.append(annotation) - def simplify(self): + def simplify(self) -> None: """Simplify this expression.""" - self.raw = z3.simplify(self.raw) + self.raw = cast(T, z3.simplify(self.raw)) - def __repr__(self): + def __repr__(self) -> str: return repr(self.raw) -def simplify(expression: Expression): +def simplify(expression: Expression) -> Expression: """Simplify the expression . :param expression: diff --git a/mythril/laser/smt/solver.py b/mythril/laser/smt/solver.py index 4b973470..1523456d 100644 --- a/mythril/laser/smt/solver.py +++ b/mythril/laser/smt/solver.py @@ -1,5 +1,6 @@ """This module contains an abstract SMT representation of an SMT solver.""" import z3 +from typing import Union, cast from mythril.laser.smt.expression import Expression @@ -7,9 +8,9 @@ from mythril.laser.smt.expression import Expression class Solver: """An SMT solver object.""" - def __init__(self): + def __init__(self) -> None: """""" - self.raw = z3.Solver() + self.raw = z3.Solver() # type: Union[z3.Solver, z3.Optimize] def set_timeout(self, timeout: int) -> None: """Sets the timeout that will be used by this solver, timeout is in @@ -43,14 +44,14 @@ class Solver: constraints = [c.raw for c in constraints] self.raw.add(constraints) - def check(self): + def check(self) -> z3.CheckSatResult: """Returns z3 smt check result. :return: """ return self.raw.check() - def model(self): + def model(self) -> z3.ModelRef: """Returns z3 model for a solution. :return: @@ -59,34 +60,34 @@ class Solver: def reset(self) -> None: """Reset this solver.""" - self.raw.reset() + cast(z3.Solver, self.raw).reset() - def pop(self, num) -> None: + def pop(self, num: int) -> None: """Pop num constraints from this solver. :param num: """ - self.raw.pop(num) + cast(z3.Solver, self.raw).pop(num) class Optimize(Solver): """An optimizing smt solver.""" - def __init__(self): + def __init__(self) -> None: """Create a new optimizing solver instance.""" super().__init__() self.raw = z3.Optimize() - def minimize(self, element: Expression): + def minimize(self, element: Expression) -> None: """In solving this solver will try to minimize the passed expression. :param element: """ - self.raw.minimize(element.raw) + cast(z3.Optimize, self.raw).minimize(element.raw) - def maximize(self, element: Expression): + def maximize(self, element: Expression) -> None: """In solving this solver will try to maximize the passed expression. :param element: """ - self.raw.maximize(element.raw) + cast(z3.Optimize, self.raw).maximize(element.raw)