Add more types

pull/1388/head
Josselin Feist 2 years ago
parent 719e4e98d8
commit aa4a57d050
  1. 9
      plugin_example/slither_my_plugin/__init__.py
  2. 96
      slither/__main__.py
  3. 14
      slither/core/declarations/structure.py
  4. 4
      slither/core/expressions/tuple_expression.py
  5. 8
      slither/core/source_mapping/source_mapping.py
  6. 12
      slither/slither.py
  7. 8
      slither/slithir/operations/call.py
  8. 8
      slither/slithir/tmp_operations/tmp_call.py
  9. 12
      slither/slithir/tmp_operations/tmp_new_elementary_type.py
  10. 4
      slither/tools/demo/__main__.py
  11. 9
      slither/tools/erc_conformance/__main__.py
  12. 15
      slither/tools/erc_conformance/erc/ercs.py
  13. 4
      slither/tools/flattening/__main__.py
  14. 4
      slither/tools/kspec_coverage/__main__.py
  15. 4
      slither/tools/kspec_coverage/kspec_coverage.py
  16. 13
      slither/tools/mutator/__main__.py
  17. 7
      slither/tools/possible_paths/__main__.py
  18. 45
      slither/tools/possible_paths/possible_paths.py
  19. 19
      slither/tools/properties/__main__.py
  20. 8
      slither/tools/properties/platforms/truffle.py
  21. 2
      slither/tools/properties/utils.py
  22. 4
      slither/tools/similarity/__main__.py
  23. 2
      slither/tools/similarity/encode.py
  24. 3
      slither/tools/similarity/info.py
  25. 3
      slither/tools/similarity/plot.py
  26. 3
      slither/tools/similarity/train.py
  27. 4
      slither/tools/slither_format/__main__.py
  28. 18
      slither/tools/slither_format/slither_format.py
  29. 75
      slither/tools/upgradeability/__main__.py
  30. 20
      slither/tools/upgradeability/utils/command_line.py
  31. 57
      slither/utils/command_line.py
  32. 4
      slither/utils/output.py
  33. 8
      slither/utils/output_capture.py

@ -1,8 +1,13 @@
from typing import Tuple, List, Type
from slither_my_plugin.detectors.example import Example
from slither.detectors.abstract_detector import AbstractDetector
from slither.printers.abstract_printer import AbstractPrinter
def make_plugin():
def make_plugin() -> Tuple[List[Type[AbstractDetector]], List[Type[AbstractPrinter]]]:
plugin_detectors = [Example]
plugin_printers = []
plugin_printers: List[Type[AbstractPrinter]] = []
return plugin_detectors, plugin_printers

@ -10,11 +10,11 @@ import os
import pstats
import sys
import traceback
from typing import Tuple, Optional, List, Dict
from typing import Tuple, Optional, List, Dict, Type, Union, Any, Sequence
from pkg_resources import iter_entry_points, require
from crytic_compile import cryticparser
from crytic_compile import cryticparser, CryticCompile
from crytic_compile.platform.standard import generate_standard_export
from crytic_compile.platform.etherscan import SUPPORTED_NETWORK
from crytic_compile import compile_all, is_supported
@ -55,10 +55,10 @@ logger = logging.getLogger("Slither")
def process_single(
target: str,
target: Union[str, CryticCompile],
args: argparse.Namespace,
detector_classes: List[AbstractDetector],
printer_classes: List[AbstractPrinter],
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]:
"""
The core high-level code for running Slither static analysis.
@ -80,8 +80,8 @@ def process_single(
def process_all(
target: str,
args: argparse.Namespace,
detector_classes: List[AbstractDetector],
printer_classes: List[AbstractPrinter],
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[List[Slither], List[Dict], List[Dict], int]:
compilations = compile_all(target, **vars(args))
slither_instances = []
@ -109,8 +109,8 @@ def process_all(
def _process(
slither: Slither,
detector_classes: List[AbstractDetector],
printer_classes: List[AbstractPrinter],
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]:
for detector_cls in detector_classes:
slither.register_detector(detector_cls)
@ -137,13 +137,14 @@ def _process(
return slither, results_detectors, results_printers, analyzed_contracts_count
# TODO: delete me?
def process_from_asts(
filenames: List[str],
args: argparse.Namespace,
detector_classes: List[AbstractDetector],
printer_classes: List[AbstractPrinter],
):
all_contracts = []
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]:
all_contracts: List[str] = []
for filename in filenames:
with open(filename, encoding="utf8") as file_open:
@ -162,13 +163,15 @@ def process_from_asts(
###################################################################################
def get_detectors_and_printers():
def get_detectors_and_printers() -> Tuple[
List[Type[AbstractDetector]], List[Type[AbstractPrinter]]
]:
detectors = [getattr(all_detectors, name) for name in dir(all_detectors)]
detectors = [d for d in detectors if inspect.isclass(d) and issubclass(d, AbstractDetector)]
detectors_ = [getattr(all_detectors, name) for name in dir(all_detectors)]
detectors = [d for d in detectors_ if inspect.isclass(d) and issubclass(d, AbstractDetector)]
printers = [getattr(all_printers, name) for name in dir(all_printers)]
printers = [p for p in printers if inspect.isclass(p) and issubclass(p, AbstractPrinter)]
printers_ = [getattr(all_printers, name) for name in dir(all_printers)]
printers = [p for p in printers_ if inspect.isclass(p) and issubclass(p, AbstractPrinter)]
# Handle plugins!
for entry_point in iter_entry_points(group="slither_analyzer.plugin", name=None):
@ -194,8 +197,8 @@ def get_detectors_and_printers():
# pylint: disable=too-many-branches
def choose_detectors(
args: argparse.Namespace, all_detector_classes: List[AbstractDetector]
) -> List[AbstractDetector]:
args: argparse.Namespace, all_detector_classes: List[Type[AbstractDetector]]
) -> List[Type[AbstractDetector]]:
# If detectors are specified, run only these ones
detectors_to_run = []
@ -245,8 +248,8 @@ def choose_detectors(
def choose_printers(
args: argparse.Namespace, all_printer_classes: List[AbstractPrinter]
) -> List[AbstractPrinter]:
args: argparse.Namespace, all_printer_classes: List[Type[AbstractPrinter]]
) -> List[Type[AbstractPrinter]]:
printers_to_run = []
# disable default printer
@ -273,13 +276,16 @@ def choose_printers(
###################################################################################
def parse_filter_paths(args):
def parse_filter_paths(args: argparse.Namespace) -> List[str]:
if args.filter_paths:
return args.filter_paths.split(",")
return []
def parse_args(detector_classes, printer_classes): # pylint: disable=too-many-statements
# pylint: disable=too-many-statements
def parse_args(
detector_classes: List[Type[AbstractDetector]], printer_classes: List[Type[AbstractPrinter]]
) -> argparse.Namespace:
usage = "slither target [flag]\n"
usage += "\ntarget can be:\n"
@ -622,7 +628,9 @@ class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods
class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs
def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
detectors, _ = get_detectors_and_printers()
detector_types_json = output_detectors_json(detectors)
print(json.dumps(detector_types_json))
@ -630,22 +638,38 @@ class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-meth
class ListPrinters(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs
def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
_, printers = get_detectors_and_printers()
output_printers(printers)
parser.exit()
class OutputMarkdown(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, args, values, option_string=None):
def __call__(
self,
parser: Any,
args: Any,
values: Optional[Union[str, Sequence[Any]]],
option_string: Any = None,
) -> None:
detectors, printers = get_detectors_and_printers()
assert isinstance(values, str)
output_to_markdown(detectors, printers, values)
parser.exit()
class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, args, values, option_string=None):
def __call__(
self,
parser: Any,
args: Any,
values: Optional[Union[str, Sequence[Any]]],
option_string: Any = None,
) -> None:
detectors, _ = get_detectors_and_printers()
assert isinstance(values, str)
output_wiki(detectors, values)
parser.exit()
@ -678,7 +702,7 @@ class FormatterCryticCompile(logging.Formatter):
###################################################################################
def main():
def main() -> None:
# Codebase with complex domninators can lead to a lot of SSA recursive call
sys.setrecursionlimit(1500)
@ -689,8 +713,9 @@ def main():
# pylint: disable=too-many-statements,too-many-branches,too-many-locals
def main_impl(
all_detector_classes: List[AbstractDetector], all_printer_classes: List[AbstractPrinter]
):
all_detector_classes: List[Type[AbstractDetector]],
all_printer_classes: List[Type[AbstractPrinter]],
) -> None:
"""
:param all_detector_classes: A list of all detectors that can be included/excluded.
:param all_printer_classes: A list of all printers that can be included.
@ -756,8 +781,8 @@ def main_impl(
crytic_compile_error.propagate = False
crytic_compile_error.setLevel(logging.INFO)
results_detectors = []
results_printers = []
results_detectors: List[Dict] = []
results_printers: List[Dict] = []
try:
filename = args.filename
@ -806,6 +831,7 @@ def main_impl(
if "compilations" in args.json_types:
compilation_results = []
for slither_instance in slither_instances:
assert slither_instance.crytic_compile
compilation_results.append(
generate_standard_export(slither_instance.crytic_compile)
)
@ -856,7 +882,7 @@ def main_impl(
except Exception: # pylint: disable=broad-except
output_error = traceback.format_exc()
logging.error(traceback.print_exc())
traceback.print_exc()
logging.error(f"Error in {args.filename}") # pylint: disable=logging-fstring-interpolation
logging.error(output_error)
@ -879,7 +905,7 @@ def main_impl(
if outputting_zip:
output_to_zip(args.zip, output_error, json_results, args.zip_type)
if args.perf:
if args.perf and cp:
cp.disable()
stats = pstats.Stats(cp).sort_stats("cumtime")
stats.print_stats()

@ -8,10 +8,10 @@ if TYPE_CHECKING:
class Structure(SourceMapping):
def __init__(self, compilation_unit: "SlitherCompilationUnit"):
def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None:
super().__init__()
self._name: Optional[str] = None
self._canonical_name = None
self._canonical_name: Optional[str] = None
self._elems: Dict[str, "StructureVariable"] = {}
# Name of the elements in the order of declaration
self._elems_ordered: List[str] = []
@ -19,25 +19,27 @@ class Structure(SourceMapping):
@property
def canonical_name(self) -> str:
assert self._canonical_name
return self._canonical_name
@canonical_name.setter
def canonical_name(self, name: str):
def canonical_name(self, name: str) -> None:
self._canonical_name = name
@property
def name(self) -> str:
assert self._name
return self._name
@name.setter
def name(self, new_name: str):
def name(self, new_name: str) -> None:
self._name = new_name
@property
def elems(self) -> Dict[str, "StructureVariable"]:
return self._elems
def add_elem_in_order(self, s: str):
def add_elem_in_order(self, s: str) -> None:
self._elems_ordered.append(s)
@property
@ -47,5 +49,5 @@ class Structure(SourceMapping):
ret.append(self._elems[e])
return ret
def __str__(self):
def __str__(self) -> str:
return self.name

@ -4,7 +4,7 @@ from slither.core.expressions.expression import Expression
class TupleExpression(Expression):
def __init__(self, expressions):
def __init__(self, expressions: List[Expression]) -> None:
assert all(isinstance(x, Expression) for x in expressions if x)
super().__init__()
self._expressions = expressions
@ -13,6 +13,6 @@ class TupleExpression(Expression):
def expressions(self) -> List[Expression]:
return self._expressions
def __str__(self):
def __str__(self) -> str:
expressions_str = [str(e) for e in self.expressions]
return "(" + ",".join(expressions_str) + ")"

@ -162,13 +162,15 @@ def _convert_source_mapping(
class SourceMapping(Context, metaclass=ABCMeta):
def __init__(self):
def __init__(self) -> None:
super().__init__()
# self._source_mapping: Optional[Dict] = None
self.source_mapping: Source = Source()
self.references: List[Source] = []
def set_offset(self, offset: Union["Source", str], compilation_unit: "SlitherCompilationUnit"):
def set_offset(
self, offset: Union["Source", str], compilation_unit: "SlitherCompilationUnit"
) -> None:
if isinstance(offset, Source):
self.source_mapping.start = offset.start
self.source_mapping.length = offset.length
@ -184,6 +186,6 @@ class SourceMapping(Context, metaclass=ABCMeta):
def add_reference_from_raw_source(
self, offset: str, compilation_unit: "SlitherCompilationUnit"
):
) -> None:
s = _convert_source_mapping(offset, compilation_unit)
self.references.append(s)

@ -1,5 +1,5 @@
import logging
from typing import Union, List, ValuesView
from typing import Union, List, ValuesView, Type, Dict
from crytic_compile import CryticCompile, InvalidCompilation
@ -19,7 +19,9 @@ logger_detector = logging.getLogger("Detectors")
logger_printer = logging.getLogger("Printers")
def _check_common_things(thing_name, cls, base_cls, instances_list):
def _check_common_things(
thing_name: str, cls: Type, base_cls: Type, instances_list: List[Type[AbstractDetector]]
) -> None:
if not issubclass(cls, base_cls) or cls is base_cls:
raise Exception(
@ -178,7 +180,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes
def detectors_optimization(self):
return [d for d in self.detectors if d.IMPACT == DetectorClassification.OPTIMIZATION]
def register_detector(self, detector_class):
def register_detector(self, detector_class: Type[AbstractDetector]) -> None:
"""
:param detector_class: Class inheriting from `AbstractDetector`.
"""
@ -188,7 +190,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes
instance = detector_class(compilation_unit, self, logger_detector)
self._detectors.append(instance)
def register_printer(self, printer_class):
def register_printer(self, printer_class: Type[AbstractPrinter]) -> None:
"""
:param printer_class: Class inheriting from `AbstractPrinter`.
"""
@ -197,7 +199,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes
instance = printer_class(self, logger_printer)
self._printers.append(instance)
def run_detectors(self):
def run_detectors(self) -> List[Dict]:
"""
:return: List of registered detectors results.
"""

@ -1,8 +1,10 @@
from typing import Optional, List
from slither.slithir.operations.operation import Operation
class Call(Operation):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._arguments = []
@ -14,14 +16,14 @@ class Call(Operation):
def arguments(self, v):
self._arguments = v
def can_reenter(self, _callstack=None): # pylint: disable=no-self-use
def can_reenter(self, _callstack: Optional[List] = None) -> bool: # pylint: disable=no-self-use
"""
Must be called after slithIR analysis pass
:return: bool
"""
return False
def can_send_eth(self): # pylint: disable=no-self-use
def can_send_eth(self) -> bool: # pylint: disable=no-self-use
"""
Must be called after slithIR analysis pass
:return: bool

@ -63,14 +63,14 @@ class TmpCall(OperationWithLValue): # pylint: disable=too-many-instance-attribu
def call_id(self):
return self._callid
@property
def read(self):
return [self.called]
@call_id.setter
def call_id(self, c):
self._callid = c
@property
def read(self):
return [self.called]
@property
def called(self):
return self._called

@ -1,21 +1,23 @@
from typing import List
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.core.solidity_types.elementary_type import ElementaryType
class TmpNewElementaryType(OperationWithLValue):
def __init__(self, new_type, lvalue):
def __init__(self, new_type: ElementaryType, lvalue):
assert isinstance(new_type, ElementaryType)
super().__init__()
self._type = new_type
self._type: ElementaryType = new_type
self._lvalue = lvalue
@property
def read(self):
def read(self) -> List:
return []
@property
def type(self):
def type(self) -> ElementaryType:
return self._type
def __str__(self):
def __str__(self) -> str:
return f"{self.lvalue} = new {self._type}"

@ -9,7 +9,7 @@ logging.getLogger("Slither").setLevel(logging.INFO)
logger = logging.getLogger("Slither-demo")
def parse_args():
def parse_args() -> argparse.Namespace:
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
@ -26,7 +26,7 @@ def parse_args():
return parser.parse_args()
def main():
def main() -> None:
args = parse_args()
# Perform slither analysis on the given filename

@ -1,6 +1,7 @@
import argparse
import logging
from collections import defaultdict
from typing import Any, Dict, List
from crytic_compile import cryticparser
from slither import Slither
@ -26,7 +27,7 @@ logger.propagate = False
ADDITIONAL_CHECKS = {"ERC20": check_erc20, "ERC1155": check_erc1155}
def parse_args():
def parse_args() -> argparse.Namespace:
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
@ -63,20 +64,20 @@ def parse_args():
return parser.parse_args()
def _log_error(err, args):
def _log_error(err: Any, args: argparse.Namespace) -> None:
if args.json:
output_to_json(args.json, str(err), {"upgradeability-check": []})
logger.error(err)
def main():
def main() -> None:
args = parse_args()
# Perform slither analysis on the given filename
slither = Slither(args.project, **vars(args))
ret = defaultdict(list)
ret: Dict[str, List] = defaultdict(list)
if args.erc.upper() in ERCS:

@ -1,7 +1,10 @@
import logging
from typing import Dict, List, Optional, Set
from slither.core.declarations import Contract
from slither.slithir.operations import EventCall
from slither.utils import output
from slither.utils.erc import ERC, ERC_EVENT
from slither.utils.type import (
export_nested_types_from_variable,
export_return_type_from_variable,
@ -11,7 +14,7 @@ logger = logging.getLogger("Slither-conformance")
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
def _check_signature(erc_function, contract, ret):
def _check_signature(erc_function: ERC, contract: Contract, ret: Dict) -> None:
name = erc_function.name
parameters = erc_function.parameters
return_type = erc_function.return_type
@ -146,7 +149,7 @@ def _check_signature(erc_function, contract, ret):
ret["missing_event_emmited"].append(missing_event_emmited.data)
def _check_events(erc_event, contract, ret):
def _check_events(erc_event: ERC_EVENT, contract: Contract, ret: Dict[str, List]) -> None:
name = erc_event.name
parameters = erc_event.parameters
indexes = erc_event.indexes
@ -180,7 +183,13 @@ def _check_events(erc_event, contract, ret):
ret["missing_event_index"].append(missing_event_index.data)
def generic_erc_checks(contract, erc_functions, erc_events, ret, explored=None):
def generic_erc_checks(
contract: Contract,
erc_functions: List[ERC],
erc_events: List[ERC_EVENT],
ret: Dict[str, List],
explored: Optional[Set[Contract]] = None,
) -> None:
if explored is None:
explored = set()

@ -18,7 +18,7 @@ logger = logging.getLogger("Slither")
logger.setLevel(logging.INFO)
def parse_args():
def parse_args() -> argparse.Namespace:
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
@ -106,7 +106,7 @@ def parse_args():
return parser.parse_args()
def main():
def main() -> None:
args = parse_args()
slither = Slither(args.filename, **vars(args))

@ -16,7 +16,7 @@ logger.handlers[0].setFormatter(formatter)
logger.propagate = False
def parse_args():
def parse_args() -> argparse.Namespace:
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
@ -56,7 +56,7 @@ def parse_args():
return parser.parse_args()
def main():
def main() -> None:
# ------------------------------
# Usage: slither-kspec-coverage contract kspec
# Example: slither-kspec-coverage contract.sol kspec.md

@ -1,8 +1,10 @@
import argparse
from slither.tools.kspec_coverage.analysis import run_analysis
from slither import Slither
def kspec_coverage(args):
def kspec_coverage(args: argparse.Namespace) -> None:
contract = args.contract
kspec = args.kspec

@ -2,6 +2,7 @@ import argparse
import inspect
import logging
import sys
from typing import Type, List, Any
from crytic_compile import cryticparser
@ -22,7 +23,7 @@ logger.setLevel(logging.INFO)
###################################################################################
def parse_args():
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Experimental smart contract mutator. Based on https://arxiv.org/abs/2006.11597",
usage="slither-mutate target",
@ -48,14 +49,16 @@ def parse_args():
return parser.parse_args()
def _get_mutators():
detectors = [getattr(all_mutators, name) for name in dir(all_mutators)]
detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractMutator)]
def _get_mutators() -> List[Type[AbstractMutator]]:
detectors_ = [getattr(all_mutators, name) for name in dir(all_mutators)]
detectors = [c for c in detectors_ if inspect.isclass(c) and issubclass(c, AbstractMutator)]
return detectors
class ListMutators(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs
def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
checks = _get_mutators()
output_mutators(checks)
parser.exit()

@ -5,6 +5,7 @@ from argparse import ArgumentParser, Namespace
from crytic_compile import cryticparser
from slither import Slither
from slither.core.declarations import FunctionContract
from slither.utils.colors import red
from slither.tools.possible_paths.possible_paths import (
find_target_paths,
@ -58,7 +59,11 @@ def main() -> None:
# Print out all target functions.
print("Target functions:")
for target in targets:
print(f"- {target.contract_declarer.name}.{target.full_name}")
if isinstance(target, FunctionContract):
print(f"- {target.contract_declarer.name}.{target.full_name}")
else:
pass
# TODO implement me
print("\n")
# Obtain all paths which reach the target functions.

@ -1,8 +1,15 @@
from typing import List, Tuple, Union, Optional, Set
from slither import Slither
from slither.core.declarations import Function, FunctionContract
from slither.core.slither_core import SlitherCore
class ResolveFunctionException(Exception):
pass
def resolve_function(slither, contract_name, function_name):
def resolve_function(slither: SlitherCore, contract_name: str, function_name: str) -> Function:
"""
Resolves a function instance, given a contract name and function.
:param contract_name: The name of the contract the function is declared in.
@ -32,7 +39,9 @@ def resolve_function(slither, contract_name, function_name):
return target_function
def resolve_functions(slither, functions):
def resolve_functions(
slither: Slither, functions: List[Union[str, Tuple[str, str]]]
) -> List[Function]:
"""
Resolves the provided function descriptors.
:param functions: A list of tuples (contract_name, function_name) or str (of form "ContractName.FunctionName")
@ -40,7 +49,7 @@ def resolve_functions(slither, functions):
:return: Returns a list of resolved functions.
"""
# Create the resolved list.
resolved = []
resolved: List[Function] = []
# Verify that the provided argument is a list.
if not isinstance(functions, list):
@ -72,24 +81,31 @@ def resolve_functions(slither, functions):
return resolved
def all_function_definitions(function):
def all_function_definitions(function: Function) -> List[Function]:
"""
Obtains a list of representing this function and any base definitions
:param function: The function to obtain all definitions at and beneath.
:return: Returns a list composed of the provided function definition and any base definitions.
"""
return [function] + [
# TODO implement me
if not isinstance(function, FunctionContract):
return []
ret: List[Function] = [function]
ret += [
f
for c in function.contract.inheritance
for f in c.functions_and_modifiers_declared
if f.full_name == function.full_name
]
return ret
def __find_target_paths(slither, target_function, current_path=None):
def __find_target_paths(
slither: SlitherCore, target_function: Function, current_path: Optional[List[Function]] = None
) -> Set[Tuple[Function, ...]]:
current_path = current_path if current_path else []
# Create our results list
results = set()
results: Set[Tuple[Function, ...]] = set()
# Add our current function to the path.
current_path = [target_function] + current_path
@ -106,9 +122,12 @@ def __find_target_paths(slither, target_function, current_path=None):
continue
# Find all function calls in this function (except for low level)
called_functions = [f for (_, f) in function.high_level_calls + function.library_calls]
called_functions += function.internal_calls
called_functions = set(called_functions)
called_functions_list = [
f for (_, f) in function.high_level_calls if isinstance(f, Function)
]
called_functions_list += [f for (_, f) in function.library_calls]
called_functions_list += [f for f in function.internal_calls if isinstance(f, Function)]
called_functions = set(called_functions_list)
# If any of our target functions are reachable from this function, it's a result.
if all_target_functions.intersection(called_functions):
@ -123,14 +142,16 @@ def __find_target_paths(slither, target_function, current_path=None):
return results
def find_target_paths(slither, target_functions):
def find_target_paths(
slither: SlitherCore, target_functions: List[Function]
) -> Set[Tuple[Function, ...]]:
"""
Obtains all functions which can lead to any of the target functions being called.
:param target_functions: The functions we are interested in reaching.
:return: Returns a list of all functions which can reach any of the target_functions.
"""
# Create our results list
results = set()
results: Set[Tuple[Function, ...]] = set()
# Loop for each target function
for target_function in target_functions:

@ -1,6 +1,7 @@
import argparse
import logging
import sys
from typing import Any
from crytic_compile import cryticparser
@ -26,7 +27,7 @@ logger.handlers[0].setFormatter(formatter)
logger.propagate = False
def _all_scenarios():
def _all_scenarios() -> str:
txt = "\n"
txt += "#################### ERC20 ####################\n"
for k, value in ERC20_PROPERTIES.items():
@ -35,29 +36,33 @@ def _all_scenarios():
return txt
def _all_properties():
def _all_properties() -> MyPrettyTable:
table = MyPrettyTable(["Num", "Description", "Scenario"])
idx = 0
for scenario, value in ERC20_PROPERTIES.items():
for prop in value.properties:
table.add_row([idx, prop.description, scenario])
table.add_row([str(idx), prop.description, scenario])
idx = idx + 1
return table
class ListScenarios(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs
def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
logger.info(_all_scenarios())
parser.exit()
class ListProperties(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs
def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
logger.info(_all_properties())
parser.exit()
def parse_args():
def parse_args() -> argparse.Namespace:
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
@ -120,7 +125,7 @@ def parse_args():
return parser.parse_args()
def main():
def main() -> None:
args = parse_args()
# Perform slither analysis on the given filename

@ -15,7 +15,7 @@ PATTERN_TRUFFLE_MIGRATION = re.compile("^[0-9]*_")
logger = logging.getLogger("Slither")
def _extract_caller(p: PropertyCaller):
def _extract_caller(p: PropertyCaller) -> List[str]:
if p == PropertyCaller.OWNER:
return ["owner"]
if p == PropertyCaller.SENDER:
@ -28,7 +28,7 @@ def _extract_caller(p: PropertyCaller):
return ["user"]
def _helpers():
def _helpers() -> str:
"""
Generate two functions:
- catchRevertThrowReturnFalse: check if the call revert/throw or return false
@ -75,7 +75,7 @@ def generate_unit_test( # pylint: disable=too-many-arguments,too-many-branches
output_dir: Path,
addresses: Addresses,
assert_message: str = "",
):
) -> Path:
"""
Generate unit tests files
:param test_contract:
@ -134,7 +134,7 @@ def generate_unit_test( # pylint: disable=too-many-arguments,too-many-branches
return output_dir
def generate_migration(test_contract: str, output_dir: Path, owner_address: str):
def generate_migration(test_contract: str, output_dir: Path, owner_address: str) -> None:
"""
Generate migration file
:param test_contract:

@ -12,7 +12,7 @@ def write_file(
content: str,
allow_overwrite: bool = True,
discard_if_exist: bool = False,
):
) -> None:
"""
Write the content into output_dir/filename
:param output_dir:

@ -17,7 +17,7 @@ logger = logging.getLogger("Slither-simil")
modes = ["info", "test", "train", "plot"]
def parse_args():
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector"
)
@ -78,7 +78,7 @@ def parse_args():
###################################################################################
def main():
def main() -> None:
args = parse_args()
default_log = logging.INFO

@ -74,7 +74,7 @@ def parse_target(target):
return None
def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs):
def load_and_encode(infile: str, vmodel, ext=None, nsamples=None, **kwargs):
r = {}
if infile.endswith(".npz"):
r = load_cache(infile, nsamples=nsamples)

@ -1,3 +1,4 @@
import argparse
import logging
import sys
import os.path
@ -10,7 +11,7 @@ logging.basicConfig()
logger = logging.getLogger("Slither-simil")
def info(args):
def info(args: argparse.Namespace) -> None:
try:

@ -1,3 +1,4 @@
import argparse
import logging
import random
import sys
@ -23,7 +24,7 @@ except ImportError:
logger = logging.getLogger("Slither-simil")
def plot(args): # pylint: disable=too-many-locals
def plot(args: argparse.Namespace) -> None: # pylint: disable=too-many-locals
if decomposition is None or plt is None:
logger.error(

@ -1,3 +1,4 @@
import argparse
import logging
import os
import sys
@ -10,7 +11,7 @@ from slither.tools.similarity.model import train_unsupervised
logger = logging.getLogger("Slither-simil")
def train(args): # pylint: disable=too-many-locals
def train(args: argparse.Namespace) -> None: # pylint: disable=too-many-locals
try:
last_data_train_filename = "last_data_train.txt"

@ -23,7 +23,7 @@ available_detectors = [
]
def parse_args():
def parse_args() -> argparse.Namespace:
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
@ -90,7 +90,7 @@ def parse_args():
return parser.parse_args()
def main():
def main() -> None:
# ------------------------------
# Usage: python3 -m slither_format filename
# Example: python3 -m slither_format contract.sol

@ -1,5 +1,9 @@
import logging
from pathlib import Path
from typing import Type, List, Dict
from slither import Slither
from slither.detectors.abstract_detector import AbstractDetector
from slither.detectors.variables.unused_state_variables import UnusedStateVars
from slither.detectors.attributes.incorrect_solc import IncorrectSolc
from slither.detectors.attributes.constant_pragma import ConstantPragma
@ -13,7 +17,7 @@ from slither.utils.colors import yellow
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("Slither.Format")
all_detectors = {
all_detectors: Dict[str, Type[AbstractDetector]] = {
"unused-state": UnusedStateVars,
"solc-version": IncorrectSolc,
"pragma": ConstantPragma,
@ -25,7 +29,7 @@ all_detectors = {
}
def slither_format(slither, **kwargs): # pylint: disable=too-many-locals
def slither_format(slither: Slither, **kwargs: Dict) -> None: # pylint: disable=too-many-locals
"""'
Keyword Args:
detectors_to_run (str): Comma-separated list of detectors, defaults to all
@ -85,9 +89,11 @@ def slither_format(slither, **kwargs): # pylint: disable=too-many-locals
###################################################################################
def choose_detectors(detectors_to_run, detectors_to_exclude):
def choose_detectors(
detectors_to_run: str, detectors_to_exclude: str
) -> List[Type[AbstractDetector]]:
# If detectors are specified, run only these ones
cls_detectors_to_run = []
cls_detectors_to_run: List[Type[AbstractDetector]] = []
exclude = detectors_to_exclude.split(",")
if detectors_to_run == "all":
for key, detector in all_detectors.items():
@ -114,7 +120,7 @@ def choose_detectors(detectors_to_run, detectors_to_exclude):
###################################################################################
def print_patches(number_of_slither_results, patches):
def print_patches(number_of_slither_results: int, patches: Dict) -> None:
logger.info("Number of Slither results: " + str(number_of_slither_results))
number_of_patches = 0
for file in patches:
@ -130,7 +136,7 @@ def print_patches(number_of_slither_results, patches):
logger.info("Location end: " + str(patch["end"]))
def print_patches_json(number_of_slither_results, patches):
def print_patches_json(number_of_slither_results: int, patches: Dict) -> None:
print("{", end="")
print('"Number of Slither results":' + '"' + str(number_of_slither_results) + '",')
print('"Number of patchlets":' + '"' + str(len(patches)) + '"', ",")

@ -3,10 +3,13 @@ import inspect
import json
import logging
import sys
from typing import List, Any, Type, Dict, Tuple, Union, Sequence, Optional
from crytic_compile import cryticparser
from slither import Slither
from slither.core.declarations import Contract
from slither.exceptions import SlitherException
from slither.utils.colors import red
from slither.utils.output import output_to_json
@ -24,7 +27,7 @@ logger: logging.Logger = logging.getLogger("Slither")
logger.setLevel(logging.INFO)
def parse_args():
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.",
usage="slither-check-upgradeability contract.sol ContractName",
@ -93,21 +96,27 @@ def parse_args():
###################################################################################
def _get_checks():
detectors = [getattr(all_checks, name) for name in dir(all_checks)]
detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractCheck)]
def _get_checks() -> List[Type[AbstractCheck]]:
detectors_ = [getattr(all_checks, name) for name in dir(all_checks)]
detectors: List[Type[AbstractCheck]] = [
c for c in detectors_ if inspect.isclass(c) and issubclass(c, AbstractCheck)
]
return detectors
class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs
def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
checks = _get_checks()
output_detectors(checks)
parser.exit()
class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs
def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
checks = _get_checks()
detector_types_json = output_detectors_json(checks)
print(json.dumps(detector_types_json))
@ -116,48 +125,64 @@ class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-meth
class OutputMarkdown(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(
self, parser, args, values, option_string=None
): # pylint: disable=signature-differs
self,
parser: Any,
args: Any,
values: Optional[Union[str, Sequence[Any]]],
option_string: Any = None,
) -> None: # pylint: disable=signature-differs
checks = _get_checks()
assert isinstance(values, str)
output_to_markdown(checks, values)
parser.exit()
class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(
self, parser, args, values, option_string=None
): # pylint: disable=signature-differs
self,
parser: Any,
args: Any,
values: Optional[Union[str, Sequence[Any]]],
option_string: Any = None,
) -> Any: # pylint: disable=signature-differs
checks = _get_checks()
assert isinstance(values, str)
output_wiki(checks, values)
parser.exit()
def _run_checks(detectors):
results = [d.check() for d in detectors]
results = [r for r in results if r]
results = [item for sublist in results for item in sublist] # flatten
def _run_checks(detectors: List[AbstractCheck]) -> List[Dict]:
results_ = [d.check() for d in detectors]
results_ = [r for r in results_ if r]
results = [item for sublist in results_ for item in sublist] # flatten
return results
def _checks_on_contract(detectors, contract):
detectors = [
def _checks_on_contract(
detectors: List[Type[AbstractCheck]], contract: Contract
) -> Tuple[List[Dict], int]:
detectors_ = [
d(logger, contract)
for d in detectors
if (not d.REQUIRE_PROXY and not d.REQUIRE_CONTRACT_V2)
]
return _run_checks(detectors), len(detectors)
return _run_checks(detectors_), len(detectors)
def _checks_on_contract_update(detectors, contract_v1, contract_v2):
detectors = [
def _checks_on_contract_update(
detectors: List[Type[AbstractCheck]], contract_v1: Contract, contract_v2: Contract
) -> Tuple[List[Dict], int]:
detectors_ = [
d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2
]
return _run_checks(detectors), len(detectors)
return _run_checks(detectors_), len(detectors)
def _checks_on_contract_and_proxy(detectors, contract, proxy):
detectors = [d(logger, contract, proxy=proxy) for d in detectors if d.REQUIRE_PROXY]
return _run_checks(detectors), len(detectors)
def _checks_on_contract_and_proxy(
detectors: List[Type[AbstractCheck]], contract: Contract, proxy: Contract
) -> Tuple[List[Dict], int]:
detectors_ = [d(logger, contract, proxy=proxy) for d in detectors if d.REQUIRE_PROXY]
return _run_checks(detectors_), len(detectors)
# endregion
@ -168,8 +193,8 @@ def _checks_on_contract_and_proxy(detectors, contract, proxy):
###################################################################################
# pylint: disable=too-many-statements,too-many-branches,too-many-locals
def main():
json_results = {
def main() -> None:
json_results: Dict = {
"proxy-present": False,
"contract_v2-present": False,
"detectors": [],

@ -1,8 +1,10 @@
from slither.tools.upgradeability.checks.abstract_checks import classification_txt
from typing import List, Union, Dict, Type
from slither.tools.upgradeability.checks.abstract_checks import classification_txt, AbstractCheck
from slither.utils.myprettytable import MyPrettyTable
def output_wiki(detector_classes, filter_wiki):
def output_wiki(detector_classes: List[Type[AbstractCheck]], filter_wiki: str) -> None:
# Sort by impact, confidence, and name
detectors_list = sorted(
detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT)
@ -31,7 +33,7 @@ def output_wiki(detector_classes, filter_wiki):
print(recommendation)
def output_detectors(detector_classes):
def output_detectors(detector_classes: List[Type[AbstractCheck]]) -> None:
detectors_list = []
for detector in detector_classes:
argument = detector.ARGUMENT
@ -48,7 +50,7 @@ def output_detectors(detector_classes):
for (argument, help_info, impact, proxy, v2) in detectors_list:
table.add_row(
[
idx,
str(idx),
argument,
help_info,
classification_txt[impact],
@ -60,8 +62,8 @@ def output_detectors(detector_classes):
print(table)
def output_to_markdown(detector_classes, _filter_wiki):
def extract_help(cls):
def output_to_markdown(detector_classes: List[Type[AbstractCheck]], _filter_wiki: str) -> None:
def extract_help(cls: AbstractCheck) -> str:
if cls.WIKI == "":
return cls.HELP
return f"[{cls.HELP}]({cls.WIKI})"
@ -85,7 +87,9 @@ def output_to_markdown(detector_classes, _filter_wiki):
idx = idx + 1
def output_detectors_json(detector_classes):
def output_detectors_json(
detector_classes: List[Type[AbstractCheck]],
) -> List[Dict[str, Union[str, int]]]:
detectors_list = []
for detector in detector_classes:
argument = detector.ARGUMENT
@ -110,7 +114,7 @@ def output_detectors_json(detector_classes):
# Sort by impact, confidence, and name
detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0]))
idx = 1
table = []
table: List[Dict[str, Union[str, int]]] = []
for (
argument,
help_info,

@ -1,13 +1,17 @@
import argparse
import json
import os
import re
import logging
from collections import defaultdict
from typing import Dict, List, Type, Union
from crytic_compile.cryticparser.defaults import (
DEFAULTS_FLAG_IN_CONFIG as DEFAULTS_FLAG_IN_CONFIG_CRYTIC_COMPILE,
)
from slither.detectors.abstract_detector import classification_txt
from slither.detectors.abstract_detector import classification_txt, AbstractDetector
from slither.printers.abstract_printer import AbstractPrinter
from slither.utils.colors import yellow, red
from slither.utils.myprettytable import MyPrettyTable
@ -54,7 +58,7 @@ defaults_flag_in_config = {
}
def read_config_file(args):
def read_config_file(args: argparse.Namespace) -> None:
# No config file was provided as an argument
if args.config_file is None:
# Check wether the default config file is present
@ -83,8 +87,12 @@ def read_config_file(args):
logger.error(yellow("Falling back to the default settings..."))
def output_to_markdown(detector_classes, printer_classes, filter_wiki):
def extract_help(cls):
def output_to_markdown(
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
filter_wiki: str,
) -> None:
def extract_help(cls: Union[Type[AbstractDetector], Type[AbstractPrinter]]) -> str:
if cls.WIKI == "":
return cls.HELP
return f"[{cls.HELP}]({cls.WIKI})"
@ -127,7 +135,7 @@ def output_to_markdown(detector_classes, printer_classes, filter_wiki):
idx = idx + 1
def get_level(l):
def get_level(l: str) -> int:
tab = l.count("\t") + 1
if l.replace("\t", "").startswith(" -"):
tab = tab + 1
@ -136,7 +144,7 @@ def get_level(l):
return tab
def convert_result_to_markdown(txt):
def convert_result_to_markdown(txt: str) -> str:
# -1 to remove the last \n
lines = txt[0:-1].split("\n")
ret = []
@ -154,16 +162,21 @@ def convert_result_to_markdown(txt):
return "".join(ret)
def output_results_to_markdown(all_results, checklistlimit: str):
def output_results_to_markdown(all_results: List[Dict], checklistlimit: str) -> None:
checks = defaultdict(list)
info = defaultdict(dict)
for results in all_results:
checks[results["check"]].append(results)
info[results["check"]] = {"impact": results["impact"], "confidence": results["confidence"]}
info: Dict = defaultdict(dict)
for results_ in all_results:
checks[results_["check"]].append(results_)
info[results_["check"]] = {
"impact": results_["impact"],
"confidence": results_["confidence"],
}
print("Summary")
for check in checks:
print(f" - [{check}](#{check}) ({len(checks[check])} results) ({info[check]['impact']})")
for check_ in checks:
print(
f" - [{check_}](#{check_}) ({len(checks[check_])} results) ({info[check_]['impact']})"
)
counter = 0
for (check, results) in checks.items():
@ -185,8 +198,7 @@ def output_results_to_markdown(all_results, checklistlimit: str):
print(f"**More results were found, check [{checklistlimit}]({checklistlimit})**")
def output_wiki(detector_classes, filter_wiki):
detectors_list = []
def output_wiki(detector_classes: List[Type[AbstractDetector]], filter_wiki: str) -> None:
# Sort by impact, confidence, and name
detectors_list = sorted(
@ -223,7 +235,7 @@ def output_wiki(detector_classes, filter_wiki):
print(recommendation)
def output_detectors(detector_classes):
def output_detectors(detector_classes: List[Type[AbstractDetector]]) -> None:
detectors_list = []
for detector in detector_classes:
argument = detector.ARGUMENT
@ -242,12 +254,15 @@ def output_detectors(detector_classes):
)
idx = 1
for (argument, help_info, impact, confidence) in detectors_list:
table.add_row([idx, argument, help_info, classification_txt[impact], confidence])
table.add_row([str(idx), argument, help_info, classification_txt[impact], confidence])
idx = idx + 1
print(table)
def output_detectors_json(detector_classes): # pylint: disable=too-many-locals
# pylint: disable=too-many-locals
def output_detectors_json(
detector_classes: List[Type[AbstractDetector]],
) -> List[Dict]:
detectors_list = []
for detector in detector_classes:
argument = detector.ARGUMENT
@ -307,7 +322,7 @@ def output_detectors_json(detector_classes): # pylint: disable=too-many-locals
return table
def output_printers(printer_classes):
def output_printers(printer_classes: List[Type[AbstractPrinter]]) -> None:
printers_list = []
for printer in printer_classes:
argument = printer.ARGUMENT
@ -319,12 +334,12 @@ def output_printers(printer_classes):
printers_list = sorted(printers_list, key=lambda element: (element[0]))
idx = 1
for (argument, help_info) in printers_list:
table.add_row([idx, argument, help_info])
table.add_row([str(idx), argument, help_info])
idx = idx + 1
print(table)
def output_printers_json(printer_classes):
def output_printers_json(printer_classes: List[Type[AbstractPrinter]]) -> List[Dict]:
printers_list = []
for printer in printer_classes:
argument = printer.ARGUMENT

@ -4,7 +4,7 @@ import json
import logging
import zipfile
from collections import OrderedDict
from typing import Optional, Dict, List, Union, Any, TYPE_CHECKING
from typing import Optional, Dict, List, Union, Any, TYPE_CHECKING, Type
from zipfile import ZipFile
from pkg_resources import require
@ -129,7 +129,7 @@ def _output_result_to_sarif(
def output_to_sarif(
filename: Optional[str], results: Dict, detectors_classes: List["AbstractDetector"]
filename: Optional[str], results: Dict, detectors_classes: List[Type["AbstractDetector"]]
) -> None:
"""

@ -28,7 +28,7 @@ class StandardOutputCapture:
original_logger_handlers = None
@staticmethod
def enable(block_original=True):
def enable(block_original: bool = True) -> None:
"""
Redirects stdout and stderr to a capturable StringIO.
:param block_original: If True, blocks all output to the original stream. If False, duplicates output.
@ -54,7 +54,7 @@ class StandardOutputCapture:
root_logger.handlers = [logging.StreamHandler(sys.stderr)]
@staticmethod
def disable():
def disable() -> None:
"""
Disables redirection of stdout/stderr, if previously enabled.
:return: None
@ -78,7 +78,7 @@ class StandardOutputCapture:
StandardOutputCapture.original_logger_handlers = None
@staticmethod
def get_stdout_output():
def get_stdout_output() -> str:
"""
Obtains the output from the currently set stdout
:return: Returns stdout output as a string
@ -87,7 +87,7 @@ class StandardOutputCapture:
return sys.stdout.read()
@staticmethod
def get_stderr_output():
def get_stderr_output() -> str:
"""
Obtains the output from the currently set stderr
:return: Returns stderr output as a string

Loading…
Cancel
Save