diff --git a/mythril/analysis/modules/base.py b/mythril/analysis/modules/base.py index 072f74e9..d7075109 100644 --- a/mythril/analysis/modules/base.py +++ b/mythril/analysis/modules/base.py @@ -3,6 +3,7 @@ modules.""" import logging from typing import List +from mythril.analysis.report import Issue log = logging.getLogger(__name__) @@ -21,7 +22,7 @@ class DetectionModule: entrypoint: str = "post", pre_hooks: List[str] = None, post_hooks: List[str] = None, - ): + ) -> None: self.name = name self.swc_id = swc_id self.pre_hooks = pre_hooks if pre_hooks else [] @@ -33,7 +34,7 @@ class DetectionModule: self.name, ) self.entrypoint = entrypoint - self._issues = [] + self._issues = [] # type: List[Issue] @property def issues(self): diff --git a/mythril/analysis/modules/delegatecall.py b/mythril/analysis/modules/delegatecall.py index 596f8ce5..9ebfcd17 100644 --- a/mythril/analysis/modules/delegatecall.py +++ b/mythril/analysis/modules/delegatecall.py @@ -17,7 +17,7 @@ log = logging.getLogger(__name__) class DelegateCallModule(DetectionModule): """This module detects calldata being forwarded using DELEGATECALL.""" - def __init__(self): + def __init__(self) -> None: """""" super().__init__( name="DELEGATECALL Usage in Fallback Function", @@ -46,7 +46,7 @@ def _analyze_states(state: GlobalState) -> List[Issue]: call = get_call_from_state(state) if call is None: return [] - issues = [] + issues = [] # type: List[Issue] if call.type is not "DELEGATECALL": return [] diff --git a/mythril/analysis/modules/dependence_on_predictable_vars.py b/mythril/analysis/modules/dependence_on_predictable_vars.py index 632672b8..1450f93a 100644 --- a/mythril/analysis/modules/dependence_on_predictable_vars.py +++ b/mythril/analysis/modules/dependence_on_predictable_vars.py @@ -19,7 +19,7 @@ class PredictableDependenceModule(DetectionModule): """This module detects whether Ether is sent using predictable parameters.""" - def __init__(self): + def __init__(self) -> None: """""" super().__init__( name="Dependence of Predictable Variables", @@ -118,9 +118,9 @@ def _analyze_states(state: GlobalState) -> list: m = re.search(r"blockhash\w+(\s-\s(\d+))*", str(constraint)) if m and solve(call): - found = m.group(1) + found_item = m.group(1) - if found: # block.blockhash(block.number - N) + if found_item: # block.blockhash(block.number - N) description = ( "The predictable expression 'block.blockhash(block.number - " + m.group(2) diff --git a/mythril/analysis/modules/integer.py b/mythril/analysis/modules/integer.py index c11b40ed..6cbfd5ac 100644 --- a/mythril/analysis/modules/integer.py +++ b/mythril/analysis/modules/integer.py @@ -2,7 +2,7 @@ underflows.""" import json - +from typing import Dict from mythril.analysis import solver from mythril.analysis.report import Issue from mythril.analysis.swc_data import INTEGER_OVERFLOW_AND_UNDERFLOW @@ -27,7 +27,9 @@ log = logging.getLogger(__name__) class OverUnderflowAnnotation: - def __init__(self, overflowing_state: GlobalState, operator: str, constraint): + def __init__( + self, overflowing_state: GlobalState, operator: str, constraint + ) -> None: self.overflowing_state = overflowing_state self.operator = operator self.constraint = constraint @@ -36,7 +38,7 @@ class OverUnderflowAnnotation: class IntegerOverflowUnderflowModule(DetectionModule): """This module searches for integer over- and underflows.""" - def __init__(self): + def __init__(self) -> None: """""" super().__init__( name="Integer Overflow and Underflow", @@ -49,8 +51,8 @@ class IntegerOverflowUnderflowModule(DetectionModule): entrypoint="callback", pre_hooks=["ADD", "MUL", "SUB", "SSTORE", "JUMPI"], ) - self._overflow_cache = {} - self._underflow_cache = {} + self._overflow_cache = {} # type: Dict[int, bool] + self._underflow_cache = {} # type: Dict[int, bool] def reset_module(self): """ diff --git a/mythril/analysis/modules/multiple_sends.py b/mythril/analysis/modules/multiple_sends.py index 04a74bce..fbe0c03b 100644 --- a/mythril/analysis/modules/multiple_sends.py +++ b/mythril/analysis/modules/multiple_sends.py @@ -1,7 +1,9 @@ """This module contains the detection code to find multiple sends occurring in a single transaction.""" from copy import copy +from typing import cast, List, Optional +from mythril.analysis.ops import Call from mythril.analysis.report import Issue from mythril.analysis.swc_data import MULTIPLE_SENDS from mythril.analysis.modules.base import DetectionModule @@ -14,8 +16,8 @@ log = logging.getLogger(__name__) class MultipleSendsAnnotation(StateAnnotation): - def __init__(self): - self.calls = [] + def __init__(self) -> None: + self.calls = [] # type: List[Optional[Call]] def __copy__(self): result = MultipleSendsAnnotation() @@ -56,11 +58,17 @@ def _analyze_state(state: GlobalState): node = state.node instruction = state.get_current_instruction() - annotations = [a for a in state.get_annotations(MultipleSendsAnnotation)] + annotations = cast( + List[MultipleSendsAnnotation], + [a for a in state.get_annotations(MultipleSendsAnnotation)], + ) if len(annotations) == 0: log.debug("Creating annotation for state") state.annotate(MultipleSendsAnnotation()) - annotations = [a for a in state.get_annotations(MultipleSendsAnnotation)] + annotations = cast( + List[MultipleSendsAnnotation], + [a for a in state.get_annotations(MultipleSendsAnnotation)], + ) calls = annotations[0].calls diff --git a/mythril/analysis/modules/unchecked_retval.py b/mythril/analysis/modules/unchecked_retval.py index 70d7f0fd..5beb0fda 100644 --- a/mythril/analysis/modules/unchecked_retval.py +++ b/mythril/analysis/modules/unchecked_retval.py @@ -1,12 +1,15 @@ """This module contains detection code to find occurrences of calls whose return value remains unchecked.""" from copy import copy +from typing import cast, List, Union, Mapping from mythril.analysis import solver from mythril.analysis.report import Issue from mythril.analysis.swc_data import UNCHECKED_RET_VAL from mythril.analysis.modules.base import DetectionModule from mythril.exceptions import UnsatError +from mythril.laser.smt.bitvec import BitVec + from mythril.laser.ethereum.state.annotation import StateAnnotation from mythril.laser.ethereum.state.global_state import GlobalState @@ -16,8 +19,8 @@ log = logging.getLogger(__name__) class UncheckedRetvalAnnotation(StateAnnotation): - def __init__(self): - self.retvals = [] + def __init__(self) -> None: + self.retvals = [] # type: List[Mapping[str, Union[int, BitVec]]] def __copy__(self): result = UncheckedRetvalAnnotation() @@ -60,10 +63,16 @@ def _analyze_state(state: GlobalState) -> list: instruction = state.get_current_instruction() node = state.node - annotations = [a for a in state.get_annotations(UncheckedRetvalAnnotation)] + annotations = cast( + List[UncheckedRetvalAnnotation], + [a for a in state.get_annotations(UncheckedRetvalAnnotation)], + ) if len(annotations) == 0: state.annotate(UncheckedRetvalAnnotation()) - annotations = [a for a in state.get_annotations(UncheckedRetvalAnnotation)] + annotations = cast( + List[UncheckedRetvalAnnotation], + [a for a in state.get_annotations(UncheckedRetvalAnnotation)], + ) retvals = annotations[0].retvals @@ -103,7 +112,13 @@ def _analyze_state(state: GlobalState) -> list: "opcode" ] in ["CALL", "DELEGATECALL", "STATICCALL", "CALLCODE"] retval = state.mstate.stack[-1] - retvals.append({"address": state.instruction["address"] - 1, "retval": retval}) + # Use Typed Dict after release of mypy 0.670 and remove type ignore + retvals.append( + { # type: ignore + "address": state.instruction["address"] - 1, + "retval": retval, + } + ) return [] diff --git a/mythril/disassembler/asm.py b/mythril/disassembler/asm.py index 2c9c7907..cbbe1881 100644 --- a/mythril/disassembler/asm.py +++ b/mythril/disassembler/asm.py @@ -90,7 +90,7 @@ def is_sequence_match(pattern: list, instruction_list: list, index: int) -> bool return True -def disassemble(bytecode: str) -> list: +def disassemble(bytecode: bytes) -> list: """Disassembles evm bytecode and returns a list of instructions. :param bytecode: diff --git a/mythril/disassembler/disassembly.py b/mythril/disassembler/disassembly.py index 0d4828fc..0cc1c8aa 100644 --- a/mythril/disassembler/disassembly.py +++ b/mythril/disassembler/disassembly.py @@ -3,6 +3,8 @@ from mythril.ethereum import util from mythril.disassembler import asm from mythril.support.signatures import SignatureDB +from typing import Dict, List, Tuple + class Disassembly(object): """Disassembly class. @@ -14,7 +16,7 @@ class Disassembly(object): - function entry point to function name mapping """ - def __init__(self, code: str, enable_online_lookup: bool = False): + def __init__(self, code: str, enable_online_lookup: bool = False) -> None: """ :param code: @@ -23,9 +25,9 @@ class Disassembly(object): self.bytecode = code self.instruction_list = asm.disassemble(util.safe_decode(code)) - self.func_hashes = [] - self.function_name_to_address = {} - self.address_to_function_name = {} + self.func_hashes = [] # type: List[str] + self.function_name_to_address = {} # type: Dict[str, int] + self.address_to_function_name = {} # type: Dict[int, str] # open from default locations # control if you want to have online signature hash lookups @@ -41,7 +43,6 @@ class Disassembly(object): index, self.instruction_list, signatures ) self.func_hashes.append(function_hash) - if jump_target is not None and function_name is not None: self.function_name_to_address[function_name] = jump_target self.address_to_function_name[jump_target] = function_name @@ -56,7 +57,7 @@ class Disassembly(object): def get_function_info( index: int, instruction_list: list, signature_database: SignatureDB -) -> (str, int, str): +) -> Tuple[str, int, str]: """Finds the function information for a call table entry Solidity uses the first 4 bytes of the calldata to indicate which function the message call should execute The generated code that directs execution to the correct diff --git a/mythril/laser/ethereum/call.py b/mythril/laser/ethereum/call.py index 7ee0af34..27df993e 100644 --- a/mythril/laser/ethereum/call.py +++ b/mythril/laser/ethereum/call.py @@ -3,9 +3,9 @@ instructions.py to get the necessary elements from the stack and determine the parameters for the new global state.""" import logging -from typing import Union, List +from typing import Union, List, cast, Callable from z3 import Z3Exception - +from mythril.laser.smt import BitVec from mythril.laser.ethereum import natives from mythril.laser.ethereum.gas import OPCODE_GAS from mythril.laser.smt import simplify, Expression, symbol_factory @@ -155,8 +155,8 @@ def get_callee_account( def get_call_data( global_state: GlobalState, - memory_start: Union[int, Expression], - memory_size: Union[int, Expression], + memory_start: Union[int, BitVec], + memory_size: Union[int, BitVec], ): """Gets call_data from the global_state. @@ -168,22 +168,28 @@ def get_call_data( state = global_state.mstate transaction_id = "{}_internalcall".format(global_state.current_transaction.id) - memory_start = ( - symbol_factory.BitVecVal(memory_start, 256) - if isinstance(memory_start, int) - else memory_start + memory_start = cast( + BitVec, + ( + symbol_factory.BitVecVal(memory_start, 256) + if isinstance(memory_start, int) + else memory_start + ), ) - memory_size = ( - symbol_factory.BitVecVal(memory_size, 256) - if isinstance(memory_size, int) - else memory_size + memory_size = cast( + BitVec, + ( + symbol_factory.BitVecVal(memory_size, 256) + if isinstance(memory_size, int) + else memory_size + ), ) uses_entire_calldata = simplify( memory_size - global_state.environment.calldata.calldatasize == 0 ) - if uses_entire_calldata == True: + if uses_entire_calldata is True: return global_state.environment.calldata try: @@ -218,7 +224,7 @@ def native_call( contract_list = ["ecrecover", "sha256", "ripemd160", "identity"] call_address_int = int(callee_address, 16) - native_gas_min, native_gas_max = OPCODE_GAS["NATIVE_COST"]( + native_gas_min, native_gas_max = cast(Callable, OPCODE_GAS["NATIVE_COST"])( global_state.mstate.calculate_extension_size(mem_out_start, mem_out_sz), contract_list[call_address_int - 1], ) diff --git a/mythril/laser/ethereum/cfg.py b/mythril/laser/ethereum/cfg.py index 056692e5..0578c602 100644 --- a/mythril/laser/ethereum/cfg.py +++ b/mythril/laser/ethereum/cfg.py @@ -1,9 +1,12 @@ """This module.""" from enum import Enum -from typing import Dict +from typing import Dict, List, TYPE_CHECKING from flags import Flags +if TYPE_CHECKING: + from mythril.laser.ethereum.state.global_state import GlobalState + gbl_next_uid = 0 # node counter @@ -20,6 +23,9 @@ class JumpType(Enum): class NodeFlags(Flags): """A collection of flags to denote the type a call graph node can have.""" + def __or__(self, other) -> "NodeFlags": + return super().__or__(other) + FUNC_ENTRY = 1 CALL_RETURN = 2 @@ -33,7 +39,7 @@ class Node: start_addr=0, constraints=None, function_name="unknown", - ): + ) -> None: """ :param contract_name: @@ -43,7 +49,7 @@ class Node: constraints = constraints if constraints else [] self.contract_name = contract_name self.start_addr = start_addr - self.states = [] + self.states = [] # type: List[GlobalState] self.constraints = constraints self.function_name = function_name self.flags = NodeFlags() @@ -86,7 +92,7 @@ class Edge: node_to: int, edge_type=JumpType.UNCONDITIONAL, condition=None, - ): + ) -> None: """ :param node_from: diff --git a/mythril/laser/ethereum/gas.py b/mythril/laser/ethereum/gas.py index 43c5d48b..8fae5f53 100644 --- a/mythril/laser/ethereum/gas.py +++ b/mythril/laser/ethereum/gas.py @@ -2,6 +2,7 @@ table.""" from ethereum import opcodes from ethereum.utils import ceil32 +from typing import Callable, Dict, Tuple, Union def calculate_native_gas(size: int, contract: str): @@ -185,4 +186,4 @@ OPCODE_GAS = { "SUICIDE": (5000, 30000), "ASSERT_FAIL": (0, 0), "INVALID": (0, 0), -} +} # type: Dict[str, Union[Tuple[int, int], Callable]] diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index 8ea097fe..ab0df0b3 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -4,7 +4,7 @@ import binascii import logging from copy import copy, deepcopy -from typing import Callable, List, Union +from typing import cast, Callable, List, Union, Tuple from datetime import datetime from ethereum import utils @@ -127,7 +127,7 @@ class StateTransition(object): if not self.enable_gas: return global_state opcode = global_state.instruction["opcode"] - min_gas, max_gas = OPCODE_GAS[opcode] + min_gas, max_gas = cast(Tuple[int, int], OPCODE_GAS[opcode]) global_state.mstate.min_gas_used += min_gas global_state.mstate.max_gas_used += max_gas return global_state @@ -155,7 +155,7 @@ class Instruction: """Instruction class is used to mutate a state according to the current instruction.""" - def __init__(self, op_code: str, dynamic_loader: DynLoader, iprof=None): + def __init__(self, op_code: str, dynamic_loader: DynLoader, iprof=None) -> None: """ :param op_code: @@ -358,7 +358,7 @@ class Instruction: symbol_factory.BitVecVal(0, 248), Extract(offset + 7, offset, op1), ) - ) + ) # type: Union[int, Expression] else: result = 0 except TypeError: @@ -717,17 +717,15 @@ class Instruction: log.debug("Unsupported symbolic memory offset in CALLDATACOPY") return [global_state] - dstart_sym = False try: - dstart = util.get_concrete_int(op1) + dstart = util.get_concrete_int(op1) # type: Union[int, BitVec] except TypeError: log.debug("Unsupported symbolic calldata offset in CALLDATACOPY") dstart = simplify(op1) - dstart_sym = True size_sym = False try: - size = util.get_concrete_int(op2) + size = util.get_concrete_int(op2) # type: Union[int, BitVec] except TypeError: log.debug("Unsupported symbolic size in CALLDATACOPY") size = simplify(op2) @@ -746,7 +744,7 @@ class Instruction: 8, ) return [global_state] - + size = cast(int, size) if size > 0: try: state.mem_extend(mstart, size) @@ -778,7 +776,9 @@ class Instruction: new_memory.append(value) i_data = ( - i_data + 1 if isinstance(i_data, int) else simplify(i_data + 1) + i_data + 1 + if isinstance(i_data, int) + else simplify(cast(BitVec, i_data) + 1) ) for i in range(len(new_memory)): state.memory[i + mstart] = new_memory[i] @@ -881,11 +881,12 @@ class Instruction: state.stack.append( symbol_factory.BitVecSym("KECCAC_mem[" + str(op0) + "]", 256) ) - state.min_gas_used += OPCODE_GAS["SHA3"][0] - state.max_gas_used += OPCODE_GAS["SHA3"][1] + gas_tuple = cast(Tuple, OPCODE_GAS["SHA3"]) + state.min_gas_used += gas_tuple[0] + state.max_gas_used += gas_tuple[1] return [global_state] - min_gas, max_gas = OPCODE_GAS["SHA3_FUNC"](length) + min_gas, max_gas = cast(Callable, OPCODE_GAS["SHA3_FUNC"])(length) state.min_gas_used += min_gas state.max_gas_used += max_gas StateTransition.check_gas_usage_limit(global_state) @@ -1268,7 +1269,9 @@ class Instruction: state.mem_extend(offset, 1) try: - value_to_write = util.get_concrete_int(value) ^ 0xFF + value_to_write = ( + util.get_concrete_int(value) ^ 0xFF + ) # type: Union[int, BitVec] except TypeError: # BitVec value_to_write = Extract(7, 0, value) log.debug("MSTORE8 to mem[" + str(offset) + "]: " + str(value_to_write)) @@ -1301,7 +1304,7 @@ class Instruction: storage_keys = global_state.environment.active_account.storage.keys() keccak_keys = list(filter(keccak_function_manager.is_keccak, storage_keys)) - results = [] + results = [] # type: List[GlobalState] constraints = [] for keccak_key in keccak_keys: @@ -1328,7 +1331,7 @@ class Instruction: @staticmethod def _sload_helper( - global_state: GlobalState, index: Union[int, Expression], constraints=None + global_state: GlobalState, index: Union[str, int], constraints=None ): """ @@ -1387,17 +1390,21 @@ class Instruction: storage_keys = global_state.environment.active_account.storage.keys() keccak_keys = filter(keccak_function_manager.is_keccak, storage_keys) - results = [] + results = [] # type: List[GlobalState] new = symbol_factory.Bool(False) for keccak_key in keccak_keys: - key_argument = keccak_function_manager.get_argument(keccak_key) - index_argument = keccak_function_manager.get_argument(index) + key_argument = keccak_function_manager.get_argument( + keccak_key + ) # type: Expression + index_argument = keccak_function_manager.get_argument( + index + ) # type: Expression condition = key_argument == index_argument condition = ( condition if type(condition) == bool - else is_true(simplify(condition)) + else is_true(simplify(cast(Bool, condition))) ) if condition: return self._sstore_helper( @@ -1414,7 +1421,7 @@ class Instruction: key_argument == index_argument, ) - new = Or(new, key_argument != index_argument) + new = Or(new, cast(Bool, key_argument != index_argument)) if len(results) > 0: results += self._sstore_helper( @@ -1482,7 +1489,7 @@ class Instruction: new_state = copy(global_state) # add JUMP gas cost - min_gas, max_gas = OPCODE_GAS["JUMP"] + min_gas, max_gas = cast(Tuple[int, int], OPCODE_GAS["JUMP"]) new_state.mstate.min_gas_used += min_gas new_state.mstate.max_gas_used += max_gas @@ -1501,7 +1508,7 @@ class Instruction: """ state = global_state.mstate disassembly = global_state.environment.code - min_gas, max_gas = OPCODE_GAS["JUMPI"] + min_gas, max_gas = cast(Tuple[int, int], OPCODE_GAS["JUMPI"]) states = [] op0, condition = state.stack.pop(), state.stack.pop() @@ -1910,12 +1917,12 @@ class Instruction: try: memory_out_offset = ( util.get_concrete_int(memory_out_offset) - if isinstance(memory_out_offset, ExprRef) + if isinstance(memory_out_offset, Expression) else memory_out_offset ) memory_out_size = ( util.get_concrete_int(memory_out_size) - if isinstance(memory_out_size, ExprRef) + if isinstance(memory_out_size, Expression) else memory_out_size ) except TypeError: diff --git a/mythril/laser/ethereum/keccak.py b/mythril/laser/ethereum/keccak.py index 9d2aefdc..47a39ed1 100644 --- a/mythril/laser/ethereum/keccak.py +++ b/mythril/laser/ethereum/keccak.py @@ -18,7 +18,7 @@ class KeccakFunctionManager: """ return str(expression) in self.keccak_expression_mapping.keys() - def get_argument(self, expression: str) -> Expression: + def get_argument(self, expression: Expression) -> Expression: """ :param expression: diff --git a/mythril/laser/ethereum/natives.py b/mythril/laser/ethereum/natives.py index 205e325a..ef38d8c9 100644 --- a/mythril/laser/ethereum/natives.py +++ b/mythril/laser/ethereum/natives.py @@ -9,7 +9,8 @@ from py_ecc.secp256k1 import N as secp256k1n from rlp.utils import ALL_BYTES from mythril.laser.ethereum.state.calldata import BaseCalldata, ConcreteCalldata -from mythril.laser.ethereum.util import bytearray_to_int, sha3 +from mythril.laser.ethereum.util import bytearray_to_int +from ethereum.utils import sha3 from mythril.laser.smt import Concat, simplify log = logging.getLogger(__name__) @@ -50,7 +51,7 @@ def extract32(data: bytearray, i: int) -> int: return bytearray_to_int(o) -def ecrecover(data: Union[bytes, str, List[int]]) -> bytes: +def ecrecover(data: List[int]) -> List[int]: """ :param data: @@ -58,54 +59,54 @@ def ecrecover(data: Union[bytes, str, List[int]]) -> bytes: """ # TODO: Add type hints try: - data = bytearray(data) - v = extract32(data, 32) - r = extract32(data, 64) - s = extract32(data, 96) + byte_data = bytearray(data) + v = extract32(byte_data, 32) + r = extract32(byte_data, 64) + s = extract32(byte_data, 96) except TypeError: raise NativeContractException - message = b"".join([ALL_BYTES[x] for x in data[0:32]]) + message = b"".join([ALL_BYTES[x] for x in byte_data[0:32]]) if r >= secp256k1n or s >= secp256k1n or v < 27 or v > 28: return [] try: pub = ecrecover_to_pub(message, v, r, s) except Exception as e: - log.debug("An error has occured while extracting public key: " + e) + log.debug("An error has occured while extracting public key: " + str(e)) return [] o = [0] * 12 + [x for x in sha3(pub)[-20:]] - return o + return list(bytearray(o)) -def sha256(data: Union[bytes, str, List[int]]) -> bytes: +def sha256(data: List[int]) -> List[int]: """ :param data: :return: """ try: - data = bytes(data) + byte_data = bytes(data) except TypeError: raise NativeContractException - return hashlib.sha256(data).digest() + return list(bytearray(hashlib.sha256(byte_data).digest())) -def ripemd160(data: Union[bytes, str, List[int]]) -> bytes: +def ripemd160(data: List[int]) -> List[int]: """ :param data: :return: """ try: - data = bytes(data) + bytes_data = bytes(data) except TypeError: raise NativeContractException - digest = hashlib.new("ripemd160", data).digest() + digest = hashlib.new("ripemd160", bytes_data).digest() padded = 12 * [0] + list(digest) - return bytes(padded) + return list(bytearray(bytes(padded))) -def identity(data: Union[bytes, str, List[int]]) -> bytes: +def identity(data: List[int]) -> List[int]: """ :param data: @@ -117,13 +118,9 @@ def identity(data: Union[bytes, str, List[int]]) -> bytes: # implementation would be byte indexed for the most # part. return data - result = [] - for i in range(0, len(data), 32): - result.append(simplify(Concat(data[i : i + 32]))) - return result -def native_contracts(address: int, data: BaseCalldata): +def native_contracts(address: int, data: BaseCalldata) -> List[int]: """Takes integer address 1, 2, 3, 4. :param address: @@ -133,8 +130,8 @@ def native_contracts(address: int, data: BaseCalldata): functions = (ecrecover, sha256, ripemd160, identity) if isinstance(data, ConcreteCalldata): - data = data.concrete(None) + concrete_data = data.concrete(None) else: raise NativeContractException() - return functions[address - 1](data) + return functions[address - 1](concrete_data) diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index 2c69dbc5..c13c3f2f 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -14,17 +14,17 @@ from mythril.laser.smt import symbol_factory class Storage: """Storage class represents the storage of an Account.""" - def __init__(self, concrete=False, address=None, dynamic_loader=None): + def __init__(self, concrete=False, address=None, dynamic_loader=None) -> None: """Constructor for Storage. :param concrete: bool indicating whether to interpret uninitialized storage as concrete versus symbolic """ - self._storage = {} + self._storage = {} # type: Dict[Union[int, str], Any] self.concrete = concrete self.dynld = dynamic_loader self.address = address - def __getitem__(self, item: Union[int, slice]) -> Any: + def __getitem__(self, item: Union[str, int]) -> Any: try: return self._storage[item] except KeyError: @@ -51,7 +51,7 @@ class Storage: self._storage[item] = symbol_factory.BitVecVal(0, 256) return self._storage[item] - def __setitem__(self, key: str, value: ExprRef) -> None: + def __setitem__(self, key: Union[int, str], value: Any) -> None: self._storage[key] = value def keys(self) -> KeysView: @@ -73,7 +73,7 @@ class Account: balance=None, concrete_storage=False, dynamic_loader=None, - ): + ) -> None: """Constructor for account. :param address: Address of the account diff --git a/mythril/laser/ethereum/state/calldata.py b/mythril/laser/ethereum/state/calldata.py index b2ea15a0..a8ebfa2b 100644 --- a/mythril/laser/ethereum/state/calldata.py +++ b/mythril/laser/ethereum/state/calldata.py @@ -1,7 +1,6 @@ """This module declares classes to represent call data.""" -from typing import Union, Any +from typing import cast, Union, Tuple, List -from mythril.laser.smt import K, Array, If, simplify, Concat, Expression, BitVec from enum import Enum from typing import Any, Union @@ -13,6 +12,7 @@ from mythril.laser.ethereum.util import get_concrete_int from mythril.laser.smt import ( Array, BitVec, + Bool, Concat, Expression, If, @@ -26,7 +26,7 @@ class BaseCalldata: """Base calldata class This represents the calldata provided when sending a transaction to a contract.""" - def __init__(self, tx_id): + def __init__(self, tx_id: str) -> None: """ :param tx_id: @@ -34,7 +34,7 @@ class BaseCalldata: self.tx_id = tx_id @property - def calldatasize(self) -> Expression: + def calldatasize(self) -> BitVec: """ :return: Calldata size for this calldata object @@ -53,7 +53,7 @@ class BaseCalldata: parts = self[offset : offset + 32] return simplify(Concat(parts)) - def __getitem__(self, item: Union[int, slice]) -> Any: + def __getitem__(self, item: Union[int, slice, BitVec]) -> Any: """ :param item: @@ -88,7 +88,7 @@ class BaseCalldata: raise ValueError - def _load(self, item: Union[int, Expression]) -> Any: + def _load(self, item: Union[int, BitVec]) -> Any: """ :param item: @@ -96,7 +96,7 @@ class BaseCalldata: raise NotImplementedError() @property - def size(self) -> Union[Expression, int]: + def size(self) -> Union[BitVec, int]: """Returns the exact size of this calldata, this is not normalized. :return: unnormalized call data size @@ -114,7 +114,7 @@ class BaseCalldata: class ConcreteCalldata(BaseCalldata): """A concrete call data representation.""" - def __init__(self, tx_id: int, calldata: list): + def __init__(self, tx_id: str, calldata: list) -> None: """Initializes the ConcreteCalldata object. :param tx_id: Id of the transaction that the calldata is for. @@ -132,7 +132,7 @@ class ConcreteCalldata(BaseCalldata): super().__init__(tx_id) - def _load(self, item: Union[int, Expression]) -> BitVec: + def _load(self, item: Union[int, BitVec]) -> BitVec: """ :param item: @@ -161,7 +161,7 @@ class ConcreteCalldata(BaseCalldata): class BasicConcreteCalldata(BaseCalldata): """A base class to represent concrete call data.""" - def __init__(self, tx_id: int, calldata: list): + def __init__(self, tx_id: str, calldata: list) -> None: """Initializes the ConcreteCalldata object, that doesn't use z3 arrays. :param tx_id: Id of the transaction that the calldata is for. @@ -184,7 +184,7 @@ class BasicConcreteCalldata(BaseCalldata): value = symbol_factory.BitVecVal(0x0, 8) for i in range(self.size): - value = If(item == i, self._calldata[i], value) + value = If(cast(Union[BitVec, Bool], item) == i, self._calldata[i], value) return value def concrete(self, model: Model) -> list: @@ -207,7 +207,7 @@ class BasicConcreteCalldata(BaseCalldata): class SymbolicCalldata(BaseCalldata): """A class for representing symbolic call data.""" - def __init__(self, tx_id: int): + def __init__(self, tx_id: str) -> None: """Initializes the SymbolicCalldata object. :param tx_id: Id of the transaction that the calldata is for. @@ -216,7 +216,7 @@ class SymbolicCalldata(BaseCalldata): self._calldata = Array("{}_calldata".format(tx_id), 256, 8) super().__init__(tx_id) - def _load(self, item: Union[int, Expression]) -> Any: + def _load(self, item: Union[int, BitVec]) -> Any: """ :param item: @@ -226,7 +226,7 @@ class SymbolicCalldata(BaseCalldata): return simplify( If( item < self._size, - simplify(self._calldata[item]), + simplify(self._calldata[cast(BitVec, item)]), symbol_factory.BitVecVal(0, 8), ) ) @@ -247,7 +247,7 @@ class SymbolicCalldata(BaseCalldata): return result @property - def size(self) -> Expression: + def size(self) -> BitVec: """ :return: @@ -258,29 +258,34 @@ class SymbolicCalldata(BaseCalldata): class BasicSymbolicCalldata(BaseCalldata): """A basic class representing symbolic call data.""" - def __init__(self, tx_id: int): + def __init__(self, tx_id: str) -> None: """Initializes the SymbolicCalldata object. :param tx_id: Id of the transaction that the calldata is for. """ - self._reads = [] - self._size = BitVec(str(tx_id) + "_calldatasize", 256) + self._reads = [] # type: List[Tuple[Union[int, BitVec], BitVec]] + self._size = symbol_factory.BitVecSym(str(tx_id) + "_calldatasize", 256) super().__init__(tx_id) - def _load(self, item: Union[int, Expression], clean=False) -> Any: - x = symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item + def _load(self, item: Union[int, BitVec], clean=False) -> Any: + expr_item = ( + symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item + ) # type: BitVec symbolic_base_value = If( - x >= self._size, + expr_item >= self._size, symbol_factory.BitVecVal(0, 8), - BitVec("{}_calldata_{}".format(self.tx_id, str(item)), 8), + BitVec( + symbol_factory.BitVecSym( + "{}_calldata_{}".format(self.tx_id, str(item)), 8 + ) + ), ) return_value = symbolic_base_value for r_index, r_value in self._reads: - return_value = If(r_index == item, r_value, return_value) - + return_value = If(r_index == expr_item, r_value, return_value) if not clean: - self._reads.append((item, symbolic_base_value)) + self._reads.append((expr_item, symbolic_base_value)) return simplify(return_value) def concrete(self, model: Model) -> list: @@ -299,7 +304,7 @@ class BasicSymbolicCalldata(BaseCalldata): return result @property - def size(self) -> Expression: + def size(self) -> BitVec: """ :return: diff --git a/mythril/laser/ethereum/state/environment.py b/mythril/laser/ethereum/state/environment.py index 6fb018ab..69830007 100644 --- a/mythril/laser/ethereum/state/environment.py +++ b/mythril/laser/ethereum/state/environment.py @@ -22,7 +22,7 @@ class Environment: callvalue: ExprRef, origin: ExprRef, code=None, - ): + ) -> None: """ :param active_account: diff --git a/mythril/laser/ethereum/state/global_state.py b/mythril/laser/ethereum/state/global_state.py index 8facc663..aa931606 100644 --- a/mythril/laser/ethereum/state/global_state.py +++ b/mythril/laser/ethereum/state/global_state.py @@ -1,5 +1,5 @@ """This module contains a representation of the global execution state.""" -from typing import Dict, Union, List, Iterable +from typing import Dict, Union, List, Iterable, TYPE_CHECKING from copy import copy, deepcopy from z3 import BitVec @@ -10,6 +10,13 @@ from mythril.laser.ethereum.state.environment import Environment from mythril.laser.ethereum.state.machine_state import MachineState from mythril.laser.ethereum.state.annotation import StateAnnotation +if TYPE_CHECKING: + from mythril.laser.ethereum.state.world_state import WorldState + from mythril.laser.ethereum.transaction.transaction_models import ( + MessageCallTransaction, + ContractCreationTransaction, + ) + class GlobalState: """GlobalState represents the current globalstate.""" @@ -23,7 +30,7 @@ class GlobalState: transaction_stack=None, last_return_data=None, annotations=None, - ): + ) -> None: """Constructor for GlobalState. :param world_state: diff --git a/mythril/laser/ethereum/state/machine_state.py b/mythril/laser/ethereum/state/machine_state.py index 45e824db..0ee781ae 100644 --- a/mythril/laser/ethereum/state/machine_state.py +++ b/mythril/laser/ethereum/state/machine_state.py @@ -1,9 +1,9 @@ """This module contains a representation of the EVM's machine state and its stack.""" from copy import copy -from typing import Union, Any, List, Dict +from typing import cast, Sized, Union, Any, List, Dict, Optional -from z3 import BitVec +from mythril.laser.smt import BitVec, Expression from ethereum import opcodes, utils from mythril.laser.ethereum.evm_exceptions import ( @@ -20,16 +20,14 @@ class MachineStack(list): STACK_LIMIT = 1024 - def __init__(self, default_list=None): + def __init__(self, default_list=None) -> None: """ :param default_list: """ - if default_list is None: - default_list = [] - super(MachineStack, self).__init__(default_list) + super(MachineStack, self).__init__(default_list or []) - def append(self, element: BitVec) -> None: + def append(self, element: Union[int, Expression]) -> None: """ :param element: element to be appended to the list :function: appends the element to list if the size is less than STACK_LIMIT, else throws an error @@ -41,7 +39,7 @@ class MachineStack(list): ) super(MachineStack, self).append(element) - def pop(self, index=-1) -> BitVec: + def pop(self, index=-1) -> Union[int, Expression]: """ :param index:index to be popped, same as the list() class. :returns popped value @@ -90,12 +88,12 @@ class MachineState: gas_limit: int, pc=0, stack=None, - memory=None, + memory: Optional[Memory] = None, constraints=None, depth=0, max_gas_used=0, min_gas_used=0, - ): + ) -> None: """Constructor for machineState. :param gas_limit: @@ -164,7 +162,7 @@ class MachineState: self.check_gas() self.memory.extend(m_extend) - def memory_write(self, offset: int, data: List[int]) -> None: + def memory_write(self, offset: int, data: List[Union[int, BitVec]]) -> None: """Writes data to memory starting at offset. :param offset: @@ -217,7 +215,7 @@ class MachineState: :return: """ - return len(self.memory) + return len(cast(Sized, self.memory)) @property def as_dict(self) -> Dict: diff --git a/mythril/laser/ethereum/state/memory.py b/mythril/laser/ethereum/state/memory.py index 64be50f3..4fac4120 100644 --- a/mythril/laser/ethereum/state/memory.py +++ b/mythril/laser/ethereum/state/memory.py @@ -1,5 +1,5 @@ """This module contains a representation of a smart contract's memory.""" -from typing import Union +from typing import cast, List, Union, overload from z3 import Z3Exception @@ -20,7 +20,7 @@ class Memory: def __init__(self): """""" - self._memory = [] + self._memory = [] # type: List[Union[int, BitVec]] def __len__(self): """ @@ -50,12 +50,14 @@ class Memory: ), 256, ) - except: + except TypeError: result = simplify( Concat( [ b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8) - for b in self[index : index + 32] + for b in cast( + List[Union[int, BitVec]], self[index : index + 32] + ) ] ) ) @@ -79,8 +81,9 @@ class Memory: else: _bytes = util.concrete_int_to_bytes(value) assert len(_bytes) == 32 - self[index : index + 32] = _bytes + self[index : index + 32] = list(bytearray(_bytes)) except (Z3Exception, AttributeError): # BitVector or BoolRef + value = cast(Union[BitVec, Bool], value) if isinstance(value, Bool): value_to_write = If( value, @@ -94,7 +97,17 @@ class Memory: for i in range(0, value_to_write.size(), 8): self[index + 31 - (i // 8)] = Extract(i + 7, i, value_to_write) - def __getitem__(self, item: Union[int, slice]) -> Union[BitVec, int, list]: + @overload + def __getitem__(self, item: int) -> Union[int, BitVec]: + ... + + @overload + def __getitem__(self, item: slice) -> List[Union[int, BitVec]]: + ... + + def __getitem__( + self, item: Union[int, slice] + ) -> Union[BitVec, int, List[Union[int, BitVec]]]: """ :param item: @@ -108,14 +121,18 @@ class Memory: raise IndexError("Invalid Memory Slice") if step is None: step = 1 - return [self[i] for i in range(start, stop, step)] + return [cast(Union[int, BitVec], self[i]) for i in range(start, stop, step)] try: return self._memory[item] except IndexError: return 0 - def __setitem__(self, key: Union[int, slice], value: Union[BitVec, int, list]): + def __setitem__( + self, + key: Union[int, slice], + value: Union[BitVec, int, List[Union[int, BitVec]]], + ): """ :param key: @@ -130,13 +147,13 @@ class Memory: raise IndexError("Invalid Memory Slice") if step is None: step = 1 - + assert type(value) == list for i in range(0, stop - start, step): - self[start + i] = value[i] + self[start + i] = cast(List[Union[int, BitVec]], value)[i] else: if isinstance(value, int): assert 0 <= value <= 0xFF if isinstance(value, BitVec): assert value.size() == 8 - self._memory[key] = value + self._memory[key] = cast(Union[int, BitVec], value) diff --git a/mythril/laser/ethereum/state/world_state.py b/mythril/laser/ethereum/state/world_state.py index 7bcd06b4..9b5bdcaa 100644 --- a/mythril/laser/ethereum/state/world_state.py +++ b/mythril/laser/ethereum/state/world_state.py @@ -1,11 +1,14 @@ """This module contains a representation of the EVM's world state.""" from copy import copy from random import randint -from typing import List, Iterator +from typing import Dict, List, Iterator, Optional, TYPE_CHECKING from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.annotation import StateAnnotation +if TYPE_CHECKING: + from mythril.laser.ethereum.cfg import Node + class WorldState: """The WorldState class represents the world state as described in the @@ -19,8 +22,8 @@ class WorldState: :param transaction_sequence: :param annotations: """ - self.accounts = {} - self.node = None + self.accounts = {} # type: Dict[str, Account] + self.node = None # type: Optional['Node'] self.transaction_sequence = transaction_sequence or [] self._annotations = annotations or [] diff --git a/mythril/laser/ethereum/strategy/basic.py b/mythril/laser/ethereum/strategy/basic.py index 6627241c..5930f4e6 100644 --- a/mythril/laser/ethereum/strategy/basic.py +++ b/mythril/laser/ethereum/strategy/basic.py @@ -1,5 +1,6 @@ """This module implements basic symbolic execution search strategies.""" from random import randrange +from typing import List from mythril.laser.ethereum.state.global_state import GlobalState from . import BasicSearchStrategy @@ -13,7 +14,10 @@ except ImportError: from random import random from bisect import bisect - def choices(population, weights=None): + # TODO: Remove ignore after this has been fixed: https://github.com/python/mypy/issues/1297 + def choices( # type: ignore + population: List, weights: List[int] = None + ) -> List[int]: """Returns a random element out of the population based on weight. If the relative weights or cumulative weights are not specified, @@ -21,7 +25,7 @@ except ImportError: """ if weights is None: return [population[int(random() * len(population))]] - cum_weights = accumulate(weights) + cum_weights = list(accumulate(weights)) return [ population[ bisect(cum_weights, random() * cum_weights[-1], 0, len(population) - 1) diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index 7b2a39d8..491719ea 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -4,7 +4,7 @@ from collections import defaultdict from copy import copy from datetime import datetime, timedelta from functools import reduce -from typing import Callable, Dict, List, Tuple, Union +from typing import Callable, Dict, DefaultDict, List, Tuple, Union from mythril.laser.ethereum.cfg import NodeFlags, Node, Edge, JumpType from mythril.laser.ethereum.evm_exceptions import StackUnderflowException @@ -56,7 +56,7 @@ class LaserEVM: transaction_count=2, requires_statespace=True, enable_iprof=False, - ): + ) -> None: """ :param accounts: @@ -73,12 +73,12 @@ class LaserEVM: self.world_state = world_state self.open_states = [world_state] - self.coverage = {} + self.coverage = {} # type: Dict[str, Tuple[int, List[bool]]] self.total_states = 0 self.dynamic_loader = dynamic_loader - self.work_list = [] + self.work_list = [] # type: List[GlobalState] self.strategy = strategy(self.work_list, max_depth) self.max_depth = max_depth self.transaction_count = transaction_count @@ -88,14 +88,15 @@ class LaserEVM: self.requires_statespace = requires_statespace if self.requires_statespace: - self.nodes = {} - self.edges = [] + self.nodes = {} # type: Dict[int, Node] + self.edges = [] # type: List[Edge] - self.time = None + self.time = None # type: datetime - self.pre_hooks = defaultdict(list) - self.post_hooks = defaultdict(list) - self._add_world_state_hooks = [] + self.pre_hooks = defaultdict(list) # type: DefaultDict[str, List[Callable]] + self.post_hooks = defaultdict(list) # type: DefaultDict[str, List[Callable]] + + self._add_world_state_hooks = [] # type: List[Callable] self.iprof = InstructionProfiler() if enable_iprof else None log.info("LASER EVM initialized with dynamic loader: " + str(dynamic_loader)) @@ -153,11 +154,8 @@ class LaserEVM: self.total_states, ) for code, coverage in self.coverage.items(): - cov = ( - reduce(lambda sum_, val: sum_ + 1 if val else sum_, coverage[1]) - / float(coverage[0]) - * 100 - ) + cov = sum(coverage[1]) / float(coverage[0]) * 100 + log.info("Achieved {:.2f}% coverage for code: {}".format(cov, code)) if self.iprof is not None: @@ -198,9 +196,7 @@ class LaserEVM: """ total_covered_instructions = 0 for _, cv in self.coverage.items(): - total_covered_instructions += reduce( - lambda sum_, val: sum_ + 1 if val else sum_, cv[1] - ) + total_covered_instructions += sum(cv[1]) return total_covered_instructions def exec(self, create=False, track_gas=False) -> Union[List[GlobalState], None]: @@ -210,7 +206,7 @@ class LaserEVM: :param track_gas: :return: """ - final_states = [] + final_states = [] # type: List[GlobalState] for global_state in self.strategy: if ( self.create_timeout @@ -385,10 +381,10 @@ class LaserEVM: instruction_index = global_state.mstate.pc if code not in self.coverage.keys(): - self.coverage[code] = [ + self.coverage[code] = ( number_of_instructions, [False] * number_of_instructions, - ] + ) self.coverage[code][1][instruction_index] = True diff --git a/mythril/laser/ethereum/transaction/transaction_models.py b/mythril/laser/ethereum/transaction/transaction_models.py index a915c0e3..def9a494 100644 --- a/mythril/laser/ethereum/transaction/transaction_models.py +++ b/mythril/laser/ethereum/transaction/transaction_models.py @@ -3,7 +3,7 @@ execution.""" import array from z3 import ExprRef -from typing import Union +from typing import Union, Optional, cast from mythril.laser.ethereum.state.calldata import ConcreteCalldata from mythril.laser.ethereum.state.account import Account @@ -17,20 +17,20 @@ from mythril.laser.smt import symbol_factory _next_transaction_id = 0 -def get_next_transaction_id() -> int: +def get_next_transaction_id() -> str: """ :return: """ global _next_transaction_id _next_transaction_id += 1 - return _next_transaction_id + return str(_next_transaction_id) class TransactionEndSignal(Exception): """Exception raised when a transaction is finalized.""" - def __init__(self, global_state: GlobalState, revert=False): + def __init__(self, global_state: GlobalState, revert=False) -> None: self.global_state = global_state self.revert = revert @@ -42,7 +42,7 @@ class TransactionStartSignal(Exception): self, transaction: Union["MessageCallTransaction", "ContractCreationTransaction"], op_code: str, - ): + ) -> None: self.transaction = transaction self.op_code = op_code @@ -56,14 +56,14 @@ class BaseTransaction: callee_account: Account = None, caller: ExprRef = None, call_data=None, - identifier=None, + identifier: Optional[str] = None, gas_price=None, gas_limit=None, origin=None, code=None, call_value=None, init_call_data=True, - ): + ) -> None: assert isinstance(world_state, WorldState) self.world_state = world_state self.id = identifier or get_next_transaction_id() @@ -85,7 +85,7 @@ class BaseTransaction: self.caller = caller self.callee_account = callee_account if call_data is None and init_call_data: - self.call_data = SymbolicCalldata(self.id) + self.call_data = SymbolicCalldata(self.id) # type: BaseCalldata else: self.call_data = ( call_data @@ -99,7 +99,7 @@ class BaseTransaction: else symbol_factory.BitVecSym("callvalue{}".format(identifier), 256) ) - self.return_data = None + self.return_data = None # type: str def initial_global_state_from_environment(self, environment, active_function): """ @@ -117,7 +117,7 @@ class BaseTransaction: class MessageCallTransaction(BaseTransaction): """Transaction object models an transaction.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) def initial_global_state(self) -> GlobalState: @@ -149,8 +149,9 @@ class MessageCallTransaction(BaseTransaction): class ContractCreationTransaction(BaseTransaction): """Transaction object models an transaction.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, init_call_data=False) + def __init__(self, *args, **kwargs) -> None: + # Remove ignore after https://github.com/python/mypy/issues/4335 is fixed + super().__init__(*args, **kwargs, init_call_data=False) # type: ignore # TODO: set correct balance for new account self.callee_account = self.callee_account or self.world_state.create_account( 0, concrete_storage=True diff --git a/mythril/laser/ethereum/util.py b/mythril/laser/ethereum/util.py index 880e04c4..9cb5d950 100644 --- a/mythril/laser/ethereum/util.py +++ b/mythril/laser/ethereum/util.py @@ -1,9 +1,10 @@ """This module contains various utility conversion functions and constants for LASER.""" import re -from typing import Dict, List, Union +from typing import Dict, List, Union, TYPE_CHECKING, cast -import sha3 as _sha3 +if TYPE_CHECKING: + from mythril.laser.ethereum.state.machine_state import MachineState from mythril.laser.smt import BitVec, Bool, Expression, If, simplify, symbol_factory @@ -12,15 +13,6 @@ TT256M1 = 2 ** 256 - 1 TT255 = 2 ** 255 -def sha3(seed: str) -> bytes: - """ - - :param seed: - :return: - """ - return _sha3.keccak_256(bytes(seed)).digest() - - def safe_decode(hex_encoded_string: str) -> bytes: """ @@ -83,18 +75,16 @@ def pop_bitvec(state: "MachineState") -> BitVec: item = state.stack.pop() - if type(item) == Bool: + if isinstance(item, Bool): return If( - item, symbol_factory.BitVecVal(1, 256), symbol_factory.BitVecVal(0, 256) + cast(Bool, item), + symbol_factory.BitVecVal(1, 256), + symbol_factory.BitVecVal(0, 256), ) - elif type(item) == bool: - if item: - return symbol_factory.BitVecVal(1, 256) - else: - return symbol_factory.BitVecVal(0, 256) - elif type(item) == int: + elif isinstance(item, int): return symbol_factory.BitVecVal(item, 256) else: + item = cast(BitVec, item) return simplify(item) @@ -116,8 +106,12 @@ def get_concrete_int(item: Union[int, Expression]) -> int: raise TypeError("Symbolic boolref encountered") return value + assert False, "Unhandled type {} encountered".format(str(type(item))) + -def concrete_int_from_bytes(concrete_bytes: bytes, start_index: int) -> int: +def concrete_int_from_bytes( + concrete_bytes: Union[List[Union[BitVec, int]], bytes], start_index: int +) -> int: """ :param concrete_bytes: @@ -130,7 +124,8 @@ def concrete_int_from_bytes(concrete_bytes: bytes, start_index: int) -> int: ] integer_bytes = concrete_bytes[start_index : start_index + 32] - return int.from_bytes(integer_bytes, byteorder="big") + # The below statement is expected to fail in some circumstances whose error is caught + return int.from_bytes(integer_bytes, byteorder="big") # type: ignore def concrete_int_to_bytes(val): diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index feeac40b..8347352c 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -11,6 +11,7 @@ Annotations = List[Any] # fmt: off + class BitVec(Expression[z3.BitVecRef]): """A bit vector symbol.""" @@ -139,6 +140,24 @@ class BitVec(Expression[z3.BitVecRef]): union = self.annotations + other.annotations return Bool(self.raw > other.raw, annotations=union) + def __le__(self, other: "BitVec") -> Bool: + """Create a signed less than expression. + + :param other: + :return: + """ + union = self.annotations + other.annotations + return Bool(self.raw <= other.raw, annotations=union) + + def __ge__(self, other: "BitVec") -> Bool: + """Create a signed greater than expression. + + :param other: + :return: + """ + union = self.annotations + other.annotations + return Bool(self.raw >= other.raw, annotations=union) + # MYPY: fix complains about overriding __eq__ def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore """Create an equality expression. diff --git a/mythril/laser/smt/bool.py b/mythril/laser/smt/bool.py index c82e4f1f..9fa097e4 100644 --- a/mythril/laser/smt/bool.py +++ b/mythril/laser/smt/bool.py @@ -1,7 +1,7 @@ """This module provides classes for an SMT abstraction of boolean expressions.""" -from typing import Union, cast +from typing import Union, cast, List import z3 @@ -81,13 +81,13 @@ class Bool(Expression[z3.BoolRef]): return False -def And(*args: Bool) -> Bool: +def And(*args: Union[Bool, bool]) -> Bool: """Create an And expression.""" union = [] - args = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args] - for arg in args: + args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args] + for arg in args_list: union.append(arg.annotations) - return Bool(z3.And([a.raw for a in args]), union) + return Bool(z3.And([a.raw for a in args_list]), union) def Or(a: Bool, b: Bool) -> Bool: diff --git a/mythril/laser/smt/expression.py b/mythril/laser/smt/expression.py index 2ed166c7..8e9e697e 100644 --- a/mythril/laser/smt/expression.py +++ b/mythril/laser/smt/expression.py @@ -46,7 +46,10 @@ class Expression(Generic[T]): return repr(self.raw) -def simplify(expression: Expression) -> Expression: +G = TypeVar("G", bound=Expression) + + +def simplify(expression: G) -> G: """Simplify the expression . :param expression: diff --git a/mythril/support/loader.py b/mythril/support/loader.py index d55bc185..6e434439 100644 --- a/mythril/support/loader.py +++ b/mythril/support/loader.py @@ -22,7 +22,7 @@ class DynLoader: self.contract_loading = contract_loading self.storage_loading = storage_loading - def read_storage(self, contract_address, index): + def read_storage(self, contract_address: str, index: int): """ :param contract_address: diff --git a/mythril/support/signatures.py b/mythril/support/signatures.py index 530f0137..e0deb9ea 100644 --- a/mythril/support/signatures.py +++ b/mythril/support/signatures.py @@ -7,7 +7,7 @@ import sqlite3 import time from collections import defaultdict from subprocess import PIPE, Popen -from typing import List +from typing import List, Set, DefaultDict, Dict from mythril.exceptions import CompilerError @@ -45,7 +45,7 @@ def synchronized(sync_lock): class Singleton(type): """A metaclass type implementing the singleton pattern.""" - _instances = {} + _instances = dict() # type: Dict[Singleton, Singleton] @synchronized(lock) def __call__(cls, *args, **kwargs): @@ -60,6 +60,7 @@ class Singleton(type): """ if cls not in cls._instances: cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] @@ -120,12 +121,12 @@ class SignatureDB(object, metaclass=Singleton): :param path: """ self.enable_online_lookup = enable_online_lookup - self.online_lookup_miss = set() + self.online_lookup_miss = set() # type: Set[str] self.online_lookup_timeout = 0 # if we're analysing a Solidity file, store its hashes # here to prevent unnecessary lookups - self.solidity_sigs = defaultdict(list) + self.solidity_sigs = defaultdict(list) # type: DefaultDict[str, List[str]] if path is None: self.path = os.environ.get("MYTHRIL_DIR") or os.path.join( os.path.expanduser("~"), ".mythril" @@ -225,7 +226,7 @@ class SignatureDB(object, metaclass=Singleton): return text_sigs except FourByteDirectoryOnlineLookupError as fbdole: # wait at least 2 mins to try again - self.online_lookup_timeout = time.time() + 2 * 60 + self.online_lookup_timeout = int(time.time()) + 2 * 60 log.warning("Online lookup failed, not retrying for 2min: %s", fbdole) return [] diff --git a/mythril/support/support_utils.py b/mythril/support/support_utils.py index 9fe90a8f..b437d795 100644 --- a/mythril/support/support_utils.py +++ b/mythril/support/support_utils.py @@ -1,10 +1,11 @@ """This module contains utility functions for the Mythril support package.""" +from typing import Dict class Singleton(type): """A metaclass type implementing the singleton pattern.""" - _instances = {} + _instances = {} # type: Dict def __call__(cls, *args, **kwargs): """Delegate the call to an existing resource or a a new one. diff --git a/setup.py b/setup.py index 82a6fbed..1f0a0a02 100755 --- a/setup.py +++ b/setup.py @@ -98,7 +98,7 @@ setup( "persistent>=4.2.0", "ethereum-input-decoder>=0.2.2", ], - tests_require=["pytest>=3.6.0", "pytest_mock", "pytest-cov"], + tests_require=["mypy", "pytest>=3.6.0", "pytest_mock", "pytest-cov"], python_requires=">=3.5", extras_require={}, package_data={"mythril.analysis.templates": ["*"], "mythril.support.assets": ["*"]}, diff --git a/tox.ini b/tox.ini index aeab98ca..93ee48a2 100644 --- a/tox.ini +++ b/tox.ini @@ -19,12 +19,14 @@ basepython = python3.6 setenv = COVERAGE_FILE = .coverage.{envname} deps = + mypy pytest pytest-mock pytest-cov passenv = MYTHRIL_DIR = {homedir} whitelist_externals = mkdir commands = + mypy --follow-imports=silent --warn-unused-ignores --ignore-missing-imports --no-strict-optional mythril mkdir -p {toxinidir}/tests/testdata/outputs_current/ mkdir -p {toxinidir}/tests/testdata/outputs_current_laser_result/ py.test -v \ @@ -35,6 +37,8 @@ commands = --junitxml={toxworkdir}/output/{envname}/junit.xml \ {posargs} + + [coverage:report] omit = *__init__.py