Add more types hints

pull/1666/head
Feist Josselin 2 years ago
parent 20a79519f3
commit 371e3cbe47
  1. 63
      slither/__main__.py
  2. 88
      slither/analyses/data_dependency/data_dependency.py
  3. 89
      slither/core/cfg/node.py
  4. 9
      slither/core/children/child_contract.py
  5. 11
      slither/core/children/child_event.py
  6. 11
      slither/core/children/child_expression.py
  7. 9
      slither/core/children/child_function.py
  8. 9
      slither/core/children/child_inheritance.py
  9. 11
      slither/core/children/child_node.py
  10. 9
      slither/core/children/child_structure.py
  11. 51
      slither/core/declarations/contract.py
  12. 4
      slither/core/declarations/custom_error.py
  13. 7
      slither/core/declarations/custom_error_contract.py
  14. 2
      slither/core/declarations/custom_error_top_level.py
  15. 7
      slither/core/declarations/function.py
  16. 12
      slither/core/declarations/solidity_variables.py
  17. 1
      slither/core/dominators/utils.py
  18. 2
      slither/core/expressions/assignment_operation.py
  19. 4
      slither/core/slither_core.py
  20. 4
      slither/core/solidity_types/array_type.py
  21. 6
      slither/core/solidity_types/elementary_type.py
  22. 4
      slither/core/solidity_types/mapping_type.py
  23. 4
      slither/core/solidity_types/type_alias.py
  24. 6
      slither/core/solidity_types/type_information.py
  25. 6
      slither/core/source_mapping/source_mapping.py
  26. 2
      slither/core/variables/event_variable.py
  27. 6
      slither/core/variables/variable.py
  28. 4
      slither/detectors/abstract_detector.py
  29. 13
      slither/detectors/assembly/shift_parameter_mixup.py
  30. 6
      slither/detectors/attributes/const_functions_asm.py
  31. 7
      slither/detectors/compiler_bugs/array_by_reference.py
  32. 8
      slither/detectors/erc/erc20/arbitrary_send_erc20.py
  33. 8
      slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py
  34. 8
      slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py
  35. 10
      slither/detectors/functions/arbitrary_send_eth.py
  36. 3
      slither/detectors/statements/array_length_assignment.py
  37. 8
      slither/detectors/statements/assembly.py
  38. 7
      slither/slithir/operations/call.py
  39. 20
      slither/slithir/operations/high_level_call.py
  40. 26
      slither/slithir/operations/index.py
  41. 17
      slither/slithir/operations/library_call.py
  42. 6
      slither/slithir/operations/low_level_call.py
  43. 18
      slither/slithir/operations/lvalue.py
  44. 12
      slither/slithir/operations/member.py
  45. 11
      slither/slithir/operations/new_contract.py
  46. 11
      slither/slithir/operations/solidity_call.py
  47. 20
      slither/slithir/utils/utils.py
  48. 7
      slither/tools/mutator/__main__.py
  49. 11
      slither/tools/mutator/mutators/abstract_mutator.py

@ -66,7 +66,7 @@ def process_single(
args: argparse.Namespace,
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]:
) -> Tuple[Slither, List[Dict], List[Output], int]:
"""
The core high-level code for running Slither static analysis.
@ -89,7 +89,7 @@ def process_all(
args: argparse.Namespace,
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[List[Slither], List[Dict], List[Dict], int]:
) -> Tuple[List[Slither], List[Dict], List[Output], int]:
compilations = compile_all(target, **vars(args))
slither_instances = []
results_detectors = []
@ -144,23 +144,6 @@ def _process(
return slither, results_detectors, results_printers, analyzed_contracts_count
# TODO: delete me?
def process_from_asts(
filenames: List[str],
args: argparse.Namespace,
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]:
all_contracts: List[str] = []
for filename in filenames:
with open(filename, encoding="utf8") as file_open:
contract_loaded = json.load(file_open)
all_contracts.append(contract_loaded["ast"])
return process_single(all_contracts, args, detector_classes, printer_classes)
# endregion
###################################################################################
###################################################################################
@ -608,9 +591,6 @@ def parse_args(
default=False,
)
# if the json is splitted in different files
parser.add_argument("--splitted", help=argparse.SUPPRESS, action="store_true", default=False)
# Disable the throw/catch on partial analyses
parser.add_argument(
"--disallow-partial", help=argparse.SUPPRESS, action="store_true", default=False
@ -626,7 +606,7 @@ def parse_args(
args.filter_paths = parse_filter_paths(args)
# Verify our json-type output is valid
args.json_types = set(args.json_types.split(","))
args.json_types = set(args.json_types.split(",")) # type:ignore
for json_type in args.json_types:
if json_type not in JSON_OUTPUT_TYPES:
raise Exception(f'Error: "{json_type}" is not a valid JSON result output type.')
@ -697,14 +677,14 @@ class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods
class FormatterCryticCompile(logging.Formatter):
def format(self, record):
def format(self, record: logging.LogRecord) -> str:
# for i, msg in enumerate(record.msg):
if record.msg.startswith("Compilation warnings/errors on "):
txt = record.args[1]
txt = txt.split("\n")
txt = record.args[1] # type:ignore
txt = txt.split("\n") # type:ignore
txt = [red(x) if "Error" in x else x for x in txt]
txt = "\n".join(txt)
record.args = (record.args[0], txt)
record.args = (record.args[0], txt) # type:ignore
return super().format(record)
@ -747,7 +727,7 @@ def main_impl(
set_colorization_enabled(False if args.disable_color else sys.stdout.isatty())
# Define some variables for potential JSON output
json_results = {}
json_results: Dict[str, Any] = {}
output_error = None
outputting_json = args.json is not None
outputting_json_stdout = args.json == "-"
@ -796,7 +776,7 @@ def main_impl(
crytic_compile_error.setLevel(logging.INFO)
results_detectors: List[Dict] = []
results_printers: List[Dict] = []
results_printers: List[Output] = []
try:
filename = args.filename
@ -809,26 +789,17 @@ def main_impl(
number_contracts = 0
slither_instances = []
if args.splitted:
for filename in filenames:
(
slither_instance,
results_detectors,
results_printers,
number_contracts,
) = process_from_asts(filenames, args, detector_classes, printer_classes)
results_detectors_tmp,
results_printers_tmp,
number_contracts_tmp,
) = process_single(filename, args, detector_classes, printer_classes)
number_contracts += number_contracts_tmp
results_detectors += results_detectors_tmp
results_printers += results_printers_tmp
slither_instances.append(slither_instance)
else:
for filename in filenames:
(
slither_instance,
results_detectors_tmp,
results_printers_tmp,
number_contracts_tmp,
) = process_single(filename, args, detector_classes, printer_classes)
number_contracts += number_contracts_tmp
results_detectors += results_detectors_tmp
results_printers += results_printers_tmp
slither_instances.append(slither_instance)
# Rely on CryticCompile to discern the underlying type of compilations.
else:

@ -4,6 +4,7 @@
from collections import defaultdict
from typing import Union, Set, Dict, TYPE_CHECKING
from slither.core.cfg.node import Node
from slither.core.declarations import (
Contract,
Enum,
@ -12,6 +13,7 @@ from slither.core.declarations import (
SolidityVariable,
SolidityVariableComposed,
Structure,
FunctionContract,
)
from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder
from slither.core.variables.top_level_variable import TopLevelVariable
@ -40,25 +42,37 @@ if TYPE_CHECKING:
Variable_types = Union[Variable, SolidityVariable]
# TODO refactor the data deps to be better suited for top level function object
# Right now we allow to pass a node to ease the API, but we need something
# better
# The deps propagation for top level elements is also not working as expected
Context_types_API = Union[Contract, Function, Node]
Context_types = Union[Contract, Function]
def is_dependent(
variable: Variable_types,
source: Variable_types,
context: Context_types,
context: Context_types_API,
only_unprotected: bool = False,
) -> bool:
"""
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args:
variable (Variable)
source (Variable)
context (Contract|Function)
context (Contract|Function|Node).
only_unprotected (bool): True only unprotected function are considered
Returns:
bool
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
if isinstance(variable, Constant):
return False
if variable == source:
@ -76,10 +90,13 @@ def is_dependent(
def is_dependent_ssa(
variable: Variable_types,
source: Variable_types,
context: Context_types,
context: Context_types_API,
only_unprotected: bool = False,
) -> bool:
"""
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args:
variable (Variable)
taint (Variable)
@ -88,7 +105,10 @@ def is_dependent_ssa(
Returns:
bool
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
context_dict = context.context
if isinstance(variable, Constant):
return False
@ -112,11 +132,14 @@ GENERIC_TAINT = {
def is_tainted(
variable: Variable_types,
context: Context_types,
context: Context_types_API,
only_unprotected: bool = False,
ignore_generic_taint: bool = False,
) -> bool:
"""
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args:
variable
context (Contract|Function)
@ -124,7 +147,10 @@ def is_tainted(
Returns:
bool
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool)
if isinstance(variable, Constant):
return False
@ -139,11 +165,14 @@ def is_tainted(
def is_tainted_ssa(
variable: Variable_types,
context: Context_types,
context: Context_types_API,
only_unprotected: bool = False,
ignore_generic_taint: bool = False,
):
) -> bool:
"""
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args:
variable
context (Contract|Function)
@ -151,7 +180,10 @@ def is_tainted_ssa(
Returns:
bool
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool)
if isinstance(variable, Constant):
return False
@ -166,18 +198,23 @@ def is_tainted_ssa(
def get_dependencies(
variable: Variable_types,
context: Context_types,
context: Context_types_API,
only_unprotected: bool = False,
) -> Set[Variable]:
"""
Return the variables for which `variable` depends on.
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param variable: The target
:param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions
:return: set(Variable)
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool)
if only_unprotected:
return context.context[KEY_NON_SSA_UNPROTECTED].get(variable, set())
@ -185,16 +222,21 @@ def get_dependencies(
def get_all_dependencies(
context: Context_types, only_unprotected: bool = False
context: Context_types_API, only_unprotected: bool = False
) -> Dict[Variable, Set[Variable]]:
"""
Return the dictionary of dependencies.
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions
:return: Dict(Variable, set(Variable))
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool)
if only_unprotected:
return context.context[KEY_NON_SSA_UNPROTECTED]
@ -203,18 +245,23 @@ def get_all_dependencies(
def get_dependencies_ssa(
variable: Variable_types,
context: Context_types,
context: Context_types_API,
only_unprotected: bool = False,
) -> Set[Variable]:
"""
Return the variables for which `variable` depends on (SSA version).
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param variable: The target (must be SSA variable)
:param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions
:return: set(Variable)
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool)
if only_unprotected:
return context.context[KEY_SSA_UNPROTECTED].get(variable, set())
@ -222,16 +269,21 @@ def get_dependencies_ssa(
def get_all_dependencies_ssa(
context: Context_types, only_unprotected: bool = False
context: Context_types_API, only_unprotected: bool = False
) -> Dict[Variable, Set[Variable]]:
"""
Return the dictionary of dependencies.
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions
:return: Dict(Variable, set(Variable))
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool)
if only_unprotected:
return context.context[KEY_SSA_UNPROTECTED]

@ -6,7 +6,7 @@ from typing import Optional, List, Set, Dict, Tuple, Union, TYPE_CHECKING
from slither.all_exceptions import SlitherException
from slither.core.children.child_function import ChildFunction
from slither.core.declarations import Contract, Function
from slither.core.declarations import Contract, Function, FunctionContract
from slither.core.declarations.solidity_variables import (
SolidityVariable,
SolidityFunction,
@ -33,6 +33,7 @@ from slither.slithir.operations import (
Return,
Operation,
)
from slither.slithir.utils.utils import RVALUE
from slither.slithir.variables import (
Constant,
LocalIRVariable,
@ -146,12 +147,12 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._node_id: int = node_id
self._vars_written: List[Variable] = []
self._vars_read: List[Variable] = []
self._vars_read: List[Union[Variable, SolidityVariable]] = []
self._ssa_vars_written: List["SlithIRVariable"] = []
self._ssa_vars_read: List["SlithIRVariable"] = []
self._internal_calls: List["Function"] = []
self._internal_calls: List[Union["Function", "SolidityFunction"]] = []
self._solidity_calls: List[SolidityFunction] = []
self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls
self._library_calls: List["LibraryCallType"] = []
@ -172,7 +173,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._local_vars_read: List[LocalVariable] = []
self._local_vars_written: List[LocalVariable] = []
self._slithir_vars: Set["SlithIRVariable"] = set() # non SSA
self._slithir_vars: Set[
Union["SlithIRVariable", ReferenceVariable, TemporaryVariable, TupleVariable]
] = set() # non SSA
self._ssa_local_vars_read: List[LocalIRVariable] = []
self._ssa_local_vars_written: List[LocalIRVariable] = []
@ -213,7 +216,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._node_type
@type.setter
def type(self, new_type: NodeType):
def type(self, new_type: NodeType) -> None:
self._node_type = new_type
@property
@ -232,7 +235,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
###################################################################################
@property
def variables_read(self) -> List[Variable]:
def variables_read(self) -> List[Union[Variable, SolidityVariable]]:
"""
list(Variable): Variables read (local/state/solidity)
"""
@ -285,11 +288,13 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._expression_vars_read
@variables_read_as_expression.setter
def variables_read_as_expression(self, exprs: List[Expression]):
def variables_read_as_expression(self, exprs: List[Expression]) -> None:
self._expression_vars_read = exprs
@property
def slithir_variables(self) -> List["SlithIRVariable"]:
def slithir_variables(
self,
) -> List[Union["SlithIRVariable", ReferenceVariable, TemporaryVariable, TupleVariable]]:
return list(self._slithir_vars)
@property
@ -339,7 +344,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._expression_vars_written
@variables_written_as_expression.setter
def variables_written_as_expression(self, exprs: List[Expression]):
def variables_written_as_expression(self, exprs: List[Expression]) -> None:
self._expression_vars_written = exprs
# endregion
@ -399,7 +404,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._external_calls_as_expressions
@external_calls_as_expressions.setter
def external_calls_as_expressions(self, exprs: List[Expression]):
def external_calls_as_expressions(self, exprs: List[Expression]) -> None:
self._external_calls_as_expressions = exprs
@property
@ -410,7 +415,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._internal_calls_as_expressions
@internal_calls_as_expressions.setter
def internal_calls_as_expressions(self, exprs: List[Expression]):
def internal_calls_as_expressions(self, exprs: List[Expression]) -> None:
self._internal_calls_as_expressions = exprs
@property
@ -418,10 +423,10 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return list(self._expression_calls)
@calls_as_expression.setter
def calls_as_expression(self, exprs: List[Expression]):
def calls_as_expression(self, exprs: List[Expression]) -> None:
self._expression_calls = exprs
def can_reenter(self, callstack=None) -> bool:
def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
"""
Check if the node can re-enter
Do not consider CREATE as potential re-enter, but check if the
@ -567,7 +572,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
"""
self._fathers.append(father)
def set_fathers(self, fathers: List["Node"]):
def set_fathers(self, fathers: List["Node"]) -> None:
"""Set the father nodes
Args:
@ -663,20 +668,20 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._irs_ssa
@irs_ssa.setter
def irs_ssa(self, irs):
def irs_ssa(self, irs: List[Operation]) -> None:
self._irs_ssa = irs
def add_ssa_ir(self, ir: Operation) -> None:
"""
Use to place phi operation
"""
ir.set_node(self)
ir.set_node(self) # type: ignore
self._irs_ssa.append(ir)
def slithir_generation(self) -> None:
if self.expression:
expression = self.expression
self._irs = convert_expression(expression, self)
self._irs = convert_expression(expression, self) # type:ignore
self._find_read_write_call()
@ -713,7 +718,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._dominators
@dominators.setter
def dominators(self, dom: Set["Node"]):
def dominators(self, dom: Set["Node"]) -> None:
self._dominators = dom
@property
@ -725,7 +730,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._immediate_dominator
@immediate_dominator.setter
def immediate_dominator(self, idom: "Node"):
def immediate_dominator(self, idom: "Node") -> None:
self._immediate_dominator = idom
@property
@ -737,7 +742,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
return self._dominance_frontier
@dominance_frontier.setter
def dominance_frontier(self, doms: Set["Node"]):
def dominance_frontier(self, doms: Set["Node"]) -> None:
"""
Returns:
set(Node)
@ -789,6 +794,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node") -> None:
if variable.name not in self._phi_origins_local_variables:
assert variable.name
self._phi_origins_local_variables[variable.name] = (variable, set())
(v, nodes) = self._phi_origins_local_variables[variable.name]
assert v == variable
@ -827,7 +833,8 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
if isinstance(ir, OperationWithLValue):
var = ir.lvalue
if var and self._is_valid_slithir_var(var):
self._slithir_vars.add(var)
# The type is checked by is_valid_slithir_var
self._slithir_vars.add(var) # type: ignore
if not isinstance(ir, (Phi, Index, Member)):
self._vars_read += [v for v in ir.read if self._is_non_slithir_var(v)]
@ -835,8 +842,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
if isinstance(var, ReferenceVariable):
self._vars_read.append(var.points_to_origin)
elif isinstance(ir, (Member, Index)):
# TODO investigate types for member variable left
var = ir.variable_left if isinstance(ir, Member) else ir.variable_right
if self._is_non_slithir_var(var):
if var and self._is_non_slithir_var(var):
self._vars_read.append(var)
if isinstance(var, ReferenceVariable):
origin = var.points_to_origin
@ -860,14 +868,21 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._internal_calls.append(ir.function)
if isinstance(ir, LowLevelCall):
assert isinstance(ir.destination, (Variable, SolidityVariable))
self._low_level_calls.append((ir.destination, ir.function_name.value))
self._low_level_calls.append((ir.destination, str(ir.function_name.value)))
elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall):
# Todo investigate this if condition
# It does seem right to compare against a contract
# This might need a refactoring
if isinstance(ir.destination.type, Contract):
self._high_level_calls.append((ir.destination.type, ir.function))
elif ir.destination == SolidityVariable("this"):
self._high_level_calls.append((self.function.contract, ir.function))
func = self.function
# Can't use this in a top level function
assert isinstance(func, FunctionContract)
self._high_level_calls.append((func.contract, ir.function))
else:
try:
# Todo this part needs more tests and documentation
self._high_level_calls.append((ir.destination.type.type, ir.function))
except AttributeError as error:
# pylint: disable=raise-missing-from
@ -883,7 +898,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._vars_read = list(set(self._vars_read))
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)]
self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)]
self._solidity_vars_read = [v for v in self._vars_read if isinstance(v, SolidityVariable)]
self._solidity_vars_read = [
v_ for v_ in self._vars_read if isinstance(v_, SolidityVariable)
]
self._vars_written = list(set(self._vars_written))
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)]
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)]
@ -895,12 +912,15 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
@staticmethod
def _convert_ssa(v: Variable) -> Optional[Union[StateVariable, LocalVariable]]:
non_ssa_var: Optional[Union[StateVariable, LocalVariable]]
if isinstance(v, StateIRVariable):
contract = v.contract
assert v.name
non_ssa_var = contract.get_state_variable_from_name(v.name)
return non_ssa_var
assert isinstance(v, LocalIRVariable)
function = v.function
assert v.name
non_ssa_var = function.get_local_variable_from_name(v.name)
return non_ssa_var
@ -921,10 +941,11 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._ssa_vars_read.append(origin)
elif isinstance(ir, (Member, Index)):
if isinstance(ir.variable_right, (StateIRVariable, LocalIRVariable)):
self._ssa_vars_read.append(ir.variable_right)
if isinstance(ir.variable_right, ReferenceVariable):
origin = ir.variable_right.points_to_origin
variable_right: RVALUE = ir.variable_right
if isinstance(variable_right, (StateIRVariable, LocalIRVariable)):
self._ssa_vars_read.append(variable_right)
if isinstance(variable_right, ReferenceVariable):
origin = variable_right.points_to_origin
if isinstance(origin, (StateIRVariable, LocalIRVariable)):
self._ssa_vars_read.append(origin)
@ -944,20 +965,20 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
self._ssa_local_vars_read = [v for v in self._ssa_vars_read if isinstance(v, LocalVariable)]
self._ssa_vars_written = list(set(self._ssa_vars_written))
self._ssa_state_vars_written = [
v for v in self._ssa_vars_written if isinstance(v, StateVariable)
v for v in self._ssa_vars_written if v and isinstance(v, StateIRVariable)
]
self._ssa_local_vars_written = [
v for v in self._ssa_vars_written if isinstance(v, LocalVariable)
v for v in self._ssa_vars_written if v and isinstance(v, LocalIRVariable)
]
vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read]
vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written]
self._vars_read += [v for v in vars_read if v not in self._vars_read]
self._vars_read += [v_ for v_ in vars_read if v_ and v_ not in self._vars_read]
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)]
self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)]
self._vars_written += [v for v in vars_written if v not in self._vars_written]
self._vars_written += [v_ for v_ in vars_written if v_ and v_ not in self._vars_written]
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)]
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)]
@ -974,7 +995,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met
additional_info += " " + str(self.expression)
elif self.variable_declaration:
additional_info += " " + str(self.variable_declaration)
txt = self._node_type.value + additional_info
txt = str(self._node_type.value) + additional_info
return txt

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from slither.core.source_mapping.source_mapping import SourceMapping
@ -9,11 +9,14 @@ if TYPE_CHECKING:
class ChildContract(SourceMapping):
def __init__(self) -> None:
super().__init__()
self._contract = None
# TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._contract: Optional["Contract"] = None
def set_contract(self, contract: "Contract") -> None:
self._contract = contract
@property
def contract(self) -> "Contract":
return self._contract
return self._contract # type: ignore

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from slither.core.declarations import Event
@ -7,11 +7,14 @@ if TYPE_CHECKING:
class ChildEvent:
def __init__(self) -> None:
super().__init__()
self._event = None
# TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._event: Optional["Event"] = None
def set_event(self, event: "Event"):
def set_event(self, event: "Event") -> None:
self._event = event
@property
def event(self) -> "Event":
return self._event
return self._event # type: ignore

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Union, Optional
if TYPE_CHECKING:
from slither.core.expressions.expression import Expression
@ -8,11 +8,16 @@ if TYPE_CHECKING:
class ChildExpression:
def __init__(self) -> None:
super().__init__()
self._expression = None
# TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._expression: Optional[Union["Expression", "Operation"]] = None
def set_expression(self, expression: Union["Expression", "Operation"]) -> None:
# TODO investigate when this can be an operation?
# It was auto generated during an AST or detectors tests
self._expression = expression
@property
def expression(self) -> Union["Expression", "Operation"]:
return self._expression
return self._expression # type: ignore

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from slither.core.declarations import Function
@ -7,11 +7,14 @@ if TYPE_CHECKING:
class ChildFunction:
def __init__(self) -> None:
super().__init__()
self._function = None
# TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._function: Optional["Function"] = None
def set_function(self, function: "Function") -> None:
self._function = function
@property
def function(self) -> "Function":
return self._function
return self._function # type: ignore

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from slither.core.declarations import Contract
@ -7,11 +7,14 @@ if TYPE_CHECKING:
class ChildInheritance:
def __init__(self) -> None:
super().__init__()
self._contract_declarer = None
# TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._contract_declarer: Optional["Contract"] = None
def set_contract_declarer(self, contract: "Contract") -> None:
self._contract_declarer = contract
@property
def contract_declarer(self) -> "Contract":
return self._contract_declarer
return self._contract_declarer # type: ignore

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit
@ -9,14 +9,17 @@ if TYPE_CHECKING:
class ChildNode:
def __init__(self) -> None:
super().__init__()
self._node = None
# TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._node: Optional["Node"] = None
def set_node(self, node: "Node") -> None:
self._node = node
@property
def node(self) -> "Node":
return self._node
return self._node # type:ignore
@property
def function(self) -> "Function":
@ -24,7 +27,7 @@ class ChildNode:
@property
def contract(self) -> "Contract":
return self.node.function.contract
return self.node.function.contract # type: ignore
@property
def compilation_unit(self) -> "SlitherCompilationUnit":

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from slither.core.declarations import Structure
@ -7,11 +7,14 @@ if TYPE_CHECKING:
class ChildStructure:
def __init__(self) -> None:
super().__init__()
self._structure = None
# TODO remove all the setters for the child objects
# And make it a constructor arguement
# This will remove the optional
self._structure: Optional["Structure"] = None
def set_structure(self, structure: "Structure") -> None:
self._structure = structure
@property
def structure(self) -> "Structure":
return self._structure
return self._structure # type: ignore

@ -81,7 +81,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
# The only str is "*"
self._using_for: Dict[Union[str, Type], List[Type]] = {}
self._using_for_complete: Dict[Union[str, Type], List[Type]] = None
self._using_for_complete: Optional[Dict[Union[str, Type], List[Type]]] = None
self._kind: Optional[str] = None
self._is_interface: bool = False
self._is_library: bool = False
@ -275,7 +275,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive
"""
def _merge_using_for(uf1, uf2):
def _merge_using_for(uf1: Dict, uf2: Dict) -> Dict:
result = {**uf1, **uf2}
for key, value in result.items():
if key in uf1 and key in uf2:
@ -524,14 +524,14 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
"""
return list(self._functions.values())
def available_functions_as_dict(self) -> Dict[str, "FunctionContract"]:
def available_functions_as_dict(self) -> Dict[str, "Function"]:
if self._available_functions_as_dict is None:
self._available_functions_as_dict = {
f.full_name: f for f in self._functions.values() if not f.is_shadowed
}
return self._available_functions_as_dict
def add_function(self, func: "FunctionContract"):
def add_function(self, func: "FunctionContract") -> None:
self._functions[func.canonical_name] = func
def set_functions(self, functions: Dict[str, "FunctionContract"]) -> None:
@ -699,7 +699,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
list(Contract): Return the list of contracts derived from self
"""
candidates = self.compilation_unit.contracts
return [c for c in candidates if self in c.inheritance]
return [c for c in candidates if self in c.inheritance] # type: ignore
# endregion
###################################################################################
@ -855,7 +855,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
"""
return next((e for e in self.enums if e.name == enum_name), None)
def get_enum_from_canonical_name(self, enum_name) -> Optional["Enum"]:
def get_enum_from_canonical_name(self, enum_name: str) -> Optional["Enum"]:
"""
Return an enum from a canonical name
Args:
@ -956,7 +956,9 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
###################################################################################
###################################################################################
def get_summary(self, include_shadowed=True) -> Tuple[str, List[str], List[str], List, List]:
def get_summary(
self, include_shadowed: bool = True
) -> Tuple[str, List[str], List[str], List, List]:
"""Return the function summary
:param include_shadowed: boolean to indicate if shadowed functions should be included (default True)
@ -1209,7 +1211,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
@property
def is_test(self) -> bool:
return is_test_contract(self) or self.is_truffle_migration
return is_test_contract(self) or self.is_truffle_migration # type: ignore
# endregion
###################################################################################
@ -1219,7 +1221,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
###################################################################################
def update_read_write_using_ssa(self) -> None:
for function in self.functions + self.modifiers:
for function in self.functions + list(self.modifiers):
function.update_read_write_using_ssa()
# endregion
@ -1254,7 +1256,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_upgradeable
@is_upgradeable.setter
def is_upgradeable(self, upgradeable: bool):
def is_upgradeable(self, upgradeable: bool) -> None:
self._is_upgradeable = upgradeable
@property
@ -1283,7 +1285,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_upgradeable_proxy
@is_upgradeable_proxy.setter
def is_upgradeable_proxy(self, upgradeable_proxy: bool):
def is_upgradeable_proxy(self, upgradeable_proxy: bool) -> None:
self._is_upgradeable_proxy = upgradeable_proxy
@property
@ -1291,7 +1293,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._upgradeable_version
@upgradeable_version.setter
def upgradeable_version(self, version_name: str):
def upgradeable_version(self, version_name: str) -> None:
self._upgradeable_version = version_name
# endregion
@ -1310,7 +1312,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self._is_incorrectly_parsed
@is_incorrectly_constructed.setter
def is_incorrectly_constructed(self, incorrect: bool):
def is_incorrectly_constructed(self, incorrect: bool) -> None:
self._is_incorrectly_parsed = incorrect
def add_constructor_variables(self) -> None:
@ -1322,8 +1324,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
constructor_variable = FunctionContract(self.compilation_unit)
constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES)
constructor_variable.set_contract(self)
constructor_variable.set_contract_declarer(self)
constructor_variable.set_contract(self) # type: ignore
constructor_variable.set_contract_declarer(self) # type: ignore
constructor_variable.set_visibility("internal")
# For now, source mapping of the constructor variable is the whole contract
# Could be improved with a targeted source mapping
@ -1354,8 +1356,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
constructor_variable.set_function_type(
FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES
)
constructor_variable.set_contract(self)
constructor_variable.set_contract_declarer(self)
constructor_variable.set_contract(self) # type: ignore
constructor_variable.set_contract_declarer(self) # type: ignore
constructor_variable.set_visibility("internal")
# For now, source mapping of the constructor variable is the whole contract
# Could be improved with a targeted source mapping
@ -1436,22 +1438,23 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
all_ssa_state_variables_instances[v.canonical_name] = new_var
self._initial_state_variables.append(new_var)
for func in self.functions + self.modifiers:
for func in self.functions + list(self.modifiers):
func.generate_slithir_ssa(all_ssa_state_variables_instances)
def fix_phi(self) -> None:
last_state_variables_instances = {}
initial_state_variables_instances = {}
last_state_variables_instances: Dict[str, List["StateVariable"]] = {}
initial_state_variables_instances: Dict[str, "StateVariable"] = {}
for v in self._initial_state_variables:
last_state_variables_instances[v.canonical_name] = []
initial_state_variables_instances[v.canonical_name] = v
for func in self.functions + self.modifiers:
for func in self.functions + list(self.modifiers):
result = func.get_last_ssa_state_variables_instances()
for variable_name, instances in result.items():
# TODO: investigate the next operation
last_state_variables_instances[variable_name] += instances
for func in self.functions + self.modifiers:
for func in self.functions + list(self.modifiers):
func.fix_phi(last_state_variables_instances, initial_state_variables_instances)
# endregion
@ -1461,7 +1464,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
###################################################################################
###################################################################################
def __eq__(self, other: SourceMapping) -> bool:
def __eq__(self, other: Any) -> bool:
if isinstance(other, str):
return other == self.name
return NotImplemented
@ -1475,6 +1478,6 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
return self.name
def __hash__(self) -> int:
return self._id
return self._id # type:ignore
# endregion

@ -51,7 +51,7 @@ class CustomError(SourceMapping):
return str(t)
@property
def solidity_signature(self) -> Optional[str]:
def solidity_signature(self) -> str:
"""
Return a signature following the Solidity Standard
Contract and converted into address
@ -63,7 +63,7 @@ class CustomError(SourceMapping):
# (set_solidity_sig was not called before find_variable)
if self._solidity_signature is None:
raise ValueError("Custom Error not yet built")
return self._solidity_signature
return self._solidity_signature # type: ignore
def set_solidity_sig(self) -> None:
"""

@ -1,9 +1,14 @@
from typing import TYPE_CHECKING
from slither.core.children.child_contract import ChildContract
from slither.core.declarations.custom_error import CustomError
if TYPE_CHECKING:
from slither.core.declarations import Contract
class CustomErrorContract(CustomError, ChildContract):
def is_declared_by(self, contract):
def is_declared_by(self, contract: "Contract") -> bool:
"""
Check if the element is declared by the contract
:param contract:

@ -9,6 +9,6 @@ if TYPE_CHECKING:
class CustomErrorTopLevel(CustomError, TopLevel):
def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"):
def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None:
super().__init__(compilation_unit)
self.file_scope: "FileScope" = scope

@ -47,7 +47,6 @@ if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.scope.scope import FileScope
from slither.slithir.variables.state_variable import StateIRVariable
from slither.core.declarations.function_contract import FunctionContract
LOGGER = logging.getLogger("Function")
ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"])
@ -298,7 +297,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
def contains_assembly(self, c: bool):
self._contains_assembly = c
def can_reenter(self, callstack: Optional[List["FunctionContract"]] = None) -> bool:
def can_reenter(self, callstack: Optional[List[Union["Function", "Variable"]]] = None) -> bool:
"""
Check if the function can re-enter
Follow internal calls.
@ -1720,8 +1719,8 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
def fix_phi(
self,
last_state_variables_instances: Dict[str, List["StateIRVariable"]],
initial_state_variables_instances: Dict[str, "StateIRVariable"],
last_state_variables_instances: Dict[str, List["StateVariable"]],
initial_state_variables_instances: Dict[str, "StateVariable"],
) -> None:
from slither.slithir.operations import InternalCall, PhiCallback
from slither.slithir.variables import Constant, StateIRVariable

@ -82,7 +82,7 @@ SOLIDITY_FUNCTIONS: Dict[str, List[str]] = {
}
def solidity_function_signature(name):
def solidity_function_signature(name: str) -> str:
"""
Return the function signature (containing the return value)
It is useful if a solidity function is used as a pointer
@ -106,7 +106,7 @@ class SolidityVariable(SourceMapping):
assert name in SOLIDITY_VARIABLES or name.endswith(("_slot", "_offset"))
@property
def state_variable(self):
def state_variable(self) -> str:
if self._name.endswith("_slot"):
return self._name[:-5]
if self._name.endswith("_offset"):
@ -125,7 +125,7 @@ class SolidityVariable(SourceMapping):
def __str__(self) -> str:
return self._name
def __eq__(self, other: SourceMapping) -> bool:
def __eq__(self, other: Any) -> bool:
return self.__class__ == other.__class__ and self.name == other.name
def __hash__(self) -> int:
@ -182,13 +182,13 @@ class SolidityFunction(SourceMapping):
return self._return_type
@return_type.setter
def return_type(self, r: List[Union[TypeInformation, ElementaryType]]):
def return_type(self, r: List[Union[TypeInformation, ElementaryType]]) -> None:
self._return_type = r
def __str__(self) -> str:
return self._name
def __eq__(self, other: "SolidityFunction") -> bool:
def __eq__(self, other: Any) -> bool:
return self.__class__ == other.__class__ and self.name == other.name
def __hash__(self) -> int:
@ -201,7 +201,7 @@ class SolidityCustomRevert(SolidityFunction):
self._custom_error = custom_error
self._return_type: List[Union[TypeInformation, ElementaryType]] = []
def __eq__(self, other: Union["SolidityCustomRevert", SolidityFunction]) -> bool:
def __eq__(self, other: Any) -> bool:
return (
self.__class__ == other.__class__
and self.name == other.name

@ -95,4 +95,5 @@ def compute_dominance_frontier(nodes: List["Node"]) -> None:
runner.dominance_frontier = runner.dominance_frontier.union({node})
while runner != node.immediate_dominator:
runner.dominance_frontier = runner.dominance_frontier.union({node})
assert runner.immediate_dominator
runner = runner.immediate_dominator

@ -91,7 +91,7 @@ class AssignmentOperation(ExpressionTyped):
super().__init__()
left_expression.set_lvalue()
self._expressions = [left_expression, right_expression]
self._type: Optional["AssignmentOperationType"] = expression_type
self._type: AssignmentOperationType = expression_type
self._expression_return_type: Optional["Type"] = expression_return_type
@property

@ -482,8 +482,8 @@ class SlitherCore(Context):
###################################################################################
@property
def crytic_compile(self) -> Optional[CryticCompile]:
return self._crytic_compile
def crytic_compile(self) -> CryticCompile:
return self._crytic_compile # type: ignore
# endregion
###################################################################################

@ -4,11 +4,11 @@ from slither.core.expressions.expression import Expression
from slither.core.solidity_types.type import Type
from slither.visitors.expression.constants_folding import ConstantFolding
from slither.core.expressions.literal import Literal
from slither.core.solidity_types.elementary_type import ElementaryType
if TYPE_CHECKING:
from slither.core.expressions.binary_operation import BinaryOperation
from slither.core.expressions.identifier import Identifier
from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.solidity_types.function_type import FunctionType
from slither.core.solidity_types.type_alias import TypeAliasTopLevel
@ -22,7 +22,7 @@ class ArrayType(Type):
assert isinstance(t, Type)
if length:
if isinstance(length, int):
length = Literal(length, "uint256")
length = Literal(length, ElementaryType("uint256"))
assert isinstance(length, Expression)
super().__init__()
self._type: Type = t

@ -1,5 +1,5 @@
import itertools
from typing import Tuple
from typing import Tuple, Optional, Any
from slither.core.solidity_types.type import Type
@ -176,7 +176,7 @@ class ElementaryType(Type):
return self.type
@property
def size(self) -> int:
def size(self) -> Optional[int]:
"""
Return the size in bits
Return None if the size is not known
@ -219,7 +219,7 @@ class ElementaryType(Type):
def __str__(self) -> str:
return self._type
def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
if not isinstance(other, ElementaryType):
return False
return self.type == other.type

@ -1,4 +1,4 @@
from typing import Union, Tuple, TYPE_CHECKING
from typing import Union, Tuple, TYPE_CHECKING, Any
from slither.core.solidity_types.type import Type
@ -38,7 +38,7 @@ class MappingType(Type):
def __str__(self) -> str:
return f"mapping({str(self._from)} => {str(self._to)})"
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if not isinstance(other, MappingType):
return False
return self.type_from == other.type_from and self.type_to == other.type_to

@ -40,7 +40,7 @@ class TypeAlias(Type):
class TypeAliasTopLevel(TypeAlias, TopLevel):
def __init__(self, underlying_type: Type, name: str, scope: "FileScope") -> None:
def __init__(self, underlying_type: ElementaryType, name: str, scope: "FileScope") -> None:
super().__init__(underlying_type, name)
self.file_scope: "FileScope" = scope
@ -49,7 +49,7 @@ class TypeAliasTopLevel(TypeAlias, TopLevel):
class TypeAliasContract(TypeAlias, ChildContract):
def __init__(self, underlying_type: Type, name: str, contract: "Contract") -> None:
def __init__(self, underlying_type: ElementaryType, name: str, contract: "Contract") -> None:
super().__init__(underlying_type, name)
self._contract: "Contract" = contract

@ -1,4 +1,4 @@
from typing import Union, TYPE_CHECKING, Tuple
from typing import Union, TYPE_CHECKING, Tuple, Any
from slither.core.solidity_types import ElementaryType
from slither.core.solidity_types.type import Type
@ -40,10 +40,10 @@ class TypeInformation(Type):
def is_dynamic(self) -> bool:
raise NotImplementedError
def __str__(self):
def __str__(self) -> str:
return f"type({self.type.name})"
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if not isinstance(other, TypeInformation):
return False
return self.type == other.type

@ -1,6 +1,6 @@
import re
from abc import ABCMeta
from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional
from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional, Any
from Crypto.Hash import SHA1
from crytic_compile.utils.naming import Filename
@ -102,10 +102,10 @@ class Source:
filename_short: str = self.filename.short if self.filename.short else ""
return f"{filename_short}{lines}"
def __hash__(self):
def __hash__(self) -> int:
return hash(str(self))
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return (

@ -16,5 +16,5 @@ class EventVariable(ChildEvent, Variable):
return self._indexed
@indexed.setter
def indexed(self, is_indexed: bool):
def indexed(self, is_indexed: bool) -> bool:
self._indexed = is_indexed

@ -55,7 +55,7 @@ class Variable(SourceMapping):
return self._initialized
@initialized.setter
def initialized(self, is_init: bool):
def initialized(self, is_init: bool) -> None:
self._initialized = is_init
@property
@ -73,7 +73,7 @@ class Variable(SourceMapping):
return self._name
@name.setter
def name(self, name):
def name(self, name: str) -> None:
self._name = name
@property
@ -89,7 +89,7 @@ class Variable(SourceMapping):
return self._is_constant
@is_constant.setter
def is_constant(self, is_cst: bool):
def is_constant(self, is_cst: bool) -> None:
self._is_constant = is_cst
@property

@ -59,6 +59,8 @@ ALL_SOLC_VERSIONS_06 = make_solc_versions(6, 0, 12)
ALL_SOLC_VERSIONS_07 = make_solc_versions(7, 0, 6)
# No VERSIONS_08 as it is still in dev
DETECTOR_INFO = Union[str, List[Union[str, SupportedOutput]]]
class AbstractDetector(metaclass=abc.ABCMeta):
ARGUMENT = "" # run the detector with slither.py --ARGUMENT
@ -251,7 +253,7 @@ class AbstractDetector(metaclass=abc.ABCMeta):
def generate_result(
self,
info: Union[str, List[Union[str, SupportedOutput]]],
info: DETECTOR_INFO,
additional_fields: Optional[Dict] = None,
) -> Output:
output = Output(

@ -1,5 +1,9 @@
from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.slithir.operations import Binary, BinaryType
from slither.slithir.variables import Constant
from slither.core.declarations.function_contract import FunctionContract
@ -49,7 +53,12 @@ The shift statement will right-shift the constant 8 by `a` bits"""
BinaryType.RIGHT_SHIFT,
]:
if isinstance(ir.variable_left, Constant):
info = [f, " contains an incorrect shift operation: ", node, "\n"]
info: DETECTOR_INFO = [
f,
" contains an incorrect shift operation: ",
node,
"\n",
]
json = self.generate_result(info)
results.append(json)

@ -7,6 +7,7 @@ from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
ALL_SOLC_VERSIONS_04,
DETECTOR_INFO,
)
from slither.formatters.attributes.const_functions import custom_format
from slither.utils.output import Output
@ -73,7 +74,10 @@ All the calls to `get` revert, breaking Bob's smart contract execution."""
if f.contains_assembly:
attr = "view" if f.view else "pure"
info = [f, f" is declared {attr} but contains assembly code\n"]
info: DETECTOR_INFO = [
f,
f" is declared {attr} but contains assembly code\n",
]
res = self.generate_result(info, {"contains_assembly": True})
results.append(res)

@ -105,7 +105,12 @@ As a result, Bob's usage of the contract is incorrect."""
write to the array unsuccessfully.
"""
# Define our resulting array.
results = []
results: List[
Union[
Tuple[Node, StateVariable, FunctionContract],
Tuple[Node, LocalVariable, FunctionContract],
]
] = []
# Verify we have functions in our list to check for.
if not array_modifying_funcs:

@ -61,12 +61,12 @@ class ArbitrarySendErc20:
is_dependent(
ir.arguments[0],
SolidityVariableComposed("msg.sender"),
node.function.contract,
node,
)
or is_dependent(
ir.arguments[0],
SolidityVariable("this"),
node.function.contract,
node,
)
)
):
@ -79,12 +79,12 @@ class ArbitrarySendErc20:
is_dependent(
ir.arguments[1],
SolidityVariableComposed("msg.sender"),
node.function.contract,
node,
)
or is_dependent(
ir.arguments[1],
SolidityVariable("this"),
node.function.contract,
node,
)
)
):

@ -1,5 +1,9 @@
from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output
from .arbitrary_send_erc20 import ArbitrarySendErc20
@ -38,7 +42,7 @@ Use `msg.sender` as `from` in transferFrom.
arbitrary_sends.detect()
for node in arbitrary_sends.no_permit_results:
func = node.function
info = [func, " uses arbitrary from in transferFrom: ", node, "\n"]
info: DETECTOR_INFO = [func, " uses arbitrary from in transferFrom: ", node, "\n"]
res = self.generate_result(info)
results.append(res)

@ -1,5 +1,9 @@
from typing import List
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output
from .arbitrary_send_erc20 import ArbitrarySendErc20
@ -41,7 +45,7 @@ Ensure that the underlying ERC20 token correctly implements a permit function.
arbitrary_sends.detect()
for node in arbitrary_sends.permit_results:
func = node.function
info = [
info: DETECTOR_INFO = [
func,
" uses arbitrary from in transferFrom in combination with permit: ",
node,

@ -39,6 +39,10 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]:
ret: List[Node] = []
for node in func.nodes:
func = node.function
deps_target: Union[Contract, Function] = (
func.contract if isinstance(func, FunctionContract) else func
)
for ir in node.irs:
if isinstance(ir, SolidityCall):
if ir.function == SolidityFunction("ecrecover(bytes32,uint8,bytes32,bytes32)"):
@ -49,7 +53,7 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]:
if is_dependent(
ir.variable_right,
SolidityVariableComposed("msg.sender"),
func.contract,
deps_target,
):
return False
if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)):
@ -64,11 +68,11 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]:
if is_dependent(
ir.call_value,
SolidityVariableComposed("msg.value"),
func.contract,
node,
):
continue
if is_tainted(ir.destination, func.contract):
if is_tainted(ir.destination, node):
ret.append(node)
return ret

@ -7,6 +7,7 @@ from slither.detectors.abstract_detector import (
DetectorClassification,
ALL_SOLC_VERSIONS_04,
ALL_SOLC_VERSIONS_05,
DETECTOR_INFO,
)
from slither.core.cfg.node import Node, NodeType
from slither.slithir.operations import Assignment, Length
@ -120,7 +121,7 @@ Otherwise, thoroughly review the contract to ensure a user-controlled variable c
for contract in self.contracts:
array_length_assignments = detect_array_length_assignment(contract)
if array_length_assignments:
contract_info = [
contract_info: DETECTOR_INFO = [
contract,
" contract sets array length with a user-controlled value:\n",
]

@ -6,7 +6,11 @@ from typing import List, Tuple
from slither.core.cfg.node import Node, NodeType
from slither.core.declarations.contract import Contract
from slither.core.declarations.function_contract import FunctionContract
from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification
from slither.detectors.abstract_detector import (
AbstractDetector,
DetectorClassification,
DETECTOR_INFO,
)
from slither.utils.output import Output
@ -52,7 +56,7 @@ class Assembly(AbstractDetector):
for c in self.contracts:
values = self.detect_assembly(c)
for func, nodes in values:
info = [func, " uses assembly\n"]
info: DETECTOR_INFO = [func, " uses assembly\n"]
# sort the nodes to get deterministic results
nodes.sort(key=lambda x: x.node_id)

@ -1,5 +1,7 @@
from typing import Optional, List
from typing import Optional, List, Union
from slither.core.declarations import Function
from slither.core.variables import Variable
from slither.slithir.operations.operation import Operation
@ -16,7 +18,8 @@ class Call(Operation):
def arguments(self, v):
self._arguments = v
def can_reenter(self, _callstack: Optional[List] = None) -> bool: # pylint: disable=no-self-use
# pylint: disable=no-self-use
def can_reenter(self, _callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
"""
Must be called after slithIR analysis pass
:return: bool

@ -1,5 +1,6 @@
from typing import List, Optional, Union
from slither.core.declarations import Contract
from slither.slithir.operations.call import Call
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.core.variables.variable import Variable
@ -32,7 +33,8 @@ class HighLevelCall(Call, OperationWithLValue):
assert is_valid_lvalue(result) or result is None
self._check_destination(destination)
super().__init__()
self._destination = destination
# Contract is only possible for library call, which inherits from highlevelcall
self._destination: Union[Variable, SolidityVariable, Contract] = destination # type: ignore
self._function_name = function_name
self._nbr_arguments = nbr_arguments
self._type_call = type_call
@ -44,8 +46,9 @@ class HighLevelCall(Call, OperationWithLValue):
self._call_gas = None
# Development function, to be removed once the code is stable
# It is ovveride by LbraryCall
def _check_destination(self, destination: SourceMapping) -> None: # pylint: disable=no-self-use
# It is overridden by LibraryCall
# pylint: disable=no-self-use
def _check_destination(self, destination: Union[Variable, SolidityVariable, Contract]) -> None:
assert isinstance(destination, (Variable, SolidityVariable))
@property
@ -79,7 +82,14 @@ class HighLevelCall(Call, OperationWithLValue):
return [x for x in all_read if x] + [self.destination]
@property
def destination(self) -> SourceMapping:
def destination(self) -> Union[Variable, SolidityVariable, Contract]:
"""
Return a variable or a solidityVariable
Contract is only possible for LibraryCall
Returns:
"""
return self._destination
@property
@ -116,7 +126,7 @@ class HighLevelCall(Call, OperationWithLValue):
return True
return False
def can_reenter(self, callstack: None = None) -> bool:
def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
"""
Must be called after slithIR analysis pass
For Solidity > 0.5, filter access to public variables and constant/pure/view

@ -1,20 +1,20 @@
from typing import List, Union
from slither.core.declarations import SolidityVariableComposed
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue
from slither.slithir.variables.reference import ReferenceVariable
from slither.core.solidity_types.elementary_type import ElementaryType
from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.variable import Variable
from slither.slithir.variables.reference_ssa import ReferenceVariableSSA
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue, RVALUE, LVALUE
from slither.slithir.variables.reference import ReferenceVariable
class Index(OperationWithLValue):
def __init__(
self,
result: Union[ReferenceVariable, ReferenceVariableSSA],
result: ReferenceVariable,
left_variable: Variable,
right_variable: SourceMapping,
right_variable: RVALUE,
index_type: Union[ElementaryType, str],
) -> None:
super().__init__()
@ -25,23 +25,23 @@ class Index(OperationWithLValue):
assert isinstance(result, ReferenceVariable)
self._variables = [left_variable, right_variable]
self._type = index_type
self._lvalue = result
self._lvalue: ReferenceVariable = result
@property
def read(self) -> List[SourceMapping]:
return list(self.variables)
@property
def variables(self) -> List[SourceMapping]:
return self._variables
def variables(self) -> List[Union[LVALUE, RVALUE, SolidityVariableComposed]]:
return self._variables # type: ignore
@property
def variable_left(self) -> Variable:
return self._variables[0]
def variable_left(self) -> Union[LVALUE, SolidityVariableComposed]:
return self._variables[0] # type: ignore
@property
def variable_right(self) -> SourceMapping:
return self._variables[1]
def variable_right(self) -> RVALUE:
return self._variables[1] # type: ignore
@property
def index_type(self) -> Union[ElementaryType, str]:

@ -1,4 +1,7 @@
from slither.core.declarations import Function
from typing import Union, Optional, List
from slither.core.declarations import Function, SolidityVariable
from slither.core.variables import Variable
from slither.slithir.operations.high_level_call import HighLevelCall
from slither.core.declarations.contract import Contract
@ -9,10 +12,10 @@ class LibraryCall(HighLevelCall):
"""
# Development function, to be removed once the code is stable
def _check_destination(self, destination: Contract) -> None:
def _check_destination(self, destination: Union[Variable, SolidityVariable, Contract]) -> None:
assert isinstance(destination, Contract)
def can_reenter(self, callstack: None = None) -> bool:
def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
"""
Must be called after slithIR analysis pass
:return: bool
@ -20,11 +23,11 @@ class LibraryCall(HighLevelCall):
if self.is_static_call():
return False
# In case of recursion, return False
callstack = [] if callstack is None else callstack
if self.function in callstack:
callstack_local = [] if callstack is None else callstack
if self.function in callstack_local:
return False
callstack = callstack + [self.function]
return self.function.can_reenter(callstack)
callstack_local = callstack_local + [self.function]
return self.function.can_reenter(callstack_local)
def __str__(self):
gas = ""

@ -1,4 +1,6 @@
from typing import List, Union
from typing import List, Union, Optional
from slither.core.declarations import Function
from slither.slithir.operations.call import Call
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.core.variables.variable import Variable
@ -74,7 +76,7 @@ class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-insta
# remove None
return self._unroll([x for x in all_read if x])
def can_reenter(self, _callstack: None = None) -> bool:
def can_reenter(self, _callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
"""
Must be called after slithIR analysis pass
:return: bool

@ -1,4 +1,6 @@
from typing import Any, List
from typing import Any, List, Optional
from slither.core.variables import Variable
from slither.slithir.operations.operation import Operation
@ -10,16 +12,16 @@ class OperationWithLValue(Operation):
def __init__(self) -> None:
super().__init__()
self._lvalue = None
self._lvalue: Optional[Variable] = None
@property
def lvalue(self):
def lvalue(self) -> Optional[Variable]:
return self._lvalue
@property
def used(self) -> List[Any]:
return self.read + [self.lvalue]
@lvalue.setter
def lvalue(self, lvalue):
def lvalue(self, lvalue: Variable) -> None:
self._lvalue = lvalue
@property
def used(self) -> List[Optional[Any]]:
return self.read + [self.lvalue]

@ -5,7 +5,7 @@ from slither.core.declarations.enum import Enum
from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder
from slither.core.solidity_types import ElementaryType
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.slithir.utils.utils import is_valid_rvalue
from slither.slithir.utils.utils import is_valid_rvalue, RVALUE
from slither.slithir.variables.constant import Constant
from slither.slithir.variables.reference import ReferenceVariable
from slither.core.source_mapping.source_mapping import SourceMapping
@ -39,7 +39,9 @@ class Member(OperationWithLValue):
assert isinstance(variable_right, Constant)
assert isinstance(result, ReferenceVariable)
super().__init__()
self._variable_left = variable_left
self._variable_left: Union[
RVALUE, Contract, Enum, Function, CustomError, SolidityImportPlaceHolder, ElementaryType
] = variable_left
self._variable_right = variable_right
self._lvalue = result
self._gas = None
@ -50,7 +52,11 @@ class Member(OperationWithLValue):
return [self.variable_left, self.variable_right]
@property
def variable_left(self) -> SourceMapping:
def variable_left(
self,
) -> Union[
RVALUE, Contract, Enum, Function, CustomError, SolidityImportPlaceHolder, ElementaryType
]:
return self._variable_left
@property

@ -1,11 +1,13 @@
from typing import Optional, Any, List, Union
from slither.core.declarations import Function
from slither.core.declarations.contract import Contract
from slither.core.variables import Variable
from slither.slithir.operations import Call, OperationWithLValue
from slither.slithir.utils.utils import is_valid_lvalue
from slither.slithir.variables.constant import Constant
from slither.core.declarations.contract import Contract
from slither.slithir.variables.temporary import TemporaryVariable
from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA
from slither.core.declarations.function_contract import FunctionContract
class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes
@ -58,6 +60,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan
def contract_created(self) -> Contract:
contract_name = self.contract_name
contract_instance = self.node.file_scope.get_contract_from_name(contract_name)
assert contract_instance
return contract_instance
###################################################################################
@ -66,7 +69,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan
###################################################################################
###################################################################################
def can_reenter(self, callstack: Optional[List[FunctionContract]] = None) -> bool:
def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool:
"""
Must be called after slithIR analysis pass
For Solidity > 0.5, filter access to public variables and constant/pure/view
@ -92,7 +95,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan
# endregion
def __str__(self):
def __str__(self) -> str:
options = ""
if self.call_value:
options = f"value:{self.call_value} "

@ -1,15 +1,16 @@
from typing import Any, List, Union
from slither.core.declarations.solidity_variables import SolidityCustomRevert, SolidityFunction
from slither.slithir.operations.call import Call
from slither.slithir.operations.lvalue import OperationWithLValue
from slither.core.children.child_node import ChildNode
from slither.core.declarations.solidity_variables import SolidityFunction
from slither.core.solidity_types.elementary_type import ElementaryType
from slither.slithir.operations.call import Call
from slither.slithir.operations.lvalue import OperationWithLValue
class SolidityCall(Call, OperationWithLValue):
def __init__(
self,
function: Union[SolidityCustomRevert, SolidityFunction],
function: SolidityFunction,
nbr_arguments: int,
result: ChildNode,
type_call: Union[str, List[ElementaryType]],
@ -26,7 +27,7 @@ class SolidityCall(Call, OperationWithLValue):
return self._unroll(self.arguments)
@property
def function(self) -> Union[SolidityCustomRevert, SolidityFunction]:
def function(self) -> SolidityFunction:
return self._function
@property

@ -1,3 +1,5 @@
from typing import Union
from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.state_variable import StateVariable
@ -10,6 +12,24 @@ from slither.slithir.variables.reference import ReferenceVariable
from slither.slithir.variables.tuple import TupleVariable
from slither.core.source_mapping.source_mapping import SourceMapping
RVALUE = Union[
StateVariable,
LocalVariable,
TopLevelVariable,
TemporaryVariable,
Constant,
SolidityVariable,
ReferenceVariable,
]
LVALUE = Union[
StateVariable,
LocalVariable,
TemporaryVariable,
ReferenceVariable,
TupleVariable,
]
def is_valid_rvalue(v: SourceMapping) -> bool:
return isinstance(

@ -79,9 +79,10 @@ def main() -> None:
print(args.codebase)
sl = Slither(args.codebase, **vars(args))
for M in _get_mutators():
m = M(sl)
m.mutate()
for compilation_unit in sl.compilation_units:
for M in _get_mutators():
m = M(compilation_unit)
m.mutate()
# endregion

@ -3,7 +3,7 @@ import logging
from enum import Enum
from typing import Optional, Dict
from slither import Slither
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.formatters.utils.patches import apply_patch, create_diff
logger = logging.getLogger("Slither")
@ -34,8 +34,11 @@ class AbstractMutator(metaclass=abc.ABCMeta): # pylint: disable=too-few-public-
FAULTCLASS = FaultClass.Undefined
FAULTNATURE = FaultNature.Undefined
def __init__(self, slither: Slither, rate: int = 10, seed: Optional[int] = None):
self.slither = slither
def __init__(
self, compilation_unit: SlitherCompilationUnit, rate: int = 10, seed: Optional[int] = None
):
self.compilation_unit = compilation_unit
self.slither = compilation_unit.core
self.seed = seed
self.rate = rate
@ -87,7 +90,7 @@ class AbstractMutator(metaclass=abc.ABCMeta): # pylint: disable=too-few-public-
continue
for patch in patches:
patched_txt, offset = apply_patch(patched_txt, patch, offset)
diff = create_diff(self.slither, original_txt, patched_txt, file)
diff = create_diff(self.compilation_unit, original_txt, patched_txt, file)
if not diff:
logger.info(f"Impossible to generate patch; empty {patches}")
print(diff)

Loading…
Cancel
Save