From bd27897533d619305099274d636c65e122d6ab96 Mon Sep 17 00:00:00 2001 From: Nikhil Parasaram Date: Sun, 10 Sep 2023 15:59:29 +0100 Subject: [PATCH] Misc fixes (#1799) * Misc fixes * Fix singleton --- .../analysis/module/modules/external_calls.py | 4 +-- mythril/analysis/module/modules/integer.py | 6 ++-- .../analysis/module/modules/multiple_sends.py | 2 +- .../modules/state_change_external_calls.py | 4 +-- mythril/analysis/module/util.py | 2 +- mythril/analysis/report.py | 6 ++-- mythril/analysis/security.py | 4 +-- mythril/analysis/solver.py | 6 ++-- mythril/analysis/symbolic.py | 4 +-- mythril/concolic/find_trace.py | 2 +- mythril/disassembler/asm.py | 4 +-- mythril/disassembler/disassembly.py | 12 ++++---- mythril/ethereum/util.py | 5 ---- mythril/interfaces/cli.py | 2 +- mythril/laser/ethereum/cfg.py | 2 +- .../keccak_function_manager.py | 2 +- mythril/laser/ethereum/instructions.py | 30 +++++++++---------- mythril/laser/ethereum/state/account.py | 4 +-- mythril/laser/ethereum/state/calldata.py | 6 ++-- mythril/laser/ethereum/state/constraints.py | 4 +-- mythril/laser/ethereum/state/memory.py | 4 +-- mythril/laser/ethereum/state/world_state.py | 4 +-- mythril/laser/ethereum/strategy/__init__.py | 4 +-- .../ethereum/strategy/constraint_strategy.py | 17 +++++------ .../strategy/extensions/bounded_loops.py | 4 +-- mythril/laser/ethereum/svm.py | 6 ++-- .../transaction/transaction_models.py | 8 +++-- mythril/laser/ethereum/util.py | 2 +- mythril/laser/plugin/loader.py | 6 ++-- .../plugins/coverage/coverage_plugin.py | 4 +-- .../laser/plugin/plugins/dependency_pruner.py | 8 ++--- .../plugin/plugins/plugin_annotations.py | 10 +++---- mythril/laser/smt/bitvec_helper.py | 2 +- mythril/laser/smt/bool.py | 4 +-- mythril/laser/smt/model.py | 2 +- .../laser/smt/solver/independence_solver.py | 22 +++++++------- mythril/laser/smt/solver/solver.py | 8 ++--- mythril/mythril/mythril_analyzer.py | 6 ++-- mythril/mythril/mythril_config.py | 4 +-- mythril/mythril/mythril_disassembler.py | 7 ++--- mythril/plugin/discovery.py | 2 +- mythril/plugin/loader.py | 2 +- mythril/solidity/features.py | 18 +++++------ mythril/support/loader.py | 4 +-- mythril/support/model.py | 10 +++++-- mythril/support/signatures.py | 6 ++-- mythril/support/source_support.py | 2 +- mythril/support/support_utils.py | 11 ++++--- tests/features_test.py | 1 - tests/laser/evm_testsuite/evm_test.py | 4 +-- tests/laser/tx_prioritisation_test.py | 1 - 51 files changed, 148 insertions(+), 156 deletions(-) diff --git a/mythril/analysis/module/modules/external_calls.py b/mythril/analysis/module/modules/external_calls.py index 82e8825c..b1ce8009 100644 --- a/mythril/analysis/module/modules/external_calls.py +++ b/mythril/analysis/module/modules/external_calls.py @@ -27,8 +27,8 @@ Search for external calls with unrestricted gas to a user-specified address. def _is_precompile_call(global_state: GlobalState): - to = global_state.mstate.stack[-2] # type: BitVec - constraints = copy(global_state.world_state.constraints) + to: BitVec = global_state.mstate.stack[-2] + constraints: Constraints = copy(global_state.world_state.constraints) constraints += [ Or( to < symbol_factory.BitVecVal(1, 256), diff --git a/mythril/analysis/module/modules/integer.py b/mythril/analysis/module/modules/integer.py index 75beda52..5afacda4 100644 --- a/mythril/analysis/module/modules/integer.py +++ b/mythril/analysis/module/modules/integer.py @@ -50,7 +50,7 @@ class OverUnderflowStateAnnotation(StateAnnotation): """State Annotation used if an overflow is both possible and used in the annotated path""" def __init__(self) -> None: - self.overflowing_state_annotations = set() # type: Set[OverUnderflowAnnotation] + self.overflowing_state_annotations: Set[OverUnderflowAnnotation] = set() def __copy__(self): new_annotation = OverUnderflowStateAnnotation() @@ -91,8 +91,8 @@ class IntegerArithmetics(DetectionModule): """ super().__init__() - self._ostates_satisfiable = set() # type: Set[GlobalState] - self._ostates_unsatisfiable = set() # type: Set[GlobalState] + self._ostates_satisfiable: Set[GlobalState] = set() + self._ostates_unsatisfiable: Set[GlobalState] = set() def reset_module(self): """ diff --git a/mythril/analysis/module/modules/multiple_sends.py b/mythril/analysis/module/modules/multiple_sends.py index 97e39ef2..c46bd077 100644 --- a/mythril/analysis/module/modules/multiple_sends.py +++ b/mythril/analysis/module/modules/multiple_sends.py @@ -17,7 +17,7 @@ log = logging.getLogger(__name__) class MultipleSendsAnnotation(StateAnnotation): def __init__(self) -> None: - self.call_offsets = [] # type: List[int] + self.call_offsets: List[int] = [] def __copy__(self): result = MultipleSendsAnnotation() diff --git a/mythril/analysis/module/modules/state_change_external_calls.py b/mythril/analysis/module/modules/state_change_external_calls.py index cea6b72d..a74bdae9 100644 --- a/mythril/analysis/module/modules/state_change_external_calls.py +++ b/mythril/analysis/module/modules/state_change_external_calls.py @@ -29,7 +29,7 @@ STATE_READ_WRITE_LIST = ["SSTORE", "SLOAD", "CREATE", "CREATE2"] class StateChangeCallsAnnotation(StateAnnotation): def __init__(self, call_state: GlobalState, user_defined_address: bool) -> None: self.call_state = call_state - self.state_change_states = [] # type: List[GlobalState] + self.state_change_states: List[GlobalState] = [] self.user_defined_address = user_defined_address def __copy__(self): @@ -165,7 +165,7 @@ class StateChangeAfterCall(DetectionModule): # Record state changes following from a transfer of ether if op_code in CALL_LIST: - value = global_state.mstate.stack[-3] # type: BitVec + value: BitVec = global_state.mstate.stack[-3] if StateChangeAfterCall._balance_change(value, global_state): for annotation in annotations: annotation.state_change_states.append(global_state) diff --git a/mythril/analysis/module/util.py b/mythril/analysis/module/util.py index 1c9b3b97..a8121f06 100644 --- a/mythril/analysis/module/util.py +++ b/mythril/analysis/module/util.py @@ -19,7 +19,7 @@ def get_detection_module_hooks( :param hook_type: The type of hooks to retrieve (default: "pre") :return: Dictionary with discovered hooks """ - hook_dict = defaultdict(list) # type: Mapping[str, List[Callable]] + hook_dict: Mapping[str, List[Callable]] = defaultdict(list) for module in modules: hooks = module.pre_hooks if hook_type == "pre" else module.post_hooks diff --git a/mythril/analysis/report.py b/mythril/analysis/report.py index 2e310000..fb7850d9 100644 --- a/mythril/analysis/report.py +++ b/mythril/analysis/report.py @@ -254,9 +254,9 @@ class Report: :param contracts: :param exceptions: """ - self.issues = {} # type: Dict[bytes, Issue] + self.issues: Dict[bytes, Issue] = {} self.solc_version = "" - self.meta = {} # type: Dict[str, Any] + self.meta: Dict[str, Any] = {} self.source = Source() self.source.get_source_from_contracts_list(contracts) self.exceptions = exceptions or [] @@ -306,7 +306,7 @@ class Report: def _get_exception_data(self) -> dict: if not self.exceptions: return {} - logs = [] # type: List[Dict] + logs: List[Dict] = [] for exception in self.exceptions: logs += [{"level": "error", "hidden": True, "msg": exception}] return {"logs": logs} diff --git a/mythril/analysis/security.py b/mythril/analysis/security.py index e53da210..cd81fe31 100644 --- a/mythril/analysis/security.py +++ b/mythril/analysis/security.py @@ -13,7 +13,7 @@ log = logging.getLogger(__name__) def retrieve_callback_issues(white_list: Optional[List[str]] = None) -> List[Issue]: """Get the issues discovered by callback type detection modules""" - issues = [] # type: List[Issue] + issues: List[Issue] = [] for module in ModuleLoader().get_detection_modules( entry_point=EntryPoint.CALLBACK, white_list=white_list ): @@ -34,7 +34,7 @@ def fire_lasers(statespace, white_list: Optional[List[str]] = None) -> List[Issu """ log.info("Starting analysis") - issues = [] # type: List[Issue] + issues: List[Issue] = [] for module in ModuleLoader().get_detection_modules( entry_point=EntryPoint.POST, white_list=white_list ): diff --git a/mythril/analysis/solver.py b/mythril/analysis/solver.py index 58077870..df4d1076 100644 --- a/mythril/analysis/solver.py +++ b/mythril/analysis/solver.py @@ -36,7 +36,7 @@ def pretty_print_model(model): ret = "" for d in model.decls(): - if type(model[d]) == FuncInterp: + if isinstance(model[d], FuncInterp): condition = model[d].as_list() ret += "%s: %s\n" % (d.name(), condition) continue @@ -119,7 +119,7 @@ def _add_calldata_placeholder( if not isinstance(transaction_sequence[0], ContractCreationTransaction): return - if type(transaction_sequence[0].code.bytecode) == tuple: + if isinstance(transaction_sequence[0].code.bytecode, tuple): code_len = len(transaction_sequence[0].code.bytecode) * 2 else: code_len = len(transaction_sequence[0].code.bytecode) @@ -206,7 +206,7 @@ def _get_concrete_transaction(model: z3.Model, transaction: BaseTransaction): ) # Create concrete transaction dict - concrete_transaction = dict() # type: Dict[str, str] + concrete_transaction: Dict[str, str] = dict() concrete_transaction["input"] = "0x" + input_ concrete_transaction["value"] = "0x%x" % value # Fixme: base origin assignment on origin symbol diff --git a/mythril/analysis/symbolic.py b/mythril/analysis/symbolic.py index 561a75dd..e7cb4bd5 100644 --- a/mythril/analysis/symbolic.py +++ b/mythril/analysis/symbolic.py @@ -85,7 +85,7 @@ class SymExecWrapper: address = symbol_factory.BitVecVal(address, 256) beam_width = None if strategy == "dfs": - s_strategy = DepthFirstSearchStrategy # type: Type[BasicSearchStrategy] + s_strategy: Type[BasicSearchStrategy] = DepthFirstSearchStrategy elif strategy == "bfs": s_strategy = BreadthFirstSearchStrategy elif strategy == "naive-random": @@ -240,7 +240,7 @@ class SymExecWrapper: # Parse calls to make them easily accessible - self.calls = [] # type: List[Call] + self.calls: List[Call] = [] for key in self.nodes: diff --git a/mythril/concolic/find_trace.py b/mythril/concolic/find_trace.py index f2e1f216..5700f4be 100644 --- a/mythril/concolic/find_trace.py +++ b/mythril/concolic/find_trace.py @@ -30,7 +30,7 @@ def setup_concrete_initial_state(concrete_data: ConcreteData) -> WorldState: account = Account(address, concrete_storage=True) account.code = Disassembly(details["code"][2:]) account.nonce = details["nonce"] - if type(details["storage"]) == str: + if isinstance(type(details["storage"]), str): details["storage"] = eval(details["storage"]) # type: ignore for key, value in details["storage"].items(): key_bitvec = symbol_factory.BitVecVal(int(key, 16), 256) diff --git a/mythril/disassembler/asm.py b/mythril/disassembler/asm.py index f0c388b8..c05197ff 100644 --- a/mythril/disassembler/asm.py +++ b/mythril/disassembler/asm.py @@ -106,7 +106,7 @@ def disassemble(bytecode) -> list: address = 0 length = len(bytecode) - if type(bytecode) == str: + if isinstance(bytecode, str): bytecode = util.safe_decode(bytecode) length = len(bytecode) part_code = bytecode[-43:] @@ -135,7 +135,7 @@ def disassemble(bytecode) -> list: match = re.search(regex_PUSH, op_code) if match: argument_bytes = bytecode[address + 1 : address + 1 + int(match.group(1))] - if type(argument_bytes) == bytes: + if isinstance(argument_bytes, bytes): current_instruction.argument = "0x" + argument_bytes.hex() else: current_instruction.argument = argument_bytes diff --git a/mythril/disassembler/disassembly.py b/mythril/disassembler/disassembly.py index 2376be85..99b2709d 100644 --- a/mythril/disassembler/disassembly.py +++ b/mythril/disassembler/disassembly.py @@ -23,13 +23,13 @@ class Disassembly(object): :param enable_online_lookup: """ self.bytecode = code - if type(code) == str: + if isinstance(code, str): self.instruction_list = asm.disassemble(util.safe_decode(code)) else: self.instruction_list = asm.disassemble(code) - self.func_hashes = [] # type: List[str] - self.function_name_to_address = {} # type: Dict[str, int] - self.address_to_function_name = {} # type: Dict[int, str] + self.func_hashes: List[str] = [] + self.function_name_to_address: Dict[str, int] = {} + self.address_to_function_name: Dict[int, str] = {} self.enable_online_lookup = enable_online_lookup self.assign_bytecode(bytecode=code) @@ -84,7 +84,7 @@ def get_function_info( """ # Append with missing 0s at the beginning - if type(instruction_list[index]["argument"]) == tuple: + if isinstance(instruction_list[index]["argument"], tuple): try: function_hash = "0x" + bytes( instruction_list[index]["argument"] @@ -105,7 +105,7 @@ def get_function_info( try: offset = instruction_list[index + 2]["argument"] - if type(offset) == tuple: + if isinstance(offset, tuple): offset = bytes(offset).hex() entry_point = int(offset, 16) except (KeyError, IndexError): diff --git a/mythril/ethereum/util.py b/mythril/ethereum/util.py index 4045dee6..a053f496 100644 --- a/mythril/ethereum/util.py +++ b/mythril/ethereum/util.py @@ -216,11 +216,6 @@ def extract_version(file: typing.Optional[str]): continue return str(version) - else: - return None - - return None - def extract_binary(file: str) -> str: file_data = None diff --git a/mythril/interfaces/cli.py b/mythril/interfaces/cli.py index 73d9403e..38afbc36 100644 --- a/mythril/interfaces/cli.py +++ b/mythril/interfaces/cli.py @@ -322,7 +322,7 @@ def main() -> None: formatter_class=RawTextHelpFormatter, ) - list_detectors_parser = subparsers.add_parser( + _ = subparsers.add_parser( LIST_DETECTORS_COMMAND, parents=[output_parser], help="Lists available detection modules", diff --git a/mythril/laser/ethereum/cfg.py b/mythril/laser/ethereum/cfg.py index af9a2152..e7456e11 100644 --- a/mythril/laser/ethereum/cfg.py +++ b/mythril/laser/ethereum/cfg.py @@ -48,7 +48,7 @@ class Node: constraints = constraints if constraints else Constraints() self.contract_name = contract_name self.start_addr = start_addr - self.states = [] # type: List[GlobalState] + self.states: List[GlobalState] = [] self.constraints = constraints self.function_name = function_name self.flags = NodeFlags() diff --git a/mythril/laser/ethereum/function_managers/keccak_function_manager.py b/mythril/laser/ethereum/function_managers/keccak_function_manager.py index 4a8d04da..a04fad10 100644 --- a/mythril/laser/ethereum/function_managers/keccak_function_manager.py +++ b/mythril/laser/ethereum/function_managers/keccak_function_manager.py @@ -135,7 +135,7 @@ class KeccakFunctionManager: :param model: The z3 model to query for concrete values :return: A dictionary with concrete hashes { : [, ]} """ - concrete_hashes = {} # type: Dict[int, List[Optional[int]]] + concrete_hashes: Dict[int, List[Optional[int]]] = {} for size in self.hash_result_store: concrete_hashes[size] = [] for val in self.hash_result_store[size]: diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index c50c08fe..7a8f3996 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -293,14 +293,14 @@ class Instruction: if length_of_value == 0: global_state.mstate.stack.append(symbol_factory.BitVecVal(0, 256)) - elif type(push_value) == tuple: - if type(push_value[0]) == int: + elif isinstance(push_value, tuple): + if isinstance(push_value[0], int): new_value = symbol_factory.BitVecVal(push_value[0], 8) else: new_value = push_value[0] if len(push_value) > 1: for val in push_value[1:]: - if type(val) == int: + if isinstance(val, int): new_value = Concat(new_value, symbol_factory.BitVecVal(val, 8)) else: new_value = Concat(new_value, val) @@ -442,12 +442,12 @@ class Instruction: index = util.get_concrete_int(op0) offset = (31 - index) * 8 if offset >= 0: - result = simplify( + result: Union[int, Expression] = simplify( Concat( symbol_factory.BitVecVal(0, 248), Extract(offset + 7, offset, op1), ) - ) # type: Union[int, Expression] + ) else: result = 0 except TypeError: @@ -818,13 +818,13 @@ class Instruction: return [global_state] try: - dstart = util.get_concrete_int(dstart) # type: Union[int, BitVec] + dstart: Union[int, BitVec] = util.get_concrete_int(dstart) except TypeError: log.debug("Unsupported symbolic calldata offset in CALLDATACOPY") dstart = simplify(dstart) try: - size = util.get_concrete_int(size) # type: Union[int, BitVec] + size: Union[int, BitVec] = util.get_concrete_int(size) except TypeError: log.debug("Unsupported symbolic size in CALLDATACOPY") size = SYMBOLIC_CALLDATA_SIZE # The excess size will get overwritten @@ -1087,7 +1087,7 @@ class Instruction: global_state.mstate.stack.pop(), ) code = global_state.environment.code.bytecode - if code[0:2] == "0x": + if code.startswith("0x"): code = code[2:] code_size = len(code) // 2 if isinstance(global_state.current_transaction, ContractCreationTransaction): @@ -1227,7 +1227,7 @@ class Instruction: ) return [global_state] - if code[0:2] == "0x": + if isinstance(code, str) and code.startswith("0x"): code = code[2:] for i in range(concrete_size): @@ -1486,9 +1486,7 @@ class Instruction: state.mem_extend(offset, 1) try: - value_to_write = ( - util.get_concrete_int(value) % 256 - ) # type: Union[int, BitVec] + value_to_write: Union[int, BitVec] = util.get_concrete_int(value) % 256 except TypeError: # BitVec value_to_write = Extract(7, 0, value) @@ -1591,10 +1589,10 @@ class Instruction: condi = simplify(condition) if isinstance(condition, Bool) else condition != 0 condi.simplify() - negated_cond = (type(negated) == bool and negated) or ( + negated_cond = (isinstance(negated, bool) and negated) or ( isinstance(negated, Bool) and not is_false(negated) ) - positive_cond = (type(condi) == bool and condi) or ( + positive_cond = (isinstance(condi, bool) and condi) or ( isinstance(condi, Bool) and not is_false(condi) ) @@ -1734,7 +1732,7 @@ class Instruction: world_state = global_state.world_state call_data = get_call_data(global_state, mem_offset, mem_offset + mem_size) - code_raw = [] + code_raw: List[int] = [] code_end = call_data.size size = call_data.size @@ -1776,7 +1774,7 @@ class Instruction: gas_price = environment.gasprice origin = environment.origin - contract_address = None # type: Union[BitVec, int] + contract_address: Union[BitVec, int] = None Instruction._sha3_gas_helper(global_state, len(code_str[2:]) // 2) if create2_salt: diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index 1c738767..481d6c93 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -204,11 +204,11 @@ class Account: } def serialised_code(self): - if type(self.code.bytecode) == str: + if isinstance(self.code.bytecode, str): return self.code.bytecode new_code = "0x" for byte in self.code.bytecode: - if type(byte) == int: + if isinstance(byte, int): new_code += hex(byte) else: new_code += "" diff --git a/mythril/laser/ethereum/state/calldata.py b/mythril/laser/ethereum/state/calldata.py index 50d46166..0f774a34 100644 --- a/mythril/laser/ethereum/state/calldata.py +++ b/mythril/laser/ethereum/state/calldata.py @@ -278,14 +278,14 @@ class BasicSymbolicCalldata(BaseCalldata): :param tx_id: Id of the transaction that the calldata is for. """ - self._reads = [] # type: List[Tuple[Union[int, BitVec], BitVec]] + self._reads: List[Tuple[Union[int, BitVec], BitVec]] = [] self._size = symbol_factory.BitVecSym(str(tx_id) + "_calldatasize", 256) super().__init__(tx_id) def _load(self, item: Union[int, BitVec], clean=False) -> Any: - expr_item = ( + expr_item: BitVec = ( symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item - ) # type: BitVec + ) symbolic_base_value = If( expr_item >= self._size, diff --git a/mythril/laser/ethereum/state/constraints.py b/mythril/laser/ethereum/state/constraints.py index 30efa93e..9afa5c76 100644 --- a/mythril/laser/ethereum/state/constraints.py +++ b/mythril/laser/ethereum/state/constraints.py @@ -4,7 +4,7 @@ from mythril.exceptions import UnsatError, SolverTimeOutException from mythril.laser.smt import symbol_factory, simplify, Bool from mythril.support.model import get_model from mythril.laser.ethereum.function_managers import keccak_function_manager - +from mythril.laser.smt.model import Model from copy import copy from typing import Iterable, List, Optional, Union @@ -42,7 +42,7 @@ class Constraints(list): return False return True - def get_model(self, solver_timeout=None) -> bool: + def get_model(self, solver_timeout=None) -> Optional[Model]: """ :param solver_timeout: The default timeout uses analysis timeout from args.solver_timeout :return: True/False based on the existence of solution of constraints diff --git a/mythril/laser/ethereum/state/memory.py b/mythril/laser/ethereum/state/memory.py index 0e0fb8f5..f725e2ae 100644 --- a/mythril/laser/ethereum/state/memory.py +++ b/mythril/laser/ethereum/state/memory.py @@ -31,7 +31,7 @@ class Memory: def __init__(self): """""" self._msize = 0 - self._memory = {} # type: Dict[BitVec, Union[int, BitVec]] + self._memory: Dict[BitVec, Union[int, BitVec]] = {} def __len__(self): """ @@ -179,7 +179,7 @@ class Memory: step = 1 else: assert False, "Currently mentioning step size is not supported" - assert type(value) == list + assert isinstance(value, list) bvstart, bvstop, bvstep = ( convert_bv(start), convert_bv(stop), diff --git a/mythril/laser/ethereum/state/world_state.py b/mythril/laser/ethereum/state/world_state.py index 4efc3518..1cc8cda1 100644 --- a/mythril/laser/ethereum/state/world_state.py +++ b/mythril/laser/ethereum/state/world_state.py @@ -29,12 +29,12 @@ class WorldState: :param transaction_sequence: :param annotations: """ - self._accounts = {} # type: Dict[int, Account] + self._accounts: Dict[int, Account] = {} self.balances = Array("balance", 256, 256) self.starting_balances = deepcopy(self.balances) self.constraints = constraints or Constraints() - self.node = None # type: Optional['Node'] + self.node: Optional["Node"] = None self.transaction_sequence = transaction_sequence or [] self._annotations = annotations or [] diff --git a/mythril/laser/ethereum/strategy/__init__.py b/mythril/laser/ethereum/strategy/__init__.py index 9671de54..5e262fb5 100644 --- a/mythril/laser/ethereum/strategy/__init__.py +++ b/mythril/laser/ethereum/strategy/__init__.py @@ -9,8 +9,8 @@ class BasicSearchStrategy(ABC): """ def __init__(self, work_list, max_depth, **kwargs): - self.work_list = work_list # type: List[GlobalState] - self.max_depth = max_depth + self.work_list: List[GlobalState] = work_list + self.max_depth: int = max_depth def __iter__(self): return self diff --git a/mythril/laser/ethereum/strategy/constraint_strategy.py b/mythril/laser/ethereum/strategy/constraint_strategy.py index 981ee6bf..b350cb25 100644 --- a/mythril/laser/ethereum/strategy/constraint_strategy.py +++ b/mythril/laser/ethereum/strategy/constraint_strategy.py @@ -28,12 +28,11 @@ class DelayConstraintStrategy(BasicSearchStrategy): :return: Global state """ - while True: - while len(self.work_list) == 0: - state = self.pending_worklist.pop(0) - model = state.world_state.constraints.get_model() - if model is not None: - self.model_cache.put(model, 1) - self.work_list.append(state) - state = self.work_list.pop(0) - return state + while len(self.work_list) == 0: + state = self.pending_worklist.pop(0) + model = state.world_state.constraints.get_model() + if model is not None: + self.model_cache.put(model, 1) + self.work_list.append(state) + state = self.work_list.pop(0) + return state diff --git a/mythril/laser/ethereum/strategy/extensions/bounded_loops.py b/mythril/laser/ethereum/strategy/extensions/bounded_loops.py index 604f84d9..948591a6 100644 --- a/mythril/laser/ethereum/strategy/extensions/bounded_loops.py +++ b/mythril/laser/ethereum/strategy/extensions/bounded_loops.py @@ -14,8 +14,8 @@ class JumpdestCountAnnotation(StateAnnotation): """State annotation that counts the number of jumps per destination.""" def __init__(self) -> None: - self._reached_count = {} # type: Dict[int, int] - self.trace = [] # type: List[int] + self._reached_count: Dict[int, int] = {} + self.trace: List[int] = [] def __copy__(self): result = JumpdestCountAnnotation() diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index 77a1acbe..d1b9fffe 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -238,7 +238,7 @@ class LaserEVM: for hook in self._stop_exec_trans_hooks: hook() - def _execute_transactions_ordered(self, address): + def _execute_transactions_non_ordered(self, address): """ This function executes multiple transactions non-incrementally, using some type priority ordering @@ -327,7 +327,7 @@ class LaserEVM: :param track_gas: :return: """ - final_states = [] # type: List[GlobalState] + final_states: List[GlobalState] = [] for hook in self._start_exec_hooks: hook() @@ -387,7 +387,7 @@ class LaserEVM: # exceptional halt all changes should be discarded, and this world state would not provide us with a # previously unseen world state log.debug("Encountered a VmException, ending path: `{}`".format(error_msg)) - new_global_states = [] # type: List[GlobalState] + new_global_states: List[GlobalState] = [] else: # First execute the post hook for the transaction ending instruction self._execute_post_hook(op_code, [global_state]) diff --git a/mythril/laser/ethereum/transaction/transaction_models.py b/mythril/laser/ethereum/transaction/transaction_models.py index eb43afe0..c75ad31c 100644 --- a/mythril/laser/ethereum/transaction/transaction_models.py +++ b/mythril/laser/ethereum/transaction/transaction_models.py @@ -105,7 +105,7 @@ class BaseTransaction: self.caller = caller self.callee_account = callee_account if call_data is None and init_call_data: - self.call_data = SymbolicCalldata(self.id) # type: BaseCalldata + self.call_data: BaseCalldata = SymbolicCalldata(self.id) else: self.call_data = ( call_data @@ -119,7 +119,7 @@ class BaseTransaction: else symbol_factory.BitVecSym(f"callvalue{identifier}", 256) ) self.static = static - self.return_data = None # type: str + self.return_data: Optional[ReturnData] = None def initial_global_state_from_environment(self, environment, active_function): """ @@ -278,7 +278,9 @@ class ContractCreationTransaction(BaseTransaction): tuple(return_data.return_data) ) return_data = str(hex(global_state.environment.active_account.address.value)) - self.return_data = ReturnData(return_data, len(return_data) // 2) + self.return_data: Optional[ReturnData] = ReturnData( + return_data, symbol_factory.BitVecVal(len(return_data) // 2, 256) + ) assert global_state.environment.active_account.code.instruction_list != [] raise TransactionEndSignal(global_state, revert=revert) diff --git a/mythril/laser/ethereum/util.py b/mythril/laser/ethereum/util.py index ede5443b..d0be0d94 100644 --- a/mythril/laser/ethereum/util.py +++ b/mythril/laser/ethereum/util.py @@ -143,7 +143,7 @@ def concrete_int_to_bytes(val): :return: """ # logging.debug("concrete_int_to_bytes " + str(val)) - if type(val) == int: + if isinstance(val, int): return val.to_bytes(32, byteorder="big") return simplify(val).value.to_bytes(32, byteorder="big") diff --git a/mythril/laser/plugin/loader.py b/mythril/laser/plugin/loader.py index f2cc1172..ad8b743a 100644 --- a/mythril/laser/plugin/loader.py +++ b/mythril/laser/plugin/loader.py @@ -17,9 +17,9 @@ class LaserPluginLoader(object, metaclass=Singleton): def __init__(self) -> None: """Initializes the plugin loader""" - self.laser_plugin_builders = {} # type: Dict[str, PluginBuilder] - self.plugin_args = {} # type: Dict[str, Dict] - self.plugin_list = {} # type: Dict[str, LaserPlugin] + self.laser_plugin_builders: Dict[str, PluginBuilder] = {} + self.plugin_args: Dict[str, Dict] = {} + self.plugin_list: Dict[str, LaserPlugin] = {} def add_args(self, plugin_name, **kwargs): self.plugin_args[plugin_name] = kwargs diff --git a/mythril/laser/plugin/plugins/coverage/coverage_plugin.py b/mythril/laser/plugin/plugins/coverage/coverage_plugin.py index f235a5e8..323fd918 100644 --- a/mythril/laser/plugin/plugins/coverage/coverage_plugin.py +++ b/mythril/laser/plugin/plugins/coverage/coverage_plugin.py @@ -30,7 +30,7 @@ class InstructionCoveragePlugin(LaserPlugin): """ def __init__(self): - self.coverage = {} # type: Dict[str, Tuple[int, List[bool]]] + self.coverage: Dict[str, Tuple[int, List[bool]]] = {} self.initial_coverage = 0 self.tx_id = 0 @@ -54,7 +54,7 @@ class InstructionCoveragePlugin(LaserPlugin): else: cov_percentage = sum(code_cov[1]) / float(code_cov[0]) * 100 string_code = code - if type(code) == tuple: + if isinstance(code, tuple): try: string_code = bytearray(code).hex() except TypeError: diff --git a/mythril/laser/plugin/plugins/dependency_pruner.py b/mythril/laser/plugin/plugins/dependency_pruner.py index 863ead99..c3c49878 100644 --- a/mythril/laser/plugin/plugins/dependency_pruner.py +++ b/mythril/laser/plugin/plugins/dependency_pruner.py @@ -95,10 +95,10 @@ class DependencyPruner(LaserPlugin): def _reset(self): self.iteration = 0 - self.calls_on_path = {} # type: Dict[int, bool] - self.sloads_on_path = {} # type: Dict[int, List[object]] - self.sstores_on_path = {} # type: Dict[int, List[object]] - self.storage_accessed_global = set() # type: Set + self.calls_on_path: Dict[int, bool] = {} + self.sloads_on_path: Dict[int, List[object]] = {} + self.sstores_on_path: Dict[int, List[object]] = {} + self.storage_accessed_global: Set = set() def update_sloads(self, path: List[int], target_location: object) -> None: """Update the dependency map for the block offsets on the given path. diff --git a/mythril/laser/plugin/plugins/plugin_annotations.py b/mythril/laser/plugin/plugins/plugin_annotations.py index 85fddc45..f1bf43a5 100644 --- a/mythril/laser/plugin/plugins/plugin_annotations.py +++ b/mythril/laser/plugin/plugins/plugin_annotations.py @@ -31,11 +31,11 @@ class DependencyAnnotation(MergeableStateAnnotation): """ def __init__(self): - self.storage_loaded = set() # type: Set - self.storage_written = {} # type: Dict[int, Set] - self.has_call = False # type: bool - self.path = [0] # type: List - self.blocks_seen = set() # type: Set[int] + self.storage_loaded: Set = set() + self.storage_written: Dict[int, Set] = {} + self.has_call: bool = False + self.path: List = [0] + self.blocks_seen: Set[int] = set() def __copy__(self): result = DependencyAnnotation() diff --git a/mythril/laser/smt/bitvec_helper.py b/mythril/laser/smt/bitvec_helper.py index 3c3e1289..aadcefad 100644 --- a/mythril/laser/smt/bitvec_helper.py +++ b/mythril/laser/smt/bitvec_helper.py @@ -190,7 +190,7 @@ def Sum(*args: BitVec) -> BitVec: :return: """ raw = z3.Sum([a.raw for a in args]) - annotations = set() # type: Annotations + annotations: Annotations = set() for bv in args: annotations = annotations.union(bv.annotations) diff --git a/mythril/laser/smt/bool.py b/mythril/laser/smt/bool.py index b98e10c4..cc3f8d13 100644 --- a/mythril/laser/smt/bool.py +++ b/mythril/laser/smt/bool.py @@ -97,7 +97,7 @@ class Bool(Expression[z3.BoolRef]): def And(*args: Union[Bool, bool]) -> Bool: """Create an And expression.""" - annotations = set() # type: Set + annotations: Set = set() args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args] for arg in args_list: annotations = annotations.union(arg.annotations) @@ -119,7 +119,7 @@ def Or(*args: Union[Bool, bool]) -> Bool: :return: """ args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args] - annotations = set() # type: Set + annotations: Set = set() for arg in args_list: annotations = annotations.union(arg.annotations) return Bool(z3.Or([a.raw for a in args_list]), annotations=annotations) diff --git a/mythril/laser/smt/model.py b/mythril/laser/smt/model.py index b5e2646d..ffb10579 100644 --- a/mythril/laser/smt/model.py +++ b/mythril/laser/smt/model.py @@ -19,7 +19,7 @@ class Model: def decls(self) -> List[z3.ExprRef]: """Get the declarations for this model""" - result = [] # type: List[z3.ExprRef] + result: List[z3.ExprRef] = [] for internal_model in self.raw: result.extend(internal_model.decls()) return result diff --git a/mythril/laser/smt/solver/independence_solver.py b/mythril/laser/smt/solver/independence_solver.py index a8ce59dd..6c54d18a 100644 --- a/mythril/laser/smt/solver/independence_solver.py +++ b/mythril/laser/smt/solver/independence_solver.py @@ -31,8 +31,8 @@ class DependenceBucket: :param variables: Variables contained in the conditions :param conditions: The conditions that are dependent on each other """ - self.variables = variables or [] # type: List[z3.ExprRef] - self.conditions = conditions or [] # type: List[z3.ExprRef] + self.variables: List[z3.ExprRef] = variables or [] + self.conditions: List[z3.ExprRef] = conditions or [] class DependenceMap: @@ -40,8 +40,8 @@ class DependenceMap: def __init__(self): """Initializes a DependenceMap object""" - self.buckets = [] # type: List[DependenceBucket] - self.variable_map = {} # type: Dict[str, DependenceBucket] + self.buckets: List[DependenceBucket] = [] + self.variable_map: Dict[str, DependenceBucket] = {} def add_condition(self, condition: z3.BoolRef) -> None: """ @@ -70,8 +70,8 @@ class DependenceMap: def _merge_buckets(self, bucket_list: Set[DependenceBucket]) -> DependenceBucket: """Merges the buckets in bucket list""" - variables = [] # type: List[str] - conditions = [] # type: List[z3.BoolRef] + variables: List[str] = [] + conditions: List[z3.BoolRef] = [] for bucket in bucket_list: self.buckets.remove(bucket) variables += bucket.variables @@ -100,14 +100,14 @@ class IndependenceSolver: """ self.raw.set(timeout=timeout) - def add(self, *constraints: Tuple[Bool]) -> None: + def add(self, *constraints: Bool) -> None: """Adds the constraints to this solver. :param constraints: constraints to add """ - raw_constraints = [ + raw_constraints: List[z3.BoolRef] = [ c.raw for c in cast(Tuple[Bool], constraints) - ] # type: List[z3.BoolRef] + ] self.constraints.extend(raw_constraints) def append(self, *constraints: Tuple[Bool]) -> None: @@ -115,9 +115,9 @@ class IndependenceSolver: :param constraints: constraints to add """ - raw_constraints = [ + raw_constraints: List[z3.BoolRef] = [ c.raw for c in cast(Tuple[Bool], constraints) - ] # type: List[z3.BoolRef] + ] self.constraints.extend(raw_constraints) @stat_smt_query diff --git a/mythril/laser/smt/solver/solver.py b/mythril/laser/smt/solver/solver.py index 1bf00a81..bb729b90 100644 --- a/mythril/laser/smt/solver/solver.py +++ b/mythril/laser/smt/solver/solver.py @@ -28,18 +28,18 @@ class BaseSolver(Generic[T]): """ self.raw.set(timeout=timeout) - def add(self, *constraints: List[Bool]) -> None: + def add(self, *constraints: Bool) -> None: """Adds the constraints to this solver. :param constraints: :return: """ - z3_constraints = [ + z3_constraints: Sequence[z3.BoolRef] = [ c.raw for c in cast(List[Bool], constraints) - ] # type: Sequence[z3.BoolRef] + ] self.raw.add(z3_constraints) - def append(self, *constraints: List[Bool]) -> None: + def append(self, *constraints: Bool) -> None: """Adds the constraints to this solver. :param constraints: diff --git a/mythril/mythril/mythril_analyzer.py b/mythril/mythril/mythril_analyzer.py index 1f6556c2..8951027f 100644 --- a/mythril/mythril/mythril_analyzer.py +++ b/mythril/mythril/mythril_analyzer.py @@ -47,7 +47,7 @@ class MythrilAnalyzer: :param address: Address of the contract """ self.eth = disassembler.eth - self.contracts = disassembler.contracts or [] # type: List[EVMContract] + self.contracts: List[EVMContract] = disassembler.contracts or [] self.enable_online_lookup = disassembler.enable_online_lookup self.use_onchain_data = not cmd_args.no_onchain_data self.strategy = strategy @@ -141,10 +141,10 @@ class MythrilAnalyzer: :param transaction_count: The amount of transactions to be executed :return: The Report class which contains the all the issues/vulnerabilities """ - all_issues = [] # type: List[Issue] + all_issues: List[Issue] = [] SolverStatistics().enabled = True exceptions = [] - execution_info = None # type: Optional[List[ExecutionInfo]] + execution_info: Optional[List[ExecutionInfo]] = None for contract in self.contracts: StartTime() # Reinitialize start time for new contracts try: diff --git a/mythril/mythril/mythril_config.py b/mythril/mythril/mythril_config.py index 5fb39abd..41dd6937 100644 --- a/mythril/mythril/mythril_config.py +++ b/mythril/mythril/mythril_config.py @@ -22,11 +22,11 @@ class MythrilConfig: """ def __init__(self): - self.infura_id = os.getenv("INFURA_ID") # type: str + self.infura_id: str = os.getenv("INFURA_ID") self.mythril_dir = self.init_mythril_dir() self.config_path = os.path.join(self.mythril_dir, "config.ini") self._init_config() - self.eth = None # type: Optional[EthJsonRpc] + self.eth: Optional[EthJsonRpc] = None def set_api_infura_id(self, id): self.infura_id = id diff --git a/mythril/mythril/mythril_disassembler.py b/mythril/mythril/mythril_disassembler.py index bcee5550..d93294b9 100644 --- a/mythril/mythril/mythril_disassembler.py +++ b/mythril/mythril/mythril_disassembler.py @@ -30,11 +30,11 @@ from mythril.solidity.soliditycontract import ( from mythril.support.support_args import args -def format_Warning(message, category, filename, lineno, line=""): +def format_warning(message, category, filename, lineno, line=""): return "{}: {}\n\n".format(str(filename), str(message)) -warnings.formatwarning = format_Warning +warnings.formatwarning = format_warning log = logging.getLogger(__name__) @@ -62,7 +62,7 @@ class MythrilDisassembler: self.eth = eth self.enable_online_lookup = enable_online_lookup self.sigs = signatures.SignatureDB(enable_online_lookup=enable_online_lookup) - self.contracts = [] # type: List[EVMContract] + self.contracts: List[EVMContract] = [] @staticmethod def _init_solc_binary(version: str) -> Optional[str]: @@ -190,7 +190,6 @@ class MythrilDisassembler: build_dir = os.path.join(project_root, "artifacts", "contracts", "build-info") - files = os.listdir(build_dir) address = util.get_indexed_address(0) files = sorted( diff --git a/mythril/plugin/discovery.py b/mythril/plugin/discovery.py index f6d66905..9f7ea02e 100644 --- a/mythril/plugin/discovery.py +++ b/mythril/plugin/discovery.py @@ -12,7 +12,7 @@ class PluginDiscovery(object, metaclass=Singleton): """ # Installed plugins structure. Retrieves all modules that have an entry point for mythril.plugins - _installed_plugins = None # type: Optional[Dict[str, Any]] + _installed_plugins: Optional[Dict[str, Any]] = None def init_installed_plugins(self): self._installed_plugins = { diff --git a/mythril/plugin/loader.py b/mythril/plugin/loader.py index 5879f8ea..b41bc16e 100644 --- a/mythril/plugin/loader.py +++ b/mythril/plugin/loader.py @@ -27,7 +27,7 @@ class MythrilPluginLoader(object, metaclass=Singleton): def __init__(self): log.info("Initializing mythril plugin loader") self.loaded_plugins = [] - self.plugin_args = dict() # type: Dict[str, Dict] + self.plugin_args: Dict[str, Dict] = dict() self._load_default_enabled() def set_args(self, plugin_name: str, **kwargs): diff --git a/mythril/solidity/features.py b/mythril/solidity/features.py index afe140f0..6fb03c6a 100644 --- a/mythril/solidity/features.py +++ b/mythril/solidity/features.py @@ -203,7 +203,7 @@ class SolidityFeatureExtractor: return variables def extract_address_variable(self, node): - if type(node) == int: + if isinstance(node, int): return set([]) transfer_vars = set([]) if ( @@ -211,15 +211,13 @@ class SolidityFeatureExtractor: and node.get("expression", {}).get("nodeType") == "FunctionCall" ): expression = node["expression"].get("expression", None) - if expression is not None: - if ( - expression["nodeType"] == "MemberAccess" - and expression["memberName"] in TRANSFER_METHODS - ): - print(expression) - address_variable = expression["expression"].get("name") - if address_variable: - transfer_vars.update(set([address_variable])) + if expression is not None and ( + expression["nodeType"] == "MemberAccess" + and expression["memberName"] in TRANSFER_METHODS + ): + address_variable = expression["expression"].get("name") + if address_variable: + transfer_vars.update(set([address_variable])) for key, value in node.items(): if isinstance(value, dict): diff --git a/mythril/support/loader.py b/mythril/support/loader.py index 8ea16ca9..d8a45736 100644 --- a/mythril/support/loader.py +++ b/mythril/support/loader.py @@ -40,7 +40,7 @@ class DynLoader: value = self.eth.eth_getStorageAt( contract_address, position=index, block="latest" ) - if value == "0x": + if value.startswith("0x"): value = "0x0000000000000000000000000000000000000000000000000000000000000000" return value @@ -96,7 +96,7 @@ class DynLoader: code = self.eth.eth_getCode(dependency_address) - if code == "0x": + if code.startswith("0x"): return None else: return Disassembly(code) diff --git a/mythril/support/model.py b/mythril/support/model.py index ab6cdbee..6dff0484 100644 --- a/mythril/support/model.py +++ b/mythril/support/model.py @@ -86,12 +86,16 @@ def get_model( if solver_timeout <= 0: raise SolverTimeOutException for constraint in constraints: - if type(constraint) == bool and not constraint: + if isinstance(constraint, bool) and not constraint: raise UnsatError - if type(constraints) != tuple: + if isinstance(constraints, tuple) is False: constraints = constraints.get_all_constraints() - constraints = [constraint for constraint in constraints if type(constraint) != bool] + constraints = [ + constraint + for constraint in constraints + if isinstance(constraint, bool) is False + ] if len(maximize) + len(minimize) == 0: ret_model = model_cache.check_quick_sat(simplify(And(*constraints)).raw) diff --git a/mythril/support/signatures.py b/mythril/support/signatures.py index 216a3d4c..ec3810d6 100644 --- a/mythril/support/signatures.py +++ b/mythril/support/signatures.py @@ -44,7 +44,7 @@ def synchronized(sync_lock): class Singleton(type): """A metaclass type implementing the singleton pattern.""" - _instances = dict() # type: Dict[Singleton, Singleton] + _instances: Dict["Singleton", "Singleton"] = dict() @synchronized(lock) def __call__(cls, *args, **kwargs): @@ -123,12 +123,12 @@ class SignatureDB(object, metaclass=Singleton): :param path: """ self.enable_online_lookup = enable_online_lookup - self.online_lookup_miss = set() # type: Set[str] + self.online_lookup_miss: Set[str] = set() self.online_lookup_timeout = 0 # if we're analysing a Solidity file, store its hashes # here to prevent unnecessary lookups - self.solidity_sigs = defaultdict(list) # type: DefaultDict[str, List[str]] + self.solidity_sigs: DefaultDict[str, List[str]] = defaultdict(list) if path is None: self.path = os.environ.get("MYTHRIL_DIR") or os.path.join( os.path.expanduser("~"), ".mythril" diff --git a/mythril/support/source_support.py b/mythril/support/source_support.py index db9ad2f1..aa3bec3a 100644 --- a/mythril/support/source_support.py +++ b/mythril/support/source_support.py @@ -38,7 +38,7 @@ class Source: self.source_format = "evm-byzantium-bytecode" self.source_type = ( "ethereum-address" - if len(contracts[0].name) == 42 and contracts[0].name[0:2] == "0x" + if len(contracts[0].name) == 42 and contracts[0].name.startswith("0x") else "raw-bytecode" ) for contract in contracts: diff --git a/mythril/support/support_utils.py b/mythril/support/support_utils.py index 484a3c9e..fa8d7f26 100644 --- a/mythril/support/support_utils.py +++ b/mythril/support/support_utils.py @@ -14,7 +14,7 @@ log = logging.getLogger(__name__) class Singleton(type): """A metaclass type implementing the singleton pattern.""" - _instances = {} # type: Dict + _instances: Dict = {} def __call__(cls, *args, **kwargs): """Delegate the call to an existing resource or a a new one. @@ -59,7 +59,6 @@ class ModelCache: @lru_cache(maxsize=2**10) def check_quick_sat(self, constraints) -> bool: - model_list = list(reversed(self.model_cache.lru_cache.keys())) for model in reversed(self.model_cache.lru_cache.keys()): model_copy = deepcopy(model) if is_true(model_copy.eval(constraints, model_completion=True)): @@ -77,11 +76,11 @@ def get_code_hash(code) -> str: :param code: bytecode :return: Returns hash of the given bytecode """ - if type(code) == tuple: + if isinstance(code, tuple): # Temporary hack, since we cannot find symbols of sha3 return str(hash(code)) - code = code[2:] if code[:2] == "0x" else code + code = code[2:] if code.startswith("0x") else code try: hash_ = keccak(bytes.fromhex(code)) return "0x" + hash_.hex() @@ -91,8 +90,8 @@ def get_code_hash(code) -> str: def sha3(value): - if type(value) == str: - if value[:2] == "0x": + if isinstance(value, str): + if value.startswith("0x"): new_hash = keccak(bytes.fromhex(value)) else: new_hash = keccak(value.encode()) diff --git a/tests/features_test.py b/tests/features_test.py index 1afd1f29..a89d1148 100644 --- a/tests/features_test.py +++ b/tests/features_test.py @@ -61,7 +61,6 @@ test_cases = [ def test_feature_selfdestruct(file_name, num_funcs, func_name, field, expected_value): input_file = TEST_FILES / file_name name = file_name.split(".")[0] - print(name, name.capitalize()) if name[0].islower(): name = name.capitalize() contract = SolidityContract(str(input_file), name=name, solc_binary=solc_binary) diff --git a/tests/laser/evm_testsuite/evm_test.py b/tests/laser/evm_testsuite/evm_test.py index 977af4b6..e38eb237 100644 --- a/tests/laser/evm_testsuite/evm_test.py +++ b/tests/laser/evm_testsuite/evm_test.py @@ -182,8 +182,8 @@ def test_vmtest( actual = actual.value actual = 1 if actual is True else 0 if actual is False else actual else: - if type(actual) == bytes: + if isinstance(actual, bytes): actual = int(binascii.b2a_hex(actual), 16) - elif type(actual) == str: + elif isinstance(actual, str): actual = int(actual, 16) assert actual == expected diff --git a/tests/laser/tx_prioritisation_test.py b/tests/laser/tx_prioritisation_test.py index db7bd297..491537d6 100644 --- a/tests/laser/tx_prioritisation_test.py +++ b/tests/laser/tx_prioritisation_test.py @@ -5,7 +5,6 @@ from unittest.mock import Mock, patch, mock_open def mock_predict_proba(X): - print(X) if X[0][-1] == 1: return np.array([[0.1, 0.7, 0.1, 0.1]]) elif X[0][-1] == 2: