Merge pull request #1388 from crytic/dev-types

Improve types
pull/1395/head
Feist Josselin 2 years ago committed by GitHub
commit 9764c54eff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 9
      plugin_example/slither_my_plugin/__init__.py
  2. 96
      slither/__main__.py
  3. 14
      slither/core/declarations/structure.py
  4. 4
      slither/core/expressions/tuple_expression.py
  5. 8
      slither/core/source_mapping/source_mapping.py
  6. 12
      slither/slither.py
  7. 8
      slither/slithir/operations/call.py
  8. 8
      slither/slithir/tmp_operations/tmp_call.py
  9. 12
      slither/slithir/tmp_operations/tmp_new_elementary_type.py
  10. 4
      slither/tools/demo/__main__.py
  11. 9
      slither/tools/erc_conformance/__main__.py
  12. 15
      slither/tools/erc_conformance/erc/ercs.py
  13. 4
      slither/tools/flattening/__main__.py
  14. 4
      slither/tools/kspec_coverage/__main__.py
  15. 4
      slither/tools/kspec_coverage/kspec_coverage.py
  16. 13
      slither/tools/mutator/__main__.py
  17. 7
      slither/tools/possible_paths/__main__.py
  18. 45
      slither/tools/possible_paths/possible_paths.py
  19. 19
      slither/tools/properties/__main__.py
  20. 8
      slither/tools/properties/platforms/truffle.py
  21. 2
      slither/tools/properties/utils.py
  22. 4
      slither/tools/similarity/__main__.py
  23. 2
      slither/tools/similarity/encode.py
  24. 3
      slither/tools/similarity/info.py
  25. 3
      slither/tools/similarity/plot.py
  26. 3
      slither/tools/similarity/train.py
  27. 4
      slither/tools/slither_format/__main__.py
  28. 18
      slither/tools/slither_format/slither_format.py
  29. 77
      slither/tools/upgradeability/__main__.py
  30. 20
      slither/tools/upgradeability/utils/command_line.py
  31. 57
      slither/utils/command_line.py
  32. 4
      slither/utils/output.py
  33. 8
      slither/utils/output_capture.py
  34. 2
      tests/check-upgradeability/test_10.txt
  35. 2
      tests/check-upgradeability/test_2.txt
  36. 2
      tests/check-upgradeability/test_3.txt
  37. 2
      tests/check-upgradeability/test_4.txt

@ -1,8 +1,13 @@
from typing import Tuple, List, Type
from slither_my_plugin.detectors.example import Example from slither_my_plugin.detectors.example import Example
from slither.detectors.abstract_detector import AbstractDetector
from slither.printers.abstract_printer import AbstractPrinter
def make_plugin(): def make_plugin() -> Tuple[List[Type[AbstractDetector]], List[Type[AbstractPrinter]]]:
plugin_detectors = [Example] plugin_detectors = [Example]
plugin_printers = [] plugin_printers: List[Type[AbstractPrinter]] = []
return plugin_detectors, plugin_printers return plugin_detectors, plugin_printers

@ -10,11 +10,11 @@ import os
import pstats import pstats
import sys import sys
import traceback import traceback
from typing import Tuple, Optional, List, Dict from typing import Tuple, Optional, List, Dict, Type, Union, Any, Sequence
from pkg_resources import iter_entry_points, require from pkg_resources import iter_entry_points, require
from crytic_compile import cryticparser from crytic_compile import cryticparser, CryticCompile
from crytic_compile.platform.standard import generate_standard_export from crytic_compile.platform.standard import generate_standard_export
from crytic_compile.platform.etherscan import SUPPORTED_NETWORK from crytic_compile.platform.etherscan import SUPPORTED_NETWORK
from crytic_compile import compile_all, is_supported from crytic_compile import compile_all, is_supported
@ -55,10 +55,10 @@ logger = logging.getLogger("Slither")
def process_single( def process_single(
target: str, target: Union[str, CryticCompile],
args: argparse.Namespace, args: argparse.Namespace,
detector_classes: List[AbstractDetector], detector_classes: List[Type[AbstractDetector]],
printer_classes: List[AbstractPrinter], printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]: ) -> Tuple[Slither, List[Dict], List[Dict], int]:
""" """
The core high-level code for running Slither static analysis. The core high-level code for running Slither static analysis.
@ -80,8 +80,8 @@ def process_single(
def process_all( def process_all(
target: str, target: str,
args: argparse.Namespace, args: argparse.Namespace,
detector_classes: List[AbstractDetector], detector_classes: List[Type[AbstractDetector]],
printer_classes: List[AbstractPrinter], printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[List[Slither], List[Dict], List[Dict], int]: ) -> Tuple[List[Slither], List[Dict], List[Dict], int]:
compilations = compile_all(target, **vars(args)) compilations = compile_all(target, **vars(args))
slither_instances = [] slither_instances = []
@ -109,8 +109,8 @@ def process_all(
def _process( def _process(
slither: Slither, slither: Slither,
detector_classes: List[AbstractDetector], detector_classes: List[Type[AbstractDetector]],
printer_classes: List[AbstractPrinter], printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]: ) -> Tuple[Slither, List[Dict], List[Dict], int]:
for detector_cls in detector_classes: for detector_cls in detector_classes:
slither.register_detector(detector_cls) slither.register_detector(detector_cls)
@ -137,13 +137,14 @@ def _process(
return slither, results_detectors, results_printers, analyzed_contracts_count return slither, results_detectors, results_printers, analyzed_contracts_count
# TODO: delete me?
def process_from_asts( def process_from_asts(
filenames: List[str], filenames: List[str],
args: argparse.Namespace, args: argparse.Namespace,
detector_classes: List[AbstractDetector], detector_classes: List[Type[AbstractDetector]],
printer_classes: List[AbstractPrinter], printer_classes: List[Type[AbstractPrinter]],
): ) -> Tuple[Slither, List[Dict], List[Dict], int]:
all_contracts = [] all_contracts: List[str] = []
for filename in filenames: for filename in filenames:
with open(filename, encoding="utf8") as file_open: with open(filename, encoding="utf8") as file_open:
@ -162,13 +163,15 @@ def process_from_asts(
################################################################################### ###################################################################################
def get_detectors_and_printers(): def get_detectors_and_printers() -> Tuple[
List[Type[AbstractDetector]], List[Type[AbstractPrinter]]
]:
detectors = [getattr(all_detectors, name) for name in dir(all_detectors)] detectors_ = [getattr(all_detectors, name) for name in dir(all_detectors)]
detectors = [d for d in detectors if inspect.isclass(d) and issubclass(d, AbstractDetector)] detectors = [d for d in detectors_ if inspect.isclass(d) and issubclass(d, AbstractDetector)]
printers = [getattr(all_printers, name) for name in dir(all_printers)] printers_ = [getattr(all_printers, name) for name in dir(all_printers)]
printers = [p for p in printers if inspect.isclass(p) and issubclass(p, AbstractPrinter)] printers = [p for p in printers_ if inspect.isclass(p) and issubclass(p, AbstractPrinter)]
# Handle plugins! # Handle plugins!
for entry_point in iter_entry_points(group="slither_analyzer.plugin", name=None): for entry_point in iter_entry_points(group="slither_analyzer.plugin", name=None):
@ -194,8 +197,8 @@ def get_detectors_and_printers():
# pylint: disable=too-many-branches # pylint: disable=too-many-branches
def choose_detectors( def choose_detectors(
args: argparse.Namespace, all_detector_classes: List[AbstractDetector] args: argparse.Namespace, all_detector_classes: List[Type[AbstractDetector]]
) -> List[AbstractDetector]: ) -> List[Type[AbstractDetector]]:
# If detectors are specified, run only these ones # If detectors are specified, run only these ones
detectors_to_run = [] detectors_to_run = []
@ -245,8 +248,8 @@ def choose_detectors(
def choose_printers( def choose_printers(
args: argparse.Namespace, all_printer_classes: List[AbstractPrinter] args: argparse.Namespace, all_printer_classes: List[Type[AbstractPrinter]]
) -> List[AbstractPrinter]: ) -> List[Type[AbstractPrinter]]:
printers_to_run = [] printers_to_run = []
# disable default printer # disable default printer
@ -273,13 +276,16 @@ def choose_printers(
################################################################################### ###################################################################################
def parse_filter_paths(args): def parse_filter_paths(args: argparse.Namespace) -> List[str]:
if args.filter_paths: if args.filter_paths:
return args.filter_paths.split(",") return args.filter_paths.split(",")
return [] return []
def parse_args(detector_classes, printer_classes): # pylint: disable=too-many-statements # pylint: disable=too-many-statements
def parse_args(
detector_classes: List[Type[AbstractDetector]], printer_classes: List[Type[AbstractPrinter]]
) -> argparse.Namespace:
usage = "slither target [flag]\n" usage = "slither target [flag]\n"
usage += "\ntarget can be:\n" usage += "\ntarget can be:\n"
@ -622,7 +628,9 @@ class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods
class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-methods class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
detectors, _ = get_detectors_and_printers() detectors, _ = get_detectors_and_printers()
detector_types_json = output_detectors_json(detectors) detector_types_json = output_detectors_json(detectors)
print(json.dumps(detector_types_json)) print(json.dumps(detector_types_json))
@ -630,22 +638,38 @@ class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-meth
class ListPrinters(argparse.Action): # pylint: disable=too-few-public-methods class ListPrinters(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
_, printers = get_detectors_and_printers() _, printers = get_detectors_and_printers()
output_printers(printers) output_printers(printers)
parser.exit() parser.exit()
class OutputMarkdown(argparse.Action): # pylint: disable=too-few-public-methods class OutputMarkdown(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, args, values, option_string=None): def __call__(
self,
parser: Any,
args: Any,
values: Optional[Union[str, Sequence[Any]]],
option_string: Any = None,
) -> None:
detectors, printers = get_detectors_and_printers() detectors, printers = get_detectors_and_printers()
assert isinstance(values, str)
output_to_markdown(detectors, printers, values) output_to_markdown(detectors, printers, values)
parser.exit() parser.exit()
class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, args, values, option_string=None): def __call__(
self,
parser: Any,
args: Any,
values: Optional[Union[str, Sequence[Any]]],
option_string: Any = None,
) -> None:
detectors, _ = get_detectors_and_printers() detectors, _ = get_detectors_and_printers()
assert isinstance(values, str)
output_wiki(detectors, values) output_wiki(detectors, values)
parser.exit() parser.exit()
@ -678,7 +702,7 @@ class FormatterCryticCompile(logging.Formatter):
################################################################################### ###################################################################################
def main(): def main() -> None:
# Codebase with complex domninators can lead to a lot of SSA recursive call # Codebase with complex domninators can lead to a lot of SSA recursive call
sys.setrecursionlimit(1500) sys.setrecursionlimit(1500)
@ -689,8 +713,9 @@ def main():
# pylint: disable=too-many-statements,too-many-branches,too-many-locals # pylint: disable=too-many-statements,too-many-branches,too-many-locals
def main_impl( def main_impl(
all_detector_classes: List[AbstractDetector], all_printer_classes: List[AbstractPrinter] all_detector_classes: List[Type[AbstractDetector]],
): all_printer_classes: List[Type[AbstractPrinter]],
) -> None:
""" """
:param all_detector_classes: A list of all detectors that can be included/excluded. :param all_detector_classes: A list of all detectors that can be included/excluded.
:param all_printer_classes: A list of all printers that can be included. :param all_printer_classes: A list of all printers that can be included.
@ -756,8 +781,8 @@ def main_impl(
crytic_compile_error.propagate = False crytic_compile_error.propagate = False
crytic_compile_error.setLevel(logging.INFO) crytic_compile_error.setLevel(logging.INFO)
results_detectors = [] results_detectors: List[Dict] = []
results_printers = [] results_printers: List[Dict] = []
try: try:
filename = args.filename filename = args.filename
@ -806,6 +831,7 @@ def main_impl(
if "compilations" in args.json_types: if "compilations" in args.json_types:
compilation_results = [] compilation_results = []
for slither_instance in slither_instances: for slither_instance in slither_instances:
assert slither_instance.crytic_compile
compilation_results.append( compilation_results.append(
generate_standard_export(slither_instance.crytic_compile) generate_standard_export(slither_instance.crytic_compile)
) )
@ -856,7 +882,7 @@ def main_impl(
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
output_error = traceback.format_exc() output_error = traceback.format_exc()
logging.error(traceback.print_exc()) traceback.print_exc()
logging.error(f"Error in {args.filename}") # pylint: disable=logging-fstring-interpolation logging.error(f"Error in {args.filename}") # pylint: disable=logging-fstring-interpolation
logging.error(output_error) logging.error(output_error)
@ -879,7 +905,7 @@ def main_impl(
if outputting_zip: if outputting_zip:
output_to_zip(args.zip, output_error, json_results, args.zip_type) output_to_zip(args.zip, output_error, json_results, args.zip_type)
if args.perf: if args.perf and cp:
cp.disable() cp.disable()
stats = pstats.Stats(cp).sort_stats("cumtime") stats = pstats.Stats(cp).sort_stats("cumtime")
stats.print_stats() stats.print_stats()

@ -8,10 +8,10 @@ if TYPE_CHECKING:
class Structure(SourceMapping): class Structure(SourceMapping):
def __init__(self, compilation_unit: "SlitherCompilationUnit"): def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None:
super().__init__() super().__init__()
self._name: Optional[str] = None self._name: Optional[str] = None
self._canonical_name = None self._canonical_name: Optional[str] = None
self._elems: Dict[str, "StructureVariable"] = {} self._elems: Dict[str, "StructureVariable"] = {}
# Name of the elements in the order of declaration # Name of the elements in the order of declaration
self._elems_ordered: List[str] = [] self._elems_ordered: List[str] = []
@ -19,25 +19,27 @@ class Structure(SourceMapping):
@property @property
def canonical_name(self) -> str: def canonical_name(self) -> str:
assert self._canonical_name
return self._canonical_name return self._canonical_name
@canonical_name.setter @canonical_name.setter
def canonical_name(self, name: str): def canonical_name(self, name: str) -> None:
self._canonical_name = name self._canonical_name = name
@property @property
def name(self) -> str: def name(self) -> str:
assert self._name
return self._name return self._name
@name.setter @name.setter
def name(self, new_name: str): def name(self, new_name: str) -> None:
self._name = new_name self._name = new_name
@property @property
def elems(self) -> Dict[str, "StructureVariable"]: def elems(self) -> Dict[str, "StructureVariable"]:
return self._elems return self._elems
def add_elem_in_order(self, s: str): def add_elem_in_order(self, s: str) -> None:
self._elems_ordered.append(s) self._elems_ordered.append(s)
@property @property
@ -47,5 +49,5 @@ class Structure(SourceMapping):
ret.append(self._elems[e]) ret.append(self._elems[e])
return ret return ret
def __str__(self): def __str__(self) -> str:
return self.name return self.name

@ -4,7 +4,7 @@ from slither.core.expressions.expression import Expression
class TupleExpression(Expression): class TupleExpression(Expression):
def __init__(self, expressions): def __init__(self, expressions: List[Expression]) -> None:
assert all(isinstance(x, Expression) for x in expressions if x) assert all(isinstance(x, Expression) for x in expressions if x)
super().__init__() super().__init__()
self._expressions = expressions self._expressions = expressions
@ -13,6 +13,6 @@ class TupleExpression(Expression):
def expressions(self) -> List[Expression]: def expressions(self) -> List[Expression]:
return self._expressions return self._expressions
def __str__(self): def __str__(self) -> str:
expressions_str = [str(e) for e in self.expressions] expressions_str = [str(e) for e in self.expressions]
return "(" + ",".join(expressions_str) + ")" return "(" + ",".join(expressions_str) + ")"

@ -162,13 +162,15 @@ def _convert_source_mapping(
class SourceMapping(Context, metaclass=ABCMeta): class SourceMapping(Context, metaclass=ABCMeta):
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
# self._source_mapping: Optional[Dict] = None # self._source_mapping: Optional[Dict] = None
self.source_mapping: Source = Source() self.source_mapping: Source = Source()
self.references: List[Source] = [] self.references: List[Source] = []
def set_offset(self, offset: Union["Source", str], compilation_unit: "SlitherCompilationUnit"): def set_offset(
self, offset: Union["Source", str], compilation_unit: "SlitherCompilationUnit"
) -> None:
if isinstance(offset, Source): if isinstance(offset, Source):
self.source_mapping.start = offset.start self.source_mapping.start = offset.start
self.source_mapping.length = offset.length self.source_mapping.length = offset.length
@ -184,6 +186,6 @@ class SourceMapping(Context, metaclass=ABCMeta):
def add_reference_from_raw_source( def add_reference_from_raw_source(
self, offset: str, compilation_unit: "SlitherCompilationUnit" self, offset: str, compilation_unit: "SlitherCompilationUnit"
): ) -> None:
s = _convert_source_mapping(offset, compilation_unit) s = _convert_source_mapping(offset, compilation_unit)
self.references.append(s) self.references.append(s)

@ -1,5 +1,5 @@
import logging import logging
from typing import Union, List, ValuesView from typing import Union, List, ValuesView, Type, Dict
from crytic_compile import CryticCompile, InvalidCompilation from crytic_compile import CryticCompile, InvalidCompilation
@ -19,7 +19,9 @@ logger_detector = logging.getLogger("Detectors")
logger_printer = logging.getLogger("Printers") logger_printer = logging.getLogger("Printers")
def _check_common_things(thing_name, cls, base_cls, instances_list): def _check_common_things(
thing_name: str, cls: Type, base_cls: Type, instances_list: List[Type[AbstractDetector]]
) -> None:
if not issubclass(cls, base_cls) or cls is base_cls: if not issubclass(cls, base_cls) or cls is base_cls:
raise Exception( raise Exception(
@ -178,7 +180,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes
def detectors_optimization(self): def detectors_optimization(self):
return [d for d in self.detectors if d.IMPACT == DetectorClassification.OPTIMIZATION] return [d for d in self.detectors if d.IMPACT == DetectorClassification.OPTIMIZATION]
def register_detector(self, detector_class): def register_detector(self, detector_class: Type[AbstractDetector]) -> None:
""" """
:param detector_class: Class inheriting from `AbstractDetector`. :param detector_class: Class inheriting from `AbstractDetector`.
""" """
@ -188,7 +190,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes
instance = detector_class(compilation_unit, self, logger_detector) instance = detector_class(compilation_unit, self, logger_detector)
self._detectors.append(instance) self._detectors.append(instance)
def register_printer(self, printer_class): def register_printer(self, printer_class: Type[AbstractPrinter]) -> None:
""" """
:param printer_class: Class inheriting from `AbstractPrinter`. :param printer_class: Class inheriting from `AbstractPrinter`.
""" """
@ -197,7 +199,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes
instance = printer_class(self, logger_printer) instance = printer_class(self, logger_printer)
self._printers.append(instance) self._printers.append(instance)
def run_detectors(self): def run_detectors(self) -> List[Dict]:
""" """
:return: List of registered detectors results. :return: List of registered detectors results.
""" """

@ -1,8 +1,10 @@
from typing import Optional, List
from slither.slithir.operations.operation import Operation from slither.slithir.operations.operation import Operation
class Call(Operation): class Call(Operation):
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self._arguments = [] self._arguments = []
@ -14,14 +16,14 @@ class Call(Operation):
def arguments(self, v): def arguments(self, v):
self._arguments = v self._arguments = v
def can_reenter(self, _callstack=None): # pylint: disable=no-self-use def can_reenter(self, _callstack: Optional[List] = None) -> bool: # pylint: disable=no-self-use
""" """
Must be called after slithIR analysis pass Must be called after slithIR analysis pass
:return: bool :return: bool
""" """
return False return False
def can_send_eth(self): # pylint: disable=no-self-use def can_send_eth(self) -> bool: # pylint: disable=no-self-use
""" """
Must be called after slithIR analysis pass Must be called after slithIR analysis pass
:return: bool :return: bool

@ -63,14 +63,14 @@ class TmpCall(OperationWithLValue): # pylint: disable=too-many-instance-attribu
def call_id(self): def call_id(self):
return self._callid return self._callid
@property
def read(self):
return [self.called]
@call_id.setter @call_id.setter
def call_id(self, c): def call_id(self, c):
self._callid = c self._callid = c
@property
def read(self):
return [self.called]
@property @property
def called(self): def called(self):
return self._called return self._called

@ -1,21 +1,23 @@
from typing import List
from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.operations.lvalue import OperationWithLValue
from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.elementary_type import ElementaryType
class TmpNewElementaryType(OperationWithLValue): class TmpNewElementaryType(OperationWithLValue):
def __init__(self, new_type, lvalue): def __init__(self, new_type: ElementaryType, lvalue):
assert isinstance(new_type, ElementaryType) assert isinstance(new_type, ElementaryType)
super().__init__() super().__init__()
self._type = new_type self._type: ElementaryType = new_type
self._lvalue = lvalue self._lvalue = lvalue
@property @property
def read(self): def read(self) -> List:
return [] return []
@property @property
def type(self): def type(self) -> ElementaryType:
return self._type return self._type
def __str__(self): def __str__(self) -> str:
return f"{self.lvalue} = new {self._type}" return f"{self.lvalue} = new {self._type}"

@ -9,7 +9,7 @@ logging.getLogger("Slither").setLevel(logging.INFO)
logger = logging.getLogger("Slither-demo") logger = logging.getLogger("Slither-demo")
def parse_args(): def parse_args() -> argparse.Namespace:
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
@ -26,7 +26,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def main(): def main() -> None:
args = parse_args() args = parse_args()
# Perform slither analysis on the given filename # Perform slither analysis on the given filename

@ -1,6 +1,7 @@
import argparse import argparse
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List
from crytic_compile import cryticparser from crytic_compile import cryticparser
from slither import Slither from slither import Slither
@ -26,7 +27,7 @@ logger.propagate = False
ADDITIONAL_CHECKS = {"ERC20": check_erc20, "ERC1155": check_erc1155} ADDITIONAL_CHECKS = {"ERC20": check_erc20, "ERC1155": check_erc1155}
def parse_args(): def parse_args() -> argparse.Namespace:
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
@ -63,20 +64,20 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def _log_error(err, args): def _log_error(err: Any, args: argparse.Namespace) -> None:
if args.json: if args.json:
output_to_json(args.json, str(err), {"upgradeability-check": []}) output_to_json(args.json, str(err), {"upgradeability-check": []})
logger.error(err) logger.error(err)
def main(): def main() -> None:
args = parse_args() args = parse_args()
# Perform slither analysis on the given filename # Perform slither analysis on the given filename
slither = Slither(args.project, **vars(args)) slither = Slither(args.project, **vars(args))
ret = defaultdict(list) ret: Dict[str, List] = defaultdict(list)
if args.erc.upper() in ERCS: if args.erc.upper() in ERCS:

@ -1,7 +1,10 @@
import logging import logging
from typing import Dict, List, Optional, Set
from slither.core.declarations import Contract
from slither.slithir.operations import EventCall from slither.slithir.operations import EventCall
from slither.utils import output from slither.utils import output
from slither.utils.erc import ERC, ERC_EVENT
from slither.utils.type import ( from slither.utils.type import (
export_nested_types_from_variable, export_nested_types_from_variable,
export_return_type_from_variable, export_return_type_from_variable,
@ -11,7 +14,7 @@ logger = logging.getLogger("Slither-conformance")
# pylint: disable=too-many-locals,too-many-branches,too-many-statements # pylint: disable=too-many-locals,too-many-branches,too-many-statements
def _check_signature(erc_function, contract, ret): def _check_signature(erc_function: ERC, contract: Contract, ret: Dict) -> None:
name = erc_function.name name = erc_function.name
parameters = erc_function.parameters parameters = erc_function.parameters
return_type = erc_function.return_type return_type = erc_function.return_type
@ -146,7 +149,7 @@ def _check_signature(erc_function, contract, ret):
ret["missing_event_emmited"].append(missing_event_emmited.data) ret["missing_event_emmited"].append(missing_event_emmited.data)
def _check_events(erc_event, contract, ret): def _check_events(erc_event: ERC_EVENT, contract: Contract, ret: Dict[str, List]) -> None:
name = erc_event.name name = erc_event.name
parameters = erc_event.parameters parameters = erc_event.parameters
indexes = erc_event.indexes indexes = erc_event.indexes
@ -180,7 +183,13 @@ def _check_events(erc_event, contract, ret):
ret["missing_event_index"].append(missing_event_index.data) ret["missing_event_index"].append(missing_event_index.data)
def generic_erc_checks(contract, erc_functions, erc_events, ret, explored=None): def generic_erc_checks(
contract: Contract,
erc_functions: List[ERC],
erc_events: List[ERC_EVENT],
ret: Dict[str, List],
explored: Optional[Set[Contract]] = None,
) -> None:
if explored is None: if explored is None:
explored = set() explored = set()

@ -18,7 +18,7 @@ logger = logging.getLogger("Slither")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
def parse_args(): def parse_args() -> argparse.Namespace:
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
@ -106,7 +106,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def main(): def main() -> None:
args = parse_args() args = parse_args()
slither = Slither(args.filename, **vars(args)) slither = Slither(args.filename, **vars(args))

@ -16,7 +16,7 @@ logger.handlers[0].setFormatter(formatter)
logger.propagate = False logger.propagate = False
def parse_args(): def parse_args() -> argparse.Namespace:
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
@ -56,7 +56,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def main(): def main() -> None:
# ------------------------------ # ------------------------------
# Usage: slither-kspec-coverage contract kspec # Usage: slither-kspec-coverage contract kspec
# Example: slither-kspec-coverage contract.sol kspec.md # Example: slither-kspec-coverage contract.sol kspec.md

@ -1,8 +1,10 @@
import argparse
from slither.tools.kspec_coverage.analysis import run_analysis from slither.tools.kspec_coverage.analysis import run_analysis
from slither import Slither from slither import Slither
def kspec_coverage(args): def kspec_coverage(args: argparse.Namespace) -> None:
contract = args.contract contract = args.contract
kspec = args.kspec kspec = args.kspec

@ -2,6 +2,7 @@ import argparse
import inspect import inspect
import logging import logging
import sys import sys
from typing import Type, List, Any
from crytic_compile import cryticparser from crytic_compile import cryticparser
@ -22,7 +23,7 @@ logger.setLevel(logging.INFO)
################################################################################### ###################################################################################
def parse_args(): def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Experimental smart contract mutator. Based on https://arxiv.org/abs/2006.11597", description="Experimental smart contract mutator. Based on https://arxiv.org/abs/2006.11597",
usage="slither-mutate target", usage="slither-mutate target",
@ -48,14 +49,16 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def _get_mutators(): def _get_mutators() -> List[Type[AbstractMutator]]:
detectors = [getattr(all_mutators, name) for name in dir(all_mutators)] detectors_ = [getattr(all_mutators, name) for name in dir(all_mutators)]
detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractMutator)] detectors = [c for c in detectors_ if inspect.isclass(c) and issubclass(c, AbstractMutator)]
return detectors return detectors
class ListMutators(argparse.Action): # pylint: disable=too-few-public-methods class ListMutators(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
checks = _get_mutators() checks = _get_mutators()
output_mutators(checks) output_mutators(checks)
parser.exit() parser.exit()

@ -5,6 +5,7 @@ from argparse import ArgumentParser, Namespace
from crytic_compile import cryticparser from crytic_compile import cryticparser
from slither import Slither from slither import Slither
from slither.core.declarations import FunctionContract
from slither.utils.colors import red from slither.utils.colors import red
from slither.tools.possible_paths.possible_paths import ( from slither.tools.possible_paths.possible_paths import (
find_target_paths, find_target_paths,
@ -58,7 +59,11 @@ def main() -> None:
# Print out all target functions. # Print out all target functions.
print("Target functions:") print("Target functions:")
for target in targets: for target in targets:
print(f"- {target.contract_declarer.name}.{target.full_name}") if isinstance(target, FunctionContract):
print(f"- {target.contract_declarer.name}.{target.full_name}")
else:
pass
# TODO implement me
print("\n") print("\n")
# Obtain all paths which reach the target functions. # Obtain all paths which reach the target functions.

@ -1,8 +1,15 @@
from typing import List, Tuple, Union, Optional, Set
from slither import Slither
from slither.core.declarations import Function, FunctionContract
from slither.core.slither_core import SlitherCore
class ResolveFunctionException(Exception): class ResolveFunctionException(Exception):
pass pass
def resolve_function(slither, contract_name, function_name): def resolve_function(slither: SlitherCore, contract_name: str, function_name: str) -> Function:
""" """
Resolves a function instance, given a contract name and function. Resolves a function instance, given a contract name and function.
:param contract_name: The name of the contract the function is declared in. :param contract_name: The name of the contract the function is declared in.
@ -32,7 +39,9 @@ def resolve_function(slither, contract_name, function_name):
return target_function return target_function
def resolve_functions(slither, functions): def resolve_functions(
slither: Slither, functions: List[Union[str, Tuple[str, str]]]
) -> List[Function]:
""" """
Resolves the provided function descriptors. Resolves the provided function descriptors.
:param functions: A list of tuples (contract_name, function_name) or str (of form "ContractName.FunctionName") :param functions: A list of tuples (contract_name, function_name) or str (of form "ContractName.FunctionName")
@ -40,7 +49,7 @@ def resolve_functions(slither, functions):
:return: Returns a list of resolved functions. :return: Returns a list of resolved functions.
""" """
# Create the resolved list. # Create the resolved list.
resolved = [] resolved: List[Function] = []
# Verify that the provided argument is a list. # Verify that the provided argument is a list.
if not isinstance(functions, list): if not isinstance(functions, list):
@ -72,24 +81,31 @@ def resolve_functions(slither, functions):
return resolved return resolved
def all_function_definitions(function): def all_function_definitions(function: Function) -> List[Function]:
""" """
Obtains a list of representing this function and any base definitions Obtains a list of representing this function and any base definitions
:param function: The function to obtain all definitions at and beneath. :param function: The function to obtain all definitions at and beneath.
:return: Returns a list composed of the provided function definition and any base definitions. :return: Returns a list composed of the provided function definition and any base definitions.
""" """
return [function] + [ # TODO implement me
if not isinstance(function, FunctionContract):
return []
ret: List[Function] = [function]
ret += [
f f
for c in function.contract.inheritance for c in function.contract.inheritance
for f in c.functions_and_modifiers_declared for f in c.functions_and_modifiers_declared
if f.full_name == function.full_name if f.full_name == function.full_name
] ]
return ret
def __find_target_paths(slither, target_function, current_path=None): def __find_target_paths(
slither: SlitherCore, target_function: Function, current_path: Optional[List[Function]] = None
) -> Set[Tuple[Function, ...]]:
current_path = current_path if current_path else [] current_path = current_path if current_path else []
# Create our results list # Create our results list
results = set() results: Set[Tuple[Function, ...]] = set()
# Add our current function to the path. # Add our current function to the path.
current_path = [target_function] + current_path current_path = [target_function] + current_path
@ -106,9 +122,12 @@ def __find_target_paths(slither, target_function, current_path=None):
continue continue
# Find all function calls in this function (except for low level) # Find all function calls in this function (except for low level)
called_functions = [f for (_, f) in function.high_level_calls + function.library_calls] called_functions_list = [
called_functions += function.internal_calls f for (_, f) in function.high_level_calls if isinstance(f, Function)
called_functions = set(called_functions) ]
called_functions_list += [f for (_, f) in function.library_calls]
called_functions_list += [f for f in function.internal_calls if isinstance(f, Function)]
called_functions = set(called_functions_list)
# If any of our target functions are reachable from this function, it's a result. # If any of our target functions are reachable from this function, it's a result.
if all_target_functions.intersection(called_functions): if all_target_functions.intersection(called_functions):
@ -123,14 +142,16 @@ def __find_target_paths(slither, target_function, current_path=None):
return results return results
def find_target_paths(slither, target_functions): def find_target_paths(
slither: SlitherCore, target_functions: List[Function]
) -> Set[Tuple[Function, ...]]:
""" """
Obtains all functions which can lead to any of the target functions being called. Obtains all functions which can lead to any of the target functions being called.
:param target_functions: The functions we are interested in reaching. :param target_functions: The functions we are interested in reaching.
:return: Returns a list of all functions which can reach any of the target_functions. :return: Returns a list of all functions which can reach any of the target_functions.
""" """
# Create our results list # Create our results list
results = set() results: Set[Tuple[Function, ...]] = set()
# Loop for each target function # Loop for each target function
for target_function in target_functions: for target_function in target_functions:

@ -1,6 +1,7 @@
import argparse import argparse
import logging import logging
import sys import sys
from typing import Any
from crytic_compile import cryticparser from crytic_compile import cryticparser
@ -26,7 +27,7 @@ logger.handlers[0].setFormatter(formatter)
logger.propagate = False logger.propagate = False
def _all_scenarios(): def _all_scenarios() -> str:
txt = "\n" txt = "\n"
txt += "#################### ERC20 ####################\n" txt += "#################### ERC20 ####################\n"
for k, value in ERC20_PROPERTIES.items(): for k, value in ERC20_PROPERTIES.items():
@ -35,29 +36,33 @@ def _all_scenarios():
return txt return txt
def _all_properties(): def _all_properties() -> MyPrettyTable:
table = MyPrettyTable(["Num", "Description", "Scenario"]) table = MyPrettyTable(["Num", "Description", "Scenario"])
idx = 0 idx = 0
for scenario, value in ERC20_PROPERTIES.items(): for scenario, value in ERC20_PROPERTIES.items():
for prop in value.properties: for prop in value.properties:
table.add_row([idx, prop.description, scenario]) table.add_row([str(idx), prop.description, scenario])
idx = idx + 1 idx = idx + 1
return table return table
class ListScenarios(argparse.Action): # pylint: disable=too-few-public-methods class ListScenarios(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
logger.info(_all_scenarios()) logger.info(_all_scenarios())
parser.exit() parser.exit()
class ListProperties(argparse.Action): # pylint: disable=too-few-public-methods class ListProperties(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
logger.info(_all_properties()) logger.info(_all_properties())
parser.exit() parser.exit()
def parse_args(): def parse_args() -> argparse.Namespace:
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
@ -120,7 +125,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def main(): def main() -> None:
args = parse_args() args = parse_args()
# Perform slither analysis on the given filename # Perform slither analysis on the given filename

@ -15,7 +15,7 @@ PATTERN_TRUFFLE_MIGRATION = re.compile("^[0-9]*_")
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
def _extract_caller(p: PropertyCaller): def _extract_caller(p: PropertyCaller) -> List[str]:
if p == PropertyCaller.OWNER: if p == PropertyCaller.OWNER:
return ["owner"] return ["owner"]
if p == PropertyCaller.SENDER: if p == PropertyCaller.SENDER:
@ -28,7 +28,7 @@ def _extract_caller(p: PropertyCaller):
return ["user"] return ["user"]
def _helpers(): def _helpers() -> str:
""" """
Generate two functions: Generate two functions:
- catchRevertThrowReturnFalse: check if the call revert/throw or return false - catchRevertThrowReturnFalse: check if the call revert/throw or return false
@ -75,7 +75,7 @@ def generate_unit_test( # pylint: disable=too-many-arguments,too-many-branches
output_dir: Path, output_dir: Path,
addresses: Addresses, addresses: Addresses,
assert_message: str = "", assert_message: str = "",
): ) -> Path:
""" """
Generate unit tests files Generate unit tests files
:param test_contract: :param test_contract:
@ -134,7 +134,7 @@ def generate_unit_test( # pylint: disable=too-many-arguments,too-many-branches
return output_dir return output_dir
def generate_migration(test_contract: str, output_dir: Path, owner_address: str): def generate_migration(test_contract: str, output_dir: Path, owner_address: str) -> None:
""" """
Generate migration file Generate migration file
:param test_contract: :param test_contract:

@ -12,7 +12,7 @@ def write_file(
content: str, content: str,
allow_overwrite: bool = True, allow_overwrite: bool = True,
discard_if_exist: bool = False, discard_if_exist: bool = False,
): ) -> None:
""" """
Write the content into output_dir/filename Write the content into output_dir/filename
:param output_dir: :param output_dir:

@ -17,7 +17,7 @@ logger = logging.getLogger("Slither-simil")
modes = ["info", "test", "train", "plot"] modes = ["info", "test", "train", "plot"]
def parse_args(): def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector" description="Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector"
) )
@ -78,7 +78,7 @@ def parse_args():
################################################################################### ###################################################################################
def main(): def main() -> None:
args = parse_args() args = parse_args()
default_log = logging.INFO default_log = logging.INFO

@ -74,7 +74,7 @@ def parse_target(target):
return None return None
def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs): def load_and_encode(infile: str, vmodel, ext=None, nsamples=None, **kwargs):
r = {} r = {}
if infile.endswith(".npz"): if infile.endswith(".npz"):
r = load_cache(infile, nsamples=nsamples) r = load_cache(infile, nsamples=nsamples)

@ -1,3 +1,4 @@
import argparse
import logging import logging
import sys import sys
import os.path import os.path
@ -10,7 +11,7 @@ logging.basicConfig()
logger = logging.getLogger("Slither-simil") logger = logging.getLogger("Slither-simil")
def info(args): def info(args: argparse.Namespace) -> None:
try: try:

@ -1,3 +1,4 @@
import argparse
import logging import logging
import random import random
import sys import sys
@ -23,7 +24,7 @@ except ImportError:
logger = logging.getLogger("Slither-simil") logger = logging.getLogger("Slither-simil")
def plot(args): # pylint: disable=too-many-locals def plot(args: argparse.Namespace) -> None: # pylint: disable=too-many-locals
if decomposition is None or plt is None: if decomposition is None or plt is None:
logger.error( logger.error(

@ -1,3 +1,4 @@
import argparse
import logging import logging
import os import os
import sys import sys
@ -10,7 +11,7 @@ from slither.tools.similarity.model import train_unsupervised
logger = logging.getLogger("Slither-simil") logger = logging.getLogger("Slither-simil")
def train(args): # pylint: disable=too-many-locals def train(args: argparse.Namespace) -> None: # pylint: disable=too-many-locals
try: try:
last_data_train_filename = "last_data_train.txt" last_data_train_filename = "last_data_train.txt"

@ -23,7 +23,7 @@ available_detectors = [
] ]
def parse_args(): def parse_args() -> argparse.Namespace:
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
@ -90,7 +90,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def main(): def main() -> None:
# ------------------------------ # ------------------------------
# Usage: python3 -m slither_format filename # Usage: python3 -m slither_format filename
# Example: python3 -m slither_format contract.sol # Example: python3 -m slither_format contract.sol

@ -1,5 +1,9 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Type, List, Dict
from slither import Slither
from slither.detectors.abstract_detector import AbstractDetector
from slither.detectors.variables.unused_state_variables import UnusedStateVars from slither.detectors.variables.unused_state_variables import UnusedStateVars
from slither.detectors.attributes.incorrect_solc import IncorrectSolc from slither.detectors.attributes.incorrect_solc import IncorrectSolc
from slither.detectors.attributes.constant_pragma import ConstantPragma from slither.detectors.attributes.constant_pragma import ConstantPragma
@ -13,7 +17,7 @@ from slither.utils.colors import yellow
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("Slither.Format") logger = logging.getLogger("Slither.Format")
all_detectors = { all_detectors: Dict[str, Type[AbstractDetector]] = {
"unused-state": UnusedStateVars, "unused-state": UnusedStateVars,
"solc-version": IncorrectSolc, "solc-version": IncorrectSolc,
"pragma": ConstantPragma, "pragma": ConstantPragma,
@ -25,7 +29,7 @@ all_detectors = {
} }
def slither_format(slither, **kwargs): # pylint: disable=too-many-locals def slither_format(slither: Slither, **kwargs: Dict) -> None: # pylint: disable=too-many-locals
"""' """'
Keyword Args: Keyword Args:
detectors_to_run (str): Comma-separated list of detectors, defaults to all detectors_to_run (str): Comma-separated list of detectors, defaults to all
@ -85,9 +89,11 @@ def slither_format(slither, **kwargs): # pylint: disable=too-many-locals
################################################################################### ###################################################################################
def choose_detectors(detectors_to_run, detectors_to_exclude): def choose_detectors(
detectors_to_run: str, detectors_to_exclude: str
) -> List[Type[AbstractDetector]]:
# If detectors are specified, run only these ones # If detectors are specified, run only these ones
cls_detectors_to_run = [] cls_detectors_to_run: List[Type[AbstractDetector]] = []
exclude = detectors_to_exclude.split(",") exclude = detectors_to_exclude.split(",")
if detectors_to_run == "all": if detectors_to_run == "all":
for key, detector in all_detectors.items(): for key, detector in all_detectors.items():
@ -114,7 +120,7 @@ def choose_detectors(detectors_to_run, detectors_to_exclude):
################################################################################### ###################################################################################
def print_patches(number_of_slither_results, patches): def print_patches(number_of_slither_results: int, patches: Dict) -> None:
logger.info("Number of Slither results: " + str(number_of_slither_results)) logger.info("Number of Slither results: " + str(number_of_slither_results))
number_of_patches = 0 number_of_patches = 0
for file in patches: for file in patches:
@ -130,7 +136,7 @@ def print_patches(number_of_slither_results, patches):
logger.info("Location end: " + str(patch["end"])) logger.info("Location end: " + str(patch["end"]))
def print_patches_json(number_of_slither_results, patches): def print_patches_json(number_of_slither_results: int, patches: Dict) -> None:
print("{", end="") print("{", end="")
print('"Number of Slither results":' + '"' + str(number_of_slither_results) + '",') print('"Number of Slither results":' + '"' + str(number_of_slither_results) + '",')
print('"Number of patchlets":' + '"' + str(len(patches)) + '"', ",") print('"Number of patchlets":' + '"' + str(len(patches)) + '"', ",")

@ -3,10 +3,13 @@ import inspect
import json import json
import logging import logging
import sys import sys
from typing import List, Any, Type, Dict, Tuple, Union, Sequence, Optional
from crytic_compile import cryticparser from crytic_compile import cryticparser
from slither import Slither from slither import Slither
from slither.core.declarations import Contract
from slither.exceptions import SlitherException from slither.exceptions import SlitherException
from slither.utils.colors import red from slither.utils.colors import red
from slither.utils.output import output_to_json from slither.utils.output import output_to_json
@ -24,7 +27,7 @@ logger: logging.Logger = logging.getLogger("Slither")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
def parse_args(): def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.", description="Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.",
usage="slither-check-upgradeability contract.sol ContractName", usage="slither-check-upgradeability contract.sol ContractName",
@ -93,21 +96,27 @@ def parse_args():
################################################################################### ###################################################################################
def _get_checks(): def _get_checks() -> List[Type[AbstractCheck]]:
detectors = [getattr(all_checks, name) for name in dir(all_checks)] detectors_ = [getattr(all_checks, name) for name in dir(all_checks)]
detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractCheck)] detectors: List[Type[AbstractCheck]] = [
c for c in detectors_ if inspect.isclass(c) and issubclass(c, AbstractCheck)
]
return detectors return detectors
class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
checks = _get_checks() checks = _get_checks()
output_detectors(checks) output_detectors(checks)
parser.exit() parser.exit()
class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-methods class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
checks = _get_checks() checks = _get_checks()
detector_types_json = output_detectors_json(checks) detector_types_json = output_detectors_json(checks)
print(json.dumps(detector_types_json)) print(json.dumps(detector_types_json))
@ -116,48 +125,64 @@ class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-meth
class OutputMarkdown(argparse.Action): # pylint: disable=too-few-public-methods class OutputMarkdown(argparse.Action): # pylint: disable=too-few-public-methods
def __call__( def __call__(
self, parser, args, values, option_string=None self,
): # pylint: disable=signature-differs parser: Any,
args: Any,
values: Optional[Union[str, Sequence[Any]]],
option_string: Any = None,
) -> None: # pylint: disable=signature-differs
checks = _get_checks() checks = _get_checks()
assert isinstance(values, str)
output_to_markdown(checks, values) output_to_markdown(checks, values)
parser.exit() parser.exit()
class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods
def __call__( def __call__(
self, parser, args, values, option_string=None self,
): # pylint: disable=signature-differs parser: Any,
args: Any,
values: Optional[Union[str, Sequence[Any]]],
option_string: Any = None,
) -> Any: # pylint: disable=signature-differs
checks = _get_checks() checks = _get_checks()
assert isinstance(values, str)
output_wiki(checks, values) output_wiki(checks, values)
parser.exit() parser.exit()
def _run_checks(detectors): def _run_checks(detectors: List[AbstractCheck]) -> List[Dict]:
results = [d.check() for d in detectors] results_ = [d.check() for d in detectors]
results = [r for r in results if r] results_ = [r for r in results_ if r]
results = [item for sublist in results for item in sublist] # flatten results = [item for sublist in results_ for item in sublist] # flatten
return results return results
def _checks_on_contract(detectors, contract): def _checks_on_contract(
detectors = [ detectors: List[Type[AbstractCheck]], contract: Contract
) -> Tuple[List[Dict], int]:
detectors_ = [
d(logger, contract) d(logger, contract)
for d in detectors for d in detectors
if (not d.REQUIRE_PROXY and not d.REQUIRE_CONTRACT_V2) if (not d.REQUIRE_PROXY and not d.REQUIRE_CONTRACT_V2)
] ]
return _run_checks(detectors), len(detectors) return _run_checks(detectors_), len(detectors_)
def _checks_on_contract_update(detectors, contract_v1, contract_v2): def _checks_on_contract_update(
detectors = [ detectors: List[Type[AbstractCheck]], contract_v1: Contract, contract_v2: Contract
) -> Tuple[List[Dict], int]:
detectors_ = [
d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2 d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2
] ]
return _run_checks(detectors), len(detectors) return _run_checks(detectors_), len(detectors_)
def _checks_on_contract_and_proxy(detectors, contract, proxy): def _checks_on_contract_and_proxy(
detectors = [d(logger, contract, proxy=proxy) for d in detectors if d.REQUIRE_PROXY] detectors: List[Type[AbstractCheck]], contract: Contract, proxy: Contract
return _run_checks(detectors), len(detectors) ) -> Tuple[List[Dict], int]:
detectors_ = [d(logger, contract, proxy=proxy) for d in detectors if d.REQUIRE_PROXY]
return _run_checks(detectors_), len(detectors_)
# endregion # endregion
@ -168,8 +193,8 @@ def _checks_on_contract_and_proxy(detectors, contract, proxy):
################################################################################### ###################################################################################
# pylint: disable=too-many-statements,too-many-branches,too-many-locals # pylint: disable=too-many-statements,too-many-branches,too-many-locals
def main(): def main() -> None:
json_results = { json_results: Dict = {
"proxy-present": False, "proxy-present": False,
"contract_v2-present": False, "contract_v2-present": False,
"detectors": [], "detectors": [],
@ -254,7 +279,7 @@ def main():
number_detectors_run += number_detectors number_detectors_run += number_detectors
# If there is a V2, we run the contract-only check on the V2 # If there is a V2, we run the contract-only check on the V2
detectors_results, _ = _checks_on_contract(detectors, v2_contract) detectors_results, number_detectors = _checks_on_contract(detectors, v2_contract)
json_results["detectors"] += detectors_results json_results["detectors"] += detectors_results
number_detectors_run += number_detectors number_detectors_run += number_detectors

@ -1,8 +1,10 @@
from slither.tools.upgradeability.checks.abstract_checks import classification_txt from typing import List, Union, Dict, Type
from slither.tools.upgradeability.checks.abstract_checks import classification_txt, AbstractCheck
from slither.utils.myprettytable import MyPrettyTable from slither.utils.myprettytable import MyPrettyTable
def output_wiki(detector_classes, filter_wiki): def output_wiki(detector_classes: List[Type[AbstractCheck]], filter_wiki: str) -> None:
# Sort by impact, confidence, and name # Sort by impact, confidence, and name
detectors_list = sorted( detectors_list = sorted(
detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT) detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT)
@ -31,7 +33,7 @@ def output_wiki(detector_classes, filter_wiki):
print(recommendation) print(recommendation)
def output_detectors(detector_classes): def output_detectors(detector_classes: List[Type[AbstractCheck]]) -> None:
detectors_list = [] detectors_list = []
for detector in detector_classes: for detector in detector_classes:
argument = detector.ARGUMENT argument = detector.ARGUMENT
@ -48,7 +50,7 @@ def output_detectors(detector_classes):
for (argument, help_info, impact, proxy, v2) in detectors_list: for (argument, help_info, impact, proxy, v2) in detectors_list:
table.add_row( table.add_row(
[ [
idx, str(idx),
argument, argument,
help_info, help_info,
classification_txt[impact], classification_txt[impact],
@ -60,8 +62,8 @@ def output_detectors(detector_classes):
print(table) print(table)
def output_to_markdown(detector_classes, _filter_wiki): def output_to_markdown(detector_classes: List[Type[AbstractCheck]], _filter_wiki: str) -> None:
def extract_help(cls): def extract_help(cls: AbstractCheck) -> str:
if cls.WIKI == "": if cls.WIKI == "":
return cls.HELP return cls.HELP
return f"[{cls.HELP}]({cls.WIKI})" return f"[{cls.HELP}]({cls.WIKI})"
@ -85,7 +87,9 @@ def output_to_markdown(detector_classes, _filter_wiki):
idx = idx + 1 idx = idx + 1
def output_detectors_json(detector_classes): def output_detectors_json(
detector_classes: List[Type[AbstractCheck]],
) -> List[Dict[str, Union[str, int]]]:
detectors_list = [] detectors_list = []
for detector in detector_classes: for detector in detector_classes:
argument = detector.ARGUMENT argument = detector.ARGUMENT
@ -110,7 +114,7 @@ def output_detectors_json(detector_classes):
# Sort by impact, confidence, and name # Sort by impact, confidence, and name
detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0]))
idx = 1 idx = 1
table = [] table: List[Dict[str, Union[str, int]]] = []
for ( for (
argument, argument,
help_info, help_info,

@ -1,13 +1,17 @@
import argparse
import json import json
import os import os
import re import re
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Type, Union
from crytic_compile.cryticparser.defaults import ( from crytic_compile.cryticparser.defaults import (
DEFAULTS_FLAG_IN_CONFIG as DEFAULTS_FLAG_IN_CONFIG_CRYTIC_COMPILE, DEFAULTS_FLAG_IN_CONFIG as DEFAULTS_FLAG_IN_CONFIG_CRYTIC_COMPILE,
) )
from slither.detectors.abstract_detector import classification_txt from slither.detectors.abstract_detector import classification_txt, AbstractDetector
from slither.printers.abstract_printer import AbstractPrinter
from slither.utils.colors import yellow, red from slither.utils.colors import yellow, red
from slither.utils.myprettytable import MyPrettyTable from slither.utils.myprettytable import MyPrettyTable
@ -54,7 +58,7 @@ defaults_flag_in_config = {
} }
def read_config_file(args): def read_config_file(args: argparse.Namespace) -> None:
# No config file was provided as an argument # No config file was provided as an argument
if args.config_file is None: if args.config_file is None:
# Check wether the default config file is present # Check wether the default config file is present
@ -83,8 +87,12 @@ def read_config_file(args):
logger.error(yellow("Falling back to the default settings...")) logger.error(yellow("Falling back to the default settings..."))
def output_to_markdown(detector_classes, printer_classes, filter_wiki): def output_to_markdown(
def extract_help(cls): detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
filter_wiki: str,
) -> None:
def extract_help(cls: Union[Type[AbstractDetector], Type[AbstractPrinter]]) -> str:
if cls.WIKI == "": if cls.WIKI == "":
return cls.HELP return cls.HELP
return f"[{cls.HELP}]({cls.WIKI})" return f"[{cls.HELP}]({cls.WIKI})"
@ -127,7 +135,7 @@ def output_to_markdown(detector_classes, printer_classes, filter_wiki):
idx = idx + 1 idx = idx + 1
def get_level(l): def get_level(l: str) -> int:
tab = l.count("\t") + 1 tab = l.count("\t") + 1
if l.replace("\t", "").startswith(" -"): if l.replace("\t", "").startswith(" -"):
tab = tab + 1 tab = tab + 1
@ -136,7 +144,7 @@ def get_level(l):
return tab return tab
def convert_result_to_markdown(txt): def convert_result_to_markdown(txt: str) -> str:
# -1 to remove the last \n # -1 to remove the last \n
lines = txt[0:-1].split("\n") lines = txt[0:-1].split("\n")
ret = [] ret = []
@ -154,16 +162,21 @@ def convert_result_to_markdown(txt):
return "".join(ret) return "".join(ret)
def output_results_to_markdown(all_results, checklistlimit: str): def output_results_to_markdown(all_results: List[Dict], checklistlimit: str) -> None:
checks = defaultdict(list) checks = defaultdict(list)
info = defaultdict(dict) info: Dict = defaultdict(dict)
for results in all_results: for results_ in all_results:
checks[results["check"]].append(results) checks[results_["check"]].append(results_)
info[results["check"]] = {"impact": results["impact"], "confidence": results["confidence"]} info[results_["check"]] = {
"impact": results_["impact"],
"confidence": results_["confidence"],
}
print("Summary") print("Summary")
for check in checks: for check_ in checks:
print(f" - [{check}](#{check}) ({len(checks[check])} results) ({info[check]['impact']})") print(
f" - [{check_}](#{check_}) ({len(checks[check_])} results) ({info[check_]['impact']})"
)
counter = 0 counter = 0
for (check, results) in checks.items(): for (check, results) in checks.items():
@ -185,8 +198,7 @@ def output_results_to_markdown(all_results, checklistlimit: str):
print(f"**More results were found, check [{checklistlimit}]({checklistlimit})**") print(f"**More results were found, check [{checklistlimit}]({checklistlimit})**")
def output_wiki(detector_classes, filter_wiki): def output_wiki(detector_classes: List[Type[AbstractDetector]], filter_wiki: str) -> None:
detectors_list = []
# Sort by impact, confidence, and name # Sort by impact, confidence, and name
detectors_list = sorted( detectors_list = sorted(
@ -223,7 +235,7 @@ def output_wiki(detector_classes, filter_wiki):
print(recommendation) print(recommendation)
def output_detectors(detector_classes): def output_detectors(detector_classes: List[Type[AbstractDetector]]) -> None:
detectors_list = [] detectors_list = []
for detector in detector_classes: for detector in detector_classes:
argument = detector.ARGUMENT argument = detector.ARGUMENT
@ -242,12 +254,15 @@ def output_detectors(detector_classes):
) )
idx = 1 idx = 1
for (argument, help_info, impact, confidence) in detectors_list: for (argument, help_info, impact, confidence) in detectors_list:
table.add_row([idx, argument, help_info, classification_txt[impact], confidence]) table.add_row([str(idx), argument, help_info, classification_txt[impact], confidence])
idx = idx + 1 idx = idx + 1
print(table) print(table)
def output_detectors_json(detector_classes): # pylint: disable=too-many-locals # pylint: disable=too-many-locals
def output_detectors_json(
detector_classes: List[Type[AbstractDetector]],
) -> List[Dict]:
detectors_list = [] detectors_list = []
for detector in detector_classes: for detector in detector_classes:
argument = detector.ARGUMENT argument = detector.ARGUMENT
@ -307,7 +322,7 @@ def output_detectors_json(detector_classes): # pylint: disable=too-many-locals
return table return table
def output_printers(printer_classes): def output_printers(printer_classes: List[Type[AbstractPrinter]]) -> None:
printers_list = [] printers_list = []
for printer in printer_classes: for printer in printer_classes:
argument = printer.ARGUMENT argument = printer.ARGUMENT
@ -319,12 +334,12 @@ def output_printers(printer_classes):
printers_list = sorted(printers_list, key=lambda element: (element[0])) printers_list = sorted(printers_list, key=lambda element: (element[0]))
idx = 1 idx = 1
for (argument, help_info) in printers_list: for (argument, help_info) in printers_list:
table.add_row([idx, argument, help_info]) table.add_row([str(idx), argument, help_info])
idx = idx + 1 idx = idx + 1
print(table) print(table)
def output_printers_json(printer_classes): def output_printers_json(printer_classes: List[Type[AbstractPrinter]]) -> List[Dict]:
printers_list = [] printers_list = []
for printer in printer_classes: for printer in printer_classes:
argument = printer.ARGUMENT argument = printer.ARGUMENT

@ -4,7 +4,7 @@ import json
import logging import logging
import zipfile import zipfile
from collections import OrderedDict from collections import OrderedDict
from typing import Optional, Dict, List, Union, Any, TYPE_CHECKING from typing import Optional, Dict, List, Union, Any, TYPE_CHECKING, Type
from zipfile import ZipFile from zipfile import ZipFile
from pkg_resources import require from pkg_resources import require
@ -129,7 +129,7 @@ def _output_result_to_sarif(
def output_to_sarif( def output_to_sarif(
filename: Optional[str], results: Dict, detectors_classes: List["AbstractDetector"] filename: Optional[str], results: Dict, detectors_classes: List[Type["AbstractDetector"]]
) -> None: ) -> None:
""" """

@ -28,7 +28,7 @@ class StandardOutputCapture:
original_logger_handlers = None original_logger_handlers = None
@staticmethod @staticmethod
def enable(block_original=True): def enable(block_original: bool = True) -> None:
""" """
Redirects stdout and stderr to a capturable StringIO. Redirects stdout and stderr to a capturable StringIO.
:param block_original: If True, blocks all output to the original stream. If False, duplicates output. :param block_original: If True, blocks all output to the original stream. If False, duplicates output.
@ -54,7 +54,7 @@ class StandardOutputCapture:
root_logger.handlers = [logging.StreamHandler(sys.stderr)] root_logger.handlers = [logging.StreamHandler(sys.stderr)]
@staticmethod @staticmethod
def disable(): def disable() -> None:
""" """
Disables redirection of stdout/stderr, if previously enabled. Disables redirection of stdout/stderr, if previously enabled.
:return: None :return: None
@ -78,7 +78,7 @@ class StandardOutputCapture:
StandardOutputCapture.original_logger_handlers = None StandardOutputCapture.original_logger_handlers = None
@staticmethod @staticmethod
def get_stdout_output(): def get_stdout_output() -> str:
""" """
Obtains the output from the currently set stdout Obtains the output from the currently set stdout
:return: Returns stdout output as a string :return: Returns stdout output as a string
@ -87,7 +87,7 @@ class StandardOutputCapture:
return sys.stdout.read() return sys.stdout.read()
@staticmethod @staticmethod
def get_stderr_output(): def get_stderr_output() -> str:
""" """
Obtains the output from the currently set stderr Obtains the output from the currently set stderr
:return: Returns stderr output as a string :return: Returns stderr output as a string

@ -10,4 +10,4 @@ Reference: https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-
INFO:Slither: INFO:Slither:
Initializable contract not found, the contract does not follow a standard initalization schema. Initializable contract not found, the contract does not follow a standard initalization schema.
Reference: https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing Reference: https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing
INFO:Slither:4 findings, 18 detectors run INFO:Slither:4 findings, 21 detectors run

@ -4,4 +4,4 @@ Reference: https://github.com/crytic/slither/wiki/Upgradeability-Checks#initiali
INFO:Slither: INFO:Slither:
Initializable contract not found, the contract does not follow a standard initalization schema. Initializable contract not found, the contract does not follow a standard initalization schema.
Reference: https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing Reference: https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing
INFO:Slither:2 findings, 22 detectors run INFO:Slither:2 findings, 25 detectors run

@ -20,4 +20,4 @@ Reference: https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-va
INFO:Slither: INFO:Slither:
Initializable contract not found, the contract does not follow a standard initalization schema. Initializable contract not found, the contract does not follow a standard initalization schema.
Reference: https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing Reference: https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing
INFO:Slither:6 findings, 22 detectors run INFO:Slither:6 findings, 25 detectors run

@ -17,4 +17,4 @@ Reference: https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-va
INFO:Slither: INFO:Slither:
Initializable contract not found, the contract does not follow a standard initalization schema. Initializable contract not found, the contract does not follow a standard initalization schema.
Reference: https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing Reference: https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing
INFO:Slither:5 findings, 22 detectors run INFO:Slither:5 findings, 25 detectors run

Loading…
Cancel
Save