diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index 1b8080f8..ff84cfb5 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -10,6 +10,16 @@ from mythril.laser.smt.expression import Expression Annotations = Set[Any] # fmt: off +import operator + + +def _padded_operation(a: z3.BitVec, b: z3.BitVec, operator): + if a.size() == b.size(): + return operator(a, b) + if a.size() < b.size(): + a, b = b, a + b = z3.Concat(z3.BitVecVal(0, a.size() - b.size()), b) + return operator(a, b) class BitVec(Expression[z3.BitVecRef]): @@ -203,10 +213,9 @@ class BitVec(Expression[z3.BitVecRef]): union = self.annotations.union(other.annotations) # Some of the BitVecs can be 512 bit due to sha3() - if self.raw.size() != other.raw.size(): - return Bool(z3.BoolVal(False), annotations=union) + eq_check = _padded_operation(self.raw, other.raw, operator.eq) # MYPY: fix complaints due to z3 overriding __eq__ - return Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union) + return Bool(cast(z3.BoolRef, eq_check), annotations=union) # MYPY: fix complains about overriding __ne__ def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore @@ -224,10 +233,9 @@ class BitVec(Expression[z3.BitVecRef]): union = self.annotations.union(other.annotations) # Some of the BitVecs can be 512 bit due to sha3() - if self.raw.size() != other.raw.size(): - return Bool(z3.BoolVal(True), annotations=union) + neq_check = _padded_operation(self.raw, other.raw, operator.ne) # MYPY: fix complaints due to z3 overriding __eq__ - return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union) + return Bool(cast(z3.BoolRef, neq_check), annotations=union) def _handle_shift(self, other: Union[int, "BitVec"], operator: Callable) -> "BitVec": """