diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index eec599bc4..f2a431571 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -42,6 +42,7 @@ from slither.all_exceptions import SlitherException from slither.core.declarations import Contract, Function from slither.core.expressions.expression import Expression +import slither.slithir.operations.operation if TYPE_CHECKING: from slither.slithir.variables.variable import SlithIRVariable @@ -104,7 +105,7 @@ class NodeType(Enum): OTHER_ENTRYPOINT = 0x60 # @staticmethod - def __str__(self): + def __str__(self) -> str: if self == NodeType.ENTRYPOINT: return "ENTRY_POINT" if self == NodeType.EXPRESSION: @@ -158,7 +159,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met node_id: int, scope: Union["Scope", "Function"], file_scope: "FileScope", - ): + ) -> None: super().__init__() self._node_type = node_type @@ -513,11 +514,11 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met """ return self._expression - def add_expression(self, expression: Expression, bypass_verif_empty: bool = False): + def add_expression(self, expression: Expression, bypass_verif_empty: bool = False) -> None: assert self._expression is None or bypass_verif_empty self._expression = expression - def add_variable_declaration(self, var: LocalVariable): + def add_variable_declaration(self, var: LocalVariable) -> None: assert self._variable_declaration is None self._variable_declaration = var if var.expression: @@ -550,7 +551,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met for c in self.internal_calls ) - def contains_if(self, include_loop=True) -> bool: + def contains_if(self, include_loop: bool=True) -> bool: """ Check if the node is a IF node Returns: @@ -560,7 +561,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self.type in [NodeType.IF, NodeType.IFLOOP] return self.type == NodeType.IF - def is_conditional(self, include_loop=True) -> bool: + def is_conditional(self, include_loop: bool=True) -> bool: """ Check if the node is a conditional node A conditional node is either a IF or a require/assert or a RETURN bool @@ -589,7 +590,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met def inline_asm(self) -> Optional[Union[str, Dict]]: return self._asm_source_code - def add_inline_asm(self, asm: Union[str, Dict]): + def add_inline_asm(self, asm: Union[str, Dict]) -> None: self._asm_source_code = asm # endregion @@ -599,7 +600,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met ################################################################################### ################################################################################### - def add_father(self, father: "Node"): + def add_father(self, father: "Node") -> None: """Add a father node Args: @@ -624,7 +625,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met """ return list(self._fathers) - def remove_father(self, father: "Node"): + def remove_father(self, father: "Node") -> None: """Remove the father node. Do nothing if the node is not a father Args: @@ -632,7 +633,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met """ self._fathers = [x for x in self._fathers if x.node_id != father.node_id] - def remove_son(self, son: "Node"): + def remove_son(self, son: "Node") -> None: """Remove the son node. Do nothing if the node is not a son Args: @@ -640,7 +641,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met """ self._sons = [x for x in self._sons if x.node_id != son.node_id] - def add_son(self, son: "Node"): + def add_son(self, son: "Node") -> None: """Add a son node Args: @@ -648,7 +649,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met """ self._sons.append(son) - def set_sons(self, sons: List["Node"]): + def set_sons(self, sons: List["Node"]) -> None: """Set the son nodes Args: @@ -706,14 +707,14 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met def irs_ssa(self, irs): self._irs_ssa = irs - def add_ssa_ir(self, ir: Operation): + def add_ssa_ir(self, ir: Operation) -> None: """ Use to place phi operation """ ir.set_node(self) self._irs_ssa.append(ir) - def slithir_generation(self): + def slithir_generation(self) -> None: if self.expression: expression = self.expression self._irs = convert_expression(expression, self) @@ -730,11 +731,11 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._all_slithir_operations @staticmethod - def _is_non_slithir_var(var: Variable): + def _is_non_slithir_var(var: Variable) -> bool: return not isinstance(var, (Constant, ReferenceVariable, TemporaryVariable, TupleVariable)) @staticmethod - def _is_valid_slithir_var(var: Variable): + def _is_valid_slithir_var(var: Variable) -> bool: return isinstance(var, (ReferenceVariable, TemporaryVariable, TupleVariable)) # endregion @@ -785,7 +786,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._dominance_frontier = doms @property - def dominator_successors(self): + def dominator_successors(self) -> Set["Node"]: return self._dom_successors @property @@ -827,14 +828,14 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met # def phi_origin_member_variables(self) -> Dict[str, Tuple[MemberVariable, Set["Node"]]]: # return self._phi_origins_member_variables - def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node"): + def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node") -> None: if variable.name not in self._phi_origins_local_variables: self._phi_origins_local_variables[variable.name] = (variable, set()) (v, nodes) = self._phi_origins_local_variables[variable.name] assert v == variable nodes.add(node) - def add_phi_origin_state_variable(self, variable: StateVariable, node: "Node"): + def add_phi_origin_state_variable(self, variable: StateVariable, node: "Node") -> None: if variable.canonical_name not in self._phi_origins_state_variables: self._phi_origins_state_variables[variable.canonical_name] = ( variable, @@ -858,7 +859,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met ################################################################################### ################################################################################### - def _find_read_write_call(self): # pylint: disable=too-many-statements + def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements for ir in self.irs: @@ -934,7 +935,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._low_level_calls = list(set(self._low_level_calls)) @staticmethod - def _convert_ssa(v: Variable): + def _convert_ssa(v: Variable) -> Optional[Union[StateVariable, LocalVariable]]: if isinstance(v, StateIRVariable): contract = v.contract non_ssa_var = contract.get_state_variable_from_name(v.name) @@ -944,7 +945,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met non_ssa_var = function.get_local_variable_from_name(v.name) return non_ssa_var - def update_read_write_using_ssa(self): + def update_read_write_using_ssa(self) -> None: if not self.expression: return for ir in self.irs_ssa: @@ -1026,12 +1027,12 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met ################################################################################### -def link_nodes(node1: Node, node2: Node): +def link_nodes(node1: Node, node2: Node) -> None: node1.add_son(node2) node2.add_father(node1) -def insert_node(origin: Node, node_inserted: Node): +def insert_node(origin: Node, node_inserted: Node) -> None: sons = origin.sons link_nodes(origin, node_inserted) for son in sons: diff --git a/slither/core/cfg/scope.py b/slither/core/cfg/scope.py index 06e68c3e9..d3ac4e836 100644 --- a/slither/core/cfg/scope.py +++ b/slither/core/cfg/scope.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: # pylint: disable=too-few-public-methods class Scope: - def __init__(self, is_checked: bool, is_yul: bool, scope: Union["Scope", "Function"]): + def __init__(self, is_checked: bool, is_yul: bool, scope: Union["Scope", "Function"]) -> None: self.nodes: List["Node"] = [] self.is_checked = is_checked self.is_yul = is_yul diff --git a/slither/core/children/child_contract.py b/slither/core/children/child_contract.py index 285623b0e..86f9dea53 100644 --- a/slither/core/children/child_contract.py +++ b/slither/core/children/child_contract.py @@ -7,11 +7,11 @@ if TYPE_CHECKING: class ChildContract(SourceMapping): - def __init__(self): + def __init__(self) -> None: super().__init__() self._contract = None - def set_contract(self, contract: "Contract"): + def set_contract(self, contract: "Contract") -> None: self._contract = contract @property diff --git a/slither/core/children/child_event.py b/slither/core/children/child_event.py index 6df697747..df91596e3 100644 --- a/slither/core/children/child_event.py +++ b/slither/core/children/child_event.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: class ChildEvent: - def __init__(self): + def __init__(self) -> None: super().__init__() self._event = None diff --git a/slither/core/children/child_expression.py b/slither/core/children/child_expression.py index bc88ea7a4..102e267b3 100644 --- a/slither/core/children/child_expression.py +++ b/slither/core/children/child_expression.py @@ -5,11 +5,11 @@ if TYPE_CHECKING: class ChildExpression: - def __init__(self): + def __init__(self) -> None: super().__init__() self._expression = None - def set_expression(self, expression: "Expression"): + def set_expression(self, expression: "Expression") -> None: self._expression = expression @property diff --git a/slither/core/children/child_inheritance.py b/slither/core/children/child_inheritance.py index 5efe53412..30b32f6c1 100644 --- a/slither/core/children/child_inheritance.py +++ b/slither/core/children/child_inheritance.py @@ -5,11 +5,11 @@ if TYPE_CHECKING: class ChildInheritance: - def __init__(self): + def __init__(self) -> None: super().__init__() self._contract_declarer = None - def set_contract_declarer(self, contract: "Contract"): + def set_contract_declarer(self, contract: "Contract") -> None: self._contract_declarer = contract @property diff --git a/slither/core/children/child_node.py b/slither/core/children/child_node.py index c1fffd49a..8e6e1f0b5 100644 --- a/slither/core/children/child_node.py +++ b/slither/core/children/child_node.py @@ -7,11 +7,11 @@ if TYPE_CHECKING: class ChildNode: - def __init__(self): + def __init__(self) -> None: super().__init__() self._node = None - def set_node(self, node: "Node"): + def set_node(self, node: "Node") -> None: self._node = node @property diff --git a/slither/core/children/child_structure.py b/slither/core/children/child_structure.py index 0f4c7db82..abcb041c2 100644 --- a/slither/core/children/child_structure.py +++ b/slither/core/children/child_structure.py @@ -5,11 +5,11 @@ if TYPE_CHECKING: class ChildStructure: - def __init__(self): + def __init__(self) -> None: super().__init__() self._structure = None - def set_structure(self, structure: "Structure"): + def set_structure(self, structure: "Structure") -> None: self._structure = structure @property diff --git a/slither/core/compilation_unit.py b/slither/core/compilation_unit.py index 2144d4c81..71c7f367f 100644 --- a/slither/core/compilation_unit.py +++ b/slither/core/compilation_unit.py @@ -24,6 +24,12 @@ from slither.core.variables.state_variable import StateVariable from slither.core.variables.top_level_variable import TopLevelVariable from slither.slithir.operations import InternalCall from slither.slithir.variables import Constant +import crytic_compile.compilation_unit +import slither.core.declarations.contract +import slither.core.declarations.function +import slither.core.declarations.import_directive +import slither.core.declarations.modifier +import slither.core.declarations.pragma_directive if TYPE_CHECKING: from slither.core.slither_core import SlitherCore @@ -31,7 +37,7 @@ if TYPE_CHECKING: # pylint: disable=too-many-instance-attributes,too-many-public-methods class SlitherCompilationUnit(Context): - def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit): + def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit) -> None: super().__init__() self._core = core @@ -150,21 +156,21 @@ class SlitherCompilationUnit(Context): def functions(self) -> List[Function]: return list(self._all_functions) - def add_function(self, func: Function): + def add_function(self, func: Function) -> None: self._all_functions.add(func) @property def modifiers(self) -> List[Modifier]: return list(self._all_modifiers) - def add_modifier(self, modif: Modifier): + def add_modifier(self, modif: Modifier) -> None: self._all_modifiers.add(modif) @property def functions_and_modifiers(self) -> List[Function]: return self.functions + self.modifiers - def propagate_function_calls(self): + def propagate_function_calls(self) -> None: for f in self.functions_and_modifiers: for node in f.nodes: for ir in node.irs_ssa: @@ -256,7 +262,7 @@ class SlitherCompilationUnit(Context): ################################################################################### ################################################################################### - def compute_storage_layout(self): + def compute_storage_layout(self) -> None: for contract in self.contracts_derived: self._storage_layouts[contract.name] = {} diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index d1feebb05..5c958e7fd 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -55,7 +55,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods Contract class """ - def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"): + def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None: super().__init__() self._name: Optional[str] = None @@ -366,7 +366,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ return list(self._variables_ordered) - def add_variables_ordered(self, new_vars: List["StateVariable"]): + def add_variables_ordered(self, new_vars: List["StateVariable"]) -> None: self._variables_ordered += new_vars @property @@ -534,7 +534,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods def add_function(self, func: "FunctionContract"): self._functions[func.canonical_name] = func - def set_functions(self, functions: Dict[str, "FunctionContract"]): + def set_functions(self, functions: Dict[str, "FunctionContract"]) -> None: """ Set the functions @@ -578,7 +578,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods def available_modifiers_as_dict(self) -> Dict[str, "Modifier"]: return {m.full_name: m for m in self._modifiers.values() if not m.is_shadowed} - def set_modifiers(self, modifiers: Dict[str, "Modifier"]): + def set_modifiers(self, modifiers: Dict[str, "Modifier"]) -> None: """ Set the modifiers @@ -688,7 +688,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods inheritance: List["Contract"], immediate_inheritance: List["Contract"], called_base_constructor_contracts: List["Contract"], - ): + ) -> None: self._inheritance = inheritance self._immediate_inheritance = immediate_inheritance self._explicit_base_constructor_calls = called_base_constructor_contracts @@ -1216,7 +1216,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### ################################################################################### - def update_read_write_using_ssa(self): + def update_read_write_using_ssa(self) -> None: for function in self.functions + self.modifiers: function.update_read_write_using_ssa() @@ -1311,7 +1311,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods def is_incorrectly_constructed(self, incorrect: bool): self._is_incorrectly_parsed = incorrect - def add_constructor_variables(self): + def add_constructor_variables(self) -> None: from slither.core.declarations.function_contract import FunctionContract if self.state_variables: @@ -1380,7 +1380,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods def _create_node( self, func: Function, counter: int, variable: "Variable", scope: Union[Scope, Function] - ): + ) -> "Node": from slither.core.cfg.node import Node, NodeType from slither.core.expressions import ( AssignmentOperationType, @@ -1412,7 +1412,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### ################################################################################### - def convert_expression_to_slithir_ssa(self): + def convert_expression_to_slithir_ssa(self) -> None: """ Assume generate_slithir_and_analyze was called on all functions @@ -1437,7 +1437,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods for func in self.functions + self.modifiers: func.generate_slithir_ssa(all_ssa_state_variables_instances) - def fix_phi(self): + def fix_phi(self) -> None: last_state_variables_instances = {} initial_state_variables_instances = {} for v in self._initial_state_variables: @@ -1459,7 +1459,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### ################################################################################### - def __eq__(self, other): + def __eq__(self, other: SourceMapping): if isinstance(other, str): return other == self.name return NotImplemented @@ -1469,10 +1469,10 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return other != self.name return NotImplemented - def __str__(self): + def __str__(self) -> str: return self.name - def __hash__(self): + def __hash__(self) -> int: return self._id # endregion diff --git a/slither/core/declarations/custom_error.py b/slither/core/declarations/custom_error.py index a1a689fcc..5e851c8da 100644 --- a/slither/core/declarations/custom_error.py +++ b/slither/core/declarations/custom_error.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: class CustomError(SourceMapping): - def __init__(self, compilation_unit: "SlitherCompilationUnit"): + def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None: super().__init__() self._name: str = "" self._parameters: List[LocalVariable] = [] @@ -30,7 +30,7 @@ class CustomError(SourceMapping): def parameters(self) -> List[LocalVariable]: return self._parameters - def add_parameters(self, p: "LocalVariable"): + def add_parameters(self, p: "LocalVariable") -> None: self._parameters.append(p) @property @@ -42,7 +42,7 @@ class CustomError(SourceMapping): ################################################################################### @staticmethod - def _convert_type_for_solidity_signature(t: Optional[Union[Type, List[Type]]]): + def _convert_type_for_solidity_signature(t: Optional[Union[Type, List[Type]]]) -> str: # pylint: disable=import-outside-toplevel from slither.core.declarations import Contract @@ -92,5 +92,5 @@ class CustomError(SourceMapping): ################################################################################### ################################################################################### - def __str__(self): + def __str__(self) -> str: return "revert " + self.solidity_signature diff --git a/slither/core/declarations/enum.py b/slither/core/declarations/enum.py index c53c1c38d..66a02fd11 100644 --- a/slither/core/declarations/enum.py +++ b/slither/core/declarations/enum.py @@ -4,7 +4,7 @@ from slither.core.source_mapping.source_mapping import SourceMapping class Enum(SourceMapping): - def __init__(self, name: str, canonical_name: str, values: List[str]): + def __init__(self, name: str, canonical_name: str, values: List[str]) -> None: super().__init__() self._name = name self._canonical_name = canonical_name @@ -33,5 +33,5 @@ class Enum(SourceMapping): def max(self) -> int: return self._max - def __str__(self): + def __str__(self) -> str: return self.name diff --git a/slither/core/declarations/enum_top_level.py b/slither/core/declarations/enum_top_level.py index 2546b376b..2d3fab52f 100644 --- a/slither/core/declarations/enum_top_level.py +++ b/slither/core/declarations/enum_top_level.py @@ -8,6 +8,6 @@ if TYPE_CHECKING: class EnumTopLevel(Enum, TopLevel): - def __init__(self, name: str, canonical_name: str, values: List[str], scope: "FileScope"): + def __init__(self, name: str, canonical_name: str, values: List[str], scope: "FileScope") -> None: super().__init__(name, canonical_name, values) self.file_scope: "FileScope" = scope diff --git a/slither/core/declarations/event.py b/slither/core/declarations/event.py index 9e445ee74..e166008de 100644 --- a/slither/core/declarations/event.py +++ b/slither/core/declarations/event.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: class Event(ChildContract, SourceMapping): - def __init__(self): + def __init__(self) -> None: super().__init__() self._name = None self._elems: List[EventVariable] = [] diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 2fdea7210..9bd2181b8 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -6,7 +6,7 @@ from abc import abstractmethod, ABCMeta from collections import namedtuple from enum import Enum from itertools import groupby -from typing import Dict, TYPE_CHECKING, List, Optional, Set, Union, Callable, Tuple +from typing import Any, Dict, TYPE_CHECKING, List, Optional, Set, Union, Callable, Tuple from slither.core.cfg.scope import Scope from slither.core.declarations.solidity_variables import ( @@ -27,6 +27,7 @@ from slither.core.variables.state_variable import StateVariable from slither.utils.type import convert_type_for_solidity_signature_to_string from slither.utils.utils import unroll + # pylint: disable=import-outside-toplevel,too-many-instance-attributes,too-many-statements,too-many-lines if TYPE_CHECKING: @@ -45,6 +46,7 @@ if TYPE_CHECKING: from slither.slithir.operations import Operation from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.scope.scope import FileScope + from slither.slithir.variables.state_variable import StateIRVariable LOGGER = logging.getLogger("Function") ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"]) @@ -56,7 +58,7 @@ class ModifierStatements: modifier: Union["Contract", "Function"], entry_point: "Node", nodes: List["Node"], - ): + ) -> None: self._modifier = modifier self._entry_point = entry_point self._nodes = nodes @@ -116,7 +118,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu Function class """ - def __init__(self, compilation_unit: "SlitherCompilationUnit"): + def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None: super().__init__() self._internal_scope: List[str] = [] self._name: Optional[str] = None @@ -370,7 +372,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ################################################################################### ################################################################################### - def set_function_type(self, t: FunctionType): + def set_function_type(self, t: FunctionType) -> None: assert isinstance(t, FunctionType) self._function_type = t @@ -455,7 +457,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu def visibility(self, v: str): self._visibility = v - def set_visibility(self, v: str): + def set_visibility(self, v: str) -> None: self._visibility = v @property @@ -554,7 +556,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu def entry_point(self, node: "Node"): self._entry_point = node - def add_node(self, node: "Node"): + def add_node(self, node: "Node") -> None: if not self._entry_point: self._entry_point = node self._nodes.append(node) @@ -598,7 +600,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu """ return list(self._parameters) - def add_parameters(self, p: "LocalVariable"): + def add_parameters(self, p: "LocalVariable") -> None: self._parameters.append(p) @property @@ -608,7 +610,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu """ return list(self._parameters_ssa) - def add_parameter_ssa(self, var: "LocalIRVariable"): + def add_parameter_ssa(self, var: "LocalIRVariable") -> None: self._parameters_ssa.append(var) def parameters_src(self) -> SourceMapping: @@ -651,7 +653,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu """ return list(self._returns) - def add_return(self, r: "LocalVariable"): + def add_return(self, r: "LocalVariable") -> None: self._returns.append(r) @property @@ -661,7 +663,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu """ return list(self._returns_ssa) - def add_return_ssa(self, var: "LocalIRVariable"): + def add_return_ssa(self, var: "LocalIRVariable") -> None: self._returns_ssa.append(var) # endregion @@ -680,7 +682,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu """ return [c.modifier for c in self._modifiers] - def add_modifier(self, modif: "ModifierStatements"): + def add_modifier(self, modif: "ModifierStatements") -> None: self._modifiers.append(modif) @property @@ -714,7 +716,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu # This is a list of contracts internally, so we convert it to a list of constructor functions. return list(self._explicit_base_constructor_calls) - def add_explicit_base_constructor_calls_statements(self, modif: ModifierStatements): + def add_explicit_base_constructor_calls_statements(self, modif: ModifierStatements) -> None: self._explicit_base_constructor_calls.append(modif) # endregion @@ -1057,7 +1059,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu self._all_reachable_from_functions = functions return self._all_reachable_from_functions - def add_reachable_from_node(self, n: "Node", ir: "Operation"): + def add_reachable_from_node(self, n: "Node", ir: "Operation") -> None: self._reachable_from_nodes.add(ReacheableNode(n, ir)) self._reachable_from_functions.add(n.function) @@ -1068,7 +1070,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ################################################################################### ################################################################################### - def _explore_functions(self, f_new_values: Callable[["Function"], List]): + def _explore_functions(self, f_new_values: Callable[["Function"], List]) -> List[Any]: values = f_new_values(self) explored = [self] to_explore = [ @@ -1218,11 +1220,11 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu func: "Function", f: Callable[["Node"], List[SolidityVariable]], include_loop: bool, - ): + ) -> List[Any]: ret = [f(n) for n in func.nodes if n.is_conditional(include_loop)] return [item for sublist in ret for item in sublist] - def all_conditional_solidity_variables_read(self, include_loop=True) -> List[SolidityVariable]: + def all_conditional_solidity_variables_read(self, include_loop: bool=True) -> List[SolidityVariable]: """ Return the Soldiity variables directly used in a condtion @@ -1258,7 +1260,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu return [var for var in ret if isinstance(var, SolidityVariable)] @staticmethod - def _explore_func_nodes(func: "Function", f: Callable[["Node"], List[SolidityVariable]]): + def _explore_func_nodes(func: "Function", f: Callable[["Node"], List[SolidityVariable]]) -> List[Union[Any, SolidityVariableComposed]]: ret = [f(n) for n in func.nodes] return [item for sublist in ret for item in sublist] @@ -1367,7 +1369,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu with open(filename, "w", encoding="utf8") as f: f.write(content) - def slithir_cfg_to_dot_str(self, skip_expressions=False) -> str: + def slithir_cfg_to_dot_str(self, skip_expressions: bool=False) -> str: """ Export the CFG to a DOT format. The nodes includes the Solidity expressions and the IRs :return: the DOT content @@ -1512,7 +1514,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ################################################################################### ################################################################################### - def _analyze_read_write(self): + def _analyze_read_write(self) -> None: """Compute variables read/written/...""" write_var = [x.variables_written_as_expression for x in self.nodes] write_var = [x for x in write_var if x] @@ -1570,7 +1572,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu slithir_variables = [x for x in slithir_variables if x] self._slithir_variables = [item for sublist in slithir_variables for item in sublist] - def _analyze_calls(self): + def _analyze_calls(self) -> None: calls = [x.calls_as_expression for x in self.nodes] calls = [x for x in calls if x] calls = [item for sublist in calls for item in sublist] @@ -1702,7 +1704,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu return self._get_last_ssa_variable_instances(target_state=False, target_local=True) @staticmethod - def _unchange_phi(ir: "Operation"): + def _unchange_phi(ir: "Operation") -> bool: from slither.slithir.operations import Phi, PhiCallback if not isinstance(ir, (Phi, PhiCallback)) or len(ir.rvalues) > 1: @@ -1711,7 +1713,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu return True return ir.rvalues[0] == ir.lvalue - def fix_phi(self, last_state_variables_instances, initial_state_variables_instances): + def fix_phi(self, last_state_variables_instances: Dict[str, Union[List[Any], List[Union[Any, "StateIRVariable"]], List[ "StateIRVariable"]]], initial_state_variables_instances: Dict[str, "StateIRVariable"]) -> None: from slither.slithir.operations import InternalCall, PhiCallback from slither.slithir.variables import Constant, StateIRVariable @@ -1745,7 +1747,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu node.irs_ssa = [ir for ir in node.irs_ssa if not self._unchange_phi(ir)] - def generate_slithir_and_analyze(self): + def generate_slithir_and_analyze(self) -> None: for node in self.nodes: node.slithir_generation() @@ -1756,7 +1758,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu def generate_slithir_ssa(self, all_ssa_state_variables_instances): pass - def update_read_write_using_ssa(self): + def update_read_write_using_ssa(self) -> None: for node in self.nodes: node.update_read_write_using_ssa() self._analyze_read_write() @@ -1767,7 +1769,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ################################################################################### ################################################################################### - def __str__(self): + def __str__(self) -> str: return self.name # endregion diff --git a/slither/core/declarations/function_contract.py b/slither/core/declarations/function_contract.py index 2235bd227..9a607e291 100644 --- a/slither/core/declarations/function_contract.py +++ b/slither/core/declarations/function_contract.py @@ -1,17 +1,19 @@ """ Function module """ -from typing import TYPE_CHECKING, List, Tuple +from typing import Dict, TYPE_CHECKING, List, Tuple from slither.core.children.child_contract import ChildContract from slither.core.children.child_inheritance import ChildInheritance from slither.core.declarations import Function + # pylint: disable=import-outside-toplevel,too-many-instance-attributes,too-many-statements,too-many-lines if TYPE_CHECKING: from slither.core.declarations import Contract from slither.core.scope.scope import FileScope + from slither.slithir.variables.state_variable import StateIRVariable class FunctionContract(Function, ChildContract, ChildInheritance): @@ -96,7 +98,7 @@ class FunctionContract(Function, ChildContract, ChildInheritance): ################################################################################### ################################################################################### - def generate_slithir_ssa(self, all_ssa_state_variables_instances): + def generate_slithir_ssa(self, all_ssa_state_variables_instances: Dict[str, "StateIRVariable"]) -> None: from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa from slither.core.dominators.utils import ( compute_dominance_frontier, diff --git a/slither/core/declarations/function_top_level.py b/slither/core/declarations/function_top_level.py index d71033069..93d7a983d 100644 --- a/slither/core/declarations/function_top_level.py +++ b/slither/core/declarations/function_top_level.py @@ -1,7 +1,7 @@ """ Function module """ -from typing import List, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Tuple, TYPE_CHECKING from slither.core.declarations import Function from slither.core.declarations.top_level import TopLevel @@ -12,7 +12,7 @@ if TYPE_CHECKING: class FunctionTopLevel(Function, TopLevel): - def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"): + def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None: super().__init__(compilation_unit) self._scope: "FileScope" = scope @@ -78,7 +78,7 @@ class FunctionTopLevel(Function, TopLevel): ################################################################################### ################################################################################### - def generate_slithir_ssa(self, all_ssa_state_variables_instances): + def generate_slithir_ssa(self, all_ssa_state_variables_instances: Dict[Any, Any]) -> None: # pylint: disable=import-outside-toplevel from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa from slither.core.dominators.utils import ( diff --git a/slither/core/declarations/import_directive.py b/slither/core/declarations/import_directive.py index 745f8007f..75c0406fe 100644 --- a/slither/core/declarations/import_directive.py +++ b/slither/core/declarations/import_directive.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: class Import(SourceMapping): - def __init__(self, filename: Path, scope: "FileScope"): + def __init__(self, filename: Path, scope: "FileScope") -> None: super().__init__() self._filename: Path = filename self._alias: Optional[str] = None diff --git a/slither/core/declarations/pragma_directive.py b/slither/core/declarations/pragma_directive.py index 602dab6b2..f6ecef490 100644 --- a/slither/core/declarations/pragma_directive.py +++ b/slither/core/declarations/pragma_directive.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: class Pragma(SourceMapping): - def __init__(self, directive: List[str], scope: "FileScope"): + def __init__(self, directive: List[str], scope: "FileScope") -> None: super().__init__() self._directive = directive self.scope: "FileScope" = scope diff --git a/slither/core/declarations/solidity_import_placeholder.py b/slither/core/declarations/solidity_import_placeholder.py index 070d3fff3..d5ee1dbae 100644 --- a/slither/core/declarations/solidity_import_placeholder.py +++ b/slither/core/declarations/solidity_import_placeholder.py @@ -4,6 +4,11 @@ Special variable to model import with renaming from slither.core.declarations import Import from slither.core.solidity_types import ElementaryType from slither.core.variables.variable import Variable +import slither.core.declarations.import_directive +import slither.core.solidity_types.elementary_type +from slither.core.declarations.contract import Contract +from slither.core.declarations.solidity_variables import SolidityVariable +from typing import Union class SolidityImportPlaceHolder(Variable): @@ -13,7 +18,7 @@ class SolidityImportPlaceHolder(Variable): In the long term we should remove this and better integrate import aliases """ - def __init__(self, import_directive: Import): + def __init__(self, import_directive: Import) -> None: super().__init__() assert import_directive.alias is not None self._import_directive = import_directive @@ -27,7 +32,7 @@ class SolidityImportPlaceHolder(Variable): def type(self) -> ElementaryType: return ElementaryType("string") - def __eq__(self, other): + def __eq__(self, other: Union[Contract, SolidityVariable]) -> bool: return ( self.__class__ == other.__class__ and self._import_directive.filename == self._import_directive.filename diff --git a/slither/core/declarations/solidity_variables.py b/slither/core/declarations/solidity_variables.py index 3a5db010c..9569cde93 100644 --- a/slither/core/declarations/solidity_variables.py +++ b/slither/core/declarations/solidity_variables.py @@ -1,13 +1,11 @@ # https://solidity.readthedocs.io/en/v0.4.24/units-and-global-variables.html -from typing import List, Dict, Union, TYPE_CHECKING +from typing import List, Dict, Union, Any from slither.core.declarations.custom_error import CustomError from slither.core.solidity_types import ElementaryType, TypeInformation from slither.core.source_mapping.source_mapping import SourceMapping from slither.exceptions import SlitherException -if TYPE_CHECKING: - pass SOLIDITY_VARIABLES = { "now": "uint256", @@ -98,13 +96,13 @@ def solidity_function_signature(name): class SolidityVariable(SourceMapping): - def __init__(self, name: str): + def __init__(self, name: str) -> None: super().__init__() self._check_name(name) self._name = name # dev function, will be removed once the code is stable - def _check_name(self, name: str): # pylint: disable=no-self-use + def _check_name(self, name: str) -> None: # pylint: disable=no-self-use assert name in SOLIDITY_VARIABLES or name.endswith(("_slot", "_offset")) @property @@ -124,18 +122,18 @@ class SolidityVariable(SourceMapping): def type(self) -> ElementaryType: return ElementaryType(SOLIDITY_VARIABLES[self.name]) - def __str__(self): + def __str__(self) -> str: return self._name - def __eq__(self, other): + def __eq__(self, other: SourceMapping) -> bool: return self.__class__ == other.__class__ and self.name == other.name - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) class SolidityVariableComposed(SolidityVariable): - def _check_name(self, name: str): + def _check_name(self, name: str) -> None: assert name in SOLIDITY_VARIABLES_COMPOSED @property @@ -146,13 +144,13 @@ class SolidityVariableComposed(SolidityVariable): def type(self) -> ElementaryType: return ElementaryType(SOLIDITY_VARIABLES_COMPOSED[self.name]) - def __str__(self): + def __str__(self) -> str: return self._name - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self.__class__ == other.__class__ and self.name == other.name - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) @@ -162,7 +160,7 @@ class SolidityFunction(SourceMapping): # https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information # As a result, we set return_type during the Ir conversion - def __init__(self, name: str): + def __init__(self, name: str) -> None: super().__init__() assert name in SOLIDITY_FUNCTIONS self._name = name @@ -187,28 +185,28 @@ class SolidityFunction(SourceMapping): def return_type(self, r: List[Union[TypeInformation, ElementaryType]]): self._return_type = r - def __str__(self): + def __str__(self) -> str: return self._name - def __eq__(self, other): + def __eq__(self, other: "SolidityFunction") -> bool: return self.__class__ == other.__class__ and self.name == other.name - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) class SolidityCustomRevert(SolidityFunction): - def __init__(self, custom_error: CustomError): # pylint: disable=super-init-not-called + def __init__(self, custom_error: CustomError) -> None: # pylint: disable=super-init-not-called self._name = "revert " + custom_error.solidity_signature self._custom_error = custom_error self._return_type: List[Union[TypeInformation, ElementaryType]] = [] - def __eq__(self, other): + def __eq__(self, other: Union["SolidityCustomRevert", SolidityFunction]) -> bool: return ( self.__class__ == other.__class__ and self.name == other.name and self._custom_error == other._custom_error ) - def __hash__(self): + def __hash__(self) -> int: return hash(hash(self.name) + hash(self._custom_error)) diff --git a/slither/core/declarations/structure_top_level.py b/slither/core/declarations/structure_top_level.py index f06cd2318..f4e2b8a9c 100644 --- a/slither/core/declarations/structure_top_level.py +++ b/slither/core/declarations/structure_top_level.py @@ -9,6 +9,6 @@ if TYPE_CHECKING: class StructureTopLevel(Structure, TopLevel): - def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"): + def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None: super().__init__(compilation_unit) self.file_scope: "FileScope" = scope diff --git a/slither/core/declarations/using_for_top_level.py b/slither/core/declarations/using_for_top_level.py index a1b43e1c1..2c8dfc9b8 100644 --- a/slither/core/declarations/using_for_top_level.py +++ b/slither/core/declarations/using_for_top_level.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: class UsingForTopLevel(TopLevel): - def __init__(self, scope: "FileScope"): + def __init__(self, scope: "FileScope") -> None: super().__init__() self._using_for: Dict[Union[str, Type], List[Type]] = {} self.file_scope: "FileScope" = scope diff --git a/slither/core/dominators/utils.py b/slither/core/dominators/utils.py index 837fe46ea..5ad480f07 100644 --- a/slither/core/dominators/utils.py +++ b/slither/core/dominators/utils.py @@ -1,4 +1,4 @@ -from typing import List, TYPE_CHECKING +from typing import Set, List, TYPE_CHECKING from slither.core.cfg.node import NodeType @@ -6,7 +6,7 @@ if TYPE_CHECKING: from slither.core.cfg.node import Node -def intersection_predecessor(node: "Node"): +def intersection_predecessor(node: "Node") -> Set[Node]: if not node.fathers: return set() ret = node.fathers[0].dominators @@ -15,7 +15,7 @@ def intersection_predecessor(node: "Node"): return ret -def _compute_dominators(nodes: List["Node"]): +def _compute_dominators(nodes: List["Node"]) -> None: changed = True while changed: @@ -28,7 +28,7 @@ def _compute_dominators(nodes: List["Node"]): changed = True -def _compute_immediate_dominators(nodes: List["Node"]): +def _compute_immediate_dominators(nodes: List["Node"]) -> None: for node in nodes: idom_candidates = set(node.dominators) idom_candidates.remove(node) @@ -58,7 +58,7 @@ def _compute_immediate_dominators(nodes: List["Node"]): idom.dominator_successors.add(node) -def compute_dominators(nodes: List["Node"]): +def compute_dominators(nodes: List["Node"]) -> None: """ Naive implementation of Cooper, Harvey, Kennedy algo See 'A Simple,Fast Dominance Algorithm' @@ -74,7 +74,7 @@ def compute_dominators(nodes: List["Node"]): _compute_immediate_dominators(nodes) -def compute_dominance_frontier(nodes: List["Node"]): +def compute_dominance_frontier(nodes: List["Node"]) -> None: """ Naive implementation of Cooper, Harvey, Kennedy algo See 'A Simple,Fast Dominance Algorithm' diff --git a/slither/core/expressions/assignment_operation.py b/slither/core/expressions/assignment_operation.py index b5fd3f4a3..22aba57fb 100644 --- a/slither/core/expressions/assignment_operation.py +++ b/slither/core/expressions/assignment_operation.py @@ -85,7 +85,7 @@ class AssignmentOperation(ExpressionTyped): right_expression: Expression, expression_type: AssignmentOperationType, expression_return_type: Optional["Type"], - ): + ) -> None: assert isinstance(left_expression, Expression) assert isinstance(right_expression, Expression) super().__init__() diff --git a/slither/core/expressions/call_expression.py b/slither/core/expressions/call_expression.py index 8ab6668fe..1dbc4074a 100644 --- a/slither/core/expressions/call_expression.py +++ b/slither/core/expressions/call_expression.py @@ -1,10 +1,10 @@ -from typing import Optional, List +from typing import Any, Optional, List from slither.core.expressions.expression import Expression class CallExpression(Expression): # pylint: disable=too-many-instance-attributes - def __init__(self, called, arguments, type_call): + def __init__(self, called: Expression, arguments: List[Any], type_call: str) -> None: assert isinstance(called, Expression) super().__init__() self._called: Expression = called @@ -53,7 +53,7 @@ class CallExpression(Expression): # pylint: disable=too-many-instance-attribute def type_call(self) -> str: return self._type_call - def __str__(self): + def __str__(self) -> str: txt = str(self._called) if self.call_gas or self.call_value: gas = f"gas: {self.call_gas}" if self.call_gas else "" diff --git a/slither/core/expressions/conditional_expression.py b/slither/core/expressions/conditional_expression.py index adcf8bb1f..6ad668160 100644 --- a/slither/core/expressions/conditional_expression.py +++ b/slither/core/expressions/conditional_expression.py @@ -1,10 +1,17 @@ -from typing import List +from typing import Union, List from .expression import Expression +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.expression import Expression +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.literal import Literal +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.type_conversion import TypeConversion +from slither.core.expressions.unary_operation import UnaryOperation class ConditionalExpression(Expression): - def __init__(self, if_expression, then_expression, else_expression): + def __init__(self, if_expression: Union[BinaryOperation, Identifier, Literal], then_expression: Union["ConditionalExpression", TypeConversion, Literal, TupleExpression, Identifier], else_expression: Union[TupleExpression, UnaryOperation, Identifier, Literal]) -> None: assert isinstance(if_expression, Expression) assert isinstance(then_expression, Expression) assert isinstance(else_expression, Expression) diff --git a/slither/core/expressions/elementary_type_name_expression.py b/slither/core/expressions/elementary_type_name_expression.py index 0a310c86a..9a93f0839 100644 --- a/slither/core/expressions/elementary_type_name_expression.py +++ b/slither/core/expressions/elementary_type_name_expression.py @@ -3,10 +3,11 @@ """ from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type +from slither.core.solidity_types.elementary_type import ElementaryType class ElementaryTypeNameExpression(Expression): - def __init__(self, t): + def __init__(self, t: ElementaryType) -> None: assert isinstance(t, Type) super().__init__() self._type = t @@ -20,5 +21,5 @@ class ElementaryTypeNameExpression(Expression): assert isinstance(new_type, Type) self._type = new_type - def __str__(self): + def __str__(self) -> str: return str(self._type) diff --git a/slither/core/expressions/index_access.py b/slither/core/expressions/index_access.py index 49b2a8dd9..f2933cb37 100644 --- a/slither/core/expressions/index_access.py +++ b/slither/core/expressions/index_access.py @@ -1,6 +1,8 @@ -from typing import List, TYPE_CHECKING +from typing import Union, List, TYPE_CHECKING from slither.core.expressions.expression_typed import ExpressionTyped +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.literal import Literal if TYPE_CHECKING: @@ -9,7 +11,7 @@ if TYPE_CHECKING: class IndexAccess(ExpressionTyped): - def __init__(self, left_expression, right_expression, index_type): + def __init__(self, left_expression: Union["IndexAccess", Identifier], right_expression: Union[Literal, Identifier], index_type: str) -> None: super().__init__() self._expressions = [left_expression, right_expression] # TODO type of undexAccess is not always a Type @@ -32,5 +34,5 @@ class IndexAccess(ExpressionTyped): def type(self) -> "Type": return self._type - def __str__(self): + def __str__(self) -> str: return str(self.expression_left) + "[" + str(self.expression_right) + "]" diff --git a/slither/core/expressions/literal.py b/slither/core/expressions/literal.py index 2eaeb715d..5dace3c41 100644 --- a/slither/core/expressions/literal.py +++ b/slither/core/expressions/literal.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: class Literal(Expression): def __init__( self, value: Union[int, str], custom_type: "Type", subdenomination: Optional[str] = None - ): + ) -> None: super().__init__() self._value = value self._type = custom_type diff --git a/slither/core/expressions/member_access.py b/slither/core/expressions/member_access.py index 73bd92641..36d6818b2 100644 --- a/slither/core/expressions/member_access.py +++ b/slither/core/expressions/member_access.py @@ -5,7 +5,7 @@ from slither.core.solidity_types.type import Type class MemberAccess(ExpressionTyped): - def __init__(self, member_name, member_type, expression): + def __init__(self, member_name: str, member_type: str, expression: Expression) -> None: # assert isinstance(member_type, Type) # TODO member_type is not always a Type assert isinstance(expression, Expression) @@ -26,5 +26,5 @@ class MemberAccess(ExpressionTyped): def type(self) -> Type: return self._type - def __str__(self): + def __str__(self) -> str: return str(self.expression) + "." + self.member_name diff --git a/slither/core/expressions/new_array.py b/slither/core/expressions/new_array.py index 59c5780d4..9bb8893ff 100644 --- a/slither/core/expressions/new_array.py +++ b/slither/core/expressions/new_array.py @@ -1,11 +1,18 @@ +from typing import Union, TYPE_CHECKING + + from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type +if TYPE_CHECKING: + from slither.core.solidity_types.elementary_type import ElementaryType + from slither.core.solidity_types.type_alias import TypeAliasTopLevel + class NewArray(Expression): # note: dont conserve the size of the array if provided - def __init__(self, depth, array_type): + def __init__(self, depth: int, array_type: Union["TypeAliasTopLevel", "ElementaryType"]) -> None: super().__init__() assert isinstance(array_type, Type) self._depth: int = depth diff --git a/slither/core/expressions/new_contract.py b/slither/core/expressions/new_contract.py index 762a3dcfe..70f930aad 100644 --- a/slither/core/expressions/new_contract.py +++ b/slither/core/expressions/new_contract.py @@ -2,7 +2,7 @@ from slither.core.expressions.expression import Expression class NewContract(Expression): - def __init__(self, contract_name): + def __init__(self, contract_name: str) -> None: super().__init__() self._contract_name: str = contract_name self._gas = None @@ -29,5 +29,5 @@ class NewContract(Expression): def call_salt(self, salt): self._salt = salt - def __str__(self): + def __str__(self) -> str: return "new " + str(self._contract_name) diff --git a/slither/core/expressions/type_conversion.py b/slither/core/expressions/type_conversion.py index 97c70f45b..e52b03d95 100644 --- a/slither/core/expressions/type_conversion.py +++ b/slither/core/expressions/type_conversion.py @@ -1,10 +1,20 @@ +from typing import Union, TYPE_CHECKING + from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type +if TYPE_CHECKING: + from slither.core.expressions.call_expression import CallExpression + from slither.core.expressions.identifier import Identifier + from slither.core.expressions.literal import Literal + from slither.core.expressions.member_access import MemberAccess + from slither.core.solidity_types.elementary_type import ElementaryType + from slither.core.solidity_types.type_alias import TypeAliasContract + from slither.core.solidity_types.user_defined_type import UserDefinedType class TypeConversion(ExpressionTyped): - def __init__(self, expression, expression_type): + def __init__(self, expression: Union["MemberAccess", "Literal", "CallExpression", "TypeConversion", "Identifier"], expression_type: Union["ElementaryType", "UserDefinedType", "TypeAliasContract"]) -> None: super().__init__() assert isinstance(expression, Expression) assert isinstance(expression_type, Type) @@ -15,5 +25,5 @@ class TypeConversion(ExpressionTyped): def expression(self) -> Expression: return self._expression - def __str__(self): + def __str__(self) -> str: return str(self.type) + "(" + str(self.expression) + ")" diff --git a/slither/core/expressions/unary_operation.py b/slither/core/expressions/unary_operation.py index 596d9dbf0..99d23d23f 100644 --- a/slither/core/expressions/unary_operation.py +++ b/slither/core/expressions/unary_operation.py @@ -4,6 +4,11 @@ from enum import Enum from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.exceptions import SlitherCoreError +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.literal import Literal +from slither.core.expressions.tuple_expression import TupleExpression +from typing import Union logger = logging.getLogger("UnaryOperation") @@ -20,7 +25,7 @@ class UnaryOperationType(Enum): MINUS_PRE = 8 # for stuff like uint(-1) @staticmethod - def get_type(operation_type, isprefix): + def get_type(operation_type: str, isprefix: bool) -> "UnaryOperationType": if isprefix: if operation_type == "!": return UnaryOperationType.BANG @@ -86,7 +91,7 @@ class UnaryOperationType(Enum): class UnaryOperation(ExpressionTyped): - def __init__(self, expression, expression_type): + def __init__(self, expression: Union[Literal, Identifier, IndexAccess, TupleExpression], expression_type: UnaryOperationType) -> None: assert isinstance(expression, Expression) super().__init__() self._expression: Expression = expression diff --git a/slither/core/scope/scope.py b/slither/core/scope/scope.py index 1eb344c2b..cafeb3585 100644 --- a/slither/core/scope/scope.py +++ b/slither/core/scope/scope.py @@ -25,7 +25,7 @@ def _dict_contain(d1: Dict, d2: Dict) -> bool: # pylint: disable=too-many-instance-attributes class FileScope: - def __init__(self, filename: Filename): + def __init__(self, filename: Filename) -> None: self.filename = filename self.accessible_scopes: List[FileScope] = [] diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index cba86e56e..ff6df04f0 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -40,7 +40,7 @@ class SlitherCore(Context): Slither static analyzer """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._filename: Optional[str] = None diff --git a/slither/core/solidity_types/array_type.py b/slither/core/solidity_types/array_type.py index 59a15dcc6..26dfa9d31 100644 --- a/slither/core/solidity_types/array_type.py +++ b/slither/core/solidity_types/array_type.py @@ -1,13 +1,20 @@ -from typing import Optional, Tuple +from typing import Union, Optional, Tuple, Any, TYPE_CHECKING -from slither.core.expressions import Literal from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type from slither.visitors.expression.constants_folding import ConstantFolding +if TYPE_CHECKING: + from slither.core.expressions.literal import Literal + from slither.core.expressions.binary_operation import BinaryOperation + from slither.core.expressions.identifier import Identifier + from slither.core.solidity_types.elementary_type import ElementaryType + from slither.core.solidity_types.function_type import FunctionType + from slither.core.solidity_types.type_alias import TypeAliasTopLevel + class ArrayType(Type): - def __init__(self, t, length): + def __init__(self, t: Union["TypeAliasTopLevel", "ArrayType", "FunctionType", "ElementaryType"], length: Optional[Union["Identifier", "Literal", "BinaryOperation"]]) -> None: assert isinstance(t, Type) if length: if isinstance(length, int): @@ -38,7 +45,7 @@ class ArrayType(Type): return self._length @property - def length_value(self) -> Optional[Literal]: + def length_value(self) -> Optional["Literal"]: return self._length_value @property @@ -56,15 +63,15 @@ class ArrayType(Type): return elem_size * int(str(self._length_value)), True return 32, True - def __str__(self): + def __str__(self) -> str: if self._length: return str(self._type) + f"[{str(self._length_value)}]" return str(self._type) + "[]" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, ArrayType): return False return self._type == other.type and self.length == other.length - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) diff --git a/slither/core/solidity_types/elementary_type.py b/slither/core/solidity_types/elementary_type.py index fc248e946..ec2b0ef04 100644 --- a/slither/core/solidity_types/elementary_type.py +++ b/slither/core/solidity_types/elementary_type.py @@ -216,13 +216,13 @@ class ElementaryType(Type): return MaxValues[self.name] raise SlitherException(f"{self.name} does not have a max value") - def __str__(self): + def __str__(self) -> str: return self._type - def __eq__(self, other): + def __eq__(self, other) -> bool: if not isinstance(other, ElementaryType): return False return self.type == other.type - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) diff --git a/slither/core/solidity_types/function_type.py b/slither/core/solidity_types/function_type.py index 3146aa0bf..4682bcf19 100644 --- a/slither/core/solidity_types/function_type.py +++ b/slither/core/solidity_types/function_type.py @@ -2,6 +2,7 @@ from typing import List, Tuple from slither.core.solidity_types.type import Type from slither.core.variables.function_type_variable import FunctionTypeVariable +from slither.core.solidity_types.elementary_type import ElementaryType class FunctionType(Type): @@ -9,7 +10,7 @@ class FunctionType(Type): self, params: List[FunctionTypeVariable], return_values: List[FunctionTypeVariable], - ): + ) -> None: assert all(isinstance(x, FunctionTypeVariable) for x in params) assert all(isinstance(x, FunctionTypeVariable) for x in return_values) super().__init__() @@ -68,7 +69,7 @@ class FunctionType(Type): return f"({params}) returns({return_values})" return f"({params})" - def __eq__(self, other): + def __eq__(self, other: ElementaryType) -> bool: if not isinstance(other, FunctionType): return False return self.params == other.params and self.return_values == other.return_values diff --git a/slither/core/solidity_types/mapping_type.py b/slither/core/solidity_types/mapping_type.py index fea5bdb7b..92b27fd04 100644 --- a/slither/core/solidity_types/mapping_type.py +++ b/slither/core/solidity_types/mapping_type.py @@ -1,10 +1,13 @@ -from typing import Tuple +from typing import Union, Tuple, TYPE_CHECKING from slither.core.solidity_types.type import Type +if TYPE_CHECKING: + from slither.core.solidity_types.elementary_type import ElementaryType + from slither.core.solidity_types.type_alias import TypeAliasTopLevel class MappingType(Type): - def __init__(self, type_from, type_to): + def __init__(self, type_from: "ElementaryType", type_to: Union["MappingType", "TypeAliasTopLevel", "ElementaryType"]) -> None: assert isinstance(type_from, Type) assert isinstance(type_to, Type) super().__init__() @@ -27,7 +30,7 @@ class MappingType(Type): def is_dynamic(self) -> bool: return True - def __str__(self): + def __str__(self) -> str: return f"mapping({str(self._from)} => {str(self._to)})" def __eq__(self, other): @@ -35,5 +38,5 @@ class MappingType(Type): return False return self.type_from == other.type_from and self.type_to == other.type_to - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) diff --git a/slither/core/solidity_types/type_alias.py b/slither/core/solidity_types/type_alias.py index 128d12597..b202d3608 100644 --- a/slither/core/solidity_types/type_alias.py +++ b/slither/core/solidity_types/type_alias.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: class TypeAlias(Type): - def __init__(self, underlying_type: Type, name: str): + def __init__(self, underlying_type: Type, name: str) -> None: super().__init__() self.name = name self.underlying_type = underlying_type @@ -31,7 +31,7 @@ class TypeAlias(Type): def storage_size(self) -> Tuple[int, bool]: return self.underlying_type.storage_size - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) @property @@ -40,18 +40,18 @@ class TypeAlias(Type): class TypeAliasTopLevel(TypeAlias, TopLevel): - def __init__(self, underlying_type: Type, name: str, scope: "FileScope"): + def __init__(self, underlying_type: Type, name: str, scope: "FileScope") -> None: super().__init__(underlying_type, name) self.file_scope: "FileScope" = scope - def __str__(self): + def __str__(self) -> str: return self.name class TypeAliasContract(TypeAlias, ChildContract): - def __init__(self, underlying_type: Type, name: str, contract: "Contract"): + def __init__(self, underlying_type: Type, name: str, contract: "Contract") -> None: super().__init__(underlying_type, name) self._contract: "Contract" = contract - def __str__(self): + def __str__(self) -> str: return self.contract.name + "." + self.name diff --git a/slither/core/solidity_types/type_information.py b/slither/core/solidity_types/type_information.py index 0477bb7e6..dd0aeec43 100644 --- a/slither/core/solidity_types/type_information.py +++ b/slither/core/solidity_types/type_information.py @@ -1,16 +1,17 @@ -from typing import TYPE_CHECKING, Tuple +from typing import Union, TYPE_CHECKING, Tuple from slither.core.solidity_types import ElementaryType from slither.core.solidity_types.type import Type if TYPE_CHECKING: from slither.core.declarations.contract import Contract + from slither.core.declarations.enum_top_level import EnumTopLevel # Use to model the Type(X) function, which returns an undefined type # https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information class TypeInformation(Type): - def __init__(self, c): + def __init__(self, c: Union[ElementaryType, "Contract", "EnumTopLevel"]) -> None: # pylint: disable=import-outside-toplevel from slither.core.declarations.contract import Contract from slither.core.declarations.enum import Enum diff --git a/slither/core/solidity_types/user_defined_type.py b/slither/core/solidity_types/user_defined_type.py index 38300cdd9..618d63480 100644 --- a/slither/core/solidity_types/user_defined_type.py +++ b/slither/core/solidity_types/user_defined_type.py @@ -1,4 +1,4 @@ -from typing import Union, TYPE_CHECKING, Tuple +from typing import Union, TYPE_CHECKING, Tuple, Any import math from slither.core.solidity_types.type import Type @@ -8,10 +8,12 @@ if TYPE_CHECKING: from slither.core.declarations.structure import Structure from slither.core.declarations.enum import Enum from slither.core.declarations.contract import Contract + from slither.core.declarations import EnumContract + from slither.core.declarations.structure_top_level import StructureTopLevel # pylint: disable=import-outside-toplevel class UserDefinedType(Type): - def __init__(self, t): + def __init__(self, t: Union["EnumContract", "StructureTopLevel", "Contract", "StructureContract"]) -> None: from slither.core.declarations.structure import Structure from slither.core.declarations.enum import Enum from slither.core.declarations.contract import Contract @@ -62,7 +64,7 @@ class UserDefinedType(Type): to_log = f"{self} does not have storage size" raise SlitherException(to_log) - def __str__(self): + def __str__(self) -> str: from slither.core.declarations.structure_contract import StructureContract from slither.core.declarations.enum_contract import EnumContract @@ -71,10 +73,10 @@ class UserDefinedType(Type): return str(type_used.contract) + "." + str(type_used.name) return str(type_used.name) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, UserDefinedType): return False return self.type == other.type - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) diff --git a/slither/core/variables/event_variable.py b/slither/core/variables/event_variable.py index ca2f40570..f3ad60d0b 100644 --- a/slither/core/variables/event_variable.py +++ b/slither/core/variables/event_variable.py @@ -3,7 +3,7 @@ from slither.core.children.child_event import ChildEvent class EventVariable(ChildEvent, Variable): - def __init__(self): + def __init__(self) -> None: super().__init__() self._indexed = False diff --git a/slither/core/variables/state_variable.py b/slither/core/variables/state_variable.py index c9a90f36b..47b7682a4 100644 --- a/slither/core/variables/state_variable.py +++ b/slither/core/variables/state_variable.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: class StateVariable(ChildContract, Variable): - def __init__(self): + def __init__(self) -> None: super().__init__() self._node_initialization: Optional["Node"] = None diff --git a/slither/core/variables/top_level_variable.py b/slither/core/variables/top_level_variable.py index 6d821092e..e6447c1ef 100644 --- a/slither/core/variables/top_level_variable.py +++ b/slither/core/variables/top_level_variable.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: class TopLevelVariable(TopLevel, Variable): - def __init__(self, scope: "FileScope"): + def __init__(self, scope: "FileScope") -> None: super().__init__() self._node_initialization: Optional["Node"] = None self.file_scope = scope diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index 5fda02e93..8607a8921 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: # pylint: disable=too-many-instance-attributes class Variable(SourceMapping): - def __init__(self): + def __init__(self) -> None: super().__init__() self._name: Optional[str] = None self._initial_expression: Optional["Expression"] = None diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index 95d113a84..4a2f1f28f 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -31,6 +31,9 @@ from slither.slithir.operations import ( from slither.slithir.operations.binary import Binary from slither.slithir.variables import Constant from slither.visitors.expression.constants_folding import ConstantFolding +import slither.core.declarations.function +import slither.slithir.operations.operation +from slither.utils.output import Output def _get_name(f: Union[Function, Variable]) -> str: @@ -168,7 +171,7 @@ def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-n all_cst_used: List[ConstantValue], all_cst_used_in_binary: Dict[str, List[ConstantValue]], context_explored: Set[Node], -): +) -> None: for ir in irs: if isinstance(ir, Binary): for r in ir.read: @@ -364,7 +367,7 @@ class Echidna(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#echidna" - def output(self, filename): # pylint: disable=too-many-locals + def output(self, filename: str) -> Output: # pylint: disable=too-many-locals """ Output the inheritance relation diff --git a/slither/slither.py b/slither/slither.py index 93cf16394..14e67a8ce 100644 --- a/slither/slither.py +++ b/slither/slither.py @@ -12,6 +12,7 @@ from slither.exceptions import SlitherError from slither.printers.abstract_printer import AbstractPrinter from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc from slither.utils.output import Output +import crytic_compile.crytic_compile logger = logging.getLogger("Slither") logging.basicConfig() @@ -49,7 +50,7 @@ def _update_file_scopes(candidates: ValuesView[FileScope]): class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes - def __init__(self, target: Union[str, CryticCompile], **kwargs): + def __init__(self, target: Union[str, CryticCompile], **kwargs) -> None: """ Args: target (str | CryticCompile) diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index 89f85499c..17dfe7675 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import List, TYPE_CHECKING, Union, Optional +from typing import Any, List, TYPE_CHECKING, Union, Optional # pylint: disable= too-many-lines,import-outside-toplevel,too-many-branches,too-many-statements,too-many-nested-blocks from slither.core.declarations import ( @@ -34,7 +34,7 @@ from slither.core.solidity_types.elementary_type import ( MaxValues, ) from slither.core.solidity_types.type import Type -from slither.core.solidity_types.type_alias import TypeAlias +from slither.core.solidity_types.type_alias import TypeAliasTopLevel, TypeAlias from slither.core.variables.function_type_variable import FunctionTypeVariable from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable @@ -83,6 +83,28 @@ from slither.slithir.variables import TupleVariable from slither.utils.function import get_function_id from slither.utils.type import export_nested_types_from_variable from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR +import slither.core.declarations.contract +import slither.core.declarations.function +import slither.core.solidity_types.elementary_type +import slither.core.solidity_types.function_type +import slither.core.solidity_types.user_defined_type +import slither.slithir.operations.assignment +import slither.slithir.operations.binary +import slither.slithir.operations.call +import slither.slithir.operations.high_level_call +import slither.slithir.operations.index +import slither.slithir.operations.init_array +import slither.slithir.operations.internal_call +import slither.slithir.operations.length +import slither.slithir.operations.library_call +import slither.slithir.operations.low_level_call +import slither.slithir.operations.member +import slither.slithir.operations.operation +import slither.slithir.operations.send +import slither.slithir.operations.solidity_call +import slither.slithir.operations.transfer +import slither.slithir.variables.temporary +from slither.core.expressions.expression import Expression if TYPE_CHECKING: from slither.core.cfg.node import Node @@ -91,7 +113,7 @@ if TYPE_CHECKING: logger = logging.getLogger("ConvertToIR") -def convert_expression(expression, node): +def convert_expression(expression: Expression, node: "Node") -> List[Any]: # handle standlone expression # such as return true; from slither.core.cfg.node import NodeType @@ -143,7 +165,7 @@ def convert_expression(expression, node): ################################################################################### -def is_value(ins): +def is_value(ins: slither.slithir.operations.operation.Operation) -> bool: if isinstance(ins, TmpCall): if isinstance(ins.ori, Member): if ins.ori.variable_right == "value": @@ -151,7 +173,7 @@ def is_value(ins): return False -def is_gas(ins): +def is_gas(ins: slither.slithir.operations.operation.Operation) -> bool: if isinstance(ins, TmpCall): if isinstance(ins.ori, Member): if ins.ori.variable_right == "gas": @@ -159,7 +181,7 @@ def is_gas(ins): return False -def _fits_under_integer(val: int, can_be_int: bool, can_be_uint) -> List[str]: +def _fits_under_integer(val: int, can_be_int: bool, can_be_uint: bool) -> List[str]: """ Return the list of uint/int that can contain val @@ -271,7 +293,7 @@ def _find_function_from_parameter( return None -def is_temporary(ins): +def is_temporary(ins: slither.slithir.operations.operation.Operation) -> bool: return isinstance( ins, (Argument, TmpNewElementaryType, TmpNewContract, TmpNewArray, TmpNewStructure), @@ -300,7 +322,7 @@ def _make_function_type(func: Function) -> FunctionType: ################################################################################### -def integrate_value_gas(result): +def integrate_value_gas(result: List[Any]) -> List[Any]: """ Integrate value and gas temporary arguments to call instruction """ @@ -504,7 +526,7 @@ def _convert_type_contract(ir: Member) -> Assignment: raise SlithIRError(f"type({contract.name}).{ir.variable_right} is unknown") -def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals +def propagate_types(ir: slither.slithir.operations.operation.Operation, node: "Node"): # pylint: disable=too-many-locals # propagate the type node_function = node.function using_for = ( @@ -813,7 +835,7 @@ def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals return None -def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: disable=too-many-locals +def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]) -> slither.slithir.operations.call.Call: # pylint: disable=too-many-locals assert isinstance(ins, TmpCall) if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType): # If the call is made to a variable member, where the member is this @@ -1114,7 +1136,7 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: dis ################################################################################### -def can_be_low_level(ir): +def can_be_low_level(ir: slither.slithir.operations.high_level_call.HighLevelCall) -> bool: return ir.function_name in [ "transfer", "send", @@ -1125,7 +1147,7 @@ def can_be_low_level(ir): ] -def convert_to_low_level(ir): +def convert_to_low_level(ir: slither.slithir.operations.high_level_call.HighLevelCall) -> Union[slither.slithir.operations.send.Send, slither.slithir.operations.low_level_call.LowLevelCall, slither.slithir.operations.transfer.Transfer]: """ Convert to a transfer/send/or low level call The funciton assume to receive a correct IR @@ -1165,7 +1187,7 @@ def convert_to_low_level(ir): raise SlithIRError(f"Incorrect conversion to low level {ir}") -def can_be_solidity_func(ir) -> bool: +def can_be_solidity_func(ir: slither.slithir.operations.high_level_call.HighLevelCall) -> bool: if not isinstance(ir, HighLevelCall): return False return ir.destination.name == "abi" and ir.function_name in [ @@ -1178,7 +1200,7 @@ def can_be_solidity_func(ir) -> bool: ] -def convert_to_solidity_func(ir): +def convert_to_solidity_func(ir: slither.slithir.operations.high_level_call.HighLevelCall) -> slither.slithir.operations.solidity_call.SolidityCall: """ Must be called after can_be_solidity_func :param ir: @@ -1214,7 +1236,7 @@ def convert_to_solidity_func(ir): return new_ir -def convert_to_push_expand_arr(ir, node, ret): +def convert_to_push_expand_arr(ir: slither.slithir.operations.high_level_call.HighLevelCall, node: "Node", ret: List[Any]) -> slither.slithir.variables.temporary.TemporaryVariable: arr = ir.destination length = ReferenceVariable(node) @@ -1249,7 +1271,7 @@ def convert_to_push_expand_arr(ir, node, ret): return length_val -def convert_to_push_set_val(ir, node, length_val, ret): +def convert_to_push_set_val(ir: slither.slithir.operations.high_level_call.HighLevelCall, node: "Node", length_val: slither.slithir.variables.temporary.TemporaryVariable, ret: List[Union[slither.slithir.operations.length.Length, slither.slithir.operations.assignment.Assignment, slither.slithir.operations.binary.Binary]]) -> None: arr = ir.destination new_type = ir.destination.type.type @@ -1284,7 +1306,7 @@ def convert_to_push_set_val(ir, node, length_val, ret): ret.append(ir_assign_value) -def convert_to_push(ir, node): +def convert_to_push(ir: slither.slithir.operations.high_level_call.HighLevelCall, node: "Node") -> List[Union[slither.slithir.operations.length.Length, slither.slithir.operations.assignment.Assignment, slither.slithir.operations.binary.Binary, slither.slithir.operations.index.Index, slither.slithir.operations.init_array.InitArray]]: """ Convert a call to a series of operations to push a new value onto the array @@ -1358,7 +1380,7 @@ def convert_to_pop(ir, node): return ret -def look_for_library_or_top_level(contract, ir, using_for, t): +def look_for_library_or_top_level(contract: slither.core.declarations.contract.Contract, ir: slither.slithir.operations.high_level_call.HighLevelCall, using_for, t: Union[slither.core.solidity_types.user_defined_type.UserDefinedType, slither.core.solidity_types.elementary_type.ElementaryType, str, TypeAliasTopLevel]) -> Optional[Union[slither.slithir.operations.library_call.LibraryCall, slither.slithir.operations.internal_call.InternalCall]]: for destination in using_for[t]: if isinstance(destination, FunctionTopLevel) and destination.name == ir.function_name: arguments = [ir.destination] + ir.arguments @@ -1403,7 +1425,7 @@ def look_for_library_or_top_level(contract, ir, using_for, t): return None -def convert_to_library_or_top_level(ir, node, using_for): +def convert_to_library_or_top_level(ir: slither.slithir.operations.high_level_call.HighLevelCall, node: "Node", using_for) -> Optional[Union[slither.slithir.operations.library_call.LibraryCall, slither.slithir.operations.internal_call.InternalCall]]: # We use contract_declarer, because Solidity resolve the library # before resolving the inheritance. # Though we could use .contract as libraries cannot be shadowed @@ -1422,7 +1444,7 @@ def convert_to_library_or_top_level(ir, node, using_for): return None -def get_type(t): +def get_type(t: Union[slither.core.solidity_types.user_defined_type.UserDefinedType, slither.core.solidity_types.elementary_type.ElementaryType]) -> str: """ Convert a type to a str If the instance is a Contract, return 'address' instead @@ -1441,7 +1463,7 @@ def _can_be_implicitly_converted(source: str, target: str) -> bool: return source == target -def convert_type_library_call(ir: HighLevelCall, lib_contract: Contract): +def convert_type_library_call(ir: HighLevelCall, lib_contract: Contract) -> Optional[slither.slithir.operations.library_call.LibraryCall]: func = None candidates = [ f @@ -1652,7 +1674,7 @@ def convert_type_of_high_and_internal_level_call( ################################################################################### -def find_references_origin(irs): +def find_references_origin(irs: List[Any]) -> None: """ Make lvalue of each Index, Member operation points to the left variable @@ -1689,7 +1711,7 @@ def remove_temporary(result): return result -def remove_unused(result): +def remove_unused(result: List[Any]) -> List[Any]: removed = True if not result: @@ -1736,7 +1758,7 @@ def remove_unused(result): ################################################################################### -def convert_constant_types(irs): +def convert_constant_types(irs: List[Any]) -> None: """ late conversion of uint -> type for constant (Literal) :param irs: @@ -1812,7 +1834,7 @@ def convert_constant_types(irs): ################################################################################### -def convert_delete(irs): +def convert_delete(irs: List[Any]) -> None: """ Convert the lvalue of the Delete to point to the variable removed This can only be done after find_references_origin is called @@ -1833,7 +1855,7 @@ def convert_delete(irs): ################################################################################### -def _find_source_mapping_references(irs: List[Operation]): +def _find_source_mapping_references(irs: List[Operation]) -> None: for ir in irs: if isinstance(ir, NewContract): @@ -1848,7 +1870,7 @@ def _find_source_mapping_references(irs: List[Operation]): ################################################################################### -def apply_ir_heuristics(irs: List[Operation], node: "Node"): +def apply_ir_heuristics(irs: List[Operation], node: "Node") -> List[Any]: """ Apply a set of heuristic to improve slithIR """ diff --git a/slither/slithir/operations/assignment.py b/slither/slithir/operations/assignment.py index 3d05c3040..84eda0e12 100644 --- a/slither/slithir/operations/assignment.py +++ b/slither/slithir/operations/assignment.py @@ -4,12 +4,15 @@ from slither.core.declarations.function import Function from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue from slither.slithir.variables import TupleVariable, ReferenceVariable +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.variable import Variable +from typing import List logger = logging.getLogger("AssignmentOperationIR") class Assignment(OperationWithLValue): - def __init__(self, left_variable, right_variable, variable_return_type): + def __init__(self, left_variable: Variable, right_variable: SourceMapping, variable_return_type) -> None: assert is_valid_lvalue(left_variable) assert is_valid_rvalue(right_variable) or isinstance( right_variable, (Function, TupleVariable) @@ -25,7 +28,7 @@ class Assignment(OperationWithLValue): return list(self._variables) @property - def read(self): + def read(self) -> List[SourceMapping]: return [self.rvalue] @property @@ -33,7 +36,7 @@ class Assignment(OperationWithLValue): return self._variable_return_type @property - def rvalue(self): + def rvalue(self) -> SourceMapping: return self._rvalue def __str__(self): diff --git a/slither/slithir/operations/binary.py b/slither/slithir/operations/binary.py index d8844fa32..6b679bca0 100644 --- a/slither/slithir/operations/binary.py +++ b/slither/slithir/operations/binary.py @@ -7,6 +7,9 @@ from slither.slithir.exceptions import SlithIRError from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue from slither.slithir.variables import ReferenceVariable +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.variable import Variable +from typing import List logger = logging.getLogger("BinaryOperationIR") @@ -33,7 +36,7 @@ class BinaryType(Enum): OROR = 18 # || @staticmethod - def return_bool(operation_type): + def return_bool(operation_type: "BinaryType") -> bool: return operation_type in [ BinaryType.OROR, BinaryType.ANDAND, @@ -98,7 +101,7 @@ class BinaryType(Enum): BinaryType.DIVISION, ] - def __str__(self): # pylint: disable=too-many-branches + def __str__(self) -> str: # pylint: disable=too-many-branches if self == BinaryType.POWER: return "**" if self == BinaryType.MULTIPLICATION: @@ -141,7 +144,7 @@ class BinaryType(Enum): class Binary(OperationWithLValue): - def __init__(self, result, left_variable, right_variable, operation_type: BinaryType): + def __init__(self, result: Variable, left_variable: SourceMapping, right_variable: Variable, operation_type: BinaryType) -> None: assert is_valid_rvalue(left_variable) or isinstance(left_variable, Function) assert is_valid_rvalue(right_variable) or isinstance(right_variable, Function) assert is_valid_lvalue(result) @@ -156,7 +159,7 @@ class Binary(OperationWithLValue): result.set_type(left_variable.type) @property - def read(self): + def read(self) -> List[SourceMapping]: return [self.variable_left, self.variable_right] @property @@ -164,15 +167,15 @@ class Binary(OperationWithLValue): return self._variables @property - def variable_left(self): + def variable_left(self) -> SourceMapping: return self._variables[0] @property - def variable_right(self): + def variable_right(self) -> Variable: return self._variables[1] @property - def type(self): + def type(self) -> BinaryType: return self._type @property diff --git a/slither/slithir/operations/condition.py b/slither/slithir/operations/condition.py index 5ba959a73..a5ce929a4 100644 --- a/slither/slithir/operations/condition.py +++ b/slither/slithir/operations/condition.py @@ -1,6 +1,12 @@ from slither.slithir.operations.operation import Operation from slither.slithir.utils.utils import is_valid_rvalue +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from typing import List, Union class Condition(Operation): @@ -9,13 +15,13 @@ class Condition(Operation): Only present as last operation in conditional node """ - def __init__(self, value): + def __init__(self, value: Union[LocalVariable, TemporaryVariableSSA, TemporaryVariable, Constant, LocalIRVariable]) -> None: assert is_valid_rvalue(value) super().__init__() self._value = value @property - def read(self): + def read(self) -> List[Union[LocalIRVariable, Constant, LocalVariable, TemporaryVariableSSA, TemporaryVariable]]: return [self.value] @property diff --git a/slither/slithir/operations/delete.py b/slither/slithir/operations/delete.py index 4fb05b8f5..ed5aadaa8 100644 --- a/slither/slithir/operations/delete.py +++ b/slither/slithir/operations/delete.py @@ -1,6 +1,11 @@ from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue +from slither.core.variables.state_variable import StateVariable +from slither.slithir.variables.reference import ReferenceVariable +from slither.slithir.variables.reference_ssa import ReferenceVariableSSA +from slither.slithir.variables.state_variable import StateIRVariable +from typing import List, Union class Delete(OperationWithLValue): @@ -9,18 +14,18 @@ class Delete(OperationWithLValue): of its operand """ - def __init__(self, lvalue, variable): + def __init__(self, lvalue: Union[StateIRVariable, StateVariable, ReferenceVariable], variable: Union[StateIRVariable, StateVariable, ReferenceVariable, ReferenceVariableSSA]) -> None: assert is_valid_lvalue(variable) super().__init__() self._variable = variable self._lvalue = lvalue @property - def read(self): + def read(self) -> List[Union[StateIRVariable, ReferenceVariable, ReferenceVariableSSA, StateVariable]]: return [self.variable] @property - def variable(self): + def variable(self) -> Union[StateIRVariable, StateVariable, ReferenceVariable, ReferenceVariableSSA]: return self._variable def __str__(self): diff --git a/slither/slithir/operations/event_call.py b/slither/slithir/operations/event_call.py index 6ef846d4b..efaecae88 100644 --- a/slither/slithir/operations/event_call.py +++ b/slither/slithir/operations/event_call.py @@ -1,18 +1,20 @@ from slither.slithir.operations.call import Call +from slither.slithir.variables.constant import Constant +from typing import Any, List, Union class EventCall(Call): - def __init__(self, name): + def __init__(self, name: Union[str, Constant]) -> None: super().__init__() self._name = name # todo add instance of the Event @property - def name(self): + def name(self) -> Union[str, Constant]: return self._name @property - def read(self): + def read(self) -> List[Any]: return self._unroll(self.arguments) def __str__(self): diff --git a/slither/slithir/operations/high_level_call.py b/slither/slithir/operations/high_level_call.py index ff72a0899..8184b0188 100644 --- a/slither/slithir/operations/high_level_call.py +++ b/slither/slithir/operations/high_level_call.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import List, Optional, Union from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue @@ -8,6 +8,10 @@ from slither.core.declarations.function import Function from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.variables.constant import Constant +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from slither.slithir.variables.tuple import TupleVariable class HighLevelCall(Call, OperationWithLValue): @@ -16,7 +20,7 @@ class HighLevelCall(Call, OperationWithLValue): """ # pylint: disable=too-many-arguments,too-many-instance-attributes - def __init__(self, destination, function_name, nbr_arguments, result, type_call): + def __init__(self, destination: SourceMapping, function_name: Constant, nbr_arguments: int, result: Optional[Union[TemporaryVariable, TupleVariable, TemporaryVariableSSA]], type_call: str) -> None: assert isinstance(function_name, Constant) assert is_valid_lvalue(result) or result is None self._check_destination(destination) @@ -34,7 +38,7 @@ class HighLevelCall(Call, OperationWithLValue): # Development function, to be removed once the code is stable # It is ovveride by LbraryCall - def _check_destination(self, destination): # pylint: disable=no-self-use + def _check_destination(self, destination: SourceMapping) -> None: # pylint: disable=no-self-use assert isinstance(destination, (Variable, SolidityVariable)) @property @@ -62,17 +66,17 @@ class HighLevelCall(Call, OperationWithLValue): self._call_gas = v @property - def read(self): + def read(self) -> List[SourceMapping]: all_read = [self.destination, self.call_gas, self.call_value] + self._unroll(self.arguments) # remove None return [x for x in all_read if x] + [self.destination] @property - def destination(self): + def destination(self) -> SourceMapping: return self._destination @property - def function_name(self): + def function_name(self) -> Constant: return self._function_name @property @@ -84,11 +88,11 @@ class HighLevelCall(Call, OperationWithLValue): self._function_instance = function @property - def nbr_arguments(self): + def nbr_arguments(self) -> int: return self._nbr_arguments @property - def type_call(self): + def type_call(self) -> str: return self._type_call ################################################################################### diff --git a/slither/slithir/operations/index.py b/slither/slithir/operations/index.py index 096cc7268..cef575d4a 100644 --- a/slither/slithir/operations/index.py +++ b/slither/slithir/operations/index.py @@ -2,10 +2,15 @@ from slither.core.declarations import SolidityVariableComposed from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue from slither.slithir.variables.reference import ReferenceVariable +from slither.core.solidity_types.elementary_type import ElementaryType +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.variable import Variable +from slither.slithir.variables.reference_ssa import ReferenceVariableSSA +from typing import List, Union class Index(OperationWithLValue): - def __init__(self, result, left_variable, right_variable, index_type): + def __init__(self, result: Union[ReferenceVariable, ReferenceVariableSSA], left_variable: Variable, right_variable: SourceMapping, index_type: Union[ElementaryType, str]) -> None: super().__init__() assert is_valid_lvalue(left_variable) or left_variable == SolidityVariableComposed( "msg.data" @@ -17,23 +22,23 @@ class Index(OperationWithLValue): self._lvalue = result @property - def read(self): + def read(self) -> List[SourceMapping]: return list(self.variables) @property - def variables(self): + def variables(self) -> List[SourceMapping]: return self._variables @property - def variable_left(self): + def variable_left(self) -> Variable: return self._variables[0] @property - def variable_right(self): + def variable_right(self) -> SourceMapping: return self._variables[1] @property - def index_type(self): + def index_type(self) -> Union[ElementaryType, str]: return self._type def __str__(self): diff --git a/slither/slithir/operations/init_array.py b/slither/slithir/operations/init_array.py index 23a95ebe7..c16ba4510 100644 --- a/slither/slithir/operations/init_array.py +++ b/slither/slithir/operations/init_array.py @@ -1,9 +1,13 @@ from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_rvalue +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from typing import List, Union class InitArray(OperationWithLValue): - def __init__(self, init_values, lvalue): + def __init__(self, init_values: List[Constant], lvalue: Union[TemporaryVariableSSA, TemporaryVariable]) -> None: # init_values can be an array of n dimension # reduce was removed in py3 super().__init__() @@ -24,11 +28,11 @@ class InitArray(OperationWithLValue): self._lvalue = lvalue @property - def read(self): + def read(self) -> List[Constant]: return self._unroll(self.init_values) @property - def init_values(self): + def init_values(self) -> List[Constant]: return list(self._init_values) def __str__(self): diff --git a/slither/slithir/operations/internal_call.py b/slither/slithir/operations/internal_call.py index c096bdceb..2f423ecaf 100644 --- a/slither/slithir/operations/internal_call.py +++ b/slither/slithir/operations/internal_call.py @@ -1,15 +1,20 @@ -from typing import Union, Tuple, List, Optional +from typing import Any, Union, Tuple, List, Optional from slither.core.declarations import Modifier from slither.core.declarations.function import Function from slither.core.declarations.function_contract import FunctionContract from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from slither.slithir.variables.tuple import TupleVariable +from slither.slithir.variables.tuple_ssa import TupleVariableSSA class InternalCall(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes def __init__( - self, function: Union[Function, Tuple[str, str]], nbr_arguments, result, type_call - ): + self, function: Union[Function, Tuple[str, str]], nbr_arguments: int, result: Optional[Union[TupleVariableSSA, TemporaryVariableSSA, TupleVariable, TemporaryVariable]], type_call: str + ) -> None: super().__init__() self._contract_name = "" if isinstance(function, Function): @@ -30,7 +35,7 @@ class InternalCall(Call, OperationWithLValue): # pylint: disable=too-many-insta self.function_candidates: Optional[List[Function]] = None @property - def read(self): + def read(self) -> List[Any]: return list(self._unroll(self.arguments)) @property @@ -42,19 +47,19 @@ class InternalCall(Call, OperationWithLValue): # pylint: disable=too-many-insta self._function = f @property - def function_name(self): + def function_name(self) -> Constant: return self._function_name @property - def contract_name(self): + def contract_name(self) -> str: return self._contract_name @property - def nbr_arguments(self): + def nbr_arguments(self) -> int: return self._nbr_arguments @property - def type_call(self): + def type_call(self) -> str: return self._type_call @property diff --git a/slither/slithir/operations/internal_dynamic_call.py b/slither/slithir/operations/internal_dynamic_call.py index 39a8ae6a1..5c8884ea1 100644 --- a/slither/slithir/operations/internal_dynamic_call.py +++ b/slither/slithir/operations/internal_dynamic_call.py @@ -3,12 +3,19 @@ from slither.core.variables.variable import Variable from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue +import slither.core.solidity_types.function_type +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from typing import List, Optional, Union class InternalDynamicCall( Call, OperationWithLValue ): # pylint: disable=too-many-instance-attributes - def __init__(self, lvalue, function, function_type): + def __init__(self, lvalue: Optional[Union[TemporaryVariableSSA, TemporaryVariable]], function: Union[LocalVariable, LocalIRVariable], function_type: slither.core.solidity_types.function_type.FunctionType) -> None: assert isinstance(function_type, FunctionType) assert isinstance(function, Variable) assert is_valid_lvalue(lvalue) or lvalue is None @@ -22,15 +29,15 @@ class InternalDynamicCall( self._call_gas = None @property - def read(self): + def read(self) -> List[Union[Constant, LocalIRVariable, LocalVariable]]: return self._unroll(self.arguments) + [self.function] @property - def function(self): + def function(self) -> Union[LocalVariable, LocalIRVariable]: return self._function @property - def function_type(self): + def function_type(self) -> slither.core.solidity_types.function_type.FunctionType: return self._function_type @property diff --git a/slither/slithir/operations/length.py b/slither/slithir/operations/length.py index 9ba33e655..d09eadf1f 100644 --- a/slither/slithir/operations/length.py +++ b/slither/slithir/operations/length.py @@ -1,10 +1,17 @@ from slither.core.solidity_types import ElementaryType from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +from slither.core.variables.local_variable import LocalVariable +from slither.core.variables.state_variable import StateVariable +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.reference import ReferenceVariable +from slither.slithir.variables.reference_ssa import ReferenceVariableSSA +from slither.slithir.variables.state_variable import StateIRVariable +from typing import List, Union class Length(OperationWithLValue): - def __init__(self, value, lvalue): + def __init__(self, value: Union[StateVariable, LocalIRVariable, LocalVariable, StateIRVariable], lvalue: Union[ReferenceVariable, ReferenceVariableSSA]) -> None: super().__init__() assert is_valid_rvalue(value) assert is_valid_lvalue(lvalue) @@ -13,11 +20,11 @@ class Length(OperationWithLValue): lvalue.set_type(ElementaryType("uint256")) @property - def read(self): + def read(self) -> List[Union[LocalVariable, StateVariable, LocalIRVariable, StateIRVariable]]: return [self._value] @property - def value(self): + def value(self) -> Union[StateVariable, LocalVariable]: return self._value def __str__(self): diff --git a/slither/slithir/operations/library_call.py b/slither/slithir/operations/library_call.py index 1f1f507a6..105187ad7 100644 --- a/slither/slithir/operations/library_call.py +++ b/slither/slithir/operations/library_call.py @@ -9,7 +9,7 @@ class LibraryCall(HighLevelCall): """ # Development function, to be removed once the code is stable - def _check_destination(self, destination): + def _check_destination(self, destination: Contract) -> None: assert isinstance(destination, Contract) def can_reenter(self, callstack=None): diff --git a/slither/slithir/operations/low_level_call.py b/slither/slithir/operations/low_level_call.py index 83bbbb336..d9370bf84 100644 --- a/slither/slithir/operations/low_level_call.py +++ b/slither/slithir/operations/low_level_call.py @@ -4,6 +4,13 @@ from slither.core.variables.variable import Variable from slither.core.declarations.solidity_variables import SolidityVariable from slither.slithir.variables.constant import Constant +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from slither.slithir.variables.tuple import TupleVariable +from slither.slithir.variables.tuple_ssa import TupleVariableSSA +from typing import List, Union class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes @@ -11,7 +18,7 @@ class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-insta High level message call """ - def __init__(self, destination, function_name, nbr_arguments, result, type_call): + def __init__(self, destination: Union[LocalVariable, LocalIRVariable, TemporaryVariableSSA, TemporaryVariable], function_name: Constant, nbr_arguments: int, result: Union[TupleVariable, TupleVariableSSA], type_call: str) -> None: # pylint: disable=too-many-arguments assert isinstance(destination, (Variable, SolidityVariable)) assert isinstance(function_name, Constant) @@ -51,7 +58,7 @@ class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-insta self._call_gas = v @property - def read(self): + def read(self) -> List[Union[LocalIRVariable, Constant, LocalVariable, TemporaryVariableSSA, TemporaryVariable]]: all_read = [self.destination, self.call_gas, self.call_value] + self.arguments # remove None return self._unroll([x for x in all_read if x]) @@ -73,19 +80,19 @@ class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-insta return self._call_value is not None @property - def destination(self): + def destination(self) -> Union[LocalVariable, LocalIRVariable, TemporaryVariableSSA, TemporaryVariable]: return self._destination @property - def function_name(self): + def function_name(self) -> Constant: return self._function_name @property - def nbr_arguments(self): + def nbr_arguments(self) -> int: return self._nbr_arguments @property - def type_call(self): + def type_call(self) -> str: return self._type_call def __str__(self): diff --git a/slither/slithir/operations/lvalue.py b/slither/slithir/operations/lvalue.py index 4571b9f1f..482385432 100644 --- a/slither/slithir/operations/lvalue.py +++ b/slither/slithir/operations/lvalue.py @@ -6,7 +6,7 @@ class OperationWithLValue(Operation): Operation with a lvalue """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._lvalue = None diff --git a/slither/slithir/operations/member.py b/slither/slithir/operations/member.py index f0c6ae523..799cf4317 100644 --- a/slither/slithir/operations/member.py +++ b/slither/slithir/operations/member.py @@ -7,10 +7,13 @@ from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_rvalue from slither.slithir.variables.constant import Constant from slither.slithir.variables.reference import ReferenceVariable +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.slithir.variables.reference_ssa import ReferenceVariableSSA +from typing import List, Union class Member(OperationWithLValue): - def __init__(self, variable_left, variable_right, result): + def __init__(self, variable_left: SourceMapping, variable_right: Constant, result: Union[ReferenceVariable, ReferenceVariableSSA]) -> None: # Function can happen for something like # library FunctionExtensions { # function h(function() internal _t, uint8) internal { } @@ -38,15 +41,15 @@ class Member(OperationWithLValue): self._value = None @property - def read(self): + def read(self) -> List[SourceMapping]: return [self.variable_left, self.variable_right] @property - def variable_left(self): + def variable_left(self) -> SourceMapping: return self._variable_left @property - def variable_right(self): + def variable_right(self) -> Constant: return self._variable_right @property diff --git a/slither/slithir/operations/new_array.py b/slither/slithir/operations/new_array.py index 57ee6dcf0..bbbda5bb7 100644 --- a/slither/slithir/operations/new_array.py +++ b/slither/slithir/operations/new_array.py @@ -1,10 +1,18 @@ +from typing import List, Union, TYPE_CHECKING from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.operations.call import Call from slither.core.solidity_types.type import Type +if TYPE_CHECKING: + from slither.core.solidity_types.type_alias import TypeAliasTopLevel + from slither.slithir.variables.constant import Constant + from slither.slithir.variables.temporary import TemporaryVariable + from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA + class NewArray(Call, OperationWithLValue): - def __init__(self, depth, array_type, lvalue): + def __init__(self, depth: int, array_type: "TypeAliasTopLevel", + lvalue: Union["TemporaryVariableSSA", "TemporaryVariable"]) -> None: super().__init__() assert isinstance(array_type, Type) self._depth = depth @@ -13,15 +21,15 @@ class NewArray(Call, OperationWithLValue): self._lvalue = lvalue @property - def array_type(self): + def array_type(self) -> "TypeAliasTopLevel": return self._array_type @property - def read(self): + def read(self) -> List["Constant"]: return self._unroll(self.arguments) @property - def depth(self): + def depth(self) -> int: return self._depth def __str__(self): diff --git a/slither/slithir/operations/new_contract.py b/slither/slithir/operations/new_contract.py index fb0014c76..91a48ec15 100644 --- a/slither/slithir/operations/new_contract.py +++ b/slither/slithir/operations/new_contract.py @@ -1,10 +1,14 @@ from slither.slithir.operations import Call, OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.variables.constant import Constant +from slither.core.declarations.contract import Contract +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from typing import Any, List, Union class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes - def __init__(self, contract_name, lvalue): + def __init__(self, contract_name: Constant, lvalue: Union[TemporaryVariableSSA, TemporaryVariable]) -> None: assert isinstance(contract_name, Constant) assert is_valid_lvalue(lvalue) super().__init__() @@ -40,15 +44,15 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan self._call_salt = s @property - def contract_name(self): + def contract_name(self) -> Constant: return self._contract_name @property - def read(self): + def read(self) -> List[Any]: return self._unroll(self.arguments) @property - def contract_created(self): + def contract_created(self) -> Contract: contract_name = self.contract_name contract_instance = self.node.file_scope.get_contract_from_name(contract_name) return contract_instance diff --git a/slither/slithir/operations/new_structure.py b/slither/slithir/operations/new_structure.py index 16a8af785..9287a3f7f 100644 --- a/slither/slithir/operations/new_structure.py +++ b/slither/slithir/operations/new_structure.py @@ -4,10 +4,15 @@ from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue from slither.core.declarations.structure import Structure +from slither.core.declarations.structure_contract import StructureContract +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from typing import List, Union class NewStructure(Call, OperationWithLValue): - def __init__(self, structure, lvalue): + def __init__(self, structure: StructureContract, lvalue: Union[TemporaryVariableSSA, TemporaryVariable]) -> None: super().__init__() assert isinstance(structure, Structure) assert is_valid_lvalue(lvalue) @@ -16,11 +21,11 @@ class NewStructure(Call, OperationWithLValue): self._lvalue = lvalue @property - def read(self): + def read(self) -> List[Union[TemporaryVariableSSA, TemporaryVariable, Constant]]: return self._unroll(self.arguments) @property - def structure(self): + def structure(self) -> StructureContract: return self._structure @property diff --git a/slither/slithir/operations/operation.py b/slither/slithir/operations/operation.py index fa1db89c2..b28cf5f3b 100644 --- a/slither/slithir/operations/operation.py +++ b/slither/slithir/operations/operation.py @@ -3,6 +3,7 @@ from slither.core.context.context import Context from slither.core.children.child_expression import ChildExpression from slither.core.children.child_node import ChildNode from slither.utils.utils import unroll +from typing import Any, List class AbstractOperation(abc.ABC): @@ -33,5 +34,5 @@ class Operation(Context, ChildExpression, ChildNode, AbstractOperation): # if array inside the parameters @staticmethod - def _unroll(l): + def _unroll(l: List[Any]) -> List[Any]: return unroll(l) diff --git a/slither/slithir/operations/phi.py b/slither/slithir/operations/phi.py index a4fa0217e..211a71474 100644 --- a/slither/slithir/operations/phi.py +++ b/slither/slithir/operations/phi.py @@ -1,9 +1,16 @@ +from typing import List, Set, Union, TYPE_CHECKING from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue +from slither.core.declarations.solidity_variables import SolidityVariableComposed +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.state_variable import StateIRVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +if TYPE_CHECKING: + from slither.core.cfg.node import Node class Phi(OperationWithLValue): - def __init__(self, left_variable, nodes): + def __init__(self, left_variable: Union[LocalIRVariable, StateIRVariable], nodes: Set["Node" ]) -> None: # When Phi operations are created the # correct indexes of the variables are not yet computed # We store the nodes where the variables are written @@ -17,7 +24,7 @@ class Phi(OperationWithLValue): self._nodes = nodes @property - def read(self): + def read(self) -> List[Union[SolidityVariableComposed, LocalIRVariable, TemporaryVariableSSA, StateIRVariable]]: return self.rvalues @property @@ -29,7 +36,7 @@ class Phi(OperationWithLValue): self._rvalues = vals @property - def nodes(self): + def nodes(self) -> Set["Node"]: return self._nodes def __str__(self): diff --git a/slither/slithir/operations/phi_callback.py b/slither/slithir/operations/phi_callback.py index 486015b0c..262230d61 100644 --- a/slither/slithir/operations/phi_callback.py +++ b/slither/slithir/operations/phi_callback.py @@ -1,9 +1,17 @@ +from typing import List, Set, Union, TYPE_CHECKING + from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.operations.phi import Phi +from slither.slithir.operations.high_level_call import HighLevelCall +from slither.slithir.operations.internal_call import InternalCall +from slither.slithir.variables.state_variable import StateIRVariable + +if TYPE_CHECKING: + from slither.core.cfg.node import Node class PhiCallback(Phi): - def __init__(self, left_variable, nodes, call_ir, rvalue): + def __init__(self, left_variable: StateIRVariable, nodes: Set["Node"], call_ir: Union[InternalCall, HighLevelCall], rvalue: StateIRVariable) -> None: assert is_valid_lvalue(left_variable) assert isinstance(nodes, set) super().__init__(left_variable, nodes) @@ -12,11 +20,11 @@ class PhiCallback(Phi): self._rvalue_no_callback = rvalue @property - def callee_ir(self): + def callee_ir(self) -> Union[InternalCall, HighLevelCall]: return self._call_ir @property - def read(self): + def read(self) -> List[StateIRVariable]: return self.rvalues @property diff --git a/slither/slithir/operations/return_operation.py b/slither/slithir/operations/return_operation.py index c1ccf47d1..234635c84 100644 --- a/slither/slithir/operations/return_operation.py +++ b/slither/slithir/operations/return_operation.py @@ -3,6 +3,8 @@ from slither.slithir.operations.operation import Operation from slither.slithir.variables.tuple import TupleVariable from slither.slithir.utils.utils import is_valid_rvalue +from slither.core.variables.variable import Variable +from typing import List class Return(Operation): @@ -11,7 +13,7 @@ class Return(Operation): Only present as last operation in RETURN node """ - def __init__(self, values): + def __init__(self, values) -> None: # Note: Can return None # ex: return call() # where call() dont return @@ -35,7 +37,7 @@ class Return(Operation): super().__init__() self._values = values - def _valid_value(self, value): + def _valid_value(self, value) -> bool: if isinstance(value, list): assert all(self._valid_value(v) for v in value) else: @@ -43,11 +45,11 @@ class Return(Operation): return True @property - def read(self): + def read(self) -> List[Variable]: return self._unroll(self.values) @property - def values(self): + def values(self) -> List[Variable]: return self._unroll(self._values) def __str__(self): diff --git a/slither/slithir/operations/send.py b/slither/slithir/operations/send.py index 27b0c7021..ae2a7b9af 100644 --- a/slither/slithir/operations/send.py +++ b/slither/slithir/operations/send.py @@ -3,10 +3,16 @@ from slither.core.variables.variable import Variable from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from typing import List, Union class Send(Call, OperationWithLValue): - def __init__(self, destination, value, result): + def __init__(self, destination: Union[LocalVariable, LocalIRVariable], value: Constant, result: Union[TemporaryVariable, TemporaryVariableSSA]) -> None: assert is_valid_lvalue(result) assert isinstance(destination, (Variable, SolidityVariable)) super().__init__() @@ -19,15 +25,15 @@ class Send(Call, OperationWithLValue): return True @property - def call_value(self): + def call_value(self) -> Constant: return self._call_value @property - def read(self): + def read(self) -> List[Union[Constant, LocalIRVariable, LocalVariable]]: return [self.destination, self.call_value] @property - def destination(self): + def destination(self) -> Union[LocalVariable, LocalIRVariable]: return self._destination def __str__(self): diff --git a/slither/slithir/operations/solidity_call.py b/slither/slithir/operations/solidity_call.py index 628a527fe..aab901872 100644 --- a/slither/slithir/operations/solidity_call.py +++ b/slither/slithir/operations/solidity_call.py @@ -1,10 +1,13 @@ -from slither.core.declarations.solidity_variables import SolidityFunction +from slither.core.declarations.solidity_variables import SolidityCustomRevert, SolidityFunction from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.children.child_node import ChildNode +from slither.core.solidity_types.elementary_type import ElementaryType +from typing import Any, List, Union class SolidityCall(Call, OperationWithLValue): - def __init__(self, function, nbr_arguments, result, type_call): + def __init__(self, function: Union[SolidityCustomRevert, SolidityFunction], nbr_arguments: int, result: ChildNode, type_call: Union[str, List[ElementaryType]]) -> None: assert isinstance(function, SolidityFunction) super().__init__() self._function = function @@ -13,19 +16,19 @@ class SolidityCall(Call, OperationWithLValue): self._lvalue = result @property - def read(self): + def read(self) -> List[Any]: return self._unroll(self.arguments) @property - def function(self): + def function(self) -> Union[SolidityCustomRevert, SolidityFunction]: return self._function @property - def nbr_arguments(self): + def nbr_arguments(self) -> int: return self._nbr_arguments @property - def type_call(self): + def type_call(self) -> Union[str, List[ElementaryType]]: return self._type_call def __str__(self): diff --git a/slither/slithir/operations/transfer.py b/slither/slithir/operations/transfer.py index 6cfc58f07..8f638b9f8 100644 --- a/slither/slithir/operations/transfer.py +++ b/slither/slithir/operations/transfer.py @@ -1,10 +1,14 @@ from slither.slithir.operations.call import Call from slither.core.variables.variable import Variable from slither.core.declarations.solidity_variables import SolidityVariable +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.local_variable import LocalIRVariable +from typing import List, Union class Transfer(Call): - def __init__(self, destination, value): + def __init__(self, destination: Union[LocalVariable, LocalIRVariable], value: Constant) -> None: assert isinstance(destination, (Variable, SolidityVariable)) self._destination = destination super().__init__() @@ -15,15 +19,15 @@ class Transfer(Call): return True @property - def call_value(self): + def call_value(self) -> Constant: return self._call_value @property - def read(self): + def read(self) -> List[Union[Constant, LocalIRVariable, LocalVariable]]: return [self.destination, self.call_value] @property - def destination(self): + def destination(self) -> Union[LocalVariable, LocalIRVariable]: return self._destination def __str__(self): diff --git a/slither/slithir/operations/type_conversion.py b/slither/slithir/operations/type_conversion.py index feee46a2c..629bacfd8 100644 --- a/slither/slithir/operations/type_conversion.py +++ b/slither/slithir/operations/type_conversion.py @@ -2,10 +2,18 @@ from slither.core.declarations import Contract from slither.core.solidity_types.type import Type from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +import slither.core.declarations.contract +from slither.core.solidity_types.elementary_type import ElementaryType +from slither.core.solidity_types.type_alias import TypeAliasContract, TypeAliasTopLevel +from slither.core.solidity_types.user_defined_type import UserDefinedType +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from typing import List, Union class TypeConversion(OperationWithLValue): - def __init__(self, result, variable, variable_type): + def __init__(self, result: Union[TemporaryVariableSSA, TemporaryVariable], variable: SourceMapping, variable_type: Union[TypeAliasContract, UserDefinedType, ElementaryType, TypeAliasTopLevel]) -> None: super().__init__() assert is_valid_rvalue(variable) or isinstance(variable, Contract) assert is_valid_lvalue(result) @@ -16,15 +24,15 @@ class TypeConversion(OperationWithLValue): self._lvalue = result @property - def variable(self): + def variable(self) -> SourceMapping: return self._variable @property - def type(self): + def type(self) -> Union[TypeAliasContract, TypeAliasTopLevel, slither.core.declarations.contract.Contract, UserDefinedType, ElementaryType]: return self._type @property - def read(self): + def read(self) -> List[SourceMapping]: return [self.variable] def __str__(self): diff --git a/slither/slithir/operations/unary.py b/slither/slithir/operations/unary.py index 7df55ad75..b09a64deb 100644 --- a/slither/slithir/operations/unary.py +++ b/slither/slithir/operations/unary.py @@ -4,6 +4,13 @@ from enum import Enum from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue from slither.slithir.exceptions import SlithIRError +from slither.core.expressions.unary_operation import UnaryOperationType +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.variables.constant import Constant +from slither.slithir.variables.local_variable import LocalIRVariable +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA +from typing import List, Union logger = logging.getLogger("BinaryOperationIR") @@ -31,7 +38,7 @@ class UnaryType(Enum): class Unary(OperationWithLValue): - def __init__(self, result, variable, operation_type): + def __init__(self, result: Union[TemporaryVariableSSA, TemporaryVariable], variable: Union[Constant, LocalIRVariable, LocalVariable], operation_type: UnaryOperationType) -> None: assert is_valid_rvalue(variable) assert is_valid_lvalue(result) super().__init__() @@ -40,15 +47,15 @@ class Unary(OperationWithLValue): self._lvalue = result @property - def read(self): + def read(self) -> List[Union[Constant, LocalIRVariable, LocalVariable]]: return [self._variable] @property - def rvalue(self): + def rvalue(self) -> Union[Constant, LocalVariable]: return self._variable @property - def type(self): + def type(self) -> UnaryOperationType: return self._type @property diff --git a/slither/slithir/tmp_operations/argument.py b/slither/slithir/tmp_operations/argument.py index 746bc13f2..25ea5d019 100644 --- a/slither/slithir/tmp_operations/argument.py +++ b/slither/slithir/tmp_operations/argument.py @@ -10,7 +10,7 @@ class ArgumentType(Enum): class Argument(Operation): - def __init__(self, argument): + def __init__(self, argument) -> None: super().__init__() self._argument = argument self._type = ArgumentType.CALL @@ -32,11 +32,11 @@ class Argument(Operation): def read(self): return [self.argument] - def set_type(self, t): + def set_type(self, t: ArgumentType) -> None: assert isinstance(t, ArgumentType) self._type = t - def get_type(self): + def get_type(self) -> ArgumentType: return self._type def __str__(self): diff --git a/slither/slithir/tmp_operations/tmp_call.py b/slither/slithir/tmp_operations/tmp_call.py index fb6641139..ace16c098 100644 --- a/slither/slithir/tmp_operations/tmp_call.py +++ b/slither/slithir/tmp_operations/tmp_call.py @@ -8,10 +8,17 @@ from slither.core.declarations import ( from slither.core.declarations.custom_error import CustomError from slither.core.variables.variable import Variable from slither.slithir.operations.lvalue import OperationWithLValue +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.slithir.operations.member import Member +from slither.slithir.tmp_operations.tmp_new_array import TmpNewArray +from slither.slithir.tmp_operations.tmp_new_contract import TmpNewContract +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.tuple import TupleVariable +from typing import Optional, Union class TmpCall(OperationWithLValue): # pylint: disable=too-many-instance-attributes - def __init__(self, called, nbr_arguments, result, type_call): + def __init__(self, called: SourceMapping, nbr_arguments: int, result: Union[TupleVariable, TemporaryVariable], type_call: str) -> None: assert isinstance( called, ( @@ -80,18 +87,18 @@ class TmpCall(OperationWithLValue): # pylint: disable=too-many-instance-attribu self._called = c @property - def nbr_arguments(self): + def nbr_arguments(self) -> int: return self._nbr_arguments @property - def type_call(self): + def type_call(self) -> str: return self._type_call @property - def ori(self): + def ori(self) -> Optional[Union[TmpNewContract, TmpNewArray, "TmpCall", Member]]: return self._ori - def set_ori(self, ori): + def set_ori(self, ori: Union["TmpCall", TmpNewContract, TmpNewArray, Member]) -> None: self._ori = ori def __str__(self): diff --git a/slither/slithir/tmp_operations/tmp_new_array.py b/slither/slithir/tmp_operations/tmp_new_array.py index 0da9c54eb..cff256e10 100644 --- a/slither/slithir/tmp_operations/tmp_new_array.py +++ b/slither/slithir/tmp_operations/tmp_new_array.py @@ -1,9 +1,13 @@ from slither.slithir.operations.lvalue import OperationWithLValue from slither.core.solidity_types.type import Type +from slither.core.solidity_types.elementary_type import ElementaryType +from slither.core.solidity_types.type_alias import TypeAliasTopLevel +from slither.slithir.variables.temporary import TemporaryVariable +from typing import Union class TmpNewArray(OperationWithLValue): - def __init__(self, depth, array_type, lvalue): + def __init__(self, depth: int, array_type: Union[TypeAliasTopLevel, ElementaryType], lvalue: TemporaryVariable) -> None: super().__init__() assert isinstance(array_type, Type) self._depth = depth @@ -11,7 +15,7 @@ class TmpNewArray(OperationWithLValue): self._lvalue = lvalue @property - def array_type(self): + def array_type(self) -> TypeAliasTopLevel: return self._array_type @property @@ -19,7 +23,7 @@ class TmpNewArray(OperationWithLValue): return [] @property - def depth(self): + def depth(self) -> int: return self._depth def __str__(self): diff --git a/slither/slithir/tmp_operations/tmp_new_contract.py b/slither/slithir/tmp_operations/tmp_new_contract.py index 337a40bd1..72c09cc00 100644 --- a/slither/slithir/tmp_operations/tmp_new_contract.py +++ b/slither/slithir/tmp_operations/tmp_new_contract.py @@ -1,8 +1,9 @@ from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.variables.temporary import TemporaryVariable class TmpNewContract(OperationWithLValue): - def __init__(self, contract_name, lvalue): + def __init__(self, contract_name: str, lvalue: TemporaryVariable) -> None: super().__init__() self._contract_name = contract_name self._lvalue = lvalue diff --git a/slither/slithir/utils/ssa.py b/slither/slithir/utils/ssa.py index 322583381..73a2aa153 100644 --- a/slither/slithir/utils/ssa.py +++ b/slither/slithir/utils/ssa.py @@ -1,6 +1,6 @@ import logging -from slither.core.cfg.node import NodeType +from slither.core.cfg.node import Node, NodeType from slither.core.declarations import ( Contract, Enum, @@ -58,6 +58,21 @@ from slither.slithir.variables import ( TupleVariableSSA, ) from slither.slithir.exceptions import SlithIRError +import slither.slithir.operations.init_array +import slither.slithir.operations.new_array +import slither.slithir.operations.return_operation +import slither.slithir.variables.local_variable +import slither.slithir.variables.reference_ssa +import slither.slithir.variables.state_variable +import slither.slithir.variables.temporary_ssa +import slither.slithir.variables.tuple_ssa +from slither.core.declarations.function_contract import FunctionContract +from slither.core.declarations.function_top_level import FunctionTopLevel +from slither.core.declarations.modifier import Modifier +from slither.core.variables.variable import Variable +from slither.slithir.operations.call import Call +from slither.slithir.operations.operation import Operation +from typing import Any, Callable, Dict, List, Tuple, Union logger = logging.getLogger("SSA_Conversion") @@ -68,7 +83,7 @@ logger = logging.getLogger("SSA_Conversion") ################################################################################### -def transform_slithir_vars_to_ssa(function): +def transform_slithir_vars_to_ssa(function: Union[FunctionContract, Modifier, FunctionTopLevel]) -> None: """ Transform slithIR vars to SSA (TemporaryVariable, ReferenceVariable, TupleVariable) """ @@ -98,7 +113,7 @@ def transform_slithir_vars_to_ssa(function): # pylint: disable=too-many-arguments,too-many-locals,too-many-nested-blocks,too-many-statements,too-many-branches -def add_ssa_ir(function, all_state_variables_instances): +def add_ssa_ir(function: Union[FunctionContract, Modifier, FunctionTopLevel], all_state_variables_instances: Dict[str, slither.slithir.variables.state_variable.StateIRVariable]) -> None: """ Add SSA version of the IR Args: @@ -199,14 +214,14 @@ def add_ssa_ir(function, all_state_variables_instances): def generate_ssa_irs( - node, - local_variables_instances, - all_local_variables_instances, - state_variables_instances, - all_state_variables_instances, - init_local_variables_instances, - visited, -): + node: Node, + local_variables_instances: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable], + all_local_variables_instances: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable], + state_variables_instances: Dict[str, slither.slithir.variables.state_variable.StateIRVariable], + all_state_variables_instances: Dict[str, slither.slithir.variables.state_variable.StateIRVariable], + init_local_variables_instances: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable], + visited: List[Union[Node, Any]], +) -> None: if node in visited: return @@ -323,7 +338,7 @@ def generate_ssa_irs( ################################################################################### -def last_name(n, var, init_vars): +def last_name(n: Node, var: Union[slither.slithir.variables.state_variable.StateIRVariable, slither.slithir.variables.local_variable.LocalIRVariable], init_vars: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable]) -> Union[slither.slithir.variables.state_variable.StateIRVariable, slither.slithir.variables.local_variable.LocalIRVariable]: candidates = [] # Todo optimize by creating a variables_ssa_written attribute for ir_ssa in n.irs_ssa: @@ -342,7 +357,7 @@ def last_name(n, var, init_vars): return max(candidates, key=lambda v: v.index) -def is_used_later(initial_node, variable): +def is_used_later(initial_node: Node, variable: Union[slither.slithir.variables.state_variable.StateIRVariable, LocalVariable]) -> bool: # TODO: does not handle the case where its read and written in the declaration node # It can be problematic if this happens in a loop/if structure # Ex: @@ -389,13 +404,13 @@ def is_used_later(initial_node, variable): def update_lvalue( - new_ir, - node, - local_variables_instances, - all_local_variables_instances, - state_variables_instances, - all_state_variables_instances, -): + new_ir: Operation, + node: Node, + local_variables_instances: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable], + all_local_variables_instances: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable], + state_variables_instances: Dict[str, slither.slithir.variables.state_variable.StateIRVariable], + all_state_variables_instances: Dict[str, slither.slithir.variables.state_variable.StateIRVariable], +) -> None: if isinstance(new_ir, OperationWithLValue): lvalue = new_ir.lvalue update_through_ref = False @@ -437,8 +452,8 @@ def update_lvalue( def initiate_all_local_variables_instances( - nodes, local_variables_instances, all_local_variables_instances -): + nodes: List[Node], local_variables_instances: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable], all_local_variables_instances: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable] +) -> None: for node in nodes: if node.variable_declaration: new_var = LocalIRVariable(node.variable_declaration) @@ -457,13 +472,13 @@ def initiate_all_local_variables_instances( def fix_phi_rvalues_and_storage_ref( - node, - local_variables_instances, - all_local_variables_instances, - state_variables_instances, - all_state_variables_instances, - init_local_variables_instances, -): + node: Node, + local_variables_instances: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable], + all_local_variables_instances: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable], + state_variables_instances: Dict[str, slither.slithir.variables.state_variable.StateIRVariable], + all_state_variables_instances: Dict[str, slither.slithir.variables.state_variable.StateIRVariable], + init_local_variables_instances: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable], +) -> None: for ir in node.irs_ssa: if isinstance(ir, (Phi)) and not ir.rvalues: variables = [ @@ -506,7 +521,7 @@ def fix_phi_rvalues_and_storage_ref( ) -def add_phi_origins(node, local_variables_definition, state_variables_definition): +def add_phi_origins(node: Node, local_variables_definition: Dict[str, Tuple[LocalVariable, Node]], state_variables_definition: Dict[str, Tuple[StateVariable, Node]]) -> None: # Add new key to local_variables_definition # The key is the variable_name @@ -557,12 +572,12 @@ def add_phi_origins(node, local_variables_definition, state_variables_definition def get( variable, - local_variables_instances, - state_variables_instances, - temporary_variables_instances, - reference_variables_instances, - tuple_variables_instances, - all_local_variables_instances, + local_variables_instances: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable], + state_variables_instances: Dict[str, slither.slithir.variables.state_variable.StateIRVariable], + temporary_variables_instances: Dict[int, slither.slithir.variables.temporary_ssa.TemporaryVariableSSA], + reference_variables_instances: Dict[int, slither.slithir.variables.reference_ssa.ReferenceVariableSSA], + tuple_variables_instances: Dict[int, slither.slithir.variables.tuple_ssa.TupleVariableSSA], + all_local_variables_instances: Dict[str, slither.slithir.variables.local_variable.LocalIRVariable], ): # variable can be None # for example, on LowLevelCall, ir.lvalue can be none @@ -623,14 +638,14 @@ def get( return variable -def get_variable(ir, f, *instances): +def get_variable(ir: Operation, f: Callable, *instances): # pylint: disable=no-value-for-parameter variable = f(ir) variable = get(variable, *instances) return variable -def _get_traversal(values, *instances): +def _get_traversal(values: List[Any], *instances) -> List[Any]: ret = [] # pylint: disable=no-value-for-parameter for v in values: @@ -642,11 +657,11 @@ def _get_traversal(values, *instances): return ret -def get_arguments(ir, *instances): +def get_arguments(ir: Call, *instances) -> List[Any]: return _get_traversal(ir.arguments, *instances) -def get_rec_values(ir, f, *instances): +def get_rec_values(ir: Union[slither.slithir.operations.init_array.InitArray, slither.slithir.operations.return_operation.Return, slither.slithir.operations.new_array.NewArray], f: Callable, *instances) -> List[Variable]: # Use by InitArray and NewArray # Potential recursive array(s) ori_init_values = f(ir) @@ -654,7 +669,7 @@ def get_rec_values(ir, f, *instances): return _get_traversal(ori_init_values, *instances) -def copy_ir(ir, *instances): +def copy_ir(ir: Operation, *instances) -> Operation: """ Args: ir (Operation) diff --git a/slither/slithir/utils/utils.py b/slither/slithir/utils/utils.py index 7bebc0a80..0a50f8e50 100644 --- a/slither/slithir/utils/utils.py +++ b/slither/slithir/utils/utils.py @@ -8,9 +8,10 @@ from slither.slithir.variables.temporary import TemporaryVariable from slither.slithir.variables.constant import Constant from slither.slithir.variables.reference import ReferenceVariable from slither.slithir.variables.tuple import TupleVariable +from slither.core.source_mapping.source_mapping import SourceMapping -def is_valid_rvalue(v): +def is_valid_rvalue(v: SourceMapping) -> bool: return isinstance( v, ( @@ -25,7 +26,7 @@ def is_valid_rvalue(v): ) -def is_valid_lvalue(v): +def is_valid_lvalue(v) -> bool: return isinstance( v, ( diff --git a/slither/slithir/variables/constant.py b/slither/slithir/variables/constant.py index 5da2d9cc0..147ad8115 100644 --- a/slither/slithir/variables/constant.py +++ b/slither/slithir/variables/constant.py @@ -4,13 +4,14 @@ from slither.core.solidity_types.elementary_type import ElementaryType, Int, Uin from slither.slithir.variables.variable import SlithIRVariable from slither.utils.arithmetic import convert_subdenomination from slither.utils.integer_conversion import convert_string_to_int +from typing import Optional, Union @total_ordering class Constant(SlithIRVariable): def __init__( - self, val, constant_type=None, subdenomination=None - ): # pylint: disable=too-many-branches + self, val: str, constant_type: Optional[ElementaryType]=None, subdenomination: Optional[str]=None + ) -> None: # pylint: disable=too-many-branches super().__init__() assert isinstance(val, str) @@ -38,7 +39,7 @@ class Constant(SlithIRVariable): self._val = val @property - def value(self): + def value(self) -> Union[bool, int, str]: """ Return the value. If the expression was an hexadecimal delcared as hex'...' @@ -56,14 +57,14 @@ class Constant(SlithIRVariable): """ return self._original_value - def __str__(self): + def __str__(self) -> str: return str(self.value) @property - def name(self): + def name(self) -> str: return str(self) - def __eq__(self, other): + def __eq__(self, other: Union["Constant", str]) -> bool: return self.value == other def __ne__(self, other): diff --git a/slither/slithir/variables/local_variable.py b/slither/slithir/variables/local_variable.py index b78500f14..627629bc5 100644 --- a/slither/slithir/variables/local_variable.py +++ b/slither/slithir/variables/local_variable.py @@ -1,12 +1,13 @@ from slither.core.variables.local_variable import LocalVariable from slither.slithir.variables.temporary import TemporaryVariable from slither.slithir.variables.variable import SlithIRVariable +from slither.slithir.variables.state_variable import StateIRVariable class LocalIRVariable( LocalVariable, SlithIRVariable ): # pylint: disable=too-many-instance-attributes - def __init__(self, local_variable): + def __init__(self, local_variable: LocalVariable) -> None: assert isinstance(local_variable, LocalVariable) super().__init__() @@ -57,10 +58,10 @@ class LocalIRVariable( self._refers_to = variables @property - def non_ssa_version(self): + def non_ssa_version(self) -> LocalVariable: return self._non_ssa_version - def add_refers_to(self, variable): + def add_refers_to(self, variable: StateIRVariable) -> None: # It is a temporaryVariable if its the return of a new .. # ex: string[] memory dynargs = new string[](1); assert isinstance(variable, (SlithIRVariable, TemporaryVariable)) diff --git a/slither/slithir/variables/reference.py b/slither/slithir/variables/reference.py index 100886f2f..64c42a24d 100644 --- a/slither/slithir/variables/reference.py +++ b/slither/slithir/variables/reference.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from slither.core.children.child_node import ChildNode from slither.core.declarations import Contract, Enum, SolidityVariable, Function @@ -9,7 +9,7 @@ if TYPE_CHECKING: class ReferenceVariable(ChildNode, Variable): - def __init__(self, node: "Node", index=None): + def __init__(self, node: "Node", index: Optional[int]=None) -> None: super().__init__() if index is None: self._index = node.compilation_unit.counter_slithir_reference @@ -56,17 +56,17 @@ class ReferenceVariable(ChildNode, Variable): self._points_to = points_to @property - def name(self): + def name(self) -> str: return f"REF_{self.index}" # overide of core.variables.variables # reference can have Function has a type # to handle the function selector - def set_type(self, t): + def set_type(self, t) -> None: if not isinstance(t, Function): super().set_type(t) else: self._type = t - def __str__(self): + def __str__(self) -> str: return self.name diff --git a/slither/slithir/variables/reference_ssa.py b/slither/slithir/variables/reference_ssa.py index 3e591555d..4fa82c6d7 100644 --- a/slither/slithir/variables/reference_ssa.py +++ b/slither/slithir/variables/reference_ssa.py @@ -7,7 +7,7 @@ from slither.slithir.variables.reference import ReferenceVariable class ReferenceVariableSSA(ReferenceVariable): # pylint: disable=too-few-public-methods - def __init__(self, reference): + def __init__(self, reference: ReferenceVariable) -> None: super().__init__(reference.node, reference.index) self._non_ssa_version = reference diff --git a/slither/slithir/variables/state_variable.py b/slither/slithir/variables/state_variable.py index 0f92d8687..7bb3a4077 100644 --- a/slither/slithir/variables/state_variable.py +++ b/slither/slithir/variables/state_variable.py @@ -5,7 +5,7 @@ from slither.slithir.variables.variable import SlithIRVariable class StateIRVariable( StateVariable, SlithIRVariable ): # pylint: disable=too-many-instance-attributes - def __init__(self, state_variable): + def __init__(self, state_variable: StateVariable) -> None: assert isinstance(state_variable, StateVariable) super().__init__() @@ -38,7 +38,7 @@ class StateIRVariable( self._index = idx @property - def non_ssa_version(self): + def non_ssa_version(self) -> StateVariable: return self._non_ssa_version @property diff --git a/slither/slithir/variables/temporary.py b/slither/slithir/variables/temporary.py index e0a3adb26..a5b638f5c 100644 --- a/slither/slithir/variables/temporary.py +++ b/slither/slithir/variables/temporary.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from slither.core.children.child_node import ChildNode from slither.core.variables.variable import Variable @@ -8,7 +8,7 @@ if TYPE_CHECKING: class TemporaryVariable(ChildNode, Variable): - def __init__(self, node: "Node", index=None): + def __init__(self, node: "Node", index: Optional[int]=None) -> None: super().__init__() if index is None: self._index = node.compilation_unit.counter_slithir_temporary @@ -26,8 +26,8 @@ class TemporaryVariable(ChildNode, Variable): self._index = idx @property - def name(self): + def name(self) -> str: return f"TMP_{self.index}" - def __str__(self): + def __str__(self) -> str: return self.name diff --git a/slither/slithir/variables/temporary_ssa.py b/slither/slithir/variables/temporary_ssa.py index 3dc772b75..0e3221650 100644 --- a/slither/slithir/variables/temporary_ssa.py +++ b/slither/slithir/variables/temporary_ssa.py @@ -4,14 +4,17 @@ as the TemporaryVariable are in SSA form in both version """ from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.reference import ReferenceVariable +from slither.slithir.variables.tuple import TupleVariable +from typing import Union class TemporaryVariableSSA(TemporaryVariable): # pylint: disable=too-few-public-methods - def __init__(self, temporary): + def __init__(self, temporary: TemporaryVariable) -> None: super().__init__(temporary.node, temporary.index) self._non_ssa_version = temporary @property - def non_ssa_version(self): + def non_ssa_version(self) -> Union[TemporaryVariable, TupleVariable, ReferenceVariable]: return self._non_ssa_version diff --git a/slither/slithir/variables/tuple.py b/slither/slithir/variables/tuple.py index 537f91d42..3557ddf3d 100644 --- a/slither/slithir/variables/tuple.py +++ b/slither/slithir/variables/tuple.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from slither.core.children.child_node import ChildNode from slither.slithir.variables.variable import SlithIRVariable @@ -8,7 +8,7 @@ if TYPE_CHECKING: class TupleVariable(ChildNode, SlithIRVariable): - def __init__(self, node: "Node", index=None): + def __init__(self, node: "Node", index: Optional[int]=None) -> None: super().__init__() if index is None: self._index = node.compilation_unit.counter_slithir_tuple @@ -27,8 +27,8 @@ class TupleVariable(ChildNode, SlithIRVariable): self._index = idx @property - def name(self): + def name(self) -> str: return f"TUPLE_{self.index}" - def __str__(self): + def __str__(self) -> str: return self.name diff --git a/slither/slithir/variables/tuple_ssa.py b/slither/slithir/variables/tuple_ssa.py index 50a4b672b..881feb1d6 100644 --- a/slither/slithir/variables/tuple_ssa.py +++ b/slither/slithir/variables/tuple_ssa.py @@ -7,7 +7,7 @@ from slither.slithir.variables.tuple import TupleVariable class TupleVariableSSA(TupleVariable): # pylint: disable=too-few-public-methods - def __init__(self, t): + def __init__(self, t: TupleVariable) -> None: super().__init__(t.node, t.index) self._non_ssa_version = t diff --git a/slither/slithir/variables/variable.py b/slither/slithir/variables/variable.py index 8e2cb145c..a1a1a6df9 100644 --- a/slither/slithir/variables/variable.py +++ b/slither/slithir/variables/variable.py @@ -2,7 +2,7 @@ from slither.core.variables.variable import Variable class SlithIRVariable(Variable): - def __init__(self): + def __init__(self) -> None: super().__init__() self._index = 0 diff --git a/slither/solc_parsing/cfg/node.py b/slither/solc_parsing/cfg/node.py index d7009fdda..b1380553b 100644 --- a/slither/solc_parsing/cfg/node.py +++ b/slither/solc_parsing/cfg/node.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict +from typing import Union, Optional, Dict, TYPE_CHECKING from slither.core.cfg.node import Node from slither.core.cfg.node import NodeType @@ -12,9 +12,13 @@ from slither.visitors.expression.find_calls import FindCalls from slither.visitors.expression.read_var import ReadVar from slither.visitors.expression.write_var import WriteVar +if TYPE_CHECKING: + from slither.solc_parsing.declarations.function import FunctionSolc + from slither.solc_parsing.declarations.modifier import ModifierSolc + class NodeSolc: - def __init__(self, node: Node): + def __init__(self, node: Node) -> None: self._unparsed_expression: Optional[Dict] = None self._node = node @@ -22,11 +26,11 @@ class NodeSolc: def underlying_node(self) -> Node: return self._node - def add_unparsed_expression(self, expression: Dict): + def add_unparsed_expression(self, expression: Dict) -> None: assert self._unparsed_expression is None self._unparsed_expression = expression - def analyze_expressions(self, caller_context): + def analyze_expressions(self, caller_context: Union["FunctionSolc", "ModifierSolc"]) -> None: if self._node.type == NodeType.VARIABLE and not self._node.expression: self._node.add_expression(self._node.variable_declaration.expression) if self._unparsed_expression: diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 0ffd24131..304b8d342 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -1,6 +1,6 @@ import logging import re -from typing import List, Dict, Callable, TYPE_CHECKING, Union, Set +from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set from slither.core.declarations import Modifier, Event, EnumContract, StructureContract, Function from slither.core.declarations.contract import Contract @@ -17,19 +17,21 @@ from slither.solc_parsing.declarations.structure_contract import StructureContra from slither.solc_parsing.exceptions import ParsingError, VariableNotFound from slither.solc_parsing.solidity_types.type_parsing import parse_type from slither.solc_parsing.variables.state_variable import StateVariableSolc +import slither.core.declarations.function +import slither.core.declarations.modifier +import slither.core.solidity_types.type LOGGER = logging.getLogger("ContractSolcParsing") if TYPE_CHECKING: from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc - from slither.core.slither_core import SlitherCore from slither.core.compilation_unit import SlitherCompilationUnit # pylint: disable=too-many-instance-attributes,import-outside-toplevel,too-many-nested-blocks,too-many-public-methods class ContractSolc(CallerContextExpression): - def __init__(self, slither_parser: "SlitherCompilationUnitSolc", contract: Contract, data): + def __init__(self, slither_parser: "SlitherCompilationUnitSolc", contract: Contract, data: Dict[str, Any]) -> None: # assert slitherSolc.solc_version.startswith('0.4') self._contract = contract @@ -86,7 +88,7 @@ class ContractSolc(CallerContextExpression): def is_analyzed(self) -> bool: return self._is_analyzed - def set_is_analyzed(self, is_analyzed: bool): + def set_is_analyzed(self, is_analyzed: bool) -> None: self._is_analyzed = is_analyzed @property @@ -130,7 +132,7 @@ class ContractSolc(CallerContextExpression): def get_key(self) -> str: return self._slither_parser.get_key() - def get_children(self, key="nodes") -> str: + def get_children(self, key: str="nodes") -> str: if self.is_compact_ast: return key return "children" @@ -150,7 +152,7 @@ class ContractSolc(CallerContextExpression): ################################################################################### ################################################################################### - def _parse_contract_info(self): + def _parse_contract_info(self) -> None: if self.is_compact_ast: attributes = self._data else: @@ -178,7 +180,7 @@ class ContractSolc(CallerContextExpression): "name" ] - def _parse_base_contract_info(self): # pylint: disable=too-many-branches + def _parse_base_contract_info(self) -> None: # pylint: disable=too-many-branches # Parse base contracts (immediate, non-linearized) if self.is_compact_ast: # Parse base contracts + constructors in compact-ast @@ -236,7 +238,7 @@ class ContractSolc(CallerContextExpression): ): self.baseConstructorContractsCalled.append(referencedDeclaration) - def _parse_contract_items(self): + def _parse_contract_items(self) -> None: # pylint: disable=too-many-branches if not self.get_children() in self._data: # empty contract return @@ -289,7 +291,7 @@ class ContractSolc(CallerContextExpression): self._contract.file_scope.user_defined_types[alias] = user_defined_type self._contract.file_scope.user_defined_types[alias_canonical] = user_defined_type - def _parse_struct(self, struct: Dict): + def _parse_struct(self, struct: Dict) -> None: st = StructureContract(self._contract.compilation_unit) st.set_contract(self._contract) @@ -299,7 +301,7 @@ class ContractSolc(CallerContextExpression): self._contract.structures_as_dict[st.name] = st self._structures_parser.append(st_parser) - def parse_structs(self): + def parse_structs(self) -> None: for father in self._contract.inheritance_reverse: self._contract.structures_as_dict.update(father.structures_as_dict) @@ -307,7 +309,7 @@ class ContractSolc(CallerContextExpression): self._parse_struct(struct) self._structuresNotParsed = None - def _parse_custom_error(self, custom_error: Dict): + def _parse_custom_error(self, custom_error: Dict) -> None: ce = CustomErrorContract(self.compilation_unit) ce.set_contract(self._contract) ce.set_offset(custom_error["src"], self.compilation_unit) @@ -316,7 +318,7 @@ class ContractSolc(CallerContextExpression): self._contract.custom_errors_as_dict[ce.name] = ce self._custom_errors_parser.append(ce_parser) - def parse_custom_errors(self): + def parse_custom_errors(self) -> None: for father in self._contract.inheritance_reverse: self._contract.custom_errors_as_dict.update(father.custom_errors_as_dict) @@ -324,7 +326,7 @@ class ContractSolc(CallerContextExpression): self._parse_custom_error(custom_error) self._customErrorParsed = None - def parse_state_variables(self): + def parse_state_variables(self) -> None: for father in self._contract.inheritance_reverse: self._contract.variables_as_dict.update( { @@ -352,7 +354,7 @@ class ContractSolc(CallerContextExpression): self._contract.variables_as_dict[var.name] = var self._contract.add_variables_ordered([var]) - def _parse_modifier(self, modifier_data: Dict): + def _parse_modifier(self, modifier_data: Dict) -> None: modif = Modifier(self._contract.compilation_unit) modif.set_offset(modifier_data["src"], self._contract.compilation_unit) modif.set_contract(self._contract) @@ -365,12 +367,12 @@ class ContractSolc(CallerContextExpression): self._slither_parser.add_function_or_modifier_parser(modif_parser) - def parse_modifiers(self): + def parse_modifiers(self) -> None: for modifier in self._modifiersNotParsed: self._parse_modifier(modifier) self._modifiersNotParsed = None - def _parse_function(self, function_data: Dict): + def _parse_function(self, function_data: Dict) -> None: func = FunctionContract(self._contract.compilation_unit) func.set_offset(function_data["src"], self._contract.compilation_unit) func.set_contract(self._contract) @@ -383,7 +385,7 @@ class ContractSolc(CallerContextExpression): self._slither_parser.add_function_or_modifier_parser(func_parser) - def parse_functions(self): + def parse_functions(self) -> None: for function in self._functionsNotParsed: self._parse_function(function) @@ -403,21 +405,21 @@ class ContractSolc(CallerContextExpression): LOGGER.error(error) self._contract.is_incorrectly_constructed = True - def analyze_content_modifiers(self): + def analyze_content_modifiers(self) -> None: try: for modifier_parser in self._modifiers_parser: modifier_parser.analyze_content() except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing modifier {e}") - def analyze_content_functions(self): + def analyze_content_functions(self) -> None: try: for function_parser in self._functions_parser: function_parser.analyze_content() except (VariableNotFound, KeyError, ParsingError) as e: self.log_incorrect_parsing(f"Missing function {e}") - def analyze_params_modifiers(self): + def analyze_params_modifiers(self) -> None: try: elements_no_params = self._modifiers_no_params getter = lambda c: c.modifiers_parser @@ -437,7 +439,7 @@ class ContractSolc(CallerContextExpression): self.log_incorrect_parsing(f"Missing params {e}") self._modifiers_no_params = [] - def analyze_params_functions(self): + def analyze_params_functions(self) -> None: try: elements_no_params = self._functions_no_params getter = lambda c: c.functions_parser @@ -465,7 +467,7 @@ class ContractSolc(CallerContextExpression): explored_reference_id: Set[str], parser: List[FunctionSolc], all_elements: Dict[str, Function], - ): + ) -> None: elem = Cls(self._contract.compilation_unit) elem.set_contract(self._contract) underlying_function = element_parser.underlying_function @@ -566,7 +568,7 @@ class ContractSolc(CallerContextExpression): self.log_incorrect_parsing(f"Missing params {e}") return all_elements - def analyze_constant_state_variables(self): + def analyze_constant_state_variables(self) -> None: for var_parser in self._variables_parser: if var_parser.underlying_variable.is_constant: # cant parse constant expression based on function calls @@ -575,7 +577,7 @@ class ContractSolc(CallerContextExpression): except (VariableNotFound, KeyError) as e: LOGGER.error(e) - def analyze_state_variables(self): + def analyze_state_variables(self) -> None: try: for var_parser in self._variables_parser: var_parser.analyze(self) @@ -583,7 +585,7 @@ class ContractSolc(CallerContextExpression): except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing state variable {e}") - def analyze_using_for(self): # pylint: disable=too-many-branches + def analyze_using_for(self) -> None: # pylint: disable=too-many-branches try: for father in self._contract.inheritance: self._contract.using_for.update(father.using_for) @@ -620,7 +622,7 @@ class ContractSolc(CallerContextExpression): except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing using for {e}") - def _analyze_function_list(self, function_list: List, type_name: Type): + def _analyze_function_list(self, function_list: List, type_name: Type) -> None: for f in function_list: full_name_split = f["function"]["name"].split(".") if len(full_name_split) == 1: @@ -639,7 +641,7 @@ class ContractSolc(CallerContextExpression): function_name = full_name_split[2] self._analyze_library_function(library_name, function_name, type_name) - def _check_aliased_import(self, first_part: str, function_name: str, type_name: Type): + def _check_aliased_import(self, first_part: str, function_name: str, type_name: Type) -> None: # We check if the first part appear as alias for an import # if it is then function_name must be a top level function # otherwise it's a library function @@ -649,7 +651,7 @@ class ContractSolc(CallerContextExpression): return self._analyze_library_function(first_part, function_name, type_name) - def _analyze_top_level_function(self, function_name: str, type_name: Type): + def _analyze_top_level_function(self, function_name: str, type_name: Type) -> None: for tl_function in self.compilation_unit.functions_top_level: if tl_function.name == function_name: self._contract.using_for[type_name].append(tl_function) @@ -673,7 +675,7 @@ class ContractSolc(CallerContextExpression): f"Contract level using for: Library {library_name} - function {function_name} not found" ) - def analyze_enums(self): + def analyze_enums(self) -> None: try: for father in self._contract.inheritance: self._contract.enums_as_dict.update(father.enums_as_dict) @@ -686,7 +688,7 @@ class ContractSolc(CallerContextExpression): except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing enum {e}") - def _analyze_enum(self, enum): + def _analyze_enum(self, enum: Dict[str, Union[str, int, List[Dict[str, Union[int, str]]], Dict[str, str], List[Dict[str, Union[Dict[str, str], int, str]]]]]) -> None: # Enum can be parsed in one pass if self.is_compact_ast: name = enum["name"] @@ -710,21 +712,21 @@ class ContractSolc(CallerContextExpression): new_enum.set_offset(enum["src"], self._contract.compilation_unit) self._contract.enums_as_dict[canonicalName] = new_enum - def _analyze_struct(self, struct: StructureContractSolc): # pylint: disable=no-self-use + def _analyze_struct(self, struct: StructureContractSolc) -> None: # pylint: disable=no-self-use struct.analyze() - def analyze_structs(self): + def analyze_structs(self) -> None: try: for struct in self._structures_parser: self._analyze_struct(struct) except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing struct {e}") - def analyze_custom_errors(self): + def analyze_custom_errors(self) -> None: for custom_error in self._custom_errors_parser: custom_error.analyze_params() - def analyze_events(self): + def analyze_events(self) -> None: try: for father in self._contract.inheritance_reverse: self._contract.events_as_dict.update(father.events_as_dict) diff --git a/slither/solc_parsing/declarations/custom_error.py b/slither/solc_parsing/declarations/custom_error.py index 83f6f0e5c..8cd459262 100644 --- a/slither/solc_parsing/declarations/custom_error.py +++ b/slither/solc_parsing/declarations/custom_error.py @@ -22,7 +22,7 @@ class CustomErrorSolc(CallerContextExpression): custom_error: CustomError, custom_error_data: dict, slither_parser: "SlitherCompilationUnitSolc", - ): + ) -> None: self._slither_parser: "SlitherCompilationUnitSolc" = slither_parser self._custom_error = custom_error custom_error.name = custom_error_data["name"] @@ -32,7 +32,7 @@ class CustomErrorSolc(CallerContextExpression): custom_error_data = custom_error_data["attributes"] self._custom_error_data = custom_error_data - def analyze_params(self): + def analyze_params(self) -> None: # Can be re-analyzed due to inheritance if self._params_was_analyzed: return @@ -68,7 +68,7 @@ class CustomErrorSolc(CallerContextExpression): return key return "children" - def _parse_params(self, params: Dict): + def _parse_params(self, params: Dict) -> None: assert params[self.get_key()] == "ParameterList" if self._slither_parser.is_compact_ast: diff --git a/slither/solc_parsing/declarations/event.py b/slither/solc_parsing/declarations/event.py index 1f8904fc1..6531e6536 100644 --- a/slither/solc_parsing/declarations/event.py +++ b/slither/solc_parsing/declarations/event.py @@ -16,7 +16,7 @@ class EventSolc: Event class """ - def __init__(self, event: Event, event_data: Dict, contract_parser: "ContractSolc"): + def __init__(self, event: Event, event_data: Dict, contract_parser: "ContractSolc") -> None: self._event = event event.set_contract(contract_parser.underlying_contract) @@ -43,7 +43,7 @@ class EventSolc: def is_compact_ast(self) -> bool: return self._parser_contract.is_compact_ast - def analyze(self, contract: "ContractSolc"): + def analyze(self, contract: "ContractSolc") -> None: for elem_to_parse in self._elemsNotParsed: elem = EventVariable() # Todo: check if the source offset is always here diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 6b8aca51e..66f6e0fe5 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -23,10 +23,10 @@ from slither.solc_parsing.variables.local_variable_init_from_tuple import ( LocalVariableInitFromTupleSolc, ) from slither.solc_parsing.variables.variable_declaration import MultipleVariablesDeclaration -from slither.solc_parsing.yul.parse_yul import YulBlock from slither.utils.expression_manipulations import SplitTernaryExpression from slither.visitors.expression.export_values import ExportValues from slither.visitors.expression.has_conditional import HasConditional +from slither.solc_parsing.yul.parse_yul import YulBlock if TYPE_CHECKING: from slither.core.expressions.expression import Expression @@ -35,6 +35,7 @@ if TYPE_CHECKING: from slither.core.compilation_unit import SlitherCompilationUnit + LOGGER = logging.getLogger("FunctionSolc") @@ -55,7 +56,7 @@ class FunctionSolc(CallerContextExpression): function_data: Dict, contract_parser: Optional["ContractSolc"], slither_parser: "SlitherCompilationUnitSolc", - ): + ) -> None: self._slither_parser: "SlitherCompilationUnitSolc" = slither_parser self._contract_parser = contract_parser self._function = function @@ -143,7 +144,7 @@ class FunctionSolc(CallerContextExpression): def _add_local_variable( self, local_var_parser: Union[LocalVariableSolc, LocalVariableInitFromTupleSolc] - ): + ) -> None: # If two local variables have the same name # We add a suffix to the new variable # This is done to prevent collision during SSA translation @@ -175,7 +176,7 @@ class FunctionSolc(CallerContextExpression): def function_not_parsed(self) -> Dict: return self._functionNotParsed - def _analyze_type(self): + def _analyze_type(self) -> None: """ Analyz the type of the function Myst be called in the constructor as the name might change according to the function's type @@ -201,7 +202,7 @@ class FunctionSolc(CallerContextExpression): if self._function.name == self._function.contract_declarer.name: self._function.function_type = FunctionType.CONSTRUCTOR - def _analyze_attributes(self): + def _analyze_attributes(self) -> None: if self.is_compact_ast: attributes = self._functionNotParsed else: @@ -1018,7 +1019,7 @@ class FunctionSolc(CallerContextExpression): return node - def _parse_block(self, block: Dict, node: NodeSolc, check_arithmetic: bool = False): + def _parse_block(self, block: Dict, node: NodeSolc, check_arithmetic: bool = False) -> NodeSolc: """ Return: Node @@ -1053,7 +1054,7 @@ class FunctionSolc(CallerContextExpression): node = self._parse_statement(statement, node, new_scope) return node - def _parse_cfg(self, cfg: Dict): + def _parse_cfg(self, cfg: Dict) -> None: assert cfg[self.get_key()] == "Block" @@ -1118,7 +1119,7 @@ class FunctionSolc(CallerContextExpression): return None - def _fix_break_node(self, node: Node): + def _fix_break_node(self, node: Node) -> None: end_node = self._find_end_loop(node, [], 0) if not end_node: @@ -1134,7 +1135,7 @@ class FunctionSolc(CallerContextExpression): node.set_sons([end_node]) end_node.add_father(node) - def _fix_continue_node(self, node: Node): + def _fix_continue_node(self, node: Node) -> None: start_node = self._find_start_loop(node, []) if not start_node: @@ -1145,14 +1146,14 @@ class FunctionSolc(CallerContextExpression): node.set_sons([start_node]) start_node.add_father(node) - def _fix_try(self, node: Node): + def _fix_try(self, node: Node) -> None: end_node = next((son for son in node.sons if son.type != NodeType.CATCH), None) if end_node: for son in node.sons: if son.type == NodeType.CATCH: self._fix_catch(son, end_node) - def _fix_catch(self, node: Node, end_node: Node): + def _fix_catch(self, node: Node, end_node: Node) -> None: if not node.sons: link_nodes(node, end_node) else: diff --git a/slither/solc_parsing/declarations/modifier.py b/slither/solc_parsing/declarations/modifier.py index e55487612..ea7af00b3 100644 --- a/slither/solc_parsing/declarations/modifier.py +++ b/slither/solc_parsing/declarations/modifier.py @@ -23,7 +23,7 @@ class ModifierSolc(FunctionSolc): function_data: Dict, contract_parser: "ContractSolc", slither_parser: "SlitherCompilationUnitSolc", - ): + ) -> None: super().__init__(modifier, function_data, contract_parser, slither_parser) # _modifier is equal to _function, but keep it here to prevent # confusion for mypy in underlying_function @@ -33,7 +33,7 @@ class ModifierSolc(FunctionSolc): def underlying_function(self) -> Modifier: return self._modifier - def analyze_params(self): + def analyze_params(self) -> None: # Can be re-analyzed due to inheritance if self._params_was_analyzed: return @@ -55,7 +55,7 @@ class ModifierSolc(FunctionSolc): if params: self._parse_params(params) - def analyze_content(self): + def analyze_content(self) -> None: if self._content_was_analyzed: return diff --git a/slither/solc_parsing/declarations/structure_contract.py b/slither/solc_parsing/declarations/structure_contract.py index 9c3784ea9..c48c73c4f 100644 --- a/slither/solc_parsing/declarations/structure_contract.py +++ b/slither/solc_parsing/declarations/structure_contract.py @@ -23,7 +23,7 @@ class StructureContractSolc: # pylint: disable=too-few-public-methods st: Structure, struct: Dict, contract_parser: "ContractSolc", - ): + ) -> None: if contract_parser.is_compact_ast: name = struct["name"] @@ -45,7 +45,7 @@ class StructureContractSolc: # pylint: disable=too-few-public-methods self._elemsNotParsed = children - def analyze(self): + def analyze(self) -> None: for elem_to_parse in self._elemsNotParsed: elem = StructureVariable() elem.set_structure(self._structure) diff --git a/slither/solc_parsing/declarations/structure_top_level.py b/slither/solc_parsing/declarations/structure_top_level.py index 1597ad44e..6dcca19d4 100644 --- a/slither/solc_parsing/declarations/structure_top_level.py +++ b/slither/solc_parsing/declarations/structure_top_level.py @@ -25,7 +25,7 @@ class StructureTopLevelSolc(CallerContextExpression): # pylint: disable=too-few st: StructureTopLevel, struct: Dict, slither_parser: "SlitherCompilationUnitSolc", - ): + ) -> None: if slither_parser.is_compact_ast: name = struct["name"] @@ -47,7 +47,7 @@ class StructureTopLevelSolc(CallerContextExpression): # pylint: disable=too-few self._elemsNotParsed = children - def analyze(self): + def analyze(self) -> None: for elem_to_parse in self._elemsNotParsed: elem = StructureVariable() elem.set_structure(self._structure) diff --git a/slither/solc_parsing/declarations/using_for_top_level.py b/slither/solc_parsing/declarations/using_for_top_level.py index 16e3666b0..52e734e8f 100644 --- a/slither/solc_parsing/declarations/using_for_top_level.py +++ b/slither/solc_parsing/declarations/using_for_top_level.py @@ -15,6 +15,7 @@ from slither.core.solidity_types import TypeAliasTopLevel from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.solidity_types.type_parsing import parse_type +import slither.core.solidity_types.type_alias if TYPE_CHECKING: from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc @@ -77,7 +78,7 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few- first_part: str, function_name: str, type_name: Union[TypeAliasTopLevel, UserDefinedType], - ): + ) -> None: # We check if the first part appear as alias for an import # if it is then function_name must be a top level function # otherwise it's a library function @@ -132,7 +133,7 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few- f"Error when propagating global using for {type_name} {type(type_name)}" ) - def _propagate_global_UserDefinedType(self, scope: FileScope, type_name: UserDefinedType): + def _propagate_global_UserDefinedType(self, scope: FileScope, type_name: UserDefinedType) -> None: underlying = type_name.type if isinstance(underlying, StructureTopLevel): for struct in scope.structures.values(): diff --git a/slither/solc_parsing/expressions/expression_parsing.py b/slither/solc_parsing/expressions/expression_parsing.py index 9e3a12b06..207409708 100644 --- a/slither/solc_parsing/expressions/expression_parsing.py +++ b/slither/solc_parsing/expressions/expression_parsing.py @@ -1,19 +1,12 @@ import logging import re -from typing import Dict, TYPE_CHECKING +from typing import Union, Dict, TYPE_CHECKING +import slither.core.expressions.type_conversion from slither.core.declarations.solidity_variables import ( SOLIDITY_VARIABLES_COMPOSED, SolidityVariableComposed, ) -from slither.core.expressions.assignment_operation import ( - AssignmentOperation, - AssignmentOperationType, -) -from slither.core.expressions.binary_operation import ( - BinaryOperation, - BinaryOperationType, -) from slither.core.expressions import ( CallExpression, ConditionalExpression, @@ -32,6 +25,14 @@ from slither.core.expressions import ( UnaryOperation, UnaryOperationType, ) +from slither.core.expressions.assignment_operation import ( + AssignmentOperation, + AssignmentOperationType, +) +from slither.core.expressions.binary_operation import ( + BinaryOperation, + BinaryOperationType, +) from slither.core.solidity_types import ( ArrayType, ElementaryType, @@ -41,8 +42,12 @@ from slither.solc_parsing.exceptions import ParsingError, VariableNotFound from slither.solc_parsing.expressions.find_variable import find_variable from slither.solc_parsing.solidity_types.type_parsing import UnknownType, parse_type + if TYPE_CHECKING: from slither.core.expressions.expression import Expression + from slither.solc_parsing.declarations.contract import ContractSolc + from slither.solc_parsing.declarations.function import FunctionSolc + from slither.solc_parsing.variables.top_level_variable import TopLevelVariableSolc logger = logging.getLogger("ExpressionParsing") @@ -98,7 +103,7 @@ def filter_name(value: str) -> str: ################################################################################### -def parse_call(expression: Dict, caller_context): # pylint: disable=too-many-statements +def parse_call(expression: Dict, caller_context: Union["FunctionSolc", "ContractSolc", "TopLevelVariableSolc"]) -> Union[slither.core.expressions.call_expression.CallExpression, slither.core.expressions.type_conversion.TypeConversion]: # pylint: disable=too-many-statements src = expression["src"] if caller_context.is_compact_ast: attributes = expression @@ -221,8 +226,7 @@ def _parse_elementary_type_name_expression( if TYPE_CHECKING: - - from slither.core.scope.scope import FileScope + pass def parse_expression(expression: Dict, caller_context: CallerContextExpression) -> "Expression": diff --git a/slither/solc_parsing/expressions/find_variable.py b/slither/solc_parsing/expressions/find_variable.py index 471f1a750..c62c700a2 100644 --- a/slither/solc_parsing/expressions/find_variable.py +++ b/slither/solc_parsing/expressions/find_variable.py @@ -25,6 +25,10 @@ from slither.core.variables.variable import Variable from slither.exceptions import SlitherError from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.exceptions import VariableNotFound +import slither.core.declarations.enum +import slither.core.declarations.event +import slither.core.declarations.structure +import slither.core.solidity_types.type_alias if TYPE_CHECKING: from slither.solc_parsing.declarations.function import FunctionSolc @@ -36,7 +40,7 @@ if TYPE_CHECKING: # CallerContext =Union["ContractSolc", "FunctionSolc", "CustomErrorSolc", "StructureTopLevelSolc"] -def _get_pointer_name(variable: Variable): +def _get_pointer_name(variable: Variable) -> Optional[str]: curr_type = variable.type while isinstance(curr_type, (ArrayType, MappingType)): if isinstance(curr_type, ArrayType): diff --git a/slither/solc_parsing/slither_compilation_unit_solc.py b/slither/solc_parsing/slither_compilation_unit_solc.py index 6e7de2db9..1977e155c 100644 --- a/slither/solc_parsing/slither_compilation_unit_solc.py +++ b/slither/solc_parsing/slither_compilation_unit_solc.py @@ -19,6 +19,7 @@ from slither.core.scope.scope import FileScope from slither.core.solidity_types import ElementaryType, TypeAliasTopLevel from slither.core.variables.top_level_variable import TopLevelVariable from slither.exceptions import SlitherException +from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.custom_error import CustomErrorSolc from slither.solc_parsing.declarations.function import FunctionSolc @@ -26,7 +27,6 @@ from slither.solc_parsing.declarations.structure_top_level import StructureTopLe from slither.solc_parsing.declarations.using_for_top_level import UsingForTopLevelSolc from slither.solc_parsing.exceptions import VariableNotFound from slither.solc_parsing.variables.top_level_variable import TopLevelVariableSolc -from slither.solc_parsing.declarations.caller_context import CallerContextExpression logging.basicConfig() logger = logging.getLogger("SlitherSolcParsing") @@ -68,7 +68,7 @@ def _handle_import_aliases( class SlitherCompilationUnitSolc(CallerContextExpression): # pylint: disable=no-self-use,too-many-instance-attributes - def __init__(self, compilation_unit: SlitherCompilationUnit): + def __init__(self, compilation_unit: SlitherCompilationUnit) -> None: super().__init__() self._contracts_by_id: Dict[int, ContractSolc] = {} @@ -98,7 +98,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression): def all_functions_and_modifiers_parser(self) -> List[FunctionSolc]: return self._all_functions_and_modifier_parser - def add_function_or_modifier_parser(self, f: FunctionSolc): + def add_function_or_modifier_parser(self, f: FunctionSolc) -> None: self._all_functions_and_modifier_parser.append(f) @property @@ -163,7 +163,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression): return True return False - def _parse_enum(self, top_level_data: Dict, filename: str): + def _parse_enum(self, top_level_data: Dict, filename: str) -> None: if self.is_compact_ast: name = top_level_data["name"] canonicalName = top_level_data["canonicalName"] @@ -194,7 +194,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression): def parse_top_level_from_loaded_json( self, data_loaded: Dict, filename: str - ): # pylint: disable=too-many-branches,too-many-statements,too-many-locals + ) -> None: # pylint: disable=too-many-branches,too-many-statements,too-many-locals if "nodeType" in data_loaded: self._is_compact_ast = True @@ -342,7 +342,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression): else: raise SlitherException(f"Top level {top_level_data[self.get_key()]} not supported") - def _parse_source_unit(self, data: Dict, filename: str): + def _parse_source_unit(self, data: Dict, filename: str) -> None: if data[self.get_key()] != "SourceUnit": return # handle solc prior 0.3.6 @@ -392,7 +392,7 @@ class SlitherCompilationUnitSolc(CallerContextExpression): def analyzed(self) -> bool: return self._analyzed - def parse_contracts(self): # pylint: disable=too-many-statements,too-many-branches + def parse_contracts(self) -> None: # pylint: disable=too-many-statements,too-many-branches if not self._underlying_contract_to_parser: logger.info( f"No contract were found in {self._compilation_unit.core.filename}, check the correct compilation" @@ -523,7 +523,7 @@ Please rename it, this name is reserved for Slither's internals""" self._parsed = True - def analyze_contracts(self): # pylint: disable=too-many-statements,too-many-branches + def analyze_contracts(self) -> None: # pylint: disable=too-many-statements,too-many-branches if not self._parsed: raise SlitherException("Parse the contract before running analyses") self._convert_to_slithir() @@ -532,7 +532,7 @@ Please rename it, this name is reserved for Slither's internals""" self._compilation_unit.compute_storage_layout() self._analyzed = True - def _analyze_all_enums(self, contracts_to_be_analyzed: List[ContractSolc]): + def _analyze_all_enums(self, contracts_to_be_analyzed: List[ContractSolc]) -> None: while contracts_to_be_analyzed: contract = contracts_to_be_analyzed[0] @@ -551,7 +551,7 @@ Please rename it, this name is reserved for Slither's internals""" self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc], - ): + ) -> None: for lib in libraries: self._parse_struct_var_modifiers_functions(lib) @@ -578,7 +578,7 @@ Please rename it, this name is reserved for Slither's internals""" self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc], - ): + ) -> None: for lib in libraries: self._analyze_struct_events(lib) @@ -608,7 +608,7 @@ Please rename it, this name is reserved for Slither's internals""" self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc], - ): + ) -> None: for lib in libraries: self._analyze_variables_modifiers_functions(lib) @@ -633,7 +633,7 @@ Please rename it, this name is reserved for Slither's internals""" def _analyze_using_for( self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc] - ): + ) -> None: self._analyze_top_level_using_for() for lib in libraries: @@ -654,12 +654,12 @@ Please rename it, this name is reserved for Slither's internals""" else: contracts_to_be_analyzed += [contract] - def _analyze_enums(self, contract: ContractSolc): + def _analyze_enums(self, contract: ContractSolc) -> None: # Enum must be analyzed first contract.analyze_enums() contract.set_is_analyzed(True) - def _parse_struct_var_modifiers_functions(self, contract: ContractSolc): + def _parse_struct_var_modifiers_functions(self, contract: ContractSolc) -> None: contract.parse_structs() # struct can refer another struct contract.parse_state_variables() contract.parse_modifiers() @@ -667,7 +667,7 @@ Please rename it, this name is reserved for Slither's internals""" contract.parse_custom_errors() contract.set_is_analyzed(True) - def _analyze_struct_events(self, contract: ContractSolc): + def _analyze_struct_events(self, contract: ContractSolc) -> None: contract.analyze_constant_state_variables() @@ -680,41 +680,41 @@ Please rename it, this name is reserved for Slither's internals""" contract.set_is_analyzed(True) - def _analyze_top_level_structures(self): + def _analyze_top_level_structures(self) -> None: try: for struct in self._structures_top_level_parser: struct.analyze() except (VariableNotFound, KeyError) as e: raise SlitherException(f"Missing struct {e} during top level structure analyze") from e - def _analyze_top_level_variables(self): + def _analyze_top_level_variables(self) -> None: try: for var in self._variables_top_level_parser: var.analyze(var) except (VariableNotFound, KeyError) as e: raise SlitherException(f"Missing {e} during variable analyze") from e - def _analyze_params_top_level_function(self): + def _analyze_params_top_level_function(self) -> None: for func_parser in self._functions_top_level_parser: func_parser.analyze_params() self._compilation_unit.add_function(func_parser.underlying_function) - def _analyze_top_level_using_for(self): + def _analyze_top_level_using_for(self) -> None: for using_for in self._using_for_top_level_parser: using_for.analyze() - def _analyze_params_custom_error(self): + def _analyze_params_custom_error(self) -> None: for custom_error_parser in self._custom_error_parser: custom_error_parser.analyze_params() - def _analyze_content_top_level_function(self): + def _analyze_content_top_level_function(self) -> None: try: for func_parser in self._functions_top_level_parser: func_parser.analyze_content() except (VariableNotFound, KeyError) as e: raise SlitherException(f"Missing {e} during top level function analyze") from e - def _analyze_variables_modifiers_functions(self, contract: ContractSolc): + def _analyze_variables_modifiers_functions(self, contract: ContractSolc) -> None: # State variables, modifiers and functions can refer to anything contract.analyze_params_modifiers() @@ -730,7 +730,7 @@ Please rename it, this name is reserved for Slither's internals""" contract.set_is_analyzed(True) - def _convert_to_slithir(self): + def _convert_to_slithir(self) -> None: for contract in self._compilation_unit.contracts: contract.add_constructor_variables() diff --git a/slither/solc_parsing/solidity_types/type_parsing.py b/slither/solc_parsing/solidity_types/type_parsing.py index 91e320a42..e289090b1 100644 --- a/slither/solc_parsing/solidity_types/type_parsing.py +++ b/slither/solc_parsing/solidity_types/type_parsing.py @@ -33,11 +33,11 @@ logger = logging.getLogger("TypeParsing") class UnknownType: # pylint: disable=too-few-public-methods - def __init__(self, name): + def __init__(self, name: str) -> None: self._name = name @property - def name(self): + def name(self) -> str: return self._name @@ -195,7 +195,7 @@ def _find_from_type_name( # pylint: disable=too-many-locals,too-many-branches,t return UserDefinedType(var_type) -def _add_type_references(type_found: Type, src: str, sl: "SlitherCompilationUnit"): +def _add_type_references(type_found: Type, src: str, sl: "SlitherCompilationUnit") -> None: if isinstance(type_found, UserDefinedType): type_found.type.add_reference_from_raw_source(src, sl) diff --git a/slither/solc_parsing/variables/event_variable.py b/slither/solc_parsing/variables/event_variable.py index 664b7d057..fe30b8a3a 100644 --- a/slither/solc_parsing/variables/event_variable.py +++ b/slither/solc_parsing/variables/event_variable.py @@ -14,7 +14,7 @@ class EventVariableSolc(VariableDeclarationSolc): assert isinstance(self._variable, EventVariable) return self._variable - def _analyze_variable_attributes(self, attributes: Dict): + def _analyze_variable_attributes(self, attributes: Dict) -> None: """ Analyze event variable attributes :param attributes: The event variable attributes to parse. diff --git a/slither/solc_parsing/variables/local_variable.py b/slither/solc_parsing/variables/local_variable.py index b9617a59c..cd9030d58 100644 --- a/slither/solc_parsing/variables/local_variable.py +++ b/slither/solc_parsing/variables/local_variable.py @@ -5,7 +5,7 @@ from slither.core.variables.local_variable import LocalVariable class LocalVariableSolc(VariableDeclarationSolc): - def __init__(self, variable: LocalVariable, variable_data: Dict): + def __init__(self, variable: LocalVariable, variable_data: Dict) -> None: super().__init__(variable, variable_data) @property @@ -14,7 +14,7 @@ class LocalVariableSolc(VariableDeclarationSolc): assert isinstance(self._variable, LocalVariable) return self._variable - def _analyze_variable_attributes(self, attributes: Dict): + def _analyze_variable_attributes(self, attributes: Dict) -> None: """' Variable Location Can be storage/memory or default diff --git a/slither/solc_parsing/variables/top_level_variable.py b/slither/solc_parsing/variables/top_level_variable.py index 6c24c3bdf..56eb79c46 100644 --- a/slither/solc_parsing/variables/top_level_variable.py +++ b/slither/solc_parsing/variables/top_level_variable.py @@ -15,7 +15,7 @@ class TopLevelVariableSolc(VariableDeclarationSolc, CallerContextExpression): variable: TopLevelVariable, variable_data: Dict, slither_parser: "SlitherCompilationUnitSolc", - ): + ) -> None: super().__init__(variable, variable_data) self._slither_parser = slither_parser diff --git a/slither/solc_parsing/variables/variable_declaration.py b/slither/solc_parsing/variables/variable_declaration.py index 119604ca4..46827cb7e 100644 --- a/slither/solc_parsing/variables/variable_declaration.py +++ b/slither/solc_parsing/variables/variable_declaration.py @@ -32,7 +32,7 @@ class MultipleVariablesDeclaration(Exception): class VariableDeclarationSolc: def __init__( self, variable: Variable, variable_data: Dict - ): # pylint: disable=too-many-branches + ) -> None: # pylint: disable=too-many-branches """ A variable can be declared through a statement, or directly. If it is through a statement, the following children may contain @@ -104,7 +104,7 @@ class VariableDeclarationSolc: """ return self._reference_id - def _handle_comment(self, attributes: Dict): + def _handle_comment(self, attributes: Dict) -> None: if "documentation" in attributes and "text" in attributes["documentation"]: candidates = attributes["documentation"]["text"].split(",") @@ -121,13 +121,13 @@ class VariableDeclarationSolc: self._variable.write_protection = [] self._variable.write_protection.append(write_protection.group(1)) - def _analyze_variable_attributes(self, attributes: Dict): + def _analyze_variable_attributes(self, attributes: Dict) -> None: if "visibility" in attributes: self._variable.visibility = attributes["visibility"] else: self._variable.visibility = "internal" - def _init_from_declaration(self, var: Dict, init: bool): # pylint: disable=too-many-branches + def _init_from_declaration(self, var: Dict, init: bool) -> None: # pylint: disable=too-many-branches if self._is_compact_ast: attributes = var self._typeName = attributes["typeDescriptions"]["typeString"] @@ -200,7 +200,7 @@ class VariableDeclarationSolc: self._variable.initialized = True self._initializedNotParsed = var["children"][1] - def analyze(self, caller_context: CallerContextExpression): + def analyze(self, caller_context: CallerContextExpression) -> None: # Can be re-analyzed due to inheritance if self._was_analyzed: return diff --git a/slither/solc_parsing/yul/evm_functions.py b/slither/solc_parsing/yul/evm_functions.py index 41c150765..dfb52a244 100644 --- a/slither/solc_parsing/yul/evm_functions.py +++ b/slither/solc_parsing/yul/evm_functions.py @@ -225,7 +225,7 @@ function_args = { } -def format_function_descriptor(name): +def format_function_descriptor(name: str) -> str: if name not in function_args: return name + "()" diff --git a/slither/solc_parsing/yul/parse_yul.py b/slither/solc_parsing/yul/parse_yul.py index f7c9938fc..5c3ae7831 100644 --- a/slither/solc_parsing/yul/parse_yul.py +++ b/slither/solc_parsing/yul/parse_yul.py @@ -1,4 +1,8 @@ import abc +import slither.core.declarations.contract +import slither.core.declarations.function +import slither.core.expressions.identifier + import json from typing import Optional, Dict, List, Union @@ -43,7 +47,7 @@ from slither.visitors.expression.write_var import WriteVar class YulNode: - def __init__(self, node: Node, scope: "YulScope"): + def __init__(self, node: Node, scope: "YulScope") -> None: self._node = node self._scope = scope self._unparsed_expression: Optional[Dict] = None @@ -99,7 +103,7 @@ class YulNode: ] -def link_underlying_nodes(node1: YulNode, node2: YulNode): +def link_underlying_nodes(node1: YulNode, node2: YulNode) -> None: link_nodes(node1.underlying_node, node2.underlying_node) @@ -191,7 +195,7 @@ class YulScope(metaclass=abc.ABCMeta): class YulLocalVariable: # pylint: disable=too-few-public-methods __slots__ = ["_variable", "_root"] - def __init__(self, var: LocalVariable, root: YulScope, ast: Dict): + def __init__(self, var: LocalVariable, root: YulScope, ast: Dict) -> None: assert ast["nodeType"] == "YulTypedName" self._variable = var @@ -215,7 +219,7 @@ class YulFunction(YulScope): def __init__( self, func: Function, root: YulScope, ast: Dict, node_scope: Union[Function, Scope] - ): + ) -> None: super().__init__(root.contract, root.id + [ast["name"]], parent_func=root.parent_func) assert ast["nodeType"] == "YulFunctionDefinition" @@ -272,7 +276,7 @@ class YulFunction(YulScope): for node in self._nodes: node.analyze_expressions() - def new_node(self, node_type, src) -> YulNode: + def new_node(self, node_type: NodeType, src: str) -> YulNode: if self._function: node = self._function.new_node(node_type, src, self.node_scope) else: @@ -299,7 +303,7 @@ class YulBlock(YulScope): entrypoint: Node, yul_id: List[str], node_scope: Union[Scope, Function], - ): + ) -> None: super().__init__(contract, yul_id, entrypoint.function) self._entrypoint: YulNode = YulNode(entrypoint, self) @@ -883,7 +887,7 @@ def vars_to_typestr(rets: List[Expression]) -> str: return f"tuple({','.join(str(ret.type) for ret in rets)})" -def vars_to_val(vars_to_convert): +def vars_to_val(vars_to_convert: List[slither.core.expressions.identifier.Identifier]) -> slither.core.expressions.identifier.Identifier: if len(vars_to_convert) == 1: return vars_to_convert[0] return TupleExpression(vars_to_convert) diff --git a/slither/utils/expression_manipulations.py b/slither/utils/expression_manipulations.py index a63db9829..e27feb246 100644 --- a/slither/utils/expression_manipulations.py +++ b/slither/utils/expression_manipulations.py @@ -20,6 +20,7 @@ from slither.core.expressions.new_contract import NewContract from slither.core.expressions.tuple_expression import TupleExpression from slither.core.expressions.type_conversion import TypeConversion from slither.all_exceptions import SlitherException +import slither.core.expressions.unary_operation # pylint: disable=protected-access def f_expressions( @@ -29,7 +30,7 @@ def f_expressions( e._expressions.append(x) -def f_call(e: CallExpression, x): +def f_call(e: CallExpression, x: ElementaryTypeNameExpression) -> None: e._arguments.append(x) @@ -41,11 +42,11 @@ def f_call_gas(e: CallExpression, x): e._gas = x -def f_expression(e: Union[TypeConversion, UnaryOperation, MemberAccess], x): +def f_expression(e: Union[TypeConversion, UnaryOperation, MemberAccess], x: CallExpression) -> None: e._expression = x -def f_called(e: CallExpression, x): +def f_called(e: CallExpression, x: Identifier) -> None: e._called = x diff --git a/slither/utils/output.py b/slither/utils/output.py index 5db6492db..c67805b1c 100644 --- a/slither/utils/output.py +++ b/slither/utils/output.py @@ -15,6 +15,12 @@ from slither.core.variables.variable import Variable from slither.exceptions import SlitherError from slither.utils.colors import yellow from slither.utils.myprettytable import MyPrettyTable +import slither.core.declarations.contract +import slither.core.declarations.enum +import slither.core.declarations.event +import slither.core.declarations.function +import slither.core.declarations.pragma_directive +import slither.core.declarations.structure if TYPE_CHECKING: from slither.core.compilation_unit import SlitherCompilationUnit @@ -218,7 +224,7 @@ def output_to_zip(filename: str, error: Optional[str], results: Dict, zip_type: ################################################################################### -def _convert_to_description(d): +def _convert_to_description(d: str) -> str: if isinstance(d, str): return d @@ -239,7 +245,7 @@ def _convert_to_description(d): raise SlitherError(f"{type(d)} cannot be converted (no name, or canonical_name") -def _convert_to_markdown(d, markdown_root): +def _convert_to_markdown(d: str, markdown_root: str) -> str: if isinstance(d, str): return d @@ -260,7 +266,7 @@ def _convert_to_markdown(d, markdown_root): raise SlitherError(f"{type(d)} cannot be converted (no name, or canonical_name") -def _convert_to_id(d): +def _convert_to_id(d: str) -> str: """ Id keeps the source mapping of the node, otherwise we risk to consider two different node as the same :param d: @@ -345,9 +351,9 @@ class Output: self, info_: Union[str, List[Union[str, SupportedOutput]]], additional_fields: Optional[Dict] = None, - markdown_root="", - standard_format=True, - ): + markdown_root: str="", + standard_format: bool=True, + ) -> None: if additional_fields is None: additional_fields = {} diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index 797d1f46e..7e76377c5 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -1,4 +1,6 @@ from fractions import Fraction +from typing import Union, TYPE_CHECKING + from slither.core.expressions import ( BinaryOperationType, Literal, @@ -7,9 +9,13 @@ from slither.core.expressions import ( BinaryOperation, UnaryOperation, ) + from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int from slither.visitors.expression.expression import ExpressionVisitor +if TYPE_CHECKING: + from slither.core.solidity_types.elementary_type import ElementaryType + class NotConstant(Exception): pass @@ -17,33 +23,35 @@ class NotConstant(Exception): KEY = "ConstantFolding" +CONSTANT_TYPES_OPERATIONS = Union[Literal, BinaryOperation, UnaryOperation, Identifier] + -def get_val(expression): +def get_val(expression: CONSTANT_TYPES_OPERATIONS) -> Union[bool, int, Fraction, str]: val = expression.context[KEY] # we delete the item to reduce memory use del expression.context[KEY] return val -def set_val(expression, val): +def set_val(expression: CONSTANT_TYPES_OPERATIONS, val: Union[bool, int, Fraction, str]) -> None: expression.context[KEY] = val class ConstantFolding(ExpressionVisitor): - def __init__(self, expression, custom_type): + def __init__(self, expression: CONSTANT_TYPES_OPERATIONS, custom_type: Union[str, "ElementaryType"]) -> None: self._type = custom_type super().__init__(expression) - def result(self): + def result(self) -> "Literal": value = get_val(self._expression) if isinstance(value, Fraction): value = int(value) # emulate 256-bit wrapping if str(self._type).startswith("uint"): - value = value & (2**256 - 1) + value = value & (2 ** 256 - 1) return Literal(value, self._type) - def _post_identifier(self, expression: Identifier): + def _post_identifier(self, expression: Identifier) -> None: if not expression.value.is_constant: raise NotConstant expr = expression.value.expression @@ -54,11 +62,11 @@ class ConstantFolding(ExpressionVisitor): set_val(expression, convert_string_to_int(expr.converted_value)) # pylint: disable=too-many-branches - def _post_binary_operation(self, expression: BinaryOperation): + def _post_binary_operation(self, expression: BinaryOperation) -> None: left = get_val(expression.expression_left) right = get_val(expression.expression_right) if expression.type == BinaryOperationType.POWER: - set_val(expression, left**right) + set_val(expression, left ** right) elif expression.type == BinaryOperationType.MULTIPLICATION: set_val(expression, left * right) elif expression.type == BinaryOperationType.DIVISION: @@ -100,7 +108,7 @@ class ConstantFolding(ExpressionVisitor): else: raise NotConstant - def _post_unary_operation(self, expression: UnaryOperation): + def _post_unary_operation(self, expression: UnaryOperation) -> None: # Case of uint a = -7; uint[-a] arr; if expression.type == UnaryOperationType.MINUS_PRE: expr = expression.expression @@ -112,7 +120,7 @@ class ConstantFolding(ExpressionVisitor): else: raise NotConstant - def _post_literal(self, expression: Literal): + def _post_literal(self, expression: Literal) -> None: if expression.converted_value in ["true", "false"]: set_val(expression, expression.converted_value) else: diff --git a/slither/visitors/expression/export_values.py b/slither/visitors/expression/export_values.py index 246cc8384..04f6f189f 100644 --- a/slither/visitors/expression/export_values.py +++ b/slither/visitors/expression/export_values.py @@ -1,4 +1,7 @@ from slither.visitors.expression.expression import ExpressionVisitor +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.literal import Literal key = "ExportValues" @@ -32,7 +35,7 @@ class ExportValues(ExpressionVisitor): val = left + right set_val(expression, val) - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression.called) args = [get(a) for a in expression.arguments if a] args = [item for sublist in args for item in sublist] @@ -49,7 +52,7 @@ class ExportValues(ExpressionVisitor): def _post_elementary_type_name_expression(self, expression): set_val(expression, []) - def _post_identifier(self, expression): + def _post_identifier(self, expression: Identifier) -> None: set_val(expression, [expression.value]) def _post_index_access(self, expression): @@ -58,7 +61,7 @@ class ExportValues(ExpressionVisitor): val = left + right set_val(expression, val) - def _post_literal(self, expression): + def _post_literal(self, expression: Literal) -> None: set_val(expression, []) def _post_member_access(self, expression): diff --git a/slither/visitors/expression/expression.py b/slither/visitors/expression/expression.py index 17020aaba..38c3fbb06 100644 --- a/slither/visitors/expression/expression.py +++ b/slither/visitors/expression/expression.py @@ -23,7 +23,7 @@ logger = logging.getLogger("ExpressionVisitor") class ExpressionVisitor: - def __init__(self, expression: Expression): + def __init__(self, expression: Expression) -> None: # Inherited class must declared their variables prior calling super().__init__ self._expression = expression self._result: Any = None @@ -38,7 +38,7 @@ class ExpressionVisitor: # visit an expression # call pre_visit, visit_expression_name, post_visit - def _visit_expression(self, expression: Expression): # pylint: disable=too-many-branches + def _visit_expression(self, expression: Expression) -> None: # pylint: disable=too-many-branches self._pre_visit(expression) if isinstance(expression, AssignmentOperation): @@ -96,15 +96,15 @@ class ExpressionVisitor: # visit_expression_name - def _visit_assignement_operation(self, expression): + def _visit_assignement_operation(self, expression: AssignmentOperation) -> None: self._visit_expression(expression.expression_left) self._visit_expression(expression.expression_right) - def _visit_binary_operation(self, expression): + def _visit_binary_operation(self, expression: BinaryOperation) -> None: self._visit_expression(expression.expression_left) self._visit_expression(expression.expression_right) - def _visit_call_expression(self, expression): + def _visit_call_expression(self, expression: CallExpression) -> None: self._visit_expression(expression.called) for arg in expression.arguments: if arg: @@ -116,50 +116,50 @@ class ExpressionVisitor: if expression.call_salt: self._visit_expression(expression.call_salt) - def _visit_conditional_expression(self, expression): + def _visit_conditional_expression(self, expression: ConditionalExpression) -> None: self._visit_expression(expression.if_expression) self._visit_expression(expression.else_expression) self._visit_expression(expression.then_expression) - def _visit_elementary_type_name_expression(self, expression): + def _visit_elementary_type_name_expression(self, expression: ElementaryTypeNameExpression) -> None: pass - def _visit_identifier(self, expression): + def _visit_identifier(self, expression: Identifier) -> None: pass - def _visit_index_access(self, expression): + def _visit_index_access(self, expression: IndexAccess) -> None: self._visit_expression(expression.expression_left) self._visit_expression(expression.expression_right) - def _visit_literal(self, expression): + def _visit_literal(self, expression: Literal) -> None: pass - def _visit_member_access(self, expression): + def _visit_member_access(self, expression: MemberAccess) -> None: self._visit_expression(expression.expression) - def _visit_new_array(self, expression): + def _visit_new_array(self, expression: NewArray) -> None: pass - def _visit_new_contract(self, expression): + def _visit_new_contract(self, expression: NewContract) -> None: pass def _visit_new_elementary_type(self, expression): pass - def _visit_tuple_expression(self, expression): + def _visit_tuple_expression(self, expression: TupleExpression) -> None: for e in expression.expressions: if e: self._visit_expression(e) - def _visit_type_conversion(self, expression): + def _visit_type_conversion(self, expression: TypeConversion) -> None: self._visit_expression(expression.expression) - def _visit_unary_operation(self, expression): + def _visit_unary_operation(self, expression: UnaryOperation) -> None: self._visit_expression(expression.expression) # pre visit - def _pre_visit(self, expression): # pylint: disable=too-many-branches + def _pre_visit(self, expression) -> None: # pylint: disable=too-many-branches if isinstance(expression, AssignmentOperation): self._pre_assignement_operation(expression) @@ -213,54 +213,54 @@ class ExpressionVisitor: # pre_expression_name - def _pre_assignement_operation(self, expression): + def _pre_assignement_operation(self, expression: AssignmentOperation) -> None: pass - def _pre_binary_operation(self, expression): + def _pre_binary_operation(self, expression: BinaryOperation) -> None: pass - def _pre_call_expression(self, expression): + def _pre_call_expression(self, expression: CallExpression) -> None: pass - def _pre_conditional_expression(self, expression): + def _pre_conditional_expression(self, expression: ConditionalExpression) -> None: pass - def _pre_elementary_type_name_expression(self, expression): + def _pre_elementary_type_name_expression(self, expression: ElementaryTypeNameExpression) -> None: pass - def _pre_identifier(self, expression): + def _pre_identifier(self, expression: Identifier) -> None: pass - def _pre_index_access(self, expression): + def _pre_index_access(self, expression: IndexAccess) -> None: pass - def _pre_literal(self, expression): + def _pre_literal(self, expression: Literal) -> None: pass - def _pre_member_access(self, expression): + def _pre_member_access(self, expression: MemberAccess) -> None: pass - def _pre_new_array(self, expression): + def _pre_new_array(self, expression: NewArray) -> None: pass - def _pre_new_contract(self, expression): + def _pre_new_contract(self, expression: NewContract) -> None: pass def _pre_new_elementary_type(self, expression): pass - def _pre_tuple_expression(self, expression): + def _pre_tuple_expression(self, expression: TupleExpression) -> None: pass - def _pre_type_conversion(self, expression): + def _pre_type_conversion(self, expression: TypeConversion) -> None: pass - def _pre_unary_operation(self, expression): + def _pre_unary_operation(self, expression: UnaryOperation) -> None: pass # post visit - def _post_visit(self, expression): # pylint: disable=too-many-branches + def _post_visit(self, expression) -> None: # pylint: disable=too-many-branches if isinstance(expression, AssignmentOperation): self._post_assignement_operation(expression) @@ -314,47 +314,47 @@ class ExpressionVisitor: # post_expression_name - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: pass - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: BinaryOperation) -> None: pass - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: CallExpression) -> None: pass def _post_conditional_expression(self, expression): pass - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression(self, expression: ElementaryTypeNameExpression) -> None: pass - def _post_identifier(self, expression): + def _post_identifier(self, expression: Identifier) -> None: pass - def _post_index_access(self, expression): + def _post_index_access(self, expression: IndexAccess) -> None: pass - def _post_literal(self, expression): + def _post_literal(self, expression: Literal) -> None: pass - def _post_member_access(self, expression): + def _post_member_access(self, expression: MemberAccess) -> None: pass - def _post_new_array(self, expression): + def _post_new_array(self, expression: NewArray) -> None: pass - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: NewContract) -> None: pass def _post_new_elementary_type(self, expression): pass - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: TupleExpression) -> None: pass - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: TypeConversion) -> None: pass - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: UnaryOperation) -> None: pass diff --git a/slither/visitors/expression/find_calls.py b/slither/visitors/expression/find_calls.py index 9b9141c76..cdf10d7fa 100644 --- a/slither/visitors/expression/find_calls.py +++ b/slither/visitors/expression/find_calls.py @@ -1,19 +1,33 @@ -from typing import List +from typing import Any, Union, List from slither.core.expressions.expression import Expression from slither.visitors.expression.expression import ExpressionVisitor +from slither.core.expressions.assignment_operation import AssignmentOperation +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.conditional_expression import ConditionalExpression +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.literal import Literal +from slither.core.expressions.member_access import MemberAccess +from slither.core.expressions.new_array import NewArray +from slither.core.expressions.new_contract import NewContract +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.type_conversion import TypeConversion +from slither.core.expressions.unary_operation import UnaryOperation key = "FindCall" -def get(expression): +def get(expression: Expression) -> List[Union[Any, CallExpression]]: val = expression.context[key] # we delete the item to reduce memory use del expression.context[key] return val -def set_val(expression, val): +def set_val(expression: Expression, val: List[Union[Any, CallExpression]]) -> None: expression.context[key] = val @@ -23,19 +37,19 @@ class FindCalls(ExpressionVisitor): self._result = list(set(get(self.expression))) return self._result - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right set_val(expression, val) - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: BinaryOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right set_val(expression, val) - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression.called) args = [get(a) for a in expression.arguments if a] args = [item for sublist in args for item in sublist] @@ -43,54 +57,54 @@ class FindCalls(ExpressionVisitor): val += [expression] set_val(expression, val) - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: if_expr = get(expression.if_expression) else_expr = get(expression.else_expression) then_expr = get(expression.then_expression) val = if_expr + else_expr + then_expr set_val(expression, val) - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression(self, expression: ElementaryTypeNameExpression) -> None: set_val(expression, []) # save only identifier expression - def _post_identifier(self, expression): + def _post_identifier(self, expression: Identifier) -> None: set_val(expression, []) - def _post_index_access(self, expression): + def _post_index_access(self, expression: IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right set_val(expression, val) - def _post_literal(self, expression): + def _post_literal(self, expression: Literal) -> None: set_val(expression, []) - def _post_member_access(self, expression): + def _post_member_access(self, expression: MemberAccess) -> None: expr = get(expression.expression) val = expr set_val(expression, val) - def _post_new_array(self, expression): + def _post_new_array(self, expression: NewArray) -> None: set_val(expression, []) - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) def _post_new_elementary_type(self, expression): set_val(expression, []) - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: TupleExpression) -> None: expressions = [get(e) for e in expression.expressions if e] val = [item for sublist in expressions for item in sublist] set_val(expression, val) - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: TypeConversion) -> None: expr = get(expression.expression) val = expr set_val(expression, val) - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: UnaryOperation) -> None: expr = get(expression.expression) val = expr set_val(expression, val) diff --git a/slither/visitors/expression/has_conditional.py b/slither/visitors/expression/has_conditional.py index 906f522ff..b866a696b 100644 --- a/slither/visitors/expression/has_conditional.py +++ b/slither/visitors/expression/has_conditional.py @@ -1,12 +1,13 @@ from slither.visitors.expression.expression import ExpressionVisitor +from slither.core.expressions.conditional_expression import ConditionalExpression class HasConditional(ExpressionVisitor): - def result(self): + def result(self) -> bool: # == True, to convert None to false return self._result is True - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: # if self._result is True: # raise('Slither does not support nested ternary operator') self._result = True diff --git a/slither/visitors/expression/read_var.py b/slither/visitors/expression/read_var.py index 8fe063a2f..86151fa8f 100644 --- a/slither/visitors/expression/read_var.py +++ b/slither/visitors/expression/read_var.py @@ -1,38 +1,53 @@ from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.assignment_operation import AssignmentOperationType +from slither.core.expressions.assignment_operation import AssignmentOperation, AssignmentOperationType from slither.core.variables.variable import Variable from slither.core.declarations.solidity_variables import SolidityVariable +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.conditional_expression import ConditionalExpression +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression +from slither.core.expressions.expression import Expression +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.literal import Literal +from slither.core.expressions.member_access import MemberAccess +from slither.core.expressions.new_array import NewArray +from slither.core.expressions.new_contract import NewContract +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.type_conversion import TypeConversion +from slither.core.expressions.unary_operation import UnaryOperation +from typing import Any, List, Union key = "ReadVar" -def get(expression): +def get(expression: Expression) -> List[Union[Identifier, IndexAccess, Any]]: val = expression.context[key] # we delete the item to reduce memory use del expression.context[key] return val -def set_val(expression, val): +def set_val(expression: Expression, val: List[Union[Identifier, IndexAccess, Any]]) -> None: expression.context[key] = val class ReadVar(ExpressionVisitor): - def result(self): + def result(self) -> List[Union[Identifier, IndexAccess, Any]]: if self._result is None: self._result = list(set(get(self.expression))) return self._result # overide assignement # dont explore if its direct assignement (we explore if its +=, -=, ...) - def _visit_assignement_operation(self, expression): + def _visit_assignement_operation(self, expression: AssignmentOperation) -> None: if expression.type != AssignmentOperationType.ASSIGN: self._visit_expression(expression.expression_left) self._visit_expression(expression.expression_right) - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: if expression.type != AssignmentOperationType.ASSIGN: left = get(expression.expression_left) else: @@ -41,31 +56,31 @@ class ReadVar(ExpressionVisitor): val = left + right set_val(expression, val) - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: BinaryOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right set_val(expression, val) - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression.called) args = [get(a) for a in expression.arguments if a] args = [item for sublist in args for item in sublist] val = called + args set_val(expression, val) - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: if_expr = get(expression.if_expression) else_expr = get(expression.else_expression) then_expr = get(expression.then_expression) val = if_expr + else_expr + then_expr set_val(expression, val) - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression(self, expression: ElementaryTypeNameExpression) -> None: set_val(expression, []) # save only identifier expression - def _post_identifier(self, expression): + def _post_identifier(self, expression: Identifier) -> None: if isinstance(expression.value, Variable): set_val(expression, [expression]) elif isinstance(expression.value, SolidityVariable): @@ -73,40 +88,40 @@ class ReadVar(ExpressionVisitor): else: set_val(expression, []) - def _post_index_access(self, expression): + def _post_index_access(self, expression: IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right + [expression] set_val(expression, val) - def _post_literal(self, expression): + def _post_literal(self, expression: Literal) -> None: set_val(expression, []) - def _post_member_access(self, expression): + def _post_member_access(self, expression: MemberAccess) -> None: expr = get(expression.expression) val = expr set_val(expression, val) - def _post_new_array(self, expression): + def _post_new_array(self, expression: NewArray) -> None: set_val(expression, []) - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) def _post_new_elementary_type(self, expression): set_val(expression, []) - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: TupleExpression) -> None: expressions = [get(e) for e in expression.expressions if e] val = [item for sublist in expressions for item in sublist] set_val(expression, val) - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: TypeConversion) -> None: expr = get(expression.expression) val = expr set_val(expression, val) - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: UnaryOperation) -> None: expr = get(expression.expression) val = expr set_val(expression, val) diff --git a/slither/visitors/expression/write_var.py b/slither/visitors/expression/write_var.py index 0509a2eb7..b7f35335d 100644 --- a/slither/visitors/expression/write_var.py +++ b/slither/visitors/expression/write_var.py @@ -1,26 +1,42 @@ from slither.visitors.expression.expression import ExpressionVisitor +from slither.core.expressions.assignment_operation import AssignmentOperation +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.call_expression import CallExpression +from slither.core.expressions.conditional_expression import ConditionalExpression +from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression +from slither.core.expressions.expression import Expression +from slither.core.expressions.identifier import Identifier +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.literal import Literal +from slither.core.expressions.member_access import MemberAccess +from slither.core.expressions.new_array import NewArray +from slither.core.expressions.new_contract import NewContract +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.type_conversion import TypeConversion +from slither.core.expressions.unary_operation import UnaryOperation +from typing import Any, List key = "WriteVar" -def get(expression): +def get(expression: Expression) -> List[Any]: val = expression.context[key] # we delete the item to reduce memory use del expression.context[key] return val -def set_val(expression, val): +def set_val(expression: Expression, val: List[Any]) -> None: expression.context[key] = val class WriteVar(ExpressionVisitor): - def result(self): + def result(self) -> List[Any]: if self._result is None: self._result = list(set(get(self.expression))) return self._result - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: BinaryOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right @@ -28,7 +44,7 @@ class WriteVar(ExpressionVisitor): val += [expression] set_val(expression, val) - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression.called) args = [get(a) for a in expression.arguments if a] args = [item for sublist in args for item in sublist] @@ -37,7 +53,7 @@ class WriteVar(ExpressionVisitor): val += [expression] set_val(expression, val) - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: if_expr = get(expression.if_expression) else_expr = get(expression.else_expression) then_expr = get(expression.then_expression) @@ -46,7 +62,7 @@ class WriteVar(ExpressionVisitor): val += [expression] set_val(expression, val) - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right @@ -54,11 +70,11 @@ class WriteVar(ExpressionVisitor): val += [expression] set_val(expression, val) - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression(self, expression: ElementaryTypeNameExpression) -> None: set_val(expression, []) # save only identifier expression - def _post_identifier(self, expression): + def _post_identifier(self, expression: Identifier) -> None: if expression.is_lvalue: set_val(expression, [expression]) else: @@ -69,7 +85,7 @@ class WriteVar(ExpressionVisitor): # else: # set_val(expression, []) - def _post_index_access(self, expression): + def _post_index_access(self, expression: IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right @@ -87,10 +103,10 @@ class WriteVar(ExpressionVisitor): # n = n.expression set_val(expression, val) - def _post_literal(self, expression): + def _post_literal(self, expression: Literal) -> None: set_val(expression, []) - def _post_member_access(self, expression): + def _post_member_access(self, expression: MemberAccess) -> None: expr = get(expression.expression) val = expr if expression.is_lvalue: @@ -98,30 +114,30 @@ class WriteVar(ExpressionVisitor): val += [expression.expression] set_val(expression, val) - def _post_new_array(self, expression): + def _post_new_array(self, expression: NewArray) -> None: set_val(expression, []) - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) def _post_new_elementary_type(self, expression): set_val(expression, []) - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: TupleExpression) -> None: expressions = [get(e) for e in expression.expressions if e] val = [item for sublist in expressions for item in sublist] if expression.is_lvalue: val += [expression] set_val(expression, val) - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: TypeConversion) -> None: expr = get(expression.expression) val = expr if expression.is_lvalue: val += [expression] set_val(expression, val) - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: UnaryOperation) -> None: expr = get(expression.expression) val = expr if expression.is_lvalue: diff --git a/slither/visitors/slithir/expression_to_slithir.py b/slither/visitors/slithir/expression_to_slithir.py index 6b7b4c264..0ad25485a 100644 --- a/slither/visitors/slithir/expression_to_slithir.py +++ b/slither/visitors/slithir/expression_to_slithir.py @@ -1,7 +1,7 @@ import logging +from typing import Any, Union, List, TYPE_CHECKING -from typing import List - +import slither.slithir.variables.reference from slither.core.declarations import ( Function, SolidityVariable, @@ -19,9 +19,20 @@ from slither.core.expressions import ( Identifier, MemberAccess, ) +from slither.core.expressions.binary_operation import BinaryOperation +from slither.core.expressions.expression import Expression +from slither.core.expressions.index_access import IndexAccess +from slither.core.expressions.literal import Literal +from slither.core.expressions.new_array import NewArray +from slither.core.expressions.new_contract import NewContract +from slither.core.expressions.tuple_expression import TupleExpression +from slither.core.expressions.unary_operation import UnaryOperation from slither.core.solidity_types import ArrayType, ElementaryType, TypeAlias from slither.core.solidity_types.type import Type +from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple +from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable from slither.slithir.exceptions import SlithIRError from slither.slithir.operations import ( @@ -53,12 +64,15 @@ from slither.slithir.variables import ( ) from slither.visitors.expression.expression import ExpressionVisitor +if TYPE_CHECKING: + from slither.core.cfg.node import Node + logger = logging.getLogger("VISTIOR:ExpressionToSlithIR") key = "expressionToSlithIR" -def get(expression): +def get(expression: Expression): val = expression.context[key] # we delete the item to reduce memory use del expression.context[key] @@ -69,7 +83,7 @@ def get_without_removing(expression): return expression.context[key] -def set_val(expression, val): +def set_val(expression: Expression, val) -> None: expression.context[key] = val @@ -104,7 +118,7 @@ _signed_to_unsigned = { } -def convert_assignment(left, right, t, return_type): +def convert_assignment(left: Union[LocalVariable, StateVariable, slither.slithir.variables.reference.ReferenceVariable], right: SourceMapping, t: slither.core.expressions.assignment_operation.AssignmentOperationType, return_type) -> Union[slither.slithir.operations.binary.Binary, slither.slithir.operations.assignment.Assignment]: if t == AssignmentOperationType.ASSIGN: return Assignment(left, right, return_type) if t == AssignmentOperationType.ASSIGN_OR: @@ -132,7 +146,7 @@ def convert_assignment(left, right, t, return_type): class ExpressionToSlithIR(ExpressionVisitor): - def __init__(self, expression, node): # pylint: disable=super-init-not-called + def __init__(self, expression: Expression, node: "Node") -> None: # pylint: disable=super-init-not-called from slither.core.cfg.node import NodeType # pylint: disable=import-outside-toplevel self._expression = expression @@ -146,10 +160,10 @@ class ExpressionToSlithIR(ExpressionVisitor): for ir in self._result: ir.set_node(node) - def result(self): + def result(self) -> List[Any]: return self._result - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: slither.core.expressions.assignment_operation.AssignmentOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) if isinstance(left, list): # tuple expression: @@ -211,7 +225,7 @@ class ExpressionToSlithIR(ExpressionVisitor): # a = b = 1; set_val(expression, left) - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: BinaryOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = TemporaryVariable(self._node) @@ -249,8 +263,8 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, val) def _post_call_expression( - self, expression - ): # pylint: disable=too-many-branches,too-many-statements,too-many-locals + self, expression: slither.core.expressions.call_expression.CallExpression + ) -> None: # pylint: disable=too-many-branches,too-many-statements,too-many-locals assert isinstance(expression, CallExpression) @@ -358,13 +372,13 @@ class ExpressionToSlithIR(ExpressionVisitor): def _post_conditional_expression(self, expression): raise Exception(f"Ternary operator are not convertible to SlithIR {expression}") - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression(self, expression: slither.core.expressions.elementary_type_name_expression.ElementaryTypeNameExpression) -> None: set_val(expression, expression.type) - def _post_identifier(self, expression): + def _post_identifier(self, expression: slither.core.expressions.identifier.Identifier) -> None: set_val(expression, expression.value) - def _post_index_access(self, expression): + def _post_index_access(self, expression: IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) # Left can be a type for abi.decode(var, uint[2]) @@ -390,11 +404,11 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(operation) set_val(expression, val) - def _post_literal(self, expression): + def _post_literal(self, expression: Literal) -> None: cst = Constant(expression.value, expression.type, expression.subdenomination) set_val(expression, cst) - def _post_member_access(self, expression): + def _post_member_access(self, expression: slither.core.expressions.member_access.MemberAccess) -> None: expr = get(expression.expression) # Look for type(X).max / min @@ -479,14 +493,14 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(member) set_val(expression, val) - def _post_new_array(self, expression): + def _post_new_array(self, expression: NewArray) -> None: val = TemporaryVariable(self._node) operation = TmpNewArray(expression.depth, expression.array_type, val) operation.set_expression(expression) self._result.append(operation) set_val(expression, val) - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: NewContract) -> None: val = TemporaryVariable(self._node) operation = TmpNewContract(expression.contract_name, val) operation.set_expression(expression) @@ -508,7 +522,7 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(operation) set_val(expression, val) - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: TupleExpression) -> None: expressions = [get(e) if e else None for e in expression.expressions] if len(expressions) == 1: val = expressions[0] @@ -516,7 +530,7 @@ class ExpressionToSlithIR(ExpressionVisitor): val = expressions set_val(expression, val) - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: slither.core.expressions.type_conversion.TypeConversion) -> None: expr = get(expression.expression) val = TemporaryVariable(self._node) operation = TypeConversion(val, expr, expression.type) @@ -526,8 +540,8 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, val) def _post_unary_operation( - self, expression - ): # pylint: disable=too-many-branches,too-many-statements + self, expression: UnaryOperation + ) -> None: # pylint: disable=too-many-branches,too-many-statements value = get(expression.expression) if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]: lvalue = TemporaryVariable(self._node)