Merge pull request #1666 from crytic/dev-more-types

Add more types hints
pull/1346/merge
Feist Josselin 2 years ago committed by GitHub
commit 346d3b63f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 6
      examples/scripts/data_dependency.py
  2. 1
      examples/scripts/variable_in_condition.py
  3. 67
      slither/__main__.py
  4. 146
      slither/analyses/data_dependency/data_dependency.py
  5. 29
      slither/analyses/write/are_variables_written.py
  6. 100
      slither/core/cfg/node.py
  7. 19
      slither/core/children/child_contract.py
  8. 17
      slither/core/children/child_event.py
  9. 18
      slither/core/children/child_expression.py
  10. 17
      slither/core/children/child_function.py
  11. 17
      slither/core/children/child_inheritance.py
  12. 31
      slither/core/children/child_node.py
  13. 17
      slither/core/children/child_structure.py
  14. 17
      slither/core/compilation_unit.py
  15. 76
      slither/core/declarations/contract.py
  16. 29
      slither/core/declarations/contract_level.py
  17. 10
      slither/core/declarations/custom_error.py
  18. 12
      slither/core/declarations/custom_error_contract.py
  19. 2
      slither/core/declarations/custom_error_top_level.py
  20. 4
      slither/core/declarations/enum_contract.py
  21. 4
      slither/core/declarations/event.py
  22. 7
      slither/core/declarations/function.py
  23. 29
      slither/core/declarations/function_contract.py
  24. 12
      slither/core/declarations/solidity_variables.py
  25. 4
      slither/core/declarations/structure_contract.py
  26. 6
      slither/core/declarations/top_level.py
  27. 3
      slither/core/declarations/using_for_top_level.py
  28. 1
      slither/core/dominators/utils.py
  29. 5
      slither/core/expressions/assignment_operation.py
  30. 3
      slither/core/expressions/binary_operation.py
  31. 8
      slither/core/expressions/call_expression.py
  32. 2
      slither/core/expressions/conditional_expression.py
  33. 20
      slither/core/expressions/expression_typed.py
  34. 76
      slither/core/expressions/identifier.py
  35. 19
      slither/core/expressions/index_access.py
  36. 4
      slither/core/expressions/literal.py
  37. 3
      slither/core/expressions/member_access.py
  38. 11
      slither/core/expressions/type_conversion.py
  39. 3
      slither/core/expressions/unary_operation.py
  40. 10
      slither/core/slither_core.py
  41. 15
      slither/core/solidity_types/array_type.py
  42. 6
      slither/core/solidity_types/elementary_type.py
  43. 4
      slither/core/solidity_types/mapping_type.py
  44. 8
      slither/core/solidity_types/type_alias.py
  45. 6
      slither/core/solidity_types/type_information.py
  46. 6
      slither/core/source_mapping/source_mapping.py
  47. 5
      slither/core/variables/event_variable.py
  48. 17
      slither/core/variables/local_variable.py
  49. 4
      slither/core/variables/state_variable.py
  50. 19
      slither/core/variables/structure_variable.py
  51. 18
      slither/core/variables/variable.py
  52. 4
      slither/detectors/abstract_detector.py
  53. 13
      slither/detectors/assembly/shift_parameter_mixup.py
  54. 12
      slither/detectors/attributes/const_functions_asm.py
  55. 9
      slither/detectors/attributes/const_functions_state.py
  56. 15
      slither/detectors/attributes/constant_pragma.py
  57. 8
      slither/detectors/attributes/incorrect_solc.py
  58. 8
      slither/detectors/attributes/locked_ether.py
  59. 8
      slither/detectors/attributes/unimplemented_interface.py
  60. 20
      slither/detectors/compiler_bugs/array_by_reference.py
  61. 11
      slither/detectors/compiler_bugs/enum_conversion.py
  62. 11
      slither/detectors/compiler_bugs/multiple_constructor_schemes.py
  63. 3
      slither/detectors/compiler_bugs/reused_base_constructor.py
  64. 9
      slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py
  65. 32
      slither/detectors/compiler_bugs/storage_signed_integer_array.py
  66. 5
      slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py
  67. 8
      slither/detectors/erc/erc20/arbitrary_send_erc20.py
  68. 8
      slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py
  69. 8
      slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py
  70. 8
      slither/detectors/erc/erc20/incorrect_erc20_interface.py
  71. 12
      slither/detectors/erc/incorrect_erc721_interface.py
  72. 8
      slither/detectors/examples/backdoor.py
  73. 15
      slither/detectors/functions/arbitrary_send_eth.py
  74. 8
      slither/detectors/functions/cyclomatic_complexity.py
  75. 8
      slither/detectors/functions/dead_code.py
  76. 12
      slither/detectors/functions/modifier.py
  77. 8
      slither/detectors/functions/permit_domain_signature_collision.py
  78. 8
      slither/detectors/functions/protected_variable.py
  79. 8
      slither/detectors/functions/suicidal.py
  80. 14
      slither/detectors/functions/unimplemented.py
  81. 7
      slither/detectors/naming_convention/naming_convention.py
  82. 9
      slither/detectors/operations/bad_prng.py
  83. 29
      slither/detectors/operations/block_timestamp.py
  84. 8
      slither/detectors/operations/low_level_calls.py
  85. 8
      slither/detectors/operations/missing_events_access_control.py
  86. 8
      slither/detectors/operations/missing_events_arithmetic.py
  87. 8
      slither/detectors/operations/missing_zero_address_validation.py
  88. 8
      slither/detectors/operations/unused_return_values.py
  89. 8
      slither/detectors/operations/void_constructor.py
  90. 8
      slither/detectors/reentrancy/token.py
  91. 8
      slither/detectors/shadowing/builtin_symbols.py
  92. 10
      slither/detectors/shadowing/local.py
  93. 8
      slither/detectors/shadowing/state.py
  94. 13
      slither/detectors/slither/name_reused.py
  95. 8
      slither/detectors/source/rtlo.py
  96. 16
      slither/detectors/statements/array_length_assignment.py
  97. 8
      slither/detectors/statements/assembly.py
  98. 17
      slither/detectors/statements/assert_state_change.py
  99. 8
      slither/detectors/statements/boolean_constant_equality.py
  100. 8
      slither/detectors/statements/boolean_constant_misuse.py
  101. Some files were not shown because too many files have changed in this diff Show More

@ -18,6 +18,8 @@ assert len(contracts) == 1
contract = contracts[0] contract = contracts[0]
destination = contract.get_state_variable_from_name("destination") destination = contract.get_state_variable_from_name("destination")
source = contract.get_state_variable_from_name("source") source = contract.get_state_variable_from_name("source")
assert source
assert destination
print(f"{source} is dependent of {destination}: {is_dependent(source, destination, contract)}") print(f"{source} is dependent of {destination}: {is_dependent(source, destination, contract)}")
assert not is_dependent(source, destination, contract) assert not is_dependent(source, destination, contract)
@ -47,9 +49,11 @@ print(f"{destination} is tainted {is_tainted(destination, contract)}")
assert is_tainted(destination, contract) assert is_tainted(destination, contract)
destination_indirect_1 = contract.get_state_variable_from_name("destination_indirect_1") destination_indirect_1 = contract.get_state_variable_from_name("destination_indirect_1")
assert destination_indirect_1
print(f"{destination_indirect_1} is tainted {is_tainted(destination_indirect_1, contract)}") print(f"{destination_indirect_1} is tainted {is_tainted(destination_indirect_1, contract)}")
assert is_tainted(destination_indirect_1, contract) assert is_tainted(destination_indirect_1, contract)
destination_indirect_2 = contract.get_state_variable_from_name("destination_indirect_2") destination_indirect_2 = contract.get_state_variable_from_name("destination_indirect_2")
assert destination_indirect_2
print(f"{destination_indirect_2} is tainted {is_tainted(destination_indirect_2, contract)}") print(f"{destination_indirect_2} is tainted {is_tainted(destination_indirect_2, contract)}")
assert is_tainted(destination_indirect_2, contract) assert is_tainted(destination_indirect_2, contract)
@ -88,6 +92,8 @@ contract = contracts[0]
contract_derived = slither.get_contract_from_name("Derived")[0] contract_derived = slither.get_contract_from_name("Derived")[0]
destination = contract.get_state_variable_from_name("destination") destination = contract.get_state_variable_from_name("destination")
source = contract.get_state_variable_from_name("source") source = contract.get_state_variable_from_name("source")
assert destination
assert source
print(f"{destination} is dependent of {source}: {is_dependent(destination, source, contract)}") print(f"{destination} is dependent of {source}: {is_dependent(destination, source, contract)}")
assert not is_dependent(destination, source, contract) assert not is_dependent(destination, source, contract)

@ -14,6 +14,7 @@ assert len(contracts) == 1
contract = contracts[0] contract = contracts[0]
# Get the variable # Get the variable
var_a = contract.get_state_variable_from_name("a") var_a = contract.get_state_variable_from_name("a")
assert var_a
# Get the functions reading the variable # Get the functions reading the variable
functions_reading_a = contract.get_functions_reading_from_variable(var_a) functions_reading_a = contract.get_functions_reading_from_variable(var_a)

@ -66,7 +66,7 @@ def process_single(
args: argparse.Namespace, args: argparse.Namespace,
detector_classes: List[Type[AbstractDetector]], detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]], printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]: ) -> Tuple[Slither, List[Dict], List[Output], int]:
""" """
The core high-level code for running Slither static analysis. The core high-level code for running Slither static analysis.
@ -86,7 +86,7 @@ def process_all(
args: argparse.Namespace, args: argparse.Namespace,
detector_classes: List[Type[AbstractDetector]], detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]], printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[List[Slither], List[Dict], List[Dict], int]: ) -> Tuple[List[Slither], List[Dict], List[Output], int]:
compilations = compile_all(target, **vars(args)) compilations = compile_all(target, **vars(args))
slither_instances = [] slither_instances = []
results_detectors = [] results_detectors = []
@ -141,23 +141,6 @@ def _process(
return slither, results_detectors, results_printers, analyzed_contracts_count return slither, results_detectors, results_printers, analyzed_contracts_count
# TODO: delete me?
def process_from_asts(
filenames: List[str],
args: argparse.Namespace,
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]:
all_contracts: List[str] = []
for filename in filenames:
with open(filename, encoding="utf8") as file_open:
contract_loaded = json.load(file_open)
all_contracts.append(contract_loaded["ast"])
return process_single(all_contracts, args, detector_classes, printer_classes)
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -605,9 +588,6 @@ def parse_args(
default=False, default=False,
) )
# if the json is splitted in different files
parser.add_argument("--splitted", help=argparse.SUPPRESS, action="store_true", default=False)
# Disable the throw/catch on partial analyses # Disable the throw/catch on partial analyses
parser.add_argument( parser.add_argument(
"--disallow-partial", help=argparse.SUPPRESS, action="store_true", default=False "--disallow-partial", help=argparse.SUPPRESS, action="store_true", default=False
@ -623,7 +603,7 @@ def parse_args(
args.filter_paths = parse_filter_paths(args) args.filter_paths = parse_filter_paths(args)
# Verify our json-type output is valid # Verify our json-type output is valid
args.json_types = set(args.json_types.split(",")) args.json_types = set(args.json_types.split(",")) # type:ignore
for json_type in args.json_types: for json_type in args.json_types:
if json_type not in JSON_OUTPUT_TYPES: if json_type not in JSON_OUTPUT_TYPES:
raise Exception(f'Error: "{json_type}" is not a valid JSON result output type.') raise Exception(f'Error: "{json_type}" is not a valid JSON result output type.')
@ -632,7 +612,9 @@ def parse_args(
class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods
def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs def __call__(
self, parser: Any, *args: Any, **kwargs: Any
) -> None: # pylint: disable=signature-differs
detectors, _ = get_detectors_and_printers() detectors, _ = get_detectors_and_printers()
output_detectors(detectors) output_detectors(detectors)
parser.exit() parser.exit()
@ -694,14 +676,14 @@ class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods
class FormatterCryticCompile(logging.Formatter): class FormatterCryticCompile(logging.Formatter):
def format(self, record): def format(self, record: logging.LogRecord) -> str:
# for i, msg in enumerate(record.msg): # for i, msg in enumerate(record.msg):
if record.msg.startswith("Compilation warnings/errors on "): if record.msg.startswith("Compilation warnings/errors on "):
txt = record.args[1] txt = record.args[1] # type:ignore
txt = txt.split("\n") txt = txt.split("\n") # type:ignore
txt = [red(x) if "Error" in x else x for x in txt] txt = [red(x) if "Error" in x else x for x in txt]
txt = "\n".join(txt) txt = "\n".join(txt)
record.args = (record.args[0], txt) record.args = (record.args[0], txt) # type:ignore
return super().format(record) return super().format(record)
@ -744,7 +726,7 @@ def main_impl(
set_colorization_enabled(False if args.disable_color else sys.stdout.isatty()) set_colorization_enabled(False if args.disable_color else sys.stdout.isatty())
# Define some variables for potential JSON output # Define some variables for potential JSON output
json_results = {} json_results: Dict[str, Any] = {}
output_error = None output_error = None
outputting_json = args.json is not None outputting_json = args.json is not None
outputting_json_stdout = args.json == "-" outputting_json_stdout = args.json == "-"
@ -793,7 +775,7 @@ def main_impl(
crytic_compile_error.setLevel(logging.INFO) crytic_compile_error.setLevel(logging.INFO)
results_detectors: List[Dict] = [] results_detectors: List[Dict] = []
results_printers: List[Dict] = [] results_printers: List[Output] = []
try: try:
filename = args.filename filename = args.filename
@ -806,26 +788,17 @@ def main_impl(
number_contracts = 0 number_contracts = 0
slither_instances = [] slither_instances = []
if args.splitted: for filename in filenames:
( (
slither_instance, slither_instance,
results_detectors, results_detectors_tmp,
results_printers, results_printers_tmp,
number_contracts, number_contracts_tmp,
) = process_from_asts(filenames, args, detector_classes, printer_classes) ) = process_single(filename, args, detector_classes, printer_classes)
number_contracts += number_contracts_tmp
results_detectors += results_detectors_tmp
results_printers += results_printers_tmp
slither_instances.append(slither_instance) slither_instances.append(slither_instance)
else:
for filename in filenames:
(
slither_instance,
results_detectors_tmp,
results_printers_tmp,
number_contracts_tmp,
) = process_single(filename, args, detector_classes, printer_classes)
number_contracts += number_contracts_tmp
results_detectors += results_detectors_tmp
results_printers += results_printers_tmp
slither_instances.append(slither_instance)
# Rely on CryticCompile to discern the underlying type of compilations. # Rely on CryticCompile to discern the underlying type of compilations.
else: else:

@ -2,8 +2,9 @@
Compute the data depenency between all the SSA variables Compute the data depenency between all the SSA variables
""" """
from collections import defaultdict from collections import defaultdict
from typing import Union, Set, Dict, TYPE_CHECKING from typing import Union, Set, Dict, TYPE_CHECKING, List
from slither.core.cfg.node import Node
from slither.core.declarations import ( from slither.core.declarations import (
Contract, Contract,
Enum, Enum,
@ -12,11 +13,14 @@ from slither.core.declarations import (
SolidityVariable, SolidityVariable,
SolidityVariableComposed, SolidityVariableComposed,
Structure, Structure,
FunctionContract,
) )
from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder
from slither.core.solidity_types.type import Type
from slither.core.variables.top_level_variable import TopLevelVariable from slither.core.variables.top_level_variable import TopLevelVariable
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.slithir.operations import Index, OperationWithLValue, InternalCall, Operation from slither.slithir.operations import Index, OperationWithLValue, InternalCall, Operation
from slither.slithir.utils.utils import LVALUE
from slither.slithir.variables import ( from slither.slithir.variables import (
Constant, Constant,
LocalIRVariable, LocalIRVariable,
@ -26,12 +30,11 @@ from slither.slithir.variables import (
TemporaryVariableSSA, TemporaryVariableSSA,
TupleVariableSSA, TupleVariableSSA,
) )
from slither.core.solidity_types.type import Type from slither.slithir.variables.variable import SlithIRVariable
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
################################################################################### ###################################################################################
################################################################################### ###################################################################################
# region User APIs # region User APIs
@ -39,26 +42,39 @@ if TYPE_CHECKING:
################################################################################### ###################################################################################
Variable_types = Union[Variable, SolidityVariable] SUPPORTED_TYPES = Union[Variable, SolidityVariable]
# TODO refactor the data deps to be better suited for top level function object
# Right now we allow to pass a node to ease the API, but we need something
# better
# The deps propagation for top level elements is also not working as expected
Context_types_API = Union[Contract, Function, Node]
Context_types = Union[Contract, Function] Context_types = Union[Contract, Function]
def is_dependent( def is_dependent(
variable: Variable_types, variable: SUPPORTED_TYPES,
source: Variable_types, source: SUPPORTED_TYPES,
context: Context_types, context: Context_types_API,
only_unprotected: bool = False, only_unprotected: bool = False,
) -> bool: ) -> bool:
""" """
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args: Args:
variable (Variable) variable (Variable)
source (Variable) source (Variable)
context (Contract|Function) context (Contract|Function|Node).
only_unprotected (bool): True only unprotected function are considered only_unprotected (bool): True only unprotected function are considered
Returns: Returns:
bool bool
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
if isinstance(variable, Constant): if isinstance(variable, Constant):
return False return False
if variable == source: if variable == source:
@ -74,12 +90,15 @@ def is_dependent(
def is_dependent_ssa( def is_dependent_ssa(
variable: Variable_types, variable: SUPPORTED_TYPES,
source: Variable_types, source: SUPPORTED_TYPES,
context: Context_types, context: Context_types_API,
only_unprotected: bool = False, only_unprotected: bool = False,
) -> bool: ) -> bool:
""" """
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args: Args:
variable (Variable) variable (Variable)
taint (Variable) taint (Variable)
@ -88,7 +107,10 @@ def is_dependent_ssa(
Returns: Returns:
bool bool
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
context_dict = context.context context_dict = context.context
if isinstance(variable, Constant): if isinstance(variable, Constant):
return False return False
@ -111,12 +133,15 @@ GENERIC_TAINT = {
def is_tainted( def is_tainted(
variable: Variable_types, variable: SUPPORTED_TYPES,
context: Context_types, context: Context_types_API,
only_unprotected: bool = False, only_unprotected: bool = False,
ignore_generic_taint: bool = False, ignore_generic_taint: bool = False,
) -> bool: ) -> bool:
""" """
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args: Args:
variable variable
context (Contract|Function) context (Contract|Function)
@ -124,7 +149,10 @@ def is_tainted(
Returns: Returns:
bool bool
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool) assert isinstance(only_unprotected, bool)
if isinstance(variable, Constant): if isinstance(variable, Constant):
return False return False
@ -138,12 +166,15 @@ def is_tainted(
def is_tainted_ssa( def is_tainted_ssa(
variable: Variable_types, variable: SUPPORTED_TYPES,
context: Context_types, context: Context_types_API,
only_unprotected: bool = False, only_unprotected: bool = False,
ignore_generic_taint: bool = False, ignore_generic_taint: bool = False,
): ) -> bool:
""" """
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args: Args:
variable variable
context (Contract|Function) context (Contract|Function)
@ -151,7 +182,10 @@ def is_tainted_ssa(
Returns: Returns:
bool bool
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool) assert isinstance(only_unprotected, bool)
if isinstance(variable, Constant): if isinstance(variable, Constant):
return False return False
@ -165,19 +199,24 @@ def is_tainted_ssa(
def get_dependencies( def get_dependencies(
variable: Variable_types, variable: SUPPORTED_TYPES,
context: Context_types, context: Context_types_API,
only_unprotected: bool = False, only_unprotected: bool = False,
) -> Set[Variable]: ) -> Set[Variable]:
""" """
Return the variables for which `variable` depends on. Return the variables for which `variable` depends on.
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param variable: The target :param variable: The target
:param context: Either a function (interprocedural) or a contract (inter transactional) :param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions :param only_unprotected: True if consider only protected functions
:return: set(Variable) :return: set(Variable)
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool) assert isinstance(only_unprotected, bool)
if only_unprotected: if only_unprotected:
return context.context[KEY_NON_SSA_UNPROTECTED].get(variable, set()) return context.context[KEY_NON_SSA_UNPROTECTED].get(variable, set())
@ -185,16 +224,21 @@ def get_dependencies(
def get_all_dependencies( def get_all_dependencies(
context: Context_types, only_unprotected: bool = False context: Context_types_API, only_unprotected: bool = False
) -> Dict[Variable, Set[Variable]]: ) -> Dict[Variable, Set[Variable]]:
""" """
Return the dictionary of dependencies. Return the dictionary of dependencies.
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param context: Either a function (interprocedural) or a contract (inter transactional) :param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions :param only_unprotected: True if consider only protected functions
:return: Dict(Variable, set(Variable)) :return: Dict(Variable, set(Variable))
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool) assert isinstance(only_unprotected, bool)
if only_unprotected: if only_unprotected:
return context.context[KEY_NON_SSA_UNPROTECTED] return context.context[KEY_NON_SSA_UNPROTECTED]
@ -202,19 +246,24 @@ def get_all_dependencies(
def get_dependencies_ssa( def get_dependencies_ssa(
variable: Variable_types, variable: SUPPORTED_TYPES,
context: Context_types, context: Context_types_API,
only_unprotected: bool = False, only_unprotected: bool = False,
) -> Set[Variable]: ) -> Set[Variable]:
""" """
Return the variables for which `variable` depends on (SSA version). Return the variables for which `variable` depends on (SSA version).
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param variable: The target (must be SSA variable) :param variable: The target (must be SSA variable)
:param context: Either a function (interprocedural) or a contract (inter transactional) :param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions :param only_unprotected: True if consider only protected functions
:return: set(Variable) :return: set(Variable)
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool) assert isinstance(only_unprotected, bool)
if only_unprotected: if only_unprotected:
return context.context[KEY_SSA_UNPROTECTED].get(variable, set()) return context.context[KEY_SSA_UNPROTECTED].get(variable, set())
@ -222,16 +271,21 @@ def get_dependencies_ssa(
def get_all_dependencies_ssa( def get_all_dependencies_ssa(
context: Context_types, only_unprotected: bool = False context: Context_types_API, only_unprotected: bool = False
) -> Dict[Variable, Set[Variable]]: ) -> Dict[Variable, Set[Variable]]:
""" """
Return the dictionary of dependencies. Return the dictionary of dependencies.
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param context: Either a function (interprocedural) or a contract (inter transactional) :param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions :param only_unprotected: True if consider only protected functions
:return: Dict(Variable, set(Variable)) :return: Dict(Variable, set(Variable))
""" """
assert isinstance(context, (Contract, Function)) assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool) assert isinstance(only_unprotected, bool)
if only_unprotected: if only_unprotected:
return context.context[KEY_SSA_UNPROTECTED] return context.context[KEY_SSA_UNPROTECTED]
@ -341,13 +395,9 @@ def transitive_close_dependencies(
while changed: while changed:
changed = False changed = False
to_add = defaultdict(set) to_add = defaultdict(set)
[ # pylint: disable=expression-not-assigned for key, items in context.context[context_key].items():
[ for item in items & keys:
to_add[key].update(context.context[context_key][item] - {key} - items) to_add[key].update(context.context[context_key][item] - {key} - items)
for item in items & keys
]
for key, items in context.context[context_key].items()
]
for k, v in to_add.items(): for k, v in to_add.items():
# Because we dont have any check on the update operation # Because we dont have any check on the update operation
# We might update an empty set with an empty set # We might update an empty set with an empty set
@ -366,20 +416,20 @@ def add_dependency(lvalue: Variable, function: Function, ir: Operation, is_prote
function.context[KEY_SSA][lvalue] = set() function.context[KEY_SSA][lvalue] = set()
if not is_protected: if not is_protected:
function.context[KEY_SSA_UNPROTECTED][lvalue] = set() function.context[KEY_SSA_UNPROTECTED][lvalue] = set()
read: Union[List[Union[LVALUE, SolidityVariableComposed]], List[SlithIRVariable]]
if isinstance(ir, Index): if isinstance(ir, Index):
read = [ir.variable_left] read = [ir.variable_left]
elif isinstance(ir, InternalCall): elif isinstance(ir, InternalCall) and ir.function:
read = ir.function.return_values_ssa read = ir.function.return_values_ssa
else: else:
read = ir.read read = ir.read
# pylint: disable=expression-not-assigned for v in read:
[function.context[KEY_SSA][lvalue].add(v) for v in read if not isinstance(v, Constant)] if not isinstance(v, Constant):
function.context[KEY_SSA][lvalue].add(v)
if not is_protected: if not is_protected:
[ for v in read:
function.context[KEY_SSA_UNPROTECTED][lvalue].add(v) if not isinstance(v, Constant):
for v in read function.context[KEY_SSA_UNPROTECTED][lvalue].add(v)
if not isinstance(v, Constant)
]
def compute_dependency_function(function: Function) -> None: def compute_dependency_function(function: Function) -> None:
@ -407,7 +457,7 @@ def compute_dependency_function(function: Function) -> None:
) )
def convert_variable_to_non_ssa(v: Variable_types) -> Variable_types: def convert_variable_to_non_ssa(v: SUPPORTED_TYPES) -> SUPPORTED_TYPES:
if isinstance( if isinstance(
v, v,
( (
@ -438,10 +488,10 @@ def convert_variable_to_non_ssa(v: Variable_types) -> Variable_types:
def convert_to_non_ssa( def convert_to_non_ssa(
data_depencies: Dict[Variable_types, Set[Variable_types]] data_depencies: Dict[SUPPORTED_TYPES, Set[SUPPORTED_TYPES]]
) -> Dict[Variable_types, Set[Variable_types]]: ) -> Dict[SUPPORTED_TYPES, Set[SUPPORTED_TYPES]]:
# Need to create new set() as its changed during iteration # Need to create new set() as its changed during iteration
ret: Dict[Variable_types, Set[Variable_types]] = {} ret: Dict[SUPPORTED_TYPES, Set[SUPPORTED_TYPES]] = {}
for (k, values) in data_depencies.items(): for (k, values) in data_depencies.items():
var = convert_variable_to_non_ssa(k) var = convert_variable_to_non_ssa(k)
if not var in ret: if not var in ret:

@ -2,10 +2,10 @@
Detect if all the given variables are written in all the paths of the function Detect if all the given variables are written in all the paths of the function
""" """
from collections import defaultdict from collections import defaultdict
from typing import Dict, Set, List from typing import Dict, Set, List, Any, Optional
from slither.core.cfg.node import NodeType, Node from slither.core.cfg.node import NodeType, Node
from slither.core.declarations import SolidityFunction from slither.core.declarations import SolidityFunction, Function
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.slithir.operations import ( from slither.slithir.operations import (
Index, Index,
@ -18,7 +18,7 @@ from slither.slithir.variables import ReferenceVariable, TemporaryVariable
class State: # pylint: disable=too-few-public-methods class State: # pylint: disable=too-few-public-methods
def __init__(self): def __init__(self) -> None:
# Map node -> list of variables set # Map node -> list of variables set
# Were each variables set represents a configuration of a path # Were each variables set represents a configuration of a path
# If two paths lead to the exact same set of variables written, we dont need to explore both # If two paths lead to the exact same set of variables written, we dont need to explore both
@ -34,11 +34,11 @@ class State: # pylint: disable=too-few-public-methods
# pylint: disable=too-many-branches # pylint: disable=too-many-branches
def _visit( def _visit(
node: Node, node: Optional[Node],
state: State, state: State,
variables_written: Set[Variable], variables_written: Set[Variable],
variables_to_write: List[Variable], variables_to_write: List[Variable],
): ) -> List[Variable]:
""" """
Explore all the nodes to look for values not written when the node's function return Explore all the nodes to look for values not written when the node's function return
Fixpoint reaches if no new written variables are found Fixpoint reaches if no new written variables are found
@ -51,6 +51,8 @@ def _visit(
refs = {} refs = {}
variables_written = set(variables_written) variables_written = set(variables_written)
if not node:
return []
for ir in node.irs: for ir in node.irs:
if isinstance(ir, SolidityCall): if isinstance(ir, SolidityCall):
# TODO convert the revert to a THROW node # TODO convert the revert to a THROW node
@ -70,17 +72,20 @@ def _visit(
if ir.lvalue and not isinstance(ir.lvalue, (TemporaryVariable, ReferenceVariable)): if ir.lvalue and not isinstance(ir.lvalue, (TemporaryVariable, ReferenceVariable)):
variables_written.add(ir.lvalue) variables_written.add(ir.lvalue)
lvalue = ir.lvalue lvalue: Any = ir.lvalue
while isinstance(lvalue, ReferenceVariable): while isinstance(lvalue, ReferenceVariable):
if lvalue not in refs: if lvalue not in refs:
break break
if refs[lvalue] and not isinstance( refs_lvalues = refs[lvalue]
refs[lvalue], (TemporaryVariable, ReferenceVariable) if (
refs_lvalues
and isinstance(refs_lvalues, Variable)
and not isinstance(refs_lvalues, (TemporaryVariable, ReferenceVariable))
): ):
variables_written.add(refs[lvalue]) variables_written.add(refs_lvalues)
lvalue = refs[lvalue] lvalue = refs_lvalues
ret = [] ret: List[Variable] = []
if not node.sons and node.type not in [NodeType.THROW, NodeType.RETURN]: if not node.sons and node.type not in [NodeType.THROW, NodeType.RETURN]:
ret += [v for v in variables_to_write if v not in variables_written] ret += [v for v in variables_to_write if v not in variables_written]
@ -96,7 +101,7 @@ def _visit(
return ret return ret
def are_variables_written(function, variables_to_write): def are_variables_written(function: Function, variables_to_write: List[Variable]) -> List[Variable]:
""" """
Return the list of variable that are not written at the end of the function Return the list of variable that are not written at the end of the function

@ -5,8 +5,7 @@ from enum import Enum
from typing import Optional, List, Set, Dict, Tuple, Union, TYPE_CHECKING from typing import Optional, List, Set, Dict, Tuple, Union, TYPE_CHECKING
from slither.all_exceptions import SlitherException from slither.all_exceptions import SlitherException
from slither.core.children.child_function import ChildFunction from slither.core.declarations import Contract, Function, FunctionContract
from slither.core.declarations import Contract, Function
from slither.core.declarations.solidity_variables import ( from slither.core.declarations.solidity_variables import (
SolidityVariable, SolidityVariable,
SolidityFunction, SolidityFunction,
@ -33,6 +32,7 @@ from slither.slithir.operations import (
Return, Return,
Operation, Operation,
) )
from slither.slithir.utils.utils import RVALUE
from slither.slithir.variables import ( from slither.slithir.variables import (
Constant, Constant,
LocalIRVariable, LocalIRVariable,
@ -106,7 +106,7 @@ class NodeType(Enum):
# I am not sure why, but pylint reports a lot of "no-member" issue that are not real (Josselin) # I am not sure why, but pylint reports a lot of "no-member" issue that are not real (Josselin)
# pylint: disable=no-member # pylint: disable=no-member
class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-methods class Node(SourceMapping): # pylint: disable=too-many-public-methods
""" """
Node class Node class
@ -146,12 +146,12 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._node_id: int = node_id self._node_id: int = node_id
self._vars_written: List[Variable] = [] self._vars_written: List[Variable] = []
self._vars_read: List[Variable] = [] self._vars_read: List[Union[Variable, SolidityVariable]] = []
self._ssa_vars_written: List["SlithIRVariable"] = [] self._ssa_vars_written: List["SlithIRVariable"] = []
self._ssa_vars_read: List["SlithIRVariable"] = [] self._ssa_vars_read: List["SlithIRVariable"] = []
self._internal_calls: List["Function"] = [] self._internal_calls: List[Union["Function", "SolidityFunction"]] = []
self._solidity_calls: List[SolidityFunction] = [] self._solidity_calls: List[SolidityFunction] = []
self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls
self._library_calls: List["LibraryCallType"] = [] self._library_calls: List["LibraryCallType"] = []
@ -172,7 +172,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._local_vars_read: List[LocalVariable] = [] self._local_vars_read: List[LocalVariable] = []
self._local_vars_written: List[LocalVariable] = [] self._local_vars_written: List[LocalVariable] = []
self._slithir_vars: Set["SlithIRVariable"] = set() # non SSA self._slithir_vars: Set[
Union["SlithIRVariable", ReferenceVariable, TemporaryVariable, TupleVariable]
] = set() # non SSA
self._ssa_local_vars_read: List[LocalIRVariable] = [] self._ssa_local_vars_read: List[LocalIRVariable] = []
self._ssa_local_vars_written: List[LocalIRVariable] = [] self._ssa_local_vars_written: List[LocalIRVariable] = []
@ -189,6 +191,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self.scope: Union["Scope", "Function"] = scope self.scope: Union["Scope", "Function"] = scope
self.file_scope: "FileScope" = file_scope self.file_scope: "FileScope" = file_scope
self._function: Optional["Function"] = None
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -213,7 +216,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._node_type return self._node_type
@type.setter @type.setter
def type(self, new_type: NodeType): def type(self, new_type: NodeType) -> None:
self._node_type = new_type self._node_type = new_type
@property @property
@ -224,6 +227,13 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return True return True
return False return False
def set_function(self, function: "Function") -> None:
self._function = function
@property
def function(self) -> "Function":
return self._function
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -232,7 +242,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
################################################################################### ###################################################################################
@property @property
def variables_read(self) -> List[Variable]: def variables_read(self) -> List[Union[Variable, SolidityVariable]]:
""" """
list(Variable): Variables read (local/state/solidity) list(Variable): Variables read (local/state/solidity)
""" """
@ -285,11 +295,13 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._expression_vars_read return self._expression_vars_read
@variables_read_as_expression.setter @variables_read_as_expression.setter
def variables_read_as_expression(self, exprs: List[Expression]): def variables_read_as_expression(self, exprs: List[Expression]) -> None:
self._expression_vars_read = exprs self._expression_vars_read = exprs
@property @property
def slithir_variables(self) -> List["SlithIRVariable"]: def slithir_variables(
self,
) -> List[Union["SlithIRVariable", ReferenceVariable, TemporaryVariable, TupleVariable]]:
return list(self._slithir_vars) return list(self._slithir_vars)
@property @property
@ -339,7 +351,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._expression_vars_written return self._expression_vars_written
@variables_written_as_expression.setter @variables_written_as_expression.setter
def variables_written_as_expression(self, exprs: List[Expression]): def variables_written_as_expression(self, exprs: List[Expression]) -> None:
self._expression_vars_written = exprs self._expression_vars_written = exprs
# endregion # endregion
@ -399,7 +411,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._external_calls_as_expressions return self._external_calls_as_expressions
@external_calls_as_expressions.setter @external_calls_as_expressions.setter
def external_calls_as_expressions(self, exprs: List[Expression]): def external_calls_as_expressions(self, exprs: List[Expression]) -> None:
self._external_calls_as_expressions = exprs self._external_calls_as_expressions = exprs
@property @property
@ -410,7 +422,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._internal_calls_as_expressions return self._internal_calls_as_expressions
@internal_calls_as_expressions.setter @internal_calls_as_expressions.setter
def internal_calls_as_expressions(self, exprs: List[Expression]): def internal_calls_as_expressions(self, exprs: List[Expression]) -> None:
self._internal_calls_as_expressions = exprs self._internal_calls_as_expressions = exprs
@property @property
@ -418,10 +430,10 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return list(self._expression_calls) return list(self._expression_calls)
@calls_as_expression.setter @calls_as_expression.setter
def calls_as_expression(self, exprs: List[Expression]): def calls_as_expression(self, exprs: List[Expression]) -> None:
self._expression_calls = exprs self._expression_calls = exprs
def can_reenter(self, callstack=None) -> bool: def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
""" """
Check if the node can re-enter Check if the node can re-enter
Do not consider CREATE as potential re-enter, but check if the Do not consider CREATE as potential re-enter, but check if the
@ -567,7 +579,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
""" """
self._fathers.append(father) self._fathers.append(father)
def set_fathers(self, fathers: List["Node"]): def set_fathers(self, fathers: List["Node"]) -> None:
"""Set the father nodes """Set the father nodes
Args: Args:
@ -663,20 +675,20 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._irs_ssa return self._irs_ssa
@irs_ssa.setter @irs_ssa.setter
def irs_ssa(self, irs): def irs_ssa(self, irs: List[Operation]) -> None:
self._irs_ssa = irs self._irs_ssa = irs
def add_ssa_ir(self, ir: Operation) -> None: def add_ssa_ir(self, ir: Operation) -> None:
""" """
Use to place phi operation Use to place phi operation
""" """
ir.set_node(self) ir.set_node(self) # type: ignore
self._irs_ssa.append(ir) self._irs_ssa.append(ir)
def slithir_generation(self) -> None: def slithir_generation(self) -> None:
if self.expression: if self.expression:
expression = self.expression expression = self.expression
self._irs = convert_expression(expression, self) self._irs = convert_expression(expression, self) # type:ignore
self._find_read_write_call() self._find_read_write_call()
@ -713,7 +725,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._dominators return self._dominators
@dominators.setter @dominators.setter
def dominators(self, dom: Set["Node"]): def dominators(self, dom: Set["Node"]) -> None:
self._dominators = dom self._dominators = dom
@property @property
@ -725,7 +737,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._immediate_dominator return self._immediate_dominator
@immediate_dominator.setter @immediate_dominator.setter
def immediate_dominator(self, idom: "Node"): def immediate_dominator(self, idom: "Node") -> None:
self._immediate_dominator = idom self._immediate_dominator = idom
@property @property
@ -737,7 +749,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._dominance_frontier return self._dominance_frontier
@dominance_frontier.setter @dominance_frontier.setter
def dominance_frontier(self, doms: Set["Node"]): def dominance_frontier(self, doms: Set["Node"]) -> None:
""" """
Returns: Returns:
set(Node) set(Node)
@ -789,6 +801,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node") -> None: def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node") -> None:
if variable.name not in self._phi_origins_local_variables: if variable.name not in self._phi_origins_local_variables:
assert variable.name
self._phi_origins_local_variables[variable.name] = (variable, set()) self._phi_origins_local_variables[variable.name] = (variable, set())
(v, nodes) = self._phi_origins_local_variables[variable.name] (v, nodes) = self._phi_origins_local_variables[variable.name]
assert v == variable assert v == variable
@ -827,7 +840,8 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
if isinstance(ir, OperationWithLValue): if isinstance(ir, OperationWithLValue):
var = ir.lvalue var = ir.lvalue
if var and self._is_valid_slithir_var(var): if var and self._is_valid_slithir_var(var):
self._slithir_vars.add(var) # The type is checked by is_valid_slithir_var
self._slithir_vars.add(var) # type: ignore
if not isinstance(ir, (Phi, Index, Member)): if not isinstance(ir, (Phi, Index, Member)):
self._vars_read += [v for v in ir.read if self._is_non_slithir_var(v)] self._vars_read += [v for v in ir.read if self._is_non_slithir_var(v)]
@ -835,8 +849,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
if isinstance(var, ReferenceVariable): if isinstance(var, ReferenceVariable):
self._vars_read.append(var.points_to_origin) self._vars_read.append(var.points_to_origin)
elif isinstance(ir, (Member, Index)): elif isinstance(ir, (Member, Index)):
# TODO investigate types for member variable left
var = ir.variable_left if isinstance(ir, Member) else ir.variable_right var = ir.variable_left if isinstance(ir, Member) else ir.variable_right
if self._is_non_slithir_var(var): if var and self._is_non_slithir_var(var):
self._vars_read.append(var) self._vars_read.append(var)
if isinstance(var, ReferenceVariable): if isinstance(var, ReferenceVariable):
origin = var.points_to_origin origin = var.points_to_origin
@ -860,14 +875,21 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._internal_calls.append(ir.function) self._internal_calls.append(ir.function)
if isinstance(ir, LowLevelCall): if isinstance(ir, LowLevelCall):
assert isinstance(ir.destination, (Variable, SolidityVariable)) assert isinstance(ir.destination, (Variable, SolidityVariable))
self._low_level_calls.append((ir.destination, ir.function_name.value)) self._low_level_calls.append((ir.destination, str(ir.function_name.value)))
elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall): elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall):
# Todo investigate this if condition
# It does seem right to compare against a contract
# This might need a refactoring
if isinstance(ir.destination.type, Contract): if isinstance(ir.destination.type, Contract):
self._high_level_calls.append((ir.destination.type, ir.function)) self._high_level_calls.append((ir.destination.type, ir.function))
elif ir.destination == SolidityVariable("this"): elif ir.destination == SolidityVariable("this"):
self._high_level_calls.append((self.function.contract, ir.function)) func = self.function
# Can't use this in a top level function
assert isinstance(func, FunctionContract)
self._high_level_calls.append((func.contract, ir.function))
else: else:
try: try:
# Todo this part needs more tests and documentation
self._high_level_calls.append((ir.destination.type.type, ir.function)) self._high_level_calls.append((ir.destination.type.type, ir.function))
except AttributeError as error: except AttributeError as error:
# pylint: disable=raise-missing-from # pylint: disable=raise-missing-from
@ -883,7 +905,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._vars_read = list(set(self._vars_read)) self._vars_read = list(set(self._vars_read))
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)]
self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)]
self._solidity_vars_read = [v for v in self._vars_read if isinstance(v, SolidityVariable)] self._solidity_vars_read = [
v_ for v_ in self._vars_read if isinstance(v_, SolidityVariable)
]
self._vars_written = list(set(self._vars_written)) self._vars_written = list(set(self._vars_written))
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)]
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)]
@ -895,12 +919,15 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
@staticmethod @staticmethod
def _convert_ssa(v: Variable) -> Optional[Union[StateVariable, LocalVariable]]: def _convert_ssa(v: Variable) -> Optional[Union[StateVariable, LocalVariable]]:
non_ssa_var: Optional[Union[StateVariable, LocalVariable]]
if isinstance(v, StateIRVariable): if isinstance(v, StateIRVariable):
contract = v.contract contract = v.contract
assert v.name
non_ssa_var = contract.get_state_variable_from_name(v.name) non_ssa_var = contract.get_state_variable_from_name(v.name)
return non_ssa_var return non_ssa_var
assert isinstance(v, LocalIRVariable) assert isinstance(v, LocalIRVariable)
function = v.function function = v.function
assert v.name
non_ssa_var = function.get_local_variable_from_name(v.name) non_ssa_var = function.get_local_variable_from_name(v.name)
return non_ssa_var return non_ssa_var
@ -921,10 +948,11 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._ssa_vars_read.append(origin) self._ssa_vars_read.append(origin)
elif isinstance(ir, (Member, Index)): elif isinstance(ir, (Member, Index)):
if isinstance(ir.variable_right, (StateIRVariable, LocalIRVariable)): variable_right: RVALUE = ir.variable_right
self._ssa_vars_read.append(ir.variable_right) if isinstance(variable_right, (StateIRVariable, LocalIRVariable)):
if isinstance(ir.variable_right, ReferenceVariable): self._ssa_vars_read.append(variable_right)
origin = ir.variable_right.points_to_origin if isinstance(variable_right, ReferenceVariable):
origin = variable_right.points_to_origin
if isinstance(origin, (StateIRVariable, LocalIRVariable)): if isinstance(origin, (StateIRVariable, LocalIRVariable)):
self._ssa_vars_read.append(origin) self._ssa_vars_read.append(origin)
@ -944,20 +972,20 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._ssa_local_vars_read = [v for v in self._ssa_vars_read if isinstance(v, LocalVariable)] self._ssa_local_vars_read = [v for v in self._ssa_vars_read if isinstance(v, LocalVariable)]
self._ssa_vars_written = list(set(self._ssa_vars_written)) self._ssa_vars_written = list(set(self._ssa_vars_written))
self._ssa_state_vars_written = [ self._ssa_state_vars_written = [
v for v in self._ssa_vars_written if isinstance(v, StateVariable) v for v in self._ssa_vars_written if v and isinstance(v, StateIRVariable)
] ]
self._ssa_local_vars_written = [ self._ssa_local_vars_written = [
v for v in self._ssa_vars_written if isinstance(v, LocalVariable) v for v in self._ssa_vars_written if v and isinstance(v, LocalIRVariable)
] ]
vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read] vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read]
vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written] vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written]
self._vars_read += [v for v in vars_read if v not in self._vars_read] self._vars_read += [v_ for v_ in vars_read if v_ and v_ not in self._vars_read]
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)]
self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)]
self._vars_written += [v for v in vars_written if v not in self._vars_written] self._vars_written += [v_ for v_ in vars_written if v_ and v_ not in self._vars_written]
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)]
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)]
@ -974,7 +1002,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
additional_info += " " + str(self.expression) additional_info += " " + str(self.expression)
elif self.variable_declaration: elif self.variable_declaration:
additional_info += " " + str(self.variable_declaration) additional_info += " " + str(self.variable_declaration)
txt = self._node_type.value + additional_info txt = str(self._node_type.value) + additional_info
return txt return txt

@ -1,19 +0,0 @@
from typing import TYPE_CHECKING
from slither.core.source_mapping.source_mapping import SourceMapping
if TYPE_CHECKING:
from slither.core.declarations import Contract
class ChildContract(SourceMapping):
def __init__(self) -> None:
super().__init__()
self._contract = None
def set_contract(self, contract: "Contract") -> None:
self._contract = contract
@property
def contract(self) -> "Contract":
return self._contract

@ -1,17 +0,0 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither.core.declarations import Event
class ChildEvent:
def __init__(self) -> None:
super().__init__()
self._event = None
def set_event(self, event: "Event"):
self._event = event
@property
def event(self) -> "Event":
return self._event

@ -1,18 +0,0 @@
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from slither.core.expressions.expression import Expression
from slither.slithir.operations import Operation
class ChildExpression:
def __init__(self) -> None:
super().__init__()
self._expression = None
def set_expression(self, expression: Union["Expression", "Operation"]) -> None:
self._expression = expression
@property
def expression(self) -> Union["Expression", "Operation"]:
return self._expression

@ -1,17 +0,0 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither.core.declarations import Function
class ChildFunction:
def __init__(self) -> None:
super().__init__()
self._function = None
def set_function(self, function: "Function") -> None:
self._function = function
@property
def function(self) -> "Function":
return self._function

@ -1,17 +0,0 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither.core.declarations import Contract
class ChildInheritance:
def __init__(self) -> None:
super().__init__()
self._contract_declarer = None
def set_contract_declarer(self, contract: "Contract") -> None:
self._contract_declarer = contract
@property
def contract_declarer(self) -> "Contract":
return self._contract_declarer

@ -1,31 +0,0 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.cfg.node import Node
from slither.core.declarations import Function, Contract
class ChildNode:
def __init__(self) -> None:
super().__init__()
self._node = None
def set_node(self, node: "Node") -> None:
self._node = node
@property
def node(self) -> "Node":
return self._node
@property
def function(self) -> "Function":
return self.node.function
@property
def contract(self) -> "Contract":
return self.node.function.contract
@property
def compilation_unit(self) -> "SlitherCompilationUnit":
return self.node.compilation_unit

@ -1,17 +0,0 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from slither.core.declarations import Structure
class ChildStructure:
def __init__(self) -> None:
super().__init__()
self._structure = None
def set_structure(self, structure: "Structure") -> None:
self._structure = structure
@property
def structure(self) -> "Structure":
return self._structure

@ -57,7 +57,7 @@ class SlitherCompilationUnit(Context):
self._storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {} self._storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {}
self._contract_with_missing_inheritance = set() self._contract_with_missing_inheritance: Set[Contract] = set()
self._source_units: Dict[int, str] = {} self._source_units: Dict[int, str] = {}
@ -88,7 +88,8 @@ class SlitherCompilationUnit(Context):
@property @property
def solc_version(self) -> str: def solc_version(self) -> str:
return self._crytic_compile_compilation_unit.compiler_version.version # TODO: make version a non optional argument of compiler version in cc
return self._crytic_compile_compilation_unit.compiler_version.version # type:ignore
@property @property
def crytic_compile_compilation_unit(self) -> CompilationUnit: def crytic_compile_compilation_unit(self) -> CompilationUnit:
@ -162,13 +163,14 @@ class SlitherCompilationUnit(Context):
@property @property
def functions_and_modifiers(self) -> List[Function]: def functions_and_modifiers(self) -> List[Function]:
return self.functions + self.modifiers return self.functions + list(self.modifiers)
def propagate_function_calls(self) -> None: def propagate_function_calls(self) -> None:
for f in self.functions_and_modifiers: for f in self.functions_and_modifiers:
for node in f.nodes: for node in f.nodes:
for ir in node.irs_ssa: for ir in node.irs_ssa:
if isinstance(ir, InternalCall): if isinstance(ir, InternalCall):
assert ir.function
ir.function.add_reachable_from_node(node, ir) ir.function.add_reachable_from_node(node, ir)
# endregion # endregion
@ -181,8 +183,8 @@ class SlitherCompilationUnit(Context):
@property @property
def state_variables(self) -> List[StateVariable]: def state_variables(self) -> List[StateVariable]:
if self._all_state_variables is None: if self._all_state_variables is None:
state_variables = [c.state_variables for c in self.contracts] state_variabless = [c.state_variables for c in self.contracts]
state_variables = [item for sublist in state_variables for item in sublist] state_variables = [item for sublist in state_variabless for item in sublist]
self._all_state_variables = set(state_variables) self._all_state_variables = set(state_variables)
return list(self._all_state_variables) return list(self._all_state_variables)
@ -229,7 +231,7 @@ class SlitherCompilationUnit(Context):
################################################################################### ###################################################################################
@property @property
def contracts_with_missing_inheritance(self) -> Set: def contracts_with_missing_inheritance(self) -> Set[Contract]:
return self._contract_with_missing_inheritance return self._contract_with_missing_inheritance
# endregion # endregion
@ -266,6 +268,7 @@ class SlitherCompilationUnit(Context):
if var.is_constant or var.is_immutable: if var.is_constant or var.is_immutable:
continue continue
assert var.type
size, new_slot = var.type.storage_size size, new_slot = var.type.storage_size
if new_slot: if new_slot:
@ -285,7 +288,7 @@ class SlitherCompilationUnit(Context):
else: else:
offset += size offset += size
def storage_layout_of(self, contract, var) -> Tuple[int, int]: def storage_layout_of(self, contract: Contract, var: StateVariable) -> Tuple[int, int]:
return self._storage_layouts[contract.name][var.canonical_name] return self._storage_layouts[contract.name][var.canonical_name]
# endregion # endregion

@ -49,6 +49,9 @@ if TYPE_CHECKING:
LOGGER = logging.getLogger("Contract") LOGGER = logging.getLogger("Contract")
USING_FOR_KEY = Union[str, Type]
USING_FOR_ITEM = List[Union[Type, Function]]
class Contract(SourceMapping): # pylint: disable=too-many-public-methods class Contract(SourceMapping): # pylint: disable=too-many-public-methods
""" """
@ -80,8 +83,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
self._custom_errors: Dict[str, "CustomErrorContract"] = {} self._custom_errors: Dict[str, "CustomErrorContract"] = {}
# The only str is "*" # The only str is "*"
self._using_for: Dict[Union[str, Type], List[Type]] = {} self._using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = {}
self._using_for_complete: Dict[Union[str, Type], List[Type]] = None self._using_for_complete: Optional[Dict[USING_FOR_KEY, USING_FOR_ITEM]] = None
self._kind: Optional[str] = None self._kind: Optional[str] = None
self._is_interface: bool = False self._is_interface: bool = False
self._is_library: bool = False self._is_library: bool = False
@ -126,7 +129,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._name return self._name
@name.setter @name.setter
def name(self, name: str): def name(self, name: str) -> None:
self._name = name self._name = name
@property @property
@ -136,7 +139,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._id return self._id
@id.setter @id.setter
def id(self, new_id): def id(self, new_id: int) -> None:
"""Unique id.""" """Unique id."""
self._id = new_id self._id = new_id
@ -149,7 +152,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._kind return self._kind
@contract_kind.setter @contract_kind.setter
def contract_kind(self, kind): def contract_kind(self, kind: str) -> None:
self._kind = kind self._kind = kind
@property @property
@ -157,7 +160,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_interface return self._is_interface
@is_interface.setter @is_interface.setter
def is_interface(self, is_interface: bool): def is_interface(self, is_interface: bool) -> None:
self._is_interface = is_interface self._is_interface = is_interface
@property @property
@ -165,7 +168,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_library return self._is_library
@is_library.setter @is_library.setter
def is_library(self, is_library: bool): def is_library(self, is_library: bool) -> None:
self._is_library = is_library self._is_library = is_library
@property @property
@ -302,16 +305,18 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
################################################################################### ###################################################################################
@property @property
def using_for(self) -> Dict[Union[str, Type], List[Type]]: def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]:
return self._using_for return self._using_for
@property @property
def using_for_complete(self) -> Dict[Union[str, Type], List[Type]]: def using_for_complete(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]:
""" """
Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive
""" """
def _merge_using_for(uf1, uf2): def _merge_using_for(
uf1: Dict[USING_FOR_KEY, USING_FOR_ITEM], uf2: Dict[USING_FOR_KEY, USING_FOR_ITEM]
) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]:
result = {**uf1, **uf2} result = {**uf1, **uf2}
for key, value in result.items(): for key, value in result.items():
if key in uf1 and key in uf2: if key in uf1 and key in uf2:
@ -491,7 +496,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
) )
@property @property
def constructors(self) -> List["Function"]: def constructors(self) -> List["FunctionContract"]:
""" """
Return the list of constructors (including inherited) Return the list of constructors (including inherited)
""" """
@ -560,14 +565,14 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
""" """
return list(self._functions.values()) return list(self._functions.values())
def available_functions_as_dict(self) -> Dict[str, "FunctionContract"]: def available_functions_as_dict(self) -> Dict[str, "Function"]:
if self._available_functions_as_dict is None: if self._available_functions_as_dict is None:
self._available_functions_as_dict = { self._available_functions_as_dict = {
f.full_name: f for f in self._functions.values() if not f.is_shadowed f.full_name: f for f in self._functions.values() if not f.is_shadowed
} }
return self._available_functions_as_dict return self._available_functions_as_dict
def add_function(self, func: "FunctionContract"): def add_function(self, func: "FunctionContract") -> None:
self._functions[func.canonical_name] = func self._functions[func.canonical_name] = func
def set_functions(self, functions: Dict[str, "FunctionContract"]) -> None: def set_functions(self, functions: Dict[str, "FunctionContract"]) -> None:
@ -735,7 +740,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
list(Contract): Return the list of contracts derived from self list(Contract): Return the list of contracts derived from self
""" """
candidates = self.compilation_unit.contracts candidates = self.compilation_unit.contracts
return [c for c in candidates if self in c.inheritance] return [c for c in candidates if self in c.inheritance] # type: ignore
# endregion # endregion
################################################################################### ###################################################################################
@ -891,7 +896,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
""" """
return next((e for e in self.enums if e.name == enum_name), None) return next((e for e in self.enums if e.name == enum_name), None)
def get_enum_from_canonical_name(self, enum_name) -> Optional["Enum"]: def get_enum_from_canonical_name(self, enum_name: str) -> Optional["Enum"]:
""" """
Return an enum from a canonical name Return an enum from a canonical name
Args: Args:
@ -992,7 +997,9 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def get_summary(self, include_shadowed=True) -> Tuple[str, List[str], List[str], List, List]: def get_summary(
self, include_shadowed: bool = True
) -> Tuple[str, List[str], List[str], List, List]:
"""Return the function summary """Return the function summary
:param include_shadowed: boolean to indicate if shadowed functions should be included (default True) :param include_shadowed: boolean to indicate if shadowed functions should be included (default True)
@ -1245,7 +1252,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
@property @property
def is_test(self) -> bool: def is_test(self) -> bool:
return is_test_contract(self) or self.is_truffle_migration return is_test_contract(self) or self.is_truffle_migration # type: ignore
# endregion # endregion
################################################################################### ###################################################################################
@ -1255,7 +1262,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
################################################################################### ###################################################################################
def update_read_write_using_ssa(self) -> None: def update_read_write_using_ssa(self) -> None:
for function in self.functions + self.modifiers: for function in self.functions + list(self.modifiers):
function.update_read_write_using_ssa() function.update_read_write_using_ssa()
# endregion # endregion
@ -1290,7 +1297,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_upgradeable return self._is_upgradeable
@is_upgradeable.setter @is_upgradeable.setter
def is_upgradeable(self, upgradeable: bool): def is_upgradeable(self, upgradeable: bool) -> None:
self._is_upgradeable = upgradeable self._is_upgradeable = upgradeable
@property @property
@ -1319,7 +1326,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_upgradeable_proxy return self._is_upgradeable_proxy
@is_upgradeable_proxy.setter @is_upgradeable_proxy.setter
def is_upgradeable_proxy(self, upgradeable_proxy: bool): def is_upgradeable_proxy(self, upgradeable_proxy: bool) -> None:
self._is_upgradeable_proxy = upgradeable_proxy self._is_upgradeable_proxy = upgradeable_proxy
@property @property
@ -1327,7 +1334,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._upgradeable_version return self._upgradeable_version
@upgradeable_version.setter @upgradeable_version.setter
def upgradeable_version(self, version_name: str): def upgradeable_version(self, version_name: str) -> None:
self._upgradeable_version = version_name self._upgradeable_version = version_name
# endregion # endregion
@ -1346,7 +1353,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_incorrectly_parsed return self._is_incorrectly_parsed
@is_incorrectly_constructed.setter @is_incorrectly_constructed.setter
def is_incorrectly_constructed(self, incorrect: bool): def is_incorrectly_constructed(self, incorrect: bool) -> None:
self._is_incorrectly_parsed = incorrect self._is_incorrectly_parsed = incorrect
def add_constructor_variables(self) -> None: def add_constructor_variables(self) -> None:
@ -1358,8 +1365,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
constructor_variable = FunctionContract(self.compilation_unit) constructor_variable = FunctionContract(self.compilation_unit)
constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES) constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES)
constructor_variable.set_contract(self) constructor_variable.set_contract(self) # type: ignore
constructor_variable.set_contract_declarer(self) constructor_variable.set_contract_declarer(self) # type: ignore
constructor_variable.set_visibility("internal") constructor_variable.set_visibility("internal")
# For now, source mapping of the constructor variable is the whole contract # For now, source mapping of the constructor variable is the whole contract
# Could be improved with a targeted source mapping # Could be improved with a targeted source mapping
@ -1390,8 +1397,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
constructor_variable.set_function_type( constructor_variable.set_function_type(
FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES
) )
constructor_variable.set_contract(self) constructor_variable.set_contract(self) # type: ignore
constructor_variable.set_contract_declarer(self) constructor_variable.set_contract_declarer(self) # type: ignore
constructor_variable.set_visibility("internal") constructor_variable.set_visibility("internal")
# For now, source mapping of the constructor variable is the whole contract # For now, source mapping of the constructor variable is the whole contract
# Could be improved with a targeted source mapping # Could be improved with a targeted source mapping
@ -1472,22 +1479,23 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
all_ssa_state_variables_instances[v.canonical_name] = new_var all_ssa_state_variables_instances[v.canonical_name] = new_var
self._initial_state_variables.append(new_var) self._initial_state_variables.append(new_var)
for func in self.functions + self.modifiers: for func in self.functions + list(self.modifiers):
func.generate_slithir_ssa(all_ssa_state_variables_instances) func.generate_slithir_ssa(all_ssa_state_variables_instances)
def fix_phi(self) -> None: def fix_phi(self) -> None:
last_state_variables_instances = {} last_state_variables_instances: Dict[str, List["StateVariable"]] = {}
initial_state_variables_instances = {} initial_state_variables_instances: Dict[str, "StateVariable"] = {}
for v in self._initial_state_variables: for v in self._initial_state_variables:
last_state_variables_instances[v.canonical_name] = [] last_state_variables_instances[v.canonical_name] = []
initial_state_variables_instances[v.canonical_name] = v initial_state_variables_instances[v.canonical_name] = v
for func in self.functions + self.modifiers: for func in self.functions + list(self.modifiers):
result = func.get_last_ssa_state_variables_instances() result = func.get_last_ssa_state_variables_instances()
for variable_name, instances in result.items(): for variable_name, instances in result.items():
last_state_variables_instances[variable_name] += instances # TODO: investigate the next operation
last_state_variables_instances[variable_name] += list(instances)
for func in self.functions + self.modifiers: for func in self.functions + list(self.modifiers):
func.fix_phi(last_state_variables_instances, initial_state_variables_instances) func.fix_phi(last_state_variables_instances, initial_state_variables_instances)
# endregion # endregion
@ -1497,7 +1505,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
################################################################################### ###################################################################################
################################################################################### ###################################################################################
def __eq__(self, other: SourceMapping) -> bool: def __eq__(self, other: Any) -> bool:
if isinstance(other, str): if isinstance(other, str):
return other == self.name return other == self.name
return NotImplemented return NotImplemented
@ -1511,6 +1519,6 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self.name return self.name
def __hash__(self) -> int: def __hash__(self) -> int:
return self._id return self._id # type:ignore
# endregion # endregion

@ -0,0 +1,29 @@
from typing import TYPE_CHECKING, Optional
from slither.core.source_mapping.source_mapping import SourceMapping
if TYPE_CHECKING:
from slither.core.declarations import Contract
class ContractLevel(SourceMapping):
"""
This class is used to represent objects that are at the contract level
The opposite is TopLevel
"""
def __init__(self) -> None:
super().__init__()
# TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._contract: Optional["Contract"] = None
def set_contract(self, contract: "Contract") -> None:
self._contract = contract
@property
def contract(self) -> "Contract":
assert self._contract
return self._contract

@ -1,4 +1,4 @@
from typing import List, TYPE_CHECKING, Optional, Type, Union from typing import List, TYPE_CHECKING, Optional, Type
from slither.core.solidity_types import UserDefinedType from slither.core.solidity_types import UserDefinedType
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
@ -42,7 +42,7 @@ class CustomError(SourceMapping):
################################################################################### ###################################################################################
@staticmethod @staticmethod
def _convert_type_for_solidity_signature(t: Optional[Union[Type, List[Type]]]) -> str: def _convert_type_for_solidity_signature(t: Optional[Type]) -> str:
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from slither.core.declarations import Contract from slither.core.declarations import Contract
@ -51,7 +51,7 @@ class CustomError(SourceMapping):
return str(t) return str(t)
@property @property
def solidity_signature(self) -> Optional[str]: def solidity_signature(self) -> str:
""" """
Return a signature following the Solidity Standard Return a signature following the Solidity Standard
Contract and converted into address Contract and converted into address
@ -63,7 +63,7 @@ class CustomError(SourceMapping):
# (set_solidity_sig was not called before find_variable) # (set_solidity_sig was not called before find_variable)
if self._solidity_signature is None: if self._solidity_signature is None:
raise ValueError("Custom Error not yet built") raise ValueError("Custom Error not yet built")
return self._solidity_signature return self._solidity_signature # type: ignore
def set_solidity_sig(self) -> None: def set_solidity_sig(self) -> None:
""" """
@ -72,7 +72,7 @@ class CustomError(SourceMapping):
Returns: Returns:
""" """
parameters = [x.type for x in self.parameters] parameters = [x.type for x in self.parameters if x.type]
self._full_name = self.name + "(" + ",".join(map(str, parameters)) + ")" self._full_name = self.name + "(" + ",".join(map(str, parameters)) + ")"
solidity_parameters = map(self._convert_type_for_solidity_signature, parameters) solidity_parameters = map(self._convert_type_for_solidity_signature, parameters)
self._solidity_signature = self.name + "(" + ",".join(solidity_parameters) + ")" self._solidity_signature = self.name + "(" + ",".join(solidity_parameters) + ")"

@ -1,9 +1,15 @@
from slither.core.children.child_contract import ChildContract from typing import TYPE_CHECKING
from slither.core.declarations.contract_level import ContractLevel
from slither.core.declarations.custom_error import CustomError from slither.core.declarations.custom_error import CustomError
if TYPE_CHECKING:
from slither.core.declarations import Contract
class CustomErrorContract(CustomError, ChildContract): class CustomErrorContract(CustomError, ContractLevel):
def is_declared_by(self, contract): def is_declared_by(self, contract: "Contract") -> bool:
""" """
Check if the element is declared by the contract Check if the element is declared by the contract
:param contract: :param contract:

@ -9,6 +9,6 @@ if TYPE_CHECKING:
class CustomErrorTopLevel(CustomError, TopLevel): class CustomErrorTopLevel(CustomError, TopLevel):
def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"): def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None:
super().__init__(compilation_unit) super().__init__(compilation_unit)
self.file_scope: "FileScope" = scope self.file_scope: "FileScope" = scope

@ -1,13 +1,13 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from slither.core.children.child_contract import ChildContract from slither.core.declarations.contract_level import ContractLevel
from slither.core.declarations import Enum from slither.core.declarations import Enum
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.declarations import Contract from slither.core.declarations import Contract
class EnumContract(Enum, ChildContract): class EnumContract(Enum, ContractLevel):
def is_declared_by(self, contract: "Contract") -> bool: def is_declared_by(self, contract: "Contract") -> bool:
""" """
Check if the element is declared by the contract Check if the element is declared by the contract

@ -1,6 +1,6 @@
from typing import List, Tuple, TYPE_CHECKING from typing import List, Tuple, TYPE_CHECKING
from slither.core.children.child_contract import ChildContract from slither.core.declarations.contract_level import ContractLevel
from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.event_variable import EventVariable from slither.core.variables.event_variable import EventVariable
@ -8,7 +8,7 @@ if TYPE_CHECKING:
from slither.core.declarations import Contract from slither.core.declarations import Contract
class Event(ChildContract, SourceMapping): class Event(ContractLevel, SourceMapping):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._name = None self._name = None

@ -47,7 +47,6 @@ if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.scope.scope import FileScope from slither.core.scope.scope import FileScope
from slither.slithir.variables.state_variable import StateIRVariable from slither.slithir.variables.state_variable import StateIRVariable
from slither.core.declarations.function_contract import FunctionContract
LOGGER = logging.getLogger("Function") LOGGER = logging.getLogger("Function")
ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"]) ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"])
@ -298,7 +297,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
def contains_assembly(self, c: bool): def contains_assembly(self, c: bool):
self._contains_assembly = c self._contains_assembly = c
def can_reenter(self, callstack: Optional[List["FunctionContract"]] = None) -> bool: def can_reenter(self, callstack: Optional[List[Union["Function", "Variable"]]] = None) -> bool:
""" """
Check if the function can re-enter Check if the function can re-enter
Follow internal calls. Follow internal calls.
@ -1720,8 +1719,8 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
def fix_phi( def fix_phi(
self, self,
last_state_variables_instances: Dict[str, List["StateIRVariable"]], last_state_variables_instances: Dict[str, List["StateVariable"]],
initial_state_variables_instances: Dict[str, "StateIRVariable"], initial_state_variables_instances: Dict[str, "StateVariable"],
) -> None: ) -> None:
from slither.slithir.operations import InternalCall, PhiCallback from slither.slithir.operations import InternalCall, PhiCallback
from slither.slithir.variables import Constant, StateIRVariable from slither.slithir.variables import Constant, StateIRVariable

@ -1,10 +1,9 @@
""" """
Function module Function module
""" """
from typing import Dict, TYPE_CHECKING, List, Tuple from typing import Dict, TYPE_CHECKING, List, Tuple, Optional
from slither.core.children.child_contract import ChildContract from slither.core.declarations.contract_level import ContractLevel
from slither.core.children.child_inheritance import ChildInheritance
from slither.core.declarations import Function from slither.core.declarations import Function
from slither.utils.code_complexity import compute_cyclomatic_complexity from slither.utils.code_complexity import compute_cyclomatic_complexity
@ -15,9 +14,31 @@ if TYPE_CHECKING:
from slither.core.declarations import Contract from slither.core.declarations import Contract
from slither.core.scope.scope import FileScope from slither.core.scope.scope import FileScope
from slither.slithir.variables.state_variable import StateIRVariable from slither.slithir.variables.state_variable import StateIRVariable
from slither.core.compilation_unit import SlitherCompilationUnit
class FunctionContract(Function, ChildContract, ChildInheritance): class FunctionContract(Function, ContractLevel):
def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None:
super().__init__(compilation_unit)
self._contract_declarer: Optional["Contract"] = None
def set_contract_declarer(self, contract: "Contract") -> None:
self._contract_declarer = contract
@property
def contract_declarer(self) -> "Contract":
"""
Return the contract where this function was declared. Only functions have both a contract, and contract_declarer
This is because we need to have separate representation of the function depending of the contract's context
For example a function calling super.f() will generate different IR depending on the current contract's inheritance
Returns:
The contract where this function was declared
"""
assert self._contract_declarer
return self._contract_declarer
@property @property
def canonical_name(self) -> str: def canonical_name(self) -> str:
""" """

@ -82,7 +82,7 @@ SOLIDITY_FUNCTIONS: Dict[str, List[str]] = {
} }
def solidity_function_signature(name): def solidity_function_signature(name: str) -> str:
""" """
Return the function signature (containing the return value) Return the function signature (containing the return value)
It is useful if a solidity function is used as a pointer It is useful if a solidity function is used as a pointer
@ -106,7 +106,7 @@ class SolidityVariable(SourceMapping):
assert name in SOLIDITY_VARIABLES or name.endswith(("_slot", "_offset")) assert name in SOLIDITY_VARIABLES or name.endswith(("_slot", "_offset"))
@property @property
def state_variable(self): def state_variable(self) -> str:
if self._name.endswith("_slot"): if self._name.endswith("_slot"):
return self._name[:-5] return self._name[:-5]
if self._name.endswith("_offset"): if self._name.endswith("_offset"):
@ -125,7 +125,7 @@ class SolidityVariable(SourceMapping):
def __str__(self) -> str: def __str__(self) -> str:
return self._name return self._name
def __eq__(self, other: SourceMapping) -> bool: def __eq__(self, other: Any) -> bool:
return self.__class__ == other.__class__ and self.name == other.name return self.__class__ == other.__class__ and self.name == other.name
def __hash__(self) -> int: def __hash__(self) -> int:
@ -182,13 +182,13 @@ class SolidityFunction(SourceMapping):
return self._return_type return self._return_type
@return_type.setter @return_type.setter
def return_type(self, r: List[Union[TypeInformation, ElementaryType]]): def return_type(self, r: List[Union[TypeInformation, ElementaryType]]) -> None:
self._return_type = r self._return_type = r
def __str__(self) -> str: def __str__(self) -> str:
return self._name return self._name
def __eq__(self, other: "SolidityFunction") -> bool: def __eq__(self, other: Any) -> bool:
return self.__class__ == other.__class__ and self.name == other.name return self.__class__ == other.__class__ and self.name == other.name
def __hash__(self) -> int: def __hash__(self) -> int:
@ -201,7 +201,7 @@ class SolidityCustomRevert(SolidityFunction):
self._custom_error = custom_error self._custom_error = custom_error
self._return_type: List[Union[TypeInformation, ElementaryType]] = [] self._return_type: List[Union[TypeInformation, ElementaryType]] = []
def __eq__(self, other: Union["SolidityCustomRevert", SolidityFunction]) -> bool: def __eq__(self, other: Any) -> bool:
return ( return (
self.__class__ == other.__class__ self.__class__ == other.__class__
and self.name == other.name and self.name == other.name

@ -1,8 +1,8 @@
from slither.core.children.child_contract import ChildContract from slither.core.declarations.contract_level import ContractLevel
from slither.core.declarations import Structure from slither.core.declarations import Structure
class StructureContract(Structure, ChildContract): class StructureContract(Structure, ContractLevel):
def is_declared_by(self, contract): def is_declared_by(self, contract):
""" """
Check if the element is declared by the contract Check if the element is declared by the contract

@ -2,4 +2,8 @@ from slither.core.source_mapping.source_mapping import SourceMapping
class TopLevel(SourceMapping): class TopLevel(SourceMapping):
pass """
This class is used to represent objects that are at the top level
The opposite is ContractLevel
"""

@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, List, Dict, Union from typing import TYPE_CHECKING, List, Dict, Union
from slither.core.declarations.contract import USING_FOR_KEY, USING_FOR_ITEM
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
from slither.core.declarations.top_level import TopLevel from slither.core.declarations.top_level import TopLevel
@ -14,5 +15,5 @@ class UsingForTopLevel(TopLevel):
self.file_scope: "FileScope" = scope self.file_scope: "FileScope" = scope
@property @property
def using_for(self) -> Dict[Union[str, Type], List[Type]]: def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]:
return self._using_for return self._using_for

@ -95,4 +95,5 @@ def compute_dominance_frontier(nodes: List["Node"]) -> None:
runner.dominance_frontier = runner.dominance_frontier.union({node}) runner.dominance_frontier = runner.dominance_frontier.union({node})
while runner != node.immediate_dominator: while runner != node.immediate_dominator:
runner.dominance_frontier = runner.dominance_frontier.union({node}) runner.dominance_frontier = runner.dominance_frontier.union({node})
assert runner.immediate_dominator
runner = runner.immediate_dominator runner = runner.immediate_dominator

@ -2,7 +2,6 @@ import logging
from enum import Enum from enum import Enum
from typing import Optional, TYPE_CHECKING, List from typing import Optional, TYPE_CHECKING, List
from slither.core.expressions.expression_typed import ExpressionTyped
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.exceptions import SlitherCoreError from slither.core.exceptions import SlitherCoreError
@ -78,7 +77,7 @@ class AssignmentOperationType(Enum):
raise SlitherCoreError(f"str: Unknown operation type {self})") raise SlitherCoreError(f"str: Unknown operation type {self})")
class AssignmentOperation(ExpressionTyped): class AssignmentOperation(Expression):
def __init__( def __init__(
self, self,
left_expression: Expression, left_expression: Expression,
@ -91,7 +90,7 @@ class AssignmentOperation(ExpressionTyped):
super().__init__() super().__init__()
left_expression.set_lvalue() left_expression.set_lvalue()
self._expressions = [left_expression, right_expression] self._expressions = [left_expression, right_expression]
self._type: Optional["AssignmentOperationType"] = expression_type self._type: AssignmentOperationType = expression_type
self._expression_return_type: Optional["Type"] = expression_return_type self._expression_return_type: Optional["Type"] = expression_return_type
@property @property

@ -2,7 +2,6 @@ import logging
from enum import Enum from enum import Enum
from typing import List from typing import List
from slither.core.expressions.expression_typed import ExpressionTyped
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.exceptions import SlitherCoreError from slither.core.exceptions import SlitherCoreError
@ -148,7 +147,7 @@ class BinaryOperationType(Enum):
raise SlitherCoreError(f"str: Unknown operation type {self})") raise SlitherCoreError(f"str: Unknown operation type {self})")
class BinaryOperation(ExpressionTyped): class BinaryOperation(Expression):
def __init__( def __init__(
self, self,
left_expression: Expression, left_expression: Expression,

@ -22,7 +22,7 @@ class CallExpression(Expression): # pylint: disable=too-many-instance-attribute
return self._value return self._value
@call_value.setter @call_value.setter
def call_value(self, v): def call_value(self, v: Optional[Expression]) -> None:
self._value = v self._value = v
@property @property
@ -30,15 +30,15 @@ class CallExpression(Expression): # pylint: disable=too-many-instance-attribute
return self._gas return self._gas
@call_gas.setter @call_gas.setter
def call_gas(self, gas): def call_gas(self, gas: Optional[Expression]) -> None:
self._gas = gas self._gas = gas
@property @property
def call_salt(self): def call_salt(self) -> Optional[Expression]:
return self._salt return self._salt
@call_salt.setter @call_salt.setter
def call_salt(self, salt): def call_salt(self, salt: Optional[Expression]) -> None:
self._salt = salt self._salt = salt
@property @property

@ -42,7 +42,7 @@ class ConditionalExpression(Expression):
def then_expression(self) -> Expression: def then_expression(self) -> Expression:
return self._then_expression return self._then_expression
def __str__(self): def __str__(self) -> str:
return ( return (
"if " "if "
+ str(self._if_expression) + str(self._if_expression)

@ -1,20 +0,0 @@
from typing import Optional, TYPE_CHECKING
from slither.core.expressions.expression import Expression
if TYPE_CHECKING:
from slither.core.solidity_types.type import Type
class ExpressionTyped(Expression): # pylint: disable=too-few-public-methods
def __init__(self) -> None:
super().__init__()
self._type: Optional["Type"] = None
@property
def type(self) -> Optional["Type"]:
return self._type
@type.setter
def type(self, new_type: "Type"):
self._type = new_type

@ -1,18 +1,80 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional, Union
from slither.core.declarations.contract_level import ContractLevel
from slither.core.declarations.top_level import TopLevel
from slither.core.expressions.expression import Expression
from slither.core.variables.variable import Variable
from slither.core.expressions.expression_typed import ExpressionTyped
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.variables.variable import Variable from slither.core.solidity_types.type import Type
from slither.core.declarations import Contract, SolidityVariable, SolidityFunction
from slither.solc_parsing.yul.evm_functions import YulBuiltin
class Identifier(ExpressionTyped): class Identifier(Expression):
def __init__(self, value) -> None: def __init__(
self,
value: Union[
Variable,
"TopLevel",
"ContractLevel",
"Contract",
"SolidityVariable",
"SolidityFunction",
"YulBuiltin",
],
) -> None:
super().__init__() super().__init__()
self._value: "Variable" = value
# pylint: disable=import-outside-toplevel
from slither.core.declarations import Contract, SolidityVariable, SolidityFunction
from slither.solc_parsing.yul.evm_functions import YulBuiltin
assert isinstance(
value,
(
Variable,
TopLevel,
ContractLevel,
Contract,
SolidityVariable,
SolidityFunction,
YulBuiltin,
),
)
self._value: Union[
Variable,
"TopLevel",
"ContractLevel",
"Contract",
"SolidityVariable",
"SolidityFunction",
"YulBuiltin",
] = value
self._type: Optional["Type"] = None
@property
def type(self) -> Optional["Type"]:
return self._type
@type.setter
def type(self, new_type: "Type") -> None:
self._type = new_type
@property @property
def value(self) -> "Variable": def value(
self,
) -> Union[
Variable,
"TopLevel",
"ContractLevel",
"Contract",
"SolidityVariable",
"SolidityFunction",
"YulBuiltin",
]:
return self._value return self._value
def __str__(self) -> str: def __str__(self) -> str:

@ -1,27 +1,18 @@
from typing import Union, List, TYPE_CHECKING from typing import Union, List
from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression
from slither.core.expressions.identifier import Identifier from slither.core.expressions.identifier import Identifier
from slither.core.expressions.literal import Literal from slither.core.expressions.literal import Literal
if TYPE_CHECKING: class IndexAccess(Expression):
from slither.core.expressions.expression import Expression
from slither.core.solidity_types.type import Type
class IndexAccess(ExpressionTyped):
def __init__( def __init__(
self, self,
left_expression: Union["IndexAccess", Identifier], left_expression: Union["IndexAccess", Identifier],
right_expression: Union[Literal, Identifier], right_expression: Union[Literal, Identifier],
index_type: str,
) -> None: ) -> None:
super().__init__() super().__init__()
self._expressions = [left_expression, right_expression] self._expressions = [left_expression, right_expression]
# TODO type of undexAccess is not always a Type
# assert isinstance(index_type, Type)
self._type: "Type" = index_type
@property @property
def expressions(self) -> List["Expression"]: def expressions(self) -> List["Expression"]:
@ -35,9 +26,5 @@ class IndexAccess(ExpressionTyped):
def expression_right(self) -> "Expression": def expression_right(self) -> "Expression":
return self._expressions[1] return self._expressions[1]
@property
def type(self) -> "Type":
return self._type
def __str__(self) -> str: def __str__(self) -> str:
return str(self.expression_left) + "[" + str(self.expression_right) + "]" return str(self.expression_left) + "[" + str(self.expression_right) + "]"

@ -1,4 +1,4 @@
from typing import Optional, Union, TYPE_CHECKING from typing import Optional, Union, TYPE_CHECKING, Any
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.solidity_types.elementary_type import Fixed, Int, Ufixed, Uint from slither.core.solidity_types.elementary_type import Fixed, Int, Ufixed, Uint
@ -47,7 +47,7 @@ class Literal(Expression):
# be sure to handle any character # be sure to handle any character
return str(self._value) return str(self._value)
def __eq__(self, other) -> bool: def __eq__(self, other: Any) -> bool:
if not isinstance(other, Literal): if not isinstance(other, Literal):
return False return False
return (self.value, self.subdenomination) == (other.value, other.subdenomination) return (self.value, self.subdenomination) == (other.value, other.subdenomination)

@ -1,10 +1,9 @@
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.expressions.expression_typed import ExpressionTyped
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
class MemberAccess(ExpressionTyped): class MemberAccess(Expression):
def __init__(self, member_name: str, member_type: str, expression: Expression) -> None: def __init__(self, member_name: str, member_type: str, expression: Expression) -> None:
# assert isinstance(member_type, Type) # assert isinstance(member_type, Type)
# TODO member_type is not always a Type # TODO member_type is not always a Type

@ -1,6 +1,5 @@
from typing import Union, TYPE_CHECKING from typing import Union, TYPE_CHECKING
from slither.core.expressions.expression_typed import ExpressionTyped
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
@ -14,7 +13,7 @@ if TYPE_CHECKING:
from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.user_defined_type import UserDefinedType
class TypeConversion(ExpressionTyped): class TypeConversion(Expression):
def __init__( def __init__(
self, self,
expression: Union[ expression: Union[
@ -28,6 +27,14 @@ class TypeConversion(ExpressionTyped):
self._expression: Expression = expression self._expression: Expression = expression
self._type: Type = expression_type self._type: Type = expression_type
@property
def type(self) -> Type:
return self._type
@type.setter
def type(self, new_type: Type) -> None:
self._type = new_type
@property @property
def expression(self) -> Expression: def expression(self) -> Expression:
return self._expression return self._expression

@ -2,7 +2,6 @@ import logging
from typing import Union from typing import Union
from enum import Enum from enum import Enum
from slither.core.expressions.expression_typed import ExpressionTyped
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.exceptions import SlitherCoreError from slither.core.exceptions import SlitherCoreError
from slither.core.expressions.identifier import Identifier from slither.core.expressions.identifier import Identifier
@ -91,7 +90,7 @@ class UnaryOperationType(Enum):
raise SlitherCoreError(f"is_prefix: Unknown operation type {operation_type}") raise SlitherCoreError(f"is_prefix: Unknown operation type {operation_type}")
class UnaryOperation(ExpressionTyped): class UnaryOperation(Expression):
def __init__( def __init__(
self, self,
expression: Union[Literal, Identifier, IndexAccess, TupleExpression], expression: Union[Literal, Identifier, IndexAccess, TupleExpression],

@ -13,7 +13,7 @@ from typing import Optional, Dict, List, Set, Union, Tuple
from crytic_compile import CryticCompile from crytic_compile import CryticCompile
from crytic_compile.utils.naming import Filename from crytic_compile.utils.naming import Filename
from slither.core.children.child_contract import ChildContract from slither.core.declarations.contract_level import ContractLevel
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.context.context import Context from slither.core.context.context import Context
from slither.core.declarations import Contract, FunctionContract from slither.core.declarations import Contract, FunctionContract
@ -206,7 +206,7 @@ class SlitherCore(Context):
isinstance(thing, FunctionContract) isinstance(thing, FunctionContract)
and thing.contract_declarer == thing.contract and thing.contract_declarer == thing.contract
) )
or (isinstance(thing, ChildContract) and not isinstance(thing, FunctionContract)) or (isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract))
): ):
self._offset_to_objects[definition.filename][offset].add(thing) self._offset_to_objects[definition.filename][offset].add(thing)
@ -224,7 +224,7 @@ class SlitherCore(Context):
and thing.contract_declarer == thing.contract and thing.contract_declarer == thing.contract
) )
or ( or (
isinstance(thing, ChildContract) and not isinstance(thing, FunctionContract) isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract)
) )
): ):
self._offset_to_objects[definition.filename][offset].add(thing) self._offset_to_objects[definition.filename][offset].add(thing)
@ -482,8 +482,8 @@ class SlitherCore(Context):
################################################################################### ###################################################################################
@property @property
def crytic_compile(self) -> Optional[CryticCompile]: def crytic_compile(self) -> CryticCompile:
return self._crytic_compile return self._crytic_compile # type: ignore
# endregion # endregion
################################################################################### ###################################################################################

@ -1,38 +1,37 @@
from typing import Union, Optional, Tuple, Any, TYPE_CHECKING from typing import Union, Optional, Tuple, Any, TYPE_CHECKING
from slither.core.expressions.expression import Expression from slither.core.expressions.expression import Expression
from slither.core.expressions.literal import Literal
from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
from slither.visitors.expression.constants_folding import ConstantFolding from slither.visitors.expression.constants_folding import ConstantFolding
from slither.core.expressions.literal import Literal
if TYPE_CHECKING: if TYPE_CHECKING:
from slither.core.expressions.binary_operation import BinaryOperation from slither.core.expressions.binary_operation import BinaryOperation
from slither.core.expressions.identifier import Identifier from slither.core.expressions.identifier import Identifier
from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.solidity_types.function_type import FunctionType
from slither.core.solidity_types.type_alias import TypeAliasTopLevel
class ArrayType(Type): class ArrayType(Type):
def __init__( def __init__(
self, self,
t: Union["TypeAliasTopLevel", "ArrayType", "FunctionType", "ElementaryType"], t: Type,
length: Optional[Union["Identifier", Literal, "BinaryOperation", int]], length: Optional[Union["Identifier", Literal, "BinaryOperation", int]],
) -> None: ) -> None:
assert isinstance(t, Type) assert isinstance(t, Type)
if length: if length:
if isinstance(length, int): if isinstance(length, int):
length = Literal(length, "uint256") length = Literal(length, ElementaryType("uint256"))
assert isinstance(length, Expression)
super().__init__() super().__init__()
self._type: Type = t self._type: Type = t
assert length is None or isinstance(length, Expression)
self._length: Optional[Expression] = length self._length: Optional[Expression] = length
if length: if length:
if not isinstance(length, Literal): if not isinstance(length, Literal):
cf = ConstantFolding(length, "uint256") cf = ConstantFolding(length, "uint256")
length = cf.result() length = cf.result()
self._length_value = length self._length_value: Optional[Literal] = length
else: else:
self._length_value = None self._length_value = None

@ -1,5 +1,5 @@
import itertools import itertools
from typing import Tuple from typing import Tuple, Optional, Any
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
@ -176,7 +176,7 @@ class ElementaryType(Type):
return self.type return self.type
@property @property
def size(self) -> int: def size(self) -> Optional[int]:
""" """
Return the size in bits Return the size in bits
Return None if the size is not known Return None if the size is not known
@ -219,7 +219,7 @@ class ElementaryType(Type):
def __str__(self) -> str: def __str__(self) -> str:
return self._type return self._type
def __eq__(self, other) -> bool: def __eq__(self, other: Any) -> bool:
if not isinstance(other, ElementaryType): if not isinstance(other, ElementaryType):
return False return False
return self.type == other.type return self.type == other.type

@ -1,4 +1,4 @@
from typing import Union, Tuple, TYPE_CHECKING from typing import Union, Tuple, TYPE_CHECKING, Any
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
@ -38,7 +38,7 @@ class MappingType(Type):
def __str__(self) -> str: def __str__(self) -> str:
return f"mapping({str(self._from)} => {str(self._to)})" return f"mapping({str(self._from)} => {str(self._to)})"
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
if not isinstance(other, MappingType): if not isinstance(other, MappingType):
return False return False
return self.type_from == other.type_from and self.type_to == other.type_to return self.type_from == other.type_from and self.type_to == other.type_to

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from slither.core.children.child_contract import ChildContract
from slither.core.declarations.top_level import TopLevel from slither.core.declarations.top_level import TopLevel
from slither.core.declarations.contract_level import ContractLevel
from slither.core.solidity_types import Type, ElementaryType from slither.core.solidity_types import Type, ElementaryType
if TYPE_CHECKING: if TYPE_CHECKING:
@ -40,7 +40,7 @@ class TypeAlias(Type):
class TypeAliasTopLevel(TypeAlias, TopLevel): class TypeAliasTopLevel(TypeAlias, TopLevel):
def __init__(self, underlying_type: Type, name: str, scope: "FileScope") -> None: def __init__(self, underlying_type: ElementaryType, name: str, scope: "FileScope") -> None:
super().__init__(underlying_type, name) super().__init__(underlying_type, name)
self.file_scope: "FileScope" = scope self.file_scope: "FileScope" = scope
@ -48,8 +48,8 @@ class TypeAliasTopLevel(TypeAlias, TopLevel):
return self.name return self.name
class TypeAliasContract(TypeAlias, ChildContract): class TypeAliasContract(TypeAlias, ContractLevel):
def __init__(self, underlying_type: Type, name: str, contract: "Contract") -> None: def __init__(self, underlying_type: ElementaryType, name: str, contract: "Contract") -> None:
super().__init__(underlying_type, name) super().__init__(underlying_type, name)
self._contract: "Contract" = contract self._contract: "Contract" = contract

@ -1,4 +1,4 @@
from typing import Union, TYPE_CHECKING, Tuple from typing import Union, TYPE_CHECKING, Tuple, Any
from slither.core.solidity_types import ElementaryType from slither.core.solidity_types import ElementaryType
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
@ -40,10 +40,10 @@ class TypeInformation(Type):
def is_dynamic(self) -> bool: def is_dynamic(self) -> bool:
raise NotImplementedError raise NotImplementedError
def __str__(self): def __str__(self) -> str:
return f"type({self.type.name})" return f"type({self.type.name})"
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
if not isinstance(other, TypeInformation): if not isinstance(other, TypeInformation):
return False return False
return self.type == other.type return self.type == other.type

@ -1,6 +1,6 @@
import re import re
from abc import ABCMeta from abc import ABCMeta
from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional, Any
from Crypto.Hash import SHA1 from Crypto.Hash import SHA1
from crytic_compile.utils.naming import Filename from crytic_compile.utils.naming import Filename
@ -98,10 +98,10 @@ class Source:
filename_short: str = self.filename.short if self.filename.short else "" filename_short: str = self.filename.short if self.filename.short else ""
return f"{filename_short}{lines}" return f"{filename_short}{lines}"
def __hash__(self): def __hash__(self) -> int:
return hash(str(self)) return hash(str(self))
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
if not isinstance(other, type(self)): if not isinstance(other, type(self)):
return NotImplemented return NotImplemented
return ( return (

@ -1,8 +1,7 @@
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.core.children.child_event import ChildEvent
class EventVariable(ChildEvent, Variable): class EventVariable(Variable):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._indexed = False self._indexed = False
@ -16,5 +15,5 @@ class EventVariable(ChildEvent, Variable):
return self._indexed return self._indexed
@indexed.setter @indexed.setter
def indexed(self, is_indexed: bool): def indexed(self, is_indexed: bool) -> None:
self._indexed = is_indexed self._indexed = is_indexed

@ -1,7 +1,6 @@
from typing import Optional from typing import Optional, TYPE_CHECKING
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.core.children.child_function import ChildFunction
from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.user_defined_type import UserDefinedType
from slither.core.solidity_types.array_type import ArrayType from slither.core.solidity_types.array_type import ArrayType
from slither.core.solidity_types.mapping_type import MappingType from slither.core.solidity_types.mapping_type import MappingType
@ -9,11 +8,23 @@ from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.declarations.structure import Structure from slither.core.declarations.structure import Structure
if TYPE_CHECKING: # type: ignore
from slither.core.declarations import Function
class LocalVariable(ChildFunction, Variable):
class LocalVariable(Variable):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._location: Optional[str] = None self._location: Optional[str] = None
self._function: Optional["Function"] = None
def set_function(self, function: "Function") -> None:
self._function = function
@property
def function(self) -> "Function":
assert self._function
return self._function
def set_location(self, loc: str) -> None: def set_location(self, loc: str) -> None:
self._location = loc self._location = loc

@ -1,6 +1,6 @@
from typing import Optional, TYPE_CHECKING from typing import Optional, TYPE_CHECKING
from slither.core.children.child_contract import ChildContract from slither.core.declarations.contract_level import ContractLevel
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
if TYPE_CHECKING: if TYPE_CHECKING:
@ -8,7 +8,7 @@ if TYPE_CHECKING:
from slither.core.declarations import Contract from slither.core.declarations import Contract
class StateVariable(ChildContract, Variable): class StateVariable(ContractLevel, Variable):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._node_initialization: Optional["Node"] = None self._node_initialization: Optional["Node"] = None

@ -1,6 +1,19 @@
from typing import TYPE_CHECKING, Optional
from slither.core.variables.variable import Variable from slither.core.variables.variable import Variable
from slither.core.children.child_structure import ChildStructure
class StructureVariable(ChildStructure, Variable): if TYPE_CHECKING:
pass from slither.core.declarations import Structure
class StructureVariable(Variable):
def __init__(self) -> None:
super().__init__()
self._structure: Optional["Structure"] = None
def set_structure(self, structure: "Structure") -> None:
self._structure = structure
@property
def structure(self) -> "Structure":
return self._structure

@ -55,7 +55,7 @@ class Variable(SourceMapping):
return self._initialized return self._initialized
@initialized.setter @initialized.setter
def initialized(self, is_init: bool): def initialized(self, is_init: bool) -> None:
self._initialized = is_init self._initialized = is_init
@property @property
@ -73,23 +73,24 @@ class Variable(SourceMapping):
return self._name return self._name
@name.setter @name.setter
def name(self, name): def name(self, name: str) -> None:
self._name = name self._name = name
@property @property
def type(self) -> Optional[Union[Type, List[Type]]]: def type(self) -> Optional[Type]:
return self._type return self._type
@type.setter @type.setter
def type(self, types: Union[Type, List[Type]]): def type(self, new_type: Type) -> None:
self._type = types assert isinstance(new_type, Type)
self._type = new_type
@property @property
def is_constant(self) -> bool: def is_constant(self) -> bool:
return self._is_constant return self._is_constant
@is_constant.setter @is_constant.setter
def is_constant(self, is_cst: bool): def is_constant(self, is_cst: bool) -> None:
self._is_constant = is_cst self._is_constant = is_cst
@property @property
@ -159,8 +160,8 @@ class Variable(SourceMapping):
return ( return (
self.name, self.name,
[str(x) for x in export_nested_types_from_variable(self)], [str(x) for x in export_nested_types_from_variable(self)], # type: ignore
[str(x) for x in export_return_type_from_variable(self)], [str(x) for x in export_return_type_from_variable(self)], # type: ignore
) )
@property @property
@ -178,4 +179,5 @@ class Variable(SourceMapping):
return f'{name}({",".join(parameters)})' return f'{name}({",".join(parameters)})'
def __str__(self) -> str: def __str__(self) -> str:
assert self._name
return self._name return self._name

@ -59,6 +59,8 @@ ALL_SOLC_VERSIONS_06 = make_solc_versions(6, 0, 12)
ALL_SOLC_VERSIONS_07 = make_solc_versions(7, 0, 6) ALL_SOLC_VERSIONS_07 = make_solc_versions(7, 0, 6)
# No VERSIONS_08 as it is still in dev # No VERSIONS_08 as it is still in dev
DETECTOR_INFO = List[Union[str, SupportedOutput]]
class AbstractDetector(metaclass=abc.ABCMeta): class AbstractDetector(metaclass=abc.ABCMeta):
ARGUMENT = "" # run the detector with slither.py --ARGUMENT ARGUMENT = "" # run the detector with slither.py --ARGUMENT
@ -251,7 +253,7 @@ class AbstractDetector(metaclass=abc.ABCMeta):
def generate_result( def generate_result(
self, self,
info: Union[str, List[Union[str, SupportedOutput]]], info: DETECTOR_INFO,
additional_fields: Optional[Dict] = None, additional_fields: Optional[Dict] = None,
) -> Output: ) -> Output:
output = Output( output = Output(

@ -1,5 +1,9 @@
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations import Binary, BinaryType from slither.slithir.operations import Binary, BinaryType
from slither.slithir.variables import Constant from slither.slithir.variables import Constant
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
@ -49,7 +53,12 @@ The shift statement will right-shift the constant 8 by `a` bits"""
BinaryType.RIGHT_SHIFT, BinaryType.RIGHT_SHIFT,
]: ]:
if isinstance(ir.variable_left, Constant): if isinstance(ir.variable_left, Constant):
info = [f, " contains an incorrect shift operation: ", node, "\n"] info: DETECTOR_INFO = [
f,
" contains an incorrect shift operation: ",
node,
"\n",
]
json = self.generate_result(info) json = self.generate_result(info)
results.append(json) results.append(json)

@ -2,11 +2,14 @@
Module detecting constant functions Module detecting constant functions
Recursively check the called functions Recursively check the called functions
""" """
from typing import List from typing import List, Dict
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.detectors.abstract_detector import ( from slither.detectors.abstract_detector import (
AbstractDetector, AbstractDetector,
DetectorClassification, DetectorClassification,
ALL_SOLC_VERSIONS_04, ALL_SOLC_VERSIONS_04,
DETECTOR_INFO,
) )
from slither.formatters.attributes.const_functions import custom_format from slither.formatters.attributes.const_functions import custom_format
from slither.utils.output import Output from slither.utils.output import Output
@ -73,7 +76,10 @@ All the calls to `get` revert, breaking Bob's smart contract execution."""
if f.contains_assembly: if f.contains_assembly:
attr = "view" if f.view else "pure" attr = "view" if f.view else "pure"
info = [f, f" is declared {attr} but contains assembly code\n"] info: DETECTOR_INFO = [
f,
f" is declared {attr} but contains assembly code\n",
]
res = self.generate_result(info, {"contains_assembly": True}) res = self.generate_result(info, {"contains_assembly": True})
results.append(res) results.append(res)
@ -81,5 +87,5 @@ All the calls to `get` revert, breaking Bob's smart contract execution."""
return results return results
@staticmethod @staticmethod
def _format(comilation_unit, result): def _format(comilation_unit: SlitherCompilationUnit, result: Dict) -> None:
custom_format(comilation_unit, result) custom_format(comilation_unit, result)

@ -2,11 +2,14 @@
Module detecting constant functions Module detecting constant functions
Recursively check the called functions Recursively check the called functions
""" """
from typing import List from typing import List, Dict
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.detectors.abstract_detector import ( from slither.detectors.abstract_detector import (
AbstractDetector, AbstractDetector,
DetectorClassification, DetectorClassification,
ALL_SOLC_VERSIONS_04, ALL_SOLC_VERSIONS_04,
DETECTOR_INFO,
) )
from slither.formatters.attributes.const_functions import custom_format from slither.formatters.attributes.const_functions import custom_format
from slither.utils.output import Output from slither.utils.output import Output
@ -74,7 +77,7 @@ All the calls to `get` revert, breaking Bob's smart contract execution."""
if variables_written: if variables_written:
attr = "view" if f.view else "pure" attr = "view" if f.view else "pure"
info = [ info: DETECTOR_INFO = [
f, f,
f" is declared {attr} but changes state variables:\n", f" is declared {attr} but changes state variables:\n",
] ]
@ -89,5 +92,5 @@ All the calls to `get` revert, breaking Bob's smart contract execution."""
return results return results
@staticmethod @staticmethod
def _format(slither, result): def _format(slither: SlitherCompilationUnit, result: Dict) -> None:
custom_format(slither, result) custom_format(slither, result)

@ -1,9 +1,14 @@
""" """
Check that the same pragma is used in all the files Check that the same pragma is used in all the files
""" """
from typing import List from typing import List, Dict
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.core.compilation_unit import SlitherCompilationUnit
from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.formatters.attributes.constant_pragma import custom_format from slither.formatters.attributes.constant_pragma import custom_format
from slither.utils.output import Output from slither.utils.output import Output
@ -31,7 +36,7 @@ class ConstantPragma(AbstractDetector):
versions = sorted(list(set(versions))) versions = sorted(list(set(versions)))
if len(versions) > 1: if len(versions) > 1:
info = ["Different versions of Solidity are used:\n"] info: DETECTOR_INFO = ["Different versions of Solidity are used:\n"]
info += [f"\t- Version used: {[str(v) for v in versions]}\n"] info += [f"\t- Version used: {[str(v) for v in versions]}\n"]
for p in sorted(pragma, key=lambda x: x.version): for p in sorted(pragma, key=lambda x: x.version):
@ -44,5 +49,5 @@ class ConstantPragma(AbstractDetector):
return results return results
@staticmethod @staticmethod
def _format(slither, result): def _format(slither: SlitherCompilationUnit, result: Dict) -> None:
custom_format(slither, result) custom_format(slither, result)

@ -5,7 +5,11 @@
import re import re
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.formatters.attributes.incorrect_solc import custom_format from slither.formatters.attributes.incorrect_solc import custom_format
from slither.utils.output import Output from slither.utils.output import Output
@ -141,7 +145,7 @@ Consider using the latest version of Solidity for testing."""
# If we found any disallowed pragmas, we output our findings. # If we found any disallowed pragmas, we output our findings.
if disallowed_pragmas: if disallowed_pragmas:
for (reason, p) in disallowed_pragmas: for (reason, p) in disallowed_pragmas:
info = ["Pragma version", p, f" {reason}\n"] info: DETECTOR_INFO = ["Pragma version", p, f" {reason}\n"]
json = self.generate_result(info) json = self.generate_result(info)

@ -4,7 +4,11 @@
from typing import List from typing import List
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations import ( from slither.slithir.operations import (
HighLevelCall, HighLevelCall,
LowLevelCall, LowLevelCall,
@ -85,7 +89,7 @@ Every Ether sent to `Locked` will be lost."""
funcs_payable = [function for function in contract.functions if function.payable] funcs_payable = [function for function in contract.functions if function.payable]
if funcs_payable: if funcs_payable:
if self.do_no_send_ether(contract): if self.do_no_send_ether(contract):
info = ["Contract locking ether found:\n"] info: DETECTOR_INFO = ["Contract locking ether found:\n"]
info += ["\tContract ", contract, " has payable functions:\n"] info += ["\tContract ", contract, " has payable functions:\n"]
for function in funcs_payable: for function in funcs_payable:
info += ["\t - ", function, "\n"] info += ["\t - ", function, "\n"]

@ -5,7 +5,11 @@ Collect all the interfaces
Check for contracts which implement all interface functions but do not explicitly derive from those interfaces. Check for contracts which implement all interface functions but do not explicitly derive from those interfaces.
""" """
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.utils.output import Output from slither.utils.output import Output
@ -139,7 +143,7 @@ contract Something {
continue continue
intended_interfaces = self.detect_unimplemented_interface(contract, interfaces) intended_interfaces = self.detect_unimplemented_interface(contract, interfaces)
for interface in intended_interfaces: for interface in intended_interfaces:
info = [contract, " should inherit from ", interface, "\n"] info: DETECTOR_INFO = [contract, " should inherit from ", interface, "\n"]
res = self.generate_result(info) res = self.generate_result(info)
results.append(res) results.append(res)
return results return results

@ -2,7 +2,14 @@
Detects the passing of arrays located in memory to functions which expect to modify arrays via storage reference. Detects the passing of arrays located in memory to functions which expect to modify arrays via storage reference.
""" """
from typing import List, Set, Tuple, Union from typing import List, Set, Tuple, Union
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.core.declarations import Function
from slither.core.variables import Variable
from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.core.solidity_types.array_type import ArrayType from slither.core.solidity_types.array_type import ArrayType
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
@ -89,12 +96,7 @@ As a result, Bob's usage of the contract is incorrect."""
@staticmethod @staticmethod
def detect_calls_passing_ref_to_function( def detect_calls_passing_ref_to_function(
contracts: List[Contract], array_modifying_funcs: Set[FunctionContract] contracts: List[Contract], array_modifying_funcs: Set[FunctionContract]
) -> List[ ) -> List[Tuple[Node, Variable, Union[Function, Variable]]]:
Union[
Tuple[Node, StateVariable, FunctionContract],
Tuple[Node, LocalVariable, FunctionContract],
]
]:
""" """
Obtains all calls passing storage arrays by value to a function which cannot write to them successfully. Obtains all calls passing storage arrays by value to a function which cannot write to them successfully.
:param contracts: The collection of contracts to check for problematic calls in. :param contracts: The collection of contracts to check for problematic calls in.
@ -105,7 +107,7 @@ As a result, Bob's usage of the contract is incorrect."""
write to the array unsuccessfully. write to the array unsuccessfully.
""" """
# Define our resulting array. # Define our resulting array.
results = [] results: List[Tuple[Node, Variable, Union[Function, Variable]]] = []
# Verify we have functions in our list to check for. # Verify we have functions in our list to check for.
if not array_modifying_funcs: if not array_modifying_funcs:
@ -159,7 +161,7 @@ As a result, Bob's usage of the contract is incorrect."""
if problematic_calls: if problematic_calls:
for calling_node, affected_argument, invoked_function in problematic_calls: for calling_node, affected_argument, invoked_function in problematic_calls:
info = [ info: DETECTOR_INFO = [
calling_node.function, calling_node.function,
" passes array ", " passes array ",
affected_argument, affected_argument,

@ -10,6 +10,7 @@ from slither.detectors.abstract_detector import (
AbstractDetector, AbstractDetector,
DetectorClassification, DetectorClassification,
make_solc_versions, make_solc_versions,
DETECTOR_INFO,
) )
from slither.slithir.operations import TypeConversion from slither.slithir.operations import TypeConversion
from slither.core.declarations.enum import Enum from slither.core.declarations.enum import Enum
@ -73,10 +74,14 @@ Attackers can trigger unexpected behaviour by calling `bug(1)`."""
for c in self.compilation_unit.contracts: for c in self.compilation_unit.contracts:
ret = _detect_dangerous_enum_conversions(c) ret = _detect_dangerous_enum_conversions(c)
for node, var in ret: for node, var in ret:
func_info = [node, " has a dangerous enum conversion\n"] func_info: DETECTOR_INFO = [node, " has a dangerous enum conversion\n"]
# Output each node with the function info header as a separate result. # Output each node with the function info header as a separate result.
variable_info = ["\t- Variable: ", var, f" of type: {str(var.type)}\n"] variable_info: DETECTOR_INFO = [
node_info = ["\t- Enum conversion: ", node, "\n"] "\t- Variable: ",
var,
f" of type: {str(var.type)}\n",
]
node_info: DETECTOR_INFO = ["\t- Enum conversion: ", node, "\n"]
json = self.generate_result(func_info + variable_info + node_info) json = self.generate_result(func_info + variable_info + node_info)
results.append(json) results.append(json)

@ -1,6 +1,10 @@
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
@ -58,7 +62,10 @@ In Solidity [0.4.22](https://github.com/ethereum/solidity/releases/tag/v0.4.23),
# If there is more than one, we encountered the described issue occurring. # If there is more than one, we encountered the described issue occurring.
if constructors and len(constructors) > 1: if constructors and len(constructors) > 1:
info = [contract, " contains multiple constructors in the same contract:\n"] info: DETECTOR_INFO = [
contract,
" contains multiple constructors in the same contract:\n",
]
for constructor in constructors: for constructor in constructors:
info += ["\t- ", constructor, "\n"] info += ["\t- ", constructor, "\n"]

@ -6,6 +6,7 @@ from slither.detectors.abstract_detector import (
AbstractDetector, AbstractDetector,
DetectorClassification, DetectorClassification,
ALL_SOLC_VERSIONS_04, ALL_SOLC_VERSIONS_04,
DETECTOR_INFO,
) )
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
@ -151,7 +152,7 @@ The constructor of `A` is called multiple times in `D` and `E`:
continue continue
# Generate data to output. # Generate data to output.
info = [ info: DETECTOR_INFO = [
contract, contract,
" gives base constructor ", " gives base constructor ",
base_constructor, base_constructor,

@ -6,6 +6,7 @@ from slither.detectors.abstract_detector import (
AbstractDetector, AbstractDetector,
DetectorClassification, DetectorClassification,
make_solc_versions, make_solc_versions,
DETECTOR_INFO,
) )
from slither.core.solidity_types import ArrayType from slither.core.solidity_types import ArrayType
from slither.core.solidity_types import UserDefinedType from slither.core.solidity_types import UserDefinedType
@ -122,7 +123,13 @@ contract A {
for contract in self.contracts: for contract in self.contracts:
storage_abiencoderv2_arrays = self._detect_storage_abiencoderv2_arrays(contract) storage_abiencoderv2_arrays = self._detect_storage_abiencoderv2_arrays(contract)
for function, node in storage_abiencoderv2_arrays: for function, node in storage_abiencoderv2_arrays:
info = ["Function ", function, " trigger an abi encoding bug:\n\t- ", node, "\n"] info: DETECTOR_INFO = [
"Function ",
function,
" trigger an abi encoding bug:\n\t- ",
node,
"\n",
]
res = self.generate_result(info) res = self.generate_result(info)
results.append(res) results.append(res)

@ -1,18 +1,21 @@
""" """
Module detecting storage signed integer array bug Module detecting storage signed integer array bug
""" """
from typing import List from typing import List, Tuple, Set
from slither.core.declarations import Function, Contract
from slither.detectors.abstract_detector import ( from slither.detectors.abstract_detector import (
AbstractDetector, AbstractDetector,
DetectorClassification, DetectorClassification,
make_solc_versions, make_solc_versions,
DETECTOR_INFO,
) )
from slither.core.cfg.node import NodeType from slither.core.cfg.node import NodeType, Node
from slither.core.solidity_types import ArrayType from slither.core.solidity_types import ArrayType
from slither.core.solidity_types.elementary_type import Int, ElementaryType from slither.core.solidity_types.elementary_type import Int, ElementaryType
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import Operation, OperationWithLValue
from slither.slithir.operations.assignment import Assignment from slither.slithir.operations.assignment import Assignment
from slither.slithir.operations.init_array import InitArray from slither.slithir.operations.init_array import InitArray
from slither.utils.output import Output from slither.utils.output import Output
@ -60,7 +63,7 @@ contract A {
VULNERABLE_SOLC_VERSIONS = make_solc_versions(4, 7, 25) + make_solc_versions(5, 0, 9) VULNERABLE_SOLC_VERSIONS = make_solc_versions(4, 7, 25) + make_solc_versions(5, 0, 9)
@staticmethod @staticmethod
def _is_vulnerable_type(ir): def _is_vulnerable_type(ir: Operation) -> bool:
""" """
Detect if the IR lvalue is a vulnerable type Detect if the IR lvalue is a vulnerable type
Must be a storage allocation, and an array of Int Must be a storage allocation, and an array of Int
@ -68,23 +71,28 @@ contract A {
""" """
# Storage allocation # Storage allocation
# Base type is signed integer # Base type is signed integer
if not isinstance(ir, OperationWithLValue):
return False
return ( return (
( (
isinstance(ir.lvalue, StateVariable) isinstance(ir.lvalue, StateVariable)
or (isinstance(ir.lvalue, LocalVariable) and ir.lvalue.is_storage) or (isinstance(ir.lvalue, LocalVariable) and ir.lvalue.is_storage)
) )
and isinstance(ir.lvalue.type.type, ElementaryType) and isinstance(ir.lvalue.type.type, ElementaryType) # type: ignore
and ir.lvalue.type.type.type in Int and ir.lvalue.type.type.type in Int # type: ignore
) )
def detect_storage_signed_integer_arrays(self, contract): def detect_storage_signed_integer_arrays(
self, contract: Contract
) -> Set[Tuple[Function, Node]]:
""" """
Detects and returns all nodes with storage-allocated signed integer array init/assignment Detects and returns all nodes with storage-allocated signed integer array init/assignment
:param contract: Contract to detect within :param contract: Contract to detect within
:return: A list of tuples with (function, node) where function node has storage-allocated signed integer array init/assignment :return: A list of tuples with (function, node) where function node has storage-allocated signed integer array init/assignment
""" """
# Create our result set. # Create our result set.
results = set() results: Set[Tuple[Function, Node]] = set()
# Loop for each function and modifier. # Loop for each function and modifier.
for function in contract.functions_and_modifiers_declared: for function in contract.functions_and_modifiers_declared:
@ -118,9 +126,13 @@ contract A {
for contract in self.contracts: for contract in self.contracts:
storage_signed_integer_arrays = self.detect_storage_signed_integer_arrays(contract) storage_signed_integer_arrays = self.detect_storage_signed_integer_arrays(contract)
for function, node in storage_signed_integer_arrays: for function, node in storage_signed_integer_arrays:
contract_info = ["Contract ", contract, " \n"] contract_info: DETECTOR_INFO = ["Contract ", contract, " \n"]
function_info = ["\t- Function ", function, "\n"] function_info: DETECTOR_INFO = ["\t- Function ", function, "\n"]
node_info = ["\t\t- ", node, " has a storage signed integer array assignment\n"] node_info: DETECTOR_INFO = [
"\t\t- ",
node,
" has a storage signed integer array assignment\n",
]
res = self.generate_result(contract_info + function_info + node_info) res = self.generate_result(contract_info + function_info + node_info)
results.append(res) results.append(res)

@ -6,6 +6,7 @@ from slither.detectors.abstract_detector import (
AbstractDetector, AbstractDetector,
DetectorClassification, DetectorClassification,
make_solc_versions, make_solc_versions,
DETECTOR_INFO,
) )
from slither.slithir.operations import InternalDynamicCall, OperationWithLValue from slither.slithir.operations import InternalDynamicCall, OperationWithLValue
from slither.slithir.variables import ReferenceVariable from slither.slithir.variables import ReferenceVariable
@ -115,10 +116,10 @@ The call to `a(10)` will lead to unexpected behavior because function pointer `a
results = [] results = []
for contract in self.compilation_unit.contracts: for contract in self.compilation_unit.contracts:
contract_info = ["Contract ", contract, " \n"] contract_info: DETECTOR_INFO = ["Contract ", contract, " \n"]
nodes = self._detect_uninitialized_function_ptr_in_constructor(contract) nodes = self._detect_uninitialized_function_ptr_in_constructor(contract)
for node in nodes: for node in nodes:
node_info = [ node_info: DETECTOR_INFO = [
"\t ", "\t ",
node, node,
" is an unintialized function pointer call in a constructor\n", " is an unintialized function pointer call in a constructor\n",

@ -61,12 +61,12 @@ class ArbitrarySendErc20:
is_dependent( is_dependent(
ir.arguments[0], ir.arguments[0],
SolidityVariableComposed("msg.sender"), SolidityVariableComposed("msg.sender"),
node.function.contract, node,
) )
or is_dependent( or is_dependent(
ir.arguments[0], ir.arguments[0],
SolidityVariable("this"), SolidityVariable("this"),
node.function.contract, node,
) )
) )
): ):
@ -79,12 +79,12 @@ class ArbitrarySendErc20:
is_dependent( is_dependent(
ir.arguments[1], ir.arguments[1],
SolidityVariableComposed("msg.sender"), SolidityVariableComposed("msg.sender"),
node.function.contract, node,
) )
or is_dependent( or is_dependent(
ir.arguments[1], ir.arguments[1],
SolidityVariable("this"), SolidityVariable("this"),
node.function.contract, node,
) )
) )
): ):

@ -1,5 +1,9 @@
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
from .arbitrary_send_erc20 import ArbitrarySendErc20 from .arbitrary_send_erc20 import ArbitrarySendErc20
@ -38,7 +42,7 @@ Use `msg.sender` as `from` in transferFrom.
arbitrary_sends.detect() arbitrary_sends.detect()
for node in arbitrary_sends.no_permit_results: for node in arbitrary_sends.no_permit_results:
func = node.function func = node.function
info = [func, " uses arbitrary from in transferFrom: ", node, "\n"] info: DETECTOR_INFO = [func, " uses arbitrary from in transferFrom: ", node, "\n"]
res = self.generate_result(info) res = self.generate_result(info)
results.append(res) results.append(res)

@ -1,5 +1,9 @@
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
from .arbitrary_send_erc20 import ArbitrarySendErc20 from .arbitrary_send_erc20 import ArbitrarySendErc20
@ -41,7 +45,7 @@ Ensure that the underlying ERC20 token correctly implements a permit function.
arbitrary_sends.detect() arbitrary_sends.detect()
for node in arbitrary_sends.permit_results: for node in arbitrary_sends.permit_results:
func = node.function func = node.function
info = [ info: DETECTOR_INFO = [
func, func,
" uses arbitrary from in transferFrom in combination with permit: ", " uses arbitrary from in transferFrom in combination with permit: ",
node, node,

@ -6,7 +6,11 @@ from typing import List, Tuple
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
@ -109,7 +113,7 @@ contract Token{
functions = IncorrectERC20InterfaceDetection.detect_incorrect_erc20_interface(c) functions = IncorrectERC20InterfaceDetection.detect_incorrect_erc20_interface(c)
if functions: if functions:
for function in functions: for function in functions:
info = [ info: DETECTOR_INFO = [
c, c,
" has incorrect ERC20 function interface:", " has incorrect ERC20 function interface:",
function, function,

@ -2,7 +2,11 @@
Detect incorrect erc721 interface. Detect incorrect erc721 interface.
""" """
from typing import Any, List, Tuple, Union from typing import Any, List, Tuple, Union
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.utils.output import Output from slither.utils.output import Output
@ -89,7 +93,9 @@ contract Token{
return False return False
@staticmethod @staticmethod
def detect_incorrect_erc721_interface(contract: Contract) -> List[Union[FunctionContract, Any]]: def detect_incorrect_erc721_interface(
contract: Contract,
) -> List[Union[FunctionContract, Any]]:
"""Detect incorrect ERC721 interface """Detect incorrect ERC721 interface
Returns: Returns:
@ -119,7 +125,7 @@ contract Token{
functions = IncorrectERC721InterfaceDetection.detect_incorrect_erc721_interface(c) functions = IncorrectERC721InterfaceDetection.detect_incorrect_erc721_interface(c)
if functions: if functions:
for function in functions: for function in functions:
info = [ info: DETECTOR_INFO = [
c, c,
" has incorrect ERC721 function interface:", " has incorrect ERC721 function interface:",
function, function,

@ -1,6 +1,10 @@
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
@ -28,7 +32,7 @@ class Backdoor(AbstractDetector):
for f in contract.functions: for f in contract.functions:
if "backdoor" in f.name: if "backdoor" in f.name:
# Info to be printed # Info to be printed
info = ["Backdoor function found in ", f, "\n"] info: DETECTOR_INFO = ["Backdoor function found in ", f, "\n"]
# Add the result in result # Add the result in result
res = self.generate_result(info) res = self.generate_result(info)

@ -18,7 +18,9 @@ from slither.core.declarations.function_contract import FunctionContract
from slither.core.declarations.solidity_variables import ( from slither.core.declarations.solidity_variables import (
SolidityFunction, SolidityFunction,
SolidityVariableComposed, SolidityVariableComposed,
SolidityVariable,
) )
from slither.core.variables import Variable
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.slithir.operations import ( from slither.slithir.operations import (
HighLevelCall, HighLevelCall,
@ -39,6 +41,10 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]:
ret: List[Node] = [] ret: List[Node] = []
for node in func.nodes: for node in func.nodes:
func = node.function
deps_target: Union[Contract, Function] = (
func.contract if isinstance(func, FunctionContract) else func
)
for ir in node.irs: for ir in node.irs:
if isinstance(ir, SolidityCall): if isinstance(ir, SolidityCall):
if ir.function == SolidityFunction("ecrecover(bytes32,uint8,bytes32,bytes32)"): if ir.function == SolidityFunction("ecrecover(bytes32,uint8,bytes32,bytes32)"):
@ -49,7 +55,7 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]:
if is_dependent( if is_dependent(
ir.variable_right, ir.variable_right,
SolidityVariableComposed("msg.sender"), SolidityVariableComposed("msg.sender"),
func.contract, deps_target,
): ):
return False return False
if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)): if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)):
@ -64,12 +70,13 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]:
if is_dependent( if is_dependent(
ir.call_value, ir.call_value,
SolidityVariableComposed("msg.value"), SolidityVariableComposed("msg.value"),
func.contract, node,
): ):
continue continue
if is_tainted(ir.destination, func.contract): if isinstance(ir.destination, (Variable, SolidityVariable)):
ret.append(node) if is_tainted(ir.destination, node):
ret.append(node)
return ret return ret

@ -1,7 +1,11 @@
from typing import List, Tuple from typing import List, Tuple
from slither.core.declarations import Function from slither.core.declarations import Function
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.code_complexity import compute_cyclomatic_complexity from slither.utils.code_complexity import compute_cyclomatic_complexity
from slither.utils.output import Output from slither.utils.output import Output
@ -44,7 +48,7 @@ class CyclomaticComplexity(AbstractDetector):
_check_for_high_cc(high_cc_functions, f) _check_for_high_cc(high_cc_functions, f)
for f, cc in high_cc_functions: for f, cc in high_cc_functions:
info = [f, f" has a high cyclomatic complexity ({cc}).\n"] info: DETECTOR_INFO = [f, f" has a high cyclomatic complexity ({cc}).\n"]
res = self.generate_result(info) res = self.generate_result(info)
results.append(res) results.append(res)
return results return results

@ -4,7 +4,11 @@ Module detecting dead code
from typing import List, Tuple from typing import List, Tuple
from slither.core.declarations import Function, FunctionContract, Contract from slither.core.declarations import Function, FunctionContract, Contract
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
@ -72,7 +76,7 @@ contract Contract{
# Continue if the functon is not implemented because it means the contract is abstract # Continue if the functon is not implemented because it means the contract is abstract
if not function.is_implemented: if not function.is_implemented:
continue continue
info = [function, " is never used and should be removed\n"] info: DETECTOR_INFO = [function, " is never used and should be removed\n"]
res = self.generate_result(info) res = self.generate_result(info)
results.append(res) results.append(res)

@ -6,7 +6,11 @@ are in the outermost scope, they do not guarantee a revert, so a
default value can still be returned. default value can still be returned.
""" """
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.core.cfg.node import Node, NodeType from slither.core.cfg.node import Node, NodeType
from slither.utils.output import Output from slither.utils.output import Output
@ -82,7 +86,11 @@ If the condition in `myModif` is false, the execution of `get()` will return 0."
node = None node = None
else: else:
# Nothing was found in the outer scope # Nothing was found in the outer scope
info = ["Modifier ", mod, " does not always execute _; or revert"] info: DETECTOR_INFO = [
"Modifier ",
mod,
" does not always execute _; or revert",
]
res = self.generate_result(info) res = self.generate_result(info)
results.append(res) results.append(res)

@ -6,7 +6,11 @@ from typing import Union, List
from slither.core.declarations import Function from slither.core.declarations import Function
from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.function import get_function_id from slither.utils.function import get_function_id
from slither.utils.output import Output from slither.utils.output import Output
@ -63,7 +67,7 @@ contract Contract{
assert isinstance(func_or_var, StateVariable) assert isinstance(func_or_var, StateVariable)
incorrect_return_type = func_or_var.type != ElementaryType("bytes32") incorrect_return_type = func_or_var.type != ElementaryType("bytes32")
if hash_collision or incorrect_return_type: if hash_collision or incorrect_return_type:
info = [ info: DETECTOR_INFO = [
"The function signature of ", "The function signature of ",
func_or_var, func_or_var,
" collides with DOMAIN_SEPARATOR and should be renamed or removed.\n", " collides with DOMAIN_SEPARATOR and should be renamed or removed.\n",

@ -6,7 +6,11 @@ A suicidal contract is an unprotected function that calls selfdestruct
from typing import List from typing import List
from slither.core.declarations import Function, Contract from slither.core.declarations import Function, Contract
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
@ -58,7 +62,7 @@ contract Buggy{
self.logger.error(f"{function_sig} not found") self.logger.error(f"{function_sig} not found")
continue continue
if function_protection not in function.all_internal_calls(): if function_protection not in function.all_internal_calls():
info = [ info: DETECTOR_INFO = [
function, function,
" should have ", " should have ",
function_protection, function_protection,

@ -7,7 +7,11 @@ from typing import List
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
@ -78,7 +82,7 @@ Bob calls `kill` and destructs the contract."""
functions = self.detect_suicidal(c) functions = self.detect_suicidal(c)
for func in functions: for func in functions:
info = [func, " allows anyone to destruct the contract\n"] info: DETECTOR_INFO = [func, " allows anyone to destruct the contract\n"]
res = self.generate_result(info) res = self.generate_result(info)

@ -8,7 +8,13 @@ Consider public state variables as implemented functions
Do not consider fallback function or constructor Do not consider fallback function or constructor
""" """
from typing import List, Set from typing import List, Set
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.core.declarations import Function
from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.utils.output import Output from slither.utils.output import Output
@ -62,7 +68,7 @@ All unimplemented functions must be implemented on a contract that is meant to b
def _match_state_variable(contract: Contract, f: FunctionContract) -> bool: def _match_state_variable(contract: Contract, f: FunctionContract) -> bool:
return any(s.full_name == f.full_name for s in contract.state_variables) return any(s.full_name == f.full_name for s in contract.state_variables)
def _detect_unimplemented_function(self, contract: Contract) -> Set[FunctionContract]: def _detect_unimplemented_function(self, contract: Contract) -> Set[Function]:
""" """
Detects any function definitions which are not implemented in the given contract. Detects any function definitions which are not implemented in the given contract.
:param contract: The contract to search unimplemented functions for. :param contract: The contract to search unimplemented functions for.
@ -77,6 +83,8 @@ All unimplemented functions must be implemented on a contract that is meant to b
# fallback function and constructor. # fallback function and constructor.
unimplemented = set() unimplemented = set()
for f in contract.all_functions_called: for f in contract.all_functions_called:
if not isinstance(f, Function):
continue
if ( if (
not f.is_implemented not f.is_implemented
and not f.is_constructor and not f.is_constructor
@ -102,7 +110,7 @@ All unimplemented functions must be implemented on a contract that is meant to b
for contract in self.compilation_unit.contracts_derived: for contract in self.compilation_unit.contracts_derived:
functions = self._detect_unimplemented_function(contract) functions = self._detect_unimplemented_function(contract)
if functions: if functions:
info = [contract, " does not implement functions:\n"] info: DETECTOR_INFO = [contract, " does not implement functions:\n"]
for function in sorted(functions, key=lambda x: x.full_name): for function in sorted(functions, key=lambda x: x.full_name):
info += ["\t- ", function, "\n"] info += ["\t- ", function, "\n"]

@ -1,6 +1,10 @@
import re import re
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.formatters.naming_convention.naming_convention import custom_format from slither.formatters.naming_convention.naming_convention import custom_format
from slither.utils.output import Output from slither.utils.output import Output
@ -63,6 +67,7 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2
def _detect(self) -> List[Output]: def _detect(self) -> List[Output]:
results = [] results = []
info: DETECTOR_INFO
for contract in self.contracts: for contract in self.contracts:
if not self.is_cap_words(contract.name): if not self.is_cap_words(contract.name):

@ -50,14 +50,17 @@ def contains_bad_PRNG_sources(func: Function, blockhash_ret_values: List[Variabl
for node in func.nodes: for node in func.nodes:
for ir in node.irs_ssa: for ir in node.irs_ssa:
if isinstance(ir, Binary) and ir.type == BinaryType.MODULO: if isinstance(ir, Binary) and ir.type == BinaryType.MODULO:
var_left = ir.variable_left
if not isinstance(var_left, (Variable, SolidityVariable)):
continue
if is_dependent_ssa( if is_dependent_ssa(
ir.variable_left, SolidityVariableComposed("block.timestamp"), func.contract var_left, SolidityVariableComposed("block.timestamp"), node
) or is_dependent_ssa(ir.variable_left, SolidityVariable("now"), func.contract): ) or is_dependent_ssa(var_left, SolidityVariable("now"), node):
ret.add(node) ret.add(node)
break break
for ret_val in blockhash_ret_values: for ret_val in blockhash_ret_values:
if is_dependent_ssa(ir.variable_left, ret_val, func.contract): if is_dependent_ssa(var_left, ret_val, node):
ret.add(node) ret.add(node)
break break
return list(ret) return list(ret)

@ -6,12 +6,17 @@ from typing import List, Tuple
from slither.analyses.data_dependency.data_dependency import is_dependent from slither.analyses.data_dependency.data_dependency import is_dependent
from slither.core.cfg.node import Node from slither.core.cfg.node import Node
from slither.core.declarations import Function, Contract from slither.core.declarations import Function, Contract, FunctionContract
from slither.core.declarations.solidity_variables import ( from slither.core.declarations.solidity_variables import (
SolidityVariableComposed, SolidityVariableComposed,
SolidityVariable, SolidityVariable,
) )
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.core.variables import Variable
from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations import Binary, BinaryType from slither.slithir.operations import Binary, BinaryType
from slither.utils.output import Output from slither.utils.output import Output
@ -21,25 +26,25 @@ def _timestamp(func: Function) -> List[Node]:
for node in func.nodes: for node in func.nodes:
if node.contains_require_or_assert(): if node.contains_require_or_assert():
for var in node.variables_read: for var in node.variables_read:
if is_dependent(var, SolidityVariableComposed("block.timestamp"), func.contract): if is_dependent(var, SolidityVariableComposed("block.timestamp"), node):
ret.add(node) ret.add(node)
if is_dependent(var, SolidityVariable("now"), func.contract): if is_dependent(var, SolidityVariable("now"), node):
ret.add(node) ret.add(node)
for ir in node.irs: for ir in node.irs:
if isinstance(ir, Binary) and BinaryType.return_bool(ir.type): if isinstance(ir, Binary) and BinaryType.return_bool(ir.type):
for var in ir.read: for var_read in ir.read:
if is_dependent( if not isinstance(var_read, (Variable, SolidityVariable)):
var, SolidityVariableComposed("block.timestamp"), func.contract continue
): if is_dependent(var_read, SolidityVariableComposed("block.timestamp"), node):
ret.add(node) ret.add(node)
if is_dependent(var, SolidityVariable("now"), func.contract): if is_dependent(var_read, SolidityVariable("now"), node):
ret.add(node) ret.add(node)
return sorted(list(ret), key=lambda x: x.node_id) return sorted(list(ret), key=lambda x: x.node_id)
def _detect_dangerous_timestamp( def _detect_dangerous_timestamp(
contract: Contract, contract: Contract,
) -> List[Tuple[Function, List[Node]]]: ) -> List[Tuple[FunctionContract, List[Node]]]:
""" """
Args: Args:
contract (Contract) contract (Contract)
@ -48,7 +53,7 @@ def _detect_dangerous_timestamp(
""" """
ret = [] ret = []
for f in [f for f in contract.functions if f.contract_declarer == contract]: for f in [f for f in contract.functions if f.contract_declarer == contract]:
nodes = _timestamp(f) nodes: List[Node] = _timestamp(f)
if nodes: if nodes:
ret.append((f, nodes)) ret.append((f, nodes))
return ret return ret
@ -78,7 +83,7 @@ class Timestamp(AbstractDetector):
dangerous_timestamp = _detect_dangerous_timestamp(c) dangerous_timestamp = _detect_dangerous_timestamp(c)
for (func, nodes) in dangerous_timestamp: for (func, nodes) in dangerous_timestamp:
info = [func, " uses timestamp for comparisons\n"] info: DETECTOR_INFO = [func, " uses timestamp for comparisons\n"]
info += ["\tDangerous comparisons:\n"] info += ["\tDangerous comparisons:\n"]

@ -2,7 +2,11 @@
Module detecting usage of low level calls Module detecting usage of low level calls
""" """
from typing import List, Tuple from typing import List, Tuple
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations import LowLevelCall from slither.slithir.operations import LowLevelCall
from slither.core.cfg.node import Node from slither.core.cfg.node import Node
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
@ -52,7 +56,7 @@ class LowLevelCalls(AbstractDetector):
for c in self.contracts: for c in self.contracts:
values = self.detect_low_level_calls(c) values = self.detect_low_level_calls(c)
for func, nodes in values: for func, nodes in values:
info = ["Low level call in ", func, ":\n"] info: DETECTOR_INFO = ["Low level call in ", func, ":\n"]
# sort the nodes to get deterministic results # sort the nodes to get deterministic results
nodes.sort(key=lambda x: x.node_id) nodes.sort(key=lambda x: x.node_id)

@ -11,7 +11,11 @@ from slither.core.declarations.function_contract import FunctionContract
from slither.core.declarations.modifier import Modifier from slither.core.declarations.modifier import Modifier
from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations.event_call import EventCall from slither.slithir.operations.event_call import EventCall
from slither.utils.output import Output from slither.utils.output import Output
@ -100,7 +104,7 @@ contract C {
for contract in self.compilation_unit.contracts_derived: for contract in self.compilation_unit.contracts_derived:
missing_events = self._detect_missing_events(contract) missing_events = self._detect_missing_events(contract)
for (function, nodes) in missing_events: for (function, nodes) in missing_events:
info = [function, " should emit an event for: \n"] info: DETECTOR_INFO = [function, " should emit an event for: \n"]
for (node, _sv, _mod) in nodes: for (node, _sv, _mod) in nodes:
info += ["\t- ", node, " \n"] info += ["\t- ", node, " \n"]
res = self.generate_result(info) res = self.generate_result(info)

@ -10,7 +10,11 @@ from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.core.solidity_types.elementary_type import ElementaryType, Int, Uint from slither.core.solidity_types.elementary_type import ElementaryType, Int, Uint
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations.event_call import EventCall from slither.slithir.operations.event_call import EventCall
from slither.utils.output import Output from slither.utils.output import Output
@ -122,7 +126,7 @@ contract C {
for contract in self.compilation_unit.contracts_derived: for contract in self.compilation_unit.contracts_derived:
missing_events = self._detect_missing_events(contract) missing_events = self._detect_missing_events(contract)
for (function, nodes) in missing_events: for (function, nodes) in missing_events:
info = [function, " should emit an event for: \n"] info: DETECTOR_INFO = [function, " should emit an event for: \n"]
for (node, _) in nodes: for (node, _) in nodes:
info += ["\t- ", node, " \n"] info += ["\t- ", node, " \n"]
res = self.generate_result(info) res = self.generate_result(info)

@ -12,7 +12,11 @@ from slither.core.declarations.function import ModifierStatements
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations import Call from slither.slithir.operations import Call
from slither.slithir.operations import Send, Transfer, LowLevelCall from slither.slithir.operations import Send, Transfer, LowLevelCall
from slither.utils.output import Output from slither.utils.output import Output
@ -155,7 +159,7 @@ Bob calls `updateOwner` without specifying the `newOwner`, so Bob loses ownershi
missing_zero_address_validation = self._detect_missing_zero_address_validation(contract) missing_zero_address_validation = self._detect_missing_zero_address_validation(contract)
for (_, var_nodes) in missing_zero_address_validation: for (_, var_nodes) in missing_zero_address_validation:
for var, nodes in var_nodes.items(): for var, nodes in var_nodes.items():
info = [var, " lacks a zero-check on ", ":\n"] info: DETECTOR_INFO = [var, " lacks a zero-check on ", ":\n"]
for node in nodes: for node in nodes:
info += ["\t\t- ", node, "\n"] info += ["\t\t- ", node, "\n"]
res = self.generate_result(info) res = self.generate_result(info)

@ -7,7 +7,11 @@ from slither.core.cfg.node import Node
from slither.core.declarations import Function from slither.core.declarations import Function
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations import HighLevelCall from slither.slithir.operations import HighLevelCall
from slither.slithir.operations.operation import Operation from slither.slithir.operations.operation import Operation
from slither.utils.output import Output from slither.utils.output import Output
@ -91,7 +95,7 @@ contract MyConc{
if unused_return: if unused_return:
for node in unused_return: for node in unused_return:
info = [f, " ignores return value by ", node, "\n"] info: DETECTOR_INFO = [f, " ignores return value by ", node, "\n"]
res = self.generate_result(info) res = self.generate_result(info)

@ -1,6 +1,10 @@
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations import Nop from slither.slithir.operations import Nop
from slither.utils.output import Output from slither.utils.output import Output
@ -39,7 +43,7 @@ When reading `B`'s constructor definition, we might assume that `A()` initiates
for constructor_call in cst.explicit_base_constructor_calls_statements: for constructor_call in cst.explicit_base_constructor_calls_statements:
for node in constructor_call.nodes: for node in constructor_call.nodes:
if any(isinstance(ir, Nop) for ir in node.irs): if any(isinstance(ir, Nop) for ir in node.irs):
info = ["Void constructor called in ", cst, ":\n"] info: DETECTOR_INFO = ["Void constructor called in ", cst, ":\n"]
info += ["\t- ", node, "\n"] info += ["\t- ", node, "\n"]
res = self.generate_result(info) res = self.generate_result(info)

@ -4,7 +4,11 @@ from typing import Dict, List
from slither.analyses.data_dependency.data_dependency import is_dependent from slither.analyses.data_dependency.data_dependency import is_dependent
from slither.core.cfg.node import Node from slither.core.cfg.node import Node
from slither.core.declarations import Function, Contract, SolidityVariableComposed from slither.core.declarations import Function, Contract, SolidityVariableComposed
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations import LowLevelCall, HighLevelCall from slither.slithir.operations import LowLevelCall, HighLevelCall
from slither.utils.output import Output from slither.utils.output import Output
@ -88,7 +92,7 @@ If you do, ensure your users are aware of the potential issues."""
for contract in self.compilation_unit.contracts_derived: for contract in self.compilation_unit.contracts_derived:
vulns = _detect_token_reentrant(contract) vulns = _detect_token_reentrant(contract)
for function, nodes in vulns.items(): for function, nodes in vulns.items():
info = [function, " is an reentrancy unsafe token function:\n"] info: DETECTOR_INFO = [function, " is an reentrancy unsafe token function:\n"]
for node in nodes: for node in nodes:
info += ["\t-", node, "\n"] info += ["\t-", node, "\n"]
json = self.generate_result(info) json = self.generate_result(info)

@ -9,7 +9,11 @@ from slither.core.declarations.function_contract import FunctionContract
from slither.core.declarations.modifier import Modifier from slither.core.declarations.modifier import Modifier
from slither.core.variables import Variable from slither.core.variables import Variable
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
@ -194,7 +198,7 @@ contract Bug {
shadow_type = shadow[0] shadow_type = shadow[0]
shadow_object = shadow[1] shadow_object = shadow[1]
info = [ info: DETECTOR_INFO = [
shadow_object, shadow_object,
f' ({shadow_type}) shadows built-in symbol"\n', f' ({shadow_type}) shadows built-in symbol"\n',
] ]

@ -9,7 +9,11 @@ from slither.core.declarations.function_contract import FunctionContract
from slither.core.declarations.modifier import Modifier from slither.core.declarations.modifier import Modifier
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
@ -85,7 +89,7 @@ contract Bug {
] = [] ] = []
# Loop through all functions + modifiers in this contract. # Loop through all functions + modifiers in this contract.
for function in contract.functions + contract.modifiers: for function in contract.functions + list(contract.modifiers):
# We should only look for functions declared directly in this contract (not in a base contract). # We should only look for functions declared directly in this contract (not in a base contract).
if function.contract_declarer != contract: if function.contract_declarer != contract:
continue continue
@ -144,7 +148,7 @@ contract Bug {
for shadow in shadows: for shadow in shadows:
local_variable = shadow[0] local_variable = shadow[0]
overshadowed = shadow[1] overshadowed = shadow[1]
info = [local_variable, " shadows:\n"] info: DETECTOR_INFO = [local_variable, " shadows:\n"]
for overshadowed_entry in overshadowed: for overshadowed_entry in overshadowed:
info += [ info += [
"\t- ", "\t- ",

@ -6,7 +6,11 @@ from typing import List
from slither.core.declarations import Contract from slither.core.declarations import Contract
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.detectors.shadowing.common import is_upgradable_gap_variable from slither.detectors.shadowing.common import is_upgradable_gap_variable
from slither.utils.output import Output from slither.utils.output import Output
@ -89,7 +93,7 @@ contract DerivedContract is BaseContract{
for all_variables in shadowing: for all_variables in shadowing:
shadow = all_variables[0] shadow = all_variables[0]
variables = all_variables[1:] variables = all_variables[1:]
info = [shadow, " shadows:\n"] info: DETECTOR_INFO = [shadow, " shadows:\n"]
for var in variables: for var in variables:
info += ["\t- ", var, "\n"] info += ["\t- ", var, "\n"]

@ -1,12 +1,17 @@
from collections import defaultdict from collections import defaultdict
from typing import Any, List from typing import List
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.core.declarations import Contract
from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
def _find_missing_inheritance(compilation_unit: SlitherCompilationUnit) -> List[Any]: def _find_missing_inheritance(compilation_unit: SlitherCompilationUnit) -> List[Contract]:
""" """
Filter contracts with missing inheritance to return only the "most base" contracts Filter contracts with missing inheritance to return only the "most base" contracts
in the inheritance tree. in the inheritance tree.
@ -80,7 +85,7 @@ As a result, the second contract cannot be analyzed.
inheritance_corrupted[father.name].append(contract) inheritance_corrupted[father.name].append(contract)
for contract_name, files in names_reused.items(): for contract_name, files in names_reused.items():
info = [contract_name, " is re-used:\n"] info: DETECTOR_INFO = [contract_name, " is re-used:\n"]
for file in files: for file in files:
if file is None: if file is None:
info += ["\t- In an file not found, most likely in\n"] info += ["\t- In an file not found, most likely in\n"]

@ -1,7 +1,11 @@
import re import re
from typing import List from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
@ -78,7 +82,7 @@ contract Token
idx = start_index + result_index idx = start_index + result_index
relative = self.slither.crytic_compile.filename_lookup(filename).relative relative = self.slither.crytic_compile.filename_lookup(filename).relative
info = f"{relative} contains a unicode right-to-left-override character at byte offset {idx}:\n" info: DETECTOR_INFO = f"{relative} contains a unicode right-to-left-override character at byte offset {idx}:\n"
# We have a patch, so pattern.find will return at least one result # We have a patch, so pattern.find will return at least one result

@ -1,7 +1,9 @@
""" """
Module detecting assignment of array length Module detecting assignment of array length
""" """
from typing import List, Set from typing import List, Set, Union
from slither.core.variables import Variable
from slither.detectors.abstract_detector import ( from slither.detectors.abstract_detector import (
AbstractDetector, AbstractDetector,
DetectorClassification, DetectorClassification,
@ -14,7 +16,7 @@ from slither.slithir.variables.reference import ReferenceVariable
from slither.slithir.operations.binary import Binary from slither.slithir.operations.binary import Binary
from slither.analyses.data_dependency.data_dependency import is_tainted from slither.analyses.data_dependency.data_dependency import is_tainted
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.utils.output import Output from slither.utils.output import Output, SupportedOutput
def detect_array_length_assignment(contract: Contract) -> Set[Node]: def detect_array_length_assignment(contract: Contract) -> Set[Node]:
@ -50,7 +52,7 @@ def detect_array_length_assignment(contract: Contract) -> Set[Node]:
elif isinstance(ir, (Assignment, Binary)): elif isinstance(ir, (Assignment, Binary)):
if isinstance(ir.lvalue, ReferenceVariable): if isinstance(ir.lvalue, ReferenceVariable):
if ir.lvalue in array_length_refs and any( if ir.lvalue in array_length_refs and any(
is_tainted(v, contract) for v in ir.read is_tainted(v, contract) for v in ir.read if isinstance(v, Variable)
): ):
# the taint is not precise enough yet # the taint is not precise enough yet
# as a result, REF_0 = REF_0 + 1 # as a result, REF_0 = REF_0 + 1
@ -120,12 +122,16 @@ Otherwise, thoroughly review the contract to ensure a user-controlled variable c
for contract in self.contracts: for contract in self.contracts:
array_length_assignments = detect_array_length_assignment(contract) array_length_assignments = detect_array_length_assignment(contract)
if array_length_assignments: if array_length_assignments:
contract_info = [ contract_info: List[Union[str, SupportedOutput]] = [
contract, contract,
" contract sets array length with a user-controlled value:\n", " contract sets array length with a user-controlled value:\n",
] ]
for node in array_length_assignments: for node in array_length_assignments:
node_info = contract_info + ["\t- ", node, "\n"] node_info: List[Union[str, SupportedOutput]] = contract_info + [
"\t- ",
node,
"\n",
]
res = self.generate_result(node_info) res = self.generate_result(node_info)
results.append(res) results.append(res)

@ -6,7 +6,11 @@ from typing import List, Tuple
from slither.core.cfg.node import Node, NodeType from slither.core.cfg.node import Node, NodeType
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output from slither.utils.output import Output
@ -52,7 +56,7 @@ class Assembly(AbstractDetector):
for c in self.contracts: for c in self.contracts:
values = self.detect_assembly(c) values = self.detect_assembly(c)
for func, nodes in values: for func, nodes in values:
info = [func, " uses assembly\n"] info: DETECTOR_INFO = [func, " uses assembly\n"]
# sort the nodes to get deterministic results # sort the nodes to get deterministic results
nodes.sort(key=lambda x: x.node_id) nodes.sort(key=lambda x: x.node_id)

@ -6,7 +6,11 @@ from typing import List, Tuple
from slither.core.cfg.node import Node from slither.core.cfg.node import Node
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations.internal_call import InternalCall from slither.slithir.operations.internal_call import InternalCall
from slither.utils.output import Output from slither.utils.output import Output
@ -25,7 +29,7 @@ def detect_assert_state_change(
results = [] results = []
# Loop for each function and modifier. # Loop for each function and modifier.
for function in contract.functions_declared + contract.modifiers_declared: for function in contract.functions_declared + list(contract.modifiers_declared):
for node in function.nodes: for node in function.nodes:
# Detect assert() calls # Detect assert() calls
if any(c.name == "assert(bool)" for c in node.internal_calls) and ( if any(c.name == "assert(bool)" for c in node.internal_calls) and (
@ -36,7 +40,9 @@ def detect_assert_state_change(
any( any(
ir ir
for ir in node.irs for ir in node.irs
if isinstance(ir, InternalCall) and ir.function.state_variables_written if isinstance(ir, InternalCall)
and ir.function
and ir.function.state_variables_written
) )
): ):
results.append((function, node)) results.append((function, node))
@ -85,7 +91,10 @@ The assert in `bad()` increments the state variable `s_a` while checking for the
for contract in self.contracts: for contract in self.contracts:
assert_state_change = detect_assert_state_change(contract) assert_state_change = detect_assert_state_change(contract)
for (func, node) in assert_state_change: for (func, node) in assert_state_change:
info = [func, " has an assert() call which possibly changes state.\n"] info: DETECTOR_INFO = [
func,
" has an assert() call which possibly changes state.\n",
]
info += ["\t-", node, "\n"] info += ["\t-", node, "\n"]
info += [ info += [
"Consider using require() or change the invariant to not modify the state.\n" "Consider using require() or change the invariant to not modify the state.\n"

@ -6,7 +6,11 @@ from typing import List, Set, Tuple
from slither.core.cfg.node import Node from slither.core.cfg.node import Node
from slither.core.declarations import Function from slither.core.declarations import Function
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations import ( from slither.slithir.operations import (
Binary, Binary,
BinaryType, BinaryType,
@ -84,7 +88,7 @@ Boolean constants can be used directly and do not need to be compare to `true` o
boolean_constant_misuses = self._detect_boolean_equality(contract) boolean_constant_misuses = self._detect_boolean_equality(contract)
for (func, nodes) in boolean_constant_misuses: for (func, nodes) in boolean_constant_misuses:
for node in nodes: for node in nodes:
info = [ info: DETECTOR_INFO = [
func, func,
" compares to a boolean constant:\n\t-", " compares to a boolean constant:\n\t-",
node, node,

@ -7,7 +7,11 @@ from slither.core.cfg.node import Node, NodeType
from slither.core.declarations import Function from slither.core.declarations import Function
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.solidity_types import ElementaryType from slither.core.solidity_types import ElementaryType
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations import ( from slither.slithir.operations import (
Assignment, Assignment,
Call, Call,
@ -120,7 +124,7 @@ Other uses (in complex expressions, as conditionals) indicate either an error or
boolean_constant_misuses = self._detect_boolean_constant_misuses(contract) boolean_constant_misuses = self._detect_boolean_constant_misuses(contract)
for (func, nodes) in boolean_constant_misuses: for (func, nodes) in boolean_constant_misuses:
for node in nodes: for node in nodes:
info = [ info: DETECTOR_INFO = [
func, func,
" uses a Boolean constant improperly:\n\t-", " uses a Boolean constant improperly:\n\t-",
node, node,

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save