Add more types hints

pull/1666/head
Feist Josselin 2 years ago
parent 20a79519f3
commit 371e3cbe47
  1. 63
      slither/__main__.py
  2. 88
      slither/analyses/data_dependency/data_dependency.py
  3. 89
      slither/core/cfg/node.py
  4. 9
      slither/core/children/child_contract.py
  5. 11
      slither/core/children/child_event.py
  6. 11
      slither/core/children/child_expression.py
  7. 9
      slither/core/children/child_function.py
  8. 9
      slither/core/children/child_inheritance.py
  9. 11
      slither/core/children/child_node.py
  10. 9
      slither/core/children/child_structure.py
  11. 51
      slither/core/declarations/contract.py
  12. 4
      slither/core/declarations/custom_error.py
  13. 7
      slither/core/declarations/custom_error_contract.py
  14. 2
      slither/core/declarations/custom_error_top_level.py
  15. 7
      slither/core/declarations/function.py
  16. 12
      slither/core/declarations/solidity_variables.py
  17. 1
      slither/core/dominators/utils.py
  18. 2
      slither/core/expressions/assignment_operation.py
  19. 4
      slither/core/slither_core.py
  20. 4
      slither/core/solidity_types/array_type.py
  21. 6
      slither/core/solidity_types/elementary_type.py
  22. 4
      slither/core/solidity_types/mapping_type.py
  23. 4
      slither/core/solidity_types/type_alias.py
  24. 6
      slither/core/solidity_types/type_information.py
  25. 6
      slither/core/source_mapping/source_mapping.py
  26. 2
      slither/core/variables/event_variable.py
  27. 6
      slither/core/variables/variable.py
  28. 4
      slither/detectors/abstract_detector.py
  29. 13
      slither/detectors/assembly/shift_parameter_mixup.py
  30. 6
      slither/detectors/attributes/const_functions_asm.py
  31. 7
      slither/detectors/compiler_bugs/array_by_reference.py
  32. 8
      slither/detectors/erc/erc20/arbitrary_send_erc20.py
  33. 8
      slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py
  34. 8
      slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py
  35. 10
      slither/detectors/functions/arbitrary_send_eth.py
  36. 3
      slither/detectors/statements/array_length_assignment.py
  37. 8
      slither/detectors/statements/assembly.py
  38. 7
      slither/slithir/operations/call.py
  39. 20
      slither/slithir/operations/high_level_call.py
  40. 26
      slither/slithir/operations/index.py
  41. 17
      slither/slithir/operations/library_call.py
  42. 6
      slither/slithir/operations/low_level_call.py
  43. 18
      slither/slithir/operations/lvalue.py
  44. 12
      slither/slithir/operations/member.py
  45. 11
      slither/slithir/operations/new_contract.py
  46. 11
      slither/slithir/operations/solidity_call.py
  47. 20
      slither/slithir/utils/utils.py
  48. 7
      slither/tools/mutator/__main__.py
  49. 11
      slither/tools/mutator/mutators/abstract_mutator.py

@ -66,7 +66,7 @@ def process_single(
args: argparse.Namespace, args: argparse.Namespace,
detector_classes: List[Type[AbstractDetector]], detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]], printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]: ) -> Tuple[Slither, List[Dict], List[Output], int]:
""" """
The core high-level code for running Slither static analysis. The core high-level code for running Slither static analysis.
@ -89,7 +89,7 @@ def process_all(
args: argparse.Namespace, args: argparse.Namespace,
detector_classes: List[Type[AbstractDetector]], detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]], printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[List[Slither], List[Dict], List[Dict], int]: ) -> Tuple[List[Slither], List[Dict], List[Output], int]:
compilations = compile_all(target, **vars(args)) compilations = compile_all(target, **vars(args))
slither_instances = [] slither_instances = []
results_detectors = [] results_detectors = []
@ -144,23 +144,6 @@ def _process(
return slither, results_detectors, results_printers, analyzed_contracts_count return slither, results_detectors, results_printers, analyzed_contracts_count
# TODO: delete me?
def process_from_asts(
filenames: List[str],
args: argparse.Namespace,
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]:
all_contracts: List[str] = []
for filename in filenames:
with open(filename, encoding="utf8") as file_open:
contract_loaded = json.load(file_open)
all_contracts.append(contract_loaded["ast"])
return process_single(all_contracts, args, detector_classes, printer_classes)
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -608,9 +591,6 @@ def parse_args(
default=False, default=False,
) )
# if the json is splitted in different files
parser.add_argument("--splitted", help=argparse.SUPPRESS, action="store_true", default=False)
# Disable the throw/catch on partial analyses # Disable the throw/catch on partial analyses
parser.add_argument( parser.add_argument(
"--disallow-partial", help=argparse.SUPPRESS, action="store_true", default=False "--disallow-partial", help=argparse.SUPPRESS, action="store_true", default=False
@ -626,7 +606,7 @@ def parse_args(
args.filter_paths = parse_filter_paths(args) args.filter_paths = parse_filter_paths(args)
# Verify our json-type output is valid # Verify our json-type output is valid
args.json_types = set(args.json_types.split(",")) args.json_types = set(args.json_types.split(",")) # type:ignore
for json_type in args.json_types: for json_type in args.json_types:
if json_type not in JSON_OUTPUT_TYPES: if json_type not in JSON_OUTPUT_TYPES:
raise Exception(f'Error: "{json_type}" is not a valid JSON result output type.') raise Exception(f'Error: "{json_type}" is not a valid JSON result output type.')
@ -697,14 +677,14 @@ class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods
class FormatterCryticCompile(logging.Formatter): class FormatterCryticCompile(logging.Formatter):
def format(self, record): def format(self, record: logging.LogRecord) -> str:
# for i, msg in enumerate(record.msg): # for i, msg in enumerate(record.msg):
if record.msg.startswith("Compilation warnings/errors on "): if record.msg.startswith("Compilation warnings/errors on "):
txt = record.args[1] txt = record.args[1] # type:ignore
txt = txt.split("\n") txt = txt.split("\n") # type:ignore
txt = [red(x) if "Error" in x else x for x in txt] txt = [red(x) if "Error" in x else x for x in txt]
txt = "\n".join(txt) txt = "\n".join(txt)
record.args = (record.args[0], txt) record.args = (record.args[0], txt) # type:ignore
return super().format(record) return super().format(record)
@ -747,7 +727,7 @@ def main_impl(
set_colorization_enabled(False if args.disable_color else sys.stdout.isatty()) set_colorization_enabled(False if args.disable_color else sys.stdout.isatty())
# Define some variables for potential JSON output # Define some variables for potential JSON output
json_results = {} json_results: Dict[str, Any] = {}
output_error = None output_error = None
outputting_json = args.json is not None outputting_json = args.json is not None
outputting_json_stdout = args.json == "-" outputting_json_stdout = args.json == "-"
@ -796,7 +776,7 @@ def main_impl(
crytic_compile_error.setLevel(logging.INFO) crytic_compile_error.setLevel(logging.INFO)
results_detectors: List[Dict] = [] results_detectors: List[Dict] = []
results_printers: List[Dict] = [] results_printers: List[Output] = []
try: try:
filename = args.filename filename = args.filename
@ -809,26 +789,17 @@ def main_impl(
number_contracts = 0 number_contracts = 0
slither_instances = [] slither_instances = []
if args.splitted: for filename in filenames:
( (
slither_instance, slither_instance,
results_detectors, results_detectors_tmp,
results_printers, results_printers_tmp,
number_contracts, number_contracts_tmp,
) = process_from_asts(filenames, args, detector_classes, printer_classes) ) = process_single(filename, args, detector_classes, printer_classes)
number_contracts += number_contracts_tmp
results_detectors += results_detectors_tmp
results_printers += results_printers_tmp
slither_instances.append(slither_instance) slither_instances.append(slither_instance)
else:
for filename in filenames:
(
slither_instance,
results_detectors_tmp,
results_printers_tmp,
number_contracts_tmp,
) = process_single(filename, args, detector_classes, printer_classes)
number_contracts += number_contracts_tmp
results_detectors += results_detectors_tmp
results_printers += results_printers_tmp
slither_instances.append(slither_instance)
# Rely on CryticCompile to discern the underlying type of compilations. # Rely on CryticCompile to discern the underlying type of compilations.
else: else:

@ -4,6 +4,7 @@
from collections import defaultdict from collections import defaultdict
from typing import Union, Set, Dict, TYPE_CHECKING from typing import Union, Set, Dict, TYPE_CHECKING
from slither.core.cfg.node import Node
from slither.core.declarations import ( from slither.core.declarations import (
Contract, Contract,
Enum, Enum,
@ -12,6 +13,7 @@ from slither.core.declarations import (
SolidityVariable, SolidityVariable,
SolidityVariableComposed, SolidityVariableComposed,
Structure, Structure,
FunctionContract,
) )
from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder
from slither.core.variables.top_level_variable import TopLevelVariable from slither.core.variables.top_level_variable import TopLevelVariable
@ -40,25 +42,37 @@ if TYPE_CHECKING:
Variable_types = Union[Variable, SolidityVariable] Variable_types = Union[Variable, SolidityVariable]
# TODO refactor the data deps to be better suited for top level function object
# Right now we allow to pass a node to ease the API, but we need something
# better
# The deps propagation for top level elements is also not working as expected
Context_types_API = Union[Contract, Function, Node]
Context_types = Union[Contract, Function] Context_types = Union[Contract, Function]
def is_dependent( def is_dependent(
variable: Variable_types, variable: Variable_types,
source: Variable_types, source: Variable_types,
context: Context_types, context: Context_types_API,
only_unprotected: bool = False, only_unprotected: bool = False,
) -> bool: ) -> bool:
""" """
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args: Args:
variable (Variable) variable (Variable)
source (Variable) source (Variable)
context (Contract|Function) context (Contract|Function|Node).
only_unprotected (bool): True only unprotected function are considered only_unprotected (bool): True only unprotected function are considered
Returns: Returns:
bool bool
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
if isinstance(variable, Constant): if isinstance(variable, Constant):
return False return False
if variable == source: if variable == source:
@ -76,10 +90,13 @@ def is_dependent(
def is_dependent_ssa( def is_dependent_ssa(
variable: Variable_types, variable: Variable_types,
source: Variable_types, source: Variable_types,
context: Context_types, context: Context_types_API,
only_unprotected: bool = False, only_unprotected: bool = False,
) -> bool: ) -> bool:
""" """
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args: Args:
variable (Variable) variable (Variable)
taint (Variable) taint (Variable)
@ -88,7 +105,10 @@ def is_dependent_ssa(
Returns: Returns:
bool bool
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
context_dict = context.context context_dict = context.context
if isinstance(variable, Constant): if isinstance(variable, Constant):
return False return False
@ -112,11 +132,14 @@ GENERIC_TAINT = {
def is_tainted( def is_tainted(
variable: Variable_types, variable: Variable_types,
context: Context_types, context: Context_types_API,
only_unprotected: bool = False, only_unprotected: bool = False,
ignore_generic_taint: bool = False, ignore_generic_taint: bool = False,
) -> bool: ) -> bool:
""" """
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args: Args:
variable variable
context (Contract|Function) context (Contract|Function)
@ -124,7 +147,10 @@ def is_tainted(
Returns: Returns:
bool bool
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool) assert isinstance(only_unprotected, bool)
if isinstance(variable, Constant): if isinstance(variable, Constant):
return False return False
@ -139,11 +165,14 @@ def is_tainted(
def is_tainted_ssa( def is_tainted_ssa(
variable: Variable_types, variable: Variable_types,
context: Context_types, context: Context_types_API,
only_unprotected: bool = False, only_unprotected: bool = False,
ignore_generic_taint: bool = False, ignore_generic_taint: bool = False,
): ) -> bool:
""" """
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args: Args:
variable variable
context (Contract|Function) context (Contract|Function)
@ -151,7 +180,10 @@ def is_tainted_ssa(
Returns: Returns:
bool bool
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool) assert isinstance(only_unprotected, bool)
if isinstance(variable, Constant): if isinstance(variable, Constant):
return False return False
@ -166,18 +198,23 @@ def is_tainted_ssa(
def get_dependencies( def get_dependencies(
variable: Variable_types, variable: Variable_types,
context: Context_types, context: Context_types_API,
only_unprotected: bool = False, only_unprotected: bool = False,
) -> Set[Variable]: ) -> Set[Variable]:
""" """
Return the variables for which `variable` depends on. Return the variables for which `variable` depends on.
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param variable: The target :param variable: The target
:param context: Either a function (interprocedural) or a contract (inter transactional) :param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions :param only_unprotected: True if consider only protected functions
:return: set(Variable) :return: set(Variable)
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool) assert isinstance(only_unprotected, bool)
if only_unprotected: if only_unprotected:
return context.context[KEY_NON_SSA_UNPROTECTED].get(variable, set()) return context.context[KEY_NON_SSA_UNPROTECTED].get(variable, set())
@ -185,16 +222,21 @@ def get_dependencies(
def get_all_dependencies( def get_all_dependencies(
context: Context_types, only_unprotected: bool = False context: Context_types_API, only_unprotected: bool = False
) -> Dict[Variable, Set[Variable]]: ) -> Dict[Variable, Set[Variable]]:
""" """
Return the dictionary of dependencies. Return the dictionary of dependencies.
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param context: Either a function (interprocedural) or a contract (inter transactional) :param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions :param only_unprotected: True if consider only protected functions
:return: Dict(Variable, set(Variable)) :return: Dict(Variable, set(Variable))
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool) assert isinstance(only_unprotected, bool)
if only_unprotected: if only_unprotected:
return context.context[KEY_NON_SSA_UNPROTECTED] return context.context[KEY_NON_SSA_UNPROTECTED]
@ -203,18 +245,23 @@ def get_all_dependencies(
def get_dependencies_ssa( def get_dependencies_ssa(
variable: Variable_types, variable: Variable_types,
context: Context_types, context: Context_types_API,
only_unprotected: bool = False, only_unprotected: bool = False,
) -> Set[Variable]: ) -> Set[Variable]:
""" """
Return the variables for which `variable` depends on (SSA version). Return the variables for which `variable` depends on (SSA version).
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param variable: The target (must be SSA variable) :param variable: The target (must be SSA variable)
:param context: Either a function (interprocedural) or a contract (inter transactional) :param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions :param only_unprotected: True if consider only protected functions
:return: set(Variable) :return: set(Variable)
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool) assert isinstance(only_unprotected, bool)
if only_unprotected: if only_unprotected:
return context.context[KEY_SSA_UNPROTECTED].get(variable, set()) return context.context[KEY_SSA_UNPROTECTED].get(variable, set())
@ -222,16 +269,21 @@ def get_dependencies_ssa(
def get_all_dependencies_ssa( def get_all_dependencies_ssa(
context: Context_types, only_unprotected: bool = False context: Context_types_API, only_unprotected: bool = False
) -> Dict[Variable, Set[Variable]]: ) -> Dict[Variable, Set[Variable]]:
""" """
Return the dictionary of dependencies. Return the dictionary of dependencies.
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param context: Either a function (interprocedural) or a contract (inter transactional) :param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions :param only_unprotected: True if consider only protected functions
:return: Dict(Variable, set(Variable)) :return: Dict(Variable, set(Variable))
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool) assert isinstance(only_unprotected, bool)
if only_unprotected: if only_unprotected:
return context.context[KEY_SSA_UNPROTECTED] return context.context[KEY_SSA_UNPROTECTED]

@ -6,7 +6,7 @@ from typing import Optional, List, Set, Dict, Tuple, Union, TYPE_CHECKING
from slither.all_exceptions import SlitherException from slither.all_exceptions import SlitherException
from slither.core.children.child_function import ChildFunction from slither.core.children.child_function import ChildFunction
from slither.core.declarations import Contract, Function from slither.core.declarations import Contract, Function, FunctionContract
from slither.core.declarations.solidity_variables import ( from slither.core.declarations.solidity_variables import (
SolidityVariable, SolidityVariable,
SolidityFunction, SolidityFunction,
@ -33,6 +33,7 @@ from slither.slithir.operations import (
Return, Return,
Operation, Operation,
) )
from slither.slithir.utils.utils import RVALUE
from slither.slithir.variables import ( from slither.slithir.variables import (
Constant, Constant,
LocalIRVariable, LocalIRVariable,
@ -146,12 +147,12 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._node_id: int = node_id self._node_id: int = node_id
self._vars_written: List[Variable] = [] self._vars_written: List[Variable] = []
self._vars_read: List[Variable] = [] self._vars_read: List[Union[Variable, SolidityVariable]] = []
self._ssa_vars_written: List["SlithIRVariable"] = [] self._ssa_vars_written: List["SlithIRVariable"] = []
self._ssa_vars_read: List["SlithIRVariable"] = [] self._ssa_vars_read: List["SlithIRVariable"] = []
self._internal_calls: List["Function"] = [] self._internal_calls: List[Union["Function", "SolidityFunction"]] = []
self._solidity_calls: List[SolidityFunction] = [] self._solidity_calls: List[SolidityFunction] = []
self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls
self._library_calls: List["LibraryCallType"] = [] self._library_calls: List["LibraryCallType"] = []
@ -172,7 +173,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._local_vars_read: List[LocalVariable] = [] self._local_vars_read: List[LocalVariable] = []
self._local_vars_written: List[LocalVariable] = [] self._local_vars_written: List[LocalVariable] = []
self._slithir_vars: Set["SlithIRVariable"] = set() # non SSA self._slithir_vars: Set[
Union["SlithIRVariable", ReferenceVariable, TemporaryVariable, TupleVariable]
] = set() # non SSA
self._ssa_local_vars_read: List[LocalIRVariable] = [] self._ssa_local_vars_read: List[LocalIRVariable] = []
self._ssa_local_vars_written: List[LocalIRVariable] = [] self._ssa_local_vars_written: List[LocalIRVariable] = []
@ -213,7 +216,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._node_type return self._node_type
@type.setter @type.setter
def type(self, new_type: NodeType): def type(self, new_type: NodeType) -> None:
self._node_type = new_type self._node_type = new_type
@property @property
@ -232,7 +235,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
################################################################################### ###################################################################################
@property @property
def variables_read(self) -> List[Variable]: def variables_read(self) -> List[Union[Variable, SolidityVariable]]:
""" """
list(Variable): Variables read (local/state/solidity) list(Variable): Variables read (local/state/solidity)
""" """
@ -285,11 +288,13 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._expression_vars_read return self._expression_vars_read
@variables_read_as_expression.setter @variables_read_as_expression.setter
def variables_read_as_expression(self, exprs: List[Expression]): def variables_read_as_expression(self, exprs: List[Expression]) -> None:
self._expression_vars_read = exprs self._expression_vars_read = exprs
@property @property
def slithir_variables(self) -> List["SlithIRVariable"]: def slithir_variables(
self,
) -> List[Union["SlithIRVariable", ReferenceVariable, TemporaryVariable, TupleVariable]]:
return list(self._slithir_vars) return list(self._slithir_vars)
@property @property
@ -339,7 +344,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._expression_vars_written return self._expression_vars_written
@variables_written_as_expression.setter @variables_written_as_expression.setter
def variables_written_as_expression(self, exprs: List[Expression]): def variables_written_as_expression(self, exprs: List[Expression]) -> None:
self._expression_vars_written = exprs self._expression_vars_written = exprs
# endregion # endregion
@ -399,7 +404,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._external_calls_as_expressions return self._external_calls_as_expressions
@external_calls_as_expressions.setter @external_calls_as_expressions.setter
def external_calls_as_expressions(self, exprs: List[Expression]): def external_calls_as_expressions(self, exprs: List[Expression]) -> None:
self._external_calls_as_expressions = exprs self._external_calls_as_expressions = exprs
@property @property
@ -410,7 +415,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._internal_calls_as_expressions return self._internal_calls_as_expressions
@internal_calls_as_expressions.setter @internal_calls_as_expressions.setter
def internal_calls_as_expressions(self, exprs: List[Expression]): def internal_calls_as_expressions(self, exprs: List[Expression]) -> None:
self._internal_calls_as_expressions = exprs self._internal_calls_as_expressions = exprs
@property @property
@ -418,10 +423,10 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return list(self._expression_calls) return list(self._expression_calls)
@calls_as_expression.setter @calls_as_expression.setter
def calls_as_expression(self, exprs: List[Expression]): def calls_as_expression(self, exprs: List[Expression]) -> None:
self._expression_calls = exprs self._expression_calls = exprs
def can_reenter(self, callstack=None) -> bool: def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
""" """
Check if the node can re-enter Check if the node can re-enter
Do not consider CREATE as potential re-enter, but check if the Do not consider CREATE as potential re-enter, but check if the
@ -567,7 +572,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
""" """
self._fathers.append(father) self._fathers.append(father)
def set_fathers(self, fathers: List["Node"]): def set_fathers(self, fathers: List["Node"]) -> None:
"""Set the father nodes """Set the father nodes
Args: Args:
@ -663,20 +668,20 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._irs_ssa return self._irs_ssa
@irs_ssa.setter @irs_ssa.setter
def irs_ssa(self, irs): def irs_ssa(self, irs: List[Operation]) -> None:
self._irs_ssa = irs self._irs_ssa = irs
def add_ssa_ir(self, ir: Operation) -> None: def add_ssa_ir(self, ir: Operation) -> None:
""" """
Use to place phi operation Use to place phi operation
""" """
ir.set_node(self) ir.set_node(self) # type: ignore
self._irs_ssa.append(ir) self._irs_ssa.append(ir)
def slithir_generation(self) -> None: def slithir_generation(self) -> None:
if self.expression: if self.expression:
expression = self.expression expression = self.expression
self._irs = convert_expression(expression, self) self._irs = convert_expression(expression, self) # type:ignore
self._find_read_write_call() self._find_read_write_call()
@ -713,7 +718,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._dominators return self._dominators
@dominators.setter @dominators.setter
def dominators(self, dom: Set["Node"]): def dominators(self, dom: Set["Node"]) -> None:
self._dominators = dom self._dominators = dom
@property @property
@ -725,7 +730,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._immediate_dominator return self._immediate_dominator
@immediate_dominator.setter @immediate_dominator.setter
def immediate_dominator(self, idom: "Node"): def immediate_dominator(self, idom: "Node") -> None:
self._immediate_dominator = idom self._immediate_dominator = idom
@property @property
@ -737,7 +742,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._dominance_frontier return self._dominance_frontier
@dominance_frontier.setter @dominance_frontier.setter
def dominance_frontier(self, doms: Set["Node"]): def dominance_frontier(self, doms: Set["Node"]) -> None:
""" """
Returns: Returns:
set(Node) set(Node)
@ -789,6 +794,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node") -> None: def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node") -> None:
if variable.name not in self._phi_origins_local_variables: if variable.name not in self._phi_origins_local_variables:
assert variable.name
self._phi_origins_local_variables[variable.name] = (variable, set()) self._phi_origins_local_variables[variable.name] = (variable, set())
(v, nodes) = self._phi_origins_local_variables[variable.name] (v, nodes) = self._phi_origins_local_variables[variable.name]
assert v == variable assert v == variable
@ -827,7 +833,8 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
if isinstance(ir, OperationWithLValue): if isinstance(ir, OperationWithLValue):
var = ir.lvalue var = ir.lvalue
if var and self._is_valid_slithir_var(var): if var and self._is_valid_slithir_var(var):
self._slithir_vars.add(var) # The type is checked by is_valid_slithir_var
self._slithir_vars.add(var) # type: ignore
if not isinstance(ir, (Phi, Index, Member)): if not isinstance(ir, (Phi, Index, Member)):
self._vars_read += [v for v in ir.read if self._is_non_slithir_var(v)] self._vars_read += [v for v in ir.read if self._is_non_slithir_var(v)]
@ -835,8 +842,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
if isinstance(var, ReferenceVariable): if isinstance(var, ReferenceVariable):
self._vars_read.append(var.points_to_origin) self._vars_read.append(var.points_to_origin)
elif isinstance(ir, (Member, Index)): elif isinstance(ir, (Member, Index)):
# TODO investigate types for member variable left
var = ir.variable_left if isinstance(ir, Member) else ir.variable_right var = ir.variable_left if isinstance(ir, Member) else ir.variable_right
if self._is_non_slithir_var(var): if var and self._is_non_slithir_var(var):
self._vars_read.append(var) self._vars_read.append(var)
if isinstance(var, ReferenceVariable): if isinstance(var, ReferenceVariable):
origin = var.points_to_origin origin = var.points_to_origin
@ -860,14 +868,21 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._internal_calls.append(ir.function) self._internal_calls.append(ir.function)
if isinstance(ir, LowLevelCall): if isinstance(ir, LowLevelCall):
assert isinstance(ir.destination, (Variable, SolidityVariable)) assert isinstance(ir.destination, (Variable, SolidityVariable))
self._low_level_calls.append((ir.destination, ir.function_name.value)) self._low_level_calls.append((ir.destination, str(ir.function_name.value)))
elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall): elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall):
# Todo investigate this if condition
# It does seem right to compare against a contract
# This might need a refactoring
if isinstance(ir.destination.type, Contract): if isinstance(ir.destination.type, Contract):
self._high_level_calls.append((ir.destination.type, ir.function)) self._high_level_calls.append((ir.destination.type, ir.function))
elif ir.destination == SolidityVariable("this"): elif ir.destination == SolidityVariable("this"):
self._high_level_calls.append((self.function.contract, ir.function)) func = self.function
# Can't use this in a top level function
assert isinstance(func, FunctionContract)
self._high_level_calls.append((func.contract, ir.function))
else: else:
try: try:
# Todo this part needs more tests and documentation
self._high_level_calls.append((ir.destination.type.type, ir.function)) self._high_level_calls.append((ir.destination.type.type, ir.function))
except AttributeError as error: except AttributeError as error:
# pylint: disable=raise-missing-from # pylint: disable=raise-missing-from
@ -883,7 +898,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._vars_read = list(set(self._vars_read)) self._vars_read = list(set(self._vars_read))
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)]
self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)]
self._solidity_vars_read = [v for v in self._vars_read if isinstance(v, SolidityVariable)] self._solidity_vars_read = [
v_ for v_ in self._vars_read if isinstance(v_, SolidityVariable)
]
self._vars_written = list(set(self._vars_written)) self._vars_written = list(set(self._vars_written))
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)]
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)]
@ -895,12 +912,15 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
@staticmethod @staticmethod
def _convert_ssa(v: Variable) -> Optional[Union[StateVariable, LocalVariable]]: def _convert_ssa(v: Variable) -> Optional[Union[StateVariable, LocalVariable]]:
non_ssa_var: Optional[Union[StateVariable, LocalVariable]]
if isinstance(v, StateIRVariable): if isinstance(v, StateIRVariable):
contract = v.contract contract = v.contract
assert v.name
non_ssa_var = contract.get_state_variable_from_name(v.name) non_ssa_var = contract.get_state_variable_from_name(v.name)
return non_ssa_var return non_ssa_var
assert isinstance(v, LocalIRVariable) assert isinstance(v, LocalIRVariable)
function = v.function function = v.function
assert v.name
non_ssa_var = function.get_local_variable_from_name(v.name) non_ssa_var = function.get_local_variable_from_name(v.name)
return non_ssa_var return non_ssa_var
@ -921,10 +941,11 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._ssa_vars_read.append(origin) self._ssa_vars_read.append(origin)
elif isinstance(ir, (Member, Index)): elif isinstance(ir, (Member, Index)):
if isinstance(ir.variable_right, (StateIRVariable, LocalIRVariable)): variable_right: RVALUE = ir.variable_right
self._ssa_vars_read.append(ir.variable_right) if isinstance(variable_right, (StateIRVariable, LocalIRVariable)):
if isinstance(ir.variable_right, ReferenceVariable): self._ssa_vars_read.append(variable_right)
origin = ir.variable_right.points_to_origin if isinstance(variable_right, ReferenceVariable):
origin = variable_right.points_to_origin
if isinstance(origin, (StateIRVariable, LocalIRVariable)): if isinstance(origin, (StateIRVariable, LocalIRVariable)):
self._ssa_vars_read.append(origin) self._ssa_vars_read.append(origin)
@ -944,20 +965,20 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._ssa_local_vars_read = [v for v in self._ssa_vars_read if isinstance(v, LocalVariable)] self._ssa_local_vars_read = [v for v in self._ssa_vars_read if isinstance(v, LocalVariable)]
self._ssa_vars_written = list(set(self._ssa_vars_written)) self._ssa_vars_written = list(set(self._ssa_vars_written))
self._ssa_state_vars_written = [ self._ssa_state_vars_written = [
v for v in self._ssa_vars_written if isinstance(v, StateVariable) v for v in self._ssa_vars_written if v and isinstance(v, StateIRVariable)
] ]
self._ssa_local_vars_written = [ self._ssa_local_vars_written = [
v for v in self._ssa_vars_written if isinstance(v, LocalVariable) v for v in self._ssa_vars_written if v and isinstance(v, LocalIRVariable)
] ]
vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read] vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read]
vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written] vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written]
self._vars_read += [v for v in vars_read if v not in self._vars_read] self._vars_read += [v_ for v_ in vars_read if v_ and v_ not in self._vars_read]
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)]
self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)]
self._vars_written += [v for v in vars_written if v not in self._vars_written] self._vars_written += [v_ for v_ in vars_written if v_ and v_ not in self._vars_written]
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)]
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)]
@ -974,7 +995,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
additional_info += " " + str(self.expression) additional_info += " " + str(self.expression)
elif self.variable_declaration: elif self.variable_declaration:
additional_info += " " + str(self.variable_declaration) additional_info += " " + str(self.variable_declaration)
txt = self._node_type.value + additional_info txt = str(self._node_type.value) + additional_info
return txt return txt

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
@ -9,11 +9,14 @@ if TYPE_CHECKING:
class ChildContract(SourceMapping): class ChildContract(SourceMapping):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._contract = None # TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._contract: Optional["Contract"] = None
def set_contract(self, contract: "Contract") -> None: def set_contract(self, contract: "Contract") -> None:
self._contract = contract self._contract = contract
@property @property
def contract(self) -> "Contract": def contract(self) -> "Contract":
return self._contract return self._contract # type: ignore

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.declarations import Event from slither.core.declarations import Event
@ -7,11 +7,14 @@ if TYPE_CHECKING:
class ChildEvent: class ChildEvent:
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._event = None # TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._event: Optional["Event"] = None
def set_event(self, event: "Event"): def set_event(self, event: "Event") -> None:
self._event = event self._event = event
@property @property
def event(self) -> "Event": def event(self) -> "Event":
return self._event return self._event # type: ignore

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
@ -8,11 +8,16 @@ if TYPE_CHECKING:
class ChildExpression: class ChildExpression:
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._expression = None # TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._expression: Optional[Union["Expression", "Operation"]] = None
def set_expression(self, expression: Union["Expression", "Operation"]) -> None: def set_expression(self, expression: Union["Expression", "Operation"]) -> None:
# TODO investigate when this can be an operation?
# It was auto generated during an AST or detectors tests
self._expression = expression self._expression = expression
@property @property
def expression(self) -> Union["Expression", "Operation"]: def expression(self) -> Union["Expression", "Operation"]:
return self._expression return self._expression # type: ignore

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.declarations import Function from slither.core.declarations import Function
@ -7,11 +7,14 @@ if TYPE_CHECKING:
class ChildFunction: class ChildFunction:
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._function = None # TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._function: Optional["Function"] = None
def set_function(self, function: "Function") -> None: def set_function(self, function: "Function") -> None:
self._function = function self._function = function
@property @property
def function(self) -> "Function": def function(self) -> "Function":
return self._function return self._function # type: ignore

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.declarations import Contract from slither.core.declarations import Contract
@ -7,11 +7,14 @@ if TYPE_CHECKING:
class ChildInheritance: class ChildInheritance:
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._contract_declarer = None # TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._contract_declarer: Optional["Contract"] = None
def set_contract_declarer(self, contract: "Contract") -> None: def set_contract_declarer(self, contract: "Contract") -> None:
self._contract_declarer = contract self._contract_declarer = contract
@property @property
def contract_declarer(self) -> "Contract": def contract_declarer(self) -> "Contract":
return self._contract_declarer return self._contract_declarer # type: ignore

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
@ -9,14 +9,17 @@ if TYPE_CHECKING:
class ChildNode: class ChildNode:
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._node = None # TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._node: Optional["Node"] = None
def set_node(self, node: "Node") -> None: def set_node(self, node: "Node") -> None:
self._node = node self._node = node
@property @property
def node(self) -> "Node": def node(self) -> "Node":
return self._node return self._node # type:ignore
@property @property
def function(self) -> "Function": def function(self) -> "Function":
@ -24,7 +27,7 @@ class ChildNode:
@property @property
def contract(self) -> "Contract": def contract(self) -> "Contract":
return self.node.function.contract return self.node.function.contract # type: ignore
@property @property
def compilation_unit(self) -> "SlitherCompilationUnit": def compilation_unit(self) -> "SlitherCompilationUnit":

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.declarations import Structure from slither.core.declarations import Structure
@ -7,11 +7,14 @@ if TYPE_CHECKING:
class ChildStructure: class ChildStructure:
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._structure = None # TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._structure: Optional["Structure"] = None
def set_structure(self, structure: "Structure") -> None: def set_structure(self, structure: "Structure") -> None:
self._structure = structure self._structure = structure
@property @property
def structure(self) -> "Structure": def structure(self) -> "Structure":
return self._structure return self._structure # type: ignore

@ -81,7 +81,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
# The only str is "*" # The only str is "*"
self._using_for: Dict[Union[str, Type], List[Type]] = {} self._using_for: Dict[Union[str, Type], List[Type]] = {}
self._using_for_complete: Dict[Union[str, Type], List[Type]] = None self._using_for_complete: Optional[Dict[Union[str, Type], List[Type]]] = None
self._kind: Optional[str] = None self._kind: Optional[str] = None
self._is_interface: bool = False self._is_interface: bool = False
self._is_library: bool = False self._is_library: bool = False
@ -275,7 +275,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive
""" """
def _merge_using_for(uf1, uf2): def _merge_using_for(uf1: Dict, uf2: Dict) -> Dict:
result = {**uf1, **uf2} result = {**uf1, **uf2}
for key, value in result.items(): for key, value in result.items():
if key in uf1 and key in uf2: if key in uf1 and key in uf2:
@ -524,14 +524,14 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
""" """
return list(self._functions.values()) return list(self._functions.values())
def available_functions_as_dict(self) -> Dict[str, "FunctionContract"]: def available_functions_as_dict(self) -> Dict[str, "Function"]:
if self._available_functions_as_dict is None: if self._available_functions_as_dict is None:
self._available_functions_as_dict = { self._available_functions_as_dict = {
f.full_name: f for f in self._functions.values() if not f.is_shadowed f.full_name: f for f in self._functions.values() if not f.is_shadowed
} }
return self._available_functions_as_dict return self._available_functions_as_dict
def add_function(self, func: "FunctionContract"): def add_function(self, func: "FunctionContract") -> None:
self._functions[func.canonical_name] = func self._functions[func.canonical_name] = func
def set_functions(self, functions: Dict[str, "FunctionContract"]) -> None: def set_functions(self, functions: Dict[str, "FunctionContract"]) -> None:
@ -699,7 +699,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
list(Contract): Return the list of contracts derived from self list(Contract): Return the list of contracts derived from self
""" """
candidates = self.compilation_unit.contracts candidates = self.compilation_unit.contracts
return [c for c in candidates if self in c.inheritance] return [c for c in candidates if self in c.inheritance] # type: ignore
# endregion # endregion
################################################################################### ###################################################################################
@ -855,7 +855,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
""" """
return next((e for e in self.enums if e.name == enum_name), None) return next((e for e in self.enums if e.name == enum_name), None)
def get_enum_from_canonical_name(self, enum_name) -> Optional["Enum"]: def get_enum_from_canonical_name(self, enum_name: str) -> Optional["Enum"]:
""" """
Return an enum from a canonical name Return an enum from a canonical name
Args: Args:
@ -956,7 +956,9 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def get_summary(self, include_shadowed=True) -> Tuple[str, List[str], List[str], List, List]: def get_summary(
self, include_shadowed: bool = True
) -> Tuple[str, List[str], List[str], List, List]:
"""Return the function summary """Return the function summary
:param include_shadowed: boolean to indicate if shadowed functions should be included (default True) :param include_shadowed: boolean to indicate if shadowed functions should be included (default True)
@ -1209,7 +1211,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
@property @property
def is_test(self) -> bool: def is_test(self) -> bool:
return is_test_contract(self) or self.is_truffle_migration return is_test_contract(self) or self.is_truffle_migration # type: ignore
# endregion # endregion
################################################################################### ###################################################################################
@ -1219,7 +1221,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
################################################################################### ###################################################################################
def update_read_write_using_ssa(self) -> None: def update_read_write_using_ssa(self) -> None:
for function in self.functions + self.modifiers: for function in self.functions + list(self.modifiers):
function.update_read_write_using_ssa() function.update_read_write_using_ssa()
# endregion # endregion
@ -1254,7 +1256,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_upgradeable return self._is_upgradeable
@is_upgradeable.setter @is_upgradeable.setter
def is_upgradeable(self, upgradeable: bool): def is_upgradeable(self, upgradeable: bool) -> None:
self._is_upgradeable = upgradeable self._is_upgradeable = upgradeable
@property @property
@ -1283,7 +1285,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_upgradeable_proxy return self._is_upgradeable_proxy
@is_upgradeable_proxy.setter @is_upgradeable_proxy.setter
def is_upgradeable_proxy(self, upgradeable_proxy: bool): def is_upgradeable_proxy(self, upgradeable_proxy: bool) -> None:
self._is_upgradeable_proxy = upgradeable_proxy self._is_upgradeable_proxy = upgradeable_proxy
@property @property
@ -1291,7 +1293,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._upgradeable_version return self._upgradeable_version
@upgradeable_version.setter @upgradeable_version.setter
def upgradeable_version(self, version_name: str): def upgradeable_version(self, version_name: str) -> None:
self._upgradeable_version = version_name self._upgradeable_version = version_name
# endregion # endregion
@ -1310,7 +1312,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_incorrectly_parsed return self._is_incorrectly_parsed
@is_incorrectly_constructed.setter @is_incorrectly_constructed.setter
def is_incorrectly_constructed(self, incorrect: bool): def is_incorrectly_constructed(self, incorrect: bool) -> None:
self._is_incorrectly_parsed = incorrect self._is_incorrectly_parsed = incorrect
def add_constructor_variables(self) -> None: def add_constructor_variables(self) -> None:
@ -1322,8 +1324,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
constructor_variable = FunctionContract(self.compilation_unit) constructor_variable = FunctionContract(self.compilation_unit)
constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES) constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES)
constructor_variable.set_contract(self) constructor_variable.set_contract(self) # type: ignore
constructor_variable.set_contract_declarer(self) constructor_variable.set_contract_declarer(self) # type: ignore
constructor_variable.set_visibility("internal") constructor_variable.set_visibility("internal")
# For now, source mapping of the constructor variable is the whole contract # For now, source mapping of the constructor variable is the whole contract
# Could be improved with a targeted source mapping # Could be improved with a targeted source mapping
@ -1354,8 +1356,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
constructor_variable.set_function_type( constructor_variable.set_function_type(
FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES
) )
constructor_variable.set_contract(self) constructor_variable.set_contract(self) # type: ignore
constructor_variable.set_contract_declarer(self) constructor_variable.set_contract_declarer(self) # type: ignore
constructor_variable.set_visibility("internal") constructor_variable.set_visibility("internal")
# For now, source mapping of the constructor variable is the whole contract # For now, source mapping of the constructor variable is the whole contract
# Could be improved with a targeted source mapping # Could be improved with a targeted source mapping
@ -1436,22 +1438,23 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
all_ssa_state_variables_instances[v.canonical_name] = new_var all_ssa_state_variables_instances[v.canonical_name] = new_var
self._initial_state_variables.append(new_var) self._initial_state_variables.append(new_var)
for func in self.functions + self.modifiers: for func in self.functions + list(self.modifiers):
func.generate_slithir_ssa(all_ssa_state_variables_instances) func.generate_slithir_ssa(all_ssa_state_variables_instances)
def fix_phi(self) -> None: def fix_phi(self) -> None:
last_state_variables_instances = {} last_state_variables_instances: Dict[str, List["StateVariable"]] = {}
initial_state_variables_instances = {} initial_state_variables_instances: Dict[str, "StateVariable"] = {}
for v in self._initial_state_variables: for v in self._initial_state_variables:
last_state_variables_instances[v.canonical_name] = [] last_state_variables_instances[v.canonical_name] = []
initial_state_variables_instances[v.canonical_name] = v initial_state_variables_instances[v.canonical_name] = v
for func in self.functions + self.modifiers: for func in self.functions + list(self.modifiers):
result = func.get_last_ssa_state_variables_instances() result = func.get_last_ssa_state_variables_instances()
for variable_name, instances in result.items(): for variable_name, instances in result.items():
# TODO: investigate the next operation
last_state_variables_instances[variable_name] += instances last_state_variables_instances[variable_name] += instances
for func in self.functions + self.modifiers: for func in self.functions + list(self.modifiers):
func.fix_phi(last_state_variables_instances, initial_state_variables_instances) func.fix_phi(last_state_variables_instances, initial_state_variables_instances)
# endregion # endregion
@ -1461,7 +1464,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def __eq__(self, other: SourceMapping) -> bool: def __eq__(self, other: Any) -> bool:
if isinstance(other, str): if isinstance(other, str):
return other == self.name return other == self.name
return NotImplemented return NotImplemented
@ -1475,6 +1478,6 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self.name return self.name
def __hash__(self) -> int: def __hash__(self) -> int:
return self._id return self._id # type:ignore
# endregion # endregion

@ -51,7 +51,7 @@ class CustomError(SourceMapping):
return str(t) return str(t)
@property @property
def solidity_signature(self) -> Optional[str]: def solidity_signature(self) -> str:
""" """
Return a signature following the Solidity Standard Return a signature following the Solidity Standard
Contract and converted into address Contract and converted into address
@ -63,7 +63,7 @@ class CustomError(SourceMapping):
# (set_solidity_sig was not called before find_variable) # (set_solidity_sig was not called before find_variable)
if self._solidity_signature is None: if self._solidity_signature is None:
raise ValueError("Custom Error not yet built") raise ValueError("Custom Error not yet built")
return self._solidity_signature return self._solidity_signature # type: ignore
def set_solidity_sig(self) -> None: def set_solidity_sig(self) -> None:
""" """

@ -1,9 +1,14 @@
from typing import TYPE_CHECKING
from slither.core.children.child_contract import ChildContract from slither.core.children.child_contract import ChildContract
from slither.core.declarations.custom_error import CustomError from slither.core.declarations.custom_error import CustomError
if TYPE_CHECKING:
from slither.core.declarations import Contract
class CustomErrorContract(CustomError, ChildContract): class CustomErrorContract(CustomError, ChildContract):
def is_declared_by(self, contract): def is_declared_by(self, contract: "Contract") -> bool:
""" """
Check if the element is declared by the contract Check if the element is declared by the contract
:param contract: :param contract:

@ -9,6 +9,6 @@ if TYPE_CHECKING:
class CustomErrorTopLevel(CustomError, TopLevel): class CustomErrorTopLevel(CustomError, TopLevel):
def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"): def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None:
super().__init__(compilation_unit) super().__init__(compilation_unit)
self.file_scope: "FileScope" = scope self.file_scope: "FileScope" = scope

@ -47,7 +47,6 @@ if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.scope.scope import FileScope from slither.core.scope.scope import FileScope
from slither.slithir.variables.state_variable import StateIRVariable from slither.slithir.variables.state_variable import StateIRVariable
from slither.core.declarations.function_contract import FunctionContract
LOGGER = logging.getLogger("Function") LOGGER = logging.getLogger("Function")
ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"]) ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"])
@ -298,7 +297,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
def contains_assembly(self, c: bool): def contains_assembly(self, c: bool):
self._contains_assembly = c self._contains_assembly = c
def can_reenter(self, callstack: Optional[List["FunctionContract"]] = None) -> bool: def can_reenter(self, callstack: Optional[List[Union["Function", "Variable"]]] = None) -> bool:
""" """
Check if the function can re-enter Check if the function can re-enter
Follow internal calls. Follow internal calls.
@ -1720,8 +1719,8 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
def fix_phi( def fix_phi(
self, self,
last_state_variables_instances: Dict[str, List["StateIRVariable"]], last_state_variables_instances: Dict[str, List["StateVariable"]],
initial_state_variables_instances: Dict[str, "StateIRVariable"], initial_state_variables_instances: Dict[str, "StateVariable"],
) -> None: ) -> None:
from slither.slithir.operations import InternalCall, PhiCallback from slither.slithir.operations import InternalCall, PhiCallback
from slither.slithir.variables import Constant, StateIRVariable from slither.slithir.variables import Constant, StateIRVariable

@ -82,7 +82,7 @@ SOLIDITY_FUNCTIONS: Dict[str, List[str]] = {
} }
def solidity_function_signature(name): def solidity_function_signature(name: str) -> str:
""" """
Return the function signature (containing the return value) Return the function signature (containing the return value)
It is useful if a solidity function is used as a pointer It is useful if a solidity function is used as a pointer
@ -106,7 +106,7 @@ class SolidityVariable(SourceMapping):
assert name in SOLIDITY_VARIABLES or name.endswith(("_slot", "_offset")) assert name in SOLIDITY_VARIABLES or name.endswith(("_slot", "_offset"))
@property @property
def state_variable(self): def state_variable(self) -> str:
if self._name.endswith("_slot"): if self._name.endswith("_slot"):
return self._name[:-5] return self._name[:-5]
if self._name.endswith("_offset"): if self._name.endswith("_offset"):
@ -125,7 +125,7 @@ class SolidityVariable(SourceMapping):
def __str__(self) -> str: def __str__(self) -> str:
return self._name return self._name
def __eq__(self, other: SourceMapping) -> bool: def __eq__(self, other: Any) -> bool:
return self.__class__ == other.__class__ and self.name == other.name return self.__class__ == other.__class__ and self.name == other.name
def __hash__(self) -> int: def __hash__(self) -> int:
@ -182,13 +182,13 @@ class SolidityFunction(SourceMapping):
return self._return_type return self._return_type
@return_type.setter @return_type.setter
def return_type(self, r: List[Union[TypeInformation, ElementaryType]]): def return_type(self, r: List[Union[TypeInformation, ElementaryType]]) -> None:
self._return_type = r self._return_type = r
def __str__(self) -> str: def __str__(self) -> str:
return self._name return self._name
def __eq__(self, other: "SolidityFunction") -> bool: def __eq__(self, other: Any) -> bool:
return self.__class__ == other.__class__ and self.name == other.name return self.__class__ == other.__class__ and self.name == other.name
def __hash__(self) -> int: def __hash__(self) -> int:
@ -201,7 +201,7 @@ class SolidityCustomRevert(SolidityFunction):
self._custom_error = custom_error self._custom_error = custom_error
self._return_type: List[Union[TypeInformation, ElementaryType]] = [] self._return_type: List[Union[TypeInformation, ElementaryType]] = []
def __eq__(self, other: Union["SolidityCustomRevert", SolidityFunction]) -> bool: def __eq__(self, other: Any) -> bool:
return ( return (
self.__class__ == other.__class__ self.__class__ == other.__class__
and self.name == other.name and self.name == other.name

@ -95,4 +95,5 @@ def compute_dominance_frontier(nodes: List["Node"]) -> None:
runner.dominance_frontier = runner.dominance_frontier.union({node}) runner.dominance_frontier = runner.dominance_frontier.union({node})
while runner != node.immediate_dominator: while runner != node.immediate_dominator:
runner.dominance_frontier = runner.dominance_frontier.union({node}) runner.dominance_frontier = runner.dominance_frontier.union({node})
assert runner.immediate_dominator
runner = runner.immediate_dominator runner = runner.immediate_dominator

@ -91,7 +91,7 @@ class AssignmentOperation(ExpressionTyped):
super().__init__() super().__init__()
left_expression.set_lvalue() left_expression.set_lvalue()
self._expressions = [left_expression, right_expression] self._expressions = [left_expression, right_expression]
self._type: Optional["AssignmentOperationType"] = expression_type self._type: AssignmentOperationType = expression_type
self._expression_return_type: Optional["Type"] = expression_return_type self._expression_return_type: Optional["Type"] = expression_return_type
@property @property

@ -482,8 +482,8 @@ class SlitherCore(Context):
################################################################################### ###################################################################################
@property @property
def crytic_compile(self) -> Optional[CryticCompile]: def crytic_compile(self) -> CryticCompile:
return self._crytic_compile return self._crytic_compile # type: ignore
# endregion # endregion
################################################################################### ###################################################################################

@ -4,11 +4,11 @@ from slither.core.expressions.expression import Expression
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
from slither.visitors.expression.constants_folding import ConstantFolding from slither.visitors.expression.constants_folding import ConstantFolding
from slither.core.expressions.literal import Literal from slither.core.expressions.literal import Literal
from slither.core.solidity_types.elementary_type import ElementaryType
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.expressions.binary_operation import BinaryOperation from slither.core.expressions.binary_operation import BinaryOperation
from slither.core.expressions.identifier import Identifier from slither.core.expressions.identifier import Identifier
from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.solidity_types.function_type import FunctionType from slither.core.solidity_types.function_type import FunctionType
from slither.core.solidity_types.type_alias import TypeAliasTopLevel from slither.core.solidity_types.type_alias import TypeAliasTopLevel
@ -22,7 +22,7 @@ class ArrayType(Type):
assert isinstance(t, Type) assert isinstance(t, Type)
if length: if length:
if isinstance(length, int): if isinstance(length, int):
length = Literal(length, "uint256") length = Literal(length, ElementaryType("uint256"))
assert isinstance(length, Expression) assert isinstance(length, Expression)
super().__init__() super().__init__()
self._type: Type = t self._type: Type = t

@ -1,5 +1,5 @@
import itertools import itertools
from typing import Tuple from typing import Tuple, Optional, Any
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
@ -176,7 +176,7 @@ class ElementaryType(Type):
return self.type return self.type
@property @property
def size(self) -> int: def size(self) -> Optional[int]:
""" """
Return the size in bits Return the size in bits
Return None if the size is not known Return None if the size is not known
@ -219,7 +219,7 @@ class ElementaryType(Type):
def __str__(self) -> str: def __str__(self) -> str:
return self._type return self._type
def __eq__(self, other) -> bool: def __eq__(self, other: Any) -> bool:
if not isinstance(other, ElementaryType): if not isinstance(other, ElementaryType):
return False return False
return self.type == other.type return self.type == other.type

@ -1,4 +1,4 @@
from typing import Union, Tuple, TYPE_CHECKING from typing import Union, Tuple, TYPE_CHECKING, Any
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
@ -38,7 +38,7 @@ class MappingType(Type):
def __str__(self) -> str: def __str__(self) -> str:
return f"mapping({str(self._from)} => {str(self._to)})" return f"mapping({str(self._from)} => {str(self._to)})"
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
if not isinstance(other, MappingType): if not isinstance(other, MappingType):
return False return False
return self.type_from == other.type_from and self.type_to == other.type_to return self.type_from == other.type_from and self.type_to == other.type_to

@ -40,7 +40,7 @@ class TypeAlias(Type):
class TypeAliasTopLevel(TypeAlias, TopLevel): class TypeAliasTopLevel(TypeAlias, TopLevel):
def __init__(self, underlying_type: Type, name: str, scope: "FileScope") -> None: def __init__(self, underlying_type: ElementaryType, name: str, scope: "FileScope") -> None:
super().__init__(underlying_type, name) super().__init__(underlying_type, name)
self.file_scope: "FileScope" = scope self.file_scope: "FileScope" = scope
@ -49,7 +49,7 @@ class TypeAliasTopLevel(TypeAlias, TopLevel):
class TypeAliasContract(TypeAlias, ChildContract): class TypeAliasContract(TypeAlias, ChildContract):
def __init__(self, underlying_type: Type, name: str, contract: "Contract") -> None: def __init__(self, underlying_type: ElementaryType, name: str, contract: "Contract") -> None:
super().__init__(underlying_type, name) super().__init__(underlying_type, name)
self._contract: "Contract" = contract self._contract: "Contract" = contract

@ -1,4 +1,4 @@
from typing import Union, TYPE_CHECKING, Tuple from typing import Union, TYPE_CHECKING, Tuple, Any
from slither.core.solidity_types import ElementaryType from slither.core.solidity_types import ElementaryType
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
@ -40,10 +40,10 @@ class TypeInformation(Type):
def is_dynamic(self) -> bool: def is_dynamic(self) -> bool:
raise NotImplementedError raise NotImplementedError
def __str__(self): def __str__(self) -> str:
return f"type({self.type.name})" return f"type({self.type.name})"
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
if not isinstance(other, TypeInformation): if not isinstance(other, TypeInformation):
return False return False
return self.type == other.type return self.type == other.type

@ -1,6 +1,6 @@
import re import re
from abc import ABCMeta from abc import ABCMeta
from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional, Any
from Crypto.Hash import SHA1 from Crypto.Hash import SHA1
from crytic_compile.utils.naming import Filename from crytic_compile.utils.naming import Filename
@ -102,10 +102,10 @@ class Source:
filename_short: str = self.filename.short if self.filename.short else "" filename_short: str = self.filename.short if self.filename.short else ""
return f"{filename_short}{lines}" return f"{filename_short}{lines}"
def __hash__(self): def __hash__(self) -> int:
return hash(str(self)) return hash(str(self))
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
if not isinstance(other, type(self)): if not isinstance(other, type(self)):
return NotImplemented return NotImplemented
return ( return (

@ -16,5 +16,5 @@ class EventVariable(ChildEvent, Variable):
return self._indexed return self._indexed
@indexed.setter @indexed.setter
def indexed(self, is_indexed: bool): def indexed(self, is_indexed: bool) -> bool:
self._indexed = is_indexed self._indexed = is_indexed

@ -55,7 +55,7 @@ class Variable(SourceMapping):
return self._initialized return self._initialized
@initialized.setter @initialized.setter
def initialized(self, is_init: bool): def initialized(self, is_init: bool) -> None:
self._initialized = is_init self._initialized = is_init
@property @property
@ -73,7 +73,7 @@ class Variable(SourceMapping):
return self._name return self._name
@name.setter @name.setter
def name(self, name): def name(self, name: str) -> None:
self._name = name self._name = name
@property @property
@ -89,7 +89,7 @@ class Variable(SourceMapping):
return self._is_constant return self._is_constant
@is_constant.setter @is_constant.setter
def is_constant(self, is_cst: bool): def is_constant(self, is_cst: bool) -> None:
self._is_constant = is_cst self._is_constant = is_cst
@property @property

@ -59,6 +59,8 @@ ALL_SOLC_VERSIONS_06 = make_solc_versions(6, 0, 12)
ALL_SOLC_VERSIONS_07 = make_solc_versions(7, 0, 6) ALL_SOLC_VERSIONS_07 = make_solc_versions(7, 0, 6)
# No VERSIONS_08 as it is still in dev # No VERSIONS_08 as it is still in dev
DETECTOR_INFO = Union[str, List[Union[str, SupportedOutput]]]
class AbstractDetector(metaclass=abc.ABCMeta): class AbstractDetector(metaclass=abc.ABCMeta):
ARGUMENT = "" # run the detector with slither.py --ARGUMENT ARGUMENT = "" # run the detector with slither.py --ARGUMENT
@ -251,7 +253,7 @@ class AbstractDetector(metaclass=abc.ABCMeta):
def generate_result( def generate_result(
self, self,
info: Union[str, List[Union[str, SupportedOutput]]], info: DETECTOR_INFO,
additional_fields: Optional[Dict] = None, additional_fields: Optional[Dict] = None,
) -> Output: ) -> Output:
output = Output( output = Output(

@ -1,5 +1,9 @@
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations import Binary, BinaryType from slither.slithir.operations import Binary, BinaryType
from slither.slithir.variables import Constant from slither.slithir.variables import Constant
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
@ -49,7 +53,12 @@ The shift statement will right-shift the constant 8 by `a` bits"""
BinaryType.RIGHT_SHIFT, BinaryType.RIGHT_SHIFT,
]: ]:
if isinstance(ir.variable_left, Constant): if isinstance(ir.variable_left, Constant):
info = [f, " contains an incorrect shift operation: ", node, "\n"] info: DETECTOR_INFO = [
f,
" contains an incorrect shift operation: ",
node,
"\n",
]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)

@ -7,6 +7,7 @@ from slither.detectors.abstract_detector import (
AbstractDetector, AbstractDetector,
DetectorClassification, DetectorClassification,
ALL_SOLC_VERSIONS_04, ALL_SOLC_VERSIONS_04,
DETECTOR_INFO,
) )
from slither.formatters.attributes.const_functions import custom_format from slither.formatters.attributes.const_functions import custom_format
from slither.utils.output import Output from slither.utils.output import Output
@ -73,7 +74,10 @@ All the calls to `get` revert, breaking Bob's smart contract execution."""
if f.contains_assembly: if f.contains_assembly:
attr = "view" if f.view else "pure" attr = "view" if f.view else "pure"
info = [f, f" is declared {attr} but contains assembly code\n"] info: DETECTOR_INFO = [
f,
f" is declared {attr} but contains assembly code\n",
]
res = self.generate_result(info, {"contains_assembly": True}) res = self.generate_result(info, {"contains_assembly": True})
results.append(res) results.append(res)

@ -105,7 +105,12 @@ As a result, Bob's usage of the contract is incorrect."""
write to the array unsuccessfully. write to the array unsuccessfully.
""" """
# Define our resulting array. # Define our resulting array.
results = [] results: List[
Union[
Tuple[Node, StateVariable, FunctionContract],
Tuple[Node, LocalVariable, FunctionContract],
]
] = []
# Verify we have functions in our list to check for. # Verify we have functions in our list to check for.
if not array_modifying_funcs: if not array_modifying_funcs:

@ -61,12 +61,12 @@ class ArbitrarySendErc20:
is_dependent( is_dependent(
ir.arguments[0], ir.arguments[0],
SolidityVariableComposed("msg.sender"), SolidityVariableComposed("msg.sender"),
node.function.contract, node,
) )
or is_dependent( or is_dependent(
ir.arguments[0], ir.arguments[0],
SolidityVariable("this"), SolidityVariable("this"),
node.function.contract, node,
) )
) )
): ):
@ -79,12 +79,12 @@ class ArbitrarySendErc20:
is_dependent( is_dependent(
ir.arguments[1], ir.arguments[1],
SolidityVariableComposed("msg.sender"), SolidityVariableComposed("msg.sender"),
node.function.contract, node,
) )
or is_dependent( or is_dependent(
ir.arguments[1], ir.arguments[1],
SolidityVariable("this"), SolidityVariable("this"),
node.function.contract, node,
) )
) )
): ):

@ -1,5 +1,9 @@
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
from .arbitrary_send_erc20 import ArbitrarySendErc20 from .arbitrary_send_erc20 import ArbitrarySendErc20
@ -38,7 +42,7 @@ Use `msg.sender` as `from` in transferFrom.
arbitrary_sends.detect() arbitrary_sends.detect()
for node in arbitrary_sends.no_permit_results: for node in arbitrary_sends.no_permit_results:
func = node.function func = node.function
info = [func, " uses arbitrary from in transferFrom: ", node, "\n"] info: DETECTOR_INFO = [func, " uses arbitrary from in transferFrom: ", node, "\n"]
res = self.generate_result(info) res = self.generate_result(info)
results.append(res) results.append(res)

@ -1,5 +1,9 @@
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
from .arbitrary_send_erc20 import ArbitrarySendErc20 from .arbitrary_send_erc20 import ArbitrarySendErc20
@ -41,7 +45,7 @@ Ensure that the underlying ERC20 token correctly implements a permit function.
arbitrary_sends.detect() arbitrary_sends.detect()
for node in arbitrary_sends.permit_results: for node in arbitrary_sends.permit_results:
func = node.function func = node.function
info = [ info: DETECTOR_INFO = [
func, func,
" uses arbitrary from in transferFrom in combination with permit: ", " uses arbitrary from in transferFrom in combination with permit: ",
node, node,

@ -39,6 +39,10 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]:
ret: List[Node] = [] ret: List[Node] = []
for node in func.nodes: for node in func.nodes:
func = node.function
deps_target: Union[Contract, Function] = (
func.contract if isinstance(func, FunctionContract) else func
)
for ir in node.irs: for ir in node.irs:
if isinstance(ir, SolidityCall): if isinstance(ir, SolidityCall):
if ir.function == SolidityFunction("ecrecover(bytes32,uint8,bytes32,bytes32)"): if ir.function == SolidityFunction("ecrecover(bytes32,uint8,bytes32,bytes32)"):
@ -49,7 +53,7 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]:
if is_dependent( if is_dependent(
ir.variable_right, ir.variable_right,
SolidityVariableComposed("msg.sender"), SolidityVariableComposed("msg.sender"),
func.contract, deps_target,
): ):
return False return False
if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)): if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)):
@ -64,11 +68,11 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]:
if is_dependent( if is_dependent(
ir.call_value, ir.call_value,
SolidityVariableComposed("msg.value"), SolidityVariableComposed("msg.value"),
func.contract, node,
): ):
continue continue
if is_tainted(ir.destination, func.contract): if is_tainted(ir.destination, node):
ret.append(node) ret.append(node)
return ret return ret

@ -7,6 +7,7 @@ from slither.detectors.abstract_detector import (
DetectorClassification, DetectorClassification,
ALL_SOLC_VERSIONS_04, ALL_SOLC_VERSIONS_04,
ALL_SOLC_VERSIONS_05, ALL_SOLC_VERSIONS_05,
DETECTOR_INFO,
) )
from slither.core.cfg.node import Node, NodeType from slither.core.cfg.node import Node, NodeType
from slither.slithir.operations import Assignment, Length from slither.slithir.operations import Assignment, Length
@ -120,7 +121,7 @@ Otherwise, thoroughly review the contract to ensure a user-controlled variable c
for contract in self.contracts: for contract in self.contracts:
array_length_assignments = detect_array_length_assignment(contract) array_length_assignments = detect_array_length_assignment(contract)
if array_length_assignments: if array_length_assignments:
contract_info = [ contract_info: DETECTOR_INFO = [
contract, contract,
" contract sets array length with a user-controlled value:\n", " contract sets array length with a user-controlled value:\n",
] ]

@ -6,7 +6,11 @@ from typing import List, Tuple
from slither.core.cfg.node import Node, NodeType from slither.core.cfg.node import Node, NodeType
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
@ -52,7 +56,7 @@ class Assembly(AbstractDetector):
for c in self.contracts: for c in self.contracts:
values = self.detect_assembly(c) values = self.detect_assembly(c)
for func, nodes in values: for func, nodes in values:
info = [func, " uses assembly\n"] info: DETECTOR_INFO = [func, " uses assembly\n"]
# sort the nodes to get deterministic results # sort the nodes to get deterministic results
nodes.sort(key=lambda x: x.node_id) nodes.sort(key=lambda x: x.node_id)

@ -1,5 +1,7 @@
from typing import Optional, List from typing import Optional, List, Union
from slither.core.declarations import Function
from slither.core.variables import Variable
from slither.slithir.operations.operation import Operation from slither.slithir.operations.operation import Operation
@ -16,7 +18,8 @@ class Call(Operation):
def arguments(self, v): def arguments(self, v):
self._arguments = v self._arguments = v
def can_reenter(self, _callstack: Optional[List] = None) -> bool: # pylint: disable=no-self-use # pylint: disable=no-self-use
def can_reenter(self, _callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
""" """
Must be called after slithIR analysis pass Must be called after slithIR analysis pass
:return: bool :return: bool

@ -1,5 +1,6 @@
from typing import List, Optional, Union from typing import List, Optional, Union
from slither.core.declarations import Contract
from slither.slithir.operations.call import Call from slither.slithir.operations.call import Call
from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.operations.lvalue import OperationWithLValue
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
@ -32,7 +33,8 @@ class HighLevelCall(Call, OperationWithLValue):
assert is_valid_lvalue(result) or result is None assert is_valid_lvalue(result) or result is None
self._check_destination(destination) self._check_destination(destination)
super().__init__() super().__init__()
self._destination = destination # Contract is only possible for library call, which inherits from highlevelcall
self._destination: Union[Variable, SolidityVariable, Contract] = destination # type: ignore
self._function_name = function_name self._function_name = function_name
self._nbr_arguments = nbr_arguments self._nbr_arguments = nbr_arguments
self._type_call = type_call self._type_call = type_call
@ -44,8 +46,9 @@ class HighLevelCall(Call, OperationWithLValue):
self._call_gas = None self._call_gas = None
# Development function, to be removed once the code is stable # Development function, to be removed once the code is stable
# It is ovveride by LbraryCall # It is overridden by LibraryCall
def _check_destination(self, destination: SourceMapping) -> None: # pylint: disable=no-self-use # pylint: disable=no-self-use
def _check_destination(self, destination: Union[Variable, SolidityVariable, Contract]) -> None:
assert isinstance(destination, (Variable, SolidityVariable)) assert isinstance(destination, (Variable, SolidityVariable))
@property @property
@ -79,7 +82,14 @@ class HighLevelCall(Call, OperationWithLValue):
return [x for x in all_read if x] + [self.destination] return [x for x in all_read if x] + [self.destination]
@property @property
def destination(self) -> SourceMapping: def destination(self) -> Union[Variable, SolidityVariable, Contract]:
"""
Return a variable or a solidityVariable
Contract is only possible for LibraryCall
Returns:
"""
return self._destination return self._destination
@property @property
@ -116,7 +126,7 @@ class HighLevelCall(Call, OperationWithLValue):
return True return True
return False return False
def can_reenter(self, callstack: None = None) -> bool: def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
""" """
Must be called after slithIR analysis pass Must be called after slithIR analysis pass
For Solidity > 0.5, filter access to public variables and constant/pure/view For Solidity > 0.5, filter access to public variables and constant/pure/view

@ -1,20 +1,20 @@
from typing import List, Union from typing import List, Union
from slither.core.declarations import SolidityVariableComposed from slither.core.declarations import SolidityVariableComposed
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue
from slither.slithir.variables.reference import ReferenceVariable
from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.slithir.variables.reference_ssa import ReferenceVariableSSA from slither.slithir.operations.lvalue import OperationWithLValue
from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue, RVALUE, LVALUE
from slither.slithir.variables.reference import ReferenceVariable
class Index(OperationWithLValue): class Index(OperationWithLValue):
def __init__( def __init__(
self, self,
result: Union[ReferenceVariable, ReferenceVariableSSA], result: ReferenceVariable,
left_variable: Variable, left_variable: Variable,
right_variable: SourceMapping, right_variable: RVALUE,
index_type: Union[ElementaryType, str], index_type: Union[ElementaryType, str],
) -> None: ) -> None:
super().__init__() super().__init__()
@ -25,23 +25,23 @@ class Index(OperationWithLValue):
assert isinstance(result, ReferenceVariable) assert isinstance(result, ReferenceVariable)
self._variables = [left_variable, right_variable] self._variables = [left_variable, right_variable]
self._type = index_type self._type = index_type
self._lvalue = result self._lvalue: ReferenceVariable = result
@property @property
def read(self) -> List[SourceMapping]: def read(self) -> List[SourceMapping]:
return list(self.variables) return list(self.variables)
@property @property
def variables(self) -> List[SourceMapping]: def variables(self) -> List[Union[LVALUE, RVALUE, SolidityVariableComposed]]:
return self._variables return self._variables # type: ignore
@property @property
def variable_left(self) -> Variable: def variable_left(self) -> Union[LVALUE, SolidityVariableComposed]:
return self._variables[0] return self._variables[0] # type: ignore
@property @property
def variable_right(self) -> SourceMapping: def variable_right(self) -> RVALUE:
return self._variables[1] return self._variables[1] # type: ignore
@property @property
def index_type(self) -> Union[ElementaryType, str]: def index_type(self) -> Union[ElementaryType, str]:

@ -1,4 +1,7 @@
from slither.core.declarations import Function from typing import Union, Optional, List
from slither.core.declarations import Function, SolidityVariable
from slither.core.variables import Variable
from slither.slithir.operations.high_level_call import HighLevelCall from slither.slithir.operations.high_level_call import HighLevelCall
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
@ -9,10 +12,10 @@ class LibraryCall(HighLevelCall):
""" """
# Development function, to be removed once the code is stable # Development function, to be removed once the code is stable
def _check_destination(self, destination: Contract) -> None: def _check_destination(self, destination: Union[Variable, SolidityVariable, Contract]) -> None:
assert isinstance(destination, Contract) assert isinstance(destination, Contract)
def can_reenter(self, callstack: None = None) -> bool: def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
""" """
Must be called after slithIR analysis pass Must be called after slithIR analysis pass
:return: bool :return: bool
@ -20,11 +23,11 @@ class LibraryCall(HighLevelCall):
if self.is_static_call(): if self.is_static_call():
return False return False
# In case of recursion, return False # In case of recursion, return False
callstack = [] if callstack is None else callstack callstack_local = [] if callstack is None else callstack
if self.function in callstack: if self.function in callstack_local:
return False return False
callstack = callstack + [self.function] callstack_local = callstack_local + [self.function]
return self.function.can_reenter(callstack) return self.function.can_reenter(callstack_local)
def __str__(self): def __str__(self):
gas = "" gas = ""

@ -1,4 +1,6 @@
from typing import List, Union from typing import List, Union, Optional
from slither.core.declarations import Function
from slither.slithir.operations.call import Call from slither.slithir.operations.call import Call
from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.operations.lvalue import OperationWithLValue
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
@ -74,7 +76,7 @@ class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-insta
# remove None # remove None
return self._unroll([x for x in all_read if x]) return self._unroll([x for x in all_read if x])
def can_reenter(self, _callstack: None = None) -> bool: def can_reenter(self, _callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
""" """
Must be called after slithIR analysis pass Must be called after slithIR analysis pass
:return: bool :return: bool

@ -1,4 +1,6 @@
from typing import Any, List from typing import Any, List, Optional
from slither.core.variables import Variable
from slither.slithir.operations.operation import Operation from slither.slithir.operations.operation import Operation
@ -10,16 +12,16 @@ class OperationWithLValue(Operation):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._lvalue = None self._lvalue: Optional[Variable] = None
@property @property
def lvalue(self): def lvalue(self) -> Optional[Variable]:
return self._lvalue return self._lvalue
@property
def used(self) -> List[Any]:
return self.read + [self.lvalue]
@lvalue.setter @lvalue.setter
def lvalue(self, lvalue): def lvalue(self, lvalue: Variable) -> None:
self._lvalue = lvalue self._lvalue = lvalue
@property
def used(self) -> List[Optional[Any]]:
return self.read + [self.lvalue]

@ -5,7 +5,7 @@ from slither.core.declarations.enum import Enum
from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder
from slither.core.solidity_types import ElementaryType from slither.core.solidity_types import ElementaryType
from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.operations.lvalue import OperationWithLValue
from slither.slithir.utils.utils import is_valid_rvalue from slither.slithir.utils.utils import is_valid_rvalue, RVALUE
from slither.slithir.variables.constant import Constant from slither.slithir.variables.constant import Constant
from slither.slithir.variables.reference import ReferenceVariable from slither.slithir.variables.reference import ReferenceVariable
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
@ -39,7 +39,9 @@ class Member(OperationWithLValue):
assert isinstance(variable_right, Constant) assert isinstance(variable_right, Constant)
assert isinstance(result, ReferenceVariable) assert isinstance(result, ReferenceVariable)
super().__init__() super().__init__()
self._variable_left = variable_left self._variable_left: Union[
RVALUE, Contract, Enum, Function, CustomError, SolidityImportPlaceHolder, ElementaryType
] = variable_left
self._variable_right = variable_right self._variable_right = variable_right
self._lvalue = result self._lvalue = result
self._gas = None self._gas = None
@ -50,7 +52,11 @@ class Member(OperationWithLValue):
return [self.variable_left, self.variable_right] return [self.variable_left, self.variable_right]
@property @property
def variable_left(self) -> SourceMapping: def variable_left(
self,
) -> Union[
RVALUE, Contract, Enum, Function, CustomError, SolidityImportPlaceHolder, ElementaryType
]:
return self._variable_left return self._variable_left
@property @property

@ -1,11 +1,13 @@
from typing import Optional, Any, List, Union from typing import Optional, Any, List, Union
from slither.core.declarations import Function
from slither.core.declarations.contract import Contract
from slither.core.variables import Variable
from slither.slithir.operations import Call, OperationWithLValue from slither.slithir.operations import Call, OperationWithLValue
from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.utils.utils import is_valid_lvalue
from slither.slithir.variables.constant import Constant from slither.slithir.variables.constant import Constant
from slither.core.declarations.contract import Contract
from slither.slithir.variables.temporary import TemporaryVariable from slither.slithir.variables.temporary import TemporaryVariable
from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA
from slither.core.declarations.function_contract import FunctionContract
class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes
@ -58,6 +60,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan
def contract_created(self) -> Contract: def contract_created(self) -> Contract:
contract_name = self.contract_name contract_name = self.contract_name
contract_instance = self.node.file_scope.get_contract_from_name(contract_name) contract_instance = self.node.file_scope.get_contract_from_name(contract_name)
assert contract_instance
return contract_instance return contract_instance
################################################################################### ###################################################################################
@ -66,7 +69,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def can_reenter(self, callstack: Optional[List[FunctionContract]] = None) -> bool: def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
""" """
Must be called after slithIR analysis pass Must be called after slithIR analysis pass
For Solidity > 0.5, filter access to public variables and constant/pure/view For Solidity > 0.5, filter access to public variables and constant/pure/view
@ -92,7 +95,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan
# endregion # endregion
def __str__(self): def __str__(self) -> str:
options = "" options = ""
if self.call_value: if self.call_value:
options = f"value:{self.call_value} " options = f"value:{self.call_value} "

@ -1,15 +1,16 @@
from typing import Any, List, Union from typing import Any, List, Union
from slither.core.declarations.solidity_variables import SolidityCustomRevert, SolidityFunction
from slither.slithir.operations.call import Call
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.core.children.child_node import ChildNode from slither.core.children.child_node import ChildNode
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.elementary_type import ElementaryType
from slither.slithir.operations.call import Call
from slither.slithir.operations.lvalue import OperationWithLValue
class SolidityCall(Call, OperationWithLValue): class SolidityCall(Call, OperationWithLValue):
def __init__( def __init__(
self, self,
function: Union[SolidityCustomRevert, SolidityFunction], function: SolidityFunction,
nbr_arguments: int, nbr_arguments: int,
result: ChildNode, result: ChildNode,
type_call: Union[str, List[ElementaryType]], type_call: Union[str, List[ElementaryType]],
@ -26,7 +27,7 @@ class SolidityCall(Call, OperationWithLValue):
return self._unroll(self.arguments) return self._unroll(self.arguments)
@property @property
def function(self) -> Union[SolidityCustomRevert, SolidityFunction]: def function(self) -> SolidityFunction:
return self._function return self._function
@property @property

@ -1,3 +1,5 @@
from typing import Union
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
@ -10,6 +12,24 @@ from slither.slithir.variables.reference import ReferenceVariable
from slither.slithir.variables.tuple import TupleVariable from slither.slithir.variables.tuple import TupleVariable
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
RVALUE = Union[
StateVariable,
LocalVariable,
TopLevelVariable,
TemporaryVariable,
Constant,
SolidityVariable,
ReferenceVariable,
]
LVALUE = Union[
StateVariable,
LocalVariable,
TemporaryVariable,
ReferenceVariable,
TupleVariable,
]
def is_valid_rvalue(v: SourceMapping) -> bool: def is_valid_rvalue(v: SourceMapping) -> bool:
return isinstance( return isinstance(

@ -79,9 +79,10 @@ def main() -> None:
print(args.codebase) print(args.codebase)
sl = Slither(args.codebase, **vars(args)) sl = Slither(args.codebase, **vars(args))
for M in _get_mutators(): for compilation_unit in sl.compilation_units:
m = M(sl) for M in _get_mutators():
m.mutate() m = M(compilation_unit)
m.mutate()
# endregion # endregion

@ -3,7 +3,7 @@ import logging
from enum import Enum from enum import Enum
from typing import Optional, Dict from typing import Optional, Dict
from slither import Slither from slither.core.compilation_unit import SlitherCompilationUnit
from slither.formatters.utils.patches import apply_patch, create_diff from slither.formatters.utils.patches import apply_patch, create_diff
logger = logging.getLogger("Slither") logger = logging.getLogger("Slither")
@ -34,8 +34,11 @@ class AbstractMutator(metaclass=abc.ABCMeta): # pylint: disable=too-few-public-
FAULTCLASS = FaultClass.Undefined FAULTCLASS = FaultClass.Undefined
FAULTNATURE = FaultNature.Undefined FAULTNATURE = FaultNature.Undefined
def __init__(self, slither: Slither, rate: int = 10, seed: Optional[int] = None): def __init__(
self.slither = slither self, compilation_unit: SlitherCompilationUnit, rate: int = 10, seed: Optional[int] = None
):
self.compilation_unit = compilation_unit
self.slither = compilation_unit.core
self.seed = seed self.seed = seed
self.rate = rate self.rate = rate
@ -87,7 +90,7 @@ class AbstractMutator(metaclass=abc.ABCMeta): # pylint: disable=too-few-public-
continue continue
for patch in patches: for patch in patches:
patched_txt, offset = apply_patch(patched_txt, patch, offset) patched_txt, offset = apply_patch(patched_txt, patch, offset)
diff = create_diff(self.slither, original_txt, patched_txt, file) diff = create_diff(self.compilation_unit, original_txt, patched_txt, file)
if not diff: if not diff:
logger.info(f"Impossible to generate patch; empty {patches}") logger.info(f"Impossible to generate patch; empty {patches}")
print(diff) print(diff)

Loading…
Cancel
Save