Add more types hints

pull/1640/head
Feist Josselin 2 years ago
parent 4c976d5af5
commit b57be52818
  1. 6
      slither/core/slither_core.py
  2. 3
      slither/slither.py
  3. 2
      slither/tools/doctor/__main__.py
  4. 8
      slither/tools/erc_conformance/__main__.py
  5. 8
      slither/tools/erc_conformance/erc/erc1155.py
  6. 8
      slither/tools/erc_conformance/erc/erc20.py
  7. 45
      slither/tools/kspec_coverage/analysis.py
  8. 2
      slither/tools/mutator/__main__.py
  9. 8
      slither/tools/mutator/mutators/MIA.py
  10. 11
      slither/tools/mutator/mutators/MVIE.py
  11. 10
      slither/tools/mutator/mutators/MVIV.py
  12. 2
      slither/tools/mutator/utils/command_line.py
  13. 2
      slither/tools/similarity/__main__.py
  14. 11
      slither/tools/similarity/encode.py
  15. 3
      slither/tools/similarity/test.py

@ -443,7 +443,7 @@ class SlitherCore(Context):
return True return True
def load_previous_results(self): def load_previous_results(self) -> None:
filename = self._previous_results_filename filename = self._previous_results_filename
try: try:
if os.path.isfile(filename): if os.path.isfile(filename):
@ -456,7 +456,7 @@ class SlitherCore(Context):
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
logger.error(red(f"Impossible to decode {filename}. Consider removing the file")) logger.error(red(f"Impossible to decode {filename}. Consider removing the file"))
def write_results_to_hide(self): def write_results_to_hide(self) -> None:
if not self._results_to_hide: if not self._results_to_hide:
return return
filename = self._previous_results_filename filename = self._previous_results_filename
@ -464,7 +464,7 @@ class SlitherCore(Context):
results = self._results_to_hide + self._previous_results results = self._results_to_hide + self._previous_results
json.dump(results, f) json.dump(results, f)
def save_results_to_hide(self, results: List[Dict]): def save_results_to_hide(self, results: List[Dict]) -> None:
self._results_to_hide += results self._results_to_hide += results
def add_path_to_filter(self, path: str): def add_path_to_filter(self, path: str):

@ -11,6 +11,7 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi
from slither.exceptions import SlitherError from slither.exceptions import SlitherError
from slither.printers.abstract_printer import AbstractPrinter from slither.printers.abstract_printer import AbstractPrinter
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
from slither.utils.output import Output
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
logging.basicConfig() logging.basicConfig()
@ -206,7 +207,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes
self.write_results_to_hide() self.write_results_to_hide()
return results return results
def run_printers(self): def run_printers(self) -> List[Output]:
""" """
:return: List of registered printers outputs. :return: List of registered printers outputs.
""" """

@ -26,7 +26,7 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args() return parser.parse_args()
def main(): def main() -> None:
# log on stdout to keep output in order # log on stdout to keep output in order
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True) logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)

@ -1,10 +1,11 @@
import argparse import argparse
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List from typing import Any, Dict, List, Callable, Optional
from crytic_compile import cryticparser from crytic_compile import cryticparser
from slither import Slither from slither import Slither
from slither.core.declarations import Contract
from slither.utils.erc import ERCS from slither.utils.erc import ERCS
from slither.utils.output import output_to_json from slither.utils.output import output_to_json
from .erc.ercs import generic_erc_checks from .erc.ercs import generic_erc_checks
@ -24,7 +25,10 @@ logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter) logger.handlers[0].setFormatter(formatter)
logger.propagate = False logger.propagate = False
ADDITIONAL_CHECKS = {"ERC20": check_erc20, "ERC1155": check_erc1155} ADDITIONAL_CHECKS: Dict[str, Callable[[Contract, Dict[str, List]], Dict[str, List]]] = {
"ERC20": check_erc20,
"ERC1155": check_erc1155,
}
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:

@ -1,12 +1,14 @@
import logging import logging
from typing import Dict, List, Optional
from slither.core.declarations import Contract
from slither.slithir.operations import EventCall from slither.slithir.operations import EventCall
from slither.utils import output from slither.utils import output
logger = logging.getLogger("Slither-conformance") logger = logging.getLogger("Slither-conformance")
def events_safeBatchTransferFrom(contract, ret): def events_safeBatchTransferFrom(contract: Contract, ret: Dict[str, List]) -> None:
function = contract.get_function_from_signature( function = contract.get_function_from_signature(
"safeBatchTransferFrom(address,address,uint256[],uint256[],bytes)" "safeBatchTransferFrom(address,address,uint256[],uint256[],bytes)"
) )
@ -44,7 +46,9 @@ def events_safeBatchTransferFrom(contract, ret):
) )
def check_erc1155(contract, ret, explored=None): def check_erc1155(
contract: Contract, ret: Dict[str, List], explored: Optional[bool] = None
) -> Dict[str, List]:
if explored is None: if explored is None:
explored = set() explored = set()

@ -1,11 +1,13 @@
import logging import logging
from typing import Dict, List, Optional
from slither.core.declarations import Contract
from slither.utils import output from slither.utils import output
logger = logging.getLogger("Slither-conformance") logger = logging.getLogger("Slither-conformance")
def approval_race_condition(contract, ret): def approval_race_condition(contract: Contract, ret: Dict[str, List]) -> None:
increaseAllowance = contract.get_function_from_signature("increaseAllowance(address,uint256)") increaseAllowance = contract.get_function_from_signature("increaseAllowance(address,uint256)")
if not increaseAllowance: if not increaseAllowance:
@ -27,7 +29,9 @@ def approval_race_condition(contract, ret):
) )
def check_erc20(contract, ret, explored=None): def check_erc20(
contract: Contract, ret: Dict[str, List], explored: Optional[bool] = None
) -> Dict[str, List]:
if explored is None: if explored is None:
explored = set() explored = set()

@ -1,8 +1,12 @@
import re import re
import logging import logging
from typing import Set, Tuple from argparse import Namespace
from typing import Set, Tuple, List, Dict, Union, Optional, Callable
from slither.core.declarations import Function from slither import Slither
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations import Function, FunctionContract
from slither.core.variables.state_variable import StateVariable
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.utils.colors import yellow, green, red from slither.utils.colors import yellow, green, red
from slither.utils import output from slither.utils import output
@ -54,13 +58,15 @@ def _get_all_covered_kspec_functions(target: str) -> Set[Tuple[str, str]]:
return covered_functions return covered_functions
def _get_slither_functions(slither): def _get_slither_functions(
slither: SlitherCompilationUnit,
) -> Dict[Tuple[str, str], Union[FunctionContract, StateVariable]]:
# Use contract == contract_declarer to avoid dupplicate # Use contract == contract_declarer to avoid dupplicate
all_functions_declared = [ all_functions_declared: List[Union[FunctionContract, StateVariable]] = [
f f
for f in slither.functions for f in slither.functions
if ( if (
f.contract == f.contract_declarer (isinstance(f, FunctionContract) and f.contract == f.contract_declarer)
and f.is_implemented and f.is_implemented
and not f.is_constructor and not f.is_constructor
and not f.is_constructor_variables and not f.is_constructor_variables
@ -79,7 +85,12 @@ def _get_slither_functions(slither):
return slither_functions return slither_functions
def _generate_output(kspec, message, color, generate_json): def _generate_output(
kspec: List[Union[FunctionContract, StateVariable]],
message: str,
color: Callable[[str], str],
generate_json: bool,
) -> Optional[Dict]:
info = "" info = ""
for function in kspec: for function in kspec:
info += f"{message} {function.contract.name}.{function.full_name}\n" info += f"{message} {function.contract.name}.{function.full_name}\n"
@ -94,7 +105,9 @@ def _generate_output(kspec, message, color, generate_json):
return None return None
def _generate_output_unresolved(kspec, message, color, generate_json): def _generate_output_unresolved(
kspec: Set[Tuple[str, str]], message: str, color: Callable[[str], str], generate_json: bool
) -> Optional[Dict]:
info = "" info = ""
for contract, function in kspec: for contract, function in kspec:
info += f"{message} {contract}.{function}\n" info += f"{message} {contract}.{function}\n"
@ -107,17 +120,19 @@ def _generate_output_unresolved(kspec, message, color, generate_json):
return None return None
def _run_coverage_analysis(args, slither, kspec_functions): def _run_coverage_analysis(
args, slither: SlitherCompilationUnit, kspec_functions: Set[Tuple[str, str]]
) -> None:
# Collect all slither functions # Collect all slither functions
slither_functions = _get_slither_functions(slither) slither_functions = _get_slither_functions(slither)
# Determine which klab specs were not resolved. # Determine which klab specs were not resolved.
slither_functions_set = set(slither_functions) slither_functions_set = set(slither_functions)
kspec_functions_resolved = kspec_functions & slither_functions_set kspec_functions_resolved = kspec_functions & slither_functions_set
kspec_functions_unresolved = kspec_functions - kspec_functions_resolved kspec_functions_unresolved: Set[Tuple[str, str]] = kspec_functions - kspec_functions_resolved
kspec_missing = [] kspec_missing: List[Union[FunctionContract, StateVariable]] = []
kspec_present = [] kspec_present: List[Union[FunctionContract, StateVariable]] = []
for slither_func_desc in sorted(slither_functions_set): for slither_func_desc in sorted(slither_functions_set):
slither_func = slither_functions[slither_func_desc] slither_func = slither_functions[slither_func_desc]
@ -130,13 +145,13 @@ def _run_coverage_analysis(args, slither, kspec_functions):
logger.info("## Check for functions coverage") logger.info("## Check for functions coverage")
json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json) json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json)
json_kspec_missing_functions = _generate_output( json_kspec_missing_functions = _generate_output(
[f for f in kspec_missing if isinstance(f, Function)], [f for f in kspec_missing if isinstance(f, FunctionContract)],
"[ ] (Missing function)", "[ ] (Missing function)",
red, red,
args.json, args.json,
) )
json_kspec_missing_variables = _generate_output( json_kspec_missing_variables = _generate_output(
[f for f in kspec_missing if isinstance(f, Variable)], [f for f in kspec_missing if isinstance(f, StateVariable)],
"[ ] (Missing variable)", "[ ] (Missing variable)",
yellow, yellow,
args.json, args.json,
@ -159,11 +174,11 @@ def _run_coverage_analysis(args, slither, kspec_functions):
) )
def run_analysis(args, slither, kspec_arg): def run_analysis(args: Namespace, slither: SlitherCompilationUnit, kspec_arg: str) -> None:
# Get all of our kspec'd functions (tuple(contract_name, function_name)). # Get all of our kspec'd functions (tuple(contract_name, function_name)).
if "," in kspec_arg: if "," in kspec_arg:
kspecs = kspec_arg.split(",") kspecs = kspec_arg.split(",")
kspec_functions = set() kspec_functions: Set[Tuple[str, str]] = set()
for kspec in kspecs: for kspec in kspecs:
kspec_functions |= _get_all_covered_kspec_functions(kspec) kspec_functions |= _get_all_covered_kspec_functions(kspec)
else: else:

@ -72,7 +72,7 @@ class ListMutators(argparse.Action): # pylint: disable=too-few-public-methods
################################################################################### ###################################################################################
def main(): def main() -> None:
args = parse_args() args = parse_args()

@ -1,3 +1,5 @@
from typing import Dict
from slither.core.cfg.node import NodeType from slither.core.cfg.node import NodeType
from slither.formatters.utils.patches import create_patch from slither.formatters.utils.patches import create_patch
from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass
@ -9,13 +11,13 @@ class MIA(AbstractMutator): # pylint: disable=too-few-public-methods
FAULTCLASS = FaultClass.Checking FAULTCLASS = FaultClass.Checking
FAULTNATURE = FaultNature.Missing FAULTNATURE = FaultNature.Missing
def _mutate(self): def _mutate(self) -> Dict:
result = {} result: Dict = {}
for contract in self.slither.contracts: for contract in self.slither.contracts:
for function in contract.functions_declared + contract.modifiers_declared: for function in contract.functions_declared + list(contract.modifiers_declared):
for node in function.nodes: for node in function.nodes:
if node.type == NodeType.IF: if node.type == NodeType.IF:

@ -1,4 +1,7 @@
from typing import Dict
from slither.core.expressions import Literal from slither.core.expressions import Literal
from slither.core.variables.variable import Variable
from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass
from slither.tools.mutator.utils.generic_patching import remove_assignement from slither.tools.mutator.utils.generic_patching import remove_assignement
@ -9,10 +12,10 @@ class MVIE(AbstractMutator): # pylint: disable=too-few-public-methods
FAULTCLASS = FaultClass.Assignement FAULTCLASS = FaultClass.Assignement
FAULTNATURE = FaultNature.Missing FAULTNATURE = FaultNature.Missing
def _mutate(self): def _mutate(self) -> Dict:
result = {}
result: Dict = {}
variable: Variable
for contract in self.slither.contracts: for contract in self.slither.contracts:
# Create fault for state variables declaration # Create fault for state variables declaration
@ -25,7 +28,7 @@ class MVIE(AbstractMutator): # pylint: disable=too-few-public-methods
if not isinstance(variable.expression, Literal): if not isinstance(variable.expression, Literal):
remove_assignement(variable, contract, result) remove_assignement(variable, contract, result)
for function in contract.functions_declared + contract.modifiers_declared: for function in contract.functions_declared + list(contract.modifiers_declared):
for variable in function.local_variables: for variable in function.local_variables:
if variable.initialized and not isinstance(variable.expression, Literal): if variable.initialized and not isinstance(variable.expression, Literal):
remove_assignement(variable, contract, result) remove_assignement(variable, contract, result)

@ -1,4 +1,7 @@
from typing import Dict
from slither.core.expressions import Literal from slither.core.expressions import Literal
from slither.core.variables.variable import Variable
from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass
from slither.tools.mutator.utils.generic_patching import remove_assignement from slither.tools.mutator.utils.generic_patching import remove_assignement
@ -9,9 +12,10 @@ class MVIV(AbstractMutator): # pylint: disable=too-few-public-methods
FAULTCLASS = FaultClass.Assignement FAULTCLASS = FaultClass.Assignement
FAULTNATURE = FaultNature.Missing FAULTNATURE = FaultNature.Missing
def _mutate(self): def _mutate(self) -> Dict:
result = {} result: Dict = {}
variable: Variable
for contract in self.slither.contracts: for contract in self.slither.contracts:
@ -25,7 +29,7 @@ class MVIV(AbstractMutator): # pylint: disable=too-few-public-methods
if isinstance(variable.expression, Literal): if isinstance(variable.expression, Literal):
remove_assignement(variable, contract, result) remove_assignement(variable, contract, result)
for function in contract.functions_declared + contract.modifiers_declared: for function in contract.functions_declared + list(contract.modifiers_declared):
for variable in function.local_variables: for variable in function.local_variables:
if variable.initialized and isinstance(variable.expression, Literal): if variable.initialized and isinstance(variable.expression, Literal):
remove_assignement(variable, contract, result) remove_assignement(variable, contract, result)

@ -1,7 +1,7 @@
from slither.utils.myprettytable import MyPrettyTable from slither.utils.myprettytable import MyPrettyTable
def output_mutators(mutators_classes): def output_mutators(mutators_classes: List[Type[AbstractMutator]]) -> None:
mutators_list = [] mutators_list = []
for detector in mutators_classes: for detector in mutators_classes:
argument = detector.NAME argument = detector.NAME

@ -7,7 +7,7 @@ import sys
from crytic_compile import cryticparser from crytic_compile import cryticparser
from slither.tools.similarity.info import info from slither.tools.similarity.info import info
from slither.tools.similarity.test import test from slither.tools.similarity import test
from slither.tools.similarity.train import train from slither.tools.similarity.train import train
from slither.tools.similarity.plot import plot from slither.tools.similarity.plot import plot

@ -1,5 +1,6 @@
import logging import logging
import os import os
from typing import Optional, Tuple, List
from slither import Slither from slither import Slither
from slither.core.declarations import ( from slither.core.declarations import (
@ -60,7 +61,7 @@ slither_logger = logging.getLogger("Slither")
slither_logger.setLevel(logging.CRITICAL) slither_logger.setLevel(logging.CRITICAL)
def parse_target(target): def parse_target(target: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
if target is None: if target is None:
return None, None return None, None
@ -68,9 +69,9 @@ def parse_target(target):
if len(parts) == 1: if len(parts) == 1:
return None, parts[0] return None, parts[0]
if len(parts) == 2: if len(parts) == 2:
return parts return parts[0], parts[1]
simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'") simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'")
return None return None, None
def load_and_encode(infile: str, vmodel, ext=None, nsamples=None, **kwargs): def load_and_encode(infile: str, vmodel, ext=None, nsamples=None, **kwargs):
@ -88,7 +89,9 @@ def load_and_encode(infile: str, vmodel, ext=None, nsamples=None, **kwargs):
return r return r
def load_contracts(dirname, ext=None, nsamples=None): def load_contracts(
dirname: str, ext: Optional[str] = None, nsamples: Optional[int] = None
) -> List[str]:
r = [] r = []
walk = list(os.walk(dirname)) walk = list(os.walk(dirname))
for x, y, files in walk: for x, y, files in walk:

@ -2,6 +2,7 @@ import logging
import operator import operator
import sys import sys
import traceback import traceback
from argparse import Namespace
from slither.tools.similarity.encode import encode_contract, load_and_encode, parse_target from slither.tools.similarity.encode import encode_contract, load_and_encode, parse_target
from slither.tools.similarity.model import load_model from slither.tools.similarity.model import load_model
@ -10,7 +11,7 @@ from slither.tools.similarity.similarity import similarity
logger = logging.getLogger("Slither-simil") logger = logging.getLogger("Slither-simil")
def test(args): def est(args: Namespace) -> None:
try: try:
model = args.model model = args.model

Loading…
Cancel
Save