diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index 66b8fc430..cba86e56e 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -443,7 +443,7 @@ class SlitherCore(Context): return True - def load_previous_results(self): + def load_previous_results(self) -> None: filename = self._previous_results_filename try: if os.path.isfile(filename): @@ -456,7 +456,7 @@ class SlitherCore(Context): except json.decoder.JSONDecodeError: logger.error(red(f"Impossible to decode {filename}. Consider removing the file")) - def write_results_to_hide(self): + def write_results_to_hide(self) -> None: if not self._results_to_hide: return filename = self._previous_results_filename @@ -464,7 +464,7 @@ class SlitherCore(Context): results = self._results_to_hide + self._previous_results json.dump(results, f) - def save_results_to_hide(self, results: List[Dict]): + def save_results_to_hide(self, results: List[Dict]) -> None: self._results_to_hide += results def add_path_to_filter(self, path: str): diff --git a/slither/slither.py b/slither/slither.py index 45d99906f..93cf16394 100644 --- a/slither/slither.py +++ b/slither/slither.py @@ -11,6 +11,7 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi from slither.exceptions import SlitherError from slither.printers.abstract_printer import AbstractPrinter from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc +from slither.utils.output import Output logger = logging.getLogger("Slither") logging.basicConfig() @@ -206,7 +207,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes self.write_results_to_hide() return results - def run_printers(self): + def run_printers(self) -> List[Output]: """ :return: List of registered printers outputs. """ diff --git a/slither/tools/doctor/__main__.py b/slither/tools/doctor/__main__.py index b9b4c5497..f401781a7 100644 --- a/slither/tools/doctor/__main__.py +++ b/slither/tools/doctor/__main__.py @@ -26,7 +26,7 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def main(): +def main() -> None: # log on stdout to keep output in order logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True) diff --git a/slither/tools/erc_conformance/__main__.py b/slither/tools/erc_conformance/__main__.py index ef594a7c6..4f0c19f80 100644 --- a/slither/tools/erc_conformance/__main__.py +++ b/slither/tools/erc_conformance/__main__.py @@ -1,10 +1,11 @@ import argparse import logging from collections import defaultdict -from typing import Any, Dict, List +from typing import Any, Dict, List, Callable, Optional from crytic_compile import cryticparser from slither import Slither +from slither.core.declarations import Contract from slither.utils.erc import ERCS from slither.utils.output import output_to_json from .erc.ercs import generic_erc_checks @@ -24,7 +25,10 @@ logger.addHandler(ch) logger.handlers[0].setFormatter(formatter) logger.propagate = False -ADDITIONAL_CHECKS = {"ERC20": check_erc20, "ERC1155": check_erc1155} +ADDITIONAL_CHECKS: Dict[str, Callable[[Contract, Dict[str, List]], Dict[str, List]]] = { + "ERC20": check_erc20, + "ERC1155": check_erc1155, +} def parse_args() -> argparse.Namespace: diff --git a/slither/tools/erc_conformance/erc/erc1155.py b/slither/tools/erc_conformance/erc/erc1155.py index fceb4e242..34bb18bfb 100644 --- a/slither/tools/erc_conformance/erc/erc1155.py +++ b/slither/tools/erc_conformance/erc/erc1155.py @@ -1,12 +1,14 @@ import logging +from typing import Dict, List, Optional +from slither.core.declarations import Contract from slither.slithir.operations import EventCall from slither.utils import output logger = logging.getLogger("Slither-conformance") -def events_safeBatchTransferFrom(contract, ret): +def events_safeBatchTransferFrom(contract: Contract, ret: Dict[str, List]) -> None: function = contract.get_function_from_signature( "safeBatchTransferFrom(address,address,uint256[],uint256[],bytes)" ) @@ -44,7 +46,9 @@ def events_safeBatchTransferFrom(contract, ret): ) -def check_erc1155(contract, ret, explored=None): +def check_erc1155( + contract: Contract, ret: Dict[str, List], explored: Optional[bool] = None +) -> Dict[str, List]: if explored is None: explored = set() diff --git a/slither/tools/erc_conformance/erc/erc20.py b/slither/tools/erc_conformance/erc/erc20.py index 720b08322..6ee243515 100644 --- a/slither/tools/erc_conformance/erc/erc20.py +++ b/slither/tools/erc_conformance/erc/erc20.py @@ -1,11 +1,13 @@ import logging +from typing import Dict, List, Optional +from slither.core.declarations import Contract from slither.utils import output logger = logging.getLogger("Slither-conformance") -def approval_race_condition(contract, ret): +def approval_race_condition(contract: Contract, ret: Dict[str, List]) -> None: increaseAllowance = contract.get_function_from_signature("increaseAllowance(address,uint256)") if not increaseAllowance: @@ -27,7 +29,9 @@ def approval_race_condition(contract, ret): ) -def check_erc20(contract, ret, explored=None): +def check_erc20( + contract: Contract, ret: Dict[str, List], explored: Optional[bool] = None +) -> Dict[str, List]: if explored is None: explored = set() diff --git a/slither/tools/kspec_coverage/analysis.py b/slither/tools/kspec_coverage/analysis.py index 3d513d22f..514bfa58f 100755 --- a/slither/tools/kspec_coverage/analysis.py +++ b/slither/tools/kspec_coverage/analysis.py @@ -1,8 +1,12 @@ import re import logging -from typing import Set, Tuple +from argparse import Namespace +from typing import Set, Tuple, List, Dict, Union, Optional, Callable -from slither.core.declarations import Function +from slither import Slither +from slither.core.compilation_unit import SlitherCompilationUnit +from slither.core.declarations import Function, FunctionContract +from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable from slither.utils.colors import yellow, green, red from slither.utils import output @@ -54,13 +58,15 @@ def _get_all_covered_kspec_functions(target: str) -> Set[Tuple[str, str]]: return covered_functions -def _get_slither_functions(slither): +def _get_slither_functions( + slither: SlitherCompilationUnit, +) -> Dict[Tuple[str, str], Union[FunctionContract, StateVariable]]: # Use contract == contract_declarer to avoid dupplicate - all_functions_declared = [ + all_functions_declared: List[Union[FunctionContract, StateVariable]] = [ f for f in slither.functions if ( - f.contract == f.contract_declarer + (isinstance(f, FunctionContract) and f.contract == f.contract_declarer) and f.is_implemented and not f.is_constructor and not f.is_constructor_variables @@ -79,7 +85,12 @@ def _get_slither_functions(slither): return slither_functions -def _generate_output(kspec, message, color, generate_json): +def _generate_output( + kspec: List[Union[FunctionContract, StateVariable]], + message: str, + color: Callable[[str], str], + generate_json: bool, +) -> Optional[Dict]: info = "" for function in kspec: info += f"{message} {function.contract.name}.{function.full_name}\n" @@ -94,7 +105,9 @@ def _generate_output(kspec, message, color, generate_json): return None -def _generate_output_unresolved(kspec, message, color, generate_json): +def _generate_output_unresolved( + kspec: Set[Tuple[str, str]], message: str, color: Callable[[str], str], generate_json: bool +) -> Optional[Dict]: info = "" for contract, function in kspec: info += f"{message} {contract}.{function}\n" @@ -107,17 +120,19 @@ def _generate_output_unresolved(kspec, message, color, generate_json): return None -def _run_coverage_analysis(args, slither, kspec_functions): +def _run_coverage_analysis( + args, slither: SlitherCompilationUnit, kspec_functions: Set[Tuple[str, str]] +) -> None: # Collect all slither functions slither_functions = _get_slither_functions(slither) # Determine which klab specs were not resolved. slither_functions_set = set(slither_functions) kspec_functions_resolved = kspec_functions & slither_functions_set - kspec_functions_unresolved = kspec_functions - kspec_functions_resolved + kspec_functions_unresolved: Set[Tuple[str, str]] = kspec_functions - kspec_functions_resolved - kspec_missing = [] - kspec_present = [] + kspec_missing: List[Union[FunctionContract, StateVariable]] = [] + kspec_present: List[Union[FunctionContract, StateVariable]] = [] for slither_func_desc in sorted(slither_functions_set): slither_func = slither_functions[slither_func_desc] @@ -130,13 +145,13 @@ def _run_coverage_analysis(args, slither, kspec_functions): logger.info("## Check for functions coverage") json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json) json_kspec_missing_functions = _generate_output( - [f for f in kspec_missing if isinstance(f, Function)], + [f for f in kspec_missing if isinstance(f, FunctionContract)], "[ ] (Missing function)", red, args.json, ) json_kspec_missing_variables = _generate_output( - [f for f in kspec_missing if isinstance(f, Variable)], + [f for f in kspec_missing if isinstance(f, StateVariable)], "[ ] (Missing variable)", yellow, args.json, @@ -159,11 +174,11 @@ def _run_coverage_analysis(args, slither, kspec_functions): ) -def run_analysis(args, slither, kspec_arg): +def run_analysis(args: Namespace, slither: SlitherCompilationUnit, kspec_arg: str) -> None: # Get all of our kspec'd functions (tuple(contract_name, function_name)). if "," in kspec_arg: kspecs = kspec_arg.split(",") - kspec_functions = set() + kspec_functions: Set[Tuple[str, str]] = set() for kspec in kspecs: kspec_functions |= _get_all_covered_kspec_functions(kspec) else: diff --git a/slither/tools/mutator/__main__.py b/slither/tools/mutator/__main__.py index 78b86d681..27e396d0b 100644 --- a/slither/tools/mutator/__main__.py +++ b/slither/tools/mutator/__main__.py @@ -72,7 +72,7 @@ class ListMutators(argparse.Action): # pylint: disable=too-few-public-methods ################################################################################### -def main(): +def main() -> None: args = parse_args() diff --git a/slither/tools/mutator/mutators/MIA.py b/slither/tools/mutator/mutators/MIA.py index 54ca0ec1c..405888f8b 100644 --- a/slither/tools/mutator/mutators/MIA.py +++ b/slither/tools/mutator/mutators/MIA.py @@ -1,3 +1,5 @@ +from typing import Dict + from slither.core.cfg.node import NodeType from slither.formatters.utils.patches import create_patch from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass @@ -9,13 +11,13 @@ class MIA(AbstractMutator): # pylint: disable=too-few-public-methods FAULTCLASS = FaultClass.Checking FAULTNATURE = FaultNature.Missing - def _mutate(self): + def _mutate(self) -> Dict: - result = {} + result: Dict = {} for contract in self.slither.contracts: - for function in contract.functions_declared + contract.modifiers_declared: + for function in contract.functions_declared + list(contract.modifiers_declared): for node in function.nodes: if node.type == NodeType.IF: diff --git a/slither/tools/mutator/mutators/MVIE.py b/slither/tools/mutator/mutators/MVIE.py index 8f8cc11bf..a16a8252e 100644 --- a/slither/tools/mutator/mutators/MVIE.py +++ b/slither/tools/mutator/mutators/MVIE.py @@ -1,4 +1,7 @@ +from typing import Dict + from slither.core.expressions import Literal +from slither.core.variables.variable import Variable from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass from slither.tools.mutator.utils.generic_patching import remove_assignement @@ -9,10 +12,10 @@ class MVIE(AbstractMutator): # pylint: disable=too-few-public-methods FAULTCLASS = FaultClass.Assignement FAULTNATURE = FaultNature.Missing - def _mutate(self): - - result = {} + def _mutate(self) -> Dict: + result: Dict = {} + variable: Variable for contract in self.slither.contracts: # Create fault for state variables declaration @@ -25,7 +28,7 @@ class MVIE(AbstractMutator): # pylint: disable=too-few-public-methods if not isinstance(variable.expression, Literal): remove_assignement(variable, contract, result) - for function in contract.functions_declared + contract.modifiers_declared: + for function in contract.functions_declared + list(contract.modifiers_declared): for variable in function.local_variables: if variable.initialized and not isinstance(variable.expression, Literal): remove_assignement(variable, contract, result) diff --git a/slither/tools/mutator/mutators/MVIV.py b/slither/tools/mutator/mutators/MVIV.py index dac34da28..d4a7c5486 100644 --- a/slither/tools/mutator/mutators/MVIV.py +++ b/slither/tools/mutator/mutators/MVIV.py @@ -1,4 +1,7 @@ +from typing import Dict + from slither.core.expressions import Literal +from slither.core.variables.variable import Variable from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass from slither.tools.mutator.utils.generic_patching import remove_assignement @@ -9,9 +12,10 @@ class MVIV(AbstractMutator): # pylint: disable=too-few-public-methods FAULTCLASS = FaultClass.Assignement FAULTNATURE = FaultNature.Missing - def _mutate(self): + def _mutate(self) -> Dict: - result = {} + result: Dict = {} + variable: Variable for contract in self.slither.contracts: @@ -25,7 +29,7 @@ class MVIV(AbstractMutator): # pylint: disable=too-few-public-methods if isinstance(variable.expression, Literal): remove_assignement(variable, contract, result) - for function in contract.functions_declared + contract.modifiers_declared: + for function in contract.functions_declared + list(contract.modifiers_declared): for variable in function.local_variables: if variable.initialized and isinstance(variable.expression, Literal): remove_assignement(variable, contract, result) diff --git a/slither/tools/mutator/utils/command_line.py b/slither/tools/mutator/utils/command_line.py index 9799fd488..fdef145f3 100644 --- a/slither/tools/mutator/utils/command_line.py +++ b/slither/tools/mutator/utils/command_line.py @@ -1,7 +1,7 @@ from slither.utils.myprettytable import MyPrettyTable -def output_mutators(mutators_classes): +def output_mutators(mutators_classes: List[Type[AbstractMutator]]) -> None: mutators_list = [] for detector in mutators_classes: argument = detector.NAME diff --git a/slither/tools/similarity/__main__.py b/slither/tools/similarity/__main__.py index 86673fccd..00caa3fbd 100755 --- a/slither/tools/similarity/__main__.py +++ b/slither/tools/similarity/__main__.py @@ -7,7 +7,7 @@ import sys from crytic_compile import cryticparser from slither.tools.similarity.info import info -from slither.tools.similarity.test import test +from slither.tools.similarity import test from slither.tools.similarity.train import train from slither.tools.similarity.plot import plot diff --git a/slither/tools/similarity/encode.py b/slither/tools/similarity/encode.py index d08086282..48700ec4a 100644 --- a/slither/tools/similarity/encode.py +++ b/slither/tools/similarity/encode.py @@ -1,5 +1,6 @@ import logging import os +from typing import Optional, Tuple, List from slither import Slither from slither.core.declarations import ( @@ -60,7 +61,7 @@ slither_logger = logging.getLogger("Slither") slither_logger.setLevel(logging.CRITICAL) -def parse_target(target): +def parse_target(target: Optional[str]) -> Tuple[Optional[str], Optional[str]]: if target is None: return None, None @@ -68,9 +69,9 @@ def parse_target(target): if len(parts) == 1: return None, parts[0] if len(parts) == 2: - return parts + return parts[0], parts[1] simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'") - return None + return None, None def load_and_encode(infile: str, vmodel, ext=None, nsamples=None, **kwargs): @@ -88,7 +89,9 @@ def load_and_encode(infile: str, vmodel, ext=None, nsamples=None, **kwargs): return r -def load_contracts(dirname, ext=None, nsamples=None): +def load_contracts( + dirname: str, ext: Optional[str] = None, nsamples: Optional[int] = None +) -> List[str]: r = [] walk = list(os.walk(dirname)) for x, y, files in walk: diff --git a/slither/tools/similarity/test.py b/slither/tools/similarity/test.py index 76229d5bf..145e50321 100755 --- a/slither/tools/similarity/test.py +++ b/slither/tools/similarity/test.py @@ -2,6 +2,7 @@ import logging import operator import sys import traceback +from argparse import Namespace from slither.tools.similarity.encode import encode_contract, load_and_encode, parse_target from slither.tools.similarity.model import load_model @@ -10,7 +11,7 @@ from slither.tools.similarity.similarity import similarity logger = logging.getLogger("Slither-simil") -def test(args): +def est(args: Namespace) -> None: try: model = args.model