diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 1fd74673..82a3d160 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -4,7 +4,7 @@ import binascii import logging from copy import copy, deepcopy -from typing import cast, Callable, List, Union, Tuple +from typing import cast, Callable, List, Set, Union, Tuple, Any from datetime import datetime from math import ceil from ethereum import utils @@ -553,7 +553,7 @@ class Instruction: + str(hash(simplify(exponent))) + ")", 256, - base.annotations + exponent.annotations, + base.annotations.union(exponent.annotations), ) ) # Hash is used because str(symbol) takes a long time to be converted to a string else: @@ -562,7 +562,7 @@ class Instruction: symbol_factory.BitVecVal( pow(base.value, exponent.value, 2 ** 256), 256, - annotations=base.annotations + exponent.annotations, + annotations=base.annotations.union(exponent.annotations), ) ) @@ -925,11 +925,11 @@ class Instruction: if data.symbolic: - annotations = [] + annotations = set() # type: Set[Any] for b in state.memory[index : index + length]: if isinstance(b, BitVec): - annotations += b.annotations + annotations = annotations.union(b.annotations) argument_str = str(state.memory[index]).replace(" ", "_") result = symbol_factory.BitVecFuncSym( diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index 5a3f1ca1..407f6565 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -22,11 +22,11 @@ from mythril.laser.smt.array import K, Array, BaseArray from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics from mythril.laser.smt.model import Model -from typing import Union, Any, Optional, List, TypeVar, Generic +from typing import Union, Any, Optional, Set, TypeVar, Generic import z3 -Annotations = Optional[List[Any]] +Annotations = Optional[Set[Any]] T = TypeVar("T", bound=Union[bool.Bool, z3.BoolRef]) U = TypeVar("U", bound=Union[BitVec, z3.BitVecRef]) diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index 1dd15191..11060d04 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -1,13 +1,13 @@ """This module provides classes for an SMT abstraction of bit vectors.""" -from typing import Union, overload, List, cast, Any, Optional, Callable +from typing import Union, overload, List, Set, cast, Any, Optional, Callable from operator import lshift, rshift import z3 from mythril.laser.smt.bool import Bool, And, Or from mythril.laser.smt.expression import Expression -Annotations = List[Any] +Annotations = Set[Any] # fmt: off @@ -61,7 +61,7 @@ class BitVec(Expression[z3.BitVecRef]): if isinstance(other, int): return BitVec(self.raw + other, annotations=self.annotations) - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) return BitVec(self.raw + other.raw, annotations=union) def __sub__(self, other: Union[int, "BitVec"]) -> "BitVec": @@ -75,7 +75,7 @@ class BitVec(Expression[z3.BitVecRef]): if isinstance(other, int): return BitVec(self.raw - other, annotations=self.annotations) - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) return BitVec(self.raw - other.raw, annotations=union) def __mul__(self, other: "BitVec") -> "BitVec": @@ -86,7 +86,7 @@ class BitVec(Expression[z3.BitVecRef]): """ if isinstance(other, BitVecFunc): return other * self - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) return BitVec(self.raw * other.raw, annotations=union) def __truediv__(self, other: "BitVec") -> "BitVec": @@ -97,7 +97,7 @@ class BitVec(Expression[z3.BitVecRef]): """ if isinstance(other, BitVecFunc): return other / self - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) return BitVec(self.raw / other.raw, annotations=union) def __and__(self, other: Union[int, "BitVec"]) -> "BitVec": @@ -110,7 +110,7 @@ class BitVec(Expression[z3.BitVecRef]): return other & self if not isinstance(other, BitVec): other = BitVec(z3.BitVecVal(other, self.size())) - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) return BitVec(self.raw & other.raw, annotations=union) def __or__(self, other: "BitVec") -> "BitVec": @@ -121,7 +121,7 @@ class BitVec(Expression[z3.BitVecRef]): """ if isinstance(other, BitVecFunc): return other | self - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) return BitVec(self.raw | other.raw, annotations=union) def __xor__(self, other: "BitVec") -> "BitVec": @@ -132,7 +132,7 @@ class BitVec(Expression[z3.BitVecRef]): """ if isinstance(other, BitVecFunc): return other ^ self - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) return BitVec(self.raw ^ other.raw, annotations=union) def __lt__(self, other: "BitVec") -> Bool: @@ -143,7 +143,7 @@ class BitVec(Expression[z3.BitVecRef]): """ if isinstance(other, BitVecFunc): return other > self - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) return Bool(self.raw < other.raw, annotations=union) def __gt__(self, other: "BitVec") -> Bool: @@ -154,7 +154,7 @@ class BitVec(Expression[z3.BitVecRef]): """ if isinstance(other, BitVecFunc): return other < self - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) return Bool(self.raw > other.raw, annotations=union) def __le__(self, other: "BitVec") -> Bool: @@ -163,7 +163,7 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) return Bool(self.raw <= other.raw, annotations=union) def __ge__(self, other: "BitVec") -> Bool: @@ -172,7 +172,7 @@ class BitVec(Expression[z3.BitVecRef]): :param other: :return: """ - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) return Bool(self.raw >= other.raw, annotations=union) # MYPY: fix complains about overriding __eq__ @@ -189,7 +189,7 @@ class BitVec(Expression[z3.BitVecRef]): cast(z3.BoolRef, self.raw == other), annotations=self.annotations ) - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) # MYPY: fix complaints due to z3 overriding __eq__ return Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union) @@ -207,7 +207,7 @@ class BitVec(Expression[z3.BitVecRef]): cast(z3.BoolRef, self.raw != other), annotations=self.annotations ) - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) # MYPY: fix complaints due to z3 overriding __eq__ return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union) @@ -224,7 +224,7 @@ class BitVec(Expression[z3.BitVecRef]): return BitVec( operator(self.raw, other), annotations=self.annotations ) - union = self.annotations + other.annotations + union = self.annotations.union(other.annotations) return BitVec(operator(self.raw, other.raw), annotations=union) def __lshift__(self, other: Union[int, "BitVec"]) -> "BitVec": @@ -254,7 +254,7 @@ class BitVec(Expression[z3.BitVecRef]): def _comparison_helper( a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool ) -> Bool: - annotations = a.annotations + b.annotations + annotations = a.annotations.union(b.annotations) if isinstance(a, BitVecFunc): if not a.symbolic and not b.symbolic: return Bool(operation(a.raw, b.raw), annotations=annotations) @@ -277,7 +277,7 @@ def _comparison_helper( def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec: raw = operation(a.raw, b.raw) - union = a.annotations + b.annotations + union = a.annotations.union(b.annotations) if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) @@ -313,7 +313,7 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi b = BitVec(z3.BitVecVal(b, 256)) if not isinstance(c, BitVec): c = BitVec(z3.BitVecVal(c, 256)) - union = a.annotations + b.annotations + c.annotations + union = a.annotations.union(b.annotations).union(c.annotations) return BitVec(z3.If(a.raw, b.raw, c.raw), union) @@ -378,10 +378,10 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: bvs = cast(List[BitVec], args) nraw = z3.Concat([a.raw for a in bvs]) - annotations = [] # type: Annotations + annotations = set() # type: Annotations bitvecfunc = False for bv in bvs: - annotations += bv.annotations + annotations = annotations.union(bv.annotations) if isinstance(bv, BitVecFunc): bitvecfunc = True @@ -448,11 +448,11 @@ def Sum(*args: BitVec) -> BitVec: :return: """ raw = z3.Sum([a.raw for a in args]) - annotations = [] # type: Annotations + annotations = set() # type: Annotations bitvecfuncs = [] for bv in args: - annotations += bv.annotations + annotations = annotations.union(bv.annotations) if isinstance(bv, BitVecFunc): bitvecfuncs.append(bv) diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py index 257d2b46..8d680648 100644 --- a/mythril/laser/smt/bitvecfunc.py +++ b/mythril/laser/smt/bitvecfunc.py @@ -23,7 +23,7 @@ def _arithmetic_helper( b = BitVec(z3.BitVecVal(b, a.size())) raw = operation(a.raw, b.raw) - union = a.annotations + b.annotations + union = a.annotations.union(b.annotations) if isinstance(b, BitVecFunc): # TODO: Find better value to set input and name to in this case? @@ -53,7 +53,7 @@ def _comparison_helper( if isinstance(b, int): b = BitVec(z3.BitVecVal(b, a.size())) - union = a.annotations + b.annotations + union = a.annotations.union(b.annotations) if not a.symbolic and not b.symbolic: return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union) diff --git a/mythril/laser/smt/bool.py b/mythril/laser/smt/bool.py index 40f098fc..c09358d0 100644 --- a/mythril/laser/smt/bool.py +++ b/mythril/laser/smt/bool.py @@ -1,7 +1,7 @@ """This module provides classes for an SMT abstraction of boolean expressions.""" -from typing import Union, cast, List +from typing import Union, cast, List, Set import z3 @@ -55,7 +55,7 @@ class Bool(Expression[z3.BoolRef]): """ if isinstance(other, Expression): return Bool(cast(z3.BoolRef, self.raw == other.raw), - self.annotations + other.annotations) + self.annotations.union(other.annotations)) return Bool(cast(z3.BoolRef, self.raw == other), self.annotations) # MYPY: complains about overloading __ne__ # noqa @@ -67,7 +67,7 @@ class Bool(Expression[z3.BoolRef]): """ if isinstance(other, Expression): return Bool(cast(z3.BoolRef, self.raw != other.raw), - self.annotations + other.annotations) + self.annotations.union(other.annotations)) return Bool(cast(z3.BoolRef, self.raw != other), self.annotations) def __bool__(self) -> bool: @@ -86,17 +86,17 @@ class Bool(Expression[z3.BoolRef]): def And(*args: Union[Bool, bool]) -> Bool: """Create an And expression.""" - union = [] # type: List + annotations = set() # type: Set args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args] for arg in args_list: - union += arg.annotations - return Bool(z3.And([a.raw for a in args_list]), union) + annotations = annotations.union(arg.annotations) + return Bool(z3.And([a.raw for a in args_list]), annotations) def Xor(a: Bool, b: Bool) -> Bool: """Create an And expression.""" - union = a.annotations + b.annotations + union = a.annotations.union(b.annotations) return Bool(z3.Xor(a.raw, b.raw), union) @@ -108,10 +108,10 @@ def Or(*args: Union[Bool, bool]) -> Bool: :return: """ args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args] - union = [] # type: List + annotations = set() # type: Set for arg in args_list: - union += arg.annotations - return Bool(z3.Or([a.raw for a in args_list]), annotations=union) + annotations = annotations.union(arg.annotations) + return Bool(z3.Or([a.raw for a in args_list]), annotations=annotations) def Not(a: Bool) -> Bool: diff --git a/mythril/laser/smt/expression.py b/mythril/laser/smt/expression.py index ae98cd31..9fa7cef1 100644 --- a/mythril/laser/smt/expression.py +++ b/mythril/laser/smt/expression.py @@ -1,9 +1,9 @@ """This module contains the SMT abstraction for a basic symbol expression.""" -from typing import Optional, List, Any, TypeVar, Generic, cast +from typing import Optional, Set, Any, TypeVar, Generic, cast import z3 -Annotations = List[Any] +Annotations = Set[Any] T = TypeVar("T", bound=z3.ExprRef) @@ -18,7 +18,11 @@ class Expression(Generic[T]): :param annotations: """ self.raw = raw - self._annotations = annotations or [] + + if annotations: + assert isinstance(annotations, set) + + self._annotations = annotations or set() @property def annotations(self) -> Annotations: @@ -26,6 +30,7 @@ class Expression(Generic[T]): :return: """ + return self._annotations def annotate(self, annotation: Any) -> None: @@ -33,10 +38,8 @@ class Expression(Generic[T]): :param annotation: """ - if isinstance(annotation, list): - self._annotations += annotation - else: - self._annotations.append(annotation) + + self._annotations.add(annotation) def simplify(self) -> None: """Simplify this expression."""