Merge pull request #1121 from ConsenSys/annotations_to_union

Refactor Bitvec annotations to sets
fix_integers
Bernhard Mueller 5 years ago committed by GitHub
commit beea142fb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      mythril/laser/ethereum/instructions.py
  2. 4
      mythril/laser/smt/__init__.py
  3. 46
      mythril/laser/smt/bitvec.py
  4. 4
      mythril/laser/smt/bitvecfunc.py
  5. 20
      mythril/laser/smt/bool.py
  6. 17
      mythril/laser/smt/expression.py

@ -4,7 +4,7 @@ import binascii
import logging
from copy import copy, deepcopy
from typing import cast, Callable, List, Union, Tuple
from typing import cast, Callable, List, Set, Union, Tuple, Any
from datetime import datetime
from math import ceil
from ethereum import utils
@ -553,7 +553,7 @@ class Instruction:
+ str(hash(simplify(exponent)))
+ ")",
256,
base.annotations + exponent.annotations,
base.annotations.union(exponent.annotations),
)
) # Hash is used because str(symbol) takes a long time to be converted to a string
else:
@ -562,7 +562,7 @@ class Instruction:
symbol_factory.BitVecVal(
pow(base.value, exponent.value, 2 ** 256),
256,
annotations=base.annotations + exponent.annotations,
annotations=base.annotations.union(exponent.annotations),
)
)
@ -925,11 +925,11 @@ class Instruction:
if data.symbolic:
annotations = []
annotations = set() # type: Set[Any]
for b in state.memory[index : index + length]:
if isinstance(b, BitVec):
annotations += b.annotations
annotations = annotations.union(b.annotations)
argument_str = str(state.memory[index]).replace(" ", "_")
result = symbol_factory.BitVecFuncSym(

@ -22,11 +22,11 @@ from mythril.laser.smt.array import K, Array, BaseArray
from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics
from mythril.laser.smt.model import Model
from typing import Union, Any, Optional, List, TypeVar, Generic
from typing import Union, Any, Optional, Set, TypeVar, Generic
import z3
Annotations = Optional[List[Any]]
Annotations = Optional[Set[Any]]
T = TypeVar("T", bound=Union[bool.Bool, z3.BoolRef])
U = TypeVar("U", bound=Union[BitVec, z3.BitVecRef])

@ -1,13 +1,13 @@
"""This module provides classes for an SMT abstraction of bit vectors."""
from typing import Union, overload, List, cast, Any, Optional, Callable
from typing import Union, overload, List, Set, cast, Any, Optional, Callable
from operator import lshift, rshift
import z3
from mythril.laser.smt.bool import Bool, And, Or
from mythril.laser.smt.expression import Expression
Annotations = List[Any]
Annotations = Set[Any]
# fmt: off
@ -61,7 +61,7 @@ class BitVec(Expression[z3.BitVecRef]):
if isinstance(other, int):
return BitVec(self.raw + other, annotations=self.annotations)
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
return BitVec(self.raw + other.raw, annotations=union)
def __sub__(self, other: Union[int, "BitVec"]) -> "BitVec":
@ -75,7 +75,7 @@ class BitVec(Expression[z3.BitVecRef]):
if isinstance(other, int):
return BitVec(self.raw - other, annotations=self.annotations)
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
return BitVec(self.raw - other.raw, annotations=union)
def __mul__(self, other: "BitVec") -> "BitVec":
@ -86,7 +86,7 @@ class BitVec(Expression[z3.BitVecRef]):
"""
if isinstance(other, BitVecFunc):
return other * self
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
return BitVec(self.raw * other.raw, annotations=union)
def __truediv__(self, other: "BitVec") -> "BitVec":
@ -97,7 +97,7 @@ class BitVec(Expression[z3.BitVecRef]):
"""
if isinstance(other, BitVecFunc):
return other / self
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
return BitVec(self.raw / other.raw, annotations=union)
def __and__(self, other: Union[int, "BitVec"]) -> "BitVec":
@ -110,7 +110,7 @@ class BitVec(Expression[z3.BitVecRef]):
return other & self
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
return BitVec(self.raw & other.raw, annotations=union)
def __or__(self, other: "BitVec") -> "BitVec":
@ -121,7 +121,7 @@ class BitVec(Expression[z3.BitVecRef]):
"""
if isinstance(other, BitVecFunc):
return other | self
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
return BitVec(self.raw | other.raw, annotations=union)
def __xor__(self, other: "BitVec") -> "BitVec":
@ -132,7 +132,7 @@ class BitVec(Expression[z3.BitVecRef]):
"""
if isinstance(other, BitVecFunc):
return other ^ self
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
return BitVec(self.raw ^ other.raw, annotations=union)
def __lt__(self, other: "BitVec") -> Bool:
@ -143,7 +143,7 @@ class BitVec(Expression[z3.BitVecRef]):
"""
if isinstance(other, BitVecFunc):
return other > self
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
return Bool(self.raw < other.raw, annotations=union)
def __gt__(self, other: "BitVec") -> Bool:
@ -154,7 +154,7 @@ class BitVec(Expression[z3.BitVecRef]):
"""
if isinstance(other, BitVecFunc):
return other < self
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
return Bool(self.raw > other.raw, annotations=union)
def __le__(self, other: "BitVec") -> Bool:
@ -163,7 +163,7 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
return Bool(self.raw <= other.raw, annotations=union)
def __ge__(self, other: "BitVec") -> Bool:
@ -172,7 +172,7 @@ class BitVec(Expression[z3.BitVecRef]):
:param other:
:return:
"""
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
return Bool(self.raw >= other.raw, annotations=union)
# MYPY: fix complains about overriding __eq__
@ -189,7 +189,7 @@ class BitVec(Expression[z3.BitVecRef]):
cast(z3.BoolRef, self.raw == other), annotations=self.annotations
)
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
# MYPY: fix complaints due to z3 overriding __eq__
return Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union)
@ -207,7 +207,7 @@ class BitVec(Expression[z3.BitVecRef]):
cast(z3.BoolRef, self.raw != other), annotations=self.annotations
)
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
# MYPY: fix complaints due to z3 overriding __eq__
return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union)
@ -224,7 +224,7 @@ class BitVec(Expression[z3.BitVecRef]):
return BitVec(
operator(self.raw, other), annotations=self.annotations
)
union = self.annotations + other.annotations
union = self.annotations.union(other.annotations)
return BitVec(operator(self.raw, other.raw), annotations=union)
def __lshift__(self, other: Union[int, "BitVec"]) -> "BitVec":
@ -254,7 +254,7 @@ class BitVec(Expression[z3.BitVecRef]):
def _comparison_helper(
a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool
) -> Bool:
annotations = a.annotations + b.annotations
annotations = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc):
if not a.symbolic and not b.symbolic:
return Bool(operation(a.raw, b.raw), annotations=annotations)
@ -277,7 +277,7 @@ def _comparison_helper(
def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec:
raw = operation(a.raw, b.raw)
union = a.annotations + b.annotations
union = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc):
return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union)
@ -313,7 +313,7 @@ def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> Bi
b = BitVec(z3.BitVecVal(b, 256))
if not isinstance(c, BitVec):
c = BitVec(z3.BitVecVal(c, 256))
union = a.annotations + b.annotations + c.annotations
union = a.annotations.union(b.annotations).union(c.annotations)
return BitVec(z3.If(a.raw, b.raw, c.raw), union)
@ -378,10 +378,10 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
bvs = cast(List[BitVec], args)
nraw = z3.Concat([a.raw for a in bvs])
annotations = [] # type: Annotations
annotations = set() # type: Annotations
bitvecfunc = False
for bv in bvs:
annotations += bv.annotations
annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
bitvecfunc = True
@ -448,11 +448,11 @@ def Sum(*args: BitVec) -> BitVec:
:return:
"""
raw = z3.Sum([a.raw for a in args])
annotations = [] # type: Annotations
annotations = set() # type: Annotations
bitvecfuncs = []
for bv in args:
annotations += bv.annotations
annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
bitvecfuncs.append(bv)

@ -23,7 +23,7 @@ def _arithmetic_helper(
b = BitVec(z3.BitVecVal(b, a.size()))
raw = operation(a.raw, b.raw)
union = a.annotations + b.annotations
union = a.annotations.union(b.annotations)
if isinstance(b, BitVecFunc):
# TODO: Find better value to set input and name to in this case?
@ -53,7 +53,7 @@ def _comparison_helper(
if isinstance(b, int):
b = BitVec(z3.BitVecVal(b, a.size()))
union = a.annotations + b.annotations
union = a.annotations.union(b.annotations)
if not a.symbolic and not b.symbolic:
return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union)

@ -1,7 +1,7 @@
"""This module provides classes for an SMT abstraction of boolean
expressions."""
from typing import Union, cast, List
from typing import Union, cast, List, Set
import z3
@ -55,7 +55,7 @@ class Bool(Expression[z3.BoolRef]):
"""
if isinstance(other, Expression):
return Bool(cast(z3.BoolRef, self.raw == other.raw),
self.annotations + other.annotations)
self.annotations.union(other.annotations))
return Bool(cast(z3.BoolRef, self.raw == other), self.annotations)
# MYPY: complains about overloading __ne__ # noqa
@ -67,7 +67,7 @@ class Bool(Expression[z3.BoolRef]):
"""
if isinstance(other, Expression):
return Bool(cast(z3.BoolRef, self.raw != other.raw),
self.annotations + other.annotations)
self.annotations.union(other.annotations))
return Bool(cast(z3.BoolRef, self.raw != other), self.annotations)
def __bool__(self) -> bool:
@ -86,17 +86,17 @@ class Bool(Expression[z3.BoolRef]):
def And(*args: Union[Bool, bool]) -> Bool:
"""Create an And expression."""
union = [] # type: List
annotations = set() # type: Set
args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args]
for arg in args_list:
union += arg.annotations
return Bool(z3.And([a.raw for a in args_list]), union)
annotations = annotations.union(arg.annotations)
return Bool(z3.And([a.raw for a in args_list]), annotations)
def Xor(a: Bool, b: Bool) -> Bool:
"""Create an And expression."""
union = a.annotations + b.annotations
union = a.annotations.union(b.annotations)
return Bool(z3.Xor(a.raw, b.raw), union)
@ -108,10 +108,10 @@ def Or(*args: Union[Bool, bool]) -> Bool:
:return:
"""
args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args]
union = [] # type: List
annotations = set() # type: Set
for arg in args_list:
union += arg.annotations
return Bool(z3.Or([a.raw for a in args_list]), annotations=union)
annotations = annotations.union(arg.annotations)
return Bool(z3.Or([a.raw for a in args_list]), annotations=annotations)
def Not(a: Bool) -> Bool:

@ -1,9 +1,9 @@
"""This module contains the SMT abstraction for a basic symbol expression."""
from typing import Optional, List, Any, TypeVar, Generic, cast
from typing import Optional, Set, Any, TypeVar, Generic, cast
import z3
Annotations = List[Any]
Annotations = Set[Any]
T = TypeVar("T", bound=z3.ExprRef)
@ -18,7 +18,11 @@ class Expression(Generic[T]):
:param annotations:
"""
self.raw = raw
self._annotations = annotations or []
if annotations:
assert isinstance(annotations, set)
self._annotations = annotations or set()
@property
def annotations(self) -> Annotations:
@ -26,6 +30,7 @@ class Expression(Generic[T]):
:return:
"""
return self._annotations
def annotate(self, annotation: Any) -> None:
@ -33,10 +38,8 @@ class Expression(Generic[T]):
:param annotation:
"""
if isinstance(annotation, list):
self._annotations += annotation
else:
self._annotations.append(annotation)
self._annotations.add(annotation)
def simplify(self) -> None:
"""Simplify this expression."""

Loading…
Cancel
Save