diff --git a/examples/scripts/data_dependency.py b/examples/scripts/data_dependency.py index 478394766..23c82cae1 100644 --- a/examples/scripts/data_dependency.py +++ b/examples/scripts/data_dependency.py @@ -18,6 +18,8 @@ assert len(contracts) == 1 contract = contracts[0] destination = contract.get_state_variable_from_name("destination") source = contract.get_state_variable_from_name("source") +assert source +assert destination print(f"{source} is dependent of {destination}: {is_dependent(source, destination, contract)}") assert not is_dependent(source, destination, contract) @@ -47,9 +49,11 @@ print(f"{destination} is tainted {is_tainted(destination, contract)}") assert is_tainted(destination, contract) destination_indirect_1 = contract.get_state_variable_from_name("destination_indirect_1") +assert destination_indirect_1 print(f"{destination_indirect_1} is tainted {is_tainted(destination_indirect_1, contract)}") assert is_tainted(destination_indirect_1, contract) destination_indirect_2 = contract.get_state_variable_from_name("destination_indirect_2") +assert destination_indirect_2 print(f"{destination_indirect_2} is tainted {is_tainted(destination_indirect_2, contract)}") assert is_tainted(destination_indirect_2, contract) @@ -88,6 +92,8 @@ contract = contracts[0] contract_derived = slither.get_contract_from_name("Derived")[0] destination = contract.get_state_variable_from_name("destination") source = contract.get_state_variable_from_name("source") +assert destination +assert source print(f"{destination} is dependent of {source}: {is_dependent(destination, source, contract)}") assert not is_dependent(destination, source, contract) diff --git a/examples/scripts/variable_in_condition.py b/examples/scripts/variable_in_condition.py index 43dcf41e7..bde41424d 100644 --- a/examples/scripts/variable_in_condition.py +++ b/examples/scripts/variable_in_condition.py @@ -14,6 +14,7 @@ assert len(contracts) == 1 contract = contracts[0] # Get the variable var_a = contract.get_state_variable_from_name("a") +assert var_a # Get the functions reading the variable functions_reading_a = contract.get_functions_reading_from_variable(var_a) diff --git a/slither/__main__.py b/slither/__main__.py index a5d51dcd6..d6c3ea717 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -615,7 +615,9 @@ def parse_args( class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods - def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs + def __call__( + self, parser: Any, *args: Any, **kwargs: Any + ) -> None: # pylint: disable=signature-differs detectors, _ = get_detectors_and_printers() output_detectors(detectors) parser.exit() diff --git a/slither/analyses/data_dependency/data_dependency.py b/slither/analyses/data_dependency/data_dependency.py index 2b66f2bb3..d133cd2dc 100644 --- a/slither/analyses/data_dependency/data_dependency.py +++ b/slither/analyses/data_dependency/data_dependency.py @@ -2,7 +2,7 @@ Compute the data depenency between all the SSA variables """ from collections import defaultdict -from typing import Union, Set, Dict, TYPE_CHECKING +from typing import Union, Set, Dict, TYPE_CHECKING, List from slither.core.cfg.node import Node from slither.core.declarations import ( @@ -20,6 +20,7 @@ from slither.core.solidity_types.type import Type from slither.core.variables.top_level_variable import TopLevelVariable from slither.core.variables.variable import Variable from slither.slithir.operations import Index, OperationWithLValue, InternalCall, Operation +from slither.slithir.utils.utils import LVALUE from slither.slithir.variables import ( Constant, LocalIRVariable, @@ -29,6 +30,7 @@ from slither.slithir.variables import ( TemporaryVariableSSA, TupleVariableSSA, ) +from slither.slithir.variables.variable import SlithIRVariable if TYPE_CHECKING: from slither.core.compilation_unit import SlitherCompilationUnit @@ -393,13 +395,9 @@ def transitive_close_dependencies( while changed: changed = False to_add = defaultdict(set) - [ # pylint: disable=expression-not-assigned - [ + for key, items in context.context[context_key].items(): + for item in items & keys: to_add[key].update(context.context[context_key][item] - {key} - items) - for item in items & keys - ] - for key, items in context.context[context_key].items() - ] for k, v in to_add.items(): # Because we dont have any check on the update operation # We might update an empty set with an empty set @@ -418,20 +416,20 @@ def add_dependency(lvalue: Variable, function: Function, ir: Operation, is_prote function.context[KEY_SSA][lvalue] = set() if not is_protected: function.context[KEY_SSA_UNPROTECTED][lvalue] = set() + read: Union[List[Union[LVALUE, SolidityVariableComposed]], List[SlithIRVariable]] if isinstance(ir, Index): read = [ir.variable_left] - elif isinstance(ir, InternalCall): + elif isinstance(ir, InternalCall) and ir.function: read = ir.function.return_values_ssa else: read = ir.read - # pylint: disable=expression-not-assigned - [function.context[KEY_SSA][lvalue].add(v) for v in read if not isinstance(v, Constant)] + for v in read: + if not isinstance(v, Constant): + function.context[KEY_SSA][lvalue].add(v) if not is_protected: - [ - function.context[KEY_SSA_UNPROTECTED][lvalue].add(v) - for v in read - if not isinstance(v, Constant) - ] + for v in read: + if not isinstance(v, Constant): + function.context[KEY_SSA_UNPROTECTED][lvalue].add(v) def compute_dependency_function(function: Function) -> None: diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 38b4221d9..2c82f9b58 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -49,6 +49,9 @@ if TYPE_CHECKING: LOGGER = logging.getLogger("Contract") +USING_FOR_KEY = Union[str, Type] +USING_FOR_ITEM = List[Union[Type, Function]] + class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ @@ -80,8 +83,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods self._custom_errors: Dict[str, "CustomErrorContract"] = {} # The only str is "*" - self._using_for: Dict[Union[str, Type], List[Type]] = {} - self._using_for_complete: Optional[Dict[Union[str, Type], List[Type]]] = None + self._using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = {} + self._using_for_complete: Optional[Dict[USING_FOR_KEY, USING_FOR_ITEM]] = None self._kind: Optional[str] = None self._is_interface: bool = False self._is_library: bool = False @@ -123,7 +126,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._name @name.setter - def name(self, name: str): + def name(self, name: str) -> None: self._name = name @property @@ -133,7 +136,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._id @id.setter - def id(self, new_id): + def id(self, new_id: int) -> None: """Unique id.""" self._id = new_id @@ -146,7 +149,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._kind @contract_kind.setter - def contract_kind(self, kind): + def contract_kind(self, kind: str) -> None: self._kind = kind @property @@ -154,7 +157,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._is_interface @is_interface.setter - def is_interface(self, is_interface: bool): + def is_interface(self, is_interface: bool) -> None: self._is_interface = is_interface @property @@ -162,7 +165,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._is_library @is_library.setter - def is_library(self, is_library: bool): + def is_library(self, is_library: bool) -> None: self._is_library = is_library # endregion @@ -266,16 +269,18 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### @property - def using_for(self) -> Dict[Union[str, Type], List[Type]]: + def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: return self._using_for @property - def using_for_complete(self) -> Dict[Union[str, Type], List[Type]]: + def using_for_complete(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: """ Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive """ - def _merge_using_for(uf1: Dict, uf2: Dict) -> Dict: + def _merge_using_for( + uf1: Dict[USING_FOR_KEY, USING_FOR_ITEM], uf2: Dict[USING_FOR_KEY, USING_FOR_ITEM] + ) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: result = {**uf1, **uf2} for key, value in result.items(): if key in uf1 and key in uf2: @@ -1452,7 +1457,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods result = func.get_last_ssa_state_variables_instances() for variable_name, instances in result.items(): # TODO: investigate the next operation - last_state_variables_instances[variable_name] += instances + last_state_variables_instances[variable_name] += list(instances) for func in self.functions + list(self.modifiers): func.fix_phi(last_state_variables_instances, initial_state_variables_instances) diff --git a/slither/core/declarations/custom_error.py b/slither/core/declarations/custom_error.py index c566fccec..7e78748c6 100644 --- a/slither/core/declarations/custom_error.py +++ b/slither/core/declarations/custom_error.py @@ -1,4 +1,4 @@ -from typing import List, TYPE_CHECKING, Optional, Type, Union +from typing import List, TYPE_CHECKING, Optional, Type from slither.core.solidity_types import UserDefinedType from slither.core.source_mapping.source_mapping import SourceMapping @@ -42,7 +42,7 @@ class CustomError(SourceMapping): ################################################################################### @staticmethod - def _convert_type_for_solidity_signature(t: Optional[Union[Type, List[Type]]]) -> str: + def _convert_type_for_solidity_signature(t: Optional[Type]) -> str: # pylint: disable=import-outside-toplevel from slither.core.declarations import Contract @@ -72,7 +72,7 @@ class CustomError(SourceMapping): Returns: """ - parameters = [x.type for x in self.parameters] + parameters = [x.type for x in self.parameters if x.type] self._full_name = self.name + "(" + ",".join(map(str, parameters)) + ")" solidity_parameters = map(self._convert_type_for_solidity_signature, parameters) self._solidity_signature = self.name + "(" + ",".join(solidity_parameters) + ")" diff --git a/slither/core/declarations/using_for_top_level.py b/slither/core/declarations/using_for_top_level.py index 27d1f90e4..edf846a5b 100644 --- a/slither/core/declarations/using_for_top_level.py +++ b/slither/core/declarations/using_for_top_level.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, List, Dict, Union +from slither.core.declarations.contract import USING_FOR_KEY, USING_FOR_ITEM from slither.core.solidity_types.type import Type from slither.core.declarations.top_level import TopLevel @@ -14,5 +15,5 @@ class UsingForTopLevel(TopLevel): self.file_scope: "FileScope" = scope @property - def using_for(self) -> Dict[Union[str, Type], List[Type]]: + def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: return self._using_for diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index 0d610c928..2b777e672 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -160,8 +160,8 @@ class Variable(SourceMapping): return ( self.name, - [str(x) for x in export_nested_types_from_variable(self)], - [str(x) for x in export_return_type_from_variable(self)], + [str(x) for x in export_nested_types_from_variable(self)], # type: ignore + [str(x) for x in export_return_type_from_variable(self)], # type: ignore ) @property @@ -179,4 +179,5 @@ class Variable(SourceMapping): return f'{name}({",".join(parameters)})' def __str__(self) -> str: + assert self._name return self._name diff --git a/slither/detectors/statements/costly_operations_in_loop.py b/slither/detectors/statements/costly_operations_in_loop.py index 6af04329c..53fa12647 100644 --- a/slither/detectors/statements/costly_operations_in_loop.py +++ b/slither/detectors/statements/costly_operations_in_loop.py @@ -43,7 +43,7 @@ def costly_operations_in_loop( if isinstance(ir, OperationWithLValue) and isinstance(ir.lvalue, StateVariable): ret.append(ir.node) break - if isinstance(ir, (InternalCall)): + if isinstance(ir, (InternalCall)) and ir.function: costly_operations_in_loop(ir.function.entry_point, in_loop_counter, visited, ret) for son in node.sons: diff --git a/slither/detectors/statements/write_after_write.py b/slither/detectors/statements/write_after_write.py index 40a82d3ff..1f11921cb 100644 --- a/slither/detectors/statements/write_after_write.py +++ b/slither/detectors/statements/write_after_write.py @@ -37,6 +37,8 @@ def _handle_ir( _remove_states(written) if isinstance(ir, InternalCall): + if not ir.function: + return if ir.function.all_high_level_calls() or ir.function.all_library_calls(): _remove_states(written) diff --git a/slither/printers/call/call_graph.py b/slither/printers/call/call_graph.py index 0a4df0c65..38225e6d7 100644 --- a/slither/printers/call/call_graph.py +++ b/slither/printers/call/call_graph.py @@ -13,6 +13,7 @@ from slither.core.declarations.function import Function from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.variables.variable import Variable from slither.printers.abstract_printer import AbstractPrinter +from slither.utils.output import Output def _contract_subgraph(contract: Contract) -> str: @@ -222,7 +223,7 @@ class PrinterCallGraph(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#call-graph" - def output(self, filename): + def output(self, filename: str) -> Output: """ Output the graph in filename Args: diff --git a/slither/printers/summary/constructor_calls.py b/slither/printers/summary/constructor_calls.py index 665c76546..789811c36 100644 --- a/slither/printers/summary/constructor_calls.py +++ b/slither/printers/summary/constructor_calls.py @@ -5,6 +5,7 @@ from slither.core.declarations import Function from slither.core.source_mapping.source_mapping import Source from slither.printers.abstract_printer import AbstractPrinter from slither.utils import output +from slither.utils.output import Output def _get_source_code(cst: Function) -> str: @@ -17,7 +18,7 @@ class ConstructorPrinter(AbstractPrinter): ARGUMENT = "constructor-calls" HELP = "Print the constructors executed" - def output(self, _filename): + def output(self, _filename: str) -> Output: info = "" for contract in self.slither.contracts_derived: stack_name = [] diff --git a/slither/printers/summary/contract.py b/slither/printers/summary/contract.py index 5af953e20..5fee94416 100644 --- a/slither/printers/summary/contract.py +++ b/slither/printers/summary/contract.py @@ -2,9 +2,13 @@ Module printing summary of the contract """ import collections +from typing import Dict, List + +from slither.core.declarations import FunctionContract from slither.printers.abstract_printer import AbstractPrinter from slither.utils import output from slither.utils.colors import blue, green, magenta +from slither.utils.output import Output class ContractSummary(AbstractPrinter): @@ -13,7 +17,7 @@ class ContractSummary(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#contract-summary" - def output(self, _filename): # pylint: disable=too-many-locals + def output(self, _filename: str) -> Output: # pylint: disable=too-many-locals """ _filename is not used Args: @@ -53,17 +57,16 @@ class ContractSummary(AbstractPrinter): # Order the function with # contract_declarer -> list_functions - public = [ + public_function = [ (f.contract_declarer.name, f) for f in c.functions if (not f.is_shadowed and not f.is_constructor_variables) ] - collect = collections.defaultdict(list) - for a, b in public: + collect: Dict[str, List[FunctionContract]] = collections.defaultdict(list) + for a, b in public_function: collect[a].append(b) - public = list(collect.items()) - for contract, functions in public: + for contract, functions in collect.items(): txt += blue(f" - From {contract}\n") functions = sorted(functions, key=lambda f: f.full_name) @@ -90,7 +93,7 @@ class ContractSummary(AbstractPrinter): self.info(txt) res = self.generate_output(txt) - for contract, additional_fields in all_contracts: - res.add(contract, additional_fields=additional_fields) + for current_contract, current_additional_fields in all_contracts: + res.add(current_contract, additional_fields=current_additional_fields) return res diff --git a/slither/printers/summary/variable_order.py b/slither/printers/summary/variable_order.py index 9dc9e77c2..3325b7a01 100644 --- a/slither/printers/summary/variable_order.py +++ b/slither/printers/summary/variable_order.py @@ -4,6 +4,7 @@ from slither.printers.abstract_printer import AbstractPrinter from slither.utils.myprettytable import MyPrettyTable +from slither.utils.output import Output class VariableOrder(AbstractPrinter): @@ -13,7 +14,7 @@ class VariableOrder(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#variable-order" - def output(self, _filename): + def output(self, _filename: str) -> Output: """ _filename is not used Args: diff --git a/slither/slithir/utils/ssa.py b/slither/slithir/utils/ssa.py index 8b16cd516..9a180d14f 100644 --- a/slither/slithir/utils/ssa.py +++ b/slither/slithir/utils/ssa.py @@ -366,7 +366,7 @@ def last_name( def is_used_later( initial_node: Node, - variable: Union[StateIRVariable, LocalVariable], + variable: Union[StateIRVariable, LocalVariable, TemporaryVariableSSA], ) -> bool: # TODO: does not handle the case where its read and written in the declaration node # It can be problematic if this happens in a loop/if structure diff --git a/slither/slithir/utils/utils.py b/slither/slithir/utils/utils.py index a0ca0bd6f..4619c08bc 100644 --- a/slither/slithir/utils/utils.py +++ b/slither/slithir/utils/utils.py @@ -46,7 +46,7 @@ def is_valid_rvalue(v: SourceMapping) -> bool: ) -def is_valid_lvalue(v) -> bool: +def is_valid_lvalue(v: SourceMapping) -> bool: return isinstance( v, ( diff --git a/slither/slithir/variables/local_variable.py b/slither/slithir/variables/local_variable.py index eb32d4024..35b624a01 100644 --- a/slither/slithir/variables/local_variable.py +++ b/slither/slithir/variables/local_variable.py @@ -41,11 +41,11 @@ class LocalIRVariable( self._non_ssa_version = local_variable @property - def index(self): + def index(self) -> int: return self._index @index.setter - def index(self, idx): + def index(self, idx: int) -> None: self._index = idx @property diff --git a/slither/slithir/variables/state_variable.py b/slither/slithir/variables/state_variable.py index 7bb3a4077..f7fb8ab8a 100644 --- a/slither/slithir/variables/state_variable.py +++ b/slither/slithir/variables/state_variable.py @@ -30,11 +30,11 @@ class StateIRVariable( self._non_ssa_version = state_variable @property - def index(self): + def index(self) -> int: return self._index @index.setter - def index(self, idx): + def index(self, idx: int) -> None: self._index = idx @property diff --git a/slither/slithir/variables/variable.py b/slither/slithir/variables/variable.py index a1a1a6df9..20d203ea4 100644 --- a/slither/slithir/variables/variable.py +++ b/slither/slithir/variables/variable.py @@ -7,8 +7,9 @@ class SlithIRVariable(Variable): self._index = 0 @property - def ssa_name(self): + def ssa_name(self) -> str: + assert self.name return self.name - def __str__(self): + def __str__(self) -> str: return self.ssa_name diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 47ee7ec10..b9dbe9a9f 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -1,6 +1,6 @@ import logging import re -from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set +from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set, Sequence from slither.core.declarations import ( Modifier, @@ -9,10 +9,10 @@ from slither.core.declarations import ( StructureContract, Function, ) -from slither.core.declarations.contract import Contract +from slither.core.declarations.contract import Contract, USING_FOR_KEY from slither.core.declarations.custom_error_contract import CustomErrorContract from slither.core.declarations.function_contract import FunctionContract -from slither.core.solidity_types import ElementaryType, TypeAliasContract, Type +from slither.core.solidity_types import ElementaryType, TypeAliasContract from slither.core.variables.state_variable import StateVariable from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.declarations.custom_error import CustomErrorSolc @@ -302,7 +302,7 @@ class ContractSolc(CallerContextExpression): st.set_contract(self._contract) st.set_offset(struct["src"], self._contract.compilation_unit) - st_parser = StructureContractSolc(st, struct, self) + st_parser = StructureContractSolc(st, struct, self) # type: ignore self._contract.structures_as_dict[st.name] = st self._structures_parser.append(st_parser) @@ -312,7 +312,7 @@ class ContractSolc(CallerContextExpression): for struct in self._structuresNotParsed: self._parse_struct(struct) - self._structuresNotParsed = None + self._structuresNotParsed = [] def _parse_custom_error(self, custom_error: Dict) -> None: ce = CustomErrorContract(self.compilation_unit) @@ -329,7 +329,7 @@ class ContractSolc(CallerContextExpression): for custom_error in self._customErrorParsed: self._parse_custom_error(custom_error) - self._customErrorParsed = None + self._customErrorParsed = [] def parse_state_variables(self) -> None: for father in self._contract.inheritance_reverse: @@ -356,6 +356,7 @@ class ContractSolc(CallerContextExpression): var_parser = StateVariableSolc(var, varNotParsed) self._variables_parser.append(var_parser) + assert var.name self._contract.variables_as_dict[var.name] = var self._contract.add_variables_ordered([var]) @@ -365,7 +366,7 @@ class ContractSolc(CallerContextExpression): modif.set_contract(self._contract) modif.set_contract_declarer(self._contract) - modif_parser = ModifierSolc(modif, modifier_data, self, self.slither_parser) + modif_parser = ModifierSolc(modif, modifier_data, self, self.slither_parser) # type: ignore self._contract.compilation_unit.add_modifier(modif) self._modifiers_no_params.append(modif_parser) self._modifiers_parser.append(modif_parser) @@ -375,7 +376,7 @@ class ContractSolc(CallerContextExpression): def parse_modifiers(self) -> None: for modifier in self._modifiersNotParsed: self._parse_modifier(modifier) - self._modifiersNotParsed = None + self._modifiersNotParsed = [] def _parse_function(self, function_data: Dict) -> None: func = FunctionContract(self._contract.compilation_unit) @@ -383,7 +384,7 @@ class ContractSolc(CallerContextExpression): func.set_contract(self._contract) func.set_contract_declarer(self._contract) - func_parser = FunctionSolc(func, function_data, self, self._slither_parser) + func_parser = FunctionSolc(func, function_data, self, self._slither_parser) # type: ignore self._contract.compilation_unit.add_function(func) self._functions_no_params.append(func_parser) self._functions_parser.append(func_parser) @@ -395,7 +396,7 @@ class ContractSolc(CallerContextExpression): for function in self._functionsNotParsed: self._parse_function(function) - self._functionsNotParsed = None + self._functionsNotParsed = [] # endregion ################################################################################### @@ -439,7 +440,8 @@ class ContractSolc(CallerContextExpression): Cls_parser, self._modifiers_parser, ) - self._contract.set_modifiers(modifiers) + # modifiers will be using Modifier so we can ignore the next type check + self._contract.set_modifiers(modifiers) # type: ignore except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing params {e}") self._modifiers_no_params = [] @@ -459,7 +461,8 @@ class ContractSolc(CallerContextExpression): Cls_parser, self._functions_parser, ) - self._contract.set_functions(functions) + # function will be using FunctionContract so we can ignore the next type check + self._contract.set_functions(functions) # type: ignore except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing params {e}") self._functions_no_params = [] @@ -470,7 +473,7 @@ class ContractSolc(CallerContextExpression): Cls_parser: Callable, element_parser: FunctionSolc, explored_reference_id: Set[str], - parser: List[FunctionSolc], + parser: Union[List[FunctionSolc], List[ModifierSolc]], all_elements: Dict[str, Function], ) -> None: elem = Cls(self._contract.compilation_unit) @@ -508,13 +511,13 @@ class ContractSolc(CallerContextExpression): def _analyze_params_elements( # pylint: disable=too-many-arguments,too-many-locals self, - elements_no_params: List[FunctionSolc], + elements_no_params: Sequence[FunctionSolc], getter: Callable[["ContractSolc"], List[FunctionSolc]], getter_available: Callable[[Contract], List[FunctionContract]], Cls: Callable, Cls_parser: Callable, - parser: List[FunctionSolc], - ) -> Dict[str, Union[FunctionContract, Modifier]]: + parser: Union[List[FunctionSolc], List[ModifierSolc]], + ) -> Dict[str, Function]: """ Analyze the parameters of the given elements (Function or Modifier). The function iterates over the inheritance to create an instance or inherited elements (Function or Modifier) @@ -526,13 +529,13 @@ class ContractSolc(CallerContextExpression): :param Cls: Class to create for collision :return: """ - all_elements = {} + all_elements: Dict[str, Function] = {} - explored_reference_id = set() + explored_reference_id: Set[str] = set() try: for father in self._contract.inheritance: father_parser = self._slither_parser.underlying_contract_to_parser[father] - for element_parser in getter(father_parser): + for element_parser in getter(father_parser): # type: ignore self._analyze_params_element( Cls, Cls_parser, element_parser, explored_reference_id, parser, all_elements ) @@ -597,7 +600,7 @@ class ContractSolc(CallerContextExpression): if self.is_compact_ast: for using_for in self._usingForNotParsed: if "typeName" in using_for and using_for["typeName"]: - type_name = parse_type(using_for["typeName"], self) + type_name: USING_FOR_KEY = parse_type(using_for["typeName"], self) else: type_name = "*" if type_name not in self._contract.using_for: @@ -616,7 +619,7 @@ class ContractSolc(CallerContextExpression): assert children and len(children) <= 2 if len(children) == 2: new = parse_type(children[0], self) - old = parse_type(children[1], self) + old: USING_FOR_KEY = parse_type(children[1], self) else: new = parse_type(children[0], self) old = "*" @@ -627,7 +630,7 @@ class ContractSolc(CallerContextExpression): except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing using for {e}") - def _analyze_function_list(self, function_list: List, type_name: Type) -> None: + def _analyze_function_list(self, function_list: List, type_name: USING_FOR_KEY) -> None: for f in function_list: full_name_split = f["function"]["name"].split(".") if len(full_name_split) == 1: @@ -646,7 +649,9 @@ class ContractSolc(CallerContextExpression): function_name = full_name_split[2] self._analyze_library_function(library_name, function_name, type_name) - def _check_aliased_import(self, first_part: str, function_name: str, type_name: Type) -> None: + def _check_aliased_import( + self, first_part: str, function_name: str, type_name: USING_FOR_KEY + ) -> None: # We check if the first part appear as alias for an import # if it is then function_name must be a top level function # otherwise it's a library function @@ -656,13 +661,13 @@ class ContractSolc(CallerContextExpression): return self._analyze_library_function(first_part, function_name, type_name) - def _analyze_top_level_function(self, function_name: str, type_name: Type) -> None: + def _analyze_top_level_function(self, function_name: str, type_name: USING_FOR_KEY) -> None: for tl_function in self.compilation_unit.functions_top_level: if tl_function.name == function_name: self._contract.using_for[type_name].append(tl_function) def _analyze_library_function( - self, library_name: str, function_name: str, type_name: Type + self, library_name: str, function_name: str, type_name: USING_FOR_KEY ) -> None: # Get the library function found = False @@ -689,22 +694,13 @@ class ContractSolc(CallerContextExpression): # for enum, we can parse and analyze it # at the same time self._analyze_enum(enum) - self._enumsNotParsed = None + self._enumsNotParsed = [] except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing enum {e}") def _analyze_enum( self, - enum: Dict[ - str, - Union[ - str, - int, - List[Dict[str, Union[int, str]]], - Dict[str, str], - List[Dict[str, Union[Dict[str, str], int, str]]], - ], - ], + enum: Dict, ) -> None: # Enum can be parsed in one pass if self.is_compact_ast: @@ -753,13 +749,13 @@ class ContractSolc(CallerContextExpression): event.set_contract(self._contract) event.set_offset(event_to_parse["src"], self._contract.compilation_unit) - event_parser = EventSolc(event, event_to_parse, self) - event_parser.analyze(self) + event_parser = EventSolc(event, event_to_parse, self) # type: ignore + event_parser.analyze(self) # type: ignore self._contract.events_as_dict[event.full_name] = event except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing event {e}") - self._eventsNotParsed = None + self._eventsNotParsed = [] # endregion ################################################################################### @@ -768,7 +764,7 @@ class ContractSolc(CallerContextExpression): ################################################################################### ################################################################################### - def delete_content(self): + def delete_content(self) -> None: """ Remove everything not parsed from the contract This is used only if something went wrong with the inheritance parsing @@ -810,7 +806,7 @@ class ContractSolc(CallerContextExpression): ################################################################################### ################################################################################### - def __hash__(self): + def __hash__(self) -> int: return self._contract.id # endregion diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 9671d9bbe..ba2f225f0 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -242,7 +242,7 @@ class FunctionSolc(CallerContextExpression): if "payable" in attributes: self._function.payable = attributes["payable"] - def analyze_params(self): + def analyze_params(self) -> None: # Can be re-analyzed due to inheritance if self._params_was_analyzed: return @@ -272,7 +272,7 @@ class FunctionSolc(CallerContextExpression): if returns: self._parse_returns(returns) - def analyze_content(self): + def analyze_content(self) -> None: if self._content_was_analyzed: return @@ -308,8 +308,8 @@ class FunctionSolc(CallerContextExpression): for node_parser in self._node_to_nodesolc.values(): node_parser.analyze_expressions(self) - for node_parser in self._node_to_yulobject.values(): - node_parser.analyze_expressions() + for yul_parser in self._node_to_yulobject.values(): + yul_parser.analyze_expressions() self._rewrite_ternary_as_if_else() @@ -1297,7 +1297,7 @@ class FunctionSolc(CallerContextExpression): son.remove_father(node) node.set_sons(new_sons) - def _remove_alone_endif(self): + def _remove_alone_endif(self) -> None: """ Can occur on: if(..){ diff --git a/slither/solc_parsing/variables/variable_declaration.py b/slither/solc_parsing/variables/variable_declaration.py index d21d89875..69b72a521 100644 --- a/slither/solc_parsing/variables/variable_declaration.py +++ b/slither/solc_parsing/variables/variable_declaration.py @@ -1,6 +1,6 @@ import logging import re -from typing import Dict, Optional +from typing import Dict, Optional, Union from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.expressions.expression_parsing import parse_expression @@ -42,12 +42,12 @@ class VariableDeclarationSolc: self._variable = variable self._was_analyzed = False - self._elem_to_parse = None - self._initializedNotParsed = None + self._elem_to_parse: Optional[Union[Dict, UnknownType]] = None + self._initializedNotParsed: Optional[Dict] = None self._is_compact_ast = False - self._reference_id = None + self._reference_id: Optional[int] = None if "nodeType" in variable_data: self._is_compact_ast = True @@ -87,7 +87,7 @@ class VariableDeclarationSolc: declaration = variable_data["children"][0] self._init_from_declaration(declaration, init) elif nodeType == "VariableDeclaration": - self._init_from_declaration(variable_data, False) + self._init_from_declaration(variable_data, None) else: raise ParsingError(f"Incorrect variable declaration type {nodeType}") @@ -101,6 +101,7 @@ class VariableDeclarationSolc: Return the solc id. It can be compared with the referencedDeclaration attr Returns None if it was not parsed (legacy AST) """ + assert self._reference_id return self._reference_id def _handle_comment(self, attributes: Dict) -> None: @@ -127,7 +128,7 @@ class VariableDeclarationSolc: self._variable.visibility = "internal" def _init_from_declaration( - self, var: Dict, init: Optional[bool] + self, var: Dict, init: Optional[Dict] ) -> None: # pylint: disable=too-many-branches if self._is_compact_ast: attributes = var @@ -195,7 +196,7 @@ class VariableDeclarationSolc: self._initializedNotParsed = init elif len(var["children"]) in [0, 1]: self._variable.initialized = False - self._initializedNotParsed = [] + self._initializedNotParsed = None else: assert len(var["children"]) == 2 self._variable.initialized = True @@ -212,5 +213,6 @@ class VariableDeclarationSolc: self._elem_to_parse = None if self._variable.initialized: + assert self._initializedNotParsed self._variable.expression = parse_expression(self._initializedNotParsed, caller_context) self._initializedNotParsed = None diff --git a/slither/tools/doctor/checks/versions.py b/slither/tools/doctor/checks/versions.py index ec7ef1d1f..00662b3e9 100644 --- a/slither/tools/doctor/checks/versions.py +++ b/slither/tools/doctor/checks/versions.py @@ -1,6 +1,6 @@ from importlib import metadata import json -from typing import Optional +from typing import Optional, Any import urllib from packaging.version import parse, Version @@ -17,6 +17,7 @@ def get_installed_version(name: str) -> Optional[Version]: def get_github_version(name: str) -> Optional[Version]: try: + # type: ignore with urllib.request.urlopen( f"https://api.github.com/repos/crytic/{name}/releases/latest" ) as response: @@ -27,7 +28,7 @@ def get_github_version(name: str) -> Optional[Version]: return None -def show_versions(**_kwargs) -> None: +def show_versions(**_kwargs: Any) -> None: versions = { "Slither": (get_installed_version("slither-analyzer"), get_github_version("slither")), "crytic-compile": ( diff --git a/slither/tools/read_storage/utils/utils.py b/slither/tools/read_storage/utils/utils.py index 3e51e2181..4a04a5b6d 100644 --- a/slither/tools/read_storage/utils/utils.py +++ b/slither/tools/read_storage/utils/utils.py @@ -2,6 +2,7 @@ from typing import Union from eth_typing.evm import ChecksumAddress from eth_utils import to_int, to_text, to_checksum_address +from web3 import Web3 def get_offset_value(hex_bytes: bytes, offset: int, size: int) -> bytes: @@ -48,7 +49,7 @@ def coerce_type( if "address" in solidity_type: if not isinstance(value, (str, bytes)): raise TypeError - return to_checksum_address(value) + return to_checksum_address(value) # type: ignore if not isinstance(value, bytes): raise TypeError @@ -56,7 +57,7 @@ def coerce_type( def get_storage_data( - web3, checksum_address: ChecksumAddress, slot: bytes, block: Union[int, str] + web3: Web3, checksum_address: ChecksumAddress, slot: bytes, block: Union[int, str] ) -> bytes: """ Retrieves the storage data from the blockchain at target address and slot. diff --git a/slither/tools/upgradeability/checks/variable_initialization.py b/slither/tools/upgradeability/checks/variable_initialization.py index e8ae9b26c..b4535ddfe 100644 --- a/slither/tools/upgradeability/checks/variable_initialization.py +++ b/slither/tools/upgradeability/checks/variable_initialization.py @@ -1,7 +1,11 @@ +from typing import List + from slither.tools.upgradeability.checks.abstract_checks import ( CheckClassification, AbstractCheck, + CHECK_INFO, ) +from slither.utils.output import Output class VariableWithInit(AbstractCheck): @@ -37,11 +41,11 @@ Using initialize functions to write initial values in state variables. REQUIRE_CONTRACT = True - def _check(self): + def _check(self) -> List[Output]: results = [] for s in self.contract.state_variables_ordered: if s.initialized and not (s.is_constant or s.is_immutable): - info = [s, " is a state variable with an initial value.\n"] + info: CHECK_INFO = [s, " is a state variable with an initial value.\n"] json = self.generate_result(info) results.append(json) return results diff --git a/slither/tools/upgradeability/checks/variables_order.py b/slither/tools/upgradeability/checks/variables_order.py index 030fb0f65..fc83c44c6 100644 --- a/slither/tools/upgradeability/checks/variables_order.py +++ b/slither/tools/upgradeability/checks/variables_order.py @@ -1,7 +1,12 @@ +from typing import List + +from slither.core.declarations import Contract from slither.tools.upgradeability.checks.abstract_checks import ( CheckClassification, AbstractCheck, + CHECK_INFO, ) +from slither.utils.output import Output class MissingVariable(AbstractCheck): @@ -45,9 +50,12 @@ Do not change the order of the state variables in the updated contract. REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True - def _check(self): + def _check(self) -> List[Output]: contract1 = self.contract contract2 = self.contract_v2 + + assert contract2 + order1 = [ variable for variable in contract1.state_variables_ordered @@ -63,7 +71,7 @@ Do not change the order of the state variables in the updated contract. for idx, _ in enumerate(order1): variable1 = order1[idx] if len(order2) <= idx: - info = ["Variable missing in ", contract2, ": ", variable1, "\n"] + info: CHECK_INFO = ["Variable missing in ", contract2, ": ", variable1, "\n"] json = self.generate_result(info) results.append(json) @@ -108,13 +116,14 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s REQUIRE_CONTRACT = True REQUIRE_PROXY = True - def _contract1(self): + def _contract1(self) -> Contract: return self.contract - def _contract2(self): + def _contract2(self) -> Contract: + assert self.proxy return self.proxy - def _check(self): + def _check(self) -> List[Output]: contract1 = self._contract1() contract2 = self._contract2() order1 = [ @@ -128,7 +137,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s if not (variable.is_constant or variable.is_immutable) ] - results = [] + results: List[Output] = [] for idx, _ in enumerate(order1): if len(order2) <= idx: # Handle by MissingVariable @@ -137,7 +146,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s variable1 = order1[idx] variable2 = order2[idx] if (variable1.name != variable2.name) or (variable1.type != variable2.type): - info = [ + info: CHECK_INFO = [ "Different variables between ", contract1, " and ", @@ -190,7 +199,8 @@ Respect the variable order of the original contract in the updated contract. REQUIRE_PROXY = False REQUIRE_CONTRACT_V2 = True - def _contract2(self): + def _contract2(self) -> Contract: + assert self.contract_v2 return self.contract_v2 @@ -235,13 +245,14 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s REQUIRE_CONTRACT = True REQUIRE_PROXY = True - def _contract1(self): + def _contract1(self) -> Contract: return self.contract - def _contract2(self): + def _contract2(self) -> Contract: + assert self.proxy return self.proxy - def _check(self): + def _check(self) -> List[Output]: contract1 = self._contract1() contract2 = self._contract2() order1 = [ @@ -264,7 +275,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s while idx < len(order2): variable2 = order2[idx] - info = ["Extra variables in ", contract2, ": ", variable2, "\n"] + info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"] json = self.generate_result(info) results.append(json) idx = idx + 1 @@ -299,5 +310,6 @@ Ensure that all the new variables are expected. REQUIRE_PROXY = False REQUIRE_CONTRACT_V2 = True - def _contract2(self): + def _contract2(self) -> Contract: + assert self.contract_v2 return self.contract_v2 diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index 7b1a8f8ee..12eb6be9d 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -104,7 +104,7 @@ class ConstantFolding(ExpressionVisitor): and isinstance(left, (int, Fraction)) and isinstance(right, (int, Fraction)) ): - set_val(expression, left**right) #type: ignore + set_val(expression, left**right) # type: ignore elif ( expression.type == BinaryOperationType.MULTIPLICATION and isinstance(left, (int, Fraction)) diff --git a/tests/test_ssa_generation.py b/tests/test_ssa_generation.py index 9bb008fdf..c7bc8d5cc 100644 --- a/tests/test_ssa_generation.py +++ b/tests/test_ssa_generation.py @@ -6,7 +6,7 @@ from collections import defaultdict from contextlib import contextmanager from inspect import getsourcefile from tempfile import NamedTemporaryFile -from typing import Union, List, Optional +from typing import Union, List, Optional, Dict, Callable import pytest from solc_select import solc_select @@ -15,6 +15,7 @@ from solc_select.solc_select import valid_version as solc_valid_version from slither import Slither from slither.core.cfg.node import Node, NodeType from slither.core.declarations import Function, Contract +from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable from slither.slithir.operations import ( OperationWithLValue, @@ -34,10 +35,11 @@ from slither.slithir.variables import ( ReferenceVariable, LocalIRVariable, StateIRVariable, + TemporaryVariableSSA, ) # Directory of currently executing script. Will be used as basis for temporary file names. -SCRIPT_DIR = pathlib.Path(getsourcefile(lambda: 0)).parent +SCRIPT_DIR = pathlib.Path(getsourcefile(lambda: 0)).parent # type:ignore def valid_version(ver: str) -> bool: @@ -53,15 +55,15 @@ def valid_version(ver: str) -> bool: return False -def have_ssa_if_ir(function: Function): +def have_ssa_if_ir(function: Function) -> None: """Verifies that all nodes in a function that have IR also have SSA IR""" for n in function.nodes: if n.irs: assert n.irs_ssa -# pylint: disable=too-many-branches -def ssa_basic_properties(function: Function): +# pylint: disable=too-many-branches, too-many-locals +def ssa_basic_properties(function: Function) -> None: """Verifies that basic properties of ssa holds 1. Every name is defined only once @@ -75,12 +77,14 @@ def ssa_basic_properties(function: Function): """ ssa_lvalues = set() ssa_rvalues = set() - lvalue_assignments = {} + lvalue_assignments: Dict[str, int] = {} for n in function.nodes: for ir in n.irs: - if isinstance(ir, OperationWithLValue): + if isinstance(ir, OperationWithLValue) and ir.lvalue: name = ir.lvalue.name + if name is None: + continue if name in lvalue_assignments: lvalue_assignments[name] += 1 else: @@ -93,8 +97,9 @@ def ssa_basic_properties(function: Function): ssa_lvalues.add(ssa.lvalue) # 2 (if Local/State Var) - if isinstance(ssa.lvalue, (StateIRVariable, LocalIRVariable)): - assert ssa.lvalue.index > 0 + ssa_lvalue = ssa.lvalue + if isinstance(ssa_lvalue, (StateIRVariable, LocalIRVariable)): + assert ssa_lvalue.index > 0 for rvalue in filter( lambda x: not isinstance(x, (StateIRVariable, Constant)), ssa.read @@ -111,15 +116,18 @@ def ssa_basic_properties(function: Function): undef_vars.add(rvalue.non_ssa_version) # 4 - ssa_defs = defaultdict(int) + ssa_defs: Dict[str, int] = defaultdict(int) for v in ssa_lvalues: - ssa_defs[v.name] += 1 + if v and v.name: + ssa_defs[v.name] += 1 - for (k, n) in lvalue_assignments.items(): - assert ssa_defs[k] >= n + for (k, count) in lvalue_assignments.items(): + assert ssa_defs[k] >= count # Helper 5/6 - def check_property_5_and_6(variables, ssavars): + def check_property_5_and_6( + variables: List[LocalVariable], ssavars: List[LocalIRVariable] + ) -> None: for var in filter(lambda x: x.name, variables): ssa_vars = [x for x in ssavars if x.non_ssa_version == var] assert len(ssa_vars) == 1 @@ -136,7 +144,7 @@ def ssa_basic_properties(function: Function): check_property_5_and_6(function.returns, function.returns_ssa) -def ssa_phi_node_properties(f: Function): +def ssa_phi_node_properties(f: Function) -> None: """Every phi-function should have as many args as predecessors This does not apply if the phi-node refers to state variables, @@ -152,7 +160,7 @@ def ssa_phi_node_properties(f: Function): # TODO (hbrodin): This should probably go into another file, not specific to SSA -def dominance_properties(f: Function): +def dominance_properties(f: Function) -> None: """Verifies properties related to dominators holds 1. Every node have an immediate dominator except entry_node which have none @@ -180,14 +188,16 @@ def dominance_properties(f: Function): assert find_path(node.immediate_dominator, node) -def phi_values_inserted(f: Function): +def phi_values_inserted(f: Function) -> None: """Verifies that phi-values are inserted at the right places For every node that has a dominance frontier, any def (including phi) should be a phi function in its dominance frontier """ - def have_phi_for_var(node: Node, var): + def have_phi_for_var( + node: Node, var: Union[StateIRVariable, LocalIRVariable, TemporaryVariableSSA] + ) -> bool: """Checks if a node has a phi-instruction for var The ssa version would ideally be checked, but then @@ -198,7 +208,14 @@ def phi_values_inserted(f: Function): non_ssa = var.non_ssa_version for ssa in node.irs_ssa: if isinstance(ssa, Phi): - if non_ssa in map(lambda ssa_var: ssa_var.non_ssa_version, ssa.read): + if non_ssa in map( + lambda ssa_var: ssa_var.non_ssa_version, + [ + r + for r in ssa.read + if isinstance(r, (StateIRVariable, LocalIRVariable, TemporaryVariableSSA)) + ], + ): return True return False @@ -206,12 +223,15 @@ def phi_values_inserted(f: Function): for df in node.dominance_frontier: for ssa in node.irs_ssa: if isinstance(ssa, OperationWithLValue): - if is_used_later(node, ssa.lvalue): - assert have_phi_for_var(df, ssa.lvalue) + ssa_lvalue = ssa.lvalue + if isinstance( + ssa_lvalue, (StateIRVariable, LocalIRVariable, TemporaryVariableSSA) + ) and is_used_later(node, ssa_lvalue): + assert have_phi_for_var(df, ssa_lvalue) @contextmanager -def select_solc_version(version: Optional[str]): +def select_solc_version(version: Optional[str]) -> None: """Selects solc version to use for running tests. If no version is provided, latest is used.""" @@ -256,17 +276,17 @@ def slither_from_source(source_code: str, solc_version: Optional[str] = None): pathlib.Path(fname).unlink() -def verify_properties_hold(source_code_or_slither: Union[str, Slither]): +def verify_properties_hold(source_code_or_slither: Union[str, Slither]) -> None: """Ensures that basic properties of SSA hold true""" - def verify_func(func: Function): + def verify_func(func: Function) -> None: have_ssa_if_ir(func) phi_values_inserted(func) ssa_basic_properties(func) ssa_phi_node_properties(func) dominance_properties(func) - def verify(slither): + def verify(slither: Slither) -> None: for cu in slither.compilation_units: for func in cu.functions_and_modifiers: _dump_function(func) @@ -280,11 +300,12 @@ def verify_properties_hold(source_code_or_slither: Union[str, Slither]): if isinstance(source_code_or_slither, Slither): verify(source_code_or_slither) else: + slither: Slither with slither_from_source(source_code_or_slither) as slither: verify(slither) -def _dump_function(f: Function): +def _dump_function(f: Function) -> None: """Helper function to print nodes/ssa ir for a function or modifier""" print(f"---- {f.name} ----") for n in f.nodes: @@ -294,13 +315,13 @@ def _dump_function(f: Function): print("") -def _dump_functions(c: Contract): +def _dump_functions(c: Contract) -> None: """Helper function to print functions and modifiers of a contract""" for f in c.functions_and_modifiers: _dump_function(f) -def get_filtered_ssa(f: Union[Function, Node], flt) -> List[Operation]: +def get_filtered_ssa(f: Union[Function, Node], flt: Callable) -> List[Operation]: """Returns a list of all ssanodes filtered by filter for all nodes in function f""" if isinstance(f, Function): return [ssanode for node in f.nodes for ssanode in node.irs_ssa if flt(ssanode)] @@ -314,7 +335,7 @@ def get_ssa_of_type(f: Union[Function, Node], ssatype) -> List[Operation]: return get_filtered_ssa(f, lambda ssanode: isinstance(ssanode, ssatype)) -def test_multi_write(): +def test_multi_write() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -327,7 +348,7 @@ def test_multi_write(): verify_properties_hold(contract) -def test_single_branch_phi(): +def test_single_branch_phi() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -342,7 +363,7 @@ def test_single_branch_phi(): verify_properties_hold(contract) -def test_basic_phi(): +def test_basic_phi() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -359,7 +380,7 @@ def test_basic_phi(): verify_properties_hold(contract) -def test_basic_loop_phi(): +def test_basic_loop_phi() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -375,7 +396,7 @@ def test_basic_loop_phi(): @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") -def test_phi_propagation_loop(): +def test_phi_propagation_loop() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -396,7 +417,7 @@ def test_phi_propagation_loop(): @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") -def test_free_function_properties(): +def test_free_function_properties() -> None: contract = """ pragma solidity ^0.8.11; @@ -417,7 +438,7 @@ def test_free_function_properties(): verify_properties_hold(contract) -def test_ssa_inter_transactional(): +def test_ssa_inter_transactional() -> None: source = """ pragma solidity ^0.8.11; contract A { @@ -460,7 +481,7 @@ def test_ssa_inter_transactional(): @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") -def test_ssa_phi_callbacks(): +def test_ssa_phi_callbacks() -> None: source = """ pragma solidity ^0.8.11; contract A { @@ -519,7 +540,7 @@ def test_ssa_phi_callbacks(): @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") -def test_storage_refers_to(): +def test_storage_refers_to() -> None: """Test the storage aspects of the SSA IR When declaring a var as being storage, start tracking what storage it refers_to.