Add more type hints

pull/906/head
Josselin 3 years ago
parent dae40de4bb
commit a1f6efe388
  1. 30
      slither/analyses/data_dependency/data_dependency.py
  2. 4
      slither/core/context/context.py
  3. 2
      slither/core/declarations/function.py
  4. 3
      slither/core/declarations/function_contract.py
  5. 3
      slither/core/declarations/function_top_level.py
  6. 10
      slither/core/expressions/assignment_operation.py
  7. 2
      slither/core/expressions/expression.py
  8. 2
      slither/core/source_mapping/source_mapping.py
  9. 64
      slither/detectors/abstract_detector.py
  10. 16
      slither/detectors/functions/arbitrary_send.py
  11. 18
      slither/detectors/operations/bad_prng.py
  12. 19
      slither/detectors/shadowing/abstract.py
  13. 24
      slither/formatters/utils/patches.py
  14. 7
      slither/tools/kspec_coverage/analysis.py
  15. 8
      slither/tools/mutator/mutators/abstract_mutator.py
  16. 8
      slither/tools/possible_paths/__main__.py
  17. 2
      slither/tools/upgradeability/__main__.py
  18. 37
      slither/tools/upgradeability/checks/abstract_checks.py
  19. 4
      slither/utils/myprettytable.py
  20. 1
      slither/utils/output.py

@ -37,7 +37,12 @@ if TYPE_CHECKING:
################################################################################### ###################################################################################
def is_dependent(variable, source, context, only_unprotected=False): def is_dependent(
variable: Variable,
source: Variable,
context: Union[Contract, Function],
only_unprotected: bool = False,
) -> bool:
""" """
Args: Args:
variable (Variable) variable (Variable)
@ -52,17 +57,22 @@ def is_dependent(variable, source, context, only_unprotected=False):
return False return False
if variable == source: if variable == source:
return True return True
context = context.context context_dict = context.context
if only_unprotected: if only_unprotected:
return ( return (
variable in context[KEY_NON_SSA_UNPROTECTED] variable in context_dict[KEY_NON_SSA_UNPROTECTED]
and source in context[KEY_NON_SSA_UNPROTECTED][variable] and source in context_dict[KEY_NON_SSA_UNPROTECTED][variable]
) )
return variable in context[KEY_NON_SSA] and source in context[KEY_NON_SSA][variable] return variable in context_dict[KEY_NON_SSA] and source in context_dict[KEY_NON_SSA][variable]
def is_dependent_ssa(variable, source, context, only_unprotected=False): def is_dependent_ssa(
variable: Variable,
source: Variable,
context: Union[Contract, Function],
only_unprotected: bool = False,
) -> bool:
""" """
Args: Args:
variable (Variable) variable (Variable)
@ -73,17 +83,17 @@ def is_dependent_ssa(variable, source, context, only_unprotected=False):
bool bool
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function))
context = context.context context_dict = context.context
if isinstance(variable, Constant): if isinstance(variable, Constant):
return False return False
if variable == source: if variable == source:
return True return True
if only_unprotected: if only_unprotected:
return ( return (
variable in context[KEY_SSA_UNPROTECTED] variable in context_dict[KEY_SSA_UNPROTECTED]
and source in context[KEY_SSA_UNPROTECTED][variable] and source in context_dict[KEY_SSA_UNPROTECTED][variable]
) )
return variable in context[KEY_SSA] and source in context[KEY_SSA][variable] return variable in context_dict[KEY_SSA] and source in context_dict[KEY_SSA][variable]
GENERIC_TAINT = { GENERIC_TAINT = {

@ -3,9 +3,9 @@ from typing import Dict
class Context: # pylint: disable=too-few-public-methods class Context: # pylint: disable=too-few-public-methods
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
self._context = {"MEMBERS": defaultdict(None)} self._context: Dict = {"MEMBERS": defaultdict(None)}
@property @property
def context(self) -> Dict: def context(self) -> Dict:

@ -104,7 +104,7 @@ def _filter_state_variables_written(expressions: List["Expression"]):
return ret return ret
class Function(metaclass=ABCMeta): # pylint: disable=too-many-public-methods class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-public-methods
""" """
Function class Function class
""" """

@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, List, Tuple
from slither.core.children.child_contract import ChildContract from slither.core.children.child_contract import ChildContract
from slither.core.children.child_inheritance import ChildInheritance from slither.core.children.child_inheritance import ChildInheritance
from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.declarations import Function from slither.core.declarations import Function
# pylint: disable=import-outside-toplevel,too-many-instance-attributes,too-many-statements,too-many-lines # pylint: disable=import-outside-toplevel,too-many-instance-attributes,too-many-statements,too-many-lines
@ -14,7 +13,7 @@ if TYPE_CHECKING:
from slither.core.declarations import Contract from slither.core.declarations import Contract
class FunctionContract(Function, ChildContract, ChildInheritance, SourceMapping): class FunctionContract(Function, ChildContract, ChildInheritance):
@property @property
def canonical_name(self) -> str: def canonical_name(self) -> str:
""" """

@ -5,10 +5,9 @@ from typing import List, Tuple
from slither.core.declarations import Function from slither.core.declarations import Function
from slither.core.declarations.top_level import TopLevel from slither.core.declarations.top_level import TopLevel
from slither.core.source_mapping.source_mapping import SourceMapping
class FunctionTopLevel(Function, TopLevel, SourceMapping): class FunctionTopLevel(Function, TopLevel):
@property @property
def canonical_name(self) -> str: def canonical_name(self) -> str:
""" """

@ -26,7 +26,7 @@ class AssignmentOperationType(Enum):
ASSIGN_MODULO = 10 # %= ASSIGN_MODULO = 10 # %=
@staticmethod @staticmethod
def get_type(operation_type: "AssignmentOperationType"): def get_type(operation_type: str) -> "AssignmentOperationType":
if operation_type == "=": if operation_type == "=":
return AssignmentOperationType.ASSIGN return AssignmentOperationType.ASSIGN
if operation_type == "|=": if operation_type == "|=":
@ -52,7 +52,7 @@ class AssignmentOperationType(Enum):
raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type)) raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type))
def __str__(self): def __str__(self) -> str:
if self == AssignmentOperationType.ASSIGN: if self == AssignmentOperationType.ASSIGN:
return "=" return "="
if self == AssignmentOperationType.ASSIGN_OR: if self == AssignmentOperationType.ASSIGN_OR:
@ -91,7 +91,7 @@ class AssignmentOperation(ExpressionTyped):
super().__init__() super().__init__()
left_expression.set_lvalue() left_expression.set_lvalue()
self._expressions = [left_expression, right_expression] self._expressions = [left_expression, right_expression]
self._type: Optional["Type"] = expression_type self._type: Optional["AssignmentOperationType"] = expression_type
self._expression_return_type: Optional["Type"] = expression_return_type self._expression_return_type: Optional["Type"] = expression_return_type
@property @property
@ -111,8 +111,8 @@ class AssignmentOperation(ExpressionTyped):
return self._expressions[1] return self._expressions[1]
@property @property
def type(self) -> Optional["Type"]: def type(self) -> Optional["AssignmentOperationType"]:
return self._type return self._type
def __str__(self): def __str__(self) -> str:
return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right) return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right)

@ -10,5 +10,5 @@ class Expression(SourceMapping):
def is_lvalue(self) -> bool: def is_lvalue(self) -> bool:
return self._is_lvalue return self._is_lvalue
def set_lvalue(self): def set_lvalue(self) -> None:
self._is_lvalue = True self._is_lvalue = True

@ -8,7 +8,7 @@ if TYPE_CHECKING:
class SourceMapping(Context): class SourceMapping(Context):
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
# TODO create a namedtuple for the source mapping rather than a dict # TODO create a namedtuple for the source mapping rather than a dict
self._source_mapping: Optional[Dict] = None self._source_mapping: Optional[Dict] = None

@ -1,6 +1,8 @@
import abc import abc
import re import re
from typing import Optional, List, TYPE_CHECKING from functools import partial
from logging import Logger
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Callable
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations import Contract from slither.core.declarations import Contract
@ -8,7 +10,7 @@ from slither.utils.colors import green, yellow, red
from slither.formatters.exceptions import FormatImpossible from slither.formatters.exceptions import FormatImpossible
from slither.formatters.utils.patches import apply_patch, create_diff from slither.formatters.utils.patches import apply_patch, create_diff
from slither.utils.comparable_enum import ComparableEnum from slither.utils.comparable_enum import ComparableEnum
from slither.utils.output import Output from slither.utils.output import Output, SupportedOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from slither import Slither from slither import Slither
@ -25,8 +27,10 @@ class DetectorClassification(ComparableEnum):
INFORMATIONAL = 3 INFORMATIONAL = 3
OPTIMIZATION = 4 OPTIMIZATION = 4
UNIMPLEMENTED = 999
classification_colors = {
classification_colors: Dict[DetectorClassification, Callable[[str], str]] = {
DetectorClassification.INFORMATIONAL: green, DetectorClassification.INFORMATIONAL: green,
DetectorClassification.OPTIMIZATION: green, DetectorClassification.OPTIMIZATION: green,
DetectorClassification.LOW: green, DetectorClassification.LOW: green,
@ -46,8 +50,8 @@ classification_txt = {
class AbstractDetector(metaclass=abc.ABCMeta): class AbstractDetector(metaclass=abc.ABCMeta):
ARGUMENT = "" # run the detector with slither.py --ARGUMENT ARGUMENT = "" # run the detector with slither.py --ARGUMENT
HELP = "" # help information HELP = "" # help information
IMPACT: Optional[DetectorClassification] = None IMPACT: DetectorClassification = DetectorClassification.UNIMPLEMENTED
CONFIDENCE: Optional[DetectorClassification] = None CONFIDENCE: DetectorClassification = DetectorClassification.UNIMPLEMENTED
WIKI = "" WIKI = ""
@ -58,7 +62,9 @@ class AbstractDetector(metaclass=abc.ABCMeta):
STANDARD_JSON = True STANDARD_JSON = True
def __init__(self, compilation_unit: SlitherCompilationUnit, slither, logger): def __init__(
self, compilation_unit: SlitherCompilationUnit, slither: "Slither", logger: Logger
):
self.compilation_unit: SlitherCompilationUnit = compilation_unit self.compilation_unit: SlitherCompilationUnit = compilation_unit
self.contracts: List[Contract] = compilation_unit.contracts self.contracts: List[Contract] = compilation_unit.contracts
self.slither: "Slither" = slither self.slither: "Slither" = slither
@ -130,32 +136,25 @@ class AbstractDetector(metaclass=abc.ABCMeta):
"CONFIDENCE is not initialized {}".format(self.__class__.__name__) "CONFIDENCE is not initialized {}".format(self.__class__.__name__)
) )
def _log(self, info): def _log(self, info: str) -> None:
if self.logger: if self.logger:
self.logger.info(self.color(info)) self.logger.info(self.color(info))
@abc.abstractmethod @abc.abstractmethod
def _detect(self): def _detect(self) -> List[Output]:
"""TODO Documentation""" """TODO Documentation"""
return [] return []
# pylint: disable=too-many-branches # pylint: disable=too-many-branches
def detect(self): def detect(self) -> List[Dict]:
results = [] results: List[Dict] = []
# only keep valid result, and remove dupplicate # only keep valid result, and remove dupplicate
# Keep only dictionaries # Keep only dictionaries
for r in [r.data for r in self._detect()]: for r in [output.data for output in self._detect()]:
if self.compilation_unit.core.valid_result(r) and r not in results: if self.compilation_unit.core.valid_result(r) and r not in results:
results.append(r) results.append(r)
if results: if results and self.logger:
if self.logger: self._log_result(results)
info = "\n"
for idx, result in enumerate(results):
if self.slither.triage_mode:
info += "{}: ".format(idx)
info += result["description"]
info += "Reference: {}".format(self.WIKI)
self._log(info)
if self.compilation_unit.core.generate_patches: if self.compilation_unit.core.generate_patches:
for result in results: for result in results:
try: try:
@ -205,20 +204,24 @@ class AbstractDetector(metaclass=abc.ABCMeta):
if indexes.endswith("]"): if indexes.endswith("]"):
indexes = indexes[:-1] indexes = indexes[:-1]
try: try:
indexes = [int(i) for i in indexes.split(",")] indexes_converted = [int(i) for i in indexes.split(",")]
self.slither.save_results_to_hide( self.slither.save_results_to_hide(
[r for (idx, r) in enumerate(results) if idx in indexes] [r for (idx, r) in enumerate(results) if idx in indexes_converted]
) )
return [r for (idx, r) in enumerate(results) if idx not in indexes] return [r for (idx, r) in enumerate(results) if idx not in indexes_converted]
except ValueError: except ValueError:
self.logger.error(yellow("Malformed input. Example of valid input: 0,1,2,3")) self.logger.error(yellow("Malformed input. Example of valid input: 0,1,2,3"))
return results return results
@property @property
def color(self): def color(self) -> Callable[[str], str]:
return classification_colors[self.IMPACT] return classification_colors[self.IMPACT]
def generate_result(self, info, additional_fields=None): def generate_result(
self,
info: Union[str, List[Union[str, SupportedOutput]]],
additional_fields: Optional[Dict] = None,
) -> Output:
output = Output( output = Output(
info, info,
additional_fields, additional_fields,
@ -233,6 +236,15 @@ class AbstractDetector(metaclass=abc.ABCMeta):
return output return output
@staticmethod @staticmethod
def _format(_compilation_unit: SlitherCompilationUnit, _result): def _format(_compilation_unit: SlitherCompilationUnit, _result: Dict) -> None:
"""Implement format""" """Implement format"""
return return
def _log_result(self, results: List[Dict]) -> None:
info = "\n"
for idx, result in enumerate(results):
if self.slither.triage_mode:
info += "{}: ".format(idx)
info += result["description"]
info += "Reference: {}".format(self.WIKI)
self._log(info)

@ -9,7 +9,10 @@
TODO: dont report if the value is tainted by msg.value TODO: dont report if the value is tainted by msg.value
""" """
from slither.core.declarations import Function from typing import List
from slither.core.cfg.node import Node
from slither.core.declarations import Function, Contract
from slither.analyses.data_dependency.data_dependency import is_tainted, is_dependent from slither.analyses.data_dependency.data_dependency import is_tainted, is_dependent
from slither.core.declarations.solidity_variables import ( from slither.core.declarations.solidity_variables import (
SolidityFunction, SolidityFunction,
@ -27,11 +30,14 @@ from slither.slithir.operations import (
# pylint: disable=too-many-nested-blocks,too-many-branches # pylint: disable=too-many-nested-blocks,too-many-branches
def arbitrary_send(func): from slither.utils.output import Output
def arbitrary_send(func: Function):
if func.is_protected(): if func.is_protected():
return [] return []
ret = [] ret: List[Node] = []
for node in func.nodes: for node in func.nodes:
for ir in node.irs: for ir in node.irs:
if isinstance(ir, SolidityCall): if isinstance(ir, SolidityCall):
@ -68,7 +74,7 @@ def arbitrary_send(func):
return ret return ret
def detect_arbitrary_send(contract): def detect_arbitrary_send(contract: Contract):
""" """
Detect arbitrary send Detect arbitrary send
Args: Args:
@ -114,7 +120,7 @@ Bob calls `setDestination` and `withdraw`. As a result he withdraws the contract
WIKI_RECOMMENDATION = "Ensure that an arbitrary user cannot withdraw unauthorized funds." WIKI_RECOMMENDATION = "Ensure that an arbitrary user cannot withdraw unauthorized funds."
def _detect(self): def _detect(self) -> List[Output]:
"""""" """"""
results = [] results = []

@ -2,18 +2,24 @@
Module detecting bad PRNG due to the use of block.timestamp, now or blockhash (block.blockhash) as a source of randomness Module detecting bad PRNG due to the use of block.timestamp, now or blockhash (block.blockhash) as a source of randomness
""" """
from typing import List, Tuple
from slither.analyses.data_dependency.data_dependency import is_dependent_ssa from slither.analyses.data_dependency.data_dependency import is_dependent_ssa
from slither.core.cfg.node import Node
from slither.core.declarations import Function, Contract
from slither.core.declarations.solidity_variables import ( from slither.core.declarations.solidity_variables import (
SolidityVariable, SolidityVariable,
SolidityFunction, SolidityFunction,
SolidityVariableComposed, SolidityVariableComposed,
) )
from slither.core.variables.variable import Variable
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.slithir.operations import BinaryType, Binary from slither.slithir.operations import BinaryType, Binary
from slither.slithir.operations import SolidityCall from slither.slithir.operations import SolidityCall
from slither.utils.output import Output, AllSupportedOutput
def collect_return_values_of_bad_PRNG_functions(f): def collect_return_values_of_bad_PRNG_functions(f: Function) -> List:
""" """
Return the return-values of calls to blockhash() Return the return-values of calls to blockhash()
Args: Args:
@ -33,7 +39,7 @@ def collect_return_values_of_bad_PRNG_functions(f):
return values_returned return values_returned
def contains_bad_PRNG_sources(func, blockhash_ret_values): def contains_bad_PRNG_sources(func: Function, blockhash_ret_values: List[Variable]) -> List[Node]:
""" """
Check if any node in function has a modulus operator and the first operand is dependent on block.timestamp, now or blockhash() Check if any node in function has a modulus operator and the first operand is dependent on block.timestamp, now or blockhash()
Returns: Returns:
@ -57,7 +63,7 @@ def contains_bad_PRNG_sources(func, blockhash_ret_values):
return list(ret) return list(ret)
def detect_bad_PRNG(contract): def detect_bad_PRNG(contract: Contract) -> List[Tuple[Function, List[Node]]]:
""" """
Args: Args:
contract (Contract) contract (Contract)
@ -67,7 +73,7 @@ def detect_bad_PRNG(contract):
blockhash_ret_values = [] blockhash_ret_values = []
for f in contract.functions: for f in contract.functions:
blockhash_ret_values += collect_return_values_of_bad_PRNG_functions(f) blockhash_ret_values += collect_return_values_of_bad_PRNG_functions(f)
ret = [] ret: List[Tuple[Function, List[Node]]] = []
for f in contract.functions: for f in contract.functions:
bad_prng_nodes = contains_bad_PRNG_sources(f, blockhash_ret_values) bad_prng_nodes = contains_bad_PRNG_sources(f, blockhash_ret_values)
if bad_prng_nodes: if bad_prng_nodes:
@ -110,7 +116,7 @@ As a result, Eve wins the game."""
"Do not use `block.timestamp`, `now` or `blockhash` as a source of randomness" "Do not use `block.timestamp`, `now` or `blockhash` as a source of randomness"
) )
def _detect(self): def _detect(self) -> List[Output]:
"""Detect bad PRNG due to the use of block.timestamp, now or blockhash (block.blockhash) as a source of randomness""" """Detect bad PRNG due to the use of block.timestamp, now or blockhash (block.blockhash) as a source of randomness"""
results = [] results = []
for c in self.compilation_unit.contracts_derived: for c in self.compilation_unit.contracts_derived:
@ -118,7 +124,7 @@ As a result, Eve wins the game."""
for func, nodes in values: for func, nodes in values:
for node in nodes: for node in nodes:
info = [func, ' uses a weak PRNG: "', node, '" \n'] info: List[AllSupportedOutput] = [func, ' uses a weak PRNG: "', node, '" \n']
res = self.generate_result(info) res = self.generate_result(info)
results.append(res) results.append(res)

@ -2,19 +2,24 @@
Module detecting shadowing variables on abstract contract Module detecting shadowing variables on abstract contract
Recursively check the called functions Recursively check the called functions
""" """
from typing import List
from slither.core.declarations import Contract
from slither.core.variables.state_variable import StateVariable
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.utils.output import Output, AllSupportedOutput
def detect_shadowing(contract): def detect_shadowing(contract: Contract) -> List[List[StateVariable]]:
ret = [] ret: List[List[StateVariable]] = []
variables_fathers = [] variables_fathers = []
for father in contract.inheritance: for father in contract.inheritance:
if all(not f.is_implemented for f in father.functions + father.modifiers): if all(not f.is_implemented for f in father.functions + list(father.modifiers)):
variables_fathers += father.state_variables_declared variables_fathers += father.state_variables_declared
var: StateVariable
for var in contract.state_variables_declared: for var in contract.state_variables_declared:
shadow = [v for v in variables_fathers if v.name == var.name] shadow: List[StateVariable] = [v for v in variables_fathers if v.name == var.name]
if shadow: if shadow:
ret.append([var] + shadow) ret.append([var] + shadow)
return ret return ret
@ -51,7 +56,7 @@ contract DerivedContract is BaseContract{
WIKI_RECOMMENDATION = "Remove the state variable shadowing." WIKI_RECOMMENDATION = "Remove the state variable shadowing."
def _detect(self): def _detect(self) -> List[Output]:
"""Detect shadowing """Detect shadowing
Recursively visit the calls Recursively visit the calls
@ -59,14 +64,14 @@ contract DerivedContract is BaseContract{
list: {'vuln', 'filename,'contract','func', 'shadow'} list: {'vuln', 'filename,'contract','func', 'shadow'}
""" """
results = [] results: List[Output] = []
for contract in self.contracts: for contract in self.contracts:
shadowing = detect_shadowing(contract) shadowing = detect_shadowing(contract)
if shadowing: if shadowing:
for all_variables in shadowing: for all_variables in shadowing:
shadow = all_variables[0] shadow = all_variables[0]
variables = all_variables[1:] variables = all_variables[1:]
info = [shadow, " shadows:\n"] info: List[AllSupportedOutput] = [shadow, " shadows:\n"]
for var in variables: for var in variables:
info += ["\t- ", var, "\n"] info += ["\t- ", var, "\n"]

@ -1,9 +1,19 @@
import os import os
import difflib import difflib
from typing import Dict, Tuple, Union
from collections import defaultdict from collections import defaultdict
from slither.core.compilation_unit import SlitherCompilationUnit
def create_patch(result, file, start, end, old_str, new_str): # pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments
def create_patch(
result: Dict,
file: str,
start: int,
end: int,
old_str: Union[str, bytes],
new_str: Union[str, bytes],
) -> None:
if isinstance(old_str, bytes): if isinstance(old_str, bytes):
old_str = old_str.decode("utf8") old_str = old_str.decode("utf8")
if isinstance(new_str, bytes): if isinstance(new_str, bytes):
@ -15,7 +25,7 @@ def create_patch(result, file, start, end, old_str, new_str): # pylint: disable
result["patches"][file].append(p) result["patches"][file].append(p)
def apply_patch(original_txt, patch, offset): def apply_patch(original_txt: bytes, patch: Dict, offset: int) -> Tuple[bytes, int]:
patched_txt = original_txt[: int(patch["start"] + offset)] patched_txt = original_txt[: int(patch["start"] + offset)]
patched_txt += patch["new_string"].encode("utf8") patched_txt += patch["new_string"].encode("utf8")
patched_txt += original_txt[int(patch["end"] + offset) :] patched_txt += original_txt[int(patch["end"] + offset) :]
@ -25,9 +35,11 @@ def apply_patch(original_txt, patch, offset):
return patched_txt, patch_length_diff + offset return patched_txt, patch_length_diff + offset
def create_diff(slither, original_txt, patched_txt, filename): def create_diff(
if slither.crytic_compile: compilation_unit: SlitherCompilationUnit, original_txt: bytes, patched_txt: bytes, filename: str
relative_path = slither.crytic_compile.filename_lookup(filename).relative ) -> str:
if compilation_unit.crytic_compile:
relative_path = compilation_unit.crytic_compile.filename_lookup(filename).relative
relative_path = os.path.join(".", relative_path) relative_path = os.path.join(".", relative_path)
else: else:
relative_path = filename relative_path = filename

@ -1,5 +1,6 @@
import re import re
import logging import logging
from typing import Set, Tuple
from slither.core.declarations import Function from slither.core.declarations import Function
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
@ -13,13 +14,13 @@ logger = logging.getLogger("Slither.kspec")
# pylint: disable=anomalous-backslash-in-string # pylint: disable=anomalous-backslash-in-string
def _refactor_type(targeted_type): def _refactor_type(targeted_type: str) -> str:
return {"uint": "uint256", "int": "int256"}.get(targeted_type, targeted_type) return {"uint": "uint256", "int": "int256"}.get(targeted_type, targeted_type)
def _get_all_covered_kspec_functions(target): def _get_all_covered_kspec_functions(target: str) -> Set[Tuple[str, str]]:
# Create a set of our discovered functions which are covered # Create a set of our discovered functions which are covered
covered_functions = set() covered_functions: Set[Tuple[str, str]] = set()
BEHAVIOUR_PATTERN = re.compile("behaviour\s+(\S+)\s+of\s+(\S+)") BEHAVIOUR_PATTERN = re.compile("behaviour\s+(\S+)\s+of\s+(\S+)")
INTERFACE_PATTERN = re.compile("interface\s+([^\r\n]+)") INTERFACE_PATTERN = re.compile("interface\s+([^\r\n]+)")

@ -69,14 +69,14 @@ class AbstractMutator(metaclass=abc.ABCMeta): # pylint: disable=too-few-public-
"""TODO Documentation""" """TODO Documentation"""
return dict() return dict()
def mutate(self): def mutate(self) -> None:
patches = self._mutate() all_patches = self._mutate()
for file in patches["patches"]: for file in all_patches["patches"]:
original_txt = self.slither.source_code[file].encode("utf8") original_txt = self.slither.source_code[file].encode("utf8")
patched_txt = original_txt patched_txt = original_txt
offset = 0 offset = 0
patches = patches["patches"][file] patches = all_patches["patches"][file]
patches.sort(key=lambda x: x["start"]) patches.sort(key=lambda x: x["start"])
if not all(patches[i]["end"] <= patches[i + 1]["end"] for i in range(len(patches) - 1)): if not all(patches[i]["end"] <= patches[i + 1]["end"] for i in range(len(patches) - 1)):
logger.info(f"Impossible to generate patch; patches collisions: {patches}") logger.info(f"Impossible to generate patch; patches collisions: {patches}")

@ -2,6 +2,8 @@ import argparse
import sys import sys
import logging import logging
from argparse import ArgumentParser
from crytic_compile import cryticparser from crytic_compile import cryticparser
from slither import Slither from slither import Slither
from slither.utils.colors import red from slither.utils.colors import red
@ -15,12 +17,12 @@ logging.basicConfig()
logging.getLogger("Slither").setLevel(logging.INFO) logging.getLogger("Slither").setLevel(logging.INFO)
def parse_args(): def parse_args() -> argparse.Namespace:
""" """
Parse the underlying arguments for the program. Parse the underlying arguments for the program.
:return: Returns the arguments for the program. :return: Returns the arguments for the program.
""" """
parser = argparse.ArgumentParser( parser: ArgumentParser = argparse.ArgumentParser(
description="PossiblePaths", description="PossiblePaths",
usage="possible_paths.py filename [contract.function targets]", usage="possible_paths.py filename [contract.function targets]",
) )
@ -36,7 +38,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def main(): def main() -> None:
# ------------------------------ # ------------------------------
# PossiblePaths.py # PossiblePaths.py
# Usage: python3 possible_paths.py filename targets # Usage: python3 possible_paths.py filename targets

@ -20,7 +20,7 @@ from slither.tools.upgradeability.utils.command_line import (
) )
logging.basicConfig() logging.basicConfig()
logger = logging.getLogger("Slither") logger: logging.Logger = logging.getLogger("Slither")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)

@ -1,9 +1,11 @@
import abc import abc
from typing import Optional from logging import Logger
from typing import Optional, List, Dict, Union, Callable
from slither.core.declarations import Contract
from slither.utils.colors import green, yellow, red from slither.utils.colors import green, yellow, red
from slither.utils.comparable_enum import ComparableEnum from slither.utils.comparable_enum import ComparableEnum
from slither.utils.output import Output from slither.utils.output import Output, SupportedOutput
class IncorrectCheckInitialization(Exception): class IncorrectCheckInitialization(Exception):
@ -15,9 +17,10 @@ class CheckClassification(ComparableEnum):
MEDIUM = 1 MEDIUM = 1
LOW = 2 LOW = 2
INFORMATIONAL = 3 INFORMATIONAL = 3
UNIMPLEMENTED = 999
classification_colors = { classification_colors: Dict[CheckClassification, Callable[[str], str]] = {
CheckClassification.INFORMATIONAL: green, CheckClassification.INFORMATIONAL: green,
CheckClassification.LOW: yellow, CheckClassification.LOW: yellow,
CheckClassification.MEDIUM: yellow, CheckClassification.MEDIUM: yellow,
@ -35,7 +38,7 @@ classification_txt = {
class AbstractCheck(metaclass=abc.ABCMeta): class AbstractCheck(metaclass=abc.ABCMeta):
ARGUMENT = "" ARGUMENT = ""
HELP = "" HELP = ""
IMPACT: Optional[CheckClassification] = None IMPACT: CheckClassification = CheckClassification.UNIMPLEMENTED
WIKI = "" WIKI = ""
@ -48,7 +51,13 @@ class AbstractCheck(metaclass=abc.ABCMeta):
REQUIRE_PROXY = False REQUIRE_PROXY = False
REQUIRE_CONTRACT_V2 = False REQUIRE_CONTRACT_V2 = False
def __init__(self, logger, contract, proxy=None, contract_v2=None): def __init__(
self,
logger: Logger,
contract: Contract,
proxy: Optional[Contract] = None,
contract_v2: Optional[Contract] = None,
) -> None:
self.logger = logger self.logger = logger
self.contract = contract self.contract = contract
self.proxy = proxy self.proxy = proxy
@ -120,14 +129,14 @@ class AbstractCheck(metaclass=abc.ABCMeta):
) )
@abc.abstractmethod @abc.abstractmethod
def _check(self): def _check(self) -> List[Output]:
"""TODO Documentation""" """TODO Documentation"""
return [] return []
def check(self): def check(self) -> List[Dict]:
all_results = self._check() all_outputs = self._check()
# Keep only dictionaries # Keep only dictionaries
all_results = [r.data for r in all_results] all_results = [r.data for r in all_outputs]
if all_results: if all_results:
if self.logger: if self.logger:
info = "\n" info = "\n"
@ -137,7 +146,11 @@ class AbstractCheck(metaclass=abc.ABCMeta):
self._log(info) self._log(info)
return all_results return all_results
def generate_result(self, info, additional_fields=None): def generate_result(
self,
info: Union[str, List[Union[str, SupportedOutput]]],
additional_fields: Optional[Dict] = None,
) -> Output:
output = Output( output = Output(
info, additional_fields, markdown_root=self.contract.compilation_unit.core.markdown_root info, additional_fields, markdown_root=self.contract.compilation_unit.core.markdown_root
) )
@ -146,10 +159,10 @@ class AbstractCheck(metaclass=abc.ABCMeta):
return output return output
def _log(self, info): def _log(self, info: str) -> None:
if self.logger: if self.logger:
self.logger.info(self.color(info)) self.logger.info(self.color(info))
@property @property
def color(self): def color(self) -> Callable[[str], str]:
return classification_colors[self.IMPACT] return classification_colors[self.IMPACT]

@ -8,7 +8,7 @@ class MyPrettyTable:
self._field_names = field_names self._field_names = field_names
self._rows: List = [] self._rows: List = []
def add_row(self, row): def add_row(self, row: List[str]) -> None:
self._rows.append(row) self._rows.append(row)
def to_pretty_table(self) -> PrettyTable: def to_pretty_table(self) -> PrettyTable:
@ -20,5 +20,5 @@ class MyPrettyTable:
def to_json(self) -> Dict: def to_json(self) -> Dict:
return {"fields_names": self._field_names, "rows": self._rows} return {"fields_names": self._field_names, "rows": self._rows}
def __str__(self): def __str__(self) -> str:
return str(self.to_pretty_table()) return str(self.to_pretty_table())

@ -215,6 +215,7 @@ def _create_parent_element(element):
SupportedOutput = Union[Variable, Contract, Function, Enum, Event, Structure, Pragma, Node] SupportedOutput = Union[Variable, Contract, Function, Enum, Event, Structure, Pragma, Node]
AllSupportedOutput = Union[str, SupportedOutput]
class Output: class Output:

Loading…
Cancel
Save