Add more types hints

pull/1640/head
Feist Josselin 2 years ago
parent 4c976d5af5
commit b57be52818
  1. 6
      slither/core/slither_core.py
  2. 3
      slither/slither.py
  3. 2
      slither/tools/doctor/__main__.py
  4. 8
      slither/tools/erc_conformance/__main__.py
  5. 8
      slither/tools/erc_conformance/erc/erc1155.py
  6. 8
      slither/tools/erc_conformance/erc/erc20.py
  7. 45
      slither/tools/kspec_coverage/analysis.py
  8. 2
      slither/tools/mutator/__main__.py
  9. 8
      slither/tools/mutator/mutators/MIA.py
  10. 11
      slither/tools/mutator/mutators/MVIE.py
  11. 10
      slither/tools/mutator/mutators/MVIV.py
  12. 2
      slither/tools/mutator/utils/command_line.py
  13. 2
      slither/tools/similarity/__main__.py
  14. 11
      slither/tools/similarity/encode.py
  15. 3
      slither/tools/similarity/test.py

@ -443,7 +443,7 @@ class SlitherCore(Context):
return True
def load_previous_results(self):
def load_previous_results(self) -> None:
filename = self._previous_results_filename
try:
if os.path.isfile(filename):
@ -456,7 +456,7 @@ class SlitherCore(Context):
except json.decoder.JSONDecodeError:
logger.error(red(f"Impossible to decode {filename}. Consider removing the file"))
def write_results_to_hide(self):
def write_results_to_hide(self) -> None:
if not self._results_to_hide:
return
filename = self._previous_results_filename
@ -464,7 +464,7 @@ class SlitherCore(Context):
results = self._results_to_hide + self._previous_results
json.dump(results, f)
def save_results_to_hide(self, results: List[Dict]):
def save_results_to_hide(self, results: List[Dict]) -> None:
self._results_to_hide += results
def add_path_to_filter(self, path: str):

@ -11,6 +11,7 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi
from slither.exceptions import SlitherError
from slither.printers.abstract_printer import AbstractPrinter
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
from slither.utils.output import Output
logger = logging.getLogger("Slither")
logging.basicConfig()
@ -206,7 +207,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes
self.write_results_to_hide()
return results
def run_printers(self):
def run_printers(self) -> List[Output]:
"""
:return: List of registered printers outputs.
"""

@ -26,7 +26,7 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()
def main():
def main() -> None:
# log on stdout to keep output in order
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)

@ -1,10 +1,11 @@
import argparse
import logging
from collections import defaultdict
from typing import Any, Dict, List
from typing import Any, Dict, List, Callable, Optional
from crytic_compile import cryticparser
from slither import Slither
from slither.core.declarations import Contract
from slither.utils.erc import ERCS
from slither.utils.output import output_to_json
from .erc.ercs import generic_erc_checks
@ -24,7 +25,10 @@ logger.addHandler(ch)
logger.handlers[0].setFormatter(formatter)
logger.propagate = False
ADDITIONAL_CHECKS = {"ERC20": check_erc20, "ERC1155": check_erc1155}
ADDITIONAL_CHECKS: Dict[str, Callable[[Contract, Dict[str, List]], Dict[str, List]]] = {
"ERC20": check_erc20,
"ERC1155": check_erc1155,
}
def parse_args() -> argparse.Namespace:

@ -1,12 +1,14 @@
import logging
from typing import Dict, List, Optional
from slither.core.declarations import Contract
from slither.slithir.operations import EventCall
from slither.utils import output
logger = logging.getLogger("Slither-conformance")
def events_safeBatchTransferFrom(contract, ret):
def events_safeBatchTransferFrom(contract: Contract, ret: Dict[str, List]) -> None:
function = contract.get_function_from_signature(
"safeBatchTransferFrom(address,address,uint256[],uint256[],bytes)"
)
@ -44,7 +46,9 @@ def events_safeBatchTransferFrom(contract, ret):
)
def check_erc1155(contract, ret, explored=None):
def check_erc1155(
contract: Contract, ret: Dict[str, List], explored: Optional[bool] = None
) -> Dict[str, List]:
if explored is None:
explored = set()

@ -1,11 +1,13 @@
import logging
from typing import Dict, List, Optional
from slither.core.declarations import Contract
from slither.utils import output
logger = logging.getLogger("Slither-conformance")
def approval_race_condition(contract, ret):
def approval_race_condition(contract: Contract, ret: Dict[str, List]) -> None:
increaseAllowance = contract.get_function_from_signature("increaseAllowance(address,uint256)")
if not increaseAllowance:
@ -27,7 +29,9 @@ def approval_race_condition(contract, ret):
)
def check_erc20(contract, ret, explored=None):
def check_erc20(
contract: Contract, ret: Dict[str, List], explored: Optional[bool] = None
) -> Dict[str, List]:
if explored is None:
explored = set()

@ -1,8 +1,12 @@
import re
import logging
from typing import Set, Tuple
from argparse import Namespace
from typing import Set, Tuple, List, Dict, Union, Optional, Callable
from slither.core.declarations import Function
from slither import Slither
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations import Function, FunctionContract
from slither.core.variables.state_variable import StateVariable
from slither.core.variables.variable import Variable
from slither.utils.colors import yellow, green, red
from slither.utils import output
@ -54,13 +58,15 @@ def _get_all_covered_kspec_functions(target: str) -> Set[Tuple[str, str]]:
return covered_functions
def _get_slither_functions(slither):
def _get_slither_functions(
slither: SlitherCompilationUnit,
) -> Dict[Tuple[str, str], Union[FunctionContract, StateVariable]]:
# Use contract == contract_declarer to avoid dupplicate
all_functions_declared = [
all_functions_declared: List[Union[FunctionContract, StateVariable]] = [
f
for f in slither.functions
if (
f.contract == f.contract_declarer
(isinstance(f, FunctionContract) and f.contract == f.contract_declarer)
and f.is_implemented
and not f.is_constructor
and not f.is_constructor_variables
@ -79,7 +85,12 @@ def _get_slither_functions(slither):
return slither_functions
def _generate_output(kspec, message, color, generate_json):
def _generate_output(
kspec: List[Union[FunctionContract, StateVariable]],
message: str,
color: Callable[[str], str],
generate_json: bool,
) -> Optional[Dict]:
info = ""
for function in kspec:
info += f"{message} {function.contract.name}.{function.full_name}\n"
@ -94,7 +105,9 @@ def _generate_output(kspec, message, color, generate_json):
return None
def _generate_output_unresolved(kspec, message, color, generate_json):
def _generate_output_unresolved(
kspec: Set[Tuple[str, str]], message: str, color: Callable[[str], str], generate_json: bool
) -> Optional[Dict]:
info = ""
for contract, function in kspec:
info += f"{message} {contract}.{function}\n"
@ -107,17 +120,19 @@ def _generate_output_unresolved(kspec, message, color, generate_json):
return None
def _run_coverage_analysis(args, slither, kspec_functions):
def _run_coverage_analysis(
args, slither: SlitherCompilationUnit, kspec_functions: Set[Tuple[str, str]]
) -> None:
# Collect all slither functions
slither_functions = _get_slither_functions(slither)
# Determine which klab specs were not resolved.
slither_functions_set = set(slither_functions)
kspec_functions_resolved = kspec_functions & slither_functions_set
kspec_functions_unresolved = kspec_functions - kspec_functions_resolved
kspec_functions_unresolved: Set[Tuple[str, str]] = kspec_functions - kspec_functions_resolved
kspec_missing = []
kspec_present = []
kspec_missing: List[Union[FunctionContract, StateVariable]] = []
kspec_present: List[Union[FunctionContract, StateVariable]] = []
for slither_func_desc in sorted(slither_functions_set):
slither_func = slither_functions[slither_func_desc]
@ -130,13 +145,13 @@ def _run_coverage_analysis(args, slither, kspec_functions):
logger.info("## Check for functions coverage")
json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json)
json_kspec_missing_functions = _generate_output(
[f for f in kspec_missing if isinstance(f, Function)],
[f for f in kspec_missing if isinstance(f, FunctionContract)],
"[ ] (Missing function)",
red,
args.json,
)
json_kspec_missing_variables = _generate_output(
[f for f in kspec_missing if isinstance(f, Variable)],
[f for f in kspec_missing if isinstance(f, StateVariable)],
"[ ] (Missing variable)",
yellow,
args.json,
@ -159,11 +174,11 @@ def _run_coverage_analysis(args, slither, kspec_functions):
)
def run_analysis(args, slither, kspec_arg):
def run_analysis(args: Namespace, slither: SlitherCompilationUnit, kspec_arg: str) -> None:
# Get all of our kspec'd functions (tuple(contract_name, function_name)).
if "," in kspec_arg:
kspecs = kspec_arg.split(",")
kspec_functions = set()
kspec_functions: Set[Tuple[str, str]] = set()
for kspec in kspecs:
kspec_functions |= _get_all_covered_kspec_functions(kspec)
else:

@ -72,7 +72,7 @@ class ListMutators(argparse.Action): # pylint: disable=too-few-public-methods
###################################################################################
def main():
def main() -> None:
args = parse_args()

@ -1,3 +1,5 @@
from typing import Dict
from slither.core.cfg.node import NodeType
from slither.formatters.utils.patches import create_patch
from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass
@ -9,13 +11,13 @@ class MIA(AbstractMutator): # pylint: disable=too-few-public-methods
FAULTCLASS = FaultClass.Checking
FAULTNATURE = FaultNature.Missing
def _mutate(self):
def _mutate(self) -> Dict:
result = {}
result: Dict = {}
for contract in self.slither.contracts:
for function in contract.functions_declared + contract.modifiers_declared:
for function in contract.functions_declared + list(contract.modifiers_declared):
for node in function.nodes:
if node.type == NodeType.IF:

@ -1,4 +1,7 @@
from typing import Dict
from slither.core.expressions import Literal
from slither.core.variables.variable import Variable
from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass
from slither.tools.mutator.utils.generic_patching import remove_assignement
@ -9,10 +12,10 @@ class MVIE(AbstractMutator): # pylint: disable=too-few-public-methods
FAULTCLASS = FaultClass.Assignement
FAULTNATURE = FaultNature.Missing
def _mutate(self):
result = {}
def _mutate(self) -> Dict:
result: Dict = {}
variable: Variable
for contract in self.slither.contracts:
# Create fault for state variables declaration
@ -25,7 +28,7 @@ class MVIE(AbstractMutator): # pylint: disable=too-few-public-methods
if not isinstance(variable.expression, Literal):
remove_assignement(variable, contract, result)
for function in contract.functions_declared + contract.modifiers_declared:
for function in contract.functions_declared + list(contract.modifiers_declared):
for variable in function.local_variables:
if variable.initialized and not isinstance(variable.expression, Literal):
remove_assignement(variable, contract, result)

@ -1,4 +1,7 @@
from typing import Dict
from slither.core.expressions import Literal
from slither.core.variables.variable import Variable
from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass
from slither.tools.mutator.utils.generic_patching import remove_assignement
@ -9,9 +12,10 @@ class MVIV(AbstractMutator): # pylint: disable=too-few-public-methods
FAULTCLASS = FaultClass.Assignement
FAULTNATURE = FaultNature.Missing
def _mutate(self):
def _mutate(self) -> Dict:
result = {}
result: Dict = {}
variable: Variable
for contract in self.slither.contracts:
@ -25,7 +29,7 @@ class MVIV(AbstractMutator): # pylint: disable=too-few-public-methods
if isinstance(variable.expression, Literal):
remove_assignement(variable, contract, result)
for function in contract.functions_declared + contract.modifiers_declared:
for function in contract.functions_declared + list(contract.modifiers_declared):
for variable in function.local_variables:
if variable.initialized and isinstance(variable.expression, Literal):
remove_assignement(variable, contract, result)

@ -1,7 +1,7 @@
from slither.utils.myprettytable import MyPrettyTable
def output_mutators(mutators_classes):
def output_mutators(mutators_classes: List[Type[AbstractMutator]]) -> None:
mutators_list = []
for detector in mutators_classes:
argument = detector.NAME

@ -7,7 +7,7 @@ import sys
from crytic_compile import cryticparser
from slither.tools.similarity.info import info
from slither.tools.similarity.test import test
from slither.tools.similarity import test
from slither.tools.similarity.train import train
from slither.tools.similarity.plot import plot

@ -1,5 +1,6 @@
import logging
import os
from typing import Optional, Tuple, List
from slither import Slither
from slither.core.declarations import (
@ -60,7 +61,7 @@ slither_logger = logging.getLogger("Slither")
slither_logger.setLevel(logging.CRITICAL)
def parse_target(target):
def parse_target(target: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
if target is None:
return None, None
@ -68,9 +69,9 @@ def parse_target(target):
if len(parts) == 1:
return None, parts[0]
if len(parts) == 2:
return parts
return parts[0], parts[1]
simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'")
return None
return None, None
def load_and_encode(infile: str, vmodel, ext=None, nsamples=None, **kwargs):
@ -88,7 +89,9 @@ def load_and_encode(infile: str, vmodel, ext=None, nsamples=None, **kwargs):
return r
def load_contracts(dirname, ext=None, nsamples=None):
def load_contracts(
dirname: str, ext: Optional[str] = None, nsamples: Optional[int] = None
) -> List[str]:
r = []
walk = list(os.walk(dirname))
for x, y, files in walk:

@ -2,6 +2,7 @@ import logging
import operator
import sys
import traceback
from argparse import Namespace
from slither.tools.similarity.encode import encode_contract, load_and_encode, parse_target
from slither.tools.similarity.model import load_model
@ -10,7 +11,7 @@ from slither.tools.similarity.similarity import similarity
logger = logging.getLogger("Slither-simil")
def test(args):
def est(args: Namespace) -> None:
try:
model = args.model

Loading…
Cancel
Save