Pad while comparing unequal functions

pull/1153/head
Nikhil 5 years ago
parent 120d0b2e2d
commit 7d9194a06c
  1. 20
      mythril/laser/smt/bitvec.py

@ -10,6 +10,16 @@ from mythril.laser.smt.expression import Expression
Annotations = Set[Any] Annotations = Set[Any]
# fmt: off # fmt: off
import operator
def _padded_operation(a: z3.BitVec, b: z3.BitVec, operator):
if a.size() == b.size():
return operator(a, b)
if a.size() < b.size():
a, b = b, a
b = z3.Concat(z3.BitVecVal(0, a.size() - b.size()), b)
return operator(a, b)
class BitVec(Expression[z3.BitVecRef]): class BitVec(Expression[z3.BitVecRef]):
@ -203,10 +213,9 @@ class BitVec(Expression[z3.BitVecRef]):
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
# Some of the BitVecs can be 512 bit due to sha3() # Some of the BitVecs can be 512 bit due to sha3()
if self.raw.size() != other.raw.size(): eq_check = _padded_operation(self.raw, other.raw, operator.eq)
return Bool(z3.BoolVal(False), annotations=union)
# MYPY: fix complaints due to z3 overriding __eq__ # MYPY: fix complaints due to z3 overriding __eq__
return Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union) return Bool(cast(z3.BoolRef, eq_check), annotations=union)
# MYPY: fix complains about overriding __ne__ # MYPY: fix complains about overriding __ne__
def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
@ -224,10 +233,9 @@ class BitVec(Expression[z3.BitVecRef]):
union = self.annotations.union(other.annotations) union = self.annotations.union(other.annotations)
# Some of the BitVecs can be 512 bit due to sha3() # Some of the BitVecs can be 512 bit due to sha3()
if self.raw.size() != other.raw.size(): neq_check = _padded_operation(self.raw, other.raw, operator.ne)
return Bool(z3.BoolVal(True), annotations=union)
# MYPY: fix complaints due to z3 overriding __eq__ # MYPY: fix complaints due to z3 overriding __eq__
return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union) return Bool(cast(z3.BoolRef, neq_check), annotations=union)
def _handle_shift(self, other: Union[int, "BitVec"], operator: Callable) -> "BitVec": def _handle_shift(self, other: Union[int, "BitVec"], operator: Callable) -> "BitVec":
""" """

Loading…
Cancel
Save