diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index 407f6565..f441948e 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -21,13 +21,13 @@ from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.array import K, Array, BaseArray from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics from mythril.laser.smt.model import Model - +from mythril.laser.smt.bool import Bool as SMTBool from typing import Union, Any, Optional, Set, TypeVar, Generic import z3 Annotations = Optional[Set[Any]] -T = TypeVar("T", bound=Union[bool.Bool, z3.BoolRef]) +T = TypeVar("T", bound=Union[SMTBool, z3.BoolRef]) U = TypeVar("U", bound=Union[BitVec, z3.BitVecRef]) @@ -105,14 +105,14 @@ class SymbolFactory(Generic[T, U]): raise NotImplementedError() -class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): +class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]): """ An implementation of a SymbolFactory that creates symbols using the classes in: mythril.laser.smt """ @staticmethod - def Bool(value: "__builtins__.bool", annotations: Annotations = None) -> bool.Bool: + def Bool(value: "__builtins__.bool", annotations: Annotations = None) -> SMTBool: """ Creates a Bool with concrete value :param value: The boolean value @@ -120,7 +120,7 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): :return: The freshly created Bool() """ raw = z3.BoolVal(value) - return Bool(raw, annotations) + return SMTBool(raw, annotations) @staticmethod def BitVecVal(value: int, size: int, annotations: Annotations = None) -> BitVec: diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index 1b8080f8..df537582 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -1,7 +1,7 @@ """This module provides classes for an SMT abstraction of bit vectors.""" from typing import Union, overload, List, Set, cast, Any, Optional, Callable -from operator import lshift, rshift +from operator import lshift, rshift, ne, eq import z3 from mythril.laser.smt.bool import Bool, And, Or @@ -12,6 +12,15 @@ Annotations = Set[Any] # fmt: off +def _padded_operation(a: z3.BitVec, b: z3.BitVec, operator): + if a.size() == b.size(): + return operator(a, b) + if a.size() < b.size(): + a, b = b, a + b = z3.Concat(z3.BitVecVal(0, a.size() - b.size()), b) + return operator(a, b) + + class BitVec(Expression[z3.BitVecRef]): """A bit vector symbol.""" @@ -203,10 +212,9 @@ class BitVec(Expression[z3.BitVecRef]): union = self.annotations.union(other.annotations) # Some of the BitVecs can be 512 bit due to sha3() - if self.raw.size() != other.raw.size(): - return Bool(z3.BoolVal(False), annotations=union) + eq_check = _padded_operation(self.raw, other.raw, eq) # MYPY: fix complaints due to z3 overriding __eq__ - return Bool(cast(z3.BoolRef, self.raw == other.raw), annotations=union) + return Bool(cast(z3.BoolRef, eq_check), annotations=union) # MYPY: fix complains about overriding __ne__ def __ne__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore @@ -224,10 +232,9 @@ class BitVec(Expression[z3.BitVecRef]): union = self.annotations.union(other.annotations) # Some of the BitVecs can be 512 bit due to sha3() - if self.raw.size() != other.raw.size(): - return Bool(z3.BoolVal(True), annotations=union) + neq_check = _padded_operation(self.raw, other.raw, ne) # MYPY: fix complaints due to z3 overriding __eq__ - return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union) + return Bool(cast(z3.BoolRef, neq_check), annotations=union) def _handle_shift(self, other: Union[int, "BitVec"], operator: Callable) -> "BitVec": """ diff --git a/mythril/mythril/mythril_disassembler.py b/mythril/mythril/mythril_disassembler.py index 5e1e72c3..bfd7e23c 100644 --- a/mythril/mythril/mythril_disassembler.py +++ b/mythril/mythril/mythril_disassembler.py @@ -186,7 +186,27 @@ class MythrilDisassembler: except FileNotFoundError: raise CriticalError("Input file not found: " + file) except CompilerError as e: - raise CriticalError(e) + error_msg = str(e) + # Check if error is related to solidity version mismatch + if ( + "Error: Source file requires different compiler version" + in error_msg + ): + # Grab relevant line "pragma solidity ...", excluding any comments + solv_pragma_line = error_msg.split("\n")[-3].split("//")[0] + # Grab solidity version from relevant line + solv_match = re.findall(r"[0-9]+\.[0-9]+\.[0-9]+", solv_pragma_line) + error_suggestion = ( + "" if len(solv_match) != 1 else solv_match[0] + ) + error_msg = ( + error_msg + + '\nSolidityVersionMismatch: Try adding the option "--solv ' + + error_suggestion + + '"\n' + ) + + raise CriticalError(error_msg) except NoContractFoundError: log.error( "The file " + file + " does not contain a compilable contract."