Merge pull request #1153 from ConsenSys/fix/padding

Fix the padding while comparing
pull/1156/head
Bernhard Mueller 5 years ago committed by GitHub
commit 8f04aacf86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      mythril/laser/smt/__init__.py
  2. 21
      mythril/laser/smt/bitvec.py

@ -21,13 +21,13 @@ from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And
from mythril.laser.smt.array import K, Array, BaseArray
from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics
from mythril.laser.smt.model import Model
from mythril.laser.smt.bool import Bool as SMTBool
from typing import Union, Any, Optional, Set, TypeVar, Generic
import z3
Annotations = Optional[Set[Any]]
T = TypeVar("T", bound=Union[bool.Bool, z3.BoolRef])
T = TypeVar("T", bound=Union[SMTBool, z3.BoolRef])
U = TypeVar("U", bound=Union[BitVec, z3.BitVecRef])
@ -105,14 +105,14 @@ class SymbolFactory(Generic[T, U]):
raise NotImplementedError()
class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]):
class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]):
"""
An implementation of a SymbolFactory that creates symbols using
the classes in: mythril.laser.smt
"""
@staticmethod
def Bool(value: "__builtins__.bool", annotations: Annotations = None) -> bool.Bool:
def Bool(value: "__builtins__.bool", annotations: Annotations = None) -> SMTBool:
"""
Creates a Bool with concrete value
:param value: The boolean value
@ -120,7 +120,7 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]):
:return: The freshly created Bool()
"""
raw = z3.BoolVal(value)
return Bool(raw, annotations)
return SMTBool(raw, annotations)
@staticmethod
def BitVecVal(value: int, size: int, annotations: Annotations = None) -> BitVec:

@ -1,7 +1,7 @@
"""This module provides classes for an SMT abstraction of bit vectors."""
from typing import Union, overload, List, Set, cast, Any, Optional, Callable
from operator import lshift, rshift
from operator import lshift, rshift, ne, eq
import z3
from mythril.laser.smt.bool import Bool, And, Or
@ -12,6 +12,15 @@ Annotations = Set[Any]
# fmt: off
def _padded_operation(a: z3.BitVec, b: z3.BitVec, operator):
if a.size() == b.size():
return operator(a, b)
if a.size() < b.size():
a, b = b, a
b = z3.Concat(z3.BitVecVal(0, a.size() - b.size()), b)
return operator(a, b)
class BitVec(Expression[z3.BitVecRef]):
"""A bit vector symbol."""
@ -203,10 +212,9 @@ class BitVec(Expression[z3.BitVecRef]):
union = self.annotations.union(other.annotations)
# Some of the BitVecs can be 512 bit due to sha3()
if self.raw.size() != other.raw.size():
return Bool(z3.BoolVal(False), annotations=union)
eq_check = _padded_operation(self.raw, other.raw, eq)
# MYPY: fix complaints due to z3 overriding __eq__
return Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union)
return Bool(cast(z3.BoolRef, eq_check), annotations=union)
# MYPY: fix complains about overriding __ne__
def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
@ -224,10 +232,9 @@ class BitVec(Expression[z3.BitVecRef]):
union = self.annotations.union(other.annotations)
# Some of the BitVecs can be 512 bit due to sha3()
if self.raw.size() != other.raw.size():
return Bool(z3.BoolVal(True), annotations=union)
neq_check = _padded_operation(self.raw, other.raw, ne)
# MYPY: fix complaints due to z3 overriding __eq__
return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union)
return Bool(cast(z3.BoolRef, neq_check), annotations=union)
def _handle_shift(self, other: Union[int, "BitVec"], operator: Callable) -> "BitVec":
"""

Loading…
Cancel
Save