From aa4a57d05015a8e5e103e7d5c8d25a76a58216c0 Mon Sep 17 00:00:00 2001 From: Josselin Feist Date: Wed, 31 Aug 2022 11:48:13 +0200 Subject: [PATCH] Add more types --- plugin_example/slither_my_plugin/__init__.py | 9 +- slither/__main__.py | 96 ++++++++++++------- slither/core/declarations/structure.py | 14 +-- slither/core/expressions/tuple_expression.py | 4 +- slither/core/source_mapping/source_mapping.py | 8 +- slither/slither.py | 12 ++- slither/slithir/operations/call.py | 8 +- slither/slithir/tmp_operations/tmp_call.py | 8 +- .../tmp_operations/tmp_new_elementary_type.py | 12 ++- slither/tools/demo/__main__.py | 4 +- slither/tools/erc_conformance/__main__.py | 9 +- slither/tools/erc_conformance/erc/ercs.py | 15 ++- slither/tools/flattening/__main__.py | 4 +- slither/tools/kspec_coverage/__main__.py | 4 +- .../tools/kspec_coverage/kspec_coverage.py | 4 +- slither/tools/mutator/__main__.py | 13 ++- slither/tools/possible_paths/__main__.py | 7 +- .../tools/possible_paths/possible_paths.py | 45 ++++++--- slither/tools/properties/__main__.py | 19 ++-- slither/tools/properties/platforms/truffle.py | 8 +- slither/tools/properties/utils.py | 2 +- slither/tools/similarity/__main__.py | 4 +- slither/tools/similarity/encode.py | 2 +- slither/tools/similarity/info.py | 3 +- slither/tools/similarity/plot.py | 3 +- slither/tools/similarity/train.py | 3 +- slither/tools/slither_format/__main__.py | 4 +- .../tools/slither_format/slither_format.py | 18 ++-- slither/tools/upgradeability/__main__.py | 75 ++++++++++----- .../upgradeability/utils/command_line.py | 20 ++-- slither/utils/command_line.py | 57 +++++++---- slither/utils/output.py | 4 +- slither/utils/output_capture.py | 8 +- 33 files changed, 323 insertions(+), 183 deletions(-) diff --git a/plugin_example/slither_my_plugin/__init__.py b/plugin_example/slither_my_plugin/__init__.py index eabdb147e..9379ac833 100644 --- a/plugin_example/slither_my_plugin/__init__.py +++ b/plugin_example/slither_my_plugin/__init__.py @@ -1,8 +1,13 @@ +from typing import Tuple, List, Type + from slither_my_plugin.detectors.example import Example +from slither.detectors.abstract_detector import AbstractDetector +from slither.printers.abstract_printer import AbstractPrinter + -def make_plugin(): +def make_plugin() -> Tuple[List[Type[AbstractDetector]], List[Type[AbstractPrinter]]]: plugin_detectors = [Example] - plugin_printers = [] + plugin_printers: List[Type[AbstractPrinter]] = [] return plugin_detectors, plugin_printers diff --git a/slither/__main__.py b/slither/__main__.py index 4cf148f13..98320c76a 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -10,11 +10,11 @@ import os import pstats import sys import traceback -from typing import Tuple, Optional, List, Dict +from typing import Tuple, Optional, List, Dict, Type, Union, Any, Sequence from pkg_resources import iter_entry_points, require -from crytic_compile import cryticparser +from crytic_compile import cryticparser, CryticCompile from crytic_compile.platform.standard import generate_standard_export from crytic_compile.platform.etherscan import SUPPORTED_NETWORK from crytic_compile import compile_all, is_supported @@ -55,10 +55,10 @@ logger = logging.getLogger("Slither") def process_single( - target: str, + target: Union[str, CryticCompile], args: argparse.Namespace, - detector_classes: List[AbstractDetector], - printer_classes: List[AbstractPrinter], + detector_classes: List[Type[AbstractDetector]], + printer_classes: List[Type[AbstractPrinter]], ) -> Tuple[Slither, List[Dict], List[Dict], int]: """ The core high-level code for running Slither static analysis. @@ -80,8 +80,8 @@ def process_single( def process_all( target: str, args: argparse.Namespace, - detector_classes: List[AbstractDetector], - printer_classes: List[AbstractPrinter], + detector_classes: List[Type[AbstractDetector]], + printer_classes: List[Type[AbstractPrinter]], ) -> Tuple[List[Slither], List[Dict], List[Dict], int]: compilations = compile_all(target, **vars(args)) slither_instances = [] @@ -109,8 +109,8 @@ def process_all( def _process( slither: Slither, - detector_classes: List[AbstractDetector], - printer_classes: List[AbstractPrinter], + detector_classes: List[Type[AbstractDetector]], + printer_classes: List[Type[AbstractPrinter]], ) -> Tuple[Slither, List[Dict], List[Dict], int]: for detector_cls in detector_classes: slither.register_detector(detector_cls) @@ -137,13 +137,14 @@ def _process( return slither, results_detectors, results_printers, analyzed_contracts_count +# TODO: delete me? def process_from_asts( filenames: List[str], args: argparse.Namespace, - detector_classes: List[AbstractDetector], - printer_classes: List[AbstractPrinter], -): - all_contracts = [] + detector_classes: List[Type[AbstractDetector]], + printer_classes: List[Type[AbstractPrinter]], +) -> Tuple[Slither, List[Dict], List[Dict], int]: + all_contracts: List[str] = [] for filename in filenames: with open(filename, encoding="utf8") as file_open: @@ -162,13 +163,15 @@ def process_from_asts( ################################################################################### -def get_detectors_and_printers(): +def get_detectors_and_printers() -> Tuple[ + List[Type[AbstractDetector]], List[Type[AbstractPrinter]] +]: - detectors = [getattr(all_detectors, name) for name in dir(all_detectors)] - detectors = [d for d in detectors if inspect.isclass(d) and issubclass(d, AbstractDetector)] + detectors_ = [getattr(all_detectors, name) for name in dir(all_detectors)] + detectors = [d for d in detectors_ if inspect.isclass(d) and issubclass(d, AbstractDetector)] - printers = [getattr(all_printers, name) for name in dir(all_printers)] - printers = [p for p in printers if inspect.isclass(p) and issubclass(p, AbstractPrinter)] + printers_ = [getattr(all_printers, name) for name in dir(all_printers)] + printers = [p for p in printers_ if inspect.isclass(p) and issubclass(p, AbstractPrinter)] # Handle plugins! for entry_point in iter_entry_points(group="slither_analyzer.plugin", name=None): @@ -194,8 +197,8 @@ def get_detectors_and_printers(): # pylint: disable=too-many-branches def choose_detectors( - args: argparse.Namespace, all_detector_classes: List[AbstractDetector] -) -> List[AbstractDetector]: + args: argparse.Namespace, all_detector_classes: List[Type[AbstractDetector]] +) -> List[Type[AbstractDetector]]: # If detectors are specified, run only these ones detectors_to_run = [] @@ -245,8 +248,8 @@ def choose_detectors( def choose_printers( - args: argparse.Namespace, all_printer_classes: List[AbstractPrinter] -) -> List[AbstractPrinter]: + args: argparse.Namespace, all_printer_classes: List[Type[AbstractPrinter]] +) -> List[Type[AbstractPrinter]]: printers_to_run = [] # disable default printer @@ -273,13 +276,16 @@ def choose_printers( ################################################################################### -def parse_filter_paths(args): +def parse_filter_paths(args: argparse.Namespace) -> List[str]: if args.filter_paths: return args.filter_paths.split(",") return [] -def parse_args(detector_classes, printer_classes): # pylint: disable=too-many-statements +# pylint: disable=too-many-statements +def parse_args( + detector_classes: List[Type[AbstractDetector]], printer_classes: List[Type[AbstractPrinter]] +) -> argparse.Namespace: usage = "slither target [flag]\n" usage += "\ntarget can be:\n" @@ -622,7 +628,9 @@ class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-methods - def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs + def __call__( + self, parser: Any, *args: Any, **kwargs: Any + ) -> None: # pylint: disable=signature-differs detectors, _ = get_detectors_and_printers() detector_types_json = output_detectors_json(detectors) print(json.dumps(detector_types_json)) @@ -630,22 +638,38 @@ class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-meth class ListPrinters(argparse.Action): # pylint: disable=too-few-public-methods - def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs + def __call__( + self, parser: Any, *args: Any, **kwargs: Any + ) -> None: # pylint: disable=signature-differs _, printers = get_detectors_and_printers() output_printers(printers) parser.exit() class OutputMarkdown(argparse.Action): # pylint: disable=too-few-public-methods - def __call__(self, parser, args, values, option_string=None): + def __call__( + self, + parser: Any, + args: Any, + values: Optional[Union[str, Sequence[Any]]], + option_string: Any = None, + ) -> None: detectors, printers = get_detectors_and_printers() + assert isinstance(values, str) output_to_markdown(detectors, printers, values) parser.exit() class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods - def __call__(self, parser, args, values, option_string=None): + def __call__( + self, + parser: Any, + args: Any, + values: Optional[Union[str, Sequence[Any]]], + option_string: Any = None, + ) -> None: detectors, _ = get_detectors_and_printers() + assert isinstance(values, str) output_wiki(detectors, values) parser.exit() @@ -678,7 +702,7 @@ class FormatterCryticCompile(logging.Formatter): ################################################################################### -def main(): +def main() -> None: # Codebase with complex domninators can lead to a lot of SSA recursive call sys.setrecursionlimit(1500) @@ -689,8 +713,9 @@ def main(): # pylint: disable=too-many-statements,too-many-branches,too-many-locals def main_impl( - all_detector_classes: List[AbstractDetector], all_printer_classes: List[AbstractPrinter] -): + all_detector_classes: List[Type[AbstractDetector]], + all_printer_classes: List[Type[AbstractPrinter]], +) -> None: """ :param all_detector_classes: A list of all detectors that can be included/excluded. :param all_printer_classes: A list of all printers that can be included. @@ -756,8 +781,8 @@ def main_impl( crytic_compile_error.propagate = False crytic_compile_error.setLevel(logging.INFO) - results_detectors = [] - results_printers = [] + results_detectors: List[Dict] = [] + results_printers: List[Dict] = [] try: filename = args.filename @@ -806,6 +831,7 @@ def main_impl( if "compilations" in args.json_types: compilation_results = [] for slither_instance in slither_instances: + assert slither_instance.crytic_compile compilation_results.append( generate_standard_export(slither_instance.crytic_compile) ) @@ -856,7 +882,7 @@ def main_impl( except Exception: # pylint: disable=broad-except output_error = traceback.format_exc() - logging.error(traceback.print_exc()) + traceback.print_exc() logging.error(f"Error in {args.filename}") # pylint: disable=logging-fstring-interpolation logging.error(output_error) @@ -879,7 +905,7 @@ def main_impl( if outputting_zip: output_to_zip(args.zip, output_error, json_results, args.zip_type) - if args.perf: + if args.perf and cp: cp.disable() stats = pstats.Stats(cp).sort_stats("cumtime") stats.print_stats() diff --git a/slither/core/declarations/structure.py b/slither/core/declarations/structure.py index 39b1948ee..8f6d8c50a 100644 --- a/slither/core/declarations/structure.py +++ b/slither/core/declarations/structure.py @@ -8,10 +8,10 @@ if TYPE_CHECKING: class Structure(SourceMapping): - def __init__(self, compilation_unit: "SlitherCompilationUnit"): + def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None: super().__init__() self._name: Optional[str] = None - self._canonical_name = None + self._canonical_name: Optional[str] = None self._elems: Dict[str, "StructureVariable"] = {} # Name of the elements in the order of declaration self._elems_ordered: List[str] = [] @@ -19,25 +19,27 @@ class Structure(SourceMapping): @property def canonical_name(self) -> str: + assert self._canonical_name return self._canonical_name @canonical_name.setter - def canonical_name(self, name: str): + def canonical_name(self, name: str) -> None: self._canonical_name = name @property def name(self) -> str: + assert self._name return self._name @name.setter - def name(self, new_name: str): + def name(self, new_name: str) -> None: self._name = new_name @property def elems(self) -> Dict[str, "StructureVariable"]: return self._elems - def add_elem_in_order(self, s: str): + def add_elem_in_order(self, s: str) -> None: self._elems_ordered.append(s) @property @@ -47,5 +49,5 @@ class Structure(SourceMapping): ret.append(self._elems[e]) return ret - def __str__(self): + def __str__(self) -> str: return self.name diff --git a/slither/core/expressions/tuple_expression.py b/slither/core/expressions/tuple_expression.py index 7f14601f4..1fd8fc795 100644 --- a/slither/core/expressions/tuple_expression.py +++ b/slither/core/expressions/tuple_expression.py @@ -4,7 +4,7 @@ from slither.core.expressions.expression import Expression class TupleExpression(Expression): - def __init__(self, expressions): + def __init__(self, expressions: List[Expression]) -> None: assert all(isinstance(x, Expression) for x in expressions if x) super().__init__() self._expressions = expressions @@ -13,6 +13,6 @@ class TupleExpression(Expression): def expressions(self) -> List[Expression]: return self._expressions - def __str__(self): + def __str__(self) -> str: expressions_str = [str(e) for e in self.expressions] return "(" + ",".join(expressions_str) + ")" diff --git a/slither/core/source_mapping/source_mapping.py b/slither/core/source_mapping/source_mapping.py index 7ceabd568..ee5211c7c 100644 --- a/slither/core/source_mapping/source_mapping.py +++ b/slither/core/source_mapping/source_mapping.py @@ -162,13 +162,15 @@ def _convert_source_mapping( class SourceMapping(Context, metaclass=ABCMeta): - def __init__(self): + def __init__(self) -> None: super().__init__() # self._source_mapping: Optional[Dict] = None self.source_mapping: Source = Source() self.references: List[Source] = [] - def set_offset(self, offset: Union["Source", str], compilation_unit: "SlitherCompilationUnit"): + def set_offset( + self, offset: Union["Source", str], compilation_unit: "SlitherCompilationUnit" + ) -> None: if isinstance(offset, Source): self.source_mapping.start = offset.start self.source_mapping.length = offset.length @@ -184,6 +186,6 @@ class SourceMapping(Context, metaclass=ABCMeta): def add_reference_from_raw_source( self, offset: str, compilation_unit: "SlitherCompilationUnit" - ): + ) -> None: s = _convert_source_mapping(offset, compilation_unit) self.references.append(s) diff --git a/slither/slither.py b/slither/slither.py index 59bbf8a5f..dcfc0ad7e 100644 --- a/slither/slither.py +++ b/slither/slither.py @@ -1,5 +1,5 @@ import logging -from typing import Union, List, ValuesView +from typing import Union, List, ValuesView, Type, Dict from crytic_compile import CryticCompile, InvalidCompilation @@ -19,7 +19,9 @@ logger_detector = logging.getLogger("Detectors") logger_printer = logging.getLogger("Printers") -def _check_common_things(thing_name, cls, base_cls, instances_list): +def _check_common_things( + thing_name: str, cls: Type, base_cls: Type, instances_list: List[Type[AbstractDetector]] +) -> None: if not issubclass(cls, base_cls) or cls is base_cls: raise Exception( @@ -178,7 +180,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes def detectors_optimization(self): return [d for d in self.detectors if d.IMPACT == DetectorClassification.OPTIMIZATION] - def register_detector(self, detector_class): + def register_detector(self, detector_class: Type[AbstractDetector]) -> None: """ :param detector_class: Class inheriting from `AbstractDetector`. """ @@ -188,7 +190,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes instance = detector_class(compilation_unit, self, logger_detector) self._detectors.append(instance) - def register_printer(self, printer_class): + def register_printer(self, printer_class: Type[AbstractPrinter]) -> None: """ :param printer_class: Class inheriting from `AbstractPrinter`. """ @@ -197,7 +199,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes instance = printer_class(self, logger_printer) self._printers.append(instance) - def run_detectors(self): + def run_detectors(self) -> List[Dict]: """ :return: List of registered detectors results. """ diff --git a/slither/slithir/operations/call.py b/slither/slithir/operations/call.py index cff2767cd..07304fa99 100644 --- a/slither/slithir/operations/call.py +++ b/slither/slithir/operations/call.py @@ -1,8 +1,10 @@ +from typing import Optional, List + from slither.slithir.operations.operation import Operation class Call(Operation): - def __init__(self): + def __init__(self) -> None: super().__init__() self._arguments = [] @@ -14,14 +16,14 @@ class Call(Operation): def arguments(self, v): self._arguments = v - def can_reenter(self, _callstack=None): # pylint: disable=no-self-use + def can_reenter(self, _callstack: Optional[List] = None) -> bool: # pylint: disable=no-self-use """ Must be called after slithIR analysis pass :return: bool """ return False - def can_send_eth(self): # pylint: disable=no-self-use + def can_send_eth(self) -> bool: # pylint: disable=no-self-use """ Must be called after slithIR analysis pass :return: bool diff --git a/slither/slithir/tmp_operations/tmp_call.py b/slither/slithir/tmp_operations/tmp_call.py index e9562b1c1..fb6641139 100644 --- a/slither/slithir/tmp_operations/tmp_call.py +++ b/slither/slithir/tmp_operations/tmp_call.py @@ -63,14 +63,14 @@ class TmpCall(OperationWithLValue): # pylint: disable=too-many-instance-attribu def call_id(self): return self._callid - @property - def read(self): - return [self.called] - @call_id.setter def call_id(self, c): self._callid = c + @property + def read(self): + return [self.called] + @property def called(self): return self._called diff --git a/slither/slithir/tmp_operations/tmp_new_elementary_type.py b/slither/slithir/tmp_operations/tmp_new_elementary_type.py index 357c063a7..d7a4f5e1b 100644 --- a/slither/slithir/tmp_operations/tmp_new_elementary_type.py +++ b/slither/slithir/tmp_operations/tmp_new_elementary_type.py @@ -1,21 +1,23 @@ +from typing import List + from slither.slithir.operations.lvalue import OperationWithLValue from slither.core.solidity_types.elementary_type import ElementaryType class TmpNewElementaryType(OperationWithLValue): - def __init__(self, new_type, lvalue): + def __init__(self, new_type: ElementaryType, lvalue): assert isinstance(new_type, ElementaryType) super().__init__() - self._type = new_type + self._type: ElementaryType = new_type self._lvalue = lvalue @property - def read(self): + def read(self) -> List: return [] @property - def type(self): + def type(self) -> ElementaryType: return self._type - def __str__(self): + def __str__(self) -> str: return f"{self.lvalue} = new {self._type}" diff --git a/slither/tools/demo/__main__.py b/slither/tools/demo/__main__.py index 37d265bb1..5bc2c7c8e 100644 --- a/slither/tools/demo/__main__.py +++ b/slither/tools/demo/__main__.py @@ -9,7 +9,7 @@ logging.getLogger("Slither").setLevel(logging.INFO) logger = logging.getLogger("Slither-demo") -def parse_args(): +def parse_args() -> argparse.Namespace: """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. @@ -26,7 +26,7 @@ def parse_args(): return parser.parse_args() -def main(): +def main() -> None: args = parse_args() # Perform slither analysis on the given filename diff --git a/slither/tools/erc_conformance/__main__.py b/slither/tools/erc_conformance/__main__.py index 45e57b55c..ef594a7c6 100644 --- a/slither/tools/erc_conformance/__main__.py +++ b/slither/tools/erc_conformance/__main__.py @@ -1,6 +1,7 @@ import argparse import logging from collections import defaultdict +from typing import Any, Dict, List from crytic_compile import cryticparser from slither import Slither @@ -26,7 +27,7 @@ logger.propagate = False ADDITIONAL_CHECKS = {"ERC20": check_erc20, "ERC1155": check_erc1155} -def parse_args(): +def parse_args() -> argparse.Namespace: """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. @@ -63,20 +64,20 @@ def parse_args(): return parser.parse_args() -def _log_error(err, args): +def _log_error(err: Any, args: argparse.Namespace) -> None: if args.json: output_to_json(args.json, str(err), {"upgradeability-check": []}) logger.error(err) -def main(): +def main() -> None: args = parse_args() # Perform slither analysis on the given filename slither = Slither(args.project, **vars(args)) - ret = defaultdict(list) + ret: Dict[str, List] = defaultdict(list) if args.erc.upper() in ERCS: diff --git a/slither/tools/erc_conformance/erc/ercs.py b/slither/tools/erc_conformance/erc/ercs.py index ef459eef9..a6b9050ae 100644 --- a/slither/tools/erc_conformance/erc/ercs.py +++ b/slither/tools/erc_conformance/erc/ercs.py @@ -1,7 +1,10 @@ import logging +from typing import Dict, List, Optional, Set +from slither.core.declarations import Contract from slither.slithir.operations import EventCall from slither.utils import output +from slither.utils.erc import ERC, ERC_EVENT from slither.utils.type import ( export_nested_types_from_variable, export_return_type_from_variable, @@ -11,7 +14,7 @@ logger = logging.getLogger("Slither-conformance") # pylint: disable=too-many-locals,too-many-branches,too-many-statements -def _check_signature(erc_function, contract, ret): +def _check_signature(erc_function: ERC, contract: Contract, ret: Dict) -> None: name = erc_function.name parameters = erc_function.parameters return_type = erc_function.return_type @@ -146,7 +149,7 @@ def _check_signature(erc_function, contract, ret): ret["missing_event_emmited"].append(missing_event_emmited.data) -def _check_events(erc_event, contract, ret): +def _check_events(erc_event: ERC_EVENT, contract: Contract, ret: Dict[str, List]) -> None: name = erc_event.name parameters = erc_event.parameters indexes = erc_event.indexes @@ -180,7 +183,13 @@ def _check_events(erc_event, contract, ret): ret["missing_event_index"].append(missing_event_index.data) -def generic_erc_checks(contract, erc_functions, erc_events, ret, explored=None): +def generic_erc_checks( + contract: Contract, + erc_functions: List[ERC], + erc_events: List[ERC_EVENT], + ret: Dict[str, List], + explored: Optional[Set[Contract]] = None, +) -> None: if explored is None: explored = set() diff --git a/slither/tools/flattening/__main__.py b/slither/tools/flattening/__main__.py index 977b84896..bf9856fe8 100644 --- a/slither/tools/flattening/__main__.py +++ b/slither/tools/flattening/__main__.py @@ -18,7 +18,7 @@ logger = logging.getLogger("Slither") logger.setLevel(logging.INFO) -def parse_args(): +def parse_args() -> argparse.Namespace: """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. @@ -106,7 +106,7 @@ def parse_args(): return parser.parse_args() -def main(): +def main() -> None: args = parse_args() slither = Slither(args.filename, **vars(args)) diff --git a/slither/tools/kspec_coverage/__main__.py b/slither/tools/kspec_coverage/__main__.py index b6ce0f81b..19933e0fe 100644 --- a/slither/tools/kspec_coverage/__main__.py +++ b/slither/tools/kspec_coverage/__main__.py @@ -16,7 +16,7 @@ logger.handlers[0].setFormatter(formatter) logger.propagate = False -def parse_args(): +def parse_args() -> argparse.Namespace: """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. @@ -56,7 +56,7 @@ def parse_args(): return parser.parse_args() -def main(): +def main() -> None: # ------------------------------ # Usage: slither-kspec-coverage contract kspec # Example: slither-kspec-coverage contract.sol kspec.md diff --git a/slither/tools/kspec_coverage/kspec_coverage.py b/slither/tools/kspec_coverage/kspec_coverage.py index 569a35cf1..f8c2d8cf2 100755 --- a/slither/tools/kspec_coverage/kspec_coverage.py +++ b/slither/tools/kspec_coverage/kspec_coverage.py @@ -1,8 +1,10 @@ +import argparse + from slither.tools.kspec_coverage.analysis import run_analysis from slither import Slither -def kspec_coverage(args): +def kspec_coverage(args: argparse.Namespace) -> None: contract = args.contract kspec = args.kspec diff --git a/slither/tools/mutator/__main__.py b/slither/tools/mutator/__main__.py index 442b4849d..78b86d681 100644 --- a/slither/tools/mutator/__main__.py +++ b/slither/tools/mutator/__main__.py @@ -2,6 +2,7 @@ import argparse import inspect import logging import sys +from typing import Type, List, Any from crytic_compile import cryticparser @@ -22,7 +23,7 @@ logger.setLevel(logging.INFO) ################################################################################### -def parse_args(): +def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Experimental smart contract mutator. Based on https://arxiv.org/abs/2006.11597", usage="slither-mutate target", @@ -48,14 +49,16 @@ def parse_args(): return parser.parse_args() -def _get_mutators(): - detectors = [getattr(all_mutators, name) for name in dir(all_mutators)] - detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractMutator)] +def _get_mutators() -> List[Type[AbstractMutator]]: + detectors_ = [getattr(all_mutators, name) for name in dir(all_mutators)] + detectors = [c for c in detectors_ if inspect.isclass(c) and issubclass(c, AbstractMutator)] return detectors class ListMutators(argparse.Action): # pylint: disable=too-few-public-methods - def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs + def __call__( + self, parser: Any, *args: Any, **kwargs: Any + ) -> None: # pylint: disable=signature-differs checks = _get_mutators() output_mutators(checks) parser.exit() diff --git a/slither/tools/possible_paths/__main__.py b/slither/tools/possible_paths/__main__.py index 29cd05c46..b993d266a 100644 --- a/slither/tools/possible_paths/__main__.py +++ b/slither/tools/possible_paths/__main__.py @@ -5,6 +5,7 @@ from argparse import ArgumentParser, Namespace from crytic_compile import cryticparser from slither import Slither +from slither.core.declarations import FunctionContract from slither.utils.colors import red from slither.tools.possible_paths.possible_paths import ( find_target_paths, @@ -58,7 +59,11 @@ def main() -> None: # Print out all target functions. print("Target functions:") for target in targets: - print(f"- {target.contract_declarer.name}.{target.full_name}") + if isinstance(target, FunctionContract): + print(f"- {target.contract_declarer.name}.{target.full_name}") + else: + pass + # TODO implement me print("\n") # Obtain all paths which reach the target functions. diff --git a/slither/tools/possible_paths/possible_paths.py b/slither/tools/possible_paths/possible_paths.py index 9e52ed6b7..6e836e76a 100644 --- a/slither/tools/possible_paths/possible_paths.py +++ b/slither/tools/possible_paths/possible_paths.py @@ -1,8 +1,15 @@ +from typing import List, Tuple, Union, Optional, Set + +from slither import Slither +from slither.core.declarations import Function, FunctionContract +from slither.core.slither_core import SlitherCore + + class ResolveFunctionException(Exception): pass -def resolve_function(slither, contract_name, function_name): +def resolve_function(slither: SlitherCore, contract_name: str, function_name: str) -> Function: """ Resolves a function instance, given a contract name and function. :param contract_name: The name of the contract the function is declared in. @@ -32,7 +39,9 @@ def resolve_function(slither, contract_name, function_name): return target_function -def resolve_functions(slither, functions): +def resolve_functions( + slither: Slither, functions: List[Union[str, Tuple[str, str]]] +) -> List[Function]: """ Resolves the provided function descriptors. :param functions: A list of tuples (contract_name, function_name) or str (of form "ContractName.FunctionName") @@ -40,7 +49,7 @@ def resolve_functions(slither, functions): :return: Returns a list of resolved functions. """ # Create the resolved list. - resolved = [] + resolved: List[Function] = [] # Verify that the provided argument is a list. if not isinstance(functions, list): @@ -72,24 +81,31 @@ def resolve_functions(slither, functions): return resolved -def all_function_definitions(function): +def all_function_definitions(function: Function) -> List[Function]: """ Obtains a list of representing this function and any base definitions :param function: The function to obtain all definitions at and beneath. :return: Returns a list composed of the provided function definition and any base definitions. """ - return [function] + [ + # TODO implement me + if not isinstance(function, FunctionContract): + return [] + ret: List[Function] = [function] + ret += [ f for c in function.contract.inheritance for f in c.functions_and_modifiers_declared if f.full_name == function.full_name ] + return ret -def __find_target_paths(slither, target_function, current_path=None): +def __find_target_paths( + slither: SlitherCore, target_function: Function, current_path: Optional[List[Function]] = None +) -> Set[Tuple[Function, ...]]: current_path = current_path if current_path else [] # Create our results list - results = set() + results: Set[Tuple[Function, ...]] = set() # Add our current function to the path. current_path = [target_function] + current_path @@ -106,9 +122,12 @@ def __find_target_paths(slither, target_function, current_path=None): continue # Find all function calls in this function (except for low level) - called_functions = [f for (_, f) in function.high_level_calls + function.library_calls] - called_functions += function.internal_calls - called_functions = set(called_functions) + called_functions_list = [ + f for (_, f) in function.high_level_calls if isinstance(f, Function) + ] + called_functions_list += [f for (_, f) in function.library_calls] + called_functions_list += [f for f in function.internal_calls if isinstance(f, Function)] + called_functions = set(called_functions_list) # If any of our target functions are reachable from this function, it's a result. if all_target_functions.intersection(called_functions): @@ -123,14 +142,16 @@ def __find_target_paths(slither, target_function, current_path=None): return results -def find_target_paths(slither, target_functions): +def find_target_paths( + slither: SlitherCore, target_functions: List[Function] +) -> Set[Tuple[Function, ...]]: """ Obtains all functions which can lead to any of the target functions being called. :param target_functions: The functions we are interested in reaching. :return: Returns a list of all functions which can reach any of the target_functions. """ # Create our results list - results = set() + results: Set[Tuple[Function, ...]] = set() # Loop for each target function for target_function in target_functions: diff --git a/slither/tools/properties/__main__.py b/slither/tools/properties/__main__.py index 91c990669..10837bb4b 100644 --- a/slither/tools/properties/__main__.py +++ b/slither/tools/properties/__main__.py @@ -1,6 +1,7 @@ import argparse import logging import sys +from typing import Any from crytic_compile import cryticparser @@ -26,7 +27,7 @@ logger.handlers[0].setFormatter(formatter) logger.propagate = False -def _all_scenarios(): +def _all_scenarios() -> str: txt = "\n" txt += "#################### ERC20 ####################\n" for k, value in ERC20_PROPERTIES.items(): @@ -35,29 +36,33 @@ def _all_scenarios(): return txt -def _all_properties(): +def _all_properties() -> MyPrettyTable: table = MyPrettyTable(["Num", "Description", "Scenario"]) idx = 0 for scenario, value in ERC20_PROPERTIES.items(): for prop in value.properties: - table.add_row([idx, prop.description, scenario]) + table.add_row([str(idx), prop.description, scenario]) idx = idx + 1 return table class ListScenarios(argparse.Action): # pylint: disable=too-few-public-methods - def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs + def __call__( + self, parser: Any, *args: Any, **kwargs: Any + ) -> None: # pylint: disable=signature-differs logger.info(_all_scenarios()) parser.exit() class ListProperties(argparse.Action): # pylint: disable=too-few-public-methods - def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs + def __call__( + self, parser: Any, *args: Any, **kwargs: Any + ) -> None: # pylint: disable=signature-differs logger.info(_all_properties()) parser.exit() -def parse_args(): +def parse_args() -> argparse.Namespace: """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. @@ -120,7 +125,7 @@ def parse_args(): return parser.parse_args() -def main(): +def main() -> None: args = parse_args() # Perform slither analysis on the given filename diff --git a/slither/tools/properties/platforms/truffle.py b/slither/tools/properties/platforms/truffle.py index dc83eb273..7d2f3d9b6 100644 --- a/slither/tools/properties/platforms/truffle.py +++ b/slither/tools/properties/platforms/truffle.py @@ -15,7 +15,7 @@ PATTERN_TRUFFLE_MIGRATION = re.compile("^[0-9]*_") logger = logging.getLogger("Slither") -def _extract_caller(p: PropertyCaller): +def _extract_caller(p: PropertyCaller) -> List[str]: if p == PropertyCaller.OWNER: return ["owner"] if p == PropertyCaller.SENDER: @@ -28,7 +28,7 @@ def _extract_caller(p: PropertyCaller): return ["user"] -def _helpers(): +def _helpers() -> str: """ Generate two functions: - catchRevertThrowReturnFalse: check if the call revert/throw or return false @@ -75,7 +75,7 @@ def generate_unit_test( # pylint: disable=too-many-arguments,too-many-branches output_dir: Path, addresses: Addresses, assert_message: str = "", -): +) -> Path: """ Generate unit tests files :param test_contract: @@ -134,7 +134,7 @@ def generate_unit_test( # pylint: disable=too-many-arguments,too-many-branches return output_dir -def generate_migration(test_contract: str, output_dir: Path, owner_address: str): +def generate_migration(test_contract: str, output_dir: Path, owner_address: str) -> None: """ Generate migration file :param test_contract: diff --git a/slither/tools/properties/utils.py b/slither/tools/properties/utils.py index bdc5bdd68..5a0153211 100644 --- a/slither/tools/properties/utils.py +++ b/slither/tools/properties/utils.py @@ -12,7 +12,7 @@ def write_file( content: str, allow_overwrite: bool = True, discard_if_exist: bool = False, -): +) -> None: """ Write the content into output_dir/filename :param output_dir: diff --git a/slither/tools/similarity/__main__.py b/slither/tools/similarity/__main__.py index 21ba88681..86673fccd 100755 --- a/slither/tools/similarity/__main__.py +++ b/slither/tools/similarity/__main__.py @@ -17,7 +17,7 @@ logger = logging.getLogger("Slither-simil") modes = ["info", "test", "train", "plot"] -def parse_args(): +def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector" ) @@ -78,7 +78,7 @@ def parse_args(): ################################################################################### -def main(): +def main() -> None: args = parse_args() default_log = logging.INFO diff --git a/slither/tools/similarity/encode.py b/slither/tools/similarity/encode.py index c69de91b8..9889644fb 100644 --- a/slither/tools/similarity/encode.py +++ b/slither/tools/similarity/encode.py @@ -74,7 +74,7 @@ def parse_target(target): return None -def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs): +def load_and_encode(infile: str, vmodel, ext=None, nsamples=None, **kwargs): r = {} if infile.endswith(".npz"): r = load_cache(infile, nsamples=nsamples) diff --git a/slither/tools/similarity/info.py b/slither/tools/similarity/info.py index 95aadea6a..c9f9753d1 100644 --- a/slither/tools/similarity/info.py +++ b/slither/tools/similarity/info.py @@ -1,3 +1,4 @@ +import argparse import logging import sys import os.path @@ -10,7 +11,7 @@ logging.basicConfig() logger = logging.getLogger("Slither-simil") -def info(args): +def info(args: argparse.Namespace) -> None: try: diff --git a/slither/tools/similarity/plot.py b/slither/tools/similarity/plot.py index bbdeec2cf..f11e92129 100644 --- a/slither/tools/similarity/plot.py +++ b/slither/tools/similarity/plot.py @@ -1,3 +1,4 @@ +import argparse import logging import random import sys @@ -23,7 +24,7 @@ except ImportError: logger = logging.getLogger("Slither-simil") -def plot(args): # pylint: disable=too-many-locals +def plot(args: argparse.Namespace) -> None: # pylint: disable=too-many-locals if decomposition is None or plt is None: logger.error( diff --git a/slither/tools/similarity/train.py b/slither/tools/similarity/train.py index a0d06c944..ccadf4926 100755 --- a/slither/tools/similarity/train.py +++ b/slither/tools/similarity/train.py @@ -1,3 +1,4 @@ +import argparse import logging import os import sys @@ -10,7 +11,7 @@ from slither.tools.similarity.model import train_unsupervised logger = logging.getLogger("Slither-simil") -def train(args): # pylint: disable=too-many-locals +def train(args: argparse.Namespace) -> None: # pylint: disable=too-many-locals try: last_data_train_filename = "last_data_train.txt" diff --git a/slither/tools/slither_format/__main__.py b/slither/tools/slither_format/__main__.py index a3d63d922..85c0a3917 100644 --- a/slither/tools/slither_format/__main__.py +++ b/slither/tools/slither_format/__main__.py @@ -23,7 +23,7 @@ available_detectors = [ ] -def parse_args(): +def parse_args() -> argparse.Namespace: """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. @@ -90,7 +90,7 @@ def parse_args(): return parser.parse_args() -def main(): +def main() -> None: # ------------------------------ # Usage: python3 -m slither_format filename # Example: python3 -m slither_format contract.sol diff --git a/slither/tools/slither_format/slither_format.py b/slither/tools/slither_format/slither_format.py index 794307100..c165b3fbb 100644 --- a/slither/tools/slither_format/slither_format.py +++ b/slither/tools/slither_format/slither_format.py @@ -1,5 +1,9 @@ import logging from pathlib import Path +from typing import Type, List, Dict + +from slither import Slither +from slither.detectors.abstract_detector import AbstractDetector from slither.detectors.variables.unused_state_variables import UnusedStateVars from slither.detectors.attributes.incorrect_solc import IncorrectSolc from slither.detectors.attributes.constant_pragma import ConstantPragma @@ -13,7 +17,7 @@ from slither.utils.colors import yellow logging.basicConfig(level=logging.INFO) logger = logging.getLogger("Slither.Format") -all_detectors = { +all_detectors: Dict[str, Type[AbstractDetector]] = { "unused-state": UnusedStateVars, "solc-version": IncorrectSolc, "pragma": ConstantPragma, @@ -25,7 +29,7 @@ all_detectors = { } -def slither_format(slither, **kwargs): # pylint: disable=too-many-locals +def slither_format(slither: Slither, **kwargs: Dict) -> None: # pylint: disable=too-many-locals """' Keyword Args: detectors_to_run (str): Comma-separated list of detectors, defaults to all @@ -85,9 +89,11 @@ def slither_format(slither, **kwargs): # pylint: disable=too-many-locals ################################################################################### -def choose_detectors(detectors_to_run, detectors_to_exclude): +def choose_detectors( + detectors_to_run: str, detectors_to_exclude: str +) -> List[Type[AbstractDetector]]: # If detectors are specified, run only these ones - cls_detectors_to_run = [] + cls_detectors_to_run: List[Type[AbstractDetector]] = [] exclude = detectors_to_exclude.split(",") if detectors_to_run == "all": for key, detector in all_detectors.items(): @@ -114,7 +120,7 @@ def choose_detectors(detectors_to_run, detectors_to_exclude): ################################################################################### -def print_patches(number_of_slither_results, patches): +def print_patches(number_of_slither_results: int, patches: Dict) -> None: logger.info("Number of Slither results: " + str(number_of_slither_results)) number_of_patches = 0 for file in patches: @@ -130,7 +136,7 @@ def print_patches(number_of_slither_results, patches): logger.info("Location end: " + str(patch["end"])) -def print_patches_json(number_of_slither_results, patches): +def print_patches_json(number_of_slither_results: int, patches: Dict) -> None: print("{", end="") print('"Number of Slither results":' + '"' + str(number_of_slither_results) + '",') print('"Number of patchlets":' + '"' + str(len(patches)) + '"', ",") diff --git a/slither/tools/upgradeability/__main__.py b/slither/tools/upgradeability/__main__.py index 6cc953015..d772029d0 100644 --- a/slither/tools/upgradeability/__main__.py +++ b/slither/tools/upgradeability/__main__.py @@ -3,10 +3,13 @@ import inspect import json import logging import sys +from typing import List, Any, Type, Dict, Tuple, Union, Sequence, Optional from crytic_compile import cryticparser + from slither import Slither +from slither.core.declarations import Contract from slither.exceptions import SlitherException from slither.utils.colors import red from slither.utils.output import output_to_json @@ -24,7 +27,7 @@ logger: logging.Logger = logging.getLogger("Slither") logger.setLevel(logging.INFO) -def parse_args(): +def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.", usage="slither-check-upgradeability contract.sol ContractName", @@ -93,21 +96,27 @@ def parse_args(): ################################################################################### -def _get_checks(): - detectors = [getattr(all_checks, name) for name in dir(all_checks)] - detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractCheck)] +def _get_checks() -> List[Type[AbstractCheck]]: + detectors_ = [getattr(all_checks, name) for name in dir(all_checks)] + detectors: List[Type[AbstractCheck]] = [ + c for c in detectors_ if inspect.isclass(c) and issubclass(c, AbstractCheck) + ] return detectors class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods - def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs + def __call__( + self, parser: Any, *args: Any, **kwargs: Any + ) -> None: # pylint: disable=signature-differs checks = _get_checks() output_detectors(checks) parser.exit() class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-methods - def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs + def __call__( + self, parser: Any, *args: Any, **kwargs: Any + ) -> None: # pylint: disable=signature-differs checks = _get_checks() detector_types_json = output_detectors_json(checks) print(json.dumps(detector_types_json)) @@ -116,48 +125,64 @@ class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-meth class OutputMarkdown(argparse.Action): # pylint: disable=too-few-public-methods def __call__( - self, parser, args, values, option_string=None - ): # pylint: disable=signature-differs + self, + parser: Any, + args: Any, + values: Optional[Union[str, Sequence[Any]]], + option_string: Any = None, + ) -> None: # pylint: disable=signature-differs checks = _get_checks() + assert isinstance(values, str) output_to_markdown(checks, values) parser.exit() class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods def __call__( - self, parser, args, values, option_string=None - ): # pylint: disable=signature-differs + self, + parser: Any, + args: Any, + values: Optional[Union[str, Sequence[Any]]], + option_string: Any = None, + ) -> Any: # pylint: disable=signature-differs checks = _get_checks() + assert isinstance(values, str) output_wiki(checks, values) parser.exit() -def _run_checks(detectors): - results = [d.check() for d in detectors] - results = [r for r in results if r] - results = [item for sublist in results for item in sublist] # flatten +def _run_checks(detectors: List[AbstractCheck]) -> List[Dict]: + results_ = [d.check() for d in detectors] + results_ = [r for r in results_ if r] + results = [item for sublist in results_ for item in sublist] # flatten return results -def _checks_on_contract(detectors, contract): - detectors = [ +def _checks_on_contract( + detectors: List[Type[AbstractCheck]], contract: Contract +) -> Tuple[List[Dict], int]: + detectors_ = [ d(logger, contract) for d in detectors if (not d.REQUIRE_PROXY and not d.REQUIRE_CONTRACT_V2) ] - return _run_checks(detectors), len(detectors) + return _run_checks(detectors_), len(detectors) -def _checks_on_contract_update(detectors, contract_v1, contract_v2): - detectors = [ +def _checks_on_contract_update( + detectors: List[Type[AbstractCheck]], contract_v1: Contract, contract_v2: Contract +) -> Tuple[List[Dict], int]: + detectors_ = [ d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2 ] - return _run_checks(detectors), len(detectors) + return _run_checks(detectors_), len(detectors) -def _checks_on_contract_and_proxy(detectors, contract, proxy): - detectors = [d(logger, contract, proxy=proxy) for d in detectors if d.REQUIRE_PROXY] - return _run_checks(detectors), len(detectors) +def _checks_on_contract_and_proxy( + detectors: List[Type[AbstractCheck]], contract: Contract, proxy: Contract +) -> Tuple[List[Dict], int]: + detectors_ = [d(logger, contract, proxy=proxy) for d in detectors if d.REQUIRE_PROXY] + return _run_checks(detectors_), len(detectors) # endregion @@ -168,8 +193,8 @@ def _checks_on_contract_and_proxy(detectors, contract, proxy): ################################################################################### # pylint: disable=too-many-statements,too-many-branches,too-many-locals -def main(): - json_results = { +def main() -> None: + json_results: Dict = { "proxy-present": False, "contract_v2-present": False, "detectors": [], diff --git a/slither/tools/upgradeability/utils/command_line.py b/slither/tools/upgradeability/utils/command_line.py index 6ab3f82f9..88b61ceed 100644 --- a/slither/tools/upgradeability/utils/command_line.py +++ b/slither/tools/upgradeability/utils/command_line.py @@ -1,8 +1,10 @@ -from slither.tools.upgradeability.checks.abstract_checks import classification_txt +from typing import List, Union, Dict, Type + +from slither.tools.upgradeability.checks.abstract_checks import classification_txt, AbstractCheck from slither.utils.myprettytable import MyPrettyTable -def output_wiki(detector_classes, filter_wiki): +def output_wiki(detector_classes: List[Type[AbstractCheck]], filter_wiki: str) -> None: # Sort by impact, confidence, and name detectors_list = sorted( detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT) @@ -31,7 +33,7 @@ def output_wiki(detector_classes, filter_wiki): print(recommendation) -def output_detectors(detector_classes): +def output_detectors(detector_classes: List[Type[AbstractCheck]]) -> None: detectors_list = [] for detector in detector_classes: argument = detector.ARGUMENT @@ -48,7 +50,7 @@ def output_detectors(detector_classes): for (argument, help_info, impact, proxy, v2) in detectors_list: table.add_row( [ - idx, + str(idx), argument, help_info, classification_txt[impact], @@ -60,8 +62,8 @@ def output_detectors(detector_classes): print(table) -def output_to_markdown(detector_classes, _filter_wiki): - def extract_help(cls): +def output_to_markdown(detector_classes: List[Type[AbstractCheck]], _filter_wiki: str) -> None: + def extract_help(cls: AbstractCheck) -> str: if cls.WIKI == "": return cls.HELP return f"[{cls.HELP}]({cls.WIKI})" @@ -85,7 +87,9 @@ def output_to_markdown(detector_classes, _filter_wiki): idx = idx + 1 -def output_detectors_json(detector_classes): +def output_detectors_json( + detector_classes: List[Type[AbstractCheck]], +) -> List[Dict[str, Union[str, int]]]: detectors_list = [] for detector in detector_classes: argument = detector.ARGUMENT @@ -110,7 +114,7 @@ def output_detectors_json(detector_classes): # Sort by impact, confidence, and name detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) idx = 1 - table = [] + table: List[Dict[str, Union[str, int]]] = [] for ( argument, help_info, diff --git a/slither/utils/command_line.py b/slither/utils/command_line.py index d264e65ea..c2fef5eca 100644 --- a/slither/utils/command_line.py +++ b/slither/utils/command_line.py @@ -1,13 +1,17 @@ +import argparse import json import os import re import logging from collections import defaultdict +from typing import Dict, List, Type, Union + from crytic_compile.cryticparser.defaults import ( DEFAULTS_FLAG_IN_CONFIG as DEFAULTS_FLAG_IN_CONFIG_CRYTIC_COMPILE, ) -from slither.detectors.abstract_detector import classification_txt +from slither.detectors.abstract_detector import classification_txt, AbstractDetector +from slither.printers.abstract_printer import AbstractPrinter from slither.utils.colors import yellow, red from slither.utils.myprettytable import MyPrettyTable @@ -54,7 +58,7 @@ defaults_flag_in_config = { } -def read_config_file(args): +def read_config_file(args: argparse.Namespace) -> None: # No config file was provided as an argument if args.config_file is None: # Check wether the default config file is present @@ -83,8 +87,12 @@ def read_config_file(args): logger.error(yellow("Falling back to the default settings...")) -def output_to_markdown(detector_classes, printer_classes, filter_wiki): - def extract_help(cls): +def output_to_markdown( + detector_classes: List[Type[AbstractDetector]], + printer_classes: List[Type[AbstractPrinter]], + filter_wiki: str, +) -> None: + def extract_help(cls: Union[Type[AbstractDetector], Type[AbstractPrinter]]) -> str: if cls.WIKI == "": return cls.HELP return f"[{cls.HELP}]({cls.WIKI})" @@ -127,7 +135,7 @@ def output_to_markdown(detector_classes, printer_classes, filter_wiki): idx = idx + 1 -def get_level(l): +def get_level(l: str) -> int: tab = l.count("\t") + 1 if l.replace("\t", "").startswith(" -"): tab = tab + 1 @@ -136,7 +144,7 @@ def get_level(l): return tab -def convert_result_to_markdown(txt): +def convert_result_to_markdown(txt: str) -> str: # -1 to remove the last \n lines = txt[0:-1].split("\n") ret = [] @@ -154,16 +162,21 @@ def convert_result_to_markdown(txt): return "".join(ret) -def output_results_to_markdown(all_results, checklistlimit: str): +def output_results_to_markdown(all_results: List[Dict], checklistlimit: str) -> None: checks = defaultdict(list) - info = defaultdict(dict) - for results in all_results: - checks[results["check"]].append(results) - info[results["check"]] = {"impact": results["impact"], "confidence": results["confidence"]} + info: Dict = defaultdict(dict) + for results_ in all_results: + checks[results_["check"]].append(results_) + info[results_["check"]] = { + "impact": results_["impact"], + "confidence": results_["confidence"], + } print("Summary") - for check in checks: - print(f" - [{check}](#{check}) ({len(checks[check])} results) ({info[check]['impact']})") + for check_ in checks: + print( + f" - [{check_}](#{check_}) ({len(checks[check_])} results) ({info[check_]['impact']})" + ) counter = 0 for (check, results) in checks.items(): @@ -185,8 +198,7 @@ def output_results_to_markdown(all_results, checklistlimit: str): print(f"**More results were found, check [{checklistlimit}]({checklistlimit})**") -def output_wiki(detector_classes, filter_wiki): - detectors_list = [] +def output_wiki(detector_classes: List[Type[AbstractDetector]], filter_wiki: str) -> None: # Sort by impact, confidence, and name detectors_list = sorted( @@ -223,7 +235,7 @@ def output_wiki(detector_classes, filter_wiki): print(recommendation) -def output_detectors(detector_classes): +def output_detectors(detector_classes: List[Type[AbstractDetector]]) -> None: detectors_list = [] for detector in detector_classes: argument = detector.ARGUMENT @@ -242,12 +254,15 @@ def output_detectors(detector_classes): ) idx = 1 for (argument, help_info, impact, confidence) in detectors_list: - table.add_row([idx, argument, help_info, classification_txt[impact], confidence]) + table.add_row([str(idx), argument, help_info, classification_txt[impact], confidence]) idx = idx + 1 print(table) -def output_detectors_json(detector_classes): # pylint: disable=too-many-locals +# pylint: disable=too-many-locals +def output_detectors_json( + detector_classes: List[Type[AbstractDetector]], +) -> List[Dict]: detectors_list = [] for detector in detector_classes: argument = detector.ARGUMENT @@ -307,7 +322,7 @@ def output_detectors_json(detector_classes): # pylint: disable=too-many-locals return table -def output_printers(printer_classes): +def output_printers(printer_classes: List[Type[AbstractPrinter]]) -> None: printers_list = [] for printer in printer_classes: argument = printer.ARGUMENT @@ -319,12 +334,12 @@ def output_printers(printer_classes): printers_list = sorted(printers_list, key=lambda element: (element[0])) idx = 1 for (argument, help_info) in printers_list: - table.add_row([idx, argument, help_info]) + table.add_row([str(idx), argument, help_info]) idx = idx + 1 print(table) -def output_printers_json(printer_classes): +def output_printers_json(printer_classes: List[Type[AbstractPrinter]]) -> List[Dict]: printers_list = [] for printer in printer_classes: argument = printer.ARGUMENT diff --git a/slither/utils/output.py b/slither/utils/output.py index 6296e35d3..5db6492db 100644 --- a/slither/utils/output.py +++ b/slither/utils/output.py @@ -4,7 +4,7 @@ import json import logging import zipfile from collections import OrderedDict -from typing import Optional, Dict, List, Union, Any, TYPE_CHECKING +from typing import Optional, Dict, List, Union, Any, TYPE_CHECKING, Type from zipfile import ZipFile from pkg_resources import require @@ -129,7 +129,7 @@ def _output_result_to_sarif( def output_to_sarif( - filename: Optional[str], results: Dict, detectors_classes: List["AbstractDetector"] + filename: Optional[str], results: Dict, detectors_classes: List[Type["AbstractDetector"]] ) -> None: """ diff --git a/slither/utils/output_capture.py b/slither/utils/output_capture.py index 5282afb91..aec170d7f 100644 --- a/slither/utils/output_capture.py +++ b/slither/utils/output_capture.py @@ -28,7 +28,7 @@ class StandardOutputCapture: original_logger_handlers = None @staticmethod - def enable(block_original=True): + def enable(block_original: bool = True) -> None: """ Redirects stdout and stderr to a capturable StringIO. :param block_original: If True, blocks all output to the original stream. If False, duplicates output. @@ -54,7 +54,7 @@ class StandardOutputCapture: root_logger.handlers = [logging.StreamHandler(sys.stderr)] @staticmethod - def disable(): + def disable() -> None: """ Disables redirection of stdout/stderr, if previously enabled. :return: None @@ -78,7 +78,7 @@ class StandardOutputCapture: StandardOutputCapture.original_logger_handlers = None @staticmethod - def get_stdout_output(): + def get_stdout_output() -> str: """ Obtains the output from the currently set stdout :return: Returns stdout output as a string @@ -87,7 +87,7 @@ class StandardOutputCapture: return sys.stdout.read() @staticmethod - def get_stderr_output(): + def get_stderr_output() -> str: """ Obtains the output from the currently set stderr :return: Returns stderr output as a string