diff --git a/slither/__main__.py b/slither/__main__.py index 9d611532e..528a93e8f 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -25,7 +25,13 @@ from slither.printers import all_printers from slither.printers.abstract_printer import AbstractPrinter from slither.slither import Slither from slither.utils import codex -from slither.utils.output import output_to_json, output_to_zip, output_to_sarif, ZIP_TYPES_ACCEPTED +from slither.utils.output import ( + output_to_json, + output_to_zip, + output_to_sarif, + ZIP_TYPES_ACCEPTED, + Output, +) from slither.utils.output_capture import StandardOutputCapture from slither.utils.colors import red, set_colorization_enabled from slither.utils.command_line import ( @@ -112,7 +118,7 @@ def _process( slither: Slither, detector_classes: List[Type[AbstractDetector]], printer_classes: List[Type[AbstractPrinter]], -) -> Tuple[Slither, List[Dict], List[Dict], int]: +) -> Tuple[Slither, List[Dict], List[Output], int]: for detector_cls in detector_classes: slither.register_detector(detector_cls) @@ -125,9 +131,9 @@ def _process( results_printers = [] if not printer_classes: - detector_results = slither.run_detectors() - detector_results = [x for x in detector_results if x] # remove empty results - detector_results = [item for sublist in detector_results for item in sublist] # flatten + detector_resultss = slither.run_detectors() + detector_resultss = [x for x in detector_resultss if x] # remove empty results + detector_results = [item for sublist in detector_resultss for item in sublist] # flatten results_detectors.extend(detector_results) else: diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index f2af8dc57..7643b19b7 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -4,16 +4,19 @@ from enum import Enum from typing import Optional, List, Set, Dict, Tuple, Union, TYPE_CHECKING +from slither.all_exceptions import SlitherException from slither.core.children.child_function import ChildFunction +from slither.core.declarations import Contract, Function from slither.core.declarations.solidity_variables import ( SolidityVariable, SolidityFunction, ) +from slither.core.expressions.expression import Expression +from slither.core.solidity_types import ElementaryType from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable -from slither.core.solidity_types import ElementaryType from slither.slithir.convert import convert_expression from slither.slithir.operations import ( HighLevelCall, @@ -38,10 +41,6 @@ from slither.slithir.variables import ( TemporaryVariable, TupleVariable, ) -from slither.all_exceptions import SlitherException -from slither.core.declarations import Contract, Function - -from slither.core.expressions.expression import Expression if TYPE_CHECKING: from slither.slithir.variables.variable import SlithIRVariable @@ -119,7 +118,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met node_id: int, scope: Union["Scope", "Function"], file_scope: "FileScope", - ): + ) -> None: super().__init__() self._node_type = node_type @@ -474,11 +473,11 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met """ return self._expression - def add_expression(self, expression: Expression, bypass_verif_empty: bool = False): + def add_expression(self, expression: Expression, bypass_verif_empty: bool = False) -> None: assert self._expression is None or bypass_verif_empty self._expression = expression - def add_variable_declaration(self, var: LocalVariable): + def add_variable_declaration(self, var: LocalVariable) -> None: assert self._variable_declaration is None self._variable_declaration = var if var.expression: @@ -511,7 +510,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met for c in self.internal_calls ) - def contains_if(self, include_loop=True) -> bool: + def contains_if(self, include_loop: bool = True) -> bool: """ Check if the node is a IF node Returns: @@ -521,7 +520,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self.type in [NodeType.IF, NodeType.IFLOOP] return self.type == NodeType.IF - def is_conditional(self, include_loop=True) -> bool: + def is_conditional(self, include_loop: bool = True) -> bool: """ Check if the node is a conditional node A conditional node is either a IF or a require/assert or a RETURN bool @@ -550,7 +549,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met def inline_asm(self) -> Optional[Union[str, Dict]]: return self._asm_source_code - def add_inline_asm(self, asm: Union[str, Dict]): + def add_inline_asm(self, asm: Union[str, Dict]) -> None: self._asm_source_code = asm # endregion @@ -560,7 +559,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met ################################################################################### ################################################################################### - def add_father(self, father: "Node"): + def add_father(self, father: "Node") -> None: """Add a father node Args: @@ -585,7 +584,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met """ return list(self._fathers) - def remove_father(self, father: "Node"): + def remove_father(self, father: "Node") -> None: """Remove the father node. Do nothing if the node is not a father Args: @@ -593,7 +592,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met """ self._fathers = [x for x in self._fathers if x.node_id != father.node_id] - def remove_son(self, son: "Node"): + def remove_son(self, son: "Node") -> None: """Remove the son node. Do nothing if the node is not a son Args: @@ -601,7 +600,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met """ self._sons = [x for x in self._sons if x.node_id != son.node_id] - def add_son(self, son: "Node"): + def add_son(self, son: "Node") -> None: """Add a son node Args: @@ -609,7 +608,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met """ self._sons.append(son) - def set_sons(self, sons: List["Node"]): + def set_sons(self, sons: List["Node"]) -> None: """Set the son nodes Args: @@ -667,14 +666,14 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met def irs_ssa(self, irs): self._irs_ssa = irs - def add_ssa_ir(self, ir: Operation): + def add_ssa_ir(self, ir: Operation) -> None: """ Use to place phi operation """ ir.set_node(self) self._irs_ssa.append(ir) - def slithir_generation(self): + def slithir_generation(self) -> None: if self.expression: expression = self.expression self._irs = convert_expression(expression, self) @@ -691,11 +690,11 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._all_slithir_operations @staticmethod - def _is_non_slithir_var(var: Variable): + def _is_non_slithir_var(var: Variable) -> bool: return not isinstance(var, (Constant, ReferenceVariable, TemporaryVariable, TupleVariable)) @staticmethod - def _is_valid_slithir_var(var: Variable): + def _is_valid_slithir_var(var: Variable) -> bool: return isinstance(var, (ReferenceVariable, TemporaryVariable, TupleVariable)) # endregion @@ -746,7 +745,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._dominance_frontier = doms @property - def dominator_successors(self): + def dominator_successors(self) -> Set["Node"]: return self._dom_successors @property @@ -788,14 +787,14 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met # def phi_origin_member_variables(self) -> Dict[str, Tuple[MemberVariable, Set["Node"]]]: # return self._phi_origins_member_variables - def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node"): + def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node") -> None: if variable.name not in self._phi_origins_local_variables: self._phi_origins_local_variables[variable.name] = (variable, set()) (v, nodes) = self._phi_origins_local_variables[variable.name] assert v == variable nodes.add(node) - def add_phi_origin_state_variable(self, variable: StateVariable, node: "Node"): + def add_phi_origin_state_variable(self, variable: StateVariable, node: "Node") -> None: if variable.canonical_name not in self._phi_origins_state_variables: self._phi_origins_state_variables[variable.canonical_name] = ( variable, @@ -819,7 +818,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met ################################################################################### ################################################################################### - def _find_read_write_call(self): # pylint: disable=too-many-statements + def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements for ir in self.irs: @@ -895,7 +894,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._low_level_calls = list(set(self._low_level_calls)) @staticmethod - def _convert_ssa(v: Variable): + def _convert_ssa(v: Variable) -> Optional[Union[StateVariable, LocalVariable]]: if isinstance(v, StateIRVariable): contract = v.contract non_ssa_var = contract.get_state_variable_from_name(v.name) @@ -905,7 +904,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met non_ssa_var = function.get_local_variable_from_name(v.name) return non_ssa_var - def update_read_write_using_ssa(self): + def update_read_write_using_ssa(self) -> None: if not self.expression: return for ir in self.irs_ssa: @@ -969,7 +968,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met ################################################################################### ################################################################################### - def __str__(self): + def __str__(self) -> str: additional_info = "" if self.expression: additional_info += " " + str(self.expression) @@ -987,12 +986,12 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met ################################################################################### -def link_nodes(node1: Node, node2: Node): +def link_nodes(node1: Node, node2: Node) -> None: node1.add_son(node2) node2.add_father(node1) -def insert_node(origin: Node, node_inserted: Node): +def insert_node(origin: Node, node_inserted: Node) -> None: sons = origin.sons link_nodes(origin, node_inserted) for son in sons: diff --git a/slither/core/cfg/scope.py b/slither/core/cfg/scope.py index 06e68c3e9..d3ac4e836 100644 --- a/slither/core/cfg/scope.py +++ b/slither/core/cfg/scope.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: # pylint: disable=too-few-public-methods class Scope: - def __init__(self, is_checked: bool, is_yul: bool, scope: Union["Scope", "Function"]): + def __init__(self, is_checked: bool, is_yul: bool, scope: Union["Scope", "Function"]) -> None: self.nodes: List["Node"] = [] self.is_checked = is_checked self.is_yul = is_yul diff --git a/slither/core/children/child_contract.py b/slither/core/children/child_contract.py index 285623b0e..86f9dea53 100644 --- a/slither/core/children/child_contract.py +++ b/slither/core/children/child_contract.py @@ -7,11 +7,11 @@ if TYPE_CHECKING: class ChildContract(SourceMapping): - def __init__(self): + def __init__(self) -> None: super().__init__() self._contract = None - def set_contract(self, contract: "Contract"): + def set_contract(self, contract: "Contract") -> None: self._contract = contract @property diff --git a/slither/core/children/child_event.py b/slither/core/children/child_event.py index 6df697747..df91596e3 100644 --- a/slither/core/children/child_event.py +++ b/slither/core/children/child_event.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: class ChildEvent: - def __init__(self): + def __init__(self) -> None: super().__init__() self._event = None diff --git a/slither/core/children/child_expression.py b/slither/core/children/child_expression.py index bc88ea7a4..0064658c0 100644 --- a/slither/core/children/child_expression.py +++ b/slither/core/children/child_expression.py @@ -1,17 +1,18 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: from slither.core.expressions.expression import Expression + from slither.slithir.operations import Operation class ChildExpression: - def __init__(self): + def __init__(self) -> None: super().__init__() self._expression = None - def set_expression(self, expression: "Expression"): + def set_expression(self, expression: Union["Expression", "Operation"]) -> None: self._expression = expression @property - def expression(self) -> "Expression": + def expression(self) -> Union["Expression", "Operation"]: return self._expression diff --git a/slither/core/children/child_inheritance.py b/slither/core/children/child_inheritance.py index 5efe53412..30b32f6c1 100644 --- a/slither/core/children/child_inheritance.py +++ b/slither/core/children/child_inheritance.py @@ -5,11 +5,11 @@ if TYPE_CHECKING: class ChildInheritance: - def __init__(self): + def __init__(self) -> None: super().__init__() self._contract_declarer = None - def set_contract_declarer(self, contract: "Contract"): + def set_contract_declarer(self, contract: "Contract") -> None: self._contract_declarer = contract @property diff --git a/slither/core/children/child_node.py b/slither/core/children/child_node.py index c1fffd49a..8e6e1f0b5 100644 --- a/slither/core/children/child_node.py +++ b/slither/core/children/child_node.py @@ -7,11 +7,11 @@ if TYPE_CHECKING: class ChildNode: - def __init__(self): + def __init__(self) -> None: super().__init__() self._node = None - def set_node(self, node: "Node"): + def set_node(self, node: "Node") -> None: self._node = node @property diff --git a/slither/core/children/child_structure.py b/slither/core/children/child_structure.py index 0f4c7db82..abcb041c2 100644 --- a/slither/core/children/child_structure.py +++ b/slither/core/children/child_structure.py @@ -5,11 +5,11 @@ if TYPE_CHECKING: class ChildStructure: - def __init__(self): + def __init__(self) -> None: super().__init__() self._structure = None - def set_structure(self, structure: "Structure"): + def set_structure(self, structure: "Structure") -> None: self._structure = structure @property diff --git a/slither/core/compilation_unit.py b/slither/core/compilation_unit.py index 2144d4c81..f54f08ab3 100644 --- a/slither/core/compilation_unit.py +++ b/slither/core/compilation_unit.py @@ -16,10 +16,10 @@ from slither.core.declarations import ( from slither.core.declarations.custom_error import CustomError from slither.core.declarations.enum_top_level import EnumTopLevel from slither.core.declarations.function_top_level import FunctionTopLevel -from slither.core.declarations.using_for_top_level import UsingForTopLevel from slither.core.declarations.structure_top_level import StructureTopLevel -from slither.core.solidity_types.type_alias import TypeAliasTopLevel +from slither.core.declarations.using_for_top_level import UsingForTopLevel from slither.core.scope.scope import FileScope +from slither.core.solidity_types.type_alias import TypeAliasTopLevel from slither.core.variables.state_variable import StateVariable from slither.core.variables.top_level_variable import TopLevelVariable from slither.slithir.operations import InternalCall @@ -31,7 +31,7 @@ if TYPE_CHECKING: # pylint: disable=too-many-instance-attributes,too-many-public-methods class SlitherCompilationUnit(Context): - def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit): + def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit) -> None: super().__init__() self._core = core @@ -150,21 +150,21 @@ class SlitherCompilationUnit(Context): def functions(self) -> List[Function]: return list(self._all_functions) - def add_function(self, func: Function): + def add_function(self, func: Function) -> None: self._all_functions.add(func) @property def modifiers(self) -> List[Modifier]: return list(self._all_modifiers) - def add_modifier(self, modif: Modifier): + def add_modifier(self, modif: Modifier) -> None: self._all_modifiers.add(modif) @property def functions_and_modifiers(self) -> List[Function]: return self.functions + self.modifiers - def propagate_function_calls(self): + def propagate_function_calls(self) -> None: for f in self.functions_and_modifiers: for node in f.nodes: for ir in node.irs_ssa: @@ -256,7 +256,7 @@ class SlitherCompilationUnit(Context): ################################################################################### ################################################################################### - def compute_storage_layout(self): + def compute_storage_layout(self) -> None: for contract in self.contracts_derived: self._storage_layouts[contract.name] = {} diff --git a/slither/core/declarations/__init__.py b/slither/core/declarations/__init__.py index f891ad621..92e0b9eca 100644 --- a/slither/core/declarations/__init__.py +++ b/slither/core/declarations/__init__.py @@ -17,3 +17,4 @@ from .structure_contract import StructureContract from .structure_top_level import StructureTopLevel from .function_contract import FunctionContract from .function_top_level import FunctionTopLevel +from .custom_error_contract import CustomErrorContract diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index d1feebb05..2d2d10b04 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -4,7 +4,7 @@ import logging from collections import defaultdict from pathlib import Path -from typing import Optional, List, Dict, Callable, Tuple, TYPE_CHECKING, Union, Set +from typing import Optional, List, Dict, Callable, Tuple, TYPE_CHECKING, Union, Set, Any from crytic_compile.platform import Type as PlatformType @@ -38,13 +38,13 @@ if TYPE_CHECKING: EnumContract, StructureContract, FunctionContract, + CustomErrorContract, ) from slither.slithir.variables.variable import SlithIRVariable - from slither.core.variables.variable import Variable - from slither.core.variables.state_variable import StateVariable + from slither.core.variables import Variable, StateVariable from slither.core.compilation_unit import SlitherCompilationUnit - from slither.core.declarations.custom_error_contract import CustomErrorContract from slither.core.scope.scope import FileScope + from slither.core.cfg.node import Node LOGGER = logging.getLogger("Contract") @@ -55,7 +55,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods Contract class """ - def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"): + def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None: super().__init__() self._name: Optional[str] = None @@ -366,7 +366,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ return list(self._variables_ordered) - def add_variables_ordered(self, new_vars: List["StateVariable"]): + def add_variables_ordered(self, new_vars: List["StateVariable"]) -> None: self._variables_ordered += new_vars @property @@ -534,7 +534,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods def add_function(self, func: "FunctionContract"): self._functions[func.canonical_name] = func - def set_functions(self, functions: Dict[str, "FunctionContract"]): + def set_functions(self, functions: Dict[str, "FunctionContract"]) -> None: """ Set the functions @@ -578,7 +578,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods def available_modifiers_as_dict(self) -> Dict[str, "Modifier"]: return {m.full_name: m for m in self._modifiers.values() if not m.is_shadowed} - def set_modifiers(self, modifiers: Dict[str, "Modifier"]): + def set_modifiers(self, modifiers: Dict[str, "Modifier"]) -> None: """ Set the modifiers @@ -688,7 +688,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods inheritance: List["Contract"], immediate_inheritance: List["Contract"], called_base_constructor_contracts: List["Contract"], - ): + ) -> None: self._inheritance = inheritance self._immediate_inheritance = immediate_inheritance self._explicit_base_constructor_calls = called_base_constructor_contracts @@ -803,23 +803,25 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ return next((v for v in self.state_variables if v.name == canonical_name), None) - def get_structure_from_name(self, structure_name: str) -> Optional["Structure"]: + def get_structure_from_name(self, structure_name: str) -> Optional["StructureContract"]: """ Return a structure from a name Args: structure_name (str): name of the structure Returns: - Structure + StructureContract """ return next((st for st in self.structures if st.name == structure_name), None) - def get_structure_from_canonical_name(self, structure_name: str) -> Optional["Structure"]: + def get_structure_from_canonical_name( + self, structure_name: str + ) -> Optional["StructureContract"]: """ Return a structure from a canonical name Args: structure_name (str): canonical name of the structure Returns: - Structure + StructureContract """ return next((st for st in self.structures if st.canonical_name == structure_name), None) @@ -1216,7 +1218,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### ################################################################################### - def update_read_write_using_ssa(self): + def update_read_write_using_ssa(self) -> None: for function in self.functions + self.modifiers: function.update_read_write_using_ssa() @@ -1311,7 +1313,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods def is_incorrectly_constructed(self, incorrect: bool): self._is_incorrectly_parsed = incorrect - def add_constructor_variables(self): + def add_constructor_variables(self) -> None: from slither.core.declarations.function_contract import FunctionContract if self.state_variables: @@ -1380,7 +1382,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods def _create_node( self, func: Function, counter: int, variable: "Variable", scope: Union[Scope, Function] - ): + ) -> "Node": from slither.core.cfg.node import Node, NodeType from slither.core.expressions import ( AssignmentOperationType, @@ -1412,7 +1414,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### ################################################################################### - def convert_expression_to_slithir_ssa(self): + def convert_expression_to_slithir_ssa(self) -> None: """ Assume generate_slithir_and_analyze was called on all functions @@ -1437,7 +1439,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods for func in self.functions + self.modifiers: func.generate_slithir_ssa(all_ssa_state_variables_instances) - def fix_phi(self): + def fix_phi(self) -> None: last_state_variables_instances = {} initial_state_variables_instances = {} for v in self._initial_state_variables: @@ -1459,20 +1461,20 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### ################################################################################### - def __eq__(self, other): + def __eq__(self, other: SourceMapping) -> bool: if isinstance(other, str): return other == self.name return NotImplemented - def __neq__(self, other): + def __neq__(self, other: Any) -> bool: if isinstance(other, str): return other != self.name return NotImplemented - def __str__(self): + def __str__(self) -> str: return self.name - def __hash__(self): + def __hash__(self) -> int: return self._id # endregion diff --git a/slither/core/declarations/custom_error.py b/slither/core/declarations/custom_error.py index a1a689fcc..5e851c8da 100644 --- a/slither/core/declarations/custom_error.py +++ b/slither/core/declarations/custom_error.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: class CustomError(SourceMapping): - def __init__(self, compilation_unit: "SlitherCompilationUnit"): + def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None: super().__init__() self._name: str = "" self._parameters: List[LocalVariable] = [] @@ -30,7 +30,7 @@ class CustomError(SourceMapping): def parameters(self) -> List[LocalVariable]: return self._parameters - def add_parameters(self, p: "LocalVariable"): + def add_parameters(self, p: "LocalVariable") -> None: self._parameters.append(p) @property @@ -42,7 +42,7 @@ class CustomError(SourceMapping): ################################################################################### @staticmethod - def _convert_type_for_solidity_signature(t: Optional[Union[Type, List[Type]]]): + def _convert_type_for_solidity_signature(t: Optional[Union[Type, List[Type]]]) -> str: # pylint: disable=import-outside-toplevel from slither.core.declarations import Contract @@ -92,5 +92,5 @@ class CustomError(SourceMapping): ################################################################################### ################################################################################### - def __str__(self): + def __str__(self) -> str: return "revert " + self.solidity_signature diff --git a/slither/core/declarations/enum.py b/slither/core/declarations/enum.py index c53c1c38d..66a02fd11 100644 --- a/slither/core/declarations/enum.py +++ b/slither/core/declarations/enum.py @@ -4,7 +4,7 @@ from slither.core.source_mapping.source_mapping import SourceMapping class Enum(SourceMapping): - def __init__(self, name: str, canonical_name: str, values: List[str]): + def __init__(self, name: str, canonical_name: str, values: List[str]) -> None: super().__init__() self._name = name self._canonical_name = canonical_name @@ -33,5 +33,5 @@ class Enum(SourceMapping): def max(self) -> int: return self._max - def __str__(self): + def __str__(self) -> str: return self.name diff --git a/slither/core/declarations/enum_top_level.py b/slither/core/declarations/enum_top_level.py index 2546b376b..2a94c5bb1 100644 --- a/slither/core/declarations/enum_top_level.py +++ b/slither/core/declarations/enum_top_level.py @@ -8,6 +8,8 @@ if TYPE_CHECKING: class EnumTopLevel(Enum, TopLevel): - def __init__(self, name: str, canonical_name: str, values: List[str], scope: "FileScope"): + def __init__( + self, name: str, canonical_name: str, values: List[str], scope: "FileScope" + ) -> None: super().__init__(name, canonical_name, values) self.file_scope: "FileScope" = scope diff --git a/slither/core/declarations/event.py b/slither/core/declarations/event.py index 9e445ee74..d616679a2 100644 --- a/slither/core/declarations/event.py +++ b/slither/core/declarations/event.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: class Event(ChildContract, SourceMapping): - def __init__(self): + def __init__(self) -> None: super().__init__() self._name = None self._elems: List[EventVariable] = [] @@ -59,5 +59,5 @@ class Event(ChildContract, SourceMapping): """ return self.contract == contract - def __str__(self): + def __str__(self) -> str: return self.name diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index c4f4809f6..c383fc99b 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -6,7 +6,7 @@ from abc import abstractmethod, ABCMeta from collections import namedtuple from enum import Enum from itertools import groupby -from typing import Dict, TYPE_CHECKING, List, Optional, Set, Union, Callable, Tuple +from typing import Any, Dict, TYPE_CHECKING, List, Optional, Set, Union, Callable, Tuple from slither.core.cfg.scope import Scope from slither.core.declarations.solidity_variables import ( @@ -27,6 +27,7 @@ from slither.core.variables.state_variable import StateVariable from slither.utils.type import convert_type_for_solidity_signature_to_string from slither.utils.utils import unroll + # pylint: disable=import-outside-toplevel,too-many-instance-attributes,too-many-statements,too-many-lines if TYPE_CHECKING: @@ -45,6 +46,8 @@ if TYPE_CHECKING: from slither.slithir.operations import Operation from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.scope.scope import FileScope + from slither.slithir.variables.state_variable import StateIRVariable + from slither.core.declarations.function_contract import FunctionContract LOGGER = logging.getLogger("Function") ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"]) @@ -56,7 +59,7 @@ class ModifierStatements: modifier: Union["Contract", "Function"], entry_point: "Node", nodes: List["Node"], - ): + ) -> None: self._modifier = modifier self._entry_point = entry_point self._nodes = nodes @@ -116,7 +119,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu Function class """ - def __init__(self, compilation_unit: "SlitherCompilationUnit"): + def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None: super().__init__() self._internal_scope: List[str] = [] self._name: Optional[str] = None @@ -295,7 +298,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu def contains_assembly(self, c: bool): self._contains_assembly = c - def can_reenter(self, callstack=None) -> bool: + def can_reenter(self, callstack: Optional[List["FunctionContract"]] = None) -> bool: """ Check if the function can re-enter Follow internal calls. @@ -370,7 +373,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ################################################################################### ################################################################################### - def set_function_type(self, t: FunctionType): + def set_function_type(self, t: FunctionType) -> None: assert isinstance(t, FunctionType) self._function_type = t @@ -455,7 +458,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu def visibility(self, v: str): self._visibility = v - def set_visibility(self, v: str): + def set_visibility(self, v: str) -> None: self._visibility = v @property @@ -554,7 +557,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu def entry_point(self, node: "Node"): self._entry_point = node - def add_node(self, node: "Node"): + def add_node(self, node: "Node") -> None: if not self._entry_point: self._entry_point = node self._nodes.append(node) @@ -598,7 +601,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu """ return list(self._parameters) - def add_parameters(self, p: "LocalVariable"): + def add_parameters(self, p: "LocalVariable") -> None: self._parameters.append(p) @property @@ -608,7 +611,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu """ return list(self._parameters_ssa) - def add_parameter_ssa(self, var: "LocalIRVariable"): + def add_parameter_ssa(self, var: "LocalIRVariable") -> None: self._parameters_ssa.append(var) def parameters_src(self) -> SourceMapping: @@ -651,7 +654,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu """ return list(self._returns) - def add_return(self, r: "LocalVariable"): + def add_return(self, r: "LocalVariable") -> None: self._returns.append(r) @property @@ -661,7 +664,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu """ return list(self._returns_ssa) - def add_return_ssa(self, var: "LocalIRVariable"): + def add_return_ssa(self, var: "LocalIRVariable") -> None: self._returns_ssa.append(var) # endregion @@ -680,7 +683,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu """ return [c.modifier for c in self._modifiers] - def add_modifier(self, modif: "ModifierStatements"): + def add_modifier(self, modif: "ModifierStatements") -> None: self._modifiers.append(modif) @property @@ -714,7 +717,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu # This is a list of contracts internally, so we convert it to a list of constructor functions. return list(self._explicit_base_constructor_calls) - def add_explicit_base_constructor_calls_statements(self, modif: ModifierStatements): + def add_explicit_base_constructor_calls_statements(self, modif: ModifierStatements) -> None: self._explicit_base_constructor_calls.append(modif) # endregion @@ -1057,7 +1060,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu self._all_reachable_from_functions = functions return self._all_reachable_from_functions - def add_reachable_from_node(self, n: "Node", ir: "Operation"): + def add_reachable_from_node(self, n: "Node", ir: "Operation") -> None: self._reachable_from_nodes.add(ReacheableNode(n, ir)) self._reachable_from_functions.add(n.function) @@ -1068,7 +1071,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ################################################################################### ################################################################################### - def _explore_functions(self, f_new_values: Callable[["Function"], List]): + def _explore_functions(self, f_new_values: Callable[["Function"], List]) -> List[Any]: values = f_new_values(self) explored = [self] to_explore = [ @@ -1218,11 +1221,13 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu func: "Function", f: Callable[["Node"], List[SolidityVariable]], include_loop: bool, - ): + ) -> List[Any]: ret = [f(n) for n in func.nodes if n.is_conditional(include_loop)] return [item for sublist in ret for item in sublist] - def all_conditional_solidity_variables_read(self, include_loop=True) -> List[SolidityVariable]: + def all_conditional_solidity_variables_read( + self, include_loop: bool = True + ) -> List[SolidityVariable]: """ Return the Soldiity variables directly used in a condtion @@ -1258,7 +1263,9 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu return [var for var in ret if isinstance(var, SolidityVariable)] @staticmethod - def _explore_func_nodes(func: "Function", f: Callable[["Node"], List[SolidityVariable]]): + def _explore_func_nodes( + func: "Function", f: Callable[["Node"], List[SolidityVariable]] + ) -> List[Union[Any, SolidityVariableComposed]]: ret = [f(n) for n in func.nodes] return [item for sublist in ret for item in sublist] @@ -1367,7 +1374,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu with open(filename, "w", encoding="utf8") as f: f.write(content) - def slithir_cfg_to_dot_str(self, skip_expressions=False) -> str: + def slithir_cfg_to_dot_str(self, skip_expressions: bool = False) -> str: """ Export the CFG to a DOT format. The nodes includes the Solidity expressions and the IRs :return: the DOT content @@ -1512,7 +1519,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ################################################################################### ################################################################################### - def _analyze_read_write(self): + def _analyze_read_write(self) -> None: """Compute variables read/written/...""" write_var = [x.variables_written_as_expression for x in self.nodes] write_var = [x for x in write_var if x] @@ -1570,7 +1577,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu slithir_variables = [x for x in slithir_variables if x] self._slithir_variables = [item for sublist in slithir_variables for item in sublist] - def _analyze_calls(self): + def _analyze_calls(self) -> None: calls = [x.calls_as_expression for x in self.nodes] calls = [x for x in calls if x] calls = [item for sublist in calls for item in sublist] @@ -1702,7 +1709,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu return self._get_last_ssa_variable_instances(target_state=False, target_local=True) @staticmethod - def _unchange_phi(ir: "Operation"): + def _unchange_phi(ir: "Operation") -> bool: from slither.slithir.operations import Phi, PhiCallback if not isinstance(ir, (Phi, PhiCallback)) or len(ir.rvalues) > 1: @@ -1711,7 +1718,11 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu return True return ir.rvalues[0] == ir.lvalue - def fix_phi(self, last_state_variables_instances, initial_state_variables_instances): + def fix_phi( + self, + last_state_variables_instances: Dict[str, List["StateIRVariable"]], + initial_state_variables_instances: Dict[str, "StateIRVariable"], + ) -> None: from slither.slithir.operations import InternalCall, PhiCallback from slither.slithir.variables import Constant, StateIRVariable @@ -1745,7 +1756,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu node.irs_ssa = [ir for ir in node.irs_ssa if not self._unchange_phi(ir)] - def generate_slithir_and_analyze(self): + def generate_slithir_and_analyze(self) -> None: for node in self.nodes: node.slithir_generation() @@ -1756,7 +1767,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu def generate_slithir_ssa(self, all_ssa_state_variables_instances): pass - def update_read_write_using_ssa(self): + def update_read_write_using_ssa(self) -> None: for node in self.nodes: node.update_read_write_using_ssa() self._analyze_read_write() @@ -1767,7 +1778,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ################################################################################### ################################################################################### - def __str__(self): + def __str__(self) -> str: return self.name # endregion diff --git a/slither/core/declarations/function_contract.py b/slither/core/declarations/function_contract.py index 2235bd227..19456bbea 100644 --- a/slither/core/declarations/function_contract.py +++ b/slither/core/declarations/function_contract.py @@ -1,17 +1,19 @@ """ Function module """ -from typing import TYPE_CHECKING, List, Tuple +from typing import Dict, TYPE_CHECKING, List, Tuple from slither.core.children.child_contract import ChildContract from slither.core.children.child_inheritance import ChildInheritance from slither.core.declarations import Function + # pylint: disable=import-outside-toplevel,too-many-instance-attributes,too-many-statements,too-many-lines if TYPE_CHECKING: from slither.core.declarations import Contract from slither.core.scope.scope import FileScope + from slither.slithir.variables.state_variable import StateIRVariable class FunctionContract(Function, ChildContract, ChildInheritance): @@ -96,7 +98,9 @@ class FunctionContract(Function, ChildContract, ChildInheritance): ################################################################################### ################################################################################### - def generate_slithir_ssa(self, all_ssa_state_variables_instances): + def generate_slithir_ssa( + self, all_ssa_state_variables_instances: Dict[str, "StateIRVariable"] + ) -> None: from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa from slither.core.dominators.utils import ( compute_dominance_frontier, diff --git a/slither/core/declarations/function_top_level.py b/slither/core/declarations/function_top_level.py index d71033069..407a8d045 100644 --- a/slither/core/declarations/function_top_level.py +++ b/slither/core/declarations/function_top_level.py @@ -1,7 +1,7 @@ """ Function module """ -from typing import List, Tuple, TYPE_CHECKING +from typing import Dict, List, Tuple, TYPE_CHECKING from slither.core.declarations import Function from slither.core.declarations.top_level import TopLevel @@ -9,10 +9,11 @@ from slither.core.declarations.top_level import TopLevel if TYPE_CHECKING: from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.scope.scope import FileScope + from slither.slithir.variables.state_variable import StateIRVariable class FunctionTopLevel(Function, TopLevel): - def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"): + def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None: super().__init__(compilation_unit) self._scope: "FileScope" = scope @@ -78,7 +79,9 @@ class FunctionTopLevel(Function, TopLevel): ################################################################################### ################################################################################### - def generate_slithir_ssa(self, all_ssa_state_variables_instances): + def generate_slithir_ssa( + self, all_ssa_state_variables_instances: Dict[str, "StateIRVariable"] + ) -> None: # pylint: disable=import-outside-toplevel from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa from slither.core.dominators.utils import ( diff --git a/slither/core/declarations/import_directive.py b/slither/core/declarations/import_directive.py index 745f8007f..75c0406fe 100644 --- a/slither/core/declarations/import_directive.py +++ b/slither/core/declarations/import_directive.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: class Import(SourceMapping): - def __init__(self, filename: Path, scope: "FileScope"): + def __init__(self, filename: Path, scope: "FileScope") -> None: super().__init__() self._filename: Path = filename self._alias: Optional[str] = None diff --git a/slither/core/declarations/pragma_directive.py b/slither/core/declarations/pragma_directive.py index 602dab6b2..cd790d5a4 100644 --- a/slither/core/declarations/pragma_directive.py +++ b/slither/core/declarations/pragma_directive.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: class Pragma(SourceMapping): - def __init__(self, directive: List[str], scope: "FileScope"): + def __init__(self, directive: List[str], scope: "FileScope") -> None: super().__init__() self._directive = directive self.scope: "FileScope" = scope @@ -39,5 +39,5 @@ class Pragma(SourceMapping): return self._directive[0] == "experimental" and self._directive[1] == "ABIEncoderV2" return False - def __str__(self): + def __str__(self) -> str: return "pragma " + "".join(self.directive) diff --git a/slither/core/declarations/solidity_import_placeholder.py b/slither/core/declarations/solidity_import_placeholder.py index 070d3fff3..8f6360086 100644 --- a/slither/core/declarations/solidity_import_placeholder.py +++ b/slither/core/declarations/solidity_import_placeholder.py @@ -1,7 +1,11 @@ """ Special variable to model import with renaming """ +from typing import Union + from slither.core.declarations import Import +from slither.core.declarations.contract import Contract +from slither.core.declarations.solidity_variables import SolidityVariable from slither.core.solidity_types import ElementaryType from slither.core.variables.variable import Variable @@ -13,7 +17,7 @@ class SolidityImportPlaceHolder(Variable): In the long term we should remove this and better integrate import aliases """ - def __init__(self, import_directive: Import): + def __init__(self, import_directive: Import) -> None: super().__init__() assert import_directive.alias is not None self._import_directive = import_directive @@ -27,7 +31,7 @@ class SolidityImportPlaceHolder(Variable): def type(self) -> ElementaryType: return ElementaryType("string") - def __eq__(self, other): + def __eq__(self, other: Union[Contract, SolidityVariable]) -> bool: return ( self.__class__ == other.__class__ and self._import_directive.filename == self._import_directive.filename diff --git a/slither/core/declarations/solidity_variables.py b/slither/core/declarations/solidity_variables.py index 3a5db010c..9569cde93 100644 --- a/slither/core/declarations/solidity_variables.py +++ b/slither/core/declarations/solidity_variables.py @@ -1,13 +1,11 @@ # https://solidity.readthedocs.io/en/v0.4.24/units-and-global-variables.html -from typing import List, Dict, Union, TYPE_CHECKING +from typing import List, Dict, Union, Any from slither.core.declarations.custom_error import CustomError from slither.core.solidity_types import ElementaryType, TypeInformation from slither.core.source_mapping.source_mapping import SourceMapping from slither.exceptions import SlitherException -if TYPE_CHECKING: - pass SOLIDITY_VARIABLES = { "now": "uint256", @@ -98,13 +96,13 @@ def solidity_function_signature(name): class SolidityVariable(SourceMapping): - def __init__(self, name: str): + def __init__(self, name: str) -> None: super().__init__() self._check_name(name) self._name = name # dev function, will be removed once the code is stable - def _check_name(self, name: str): # pylint: disable=no-self-use + def _check_name(self, name: str) -> None: # pylint: disable=no-self-use assert name in SOLIDITY_VARIABLES or name.endswith(("_slot", "_offset")) @property @@ -124,18 +122,18 @@ class SolidityVariable(SourceMapping): def type(self) -> ElementaryType: return ElementaryType(SOLIDITY_VARIABLES[self.name]) - def __str__(self): + def __str__(self) -> str: return self._name - def __eq__(self, other): + def __eq__(self, other: SourceMapping) -> bool: return self.__class__ == other.__class__ and self.name == other.name - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) class SolidityVariableComposed(SolidityVariable): - def _check_name(self, name: str): + def _check_name(self, name: str) -> None: assert name in SOLIDITY_VARIABLES_COMPOSED @property @@ -146,13 +144,13 @@ class SolidityVariableComposed(SolidityVariable): def type(self) -> ElementaryType: return ElementaryType(SOLIDITY_VARIABLES_COMPOSED[self.name]) - def __str__(self): + def __str__(self) -> str: return self._name - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self.__class__ == other.__class__ and self.name == other.name - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) @@ -162,7 +160,7 @@ class SolidityFunction(SourceMapping): # https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information # As a result, we set return_type during the Ir conversion - def __init__(self, name: str): + def __init__(self, name: str) -> None: super().__init__() assert name in SOLIDITY_FUNCTIONS self._name = name @@ -187,28 +185,28 @@ class SolidityFunction(SourceMapping): def return_type(self, r: List[Union[TypeInformation, ElementaryType]]): self._return_type = r - def __str__(self): + def __str__(self) -> str: return self._name - def __eq__(self, other): + def __eq__(self, other: "SolidityFunction") -> bool: return self.__class__ == other.__class__ and self.name == other.name - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) class SolidityCustomRevert(SolidityFunction): - def __init__(self, custom_error: CustomError): # pylint: disable=super-init-not-called + def __init__(self, custom_error: CustomError) -> None: # pylint: disable=super-init-not-called self._name = "revert " + custom_error.solidity_signature self._custom_error = custom_error self._return_type: List[Union[TypeInformation, ElementaryType]] = [] - def __eq__(self, other): + def __eq__(self, other: Union["SolidityCustomRevert", SolidityFunction]) -> bool: return ( self.__class__ == other.__class__ and self.name == other.name and self._custom_error == other._custom_error ) - def __hash__(self): + def __hash__(self) -> int: return hash(hash(self.name) + hash(self._custom_error)) diff --git a/slither/core/declarations/structure_top_level.py b/slither/core/declarations/structure_top_level.py index f06cd2318..f4e2b8a9c 100644 --- a/slither/core/declarations/structure_top_level.py +++ b/slither/core/declarations/structure_top_level.py @@ -9,6 +9,6 @@ if TYPE_CHECKING: class StructureTopLevel(Structure, TopLevel): - def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"): + def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None: super().__init__(compilation_unit) self.file_scope: "FileScope" = scope diff --git a/slither/core/declarations/using_for_top_level.py b/slither/core/declarations/using_for_top_level.py index a1b43e1c1..27d1f90e4 100644 --- a/slither/core/declarations/using_for_top_level.py +++ b/slither/core/declarations/using_for_top_level.py @@ -8,11 +8,11 @@ if TYPE_CHECKING: class UsingForTopLevel(TopLevel): - def __init__(self, scope: "FileScope"): + def __init__(self, scope: "FileScope") -> None: super().__init__() self._using_for: Dict[Union[str, Type], List[Type]] = {} self.file_scope: "FileScope" = scope @property - def using_for(self) -> Dict[Type, List[Type]]: + def using_for(self) -> Dict[Union[str, Type], List[Type]]: return self._using_for diff --git a/slither/core/dominators/utils.py b/slither/core/dominators/utils.py index 837fe46ea..ca5c51282 100644 --- a/slither/core/dominators/utils.py +++ b/slither/core/dominators/utils.py @@ -1,4 +1,4 @@ -from typing import List, TYPE_CHECKING +from typing import Set, List, TYPE_CHECKING from slither.core.cfg.node import NodeType @@ -6,7 +6,7 @@ if TYPE_CHECKING: from slither.core.cfg.node import Node -def intersection_predecessor(node: "Node"): +def intersection_predecessor(node: "Node") -> Set["Node"]: if not node.fathers: return set() ret = node.fathers[0].dominators @@ -15,7 +15,7 @@ def intersection_predecessor(node: "Node"): return ret -def _compute_dominators(nodes: List["Node"]): +def _compute_dominators(nodes: List["Node"]) -> None: changed = True while changed: @@ -28,7 +28,7 @@ def _compute_dominators(nodes: List["Node"]): changed = True -def _compute_immediate_dominators(nodes: List["Node"]): +def _compute_immediate_dominators(nodes: List["Node"]) -> None: for node in nodes: idom_candidates = set(node.dominators) idom_candidates.remove(node) @@ -58,7 +58,7 @@ def _compute_immediate_dominators(nodes: List["Node"]): idom.dominator_successors.add(node) -def compute_dominators(nodes: List["Node"]): +def compute_dominators(nodes: List["Node"]) -> None: """ Naive implementation of Cooper, Harvey, Kennedy algo See 'A Simple,Fast Dominance Algorithm' @@ -74,7 +74,7 @@ def compute_dominators(nodes: List["Node"]): _compute_immediate_dominators(nodes) -def compute_dominance_frontier(nodes: List["Node"]): +def compute_dominance_frontier(nodes: List["Node"]) -> None: """ Naive implementation of Cooper, Harvey, Kennedy algo See 'A Simple,Fast Dominance Algorithm' diff --git a/slither/core/expressions/assignment_operation.py b/slither/core/expressions/assignment_operation.py index b5fd3f4a3..22aba57fb 100644 --- a/slither/core/expressions/assignment_operation.py +++ b/slither/core/expressions/assignment_operation.py @@ -85,7 +85,7 @@ class AssignmentOperation(ExpressionTyped): right_expression: Expression, expression_type: AssignmentOperationType, expression_return_type: Optional["Type"], - ): + ) -> None: assert isinstance(left_expression, Expression) assert isinstance(right_expression, Expression) super().__init__() diff --git a/slither/core/expressions/call_expression.py b/slither/core/expressions/call_expression.py index 8ab6668fe..1dbc4074a 100644 --- a/slither/core/expressions/call_expression.py +++ b/slither/core/expressions/call_expression.py @@ -1,10 +1,10 @@ -from typing import Optional, List +from typing import Any, Optional, List from slither.core.expressions.expression import Expression class CallExpression(Expression): # pylint: disable=too-many-instance-attributes - def __init__(self, called, arguments, type_call): + def __init__(self, called: Expression, arguments: List[Any], type_call: str) -> None: assert isinstance(called, Expression) super().__init__() self._called: Expression = called @@ -53,7 +53,7 @@ class CallExpression(Expression): # pylint: disable=too-many-instance-attribute def type_call(self) -> str: return self._type_call - def __str__(self): + def __str__(self) -> str: txt = str(self._called) if self.call_gas or self.call_value: gas = f"gas: {self.call_gas}" if self.call_gas else "" diff --git a/slither/core/expressions/conditional_expression.py b/slither/core/expressions/conditional_expression.py index adcf8bb1f..818425ba1 100644 --- a/slither/core/expressions/conditional_expression.py +++ b/slither/core/expressions/conditional_expression.py @@ -1,10 +1,23 @@ -from typing import List +from typing import Union, List -from .expression import Expression +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.expression import Expression +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.literal import Literal +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.type_conversion import TypeConversion +from slither.core.expressions.unary_operation import UnaryOperation class ConditionalExpression(Expression): - def __init__(self, if_expression, then_expression, else_expression): + def __init__( + self, + if_expression: Union[BinaryOperation, Identifier, Literal], + then_expression: Union[ + "ConditionalExpression", TypeConversion, Literal, TupleExpression, Identifier + ], + else_expression: Union[TupleExpression, UnaryOperation, Identifier, Literal], + ) -> None: assert isinstance(if_expression, Expression) assert isinstance(then_expression, Expression) assert isinstance(else_expression, Expression) diff --git a/slither/core/expressions/elementary_type_name_expression.py b/slither/core/expressions/elementary_type_name_expression.py index 0a310c86a..9a93f0839 100644 --- a/slither/core/expressions/elementary_type_name_expression.py +++ b/slither/core/expressions/elementary_type_name_expression.py @@ -3,10 +3,11 @@ """ from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type +from slither.core.solidity_types.elementary_type import ElementaryType class ElementaryTypeNameExpression(Expression): - def __init__(self, t): + def __init__(self, t: ElementaryType) -> None: assert isinstance(t, Type) super().__init__() self._type = t @@ -20,5 +21,5 @@ class ElementaryTypeNameExpression(Expression): assert isinstance(new_type, Type) self._type = new_type - def __str__(self): + def __str__(self) -> str: return str(self._type) diff --git a/slither/core/expressions/index_access.py b/slither/core/expressions/index_access.py index 49b2a8dd9..4f96a56d6 100644 --- a/slither/core/expressions/index_access.py +++ b/slither/core/expressions/index_access.py @@ -1,6 +1,8 @@ -from typing import List, TYPE_CHECKING +from typing import Union, List, TYPE_CHECKING from slither.core.expressions.expression_typed import ExpressionTyped +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.literal import Literal if TYPE_CHECKING: @@ -9,7 +11,12 @@ if TYPE_CHECKING: class IndexAccess(ExpressionTyped): - def __init__(self, left_expression, right_expression, index_type): + def __init__( + self, + left_expression: Union["IndexAccess", Identifier], + right_expression: Union[Literal, Identifier], + index_type: str, + ) -> None: super().__init__() self._expressions = [left_expression, right_expression] # TODO type of undexAccess is not always a Type @@ -32,5 +39,5 @@ class IndexAccess(ExpressionTyped): def type(self) -> "Type": return self._type - def __str__(self): + def __str__(self) -> str: return str(self.expression_left) + "[" + str(self.expression_right) + "]" diff --git a/slither/core/expressions/literal.py b/slither/core/expressions/literal.py index 2eaeb715d..5dace3c41 100644 --- a/slither/core/expressions/literal.py +++ b/slither/core/expressions/literal.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: class Literal(Expression): def __init__( self, value: Union[int, str], custom_type: "Type", subdenomination: Optional[str] = None - ): + ) -> None: super().__init__() self._value = value self._type = custom_type diff --git a/slither/core/expressions/member_access.py b/slither/core/expressions/member_access.py index 73bd92641..36d6818b2 100644 --- a/slither/core/expressions/member_access.py +++ b/slither/core/expressions/member_access.py @@ -5,7 +5,7 @@ from slither.core.solidity_types.type import Type class MemberAccess(ExpressionTyped): - def __init__(self, member_name, member_type, expression): + def __init__(self, member_name: str, member_type: str, expression: Expression) -> None: # assert isinstance(member_type, Type) # TODO member_type is not always a Type assert isinstance(expression, Expression) @@ -26,5 +26,5 @@ class MemberAccess(ExpressionTyped): def type(self) -> Type: return self._type - def __str__(self): + def __str__(self) -> str: return str(self.expression) + "." + self.member_name diff --git a/slither/core/expressions/new_array.py b/slither/core/expressions/new_array.py index 59c5780d4..162b48f1e 100644 --- a/slither/core/expressions/new_array.py +++ b/slither/core/expressions/new_array.py @@ -1,11 +1,20 @@ +from typing import Union, TYPE_CHECKING + + from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type +if TYPE_CHECKING: + from slither.core.solidity_types.elementary_type import ElementaryType + from slither.core.solidity_types.type_alias import TypeAliasTopLevel + class NewArray(Expression): # note: dont conserve the size of the array if provided - def __init__(self, depth, array_type): + def __init__( + self, depth: int, array_type: Union["TypeAliasTopLevel", "ElementaryType"] + ) -> None: super().__init__() assert isinstance(array_type, Type) self._depth: int = depth diff --git a/slither/core/expressions/new_contract.py b/slither/core/expressions/new_contract.py index 762a3dcfe..70f930aad 100644 --- a/slither/core/expressions/new_contract.py +++ b/slither/core/expressions/new_contract.py @@ -2,7 +2,7 @@ from slither.core.expressions.expression import Expression class NewContract(Expression): - def __init__(self, contract_name): + def __init__(self, contract_name: str) -> None: super().__init__() self._contract_name: str = contract_name self._gas = None @@ -29,5 +29,5 @@ class NewContract(Expression): def call_salt(self, salt): self._salt = salt - def __str__(self): + def __str__(self) -> str: return "new " + str(self._contract_name) diff --git a/slither/core/expressions/type_conversion.py b/slither/core/expressions/type_conversion.py index 97c70f45b..b9cd6879e 100644 --- a/slither/core/expressions/type_conversion.py +++ b/slither/core/expressions/type_conversion.py @@ -1,10 +1,27 @@ +from typing import Union, TYPE_CHECKING + from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type +if TYPE_CHECKING: + from slither.core.expressions.call_expression import CallExpression + from slither.core.expressions.identifier import Identifier + from slither.core.expressions.literal import Literal + from slither.core.expressions.member_access import MemberAccess + from slither.core.solidity_types.elementary_type import ElementaryType + from slither.core.solidity_types.type_alias import TypeAliasContract + from slither.core.solidity_types.user_defined_type import UserDefinedType + class TypeConversion(ExpressionTyped): - def __init__(self, expression, expression_type): + def __init__( + self, + expression: Union[ + "MemberAccess", "Literal", "CallExpression", "TypeConversion", "Identifier" + ], + expression_type: Union["ElementaryType", "UserDefinedType", "TypeAliasContract"], + ) -> None: super().__init__() assert isinstance(expression, Expression) assert isinstance(expression_type, Type) @@ -15,5 +32,5 @@ class TypeConversion(ExpressionTyped): def expression(self) -> Expression: return self._expression - def __str__(self): + def __str__(self) -> str: return str(self.type) + "(" + str(self.expression) + ")" diff --git a/slither/core/expressions/unary_operation.py b/slither/core/expressions/unary_operation.py index 596d9dbf0..a04c57591 100644 --- a/slither/core/expressions/unary_operation.py +++ b/slither/core/expressions/unary_operation.py @@ -1,9 +1,15 @@ import logging +from typing import Union from enum import Enum from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.exceptions import SlitherCoreError +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.literal import Literal +from slither.core.expressions.tuple_expression import TupleExpression + logger = logging.getLogger("UnaryOperation") @@ -20,7 +26,7 @@ class UnaryOperationType(Enum): MINUS_PRE = 8 # for stuff like uint(-1) @staticmethod - def get_type(operation_type, isprefix): + def get_type(operation_type: str, isprefix: bool) -> "UnaryOperationType": if isprefix: if operation_type == "!": return UnaryOperationType.BANG @@ -43,7 +49,7 @@ class UnaryOperationType(Enum): return UnaryOperationType.MINUSMINUS_POST raise SlitherCoreError(f"get_type: Unknown operation type {operation_type}") - def __str__(self): + def __str__(self) -> str: if self == UnaryOperationType.BANG: return "!" if self == UnaryOperationType.TILD: @@ -65,7 +71,7 @@ class UnaryOperationType(Enum): raise SlitherCoreError(f"str: Unknown operation type {self}") @staticmethod - def is_prefix(operation_type): + def is_prefix(operation_type: "UnaryOperationType") -> bool: if operation_type in [ UnaryOperationType.BANG, UnaryOperationType.TILD, @@ -86,7 +92,11 @@ class UnaryOperationType(Enum): class UnaryOperation(ExpressionTyped): - def __init__(self, expression, expression_type): + def __init__( + self, + expression: Union[Literal, Identifier, IndexAccess, TupleExpression], + expression_type: UnaryOperationType, + ) -> None: assert isinstance(expression, Expression) super().__init__() self._expression: Expression = expression @@ -114,7 +124,7 @@ class UnaryOperation(ExpressionTyped): def is_prefix(self) -> bool: return UnaryOperationType.is_prefix(self._type) - def __str__(self): + def __str__(self) -> str: if self.is_prefix: return str(self.type) + " " + str(self._expression) return str(self._expression) + " " + str(self.type) diff --git a/slither/core/scope/scope.py b/slither/core/scope/scope.py index 1eb344c2b..cafeb3585 100644 --- a/slither/core/scope/scope.py +++ b/slither/core/scope/scope.py @@ -25,7 +25,7 @@ def _dict_contain(d1: Dict, d2: Dict) -> bool: # pylint: disable=too-many-instance-attributes class FileScope: - def __init__(self, filename: Filename): + def __init__(self, filename: Filename) -> None: self.filename = filename self.accessible_scopes: List[FileScope] = [] diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index 66b8fc430..e5f4e830a 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -8,7 +8,7 @@ import pathlib import posixpath import re from collections import defaultdict -from typing import Optional, Dict, List, Set, Union +from typing import Optional, Dict, List, Set, Union, Tuple from crytic_compile import CryticCompile from crytic_compile.utils.naming import Filename @@ -40,7 +40,7 @@ class SlitherCore(Context): Slither static analyzer """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._filename: Optional[str] = None @@ -73,8 +73,8 @@ class SlitherCore(Context): # Maps from file to detector name to the start/end ranges for that detector. # Infinity is used to signal a detector has no end range. - self._ignore_ranges: defaultdict[str, defaultdict[str, List[(int, int)]]] = defaultdict( - lambda: defaultdict(lambda: []) + self._ignore_ranges: Dict[str, Dict[str, List[Tuple[int, ...]]]] = defaultdict( + lambda: defaultdict(lambda: [(-1, -1)]) ) self._compilation_units: List[SlitherCompilationUnit] = [] @@ -443,7 +443,7 @@ class SlitherCore(Context): return True - def load_previous_results(self): + def load_previous_results(self) -> None: filename = self._previous_results_filename try: if os.path.isfile(filename): @@ -456,7 +456,7 @@ class SlitherCore(Context): except json.decoder.JSONDecodeError: logger.error(red(f"Impossible to decode {filename}. Consider removing the file")) - def write_results_to_hide(self): + def write_results_to_hide(self) -> None: if not self._results_to_hide: return filename = self._previous_results_filename @@ -464,7 +464,7 @@ class SlitherCore(Context): results = self._results_to_hide + self._previous_results json.dump(results, f) - def save_results_to_hide(self, results: List[Dict]): + def save_results_to_hide(self, results: List[Dict]) -> None: self._results_to_hide += results def add_path_to_filter(self, path: str): diff --git a/slither/core/solidity_types/array_type.py b/slither/core/solidity_types/array_type.py index 59a15dcc6..9a0b12c00 100644 --- a/slither/core/solidity_types/array_type.py +++ b/slither/core/solidity_types/array_type.py @@ -1,13 +1,24 @@ -from typing import Optional, Tuple +from typing import Union, Optional, Tuple, Any, TYPE_CHECKING -from slither.core.expressions import Literal from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type from slither.visitors.expression.constants_folding import ConstantFolding +from slither.core.expressions.literal import Literal + +if TYPE_CHECKING: + from slither.core.expressions.binary_operation import BinaryOperation + from slither.core.expressions.identifier import Identifier + from slither.core.solidity_types.elementary_type import ElementaryType + from slither.core.solidity_types.function_type import FunctionType + from slither.core.solidity_types.type_alias import TypeAliasTopLevel class ArrayType(Type): - def __init__(self, t, length): + def __init__( + self, + t: Union["TypeAliasTopLevel", "ArrayType", "FunctionType", "ElementaryType"], + length: Optional[Union["Identifier", Literal, "BinaryOperation", int]], + ) -> None: assert isinstance(t, Type) if length: if isinstance(length, int): @@ -56,15 +67,15 @@ class ArrayType(Type): return elem_size * int(str(self._length_value)), True return 32, True - def __str__(self): + def __str__(self) -> str: if self._length: return str(self._type) + f"[{str(self._length_value)}]" return str(self._type) + "[]" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, ArrayType): return False return self._type == other.type and self.length == other.length - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) diff --git a/slither/core/solidity_types/elementary_type.py b/slither/core/solidity_types/elementary_type.py index fc248e946..ec2b0ef04 100644 --- a/slither/core/solidity_types/elementary_type.py +++ b/slither/core/solidity_types/elementary_type.py @@ -216,13 +216,13 @@ class ElementaryType(Type): return MaxValues[self.name] raise SlitherException(f"{self.name} does not have a max value") - def __str__(self): + def __str__(self) -> str: return self._type - def __eq__(self, other): + def __eq__(self, other) -> bool: if not isinstance(other, ElementaryType): return False return self.type == other.type - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) diff --git a/slither/core/solidity_types/function_type.py b/slither/core/solidity_types/function_type.py index 3146aa0bf..2d644148e 100644 --- a/slither/core/solidity_types/function_type.py +++ b/slither/core/solidity_types/function_type.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, Any from slither.core.solidity_types.type import Type from slither.core.variables.function_type_variable import FunctionTypeVariable @@ -9,7 +9,7 @@ class FunctionType(Type): self, params: List[FunctionTypeVariable], return_values: List[FunctionTypeVariable], - ): + ) -> None: assert all(isinstance(x, FunctionTypeVariable) for x in params) assert all(isinstance(x, FunctionTypeVariable) for x in return_values) super().__init__() @@ -68,7 +68,7 @@ class FunctionType(Type): return f"({params}) returns({return_values})" return f"({params})" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, FunctionType): return False return self.params == other.params and self.return_values == other.return_values diff --git a/slither/core/solidity_types/mapping_type.py b/slither/core/solidity_types/mapping_type.py index fea5bdb7b..a8acb4d9c 100644 --- a/slither/core/solidity_types/mapping_type.py +++ b/slither/core/solidity_types/mapping_type.py @@ -1,10 +1,18 @@ -from typing import Tuple +from typing import Union, Tuple, TYPE_CHECKING from slither.core.solidity_types.type import Type +if TYPE_CHECKING: + from slither.core.solidity_types.elementary_type import ElementaryType + from slither.core.solidity_types.type_alias import TypeAliasTopLevel + class MappingType(Type): - def __init__(self, type_from, type_to): + def __init__( + self, + type_from: "ElementaryType", + type_to: Union["MappingType", "TypeAliasTopLevel", "ElementaryType"], + ) -> None: assert isinstance(type_from, Type) assert isinstance(type_to, Type) super().__init__() @@ -27,7 +35,7 @@ class MappingType(Type): def is_dynamic(self) -> bool: return True - def __str__(self): + def __str__(self) -> str: return f"mapping({str(self._from)} => {str(self._to)})" def __eq__(self, other): @@ -35,5 +43,5 @@ class MappingType(Type): return False return self.type_from == other.type_from and self.type_to == other.type_to - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) diff --git a/slither/core/solidity_types/type_alias.py b/slither/core/solidity_types/type_alias.py index 128d12597..5b9ea0a37 100644 --- a/slither/core/solidity_types/type_alias.py +++ b/slither/core/solidity_types/type_alias.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Tuple from slither.core.children.child_contract import ChildContract from slither.core.declarations.top_level import TopLevel -from slither.core.solidity_types import Type +from slither.core.solidity_types import Type, ElementaryType if TYPE_CHECKING: from slither.core.declarations import Contract @@ -10,13 +10,13 @@ if TYPE_CHECKING: class TypeAlias(Type): - def __init__(self, underlying_type: Type, name: str): + def __init__(self, underlying_type: ElementaryType, name: str) -> None: super().__init__() self.name = name self.underlying_type = underlying_type @property - def type(self) -> Type: + def type(self) -> ElementaryType: """ Return the underlying type. Alias for underlying_type @@ -31,7 +31,7 @@ class TypeAlias(Type): def storage_size(self) -> Tuple[int, bool]: return self.underlying_type.storage_size - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) @property @@ -40,18 +40,18 @@ class TypeAlias(Type): class TypeAliasTopLevel(TypeAlias, TopLevel): - def __init__(self, underlying_type: Type, name: str, scope: "FileScope"): + def __init__(self, underlying_type: Type, name: str, scope: "FileScope") -> None: super().__init__(underlying_type, name) self.file_scope: "FileScope" = scope - def __str__(self): + def __str__(self) -> str: return self.name class TypeAliasContract(TypeAlias, ChildContract): - def __init__(self, underlying_type: Type, name: str, contract: "Contract"): + def __init__(self, underlying_type: Type, name: str, contract: "Contract") -> None: super().__init__(underlying_type, name) self._contract: "Contract" = contract - def __str__(self): + def __str__(self) -> str: return self.contract.name + "." + self.name diff --git a/slither/core/solidity_types/type_information.py b/slither/core/solidity_types/type_information.py index 0477bb7e6..2af0b097a 100644 --- a/slither/core/solidity_types/type_information.py +++ b/slither/core/solidity_types/type_information.py @@ -1,16 +1,17 @@ -from typing import TYPE_CHECKING, Tuple +from typing import Union, TYPE_CHECKING, Tuple from slither.core.solidity_types import ElementaryType from slither.core.solidity_types.type import Type if TYPE_CHECKING: from slither.core.declarations.contract import Contract + from slither.core.declarations.enum import Enum # Use to model the Type(X) function, which returns an undefined type # https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information class TypeInformation(Type): - def __init__(self, c): + def __init__(self, c: Union[ElementaryType, "Contract", "Enum"]) -> None: # pylint: disable=import-outside-toplevel from slither.core.declarations.contract import Contract from slither.core.declarations.enum import Enum @@ -20,7 +21,7 @@ class TypeInformation(Type): self._type = c @property - def type(self) -> "Contract": + def type(self) -> Union["Contract", ElementaryType, "Enum"]: return self._type @property diff --git a/slither/core/solidity_types/user_defined_type.py b/slither/core/solidity_types/user_defined_type.py index a977ab080..a9bbd40a2 100644 --- a/slither/core/solidity_types/user_defined_type.py +++ b/slither/core/solidity_types/user_defined_type.py @@ -1,4 +1,4 @@ -from typing import Union, TYPE_CHECKING, Tuple +from typing import Union, TYPE_CHECKING, Tuple, Any import math from slither.core.solidity_types.type import Type @@ -11,7 +11,7 @@ if TYPE_CHECKING: # pylint: disable=import-outside-toplevel class UserDefinedType(Type): - def __init__(self, t): + def __init__(self, t: Union["Enum", "Contract", "Structure"]) -> None: from slither.core.declarations.structure import Structure from slither.core.declarations.enum import Enum from slither.core.declarations.contract import Contract @@ -62,7 +62,7 @@ class UserDefinedType(Type): to_log = f"{self} does not have storage size" raise SlitherException(to_log) - def __str__(self): + def __str__(self) -> str: from slither.core.declarations.structure_contract import StructureContract from slither.core.declarations.enum_contract import EnumContract @@ -71,7 +71,7 @@ class UserDefinedType(Type): return str(type_used.contract) + "." + str(type_used.name) return str(type_used.name) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: from slither.core.declarations.contract import Contract if not isinstance(other, UserDefinedType): @@ -80,5 +80,5 @@ class UserDefinedType(Type): return self.type == other.type.name return self.type == other.type - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) diff --git a/slither/core/variables/__init__.py b/slither/core/variables/__init__.py index e69de29bb..638f0f3a4 100644 --- a/slither/core/variables/__init__.py +++ b/slither/core/variables/__init__.py @@ -0,0 +1,2 @@ +from .state_variable import StateVariable +from .variable import Variable diff --git a/slither/core/variables/event_variable.py b/slither/core/variables/event_variable.py index ca2f40570..f3ad60d0b 100644 --- a/slither/core/variables/event_variable.py +++ b/slither/core/variables/event_variable.py @@ -3,7 +3,7 @@ from slither.core.children.child_event import ChildEvent class EventVariable(ChildEvent, Variable): - def __init__(self): + def __init__(self) -> None: super().__init__() self._indexed = False diff --git a/slither/core/variables/local_variable_init_from_tuple.py b/slither/core/variables/local_variable_init_from_tuple.py index 86a7cbbc2..8d584b373 100644 --- a/slither/core/variables/local_variable_init_from_tuple.py +++ b/slither/core/variables/local_variable_init_from_tuple.py @@ -13,7 +13,7 @@ class LocalVariableInitFromTuple(LocalVariable): """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._tuple_index: Optional[int] = None diff --git a/slither/core/variables/state_variable.py b/slither/core/variables/state_variable.py index c9a90f36b..47b7682a4 100644 --- a/slither/core/variables/state_variable.py +++ b/slither/core/variables/state_variable.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: class StateVariable(ChildContract, Variable): - def __init__(self): + def __init__(self) -> None: super().__init__() self._node_initialization: Optional["Node"] = None diff --git a/slither/core/variables/top_level_variable.py b/slither/core/variables/top_level_variable.py index 6d821092e..e6447c1ef 100644 --- a/slither/core/variables/top_level_variable.py +++ b/slither/core/variables/top_level_variable.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: class TopLevelVariable(TopLevel, Variable): - def __init__(self, scope: "FileScope"): + def __init__(self, scope: "FileScope") -> None: super().__init__() self._node_initialization: Optional["Node"] = None self.file_scope = scope diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index 5fda02e93..8607a8921 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: # pylint: disable=too-many-instance-attributes class Variable(SourceMapping): - def __init__(self): + def __init__(self) -> None: super().__init__() self._name: Optional[str] = None self._initial_expression: Optional["Expression"] = None diff --git a/slither/detectors/abstract_detector.py b/slither/detectors/abstract_detector.py index 778695d85..8e2dd490d 100644 --- a/slither/detectors/abstract_detector.py +++ b/slither/detectors/abstract_detector.py @@ -5,9 +5,9 @@ from typing import Optional, List, TYPE_CHECKING, Dict, Union, Callable from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.declarations import Contract -from slither.utils.colors import green, yellow, red from slither.formatters.exceptions import FormatImpossible from slither.formatters.utils.patches import apply_patch, create_diff +from slither.utils.colors import green, yellow, red from slither.utils.comparable_enum import ComparableEnum from slither.utils.output import Output, SupportedOutput @@ -81,7 +81,7 @@ class AbstractDetector(metaclass=abc.ABCMeta): def __init__( self, compilation_unit: SlitherCompilationUnit, slither: "Slither", logger: Logger - ): + ) -> None: self.compilation_unit: SlitherCompilationUnit = compilation_unit self.contracts: List[Contract] = compilation_unit.contracts self.slither: "Slither" = slither diff --git a/slither/detectors/assembly/shift_parameter_mixup.py b/slither/detectors/assembly/shift_parameter_mixup.py index 65a35d8c3..31dad2371 100644 --- a/slither/detectors/assembly/shift_parameter_mixup.py +++ b/slither/detectors/assembly/shift_parameter_mixup.py @@ -1,6 +1,9 @@ +from typing import List from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import Binary, BinaryType from slither.slithir.variables import Constant +from slither.core.declarations.function_contract import FunctionContract +from slither.utils.output import Output class ShiftParameterMixup(AbstractDetector): @@ -36,7 +39,7 @@ The shift statement will right-shift the constant 8 by `a` bits""" WIKI_RECOMMENDATION = "Swap the order of parameters." - def _check_function(self, f): + def _check_function(self, f: FunctionContract) -> List[Output]: results = [] for node in f.nodes: @@ -52,7 +55,7 @@ The shift statement will right-shift the constant 8 by `a` bits""" results.append(json) return results - def _detect(self): + def _detect(self) -> List[Output]: results = [] for c in self.contracts: for f in c.functions: diff --git a/slither/detectors/attributes/const_functions_asm.py b/slither/detectors/attributes/const_functions_asm.py index 33853c9f4..e3a938361 100644 --- a/slither/detectors/attributes/const_functions_asm.py +++ b/slither/detectors/attributes/const_functions_asm.py @@ -2,12 +2,14 @@ Module detecting constant functions Recursively check the called functions """ +from typing import List from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, ) from slither.formatters.attributes.const_functions import custom_format +from slither.utils.output import Output class ConstantFunctionsAsm(AbstractDetector): @@ -55,7 +57,7 @@ All the calls to `get` revert, breaking Bob's smart contract execution.""" VULNERABLE_SOLC_VERSIONS = ALL_SOLC_VERSIONS_04 - def _detect(self): + def _detect(self) -> List[Output]: """Detect the constant function using assembly code Recursively visit the calls diff --git a/slither/detectors/attributes/const_functions_state.py b/slither/detectors/attributes/const_functions_state.py index a351727cf..36ea8f32d 100644 --- a/slither/detectors/attributes/const_functions_state.py +++ b/slither/detectors/attributes/const_functions_state.py @@ -2,12 +2,14 @@ Module detecting constant functions Recursively check the called functions """ +from typing import List from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, ) from slither.formatters.attributes.const_functions import custom_format +from slither.utils.output import Output class ConstantFunctionsState(AbstractDetector): @@ -55,7 +57,7 @@ All the calls to `get` revert, breaking Bob's smart contract execution.""" VULNERABLE_SOLC_VERSIONS = ALL_SOLC_VERSIONS_04 - def _detect(self): + def _detect(self) -> List[Output]: """Detect the constant function changing the state Recursively visit the calls diff --git a/slither/detectors/attributes/constant_pragma.py b/slither/detectors/attributes/constant_pragma.py index 0c77b69ca..2164a78e8 100644 --- a/slither/detectors/attributes/constant_pragma.py +++ b/slither/detectors/attributes/constant_pragma.py @@ -1,9 +1,11 @@ """ Check that the same pragma is used in all the files """ +from typing import List from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.formatters.attributes.constant_pragma import custom_format +from slither.utils.output import Output class ConstantPragma(AbstractDetector): @@ -22,7 +24,7 @@ class ConstantPragma(AbstractDetector): WIKI_DESCRIPTION = "Detect whether different Solidity versions are used." WIKI_RECOMMENDATION = "Use one Solidity version." - def _detect(self): + def _detect(self) -> List[Output]: results = [] pragma = self.compilation_unit.pragma_directives versions = [p.version for p in pragma if p.is_solidity_version] diff --git a/slither/detectors/attributes/incorrect_solc.py b/slither/detectors/attributes/incorrect_solc.py index d6838328e..fa9ffd88d 100644 --- a/slither/detectors/attributes/incorrect_solc.py +++ b/slither/detectors/attributes/incorrect_solc.py @@ -3,8 +3,11 @@ """ import re +from typing import List, Optional, Tuple + from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.formatters.attributes.incorrect_solc import custom_format +from slither.utils.output import Output # group: # 0: ^ > >= < <= (optional) @@ -83,7 +86,7 @@ Consider using the latest version of Solidity for testing.""" "^0.8.8", ] - def _check_version(self, version): + def _check_version(self, version: Tuple[str, str, str, str, str]) -> Optional[str]: op = version[0] if op and op not in [">", ">=", "^"]: return self.LESS_THAN_TXT @@ -96,7 +99,7 @@ Consider using the latest version of Solidity for testing.""" return self.OLD_VERSION_TXT return None - def _check_pragma(self, version): + def _check_pragma(self, version: str) -> Optional[str]: if version in self.BUGGY_VERSIONS: return self.BUGGY_VERSION_TXT versions = PATTERN.findall(version) @@ -117,7 +120,7 @@ Consider using the latest version of Solidity for testing.""" return self._check_version(version_left) return self.COMPLEX_PRAGMA_TXT - def _detect(self): + def _detect(self) -> List[Output]: """ Detects pragma statements that allow for outdated solc versions. :return: Returns the relevant JSON data for the findings. diff --git a/slither/detectors/attributes/locked_ether.py b/slither/detectors/attributes/locked_ether.py index e023f467b..2fdabaea6 100644 --- a/slither/detectors/attributes/locked_ether.py +++ b/slither/detectors/attributes/locked_ether.py @@ -1,7 +1,9 @@ """ Check if ethers are locked in the contract """ +from typing import List +from slither.core.declarations.contract import Contract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import ( HighLevelCall, @@ -12,6 +14,7 @@ from slither.slithir.operations import ( LibraryCall, InternalCall, ) +from slither.utils.output import Output class LockedEther(AbstractDetector): # pylint: disable=too-many-nested-blocks @@ -41,7 +44,7 @@ Every Ether sent to `Locked` will be lost.""" WIKI_RECOMMENDATION = "Remove the payable attribute or add a withdraw function." @staticmethod - def do_no_send_ether(contract): + def do_no_send_ether(contract: Contract) -> bool: functions = contract.all_functions_called to_explore = functions explored = [] @@ -73,7 +76,7 @@ Every Ether sent to `Locked` will be lost.""" return True - def _detect(self): + def _detect(self) -> List[Output]: results = [] for contract in self.compilation_unit.contracts_derived: diff --git a/slither/detectors/attributes/unimplemented_interface.py b/slither/detectors/attributes/unimplemented_interface.py index c5cf5d321..ff0889d11 100644 --- a/slither/detectors/attributes/unimplemented_interface.py +++ b/slither/detectors/attributes/unimplemented_interface.py @@ -4,8 +4,10 @@ Module detecting unimplemented interfaces Collect all the interfaces Check for contracts which implement all interface functions but do not explicitly derive from those interfaces. """ - +from typing import List from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.core.declarations.contract import Contract +from slither.utils.output import Output class MissingInheritance(AbstractDetector): @@ -42,7 +44,9 @@ contract Something { WIKI_RECOMMENDATION = "Inherit from the missing interface or contract." @staticmethod - def detect_unimplemented_interface(contract, interfaces): + def detect_unimplemented_interface( + contract: Contract, interfaces: List[Contract] + ) -> List[Contract]: """ Detects if contract intends to implement one of the interfaces but does not explicitly do so by deriving from it :param contract: The contract to check @@ -50,7 +54,7 @@ contract Something { :return: Interfaces likely intended to implement by the contract """ - intended_interfaces = [] + intended_interfaces: List[Contract] = [] sigs_contract = {f.full_name for f in contract.functions_entry_points} if not sigs_contract: @@ -111,7 +115,7 @@ contract Something { return intended_interfaces - def _detect(self): + def _detect(self) -> List[Output]: """Detect unimplemented interfaces Returns: list: {'contract'} diff --git a/slither/detectors/compiler_bugs/array_by_reference.py b/slither/detectors/compiler_bugs/array_by_reference.py index 6acc78d17..83ed69b9b 100644 --- a/slither/detectors/compiler_bugs/array_by_reference.py +++ b/slither/detectors/compiler_bugs/array_by_reference.py @@ -1,13 +1,17 @@ """ Detects the passing of arrays located in memory to functions which expect to modify arrays via storage reference. """ - +from typing import List, Set, Tuple, Union from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.core.solidity_types.array_type import ArrayType from slither.core.variables.state_variable import StateVariable from slither.core.variables.local_variable import LocalVariable from slither.slithir.operations.high_level_call import HighLevelCall from slither.slithir.operations.internal_call import InternalCall +from slither.core.cfg.node import Node +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract +from slither.utils.output import Output class ArrayByReference(AbstractDetector): @@ -55,7 +59,7 @@ As a result, Bob's usage of the contract is incorrect.""" WIKI_RECOMMENDATION = "Ensure the correct usage of `memory` and `storage` in the function parameters. Make all the locations explicit." @staticmethod - def get_funcs_modifying_array_params(contracts): + def get_funcs_modifying_array_params(contracts: List[Contract]) -> Set[FunctionContract]: """ Obtains a set of functions which take arrays not located in storage as parameters, and writes to them. :param contracts: The collection of contracts to check functions in. @@ -83,7 +87,14 @@ As a result, Bob's usage of the contract is incorrect.""" return results @staticmethod - def detect_calls_passing_ref_to_function(contracts, array_modifying_funcs): + def detect_calls_passing_ref_to_function( + contracts: List[Contract], array_modifying_funcs: Set[FunctionContract] + ) -> List[ + Union[ + Tuple[Node, StateVariable, FunctionContract], + Tuple[Node, LocalVariable, FunctionContract], + ] + ]: """ Obtains all calls passing storage arrays by value to a function which cannot write to them successfully. :param contracts: The collection of contracts to check for problematic calls in. @@ -134,7 +145,7 @@ As a result, Bob's usage of the contract is incorrect.""" results.append((node, arg, ir.function)) return results - def _detect(self): + def _detect(self) -> List[Output]: """ Detects passing of arrays located in memory to functions which expect to modify arrays via storage reference. :return: The JSON results of the detector, which contains the calling_node, affected_argument_variable and diff --git a/slither/detectors/compiler_bugs/enum_conversion.py b/slither/detectors/compiler_bugs/enum_conversion.py index 477188fe0..671b8d699 100644 --- a/slither/detectors/compiler_bugs/enum_conversion.py +++ b/slither/detectors/compiler_bugs/enum_conversion.py @@ -1,7 +1,11 @@ """ Module detecting dangerous conversion to enum """ +from typing import List, Tuple +from slither.core.cfg.node import Node +from slither.core.declarations import Contract +from slither.core.source_mapping.source_mapping import SourceMapping from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, @@ -9,9 +13,10 @@ from slither.detectors.abstract_detector import ( ) from slither.slithir.operations import TypeConversion from slither.core.declarations.enum import Enum +from slither.utils.output import Output -def _detect_dangerous_enum_conversions(contract): +def _detect_dangerous_enum_conversions(contract: Contract) -> List[Tuple[Node, SourceMapping]]: """Detect dangerous conversion to enum by checking IR Args: contract (Contract) @@ -61,7 +66,7 @@ Attackers can trigger unexpected behaviour by calling `bug(1)`.""" VULNERABLE_SOLC_VERSIONS = make_solc_versions(4, 0, 4) - def _detect(self): + def _detect(self) -> List[Output]: """Detect dangerous conversion to enum""" results = [] diff --git a/slither/detectors/compiler_bugs/multiple_constructor_schemes.py b/slither/detectors/compiler_bugs/multiple_constructor_schemes.py index 5845cea1c..3486cc41b 100644 --- a/slither/detectors/compiler_bugs/multiple_constructor_schemes.py +++ b/slither/detectors/compiler_bugs/multiple_constructor_schemes.py @@ -1,4 +1,7 @@ +from typing import List + from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class MultipleConstructorSchemes(AbstractDetector): @@ -43,7 +46,7 @@ In Solidity [0.4.22](https://github.com/ethereum/solidity/releases/tag/v0.4.23), WIKI_RECOMMENDATION = "Only declare one constructor, preferably using the new scheme `constructor(...)` instead of `function (...)`." - def _detect(self): + def _detect(self) -> List[Output]: """ Detect multiple constructor schemes in the same contract :return: Returns a list of contract JSON result, where each result contains all constructor definitions. diff --git a/slither/detectors/compiler_bugs/public_mapping_nested.py b/slither/detectors/compiler_bugs/public_mapping_nested.py index 8e6b6f4a8..0ae8c3d50 100644 --- a/slither/detectors/compiler_bugs/public_mapping_nested.py +++ b/slither/detectors/compiler_bugs/public_mapping_nested.py @@ -1,7 +1,7 @@ """ Module detecting public mappings with nested variables (returns incorrect values prior to 0.5.x) """ - +from typing import Any, List, Union from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, @@ -10,9 +10,12 @@ from slither.detectors.abstract_detector import ( from slither.core.solidity_types.mapping_type import MappingType from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.declarations.structure import Structure +from slither.core.declarations.contract import Contract +from slither.core.variables.state_variable import StateVariable +from slither.utils.output import Output -def detect_public_nested_mappings(contract): +def detect_public_nested_mappings(contract: Contract) -> List[Union[StateVariable, Any]]: """ Detect any state variables that are initialized from an immediate function call (prior to constructor run). :param contract: The contract to detect state variable definitions for. @@ -68,7 +71,7 @@ class PublicMappingNested(AbstractDetector): VULNERABLE_SOLC_VERSIONS = ALL_SOLC_VERSIONS_04 - def _detect(self): + def _detect(self) -> List[Output]: """ Detect public mappings with nested variables (returns incorrect values prior to 0.5.x) diff --git a/slither/detectors/compiler_bugs/reused_base_constructor.py b/slither/detectors/compiler_bugs/reused_base_constructor.py index 9d0b91448..73cfac12e 100644 --- a/slither/detectors/compiler_bugs/reused_base_constructor.py +++ b/slither/detectors/compiler_bugs/reused_base_constructor.py @@ -1,18 +1,24 @@ """ Module detecting re-used base constructors in inheritance hierarchy. """ - +from typing import Any, Dict, List, Tuple, Union from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, ) +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract +from slither.utils.output import Output # Helper: adds explicitly called constructors with arguments to the results lookup. def _add_constructors_with_args( - base_constructors, called_by_constructor, current_contract, results -): + base_constructors: List[Union[Any, FunctionContract]], + called_by_constructor: bool, + current_contract: Contract, + results: Dict[FunctionContract, List[Tuple[Contract, bool]]], +) -> None: for explicit_base_constructor in base_constructors: if len(explicit_base_constructor.parameters) > 0: if explicit_base_constructor not in results: @@ -77,7 +83,9 @@ The constructor of `A` is called multiple times in `D` and `E`: VULNERABLE_SOLC_VERSIONS = ALL_SOLC_VERSIONS_04 - def _detect_explicitly_called_base_constructors(self, contract): + def _detect_explicitly_called_base_constructors( + self, contract: Contract + ) -> Dict[FunctionContract, List[Tuple[Contract, bool]]]: """ Detects explicitly calls to base constructors with arguments in the inheritance hierarchy. :param contract: The contract to detect explicit calls to a base constructor with arguments to. @@ -124,7 +132,7 @@ The constructor of `A` is called multiple times in `D` and `E`: return results - def _detect(self): + def _detect(self) -> List[Output]: """ Detect reused base constructors. :return: Returns a list of JSON results. diff --git a/slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py b/slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py index 59d52760e..aee6361c6 100644 --- a/slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py +++ b/slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py @@ -1,7 +1,7 @@ """ Module detecting ABIEncoderV2 array bug """ - +from typing import List, Set, Tuple from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, @@ -16,6 +16,10 @@ from slither.core.declarations.solidity_variables import SolidityFunction from slither.slithir.operations import EventCall from slither.slithir.operations import HighLevelCall from slither.utils.utils import unroll +from slither.core.cfg.node import Node +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract +from slither.utils.output import Output class ABIEncoderV2Array(AbstractDetector): @@ -55,7 +59,9 @@ contract A { VULNERABLE_SOLC_VERSIONS = make_solc_versions(4, 7, 25) + make_solc_versions(5, 0, 9) @staticmethod - def _detect_storage_abiencoderv2_arrays(contract): + def _detect_storage_abiencoderv2_arrays( + contract: Contract, + ) -> Set[Tuple[FunctionContract, Node]]: """ Detects and returns all nodes with storage-allocated abiencoderv2 arrays of arrays/structs in abi.encode, events or external calls :param contract: Contract to detect within @@ -98,7 +104,7 @@ contract A { # Return the resulting set of tuples return results - def _detect(self): + def _detect(self) -> List[Output]: """ Detect ABIEncoderV2 array bug """ diff --git a/slither/detectors/compiler_bugs/storage_signed_integer_array.py b/slither/detectors/compiler_bugs/storage_signed_integer_array.py index 419c71c87..736f66789 100644 --- a/slither/detectors/compiler_bugs/storage_signed_integer_array.py +++ b/slither/detectors/compiler_bugs/storage_signed_integer_array.py @@ -1,6 +1,7 @@ """ Module detecting storage signed integer array bug """ +from typing import List from slither.detectors.abstract_detector import ( AbstractDetector, @@ -14,6 +15,7 @@ from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable from slither.slithir.operations.assignment import Assignment from slither.slithir.operations.init_array import InitArray +from slither.utils.output import Output class StorageSignedIntegerArray(AbstractDetector): @@ -108,7 +110,7 @@ contract A { # Return the resulting set of tuples return results - def _detect(self): + def _detect(self) -> List[Output]: """ Detect storage signed integer array init/assignment """ diff --git a/slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py b/slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py index a4d3cb8f2..6685948b3 100644 --- a/slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py +++ b/slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py @@ -1,7 +1,7 @@ """ Module detecting uninitialized function pointer calls in constructors """ - +from typing import Any, List, Union from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, @@ -10,9 +10,14 @@ from slither.detectors.abstract_detector import ( from slither.slithir.operations import InternalDynamicCall, OperationWithLValue from slither.slithir.variables import ReferenceVariable from slither.slithir.variables.variable import SlithIRVariable +from slither.core.cfg.node import Node +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract +from slither.slithir.variables.state_variable import StateIRVariable +from slither.utils.output import Output -def _get_variables_entrance(function): +def _get_variables_entrance(function: FunctionContract) -> List[Union[Any, StateIRVariable]]: """ Return the first SSA variables of the function Catpure the phi operation at the entry point @@ -25,7 +30,7 @@ def _get_variables_entrance(function): return ret -def _is_vulnerable(node, variables_entrance): +def _is_vulnerable(node: Node, variables_entrance: List[Union[Any, StateIRVariable]]) -> bool: """ Vulnerable if an IR ssa: - It is an internal dynamic call @@ -84,7 +89,9 @@ The call to `a(10)` will lead to unexpected behavior because function pointer `a VULNERABLE_SOLC_VERSIONS = make_solc_versions(4, 5, 25) + make_solc_versions(5, 0, 8) @staticmethod - def _detect_uninitialized_function_ptr_in_constructor(contract): + def _detect_uninitialized_function_ptr_in_constructor( + contract: Contract, + ) -> List[Union[Any, Node]]: """ Detect uninitialized function pointer calls in constructors :param contract: The contract of interest for detection @@ -99,7 +106,7 @@ The call to `a(10)` will lead to unexpected behavior because function pointer `a ] return results - def _detect(self): + def _detect(self) -> List[Output]: """ Detect uninitialized function pointer calls in constructors of contracts Returns: diff --git a/slither/detectors/erc/erc20/arbitrary_send_erc20.py b/slither/detectors/erc/erc20/arbitrary_send_erc20.py index 7aeaa1139..17b1fba30 100644 --- a/slither/detectors/erc/erc20/arbitrary_send_erc20.py +++ b/slither/detectors/erc/erc20/arbitrary_send_erc20.py @@ -1,16 +1,17 @@ from typing import List + +from slither.analyses.data_dependency.data_dependency import is_dependent from slither.core.cfg.node import Node +from slither.core.compilation_unit import SlitherCompilationUnit +from slither.core.declarations import Contract, Function, SolidityVariableComposed from slither.core.declarations.solidity_variables import SolidityVariable from slither.slithir.operations import HighLevelCall, LibraryCall -from slither.core.declarations import Contract, Function, SolidityVariableComposed -from slither.analyses.data_dependency.data_dependency import is_dependent -from slither.core.compilation_unit import SlitherCompilationUnit class ArbitrarySendErc20: """Detects instances where ERC20 can be sent from an arbitrary from address.""" - def __init__(self, compilation_unit: SlitherCompilationUnit): + def __init__(self, compilation_unit: SlitherCompilationUnit) -> None: self._compilation_unit = compilation_unit self._no_permit_results: List[Node] = [] self._permit_results: List[Node] = [] @@ -27,7 +28,7 @@ class ArbitrarySendErc20: def permit_results(self) -> List[Node]: return self._permit_results - def _detect_arbitrary_from(self, contract: Contract): + def _detect_arbitrary_from(self, contract: Contract) -> None: for f in contract.functions: all_high_level_calls = [ f_called[1].solidity_signature @@ -48,7 +49,7 @@ class ArbitrarySendErc20: ArbitrarySendErc20._arbitrary_from(f.nodes, self._no_permit_results) @staticmethod - def _arbitrary_from(nodes: List[Node], results: List[Node]): + def _arbitrary_from(nodes: List[Node], results: List[Node]) -> None: """Finds instances of (safe)transferFrom that do not use msg.sender or address(this) as from parameter.""" for node in nodes: for ir in node.irs: @@ -89,7 +90,7 @@ class ArbitrarySendErc20: ): results.append(ir.node) - def detect(self): + def detect(self) -> None: """Detect transfers that use arbitrary `from` parameter.""" for c in self.compilation_unit.contracts_derived: self._detect_arbitrary_from(c) diff --git a/slither/detectors/erc/erc20/incorrect_erc20_interface.py b/slither/detectors/erc/erc20/incorrect_erc20_interface.py index aa9f3a916..4da6ab5ae 100644 --- a/slither/detectors/erc/erc20/incorrect_erc20_interface.py +++ b/slither/detectors/erc/erc20/incorrect_erc20_interface.py @@ -2,7 +2,12 @@ Detect incorrect erc20 interface. Some contracts do not return a bool on transfer/transferFrom/approve, which may lead to preventing the contract to be used with contracts compiled with recent solc (>0.4.22) """ +from typing import List, Tuple + +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class IncorrectERC20InterfaceDetection(AbstractDetector): @@ -36,7 +41,7 @@ contract Token{ ) @staticmethod - def incorrect_erc20_interface(signature): + def incorrect_erc20_interface(signature: Tuple[str, List[str], List[str]]) -> bool: (name, parameters, returnVars) = signature if name == "transfer" and parameters == ["address", "uint256"] and returnVars != ["bool"]: @@ -68,7 +73,7 @@ contract Token{ return False @staticmethod - def detect_incorrect_erc20_interface(contract): + def detect_incorrect_erc20_interface(contract: Contract) -> List[FunctionContract]: """Detect incorrect ERC20 interface Returns: @@ -93,7 +98,7 @@ contract Token{ return functions - def _detect(self): + def _detect(self) -> List[Output]: """Detect incorrect erc20 interface Returns: diff --git a/slither/detectors/erc/incorrect_erc721_interface.py b/slither/detectors/erc/incorrect_erc721_interface.py index 2bdc78cd7..8327e8b2e 100644 --- a/slither/detectors/erc/incorrect_erc721_interface.py +++ b/slither/detectors/erc/incorrect_erc721_interface.py @@ -1,7 +1,11 @@ """ Detect incorrect erc721 interface. """ +from typing import Any, List, Tuple, Union from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract +from slither.utils.output import Output class IncorrectERC721InterfaceDetection(AbstractDetector): @@ -37,7 +41,9 @@ contract Token{ ) @staticmethod - def incorrect_erc721_interface(signature): + def incorrect_erc721_interface( + signature: Union[Tuple[str, List[str], List[str]], Tuple[str, List[str], List[Any]]] + ) -> bool: (name, parameters, returnVars) = signature # ERC721 @@ -83,7 +89,7 @@ contract Token{ return False @staticmethod - def detect_incorrect_erc721_interface(contract): + def detect_incorrect_erc721_interface(contract: Contract) -> List[Union[FunctionContract, Any]]: """Detect incorrect ERC721 interface Returns: @@ -102,7 +108,7 @@ contract Token{ ] return functions - def _detect(self): + def _detect(self) -> List[Output]: """Detect incorrect erc721 interface Returns: diff --git a/slither/detectors/erc/unindexed_event_parameters.py b/slither/detectors/erc/unindexed_event_parameters.py index 3962e2358..6e91b0fb3 100644 --- a/slither/detectors/erc/unindexed_event_parameters.py +++ b/slither/detectors/erc/unindexed_event_parameters.py @@ -1,7 +1,12 @@ """ Detect mistakenly un-indexed ERC20 event parameters """ +from typing import Any, List, Tuple, Union from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.core.declarations.contract import Contract +from slither.core.declarations.event import Event +from slither.core.variables.event_variable import EventVariable +from slither.utils.output import Output class UnindexedERC20EventParameters(AbstractDetector): @@ -39,7 +44,9 @@ Failure to include these keywords will exclude the parameter data in the transac STANDARD_JSON = False @staticmethod - def detect_erc20_unindexed_event_params(contract): + def detect_erc20_unindexed_event_params( + contract: Contract, + ) -> List[Union[Tuple[Event, EventVariable], Any]]: """ Detect un-indexed ERC20 event parameters in a given contract. :param contract: The contract to check ERC20 events for un-indexed parameters in. @@ -68,7 +75,7 @@ Failure to include these keywords will exclude the parameter data in the transac # Return the results. return results - def _detect(self): + def _detect(self) -> List[Output]: """ Detect un-indexed ERC20 event parameters in all contracts. """ diff --git a/slither/detectors/examples/backdoor.py b/slither/detectors/examples/backdoor.py index 1e73fa814..0e8e9ad81 100644 --- a/slither/detectors/examples/backdoor.py +++ b/slither/detectors/examples/backdoor.py @@ -1,4 +1,7 @@ +from typing import List + from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class Backdoor(AbstractDetector): @@ -17,7 +20,7 @@ class Backdoor(AbstractDetector): WIKI_EXPLOIT_SCENARIO = ".." WIKI_RECOMMENDATION = ".." - def _detect(self): + def _detect(self) -> List[Output]: results = [] for contract in self.compilation_unit.contracts_derived: diff --git a/slither/detectors/functions/arbitrary_send_eth.py b/slither/detectors/functions/arbitrary_send_eth.py index e1752bbdb..390b1f2ab 100644 --- a/slither/detectors/functions/arbitrary_send_eth.py +++ b/slither/detectors/functions/arbitrary_send_eth.py @@ -9,11 +9,12 @@ TODO: dont report if the value is tainted by msg.value """ -from typing import List +from typing import Any, Tuple, Union, List +from slither.analyses.data_dependency.data_dependency import is_tainted, is_dependent from slither.core.cfg.node import Node from slither.core.declarations import Function, Contract -from slither.analyses.data_dependency.data_dependency import is_tainted, is_dependent +from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.solidity_variables import ( SolidityFunction, SolidityVariableComposed, @@ -28,12 +29,11 @@ from slither.slithir.operations import ( Transfer, ) - # pylint: disable=too-many-nested-blocks,too-many-branches from slither.utils.output import Output -def arbitrary_send(func: Function): +def arbitrary_send(func: Function) -> Union[bool, List[Node]]: if func.is_protected(): return [] @@ -74,7 +74,9 @@ def arbitrary_send(func: Function): return ret -def detect_arbitrary_send(contract: Contract): +def detect_arbitrary_send( + contract: Contract, +) -> List[Union[Tuple[FunctionContract, List[Node]], Any]]: """ Detect arbitrary send Args: diff --git a/slither/detectors/functions/dead_code.py b/slither/detectors/functions/dead_code.py index 5d632b9f9..1a25c5776 100644 --- a/slither/detectors/functions/dead_code.py +++ b/slither/detectors/functions/dead_code.py @@ -5,6 +5,7 @@ from typing import List, Tuple from slither.core.declarations import Function, FunctionContract, Contract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class DeadCode(AbstractDetector): @@ -34,7 +35,7 @@ contract Contract{ WIKI_RECOMMENDATION = "Remove unused functions." - def _detect(self): + def _detect(self) -> List[Output]: results = [] diff --git a/slither/detectors/functions/modifier.py b/slither/detectors/functions/modifier.py index 23c6fc0fc..271d8e6cb 100644 --- a/slither/detectors/functions/modifier.py +++ b/slither/detectors/functions/modifier.py @@ -5,18 +5,19 @@ Note that require()/assert() are not considered here. Even if they are in the outermost scope, they do not guarantee a revert, so a default value can still be returned. """ - +from typing import List from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.core.cfg.node import NodeType +from slither.core.cfg.node import Node, NodeType +from slither.utils.output import Output -def is_revert(node): +def is_revert(node: Node) -> bool: return node.type == NodeType.THROW or any( c.name in ["revert()", "revert(string"] for c in node.internal_calls ) -def _get_false_son(node): +def _get_false_son(node: Node) -> Node: """Select the son node corresponding to a false branch Following this node stays on the outer scope of the function """ @@ -60,7 +61,7 @@ If the condition in `myModif` is false, the execution of `get()` will return 0." WIKI_RECOMMENDATION = "All the paths in a modifier must execute `_` or revert." - def _detect(self): + def _detect(self) -> List[Output]: results = [] for c in self.contracts: for mod in c.modifiers: diff --git a/slither/detectors/functions/permit_domain_signature_collision.py b/slither/detectors/functions/permit_domain_signature_collision.py index 7142d7cf1..de64ec52e 100644 --- a/slither/detectors/functions/permit_domain_signature_collision.py +++ b/slither/detectors/functions/permit_domain_signature_collision.py @@ -8,6 +8,7 @@ from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.variables.state_variable import StateVariable from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.utils.function import get_function_id +from slither.utils.output import Output class DomainSeparatorCollision(AbstractDetector): @@ -39,7 +40,7 @@ contract Contract{ WIKI_RECOMMENDATION = "Remove or rename the function that collides with DOMAIN_SEPARATOR()." - def _detect(self): + def _detect(self) -> List[Output]: domain_sig = get_function_id("DOMAIN_SEPARATOR()") for contract in self.compilation_unit.contracts_derived: if contract.is_erc20(): diff --git a/slither/detectors/functions/protected_variable.py b/slither/detectors/functions/protected_variable.py index cbd640e18..68ed098c7 100644 --- a/slither/detectors/functions/protected_variable.py +++ b/slither/detectors/functions/protected_variable.py @@ -5,8 +5,8 @@ A suicidal contract is an unprotected function that calls selfdestruct """ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.core.declarations import Function, Contract +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.utils.output import Output @@ -71,7 +71,7 @@ contract Buggy{ results.append(res) return results - def _detect(self): + def _detect(self) -> List[Output]: """Detect the suicidal functions""" results = [] for contract in self.compilation_unit.contracts_derived: diff --git a/slither/detectors/functions/suicidal.py b/slither/detectors/functions/suicidal.py index 906b13902..7741da57d 100644 --- a/slither/detectors/functions/suicidal.py +++ b/slither/detectors/functions/suicidal.py @@ -3,8 +3,12 @@ Module detecting suicidal contract A suicidal contract is an unprotected function that calls selfdestruct """ +from typing import List +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class Suicidal(AbstractDetector): @@ -37,7 +41,7 @@ Bob calls `kill` and destructs the contract.""" WIKI_RECOMMENDATION = "Protect access to all sensitive functions." @staticmethod - def detect_suicidal_func(func): + def detect_suicidal_func(func: FunctionContract) -> bool: """Detect if the function is suicidal Detect the public functions calling suicide/selfdestruct without protection @@ -60,14 +64,14 @@ Bob calls `kill` and destructs the contract.""" return True - def detect_suicidal(self, contract): + def detect_suicidal(self, contract: Contract) -> List[FunctionContract]: ret = [] for f in contract.functions_declared: if self.detect_suicidal_func(f): ret.append(f) return ret - def _detect(self): + def _detect(self) -> List[Output]: """Detect the suicidal functions""" results = [] for c in self.contracts: diff --git a/slither/detectors/functions/unimplemented.py b/slither/detectors/functions/unimplemented.py index 88c71ac05..11a1fad80 100644 --- a/slither/detectors/functions/unimplemented.py +++ b/slither/detectors/functions/unimplemented.py @@ -7,8 +7,12 @@ Check for unimplemented functions that are never implemented Consider public state variables as implemented functions Do not consider fallback function or constructor """ - +from typing import List, Set from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract +from slither.utils.output import Output + # Since 0.5.1, Solidity allows creating state variable matching a function signature. older_solc_versions = ["0.5.0"] + ["0.4." + str(x) for x in range(0, 27)] @@ -55,10 +59,10 @@ All unimplemented functions must be implemented on a contract that is meant to b WIKI_RECOMMENDATION = "Implement all unimplemented functions in any contract you intend to use directly (not simply inherit from)." @staticmethod - def _match_state_variable(contract, f): + def _match_state_variable(contract: Contract, f: FunctionContract) -> bool: return any(s.full_name == f.full_name for s in contract.state_variables) - def _detect_unimplemented_function(self, contract): + def _detect_unimplemented_function(self, contract: Contract) -> Set[FunctionContract]: """ Detects any function definitions which are not implemented in the given contract. :param contract: The contract to search unimplemented functions for. @@ -87,7 +91,7 @@ All unimplemented functions must be implemented on a contract that is meant to b unimplemented.add(f) return unimplemented - def _detect(self): + def _detect(self) -> List[Output]: """Detect unimplemented functions Recursively visit the calls diff --git a/slither/detectors/naming_convention/naming_convention.py b/slither/detectors/naming_convention/naming_convention.py index 94fd2e8f9..96d3964fa 100644 --- a/slither/detectors/naming_convention/naming_convention.py +++ b/slither/detectors/naming_convention/naming_convention.py @@ -1,6 +1,8 @@ import re +from typing import List from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.formatters.naming_convention.naming_convention import custom_format +from slither.utils.output import Output class NamingConvention(AbstractDetector): @@ -36,28 +38,29 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2 STANDARD_JSON = False @staticmethod - def is_cap_words(name): + def is_cap_words(name: str) -> bool: return re.search("^[A-Z]([A-Za-z0-9]+)?_?$", name) is not None @staticmethod - def is_mixed_case(name): + def is_mixed_case(name: str) -> bool: return re.search("^[a-z]([A-Za-z0-9]+)?_?$", name) is not None @staticmethod - def is_mixed_case_with_underscore(name): + def is_mixed_case_with_underscore(name: str) -> bool: # Allow _ at the beginning to represent private variable # or unused parameters return re.search("^[_]?[a-z]([A-Za-z0-9]+)?_?$", name) is not None @staticmethod - def is_upper_case_with_underscores(name): + def is_upper_case_with_underscores(name: str) -> bool: return re.search("^[A-Z0-9_]+_?$", name) is not None @staticmethod - def should_avoid_name(name): + def should_avoid_name(name: str) -> bool: return re.search("^[lOI]$", name) is not None - def _detect(self): # pylint: disable=too-many-branches,too-many-statements + # pylint: disable=too-many-branches,too-many-statements + def _detect(self) -> List[Output]: results = [] for contract in self.contracts: diff --git a/slither/detectors/operations/block_timestamp.py b/slither/detectors/operations/block_timestamp.py index 01941257d..b80c8c392 100644 --- a/slither/detectors/operations/block_timestamp.py +++ b/slither/detectors/operations/block_timestamp.py @@ -13,6 +13,7 @@ from slither.core.declarations.solidity_variables import ( ) from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import Binary, BinaryType +from slither.utils.output import Output def _timestamp(func: Function) -> List[Node]: @@ -69,7 +70,7 @@ class Timestamp(AbstractDetector): WIKI_EXPLOIT_SCENARIO = """"Bob's contract relies on `block.timestamp` for its randomness. Eve is a miner and manipulates `block.timestamp` to exploit Bob's contract.""" WIKI_RECOMMENDATION = "Avoid relying on `block.timestamp`." - def _detect(self): + def _detect(self) -> List[Output]: """""" results = [] diff --git a/slither/detectors/operations/low_level_calls.py b/slither/detectors/operations/low_level_calls.py index 7e0f45e34..1ea91c37a 100644 --- a/slither/detectors/operations/low_level_calls.py +++ b/slither/detectors/operations/low_level_calls.py @@ -1,9 +1,13 @@ """ Module detecting usage of low level calls """ - +from typing import List, Tuple from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import LowLevelCall +from slither.core.cfg.node import Node +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract +from slither.utils.output import Output class LowLevelCalls(AbstractDetector): @@ -23,7 +27,7 @@ class LowLevelCalls(AbstractDetector): WIKI_RECOMMENDATION = "Avoid low-level calls. Check the call success. If the call is meant for a contract, check for code existence." @staticmethod - def _contains_low_level_calls(node): + def _contains_low_level_calls(node: Node) -> bool: """ Check if the node contains Low Level Calls Returns: @@ -31,7 +35,9 @@ class LowLevelCalls(AbstractDetector): """ return any(isinstance(ir, LowLevelCall) for ir in node.irs) - def detect_low_level_calls(self, contract): + def detect_low_level_calls( + self, contract: Contract + ) -> List[Tuple[FunctionContract, List[Node]]]: ret = [] for f in [f for f in contract.functions if contract == f.contract_declarer]: nodes = f.nodes @@ -40,7 +46,7 @@ class LowLevelCalls(AbstractDetector): ret.append((f, assembly_nodes)) return ret - def _detect(self): + def _detect(self) -> List[Output]: """Detect the functions that use low level calls""" results = [] for c in self.contracts: diff --git a/slither/detectors/operations/missing_events_access_control.py b/slither/detectors/operations/missing_events_access_control.py index 7437b32ab..20c229759 100644 --- a/slither/detectors/operations/missing_events_access_control.py +++ b/slither/detectors/operations/missing_events_access_control.py @@ -2,11 +2,18 @@ Module detecting missing events for critical contract parameters set by owners and used in access control """ +from typing import List, Tuple -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.analyses.data_dependency.data_dependency import is_tainted -from slither.slithir.operations.event_call import EventCall +from slither.core.cfg.node import Node +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract +from slither.core.declarations.modifier import Modifier from slither.core.solidity_types.elementary_type import ElementaryType +from slither.core.variables.state_variable import StateVariable +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.slithir.operations.event_call import EventCall +from slither.utils.output import Output class MissingEventsAccessControl(AbstractDetector): @@ -45,7 +52,9 @@ contract C { WIKI_RECOMMENDATION = "Emit an event for critical parameter changes." @staticmethod - def _detect_missing_events(contract): + def _detect_missing_events( + contract: Contract, + ) -> List[Tuple[FunctionContract, List[Tuple[Node, StateVariable, Modifier]]]]: """ Detects if critical contract parameters set by owners and used in access control are missing events :param contract: The contract to check @@ -80,7 +89,7 @@ contract C { results.append((function, nodes)) return results - def _detect(self): + def _detect(self) -> List[Output]: """Detect missing events for critical contract parameters set by owners and used in access control Returns: list: {'(function, node)'} diff --git a/slither/detectors/operations/missing_events_arithmetic.py b/slither/detectors/operations/missing_events_arithmetic.py index 340f471a5..6e1d5fbb5 100644 --- a/slither/detectors/operations/missing_events_arithmetic.py +++ b/slither/detectors/operations/missing_events_arithmetic.py @@ -2,11 +2,17 @@ Module detecting missing events for critical contract parameters set by owners and used in arithmetic """ +from typing import List, Tuple -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.analyses.data_dependency.data_dependency import is_tainted -from slither.slithir.operations.event_call import EventCall +from slither.core.cfg.node import Node +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types.elementary_type import ElementaryType, Int, Uint +from slither.core.variables.state_variable import StateVariable +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.slithir.operations.event_call import EventCall +from slither.utils.output import Output class MissingEventsArithmetic(AbstractDetector): @@ -49,7 +55,9 @@ contract C { WIKI_RECOMMENDATION = "Emit an event for critical parameter changes." @staticmethod - def _detect_unprotected_use(contract, sv): + def _detect_unprotected_use( + contract: Contract, sv: StateVariable + ) -> List[Tuple[Node, FunctionContract]]: unprotected_functions = [ function for function in contract.functions_declared if not function.is_protected() ] @@ -60,7 +68,9 @@ contract C { if sv in node.state_variables_read ] - def _detect_missing_events(self, contract): + def _detect_missing_events( + self, contract: Contract + ) -> List[Tuple[FunctionContract, List[Tuple[Node, List[Tuple[Node, FunctionContract]]]]]]: """ Detects if critical contract parameters set by owners and used in arithmetic are missing events :param contract: The contract to check @@ -101,7 +111,7 @@ contract C { results.append((function, nodes)) return results - def _detect(self): + def _detect(self) -> List[Output]: """Detect missing events for critical contract parameters set by owners and used in arithmetic Returns: list: {'(function, node)'} diff --git a/slither/detectors/operations/missing_zero_address_validation.py b/slither/detectors/operations/missing_zero_address_validation.py index cb6bf7cdb..a6c8de9ff 100644 --- a/slither/detectors/operations/missing_zero_address_validation.py +++ b/slither/detectors/operations/missing_zero_address_validation.py @@ -3,12 +3,19 @@ Module detecting missing zero address validation """ from collections import defaultdict +from typing import DefaultDict, List, Tuple, Union -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.analyses.data_dependency.data_dependency import is_tainted +from slither.core.cfg.node import Node +from slither.core.declarations.contract import Contract +from slither.core.declarations.function import ModifierStatements +from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types.elementary_type import ElementaryType -from slither.slithir.operations import Send, Transfer, LowLevelCall +from slither.core.variables.local_variable import LocalVariable +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import Call +from slither.slithir.operations import Send, Transfer, LowLevelCall +from slither.utils.output import Output class MissingZeroAddressValidation(AbstractDetector): @@ -46,7 +53,9 @@ Bob calls `updateOwner` without specifying the `newOwner`, so Bob loses ownershi WIKI_RECOMMENDATION = "Check that the address is not zero." - def _zero_address_validation_in_modifier(self, var, modifier_exprs): + def _zero_address_validation_in_modifier( + self, var: LocalVariable, modifier_exprs: List[ModifierStatements] + ) -> bool: for mod in modifier_exprs: for node in mod.nodes: # Skip validation if the modifier's parameters contains more than one variable @@ -62,7 +71,9 @@ Bob calls `updateOwner` without specifying the `newOwner`, so Bob loses ownershi return True return False - def _zero_address_validation(self, var, node, explored): + def _zero_address_validation( + self, var: LocalVariable, node: Node, explored: List[Node] + ) -> bool: """ Detects (recursively) if var is (zero address) checked in the function node """ @@ -83,7 +94,9 @@ Bob calls `updateOwner` without specifying the `newOwner`, so Bob loses ownershi return True return False - def _detect_missing_zero_address_validation(self, contract): + def _detect_missing_zero_address_validation( + self, contract: Contract + ) -> List[Union[Tuple[FunctionContract, DefaultDict[LocalVariable, List[Node]]]]]: """ Detects if addresses are zero address validated before use. :param contract: The contract to check @@ -130,7 +143,7 @@ Bob calls `updateOwner` without specifying the `newOwner`, so Bob loses ownershi results.append((function, var_nodes)) return results - def _detect(self): + def _detect(self) -> List[Output]: """Detect if addresses are zero address validated before use. Returns: list: {'(function, node)'} diff --git a/slither/detectors/operations/unchecked_low_level_return_values.py b/slither/detectors/operations/unchecked_low_level_return_values.py index 5064ebca6..0537ebbf2 100644 --- a/slither/detectors/operations/unchecked_low_level_return_values.py +++ b/slither/detectors/operations/unchecked_low_level_return_values.py @@ -4,6 +4,7 @@ Module detecting unused return values from low level from slither.detectors.abstract_detector import DetectorClassification from slither.detectors.operations.unused_return_values import UnusedReturnValues from slither.slithir.operations import LowLevelCall +from slither.slithir.operations.operation import Operation class UncheckedLowLevel(UnusedReturnValues): @@ -37,5 +38,5 @@ If the low level is used to prevent blocking operations, consider logging failed WIKI_RECOMMENDATION = "Ensure that the return value of a low-level call is checked or logged." - def _is_instance(self, ir): # pylint: disable=no-self-use + def _is_instance(self, ir: Operation) -> bool: # pylint: disable=no-self-use return isinstance(ir, LowLevelCall) diff --git a/slither/detectors/operations/unchecked_send_return_value.py b/slither/detectors/operations/unchecked_send_return_value.py index 0c3ff0d30..e9b2dc322 100644 --- a/slither/detectors/operations/unchecked_send_return_value.py +++ b/slither/detectors/operations/unchecked_send_return_value.py @@ -5,6 +5,7 @@ Module detecting unused return values from send from slither.detectors.abstract_detector import DetectorClassification from slither.detectors.operations.unused_return_values import UnusedReturnValues from slither.slithir.operations import Send +from slither.slithir.operations.operation import Operation class UncheckedSend(UnusedReturnValues): @@ -38,5 +39,5 @@ If `send` is used to prevent blocking operations, consider logging the failed `s WIKI_RECOMMENDATION = "Ensure that the return value of `send` is checked or logged." - def _is_instance(self, ir): # pylint: disable=no-self-use + def _is_instance(self, ir: Operation) -> bool: # pylint: disable=no-self-use return isinstance(ir, Send) diff --git a/slither/detectors/operations/unchecked_transfer.py b/slither/detectors/operations/unchecked_transfer.py index df5e8464c..224a7cda0 100644 --- a/slither/detectors/operations/unchecked_transfer.py +++ b/slither/detectors/operations/unchecked_transfer.py @@ -6,6 +6,7 @@ from slither.core.declarations import Function from slither.detectors.abstract_detector import DetectorClassification from slither.detectors.operations.unused_return_values import UnusedReturnValues from slither.slithir.operations import HighLevelCall +from slither.slithir.operations.operation import Operation class UncheckedTransfer(UnusedReturnValues): @@ -45,7 +46,7 @@ Several tokens do not revert in case of failure and return false. If one of thes "Use `SafeERC20`, or ensure that the transfer/transferFrom return value is checked." ) - def _is_instance(self, ir): # pylint: disable=no-self-use + def _is_instance(self, ir: Operation) -> bool: # pylint: disable=no-self-use return ( isinstance(ir, HighLevelCall) and isinstance(ir.function, Function) diff --git a/slither/detectors/operations/unused_return_values.py b/slither/detectors/operations/unused_return_values.py index ff3e7139d..7edde20fc 100644 --- a/slither/detectors/operations/unused_return_values.py +++ b/slither/detectors/operations/unused_return_values.py @@ -1,11 +1,16 @@ """ Module detecting unused return values from external calls """ +from typing import List +from slither.core.cfg.node import Node +from slither.core.declarations import Function +from slither.core.declarations.function_contract import FunctionContract from slither.core.variables.state_variable import StateVariable from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import HighLevelCall -from slither.core.declarations import Function +from slither.slithir.operations.operation import Operation +from slither.utils.output import Output class UnusedReturnValues(AbstractDetector): @@ -40,7 +45,7 @@ contract MyConc{ WIKI_RECOMMENDATION = "Ensure that all the return values of the function calls are used." - def _is_instance(self, ir): # pylint: disable=no-self-use + def _is_instance(self, ir: Operation) -> bool: # pylint: disable=no-self-use return isinstance(ir, HighLevelCall) and ( ( isinstance(ir.function, Function) @@ -50,7 +55,9 @@ contract MyConc{ or not isinstance(ir.function, Function) ) - def detect_unused_return_values(self, f): # pylint: disable=no-self-use + def detect_unused_return_values( + self, f: FunctionContract + ) -> List[Node]: # pylint: disable=no-self-use """ Return the nodes where the return value of a call is unused Args: @@ -73,7 +80,7 @@ contract MyConc{ return [nodes_origin[value].node for value in values_returned] - def _detect(self): + def _detect(self) -> List[Output]: """Detect high level calls which return a value that are never used""" results = [] for c in self.compilation_unit.contracts: diff --git a/slither/detectors/operations/void_constructor.py b/slither/detectors/operations/void_constructor.py index ca010b24d..fb44ea98c 100644 --- a/slither/detectors/operations/void_constructor.py +++ b/slither/detectors/operations/void_constructor.py @@ -1,5 +1,8 @@ +from typing import List + from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import Nop +from slither.utils.output import Output class VoidConstructor(AbstractDetector): @@ -26,7 +29,7 @@ contract B is A{ When reading `B`'s constructor definition, we might assume that `A()` initiates the contract, but no code is executed.""" # endregion wiki_exploit_scenario - def _detect(self): + def _detect(self) -> List[Output]: """""" results = [] for c in self.contracts: diff --git a/slither/detectors/reentrancy/reentrancy_benign.py b/slither/detectors/reentrancy/reentrancy_benign.py index c3eeb5357..25fe0ff03 100644 --- a/slither/detectors/reentrancy/reentrancy_benign.py +++ b/slither/detectors/reentrancy/reentrancy_benign.py @@ -5,10 +5,11 @@ Iterate over all the nodes of the graph until reaching a fixpoint """ from collections import namedtuple, defaultdict -from typing import List +from typing import DefaultDict, Set, List from slither.detectors.abstract_detector import DetectorClassification -from .reentrancy import Reentrancy, to_hashable +from slither.detectors.reentrancy.reentrancy import Reentrancy, to_hashable +from slither.utils.output import Output FindingKey = namedtuple("FindingKey", ["function", "calls", "send_eth"]) FindingValue = namedtuple("FindingValue", ["variable", "node", "nodes"]) @@ -50,7 +51,7 @@ Only report reentrancy that acts as a double call (see `reentrancy-eth`, `reentr STANDARD_JSON = False - def find_reentrancies(self): + def find_reentrancies(self) -> DefaultDict[FindingKey, Set[FindingValue]]: result = defaultdict(set) for contract in self.contracts: for f in contract.functions_and_modifiers_declared: @@ -87,7 +88,7 @@ Only report reentrancy that acts as a double call (see `reentrancy-eth`, `reentr result[finding_key] |= not_read_then_written return result - def _detect(self): # pylint: disable=too-many-branches + def _detect(self) -> List[Output]: # pylint: disable=too-many-branches """""" super()._detect() diff --git a/slither/detectors/reentrancy/reentrancy_events.py b/slither/detectors/reentrancy/reentrancy_events.py index 78a6058be..2d29442f7 100644 --- a/slither/detectors/reentrancy/reentrancy_events.py +++ b/slither/detectors/reentrancy/reentrancy_events.py @@ -5,9 +5,11 @@ Iterate over all the nodes of the graph until reaching a fixpoint """ from collections import namedtuple, defaultdict +from typing import DefaultDict, List, Set from slither.detectors.abstract_detector import DetectorClassification -from .reentrancy import Reentrancy, to_hashable +from slither.detectors.reentrancy.reentrancy import Reentrancy, to_hashable +from slither.utils.output import Output FindingKey = namedtuple("FindingKey", ["function", "calls", "send_eth"]) FindingValue = namedtuple("FindingValue", ["variable", "node", "nodes"]) @@ -48,7 +50,7 @@ If `d.()` re-enters, the `Counter` events will be shown in an incorrect order, w STANDARD_JSON = False - def find_reentrancies(self): + def find_reentrancies(self) -> DefaultDict[FindingKey, Set[FindingValue]]: result = defaultdict(set) for contract in self.contracts: for f in contract.functions_and_modifiers_declared: @@ -80,7 +82,7 @@ If `d.()` re-enters, the `Counter` events will be shown in an incorrect order, w result[finding_key] |= finding_vars return result - def _detect(self): # pylint: disable=too-many-branches + def _detect(self) -> List[Output]: # pylint: disable=too-many-branches """""" super()._detect() diff --git a/slither/detectors/reentrancy/reentrancy_no_gas.py b/slither/detectors/reentrancy/reentrancy_no_gas.py index 29d3b881a..c559d76df 100644 --- a/slither/detectors/reentrancy/reentrancy_no_gas.py +++ b/slither/detectors/reentrancy/reentrancy_no_gas.py @@ -5,11 +5,16 @@ Iterate over all the nodes of the graph until reaching a fixpoint """ from collections import namedtuple, defaultdict +from typing import DefaultDict, List, Union, Set from slither.core.variables.variable import Variable from slither.detectors.abstract_detector import DetectorClassification +from slither.detectors.reentrancy.reentrancy import Reentrancy, to_hashable from slither.slithir.operations import Send, Transfer, EventCall -from .reentrancy import Reentrancy, to_hashable +from slither.slithir.operations.high_level_call import HighLevelCall +from slither.slithir.operations.member import Member +from slither.slithir.operations.return_operation import Return +from slither.utils.output import Output FindingKey = namedtuple("FindingKey", ["function", "calls", "send_eth"]) FindingValue = namedtuple("FindingValue", ["variable", "node", "nodes"]) @@ -50,7 +55,7 @@ Only report reentrancy that is based on `transfer` or `send`.""" WIKI_RECOMMENDATION = "Apply the [`check-effects-interactions` pattern](http://solidity.readthedocs.io/en/v0.4.21/security-considerations.html#re-entrancy)." @staticmethod - def can_callback(ir): + def can_callback(ir: Union[Member, Return, HighLevelCall]) -> bool: """ Same as Reentrancy, but also consider Send and Transfer @@ -59,8 +64,8 @@ Only report reentrancy that is based on `transfer` or `send`.""" STANDARD_JSON = False - def find_reentrancies(self): - result = defaultdict(set) + def find_reentrancies(self) -> DefaultDict[FindingKey, Set[FindingValue]]: + result: DefaultDict[FindingKey, Set[FindingValue]] = defaultdict(set) for contract in self.contracts: for f in contract.functions_and_modifiers_declared: for node in f.nodes: @@ -97,7 +102,7 @@ Only report reentrancy that is based on `transfer` or `send`.""" result[finding_key] |= finding_vars return result - def _detect(self): # pylint: disable=too-many-branches,too-many-locals + def _detect(self) -> List[Output]: # pylint: disable=too-many-branches,too-many-locals """""" super()._detect() diff --git a/slither/detectors/reentrancy/token.py b/slither/detectors/reentrancy/token.py index 9f9ba97f4..c960bffa7 100644 --- a/slither/detectors/reentrancy/token.py +++ b/slither/detectors/reentrancy/token.py @@ -6,6 +6,7 @@ from slither.core.cfg.node import Node from slither.core.declarations import Function, Contract, SolidityVariableComposed from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import LowLevelCall, HighLevelCall +from slither.utils.output import Output def _detect_token_reentrant(contract: Contract) -> Dict[Function, List[Node]]: @@ -82,7 +83,7 @@ contract MyDefi{ If you do, ensure your users are aware of the potential issues.""" # endregion wiki_recommendation - def _detect(self): + def _detect(self) -> List[Output]: results = [] for contract in self.compilation_unit.contracts_derived: vulns = _detect_token_reentrant(contract) diff --git a/slither/detectors/shadowing/builtin_symbols.py b/slither/detectors/shadowing/builtin_symbols.py index 5001a108f..b0a44c8e2 100644 --- a/slither/detectors/shadowing/builtin_symbols.py +++ b/slither/detectors/shadowing/builtin_symbols.py @@ -1,8 +1,16 @@ """ Module detecting reserved keyword shadowing """ - +from typing import List, Tuple, Union, Optional + +from slither.core.declarations import Function, Event +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract +from slither.core.declarations.modifier import Modifier +from slither.core.variables import Variable +from slither.core.variables.local_variable import LocalVariable from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class BuiltinSymbolShadowing(AbstractDetector): @@ -114,7 +122,7 @@ contract Bug { "unchecked", ] - def is_builtin_symbol(self, word): + def is_builtin_symbol(self, word: Optional[str]) -> bool: """Detects if a given word is a built-in symbol. Returns: @@ -122,7 +130,9 @@ contract Bug { return word in self.BUILTIN_SYMBOLS or word in self.RESERVED_KEYWORDS - def detect_builtin_shadowing_locals(self, function_or_modifier): + def detect_builtin_shadowing_locals( + self, function_or_modifier: Union[Modifier, FunctionContract] + ) -> List[Tuple[str, LocalVariable]]: """Detects if local variables in a given function/modifier are named after built-in symbols. Any such items are returned in a list. @@ -135,14 +145,16 @@ contract Bug { results.append((self.SHADOWING_LOCAL_VARIABLE, local)) return results - def detect_builtin_shadowing_definitions(self, contract): + def detect_builtin_shadowing_definitions( + self, contract: Contract + ) -> List[Tuple[str, Union[Function, Variable, Event]]]: """Detects if functions, access modifiers, events, state variables, or local variables are named after built-in symbols. Any such definitions are returned in a list. Returns: list of tuple: (type, definition, [local variable parent])""" - result = [] + result: List[Tuple[str, Union[Function, Variable, Event]]] = [] # Loop through all functions, modifiers, variables (state and local) to detect any built-in symbol keywords. for function in contract.functions_declared: @@ -164,7 +176,7 @@ contract Bug { return result - def _detect(self): + def _detect(self) -> List[Output]: """Detect shadowing of built-in symbols Recursively visit the calls diff --git a/slither/detectors/shadowing/local.py b/slither/detectors/shadowing/local.py index 617c24be4..ad65b62d9 100644 --- a/slither/detectors/shadowing/local.py +++ b/slither/detectors/shadowing/local.py @@ -1,8 +1,16 @@ """ Module detecting local variable shadowing """ - +from typing import List, Tuple, Union + +from slither.core.declarations.contract import Contract +from slither.core.declarations.event import Event +from slither.core.declarations.function_contract import FunctionContract +from slither.core.declarations.modifier import Modifier +from slither.core.variables.local_variable import LocalVariable +from slither.core.variables.state_variable import StateVariable from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class LocalShadowing(AbstractDetector): @@ -50,13 +58,30 @@ contract Bug { OVERSHADOWED_STATE_VARIABLE = "state variable" OVERSHADOWED_EVENT = "event" - def detect_shadowing_definitions(self, contract): # pylint: disable=too-many-branches + # pylint: disable=too-many-branches + def detect_shadowing_definitions( + self, contract: Contract + ) -> List[ + Union[ + Tuple[LocalVariable, List[Tuple[str, StateVariable]]], + Tuple[LocalVariable, List[Tuple[str, FunctionContract]]], + Tuple[LocalVariable, List[Tuple[str, Modifier]]], + Tuple[LocalVariable, List[Tuple[str, Event]]], + ] + ]: """Detects if functions, access modifiers, events, state variables, and local variables are named after reserved keywords. Any such definitions are returned in a list. Returns: list of tuple: (type, contract name, definition)""" - result = [] + result: List[ + Union[ + Tuple[LocalVariable, List[Tuple[str, StateVariable]]], + Tuple[LocalVariable, List[Tuple[str, FunctionContract]]], + Tuple[LocalVariable, List[Tuple[str, Modifier]]], + Tuple[LocalVariable, List[Tuple[str, Event]]], + ] + ] = [] # Loop through all functions + modifiers in this contract. for function in contract.functions + contract.modifiers: @@ -93,7 +118,7 @@ contract Bug { return result - def _detect(self): + def _detect(self) -> List[Output]: """Detect shadowing local variables Recursively visit the calls diff --git a/slither/detectors/shadowing/state.py b/slither/detectors/shadowing/state.py index 766c2437d..801c370a5 100644 --- a/slither/detectors/shadowing/state.py +++ b/slither/detectors/shadowing/state.py @@ -2,12 +2,16 @@ Module detecting shadowing of state variables """ -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from typing import List + from slither.core.declarations import Contract -from .common import is_upgradable_gap_variable +from slither.core.variables.state_variable import StateVariable +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.shadowing.common import is_upgradable_gap_variable +from slither.utils.output import Output -def detect_shadowing(contract: Contract): +def detect_shadowing(contract: Contract) -> List[List[StateVariable]]: ret = [] variables_fathers = [] for father in contract.inheritance: @@ -70,7 +74,7 @@ contract DerivedContract is BaseContract{ WIKI_RECOMMENDATION = "Remove the state variable shadowing." - def _detect(self): + def _detect(self) -> List[Output]: """Detect shadowing Recursively visit the calls diff --git a/slither/detectors/slither/name_reused.py b/slither/detectors/slither/name_reused.py index 2cd10ed31..f6f2820fa 100644 --- a/slither/detectors/slither/name_reused.py +++ b/slither/detectors/slither/name_reused.py @@ -1,10 +1,12 @@ from collections import defaultdict +from typing import Any, List from slither.core.compilation_unit import SlitherCompilationUnit from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output -def _find_missing_inheritance(compilation_unit: SlitherCompilationUnit): +def _find_missing_inheritance(compilation_unit: SlitherCompilationUnit) -> List[Any]: """ Filter contracts with missing inheritance to return only the "most base" contracts in the inheritance tree. @@ -50,7 +52,8 @@ As a result, the second contract cannot be analyzed. WIKI_RECOMMENDATION = "Rename the contract." - def _detect(self): # pylint: disable=too-many-locals,too-many-branches + # pylint: disable=too-many-locals,too-many-branches + def _detect(self) -> List[Output]: results = [] compilation_unit = self.compilation_unit diff --git a/slither/detectors/source/rtlo.py b/slither/detectors/source/rtlo.py index ed73fdd4d..f89eb70eb 100644 --- a/slither/detectors/source/rtlo.py +++ b/slither/detectors/source/rtlo.py @@ -1,5 +1,9 @@ import re +from typing import List + from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output + # pylint: disable=bidirectional-unicode class RightToLeftOverride(AbstractDetector): @@ -52,7 +56,7 @@ contract Token RTLO_CHARACTER_ENCODED = "\u202e".encode("utf-8") STANDARD_JSON = False - def _detect(self): + def _detect(self) -> List[Output]: results = [] pattern = re.compile(".*\u202e.*".encode("utf-8")) diff --git a/slither/detectors/statements/array_length_assignment.py b/slither/detectors/statements/array_length_assignment.py index 7f875fa9e..51302a2c9 100644 --- a/slither/detectors/statements/array_length_assignment.py +++ b/slither/detectors/statements/array_length_assignment.py @@ -1,21 +1,23 @@ """ Module detecting assignment of array length """ - +from typing import List, Set from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, ALL_SOLC_VERSIONS_05, ) -from slither.core.cfg.node import NodeType +from slither.core.cfg.node import Node, NodeType from slither.slithir.operations import Assignment, Length from slither.slithir.variables.reference import ReferenceVariable from slither.slithir.operations.binary import Binary from slither.analyses.data_dependency.data_dependency import is_tainted +from slither.core.declarations.contract import Contract +from slither.utils.output import Output -def detect_array_length_assignment(contract): +def detect_array_length_assignment(contract: Contract) -> Set[Node]: """ Detects and returns all nodes which assign array length. :param contract: Contract to detect assignment within. @@ -110,7 +112,7 @@ Otherwise, thoroughly review the contract to ensure a user-controlled variable c VULNERABLE_SOLC_VERSIONS = ALL_SOLC_VERSIONS_04 + ALL_SOLC_VERSIONS_05 - def _detect(self): + def _detect(self) -> List[Output]: """ Detect array length assignments """ diff --git a/slither/detectors/statements/assembly.py b/slither/detectors/statements/assembly.py index 3a554e380..2c0c49f09 100644 --- a/slither/detectors/statements/assembly.py +++ b/slither/detectors/statements/assembly.py @@ -1,9 +1,13 @@ """ Module detecting usage of inline assembly """ +from typing import List, Tuple +from slither.core.cfg.node import Node, NodeType +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.core.cfg.node import NodeType +from slither.utils.output import Output class Assembly(AbstractDetector): @@ -23,7 +27,7 @@ class Assembly(AbstractDetector): WIKI_RECOMMENDATION = "Do not use `evm` assembly." @staticmethod - def _contains_inline_assembly_use(node): + def _contains_inline_assembly_use(node: Node) -> bool: """ Check if the node contains ASSEMBLY type Returns: @@ -31,7 +35,7 @@ class Assembly(AbstractDetector): """ return node.type == NodeType.ASSEMBLY - def detect_assembly(self, contract): + def detect_assembly(self, contract: Contract) -> List[Tuple[FunctionContract, List[Node]]]: ret = [] for f in contract.functions: if f.contract_declarer != contract: @@ -42,7 +46,7 @@ class Assembly(AbstractDetector): ret.append((f, assembly_nodes)) return ret - def _detect(self): + def _detect(self) -> List[Output]: """Detect the functions that use inline assembly""" results = [] for c in self.contracts: diff --git a/slither/detectors/statements/assert_state_change.py b/slither/detectors/statements/assert_state_change.py index 126f4b64d..c82919de6 100644 --- a/slither/detectors/statements/assert_state_change.py +++ b/slither/detectors/statements/assert_state_change.py @@ -1,11 +1,19 @@ """ Module detecting state changes in assert calls """ +from typing import List, Tuple + +from slither.core.cfg.node import Node +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations.internal_call import InternalCall +from slither.utils.output import Output -def detect_assert_state_change(contract): +def detect_assert_state_change( + contract: Contract, +) -> List[Tuple[FunctionContract, Node]]: """ Detects and returns all nodes with assert calls that change contract state from within the invariant :param contract: Contract to detect @@ -69,7 +77,7 @@ The assert in `bad()` increments the state variable `s_a` while checking for the WIKI_RECOMMENDATION = """Use `require` for invariants modifying the state.""" - def _detect(self): + def _detect(self) -> List[Output]: """ Detect assert calls that change state from within the invariant """ diff --git a/slither/detectors/statements/boolean_constant_equality.py b/slither/detectors/statements/boolean_constant_equality.py index eddea6236..5b91f364f 100644 --- a/slither/detectors/statements/boolean_constant_equality.py +++ b/slither/detectors/statements/boolean_constant_equality.py @@ -1,13 +1,18 @@ """ Module detecting misuse of Boolean constants """ +from typing import List, Set, Tuple +from slither.core.cfg.node import Node +from slither.core.declarations import Function +from slither.core.declarations.contract import Contract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import ( Binary, BinaryType, ) from slither.slithir.variables import Constant +from slither.utils.output import Output class BooleanEquality(AbstractDetector): @@ -44,10 +49,12 @@ Boolean constants can be used directly and do not need to be compare to `true` o WIKI_RECOMMENDATION = """Remove the equality to the boolean constant.""" @staticmethod - def _detect_boolean_equality(contract): + def _detect_boolean_equality( + contract: Contract, + ) -> List[Tuple[Function, Set[Node]]]: # Create our result set. - results = [] + results: List[Tuple[Function, Set[Node]]] = [] # Loop for each function and modifier. # pylint: disable=too-many-nested-blocks @@ -68,7 +75,7 @@ Boolean constants can be used directly and do not need to be compare to `true` o # Return the resulting set of nodes with improper uses of Boolean constants return results - def _detect(self): + def _detect(self) -> List[Output]: """ Detect Boolean constant misuses """ diff --git a/slither/detectors/statements/boolean_constant_misuse.py b/slither/detectors/statements/boolean_constant_misuse.py index 4e7a0d69d..96dd2012f 100644 --- a/slither/detectors/statements/boolean_constant_misuse.py +++ b/slither/detectors/statements/boolean_constant_misuse.py @@ -1,7 +1,11 @@ """ Module detecting misuse of Boolean constants """ -from slither.core.cfg.node import NodeType +from typing import List, Set, Tuple + +from slither.core.cfg.node import Node, NodeType +from slither.core.declarations import Function +from slither.core.declarations.contract import Contract from slither.core.solidity_types import ElementaryType from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import ( @@ -14,6 +18,7 @@ from slither.slithir.operations import ( Condition, ) from slither.slithir.variables import Constant +from slither.utils.output import Output class BooleanConstantMisuse(AbstractDetector): @@ -59,7 +64,9 @@ Other uses (in complex expressions, as conditionals) indicate either an error or WIKI_RECOMMENDATION = """Verify and simplify the condition.""" @staticmethod - def _detect_boolean_constant_misuses(contract): # pylint: disable=too-many-branches + def _detect_boolean_constant_misuses( + contract: Contract, + ) -> List[Tuple[Function, Set[Node]]]: # pylint: disable=too-many-branches """ Detects and returns all nodes which misuse a Boolean constant. :param contract: Contract to detect assignment within. @@ -67,7 +74,7 @@ Other uses (in complex expressions, as conditionals) indicate either an error or """ # Create our result set. - results = [] + results: List[Tuple[Function, Set[Node]]] = [] # Loop for each function and modifier. for function in contract.functions_declared: @@ -104,7 +111,7 @@ Other uses (in complex expressions, as conditionals) indicate either an error or # Return the resulting set of nodes with improper uses of Boolean constants return results - def _detect(self): + def _detect(self) -> List[Output]: """ Detect Boolean constant misuses """ diff --git a/slither/detectors/statements/controlled_delegatecall.py b/slither/detectors/statements/controlled_delegatecall.py index eeac55925..08280940d 100644 --- a/slither/detectors/statements/controlled_delegatecall.py +++ b/slither/detectors/statements/controlled_delegatecall.py @@ -1,9 +1,14 @@ +from typing import List + +from slither.analyses.data_dependency.data_dependency import is_tainted +from slither.core.cfg.node import Node +from slither.core.declarations.function_contract import FunctionContract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import LowLevelCall -from slither.analyses.data_dependency.data_dependency import is_tainted +from slither.utils.output import Output -def controlled_delegatecall(function): +def controlled_delegatecall(function: FunctionContract) -> List[Node]: ret = [] for node in function.nodes: for ir in node.irs: @@ -42,7 +47,7 @@ Bob calls `delegate` and delegates the execution to his malicious contract. As a WIKI_RECOMMENDATION = "Avoid using `delegatecall`. Use only trusted destinations." - def _detect(self): + def _detect(self) -> List[Output]: results = [] for contract in self.compilation_unit.contracts_derived: diff --git a/slither/detectors/statements/deprecated_calls.py b/slither/detectors/statements/deprecated_calls.py index f8fc50d32..3d0ca4ba9 100644 --- a/slither/detectors/statements/deprecated_calls.py +++ b/slither/detectors/statements/deprecated_calls.py @@ -1,14 +1,19 @@ """ Module detecting deprecated standards. """ +from typing import List, Tuple, Union -from slither.core.cfg.node import NodeType +from slither.core.cfg.node import Node, NodeType +from slither.core.declarations.contract import Contract from slither.core.declarations.solidity_variables import ( SolidityVariableComposed, SolidityFunction, ) +from slither.core.expressions.expression import Expression +from slither.core.variables import StateVariable from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import LowLevelCall +from slither.utils.output import Output from slither.visitors.expression.export_values import ExportValues @@ -73,7 +78,9 @@ contract ContractWithDeprecatedReferences { DEPRECATED_NODE_TYPES = [(NodeType.THROW, "throw", "revert()")] DEPRECATED_LOW_LEVEL_CALLS = [("callcode", "callcode", "delegatecall")] - def detect_deprecation_in_expression(self, expression): + def detect_deprecation_in_expression( + self, expression: Expression + ) -> List[Tuple[str, str, str]]: """Detects if an expression makes use of any deprecated standards. Returns: @@ -95,13 +102,15 @@ contract ContractWithDeprecatedReferences { return results - def detect_deprecated_references_in_node(self, node): + def detect_deprecated_references_in_node( + self, node: Node + ) -> List[Tuple[Union[str, NodeType], str, str]]: """Detects if a node makes use of any deprecated standards. Returns: list of tuple: (detecting_signature, original_text, recommended_text)""" # Define our results list - results = [] + results: List[Tuple[Union[str, NodeType], str, str]] = [] # If this node has an expression, we check the underlying expression. if node.expression: @@ -114,12 +123,24 @@ contract ContractWithDeprecatedReferences { return results - def detect_deprecated_references_in_contract(self, contract): + def detect_deprecated_references_in_contract( + self, contract: Contract + ) -> List[ + Union[ + Tuple[StateVariable, List[Tuple[str, str, str]]], + Tuple[Node, List[Tuple[Union[str, NodeType], str, str]]], + ] + ]: """Detects the usage of any deprecated built-in symbols. Returns: list of tuple: (state_variable | node, (detecting_signature, original_text, recommended_text))""" - results = [] + results: List[ + Union[ + Tuple[StateVariable, List[Tuple[str, str, str]]], + Tuple[Node, List[Tuple[Union[str, NodeType], str, str]]], + ] + ] = [] for state_variable in contract.state_variables_declared: if state_variable.expression: @@ -135,22 +156,22 @@ contract ContractWithDeprecatedReferences { # Loop through each node in this function. for node in function.nodes: # Detect deprecated references in the node. - deprecated_results = self.detect_deprecated_references_in_node(node) + deprecated_results_node = self.detect_deprecated_references_in_node(node) # Detect additional deprecated low-level-calls. for ir in node.irs: if isinstance(ir, LowLevelCall): for dep_llc in self.DEPRECATED_LOW_LEVEL_CALLS: if ir.function_name == dep_llc[0]: - deprecated_results.append(dep_llc) + deprecated_results_node.append(dep_llc) # If we have any results from this iteration, add them to our results list. - if deprecated_results: - results.append((node, deprecated_results)) + if deprecated_results_node: + results.append((node, deprecated_results_node)) return results - def _detect(self): + def _detect(self) -> List[Output]: """Detects if an expression makes use of any deprecated standards. Recursively visit the calls diff --git a/slither/detectors/statements/divide_before_multiply.py b/slither/detectors/statements/divide_before_multiply.py index 1b7c72197..a9de76b40 100644 --- a/slither/detectors/statements/divide_before_multiply.py +++ b/slither/detectors/statements/divide_before_multiply.py @@ -2,12 +2,18 @@ Module detecting possible loss of precision due to divide before multiple """ from collections import defaultdict +from typing import Any, DefaultDict, List, Set, Tuple + +from slither.core.cfg.node import Node +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.slithir.operations import Binary, Assignment, BinaryType, LibraryCall +from slither.slithir.operations import Binary, Assignment, BinaryType, LibraryCall, Operation from slither.slithir.variables import Constant +from slither.utils.output import Output -def is_division(ir): +def is_division(ir: Operation) -> bool: if isinstance(ir, Binary): if ir.type == BinaryType.DIVISION: return True @@ -23,7 +29,7 @@ def is_division(ir): return False -def is_multiplication(ir): +def is_multiplication(ir: Operation) -> bool: if isinstance(ir, Binary): if ir.type == BinaryType.MULTIPLICATION: return True @@ -39,7 +45,7 @@ def is_multiplication(ir): return False -def is_assert(node): +def is_assert(node: Node) -> bool: if node.contains_require_or_assert(): return True # Old Solidity code where using an internal 'assert(bool)' function @@ -50,7 +56,10 @@ def is_assert(node): return False -def _explore(to_explore, f_results, divisions): # pylint: disable=too-many-branches +# pylint: disable=too-many-branches +def _explore( + to_explore: Set[Node], f_results: List[Node], divisions: DefaultDict[Any, Any] +) -> None: explored = set() while to_explore: # pylint: disable=too-many-nested-blocks node = to_explore.pop() @@ -103,7 +112,9 @@ def _explore(to_explore, f_results, divisions): # pylint: disable=too-many-bran to_explore.add(son) -def detect_divide_before_multiply(contract): +def detect_divide_before_multiply( + contract: Contract, +) -> List[Tuple[FunctionContract, List[Node]]]: """ Detects and returns all nodes with multiplications of division results. :param contract: Contract to detect assignment within. @@ -169,7 +180,7 @@ In general, it's usually a good idea to re-arrange arithmetic to perform multipl WIKI_RECOMMENDATION = """Consider ordering multiplication before division.""" - def _detect(self): + def _detect(self) -> List[Output]: """ Detect divisions before multiplications """ diff --git a/slither/detectors/statements/incorrect_strict_equality.py b/slither/detectors/statements/incorrect_strict_equality.py index 9a2234f4f..bc7b0cebe 100644 --- a/slither/detectors/statements/incorrect_strict_equality.py +++ b/slither/detectors/statements/incorrect_strict_equality.py @@ -2,7 +2,7 @@ Module detecting dangerous strict equality """ - +from typing import Any, Dict, List, Union from slither.analyses.data_dependency.data_dependency import is_dependent_ssa from slither.core.declarations import Function from slither.core.declarations.function_top_level import FunctionTopLevel @@ -23,6 +23,14 @@ from slither.core.declarations.solidity_variables import ( SolidityVariableComposed, SolidityFunction, ) +from slither.core.cfg.node import Node +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract +from slither.slithir.operations.operation import Operation +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from slither.utils.output import Output class IncorrectStrictEquality(AbstractDetector): @@ -61,11 +69,25 @@ contract Crowdsale{ ] @staticmethod - def is_direct_comparison(ir): + def is_direct_comparison(ir: Operation) -> bool: return isinstance(ir, Binary) and ir.type == BinaryType.EQUAL @staticmethod - def is_any_tainted(variables, taints, function) -> bool: + def is_any_tainted( + variables: List[ + Union[ + Constant, + LocalIRVariable, + TemporaryVariableSSA, + SolidityVariableComposed, + SolidityVariable, + ] + ], + taints: List[ + Union[LocalIRVariable, SolidityVariable, SolidityVariableComposed, TemporaryVariableSSA] + ], + function: FunctionContract, + ) -> bool: return any( ( is_dependent_ssa(var, taint, function.contract) @@ -74,7 +96,9 @@ contract Crowdsale{ ) ) - def taint_balance_equalities(self, functions): + def taint_balance_equalities( + self, functions: List[Union[FunctionContract, Any]] + ) -> List[Union[LocalIRVariable, TemporaryVariableSSA, Any]]: taints = [] for func in functions: for node in func.nodes: @@ -105,7 +129,11 @@ contract Crowdsale{ return taints # Retrieve all tainted (node, function) pairs - def tainted_equality_nodes(self, funcs, taints): + def tainted_equality_nodes( + self, + funcs: List[Union[FunctionContract, Any]], + taints: List[Union[LocalIRVariable, TemporaryVariableSSA, Any]], + ) -> Dict[FunctionContract, List[Node]]: results = {} taints += self.sources_taint @@ -124,7 +152,7 @@ contract Crowdsale{ return results - def detect_strict_equality(self, contract): + def detect_strict_equality(self, contract: Contract) -> Dict[FunctionContract, List[Node]]: funcs = contract.all_functions_called + contract.modifiers # Taint all BALANCE accesses @@ -135,7 +163,7 @@ contract Crowdsale{ return results - def _detect(self): + def _detect(self) -> List[Output]: results = [] for c in self.compilation_unit.contracts_derived: diff --git a/slither/detectors/statements/mapping_deletion.py b/slither/detectors/statements/mapping_deletion.py index 2515eacf6..59882cc96 100644 --- a/slither/detectors/statements/mapping_deletion.py +++ b/slither/detectors/statements/mapping_deletion.py @@ -1,11 +1,16 @@ """ Detect deletion on structure containing a mapping """ +from typing import List, Tuple +from slither.core.cfg.node import Node from slither.core.declarations import Structure +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types import MappingType, UserDefinedType from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import Delete +from slither.utils.output import Output class MappingDeletionDetection(AbstractDetector): @@ -45,13 +50,15 @@ The mapping `balances` is never deleted, so `remove` does not work as intended." ) @staticmethod - def detect_mapping_deletion(contract): + def detect_mapping_deletion( + contract: Contract, + ) -> List[Tuple[FunctionContract, Structure, Node]]: """Detect deletion on structure containing a mapping Returns: list (function, structure, node) """ - ret = [] + ret: List[Tuple[FunctionContract, Structure, Node]] = [] # pylint: disable=too-many-nested-blocks for f in contract.functions: for node in f.nodes: @@ -66,7 +73,7 @@ The mapping `balances` is never deleted, so `remove` does not work as intended." ret.append((f, st, node)) return ret - def _detect(self): + def _detect(self) -> List[Output]: """Detect mapping deletion Returns: diff --git a/slither/detectors/statements/redundant_statements.py b/slither/detectors/statements/redundant_statements.py index 023d326ec..7e7223134 100644 --- a/slither/detectors/statements/redundant_statements.py +++ b/slither/detectors/statements/redundant_statements.py @@ -1,11 +1,14 @@ """ Module detecting redundant statements. """ +from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.core.cfg.node import NodeType +from slither.core.cfg.node import Node, NodeType +from slither.core.declarations.contract import Contract from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression from slither.core.expressions.identifier import Identifier +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class RedundantStatements(AbstractDetector): @@ -50,7 +53,7 @@ Each commented line references types/identifiers, but performs no action with th # This is a disallowed list of tuple (node.type, type(node.expression)) REDUNDANT_TOP_LEVEL_EXPRESSIONS = (ElementaryTypeNameExpression, Identifier) - def detect_redundant_statements_contract(self, contract): + def detect_redundant_statements_contract(self, contract: Contract) -> List[Node]: """Detects the usage of redundant statements in a contract. Returns: @@ -70,7 +73,7 @@ Each commented line references types/identifiers, but performs no action with th return results - def _detect(self): + def _detect(self) -> List[Output]: """Detect redundant statements Recursively visit the calls diff --git a/slither/detectors/statements/too_many_digits.py b/slither/detectors/statements/too_many_digits.py index 6e4cc4cd9..239efa4be 100644 --- a/slither/detectors/statements/too_many_digits.py +++ b/slither/detectors/statements/too_many_digits.py @@ -3,13 +3,18 @@ Module detecting numbers with too many digits. """ import re +from typing import List + +from slither.core.cfg.node import Node +from slither.core.declarations.function_contract import FunctionContract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.variables import Constant +from slither.utils.output import Output _HEX_ADDRESS_REGEXP = re.compile("(0[xX])?[0-9a-fA-F]{40}") -def is_hex_address(value) -> bool: +def is_hex_address(value: str) -> bool: """ Checks if the given string of text type is an address in hexadecimal encoded form. """ @@ -57,7 +62,7 @@ Use: # endregion wiki_recommendation @staticmethod - def _detect_too_many_digits(f): + def _detect_too_many_digits(f: FunctionContract) -> List[Node]: ret = [] for node in f.nodes: # each node contains a list of IR instruction @@ -73,7 +78,7 @@ Use: ret.append(node) return ret - def _detect(self): + def _detect(self) -> List[Output]: results = [] # iterate over all contracts diff --git a/slither/detectors/statements/tx_origin.py b/slither/detectors/statements/tx_origin.py index 485eeaf13..34f8173d5 100644 --- a/slither/detectors/statements/tx_origin.py +++ b/slither/detectors/statements/tx_origin.py @@ -1,8 +1,13 @@ """ Module detecting usage of `tx.origin` in a conditional node """ +from typing import List, Tuple +from slither.core.cfg.node import Node +from slither.core.declarations.contract import Contract +from slither.core.declarations.function_contract import FunctionContract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class TxOrigin(AbstractDetector): @@ -38,7 +43,7 @@ Bob is the owner of `TxOrigin`. Bob calls Eve's contract. Eve's contract calls ` WIKI_RECOMMENDATION = "Do not use `tx.origin` for authorization." @staticmethod - def _contains_incorrect_tx_origin_use(node): + def _contains_incorrect_tx_origin_use(node: Node) -> bool: """ Check if the node reads tx.origin and doesn't read msg.sender Avoid the FP due to (msg.sender == tx.origin) @@ -52,7 +57,7 @@ Bob is the owner of `TxOrigin`. Bob calls Eve's contract. Eve's contract calls ` ) return False - def detect_tx_origin(self, contract): + def detect_tx_origin(self, contract: Contract) -> List[Tuple[FunctionContract, List[Node]]]: ret = [] for f in contract.functions: @@ -67,7 +72,7 @@ Bob is the owner of `TxOrigin`. Bob calls Eve's contract. Eve's contract calls ` ret.append((f, bad_tx_nodes)) return ret - def _detect(self): + def _detect(self) -> List[Output]: """Detect the functions that use tx.origin in a conditional node""" results = [] for c in self.contracts: diff --git a/slither/detectors/statements/type_based_tautology.py b/slither/detectors/statements/type_based_tautology.py index 0129ad03f..9edb1f53e 100644 --- a/slither/detectors/statements/type_based_tautology.py +++ b/slither/detectors/statements/type_based_tautology.py @@ -1,14 +1,19 @@ """ Module detecting tautologies and contradictions based on types in comparison operations over integers """ +from typing import List, Set, Tuple +from slither.core.cfg.node import Node +from slither.core.declarations import Function +from slither.core.declarations.contract import Contract +from slither.core.solidity_types.elementary_type import Int, Uint from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import Binary, BinaryType from slither.slithir.variables import Constant -from slither.core.solidity_types.elementary_type import Int, Uint +from slither.utils.output import Output -def typeRange(t): +def typeRange(t: str) -> Tuple[int, int]: bits = int(t.split("int")[1]) if t in Uint: return 0, (2**bits) - 1 @@ -18,7 +23,7 @@ def typeRange(t): return None -def _detect_tautology_or_contradiction(low, high, cval, op): +def _detect_tautology_or_contradiction(low: int, high: int, cval: int, op: BinaryType) -> bool: """ Return true if "[low high] op cval " is always true or always false :param low: @@ -110,7 +115,7 @@ contract A { BinaryType.LESS_EQUAL: BinaryType.GREATER_EQUAL, } - def detect_type_based_tautologies(self, contract): + def detect_type_based_tautologies(self, contract: Contract) -> List[Tuple[Function, Set[Node]]]: """ Detects and returns all nodes with tautology/contradiction comparisons (based on type alone). :param contract: Contract to detect assignment within. @@ -118,7 +123,7 @@ contract A { """ # Create our result set. - results = [] + results: List[Tuple[Function, Set[Node]]] = [] allInts = Int + Uint # Loop for each function and modifier. @@ -151,7 +156,7 @@ contract A { # Return the resulting set of nodes with tautologies and contradictions return results - def _detect(self): + def _detect(self) -> List[Output]: """ Detect tautological (or contradictory) comparisons """ diff --git a/slither/detectors/statements/unary.py b/slither/detectors/statements/unary.py index 019c80e29..5bb8d9c3c 100644 --- a/slither/detectors/statements/unary.py +++ b/slither/detectors/statements/unary.py @@ -1,14 +1,17 @@ """ Module detecting the incorrect use of unary expressions """ +from typing import List +from slither.core.expressions.assignment_operation import AssignmentOperation +from slither.core.expressions.unary_operation import UnaryOperationType, UnaryOperation from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.unary_operation import UnaryOperationType, UnaryOperation class InvalidUnaryExpressionDetector(ExpressionVisitor): - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: if isinstance(expression.expression_right, UnaryOperation): if expression.expression_right.type == UnaryOperationType.PLUS_PRE: # This is defined in ExpressionVisitor but pylint @@ -18,7 +21,7 @@ class InvalidUnaryExpressionDetector(ExpressionVisitor): class InvalidUnaryStateVariableDetector(ExpressionVisitor): - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: UnaryOperation) -> None: if expression.type == UnaryOperationType.PLUS_PRE: # This is defined in ExpressionVisitor but pylint # Seems to think its not @@ -60,7 +63,7 @@ contract Bug{ WIKI_RECOMMENDATION = "Remove the unary expression." - def _detect(self): + def _detect(self) -> List[Output]: """ Detect the incorrect use of unary expressions """ diff --git a/slither/detectors/statements/unprotected_upgradeable.py b/slither/detectors/statements/unprotected_upgradeable.py index 25be6a5ae..1adf49540 100644 --- a/slither/detectors/statements/unprotected_upgradeable.py +++ b/slither/detectors/statements/unprotected_upgradeable.py @@ -4,6 +4,7 @@ from slither.core.declarations import SolidityFunction, Function from slither.core.declarations.contract import Contract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import LowLevelCall, SolidityCall +from slither.utils.output import Output def _can_be_destroyed(contract: Contract) -> List[Function]: @@ -87,7 +88,7 @@ Buggy is an upgradeable contract. Anyone can call initialize on the logic contra """Add a constructor to ensure `initialize` cannot be called on the logic contract.""" ) - def _detect(self): + def _detect(self) -> List[Output]: results = [] for contract in self.compilation_unit.contracts_derived: diff --git a/slither/detectors/statements/write_after_write.py b/slither/detectors/statements/write_after_write.py index ea654f79d..5b2e29925 100644 --- a/slither/detectors/statements/write_after_write.py +++ b/slither/detectors/statements/write_after_write.py @@ -15,9 +15,10 @@ from slither.slithir.operations import ( ) from slither.slithir.variables import ReferenceVariable, TemporaryVariable, TupleVariable from slither.slithir.variables.variable import SlithIRVariable +from slither.utils.output import Output -def _remove_states(written: Dict[Variable, Node]): +def _remove_states(written: Dict[Variable, Node]) -> None: for key in list(written.keys()): if isinstance(key, StateVariable): del written[key] @@ -27,7 +28,7 @@ def _handle_ir( ir: Operation, written: Dict[Variable, Node], ret: List[Tuple[Variable, Node, Node]], -): +) -> None: if isinstance(ir, (HighLevelCall, InternalDynamicCall, LowLevelCall)): _remove_states(written) @@ -73,7 +74,7 @@ def _detect_write_after_write( explored: Set[Node], written: Dict[Variable, Node], ret: List[Tuple[Variable, Node, Node]], -): +) -> None: if node in explored: return @@ -121,7 +122,7 @@ class WriteAfterWrite(AbstractDetector): WIKI_RECOMMENDATION = """Fix or remove the writes.""" - def _detect(self): + def _detect(self) -> List[Output]: results = [] for contract in self.compilation_unit.contracts_derived: diff --git a/slither/detectors/variables/function_init_state_variables.py b/slither/detectors/variables/function_init_state_variables.py index 081eb9f8b..e35cfe351 100644 --- a/slither/detectors/variables/function_init_state_variables.py +++ b/slither/detectors/variables/function_init_state_variables.py @@ -1,14 +1,17 @@ """ Module detecting state variables initializing from an immediate function call (prior to constructor run). """ +from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.visitors.expression.export_values import ExportValues +from slither.core.declarations.contract import Contract from slither.core.declarations.function import Function from slither.core.variables.state_variable import StateVariable +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output +from slither.visitors.expression.export_values import ExportValues -def detect_function_init_state_vars(contract): +def detect_function_init_state_vars(contract: Contract) -> List[StateVariable]: """ Detect any state variables that are initialized from an immediate function call (prior to constructor run). :param contract: The contract to detect state variable definitions for. @@ -87,7 +90,7 @@ Special care must be taken when initializing state variables from an immediate f WIKI_RECOMMENDATION = "Remove any initialization of state variables via non-constant state variables or function calls. If variables must be set upon contract deployment, locate initialization in the constructor instead." - def _detect(self): + def _detect(self) -> List[Output]: """ Detect state variables defined from an immediate function call (pre-contract deployment). diff --git a/slither/detectors/variables/predeclaration_usage_local.py b/slither/detectors/variables/predeclaration_usage_local.py index 8e36b19a3..2ba539a91 100644 --- a/slither/detectors/variables/predeclaration_usage_local.py +++ b/slither/detectors/variables/predeclaration_usage_local.py @@ -1,8 +1,14 @@ """ Module detecting any path leading to usage of a local variable before it is declared. """ +from typing import List, Set, Tuple +from slither.core.cfg.node import Node +from slither.core.declarations import Function +from slither.core.declarations.contract import Contract +from slither.core.variables.local_variable import LocalVariable from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class PredeclarationUsageLocal(AbstractDetector): @@ -48,7 +54,13 @@ Additionally, the for-loop uses the variable `max`, which is declared in a previ WIKI_RECOMMENDATION = "Move all variable declarations prior to any usage of the variable, and ensure that reaching a variable declaration does not depend on some conditional if it is used unconditionally." - def detect_predeclared_local_usage(self, node, results, already_declared, visited): + def detect_predeclared_local_usage( + self, + node: Node, + results: List[Tuple[Node, LocalVariable]], + already_declared: Set[LocalVariable], + visited: Set[Node], + ) -> None: """ Detects if a given node uses a variable prior to declaration in any code path. :param node: The node to initiate the scan from (searches recursively through all sons) @@ -87,7 +99,9 @@ Additionally, the for-loop uses the variable `max`, which is declared in a previ for son in node.sons: self.detect_predeclared_local_usage(son, results, already_declared, visited) - def detect_predeclared_in_contract(self, contract): + def detect_predeclared_in_contract( + self, contract: Contract + ) -> List[Tuple[Function, List[Tuple[Node, LocalVariable]]]]: """ Detects and returns all nodes in a contract which use a variable before it is declared. :param contract: Contract to detect pre-declaration usage of locals within. @@ -95,11 +109,11 @@ Additionally, the for-loop uses the variable `max`, which is declared in a previ """ # Create our result set. - results = [] + results: List[Tuple[Function, List[Tuple[Node, LocalVariable]]]] = [] # Loop for each function and modifier's nodes and analyze for predeclared local variable usage. for function in contract.functions_and_modifiers_declared: - predeclared_usage = [] + predeclared_usage: List[Tuple[Node, LocalVariable]] = [] if function.nodes: self.detect_predeclared_local_usage( function.nodes[0], @@ -113,7 +127,7 @@ Additionally, the for-loop uses the variable `max`, which is declared in a previ # Return the resulting set of nodes which set array length. return results - def _detect(self): + def _detect(self) -> List[Output]: """ Detect usage of a local variable before it is declared. """ diff --git a/slither/detectors/variables/similar_variables.py b/slither/detectors/variables/similar_variables.py index bab2d0acc..d0a15aaab 100644 --- a/slither/detectors/variables/similar_variables.py +++ b/slither/detectors/variables/similar_variables.py @@ -3,8 +3,12 @@ Check for state variables too similar Do not check contract inheritance """ import difflib +from typing import List, Set, Tuple +from slither.core.declarations.contract import Contract +from slither.core.variables.local_variable import LocalVariable from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class SimilarVarsDetection(AbstractDetector): @@ -27,7 +31,7 @@ class SimilarVarsDetection(AbstractDetector): WIKI_RECOMMENDATION = "Prevent variables from having similar names." @staticmethod - def similar(seq1, seq2): + def similar(seq1: str, seq2: str) -> bool: """Test the name similarity Two name are similar if difflib.SequenceMatcher on the lowercase @@ -46,7 +50,7 @@ class SimilarVarsDetection(AbstractDetector): return ret @staticmethod - def detect_sim(contract): + def detect_sim(contract: Contract) -> Set[Tuple[LocalVariable, LocalVariable]]: """Detect variables with similar name Returns: @@ -69,7 +73,7 @@ class SimilarVarsDetection(AbstractDetector): return set(ret) - def _detect(self): + def _detect(self) -> List[Output]: """Detect similar variables name Returns: diff --git a/slither/detectors/variables/unchanged_state_variables.py b/slither/detectors/variables/unchanged_state_variables.py index 0c6af80f7..f12cc5784 100644 --- a/slither/detectors/variables/unchanged_state_variables.py +++ b/slither/detectors/variables/unchanged_state_variables.py @@ -69,7 +69,7 @@ class UnchangedStateVariables: Find state variables that could be declared as constant or immutable (not written after deployment). """ - def __init__(self, compilation_unit: SlitherCompilationUnit): + def __init__(self, compilation_unit: SlitherCompilationUnit) -> None: self.compilation_unit = compilation_unit self._constant_candidates: List[StateVariable] = [] self._immutable_candidates: List[StateVariable] = [] @@ -84,7 +84,7 @@ class UnchangedStateVariables: """Return the constant candidates""" return self._constant_candidates - def detect(self): + def detect(self) -> None: """Detect state variables that could be constant or immutable""" for c in self.compilation_unit.contracts_derived: variables = [] diff --git a/slither/detectors/variables/uninitialized_local_variables.py b/slither/detectors/variables/uninitialized_local_variables.py index 7f7cb76e0..759691d50 100644 --- a/slither/detectors/variables/uninitialized_local_variables.py +++ b/slither/detectors/variables/uninitialized_local_variables.py @@ -4,8 +4,12 @@ Recursively explore the CFG to only report uninitialized local variables that are read before being written """ +from typing import List +from slither.core.cfg.node import Node +from slither.core.declarations.function_contract import FunctionContract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class UninitializedLocalVars(AbstractDetector): @@ -37,7 +41,9 @@ Bob calls `transfer`. As a result, all Ether is sent to the address `0x0` and is key = "UNINITIALIZEDLOCAL" - def _detect_uninitialized(self, function, node, visited): + def _detect_uninitialized( + self, function: FunctionContract, node: Node, visited: List[Node] + ) -> None: if node in visited: return @@ -73,7 +79,7 @@ Bob calls `transfer`. As a result, all Ether is sent to the address `0x0` and is for son in node.sons: self._detect_uninitialized(function, son, visited) - def _detect(self): + def _detect(self) -> List[Output]: """Detect uninitialized local variables Recursively visit the calls diff --git a/slither/detectors/variables/uninitialized_state_variables.py b/slither/detectors/variables/uninitialized_state_variables.py index baf1b2218..0fbb73b5d 100644 --- a/slither/detectors/variables/uninitialized_state_variables.py +++ b/slither/detectors/variables/uninitialized_state_variables.py @@ -8,10 +8,16 @@ Only analyze "leaf" contracts (contracts that are not inherited by another contract) """ +from typing import List, Tuple +from slither.core.declarations import Function +from slither.core.declarations.contract import Contract +from slither.core.variables import Variable +from slither.core.variables.state_variable import StateVariable from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import InternalCall, LibraryCall from slither.slithir.variables import ReferenceVariable +from slither.utils.output import Output class UninitializedStateVarsDetection(AbstractDetector): @@ -51,7 +57,7 @@ Initialize all the variables. If a variable is meant to be initialized to zero, # endregion wiki_recommendation @staticmethod - def _written_variables(contract): + def _written_variables(contract: Contract) -> List[StateVariable]: ret = [] # pylint: disable=too-many-nested-blocks for f in contract.all_functions_called + contract.modifiers: @@ -88,8 +94,8 @@ Initialize all the variables. If a variable is meant to be initialized to zero, self.__variables_written_in_proxy = list({v.name for v in variables_written_in_proxy}) return self.__variables_written_in_proxy - def _written_variables_in_proxy(self, contract): - variables = [] + def _written_variables_in_proxy(self, contract: Contract) -> List[StateVariable]: + variables: List[StateVariable] = [] if contract.is_upgradeable: variables_name_written_in_proxy = self._variable_written_in_proxy() if variables_name_written_in_proxy: @@ -97,18 +103,20 @@ Initialize all the variables. If a variable is meant to be initialized to zero, contract.get_state_variable_from_name(v) for v in variables_name_written_in_proxy ] - variables_in_contract = [v for v in variables_in_contract if v] - variables += variables_in_contract + variables += [v for v in variables_in_contract if v] return list(set(variables)) @staticmethod - def _read_variables(contract): + def _read_variables(contract: Contract) -> List[StateVariable]: ret = [] - for f in contract.all_functions_called + contract.modifiers: - ret += f.state_variables_read + for f in contract.all_functions_called: + if isinstance(f, Function): + ret += f.state_variables_read + for m in contract.modifiers: + ret += m.state_variables_read return ret - def _detect_uninitialized(self, contract): + def _detect_uninitialized(self, contract: Contract) -> List[Tuple[Variable, List[Function]]]: written_variables = self._written_variables(contract) written_variables += self._written_variables_in_proxy(contract) read_variables = self._read_variables(contract) @@ -120,7 +128,7 @@ Initialize all the variables. If a variable is meant to be initialized to zero, and variable in read_variables ] - def _detect(self): + def _detect(self) -> List[Output]: """Detect uninitialized state variables Recursively visit the calls diff --git a/slither/detectors/variables/uninitialized_storage_variables.py b/slither/detectors/variables/uninitialized_storage_variables.py index a0c35d80d..422996646 100644 --- a/slither/detectors/variables/uninitialized_storage_variables.py +++ b/slither/detectors/variables/uninitialized_storage_variables.py @@ -4,8 +4,12 @@ Recursively explore the CFG to only report uninitialized storage variables that are written before being read """ +from typing import List +from slither.core.cfg.node import Node +from slither.core.declarations.function_contract import FunctionContract from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.utils.output import Output class UninitializedStorageVars(AbstractDetector): @@ -45,7 +49,9 @@ Bob calls `func`. As a result, `owner` is overridden to `0`. # node.context[self.key] contains the uninitialized storage variables key = "UNINITIALIZEDSTORAGE" - def _detect_uninitialized(self, function, node, visited): + def _detect_uninitialized( + self, function: FunctionContract, node: Node, visited: List[Node] + ) -> None: if node in visited: return @@ -81,7 +87,7 @@ Bob calls `func`. As a result, `owner` is overridden to `0`. for son in node.sons: self._detect_uninitialized(function, son, visited) - def _detect(self): + def _detect(self) -> List[Output]: """Detect uninitialized storage variables Recursively visit the calls diff --git a/slither/detectors/variables/unused_state_variables.py b/slither/detectors/variables/unused_state_variables.py index 71cecbfbd..d542f67d3 100644 --- a/slither/detectors/variables/unused_state_variables.py +++ b/slither/detectors/variables/unused_state_variables.py @@ -1,15 +1,19 @@ """ Module detecting unused state variables """ +from typing import List, Optional + from slither.core.compilation_unit import SlitherCompilationUnit -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.core.declarations.contract import Contract from slither.core.solidity_types import ArrayType -from slither.visitors.expression.export_values import ExportValues from slither.core.variables.state_variable import StateVariable +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.formatters.variables.unused_state_variables import custom_format +from slither.utils.output import Output +from slither.visitors.expression.export_values import ExportValues -def detect_unused(contract): +def detect_unused(contract: Contract) -> Optional[List[StateVariable]]: if contract.is_signature_only(): return None # Get all the variables read in all the functions and modifiers @@ -54,7 +58,7 @@ class UnusedStateVars(AbstractDetector): WIKI_EXPLOIT_SCENARIO = "" WIKI_RECOMMENDATION = "Remove unused state variables." - def _detect(self): + def _detect(self) -> List[Output]: """Detect unused state variables""" results = [] for c in self.compilation_unit.contracts_derived: diff --git a/slither/detectors/variables/var_read_using_this.py b/slither/detectors/variables/var_read_using_this.py index 3d9f204c2..b224f8c17 100644 --- a/slither/detectors/variables/var_read_using_this.py +++ b/slither/detectors/variables/var_read_using_this.py @@ -1,8 +1,10 @@ from typing import List + from slither.core.cfg.node import Node from slither.core.declarations import Function, SolidityVariable from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations.high_level_call import HighLevelCall +from slither.utils.output import Output class VarReadUsingThis(AbstractDetector): @@ -28,7 +30,7 @@ contract C { WIKI_RECOMMENDATION = "Read the variable directly from storage instead of calling the contract." - def _detect(self): + def _detect(self) -> List[Output]: results = [] for c in self.contracts: for func in c.functions: diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index 95d113a84..166fa48f5 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -31,6 +31,7 @@ from slither.slithir.operations import ( from slither.slithir.operations.binary import Binary from slither.slithir.variables import Constant from slither.visitors.expression.constants_folding import ConstantFolding +from slither.utils.output import Output def _get_name(f: Union[Function, Variable]) -> str: @@ -168,7 +169,7 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n all_cst_used: List[ConstantValue], all_cst_used_in_binary: Dict[str, List[ConstantValue]], context_explored: Set[Node], -): +) -> None: for ir in irs: if isinstance(ir, Binary): for r in ir.read: @@ -364,7 +365,7 @@ class Echidna(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#echidna" - def output(self, filename): # pylint: disable=too-many-locals + def output(self, filename: str) -> Output: # pylint: disable=too-many-locals """ Output the inheritance relation diff --git a/slither/slither.py b/slither/slither.py index 227a37365..3e44944b3 100644 --- a/slither/slither.py +++ b/slither/slither.py @@ -11,6 +11,7 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi from slither.exceptions import SlitherError from slither.printers.abstract_printer import AbstractPrinter from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc +from slither.utils.output import Output logger = logging.getLogger("Slither") logging.basicConfig() @@ -48,7 +49,7 @@ def _update_file_scopes(candidates: ValuesView[FileScope]): class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes - def __init__(self, target: Union[str, CryticCompile], **kwargs): + def __init__(self, target: Union[str, CryticCompile], **kwargs) -> None: """ Args: target (str | CryticCompile) @@ -207,7 +208,7 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes self.write_results_to_hide() return results - def run_printers(self): + def run_printers(self) -> List[Output]: """ :return: List of registered printers outputs. """ @@ -215,5 +216,5 @@ class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes return [p.output(self._crytic_compile.target).data for p in self._printers] @property - def triage_mode(self): + def triage_mode(self) -> bool: return self._triage_mode diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index 89f85499c..87a6b075b 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import List, TYPE_CHECKING, Union, Optional +from typing import Any, List, TYPE_CHECKING, Union, Optional # pylint: disable= too-many-lines,import-outside-toplevel,too-many-branches,too-many-statements,too-many-nested-blocks from slither.core.declarations import ( @@ -34,7 +34,7 @@ from slither.core.solidity_types.elementary_type import ( MaxValues, ) from slither.core.solidity_types.type import Type -from slither.core.solidity_types.type_alias import TypeAlias +from slither.core.solidity_types.type_alias import TypeAliasTopLevel, TypeAlias from slither.core.variables.function_type_variable import FunctionTypeVariable from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable @@ -83,15 +83,36 @@ from slither.slithir.variables import TupleVariable from slither.utils.function import get_function_id from slither.utils.type import export_nested_types_from_variable from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR +import slither.core.declarations.contract +import slither.core.declarations.function +import slither.core.solidity_types.elementary_type +import slither.core.solidity_types.function_type +import slither.core.solidity_types.user_defined_type +import slither.slithir.operations.assignment +import slither.slithir.operations.binary +import slither.slithir.operations.call +import slither.slithir.operations.high_level_call +import slither.slithir.operations.index +import slither.slithir.operations.init_array +import slither.slithir.operations.internal_call +import slither.slithir.operations.length +import slither.slithir.operations.library_call +import slither.slithir.operations.low_level_call +import slither.slithir.operations.member +import slither.slithir.operations.operation +import slither.slithir.operations.send +import slither.slithir.operations.solidity_call +import slither.slithir.operations.transfer +import slither.slithir.variables.temporary +from slither.core.expressions.expression import Expression if TYPE_CHECKING: from slither.core.cfg.node import Node - from slither.core.compilation_unit import SlitherCompilationUnit logger = logging.getLogger("ConvertToIR") -def convert_expression(expression, node): +def convert_expression(expression: Expression, node: "Node") -> List[Any]: # handle standlone expression # such as return true; from slither.core.cfg.node import NodeType @@ -143,7 +164,7 @@ def convert_expression(expression, node): ################################################################################### -def is_value(ins): +def is_value(ins: Operation) -> bool: if isinstance(ins, TmpCall): if isinstance(ins.ori, Member): if ins.ori.variable_right == "value": @@ -151,7 +172,7 @@ def is_value(ins): return False -def is_gas(ins): +def is_gas(ins: Operation) -> bool: if isinstance(ins, TmpCall): if isinstance(ins.ori, Member): if ins.ori.variable_right == "gas": @@ -159,7 +180,7 @@ def is_gas(ins): return False -def _fits_under_integer(val: int, can_be_int: bool, can_be_uint) -> List[str]: +def _fits_under_integer(val: int, can_be_int: bool, can_be_uint: bool) -> List[str]: """ Return the list of uint/int that can contain val @@ -271,7 +292,7 @@ def _find_function_from_parameter( return None -def is_temporary(ins): +def is_temporary(ins: Operation) -> bool: return isinstance( ins, (Argument, TmpNewElementaryType, TmpNewContract, TmpNewArray, TmpNewStructure), @@ -300,7 +321,7 @@ def _make_function_type(func: Function) -> FunctionType: ################################################################################### -def integrate_value_gas(result): +def integrate_value_gas(result: List[Operation]) -> List[Operation]: """ Integrate value and gas temporary arguments to call instruction """ @@ -504,7 +525,9 @@ def _convert_type_contract(ir: Member) -> Assignment: raise SlithIRError(f"type({contract.name}).{ir.variable_right} is unknown") -def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals +def propagate_types( + ir: slither.slithir.operations.operation.Operation, node: "Node" +): # pylint: disable=too-many-locals # propagate the type node_function = node.function using_for = ( @@ -813,7 +836,10 @@ def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals return None -def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: disable=too-many-locals +# pylint: disable=too-many-locals +def extract_tmp_call( + ins: TmpCall, contract: Optional[Contract] +) -> slither.slithir.operations.call.Call: assert isinstance(ins, TmpCall) if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType): # If the call is made to a variable member, where the member is this @@ -1114,7 +1140,7 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: dis ################################################################################### -def can_be_low_level(ir): +def can_be_low_level(ir: HighLevelCall) -> bool: return ir.function_name in [ "transfer", "send", @@ -1125,7 +1151,9 @@ def can_be_low_level(ir): ] -def convert_to_low_level(ir): +def convert_to_low_level( + ir: HighLevelCall, +) -> Union[Send, LowLevelCall, Transfer,]: """ Convert to a transfer/send/or low level call The funciton assume to receive a correct IR @@ -1165,7 +1193,7 @@ def convert_to_low_level(ir): raise SlithIRError(f"Incorrect conversion to low level {ir}") -def can_be_solidity_func(ir) -> bool: +def can_be_solidity_func(ir: HighLevelCall) -> bool: if not isinstance(ir, HighLevelCall): return False return ir.destination.name == "abi" and ir.function_name in [ @@ -1178,7 +1206,9 @@ def can_be_solidity_func(ir) -> bool: ] -def convert_to_solidity_func(ir): +def convert_to_solidity_func( + ir: HighLevelCall, +) -> SolidityCall: """ Must be called after can_be_solidity_func :param ir: @@ -1214,7 +1244,9 @@ def convert_to_solidity_func(ir): return new_ir -def convert_to_push_expand_arr(ir, node, ret): +def convert_to_push_expand_arr( + ir: HighLevelCall, node: "Node", ret: List[Any] +) -> TemporaryVariable: arr = ir.destination length = ReferenceVariable(node) @@ -1249,7 +1281,18 @@ def convert_to_push_expand_arr(ir, node, ret): return length_val -def convert_to_push_set_val(ir, node, length_val, ret): +def convert_to_push_set_val( + ir: HighLevelCall, + node: "Node", + length_val: TemporaryVariable, + ret: List[ + Union[ + Length, + Assignment, + Binary, + ] + ], +) -> None: arr = ir.destination new_type = ir.destination.type.type @@ -1284,7 +1327,17 @@ def convert_to_push_set_val(ir, node, length_val, ret): ret.append(ir_assign_value) -def convert_to_push(ir, node): +def convert_to_push( + ir: slither.slithir.operations.high_level_call.HighLevelCall, node: "Node" +) -> List[ + Union[ + slither.slithir.operations.length.Length, + slither.slithir.operations.assignment.Assignment, + slither.slithir.operations.binary.Binary, + slither.slithir.operations.index.Index, + slither.slithir.operations.init_array.InitArray, + ] +]: """ Convert a call to a series of operations to push a new value onto the array @@ -1358,7 +1411,17 @@ def convert_to_pop(ir, node): return ret -def look_for_library_or_top_level(contract, ir, using_for, t): +def look_for_library_or_top_level( + contract: Contract, + ir: HighLevelCall, + using_for, + t: Union[ + UserDefinedType, + ElementaryType, + str, + TypeAliasTopLevel, + ], +) -> Optional[Union[LibraryCall, InternalCall,]]: for destination in using_for[t]: if isinstance(destination, FunctionTopLevel) and destination.name == ir.function_name: arguments = [ir.destination] + ir.arguments @@ -1403,7 +1466,9 @@ def look_for_library_or_top_level(contract, ir, using_for, t): return None -def convert_to_library_or_top_level(ir, node, using_for): +def convert_to_library_or_top_level( + ir: HighLevelCall, node: "Node", using_for +) -> Optional[Union[LibraryCall, InternalCall,]]: # We use contract_declarer, because Solidity resolve the library # before resolving the inheritance. # Though we could use .contract as libraries cannot be shadowed @@ -1422,7 +1487,12 @@ def convert_to_library_or_top_level(ir, node, using_for): return None -def get_type(t): +def get_type( + t: Union[ + UserDefinedType, + ElementaryType, + ] +) -> str: """ Convert a type to a str If the instance is a Contract, return 'address' instead @@ -1441,7 +1511,7 @@ def _can_be_implicitly_converted(source: str, target: str) -> bool: return source == target -def convert_type_library_call(ir: HighLevelCall, lib_contract: Contract): +def convert_type_library_call(ir: HighLevelCall, lib_contract: Contract) -> Optional[LibraryCall]: func = None candidates = [ f @@ -1652,7 +1722,7 @@ def convert_type_of_high_and_internal_level_call( ################################################################################### -def find_references_origin(irs): +def find_references_origin(irs: List[Operation]) -> None: """ Make lvalue of each Index, Member operation points to the left variable @@ -1689,7 +1759,7 @@ def remove_temporary(result): return result -def remove_unused(result): +def remove_unused(result: List[Operation]) -> List[Operation]: removed = True if not result: @@ -1736,7 +1806,7 @@ def remove_unused(result): ################################################################################### -def convert_constant_types(irs): +def convert_constant_types(irs: List[Operation]) -> None: """ late conversion of uint -> type for constant (Literal) :param irs: @@ -1812,7 +1882,7 @@ def convert_constant_types(irs): ################################################################################### -def convert_delete(irs): +def convert_delete(irs: List[Operation]) -> None: """ Convert the lvalue of the Delete to point to the variable removed This can only be done after find_references_origin is called @@ -1833,7 +1903,7 @@ def convert_delete(irs): ################################################################################### -def _find_source_mapping_references(irs: List[Operation]): +def _find_source_mapping_references(irs: List[Operation]) -> None: for ir in irs: if isinstance(ir, NewContract): @@ -1848,7 +1918,7 @@ def _find_source_mapping_references(irs: List[Operation]): ################################################################################### -def apply_ir_heuristics(irs: List[Operation], node: "Node"): +def apply_ir_heuristics(irs: List[Operation], node: "Node") -> List[Operation]: """ Apply a set of heuristic to improve slithIR """ diff --git a/slither/slithir/operations/assignment.py b/slither/slithir/operations/assignment.py index 3d05c3040..0ed5f70a4 100644 --- a/slither/slithir/operations/assignment.py +++ b/slither/slithir/operations/assignment.py @@ -1,15 +1,21 @@ import logging +from typing import List from slither.core.declarations.function import Function from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue from slither.slithir.variables import TupleVariable, ReferenceVariable +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.variable import Variable + logger = logging.getLogger("AssignmentOperationIR") class Assignment(OperationWithLValue): - def __init__(self, left_variable, right_variable, variable_return_type): + def __init__( + self, left_variable: Variable, right_variable: SourceMapping, variable_return_type + ) -> None: assert is_valid_lvalue(left_variable) assert is_valid_rvalue(right_variable) or isinstance( right_variable, (Function, TupleVariable) @@ -25,7 +31,7 @@ class Assignment(OperationWithLValue): return list(self._variables) @property - def read(self): + def read(self) -> List[SourceMapping]: return [self.rvalue] @property @@ -33,7 +39,7 @@ class Assignment(OperationWithLValue): return self._variable_return_type @property - def rvalue(self): + def rvalue(self) -> SourceMapping: return self._rvalue def __str__(self): diff --git a/slither/slithir/operations/binary.py b/slither/slithir/operations/binary.py index d416f3f90..ad65e3e75 100644 --- a/slither/slithir/operations/binary.py +++ b/slither/slithir/operations/binary.py @@ -1,4 +1,6 @@ import logging +from typing import List + from enum import Enum from slither.core.declarations import Function @@ -7,6 +9,9 @@ from slither.slithir.exceptions import SlithIRError from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue from slither.slithir.variables import ReferenceVariable +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.variable import Variable + logger = logging.getLogger("BinaryOperationIR") @@ -33,7 +38,7 @@ class BinaryType(Enum): OROR = "||" @staticmethod - def return_bool(operation_type): + def return_bool(operation_type: "BinaryType") -> bool: return operation_type in [ BinaryType.OROR, BinaryType.ANDAND, @@ -100,7 +105,13 @@ class BinaryType(Enum): class Binary(OperationWithLValue): - def __init__(self, result, left_variable, right_variable, operation_type: BinaryType): + def __init__( + self, + result: Variable, + left_variable: SourceMapping, + right_variable: Variable, + operation_type: BinaryType, + ) -> None: assert is_valid_rvalue(left_variable) or isinstance(left_variable, Function) assert is_valid_rvalue(right_variable) or isinstance(right_variable, Function) assert is_valid_lvalue(result) @@ -115,7 +126,7 @@ class Binary(OperationWithLValue): result.set_type(left_variable.type) @property - def read(self): + def read(self) -> List[SourceMapping]: return [self.variable_left, self.variable_right] @property @@ -123,15 +134,15 @@ class Binary(OperationWithLValue): return self._variables @property - def variable_left(self): + def variable_left(self) -> SourceMapping: return self._variables[0] @property - def variable_right(self): + def variable_right(self) -> Variable: return self._variables[1] @property - def type(self): + def type(self) -> BinaryType: return self._type @property diff --git a/slither/slithir/operations/codesize.py b/slither/slithir/operations/codesize.py index e7a910806..6640f4fd8 100644 --- a/slither/slithir/operations/codesize.py +++ b/slither/slithir/operations/codesize.py @@ -1,10 +1,19 @@ +from typing import List, Union from slither.core.solidity_types import ElementaryType from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.reference import ReferenceVariable +from slither.slithir.variables.reference_ssa import ReferenceVariableSSA class CodeSize(OperationWithLValue): - def __init__(self, value, lvalue): + def __init__( + self, + value: Union[LocalVariable, LocalIRVariable], + lvalue: Union[ReferenceVariableSSA, ReferenceVariable], + ) -> None: super().__init__() assert is_valid_rvalue(value) assert is_valid_lvalue(lvalue) @@ -13,11 +22,11 @@ class CodeSize(OperationWithLValue): lvalue.set_type(ElementaryType("uint256")) @property - def read(self): + def read(self) -> List[Union[LocalIRVariable, LocalVariable]]: return [self._value] @property - def value(self): + def value(self) -> LocalVariable: return self._value def __str__(self): diff --git a/slither/slithir/operations/condition.py b/slither/slithir/operations/condition.py index 5ba959a73..41fb3d933 100644 --- a/slither/slithir/operations/condition.py +++ b/slither/slithir/operations/condition.py @@ -1,6 +1,13 @@ +from typing import List, Union from slither.slithir.operations.operation import Operation from slither.slithir.utils.utils import is_valid_rvalue +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from slither.core.variables.variable import Variable class Condition(Operation): @@ -9,17 +16,26 @@ class Condition(Operation): Only present as last operation in conditional node """ - def __init__(self, value): + def __init__( + self, + value: Union[ + LocalVariable, TemporaryVariableSSA, TemporaryVariable, Constant, LocalIRVariable + ], + ) -> None: assert is_valid_rvalue(value) super().__init__() self._value = value @property - def read(self): + def read( + self, + ) -> List[ + Union[LocalIRVariable, Constant, LocalVariable, TemporaryVariableSSA, TemporaryVariable] + ]: return [self.value] @property - def value(self): + def value(self) -> Variable: return self._value def __str__(self): diff --git a/slither/slithir/operations/delete.py b/slither/slithir/operations/delete.py index 4fb05b8f5..496d170ad 100644 --- a/slither/slithir/operations/delete.py +++ b/slither/slithir/operations/delete.py @@ -1,6 +1,11 @@ +from typing import List, Union from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue +from slither.core.variables.state_variable import StateVariable +from slither.slithir.variables.reference import ReferenceVariable +from slither.slithir.variables.reference_ssa import ReferenceVariableSSA +from slither.slithir.variables.state_variable import StateIRVariable class Delete(OperationWithLValue): @@ -9,18 +14,26 @@ class Delete(OperationWithLValue): of its operand """ - def __init__(self, lvalue, variable): + def __init__( + self, + lvalue: Union[StateIRVariable, StateVariable, ReferenceVariable], + variable: Union[StateIRVariable, StateVariable, ReferenceVariable, ReferenceVariableSSA], + ) -> None: assert is_valid_lvalue(variable) super().__init__() self._variable = variable self._lvalue = lvalue @property - def read(self): + def read( + self, + ) -> List[Union[StateIRVariable, ReferenceVariable, ReferenceVariableSSA, StateVariable]]: return [self.variable] @property - def variable(self): + def variable( + self, + ) -> Union[StateIRVariable, StateVariable, ReferenceVariable, ReferenceVariableSSA]: return self._variable def __str__(self): diff --git a/slither/slithir/operations/event_call.py b/slither/slithir/operations/event_call.py index 6ef846d4b..8c23a8715 100644 --- a/slither/slithir/operations/event_call.py +++ b/slither/slithir/operations/event_call.py @@ -1,18 +1,20 @@ +from typing import Any, List, Union from slither.slithir.operations.call import Call +from slither.slithir.variables.constant import Constant class EventCall(Call): - def __init__(self, name): + def __init__(self, name: Union[str, Constant]) -> None: super().__init__() self._name = name # todo add instance of the Event @property - def name(self): + def name(self) -> Union[str, Constant]: return self._name @property - def read(self): + def read(self) -> List[Any]: return self._unroll(self.arguments) def __str__(self): diff --git a/slither/slithir/operations/high_level_call.py b/slither/slithir/operations/high_level_call.py index ff72a0899..93fb73bd4 100644 --- a/slither/slithir/operations/high_level_call.py +++ b/slither/slithir/operations/high_level_call.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import List, Optional, Union from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue @@ -8,6 +8,10 @@ from slither.core.declarations.function import Function from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.variables.constant import Constant +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from slither.slithir.variables.tuple import TupleVariable class HighLevelCall(Call, OperationWithLValue): @@ -16,7 +20,14 @@ class HighLevelCall(Call, OperationWithLValue): """ # pylint: disable=too-many-arguments,too-many-instance-attributes - def __init__(self, destination, function_name, nbr_arguments, result, type_call): + def __init__( + self, + destination: SourceMapping, + function_name: Constant, + nbr_arguments: int, + result: Optional[Union[TemporaryVariable, TupleVariable, TemporaryVariableSSA]], + type_call: str, + ) -> None: assert isinstance(function_name, Constant) assert is_valid_lvalue(result) or result is None self._check_destination(destination) @@ -34,7 +45,7 @@ class HighLevelCall(Call, OperationWithLValue): # Development function, to be removed once the code is stable # It is ovveride by LbraryCall - def _check_destination(self, destination): # pylint: disable=no-self-use + def _check_destination(self, destination: SourceMapping) -> None: # pylint: disable=no-self-use assert isinstance(destination, (Variable, SolidityVariable)) @property @@ -62,17 +73,17 @@ class HighLevelCall(Call, OperationWithLValue): self._call_gas = v @property - def read(self): + def read(self) -> List[SourceMapping]: all_read = [self.destination, self.call_gas, self.call_value] + self._unroll(self.arguments) # remove None return [x for x in all_read if x] + [self.destination] @property - def destination(self): + def destination(self) -> SourceMapping: return self._destination @property - def function_name(self): + def function_name(self) -> Constant: return self._function_name @property @@ -84,11 +95,11 @@ class HighLevelCall(Call, OperationWithLValue): self._function_instance = function @property - def nbr_arguments(self): + def nbr_arguments(self) -> int: return self._nbr_arguments @property - def type_call(self): + def type_call(self) -> str: return self._type_call ################################################################################### @@ -96,7 +107,7 @@ class HighLevelCall(Call, OperationWithLValue): # region Analyses ################################################################################### ################################################################################### - def is_static_call(self): + def is_static_call(self) -> bool: # If solidity >0.5, STATICCALL is used if self.compilation_unit.solc_version and self.compilation_unit.solc_version >= "0.5.0": if isinstance(self.function, Function) and (self.function.view or self.function.pure): @@ -105,7 +116,7 @@ class HighLevelCall(Call, OperationWithLValue): return True return False - def can_reenter(self, callstack=None): + def can_reenter(self, callstack: None = None) -> bool: """ Must be called after slithIR analysis pass For Solidity > 0.5, filter access to public variables and constant/pure/view @@ -134,7 +145,7 @@ class HighLevelCall(Call, OperationWithLValue): return True - def can_send_eth(self): + def can_send_eth(self) -> bool: """ Must be called after slithIR analysis pass :return: bool diff --git a/slither/slithir/operations/index.py b/slither/slithir/operations/index.py index 096cc7268..ade84fe1d 100644 --- a/slither/slithir/operations/index.py +++ b/slither/slithir/operations/index.py @@ -1,11 +1,22 @@ +from typing import List, Union from slither.core.declarations import SolidityVariableComposed from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue from slither.slithir.variables.reference import ReferenceVariable +from slither.core.solidity_types.elementary_type import ElementaryType +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.variable import Variable +from slither.slithir.variables.reference_ssa import ReferenceVariableSSA class Index(OperationWithLValue): - def __init__(self, result, left_variable, right_variable, index_type): + def __init__( + self, + result: Union[ReferenceVariable, ReferenceVariableSSA], + left_variable: Variable, + right_variable: SourceMapping, + index_type: Union[ElementaryType, str], + ) -> None: super().__init__() assert is_valid_lvalue(left_variable) or left_variable == SolidityVariableComposed( "msg.data" @@ -17,23 +28,23 @@ class Index(OperationWithLValue): self._lvalue = result @property - def read(self): + def read(self) -> List[SourceMapping]: return list(self.variables) @property - def variables(self): + def variables(self) -> List[SourceMapping]: return self._variables @property - def variable_left(self): + def variable_left(self) -> Variable: return self._variables[0] @property - def variable_right(self): + def variable_right(self) -> SourceMapping: return self._variables[1] @property - def index_type(self): + def index_type(self) -> Union[ElementaryType, str]: return self._type def __str__(self): diff --git a/slither/slithir/operations/init_array.py b/slither/slithir/operations/init_array.py index 23a95ebe7..4f6b2f9fa 100644 --- a/slither/slithir/operations/init_array.py +++ b/slither/slithir/operations/init_array.py @@ -1,9 +1,15 @@ +from typing import List, Union from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_rvalue +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA class InitArray(OperationWithLValue): - def __init__(self, init_values, lvalue): + def __init__( + self, init_values: List[Constant], lvalue: Union[TemporaryVariableSSA, TemporaryVariable] + ) -> None: # init_values can be an array of n dimension # reduce was removed in py3 super().__init__() @@ -24,11 +30,11 @@ class InitArray(OperationWithLValue): self._lvalue = lvalue @property - def read(self): + def read(self) -> List[Constant]: return self._unroll(self.init_values) @property - def init_values(self): + def init_values(self) -> List[Constant]: return list(self._init_values) def __str__(self): diff --git a/slither/slithir/operations/internal_call.py b/slither/slithir/operations/internal_call.py index c096bdceb..395c68846 100644 --- a/slither/slithir/operations/internal_call.py +++ b/slither/slithir/operations/internal_call.py @@ -1,15 +1,26 @@ -from typing import Union, Tuple, List, Optional +from typing import Any, Union, Tuple, List, Optional from slither.core.declarations import Modifier from slither.core.declarations.function import Function from slither.core.declarations.function_contract import FunctionContract from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from slither.slithir.variables.tuple import TupleVariable +from slither.slithir.variables.tuple_ssa import TupleVariableSSA class InternalCall(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes def __init__( - self, function: Union[Function, Tuple[str, str]], nbr_arguments, result, type_call - ): + self, + function: Union[Function, Tuple[str, str]], + nbr_arguments: int, + result: Optional[ + Union[TupleVariableSSA, TemporaryVariableSSA, TupleVariable, TemporaryVariable] + ], + type_call: str, + ) -> None: super().__init__() self._contract_name = "" if isinstance(function, Function): @@ -30,7 +41,7 @@ class InternalCall(Call, OperationWithLValue): # pylint: disable=too-many-insta self.function_candidates: Optional[List[Function]] = None @property - def read(self): + def read(self) -> List[Any]: return list(self._unroll(self.arguments)) @property @@ -42,19 +53,19 @@ class InternalCall(Call, OperationWithLValue): # pylint: disable=too-many-insta self._function = f @property - def function_name(self): + def function_name(self) -> Constant: return self._function_name @property - def contract_name(self): + def contract_name(self) -> str: return self._contract_name @property - def nbr_arguments(self): + def nbr_arguments(self) -> int: return self._nbr_arguments @property - def type_call(self): + def type_call(self) -> str: return self._type_call @property diff --git a/slither/slithir/operations/internal_dynamic_call.py b/slither/slithir/operations/internal_dynamic_call.py index 39a8ae6a1..a1ad1aa15 100644 --- a/slither/slithir/operations/internal_dynamic_call.py +++ b/slither/slithir/operations/internal_dynamic_call.py @@ -1,14 +1,25 @@ +from typing import List, Optional, Union from slither.core.solidity_types import FunctionType from slither.core.variables.variable import Variable from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA class InternalDynamicCall( Call, OperationWithLValue ): # pylint: disable=too-many-instance-attributes - def __init__(self, lvalue, function, function_type): + def __init__( + self, + lvalue: Optional[Union[TemporaryVariableSSA, TemporaryVariable]], + function: Union[LocalVariable, LocalIRVariable], + function_type: FunctionType, + ) -> None: assert isinstance(function_type, FunctionType) assert isinstance(function, Variable) assert is_valid_lvalue(lvalue) or lvalue is None @@ -22,15 +33,15 @@ class InternalDynamicCall( self._call_gas = None @property - def read(self): + def read(self) -> List[Union[Constant, LocalIRVariable, LocalVariable]]: return self._unroll(self.arguments) + [self.function] @property - def function(self): + def function(self) -> Union[LocalVariable, LocalIRVariable]: return self._function @property - def function_type(self): + def function_type(self) -> FunctionType: return self._function_type @property diff --git a/slither/slithir/operations/length.py b/slither/slithir/operations/length.py index 9ba33e655..46637dcc8 100644 --- a/slither/slithir/operations/length.py +++ b/slither/slithir/operations/length.py @@ -1,10 +1,21 @@ +from typing import List, Union from slither.core.solidity_types import ElementaryType from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +from slither.core.variables.local_variable import LocalVariable +from slither.core.variables.state_variable import StateVariable +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.reference import ReferenceVariable +from slither.slithir.variables.reference_ssa import ReferenceVariableSSA +from slither.slithir.variables.state_variable import StateIRVariable class Length(OperationWithLValue): - def __init__(self, value, lvalue): + def __init__( + self, + value: Union[StateVariable, LocalIRVariable, LocalVariable, StateIRVariable], + lvalue: Union[ReferenceVariable, ReferenceVariableSSA], + ) -> None: super().__init__() assert is_valid_rvalue(value) assert is_valid_lvalue(lvalue) @@ -13,11 +24,11 @@ class Length(OperationWithLValue): lvalue.set_type(ElementaryType("uint256")) @property - def read(self): + def read(self) -> List[Union[LocalVariable, StateVariable, LocalIRVariable, StateIRVariable]]: return [self._value] @property - def value(self): + def value(self) -> Union[StateVariable, LocalVariable]: return self._value def __str__(self): diff --git a/slither/slithir/operations/library_call.py b/slither/slithir/operations/library_call.py index 1f1f507a6..ebe9bf5ef 100644 --- a/slither/slithir/operations/library_call.py +++ b/slither/slithir/operations/library_call.py @@ -9,10 +9,10 @@ class LibraryCall(HighLevelCall): """ # Development function, to be removed once the code is stable - def _check_destination(self, destination): + def _check_destination(self, destination: Contract) -> None: assert isinstance(destination, Contract) - def can_reenter(self, callstack=None): + def can_reenter(self, callstack: None = None) -> bool: """ Must be called after slithIR analysis pass :return: bool diff --git a/slither/slithir/operations/low_level_call.py b/slither/slithir/operations/low_level_call.py index 83bbbb336..7e8c278b8 100644 --- a/slither/slithir/operations/low_level_call.py +++ b/slither/slithir/operations/low_level_call.py @@ -1,9 +1,16 @@ +from typing import List, Union from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.core.variables.variable import Variable from slither.core.declarations.solidity_variables import SolidityVariable from slither.slithir.variables.constant import Constant +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from slither.slithir.variables.tuple import TupleVariable +from slither.slithir.variables.tuple_ssa import TupleVariableSSA class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes @@ -11,7 +18,14 @@ class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-insta High level message call """ - def __init__(self, destination, function_name, nbr_arguments, result, type_call): + def __init__( + self, + destination: Union[LocalVariable, LocalIRVariable, TemporaryVariableSSA, TemporaryVariable], + function_name: Constant, + nbr_arguments: int, + result: Union[TupleVariable, TupleVariableSSA], + type_call: str, + ) -> None: # pylint: disable=too-many-arguments assert isinstance(destination, (Variable, SolidityVariable)) assert isinstance(function_name, Constant) @@ -51,12 +65,16 @@ class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-insta self._call_gas = v @property - def read(self): + def read( + self, + ) -> List[ + Union[LocalIRVariable, Constant, LocalVariable, TemporaryVariableSSA, TemporaryVariable] + ]: all_read = [self.destination, self.call_gas, self.call_value] + self.arguments # remove None return self._unroll([x for x in all_read if x]) - def can_reenter(self, _callstack=None): + def can_reenter(self, _callstack: None = None) -> bool: """ Must be called after slithIR analysis pass :return: bool @@ -65,7 +83,7 @@ class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-insta return False return True - def can_send_eth(self): + def can_send_eth(self) -> bool: """ Must be called after slithIR analysis pass :return: bool @@ -73,19 +91,21 @@ class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-insta return self._call_value is not None @property - def destination(self): + def destination( + self, + ) -> Union[LocalVariable, LocalIRVariable, TemporaryVariableSSA, TemporaryVariable]: return self._destination @property - def function_name(self): + def function_name(self) -> Constant: return self._function_name @property - def nbr_arguments(self): + def nbr_arguments(self) -> int: return self._nbr_arguments @property - def type_call(self): + def type_call(self) -> str: return self._type_call def __str__(self): diff --git a/slither/slithir/operations/lvalue.py b/slither/slithir/operations/lvalue.py index 4571b9f1f..d9b800c92 100644 --- a/slither/slithir/operations/lvalue.py +++ b/slither/slithir/operations/lvalue.py @@ -1,3 +1,4 @@ +from typing import Any, List from slither.slithir.operations.operation import Operation @@ -6,7 +7,7 @@ class OperationWithLValue(Operation): Operation with a lvalue """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._lvalue = None @@ -16,7 +17,7 @@ class OperationWithLValue(Operation): return self._lvalue @property - def used(self): + def used(self) -> List[Any]: return self.read + [self.lvalue] @lvalue.setter diff --git a/slither/slithir/operations/member.py b/slither/slithir/operations/member.py index f0c6ae523..9a561ea87 100644 --- a/slither/slithir/operations/member.py +++ b/slither/slithir/operations/member.py @@ -1,3 +1,4 @@ +from typing import List, Union from slither.core.declarations import Contract, Function from slither.core.declarations.custom_error import CustomError from slither.core.declarations.enum import Enum @@ -7,10 +8,17 @@ from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_rvalue from slither.slithir.variables.constant import Constant from slither.slithir.variables.reference import ReferenceVariable +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.slithir.variables.reference_ssa import ReferenceVariableSSA class Member(OperationWithLValue): - def __init__(self, variable_left, variable_right, result): + def __init__( + self, + variable_left: SourceMapping, + variable_right: Constant, + result: Union[ReferenceVariable, ReferenceVariableSSA], + ) -> None: # Function can happen for something like # library FunctionExtensions { # function h(function() internal _t, uint8) internal { } @@ -38,15 +46,15 @@ class Member(OperationWithLValue): self._value = None @property - def read(self): + def read(self) -> List[SourceMapping]: return [self.variable_left, self.variable_right] @property - def variable_left(self): + def variable_left(self) -> SourceMapping: return self._variable_left @property - def variable_right(self): + def variable_right(self) -> Constant: return self._variable_right @property diff --git a/slither/slithir/operations/new_array.py b/slither/slithir/operations/new_array.py index 57ee6dcf0..8dad8532f 100644 --- a/slither/slithir/operations/new_array.py +++ b/slither/slithir/operations/new_array.py @@ -1,10 +1,22 @@ +from typing import List, Union, TYPE_CHECKING from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.operations.call import Call from slither.core.solidity_types.type import Type +if TYPE_CHECKING: + from slither.core.solidity_types.type_alias import TypeAliasTopLevel + from slither.slithir.variables.constant import Constant + from slither.slithir.variables.temporary import TemporaryVariable + from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA + class NewArray(Call, OperationWithLValue): - def __init__(self, depth, array_type, lvalue): + def __init__( + self, + depth: int, + array_type: "TypeAliasTopLevel", + lvalue: Union["TemporaryVariableSSA", "TemporaryVariable"], + ) -> None: super().__init__() assert isinstance(array_type, Type) self._depth = depth @@ -13,15 +25,15 @@ class NewArray(Call, OperationWithLValue): self._lvalue = lvalue @property - def array_type(self): + def array_type(self) -> "TypeAliasTopLevel": return self._array_type @property - def read(self): + def read(self) -> List["Constant"]: return self._unroll(self.arguments) @property - def depth(self): + def depth(self) -> int: return self._depth def __str__(self): diff --git a/slither/slithir/operations/new_contract.py b/slither/slithir/operations/new_contract.py index fb0014c76..879d12df6 100644 --- a/slither/slithir/operations/new_contract.py +++ b/slither/slithir/operations/new_contract.py @@ -1,10 +1,17 @@ +from typing import Optional, Any, List, Union from slither.slithir.operations import Call, OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.variables.constant import Constant +from slither.core.declarations.contract import Contract +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from slither.core.declarations.function_contract import FunctionContract class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes - def __init__(self, contract_name, lvalue): + def __init__( + self, contract_name: Constant, lvalue: Union[TemporaryVariableSSA, TemporaryVariable] + ) -> None: assert isinstance(contract_name, Constant) assert is_valid_lvalue(lvalue) super().__init__() @@ -40,15 +47,15 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan self._call_salt = s @property - def contract_name(self): + def contract_name(self) -> Constant: return self._contract_name @property - def read(self): + def read(self) -> List[Any]: return self._unroll(self.arguments) @property - def contract_created(self): + def contract_created(self) -> Contract: contract_name = self.contract_name contract_instance = self.node.file_scope.get_contract_from_name(contract_name) return contract_instance @@ -59,7 +66,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan ################################################################################### ################################################################################### - def can_reenter(self, callstack=None): + def can_reenter(self, callstack: Optional[List[FunctionContract]] = None) -> bool: """ Must be called after slithIR analysis pass For Solidity > 0.5, filter access to public variables and constant/pure/view @@ -76,7 +83,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan callstack = callstack + [constructor] return constructor.can_reenter(callstack) - def can_send_eth(self): + def can_send_eth(self) -> bool: """ Must be called after slithIR analysis pass :return: bool diff --git a/slither/slithir/operations/new_structure.py b/slither/slithir/operations/new_structure.py index 16a8af785..752de6a3d 100644 --- a/slither/slithir/operations/new_structure.py +++ b/slither/slithir/operations/new_structure.py @@ -1,13 +1,21 @@ +from typing import List, Union + from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue from slither.core.declarations.structure import Structure +from slither.core.declarations.structure_contract import StructureContract +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA class NewStructure(Call, OperationWithLValue): - def __init__(self, structure, lvalue): + def __init__( + self, structure: StructureContract, lvalue: Union[TemporaryVariableSSA, TemporaryVariable] + ) -> None: super().__init__() assert isinstance(structure, Structure) assert is_valid_lvalue(lvalue) @@ -16,11 +24,11 @@ class NewStructure(Call, OperationWithLValue): self._lvalue = lvalue @property - def read(self): + def read(self) -> List[Union[TemporaryVariableSSA, TemporaryVariable, Constant]]: return self._unroll(self.arguments) @property - def structure(self): + def structure(self) -> StructureContract: return self._structure @property diff --git a/slither/slithir/operations/nop.py b/slither/slithir/operations/nop.py index 2d2d360dd..387ca3cad 100644 --- a/slither/slithir/operations/nop.py +++ b/slither/slithir/operations/nop.py @@ -1,9 +1,13 @@ -from .operation import Operation +from typing import List + + +from slither.core.variables.variable import Variable +from slither.slithir.operations import Operation class Nop(Operation): @property - def read(self): + def read(self) -> List[Variable]: return [] @property diff --git a/slither/slithir/operations/operation.py b/slither/slithir/operations/operation.py index fa1db89c2..fcf5f4868 100644 --- a/slither/slithir/operations/operation.py +++ b/slither/slithir/operations/operation.py @@ -1,7 +1,9 @@ import abc +from typing import Any, List from slither.core.context.context import Context from slither.core.children.child_expression import ChildExpression from slither.core.children.child_node import ChildNode +from slither.core.variables.variable import Variable from slither.utils.utils import unroll @@ -25,7 +27,7 @@ class AbstractOperation(abc.ABC): class Operation(Context, ChildExpression, ChildNode, AbstractOperation): @property - def used(self): + def used(self) -> List[Variable]: """ By default used is all the variables read """ @@ -33,5 +35,5 @@ class Operation(Context, ChildExpression, ChildNode, AbstractOperation): # if array inside the parameters @staticmethod - def _unroll(l): + def _unroll(l: List[Any]) -> List[Any]: return unroll(l) diff --git a/slither/slithir/operations/phi.py b/slither/slithir/operations/phi.py index a4fa0217e..dd416e8f2 100644 --- a/slither/slithir/operations/phi.py +++ b/slither/slithir/operations/phi.py @@ -1,9 +1,19 @@ +from typing import List, Set, Union, TYPE_CHECKING from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue +from slither.core.declarations.solidity_variables import SolidityVariableComposed +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.state_variable import StateIRVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA + +if TYPE_CHECKING: + from slither.core.cfg.node import Node class Phi(OperationWithLValue): - def __init__(self, left_variable, nodes): + def __init__( + self, left_variable: Union[LocalIRVariable, StateIRVariable], nodes: Set["Node"] + ) -> None: # When Phi operations are created the # correct indexes of the variables are not yet computed # We store the nodes where the variables are written @@ -17,7 +27,11 @@ class Phi(OperationWithLValue): self._nodes = nodes @property - def read(self): + def read( + self, + ) -> List[ + Union[SolidityVariableComposed, LocalIRVariable, TemporaryVariableSSA, StateIRVariable] + ]: return self.rvalues @property @@ -29,7 +43,7 @@ class Phi(OperationWithLValue): self._rvalues = vals @property - def nodes(self): + def nodes(self) -> Set["Node"]: return self._nodes def __str__(self): diff --git a/slither/slithir/operations/phi_callback.py b/slither/slithir/operations/phi_callback.py index 486015b0c..0c8994056 100644 --- a/slither/slithir/operations/phi_callback.py +++ b/slither/slithir/operations/phi_callback.py @@ -1,9 +1,24 @@ +from typing import List, Set, Union, TYPE_CHECKING + from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.operations.phi import Phi +from slither.slithir.operations.high_level_call import HighLevelCall +from slither.slithir.operations.internal_call import InternalCall +from slither.slithir.variables.state_variable import StateIRVariable + +if TYPE_CHECKING: + from slither.core.cfg.node import Node + class PhiCallback(Phi): - def __init__(self, left_variable, nodes, call_ir, rvalue): + def __init__( + self, + left_variable: StateIRVariable, + nodes: Set["Node"], + call_ir: Union[InternalCall, HighLevelCall], + rvalue: StateIRVariable, + ) -> None: assert is_valid_lvalue(left_variable) assert isinstance(nodes, set) super().__init__(left_variable, nodes) @@ -12,17 +27,21 @@ class PhiCallback(Phi): self._rvalue_no_callback = rvalue @property - def callee_ir(self): + def callee_ir(self) -> Union[InternalCall, HighLevelCall]: return self._call_ir @property - def read(self): + def read(self) -> List[StateIRVariable]: return self.rvalues @property def rvalues(self): return self._rvalues + @rvalues.setter + def rvalues(self, vals): + self._rvalues = vals + @property def rvalue_no_callback(self): """ @@ -30,10 +49,6 @@ class PhiCallback(Phi): """ return self._rvalue_no_callback - @rvalues.setter - def rvalues(self, vals): - self._rvalues = vals - @property def nodes(self): return self._nodes diff --git a/slither/slithir/operations/return_operation.py b/slither/slithir/operations/return_operation.py index c1ccf47d1..c21579763 100644 --- a/slither/slithir/operations/return_operation.py +++ b/slither/slithir/operations/return_operation.py @@ -1,8 +1,11 @@ +from typing import List + from slither.core.declarations import Function from slither.slithir.operations.operation import Operation from slither.slithir.variables.tuple import TupleVariable from slither.slithir.utils.utils import is_valid_rvalue +from slither.core.variables.variable import Variable class Return(Operation): @@ -11,7 +14,7 @@ class Return(Operation): Only present as last operation in RETURN node """ - def __init__(self, values): + def __init__(self, values) -> None: # Note: Can return None # ex: return call() # where call() dont return @@ -35,7 +38,7 @@ class Return(Operation): super().__init__() self._values = values - def _valid_value(self, value): + def _valid_value(self, value) -> bool: if isinstance(value, list): assert all(self._valid_value(v) for v in value) else: @@ -43,11 +46,11 @@ class Return(Operation): return True @property - def read(self): + def read(self) -> List[Variable]: return self._unroll(self.values) @property - def values(self): + def values(self) -> List[Variable]: return self._unroll(self._values) def __str__(self): diff --git a/slither/slithir/operations/send.py b/slither/slithir/operations/send.py index 27b0c7021..344043419 100644 --- a/slither/slithir/operations/send.py +++ b/slither/slithir/operations/send.py @@ -1,12 +1,24 @@ +from typing import List, Union + from slither.core.declarations.solidity_variables import SolidityVariable from slither.core.variables.variable import Variable from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA class Send(Call, OperationWithLValue): - def __init__(self, destination, value, result): + def __init__( + self, + destination: Union[LocalVariable, LocalIRVariable], + value: Constant, + result: Union[TemporaryVariable, TemporaryVariableSSA], + ) -> None: assert is_valid_lvalue(result) assert isinstance(destination, (Variable, SolidityVariable)) super().__init__() @@ -15,19 +27,19 @@ class Send(Call, OperationWithLValue): self._call_value = value - def can_send_eth(self): + def can_send_eth(self) -> bool: return True @property - def call_value(self): + def call_value(self) -> Constant: return self._call_value @property - def read(self): + def read(self) -> List[Union[Constant, LocalIRVariable, LocalVariable]]: return [self.destination, self.call_value] @property - def destination(self): + def destination(self) -> Union[LocalVariable, LocalIRVariable]: return self._destination def __str__(self): diff --git a/slither/slithir/operations/solidity_call.py b/slither/slithir/operations/solidity_call.py index 628a527fe..b059c55a6 100644 --- a/slither/slithir/operations/solidity_call.py +++ b/slither/slithir/operations/solidity_call.py @@ -1,10 +1,19 @@ -from slither.core.declarations.solidity_variables import SolidityFunction +from typing import Any, List, Union +from slither.core.declarations.solidity_variables import SolidityCustomRevert, SolidityFunction from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.children.child_node import ChildNode +from slither.core.solidity_types.elementary_type import ElementaryType class SolidityCall(Call, OperationWithLValue): - def __init__(self, function, nbr_arguments, result, type_call): + def __init__( + self, + function: Union[SolidityCustomRevert, SolidityFunction], + nbr_arguments: int, + result: ChildNode, + type_call: Union[str, List[ElementaryType]], + ) -> None: assert isinstance(function, SolidityFunction) super().__init__() self._function = function @@ -13,19 +22,19 @@ class SolidityCall(Call, OperationWithLValue): self._lvalue = result @property - def read(self): + def read(self) -> List[Any]: return self._unroll(self.arguments) @property - def function(self): + def function(self) -> Union[SolidityCustomRevert, SolidityFunction]: return self._function @property - def nbr_arguments(self): + def nbr_arguments(self) -> int: return self._nbr_arguments @property - def type_call(self): + def type_call(self) -> Union[str, List[ElementaryType]]: return self._type_call def __str__(self): diff --git a/slither/slithir/operations/transfer.py b/slither/slithir/operations/transfer.py index 6cfc58f07..40f1dab3d 100644 --- a/slither/slithir/operations/transfer.py +++ b/slither/slithir/operations/transfer.py @@ -1,29 +1,33 @@ +from typing import List, Union from slither.slithir.operations.call import Call from slither.core.variables.variable import Variable from slither.core.declarations.solidity_variables import SolidityVariable +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.local_variable import LocalIRVariable class Transfer(Call): - def __init__(self, destination, value): + def __init__(self, destination: Union[LocalVariable, LocalIRVariable], value: Constant) -> None: assert isinstance(destination, (Variable, SolidityVariable)) self._destination = destination super().__init__() self._call_value = value - def can_send_eth(self): + def can_send_eth(self) -> bool: return True @property - def call_value(self): + def call_value(self) -> Constant: return self._call_value @property - def read(self): + def read(self) -> List[Union[Constant, LocalIRVariable, LocalVariable]]: return [self.destination, self.call_value] @property - def destination(self): + def destination(self) -> Union[LocalVariable, LocalIRVariable]: return self._destination def __str__(self): diff --git a/slither/slithir/operations/type_conversion.py b/slither/slithir/operations/type_conversion.py index feee46a2c..f351f1fdd 100644 --- a/slither/slithir/operations/type_conversion.py +++ b/slither/slithir/operations/type_conversion.py @@ -1,11 +1,24 @@ +from typing import List, Union from slither.core.declarations import Contract from slither.core.solidity_types.type import Type from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +import slither.core.declarations.contract +from slither.core.solidity_types.elementary_type import ElementaryType +from slither.core.solidity_types.type_alias import TypeAliasContract, TypeAliasTopLevel +from slither.core.solidity_types.user_defined_type import UserDefinedType +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA class TypeConversion(OperationWithLValue): - def __init__(self, result, variable, variable_type): + def __init__( + self, + result: Union[TemporaryVariableSSA, TemporaryVariable], + variable: SourceMapping, + variable_type: Union[TypeAliasContract, UserDefinedType, ElementaryType, TypeAliasTopLevel], + ) -> None: super().__init__() assert is_valid_rvalue(variable) or isinstance(variable, Contract) assert is_valid_lvalue(result) @@ -16,15 +29,23 @@ class TypeConversion(OperationWithLValue): self._lvalue = result @property - def variable(self): + def variable(self) -> SourceMapping: return self._variable @property - def type(self): + def type( + self, + ) -> Union[ + TypeAliasContract, + TypeAliasTopLevel, + slither.core.declarations.contract.Contract, + UserDefinedType, + ElementaryType, + ]: return self._type @property - def read(self): + def read(self) -> List[SourceMapping]: return [self.variable] def __str__(self): diff --git a/slither/slithir/operations/unary.py b/slither/slithir/operations/unary.py index 59b5bcbb0..a6529d726 100644 --- a/slither/slithir/operations/unary.py +++ b/slither/slithir/operations/unary.py @@ -1,9 +1,17 @@ import logging +from typing import List, Union from enum import Enum from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue from slither.slithir.exceptions import SlithIRError +from slither.core.expressions.unary_operation import UnaryOperationType +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA + logger = logging.getLogger("BinaryOperationIR") @@ -23,7 +31,12 @@ class UnaryType(Enum): class Unary(OperationWithLValue): - def __init__(self, result, variable, operation_type): + def __init__( + self, + result: Union[TemporaryVariableSSA, TemporaryVariable], + variable: Union[Constant, LocalIRVariable, LocalVariable], + operation_type: UnaryOperationType, + ) -> None: assert is_valid_rvalue(variable) assert is_valid_lvalue(result) super().__init__() @@ -32,15 +45,15 @@ class Unary(OperationWithLValue): self._lvalue = result @property - def read(self): + def read(self) -> List[Union[Constant, LocalIRVariable, LocalVariable]]: return [self._variable] @property - def rvalue(self): + def rvalue(self) -> Union[Constant, LocalVariable]: return self._variable @property - def type(self): + def type(self) -> UnaryOperationType: return self._type @property diff --git a/slither/slithir/operations/unpack.py b/slither/slithir/operations/unpack.py index c5a45cb9e..a183463d4 100644 --- a/slither/slithir/operations/unpack.py +++ b/slither/slithir/operations/unpack.py @@ -1,10 +1,20 @@ +from typing import List, Union + from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.variables.tuple import TupleVariable +from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.tuple_ssa import TupleVariableSSA class Unpack(OperationWithLValue): - def __init__(self, result, tuple_var, idx): + def __init__( + self, + result: Union[LocalVariableInitFromTuple, LocalIRVariable], + tuple_var: Union[TupleVariable, TupleVariableSSA], + idx: int, + ) -> None: assert is_valid_lvalue(result) assert isinstance(tuple_var, TupleVariable) assert isinstance(idx, int) @@ -14,15 +24,15 @@ class Unpack(OperationWithLValue): self._lvalue = result @property - def read(self): + def read(self) -> List[Union[TupleVariableSSA, TupleVariable]]: return [self.tuple] @property - def tuple(self): + def tuple(self) -> Union[TupleVariable, TupleVariableSSA]: return self._tuple @property - def index(self): + def index(self) -> int: return self._idx def __str__(self): diff --git a/slither/slithir/tmp_operations/argument.py b/slither/slithir/tmp_operations/argument.py index 746bc13f2..25ea5d019 100644 --- a/slither/slithir/tmp_operations/argument.py +++ b/slither/slithir/tmp_operations/argument.py @@ -10,7 +10,7 @@ class ArgumentType(Enum): class Argument(Operation): - def __init__(self, argument): + def __init__(self, argument) -> None: super().__init__() self._argument = argument self._type = ArgumentType.CALL @@ -32,11 +32,11 @@ class Argument(Operation): def read(self): return [self.argument] - def set_type(self, t): + def set_type(self, t: ArgumentType) -> None: assert isinstance(t, ArgumentType) self._type = t - def get_type(self): + def get_type(self) -> ArgumentType: return self._type def __str__(self): diff --git a/slither/slithir/tmp_operations/tmp_call.py b/slither/slithir/tmp_operations/tmp_call.py index fb6641139..2137ebd81 100644 --- a/slither/slithir/tmp_operations/tmp_call.py +++ b/slither/slithir/tmp_operations/tmp_call.py @@ -1,3 +1,5 @@ +from typing import Optional, Union + from slither.core.declarations import ( Event, Contract, @@ -8,10 +10,22 @@ from slither.core.declarations import ( from slither.core.declarations.custom_error import CustomError from slither.core.variables.variable import Variable from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.slithir.operations.member import Member +from slither.slithir.tmp_operations.tmp_new_array import TmpNewArray +from slither.slithir.tmp_operations.tmp_new_contract import TmpNewContract +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.tuple import TupleVariable class TmpCall(OperationWithLValue): # pylint: disable=too-many-instance-attributes - def __init__(self, called, nbr_arguments, result, type_call): + def __init__( + self, + called: SourceMapping, + nbr_arguments: int, + result: Union[TupleVariable, TemporaryVariable], + type_call: str, + ) -> None: assert isinstance( called, ( @@ -80,18 +94,18 @@ class TmpCall(OperationWithLValue): # pylint: disable=too-many-instance-attribu self._called = c @property - def nbr_arguments(self): + def nbr_arguments(self) -> int: return self._nbr_arguments @property - def type_call(self): + def type_call(self) -> str: return self._type_call @property - def ori(self): + def ori(self) -> Optional[Union[TmpNewContract, TmpNewArray, "TmpCall", Member]]: return self._ori - def set_ori(self, ori): + def set_ori(self, ori: Union["TmpCall", TmpNewContract, TmpNewArray, Member]) -> None: self._ori = ori def __str__(self): diff --git a/slither/slithir/tmp_operations/tmp_new_array.py b/slither/slithir/tmp_operations/tmp_new_array.py index 0da9c54eb..efbdb6242 100644 --- a/slither/slithir/tmp_operations/tmp_new_array.py +++ b/slither/slithir/tmp_operations/tmp_new_array.py @@ -1,9 +1,15 @@ -from slither.slithir.operations.lvalue import OperationWithLValue from slither.core.solidity_types.type import Type +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.variables.temporary import TemporaryVariable class TmpNewArray(OperationWithLValue): - def __init__(self, depth, array_type, lvalue): + def __init__( + self, + depth: int, + array_type: Type, + lvalue: TemporaryVariable, + ) -> None: super().__init__() assert isinstance(array_type, Type) self._depth = depth @@ -11,7 +17,7 @@ class TmpNewArray(OperationWithLValue): self._lvalue = lvalue @property - def array_type(self): + def array_type(self) -> Type: return self._array_type @property @@ -19,7 +25,7 @@ class TmpNewArray(OperationWithLValue): return [] @property - def depth(self): + def depth(self) -> int: return self._depth def __str__(self): diff --git a/slither/slithir/tmp_operations/tmp_new_contract.py b/slither/slithir/tmp_operations/tmp_new_contract.py index 337a40bd1..5a987a7c7 100644 --- a/slither/slithir/tmp_operations/tmp_new_contract.py +++ b/slither/slithir/tmp_operations/tmp_new_contract.py @@ -1,8 +1,9 @@ from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.variables.temporary import TemporaryVariable class TmpNewContract(OperationWithLValue): - def __init__(self, contract_name, lvalue): + def __init__(self, contract_name: str, lvalue: TemporaryVariable) -> None: super().__init__() self._contract_name = contract_name self._lvalue = lvalue @@ -10,7 +11,7 @@ class TmpNewContract(OperationWithLValue): self._call_salt = None @property - def contract_name(self): + def contract_name(self) -> str: return self._contract_name @property diff --git a/slither/slithir/utils/ssa.py b/slither/slithir/utils/ssa.py index 322583381..156914b61 100644 --- a/slither/slithir/utils/ssa.py +++ b/slither/slithir/utils/ssa.py @@ -1,6 +1,8 @@ import logging +from typing import Any, Callable, Dict, List, Tuple, Union -from slither.core.cfg.node import NodeType +import slither.slithir.variables.tuple_ssa +from slither.core.cfg.node import Node, NodeType from slither.core.declarations import ( Contract, Enum, @@ -9,11 +11,16 @@ from slither.core.declarations import ( SolidityVariable, Structure, ) +from slither.core.declarations.function_contract import FunctionContract +from slither.core.declarations.function_top_level import FunctionTopLevel +from slither.core.declarations.modifier import Modifier from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.solidity_types.type import Type from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable from slither.core.variables.top_level_variable import TopLevelVariable +from slither.core.variables.variable import Variable +from slither.slithir.exceptions import SlithIRError from slither.slithir.operations import ( Assignment, Binary, @@ -45,7 +52,9 @@ from slither.slithir.operations import ( Unpack, Nop, ) +from slither.slithir.operations.call import Call from slither.slithir.operations.codesize import CodeSize +from slither.slithir.operations.operation import Operation from slither.slithir.variables import ( Constant, LocalIRVariable, @@ -57,7 +66,6 @@ from slither.slithir.variables import ( TupleVariable, TupleVariableSSA, ) -from slither.slithir.exceptions import SlithIRError logger = logging.getLogger("SSA_Conversion") @@ -68,7 +76,9 @@ logger = logging.getLogger("SSA_Conversion") ################################################################################### -def transform_slithir_vars_to_ssa(function): +def transform_slithir_vars_to_ssa( + function: Union[FunctionContract, Modifier, FunctionTopLevel] +) -> None: """ Transform slithIR vars to SSA (TemporaryVariable, ReferenceVariable, TupleVariable) """ @@ -98,7 +108,12 @@ def transform_slithir_vars_to_ssa(function): # pylint: disable=too-many-arguments,too-many-locals,too-many-nested-blocks,too-many-statements,too-many-branches -def add_ssa_ir(function, all_state_variables_instances): +def add_ssa_ir( + function: Union[FunctionContract, Modifier, FunctionTopLevel], + all_state_variables_instances: Dict[ + str, slither.slithir.variables.state_variable.StateIRVariable + ], +) -> None: """ Add SSA version of the IR Args: @@ -199,14 +214,14 @@ def add_ssa_ir(function, all_state_variables_instances): def generate_ssa_irs( - node, - local_variables_instances, - all_local_variables_instances, - state_variables_instances, - all_state_variables_instances, - init_local_variables_instances, - visited, -): + node: Node, + local_variables_instances: Dict[str, LocalIRVariable], + all_local_variables_instances: Dict[str, LocalIRVariable], + state_variables_instances: Dict[str, StateIRVariable], + all_state_variables_instances: Dict[str, StateIRVariable], + init_local_variables_instances: Dict[str, LocalIRVariable], + visited: List[Node], +) -> None: if node in visited: return @@ -323,7 +338,14 @@ def generate_ssa_irs( ################################################################################### -def last_name(n, var, init_vars): +def last_name( + n: Node, + var: Union[ + StateIRVariable, + LocalIRVariable, + ], + init_vars: Dict[str, LocalIRVariable], +) -> Union[StateIRVariable, LocalIRVariable,]: candidates = [] # Todo optimize by creating a variables_ssa_written attribute for ir_ssa in n.irs_ssa: @@ -342,7 +364,10 @@ def last_name(n, var, init_vars): return max(candidates, key=lambda v: v.index) -def is_used_later(initial_node, variable): +def is_used_later( + initial_node: Node, + variable: Union[StateIRVariable, LocalVariable], +) -> bool: # TODO: does not handle the case where its read and written in the declaration node # It can be problematic if this happens in a loop/if structure # Ex: @@ -389,13 +414,13 @@ def is_used_later(initial_node, variable): def update_lvalue( - new_ir, - node, - local_variables_instances, - all_local_variables_instances, - state_variables_instances, - all_state_variables_instances, -): + new_ir: Operation, + node: Node, + local_variables_instances: Dict[str, LocalIRVariable], + all_local_variables_instances: Dict[str, LocalIRVariable], + state_variables_instances: Dict[str, StateIRVariable], + all_state_variables_instances: Dict[str, StateIRVariable], +) -> None: if isinstance(new_ir, OperationWithLValue): lvalue = new_ir.lvalue update_through_ref = False @@ -437,8 +462,10 @@ def update_lvalue( def initiate_all_local_variables_instances( - nodes, local_variables_instances, all_local_variables_instances -): + nodes: List[Node], + local_variables_instances: Dict[str, LocalIRVariable], + all_local_variables_instances: Dict[str, LocalIRVariable], +) -> None: for node in nodes: if node.variable_declaration: new_var = LocalIRVariable(node.variable_declaration) @@ -457,13 +484,13 @@ def initiate_all_local_variables_instances( def fix_phi_rvalues_and_storage_ref( - node, - local_variables_instances, - all_local_variables_instances, - state_variables_instances, - all_state_variables_instances, - init_local_variables_instances, -): + node: Node, + local_variables_instances: Dict[str, LocalIRVariable], + all_local_variables_instances: Dict[str, LocalIRVariable], + state_variables_instances: Dict[str, StateIRVariable], + all_state_variables_instances: Dict[str, StateIRVariable], + init_local_variables_instances: Dict[str, LocalIRVariable], +) -> None: for ir in node.irs_ssa: if isinstance(ir, (Phi)) and not ir.rvalues: variables = [ @@ -506,7 +533,11 @@ def fix_phi_rvalues_and_storage_ref( ) -def add_phi_origins(node, local_variables_definition, state_variables_definition): +def add_phi_origins( + node: Node, + local_variables_definition: Dict[str, Tuple[LocalVariable, Node]], + state_variables_definition: Dict[str, Tuple[StateVariable, Node]], +) -> None: # Add new key to local_variables_definition # The key is the variable_name @@ -557,12 +588,12 @@ def add_phi_origins(node, local_variables_definition, state_variables_definition def get( variable, - local_variables_instances, - state_variables_instances, - temporary_variables_instances, - reference_variables_instances, - tuple_variables_instances, - all_local_variables_instances, + local_variables_instances: Dict[str, LocalIRVariable], + state_variables_instances: Dict[str, StateIRVariable], + temporary_variables_instances: Dict[int, TemporaryVariableSSA], + reference_variables_instances: Dict[int, ReferenceVariableSSA], + tuple_variables_instances: Dict[int, TupleVariableSSA], + all_local_variables_instances: Dict[str, LocalIRVariable], ): # variable can be None # for example, on LowLevelCall, ir.lvalue can be none @@ -623,14 +654,14 @@ def get( return variable -def get_variable(ir, f, *instances): +def get_variable(ir: Operation, f: Callable, *instances): # pylint: disable=no-value-for-parameter variable = f(ir) variable = get(variable, *instances) return variable -def _get_traversal(values, *instances): +def _get_traversal(values: List[Any], *instances) -> List[Any]: ret = [] # pylint: disable=no-value-for-parameter for v in values: @@ -642,11 +673,19 @@ def _get_traversal(values, *instances): return ret -def get_arguments(ir, *instances): +def get_arguments(ir: Call, *instances) -> List[Any]: return _get_traversal(ir.arguments, *instances) -def get_rec_values(ir, f, *instances): +def get_rec_values( + ir: Union[ + InitArray, + Return, + NewArray, + ], + f: Callable, + *instances, +) -> List[Variable]: # Use by InitArray and NewArray # Potential recursive array(s) ori_init_values = f(ir) @@ -654,7 +693,7 @@ def get_rec_values(ir, f, *instances): return _get_traversal(ori_init_values, *instances) -def copy_ir(ir, *instances): +def copy_ir(ir: Operation, *instances) -> Operation: """ Args: ir (Operation) diff --git a/slither/slithir/utils/utils.py b/slither/slithir/utils/utils.py index 7bebc0a80..0a50f8e50 100644 --- a/slither/slithir/utils/utils.py +++ b/slither/slithir/utils/utils.py @@ -8,9 +8,10 @@ from slither.slithir.variables.temporary import TemporaryVariable from slither.slithir.variables.constant import Constant from slither.slithir.variables.reference import ReferenceVariable from slither.slithir.variables.tuple import TupleVariable +from slither.core.source_mapping.source_mapping import SourceMapping -def is_valid_rvalue(v): +def is_valid_rvalue(v: SourceMapping) -> bool: return isinstance( v, ( @@ -25,7 +26,7 @@ def is_valid_rvalue(v): ) -def is_valid_lvalue(v): +def is_valid_lvalue(v) -> bool: return isinstance( v, ( diff --git a/slither/slithir/variables/constant.py b/slither/slithir/variables/constant.py index 5da2d9cc0..ddfc9e054 100644 --- a/slither/slithir/variables/constant.py +++ b/slither/slithir/variables/constant.py @@ -1,4 +1,5 @@ from functools import total_ordering +from typing import Optional, Union from slither.core.solidity_types.elementary_type import ElementaryType, Int, Uint from slither.slithir.variables.variable import SlithIRVariable @@ -9,8 +10,11 @@ from slither.utils.integer_conversion import convert_string_to_int @total_ordering class Constant(SlithIRVariable): def __init__( - self, val, constant_type=None, subdenomination=None - ): # pylint: disable=too-many-branches + self, + val: Union[int, str], + constant_type: Optional[ElementaryType] = None, + subdenomination: Optional[str] = None, + ) -> None: # pylint: disable=too-many-branches super().__init__() assert isinstance(val, str) @@ -38,7 +42,7 @@ class Constant(SlithIRVariable): self._val = val @property - def value(self): + def value(self) -> Union[bool, int, str]: """ Return the value. If the expression was an hexadecimal delcared as hex'...' @@ -49,21 +53,21 @@ class Constant(SlithIRVariable): return self._val @property - def original_value(self): + def original_value(self) -> str: """ Return the string representation of the value :return: str """ return self._original_value - def __str__(self): + def __str__(self) -> str: return str(self.value) @property - def name(self): + def name(self) -> str: return str(self) - def __eq__(self, other): + def __eq__(self, other: Union["Constant", str]) -> bool: return self.value == other def __ne__(self, other): diff --git a/slither/slithir/variables/local_variable.py b/slither/slithir/variables/local_variable.py index b78500f14..eb32d4024 100644 --- a/slither/slithir/variables/local_variable.py +++ b/slither/slithir/variables/local_variable.py @@ -1,12 +1,14 @@ +from typing import Set from slither.core.variables.local_variable import LocalVariable from slither.slithir.variables.temporary import TemporaryVariable from slither.slithir.variables.variable import SlithIRVariable +from slither.slithir.variables.state_variable import StateIRVariable class LocalIRVariable( LocalVariable, SlithIRVariable ): # pylint: disable=too-many-instance-attributes - def __init__(self, local_variable): + def __init__(self, local_variable: LocalVariable) -> None: assert isinstance(local_variable, LocalVariable) super().__init__() @@ -30,7 +32,7 @@ class LocalIRVariable( # Additional field # points to state variables - self._refers_to = set() + self._refers_to: Set[StateIRVariable] = set() # keep un-ssa version if isinstance(local_variable, LocalIRVariable): @@ -57,10 +59,10 @@ class LocalIRVariable( self._refers_to = variables @property - def non_ssa_version(self): + def non_ssa_version(self) -> LocalVariable: return self._non_ssa_version - def add_refers_to(self, variable): + def add_refers_to(self, variable: StateIRVariable) -> None: # It is a temporaryVariable if its the return of a new .. # ex: string[] memory dynargs = new string[](1); assert isinstance(variable, (SlithIRVariable, TemporaryVariable)) diff --git a/slither/slithir/variables/reference.py b/slither/slithir/variables/reference.py index 100886f2f..95802b7e2 100644 --- a/slither/slithir/variables/reference.py +++ b/slither/slithir/variables/reference.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from slither.core.children.child_node import ChildNode from slither.core.declarations import Contract, Enum, SolidityVariable, Function @@ -9,7 +9,7 @@ if TYPE_CHECKING: class ReferenceVariable(ChildNode, Variable): - def __init__(self, node: "Node", index=None): + def __init__(self, node: "Node", index: Optional[int] = None) -> None: super().__init__() if index is None: self._index = node.compilation_unit.counter_slithir_reference @@ -35,13 +35,6 @@ class ReferenceVariable(ChildNode, Variable): """ return self._points_to - @property - def points_to_origin(self): - points = self.points_to - while isinstance(points, ReferenceVariable): - points = points.points_to - return points - @points_to.setter def points_to(self, points_to): # Can only be a rvalue of @@ -56,17 +49,24 @@ class ReferenceVariable(ChildNode, Variable): self._points_to = points_to @property - def name(self): + def points_to_origin(self): + points = self.points_to + while isinstance(points, ReferenceVariable): + points = points.points_to + return points + + @property + def name(self) -> str: return f"REF_{self.index}" # overide of core.variables.variables # reference can have Function has a type # to handle the function selector - def set_type(self, t): + def set_type(self, t) -> None: if not isinstance(t, Function): super().set_type(t) else: self._type = t - def __str__(self): + def __str__(self) -> str: return self.name diff --git a/slither/slithir/variables/reference_ssa.py b/slither/slithir/variables/reference_ssa.py index 3e591555d..6359b5722 100644 --- a/slither/slithir/variables/reference_ssa.py +++ b/slither/slithir/variables/reference_ssa.py @@ -3,15 +3,17 @@ It is similar to the non-SSA version of slithIR as the ReferenceVariable are in SSA form in both version """ +from typing import Union from slither.slithir.variables.reference import ReferenceVariable +from slither.slithir.variables.tuple import TupleVariable class ReferenceVariableSSA(ReferenceVariable): # pylint: disable=too-few-public-methods - def __init__(self, reference): + def __init__(self, reference: ReferenceVariable) -> None: super().__init__(reference.node, reference.index) self._non_ssa_version = reference @property - def non_ssa_version(self): + def non_ssa_version(self) -> Union[ReferenceVariable, TupleVariable]: return self._non_ssa_version diff --git a/slither/slithir/variables/state_variable.py b/slither/slithir/variables/state_variable.py index 0f92d8687..7bb3a4077 100644 --- a/slither/slithir/variables/state_variable.py +++ b/slither/slithir/variables/state_variable.py @@ -5,7 +5,7 @@ from slither.slithir.variables.variable import SlithIRVariable class StateIRVariable( StateVariable, SlithIRVariable ): # pylint: disable=too-many-instance-attributes - def __init__(self, state_variable): + def __init__(self, state_variable: StateVariable) -> None: assert isinstance(state_variable, StateVariable) super().__init__() @@ -38,7 +38,7 @@ class StateIRVariable( self._index = idx @property - def non_ssa_version(self): + def non_ssa_version(self) -> StateVariable: return self._non_ssa_version @property diff --git a/slither/slithir/variables/temporary.py b/slither/slithir/variables/temporary.py index e0a3adb26..8cb1cf350 100644 --- a/slither/slithir/variables/temporary.py +++ b/slither/slithir/variables/temporary.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from slither.core.children.child_node import ChildNode from slither.core.variables.variable import Variable @@ -8,7 +8,7 @@ if TYPE_CHECKING: class TemporaryVariable(ChildNode, Variable): - def __init__(self, node: "Node", index=None): + def __init__(self, node: "Node", index: Optional[int] = None) -> None: super().__init__() if index is None: self._index = node.compilation_unit.counter_slithir_temporary @@ -26,8 +26,8 @@ class TemporaryVariable(ChildNode, Variable): self._index = idx @property - def name(self): + def name(self) -> str: return f"TMP_{self.index}" - def __str__(self): + def __str__(self) -> str: return self.name diff --git a/slither/slithir/variables/temporary_ssa.py b/slither/slithir/variables/temporary_ssa.py index 3dc772b75..0d8fb8e3c 100644 --- a/slither/slithir/variables/temporary_ssa.py +++ b/slither/slithir/variables/temporary_ssa.py @@ -3,15 +3,18 @@ It is similar to the non-SSA version of slithIR as the TemporaryVariable are in SSA form in both version """ +from typing import Union from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.reference import ReferenceVariable +from slither.slithir.variables.tuple import TupleVariable class TemporaryVariableSSA(TemporaryVariable): # pylint: disable=too-few-public-methods - def __init__(self, temporary): + def __init__(self, temporary: TemporaryVariable) -> None: super().__init__(temporary.node, temporary.index) self._non_ssa_version = temporary @property - def non_ssa_version(self): + def non_ssa_version(self) -> Union[TemporaryVariable, TupleVariable, ReferenceVariable]: return self._non_ssa_version diff --git a/slither/slithir/variables/tuple.py b/slither/slithir/variables/tuple.py index 537f91d42..dc085347e 100644 --- a/slither/slithir/variables/tuple.py +++ b/slither/slithir/variables/tuple.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from slither.core.children.child_node import ChildNode from slither.slithir.variables.variable import SlithIRVariable @@ -8,7 +8,7 @@ if TYPE_CHECKING: class TupleVariable(ChildNode, SlithIRVariable): - def __init__(self, node: "Node", index=None): + def __init__(self, node: "Node", index: Optional[int] = None) -> None: super().__init__() if index is None: self._index = node.compilation_unit.counter_slithir_tuple @@ -27,8 +27,8 @@ class TupleVariable(ChildNode, SlithIRVariable): self._index = idx @property - def name(self): + def name(self) -> str: return f"TUPLE_{self.index}" - def __str__(self): + def __str__(self) -> str: return self.name diff --git a/slither/slithir/variables/tuple_ssa.py b/slither/slithir/variables/tuple_ssa.py index 50a4b672b..881feb1d6 100644 --- a/slither/slithir/variables/tuple_ssa.py +++ b/slither/slithir/variables/tuple_ssa.py @@ -7,7 +7,7 @@ from slither.slithir.variables.tuple import TupleVariable class TupleVariableSSA(TupleVariable): # pylint: disable=too-few-public-methods - def __init__(self, t): + def __init__(self, t: TupleVariable) -> None: super().__init__(t.node, t.index) self._non_ssa_version = t diff --git a/slither/slithir/variables/variable.py b/slither/slithir/variables/variable.py index 8e2cb145c..a1a1a6df9 100644 --- a/slither/slithir/variables/variable.py +++ b/slither/slithir/variables/variable.py @@ -2,7 +2,7 @@ from slither.core.variables.variable import Variable class SlithIRVariable(Variable): - def __init__(self): + def __init__(self) -> None: super().__init__() self._index = 0 diff --git a/slither/solc_parsing/cfg/node.py b/slither/solc_parsing/cfg/node.py index d7009fdda..b1380553b 100644 --- a/slither/solc_parsing/cfg/node.py +++ b/slither/solc_parsing/cfg/node.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict +from typing import Union, Optional, Dict, TYPE_CHECKING from slither.core.cfg.node import Node from slither.core.cfg.node import NodeType @@ -12,9 +12,13 @@ from slither.visitors.expression.find_calls import FindCalls from slither.visitors.expression.read_var import ReadVar from slither.visitors.expression.write_var import WriteVar +if TYPE_CHECKING: + from slither.solc_parsing.declarations.function import FunctionSolc + from slither.solc_parsing.declarations.modifier import ModifierSolc + class NodeSolc: - def __init__(self, node: Node): + def __init__(self, node: Node) -> None: self._unparsed_expression: Optional[Dict] = None self._node = node @@ -22,11 +26,11 @@ class NodeSolc: def underlying_node(self) -> Node: return self._node - def add_unparsed_expression(self, expression: Dict): + def add_unparsed_expression(self, expression: Dict) -> None: assert self._unparsed_expression is None self._unparsed_expression = expression - def analyze_expressions(self, caller_context): + def analyze_expressions(self, caller_context: Union["FunctionSolc", "ModifierSolc"]) -> None: if self._node.type == NodeType.VARIABLE and not self._node.expression: self._node.add_expression(self._node.variable_declaration.expression) if self._unparsed_expression: diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 0ffd24131..475c3fab2 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -1,6 +1,6 @@ import logging import re -from typing import List, Dict, Callable, TYPE_CHECKING, Union, Set +from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set from slither.core.declarations import Modifier, Event, EnumContract, StructureContract, Function from slither.core.declarations.contract import Contract @@ -22,14 +22,15 @@ LOGGER = logging.getLogger("ContractSolcParsing") if TYPE_CHECKING: from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc - from slither.core.slither_core import SlitherCore from slither.core.compilation_unit import SlitherCompilationUnit # pylint: disable=too-many-instance-attributes,import-outside-toplevel,too-many-nested-blocks,too-many-public-methods class ContractSolc(CallerContextExpression): - def __init__(self, slither_parser: "SlitherCompilationUnitSolc", contract: Contract, data): + def __init__( + self, slither_parser: "SlitherCompilationUnitSolc", contract: Contract, data: Dict[str, Any] + ) -> None: # assert slitherSolc.solc_version.startswith('0.4') self._contract = contract @@ -86,7 +87,7 @@ class ContractSolc(CallerContextExpression): def is_analyzed(self) -> bool: return self._is_analyzed - def set_is_analyzed(self, is_analyzed: bool): + def set_is_analyzed(self, is_analyzed: bool) -> None: self._is_analyzed = is_analyzed @property @@ -130,7 +131,7 @@ class ContractSolc(CallerContextExpression): def get_key(self) -> str: return self._slither_parser.get_key() - def get_children(self, key="nodes") -> str: + def get_children(self, key: str = "nodes") -> str: if self.is_compact_ast: return key return "children" @@ -150,7 +151,7 @@ class ContractSolc(CallerContextExpression): ################################################################################### ################################################################################### - def _parse_contract_info(self): + def _parse_contract_info(self) -> None: if self.is_compact_ast: attributes = self._data else: @@ -178,7 +179,7 @@ class ContractSolc(CallerContextExpression): "name" ] - def _parse_base_contract_info(self): # pylint: disable=too-many-branches + def _parse_base_contract_info(self) -> None: # pylint: disable=too-many-branches # Parse base contracts (immediate, non-linearized) if self.is_compact_ast: # Parse base contracts + constructors in compact-ast @@ -236,7 +237,7 @@ class ContractSolc(CallerContextExpression): ): self.baseConstructorContractsCalled.append(referencedDeclaration) - def _parse_contract_items(self): + def _parse_contract_items(self) -> None: # pylint: disable=too-many-branches if not self.get_children() in self._data: # empty contract return @@ -289,7 +290,7 @@ class ContractSolc(CallerContextExpression): self._contract.file_scope.user_defined_types[alias] = user_defined_type self._contract.file_scope.user_defined_types[alias_canonical] = user_defined_type - def _parse_struct(self, struct: Dict): + def _parse_struct(self, struct: Dict) -> None: st = StructureContract(self._contract.compilation_unit) st.set_contract(self._contract) @@ -299,7 +300,7 @@ class ContractSolc(CallerContextExpression): self._contract.structures_as_dict[st.name] = st self._structures_parser.append(st_parser) - def parse_structs(self): + def parse_structs(self) -> None: for father in self._contract.inheritance_reverse: self._contract.structures_as_dict.update(father.structures_as_dict) @@ -307,7 +308,7 @@ class ContractSolc(CallerContextExpression): self._parse_struct(struct) self._structuresNotParsed = None - def _parse_custom_error(self, custom_error: Dict): + def _parse_custom_error(self, custom_error: Dict) -> None: ce = CustomErrorContract(self.compilation_unit) ce.set_contract(self._contract) ce.set_offset(custom_error["src"], self.compilation_unit) @@ -316,7 +317,7 @@ class ContractSolc(CallerContextExpression): self._contract.custom_errors_as_dict[ce.name] = ce self._custom_errors_parser.append(ce_parser) - def parse_custom_errors(self): + def parse_custom_errors(self) -> None: for father in self._contract.inheritance_reverse: self._contract.custom_errors_as_dict.update(father.custom_errors_as_dict) @@ -324,7 +325,7 @@ class ContractSolc(CallerContextExpression): self._parse_custom_error(custom_error) self._customErrorParsed = None - def parse_state_variables(self): + def parse_state_variables(self) -> None: for father in self._contract.inheritance_reverse: self._contract.variables_as_dict.update( { @@ -352,7 +353,7 @@ class ContractSolc(CallerContextExpression): self._contract.variables_as_dict[var.name] = var self._contract.add_variables_ordered([var]) - def _parse_modifier(self, modifier_data: Dict): + def _parse_modifier(self, modifier_data: Dict) -> None: modif = Modifier(self._contract.compilation_unit) modif.set_offset(modifier_data["src"], self._contract.compilation_unit) modif.set_contract(self._contract) @@ -365,12 +366,12 @@ class ContractSolc(CallerContextExpression): self._slither_parser.add_function_or_modifier_parser(modif_parser) - def parse_modifiers(self): + def parse_modifiers(self) -> None: for modifier in self._modifiersNotParsed: self._parse_modifier(modifier) self._modifiersNotParsed = None - def _parse_function(self, function_data: Dict): + def _parse_function(self, function_data: Dict) -> None: func = FunctionContract(self._contract.compilation_unit) func.set_offset(function_data["src"], self._contract.compilation_unit) func.set_contract(self._contract) @@ -383,7 +384,7 @@ class ContractSolc(CallerContextExpression): self._slither_parser.add_function_or_modifier_parser(func_parser) - def parse_functions(self): + def parse_functions(self) -> None: for function in self._functionsNotParsed: self._parse_function(function) @@ -403,21 +404,21 @@ class ContractSolc(CallerContextExpression): LOGGER.error(error) self._contract.is_incorrectly_constructed = True - def analyze_content_modifiers(self): + def analyze_content_modifiers(self) -> None: try: for modifier_parser in self._modifiers_parser: modifier_parser.analyze_content() except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing modifier {e}") - def analyze_content_functions(self): + def analyze_content_functions(self) -> None: try: for function_parser in self._functions_parser: function_parser.analyze_content() except (VariableNotFound, KeyError, ParsingError) as e: self.log_incorrect_parsing(f"Missing function {e}") - def analyze_params_modifiers(self): + def analyze_params_modifiers(self) -> None: try: elements_no_params = self._modifiers_no_params getter = lambda c: c.modifiers_parser @@ -437,7 +438,7 @@ class ContractSolc(CallerContextExpression): self.log_incorrect_parsing(f"Missing params {e}") self._modifiers_no_params = [] - def analyze_params_functions(self): + def analyze_params_functions(self) -> None: try: elements_no_params = self._functions_no_params getter = lambda c: c.functions_parser @@ -465,7 +466,7 @@ class ContractSolc(CallerContextExpression): explored_reference_id: Set[str], parser: List[FunctionSolc], all_elements: Dict[str, Function], - ): + ) -> None: elem = Cls(self._contract.compilation_unit) elem.set_contract(self._contract) underlying_function = element_parser.underlying_function @@ -566,7 +567,7 @@ class ContractSolc(CallerContextExpression): self.log_incorrect_parsing(f"Missing params {e}") return all_elements - def analyze_constant_state_variables(self): + def analyze_constant_state_variables(self) -> None: for var_parser in self._variables_parser: if var_parser.underlying_variable.is_constant: # cant parse constant expression based on function calls @@ -575,7 +576,7 @@ class ContractSolc(CallerContextExpression): except (VariableNotFound, KeyError) as e: LOGGER.error(e) - def analyze_state_variables(self): + def analyze_state_variables(self) -> None: try: for var_parser in self._variables_parser: var_parser.analyze(self) @@ -583,7 +584,7 @@ class ContractSolc(CallerContextExpression): except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing state variable {e}") - def analyze_using_for(self): # pylint: disable=too-many-branches + def analyze_using_for(self) -> None: # pylint: disable=too-many-branches try: for father in self._contract.inheritance: self._contract.using_for.update(father.using_for) @@ -620,7 +621,7 @@ class ContractSolc(CallerContextExpression): except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing using for {e}") - def _analyze_function_list(self, function_list: List, type_name: Type): + def _analyze_function_list(self, function_list: List, type_name: Type) -> None: for f in function_list: full_name_split = f["function"]["name"].split(".") if len(full_name_split) == 1: @@ -639,7 +640,7 @@ class ContractSolc(CallerContextExpression): function_name = full_name_split[2] self._analyze_library_function(library_name, function_name, type_name) - def _check_aliased_import(self, first_part: str, function_name: str, type_name: Type): + def _check_aliased_import(self, first_part: str, function_name: str, type_name: Type) -> None: # We check if the first part appear as alias for an import # if it is then function_name must be a top level function # otherwise it's a library function @@ -649,7 +650,7 @@ class ContractSolc(CallerContextExpression): return self._analyze_library_function(first_part, function_name, type_name) - def _analyze_top_level_function(self, function_name: str, type_name: Type): + def _analyze_top_level_function(self, function_name: str, type_name: Type) -> None: for tl_function in self.compilation_unit.functions_top_level: if tl_function.name == function_name: self._contract.using_for[type_name].append(tl_function) @@ -673,7 +674,7 @@ class ContractSolc(CallerContextExpression): f"Contract level using for: Library {library_name} - function {function_name} not found" ) - def analyze_enums(self): + def analyze_enums(self) -> None: try: for father in self._contract.inheritance: self._contract.enums_as_dict.update(father.enums_as_dict) @@ -686,7 +687,19 @@ class ContractSolc(CallerContextExpression): except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing enum {e}") - def _analyze_enum(self, enum): + def _analyze_enum( + self, + enum: Dict[ + str, + Union[ + str, + int, + List[Dict[str, Union[int, str]]], + Dict[str, str], + List[Dict[str, Union[Dict[str, str], int, str]]], + ], + ], + ) -> None: # Enum can be parsed in one pass if self.is_compact_ast: name = enum["name"] @@ -710,21 +723,21 @@ class ContractSolc(CallerContextExpression): new_enum.set_offset(enum["src"], self._contract.compilation_unit) self._contract.enums_as_dict[canonicalName] = new_enum - def _analyze_struct(self, struct: StructureContractSolc): # pylint: disable=no-self-use + def _analyze_struct(self, struct: StructureContractSolc) -> None: # pylint: disable=no-self-use struct.analyze() - def analyze_structs(self): + def analyze_structs(self) -> None: try: for struct in self._structures_parser: self._analyze_struct(struct) except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing struct {e}") - def analyze_custom_errors(self): + def analyze_custom_errors(self) -> None: for custom_error in self._custom_errors_parser: custom_error.analyze_params() - def analyze_events(self): + def analyze_events(self) -> None: try: for father in self._contract.inheritance_reverse: self._contract.events_as_dict.update(father.events_as_dict) diff --git a/slither/solc_parsing/declarations/custom_error.py b/slither/solc_parsing/declarations/custom_error.py index 83f6f0e5c..8cd459262 100644 --- a/slither/solc_parsing/declarations/custom_error.py +++ b/slither/solc_parsing/declarations/custom_error.py @@ -22,7 +22,7 @@ class CustomErrorSolc(CallerContextExpression): custom_error: CustomError, custom_error_data: dict, slither_parser: "SlitherCompilationUnitSolc", - ): + ) -> None: self._slither_parser: "SlitherCompilationUnitSolc" = slither_parser self._custom_error = custom_error custom_error.name = custom_error_data["name"] @@ -32,7 +32,7 @@ class CustomErrorSolc(CallerContextExpression): custom_error_data = custom_error_data["attributes"] self._custom_error_data = custom_error_data - def analyze_params(self): + def analyze_params(self) -> None: # Can be re-analyzed due to inheritance if self._params_was_analyzed: return @@ -68,7 +68,7 @@ class CustomErrorSolc(CallerContextExpression): return key return "children" - def _parse_params(self, params: Dict): + def _parse_params(self, params: Dict) -> None: assert params[self.get_key()] == "ParameterList" if self._slither_parser.is_compact_ast: diff --git a/slither/solc_parsing/declarations/event.py b/slither/solc_parsing/declarations/event.py index 1f8904fc1..6531e6536 100644 --- a/slither/solc_parsing/declarations/event.py +++ b/slither/solc_parsing/declarations/event.py @@ -16,7 +16,7 @@ class EventSolc: Event class """ - def __init__(self, event: Event, event_data: Dict, contract_parser: "ContractSolc"): + def __init__(self, event: Event, event_data: Dict, contract_parser: "ContractSolc") -> None: self._event = event event.set_contract(contract_parser.underlying_contract) @@ -43,7 +43,7 @@ class EventSolc: def is_compact_ast(self) -> bool: return self._parser_contract.is_compact_ast - def analyze(self, contract: "ContractSolc"): + def analyze(self, contract: "ContractSolc") -> None: for elem_to_parse in self._elemsNotParsed: elem = EventVariable() # Todo: check if the source offset is always here diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 6b8aca51e..9671d9bbe 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Optional, Union, List, TYPE_CHECKING +from typing import Dict, Optional, Union, List, TYPE_CHECKING, Tuple from slither.core.cfg.node import NodeType, link_nodes, insert_node, Node from slither.core.cfg.scope import Scope @@ -23,10 +23,10 @@ from slither.solc_parsing.variables.local_variable_init_from_tuple import ( LocalVariableInitFromTupleSolc, ) from slither.solc_parsing.variables.variable_declaration import MultipleVariablesDeclaration -from slither.solc_parsing.yul.parse_yul import YulBlock from slither.utils.expression_manipulations import SplitTernaryExpression from slither.visitors.expression.export_values import ExportValues from slither.visitors.expression.has_conditional import HasConditional +from slither.solc_parsing.yul.parse_yul import YulBlock if TYPE_CHECKING: from slither.core.expressions.expression import Expression @@ -55,7 +55,7 @@ class FunctionSolc(CallerContextExpression): function_data: Dict, contract_parser: Optional["ContractSolc"], slither_parser: "SlitherCompilationUnitSolc", - ): + ) -> None: self._slither_parser: "SlitherCompilationUnitSolc" = slither_parser self._contract_parser = contract_parser self._function = function @@ -143,7 +143,7 @@ class FunctionSolc(CallerContextExpression): def _add_local_variable( self, local_var_parser: Union[LocalVariableSolc, LocalVariableInitFromTupleSolc] - ): + ) -> None: # If two local variables have the same name # We add a suffix to the new variable # This is done to prevent collision during SSA translation @@ -175,7 +175,7 @@ class FunctionSolc(CallerContextExpression): def function_not_parsed(self) -> Dict: return self._functionNotParsed - def _analyze_type(self): + def _analyze_type(self) -> None: """ Analyz the type of the function Myst be called in the constructor as the name might change according to the function's type @@ -201,7 +201,7 @@ class FunctionSolc(CallerContextExpression): if self._function.name == self._function.contract_declarer.name: self._function.function_type = FunctionType.CONSTRUCTOR - def _analyze_attributes(self): + def _analyze_attributes(self) -> None: if self.is_compact_ast: attributes = self._functionNotParsed else: @@ -445,7 +445,7 @@ class FunctionSolc(CallerContextExpression): def _parse_for_compact_ast( # pylint: disable=no-self-use self, statement: Dict - ) -> (Optional[Dict], Optional[Dict], Optional[Dict], Dict): + ) -> Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Dict]: body = statement["body"] init_expression = statement.get("initializationExpression", None) condition = statement.get("condition", None) @@ -455,7 +455,7 @@ class FunctionSolc(CallerContextExpression): def _parse_for_legacy_ast( self, statement: Dict - ) -> (Optional[Dict], Optional[Dict], Optional[Dict], Dict): + ) -> Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Dict]: # if we're using an old version of solc (anything below and including 0.4.11) or if the user # explicitly enabled compact ast, we might need to make some best-effort guesses children = statement[self.get_children("children")] @@ -1018,7 +1018,7 @@ class FunctionSolc(CallerContextExpression): return node - def _parse_block(self, block: Dict, node: NodeSolc, check_arithmetic: bool = False): + def _parse_block(self, block: Dict, node: NodeSolc, check_arithmetic: bool = False) -> NodeSolc: """ Return: Node @@ -1053,7 +1053,7 @@ class FunctionSolc(CallerContextExpression): node = self._parse_statement(statement, node, new_scope) return node - def _parse_cfg(self, cfg: Dict): + def _parse_cfg(self, cfg: Dict) -> None: assert cfg[self.get_key()] == "Block" @@ -1118,7 +1118,7 @@ class FunctionSolc(CallerContextExpression): return None - def _fix_break_node(self, node: Node): + def _fix_break_node(self, node: Node) -> None: end_node = self._find_end_loop(node, [], 0) if not end_node: @@ -1134,7 +1134,7 @@ class FunctionSolc(CallerContextExpression): node.set_sons([end_node]) end_node.add_father(node) - def _fix_continue_node(self, node: Node): + def _fix_continue_node(self, node: Node) -> None: start_node = self._find_start_loop(node, []) if not start_node: @@ -1145,14 +1145,14 @@ class FunctionSolc(CallerContextExpression): node.set_sons([start_node]) start_node.add_father(node) - def _fix_try(self, node: Node): + def _fix_try(self, node: Node) -> None: end_node = next((son for son in node.sons if son.type != NodeType.CATCH), None) if end_node: for son in node.sons: if son.type == NodeType.CATCH: self._fix_catch(son, end_node) - def _fix_catch(self, node: Node, end_node: Node): + def _fix_catch(self, node: Node, end_node: Node) -> None: if not node.sons: link_nodes(node, end_node) else: diff --git a/slither/solc_parsing/declarations/modifier.py b/slither/solc_parsing/declarations/modifier.py index e55487612..ea7af00b3 100644 --- a/slither/solc_parsing/declarations/modifier.py +++ b/slither/solc_parsing/declarations/modifier.py @@ -23,7 +23,7 @@ class ModifierSolc(FunctionSolc): function_data: Dict, contract_parser: "ContractSolc", slither_parser: "SlitherCompilationUnitSolc", - ): + ) -> None: super().__init__(modifier, function_data, contract_parser, slither_parser) # _modifier is equal to _function, but keep it here to prevent # confusion for mypy in underlying_function @@ -33,7 +33,7 @@ class ModifierSolc(FunctionSolc): def underlying_function(self) -> Modifier: return self._modifier - def analyze_params(self): + def analyze_params(self) -> None: # Can be re-analyzed due to inheritance if self._params_was_analyzed: return @@ -55,7 +55,7 @@ class ModifierSolc(FunctionSolc): if params: self._parse_params(params) - def analyze_content(self): + def analyze_content(self) -> None: if self._content_was_analyzed: return diff --git a/slither/solc_parsing/declarations/structure_contract.py b/slither/solc_parsing/declarations/structure_contract.py index 9c3784ea9..c48c73c4f 100644 --- a/slither/solc_parsing/declarations/structure_contract.py +++ b/slither/solc_parsing/declarations/structure_contract.py @@ -23,7 +23,7 @@ class StructureContractSolc: # pylint: disable=too-few-public-methods st: Structure, struct: Dict, contract_parser: "ContractSolc", - ): + ) -> None: if contract_parser.is_compact_ast: name = struct["name"] @@ -45,7 +45,7 @@ class StructureContractSolc: # pylint: disable=too-few-public-methods self._elemsNotParsed = children - def analyze(self): + def analyze(self) -> None: for elem_to_parse in self._elemsNotParsed: elem = StructureVariable() elem.set_structure(self._structure) diff --git a/slither/solc_parsing/declarations/structure_top_level.py b/slither/solc_parsing/declarations/structure_top_level.py index 1597ad44e..6dcca19d4 100644 --- a/slither/solc_parsing/declarations/structure_top_level.py +++ b/slither/solc_parsing/declarations/structure_top_level.py @@ -25,7 +25,7 @@ class StructureTopLevelSolc(CallerContextExpression): # pylint: disable=too-few st: StructureTopLevel, struct: Dict, slither_parser: "SlitherCompilationUnitSolc", - ): + ) -> None: if slither_parser.is_compact_ast: name = struct["name"] @@ -47,7 +47,7 @@ class StructureTopLevelSolc(CallerContextExpression): # pylint: disable=too-few self._elemsNotParsed = children - def analyze(self): + def analyze(self) -> None: for elem_to_parse in self._elemsNotParsed: elem = StructureVariable() elem.set_structure(self._structure) diff --git a/slither/solc_parsing/declarations/using_for_top_level.py b/slither/solc_parsing/declarations/using_for_top_level.py index 16e3666b0..b16fadc40 100644 --- a/slither/solc_parsing/declarations/using_for_top_level.py +++ b/slither/solc_parsing/declarations/using_for_top_level.py @@ -77,7 +77,7 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few- first_part: str, function_name: str, type_name: Union[TypeAliasTopLevel, UserDefinedType], - ): + ) -> None: # We check if the first part appear as alias for an import # if it is then function_name must be a top level function # otherwise it's a library function @@ -132,7 +132,9 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few- f"Error when propagating global using for {type_name} {type(type_name)}" ) - def _propagate_global_UserDefinedType(self, scope: FileScope, type_name: UserDefinedType): + def _propagate_global_UserDefinedType( + self, scope: FileScope, type_name: UserDefinedType + ) -> None: underlying = type_name.type if isinstance(underlying, StructureTopLevel): for struct in scope.structures.values(): diff --git a/slither/solc_parsing/expressions/expression_parsing.py b/slither/solc_parsing/expressions/expression_parsing.py index 4739c60b3..ea433a921 100644 --- a/slither/solc_parsing/expressions/expression_parsing.py +++ b/slither/solc_parsing/expressions/expression_parsing.py @@ -1,19 +1,12 @@ import logging import re -from typing import Dict, TYPE_CHECKING +from typing import Union, Dict, TYPE_CHECKING +import slither.core.expressions.type_conversion from slither.core.declarations.solidity_variables import ( SOLIDITY_VARIABLES_COMPOSED, SolidityVariableComposed, ) -from slither.core.expressions.assignment_operation import ( - AssignmentOperation, - AssignmentOperationType, -) -from slither.core.expressions.binary_operation import ( - BinaryOperation, - BinaryOperationType, -) from slither.core.expressions import ( CallExpression, ConditionalExpression, @@ -32,6 +25,14 @@ from slither.core.expressions import ( UnaryOperation, UnaryOperationType, ) +from slither.core.expressions.assignment_operation import ( + AssignmentOperation, + AssignmentOperationType, +) +from slither.core.expressions.binary_operation import ( + BinaryOperation, + BinaryOperationType, +) from slither.core.solidity_types import ( ArrayType, ElementaryType, @@ -42,8 +43,12 @@ from slither.solc_parsing.exceptions import ParsingError, VariableNotFound from slither.solc_parsing.expressions.find_variable import find_variable from slither.solc_parsing.solidity_types.type_parsing import UnknownType, parse_type + if TYPE_CHECKING: from slither.core.expressions.expression import Expression + from slither.solc_parsing.declarations.contract import ContractSolc + from slither.solc_parsing.declarations.function import FunctionSolc + from slither.solc_parsing.variables.top_level_variable import TopLevelVariableSolc logger = logging.getLogger("ExpressionParsing") @@ -98,8 +103,13 @@ def filter_name(value: str) -> str: ################################################################################### ################################################################################### - -def parse_call(expression: Dict, caller_context): # pylint: disable=too-many-statements +# pylint: disable=too-many-statements +def parse_call( + expression: Dict, caller_context: Union["FunctionSolc", "ContractSolc", "TopLevelVariableSolc"] +) -> Union[ + slither.core.expressions.call_expression.CallExpression, + slither.core.expressions.type_conversion.TypeConversion, +]: src = expression["src"] if caller_context.is_compact_ast: attributes = expression @@ -223,8 +233,7 @@ def _parse_elementary_type_name_expression( if TYPE_CHECKING: - - from slither.core.scope.scope import FileScope + pass def parse_expression(expression: Dict, caller_context: CallerContextExpression) -> "Expression": diff --git a/slither/solc_parsing/expressions/find_variable.py b/slither/solc_parsing/expressions/find_variable.py index 471f1a750..32f5afc58 100644 --- a/slither/solc_parsing/expressions/find_variable.py +++ b/slither/solc_parsing/expressions/find_variable.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: # CallerContext =Union["ContractSolc", "FunctionSolc", "CustomErrorSolc", "StructureTopLevelSolc"] -def _get_pointer_name(variable: Variable): +def _get_pointer_name(variable: Variable) -> Optional[str]: curr_type = variable.type while isinstance(curr_type, (ArrayType, MappingType)): if isinstance(curr_type, ArrayType): diff --git a/slither/solc_parsing/slither_compilation_unit_solc.py b/slither/solc_parsing/slither_compilation_unit_solc.py index 6e7de2db9..b1c2387f0 100644 --- a/slither/solc_parsing/slither_compilation_unit_solc.py +++ b/slither/solc_parsing/slither_compilation_unit_solc.py @@ -19,6 +19,7 @@ from slither.core.scope.scope import FileScope from slither.core.solidity_types import ElementaryType, TypeAliasTopLevel from slither.core.variables.top_level_variable import TopLevelVariable from slither.exceptions import SlitherException +from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.custom_error import CustomErrorSolc from slither.solc_parsing.declarations.function import FunctionSolc @@ -26,7 +27,6 @@ from slither.solc_parsing.declarations.structure_top_level import StructureTopLe from slither.solc_parsing.declarations.using_for_top_level import UsingForTopLevelSolc from slither.solc_parsing.exceptions import VariableNotFound from slither.solc_parsing.variables.top_level_variable import TopLevelVariableSolc -from slither.solc_parsing.declarations.caller_context import CallerContextExpression logging.basicConfig() logger = logging.getLogger("SlitherSolcParsing") @@ -68,7 +68,7 @@ def _handle_import_aliases( class SlitherCompilationUnitSolc(CallerContextExpression): # pylint: disable=no-self-use,too-many-instance-attributes - def __init__(self, compilation_unit: SlitherCompilationUnit): + def __init__(self, compilation_unit: SlitherCompilationUnit) -> None: super().__init__() self._contracts_by_id: Dict[int, ContractSolc] = {} @@ -98,7 +98,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression): def all_functions_and_modifiers_parser(self) -> List[FunctionSolc]: return self._all_functions_and_modifier_parser - def add_function_or_modifier_parser(self, f: FunctionSolc): + def add_function_or_modifier_parser(self, f: FunctionSolc) -> None: self._all_functions_and_modifier_parser.append(f) @property @@ -163,7 +163,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression): return True return False - def _parse_enum(self, top_level_data: Dict, filename: str): + def _parse_enum(self, top_level_data: Dict, filename: str) -> None: if self.is_compact_ast: name = top_level_data["name"] canonicalName = top_level_data["canonicalName"] @@ -192,9 +192,8 @@ class SlitherCompilationUnitSolc(CallerContextExpression): enum.set_offset(top_level_data["src"], self._compilation_unit) self._compilation_unit.enums_top_level.append(enum) - def parse_top_level_from_loaded_json( - self, data_loaded: Dict, filename: str - ): # pylint: disable=too-many-branches,too-many-statements,too-many-locals + # pylint: disable=too-many-branches,too-many-statements,too-many-locals + def parse_top_level_from_loaded_json(self, data_loaded: Dict, filename: str) -> None: if "nodeType" in data_loaded: self._is_compact_ast = True @@ -342,7 +341,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression): else: raise SlitherException(f"Top level {top_level_data[self.get_key()]} not supported") - def _parse_source_unit(self, data: Dict, filename: str): + def _parse_source_unit(self, data: Dict, filename: str) -> None: if data[self.get_key()] != "SourceUnit": return # handle solc prior 0.3.6 @@ -392,7 +391,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression): def analyzed(self) -> bool: return self._analyzed - def parse_contracts(self): # pylint: disable=too-many-statements,too-many-branches + def parse_contracts(self) -> None: # pylint: disable=too-many-statements,too-many-branches if not self._underlying_contract_to_parser: logger.info( f"No contract were found in {self._compilation_unit.core.filename}, check the correct compilation" @@ -523,7 +522,7 @@ Please rename it, this name is reserved for Slither's internals""" self._parsed = True - def analyze_contracts(self): # pylint: disable=too-many-statements,too-many-branches + def analyze_contracts(self) -> None: # pylint: disable=too-many-statements,too-many-branches if not self._parsed: raise SlitherException("Parse the contract before running analyses") self._convert_to_slithir() @@ -532,7 +531,7 @@ Please rename it, this name is reserved for Slither's internals""" self._compilation_unit.compute_storage_layout() self._analyzed = True - def _analyze_all_enums(self, contracts_to_be_analyzed: List[ContractSolc]): + def _analyze_all_enums(self, contracts_to_be_analyzed: List[ContractSolc]) -> None: while contracts_to_be_analyzed: contract = contracts_to_be_analyzed[0] @@ -551,7 +550,7 @@ Please rename it, this name is reserved for Slither's internals""" self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc], - ): + ) -> None: for lib in libraries: self._parse_struct_var_modifiers_functions(lib) @@ -578,7 +577,7 @@ Please rename it, this name is reserved for Slither's internals""" self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc], - ): + ) -> None: for lib in libraries: self._analyze_struct_events(lib) @@ -608,7 +607,7 @@ Please rename it, this name is reserved for Slither's internals""" self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc], - ): + ) -> None: for lib in libraries: self._analyze_variables_modifiers_functions(lib) @@ -633,7 +632,7 @@ Please rename it, this name is reserved for Slither's internals""" def _analyze_using_for( self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc] - ): + ) -> None: self._analyze_top_level_using_for() for lib in libraries: @@ -654,12 +653,12 @@ Please rename it, this name is reserved for Slither's internals""" else: contracts_to_be_analyzed += [contract] - def _analyze_enums(self, contract: ContractSolc): + def _analyze_enums(self, contract: ContractSolc) -> None: # Enum must be analyzed first contract.analyze_enums() contract.set_is_analyzed(True) - def _parse_struct_var_modifiers_functions(self, contract: ContractSolc): + def _parse_struct_var_modifiers_functions(self, contract: ContractSolc) -> None: contract.parse_structs() # struct can refer another struct contract.parse_state_variables() contract.parse_modifiers() @@ -667,7 +666,7 @@ Please rename it, this name is reserved for Slither's internals""" contract.parse_custom_errors() contract.set_is_analyzed(True) - def _analyze_struct_events(self, contract: ContractSolc): + def _analyze_struct_events(self, contract: ContractSolc) -> None: contract.analyze_constant_state_variables() @@ -680,41 +679,41 @@ Please rename it, this name is reserved for Slither's internals""" contract.set_is_analyzed(True) - def _analyze_top_level_structures(self): + def _analyze_top_level_structures(self) -> None: try: for struct in self._structures_top_level_parser: struct.analyze() except (VariableNotFound, KeyError) as e: raise SlitherException(f"Missing struct {e} during top level structure analyze") from e - def _analyze_top_level_variables(self): + def _analyze_top_level_variables(self) -> None: try: for var in self._variables_top_level_parser: var.analyze(var) except (VariableNotFound, KeyError) as e: raise SlitherException(f"Missing {e} during variable analyze") from e - def _analyze_params_top_level_function(self): + def _analyze_params_top_level_function(self) -> None: for func_parser in self._functions_top_level_parser: func_parser.analyze_params() self._compilation_unit.add_function(func_parser.underlying_function) - def _analyze_top_level_using_for(self): + def _analyze_top_level_using_for(self) -> None: for using_for in self._using_for_top_level_parser: using_for.analyze() - def _analyze_params_custom_error(self): + def _analyze_params_custom_error(self) -> None: for custom_error_parser in self._custom_error_parser: custom_error_parser.analyze_params() - def _analyze_content_top_level_function(self): + def _analyze_content_top_level_function(self) -> None: try: for func_parser in self._functions_top_level_parser: func_parser.analyze_content() except (VariableNotFound, KeyError) as e: raise SlitherException(f"Missing {e} during top level function analyze") from e - def _analyze_variables_modifiers_functions(self, contract: ContractSolc): + def _analyze_variables_modifiers_functions(self, contract: ContractSolc) -> None: # State variables, modifiers and functions can refer to anything contract.analyze_params_modifiers() @@ -730,7 +729,7 @@ Please rename it, this name is reserved for Slither's internals""" contract.set_is_analyzed(True) - def _convert_to_slithir(self): + def _convert_to_slithir(self) -> None: for contract in self._compilation_unit.contracts: contract.add_constructor_variables() diff --git a/slither/solc_parsing/solidity_types/type_parsing.py b/slither/solc_parsing/solidity_types/type_parsing.py index 21ce8e02a..e12290722 100644 --- a/slither/solc_parsing/solidity_types/type_parsing.py +++ b/slither/solc_parsing/solidity_types/type_parsing.py @@ -22,7 +22,7 @@ from slither.solc_parsing.exceptions import ParsingError from slither.solc_parsing.expressions.expression_parsing import CallerContextExpression if TYPE_CHECKING: - from slither.core.declarations import Structure, Enum + from slither.core.declarations import Structure, Enum, Function from slither.core.declarations.contract import Contract from slither.core.compilation_unit import SlitherCompilationUnit from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc @@ -33,11 +33,11 @@ logger = logging.getLogger("TypeParsing") class UnknownType: # pylint: disable=too-few-public-methods - def __init__(self, name): + def __init__(self, name: str) -> None: self._name = name @property - def name(self): + def name(self) -> str: return self._name @@ -195,7 +195,7 @@ def _find_from_type_name( # pylint: disable=too-many-locals,too-many-branches,t return UserDefinedType(var_type) -def _add_type_references(type_found: Type, src: str, sl: "SlitherCompilationUnit"): +def _add_type_references(type_found: Type, src: str, sl: "SlitherCompilationUnit") -> None: if isinstance(type_found, UserDefinedType): type_found.type.add_reference_from_raw_source(src, sl) @@ -236,6 +236,7 @@ def parse_type( sl: "SlitherCompilationUnit" renaming: Dict[str, str] user_defined_types: Dict[str, TypeAlias] + enums_direct_access: List["Enum"] = [] # Note: for convenicence top level functions use the same parser than function in contract # but contract_parser is set to None if isinstance(caller_context, SlitherCompilationUnitSolc) or ( @@ -257,7 +258,7 @@ def parse_type( all_structuress = [c.structures for c in sl.contracts] all_structures = [item for sublist in all_structuress for item in sublist] all_structures += structures_direct_access - enums_direct_access = sl.enums_top_level + enums_direct_access += sl.enums_top_level all_enumss = [c.enums for c in sl.contracts] all_enums = [item for sublist in all_enumss for item in sublist] all_enums += enums_direct_access @@ -319,7 +320,7 @@ def parse_type( all_structuress = [c.structures for c in contract.file_scope.contracts.values()] all_structures = [item for sublist in all_structuress for item in sublist] all_structures += contract.file_scope.structures.values() - enums_direct_access: List["Enum"] = contract.enums + enums_direct_access += contract.enums enums_direct_access += contract.file_scope.enums.values() all_enumss = [c.enums for c in contract.file_scope.contracts.values()] all_enums = [item for sublist in all_enumss for item in sublist] diff --git a/slither/solc_parsing/variables/event_variable.py b/slither/solc_parsing/variables/event_variable.py index 664b7d057..fe30b8a3a 100644 --- a/slither/solc_parsing/variables/event_variable.py +++ b/slither/solc_parsing/variables/event_variable.py @@ -14,7 +14,7 @@ class EventVariableSolc(VariableDeclarationSolc): assert isinstance(self._variable, EventVariable) return self._variable - def _analyze_variable_attributes(self, attributes: Dict): + def _analyze_variable_attributes(self, attributes: Dict) -> None: """ Analyze event variable attributes :param attributes: The event variable attributes to parse. diff --git a/slither/solc_parsing/variables/local_variable.py b/slither/solc_parsing/variables/local_variable.py index b9617a59c..cd9030d58 100644 --- a/slither/solc_parsing/variables/local_variable.py +++ b/slither/solc_parsing/variables/local_variable.py @@ -5,7 +5,7 @@ from slither.core.variables.local_variable import LocalVariable class LocalVariableSolc(VariableDeclarationSolc): - def __init__(self, variable: LocalVariable, variable_data: Dict): + def __init__(self, variable: LocalVariable, variable_data: Dict) -> None: super().__init__(variable, variable_data) @property @@ -14,7 +14,7 @@ class LocalVariableSolc(VariableDeclarationSolc): assert isinstance(self._variable, LocalVariable) return self._variable - def _analyze_variable_attributes(self, attributes: Dict): + def _analyze_variable_attributes(self, attributes: Dict) -> None: """' Variable Location Can be storage/memory or default diff --git a/slither/solc_parsing/variables/local_variable_init_from_tuple.py b/slither/solc_parsing/variables/local_variable_init_from_tuple.py index 72c57281e..1a551c695 100644 --- a/slither/solc_parsing/variables/local_variable_init_from_tuple.py +++ b/slither/solc_parsing/variables/local_variable_init_from_tuple.py @@ -5,7 +5,9 @@ from slither.core.variables.local_variable_init_from_tuple import LocalVariableI class LocalVariableInitFromTupleSolc(VariableDeclarationSolc): - def __init__(self, variable: LocalVariableInitFromTuple, variable_data: Dict, index: int): + def __init__( + self, variable: LocalVariableInitFromTuple, variable_data: Dict, index: int + ) -> None: super().__init__(variable, variable_data) variable.tuple_index = index diff --git a/slither/solc_parsing/variables/state_variable.py b/slither/solc_parsing/variables/state_variable.py index 3fa132077..a9c0ff730 100644 --- a/slither/solc_parsing/variables/state_variable.py +++ b/slither/solc_parsing/variables/state_variable.py @@ -5,7 +5,7 @@ from slither.core.variables.state_variable import StateVariable class StateVariableSolc(VariableDeclarationSolc): - def __init__(self, variable: StateVariable, variable_data: Dict): + def __init__(self, variable: StateVariable, variable_data: Dict) -> None: super().__init__(variable, variable_data) @property diff --git a/slither/solc_parsing/variables/top_level_variable.py b/slither/solc_parsing/variables/top_level_variable.py index 6c24c3bdf..56eb79c46 100644 --- a/slither/solc_parsing/variables/top_level_variable.py +++ b/slither/solc_parsing/variables/top_level_variable.py @@ -15,7 +15,7 @@ class TopLevelVariableSolc(VariableDeclarationSolc, CallerContextExpression): variable: TopLevelVariable, variable_data: Dict, slither_parser: "SlitherCompilationUnitSolc", - ): + ) -> None: super().__init__(variable, variable_data) self._slither_parser = slither_parser diff --git a/slither/solc_parsing/variables/variable_declaration.py b/slither/solc_parsing/variables/variable_declaration.py index 119604ca4..d21d89875 100644 --- a/slither/solc_parsing/variables/variable_declaration.py +++ b/slither/solc_parsing/variables/variable_declaration.py @@ -1,6 +1,6 @@ import logging import re -from typing import Dict +from typing import Dict, Optional from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.expressions.expression_parsing import parse_expression @@ -30,9 +30,8 @@ class MultipleVariablesDeclaration(Exception): class VariableDeclarationSolc: - def __init__( - self, variable: Variable, variable_data: Dict - ): # pylint: disable=too-many-branches + # pylint: disable=too-many-branches + def __init__(self, variable: Variable, variable_data: Dict) -> None: """ A variable can be declared through a statement, or directly. If it is through a statement, the following children may contain @@ -104,7 +103,7 @@ class VariableDeclarationSolc: """ return self._reference_id - def _handle_comment(self, attributes: Dict): + def _handle_comment(self, attributes: Dict) -> None: if "documentation" in attributes and "text" in attributes["documentation"]: candidates = attributes["documentation"]["text"].split(",") @@ -121,13 +120,15 @@ class VariableDeclarationSolc: self._variable.write_protection = [] self._variable.write_protection.append(write_protection.group(1)) - def _analyze_variable_attributes(self, attributes: Dict): + def _analyze_variable_attributes(self, attributes: Dict) -> None: if "visibility" in attributes: self._variable.visibility = attributes["visibility"] else: self._variable.visibility = "internal" - def _init_from_declaration(self, var: Dict, init: bool): # pylint: disable=too-many-branches + def _init_from_declaration( + self, var: Dict, init: Optional[bool] + ) -> None: # pylint: disable=too-many-branches if self._is_compact_ast: attributes = var self._typeName = attributes["typeDescriptions"]["typeString"] @@ -200,7 +201,7 @@ class VariableDeclarationSolc: self._variable.initialized = True self._initializedNotParsed = var["children"][1] - def analyze(self, caller_context: CallerContextExpression): + def analyze(self, caller_context: CallerContextExpression) -> None: # Can be re-analyzed due to inheritance if self._was_analyzed: return diff --git a/slither/solc_parsing/yul/evm_functions.py b/slither/solc_parsing/yul/evm_functions.py index 41c150765..dfb52a244 100644 --- a/slither/solc_parsing/yul/evm_functions.py +++ b/slither/solc_parsing/yul/evm_functions.py @@ -225,7 +225,7 @@ function_args = { } -def format_function_descriptor(name): +def format_function_descriptor(name: str) -> str: if name not in function_args: return name + "()" diff --git a/slither/solc_parsing/yul/parse_yul.py b/slither/solc_parsing/yul/parse_yul.py index 6be4803ca..35d5cdd9d 100644 --- a/slither/solc_parsing/yul/parse_yul.py +++ b/slither/solc_parsing/yul/parse_yul.py @@ -43,7 +43,7 @@ from slither.visitors.expression.write_var import WriteVar class YulNode: - def __init__(self, node: Node, scope: "YulScope"): + def __init__(self, node: Node, scope: "YulScope") -> None: self._node = node self._scope = scope self._unparsed_expression: Optional[Dict] = None @@ -99,7 +99,7 @@ class YulNode: ] -def link_underlying_nodes(node1: YulNode, node2: YulNode): +def link_underlying_nodes(node1: YulNode, node2: YulNode) -> None: link_nodes(node1.underlying_node, node2.underlying_node) @@ -191,7 +191,7 @@ class YulScope(metaclass=abc.ABCMeta): class YulLocalVariable: # pylint: disable=too-few-public-methods __slots__ = ["_variable", "_root"] - def __init__(self, var: LocalVariable, root: YulScope, ast: Dict): + def __init__(self, var: LocalVariable, root: YulScope, ast: Dict) -> None: assert ast["nodeType"] == "YulTypedName" self._variable = var @@ -215,7 +215,7 @@ class YulFunction(YulScope): def __init__( self, func: Function, root: YulScope, ast: Dict, node_scope: Union[Function, Scope] - ): + ) -> None: super().__init__(root.contract, root.id + [ast["name"]], parent_func=root.parent_func) assert ast["nodeType"] == "YulFunctionDefinition" @@ -272,7 +272,7 @@ class YulFunction(YulScope): for node in self._nodes: node.analyze_expressions() - def new_node(self, node_type, src) -> YulNode: + def new_node(self, node_type: NodeType, src: str) -> YulNode: if self._function: node = self._function.new_node(node_type, src, self.node_scope) else: @@ -299,7 +299,7 @@ class YulBlock(YulScope): entrypoint: Node, yul_id: List[str], node_scope: Union[Scope, Function], - ): + ) -> None: super().__init__(contract, yul_id, entrypoint.function) self._entrypoint: YulNode = YulNode(entrypoint, self) @@ -884,7 +884,9 @@ def vars_to_typestr(rets: List[Expression]) -> str: return f"tuple({','.join(str(ret.type) for ret in rets)})" -def vars_to_val(vars_to_convert): +def vars_to_val( + vars_to_convert: List[Identifier], +) -> Identifier: if len(vars_to_convert) == 1: return vars_to_convert[0] return TupleExpression(vars_to_convert) diff --git a/slither/tools/doctor/__main__.py b/slither/tools/doctor/__main__.py index b9b4c5497..f401781a7 100644 --- a/slither/tools/doctor/__main__.py +++ b/slither/tools/doctor/__main__.py @@ -26,7 +26,7 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def main(): +def main() -> None: # log on stdout to keep output in order logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True) diff --git a/slither/tools/erc_conformance/__main__.py b/slither/tools/erc_conformance/__main__.py index ef594a7c6..1c9224eac 100644 --- a/slither/tools/erc_conformance/__main__.py +++ b/slither/tools/erc_conformance/__main__.py @@ -1,15 +1,17 @@ import argparse import logging from collections import defaultdict -from typing import Any, Dict, List +from typing import Any, Dict, List, Callable from crytic_compile import cryticparser + from slither import Slither +from slither.core.declarations import Contract from slither.utils.erc import ERCS from slither.utils.output import output_to_json -from .erc.ercs import generic_erc_checks -from .erc.erc20 import check_erc20 from .erc.erc1155 import check_erc1155 +from .erc.erc20 import check_erc20 +from .erc.ercs import generic_erc_checks logging.basicConfig() logging.getLogger("Slither").setLevel(logging.INFO) @@ -24,7 +26,10 @@ logger.addHandler(ch) logger.handlers[0].setFormatter(formatter) logger.propagate = False -ADDITIONAL_CHECKS = {"ERC20": check_erc20, "ERC1155": check_erc1155} +ADDITIONAL_CHECKS: Dict[str, Callable[[Contract, Dict[str, List]], Dict[str, List]]] = { + "ERC20": check_erc20, + "ERC1155": check_erc1155, +} def parse_args() -> argparse.Namespace: diff --git a/slither/tools/erc_conformance/erc/erc1155.py b/slither/tools/erc_conformance/erc/erc1155.py index fceb4e242..34bb18bfb 100644 --- a/slither/tools/erc_conformance/erc/erc1155.py +++ b/slither/tools/erc_conformance/erc/erc1155.py @@ -1,12 +1,14 @@ import logging +from typing import Dict, List, Optional +from slither.core.declarations import Contract from slither.slithir.operations import EventCall from slither.utils import output logger = logging.getLogger("Slither-conformance") -def events_safeBatchTransferFrom(contract, ret): +def events_safeBatchTransferFrom(contract: Contract, ret: Dict[str, List]) -> None: function = contract.get_function_from_signature( "safeBatchTransferFrom(address,address,uint256[],uint256[],bytes)" ) @@ -44,7 +46,9 @@ def events_safeBatchTransferFrom(contract, ret): ) -def check_erc1155(contract, ret, explored=None): +def check_erc1155( + contract: Contract, ret: Dict[str, List], explored: Optional[bool] = None +) -> Dict[str, List]: if explored is None: explored = set() diff --git a/slither/tools/erc_conformance/erc/erc20.py b/slither/tools/erc_conformance/erc/erc20.py index 720b08322..6ee243515 100644 --- a/slither/tools/erc_conformance/erc/erc20.py +++ b/slither/tools/erc_conformance/erc/erc20.py @@ -1,11 +1,13 @@ import logging +from typing import Dict, List, Optional +from slither.core.declarations import Contract from slither.utils import output logger = logging.getLogger("Slither-conformance") -def approval_race_condition(contract, ret): +def approval_race_condition(contract: Contract, ret: Dict[str, List]) -> None: increaseAllowance = contract.get_function_from_signature("increaseAllowance(address,uint256)") if not increaseAllowance: @@ -27,7 +29,9 @@ def approval_race_condition(contract, ret): ) -def check_erc20(contract, ret, explored=None): +def check_erc20( + contract: Contract, ret: Dict[str, List], explored: Optional[bool] = None +) -> Dict[str, List]: if explored is None: explored = set() diff --git a/slither/tools/kspec_coverage/analysis.py b/slither/tools/kspec_coverage/analysis.py index 3d513d22f..763939e25 100755 --- a/slither/tools/kspec_coverage/analysis.py +++ b/slither/tools/kspec_coverage/analysis.py @@ -1,11 +1,13 @@ -import re import logging -from typing import Set, Tuple +import re +from argparse import Namespace +from typing import Set, Tuple, List, Dict, Union, Optional, Callable -from slither.core.declarations import Function -from slither.core.variables.variable import Variable -from slither.utils.colors import yellow, green, red +from slither.core.compilation_unit import SlitherCompilationUnit +from slither.core.declarations import FunctionContract +from slither.core.variables.state_variable import StateVariable from slither.utils import output +from slither.utils.colors import yellow, green, red logging.basicConfig(level=logging.WARNING) logger = logging.getLogger("Slither.kspec") @@ -54,13 +56,15 @@ def _get_all_covered_kspec_functions(target: str) -> Set[Tuple[str, str]]: return covered_functions -def _get_slither_functions(slither): +def _get_slither_functions( + slither: SlitherCompilationUnit, +) -> Dict[Tuple[str, str], Union[FunctionContract, StateVariable]]: # Use contract == contract_declarer to avoid dupplicate - all_functions_declared = [ + all_functions_declared: List[Union[FunctionContract, StateVariable]] = [ f for f in slither.functions if ( - f.contract == f.contract_declarer + (isinstance(f, FunctionContract) and f.contract == f.contract_declarer) and f.is_implemented and not f.is_constructor and not f.is_constructor_variables @@ -79,7 +83,12 @@ def _get_slither_functions(slither): return slither_functions -def _generate_output(kspec, message, color, generate_json): +def _generate_output( + kspec: List[Union[FunctionContract, StateVariable]], + message: str, + color: Callable[[str], str], + generate_json: bool, +) -> Optional[Dict]: info = "" for function in kspec: info += f"{message} {function.contract.name}.{function.full_name}\n" @@ -94,7 +103,9 @@ def _generate_output(kspec, message, color, generate_json): return None -def _generate_output_unresolved(kspec, message, color, generate_json): +def _generate_output_unresolved( + kspec: Set[Tuple[str, str]], message: str, color: Callable[[str], str], generate_json: bool +) -> Optional[Dict]: info = "" for contract, function in kspec: info += f"{message} {contract}.{function}\n" @@ -107,17 +118,19 @@ def _generate_output_unresolved(kspec, message, color, generate_json): return None -def _run_coverage_analysis(args, slither, kspec_functions): +def _run_coverage_analysis( + args: Namespace, slither: SlitherCompilationUnit, kspec_functions: Set[Tuple[str, str]] +) -> None: # Collect all slither functions slither_functions = _get_slither_functions(slither) # Determine which klab specs were not resolved. slither_functions_set = set(slither_functions) kspec_functions_resolved = kspec_functions & slither_functions_set - kspec_functions_unresolved = kspec_functions - kspec_functions_resolved + kspec_functions_unresolved: Set[Tuple[str, str]] = kspec_functions - kspec_functions_resolved - kspec_missing = [] - kspec_present = [] + kspec_missing: List[Union[FunctionContract, StateVariable]] = [] + kspec_present: List[Union[FunctionContract, StateVariable]] = [] for slither_func_desc in sorted(slither_functions_set): slither_func = slither_functions[slither_func_desc] @@ -130,13 +143,13 @@ def _run_coverage_analysis(args, slither, kspec_functions): logger.info("## Check for functions coverage") json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json) json_kspec_missing_functions = _generate_output( - [f for f in kspec_missing if isinstance(f, Function)], + [f for f in kspec_missing if isinstance(f, FunctionContract)], "[ ] (Missing function)", red, args.json, ) json_kspec_missing_variables = _generate_output( - [f for f in kspec_missing if isinstance(f, Variable)], + [f for f in kspec_missing if isinstance(f, StateVariable)], "[ ] (Missing variable)", yellow, args.json, @@ -159,11 +172,11 @@ def _run_coverage_analysis(args, slither, kspec_functions): ) -def run_analysis(args, slither, kspec_arg): +def run_analysis(args: Namespace, slither: SlitherCompilationUnit, kspec_arg: str) -> None: # Get all of our kspec'd functions (tuple(contract_name, function_name)). if "," in kspec_arg: kspecs = kspec_arg.split(",") - kspec_functions = set() + kspec_functions: Set[Tuple[str, str]] = set() for kspec in kspecs: kspec_functions |= _get_all_covered_kspec_functions(kspec) else: diff --git a/slither/tools/mutator/__main__.py b/slither/tools/mutator/__main__.py index 78b86d681..27e396d0b 100644 --- a/slither/tools/mutator/__main__.py +++ b/slither/tools/mutator/__main__.py @@ -72,7 +72,7 @@ class ListMutators(argparse.Action): # pylint: disable=too-few-public-methods ################################################################################### -def main(): +def main() -> None: args = parse_args() diff --git a/slither/tools/mutator/mutators/MIA.py b/slither/tools/mutator/mutators/MIA.py index 54ca0ec1c..405888f8b 100644 --- a/slither/tools/mutator/mutators/MIA.py +++ b/slither/tools/mutator/mutators/MIA.py @@ -1,3 +1,5 @@ +from typing import Dict + from slither.core.cfg.node import NodeType from slither.formatters.utils.patches import create_patch from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass @@ -9,13 +11,13 @@ class MIA(AbstractMutator): # pylint: disable=too-few-public-methods FAULTCLASS = FaultClass.Checking FAULTNATURE = FaultNature.Missing - def _mutate(self): + def _mutate(self) -> Dict: - result = {} + result: Dict = {} for contract in self.slither.contracts: - for function in contract.functions_declared + contract.modifiers_declared: + for function in contract.functions_declared + list(contract.modifiers_declared): for node in function.nodes: if node.type == NodeType.IF: diff --git a/slither/tools/mutator/mutators/MVIE.py b/slither/tools/mutator/mutators/MVIE.py index 8f8cc11bf..a16a8252e 100644 --- a/slither/tools/mutator/mutators/MVIE.py +++ b/slither/tools/mutator/mutators/MVIE.py @@ -1,4 +1,7 @@ +from typing import Dict + from slither.core.expressions import Literal +from slither.core.variables.variable import Variable from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass from slither.tools.mutator.utils.generic_patching import remove_assignement @@ -9,10 +12,10 @@ class MVIE(AbstractMutator): # pylint: disable=too-few-public-methods FAULTCLASS = FaultClass.Assignement FAULTNATURE = FaultNature.Missing - def _mutate(self): - - result = {} + def _mutate(self) -> Dict: + result: Dict = {} + variable: Variable for contract in self.slither.contracts: # Create fault for state variables declaration @@ -25,7 +28,7 @@ class MVIE(AbstractMutator): # pylint: disable=too-few-public-methods if not isinstance(variable.expression, Literal): remove_assignement(variable, contract, result) - for function in contract.functions_declared + contract.modifiers_declared: + for function in contract.functions_declared + list(contract.modifiers_declared): for variable in function.local_variables: if variable.initialized and not isinstance(variable.expression, Literal): remove_assignement(variable, contract, result) diff --git a/slither/tools/mutator/mutators/MVIV.py b/slither/tools/mutator/mutators/MVIV.py index dac34da28..d4a7c5486 100644 --- a/slither/tools/mutator/mutators/MVIV.py +++ b/slither/tools/mutator/mutators/MVIV.py @@ -1,4 +1,7 @@ +from typing import Dict + from slither.core.expressions import Literal +from slither.core.variables.variable import Variable from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator, FaultNature, FaultClass from slither.tools.mutator.utils.generic_patching import remove_assignement @@ -9,9 +12,10 @@ class MVIV(AbstractMutator): # pylint: disable=too-few-public-methods FAULTCLASS = FaultClass.Assignement FAULTNATURE = FaultNature.Missing - def _mutate(self): + def _mutate(self) -> Dict: - result = {} + result: Dict = {} + variable: Variable for contract in self.slither.contracts: @@ -25,7 +29,7 @@ class MVIV(AbstractMutator): # pylint: disable=too-few-public-methods if isinstance(variable.expression, Literal): remove_assignement(variable, contract, result) - for function in contract.functions_declared + contract.modifiers_declared: + for function in contract.functions_declared + list(contract.modifiers_declared): for variable in function.local_variables: if variable.initialized and isinstance(variable.expression, Literal): remove_assignement(variable, contract, result) diff --git a/slither/tools/mutator/utils/command_line.py b/slither/tools/mutator/utils/command_line.py index 9799fd488..840976ccf 100644 --- a/slither/tools/mutator/utils/command_line.py +++ b/slither/tools/mutator/utils/command_line.py @@ -1,7 +1,10 @@ +from typing import List, Type + +from slither.tools.mutator.mutators.abstract_mutator import AbstractMutator from slither.utils.myprettytable import MyPrettyTable -def output_mutators(mutators_classes): +def output_mutators(mutators_classes: List[Type[AbstractMutator]]) -> None: mutators_list = [] for detector in mutators_classes: argument = detector.NAME diff --git a/slither/tools/similarity/encode.py b/slither/tools/similarity/encode.py index d08086282..48700ec4a 100644 --- a/slither/tools/similarity/encode.py +++ b/slither/tools/similarity/encode.py @@ -1,5 +1,6 @@ import logging import os +from typing import Optional, Tuple, List from slither import Slither from slither.core.declarations import ( @@ -60,7 +61,7 @@ slither_logger = logging.getLogger("Slither") slither_logger.setLevel(logging.CRITICAL) -def parse_target(target): +def parse_target(target: Optional[str]) -> Tuple[Optional[str], Optional[str]]: if target is None: return None, None @@ -68,9 +69,9 @@ def parse_target(target): if len(parts) == 1: return None, parts[0] if len(parts) == 2: - return parts + return parts[0], parts[1] simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'") - return None + return None, None def load_and_encode(infile: str, vmodel, ext=None, nsamples=None, **kwargs): @@ -88,7 +89,9 @@ def load_and_encode(infile: str, vmodel, ext=None, nsamples=None, **kwargs): return r -def load_contracts(dirname, ext=None, nsamples=None): +def load_contracts( + dirname: str, ext: Optional[str] = None, nsamples: Optional[int] = None +) -> List[str]: r = [] walk = list(os.walk(dirname)) for x, y, files in walk: diff --git a/slither/tools/similarity/test.py b/slither/tools/similarity/test.py index 76229d5bf..7d42c4a63 100755 --- a/slither/tools/similarity/test.py +++ b/slither/tools/similarity/test.py @@ -2,6 +2,7 @@ import logging import operator import sys import traceback +from argparse import Namespace from slither.tools.similarity.encode import encode_contract, load_and_encode, parse_target from slither.tools.similarity.model import load_model @@ -10,7 +11,7 @@ from slither.tools.similarity.similarity import similarity logger = logging.getLogger("Slither-simil") -def test(args): +def test(args: Namespace) -> None: try: model = args.model diff --git a/slither/utils/comparable_enum.py b/slither/utils/comparable_enum.py index 63e476d6a..d6f06bbbd 100644 --- a/slither/utils/comparable_enum.py +++ b/slither/utils/comparable_enum.py @@ -1,26 +1,27 @@ from enum import Enum # pylint: disable=comparison-with-callable +from typing import Any class ComparableEnum(Enum): - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, ComparableEnum): return self.value == other.value return False - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: if isinstance(other, ComparableEnum): return self.value != other.value return False - def __lt__(self, other): + def __lt__(self, other: Any) -> bool: if isinstance(other, ComparableEnum): return self.value < other.value return False - def __repr__(self): + def __repr__(self) -> str: return f"{str(self.value)}" - def __hash__(self): + def __hash__(self) -> int: return hash(self.value) diff --git a/slither/utils/expression_manipulations.py b/slither/utils/expression_manipulations.py index 0f3750600..753778be9 100644 --- a/slither/utils/expression_manipulations.py +++ b/slither/utils/expression_manipulations.py @@ -4,6 +4,8 @@ """ import copy from typing import Union, Callable + +from slither.all_exceptions import SlitherException from slither.core.expressions import UnaryOperation from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.binary_operation import BinaryOperation @@ -19,7 +21,7 @@ from slither.core.expressions.new_array import NewArray from slither.core.expressions.new_contract import NewContract from slither.core.expressions.tuple_expression import TupleExpression from slither.core.expressions.type_conversion import TypeConversion -from slither.all_exceptions import SlitherException + # pylint: disable=protected-access def f_expressions( @@ -29,7 +31,7 @@ def f_expressions( e._expressions.append(x) -def f_call(e: CallExpression, x): +def f_call(e: CallExpression, x: ElementaryTypeNameExpression) -> None: e._arguments.append(x) @@ -41,11 +43,11 @@ def f_call_gas(e: CallExpression, x): e._gas = x -def f_expression(e: Union[TypeConversion, UnaryOperation, MemberAccess], x): +def f_expression(e: Union[TypeConversion, UnaryOperation, MemberAccess], x: CallExpression) -> None: e._expression = x -def f_called(e: CallExpression, x): +def f_called(e: CallExpression, x: Identifier) -> None: e._called = x diff --git a/slither/utils/output.py b/slither/utils/output.py index 5db6492db..9dba15e31 100644 --- a/slither/utils/output.py +++ b/slither/utils/output.py @@ -1,11 +1,12 @@ import hashlib -import os import json import logging +import os import zipfile from collections import OrderedDict -from typing import Optional, Dict, List, Union, Any, TYPE_CHECKING, Type +from typing import Tuple, Optional, Dict, List, Union, Any, TYPE_CHECKING, Type from zipfile import ZipFile + from pkg_resources import require from slither.core.cfg.node import Node @@ -218,7 +219,7 @@ def output_to_zip(filename: str, error: Optional[str], results: Dict, zip_type: ################################################################################### -def _convert_to_description(d): +def _convert_to_description(d: str) -> str: if isinstance(d, str): return d @@ -239,7 +240,7 @@ def _convert_to_description(d): raise SlitherError(f"{type(d)} cannot be converted (no name, or canonical_name") -def _convert_to_markdown(d, markdown_root): +def _convert_to_markdown(d: str, markdown_root: str) -> str: if isinstance(d, str): return d @@ -260,7 +261,7 @@ def _convert_to_markdown(d, markdown_root): raise SlitherError(f"{type(d)} cannot be converted (no name, or canonical_name") -def _convert_to_id(d): +def _convert_to_id(d: str) -> str: """ Id keeps the source mapping of the node, otherwise we risk to consider two different node as the same :param d: @@ -298,8 +299,35 @@ def _convert_to_id(d): def _create_base_element( - custom_type, name, source_mapping: Dict, type_specific_fields=None, additional_fields=None -): + custom_type: str, + name: str, + source_mapping: Dict, + type_specific_fields: Optional[ + Dict[ + str, + Union[ + Dict[ + str, + Union[ + str, + Dict[str, Union[int, str, bool, List[int]]], + Dict[ + str, + Union[ + Dict[str, Union[str, Dict[str, Union[int, str, bool, List[int]]]]], + str, + ], + ], + ], + ], + Dict[str, Union[str, Dict[str, Union[int, str, bool, List[int]]]]], + str, + List[str], + ], + ] + ] = None, + additional_fields: Optional[Dict[str, str]] = None, +) -> Dict[str, Any]: if additional_fields is None: additional_fields = {} if type_specific_fields is None: @@ -312,7 +340,16 @@ def _create_base_element( return element -def _create_parent_element(element): +def _create_parent_element( + element: SourceMapping, +) -> Dict[ + str, + Union[ + str, + Dict[str, Union[int, str, bool, List[int]]], + Dict[str, Union[Dict[str, Union[str, Dict[str, Union[int, str, bool, List[int]]]]], str]], + ], +]: # pylint: disable=import-outside-toplevel from slither.core.children.child_contract import ChildContract from slither.core.children.child_function import ChildFunction @@ -345,9 +382,9 @@ class Output: self, info_: Union[str, List[Union[str, SupportedOutput]]], additional_fields: Optional[Dict] = None, - markdown_root="", - standard_format=True, - ): + markdown_root: str = "", + standard_format: bool = True, + ) -> None: if additional_fields is None: additional_fields = {} @@ -377,7 +414,7 @@ class Output: if additional_fields: self._data["additional_fields"] = additional_fields - def add(self, add: SupportedOutput, additional_fields: Optional[Dict] = None): + def add(self, add: SupportedOutput, additional_fields: Optional[Dict] = None) -> None: if not self._data["first_markdown_element"]: self._data["first_markdown_element"] = add.source_mapping.to_markdown( self._markdown_root @@ -416,7 +453,7 @@ class Output: ################################################################################### ################################################################################### - def add_variable(self, variable: Variable, additional_fields: Optional[Dict] = None): + def add_variable(self, variable: Variable, additional_fields: Optional[Dict] = None) -> None: if additional_fields is None: additional_fields = {} type_specific_fields = {"parent": _create_parent_element(variable)} @@ -440,7 +477,7 @@ class Output: ################################################################################### ################################################################################### - def add_contract(self, contract: Contract, additional_fields: Optional[Dict] = None): + def add_contract(self, contract: Contract, additional_fields: Optional[Dict] = None) -> None: if additional_fields is None: additional_fields = {} element = _create_base_element( @@ -455,7 +492,7 @@ class Output: ################################################################################### ################################################################################### - def add_function(self, function: Function, additional_fields: Optional[Dict] = None): + def add_function(self, function: Function, additional_fields: Optional[Dict] = None) -> None: if additional_fields is None: additional_fields = {} type_specific_fields = { @@ -484,7 +521,7 @@ class Output: ################################################################################### ################################################################################### - def add_enum(self, enum: Enum, additional_fields: Optional[Dict] = None): + def add_enum(self, enum: Enum, additional_fields: Optional[Dict] = None) -> None: if additional_fields is None: additional_fields = {} type_specific_fields = {"parent": _create_parent_element(enum)} @@ -504,7 +541,7 @@ class Output: ################################################################################### ################################################################################### - def add_struct(self, struct: Structure, additional_fields: Optional[Dict] = None): + def add_struct(self, struct: Structure, additional_fields: Optional[Dict] = None) -> None: if additional_fields is None: additional_fields = {} type_specific_fields = {"parent": _create_parent_element(struct)} @@ -524,7 +561,7 @@ class Output: ################################################################################### ################################################################################### - def add_event(self, event: Event, additional_fields: Optional[Dict] = None): + def add_event(self, event: Event, additional_fields: Optional[Dict] = None) -> None: if additional_fields is None: additional_fields = {} type_specific_fields = { @@ -548,7 +585,7 @@ class Output: ################################################################################### ################################################################################### - def add_node(self, node: Node, additional_fields: Optional[Dict] = None): + def add_node(self, node: Node, additional_fields: Optional[Dict] = None) -> None: if additional_fields is None: additional_fields = {} type_specific_fields = { @@ -575,7 +612,7 @@ class Output: ################################################################################### ################################################################################### - def add_pragma(self, pragma: Pragma, additional_fields: Optional[Dict] = None): + def add_pragma(self, pragma: Pragma, additional_fields: Optional[Dict] = None) -> None: if additional_fields is None: additional_fields = {} type_specific_fields = {"directive": pragma.directive} @@ -633,10 +670,10 @@ class Output: def add_other( self, name: str, - source_mapping, + source_mapping: Tuple[str, int, int], compilation_unit: "SlitherCompilationUnit", additional_fields: Optional[Dict] = None, - ): + ) -> None: # If this a tuple with (filename, start, end), convert it to a source mapping. if additional_fields is None: additional_fields = {} diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index 797d1f46e..5f419ef99 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -1,4 +1,6 @@ from fractions import Fraction +from typing import Union, TYPE_CHECKING + from slither.core.expressions import ( BinaryOperationType, Literal, @@ -6,10 +8,16 @@ from slither.core.expressions import ( Identifier, BinaryOperation, UnaryOperation, + TupleExpression, + TypeConversion, ) + from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int from slither.visitors.expression.expression import ExpressionVisitor +if TYPE_CHECKING: + from slither.core.solidity_types.elementary_type import ElementaryType + class NotConstant(Exception): pass @@ -17,24 +25,30 @@ class NotConstant(Exception): KEY = "ConstantFolding" +CONSTANT_TYPES_OPERATIONS = Union[ + Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion +] + -def get_val(expression): +def get_val(expression: CONSTANT_TYPES_OPERATIONS) -> Union[bool, int, Fraction, str]: val = expression.context[KEY] # we delete the item to reduce memory use del expression.context[KEY] return val -def set_val(expression, val): +def set_val(expression: CONSTANT_TYPES_OPERATIONS, val: Union[bool, int, Fraction, str]) -> None: expression.context[KEY] = val class ConstantFolding(ExpressionVisitor): - def __init__(self, expression, custom_type): + def __init__( + self, expression: CONSTANT_TYPES_OPERATIONS, custom_type: Union[str, "ElementaryType"] + ) -> None: self._type = custom_type super().__init__(expression) - def result(self): + def result(self) -> "Literal": value = get_val(self._expression) if isinstance(value, Fraction): value = int(value) @@ -43,7 +57,7 @@ class ConstantFolding(ExpressionVisitor): value = value & (2**256 - 1) return Literal(value, self._type) - def _post_identifier(self, expression: Identifier): + def _post_identifier(self, expression: Identifier) -> None: if not expression.value.is_constant: raise NotConstant expr = expression.value.expression @@ -54,7 +68,7 @@ class ConstantFolding(ExpressionVisitor): set_val(expression, convert_string_to_int(expr.converted_value)) # pylint: disable=too-many-branches - def _post_binary_operation(self, expression: BinaryOperation): + def _post_binary_operation(self, expression: BinaryOperation) -> None: left = get_val(expression.expression_left) right = get_val(expression.expression_right) if expression.type == BinaryOperationType.POWER: @@ -100,7 +114,7 @@ class ConstantFolding(ExpressionVisitor): else: raise NotConstant - def _post_unary_operation(self, expression: UnaryOperation): + def _post_unary_operation(self, expression: UnaryOperation) -> None: # Case of uint a = -7; uint[-a] arr; if expression.type == UnaryOperationType.MINUS_PRE: expr = expression.expression @@ -112,7 +126,7 @@ class ConstantFolding(ExpressionVisitor): else: raise NotConstant - def _post_literal(self, expression: Literal): + def _post_literal(self, expression: Literal) -> None: if expression.converted_value in ["true", "false"]: set_val(expression, expression.converted_value) else: diff --git a/slither/visitors/expression/export_values.py b/slither/visitors/expression/export_values.py index 246cc8384..f5ca39a96 100644 --- a/slither/visitors/expression/export_values.py +++ b/slither/visitors/expression/export_values.py @@ -1,21 +1,31 @@ +from typing import Any, List from slither.visitors.expression.expression import ExpressionVisitor +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.literal import Literal +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.expression import Expression +from slither.core.expressions.member_access import MemberAccess +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.type_conversion import TypeConversion + key = "ExportValues" -def get(expression): +def get(expression: Expression) -> List[Any]: val = expression.context[key] # we delete the item to reduce memory use del expression.context[key] return val -def set_val(expression, val): +def set_val(expression: Expression, val: List[Any]) -> None: expression.context[key] = val class ExportValues(ExpressionVisitor): - def result(self): + def result(self) -> List[Any]: if self._result is None: self._result = list(set(get(self.expression))) return self._result @@ -26,13 +36,13 @@ class ExportValues(ExpressionVisitor): val = left + right set_val(expression, val) - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: BinaryOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right set_val(expression, val) - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression.called) args = [get(a) for a in expression.arguments if a] args = [item for sublist in args for item in sublist] @@ -49,7 +59,7 @@ class ExportValues(ExpressionVisitor): def _post_elementary_type_name_expression(self, expression): set_val(expression, []) - def _post_identifier(self, expression): + def _post_identifier(self, expression: Identifier) -> None: set_val(expression, [expression.value]) def _post_index_access(self, expression): @@ -58,10 +68,10 @@ class ExportValues(ExpressionVisitor): val = left + right set_val(expression, val) - def _post_literal(self, expression): + def _post_literal(self, expression: Literal) -> None: set_val(expression, []) - def _post_member_access(self, expression): + def _post_member_access(self, expression: MemberAccess) -> None: expr = get(expression.expression) val = expr set_val(expression, val) @@ -75,12 +85,12 @@ class ExportValues(ExpressionVisitor): def _post_new_elementary_type(self, expression): set_val(expression, []) - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: TupleExpression) -> None: expressions = [get(e) for e in expression.expressions if e] val = [item for sublist in expressions for item in sublist] set_val(expression, val) - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: TypeConversion) -> None: expr = get(expression.expression) val = expr set_val(expression, val) diff --git a/slither/visitors/expression/expression.py b/slither/visitors/expression/expression.py index 17020aaba..464ea1285 100644 --- a/slither/visitors/expression/expression.py +++ b/slither/visitors/expression/expression.py @@ -1,5 +1,5 @@ import logging -from typing import Any +from typing import Optional, Any from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.binary_operation import BinaryOperation @@ -23,13 +23,13 @@ logger = logging.getLogger("ExpressionVisitor") class ExpressionVisitor: - def __init__(self, expression: Expression): + def __init__(self, expression: Expression) -> None: # Inherited class must declared their variables prior calling super().__init__ self._expression = expression self._result: Any = None self._visit_expression(self.expression) - def result(self): + def result(self) -> Optional[bool]: return self._result @property @@ -38,7 +38,8 @@ class ExpressionVisitor: # visit an expression # call pre_visit, visit_expression_name, post_visit - def _visit_expression(self, expression: Expression): # pylint: disable=too-many-branches + # pylint: disable=too-many-branches + def _visit_expression(self, expression: Expression) -> None: self._pre_visit(expression) if isinstance(expression, AssignmentOperation): @@ -96,15 +97,15 @@ class ExpressionVisitor: # visit_expression_name - def _visit_assignement_operation(self, expression): + def _visit_assignement_operation(self, expression: AssignmentOperation) -> None: self._visit_expression(expression.expression_left) self._visit_expression(expression.expression_right) - def _visit_binary_operation(self, expression): + def _visit_binary_operation(self, expression: BinaryOperation) -> None: self._visit_expression(expression.expression_left) self._visit_expression(expression.expression_right) - def _visit_call_expression(self, expression): + def _visit_call_expression(self, expression: CallExpression) -> None: self._visit_expression(expression.called) for arg in expression.arguments: if arg: @@ -116,50 +117,52 @@ class ExpressionVisitor: if expression.call_salt: self._visit_expression(expression.call_salt) - def _visit_conditional_expression(self, expression): + def _visit_conditional_expression(self, expression: ConditionalExpression) -> None: self._visit_expression(expression.if_expression) self._visit_expression(expression.else_expression) self._visit_expression(expression.then_expression) - def _visit_elementary_type_name_expression(self, expression): + def _visit_elementary_type_name_expression( + self, expression: ElementaryTypeNameExpression + ) -> None: pass - def _visit_identifier(self, expression): + def _visit_identifier(self, expression: Identifier) -> None: pass - def _visit_index_access(self, expression): + def _visit_index_access(self, expression: IndexAccess) -> None: self._visit_expression(expression.expression_left) self._visit_expression(expression.expression_right) - def _visit_literal(self, expression): + def _visit_literal(self, expression: Literal) -> None: pass - def _visit_member_access(self, expression): + def _visit_member_access(self, expression: MemberAccess) -> None: self._visit_expression(expression.expression) - def _visit_new_array(self, expression): + def _visit_new_array(self, expression: NewArray) -> None: pass - def _visit_new_contract(self, expression): + def _visit_new_contract(self, expression: NewContract) -> None: pass def _visit_new_elementary_type(self, expression): pass - def _visit_tuple_expression(self, expression): + def _visit_tuple_expression(self, expression: TupleExpression) -> None: for e in expression.expressions: if e: self._visit_expression(e) - def _visit_type_conversion(self, expression): + def _visit_type_conversion(self, expression: TypeConversion) -> None: self._visit_expression(expression.expression) - def _visit_unary_operation(self, expression): + def _visit_unary_operation(self, expression: UnaryOperation) -> None: self._visit_expression(expression.expression) # pre visit - def _pre_visit(self, expression): # pylint: disable=too-many-branches + def _pre_visit(self, expression) -> None: # pylint: disable=too-many-branches if isinstance(expression, AssignmentOperation): self._pre_assignement_operation(expression) @@ -213,54 +216,56 @@ class ExpressionVisitor: # pre_expression_name - def _pre_assignement_operation(self, expression): + def _pre_assignement_operation(self, expression: AssignmentOperation) -> None: pass - def _pre_binary_operation(self, expression): + def _pre_binary_operation(self, expression: BinaryOperation) -> None: pass - def _pre_call_expression(self, expression): + def _pre_call_expression(self, expression: CallExpression) -> None: pass - def _pre_conditional_expression(self, expression): + def _pre_conditional_expression(self, expression: ConditionalExpression) -> None: pass - def _pre_elementary_type_name_expression(self, expression): + def _pre_elementary_type_name_expression( + self, expression: ElementaryTypeNameExpression + ) -> None: pass - def _pre_identifier(self, expression): + def _pre_identifier(self, expression: Identifier) -> None: pass - def _pre_index_access(self, expression): + def _pre_index_access(self, expression: IndexAccess) -> None: pass - def _pre_literal(self, expression): + def _pre_literal(self, expression: Literal) -> None: pass - def _pre_member_access(self, expression): + def _pre_member_access(self, expression: MemberAccess) -> None: pass - def _pre_new_array(self, expression): + def _pre_new_array(self, expression: NewArray) -> None: pass - def _pre_new_contract(self, expression): + def _pre_new_contract(self, expression: NewContract) -> None: pass def _pre_new_elementary_type(self, expression): pass - def _pre_tuple_expression(self, expression): + def _pre_tuple_expression(self, expression: TupleExpression) -> None: pass - def _pre_type_conversion(self, expression): + def _pre_type_conversion(self, expression: TypeConversion) -> None: pass - def _pre_unary_operation(self, expression): + def _pre_unary_operation(self, expression: UnaryOperation) -> None: pass # post visit - def _post_visit(self, expression): # pylint: disable=too-many-branches + def _post_visit(self, expression) -> None: # pylint: disable=too-many-branches if isinstance(expression, AssignmentOperation): self._post_assignement_operation(expression) @@ -314,47 +319,49 @@ class ExpressionVisitor: # post_expression_name - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: pass - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: BinaryOperation) -> None: pass - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: CallExpression) -> None: pass def _post_conditional_expression(self, expression): pass - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression( + self, expression: ElementaryTypeNameExpression + ) -> None: pass - def _post_identifier(self, expression): + def _post_identifier(self, expression: Identifier) -> None: pass - def _post_index_access(self, expression): + def _post_index_access(self, expression: IndexAccess) -> None: pass - def _post_literal(self, expression): + def _post_literal(self, expression: Literal) -> None: pass - def _post_member_access(self, expression): + def _post_member_access(self, expression: MemberAccess) -> None: pass - def _post_new_array(self, expression): + def _post_new_array(self, expression: NewArray) -> None: pass - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: NewContract) -> None: pass def _post_new_elementary_type(self, expression): pass - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: TupleExpression) -> None: pass - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: TypeConversion) -> None: pass - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: UnaryOperation) -> None: pass diff --git a/slither/visitors/expression/find_calls.py b/slither/visitors/expression/find_calls.py index 9b9141c76..6653a9759 100644 --- a/slither/visitors/expression/find_calls.py +++ b/slither/visitors/expression/find_calls.py @@ -1,19 +1,33 @@ -from typing import List +from typing import Any, Union, List from slither.core.expressions.expression import Expression from slither.visitors.expression.expression import ExpressionVisitor +from slither.core.expressions.assignment_operation import AssignmentOperation +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.conditional_expression import ConditionalExpression +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.literal import Literal +from slither.core.expressions.member_access import MemberAccess +from slither.core.expressions.new_array import NewArray +from slither.core.expressions.new_contract import NewContract +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.type_conversion import TypeConversion +from slither.core.expressions.unary_operation import UnaryOperation key = "FindCall" -def get(expression): +def get(expression: Expression) -> List[Union[Any, CallExpression]]: val = expression.context[key] # we delete the item to reduce memory use del expression.context[key] return val -def set_val(expression, val): +def set_val(expression: Expression, val: List[Union[Any, CallExpression]]) -> None: expression.context[key] = val @@ -23,19 +37,19 @@ class FindCalls(ExpressionVisitor): self._result = list(set(get(self.expression))) return self._result - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right set_val(expression, val) - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: BinaryOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right set_val(expression, val) - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression.called) args = [get(a) for a in expression.arguments if a] args = [item for sublist in args for item in sublist] @@ -43,54 +57,56 @@ class FindCalls(ExpressionVisitor): val += [expression] set_val(expression, val) - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: if_expr = get(expression.if_expression) else_expr = get(expression.else_expression) then_expr = get(expression.then_expression) val = if_expr + else_expr + then_expr set_val(expression, val) - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression( + self, expression: ElementaryTypeNameExpression + ) -> None: set_val(expression, []) # save only identifier expression - def _post_identifier(self, expression): + def _post_identifier(self, expression: Identifier) -> None: set_val(expression, []) - def _post_index_access(self, expression): + def _post_index_access(self, expression: IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right set_val(expression, val) - def _post_literal(self, expression): + def _post_literal(self, expression: Literal) -> None: set_val(expression, []) - def _post_member_access(self, expression): + def _post_member_access(self, expression: MemberAccess) -> None: expr = get(expression.expression) val = expr set_val(expression, val) - def _post_new_array(self, expression): + def _post_new_array(self, expression: NewArray) -> None: set_val(expression, []) - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) def _post_new_elementary_type(self, expression): set_val(expression, []) - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: TupleExpression) -> None: expressions = [get(e) for e in expression.expressions if e] val = [item for sublist in expressions for item in sublist] set_val(expression, val) - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: TypeConversion) -> None: expr = get(expression.expression) val = expr set_val(expression, val) - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: UnaryOperation) -> None: expr = get(expression.expression) val = expr set_val(expression, val) diff --git a/slither/visitors/expression/has_conditional.py b/slither/visitors/expression/has_conditional.py index 906f522ff..b866a696b 100644 --- a/slither/visitors/expression/has_conditional.py +++ b/slither/visitors/expression/has_conditional.py @@ -1,12 +1,13 @@ from slither.visitors.expression.expression import ExpressionVisitor +from slither.core.expressions.conditional_expression import ConditionalExpression class HasConditional(ExpressionVisitor): - def result(self): + def result(self) -> bool: # == True, to convert None to false return self._result is True - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: # if self._result is True: # raise('Slither does not support nested ternary operator') self._result = True diff --git a/slither/visitors/expression/read_var.py b/slither/visitors/expression/read_var.py index 8fe063a2f..e8f5aae67 100644 --- a/slither/visitors/expression/read_var.py +++ b/slither/visitors/expression/read_var.py @@ -1,38 +1,58 @@ +from typing import Any, List, Union + from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.assignment_operation import AssignmentOperationType +from slither.core.expressions.assignment_operation import ( + AssignmentOperation, + AssignmentOperationType, +) from slither.core.variables.variable import Variable from slither.core.declarations.solidity_variables import SolidityVariable +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.conditional_expression import ConditionalExpression +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression +from slither.core.expressions.expression import Expression +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.literal import Literal +from slither.core.expressions.member_access import MemberAccess +from slither.core.expressions.new_array import NewArray +from slither.core.expressions.new_contract import NewContract +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.type_conversion import TypeConversion +from slither.core.expressions.unary_operation import UnaryOperation + key = "ReadVar" -def get(expression): +def get(expression: Expression) -> List[Union[Identifier, IndexAccess, Any]]: val = expression.context[key] # we delete the item to reduce memory use del expression.context[key] return val -def set_val(expression, val): +def set_val(expression: Expression, val: List[Union[Identifier, IndexAccess, Any]]) -> None: expression.context[key] = val class ReadVar(ExpressionVisitor): - def result(self): + def result(self) -> List[Union[Identifier, IndexAccess, Any]]: if self._result is None: self._result = list(set(get(self.expression))) return self._result # overide assignement # dont explore if its direct assignement (we explore if its +=, -=, ...) - def _visit_assignement_operation(self, expression): + def _visit_assignement_operation(self, expression: AssignmentOperation) -> None: if expression.type != AssignmentOperationType.ASSIGN: self._visit_expression(expression.expression_left) self._visit_expression(expression.expression_right) - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: if expression.type != AssignmentOperationType.ASSIGN: left = get(expression.expression_left) else: @@ -41,31 +61,33 @@ class ReadVar(ExpressionVisitor): val = left + right set_val(expression, val) - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: BinaryOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right set_val(expression, val) - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression.called) args = [get(a) for a in expression.arguments if a] args = [item for sublist in args for item in sublist] val = called + args set_val(expression, val) - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: if_expr = get(expression.if_expression) else_expr = get(expression.else_expression) then_expr = get(expression.then_expression) val = if_expr + else_expr + then_expr set_val(expression, val) - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression( + self, expression: ElementaryTypeNameExpression + ) -> None: set_val(expression, []) # save only identifier expression - def _post_identifier(self, expression): + def _post_identifier(self, expression: Identifier) -> None: if isinstance(expression.value, Variable): set_val(expression, [expression]) elif isinstance(expression.value, SolidityVariable): @@ -73,40 +95,40 @@ class ReadVar(ExpressionVisitor): else: set_val(expression, []) - def _post_index_access(self, expression): + def _post_index_access(self, expression: IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right + [expression] set_val(expression, val) - def _post_literal(self, expression): + def _post_literal(self, expression: Literal) -> None: set_val(expression, []) - def _post_member_access(self, expression): + def _post_member_access(self, expression: MemberAccess) -> None: expr = get(expression.expression) val = expr set_val(expression, val) - def _post_new_array(self, expression): + def _post_new_array(self, expression: NewArray) -> None: set_val(expression, []) - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) def _post_new_elementary_type(self, expression): set_val(expression, []) - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: TupleExpression) -> None: expressions = [get(e) for e in expression.expressions if e] val = [item for sublist in expressions for item in sublist] set_val(expression, val) - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: TypeConversion) -> None: expr = get(expression.expression) val = expr set_val(expression, val) - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: UnaryOperation) -> None: expr = get(expression.expression) val = expr set_val(expression, val) diff --git a/slither/visitors/expression/write_var.py b/slither/visitors/expression/write_var.py index 0509a2eb7..97d3858e7 100644 --- a/slither/visitors/expression/write_var.py +++ b/slither/visitors/expression/write_var.py @@ -1,26 +1,43 @@ +from typing import Any, List from slither.visitors.expression.expression import ExpressionVisitor +from slither.core.expressions.assignment_operation import AssignmentOperation +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.conditional_expression import ConditionalExpression +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression +from slither.core.expressions.expression import Expression +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.literal import Literal +from slither.core.expressions.member_access import MemberAccess +from slither.core.expressions.new_array import NewArray +from slither.core.expressions.new_contract import NewContract +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.type_conversion import TypeConversion +from slither.core.expressions.unary_operation import UnaryOperation + key = "WriteVar" -def get(expression): +def get(expression: Expression) -> List[Any]: val = expression.context[key] # we delete the item to reduce memory use del expression.context[key] return val -def set_val(expression, val): +def set_val(expression: Expression, val: List[Any]) -> None: expression.context[key] = val class WriteVar(ExpressionVisitor): - def result(self): + def result(self) -> List[Any]: if self._result is None: self._result = list(set(get(self.expression))) return self._result - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: BinaryOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right @@ -28,7 +45,7 @@ class WriteVar(ExpressionVisitor): val += [expression] set_val(expression, val) - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression.called) args = [get(a) for a in expression.arguments if a] args = [item for sublist in args for item in sublist] @@ -37,7 +54,7 @@ class WriteVar(ExpressionVisitor): val += [expression] set_val(expression, val) - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: if_expr = get(expression.if_expression) else_expr = get(expression.else_expression) then_expr = get(expression.then_expression) @@ -46,7 +63,7 @@ class WriteVar(ExpressionVisitor): val += [expression] set_val(expression, val) - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right @@ -54,11 +71,13 @@ class WriteVar(ExpressionVisitor): val += [expression] set_val(expression, val) - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression( + self, expression: ElementaryTypeNameExpression + ) -> None: set_val(expression, []) # save only identifier expression - def _post_identifier(self, expression): + def _post_identifier(self, expression: Identifier) -> None: if expression.is_lvalue: set_val(expression, [expression]) else: @@ -69,7 +88,7 @@ class WriteVar(ExpressionVisitor): # else: # set_val(expression, []) - def _post_index_access(self, expression): + def _post_index_access(self, expression: IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right @@ -87,10 +106,10 @@ class WriteVar(ExpressionVisitor): # n = n.expression set_val(expression, val) - def _post_literal(self, expression): + def _post_literal(self, expression: Literal) -> None: set_val(expression, []) - def _post_member_access(self, expression): + def _post_member_access(self, expression: MemberAccess) -> None: expr = get(expression.expression) val = expr if expression.is_lvalue: @@ -98,30 +117,30 @@ class WriteVar(ExpressionVisitor): val += [expression.expression] set_val(expression, val) - def _post_new_array(self, expression): + def _post_new_array(self, expression: NewArray) -> None: set_val(expression, []) - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) def _post_new_elementary_type(self, expression): set_val(expression, []) - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: TupleExpression) -> None: expressions = [get(e) for e in expression.expressions if e] val = [item for sublist in expressions for item in sublist] if expression.is_lvalue: val += [expression] set_val(expression, val) - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: TypeConversion) -> None: expr = get(expression.expression) val = expr if expression.is_lvalue: val += [expression] set_val(expression, val) - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: UnaryOperation) -> None: expr = get(expression.expression) val = expr if expression.is_lvalue: diff --git a/slither/visitors/slithir/expression_to_slithir.py b/slither/visitors/slithir/expression_to_slithir.py index 6b7b4c264..c150ee20b 100644 --- a/slither/visitors/slithir/expression_to_slithir.py +++ b/slither/visitors/slithir/expression_to_slithir.py @@ -1,6 +1,5 @@ import logging - -from typing import List +from typing import Union, List, TYPE_CHECKING from slither.core.declarations import ( Function, @@ -11,6 +10,7 @@ from slither.core.declarations import ( ) from slither.core.declarations.enum import Enum from slither.core.expressions import ( + AssignmentOperation, AssignmentOperationType, UnaryOperationType, BinaryOperationType, @@ -19,9 +19,19 @@ from slither.core.expressions import ( Identifier, MemberAccess, ) +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.expression import Expression +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.literal import Literal +from slither.core.expressions.new_array import NewArray +from slither.core.expressions.new_contract import NewContract +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.unary_operation import UnaryOperation from slither.core.solidity_types import ArrayType, ElementaryType, TypeAlias from slither.core.solidity_types.type import Type +from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple +from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable from slither.slithir.exceptions import SlithIRError from slither.slithir.operations import ( @@ -53,12 +63,15 @@ from slither.slithir.variables import ( ) from slither.visitors.expression.expression import ExpressionVisitor +if TYPE_CHECKING: + from slither.core.cfg.node import Node + logger = logging.getLogger("VISTIOR:ExpressionToSlithIR") key = "expressionToSlithIR" -def get(expression): +def get(expression: Union[Expression, Operation]): val = expression.context[key] # we delete the item to reduce memory use del expression.context[key] @@ -69,7 +82,7 @@ def get_without_removing(expression): return expression.context[key] -def set_val(expression, val): +def set_val(expression: Union[Expression, Operation], val) -> None: expression.context[key] = val @@ -104,7 +117,12 @@ _signed_to_unsigned = { } -def convert_assignment(left, right, t, return_type): +def convert_assignment( + left: Union[LocalVariable, StateVariable, ReferenceVariable], + right: Union[LocalVariable, StateVariable, ReferenceVariable], + t: AssignmentOperationType, + return_type, +) -> Union[Binary, Assignment]: if t == AssignmentOperationType.ASSIGN: return Assignment(left, right, return_type) if t == AssignmentOperationType.ASSIGN_OR: @@ -132,7 +150,8 @@ def convert_assignment(left, right, t, return_type): class ExpressionToSlithIR(ExpressionVisitor): - def __init__(self, expression, node): # pylint: disable=super-init-not-called + # pylint: disable=super-init-not-called + def __init__(self, expression: Expression, node: "Node") -> None: from slither.core.cfg.node import NodeType # pylint: disable=import-outside-toplevel self._expression = expression @@ -146,10 +165,10 @@ class ExpressionToSlithIR(ExpressionVisitor): for ir in self._result: ir.set_node(node) - def result(self): + def result(self) -> List[Operation]: return self._result - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) if isinstance(left, list): # tuple expression: @@ -211,7 +230,7 @@ class ExpressionToSlithIR(ExpressionVisitor): # a = b = 1; set_val(expression, left) - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: BinaryOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = TemporaryVariable(self._node) @@ -248,9 +267,8 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, val) - def _post_call_expression( - self, expression - ): # pylint: disable=too-many-branches,too-many-statements,too-many-locals + # pylint: disable=too-many-branches,too-many-statements,too-many-locals + def _post_call_expression(self, expression: CallExpression) -> None: assert isinstance(expression, CallExpression) @@ -358,13 +376,16 @@ class ExpressionToSlithIR(ExpressionVisitor): def _post_conditional_expression(self, expression): raise Exception(f"Ternary operator are not convertible to SlithIR {expression}") - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression( + self, + expression: ElementaryTypeNameExpression, + ) -> None: set_val(expression, expression.type) - def _post_identifier(self, expression): + def _post_identifier(self, expression: Identifier) -> None: set_val(expression, expression.value) - def _post_index_access(self, expression): + def _post_index_access(self, expression: IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) # Left can be a type for abi.decode(var, uint[2]) @@ -390,11 +411,11 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(operation) set_val(expression, val) - def _post_literal(self, expression): + def _post_literal(self, expression: Literal) -> None: cst = Constant(expression.value, expression.type, expression.subdenomination) set_val(expression, cst) - def _post_member_access(self, expression): + def _post_member_access(self, expression: MemberAccess) -> None: expr = get(expression.expression) # Look for type(X).max / min @@ -479,14 +500,14 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(member) set_val(expression, val) - def _post_new_array(self, expression): + def _post_new_array(self, expression: NewArray) -> None: val = TemporaryVariable(self._node) operation = TmpNewArray(expression.depth, expression.array_type, val) operation.set_expression(expression) self._result.append(operation) set_val(expression, val) - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: NewContract) -> None: val = TemporaryVariable(self._node) operation = TmpNewContract(expression.contract_name, val) operation.set_expression(expression) @@ -508,7 +529,7 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(operation) set_val(expression, val) - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: TupleExpression) -> None: expressions = [get(e) if e else None for e in expression.expressions] if len(expressions) == 1: val = expressions[0] @@ -516,7 +537,7 @@ class ExpressionToSlithIR(ExpressionVisitor): val = expressions set_val(expression, val) - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: TypeConversion) -> None: expr = get(expression.expression) val = TemporaryVariable(self._node) operation = TypeConversion(val, expr, expression.type) @@ -525,9 +546,8 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(operation) set_val(expression, val) - def _post_unary_operation( - self, expression - ): # pylint: disable=too-many-branches,too-many-statements + # pylint: disable=too-many-statements + def _post_unary_operation(self, expression: UnaryOperation) -> None: value = get(expression.expression) if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]: lvalue = TemporaryVariable(self._node)