From 371e3cbe475334e6cc9f5cba99cb0172bb3cc92c Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Thu, 16 Feb 2023 20:19:14 +0100 Subject: [PATCH 01/34] Add more types hints --- slither/__main__.py | 63 ++++--------- .../data_dependency/data_dependency.py | 88 ++++++++++++++---- slither/core/cfg/node.py | 89 ++++++++++++------- slither/core/children/child_contract.py | 9 +- slither/core/children/child_event.py | 11 ++- slither/core/children/child_expression.py | 11 ++- slither/core/children/child_function.py | 9 +- slither/core/children/child_inheritance.py | 9 +- slither/core/children/child_node.py | 11 ++- slither/core/children/child_structure.py | 9 +- slither/core/declarations/contract.py | 51 ++++++----- slither/core/declarations/custom_error.py | 4 +- .../declarations/custom_error_contract.py | 7 +- .../declarations/custom_error_top_level.py | 2 +- slither/core/declarations/function.py | 7 +- .../core/declarations/solidity_variables.py | 12 +-- slither/core/dominators/utils.py | 1 + .../core/expressions/assignment_operation.py | 2 +- slither/core/slither_core.py | 4 +- slither/core/solidity_types/array_type.py | 4 +- .../core/solidity_types/elementary_type.py | 6 +- slither/core/solidity_types/mapping_type.py | 4 +- slither/core/solidity_types/type_alias.py | 4 +- .../core/solidity_types/type_information.py | 6 +- slither/core/source_mapping/source_mapping.py | 6 +- slither/core/variables/event_variable.py | 2 +- slither/core/variables/variable.py | 6 +- slither/detectors/abstract_detector.py | 4 +- .../assembly/shift_parameter_mixup.py | 13 ++- .../attributes/const_functions_asm.py | 6 +- .../compiler_bugs/array_by_reference.py | 7 +- .../erc/erc20/arbitrary_send_erc20.py | 8 +- .../erc20/arbitrary_send_erc20_no_permit.py | 8 +- .../erc/erc20/arbitrary_send_erc20_permit.py | 8 +- .../detectors/functions/arbitrary_send_eth.py | 10 ++- .../statements/array_length_assignment.py | 3 +- slither/detectors/statements/assembly.py | 8 +- slither/slithir/operations/call.py | 7 +- slither/slithir/operations/high_level_call.py | 20 +++-- slither/slithir/operations/index.py | 26 +++--- slither/slithir/operations/library_call.py | 17 ++-- slither/slithir/operations/low_level_call.py | 6 +- slither/slithir/operations/lvalue.py | 18 ++-- slither/slithir/operations/member.py | 12 ++- slither/slithir/operations/new_contract.py | 11 ++- slither/slithir/operations/solidity_call.py | 11 +-- slither/slithir/utils/utils.py | 20 +++++ slither/tools/mutator/__main__.py | 7 +- .../mutator/mutators/abstract_mutator.py | 11 ++- 49 files changed, 422 insertions(+), 256 deletions(-) diff --git a/slither/__main__.py b/slither/__main__.py index 528a93e8f..a5d51dcd6 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -66,7 +66,7 @@ def process_single( args: argparse.Namespace, detector_classes: List[Type[AbstractDetector]], printer_classes: List[Type[AbstractPrinter]], -) -> Tuple[Slither, List[Dict], List[Dict], int]: +) -> Tuple[Slither, List[Dict], List[Output], int]: """ The core high-level code for running Slither static analysis. @@ -89,7 +89,7 @@ def process_all( args: argparse.Namespace, detector_classes: List[Type[AbstractDetector]], printer_classes: List[Type[AbstractPrinter]], -) -> Tuple[List[Slither], List[Dict], List[Dict], int]: +) -> Tuple[List[Slither], List[Dict], List[Output], int]: compilations = compile_all(target, **vars(args)) slither_instances = [] results_detectors = [] @@ -144,23 +144,6 @@ def _process( return slither, results_detectors, results_printers, analyzed_contracts_count -# TODO: delete me? -def process_from_asts( - filenames: List[str], - args: argparse.Namespace, - detector_classes: List[Type[AbstractDetector]], - printer_classes: List[Type[AbstractPrinter]], -) -> Tuple[Slither, List[Dict], List[Dict], int]: - all_contracts: List[str] = [] - - for filename in filenames: - with open(filename, encoding="utf8") as file_open: - contract_loaded = json.load(file_open) - all_contracts.append(contract_loaded["ast"]) - - return process_single(all_contracts, args, detector_classes, printer_classes) - - # endregion ################################################################################### ################################################################################### @@ -608,9 +591,6 @@ def parse_args( default=False, ) - # if the json is splitted in different files - parser.add_argument("--splitted", help=argparse.SUPPRESS, action="store_true", default=False) - # Disable the throw/catch on partial analyses parser.add_argument( "--disallow-partial", help=argparse.SUPPRESS, action="store_true", default=False @@ -626,7 +606,7 @@ def parse_args( args.filter_paths = parse_filter_paths(args) # Verify our json-type output is valid - args.json_types = set(args.json_types.split(",")) + args.json_types = set(args.json_types.split(",")) # type:ignore for json_type in args.json_types: if json_type not in JSON_OUTPUT_TYPES: raise Exception(f'Error: "{json_type}" is not a valid JSON result output type.') @@ -697,14 +677,14 @@ class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods class FormatterCryticCompile(logging.Formatter): - def format(self, record): + def format(self, record: logging.LogRecord) -> str: # for i, msg in enumerate(record.msg): if record.msg.startswith("Compilation warnings/errors on "): - txt = record.args[1] - txt = txt.split("\n") + txt = record.args[1] # type:ignore + txt = txt.split("\n") # type:ignore txt = [red(x) if "Error" in x else x for x in txt] txt = "\n".join(txt) - record.args = (record.args[0], txt) + record.args = (record.args[0], txt) # type:ignore return super().format(record) @@ -747,7 +727,7 @@ def main_impl( set_colorization_enabled(False if args.disable_color else sys.stdout.isatty()) # Define some variables for potential JSON output - json_results = {} + json_results: Dict[str, Any] = {} output_error = None outputting_json = args.json is not None outputting_json_stdout = args.json == "-" @@ -796,7 +776,7 @@ def main_impl( crytic_compile_error.setLevel(logging.INFO) results_detectors: List[Dict] = [] - results_printers: List[Dict] = [] + results_printers: List[Output] = [] try: filename = args.filename @@ -809,26 +789,17 @@ def main_impl( number_contracts = 0 slither_instances = [] - if args.splitted: + for filename in filenames: ( slither_instance, - results_detectors, - results_printers, - number_contracts, - ) = process_from_asts(filenames, args, detector_classes, printer_classes) + results_detectors_tmp, + results_printers_tmp, + number_contracts_tmp, + ) = process_single(filename, args, detector_classes, printer_classes) + number_contracts += number_contracts_tmp + results_detectors += results_detectors_tmp + results_printers += results_printers_tmp slither_instances.append(slither_instance) - else: - for filename in filenames: - ( - slither_instance, - results_detectors_tmp, - results_printers_tmp, - number_contracts_tmp, - ) = process_single(filename, args, detector_classes, printer_classes) - number_contracts += number_contracts_tmp - results_detectors += results_detectors_tmp - results_printers += results_printers_tmp - slither_instances.append(slither_instance) # Rely on CryticCompile to discern the underlying type of compilations. else: diff --git a/slither/analyses/data_dependency/data_dependency.py b/slither/analyses/data_dependency/data_dependency.py index b2a154672..448ee393a 100644 --- a/slither/analyses/data_dependency/data_dependency.py +++ b/slither/analyses/data_dependency/data_dependency.py @@ -4,6 +4,7 @@ from collections import defaultdict from typing import Union, Set, Dict, TYPE_CHECKING +from slither.core.cfg.node import Node from slither.core.declarations import ( Contract, Enum, @@ -12,6 +13,7 @@ from slither.core.declarations import ( SolidityVariable, SolidityVariableComposed, Structure, + FunctionContract, ) from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.variables.top_level_variable import TopLevelVariable @@ -40,25 +42,37 @@ if TYPE_CHECKING: Variable_types = Union[Variable, SolidityVariable] +# TODO refactor the data deps to be better suited for top level function object +# Right now we allow to pass a node to ease the API, but we need something +# better +# The deps propagation for top level elements is also not working as expected +Context_types_API = Union[Contract, Function, Node] Context_types = Union[Contract, Function] def is_dependent( variable: Variable_types, source: Variable_types, - context: Context_types, + context: Context_types_API, only_unprotected: bool = False, ) -> bool: """ + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not + Args: variable (Variable) source (Variable) - context (Contract|Function) + context (Contract|Function|Node). only_unprotected (bool): True only unprotected function are considered Returns: bool """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func + if isinstance(variable, Constant): return False if variable == source: @@ -76,10 +90,13 @@ def is_dependent( def is_dependent_ssa( variable: Variable_types, source: Variable_types, - context: Context_types, + context: Context_types_API, only_unprotected: bool = False, ) -> bool: """ + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not + Args: variable (Variable) taint (Variable) @@ -88,7 +105,10 @@ def is_dependent_ssa( Returns: bool """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func context_dict = context.context if isinstance(variable, Constant): return False @@ -112,11 +132,14 @@ GENERIC_TAINT = { def is_tainted( variable: Variable_types, - context: Context_types, + context: Context_types_API, only_unprotected: bool = False, ignore_generic_taint: bool = False, ) -> bool: """ + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not + Args: variable context (Contract|Function) @@ -124,7 +147,10 @@ def is_tainted( Returns: bool """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if isinstance(variable, Constant): return False @@ -139,11 +165,14 @@ def is_tainted( def is_tainted_ssa( variable: Variable_types, - context: Context_types, + context: Context_types_API, only_unprotected: bool = False, ignore_generic_taint: bool = False, -): +) -> bool: """ + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not + Args: variable context (Contract|Function) @@ -151,7 +180,10 @@ def is_tainted_ssa( Returns: bool """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if isinstance(variable, Constant): return False @@ -166,18 +198,23 @@ def is_tainted_ssa( def get_dependencies( variable: Variable_types, - context: Context_types, + context: Context_types_API, only_unprotected: bool = False, ) -> Set[Variable]: """ Return the variables for which `variable` depends on. + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not :param variable: The target :param context: Either a function (interprocedural) or a contract (inter transactional) :param only_unprotected: True if consider only protected functions :return: set(Variable) """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if only_unprotected: return context.context[KEY_NON_SSA_UNPROTECTED].get(variable, set()) @@ -185,16 +222,21 @@ def get_dependencies( def get_all_dependencies( - context: Context_types, only_unprotected: bool = False + context: Context_types_API, only_unprotected: bool = False ) -> Dict[Variable, Set[Variable]]: """ Return the dictionary of dependencies. + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not :param context: Either a function (interprocedural) or a contract (inter transactional) :param only_unprotected: True if consider only protected functions :return: Dict(Variable, set(Variable)) """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if only_unprotected: return context.context[KEY_NON_SSA_UNPROTECTED] @@ -203,18 +245,23 @@ def get_all_dependencies( def get_dependencies_ssa( variable: Variable_types, - context: Context_types, + context: Context_types_API, only_unprotected: bool = False, ) -> Set[Variable]: """ Return the variables for which `variable` depends on (SSA version). + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not :param variable: The target (must be SSA variable) :param context: Either a function (interprocedural) or a contract (inter transactional) :param only_unprotected: True if consider only protected functions :return: set(Variable) """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if only_unprotected: return context.context[KEY_SSA_UNPROTECTED].get(variable, set()) @@ -222,16 +269,21 @@ def get_dependencies_ssa( def get_all_dependencies_ssa( - context: Context_types, only_unprotected: bool = False + context: Context_types_API, only_unprotected: bool = False ) -> Dict[Variable, Set[Variable]]: """ Return the dictionary of dependencies. + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not :param context: Either a function (interprocedural) or a contract (inter transactional) :param only_unprotected: True if consider only protected functions :return: Dict(Variable, set(Variable)) """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if only_unprotected: return context.context[KEY_SSA_UNPROTECTED] diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index 7643b19b7..82502f889 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -6,7 +6,7 @@ from typing import Optional, List, Set, Dict, Tuple, Union, TYPE_CHECKING from slither.all_exceptions import SlitherException from slither.core.children.child_function import ChildFunction -from slither.core.declarations import Contract, Function +from slither.core.declarations import Contract, Function, FunctionContract from slither.core.declarations.solidity_variables import ( SolidityVariable, SolidityFunction, @@ -33,6 +33,7 @@ from slither.slithir.operations import ( Return, Operation, ) +from slither.slithir.utils.utils import RVALUE from slither.slithir.variables import ( Constant, LocalIRVariable, @@ -146,12 +147,12 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._node_id: int = node_id self._vars_written: List[Variable] = [] - self._vars_read: List[Variable] = [] + self._vars_read: List[Union[Variable, SolidityVariable]] = [] self._ssa_vars_written: List["SlithIRVariable"] = [] self._ssa_vars_read: List["SlithIRVariable"] = [] - self._internal_calls: List["Function"] = [] + self._internal_calls: List[Union["Function", "SolidityFunction"]] = [] self._solidity_calls: List[SolidityFunction] = [] self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls self._library_calls: List["LibraryCallType"] = [] @@ -172,7 +173,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._local_vars_read: List[LocalVariable] = [] self._local_vars_written: List[LocalVariable] = [] - self._slithir_vars: Set["SlithIRVariable"] = set() # non SSA + self._slithir_vars: Set[ + Union["SlithIRVariable", ReferenceVariable, TemporaryVariable, TupleVariable] + ] = set() # non SSA self._ssa_local_vars_read: List[LocalIRVariable] = [] self._ssa_local_vars_written: List[LocalIRVariable] = [] @@ -213,7 +216,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._node_type @type.setter - def type(self, new_type: NodeType): + def type(self, new_type: NodeType) -> None: self._node_type = new_type @property @@ -232,7 +235,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met ################################################################################### @property - def variables_read(self) -> List[Variable]: + def variables_read(self) -> List[Union[Variable, SolidityVariable]]: """ list(Variable): Variables read (local/state/solidity) """ @@ -285,11 +288,13 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._expression_vars_read @variables_read_as_expression.setter - def variables_read_as_expression(self, exprs: List[Expression]): + def variables_read_as_expression(self, exprs: List[Expression]) -> None: self._expression_vars_read = exprs @property - def slithir_variables(self) -> List["SlithIRVariable"]: + def slithir_variables( + self, + ) -> List[Union["SlithIRVariable", ReferenceVariable, TemporaryVariable, TupleVariable]]: return list(self._slithir_vars) @property @@ -339,7 +344,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._expression_vars_written @variables_written_as_expression.setter - def variables_written_as_expression(self, exprs: List[Expression]): + def variables_written_as_expression(self, exprs: List[Expression]) -> None: self._expression_vars_written = exprs # endregion @@ -399,7 +404,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._external_calls_as_expressions @external_calls_as_expressions.setter - def external_calls_as_expressions(self, exprs: List[Expression]): + def external_calls_as_expressions(self, exprs: List[Expression]) -> None: self._external_calls_as_expressions = exprs @property @@ -410,7 +415,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._internal_calls_as_expressions @internal_calls_as_expressions.setter - def internal_calls_as_expressions(self, exprs: List[Expression]): + def internal_calls_as_expressions(self, exprs: List[Expression]) -> None: self._internal_calls_as_expressions = exprs @property @@ -418,10 +423,10 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return list(self._expression_calls) @calls_as_expression.setter - def calls_as_expression(self, exprs: List[Expression]): + def calls_as_expression(self, exprs: List[Expression]) -> None: self._expression_calls = exprs - def can_reenter(self, callstack=None) -> bool: + def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Check if the node can re-enter Do not consider CREATE as potential re-enter, but check if the @@ -567,7 +572,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met """ self._fathers.append(father) - def set_fathers(self, fathers: List["Node"]): + def set_fathers(self, fathers: List["Node"]) -> None: """Set the father nodes Args: @@ -663,20 +668,20 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._irs_ssa @irs_ssa.setter - def irs_ssa(self, irs): + def irs_ssa(self, irs: List[Operation]) -> None: self._irs_ssa = irs def add_ssa_ir(self, ir: Operation) -> None: """ Use to place phi operation """ - ir.set_node(self) + ir.set_node(self) # type: ignore self._irs_ssa.append(ir) def slithir_generation(self) -> None: if self.expression: expression = self.expression - self._irs = convert_expression(expression, self) + self._irs = convert_expression(expression, self) # type:ignore self._find_read_write_call() @@ -713,7 +718,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._dominators @dominators.setter - def dominators(self, dom: Set["Node"]): + def dominators(self, dom: Set["Node"]) -> None: self._dominators = dom @property @@ -725,7 +730,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._immediate_dominator @immediate_dominator.setter - def immediate_dominator(self, idom: "Node"): + def immediate_dominator(self, idom: "Node") -> None: self._immediate_dominator = idom @property @@ -737,7 +742,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._dominance_frontier @dominance_frontier.setter - def dominance_frontier(self, doms: Set["Node"]): + def dominance_frontier(self, doms: Set["Node"]) -> None: """ Returns: set(Node) @@ -789,6 +794,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node") -> None: if variable.name not in self._phi_origins_local_variables: + assert variable.name self._phi_origins_local_variables[variable.name] = (variable, set()) (v, nodes) = self._phi_origins_local_variables[variable.name] assert v == variable @@ -827,7 +833,8 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met if isinstance(ir, OperationWithLValue): var = ir.lvalue if var and self._is_valid_slithir_var(var): - self._slithir_vars.add(var) + # The type is checked by is_valid_slithir_var + self._slithir_vars.add(var) # type: ignore if not isinstance(ir, (Phi, Index, Member)): self._vars_read += [v for v in ir.read if self._is_non_slithir_var(v)] @@ -835,8 +842,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met if isinstance(var, ReferenceVariable): self._vars_read.append(var.points_to_origin) elif isinstance(ir, (Member, Index)): + # TODO investigate types for member variable left var = ir.variable_left if isinstance(ir, Member) else ir.variable_right - if self._is_non_slithir_var(var): + if var and self._is_non_slithir_var(var): self._vars_read.append(var) if isinstance(var, ReferenceVariable): origin = var.points_to_origin @@ -860,14 +868,21 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._internal_calls.append(ir.function) if isinstance(ir, LowLevelCall): assert isinstance(ir.destination, (Variable, SolidityVariable)) - self._low_level_calls.append((ir.destination, ir.function_name.value)) + self._low_level_calls.append((ir.destination, str(ir.function_name.value))) elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall): + # Todo investigate this if condition + # It does seem right to compare against a contract + # This might need a refactoring if isinstance(ir.destination.type, Contract): self._high_level_calls.append((ir.destination.type, ir.function)) elif ir.destination == SolidityVariable("this"): - self._high_level_calls.append((self.function.contract, ir.function)) + func = self.function + # Can't use this in a top level function + assert isinstance(func, FunctionContract) + self._high_level_calls.append((func.contract, ir.function)) else: try: + # Todo this part needs more tests and documentation self._high_level_calls.append((ir.destination.type.type, ir.function)) except AttributeError as error: # pylint: disable=raise-missing-from @@ -883,7 +898,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._vars_read = list(set(self._vars_read)) self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] - self._solidity_vars_read = [v for v in self._vars_read if isinstance(v, SolidityVariable)] + self._solidity_vars_read = [ + v_ for v_ in self._vars_read if isinstance(v_, SolidityVariable) + ] self._vars_written = list(set(self._vars_written)) self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] @@ -895,12 +912,15 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met @staticmethod def _convert_ssa(v: Variable) -> Optional[Union[StateVariable, LocalVariable]]: + non_ssa_var: Optional[Union[StateVariable, LocalVariable]] if isinstance(v, StateIRVariable): contract = v.contract + assert v.name non_ssa_var = contract.get_state_variable_from_name(v.name) return non_ssa_var assert isinstance(v, LocalIRVariable) function = v.function + assert v.name non_ssa_var = function.get_local_variable_from_name(v.name) return non_ssa_var @@ -921,10 +941,11 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._ssa_vars_read.append(origin) elif isinstance(ir, (Member, Index)): - if isinstance(ir.variable_right, (StateIRVariable, LocalIRVariable)): - self._ssa_vars_read.append(ir.variable_right) - if isinstance(ir.variable_right, ReferenceVariable): - origin = ir.variable_right.points_to_origin + variable_right: RVALUE = ir.variable_right + if isinstance(variable_right, (StateIRVariable, LocalIRVariable)): + self._ssa_vars_read.append(variable_right) + if isinstance(variable_right, ReferenceVariable): + origin = variable_right.points_to_origin if isinstance(origin, (StateIRVariable, LocalIRVariable)): self._ssa_vars_read.append(origin) @@ -944,20 +965,20 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._ssa_local_vars_read = [v for v in self._ssa_vars_read if isinstance(v, LocalVariable)] self._ssa_vars_written = list(set(self._ssa_vars_written)) self._ssa_state_vars_written = [ - v for v in self._ssa_vars_written if isinstance(v, StateVariable) + v for v in self._ssa_vars_written if v and isinstance(v, StateIRVariable) ] self._ssa_local_vars_written = [ - v for v in self._ssa_vars_written if isinstance(v, LocalVariable) + v for v in self._ssa_vars_written if v and isinstance(v, LocalIRVariable) ] vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read] vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written] - self._vars_read += [v for v in vars_read if v not in self._vars_read] + self._vars_read += [v_ for v_ in vars_read if v_ and v_ not in self._vars_read] self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] - self._vars_written += [v for v in vars_written if v not in self._vars_written] + self._vars_written += [v_ for v_ in vars_written if v_ and v_ not in self._vars_written] self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] @@ -974,7 +995,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met additional_info += " " + str(self.expression) elif self.variable_declaration: additional_info += " " + str(self.variable_declaration) - txt = self._node_type.value + additional_info + txt = str(self._node_type.value) + additional_info return txt diff --git a/slither/core/children/child_contract.py b/slither/core/children/child_contract.py index 86f9dea53..2c93d9a51 100644 --- a/slither/core/children/child_contract.py +++ b/slither/core/children/child_contract.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from slither.core.source_mapping.source_mapping import SourceMapping @@ -9,11 +9,14 @@ if TYPE_CHECKING: class ChildContract(SourceMapping): def __init__(self) -> None: super().__init__() - self._contract = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._contract: Optional["Contract"] = None def set_contract(self, contract: "Contract") -> None: self._contract = contract @property def contract(self) -> "Contract": - return self._contract + return self._contract # type: ignore diff --git a/slither/core/children/child_event.py b/slither/core/children/child_event.py index df91596e3..e9a2177c5 100644 --- a/slither/core/children/child_event.py +++ b/slither/core/children/child_event.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from slither.core.declarations import Event @@ -7,11 +7,14 @@ if TYPE_CHECKING: class ChildEvent: def __init__(self) -> None: super().__init__() - self._event = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._event: Optional["Event"] = None - def set_event(self, event: "Event"): + def set_event(self, event: "Event") -> None: self._event = event @property def event(self) -> "Event": - return self._event + return self._event # type: ignore diff --git a/slither/core/children/child_expression.py b/slither/core/children/child_expression.py index 0064658c0..2294cf384 100644 --- a/slither/core/children/child_expression.py +++ b/slither/core/children/child_expression.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Union, Optional if TYPE_CHECKING: from slither.core.expressions.expression import Expression @@ -8,11 +8,16 @@ if TYPE_CHECKING: class ChildExpression: def __init__(self) -> None: super().__init__() - self._expression = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._expression: Optional[Union["Expression", "Operation"]] = None def set_expression(self, expression: Union["Expression", "Operation"]) -> None: + # TODO investigate when this can be an operation? + # It was auto generated during an AST or detectors tests self._expression = expression @property def expression(self) -> Union["Expression", "Operation"]: - return self._expression + return self._expression # type: ignore diff --git a/slither/core/children/child_function.py b/slither/core/children/child_function.py index 5367320ca..d79d12c10 100644 --- a/slither/core/children/child_function.py +++ b/slither/core/children/child_function.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from slither.core.declarations import Function @@ -7,11 +7,14 @@ if TYPE_CHECKING: class ChildFunction: def __init__(self) -> None: super().__init__() - self._function = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._function: Optional["Function"] = None def set_function(self, function: "Function") -> None: self._function = function @property def function(self) -> "Function": - return self._function + return self._function # type: ignore diff --git a/slither/core/children/child_inheritance.py b/slither/core/children/child_inheritance.py index 30b32f6c1..1ff1a4967 100644 --- a/slither/core/children/child_inheritance.py +++ b/slither/core/children/child_inheritance.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from slither.core.declarations import Contract @@ -7,11 +7,14 @@ if TYPE_CHECKING: class ChildInheritance: def __init__(self) -> None: super().__init__() - self._contract_declarer = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._contract_declarer: Optional["Contract"] = None def set_contract_declarer(self, contract: "Contract") -> None: self._contract_declarer = contract @property def contract_declarer(self) -> "Contract": - return self._contract_declarer + return self._contract_declarer # type: ignore diff --git a/slither/core/children/child_node.py b/slither/core/children/child_node.py index 8e6e1f0b5..998ec5ea4 100644 --- a/slither/core/children/child_node.py +++ b/slither/core/children/child_node.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from slither.core.compilation_unit import SlitherCompilationUnit @@ -9,14 +9,17 @@ if TYPE_CHECKING: class ChildNode: def __init__(self) -> None: super().__init__() - self._node = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._node: Optional["Node"] = None def set_node(self, node: "Node") -> None: self._node = node @property def node(self) -> "Node": - return self._node + return self._node # type:ignore @property def function(self) -> "Function": @@ -24,7 +27,7 @@ class ChildNode: @property def contract(self) -> "Contract": - return self.node.function.contract + return self.node.function.contract # type: ignore @property def compilation_unit(self) -> "SlitherCompilationUnit": diff --git a/slither/core/children/child_structure.py b/slither/core/children/child_structure.py index abcb041c2..413d7f4df 100644 --- a/slither/core/children/child_structure.py +++ b/slither/core/children/child_structure.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from slither.core.declarations import Structure @@ -7,11 +7,14 @@ if TYPE_CHECKING: class ChildStructure: def __init__(self) -> None: super().__init__() - self._structure = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._structure: Optional["Structure"] = None def set_structure(self, structure: "Structure") -> None: self._structure = structure @property def structure(self) -> "Structure": - return self._structure + return self._structure # type: ignore diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 2d2d10b04..d3c016b4b 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -81,7 +81,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods # The only str is "*" self._using_for: Dict[Union[str, Type], List[Type]] = {} - self._using_for_complete: Dict[Union[str, Type], List[Type]] = None + self._using_for_complete: Optional[Dict[Union[str, Type], List[Type]]] = None self._kind: Optional[str] = None self._is_interface: bool = False self._is_library: bool = False @@ -275,7 +275,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive """ - def _merge_using_for(uf1, uf2): + def _merge_using_for(uf1: Dict, uf2: Dict) -> Dict: result = {**uf1, **uf2} for key, value in result.items(): if key in uf1 and key in uf2: @@ -524,14 +524,14 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ return list(self._functions.values()) - def available_functions_as_dict(self) -> Dict[str, "FunctionContract"]: + def available_functions_as_dict(self) -> Dict[str, "Function"]: if self._available_functions_as_dict is None: self._available_functions_as_dict = { f.full_name: f for f in self._functions.values() if not f.is_shadowed } return self._available_functions_as_dict - def add_function(self, func: "FunctionContract"): + def add_function(self, func: "FunctionContract") -> None: self._functions[func.canonical_name] = func def set_functions(self, functions: Dict[str, "FunctionContract"]) -> None: @@ -699,7 +699,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods list(Contract): Return the list of contracts derived from self """ candidates = self.compilation_unit.contracts - return [c for c in candidates if self in c.inheritance] + return [c for c in candidates if self in c.inheritance] # type: ignore # endregion ################################################################################### @@ -855,7 +855,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ return next((e for e in self.enums if e.name == enum_name), None) - def get_enum_from_canonical_name(self, enum_name) -> Optional["Enum"]: + def get_enum_from_canonical_name(self, enum_name: str) -> Optional["Enum"]: """ Return an enum from a canonical name Args: @@ -956,7 +956,9 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### ################################################################################### - def get_summary(self, include_shadowed=True) -> Tuple[str, List[str], List[str], List, List]: + def get_summary( + self, include_shadowed: bool = True + ) -> Tuple[str, List[str], List[str], List, List]: """Return the function summary :param include_shadowed: boolean to indicate if shadowed functions should be included (default True) @@ -1209,7 +1211,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods @property def is_test(self) -> bool: - return is_test_contract(self) or self.is_truffle_migration + return is_test_contract(self) or self.is_truffle_migration # type: ignore # endregion ################################################################################### @@ -1219,7 +1221,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### def update_read_write_using_ssa(self) -> None: - for function in self.functions + self.modifiers: + for function in self.functions + list(self.modifiers): function.update_read_write_using_ssa() # endregion @@ -1254,7 +1256,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._is_upgradeable @is_upgradeable.setter - def is_upgradeable(self, upgradeable: bool): + def is_upgradeable(self, upgradeable: bool) -> None: self._is_upgradeable = upgradeable @property @@ -1283,7 +1285,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._is_upgradeable_proxy @is_upgradeable_proxy.setter - def is_upgradeable_proxy(self, upgradeable_proxy: bool): + def is_upgradeable_proxy(self, upgradeable_proxy: bool) -> None: self._is_upgradeable_proxy = upgradeable_proxy @property @@ -1291,7 +1293,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._upgradeable_version @upgradeable_version.setter - def upgradeable_version(self, version_name: str): + def upgradeable_version(self, version_name: str) -> None: self._upgradeable_version = version_name # endregion @@ -1310,7 +1312,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._is_incorrectly_parsed @is_incorrectly_constructed.setter - def is_incorrectly_constructed(self, incorrect: bool): + def is_incorrectly_constructed(self, incorrect: bool) -> None: self._is_incorrectly_parsed = incorrect def add_constructor_variables(self) -> None: @@ -1322,8 +1324,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods constructor_variable = FunctionContract(self.compilation_unit) constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES) - constructor_variable.set_contract(self) - constructor_variable.set_contract_declarer(self) + constructor_variable.set_contract(self) # type: ignore + constructor_variable.set_contract_declarer(self) # type: ignore constructor_variable.set_visibility("internal") # For now, source mapping of the constructor variable is the whole contract # Could be improved with a targeted source mapping @@ -1354,8 +1356,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods constructor_variable.set_function_type( FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES ) - constructor_variable.set_contract(self) - constructor_variable.set_contract_declarer(self) + constructor_variable.set_contract(self) # type: ignore + constructor_variable.set_contract_declarer(self) # type: ignore constructor_variable.set_visibility("internal") # For now, source mapping of the constructor variable is the whole contract # Could be improved with a targeted source mapping @@ -1436,22 +1438,23 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods all_ssa_state_variables_instances[v.canonical_name] = new_var self._initial_state_variables.append(new_var) - for func in self.functions + self.modifiers: + for func in self.functions + list(self.modifiers): func.generate_slithir_ssa(all_ssa_state_variables_instances) def fix_phi(self) -> None: - last_state_variables_instances = {} - initial_state_variables_instances = {} + last_state_variables_instances: Dict[str, List["StateVariable"]] = {} + initial_state_variables_instances: Dict[str, "StateVariable"] = {} for v in self._initial_state_variables: last_state_variables_instances[v.canonical_name] = [] initial_state_variables_instances[v.canonical_name] = v - for func in self.functions + self.modifiers: + for func in self.functions + list(self.modifiers): result = func.get_last_ssa_state_variables_instances() for variable_name, instances in result.items(): + # TODO: investigate the next operation last_state_variables_instances[variable_name] += instances - for func in self.functions + self.modifiers: + for func in self.functions + list(self.modifiers): func.fix_phi(last_state_variables_instances, initial_state_variables_instances) # endregion @@ -1461,7 +1464,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### ################################################################################### - def __eq__(self, other: SourceMapping) -> bool: + def __eq__(self, other: Any) -> bool: if isinstance(other, str): return other == self.name return NotImplemented @@ -1475,6 +1478,6 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self.name def __hash__(self) -> int: - return self._id + return self._id # type:ignore # endregion diff --git a/slither/core/declarations/custom_error.py b/slither/core/declarations/custom_error.py index 5e851c8da..c566fccec 100644 --- a/slither/core/declarations/custom_error.py +++ b/slither/core/declarations/custom_error.py @@ -51,7 +51,7 @@ class CustomError(SourceMapping): return str(t) @property - def solidity_signature(self) -> Optional[str]: + def solidity_signature(self) -> str: """ Return a signature following the Solidity Standard Contract and converted into address @@ -63,7 +63,7 @@ class CustomError(SourceMapping): # (set_solidity_sig was not called before find_variable) if self._solidity_signature is None: raise ValueError("Custom Error not yet built") - return self._solidity_signature + return self._solidity_signature # type: ignore def set_solidity_sig(self) -> None: """ diff --git a/slither/core/declarations/custom_error_contract.py b/slither/core/declarations/custom_error_contract.py index a96f12057..d5b7b6a92 100644 --- a/slither/core/declarations/custom_error_contract.py +++ b/slither/core/declarations/custom_error_contract.py @@ -1,9 +1,14 @@ +from typing import TYPE_CHECKING + from slither.core.children.child_contract import ChildContract from slither.core.declarations.custom_error import CustomError +if TYPE_CHECKING: + from slither.core.declarations import Contract + class CustomErrorContract(CustomError, ChildContract): - def is_declared_by(self, contract): + def is_declared_by(self, contract: "Contract") -> bool: """ Check if the element is declared by the contract :param contract: diff --git a/slither/core/declarations/custom_error_top_level.py b/slither/core/declarations/custom_error_top_level.py index 29a9fd41a..64a6a8535 100644 --- a/slither/core/declarations/custom_error_top_level.py +++ b/slither/core/declarations/custom_error_top_level.py @@ -9,6 +9,6 @@ if TYPE_CHECKING: class CustomErrorTopLevel(CustomError, TopLevel): - def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"): + def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None: super().__init__(compilation_unit) self.file_scope: "FileScope" = scope diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index c383fc99b..e77801961 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -47,7 +47,6 @@ if TYPE_CHECKING: from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.scope.scope import FileScope from slither.slithir.variables.state_variable import StateIRVariable - from slither.core.declarations.function_contract import FunctionContract LOGGER = logging.getLogger("Function") ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"]) @@ -298,7 +297,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu def contains_assembly(self, c: bool): self._contains_assembly = c - def can_reenter(self, callstack: Optional[List["FunctionContract"]] = None) -> bool: + def can_reenter(self, callstack: Optional[List[Union["Function", "Variable"]]] = None) -> bool: """ Check if the function can re-enter Follow internal calls. @@ -1720,8 +1719,8 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu def fix_phi( self, - last_state_variables_instances: Dict[str, List["StateIRVariable"]], - initial_state_variables_instances: Dict[str, "StateIRVariable"], + last_state_variables_instances: Dict[str, List["StateVariable"]], + initial_state_variables_instances: Dict[str, "StateVariable"], ) -> None: from slither.slithir.operations import InternalCall, PhiCallback from slither.slithir.variables import Constant, StateIRVariable diff --git a/slither/core/declarations/solidity_variables.py b/slither/core/declarations/solidity_variables.py index 9569cde93..f0e903d7b 100644 --- a/slither/core/declarations/solidity_variables.py +++ b/slither/core/declarations/solidity_variables.py @@ -82,7 +82,7 @@ SOLIDITY_FUNCTIONS: Dict[str, List[str]] = { } -def solidity_function_signature(name): +def solidity_function_signature(name: str) -> str: """ Return the function signature (containing the return value) It is useful if a solidity function is used as a pointer @@ -106,7 +106,7 @@ class SolidityVariable(SourceMapping): assert name in SOLIDITY_VARIABLES or name.endswith(("_slot", "_offset")) @property - def state_variable(self): + def state_variable(self) -> str: if self._name.endswith("_slot"): return self._name[:-5] if self._name.endswith("_offset"): @@ -125,7 +125,7 @@ class SolidityVariable(SourceMapping): def __str__(self) -> str: return self._name - def __eq__(self, other: SourceMapping) -> bool: + def __eq__(self, other: Any) -> bool: return self.__class__ == other.__class__ and self.name == other.name def __hash__(self) -> int: @@ -182,13 +182,13 @@ class SolidityFunction(SourceMapping): return self._return_type @return_type.setter - def return_type(self, r: List[Union[TypeInformation, ElementaryType]]): + def return_type(self, r: List[Union[TypeInformation, ElementaryType]]) -> None: self._return_type = r def __str__(self) -> str: return self._name - def __eq__(self, other: "SolidityFunction") -> bool: + def __eq__(self, other: Any) -> bool: return self.__class__ == other.__class__ and self.name == other.name def __hash__(self) -> int: @@ -201,7 +201,7 @@ class SolidityCustomRevert(SolidityFunction): self._custom_error = custom_error self._return_type: List[Union[TypeInformation, ElementaryType]] = [] - def __eq__(self, other: Union["SolidityCustomRevert", SolidityFunction]) -> bool: + def __eq__(self, other: Any) -> bool: return ( self.__class__ == other.__class__ and self.name == other.name diff --git a/slither/core/dominators/utils.py b/slither/core/dominators/utils.py index ca5c51282..4dd55749d 100644 --- a/slither/core/dominators/utils.py +++ b/slither/core/dominators/utils.py @@ -95,4 +95,5 @@ def compute_dominance_frontier(nodes: List["Node"]) -> None: runner.dominance_frontier = runner.dominance_frontier.union({node}) while runner != node.immediate_dominator: runner.dominance_frontier = runner.dominance_frontier.union({node}) + assert runner.immediate_dominator runner = runner.immediate_dominator diff --git a/slither/core/expressions/assignment_operation.py b/slither/core/expressions/assignment_operation.py index 22aba57fb..7b7bc62d6 100644 --- a/slither/core/expressions/assignment_operation.py +++ b/slither/core/expressions/assignment_operation.py @@ -91,7 +91,7 @@ class AssignmentOperation(ExpressionTyped): super().__init__() left_expression.set_lvalue() self._expressions = [left_expression, right_expression] - self._type: Optional["AssignmentOperationType"] = expression_type + self._type: AssignmentOperationType = expression_type self._expression_return_type: Optional["Type"] = expression_return_type @property diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index e5f4e830a..ba2ec802d 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -482,8 +482,8 @@ class SlitherCore(Context): ################################################################################### @property - def crytic_compile(self) -> Optional[CryticCompile]: - return self._crytic_compile + def crytic_compile(self) -> CryticCompile: + return self._crytic_compile # type: ignore # endregion ################################################################################### diff --git a/slither/core/solidity_types/array_type.py b/slither/core/solidity_types/array_type.py index 9a0b12c00..cdb8c10c7 100644 --- a/slither/core/solidity_types/array_type.py +++ b/slither/core/solidity_types/array_type.py @@ -4,11 +4,11 @@ from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type from slither.visitors.expression.constants_folding import ConstantFolding from slither.core.expressions.literal import Literal +from slither.core.solidity_types.elementary_type import ElementaryType if TYPE_CHECKING: from slither.core.expressions.binary_operation import BinaryOperation from slither.core.expressions.identifier import Identifier - from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.function_type import FunctionType from slither.core.solidity_types.type_alias import TypeAliasTopLevel @@ -22,7 +22,7 @@ class ArrayType(Type): assert isinstance(t, Type) if length: if isinstance(length, int): - length = Literal(length, "uint256") + length = Literal(length, ElementaryType("uint256")) assert isinstance(length, Expression) super().__init__() self._type: Type = t diff --git a/slither/core/solidity_types/elementary_type.py b/slither/core/solidity_types/elementary_type.py index ec2b0ef04..a9f45c8d8 100644 --- a/slither/core/solidity_types/elementary_type.py +++ b/slither/core/solidity_types/elementary_type.py @@ -1,5 +1,5 @@ import itertools -from typing import Tuple +from typing import Tuple, Optional, Any from slither.core.solidity_types.type import Type @@ -176,7 +176,7 @@ class ElementaryType(Type): return self.type @property - def size(self) -> int: + def size(self) -> Optional[int]: """ Return the size in bits Return None if the size is not known @@ -219,7 +219,7 @@ class ElementaryType(Type): def __str__(self) -> str: return self._type - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: if not isinstance(other, ElementaryType): return False return self.type == other.type diff --git a/slither/core/solidity_types/mapping_type.py b/slither/core/solidity_types/mapping_type.py index a8acb4d9c..9741569ed 100644 --- a/slither/core/solidity_types/mapping_type.py +++ b/slither/core/solidity_types/mapping_type.py @@ -1,4 +1,4 @@ -from typing import Union, Tuple, TYPE_CHECKING +from typing import Union, Tuple, TYPE_CHECKING, Any from slither.core.solidity_types.type import Type @@ -38,7 +38,7 @@ class MappingType(Type): def __str__(self) -> str: return f"mapping({str(self._from)} => {str(self._to)})" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, MappingType): return False return self.type_from == other.type_from and self.type_to == other.type_to diff --git a/slither/core/solidity_types/type_alias.py b/slither/core/solidity_types/type_alias.py index 5b9ea0a37..1da2d4182 100644 --- a/slither/core/solidity_types/type_alias.py +++ b/slither/core/solidity_types/type_alias.py @@ -40,7 +40,7 @@ class TypeAlias(Type): class TypeAliasTopLevel(TypeAlias, TopLevel): - def __init__(self, underlying_type: Type, name: str, scope: "FileScope") -> None: + def __init__(self, underlying_type: ElementaryType, name: str, scope: "FileScope") -> None: super().__init__(underlying_type, name) self.file_scope: "FileScope" = scope @@ -49,7 +49,7 @@ class TypeAliasTopLevel(TypeAlias, TopLevel): class TypeAliasContract(TypeAlias, ChildContract): - def __init__(self, underlying_type: Type, name: str, contract: "Contract") -> None: + def __init__(self, underlying_type: ElementaryType, name: str, contract: "Contract") -> None: super().__init__(underlying_type, name) self._contract: "Contract" = contract diff --git a/slither/core/solidity_types/type_information.py b/slither/core/solidity_types/type_information.py index 2af0b097a..9cef9352c 100644 --- a/slither/core/solidity_types/type_information.py +++ b/slither/core/solidity_types/type_information.py @@ -1,4 +1,4 @@ -from typing import Union, TYPE_CHECKING, Tuple +from typing import Union, TYPE_CHECKING, Tuple, Any from slither.core.solidity_types import ElementaryType from slither.core.solidity_types.type import Type @@ -40,10 +40,10 @@ class TypeInformation(Type): def is_dynamic(self) -> bool: raise NotImplementedError - def __str__(self): + def __str__(self) -> str: return f"type({self.type.name})" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, TypeInformation): return False return self.type == other.type diff --git a/slither/core/source_mapping/source_mapping.py b/slither/core/source_mapping/source_mapping.py index a0fcf354a..4c8742b22 100644 --- a/slither/core/source_mapping/source_mapping.py +++ b/slither/core/source_mapping/source_mapping.py @@ -1,6 +1,6 @@ import re from abc import ABCMeta -from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional +from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional, Any from Crypto.Hash import SHA1 from crytic_compile.utils.naming import Filename @@ -102,10 +102,10 @@ class Source: filename_short: str = self.filename.short if self.filename.short else "" return f"{filename_short}{lines}" - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, type(self)): return NotImplemented return ( diff --git a/slither/core/variables/event_variable.py b/slither/core/variables/event_variable.py index f3ad60d0b..e191433dc 100644 --- a/slither/core/variables/event_variable.py +++ b/slither/core/variables/event_variable.py @@ -16,5 +16,5 @@ class EventVariable(ChildEvent, Variable): return self._indexed @indexed.setter - def indexed(self, is_indexed: bool): + def indexed(self, is_indexed: bool) -> bool: self._indexed = is_indexed diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index 8607a8921..c775e7c98 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -55,7 +55,7 @@ class Variable(SourceMapping): return self._initialized @initialized.setter - def initialized(self, is_init: bool): + def initialized(self, is_init: bool) -> None: self._initialized = is_init @property @@ -73,7 +73,7 @@ class Variable(SourceMapping): return self._name @name.setter - def name(self, name): + def name(self, name: str) -> None: self._name = name @property @@ -89,7 +89,7 @@ class Variable(SourceMapping): return self._is_constant @is_constant.setter - def is_constant(self, is_cst: bool): + def is_constant(self, is_cst: bool) -> None: self._is_constant = is_cst @property diff --git a/slither/detectors/abstract_detector.py b/slither/detectors/abstract_detector.py index 8e2dd490d..59f8ca3a0 100644 --- a/slither/detectors/abstract_detector.py +++ b/slither/detectors/abstract_detector.py @@ -59,6 +59,8 @@ ALL_SOLC_VERSIONS_06 = make_solc_versions(6, 0, 12) ALL_SOLC_VERSIONS_07 = make_solc_versions(7, 0, 6) # No VERSIONS_08 as it is still in dev +DETECTOR_INFO = Union[str, List[Union[str, SupportedOutput]]] + class AbstractDetector(metaclass=abc.ABCMeta): ARGUMENT = "" # run the detector with slither.py --ARGUMENT @@ -251,7 +253,7 @@ class AbstractDetector(metaclass=abc.ABCMeta): def generate_result( self, - info: Union[str, List[Union[str, SupportedOutput]]], + info: DETECTOR_INFO, additional_fields: Optional[Dict] = None, ) -> Output: output = Output( diff --git a/slither/detectors/assembly/shift_parameter_mixup.py b/slither/detectors/assembly/shift_parameter_mixup.py index 31dad2371..a4169499a 100644 --- a/slither/detectors/assembly/shift_parameter_mixup.py +++ b/slither/detectors/assembly/shift_parameter_mixup.py @@ -1,5 +1,9 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Binary, BinaryType from slither.slithir.variables import Constant from slither.core.declarations.function_contract import FunctionContract @@ -49,7 +53,12 @@ The shift statement will right-shift the constant 8 by `a` bits""" BinaryType.RIGHT_SHIFT, ]: if isinstance(ir.variable_left, Constant): - info = [f, " contains an incorrect shift operation: ", node, "\n"] + info: DETECTOR_INFO = [ + f, + " contains an incorrect shift operation: ", + node, + "\n", + ] json = self.generate_result(info) results.append(json) diff --git a/slither/detectors/attributes/const_functions_asm.py b/slither/detectors/attributes/const_functions_asm.py index e3a938361..6de5062d8 100644 --- a/slither/detectors/attributes/const_functions_asm.py +++ b/slither/detectors/attributes/const_functions_asm.py @@ -7,6 +7,7 @@ from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, + DETECTOR_INFO, ) from slither.formatters.attributes.const_functions import custom_format from slither.utils.output import Output @@ -73,7 +74,10 @@ All the calls to `get` revert, breaking Bob's smart contract execution.""" if f.contains_assembly: attr = "view" if f.view else "pure" - info = [f, f" is declared {attr} but contains assembly code\n"] + info: DETECTOR_INFO = [ + f, + f" is declared {attr} but contains assembly code\n", + ] res = self.generate_result(info, {"contains_assembly": True}) results.append(res) diff --git a/slither/detectors/compiler_bugs/array_by_reference.py b/slither/detectors/compiler_bugs/array_by_reference.py index 83ed69b9b..fffe93847 100644 --- a/slither/detectors/compiler_bugs/array_by_reference.py +++ b/slither/detectors/compiler_bugs/array_by_reference.py @@ -105,7 +105,12 @@ As a result, Bob's usage of the contract is incorrect.""" write to the array unsuccessfully. """ # Define our resulting array. - results = [] + results: List[ + Union[ + Tuple[Node, StateVariable, FunctionContract], + Tuple[Node, LocalVariable, FunctionContract], + ] + ] = [] # Verify we have functions in our list to check for. if not array_modifying_funcs: diff --git a/slither/detectors/erc/erc20/arbitrary_send_erc20.py b/slither/detectors/erc/erc20/arbitrary_send_erc20.py index 17b1fba30..f06005459 100644 --- a/slither/detectors/erc/erc20/arbitrary_send_erc20.py +++ b/slither/detectors/erc/erc20/arbitrary_send_erc20.py @@ -61,12 +61,12 @@ class ArbitrarySendErc20: is_dependent( ir.arguments[0], SolidityVariableComposed("msg.sender"), - node.function.contract, + node, ) or is_dependent( ir.arguments[0], SolidityVariable("this"), - node.function.contract, + node, ) ) ): @@ -79,12 +79,12 @@ class ArbitrarySendErc20: is_dependent( ir.arguments[1], SolidityVariableComposed("msg.sender"), - node.function.contract, + node, ) or is_dependent( ir.arguments[1], SolidityVariable("this"), - node.function.contract, + node, ) ) ): diff --git a/slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py b/slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py index f43b6302e..351f1dcfa 100644 --- a/slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py +++ b/slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py @@ -1,5 +1,9 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output from .arbitrary_send_erc20 import ArbitrarySendErc20 @@ -38,7 +42,7 @@ Use `msg.sender` as `from` in transferFrom. arbitrary_sends.detect() for node in arbitrary_sends.no_permit_results: func = node.function - info = [func, " uses arbitrary from in transferFrom: ", node, "\n"] + info: DETECTOR_INFO = [func, " uses arbitrary from in transferFrom: ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py b/slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py index 1d311c442..ca4c4a793 100644 --- a/slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py +++ b/slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py @@ -1,5 +1,9 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output from .arbitrary_send_erc20 import ArbitrarySendErc20 @@ -41,7 +45,7 @@ Ensure that the underlying ERC20 token correctly implements a permit function. arbitrary_sends.detect() for node in arbitrary_sends.permit_results: func = node.function - info = [ + info: DETECTOR_INFO = [ func, " uses arbitrary from in transferFrom in combination with permit: ", node, diff --git a/slither/detectors/functions/arbitrary_send_eth.py b/slither/detectors/functions/arbitrary_send_eth.py index 390b1f2ab..e0112c6e1 100644 --- a/slither/detectors/functions/arbitrary_send_eth.py +++ b/slither/detectors/functions/arbitrary_send_eth.py @@ -39,6 +39,10 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]: ret: List[Node] = [] for node in func.nodes: + func = node.function + deps_target: Union[Contract, Function] = ( + func.contract if isinstance(func, FunctionContract) else func + ) for ir in node.irs: if isinstance(ir, SolidityCall): if ir.function == SolidityFunction("ecrecover(bytes32,uint8,bytes32,bytes32)"): @@ -49,7 +53,7 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]: if is_dependent( ir.variable_right, SolidityVariableComposed("msg.sender"), - func.contract, + deps_target, ): return False if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)): @@ -64,11 +68,11 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]: if is_dependent( ir.call_value, SolidityVariableComposed("msg.value"), - func.contract, + node, ): continue - if is_tainted(ir.destination, func.contract): + if is_tainted(ir.destination, node): ret.append(node) return ret diff --git a/slither/detectors/statements/array_length_assignment.py b/slither/detectors/statements/array_length_assignment.py index 51302a2c9..f4ae01b88 100644 --- a/slither/detectors/statements/array_length_assignment.py +++ b/slither/detectors/statements/array_length_assignment.py @@ -7,6 +7,7 @@ from slither.detectors.abstract_detector import ( DetectorClassification, ALL_SOLC_VERSIONS_04, ALL_SOLC_VERSIONS_05, + DETECTOR_INFO, ) from slither.core.cfg.node import Node, NodeType from slither.slithir.operations import Assignment, Length @@ -120,7 +121,7 @@ Otherwise, thoroughly review the contract to ensure a user-controlled variable c for contract in self.contracts: array_length_assignments = detect_array_length_assignment(contract) if array_length_assignments: - contract_info = [ + contract_info: DETECTOR_INFO = [ contract, " contract sets array length with a user-controlled value:\n", ] diff --git a/slither/detectors/statements/assembly.py b/slither/detectors/statements/assembly.py index 2c0c49f09..25b5d8034 100644 --- a/slither/detectors/statements/assembly.py +++ b/slither/detectors/statements/assembly.py @@ -6,7 +6,11 @@ from typing import List, Tuple from slither.core.cfg.node import Node, NodeType from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -52,7 +56,7 @@ class Assembly(AbstractDetector): for c in self.contracts: values = self.detect_assembly(c) for func, nodes in values: - info = [func, " uses assembly\n"] + info: DETECTOR_INFO = [func, " uses assembly\n"] # sort the nodes to get deterministic results nodes.sort(key=lambda x: x.node_id) diff --git a/slither/slithir/operations/call.py b/slither/slithir/operations/call.py index 07304fa99..37a2fe0b3 100644 --- a/slither/slithir/operations/call.py +++ b/slither/slithir/operations/call.py @@ -1,5 +1,7 @@ -from typing import Optional, List +from typing import Optional, List, Union +from slither.core.declarations import Function +from slither.core.variables import Variable from slither.slithir.operations.operation import Operation @@ -16,7 +18,8 @@ class Call(Operation): def arguments(self, v): self._arguments = v - def can_reenter(self, _callstack: Optional[List] = None) -> bool: # pylint: disable=no-self-use + # pylint: disable=no-self-use + def can_reenter(self, _callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass :return: bool diff --git a/slither/slithir/operations/high_level_call.py b/slither/slithir/operations/high_level_call.py index 93fb73bd4..d707e11b3 100644 --- a/slither/slithir/operations/high_level_call.py +++ b/slither/slithir/operations/high_level_call.py @@ -1,5 +1,6 @@ from typing import List, Optional, Union +from slither.core.declarations import Contract from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.core.variables.variable import Variable @@ -32,7 +33,8 @@ class HighLevelCall(Call, OperationWithLValue): assert is_valid_lvalue(result) or result is None self._check_destination(destination) super().__init__() - self._destination = destination + # Contract is only possible for library call, which inherits from highlevelcall + self._destination: Union[Variable, SolidityVariable, Contract] = destination # type: ignore self._function_name = function_name self._nbr_arguments = nbr_arguments self._type_call = type_call @@ -44,8 +46,9 @@ class HighLevelCall(Call, OperationWithLValue): self._call_gas = None # Development function, to be removed once the code is stable - # It is ovveride by LbraryCall - def _check_destination(self, destination: SourceMapping) -> None: # pylint: disable=no-self-use + # It is overridden by LibraryCall + # pylint: disable=no-self-use + def _check_destination(self, destination: Union[Variable, SolidityVariable, Contract]) -> None: assert isinstance(destination, (Variable, SolidityVariable)) @property @@ -79,7 +82,14 @@ class HighLevelCall(Call, OperationWithLValue): return [x for x in all_read if x] + [self.destination] @property - def destination(self) -> SourceMapping: + def destination(self) -> Union[Variable, SolidityVariable, Contract]: + """ + Return a variable or a solidityVariable + Contract is only possible for LibraryCall + + Returns: + + """ return self._destination @property @@ -116,7 +126,7 @@ class HighLevelCall(Call, OperationWithLValue): return True return False - def can_reenter(self, callstack: None = None) -> bool: + def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass For Solidity > 0.5, filter access to public variables and constant/pure/view diff --git a/slither/slithir/operations/index.py b/slither/slithir/operations/index.py index ade84fe1d..77daa9462 100644 --- a/slither/slithir/operations/index.py +++ b/slither/slithir/operations/index.py @@ -1,20 +1,20 @@ from typing import List, Union + from slither.core.declarations import SolidityVariableComposed -from slither.slithir.operations.lvalue import OperationWithLValue -from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue -from slither.slithir.variables.reference import ReferenceVariable from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.variable import Variable -from slither.slithir.variables.reference_ssa import ReferenceVariableSSA +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue, RVALUE, LVALUE +from slither.slithir.variables.reference import ReferenceVariable class Index(OperationWithLValue): def __init__( self, - result: Union[ReferenceVariable, ReferenceVariableSSA], + result: ReferenceVariable, left_variable: Variable, - right_variable: SourceMapping, + right_variable: RVALUE, index_type: Union[ElementaryType, str], ) -> None: super().__init__() @@ -25,23 +25,23 @@ class Index(OperationWithLValue): assert isinstance(result, ReferenceVariable) self._variables = [left_variable, right_variable] self._type = index_type - self._lvalue = result + self._lvalue: ReferenceVariable = result @property def read(self) -> List[SourceMapping]: return list(self.variables) @property - def variables(self) -> List[SourceMapping]: - return self._variables + def variables(self) -> List[Union[LVALUE, RVALUE, SolidityVariableComposed]]: + return self._variables # type: ignore @property - def variable_left(self) -> Variable: - return self._variables[0] + def variable_left(self) -> Union[LVALUE, SolidityVariableComposed]: + return self._variables[0] # type: ignore @property - def variable_right(self) -> SourceMapping: - return self._variables[1] + def variable_right(self) -> RVALUE: + return self._variables[1] # type: ignore @property def index_type(self) -> Union[ElementaryType, str]: diff --git a/slither/slithir/operations/library_call.py b/slither/slithir/operations/library_call.py index ebe9bf5ef..1b7f4e8a6 100644 --- a/slither/slithir/operations/library_call.py +++ b/slither/slithir/operations/library_call.py @@ -1,4 +1,7 @@ -from slither.core.declarations import Function +from typing import Union, Optional, List + +from slither.core.declarations import Function, SolidityVariable +from slither.core.variables import Variable from slither.slithir.operations.high_level_call import HighLevelCall from slither.core.declarations.contract import Contract @@ -9,10 +12,10 @@ class LibraryCall(HighLevelCall): """ # Development function, to be removed once the code is stable - def _check_destination(self, destination: Contract) -> None: + def _check_destination(self, destination: Union[Variable, SolidityVariable, Contract]) -> None: assert isinstance(destination, Contract) - def can_reenter(self, callstack: None = None) -> bool: + def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass :return: bool @@ -20,11 +23,11 @@ class LibraryCall(HighLevelCall): if self.is_static_call(): return False # In case of recursion, return False - callstack = [] if callstack is None else callstack - if self.function in callstack: + callstack_local = [] if callstack is None else callstack + if self.function in callstack_local: return False - callstack = callstack + [self.function] - return self.function.can_reenter(callstack) + callstack_local = callstack_local + [self.function] + return self.function.can_reenter(callstack_local) def __str__(self): gas = "" diff --git a/slither/slithir/operations/low_level_call.py b/slither/slithir/operations/low_level_call.py index 7e8c278b8..eac779d27 100644 --- a/slither/slithir/operations/low_level_call.py +++ b/slither/slithir/operations/low_level_call.py @@ -1,4 +1,6 @@ -from typing import List, Union +from typing import List, Union, Optional + +from slither.core.declarations import Function from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.core.variables.variable import Variable @@ -74,7 +76,7 @@ class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-insta # remove None return self._unroll([x for x in all_read if x]) - def can_reenter(self, _callstack: None = None) -> bool: + def can_reenter(self, _callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass :return: bool diff --git a/slither/slithir/operations/lvalue.py b/slither/slithir/operations/lvalue.py index d9b800c92..b983d1c5d 100644 --- a/slither/slithir/operations/lvalue.py +++ b/slither/slithir/operations/lvalue.py @@ -1,4 +1,6 @@ -from typing import Any, List +from typing import Any, List, Optional + +from slither.core.variables import Variable from slither.slithir.operations.operation import Operation @@ -10,16 +12,16 @@ class OperationWithLValue(Operation): def __init__(self) -> None: super().__init__() - self._lvalue = None + self._lvalue: Optional[Variable] = None @property - def lvalue(self): + def lvalue(self) -> Optional[Variable]: return self._lvalue - @property - def used(self) -> List[Any]: - return self.read + [self.lvalue] - @lvalue.setter - def lvalue(self, lvalue): + def lvalue(self, lvalue: Variable) -> None: self._lvalue = lvalue + + @property + def used(self) -> List[Optional[Any]]: + return self.read + [self.lvalue] diff --git a/slither/slithir/operations/member.py b/slither/slithir/operations/member.py index 9a561ea87..0942813cf 100644 --- a/slither/slithir/operations/member.py +++ b/slither/slithir/operations/member.py @@ -5,7 +5,7 @@ from slither.core.declarations.enum import Enum from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.solidity_types import ElementaryType from slither.slithir.operations.lvalue import OperationWithLValue -from slither.slithir.utils.utils import is_valid_rvalue +from slither.slithir.utils.utils import is_valid_rvalue, RVALUE from slither.slithir.variables.constant import Constant from slither.slithir.variables.reference import ReferenceVariable from slither.core.source_mapping.source_mapping import SourceMapping @@ -39,7 +39,9 @@ class Member(OperationWithLValue): assert isinstance(variable_right, Constant) assert isinstance(result, ReferenceVariable) super().__init__() - self._variable_left = variable_left + self._variable_left: Union[ + RVALUE, Contract, Enum, Function, CustomError, SolidityImportPlaceHolder, ElementaryType + ] = variable_left self._variable_right = variable_right self._lvalue = result self._gas = None @@ -50,7 +52,11 @@ class Member(OperationWithLValue): return [self.variable_left, self.variable_right] @property - def variable_left(self) -> SourceMapping: + def variable_left( + self, + ) -> Union[ + RVALUE, Contract, Enum, Function, CustomError, SolidityImportPlaceHolder, ElementaryType + ]: return self._variable_left @property diff --git a/slither/slithir/operations/new_contract.py b/slither/slithir/operations/new_contract.py index 879d12df6..8d3c949df 100644 --- a/slither/slithir/operations/new_contract.py +++ b/slither/slithir/operations/new_contract.py @@ -1,11 +1,13 @@ from typing import Optional, Any, List, Union + +from slither.core.declarations import Function +from slither.core.declarations.contract import Contract +from slither.core.variables import Variable from slither.slithir.operations import Call, OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.variables.constant import Constant -from slither.core.declarations.contract import Contract from slither.slithir.variables.temporary import TemporaryVariable from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA -from slither.core.declarations.function_contract import FunctionContract class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes @@ -58,6 +60,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan def contract_created(self) -> Contract: contract_name = self.contract_name contract_instance = self.node.file_scope.get_contract_from_name(contract_name) + assert contract_instance return contract_instance ################################################################################### @@ -66,7 +69,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan ################################################################################### ################################################################################### - def can_reenter(self, callstack: Optional[List[FunctionContract]] = None) -> bool: + def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass For Solidity > 0.5, filter access to public variables and constant/pure/view @@ -92,7 +95,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan # endregion - def __str__(self): + def __str__(self) -> str: options = "" if self.call_value: options = f"value:{self.call_value} " diff --git a/slither/slithir/operations/solidity_call.py b/slither/slithir/operations/solidity_call.py index b059c55a6..88430f934 100644 --- a/slither/slithir/operations/solidity_call.py +++ b/slither/slithir/operations/solidity_call.py @@ -1,15 +1,16 @@ from typing import Any, List, Union -from slither.core.declarations.solidity_variables import SolidityCustomRevert, SolidityFunction -from slither.slithir.operations.call import Call -from slither.slithir.operations.lvalue import OperationWithLValue + from slither.core.children.child_node import ChildNode +from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.solidity_types.elementary_type import ElementaryType +from slither.slithir.operations.call import Call +from slither.slithir.operations.lvalue import OperationWithLValue class SolidityCall(Call, OperationWithLValue): def __init__( self, - function: Union[SolidityCustomRevert, SolidityFunction], + function: SolidityFunction, nbr_arguments: int, result: ChildNode, type_call: Union[str, List[ElementaryType]], @@ -26,7 +27,7 @@ class SolidityCall(Call, OperationWithLValue): return self._unroll(self.arguments) @property - def function(self) -> Union[SolidityCustomRevert, SolidityFunction]: + def function(self) -> SolidityFunction: return self._function @property diff --git a/slither/slithir/utils/utils.py b/slither/slithir/utils/utils.py index 0a50f8e50..a0ca0bd6f 100644 --- a/slither/slithir/utils/utils.py +++ b/slither/slithir/utils/utils.py @@ -1,3 +1,5 @@ +from typing import Union + from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable @@ -10,6 +12,24 @@ from slither.slithir.variables.reference import ReferenceVariable from slither.slithir.variables.tuple import TupleVariable from slither.core.source_mapping.source_mapping import SourceMapping +RVALUE = Union[ + StateVariable, + LocalVariable, + TopLevelVariable, + TemporaryVariable, + Constant, + SolidityVariable, + ReferenceVariable, +] + +LVALUE = Union[ + StateVariable, + LocalVariable, + TemporaryVariable, + ReferenceVariable, + TupleVariable, +] + def is_valid_rvalue(v: SourceMapping) -> bool: return isinstance( diff --git a/slither/tools/mutator/__main__.py b/slither/tools/mutator/__main__.py index 27e396d0b..84286ce66 100644 --- a/slither/tools/mutator/__main__.py +++ b/slither/tools/mutator/__main__.py @@ -79,9 +79,10 @@ def main() -> None: print(args.codebase) sl = Slither(args.codebase, **vars(args)) - for M in _get_mutators(): - m = M(sl) - m.mutate() + for compilation_unit in sl.compilation_units: + for M in _get_mutators(): + m = M(compilation_unit) + m.mutate() # endregion diff --git a/slither/tools/mutator/mutators/abstract_mutator.py b/slither/tools/mutator/mutators/abstract_mutator.py index 850c3c399..169d8725e 100644 --- a/slither/tools/mutator/mutators/abstract_mutator.py +++ b/slither/tools/mutator/mutators/abstract_mutator.py @@ -3,7 +3,7 @@ import logging from enum import Enum from typing import Optional, Dict -from slither import Slither +from slither.core.compilation_unit import SlitherCompilationUnit from slither.formatters.utils.patches import apply_patch, create_diff logger = logging.getLogger("Slither") @@ -34,8 +34,11 @@ class AbstractMutator(metaclass=abc.ABCMeta): # pylint: disable=too-few-public- FAULTCLASS = FaultClass.Undefined FAULTNATURE = FaultNature.Undefined - def __init__(self, slither: Slither, rate: int = 10, seed: Optional[int] = None): - self.slither = slither + def __init__( + self, compilation_unit: SlitherCompilationUnit, rate: int = 10, seed: Optional[int] = None + ): + self.compilation_unit = compilation_unit + self.slither = compilation_unit.core self.seed = seed self.rate = rate @@ -87,7 +90,7 @@ class AbstractMutator(metaclass=abc.ABCMeta): # pylint: disable=too-few-public- continue for patch in patches: patched_txt, offset = apply_patch(patched_txt, patch, offset) - diff = create_diff(self.slither, original_txt, patched_txt, file) + diff = create_diff(self.compilation_unit, original_txt, patched_txt, file) if not diff: logger.info(f"Impossible to generate patch; empty {patches}") print(diff) From efed98327a7553badfd1c56720136637885b9207 Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Fri, 17 Feb 2023 14:01:47 +0100 Subject: [PATCH 02/34] Add more types --- .../data_dependency/data_dependency.py | 30 ++-- .../analyses/write/are_variables_written.py | 29 ++-- slither/core/children/child_expression.py | 2 +- slither/core/declarations/contract.py | 2 +- slither/core/solidity_types/array_type.py | 5 +- slither/detectors/abstract_detector.py | 2 +- .../attributes/const_functions_asm.py | 6 +- .../attributes/const_functions_state.py | 9 +- .../detectors/attributes/constant_pragma.py | 15 +- .../detectors/attributes/incorrect_solc.py | 8 +- slither/detectors/attributes/locked_ether.py | 8 +- .../attributes/unimplemented_interface.py | 8 +- .../compiler_bugs/array_by_reference.py | 8 +- .../compiler_bugs/enum_conversion.py | 11 +- .../multiple_constructor_schemes.py | 11 +- .../compiler_bugs/reused_base_constructor.py | 3 +- .../storage_ABIEncoderV2_array.py | 9 +- .../storage_signed_integer_array.py | 32 ++-- ...initialized_function_ptr_in_constructor.py | 5 +- .../erc/erc20/incorrect_erc20_interface.py | 8 +- .../erc/incorrect_erc721_interface.py | 8 +- slither/detectors/examples/backdoor.py | 8 +- .../detectors/functions/arbitrary_send_eth.py | 7 +- .../functions/cyclomatic_complexity.py | 8 +- slither/detectors/functions/dead_code.py | 8 +- slither/detectors/functions/modifier.py | 12 +- .../permit_domain_signature_collision.py | 8 +- .../detectors/functions/protected_variable.py | 8 +- slither/detectors/functions/suicidal.py | 8 +- slither/detectors/functions/unimplemented.py | 14 +- .../naming_convention/naming_convention.py | 7 +- slither/detectors/operations/bad_prng.py | 9 +- .../detectors/operations/block_timestamp.py | 29 ++-- .../detectors/operations/low_level_calls.py | 8 +- .../missing_events_access_control.py | 8 +- .../operations/missing_events_arithmetic.py | 8 +- .../missing_zero_address_validation.py | 8 +- .../operations/unused_return_values.py | 8 +- .../detectors/operations/void_constructor.py | 8 +- slither/detectors/reentrancy/token.py | 8 +- .../detectors/shadowing/builtin_symbols.py | 8 +- slither/detectors/shadowing/local.py | 10 +- slither/detectors/shadowing/state.py | 8 +- slither/detectors/slither/name_reused.py | 8 +- slither/detectors/source/rtlo.py | 8 +- .../statements/array_length_assignment.py | 17 +- .../statements/assert_state_change.py | 13 +- .../statements/boolean_constant_equality.py | 8 +- .../statements/boolean_constant_misuse.py | 8 +- slither/detectors/statements/calls_in_loop.py | 8 +- .../statements/controlled_delegatecall.py | 10 +- .../statements/costly_operations_in_loop.py | 8 +- .../statements/delegatecall_in_loop.py | 13 +- .../detectors/statements/deprecated_calls.py | 8 +- .../statements/divide_before_multiply.py | 35 ++-- .../detectors/statements/mapping_deletion.py | 8 +- .../detectors/statements/msg_value_in_loop.py | 8 +- .../statements/redundant_statements.py | 14 +- .../detectors/statements/too_many_digits.py | 10 +- slither/detectors/statements/tx_origin.py | 8 +- .../statements/type_based_tautology.py | 7 +- slither/detectors/statements/unary.py | 11 +- .../statements/unprotected_upgradeable.py | 26 +-- .../detectors/statements/write_after_write.py | 17 +- .../function_init_state_variables.py | 8 +- .../variables/predeclaration_usage_local.py | 8 +- .../detectors/variables/similar_variables.py | 14 +- .../uninitialized_state_variables.py | 8 +- .../variables/unused_state_variables.py | 25 ++- .../variables/var_read_using_this.py | 8 +- .../formatters/attributes/const_functions.py | 3 +- .../formatters/attributes/constant_pragma.py | 5 +- .../naming_convention/naming_convention.py | 163 +++++++++++++----- .../variables/unused_state_variables.py | 8 +- slither/slithir/operations/assignment.py | 37 ++-- slither/slithir/operations/binary.py | 49 +++--- slither/slithir/operations/internal_call.py | 4 +- slither/slithir/tmp_operations/argument.py | 17 +- 78 files changed, 717 insertions(+), 320 deletions(-) diff --git a/slither/analyses/data_dependency/data_dependency.py b/slither/analyses/data_dependency/data_dependency.py index 448ee393a..2b66f2bb3 100644 --- a/slither/analyses/data_dependency/data_dependency.py +++ b/slither/analyses/data_dependency/data_dependency.py @@ -16,6 +16,7 @@ from slither.core.declarations import ( FunctionContract, ) from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder +from slither.core.solidity_types.type import Type from slither.core.variables.top_level_variable import TopLevelVariable from slither.core.variables.variable import Variable from slither.slithir.operations import Index, OperationWithLValue, InternalCall, Operation @@ -28,12 +29,10 @@ from slither.slithir.variables import ( TemporaryVariableSSA, TupleVariableSSA, ) -from slither.core.solidity_types.type import Type if TYPE_CHECKING: from slither.core.compilation_unit import SlitherCompilationUnit - ################################################################################### ################################################################################### # region User APIs @@ -41,7 +40,8 @@ if TYPE_CHECKING: ################################################################################### -Variable_types = Union[Variable, SolidityVariable] +SUPPORTED_TYPES = Union[Variable, SolidityVariable] + # TODO refactor the data deps to be better suited for top level function object # Right now we allow to pass a node to ease the API, but we need something # better @@ -51,8 +51,8 @@ Context_types = Union[Contract, Function] def is_dependent( - variable: Variable_types, - source: Variable_types, + variable: SUPPORTED_TYPES, + source: SUPPORTED_TYPES, context: Context_types_API, only_unprotected: bool = False, ) -> bool: @@ -88,8 +88,8 @@ def is_dependent( def is_dependent_ssa( - variable: Variable_types, - source: Variable_types, + variable: SUPPORTED_TYPES, + source: SUPPORTED_TYPES, context: Context_types_API, only_unprotected: bool = False, ) -> bool: @@ -131,7 +131,7 @@ GENERIC_TAINT = { def is_tainted( - variable: Variable_types, + variable: SUPPORTED_TYPES, context: Context_types_API, only_unprotected: bool = False, ignore_generic_taint: bool = False, @@ -164,7 +164,7 @@ def is_tainted( def is_tainted_ssa( - variable: Variable_types, + variable: SUPPORTED_TYPES, context: Context_types_API, only_unprotected: bool = False, ignore_generic_taint: bool = False, @@ -197,7 +197,7 @@ def is_tainted_ssa( def get_dependencies( - variable: Variable_types, + variable: SUPPORTED_TYPES, context: Context_types_API, only_unprotected: bool = False, ) -> Set[Variable]: @@ -244,7 +244,7 @@ def get_all_dependencies( def get_dependencies_ssa( - variable: Variable_types, + variable: SUPPORTED_TYPES, context: Context_types_API, only_unprotected: bool = False, ) -> Set[Variable]: @@ -459,7 +459,7 @@ def compute_dependency_function(function: Function) -> None: ) -def convert_variable_to_non_ssa(v: Variable_types) -> Variable_types: +def convert_variable_to_non_ssa(v: SUPPORTED_TYPES) -> SUPPORTED_TYPES: if isinstance( v, ( @@ -490,10 +490,10 @@ def convert_variable_to_non_ssa(v: Variable_types) -> Variable_types: def convert_to_non_ssa( - data_depencies: Dict[Variable_types, Set[Variable_types]] -) -> Dict[Variable_types, Set[Variable_types]]: + data_depencies: Dict[SUPPORTED_TYPES, Set[SUPPORTED_TYPES]] +) -> Dict[SUPPORTED_TYPES, Set[SUPPORTED_TYPES]]: # Need to create new set() as its changed during iteration - ret: Dict[Variable_types, Set[Variable_types]] = {} + ret: Dict[SUPPORTED_TYPES, Set[SUPPORTED_TYPES]] = {} for (k, values) in data_depencies.items(): var = convert_variable_to_non_ssa(k) if not var in ret: diff --git a/slither/analyses/write/are_variables_written.py b/slither/analyses/write/are_variables_written.py index 1b430012f..2f8f83063 100644 --- a/slither/analyses/write/are_variables_written.py +++ b/slither/analyses/write/are_variables_written.py @@ -2,10 +2,10 @@ Detect if all the given variables are written in all the paths of the function """ from collections import defaultdict -from typing import Dict, Set, List +from typing import Dict, Set, List, Any, Optional from slither.core.cfg.node import NodeType, Node -from slither.core.declarations import SolidityFunction +from slither.core.declarations import SolidityFunction, Function from slither.core.variables.variable import Variable from slither.slithir.operations import ( Index, @@ -18,7 +18,7 @@ from slither.slithir.variables import ReferenceVariable, TemporaryVariable class State: # pylint: disable=too-few-public-methods - def __init__(self): + def __init__(self) -> None: # Map node -> list of variables set # Were each variables set represents a configuration of a path # If two paths lead to the exact same set of variables written, we dont need to explore both @@ -34,11 +34,11 @@ class State: # pylint: disable=too-few-public-methods # pylint: disable=too-many-branches def _visit( - node: Node, + node: Optional[Node], state: State, variables_written: Set[Variable], variables_to_write: List[Variable], -): +) -> List[Variable]: """ Explore all the nodes to look for values not written when the node's function return Fixpoint reaches if no new written variables are found @@ -51,6 +51,8 @@ def _visit( refs = {} variables_written = set(variables_written) + if not node: + return [] for ir in node.irs: if isinstance(ir, SolidityCall): # TODO convert the revert to a THROW node @@ -70,17 +72,20 @@ def _visit( if ir.lvalue and not isinstance(ir.lvalue, (TemporaryVariable, ReferenceVariable)): variables_written.add(ir.lvalue) - lvalue = ir.lvalue + lvalue: Any = ir.lvalue while isinstance(lvalue, ReferenceVariable): if lvalue not in refs: break - if refs[lvalue] and not isinstance( - refs[lvalue], (TemporaryVariable, ReferenceVariable) + refs_lvalues = refs[lvalue] + if ( + refs_lvalues + and isinstance(refs_lvalues, Variable) + and not isinstance(refs_lvalues, (TemporaryVariable, ReferenceVariable)) ): - variables_written.add(refs[lvalue]) - lvalue = refs[lvalue] + variables_written.add(refs_lvalues) + lvalue = refs_lvalues - ret = [] + ret: List[Variable] = [] if not node.sons and node.type not in [NodeType.THROW, NodeType.RETURN]: ret += [v for v in variables_to_write if v not in variables_written] @@ -96,7 +101,7 @@ def _visit( return ret -def are_variables_written(function, variables_to_write): +def are_variables_written(function: Function, variables_to_write: List[Variable]) -> List[Variable]: """ Return the list of variable that are not written at the end of the function diff --git a/slither/core/children/child_expression.py b/slither/core/children/child_expression.py index 2294cf384..58e1ae338 100644 --- a/slither/core/children/child_expression.py +++ b/slither/core/children/child_expression.py @@ -19,5 +19,5 @@ class ChildExpression: self._expression = expression @property - def expression(self) -> Union["Expression", "Operation"]: + def expression(self) -> "Expression": return self._expression # type: ignore diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index d3c016b4b..38b4221d9 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -455,7 +455,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ) @property - def constructors(self) -> List["Function"]: + def constructors(self) -> List["FunctionContract"]: """ Return the list of constructors (including inherited) """ diff --git a/slither/core/solidity_types/array_type.py b/slither/core/solidity_types/array_type.py index cdb8c10c7..85c2d6ba7 100644 --- a/slither/core/solidity_types/array_type.py +++ b/slither/core/solidity_types/array_type.py @@ -23,16 +23,17 @@ class ArrayType(Type): if length: if isinstance(length, int): length = Literal(length, ElementaryType("uint256")) - assert isinstance(length, Expression) + super().__init__() self._type: Type = t + assert length is None or isinstance(length, Expression) self._length: Optional[Expression] = length if length: if not isinstance(length, Literal): cf = ConstantFolding(length, "uint256") length = cf.result() - self._length_value = length + self._length_value: Optional[Literal] = length else: self._length_value = None diff --git a/slither/detectors/abstract_detector.py b/slither/detectors/abstract_detector.py index 59f8ca3a0..7bb8eb93f 100644 --- a/slither/detectors/abstract_detector.py +++ b/slither/detectors/abstract_detector.py @@ -59,7 +59,7 @@ ALL_SOLC_VERSIONS_06 = make_solc_versions(6, 0, 12) ALL_SOLC_VERSIONS_07 = make_solc_versions(7, 0, 6) # No VERSIONS_08 as it is still in dev -DETECTOR_INFO = Union[str, List[Union[str, SupportedOutput]]] +DETECTOR_INFO = List[Union[str, SupportedOutput]] class AbstractDetector(metaclass=abc.ABCMeta): diff --git a/slither/detectors/attributes/const_functions_asm.py b/slither/detectors/attributes/const_functions_asm.py index 6de5062d8..01798e085 100644 --- a/slither/detectors/attributes/const_functions_asm.py +++ b/slither/detectors/attributes/const_functions_asm.py @@ -2,7 +2,9 @@ Module detecting constant functions Recursively check the called functions """ -from typing import List +from typing import List, Dict + +from slither.core.compilation_unit import SlitherCompilationUnit from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, @@ -85,5 +87,5 @@ All the calls to `get` revert, breaking Bob's smart contract execution.""" return results @staticmethod - def _format(comilation_unit, result): + def _format(comilation_unit: SlitherCompilationUnit, result: Dict) -> None: custom_format(comilation_unit, result) diff --git a/slither/detectors/attributes/const_functions_state.py b/slither/detectors/attributes/const_functions_state.py index 36ea8f32d..d86ca7c0e 100644 --- a/slither/detectors/attributes/const_functions_state.py +++ b/slither/detectors/attributes/const_functions_state.py @@ -2,11 +2,14 @@ Module detecting constant functions Recursively check the called functions """ -from typing import List +from typing import List, Dict + +from slither.core.compilation_unit import SlitherCompilationUnit from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, + DETECTOR_INFO, ) from slither.formatters.attributes.const_functions import custom_format from slither.utils.output import Output @@ -74,7 +77,7 @@ All the calls to `get` revert, breaking Bob's smart contract execution.""" if variables_written: attr = "view" if f.view else "pure" - info = [ + info: DETECTOR_INFO = [ f, f" is declared {attr} but changes state variables:\n", ] @@ -89,5 +92,5 @@ All the calls to `get` revert, breaking Bob's smart contract execution.""" return results @staticmethod - def _format(slither, result): + def _format(slither: SlitherCompilationUnit, result: Dict) -> None: custom_format(slither, result) diff --git a/slither/detectors/attributes/constant_pragma.py b/slither/detectors/attributes/constant_pragma.py index 2164a78e8..2ed76c86a 100644 --- a/slither/detectors/attributes/constant_pragma.py +++ b/slither/detectors/attributes/constant_pragma.py @@ -1,9 +1,14 @@ """ Check that the same pragma is used in all the files """ -from typing import List - -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from typing import List, Dict + +from slither.core.compilation_unit import SlitherCompilationUnit +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.formatters.attributes.constant_pragma import custom_format from slither.utils.output import Output @@ -31,7 +36,7 @@ class ConstantPragma(AbstractDetector): versions = sorted(list(set(versions))) if len(versions) > 1: - info = ["Different versions of Solidity are used:\n"] + info: DETECTOR_INFO = ["Different versions of Solidity are used:\n"] info += [f"\t- Version used: {[str(v) for v in versions]}\n"] for p in sorted(pragma, key=lambda x: x.version): @@ -44,5 +49,5 @@ class ConstantPragma(AbstractDetector): return results @staticmethod - def _format(slither, result): + def _format(slither: SlitherCompilationUnit, result: Dict) -> None: custom_format(slither, result) diff --git a/slither/detectors/attributes/incorrect_solc.py b/slither/detectors/attributes/incorrect_solc.py index fa9ffd88d..393fbbfe4 100644 --- a/slither/detectors/attributes/incorrect_solc.py +++ b/slither/detectors/attributes/incorrect_solc.py @@ -5,7 +5,11 @@ import re from typing import List, Optional, Tuple -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.formatters.attributes.incorrect_solc import custom_format from slither.utils.output import Output @@ -143,7 +147,7 @@ Consider using the latest version of Solidity for testing.""" # If we found any disallowed pragmas, we output our findings. if disallowed_pragmas: for (reason, p) in disallowed_pragmas: - info = ["Pragma version", p, f" {reason}\n"] + info: DETECTOR_INFO = ["Pragma version", p, f" {reason}\n"] json = self.generate_result(info) diff --git a/slither/detectors/attributes/locked_ether.py b/slither/detectors/attributes/locked_ether.py index 2fdabaea6..a6f882922 100644 --- a/slither/detectors/attributes/locked_ether.py +++ b/slither/detectors/attributes/locked_ether.py @@ -4,7 +4,11 @@ from typing import List from slither.core.declarations.contract import Contract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import ( HighLevelCall, LowLevelCall, @@ -85,7 +89,7 @@ Every Ether sent to `Locked` will be lost.""" funcs_payable = [function for function in contract.functions if function.payable] if funcs_payable: if self.do_no_send_ether(contract): - info = ["Contract locking ether found:\n"] + info: DETECTOR_INFO = ["Contract locking ether found:\n"] info += ["\tContract ", contract, " has payable functions:\n"] for function in funcs_payable: info += ["\t - ", function, "\n"] diff --git a/slither/detectors/attributes/unimplemented_interface.py b/slither/detectors/attributes/unimplemented_interface.py index ff0889d11..5c6c9c5f2 100644 --- a/slither/detectors/attributes/unimplemented_interface.py +++ b/slither/detectors/attributes/unimplemented_interface.py @@ -5,7 +5,11 @@ Collect all the interfaces Check for contracts which implement all interface functions but do not explicitly derive from those interfaces. """ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.declarations.contract import Contract from slither.utils.output import Output @@ -139,7 +143,7 @@ contract Something { continue intended_interfaces = self.detect_unimplemented_interface(contract, interfaces) for interface in intended_interfaces: - info = [contract, " should inherit from ", interface, "\n"] + info: DETECTOR_INFO = [contract, " should inherit from ", interface, "\n"] res = self.generate_result(info) results.append(res) return results diff --git a/slither/detectors/compiler_bugs/array_by_reference.py b/slither/detectors/compiler_bugs/array_by_reference.py index fffe93847..ba4cadcc7 100644 --- a/slither/detectors/compiler_bugs/array_by_reference.py +++ b/slither/detectors/compiler_bugs/array_by_reference.py @@ -2,7 +2,11 @@ Detects the passing of arrays located in memory to functions which expect to modify arrays via storage reference. """ from typing import List, Set, Tuple, Union -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.solidity_types.array_type import ArrayType from slither.core.variables.state_variable import StateVariable from slither.core.variables.local_variable import LocalVariable @@ -164,7 +168,7 @@ As a result, Bob's usage of the contract is incorrect.""" if problematic_calls: for calling_node, affected_argument, invoked_function in problematic_calls: - info = [ + info: DETECTOR_INFO = [ calling_node.function, " passes array ", affected_argument, diff --git a/slither/detectors/compiler_bugs/enum_conversion.py b/slither/detectors/compiler_bugs/enum_conversion.py index 671b8d699..c7f1bcf4e 100644 --- a/slither/detectors/compiler_bugs/enum_conversion.py +++ b/slither/detectors/compiler_bugs/enum_conversion.py @@ -10,6 +10,7 @@ from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, make_solc_versions, + DETECTOR_INFO, ) from slither.slithir.operations import TypeConversion from slither.core.declarations.enum import Enum @@ -73,10 +74,14 @@ Attackers can trigger unexpected behaviour by calling `bug(1)`.""" for c in self.compilation_unit.contracts: ret = _detect_dangerous_enum_conversions(c) for node, var in ret: - func_info = [node, " has a dangerous enum conversion\n"] + func_info: DETECTOR_INFO = [node, " has a dangerous enum conversion\n"] # Output each node with the function info header as a separate result. - variable_info = ["\t- Variable: ", var, f" of type: {str(var.type)}\n"] - node_info = ["\t- Enum conversion: ", node, "\n"] + variable_info: DETECTOR_INFO = [ + "\t- Variable: ", + var, + f" of type: {str(var.type)}\n", + ] + node_info: DETECTOR_INFO = ["\t- Enum conversion: ", node, "\n"] json = self.generate_result(func_info + variable_info + node_info) results.append(json) diff --git a/slither/detectors/compiler_bugs/multiple_constructor_schemes.py b/slither/detectors/compiler_bugs/multiple_constructor_schemes.py index 3486cc41b..ae325b2a6 100644 --- a/slither/detectors/compiler_bugs/multiple_constructor_schemes.py +++ b/slither/detectors/compiler_bugs/multiple_constructor_schemes.py @@ -1,6 +1,10 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -58,7 +62,10 @@ In Solidity [0.4.22](https://github.com/ethereum/solidity/releases/tag/v0.4.23), # If there is more than one, we encountered the described issue occurring. if constructors and len(constructors) > 1: - info = [contract, " contains multiple constructors in the same contract:\n"] + info: DETECTOR_INFO = [ + contract, + " contains multiple constructors in the same contract:\n", + ] for constructor in constructors: info += ["\t- ", constructor, "\n"] diff --git a/slither/detectors/compiler_bugs/reused_base_constructor.py b/slither/detectors/compiler_bugs/reused_base_constructor.py index 73cfac12e..73bd410c7 100644 --- a/slither/detectors/compiler_bugs/reused_base_constructor.py +++ b/slither/detectors/compiler_bugs/reused_base_constructor.py @@ -6,6 +6,7 @@ from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, + DETECTOR_INFO, ) from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract @@ -151,7 +152,7 @@ The constructor of `A` is called multiple times in `D` and `E`: continue # Generate data to output. - info = [ + info: DETECTOR_INFO = [ contract, " gives base constructor ", base_constructor, diff --git a/slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py b/slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py index aee6361c6..dd34eb5e0 100644 --- a/slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py +++ b/slither/detectors/compiler_bugs/storage_ABIEncoderV2_array.py @@ -6,6 +6,7 @@ from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, make_solc_versions, + DETECTOR_INFO, ) from slither.core.solidity_types import ArrayType from slither.core.solidity_types import UserDefinedType @@ -122,7 +123,13 @@ contract A { for contract in self.contracts: storage_abiencoderv2_arrays = self._detect_storage_abiencoderv2_arrays(contract) for function, node in storage_abiencoderv2_arrays: - info = ["Function ", function, " trigger an abi encoding bug:\n\t- ", node, "\n"] + info: DETECTOR_INFO = [ + "Function ", + function, + " trigger an abi encoding bug:\n\t- ", + node, + "\n", + ] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/compiler_bugs/storage_signed_integer_array.py b/slither/detectors/compiler_bugs/storage_signed_integer_array.py index 736f66789..cfd13cdbc 100644 --- a/slither/detectors/compiler_bugs/storage_signed_integer_array.py +++ b/slither/detectors/compiler_bugs/storage_signed_integer_array.py @@ -1,18 +1,21 @@ """ Module detecting storage signed integer array bug """ -from typing import List +from typing import List, Tuple, Set +from slither.core.declarations import Function, Contract from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, make_solc_versions, + DETECTOR_INFO, ) -from slither.core.cfg.node import NodeType +from slither.core.cfg.node import NodeType, Node from slither.core.solidity_types import ArrayType from slither.core.solidity_types.elementary_type import Int, ElementaryType from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable +from slither.slithir.operations import Operation, OperationWithLValue from slither.slithir.operations.assignment import Assignment from slither.slithir.operations.init_array import InitArray from slither.utils.output import Output @@ -60,7 +63,7 @@ contract A { VULNERABLE_SOLC_VERSIONS = make_solc_versions(4, 7, 25) + make_solc_versions(5, 0, 9) @staticmethod - def _is_vulnerable_type(ir): + def _is_vulnerable_type(ir: Operation) -> bool: """ Detect if the IR lvalue is a vulnerable type Must be a storage allocation, and an array of Int @@ -68,23 +71,28 @@ contract A { """ # Storage allocation # Base type is signed integer + if not isinstance(ir, OperationWithLValue): + return False + return ( ( isinstance(ir.lvalue, StateVariable) or (isinstance(ir.lvalue, LocalVariable) and ir.lvalue.is_storage) ) - and isinstance(ir.lvalue.type.type, ElementaryType) - and ir.lvalue.type.type.type in Int + and isinstance(ir.lvalue.type.type, ElementaryType) # type: ignore + and ir.lvalue.type.type.type in Int # type: ignore ) - def detect_storage_signed_integer_arrays(self, contract): + def detect_storage_signed_integer_arrays( + self, contract: Contract + ) -> Set[Tuple[Function, Node]]: """ Detects and returns all nodes with storage-allocated signed integer array init/assignment :param contract: Contract to detect within :return: A list of tuples with (function, node) where function node has storage-allocated signed integer array init/assignment """ # Create our result set. - results = set() + results: Set[Tuple[Function, Node]] = set() # Loop for each function and modifier. for function in contract.functions_and_modifiers_declared: @@ -118,9 +126,13 @@ contract A { for contract in self.contracts: storage_signed_integer_arrays = self.detect_storage_signed_integer_arrays(contract) for function, node in storage_signed_integer_arrays: - contract_info = ["Contract ", contract, " \n"] - function_info = ["\t- Function ", function, "\n"] - node_info = ["\t\t- ", node, " has a storage signed integer array assignment\n"] + contract_info: DETECTOR_INFO = ["Contract ", contract, " \n"] + function_info: DETECTOR_INFO = ["\t- Function ", function, "\n"] + node_info: DETECTOR_INFO = [ + "\t\t- ", + node, + " has a storage signed integer array assignment\n", + ] res = self.generate_result(contract_info + function_info + node_info) results.append(res) diff --git a/slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py b/slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py index 6685948b3..826b671bd 100644 --- a/slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py +++ b/slither/detectors/compiler_bugs/uninitialized_function_ptr_in_constructor.py @@ -6,6 +6,7 @@ from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, make_solc_versions, + DETECTOR_INFO, ) from slither.slithir.operations import InternalDynamicCall, OperationWithLValue from slither.slithir.variables import ReferenceVariable @@ -115,10 +116,10 @@ The call to `a(10)` will lead to unexpected behavior because function pointer `a results = [] for contract in self.compilation_unit.contracts: - contract_info = ["Contract ", contract, " \n"] + contract_info: DETECTOR_INFO = ["Contract ", contract, " \n"] nodes = self._detect_uninitialized_function_ptr_in_constructor(contract) for node in nodes: - node_info = [ + node_info: DETECTOR_INFO = [ "\t ", node, " is an unintialized function pointer call in a constructor\n", diff --git a/slither/detectors/erc/erc20/incorrect_erc20_interface.py b/slither/detectors/erc/erc20/incorrect_erc20_interface.py index 4da6ab5ae..a17f04e8c 100644 --- a/slither/detectors/erc/erc20/incorrect_erc20_interface.py +++ b/slither/detectors/erc/erc20/incorrect_erc20_interface.py @@ -6,7 +6,11 @@ from typing import List, Tuple from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -109,7 +113,7 @@ contract Token{ functions = IncorrectERC20InterfaceDetection.detect_incorrect_erc20_interface(c) if functions: for function in functions: - info = [ + info: DETECTOR_INFO = [ c, " has incorrect ERC20 function interface:", function, diff --git a/slither/detectors/erc/incorrect_erc721_interface.py b/slither/detectors/erc/incorrect_erc721_interface.py index 8327e8b2e..9d19b5c02 100644 --- a/slither/detectors/erc/incorrect_erc721_interface.py +++ b/slither/detectors/erc/incorrect_erc721_interface.py @@ -2,7 +2,11 @@ Detect incorrect erc721 interface. """ from typing import Any, List, Tuple, Union -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract from slither.utils.output import Output @@ -119,7 +123,7 @@ contract Token{ functions = IncorrectERC721InterfaceDetection.detect_incorrect_erc721_interface(c) if functions: for function in functions: - info = [ + info: DETECTOR_INFO = [ c, " has incorrect ERC721 function interface:", function, diff --git a/slither/detectors/examples/backdoor.py b/slither/detectors/examples/backdoor.py index 0e8e9ad81..392834641 100644 --- a/slither/detectors/examples/backdoor.py +++ b/slither/detectors/examples/backdoor.py @@ -1,6 +1,10 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -28,7 +32,7 @@ class Backdoor(AbstractDetector): for f in contract.functions: if "backdoor" in f.name: # Info to be printed - info = ["Backdoor function found in ", f, "\n"] + info: DETECTOR_INFO = ["Backdoor function found in ", f, "\n"] # Add the result in result res = self.generate_result(info) diff --git a/slither/detectors/functions/arbitrary_send_eth.py b/slither/detectors/functions/arbitrary_send_eth.py index e0112c6e1..f6c688a3f 100644 --- a/slither/detectors/functions/arbitrary_send_eth.py +++ b/slither/detectors/functions/arbitrary_send_eth.py @@ -18,7 +18,9 @@ from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.solidity_variables import ( SolidityFunction, SolidityVariableComposed, + SolidityVariable, ) +from slither.core.variables import Variable from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import ( HighLevelCall, @@ -72,8 +74,9 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]: ): continue - if is_tainted(ir.destination, node): - ret.append(node) + if isinstance(ir.destination, (Variable, SolidityVariable)): + if is_tainted(ir.destination, node): + ret.append(node) return ret diff --git a/slither/detectors/functions/cyclomatic_complexity.py b/slither/detectors/functions/cyclomatic_complexity.py index f03cf61b8..d8258994f 100644 --- a/slither/detectors/functions/cyclomatic_complexity.py +++ b/slither/detectors/functions/cyclomatic_complexity.py @@ -1,7 +1,11 @@ from typing import List, Tuple from slither.core.declarations import Function -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.code_complexity import compute_cyclomatic_complexity from slither.utils.output import Output @@ -44,7 +48,7 @@ class CyclomaticComplexity(AbstractDetector): _check_for_high_cc(high_cc_functions, f) for f, cc in high_cc_functions: - info = [f, f" has a high cyclomatic complexity ({cc}).\n"] + info: DETECTOR_INFO = [f, f" has a high cyclomatic complexity ({cc}).\n"] res = self.generate_result(info) results.append(res) return results diff --git a/slither/detectors/functions/dead_code.py b/slither/detectors/functions/dead_code.py index 1a25c5776..98eb97ff7 100644 --- a/slither/detectors/functions/dead_code.py +++ b/slither/detectors/functions/dead_code.py @@ -4,7 +4,11 @@ Module detecting dead code from typing import List, Tuple from slither.core.declarations import Function, FunctionContract, Contract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -72,7 +76,7 @@ contract Contract{ # Continue if the functon is not implemented because it means the contract is abstract if not function.is_implemented: continue - info = [function, " is never used and should be removed\n"] + info: DETECTOR_INFO = [function, " is never used and should be removed\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/functions/modifier.py b/slither/detectors/functions/modifier.py index 271d8e6cb..61ec1825e 100644 --- a/slither/detectors/functions/modifier.py +++ b/slither/detectors/functions/modifier.py @@ -6,7 +6,11 @@ are in the outermost scope, they do not guarantee a revert, so a default value can still be returned. """ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.cfg.node import Node, NodeType from slither.utils.output import Output @@ -82,7 +86,11 @@ If the condition in `myModif` is false, the execution of `get()` will return 0." node = None else: # Nothing was found in the outer scope - info = ["Modifier ", mod, " does not always execute _; or revert"] + info: DETECTOR_INFO = [ + "Modifier ", + mod, + " does not always execute _; or revert", + ] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/functions/permit_domain_signature_collision.py b/slither/detectors/functions/permit_domain_signature_collision.py index de64ec52e..39543fb49 100644 --- a/slither/detectors/functions/permit_domain_signature_collision.py +++ b/slither/detectors/functions/permit_domain_signature_collision.py @@ -6,7 +6,11 @@ from typing import Union, List from slither.core.declarations import Function from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.function import get_function_id from slither.utils.output import Output @@ -63,7 +67,7 @@ contract Contract{ assert isinstance(func_or_var, StateVariable) incorrect_return_type = func_or_var.type != ElementaryType("bytes32") if hash_collision or incorrect_return_type: - info = [ + info: DETECTOR_INFO = [ "The function signature of ", func_or_var, " collides with DOMAIN_SEPARATOR and should be renamed or removed.\n", diff --git a/slither/detectors/functions/protected_variable.py b/slither/detectors/functions/protected_variable.py index 68ed098c7..579672926 100644 --- a/slither/detectors/functions/protected_variable.py +++ b/slither/detectors/functions/protected_variable.py @@ -6,7 +6,11 @@ A suicidal contract is an unprotected function that calls selfdestruct from typing import List from slither.core.declarations import Function, Contract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -58,7 +62,7 @@ contract Buggy{ self.logger.error(f"{function_sig} not found") continue if function_protection not in function.all_internal_calls(): - info = [ + info: DETECTOR_INFO = [ function, " should have ", function_protection, diff --git a/slither/detectors/functions/suicidal.py b/slither/detectors/functions/suicidal.py index 7741da57d..1f8cb52f9 100644 --- a/slither/detectors/functions/suicidal.py +++ b/slither/detectors/functions/suicidal.py @@ -7,7 +7,11 @@ from typing import List from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -78,7 +82,7 @@ Bob calls `kill` and destructs the contract.""" functions = self.detect_suicidal(c) for func in functions: - info = [func, " allows anyone to destruct the contract\n"] + info: DETECTOR_INFO = [func, " allows anyone to destruct the contract\n"] res = self.generate_result(info) diff --git a/slither/detectors/functions/unimplemented.py b/slither/detectors/functions/unimplemented.py index 11a1fad80..27a2d94a9 100644 --- a/slither/detectors/functions/unimplemented.py +++ b/slither/detectors/functions/unimplemented.py @@ -8,7 +8,13 @@ Consider public state variables as implemented functions Do not consider fallback function or constructor """ from typing import List, Set -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification + +from slither.core.declarations import Function +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract from slither.utils.output import Output @@ -62,7 +68,7 @@ All unimplemented functions must be implemented on a contract that is meant to b def _match_state_variable(contract: Contract, f: FunctionContract) -> bool: return any(s.full_name == f.full_name for s in contract.state_variables) - def _detect_unimplemented_function(self, contract: Contract) -> Set[FunctionContract]: + def _detect_unimplemented_function(self, contract: Contract) -> Set[Function]: """ Detects any function definitions which are not implemented in the given contract. :param contract: The contract to search unimplemented functions for. @@ -77,6 +83,8 @@ All unimplemented functions must be implemented on a contract that is meant to b # fallback function and constructor. unimplemented = set() for f in contract.all_functions_called: + if not isinstance(f, Function): + continue if ( not f.is_implemented and not f.is_constructor @@ -102,7 +110,7 @@ All unimplemented functions must be implemented on a contract that is meant to b for contract in self.compilation_unit.contracts_derived: functions = self._detect_unimplemented_function(contract) if functions: - info = [contract, " does not implement functions:\n"] + info: DETECTOR_INFO = [contract, " does not implement functions:\n"] for function in sorted(functions, key=lambda x: x.full_name): info += ["\t- ", function, "\n"] diff --git a/slither/detectors/naming_convention/naming_convention.py b/slither/detectors/naming_convention/naming_convention.py index 96d3964fa..02deb719e 100644 --- a/slither/detectors/naming_convention/naming_convention.py +++ b/slither/detectors/naming_convention/naming_convention.py @@ -1,6 +1,10 @@ import re from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.formatters.naming_convention.naming_convention import custom_format from slither.utils.output import Output @@ -63,6 +67,7 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2 def _detect(self) -> List[Output]: results = [] + info: DETECTOR_INFO for contract in self.contracts: if not self.is_cap_words(contract.name): diff --git a/slither/detectors/operations/bad_prng.py b/slither/detectors/operations/bad_prng.py index d8bf28f6c..f816e96c8 100644 --- a/slither/detectors/operations/bad_prng.py +++ b/slither/detectors/operations/bad_prng.py @@ -50,14 +50,17 @@ def contains_bad_PRNG_sources(func: Function, blockhash_ret_values: List[Variabl for node in func.nodes: for ir in node.irs_ssa: if isinstance(ir, Binary) and ir.type == BinaryType.MODULO: + var_left = ir.variable_left + if not isinstance(var_left, (Variable, SolidityVariable)): + continue if is_dependent_ssa( - ir.variable_left, SolidityVariableComposed("block.timestamp"), func.contract - ) or is_dependent_ssa(ir.variable_left, SolidityVariable("now"), func.contract): + var_left, SolidityVariableComposed("block.timestamp"), node + ) or is_dependent_ssa(var_left, SolidityVariable("now"), node): ret.add(node) break for ret_val in blockhash_ret_values: - if is_dependent_ssa(ir.variable_left, ret_val, func.contract): + if is_dependent_ssa(var_left, ret_val, node): ret.add(node) break return list(ret) diff --git a/slither/detectors/operations/block_timestamp.py b/slither/detectors/operations/block_timestamp.py index b80c8c392..d5c2c8df7 100644 --- a/slither/detectors/operations/block_timestamp.py +++ b/slither/detectors/operations/block_timestamp.py @@ -6,12 +6,17 @@ from typing import List, Tuple from slither.analyses.data_dependency.data_dependency import is_dependent from slither.core.cfg.node import Node -from slither.core.declarations import Function, Contract +from slither.core.declarations import Function, Contract, FunctionContract from slither.core.declarations.solidity_variables import ( SolidityVariableComposed, SolidityVariable, ) -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.core.variables import Variable +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Binary, BinaryType from slither.utils.output import Output @@ -21,25 +26,25 @@ def _timestamp(func: Function) -> List[Node]: for node in func.nodes: if node.contains_require_or_assert(): for var in node.variables_read: - if is_dependent(var, SolidityVariableComposed("block.timestamp"), func.contract): + if is_dependent(var, SolidityVariableComposed("block.timestamp"), node): ret.add(node) - if is_dependent(var, SolidityVariable("now"), func.contract): + if is_dependent(var, SolidityVariable("now"), node): ret.add(node) for ir in node.irs: if isinstance(ir, Binary) and BinaryType.return_bool(ir.type): - for var in ir.read: - if is_dependent( - var, SolidityVariableComposed("block.timestamp"), func.contract - ): + for var_read in ir.read: + if not isinstance(var_read, (Variable, SolidityVariable)): + continue + if is_dependent(var_read, SolidityVariableComposed("block.timestamp"), node): ret.add(node) - if is_dependent(var, SolidityVariable("now"), func.contract): + if is_dependent(var_read, SolidityVariable("now"), node): ret.add(node) return sorted(list(ret), key=lambda x: x.node_id) def _detect_dangerous_timestamp( contract: Contract, -) -> List[Tuple[Function, List[Node]]]: +) -> List[Tuple[FunctionContract, List[Node]]]: """ Args: contract (Contract) @@ -48,7 +53,7 @@ def _detect_dangerous_timestamp( """ ret = [] for f in [f for f in contract.functions if f.contract_declarer == contract]: - nodes = _timestamp(f) + nodes: List[Node] = _timestamp(f) if nodes: ret.append((f, nodes)) return ret @@ -78,7 +83,7 @@ class Timestamp(AbstractDetector): dangerous_timestamp = _detect_dangerous_timestamp(c) for (func, nodes) in dangerous_timestamp: - info = [func, " uses timestamp for comparisons\n"] + info: DETECTOR_INFO = [func, " uses timestamp for comparisons\n"] info += ["\tDangerous comparisons:\n"] diff --git a/slither/detectors/operations/low_level_calls.py b/slither/detectors/operations/low_level_calls.py index 1ea91c37a..463c74875 100644 --- a/slither/detectors/operations/low_level_calls.py +++ b/slither/detectors/operations/low_level_calls.py @@ -2,7 +2,11 @@ Module detecting usage of low level calls """ from typing import List, Tuple -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import LowLevelCall from slither.core.cfg.node import Node from slither.core.declarations.contract import Contract @@ -52,7 +56,7 @@ class LowLevelCalls(AbstractDetector): for c in self.contracts: values = self.detect_low_level_calls(c) for func, nodes in values: - info = ["Low level call in ", func, ":\n"] + info: DETECTOR_INFO = ["Low level call in ", func, ":\n"] # sort the nodes to get deterministic results nodes.sort(key=lambda x: x.node_id) diff --git a/slither/detectors/operations/missing_events_access_control.py b/slither/detectors/operations/missing_events_access_control.py index 20c229759..853eafd73 100644 --- a/slither/detectors/operations/missing_events_access_control.py +++ b/slither/detectors/operations/missing_events_access_control.py @@ -11,7 +11,11 @@ from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.modifier import Modifier from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations.event_call import EventCall from slither.utils.output import Output @@ -100,7 +104,7 @@ contract C { for contract in self.compilation_unit.contracts_derived: missing_events = self._detect_missing_events(contract) for (function, nodes) in missing_events: - info = [function, " should emit an event for: \n"] + info: DETECTOR_INFO = [function, " should emit an event for: \n"] for (node, _sv, _mod) in nodes: info += ["\t- ", node, " \n"] res = self.generate_result(info) diff --git a/slither/detectors/operations/missing_events_arithmetic.py b/slither/detectors/operations/missing_events_arithmetic.py index 6e1d5fbb5..c17ed32a3 100644 --- a/slither/detectors/operations/missing_events_arithmetic.py +++ b/slither/detectors/operations/missing_events_arithmetic.py @@ -10,7 +10,11 @@ from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types.elementary_type import ElementaryType, Int, Uint from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations.event_call import EventCall from slither.utils.output import Output @@ -122,7 +126,7 @@ contract C { for contract in self.compilation_unit.contracts_derived: missing_events = self._detect_missing_events(contract) for (function, nodes) in missing_events: - info = [function, " should emit an event for: \n"] + info: DETECTOR_INFO = [function, " should emit an event for: \n"] for (node, _) in nodes: info += ["\t- ", node, " \n"] res = self.generate_result(info) diff --git a/slither/detectors/operations/missing_zero_address_validation.py b/slither/detectors/operations/missing_zero_address_validation.py index a6c8de9ff..4feac9d0c 100644 --- a/slither/detectors/operations/missing_zero_address_validation.py +++ b/slither/detectors/operations/missing_zero_address_validation.py @@ -12,7 +12,11 @@ from slither.core.declarations.function import ModifierStatements from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.variables.local_variable import LocalVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Call from slither.slithir.operations import Send, Transfer, LowLevelCall from slither.utils.output import Output @@ -155,7 +159,7 @@ Bob calls `updateOwner` without specifying the `newOwner`, so Bob loses ownershi missing_zero_address_validation = self._detect_missing_zero_address_validation(contract) for (_, var_nodes) in missing_zero_address_validation: for var, nodes in var_nodes.items(): - info = [var, " lacks a zero-check on ", ":\n"] + info: DETECTOR_INFO = [var, " lacks a zero-check on ", ":\n"] for node in nodes: info += ["\t\t- ", node, "\n"] res = self.generate_result(info) diff --git a/slither/detectors/operations/unused_return_values.py b/slither/detectors/operations/unused_return_values.py index 7edde20fc..93dda274a 100644 --- a/slither/detectors/operations/unused_return_values.py +++ b/slither/detectors/operations/unused_return_values.py @@ -7,7 +7,11 @@ from slither.core.cfg.node import Node from slither.core.declarations import Function from slither.core.declarations.function_contract import FunctionContract from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import HighLevelCall from slither.slithir.operations.operation import Operation from slither.utils.output import Output @@ -91,7 +95,7 @@ contract MyConc{ if unused_return: for node in unused_return: - info = [f, " ignores return value by ", node, "\n"] + info: DETECTOR_INFO = [f, " ignores return value by ", node, "\n"] res = self.generate_result(info) diff --git a/slither/detectors/operations/void_constructor.py b/slither/detectors/operations/void_constructor.py index fb44ea98c..365904fa9 100644 --- a/slither/detectors/operations/void_constructor.py +++ b/slither/detectors/operations/void_constructor.py @@ -1,6 +1,10 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Nop from slither.utils.output import Output @@ -39,7 +43,7 @@ When reading `B`'s constructor definition, we might assume that `A()` initiates for constructor_call in cst.explicit_base_constructor_calls_statements: for node in constructor_call.nodes: if any(isinstance(ir, Nop) for ir in node.irs): - info = ["Void constructor called in ", cst, ":\n"] + info: DETECTOR_INFO = ["Void constructor called in ", cst, ":\n"] info += ["\t- ", node, "\n"] res = self.generate_result(info) diff --git a/slither/detectors/reentrancy/token.py b/slither/detectors/reentrancy/token.py index c960bffa7..d906a7303 100644 --- a/slither/detectors/reentrancy/token.py +++ b/slither/detectors/reentrancy/token.py @@ -4,7 +4,11 @@ from typing import Dict, List from slither.analyses.data_dependency.data_dependency import is_dependent from slither.core.cfg.node import Node from slither.core.declarations import Function, Contract, SolidityVariableComposed -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import LowLevelCall, HighLevelCall from slither.utils.output import Output @@ -88,7 +92,7 @@ If you do, ensure your users are aware of the potential issues.""" for contract in self.compilation_unit.contracts_derived: vulns = _detect_token_reentrant(contract) for function, nodes in vulns.items(): - info = [function, " is an reentrancy unsafe token function:\n"] + info: DETECTOR_INFO = [function, " is an reentrancy unsafe token function:\n"] for node in nodes: info += ["\t-", node, "\n"] json = self.generate_result(info) diff --git a/slither/detectors/shadowing/builtin_symbols.py b/slither/detectors/shadowing/builtin_symbols.py index b0a44c8e2..ab5486105 100644 --- a/slither/detectors/shadowing/builtin_symbols.py +++ b/slither/detectors/shadowing/builtin_symbols.py @@ -9,7 +9,11 @@ from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.modifier import Modifier from slither.core.variables import Variable from slither.core.variables.local_variable import LocalVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -194,7 +198,7 @@ contract Bug { shadow_type = shadow[0] shadow_object = shadow[1] - info = [ + info: DETECTOR_INFO = [ shadow_object, f' ({shadow_type}) shadows built-in symbol"\n', ] diff --git a/slither/detectors/shadowing/local.py b/slither/detectors/shadowing/local.py index ad65b62d9..a705f45b0 100644 --- a/slither/detectors/shadowing/local.py +++ b/slither/detectors/shadowing/local.py @@ -9,7 +9,11 @@ from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.modifier import Modifier from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -84,7 +88,7 @@ contract Bug { ] = [] # Loop through all functions + modifiers in this contract. - for function in contract.functions + contract.modifiers: + for function in contract.functions + list(contract.modifiers): # We should only look for functions declared directly in this contract (not in a base contract). if function.contract_declarer != contract: continue @@ -134,7 +138,7 @@ contract Bug { for shadow in shadows: local_variable = shadow[0] overshadowed = shadow[1] - info = [local_variable, " shadows:\n"] + info: DETECTOR_INFO = [local_variable, " shadows:\n"] for overshadowed_entry in overshadowed: info += [ "\t- ", diff --git a/slither/detectors/shadowing/state.py b/slither/detectors/shadowing/state.py index 801c370a5..c08dbfd25 100644 --- a/slither/detectors/shadowing/state.py +++ b/slither/detectors/shadowing/state.py @@ -6,7 +6,11 @@ from typing import List from slither.core.declarations import Contract from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.detectors.shadowing.common import is_upgradable_gap_variable from slither.utils.output import Output @@ -89,7 +93,7 @@ contract DerivedContract is BaseContract{ for all_variables in shadowing: shadow = all_variables[0] variables = all_variables[1:] - info = [shadow, " shadows:\n"] + info: DETECTOR_INFO = [shadow, " shadows:\n"] for var in variables: info += ["\t- ", var, "\n"] diff --git a/slither/detectors/slither/name_reused.py b/slither/detectors/slither/name_reused.py index f6f2820fa..e8a40881a 100644 --- a/slither/detectors/slither/name_reused.py +++ b/slither/detectors/slither/name_reused.py @@ -2,7 +2,11 @@ from collections import defaultdict from typing import Any, List from slither.core.compilation_unit import SlitherCompilationUnit -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -80,7 +84,7 @@ As a result, the second contract cannot be analyzed. inheritance_corrupted[father.name].append(contract) for contract_name, files in names_reused.items(): - info = [contract_name, " is re-used:\n"] + info: DETECTOR_INFO = [contract_name, " is re-used:\n"] for file in files: if file is None: info += ["\t- In an file not found, most likely in\n"] diff --git a/slither/detectors/source/rtlo.py b/slither/detectors/source/rtlo.py index f89eb70eb..b020f69f9 100644 --- a/slither/detectors/source/rtlo.py +++ b/slither/detectors/source/rtlo.py @@ -1,7 +1,11 @@ import re from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -78,7 +82,7 @@ contract Token idx = start_index + result_index relative = self.slither.crytic_compile.filename_lookup(filename).relative - info = f"{relative} contains a unicode right-to-left-override character at byte offset {idx}:\n" + info: DETECTOR_INFO = f"{relative} contains a unicode right-to-left-override character at byte offset {idx}:\n" # We have a patch, so pattern.find will return at least one result diff --git a/slither/detectors/statements/array_length_assignment.py b/slither/detectors/statements/array_length_assignment.py index f4ae01b88..70dc5aadb 100644 --- a/slither/detectors/statements/array_length_assignment.py +++ b/slither/detectors/statements/array_length_assignment.py @@ -1,13 +1,14 @@ """ Module detecting assignment of array length """ -from typing import List, Set +from typing import List, Set, Union + +from slither.core.variables import Variable from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, ALL_SOLC_VERSIONS_05, - DETECTOR_INFO, ) from slither.core.cfg.node import Node, NodeType from slither.slithir.operations import Assignment, Length @@ -15,7 +16,7 @@ from slither.slithir.variables.reference import ReferenceVariable from slither.slithir.operations.binary import Binary from slither.analyses.data_dependency.data_dependency import is_tainted from slither.core.declarations.contract import Contract -from slither.utils.output import Output +from slither.utils.output import Output, SupportedOutput def detect_array_length_assignment(contract: Contract) -> Set[Node]: @@ -51,7 +52,7 @@ def detect_array_length_assignment(contract: Contract) -> Set[Node]: elif isinstance(ir, (Assignment, Binary)): if isinstance(ir.lvalue, ReferenceVariable): if ir.lvalue in array_length_refs and any( - is_tainted(v, contract) for v in ir.read + is_tainted(v, contract) for v in ir.read if isinstance(v, Variable) ): # the taint is not precise enough yet # as a result, REF_0 = REF_0 + 1 @@ -121,12 +122,16 @@ Otherwise, thoroughly review the contract to ensure a user-controlled variable c for contract in self.contracts: array_length_assignments = detect_array_length_assignment(contract) if array_length_assignments: - contract_info: DETECTOR_INFO = [ + contract_info: List[Union[str, SupportedOutput]] = [ contract, " contract sets array length with a user-controlled value:\n", ] for node in array_length_assignments: - node_info = contract_info + ["\t- ", node, "\n"] + node_info: List[Union[str, SupportedOutput]] = contract_info + [ + "\t- ", + node, + "\n", + ] res = self.generate_result(node_info) results.append(res) diff --git a/slither/detectors/statements/assert_state_change.py b/slither/detectors/statements/assert_state_change.py index c82919de6..62299202e 100644 --- a/slither/detectors/statements/assert_state_change.py +++ b/slither/detectors/statements/assert_state_change.py @@ -6,7 +6,11 @@ from typing import List, Tuple from slither.core.cfg.node import Node from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations.internal_call import InternalCall from slither.utils.output import Output @@ -25,7 +29,7 @@ def detect_assert_state_change( results = [] # Loop for each function and modifier. - for function in contract.functions_declared + contract.modifiers_declared: + for function in contract.functions_declared + list(contract.modifiers_declared): for node in function.nodes: # Detect assert() calls if any(c.name == "assert(bool)" for c in node.internal_calls) and ( @@ -85,7 +89,10 @@ The assert in `bad()` increments the state variable `s_a` while checking for the for contract in self.contracts: assert_state_change = detect_assert_state_change(contract) for (func, node) in assert_state_change: - info = [func, " has an assert() call which possibly changes state.\n"] + info: DETECTOR_INFO = [ + func, + " has an assert() call which possibly changes state.\n", + ] info += ["\t-", node, "\n"] info += [ "Consider using require() or change the invariant to not modify the state.\n" diff --git a/slither/detectors/statements/boolean_constant_equality.py b/slither/detectors/statements/boolean_constant_equality.py index 5b91f364f..97eb14aa5 100644 --- a/slither/detectors/statements/boolean_constant_equality.py +++ b/slither/detectors/statements/boolean_constant_equality.py @@ -6,7 +6,11 @@ from typing import List, Set, Tuple from slither.core.cfg.node import Node from slither.core.declarations import Function from slither.core.declarations.contract import Contract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import ( Binary, BinaryType, @@ -84,7 +88,7 @@ Boolean constants can be used directly and do not need to be compare to `true` o boolean_constant_misuses = self._detect_boolean_equality(contract) for (func, nodes) in boolean_constant_misuses: for node in nodes: - info = [ + info: DETECTOR_INFO = [ func, " compares to a boolean constant:\n\t-", node, diff --git a/slither/detectors/statements/boolean_constant_misuse.py b/slither/detectors/statements/boolean_constant_misuse.py index 96dd2012f..093e43fee 100644 --- a/slither/detectors/statements/boolean_constant_misuse.py +++ b/slither/detectors/statements/boolean_constant_misuse.py @@ -7,7 +7,11 @@ from slither.core.cfg.node import Node, NodeType from slither.core.declarations import Function from slither.core.declarations.contract import Contract from slither.core.solidity_types import ElementaryType -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import ( Assignment, Call, @@ -120,7 +124,7 @@ Other uses (in complex expressions, as conditionals) indicate either an error or boolean_constant_misuses = self._detect_boolean_constant_misuses(contract) for (func, nodes) in boolean_constant_misuses: for node in nodes: - info = [ + info: DETECTOR_INFO = [ func, " uses a Boolean constant improperly:\n\t-", node, diff --git a/slither/detectors/statements/calls_in_loop.py b/slither/detectors/statements/calls_in_loop.py index fdd0c6732..b3a177ee6 100644 --- a/slither/detectors/statements/calls_in_loop.py +++ b/slither/detectors/statements/calls_in_loop.py @@ -1,6 +1,10 @@ from typing import List, Optional from slither.core.cfg.node import NodeType, Node -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.declarations import Contract from slither.utils.output import Output from slither.slithir.operations import ( @@ -94,7 +98,7 @@ If one of the destinations has a fallback function that reverts, `bad` will alwa for node in values: func = node.function - info = [func, " has external calls inside a loop: ", node, "\n"] + info: DETECTOR_INFO = [func, " has external calls inside a loop: ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/controlled_delegatecall.py b/slither/detectors/statements/controlled_delegatecall.py index 08280940d..32e59d6eb 100644 --- a/slither/detectors/statements/controlled_delegatecall.py +++ b/slither/detectors/statements/controlled_delegatecall.py @@ -3,7 +3,11 @@ from typing import List from slither.analyses.data_dependency.data_dependency import is_tainted from slither.core.cfg.node import Node from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import LowLevelCall from slither.utils.output import Output @@ -58,13 +62,13 @@ Bob calls `delegate` and delegates the execution to his malicious contract. As a continue nodes = controlled_delegatecall(f) if nodes: - func_info = [ + func_info: DETECTOR_INFO = [ f, " uses delegatecall to a input-controlled function id\n", ] for node in nodes: - node_info = func_info + ["\t- ", node, "\n"] + node_info: DETECTOR_INFO = func_info + ["\t- ", node, "\n"] res = self.generate_result(node_info) results.append(res) diff --git a/slither/detectors/statements/costly_operations_in_loop.py b/slither/detectors/statements/costly_operations_in_loop.py index 930085cc6..6af04329c 100644 --- a/slither/detectors/statements/costly_operations_in_loop.py +++ b/slither/detectors/statements/costly_operations_in_loop.py @@ -1,6 +1,10 @@ from typing import List, Optional from slither.core.cfg.node import NodeType, Node -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.core.declarations import Contract from slither.utils.output import Output from slither.slithir.operations import InternalCall, OperationWithLValue @@ -98,7 +102,7 @@ Incrementing `state_variable` in a loop incurs a lot of gas because of expensive values = detect_costly_operations_in_loop(c) for node in values: func = node.function - info = [func, " has costly operations inside a loop:\n"] + info: DETECTOR_INFO = [func, " has costly operations inside a loop:\n"] info += ["\t- ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/delegatecall_in_loop.py b/slither/detectors/statements/delegatecall_in_loop.py index b7bf70cbc..d97466edf 100644 --- a/slither/detectors/statements/delegatecall_in_loop.py +++ b/slither/detectors/statements/delegatecall_in_loop.py @@ -1,6 +1,10 @@ from typing import List, Optional from slither.core.cfg.node import NodeType, Node -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import LowLevelCall, InternalCall from slither.core.declarations import Contract from slither.utils.output import Output @@ -94,7 +98,12 @@ Carefully check that the function called by `delegatecall` is not payable/doesn' for node in values: func = node.function - info = [func, " has delegatecall inside a loop in a payable function: ", node, "\n"] + info: DETECTOR_INFO = [ + func, + " has delegatecall inside a loop in a payable function: ", + node, + "\n", + ] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/deprecated_calls.py b/slither/detectors/statements/deprecated_calls.py index 3d0ca4ba9..e59d254bb 100644 --- a/slither/detectors/statements/deprecated_calls.py +++ b/slither/detectors/statements/deprecated_calls.py @@ -11,7 +11,11 @@ from slither.core.declarations.solidity_variables import ( ) from slither.core.expressions.expression import Expression from slither.core.variables import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import LowLevelCall from slither.utils.output import Output from slither.visitors.expression.export_values import ExportValues @@ -186,7 +190,7 @@ contract ContractWithDeprecatedReferences { for deprecated_reference in deprecated_references: source_object = deprecated_reference[0] deprecated_entries = deprecated_reference[1] - info = ["Deprecated standard detected ", source_object, ":\n"] + info: DETECTOR_INFO = ["Deprecated standard detected ", source_object, ":\n"] for (_dep_id, original_desc, recommended_disc) in deprecated_entries: info += [ diff --git a/slither/detectors/statements/divide_before_multiply.py b/slither/detectors/statements/divide_before_multiply.py index a9de76b40..6f199db41 100644 --- a/slither/detectors/statements/divide_before_multiply.py +++ b/slither/detectors/statements/divide_before_multiply.py @@ -2,13 +2,18 @@ Module detecting possible loss of precision due to divide before multiple """ from collections import defaultdict -from typing import Any, DefaultDict, List, Set, Tuple +from typing import DefaultDict, List, Set, Tuple from slither.core.cfg.node import Node from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Binary, Assignment, BinaryType, LibraryCall, Operation +from slither.slithir.utils.utils import LVALUE from slither.slithir.variables import Constant from slither.utils.output import Output @@ -19,7 +24,7 @@ def is_division(ir: Operation) -> bool: return True if isinstance(ir, LibraryCall): - if ir.function.name.lower() in [ + if ir.function.name and ir.function.name.lower() in [ "div", "safediv", ]: @@ -35,7 +40,7 @@ def is_multiplication(ir: Operation) -> bool: return True if isinstance(ir, LibraryCall): - if ir.function.name.lower() in [ + if ir.function.name and ir.function.name.lower() in [ "mul", "safemul", ]: @@ -58,7 +63,7 @@ def is_assert(node: Node) -> bool: # pylint: disable=too-many-branches def _explore( - to_explore: Set[Node], f_results: List[Node], divisions: DefaultDict[Any, Any] + to_explore: Set[Node], f_results: List[List[Node]], divisions: DefaultDict[LVALUE, List[Node]] ) -> None: explored = set() while to_explore: # pylint: disable=too-many-nested-blocks @@ -70,22 +75,22 @@ def _explore( equality_found = False # List of nodes related to one bug instance - node_results = [] + node_results: List[Node] = [] for ir in node.irs: if isinstance(ir, Assignment): if ir.rvalue in divisions: # Avoid dupplicate. We dont use set so we keep the order of the nodes - if node not in divisions[ir.rvalue]: - divisions[ir.lvalue] = divisions[ir.rvalue] + [node] + if node not in divisions[ir.rvalue]: # type: ignore + divisions[ir.lvalue] = divisions[ir.rvalue] + [node] # type: ignore else: - divisions[ir.lvalue] = divisions[ir.rvalue] + divisions[ir.lvalue] = divisions[ir.rvalue] # type: ignore if is_division(ir): - divisions[ir.lvalue] = [node] + divisions[ir.lvalue] = [node] # type: ignore if is_multiplication(ir): - mul_arguments = ir.read if isinstance(ir, Binary) else ir.arguments + mul_arguments = ir.read if isinstance(ir, Binary) else ir.arguments # type: ignore nodes = [] for r in mul_arguments: if not isinstance(r, Constant) and (r in divisions): @@ -125,7 +130,7 @@ def detect_divide_before_multiply( # List of tuple (function -> list(list(nodes))) # Each list(nodes) of the list is one bug instances # Each node in the list(nodes) is involved in the bug - results = [] + results: List[Tuple[FunctionContract, List[Node]]] = [] # Loop for each function and modifier. for function in contract.functions_declared: @@ -134,11 +139,11 @@ def detect_divide_before_multiply( # List of list(nodes) # Each list(nodes) is one bug instances - f_results = [] + f_results: List[List[Node]] = [] # lvalue -> node # track all the division results (and the assignment of the division results) - divisions = defaultdict(list) + divisions: DefaultDict[LVALUE, List[Node]] = defaultdict(list) _explore({function.entry_point}, f_results, divisions) @@ -190,7 +195,7 @@ In general, it's usually a good idea to re-arrange arithmetic to perform multipl if divisions_before_multiplications: for (func, nodes) in divisions_before_multiplications: - info = [ + info: DETECTOR_INFO = [ func, " performs a multiplication on the result of a division:\n", ] diff --git a/slither/detectors/statements/mapping_deletion.py b/slither/detectors/statements/mapping_deletion.py index 59882cc96..4cdac7240 100644 --- a/slither/detectors/statements/mapping_deletion.py +++ b/slither/detectors/statements/mapping_deletion.py @@ -8,7 +8,11 @@ from slither.core.declarations import Structure from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract from slither.core.solidity_types import MappingType, UserDefinedType -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Delete from slither.utils.output import Output @@ -83,7 +87,7 @@ The mapping `balances` is never deleted, so `remove` does not work as intended." for c in self.contracts: mapping = MappingDeletionDetection.detect_mapping_deletion(c) for (func, struct, node) in mapping: - info = [func, " deletes ", struct, " which contains a mapping:\n"] + info: DETECTOR_INFO = [func, " deletes ", struct, " which contains a mapping:\n"] info += ["\t-", node, "\n"] res = self.generate_result(info) diff --git a/slither/detectors/statements/msg_value_in_loop.py b/slither/detectors/statements/msg_value_in_loop.py index bfd541201..55bd9bfc2 100644 --- a/slither/detectors/statements/msg_value_in_loop.py +++ b/slither/detectors/statements/msg_value_in_loop.py @@ -1,6 +1,10 @@ from typing import List, Optional from slither.core.cfg.node import NodeType, Node -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import InternalCall from slither.core.declarations import SolidityVariableComposed, Contract from slither.utils.output import Output @@ -86,7 +90,7 @@ Track msg.value through a local variable and decrease its amount on every iterat for node in values: func = node.function - info = [func, " use msg.value in a loop: ", node, "\n"] + info: DETECTOR_INFO = [func, " use msg.value in a loop: ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/redundant_statements.py b/slither/detectors/statements/redundant_statements.py index 7e7223134..cebaecebe 100644 --- a/slither/detectors/statements/redundant_statements.py +++ b/slither/detectors/statements/redundant_statements.py @@ -7,7 +7,11 @@ from slither.core.cfg.node import Node, NodeType from slither.core.declarations.contract import Contract from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression from slither.core.expressions.identifier import Identifier -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -87,7 +91,13 @@ Each commented line references types/identifiers, but performs no action with th if redundant_statements: for redundant_statement in redundant_statements: - info = ['Redundant expression "', redundant_statement, '" in', contract, "\n"] + info: DETECTOR_INFO = [ + 'Redundant expression "', + redundant_statement, + '" in', + contract, + "\n", + ] json = self.generate_result(info) results.append(json) diff --git a/slither/detectors/statements/too_many_digits.py b/slither/detectors/statements/too_many_digits.py index 239efa4be..a5e09a34c 100644 --- a/slither/detectors/statements/too_many_digits.py +++ b/slither/detectors/statements/too_many_digits.py @@ -7,7 +7,11 @@ from typing import List from slither.core.cfg.node import Node from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.variables import Constant from slither.utils.output import Output @@ -88,9 +92,9 @@ Use: # iterate over all the nodes ret = self._detect_too_many_digits(f) if ret: - func_info = [f, " uses literals with too many digits:"] + func_info: DETECTOR_INFO = [f, " uses literals with too many digits:"] for node in ret: - node_info = func_info + ["\n\t- ", node, "\n"] + node_info: DETECTOR_INFO = func_info + ["\n\t- ", node, "\n"] # Add the result in result res = self.generate_result(node_info) diff --git a/slither/detectors/statements/tx_origin.py b/slither/detectors/statements/tx_origin.py index 34f8173d5..49bf6006d 100644 --- a/slither/detectors/statements/tx_origin.py +++ b/slither/detectors/statements/tx_origin.py @@ -6,7 +6,11 @@ from typing import List, Tuple from slither.core.cfg.node import Node from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -80,7 +84,7 @@ Bob is the owner of `TxOrigin`. Bob calls Eve's contract. Eve's contract calls ` for func, nodes in values: for node in nodes: - info = [func, " uses tx.origin for authorization: ", node, "\n"] + info: DETECTOR_INFO = [func, " uses tx.origin for authorization: ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/type_based_tautology.py b/slither/detectors/statements/type_based_tautology.py index 9edb1f53e..2e0fc8480 100644 --- a/slither/detectors/statements/type_based_tautology.py +++ b/slither/detectors/statements/type_based_tautology.py @@ -17,10 +17,9 @@ def typeRange(t: str) -> Tuple[int, int]: bits = int(t.split("int")[1]) if t in Uint: return 0, (2**bits) - 1 - if t in Int: - v = (2 ** (bits - 1)) - 1 - return -v, v - return None + assert t in Int + v = (2 ** (bits - 1)) - 1 + return -v, v def _detect_tautology_or_contradiction(low: int, high: int, cval: int, op: BinaryType) -> bool: diff --git a/slither/detectors/statements/unary.py b/slither/detectors/statements/unary.py index 5bb8d9c3c..152a39736 100644 --- a/slither/detectors/statements/unary.py +++ b/slither/detectors/statements/unary.py @@ -5,7 +5,11 @@ from typing import List from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.unary_operation import UnaryOperationType, UnaryOperation -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output from slither.visitors.expression.expression import ExpressionVisitor @@ -74,7 +78,10 @@ contract Bug{ variable.expression and InvalidUnaryStateVariableDetector(variable.expression).result() ): - info = [variable, f" uses an dangerous unary operator: {variable.expression}\n"] + info: DETECTOR_INFO = [ + variable, + f" uses an dangerous unary operator: {variable.expression}\n", + ] json = self.generate_result(info) results.append(json) diff --git a/slither/detectors/statements/unprotected_upgradeable.py b/slither/detectors/statements/unprotected_upgradeable.py index 1adf49540..30e6300f1 100644 --- a/slither/detectors/statements/unprotected_upgradeable.py +++ b/slither/detectors/statements/unprotected_upgradeable.py @@ -2,7 +2,11 @@ from typing import List from slither.core.declarations import SolidityFunction, Function from slither.core.declarations.contract import Contract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import LowLevelCall, SolidityCall from slither.utils.output import Output @@ -110,17 +114,15 @@ Buggy is an upgradeable contract. Anyone can call initialize on the logic contra item for sublist in vars_init_in_constructors_ for item in sublist ] if vars_init and (set(vars_init) - set(vars_init_in_constructors)): - info = ( - [ - contract, - " is an upgradeable contract that does not protect its initialize functions: ", - ] - + initialize_functions - + [ - ". Anyone can delete the contract with: ", - ] - + functions_that_can_destroy - ) + info: DETECTOR_INFO = [ + contract, + " is an upgradeable contract that does not protect its initialize functions: ", + ] + info += initialize_functions + info += [ + ". Anyone can delete the contract with: ", + ] + info += functions_that_can_destroy res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/write_after_write.py b/slither/detectors/statements/write_after_write.py index 5b2e29925..40a82d3ff 100644 --- a/slither/detectors/statements/write_after_write.py +++ b/slither/detectors/statements/write_after_write.py @@ -4,7 +4,11 @@ from slither.core.cfg.node import Node, NodeType from slither.core.solidity_types import ElementaryType from slither.core.variables.state_variable import StateVariable from slither.core.variables.variable import Variable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import ( OperationWithLValue, HighLevelCall, @@ -128,10 +132,17 @@ class WriteAfterWrite(AbstractDetector): for contract in self.compilation_unit.contracts_derived: for function in contract.functions: if function.entry_point: - ret = [] + ret: List[Tuple[Variable, Node, Node]] = [] _detect_write_after_write(function.entry_point, set(), {}, ret) for var, node1, node2 in ret: - info = [var, " is written in both\n\t", node1, "\n\t", node2, "\n"] + info: DETECTOR_INFO = [ + var, + " is written in both\n\t", + node1, + "\n\t", + node2, + "\n", + ] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/variables/function_init_state_variables.py b/slither/detectors/variables/function_init_state_variables.py index e35cfe351..e440a4f96 100644 --- a/slither/detectors/variables/function_init_state_variables.py +++ b/slither/detectors/variables/function_init_state_variables.py @@ -6,7 +6,11 @@ from typing import List from slither.core.declarations.contract import Contract from slither.core.declarations.function import Function from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output from slither.visitors.expression.export_values import ExportValues @@ -104,7 +108,7 @@ Special care must be taken when initializing state variables from an immediate f state_variables = detect_function_init_state_vars(contract) if state_variables: for state_variable in state_variables: - info = [ + info: DETECTOR_INFO = [ state_variable, " is set pre-construction with a non-constant function or state variable:\n", ] diff --git a/slither/detectors/variables/predeclaration_usage_local.py b/slither/detectors/variables/predeclaration_usage_local.py index 2ba539a91..97217d2bb 100644 --- a/slither/detectors/variables/predeclaration_usage_local.py +++ b/slither/detectors/variables/predeclaration_usage_local.py @@ -7,7 +7,11 @@ from slither.core.cfg.node import Node from slither.core.declarations import Function from slither.core.declarations.contract import Contract from slither.core.variables.local_variable import LocalVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -148,7 +152,7 @@ Additionally, the for-loop uses the variable `max`, which is declared in a previ predeclared_usage_node, predeclared_usage_local_variable, ) in predeclared_usage_nodes: - info = [ + info: DETECTOR_INFO = [ "Variable '", predeclared_usage_local_variable, "' in ", diff --git a/slither/detectors/variables/similar_variables.py b/slither/detectors/variables/similar_variables.py index d0a15aaab..465e1ce01 100644 --- a/slither/detectors/variables/similar_variables.py +++ b/slither/detectors/variables/similar_variables.py @@ -7,7 +7,11 @@ from typing import List, Set, Tuple from slither.core.declarations.contract import Contract from slither.core.variables.local_variable import LocalVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -86,7 +90,13 @@ class SimilarVarsDetection(AbstractDetector): for (v1, v2) in sorted(allVars, key=lambda x: (x[0].name, x[1].name)): v_left = v1 if v1.name < v2.name else v2 v_right = v2 if v_left == v1 else v1 - info = ["Variable ", v_left, " is too similar to ", v_right, "\n"] + info: DETECTOR_INFO = [ + "Variable ", + v_left, + " is too similar to ", + v_right, + "\n", + ] json = self.generate_result(info) results.append(json) return results diff --git a/slither/detectors/variables/uninitialized_state_variables.py b/slither/detectors/variables/uninitialized_state_variables.py index 0fbb73b5d..13cf11052 100644 --- a/slither/detectors/variables/uninitialized_state_variables.py +++ b/slither/detectors/variables/uninitialized_state_variables.py @@ -14,7 +14,11 @@ from slither.core.declarations import Function from slither.core.declarations.contract import Contract from slither.core.variables import Variable from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import InternalCall, LibraryCall from slither.slithir.variables import ReferenceVariable from slither.utils.output import Output @@ -140,7 +144,7 @@ Initialize all the variables. If a variable is meant to be initialized to zero, ret = self._detect_uninitialized(c) for variable, functions in ret: - info = [variable, " is never initialized. It is used in:\n"] + info: DETECTOR_INFO = [variable, " is never initialized. It is used in:\n"] for f in functions: info += ["\t- ", f, "\n"] diff --git a/slither/detectors/variables/unused_state_variables.py b/slither/detectors/variables/unused_state_variables.py index d542f67d3..afb4e3ac5 100644 --- a/slither/detectors/variables/unused_state_variables.py +++ b/slither/detectors/variables/unused_state_variables.py @@ -1,13 +1,19 @@ """ Module detecting unused state variables """ -from typing import List, Optional +from typing import List, Optional, Dict from slither.core.compilation_unit import SlitherCompilationUnit +from slither.core.declarations import Function from slither.core.declarations.contract import Contract from slither.core.solidity_types import ArrayType +from slither.core.variables import Variable from slither.core.variables.state_variable import StateVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.formatters.variables.unused_state_variables import custom_format from slither.utils.output import Output from slither.visitors.expression.export_values import ExportValues @@ -18,14 +24,19 @@ def detect_unused(contract: Contract) -> Optional[List[StateVariable]]: return None # Get all the variables read in all the functions and modifiers - all_functions = contract.all_functions_called + contract.modifiers + all_functions = [ + f + for f in contract.all_functions_called + list(contract.modifiers) + if isinstance(f, Function) + ] variables_used = [x.state_variables_read for x in all_functions] variables_used += [ x.state_variables_written for x in all_functions if not x.is_constructor_variables ] - array_candidates = [x.variables for x in all_functions] - array_candidates = [i for sl in array_candidates for i in sl] + contract.state_variables + array_candidates_ = [x.variables for x in all_functions] + array_candidates: List[Variable] = [i for sl in array_candidates_ for i in sl] + array_candidates += contract.state_variables array_candidates = [ x.type.length for x in array_candidates if isinstance(x.type, ArrayType) and x.type.length ] @@ -65,12 +76,12 @@ class UnusedStateVars(AbstractDetector): unusedVars = detect_unused(c) if unusedVars: for var in unusedVars: - info = [var, " is never used in ", c, "\n"] + info: DETECTOR_INFO = [var, " is never used in ", c, "\n"] json = self.generate_result(info) results.append(json) return results @staticmethod - def _format(compilation_unit: SlitherCompilationUnit, result): + def _format(compilation_unit: SlitherCompilationUnit, result: Dict) -> None: custom_format(compilation_unit, result) diff --git a/slither/detectors/variables/var_read_using_this.py b/slither/detectors/variables/var_read_using_this.py index b224f8c17..a2b93a7d8 100644 --- a/slither/detectors/variables/var_read_using_this.py +++ b/slither/detectors/variables/var_read_using_this.py @@ -2,7 +2,11 @@ from typing import List from slither.core.cfg.node import Node from slither.core.declarations import Function, SolidityVariable -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations.high_level_call import HighLevelCall from slither.utils.output import Output @@ -35,7 +39,7 @@ contract C { for c in self.contracts: for func in c.functions: for node in self._detect_var_read_using_this(func): - info = [ + info: DETECTOR_INFO = [ "The function ", func, " reads ", diff --git a/slither/formatters/attributes/const_functions.py b/slither/formatters/attributes/const_functions.py index 33588af74..310abe0b9 100644 --- a/slither/formatters/attributes/const_functions.py +++ b/slither/formatters/attributes/const_functions.py @@ -1,11 +1,12 @@ import re +from typing import Dict from slither.core.compilation_unit import SlitherCompilationUnit from slither.formatters.exceptions import FormatError from slither.formatters.utils.patches import create_patch -def custom_format(compilation_unit: SlitherCompilationUnit, result): +def custom_format(compilation_unit: SlitherCompilationUnit, result: Dict) -> None: for file_scope in compilation_unit.scopes.values(): elements = result["elements"] for element in elements: diff --git a/slither/formatters/attributes/constant_pragma.py b/slither/formatters/attributes/constant_pragma.py index 251dd07ae..108a8fa08 100644 --- a/slither/formatters/attributes/constant_pragma.py +++ b/slither/formatters/attributes/constant_pragma.py @@ -1,4 +1,7 @@ import re +from typing import Dict + +from slither.core.compilation_unit import SlitherCompilationUnit from slither.formatters.exceptions import FormatImpossible from slither.formatters.utils.patches import create_patch @@ -16,7 +19,7 @@ REPLACEMENT_VERSIONS = ["^0.4.25", "^0.5.3"] PATTERN = re.compile(r"(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)") -def custom_format(slither, result): +def custom_format(slither: SlitherCompilationUnit, result: Dict) -> None: elements = result["elements"] versions_used = [] for element in elements: diff --git a/slither/formatters/naming_convention/naming_convention.py b/slither/formatters/naming_convention/naming_convention.py index 76974296f..4aadad072 100644 --- a/slither/formatters/naming_convention/naming_convention.py +++ b/slither/formatters/naming_convention/naming_convention.py @@ -1,9 +1,10 @@ import re import logging -from typing import List +from typing import List, Set, Dict, Union, Optional, Callable, Type, Sequence from slither.core.compilation_unit import SlitherCompilationUnit +from slither.core.variables import Variable from slither.slithir.operations import ( Send, Transfer, @@ -14,7 +15,7 @@ from slither.slithir.operations import ( InternalDynamicCall, Operation, ) -from slither.core.declarations import Modifier +from slither.core.declarations import Modifier, Event from slither.core.solidity_types import UserDefinedType, MappingType from slither.core.declarations import Enum, Contract, Structure, Function from slither.core.solidity_types.elementary_type import ElementaryTypeName @@ -29,7 +30,7 @@ logger = logging.getLogger("Slither.Format") # pylint: disable=anomalous-backslash-in-string -def custom_format(compilation_unit: SlitherCompilationUnit, result): +def custom_format(compilation_unit: SlitherCompilationUnit, result: Dict) -> None: elements = result["elements"] for element in elements: target = element["additional_fields"]["target"] @@ -129,24 +130,24 @@ SOLIDITY_KEYWORDS += [ SOLIDITY_KEYWORDS += ElementaryTypeName -def _name_already_use(slither, name): +def _name_already_use(slither: SlitherCompilationUnit, name: str) -> bool: # Do not convert to a name used somewhere else if not KEY in slither.context: - all_names = set() + all_names: Set[str] = set() for contract in slither.contracts_derived: all_names = all_names.union({st.name for st in contract.structures}) all_names = all_names.union({f.name for f in contract.functions_and_modifiers}) all_names = all_names.union({e.name for e in contract.enums}) - all_names = all_names.union({s.name for s in contract.state_variables}) + all_names = all_names.union({s.name for s in contract.state_variables if s.name}) for function in contract.functions: - all_names = all_names.union({v.name for v in function.variables}) + all_names = all_names.union({v.name for v in function.variables if v.name}) slither.context[KEY] = all_names return name in slither.context[KEY] -def _convert_CapWords(original_name, slither): +def _convert_CapWords(original_name: str, slither: SlitherCompilationUnit) -> str: name = original_name.capitalize() while "_" in name: @@ -162,10 +163,13 @@ def _convert_CapWords(original_name, slither): return name -def _convert_mixedCase(original_name, compilation_unit: SlitherCompilationUnit): - name = original_name - if isinstance(name, bytes): - name = name.decode("utf8") +def _convert_mixedCase( + original_name: Union[str, bytes], compilation_unit: SlitherCompilationUnit +) -> str: + if isinstance(original_name, bytes): + name = original_name.decode("utf8") + else: + name = original_name while "_" in name: offset = name.find("_") @@ -174,13 +178,15 @@ def _convert_mixedCase(original_name, compilation_unit: SlitherCompilationUnit): name = name[0].lower() + name[1:] if _name_already_use(compilation_unit, name): - raise FormatImpossible(f"{original_name} cannot be converted to {name} (already used)") + raise FormatImpossible(f"{original_name} cannot be converted to {name} (already used)") # type: ignore if name in SOLIDITY_KEYWORDS: - raise FormatImpossible(f"{original_name} cannot be converted to {name} (Solidity keyword)") + raise FormatImpossible(f"{original_name} cannot be converted to {name} (Solidity keyword)") # type: ignore return name -def _convert_UPPER_CASE_WITH_UNDERSCORES(name, compilation_unit: SlitherCompilationUnit): +def _convert_UPPER_CASE_WITH_UNDERSCORES( + name: str, compilation_unit: SlitherCompilationUnit +) -> str: if _name_already_use(compilation_unit, name.upper()): raise FormatImpossible(f"{name} cannot be converted to {name.upper()} (already used)") if name.upper() in SOLIDITY_KEYWORDS: @@ -188,7 +194,10 @@ def _convert_UPPER_CASE_WITH_UNDERSCORES(name, compilation_unit: SlitherCompilat return name.upper() -conventions = { +TARGET_TYPE = Union[Contract, Variable, Function] +CONVENTION_F_TYPE = Callable[[str, SlitherCompilationUnit], str] + +conventions: Dict[str, CONVENTION_F_TYPE] = { "CapWords": _convert_CapWords, "mixedCase": _convert_mixedCase, "UPPER_CASE_WITH_UNDERSCORES": _convert_UPPER_CASE_WITH_UNDERSCORES, @@ -203,7 +212,9 @@ conventions = { ################################################################################### -def _get_from_contract(compilation_unit: SlitherCompilationUnit, element, name, getter): +def _get_from_contract( + compilation_unit: SlitherCompilationUnit, element: Dict, name: str, getter: str +) -> TARGET_TYPE: scope = compilation_unit.get_scope(element["source_mapping"]["filename_absolute"]) contract_name = element["type_specific_fields"]["parent"]["name"] contract = scope.get_contract_from_name(contract_name) @@ -218,9 +229,13 @@ def _get_from_contract(compilation_unit: SlitherCompilationUnit, element, name, ################################################################################### -def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target): +def _patch( + compilation_unit: SlitherCompilationUnit, result: Dict, element: Dict, _target: str +) -> None: scope = compilation_unit.get_scope(element["source_mapping"]["filename_absolute"]) + target: Optional[TARGET_TYPE] = None + if _target == "contract": target = scope.get_contract_from_name(element["name"]) @@ -257,7 +272,9 @@ def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target): ] param_name = element["name"] contract = scope.get_contract_from_name(contract_name) + assert contract function = contract.get_function_from_full_name(function_sig) + assert function target = function.get_local_variable_from_name(param_name) elif _target in ["variable", "variable_constant"]: @@ -271,7 +288,9 @@ def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target): ] var_name = element["name"] contract = scope.get_contract_from_name(contract_name) + assert contract function = contract.get_function_from_full_name(function_sig) + assert function target = function.get_local_variable_from_name(var_name) # State variable else: @@ -287,6 +306,7 @@ def _patch(compilation_unit: SlitherCompilationUnit, result, element, _target): else: raise FormatError("Unknown naming convention! " + _target) + assert target _explore( compilation_unit, result, target, conventions[element["additional_fields"]["convention"]] ) @@ -310,7 +330,7 @@ RE_MAPPING = ( ) -def _is_var_declaration(slither, filename, start): +def _is_var_declaration(slither: SlitherCompilationUnit, filename: str, start: int) -> bool: """ Detect usage of 'var ' for Solidity < 0.5 :param slither: @@ -319,12 +339,19 @@ def _is_var_declaration(slither, filename, start): :return: """ v = "var " - return slither.source_code[filename][start : start + len(v)] == v + return slither.core.source_code[filename][start : start + len(v)] == v def _explore_type( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches - slither, result, target, convert, custom_type, filename_source_code, start, end -): + slither: SlitherCompilationUnit, + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, + custom_type: Optional[Union[Type, List[Type]]], + filename_source_code: str, + start: int, + end: int, +) -> None: if isinstance(custom_type, UserDefinedType): # Patch type based on contract/enum if isinstance(custom_type.type, (Enum, Contract)): @@ -358,7 +385,7 @@ def _explore_type( # pylint: disable=too-many-arguments,too-many-locals,too-man # Structure contain a list of elements, that might need patching # .elems return a list of VariableStructure _explore_variables_declaration( - slither, custom_type.type.elems.values(), result, target, convert + slither, list(custom_type.type.elems.values()), result, target, convert ) if isinstance(custom_type, MappingType): @@ -377,7 +404,7 @@ def _explore_type( # pylint: disable=too-many-arguments,too-many-locals,too-man full_txt_start = start full_txt_end = end - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] re_match = re.match(RE_MAPPING, full_txt) @@ -417,14 +444,19 @@ def _explore_type( # pylint: disable=too-many-arguments,too-many-locals,too-man def _explore_variables_declaration( # pylint: disable=too-many-arguments,too-many-locals,too-many-nested-blocks - slither, variables, result, target, convert, patch_comment=False -): + slither: SlitherCompilationUnit, + variables: Sequence[Variable], + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, + patch_comment: bool = False, +) -> None: for variable in variables: # First explore the type of the variable filename_source_code = variable.source_mapping.filename.absolute full_txt_start = variable.source_mapping.start full_txt_end = full_txt_start + variable.source_mapping.length - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] @@ -442,6 +474,8 @@ def _explore_variables_declaration( # pylint: disable=too-many-arguments,too-ma # If the variable is the target if variable == target: old_str = variable.name + if old_str is None: + old_str = "" new_str = convert(old_str, slither) loc_start = full_txt_start + full_txt.find(old_str.encode("utf8")) @@ -458,10 +492,10 @@ def _explore_variables_declaration( # pylint: disable=too-many-arguments,too-ma idx = len(func.parameters) - func.parameters.index(variable) + 1 first_line = end_line - idx - 2 - potential_comments = slither.source_code[filename_source_code].encode( + potential_comments_ = slither.core.source_code[filename_source_code].encode( "utf8" ) - potential_comments = potential_comments.splitlines(keepends=True)[ + potential_comments = potential_comments_.splitlines(keepends=True)[ first_line : end_line - 1 ] @@ -491,10 +525,16 @@ def _explore_variables_declaration( # pylint: disable=too-many-arguments,too-ma idx_beginning += len(line) -def _explore_structures_declaration(slither, structures, result, target, convert): +def _explore_structures_declaration( + slither: SlitherCompilationUnit, + structures: Sequence[Structure], + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: for st in structures: # Explore the variable declared within the structure (VariableStructure) - _explore_variables_declaration(slither, st.elems.values(), result, target, convert) + _explore_variables_declaration(slither, list(st.elems.values()), result, target, convert) # If the structure is the target if st == target: @@ -504,7 +544,7 @@ def _explore_structures_declaration(slither, structures, result, target, convert filename_source_code = st.source_mapping.filename.absolute full_txt_start = st.source_mapping.start full_txt_end = full_txt_start + st.source_mapping.length - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] @@ -517,7 +557,13 @@ def _explore_structures_declaration(slither, structures, result, target, convert create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) -def _explore_events_declaration(slither, events, result, target, convert): +def _explore_events_declaration( + slither: SlitherCompilationUnit, + events: Sequence[Event], + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: for event in events: # Explore the parameters _explore_variables_declaration(slither, event.elems, result, target, convert) @@ -535,7 +581,7 @@ def _explore_events_declaration(slither, events, result, target, convert): create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) -def get_ir_variables(ir): +def get_ir_variables(ir: Operation) -> List[Union[Variable, Function]]: all_vars = ir.read if isinstance(ir, (InternalCall, InternalDynamicCall, HighLevelCall)): @@ -553,9 +599,15 @@ def get_ir_variables(ir): return [v for v in all_vars if v] -def _explore_irs(slither, irs: List[Operation], result, target, convert): +def _explore_irs( + slither: SlitherCompilationUnit, + irs: List[Operation], + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: # pylint: disable=too-many-locals - if irs is None: + if not irs: return for ir in irs: for v in get_ir_variables(ir): @@ -568,7 +620,7 @@ def _explore_irs(slither, irs: List[Operation], result, target, convert): filename_source_code = source_mapping.filename.absolute full_txt_start = source_mapping.start full_txt_end = full_txt_start + source_mapping.length - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] @@ -600,7 +652,13 @@ def _explore_irs(slither, irs: List[Operation], result, target, convert): ) -def _explore_functions(slither, functions, result, target, convert): +def _explore_functions( + slither: SlitherCompilationUnit, + functions: List[Function], + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: for function in functions: _explore_variables_declaration(slither, function.variables, result, target, convert, True) _explore_irs(slither, function.all_slithir_operations(), result, target, convert) @@ -612,7 +670,7 @@ def _explore_functions(slither, functions, result, target, convert): filename_source_code = function.source_mapping.filename.absolute full_txt_start = function.source_mapping.start full_txt_end = full_txt_start + function.source_mapping.length - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] @@ -628,7 +686,13 @@ def _explore_functions(slither, functions, result, target, convert): create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) -def _explore_enums(slither, enums, result, target, convert): +def _explore_enums( + slither: SlitherCompilationUnit, + enums: Sequence[Enum], + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: for enum in enums: if enum == target: old_str = enum.name @@ -637,7 +701,7 @@ def _explore_enums(slither, enums, result, target, convert): filename_source_code = enum.source_mapping.filename.absolute full_txt_start = enum.source_mapping.start full_txt_end = full_txt_start + enum.source_mapping.length - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] @@ -650,7 +714,13 @@ def _explore_enums(slither, enums, result, target, convert): create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) -def _explore_contract(slither, contract, result, target, convert): +def _explore_contract( + slither: SlitherCompilationUnit, + contract: Contract, + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: _explore_variables_declaration(slither, contract.state_variables, result, target, convert) _explore_structures_declaration(slither, contract.structures, result, target, convert) _explore_functions(slither, contract.functions_and_modifiers, result, target, convert) @@ -660,7 +730,7 @@ def _explore_contract(slither, contract, result, target, convert): filename_source_code = contract.source_mapping.filename.absolute full_txt_start = contract.source_mapping.start full_txt_end = full_txt_start + contract.source_mapping.length - full_txt = slither.source_code[filename_source_code].encode("utf8")[ + full_txt = slither.core.source_code[filename_source_code].encode("utf8")[ full_txt_start:full_txt_end ] @@ -677,7 +747,12 @@ def _explore_contract(slither, contract, result, target, convert): create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) -def _explore(compilation_unit: SlitherCompilationUnit, result, target, convert): +def _explore( + compilation_unit: SlitherCompilationUnit, + result: Dict, + target: TARGET_TYPE, + convert: CONVENTION_F_TYPE, +) -> None: for contract in compilation_unit.contracts_derived: _explore_contract(compilation_unit, contract, result, target, convert) diff --git a/slither/formatters/variables/unused_state_variables.py b/slither/formatters/variables/unused_state_variables.py index 8e0852a17..90009c7f1 100644 --- a/slither/formatters/variables/unused_state_variables.py +++ b/slither/formatters/variables/unused_state_variables.py @@ -1,8 +1,10 @@ +from typing import Dict + from slither.core.compilation_unit import SlitherCompilationUnit from slither.formatters.utils.patches import create_patch -def custom_format(compilation_unit: SlitherCompilationUnit, result): +def custom_format(compilation_unit: SlitherCompilationUnit, result: Dict) -> None: elements = result["elements"] for element in elements: if element["type"] == "variable": @@ -14,7 +16,9 @@ def custom_format(compilation_unit: SlitherCompilationUnit, result): ) -def _patch(compilation_unit: SlitherCompilationUnit, result, in_file, modify_loc_start): +def _patch( + compilation_unit: SlitherCompilationUnit, result: Dict, in_file: str, modify_loc_start: int +) -> None: in_file_str = compilation_unit.core.source_code[in_file].encode("utf8") old_str_of_interest = in_file_str[modify_loc_start:] old_str = ( diff --git a/slither/slithir/operations/assignment.py b/slither/slithir/operations/assignment.py index 0ed5f70a4..5bedf2c85 100644 --- a/slither/slithir/operations/assignment.py +++ b/slither/slithir/operations/assignment.py @@ -1,20 +1,21 @@ import logging -from typing import List +from typing import List, Union from slither.core.declarations.function import Function +from slither.core.solidity_types import Type from slither.slithir.operations.lvalue import OperationWithLValue -from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue, RVALUE, LVALUE from slither.slithir.variables import TupleVariable, ReferenceVariable -from slither.core.source_mapping.source_mapping import SourceMapping -from slither.core.variables.variable import Variable - logger = logging.getLogger("AssignmentOperationIR") class Assignment(OperationWithLValue): def __init__( - self, left_variable: Variable, right_variable: SourceMapping, variable_return_type + self, + left_variable: LVALUE, + right_variable: Union[RVALUE, Function, TupleVariable], + variable_return_type: Type, ) -> None: assert is_valid_lvalue(left_variable) assert is_valid_rvalue(right_variable) or isinstance( @@ -22,30 +23,32 @@ class Assignment(OperationWithLValue): ) super().__init__() self._variables = [left_variable, right_variable] - self._lvalue = left_variable - self._rvalue = right_variable + self._lvalue: LVALUE = left_variable + self._rvalue: Union[RVALUE, Function, TupleVariable] = right_variable self._variable_return_type = variable_return_type @property - def variables(self): + def variables(self) -> List[Union[LVALUE, RVALUE, Function, TupleVariable]]: return list(self._variables) @property - def read(self) -> List[SourceMapping]: + def read(self) -> List[Union[RVALUE, Function, TupleVariable]]: return [self.rvalue] @property - def variable_return_type(self): + def variable_return_type(self) -> Type: return self._variable_return_type @property - def rvalue(self) -> SourceMapping: + def rvalue(self) -> Union[RVALUE, Function, TupleVariable]: return self._rvalue - def __str__(self): - if isinstance(self.lvalue, ReferenceVariable): - points = self.lvalue.points_to + def __str__(self) -> str: + lvalue = self.lvalue + assert lvalue + if lvalue and isinstance(lvalue, ReferenceVariable): + points = lvalue.points_to while isinstance(points, ReferenceVariable): points = points.points_to - return f"{self.lvalue} (->{points}) := {self.rvalue}({self.rvalue.type})" - return f"{self.lvalue}({self.lvalue.type}) := {self.rvalue}({self.rvalue.type})" + return f"{lvalue} (->{points}) := {self.rvalue}({self.rvalue.type})" + return f"{lvalue}({lvalue.type}) := {self.rvalue}({self.rvalue.type})" diff --git a/slither/slithir/operations/binary.py b/slither/slithir/operations/binary.py index ad65e3e75..42f05011d 100644 --- a/slither/slithir/operations/binary.py +++ b/slither/slithir/operations/binary.py @@ -1,17 +1,14 @@ import logging -from typing import List - from enum import Enum +from typing import List, Union from slither.core.declarations import Function from slither.core.solidity_types import ElementaryType +from slither.core.variables.variable import Variable from slither.slithir.exceptions import SlithIRError from slither.slithir.operations.lvalue import OperationWithLValue -from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue, LVALUE, RVALUE from slither.slithir.variables import ReferenceVariable -from slither.core.source_mapping.source_mapping import SourceMapping -from slither.core.variables.variable import Variable - logger = logging.getLogger("BinaryOperationIR") @@ -51,7 +48,7 @@ class BinaryType(Enum): ] @staticmethod - def get_type(operation_type): # pylint: disable=too-many-branches + def get_type(operation_type: str) -> "BinaryType": # pylint: disable=too-many-branches if operation_type == "**": return BinaryType.POWER if operation_type == "*": @@ -93,7 +90,7 @@ class BinaryType(Enum): raise SlithIRError(f"get_type: Unknown operation type {operation_type})") - def can_be_checked_for_overflow(self): + def can_be_checked_for_overflow(self) -> bool: return self in [ BinaryType.POWER, BinaryType.MULTIPLICATION, @@ -108,8 +105,8 @@ class Binary(OperationWithLValue): def __init__( self, result: Variable, - left_variable: SourceMapping, - right_variable: Variable, + left_variable: Union[LVALUE, Function], + right_variable: Union[RVALUE, Function], operation_type: BinaryType, ) -> None: assert is_valid_rvalue(left_variable) or isinstance(left_variable, Function) @@ -126,36 +123,38 @@ class Binary(OperationWithLValue): result.set_type(left_variable.type) @property - def read(self) -> List[SourceMapping]: + def read(self) -> List[Union[RVALUE, LVALUE, Function]]: return [self.variable_left, self.variable_right] @property - def get_variable(self): + def get_variable(self) -> List[Union[RVALUE, LVALUE, Function]]: return self._variables @property - def variable_left(self) -> SourceMapping: - return self._variables[0] + def variable_left(self) -> Union[LVALUE, Function]: + return self._variables[0] # type: ignore @property - def variable_right(self) -> Variable: - return self._variables[1] + def variable_right(self) -> Union[RVALUE, Function]: + return self._variables[1] # type: ignore @property def type(self) -> BinaryType: return self._type @property - def type_str(self): + def type_str(self) -> str: if self.node.scope.is_checked and self._type.can_be_checked_for_overflow(): - return "(c)" + self._type.value - return self._type.value - - def __str__(self): - if isinstance(self.lvalue, ReferenceVariable): - points = self.lvalue.points_to + return "(c)" + str(self._type.value) + return str(self._type.value) + + def __str__(self) -> str: + lvalue = self.lvalue + assert lvalue + if isinstance(lvalue, ReferenceVariable): + points = lvalue.points_to while isinstance(points, ReferenceVariable): points = points.points_to - return f"{str(self.lvalue)}(-> {points}) = {self.variable_left} {self.type_str} {self.variable_right}" + return f"{str(lvalue)}(-> {points}) = {self.variable_left} {self.type_str} {self.variable_right}" - return f"{str(self.lvalue)}({self.lvalue.type}) = {self.variable_left} {self.type_str} {self.variable_right}" + return f"{str(lvalue)}({lvalue.type}) = {self.variable_left} {self.type_str} {self.variable_right}" diff --git a/slither/slithir/operations/internal_call.py b/slither/slithir/operations/internal_call.py index 395c68846..1983b885f 100644 --- a/slither/slithir/operations/internal_call.py +++ b/slither/slithir/operations/internal_call.py @@ -24,7 +24,7 @@ class InternalCall(Call, OperationWithLValue): # pylint: disable=too-many-insta super().__init__() self._contract_name = "" if isinstance(function, Function): - self._function = function + self._function: Optional[Function] = function self._function_name = function.name if isinstance(function, FunctionContract): self._contract_name = function.contract_declarer.name @@ -45,7 +45,7 @@ class InternalCall(Call, OperationWithLValue): # pylint: disable=too-many-insta return list(self._unroll(self.arguments)) @property - def function(self): + def function(self) -> Optional[Function]: return self._function @function.setter diff --git a/slither/slithir/tmp_operations/argument.py b/slither/slithir/tmp_operations/argument.py index 25ea5d019..638c5dcb4 100644 --- a/slither/slithir/tmp_operations/argument.py +++ b/slither/slithir/tmp_operations/argument.py @@ -1,4 +1,7 @@ from enum import Enum +from typing import Optional, List + +from slither.core.expressions.expression import Expression from slither.slithir.operations.operation import Operation @@ -10,26 +13,26 @@ class ArgumentType(Enum): class Argument(Operation): - def __init__(self, argument) -> None: + def __init__(self, argument: Expression) -> None: super().__init__() self._argument = argument self._type = ArgumentType.CALL - self._callid = None + self._callid: Optional[str] = None @property - def argument(self): + def argument(self) -> Expression: return self._argument @property - def call_id(self): + def call_id(self) -> Optional[str]: return self._callid @call_id.setter - def call_id(self, c): + def call_id(self, c: str) -> None: self._callid = c @property - def read(self): + def read(self) -> List[Expression]: return [self.argument] def set_type(self, t: ArgumentType) -> None: @@ -39,7 +42,7 @@ class Argument(Operation): def get_type(self) -> ArgumentType: return self._type - def __str__(self): + def __str__(self) -> str: call_id = "none" if self.call_id: call_id = f"(id ({self.call_id}))" From 9259d53cd9b0509c2e638d82a5a4a12447f4bc51 Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Mon, 20 Feb 2023 10:30:30 +0100 Subject: [PATCH 03/34] Remove ExpressionTyped --- .../core/expressions/assignment_operation.py | 3 +-- slither/core/expressions/binary_operation.py | 3 +-- slither/core/expressions/expression_typed.py | 20 ------------------- slither/core/expressions/identifier.py | 16 ++++++++++++--- slither/core/expressions/index_access.py | 6 ++---- slither/core/expressions/member_access.py | 3 +-- slither/core/expressions/type_conversion.py | 11 ++++++++-- slither/core/expressions/unary_operation.py | 3 +-- 8 files changed, 28 insertions(+), 37 deletions(-) delete mode 100644 slither/core/expressions/expression_typed.py diff --git a/slither/core/expressions/assignment_operation.py b/slither/core/expressions/assignment_operation.py index 22aba57fb..59057e312 100644 --- a/slither/core/expressions/assignment_operation.py +++ b/slither/core/expressions/assignment_operation.py @@ -2,7 +2,6 @@ import logging from enum import Enum from typing import Optional, TYPE_CHECKING, List -from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.exceptions import SlitherCoreError @@ -78,7 +77,7 @@ class AssignmentOperationType(Enum): raise SlitherCoreError(f"str: Unknown operation type {self})") -class AssignmentOperation(ExpressionTyped): +class AssignmentOperation(Expression): def __init__( self, left_expression: Expression, diff --git a/slither/core/expressions/binary_operation.py b/slither/core/expressions/binary_operation.py index a3d435075..a395d07cf 100644 --- a/slither/core/expressions/binary_operation.py +++ b/slither/core/expressions/binary_operation.py @@ -2,7 +2,6 @@ import logging from enum import Enum from typing import List -from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.exceptions import SlitherCoreError @@ -148,7 +147,7 @@ class BinaryOperationType(Enum): raise SlitherCoreError(f"str: Unknown operation type {self})") -class BinaryOperation(ExpressionTyped): +class BinaryOperation(Expression): def __init__( self, left_expression: Expression, diff --git a/slither/core/expressions/expression_typed.py b/slither/core/expressions/expression_typed.py deleted file mode 100644 index 2bf3fe39d..000000000 --- a/slither/core/expressions/expression_typed.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Optional, TYPE_CHECKING - -from slither.core.expressions.expression import Expression - -if TYPE_CHECKING: - from slither.core.solidity_types.type import Type - - -class ExpressionTyped(Expression): # pylint: disable=too-few-public-methods - def __init__(self) -> None: - super().__init__() - self._type: Optional["Type"] = None - - @property - def type(self) -> Optional["Type"]: - return self._type - - @type.setter - def type(self, new_type: "Type"): - self._type = new_type diff --git a/slither/core/expressions/identifier.py b/slither/core/expressions/identifier.py index 0b10c5615..58a1174af 100644 --- a/slither/core/expressions/identifier.py +++ b/slither/core/expressions/identifier.py @@ -1,15 +1,25 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional -from slither.core.expressions.expression_typed import ExpressionTyped +from slither.core.expressions.expression import Expression if TYPE_CHECKING: from slither.core.variables.variable import Variable + from slither.core.solidity_types.type import Type -class Identifier(ExpressionTyped): +class Identifier(Expression): def __init__(self, value) -> None: super().__init__() self._value: "Variable" = value + self._type: Optional["Type"] = None + + @property + def type(self) -> Optional["Type"]: + return self._type + + @type.setter + def type(self, new_type: "Type") -> None: + self._type = new_type @property def value(self) -> "Variable": diff --git a/slither/core/expressions/index_access.py b/slither/core/expressions/index_access.py index 4f96a56d6..f8e630a6e 100644 --- a/slither/core/expressions/index_access.py +++ b/slither/core/expressions/index_access.py @@ -1,16 +1,14 @@ from typing import Union, List, TYPE_CHECKING -from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.identifier import Identifier from slither.core.expressions.literal import Literal - +from slither.core.expressions.expression import Expression if TYPE_CHECKING: - from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type -class IndexAccess(ExpressionTyped): +class IndexAccess(Expression): def __init__( self, left_expression: Union["IndexAccess", Identifier], diff --git a/slither/core/expressions/member_access.py b/slither/core/expressions/member_access.py index 36d6818b2..e24024318 100644 --- a/slither/core/expressions/member_access.py +++ b/slither/core/expressions/member_access.py @@ -1,10 +1,9 @@ from slither.core.expressions.expression import Expression -from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.solidity_types.type import Type -class MemberAccess(ExpressionTyped): +class MemberAccess(Expression): def __init__(self, member_name: str, member_type: str, expression: Expression) -> None: # assert isinstance(member_type, Type) # TODO member_type is not always a Type diff --git a/slither/core/expressions/type_conversion.py b/slither/core/expressions/type_conversion.py index b9cd6879e..2acc8bd52 100644 --- a/slither/core/expressions/type_conversion.py +++ b/slither/core/expressions/type_conversion.py @@ -1,6 +1,5 @@ from typing import Union, TYPE_CHECKING -from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type @@ -14,7 +13,7 @@ if TYPE_CHECKING: from slither.core.solidity_types.user_defined_type import UserDefinedType -class TypeConversion(ExpressionTyped): +class TypeConversion(Expression): def __init__( self, expression: Union[ @@ -28,6 +27,14 @@ class TypeConversion(ExpressionTyped): self._expression: Expression = expression self._type: Type = expression_type + @property + def type(self) -> Type: + return self._type + + @type.setter + def type(self, new_type: Type) -> None: + self._type = new_type + @property def expression(self) -> Expression: return self._expression diff --git a/slither/core/expressions/unary_operation.py b/slither/core/expressions/unary_operation.py index a04c57591..657224927 100644 --- a/slither/core/expressions/unary_operation.py +++ b/slither/core/expressions/unary_operation.py @@ -2,7 +2,6 @@ import logging from typing import Union from enum import Enum -from slither.core.expressions.expression_typed import ExpressionTyped from slither.core.expressions.expression import Expression from slither.core.exceptions import SlitherCoreError from slither.core.expressions.identifier import Identifier @@ -91,7 +90,7 @@ class UnaryOperationType(Enum): raise SlitherCoreError(f"is_prefix: Unknown operation type {operation_type}") -class UnaryOperation(ExpressionTyped): +class UnaryOperation(Expression): def __init__( self, expression: Union[Literal, Identifier, IndexAccess, TupleExpression], From 10109fc553d7a6fc08c0bfcb23a9d25c3abb106f Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Mon, 20 Feb 2023 11:28:54 +0100 Subject: [PATCH 04/34] Remove core.children --- slither/core/cfg/node.py | 11 ++++-- slither/core/children/child_event.py | 17 --------- slither/core/children/child_expression.py | 18 ---------- slither/core/children/child_function.py | 17 --------- slither/core/children/child_inheritance.py | 17 --------- slither/core/children/child_node.py | 31 ---------------- slither/core/children/child_structure.py | 17 --------- .../contract_level.py} | 13 +++++-- .../declarations/custom_error_contract.py | 4 +-- slither/core/declarations/enum_contract.py | 4 +-- slither/core/declarations/event.py | 4 +-- .../core/declarations/function_contract.py | 29 ++++++++++++--- .../core/declarations/structure_contract.py | 4 +-- slither/core/declarations/top_level.py | 6 +++- slither/core/slither_core.py | 10 ++++-- slither/core/solidity_types/type_alias.py | 4 +-- slither/core/variables/event_variable.py | 5 ++- slither/core/variables/local_variable.py | 17 +++++++-- slither/core/variables/state_variable.py | 4 +-- slither/core/variables/structure_variable.py | 19 ++++++++-- .../erc/incorrect_erc721_interface.py | 4 ++- .../operations/missing_events_arithmetic.py | 4 ++- slither/detectors/statements/tx_origin.py | 4 ++- slither/slithir/convert.py | 2 +- .../operations/internal_dynamic_call.py | 4 +-- slither/slithir/operations/new_structure.py | 4 ++- slither/slithir/operations/operation.py | 35 ++++++++++++++++--- slither/slithir/operations/solidity_call.py | 3 +- slither/slithir/operations/type_conversion.py | 4 ++- slither/slithir/variables/reference.py | 7 ++-- slither/slithir/variables/temporary.py | 7 ++-- slither/slithir/variables/tuple.py | 7 ++-- slither/solc_parsing/declarations/contract.py | 8 ++++- slither/utils/output.py | 21 +++++++---- tests/test_ssa_generation.py | 2 +- 35 files changed, 187 insertions(+), 180 deletions(-) delete mode 100644 slither/core/children/child_event.py delete mode 100644 slither/core/children/child_expression.py delete mode 100644 slither/core/children/child_function.py delete mode 100644 slither/core/children/child_inheritance.py delete mode 100644 slither/core/children/child_node.py delete mode 100644 slither/core/children/child_structure.py rename slither/core/{children/child_contract.py => declarations/contract_level.py} (57%) diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index 7643b19b7..a740d41b9 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -5,7 +5,6 @@ from enum import Enum from typing import Optional, List, Set, Dict, Tuple, Union, TYPE_CHECKING from slither.all_exceptions import SlitherException -from slither.core.children.child_function import ChildFunction from slither.core.declarations import Contract, Function from slither.core.declarations.solidity_variables import ( SolidityVariable, @@ -106,7 +105,7 @@ class NodeType(Enum): # I am not sure why, but pylint reports a lot of "no-member" issue that are not real (Josselin) # pylint: disable=no-member -class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-methods +class Node(SourceMapping): # pylint: disable=too-many-public-methods """ Node class @@ -189,6 +188,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self.scope: Union["Scope", "Function"] = scope self.file_scope: "FileScope" = file_scope + self._function: Optional["Function"] = None ################################################################################### ################################################################################### @@ -224,6 +224,13 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return True return False + def set_function(self, function: "Function") -> None: + self._function = function + + @property + def function(self) -> "Function": + return self._function + # endregion ################################################################################### ################################################################################### diff --git a/slither/core/children/child_event.py b/slither/core/children/child_event.py deleted file mode 100644 index df91596e3..000000000 --- a/slither/core/children/child_event.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.declarations import Event - - -class ChildEvent: - def __init__(self) -> None: - super().__init__() - self._event = None - - def set_event(self, event: "Event"): - self._event = event - - @property - def event(self) -> "Event": - return self._event diff --git a/slither/core/children/child_expression.py b/slither/core/children/child_expression.py deleted file mode 100644 index 0064658c0..000000000 --- a/slither/core/children/child_expression.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import TYPE_CHECKING, Union - -if TYPE_CHECKING: - from slither.core.expressions.expression import Expression - from slither.slithir.operations import Operation - - -class ChildExpression: - def __init__(self) -> None: - super().__init__() - self._expression = None - - def set_expression(self, expression: Union["Expression", "Operation"]) -> None: - self._expression = expression - - @property - def expression(self) -> Union["Expression", "Operation"]: - return self._expression diff --git a/slither/core/children/child_function.py b/slither/core/children/child_function.py deleted file mode 100644 index 5367320ca..000000000 --- a/slither/core/children/child_function.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.declarations import Function - - -class ChildFunction: - def __init__(self) -> None: - super().__init__() - self._function = None - - def set_function(self, function: "Function") -> None: - self._function = function - - @property - def function(self) -> "Function": - return self._function diff --git a/slither/core/children/child_inheritance.py b/slither/core/children/child_inheritance.py deleted file mode 100644 index 30b32f6c1..000000000 --- a/slither/core/children/child_inheritance.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.declarations import Contract - - -class ChildInheritance: - def __init__(self) -> None: - super().__init__() - self._contract_declarer = None - - def set_contract_declarer(self, contract: "Contract") -> None: - self._contract_declarer = contract - - @property - def contract_declarer(self) -> "Contract": - return self._contract_declarer diff --git a/slither/core/children/child_node.py b/slither/core/children/child_node.py deleted file mode 100644 index 8e6e1f0b5..000000000 --- a/slither/core/children/child_node.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.compilation_unit import SlitherCompilationUnit - from slither.core.cfg.node import Node - from slither.core.declarations import Function, Contract - - -class ChildNode: - def __init__(self) -> None: - super().__init__() - self._node = None - - def set_node(self, node: "Node") -> None: - self._node = node - - @property - def node(self) -> "Node": - return self._node - - @property - def function(self) -> "Function": - return self.node.function - - @property - def contract(self) -> "Contract": - return self.node.function.contract - - @property - def compilation_unit(self) -> "SlitherCompilationUnit": - return self.node.compilation_unit diff --git a/slither/core/children/child_structure.py b/slither/core/children/child_structure.py deleted file mode 100644 index abcb041c2..000000000 --- a/slither/core/children/child_structure.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from slither.core.declarations import Structure - - -class ChildStructure: - def __init__(self) -> None: - super().__init__() - self._structure = None - - def set_structure(self, structure: "Structure") -> None: - self._structure = structure - - @property - def structure(self) -> "Structure": - return self._structure diff --git a/slither/core/children/child_contract.py b/slither/core/declarations/contract_level.py similarity index 57% rename from slither/core/children/child_contract.py rename to slither/core/declarations/contract_level.py index 86f9dea53..5893a7035 100644 --- a/slither/core/children/child_contract.py +++ b/slither/core/declarations/contract_level.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from slither.core.source_mapping.source_mapping import SourceMapping @@ -6,14 +6,21 @@ if TYPE_CHECKING: from slither.core.declarations import Contract -class ChildContract(SourceMapping): +class ContractLevel(SourceMapping): + """ + This class is used to represent objects that are at the contract level + The opposite is TopLevel + + """ + def __init__(self) -> None: super().__init__() - self._contract = None + self._contract: Optional["Contract"] = None def set_contract(self, contract: "Contract") -> None: self._contract = contract @property def contract(self) -> "Contract": + assert self._contract return self._contract diff --git a/slither/core/declarations/custom_error_contract.py b/slither/core/declarations/custom_error_contract.py index a96f12057..a3839e3f2 100644 --- a/slither/core/declarations/custom_error_contract.py +++ b/slither/core/declarations/custom_error_contract.py @@ -1,8 +1,8 @@ -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations.custom_error import CustomError -class CustomErrorContract(CustomError, ChildContract): +class CustomErrorContract(CustomError, ContractLevel): def is_declared_by(self, contract): """ Check if the element is declared by the contract diff --git a/slither/core/declarations/enum_contract.py b/slither/core/declarations/enum_contract.py index 46168d107..2e51ae511 100644 --- a/slither/core/declarations/enum_contract.py +++ b/slither/core/declarations/enum_contract.py @@ -1,13 +1,13 @@ from typing import TYPE_CHECKING -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations import Enum if TYPE_CHECKING: from slither.core.declarations import Contract -class EnumContract(Enum, ChildContract): +class EnumContract(Enum, ContractLevel): def is_declared_by(self, contract: "Contract") -> bool: """ Check if the element is declared by the contract diff --git a/slither/core/declarations/event.py b/slither/core/declarations/event.py index d616679a2..9d42ac224 100644 --- a/slither/core/declarations/event.py +++ b/slither/core/declarations/event.py @@ -1,6 +1,6 @@ from typing import List, Tuple, TYPE_CHECKING -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.event_variable import EventVariable @@ -8,7 +8,7 @@ if TYPE_CHECKING: from slither.core.declarations import Contract -class Event(ChildContract, SourceMapping): +class Event(ContractLevel, SourceMapping): def __init__(self) -> None: super().__init__() self._name = None diff --git a/slither/core/declarations/function_contract.py b/slither/core/declarations/function_contract.py index 19456bbea..69c50a117 100644 --- a/slither/core/declarations/function_contract.py +++ b/slither/core/declarations/function_contract.py @@ -1,10 +1,9 @@ """ Function module """ -from typing import Dict, TYPE_CHECKING, List, Tuple +from typing import Dict, TYPE_CHECKING, List, Tuple, Optional -from slither.core.children.child_contract import ChildContract -from slither.core.children.child_inheritance import ChildInheritance +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations import Function @@ -14,9 +13,31 @@ if TYPE_CHECKING: from slither.core.declarations import Contract from slither.core.scope.scope import FileScope from slither.slithir.variables.state_variable import StateIRVariable + from slither.core.compilation_unit import SlitherCompilationUnit -class FunctionContract(Function, ChildContract, ChildInheritance): +class FunctionContract(Function, ContractLevel): + def __init__(self, compilation_unit: "SlitherCompilationUnit") -> None: + super().__init__(compilation_unit) + self._contract_declarer: Optional["Contract"] = None + + def set_contract_declarer(self, contract: "Contract") -> None: + self._contract_declarer = contract + + @property + def contract_declarer(self) -> "Contract": + """ + Return the contract where this function was declared. Only functions have both a contract, and contract_declarer + This is because we need to have separate representation of the function depending of the contract's context + For example a function calling super.f() will generate different IR depending on the current contract's inheritance + + Returns: + The contract where this function was declared + """ + + assert self._contract_declarer + return self._contract_declarer + @property def canonical_name(self) -> str: """ diff --git a/slither/core/declarations/structure_contract.py b/slither/core/declarations/structure_contract.py index aaf660e1e..c9d05ce4e 100644 --- a/slither/core/declarations/structure_contract.py +++ b/slither/core/declarations/structure_contract.py @@ -1,8 +1,8 @@ -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations import Structure -class StructureContract(Structure, ChildContract): +class StructureContract(Structure, ContractLevel): def is_declared_by(self, contract): """ Check if the element is declared by the contract diff --git a/slither/core/declarations/top_level.py b/slither/core/declarations/top_level.py index 15facf2f9..01e6f6dfd 100644 --- a/slither/core/declarations/top_level.py +++ b/slither/core/declarations/top_level.py @@ -2,4 +2,8 @@ from slither.core.source_mapping.source_mapping import SourceMapping class TopLevel(SourceMapping): - pass + """ + This class is used to represent objects that are at the top level + The opposite is ContractLevel + + """ diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index e5f4e830a..e55a9cf0b 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -13,7 +13,7 @@ from typing import Optional, Dict, List, Set, Union, Tuple from crytic_compile import CryticCompile from crytic_compile.utils.naming import Filename -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.context.context import Context from slither.core.declarations import Contract, FunctionContract @@ -206,7 +206,10 @@ class SlitherCore(Context): isinstance(thing, FunctionContract) and thing.contract_declarer == thing.contract ) - or (isinstance(thing, ChildContract) and not isinstance(thing, FunctionContract)) + or ( + isinstance(thing, ContractLevel) + and not isinstance(thing, FunctionContract) + ) ): self._offset_to_objects[definition.filename][offset].add(thing) @@ -224,7 +227,8 @@ class SlitherCore(Context): and thing.contract_declarer == thing.contract ) or ( - isinstance(thing, ChildContract) and not isinstance(thing, FunctionContract) + isinstance(thing, ContractLevel) + and not isinstance(thing, FunctionContract) ) ): self._offset_to_objects[definition.filename][offset].add(thing) diff --git a/slither/core/solidity_types/type_alias.py b/slither/core/solidity_types/type_alias.py index 5b9ea0a37..c47d2ee14 100644 --- a/slither/core/solidity_types/type_alias.py +++ b/slither/core/solidity_types/type_alias.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Tuple -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.declarations.top_level import TopLevel from slither.core.solidity_types import Type, ElementaryType @@ -48,7 +48,7 @@ class TypeAliasTopLevel(TypeAlias, TopLevel): return self.name -class TypeAliasContract(TypeAlias, ChildContract): +class TypeAliasContract(TypeAlias, ContractLevel): def __init__(self, underlying_type: Type, name: str, contract: "Contract") -> None: super().__init__(underlying_type, name) self._contract: "Contract" = contract diff --git a/slither/core/variables/event_variable.py b/slither/core/variables/event_variable.py index f3ad60d0b..3b6b6c511 100644 --- a/slither/core/variables/event_variable.py +++ b/slither/core/variables/event_variable.py @@ -1,8 +1,7 @@ from slither.core.variables.variable import Variable -from slither.core.children.child_event import ChildEvent -class EventVariable(ChildEvent, Variable): +class EventVariable(Variable): def __init__(self) -> None: super().__init__() self._indexed = False @@ -16,5 +15,5 @@ class EventVariable(ChildEvent, Variable): return self._indexed @indexed.setter - def indexed(self, is_indexed: bool): + def indexed(self, is_indexed: bool) -> None: self._indexed = is_indexed diff --git a/slither/core/variables/local_variable.py b/slither/core/variables/local_variable.py index 7b7b4f8bc..fc23eeba7 100644 --- a/slither/core/variables/local_variable.py +++ b/slither/core/variables/local_variable.py @@ -1,7 +1,6 @@ -from typing import Optional +from typing import Optional, TYPE_CHECKING from slither.core.variables.variable import Variable -from slither.core.children.child_function import ChildFunction from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.array_type import ArrayType from slither.core.solidity_types.mapping_type import MappingType @@ -9,11 +8,23 @@ from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.declarations.structure import Structure +if TYPE_CHECKING: # type: ignore + from slither.core.declarations import Function -class LocalVariable(ChildFunction, Variable): + +class LocalVariable(Variable): def __init__(self) -> None: super().__init__() self._location: Optional[str] = None + self._function: Optional["Function"] = None + + def set_function(self, function: "Function") -> None: + self._function = function + + @property + def function(self) -> "Function": + assert self._function + return self._function def set_location(self, loc: str) -> None: self._location = loc diff --git a/slither/core/variables/state_variable.py b/slither/core/variables/state_variable.py index 47b7682a4..f2a2d6ee3 100644 --- a/slither/core/variables/state_variable.py +++ b/slither/core/variables/state_variable.py @@ -1,6 +1,6 @@ from typing import Optional, TYPE_CHECKING -from slither.core.children.child_contract import ChildContract +from slither.core.declarations.contract_level import ContractLevel from slither.core.variables.variable import Variable if TYPE_CHECKING: @@ -8,7 +8,7 @@ if TYPE_CHECKING: from slither.core.declarations import Contract -class StateVariable(ChildContract, Variable): +class StateVariable(ContractLevel, Variable): def __init__(self) -> None: super().__init__() self._node_initialization: Optional["Node"] = None diff --git a/slither/core/variables/structure_variable.py b/slither/core/variables/structure_variable.py index c6034da63..3a001b6a9 100644 --- a/slither/core/variables/structure_variable.py +++ b/slither/core/variables/structure_variable.py @@ -1,6 +1,19 @@ +from typing import TYPE_CHECKING, Optional from slither.core.variables.variable import Variable -from slither.core.children.child_structure import ChildStructure -class StructureVariable(ChildStructure, Variable): - pass +if TYPE_CHECKING: + from slither.core.declarations import Structure + + +class StructureVariable(Variable): + def __init__(self) -> None: + super().__init__() + self._structure: Optional["Structure"] = None + + def set_structure(self, structure: "Structure") -> None: + self._structure = structure + + @property + def structure(self) -> "Structure": + return self._structure diff --git a/slither/detectors/erc/incorrect_erc721_interface.py b/slither/detectors/erc/incorrect_erc721_interface.py index 8327e8b2e..f894fb517 100644 --- a/slither/detectors/erc/incorrect_erc721_interface.py +++ b/slither/detectors/erc/incorrect_erc721_interface.py @@ -89,7 +89,9 @@ contract Token{ return False @staticmethod - def detect_incorrect_erc721_interface(contract: Contract) -> List[Union[FunctionContract, Any]]: + def detect_incorrect_erc721_interface( + contract: Contract, + ) -> List[Union[FunctionContract, Any]]: """Detect incorrect ERC721 interface Returns: diff --git a/slither/detectors/operations/missing_events_arithmetic.py b/slither/detectors/operations/missing_events_arithmetic.py index 6e1d5fbb5..e553e78eb 100644 --- a/slither/detectors/operations/missing_events_arithmetic.py +++ b/slither/detectors/operations/missing_events_arithmetic.py @@ -70,7 +70,9 @@ contract C { def _detect_missing_events( self, contract: Contract - ) -> List[Tuple[FunctionContract, List[Tuple[Node, List[Tuple[Node, FunctionContract]]]]]]: + ) -> List[ + Tuple[FunctionContract, List[Tuple[Node, List[Tuple[Node, FunctionContract]]]]] + ]: """ Detects if critical contract parameters set by owners and used in arithmetic are missing events :param contract: The contract to check diff --git a/slither/detectors/statements/tx_origin.py b/slither/detectors/statements/tx_origin.py index 34f8173d5..e281c1d09 100644 --- a/slither/detectors/statements/tx_origin.py +++ b/slither/detectors/statements/tx_origin.py @@ -57,7 +57,9 @@ Bob is the owner of `TxOrigin`. Bob calls Eve's contract. Eve's contract calls ` ) return False - def detect_tx_origin(self, contract: Contract) -> List[Tuple[FunctionContract, List[Node]]]: + def detect_tx_origin( + self, contract: Contract + ) -> List[Tuple[FunctionContract, List[Node]]]: ret = [] for f in contract.functions: diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index 87a6b075b..aa8dfb4ec 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -731,7 +731,7 @@ def propagate_types( return _convert_type_contract(ir) left = ir.variable_left t = None - ir_func = ir.function + ir_func = ir.node.function # Handling of this.function_name usage if ( left == SolidityVariable("this") diff --git a/slither/slithir/operations/internal_dynamic_call.py b/slither/slithir/operations/internal_dynamic_call.py index a1ad1aa15..ca245167e 100644 --- a/slither/slithir/operations/internal_dynamic_call.py +++ b/slither/slithir/operations/internal_dynamic_call.py @@ -24,7 +24,7 @@ class InternalDynamicCall( assert isinstance(function, Variable) assert is_valid_lvalue(lvalue) or lvalue is None super().__init__() - self._function = function + self._function: Variable = function self._function_type = function_type self._lvalue = lvalue @@ -37,7 +37,7 @@ class InternalDynamicCall( return self._unroll(self.arguments) + [self.function] @property - def function(self) -> Union[LocalVariable, LocalIRVariable]: + def function(self) -> Variable: return self._function @property diff --git a/slither/slithir/operations/new_structure.py b/slither/slithir/operations/new_structure.py index 752de6a3d..f24b3bccd 100644 --- a/slither/slithir/operations/new_structure.py +++ b/slither/slithir/operations/new_structure.py @@ -14,7 +14,9 @@ from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA class NewStructure(Call, OperationWithLValue): def __init__( - self, structure: StructureContract, lvalue: Union[TemporaryVariableSSA, TemporaryVariable] + self, + structure: StructureContract, + lvalue: Union[TemporaryVariableSSA, TemporaryVariable], ) -> None: super().__init__() assert isinstance(structure, Structure) diff --git a/slither/slithir/operations/operation.py b/slither/slithir/operations/operation.py index fcf5f4868..aca3e645b 100644 --- a/slither/slithir/operations/operation.py +++ b/slither/slithir/operations/operation.py @@ -1,11 +1,14 @@ import abc -from typing import Any, List +from typing import Any, List, Optional, TYPE_CHECKING from slither.core.context.context import Context -from slither.core.children.child_expression import ChildExpression -from slither.core.children.child_node import ChildNode +from slither.core.expressions.expression import Expression from slither.core.variables.variable import Variable from slither.utils.utils import unroll +if TYPE_CHECKING: + from slither.core.compilation_unit import SlitherCompilationUnit + from slither.core.cfg.node import Node + class AbstractOperation(abc.ABC): @property @@ -25,7 +28,24 @@ class AbstractOperation(abc.ABC): pass # pylint: disable=unnecessary-pass -class Operation(Context, ChildExpression, ChildNode, AbstractOperation): +class Operation(Context, AbstractOperation): + def __init__(self) -> None: + super().__init__() + self._node: Optional["Node"] = None + self._expression: Optional[Expression] = None + + def set_node(self, node: "Node") -> None: + self._node = node + + @property + def node(self) -> "Node": + assert self._node + return self._node + + @property + def compilation_unit(self) -> "SlitherCompilationUnit": + return self.node.compilation_unit + @property def used(self) -> List[Variable]: """ @@ -37,3 +57,10 @@ class Operation(Context, ChildExpression, ChildNode, AbstractOperation): @staticmethod def _unroll(l: List[Any]) -> List[Any]: return unroll(l) + + def set_expression(self, expression: Expression) -> None: + self._expression = expression + + @property + def expression(self) -> Optional[Expression]: + return self._expression diff --git a/slither/slithir/operations/solidity_call.py b/slither/slithir/operations/solidity_call.py index b059c55a6..c0d8d8404 100644 --- a/slither/slithir/operations/solidity_call.py +++ b/slither/slithir/operations/solidity_call.py @@ -2,7 +2,6 @@ from typing import Any, List, Union from slither.core.declarations.solidity_variables import SolidityCustomRevert, SolidityFunction from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue -from slither.core.children.child_node import ChildNode from slither.core.solidity_types.elementary_type import ElementaryType @@ -11,7 +10,7 @@ class SolidityCall(Call, OperationWithLValue): self, function: Union[SolidityCustomRevert, SolidityFunction], nbr_arguments: int, - result: ChildNode, + result, type_call: Union[str, List[ElementaryType]], ) -> None: assert isinstance(function, SolidityFunction) diff --git a/slither/slithir/operations/type_conversion.py b/slither/slithir/operations/type_conversion.py index f351f1fdd..ce41e3c54 100644 --- a/slither/slithir/operations/type_conversion.py +++ b/slither/slithir/operations/type_conversion.py @@ -17,7 +17,9 @@ class TypeConversion(OperationWithLValue): self, result: Union[TemporaryVariableSSA, TemporaryVariable], variable: SourceMapping, - variable_type: Union[TypeAliasContract, UserDefinedType, ElementaryType, TypeAliasTopLevel], + variable_type: Union[ + TypeAliasContract, UserDefinedType, ElementaryType, TypeAliasTopLevel + ], ) -> None: super().__init__() assert is_valid_rvalue(variable) or isinstance(variable, Contract) diff --git a/slither/slithir/variables/reference.py b/slither/slithir/variables/reference.py index 95802b7e2..9ab51be65 100644 --- a/slither/slithir/variables/reference.py +++ b/slither/slithir/variables/reference.py @@ -1,6 +1,5 @@ from typing import Optional, TYPE_CHECKING -from slither.core.children.child_node import ChildNode from slither.core.declarations import Contract, Enum, SolidityVariable, Function from slither.core.variables.variable import Variable @@ -8,7 +7,7 @@ if TYPE_CHECKING: from slither.core.cfg.node import Node -class ReferenceVariable(ChildNode, Variable): +class ReferenceVariable(Variable): def __init__(self, node: "Node", index: Optional[int] = None) -> None: super().__init__() if index is None: @@ -19,6 +18,10 @@ class ReferenceVariable(ChildNode, Variable): self._points_to = None self._node = node + @property + def node(self) -> "Node": + return self._node + @property def index(self): return self._index diff --git a/slither/slithir/variables/temporary.py b/slither/slithir/variables/temporary.py index 8cb1cf350..5a485f985 100644 --- a/slither/slithir/variables/temporary.py +++ b/slither/slithir/variables/temporary.py @@ -1,13 +1,12 @@ from typing import Optional, TYPE_CHECKING -from slither.core.children.child_node import ChildNode from slither.core.variables.variable import Variable if TYPE_CHECKING: from slither.core.cfg.node import Node -class TemporaryVariable(ChildNode, Variable): +class TemporaryVariable(Variable): def __init__(self, node: "Node", index: Optional[int] = None) -> None: super().__init__() if index is None: @@ -17,6 +16,10 @@ class TemporaryVariable(ChildNode, Variable): self._index = index self._node = node + @property + def node(self) -> "Node": + return self._node + @property def index(self): return self._index diff --git a/slither/slithir/variables/tuple.py b/slither/slithir/variables/tuple.py index dc085347e..9a13b1d5d 100644 --- a/slither/slithir/variables/tuple.py +++ b/slither/slithir/variables/tuple.py @@ -1,13 +1,12 @@ from typing import Optional, TYPE_CHECKING -from slither.core.children.child_node import ChildNode from slither.slithir.variables.variable import SlithIRVariable if TYPE_CHECKING: from slither.core.cfg.node import Node -class TupleVariable(ChildNode, SlithIRVariable): +class TupleVariable(SlithIRVariable): def __init__(self, node: "Node", index: Optional[int] = None) -> None: super().__init__() if index is None: @@ -18,6 +17,10 @@ class TupleVariable(ChildNode, SlithIRVariable): self._node = node + @property + def node(self) -> "Node": + return self._node + @property def index(self): return self._index diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 475c3fab2..47ee7ec10 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -2,7 +2,13 @@ import logging import re from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set -from slither.core.declarations import Modifier, Event, EnumContract, StructureContract, Function +from slither.core.declarations import ( + Modifier, + Event, + EnumContract, + StructureContract, + Function, +) from slither.core.declarations.contract import Contract from slither.core.declarations.custom_error_contract import CustomErrorContract from slither.core.declarations.function_contract import FunctionContract diff --git a/slither/utils/output.py b/slither/utils/output.py index 9dba15e31..84c9ac65a 100644 --- a/slither/utils/output.py +++ b/slither/utils/output.py @@ -10,8 +10,17 @@ from zipfile import ZipFile from pkg_resources import require from slither.core.cfg.node import Node -from slither.core.declarations import Contract, Function, Enum, Event, Structure, Pragma +from slither.core.declarations import ( + Contract, + Function, + Enum, + Event, + Structure, + Pragma, + FunctionContract, +) from slither.core.source_mapping.source_mapping import SourceMapping +from slither.core.variables.local_variable import LocalVariable from slither.core.variables.variable import Variable from slither.exceptions import SlitherError from slither.utils.colors import yellow @@ -351,21 +360,19 @@ def _create_parent_element( ], ]: # pylint: disable=import-outside-toplevel - from slither.core.children.child_contract import ChildContract - from slither.core.children.child_function import ChildFunction - from slither.core.children.child_inheritance import ChildInheritance + from slither.core.declarations.contract_level import ContractLevel - if isinstance(element, ChildInheritance): + if isinstance(element, FunctionContract): if element.contract_declarer: contract = Output("") contract.add_contract(element.contract_declarer) return contract.data["elements"][0] - elif isinstance(element, ChildContract): + elif isinstance(element, ContractLevel): if element.contract: contract = Output("") contract.add_contract(element.contract) return contract.data["elements"][0] - elif isinstance(element, ChildFunction): + elif isinstance(element, (LocalVariable, Node)): if element.function: function = Output("") function.add_function(element.function) diff --git a/tests/test_ssa_generation.py b/tests/test_ssa_generation.py index f002ec4e1..9bb008fdf 100644 --- a/tests/test_ssa_generation.py +++ b/tests/test_ssa_generation.py @@ -689,7 +689,7 @@ def test_initial_version_exists_for_state_variables_function_assign(): # temporary variable, that is then assigned to a call = get_ssa_of_type(ctor, InternalCall)[0] - assert call.function == f + assert call.node.function == f assign = get_ssa_of_type(ctor, Assignment)[0] assert assign.rvalue == call.lvalue assert isinstance(assign.lvalue, StateIRVariable) From e8bf225f32f71fa71b48363eca0eb510e1d04f97 Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Mon, 20 Feb 2023 11:59:21 +0100 Subject: [PATCH 05/34] Improvements post merge --- slither/core/declarations/custom_error_contract.py | 3 ++- slither/core/slither_core.py | 8 ++------ slither/core/solidity_types/type_alias.py | 2 +- slither/detectors/operations/missing_events_arithmetic.py | 4 +--- slither/detectors/statements/tx_origin.py | 4 +--- slither/slithir/operations/type_conversion.py | 4 +--- 6 files changed, 8 insertions(+), 17 deletions(-) diff --git a/slither/core/declarations/custom_error_contract.py b/slither/core/declarations/custom_error_contract.py index 2561a20a0..cd279a3a6 100644 --- a/slither/core/declarations/custom_error_contract.py +++ b/slither/core/declarations/custom_error_contract.py @@ -1,5 +1,6 @@ -from slither.core.declarations.contract_level import ContractLevel from typing import TYPE_CHECKING +from slither.core.declarations.contract_level import ContractLevel + from slither.core.declarations.custom_error import CustomError diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index d6851fc32..798008707 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -206,10 +206,7 @@ class SlitherCore(Context): isinstance(thing, FunctionContract) and thing.contract_declarer == thing.contract ) - or ( - isinstance(thing, ContractLevel) - and not isinstance(thing, FunctionContract) - ) + or (isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract)) ): self._offset_to_objects[definition.filename][offset].add(thing) @@ -227,8 +224,7 @@ class SlitherCore(Context): and thing.contract_declarer == thing.contract ) or ( - isinstance(thing, ContractLevel) - and not isinstance(thing, FunctionContract) + isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract) ) ): self._offset_to_objects[definition.filename][offset].add(thing) diff --git a/slither/core/solidity_types/type_alias.py b/slither/core/solidity_types/type_alias.py index 555e716bd..9387f511a 100644 --- a/slither/core/solidity_types/type_alias.py +++ b/slither/core/solidity_types/type_alias.py @@ -49,7 +49,7 @@ class TypeAliasTopLevel(TypeAlias, TopLevel): class TypeAliasContract(TypeAlias, ContractLevel): - def __init__(self, underlying_type: Type, name: str, contract: "Contract") -> None: + def __init__(self, underlying_type: ElementaryType, name: str, contract: "Contract") -> None: super().__init__(underlying_type, name) self._contract: "Contract" = contract diff --git a/slither/detectors/operations/missing_events_arithmetic.py b/slither/detectors/operations/missing_events_arithmetic.py index df8cc90c3..c17ed32a3 100644 --- a/slither/detectors/operations/missing_events_arithmetic.py +++ b/slither/detectors/operations/missing_events_arithmetic.py @@ -74,9 +74,7 @@ contract C { def _detect_missing_events( self, contract: Contract - ) -> List[ - Tuple[FunctionContract, List[Tuple[Node, List[Tuple[Node, FunctionContract]]]]] - ]: + ) -> List[Tuple[FunctionContract, List[Tuple[Node, List[Tuple[Node, FunctionContract]]]]]]: """ Detects if critical contract parameters set by owners and used in arithmetic are missing events :param contract: The contract to check diff --git a/slither/detectors/statements/tx_origin.py b/slither/detectors/statements/tx_origin.py index c0b8a03d6..49bf6006d 100644 --- a/slither/detectors/statements/tx_origin.py +++ b/slither/detectors/statements/tx_origin.py @@ -61,9 +61,7 @@ Bob is the owner of `TxOrigin`. Bob calls Eve's contract. Eve's contract calls ` ) return False - def detect_tx_origin( - self, contract: Contract - ) -> List[Tuple[FunctionContract, List[Node]]]: + def detect_tx_origin(self, contract: Contract) -> List[Tuple[FunctionContract, List[Node]]]: ret = [] for f in contract.functions: diff --git a/slither/slithir/operations/type_conversion.py b/slither/slithir/operations/type_conversion.py index ce41e3c54..f351f1fdd 100644 --- a/slither/slithir/operations/type_conversion.py +++ b/slither/slithir/operations/type_conversion.py @@ -17,9 +17,7 @@ class TypeConversion(OperationWithLValue): self, result: Union[TemporaryVariableSSA, TemporaryVariable], variable: SourceMapping, - variable_type: Union[ - TypeAliasContract, UserDefinedType, ElementaryType, TypeAliasTopLevel - ], + variable_type: Union[TypeAliasContract, UserDefinedType, ElementaryType, TypeAliasTopLevel], ) -> None: super().__init__() assert is_valid_rvalue(variable) or isinstance(variable, Contract) From 8582f9c47bb591b04369f133a13f2ee5117f491a Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Mon, 20 Feb 2023 12:01:38 +0100 Subject: [PATCH 06/34] Remove unused visitors --- slither/visitors/expression/find_push.py | 96 ----------------- slither/visitors/expression/left_value.py | 109 ------------------- slither/visitors/expression/right_value.py | 115 --------------------- 3 files changed, 320 deletions(-) delete mode 100644 slither/visitors/expression/find_push.py delete mode 100644 slither/visitors/expression/left_value.py delete mode 100644 slither/visitors/expression/right_value.py diff --git a/slither/visitors/expression/find_push.py b/slither/visitors/expression/find_push.py deleted file mode 100644 index cf2b07e60..000000000 --- a/slither/visitors/expression/find_push.py +++ /dev/null @@ -1,96 +0,0 @@ -from slither.visitors.expression.expression import ExpressionVisitor - -from slither.visitors.expression.right_value import RightValue - -key = "FindPush" - - -def get(expression): - val = expression.context[key] - # we delete the item to reduce memory use - del expression.context[key] - return val - - -def set_val(expression, val): - expression.context[key] = val - - -class FindPush(ExpressionVisitor): - def result(self): - if self._result is None: - self._result = list(set(get(self.expression))) - return self._result - - def _post_assignement_operation(self, expression): - left = get(expression.expression_left) - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_binary_operation(self, expression): - left = get(expression.expression_left) - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_call_expression(self, expression): - called = get(expression.called) - args = [get(a) for a in expression.arguments if a] - args = [item for sublist in args for item in sublist] - val = called + args - set_val(expression, val) - - def _post_conditional_expression(self, expression): - if_expr = get(expression.if_expression) - else_expr = get(expression.else_expression) - then_expr = get(expression.then_expression) - val = if_expr + else_expr + then_expr - set_val(expression, val) - - def _post_elementary_type_name_expression(self, expression): - set_val(expression, []) - - # save only identifier expression - def _post_identifier(self, expression): - set_val(expression, []) - - def _post_index_access(self, expression): - left = get(expression.expression_left) - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_literal(self, expression): - set_val(expression, []) - - def _post_member_access(self, expression): - val = [] - if expression.member_name == "push": - right = RightValue(expression.expression) - val = right.result() - set_val(expression, val) - - def _post_new_array(self, expression): - set_val(expression, []) - - def _post_new_contract(self, expression): - set_val(expression, []) - - def _post_new_elementary_type(self, expression): - set_val(expression, []) - - def _post_tuple_expression(self, expression): - expressions = [get(e) for e in expression.expressions if e] - val = [item for sublist in expressions for item in sublist] - set_val(expression, val) - - def _post_type_conversion(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) - - def _post_unary_operation(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) diff --git a/slither/visitors/expression/left_value.py b/slither/visitors/expression/left_value.py deleted file mode 100644 index 3b16c8c26..000000000 --- a/slither/visitors/expression/left_value.py +++ /dev/null @@ -1,109 +0,0 @@ -# Return the 'left' value of an expression - -from slither.visitors.expression.expression import ExpressionVisitor - -from slither.core.expressions.assignment_operation import AssignmentOperationType - -from slither.core.variables.variable import Variable - -key = "LeftValue" - - -def get(expression): - val = expression.context[key] - # we delete the item to reduce memory use - del expression.context[key] - return val - - -def set_val(expression, val): - expression.context[key] = val - - -class LeftValue(ExpressionVisitor): - def result(self): - if self._result is None: - self._result = list(set(get(self.expression))) - return self._result - - # overide index access visitor to explore only left part - def _visit_index_access(self, expression): - self._visit_expression(expression.expression_left) - - def _post_assignement_operation(self, expression): - if expression.type != AssignmentOperationType.ASSIGN: - left = get(expression.expression_left) - else: - left = [] - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_binary_operation(self, expression): - left = get(expression.expression_left) - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_call_expression(self, expression): - called = get(expression.called) - args = [get(a) for a in expression.arguments if a] - args = [item for sublist in args for item in sublist] - val = called + args - set_val(expression, val) - - def _post_conditional_expression(self, expression): - if_expr = get(expression.if_expression) - else_expr = get(expression.else_expression) - then_expr = get(expression.then_expression) - val = if_expr + else_expr + then_expr - set_val(expression, val) - - def _post_elementary_type_name_expression(self, expression): - set_val(expression, []) - - # save only identifier expression - def _post_identifier(self, expression): - if isinstance(expression.value, Variable): - set_val(expression, [expression.value]) - # elif isinstance(expression.value, SolidityInbuilt): - # set_val(expression, [expression]) - else: - set_val(expression, []) - - def _post_index_access(self, expression): - left = get(expression.expression_left) - val = left - set_val(expression, val) - - def _post_literal(self, expression): - set_val(expression, []) - - def _post_member_access(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) - - def _post_new_array(self, expression): - set_val(expression, []) - - def _post_new_contract(self, expression): - set_val(expression, []) - - def _post_new_elementary_type(self, expression): - set_val(expression, []) - - def _post_tuple_expression(self, expression): - expressions = [get(e) for e in expression.expressions if e] - val = [item for sublist in expressions for item in sublist] - set_val(expression, val) - - def _post_type_conversion(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) - - def _post_unary_operation(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) diff --git a/slither/visitors/expression/right_value.py b/slither/visitors/expression/right_value.py deleted file mode 100644 index 5a97846bc..000000000 --- a/slither/visitors/expression/right_value.py +++ /dev/null @@ -1,115 +0,0 @@ -# Return the 'right' value of an expression -# On index access, explore the left -# on member access, return the member_name -# a.b.c[d] returns c - -from slither.visitors.expression.expression import ExpressionVisitor - -from slither.core.expressions.assignment_operation import AssignmentOperationType -from slither.core.expressions.expression import Expression - -from slither.core.variables.variable import Variable - -key = "RightValue" - - -def get(expression): - val = expression.context[key] - # we delete the item to reduce memory use - del expression.context[key] - return val - - -def set_val(expression, val): - expression.context[key] = val - - -class RightValue(ExpressionVisitor): - def result(self): - if self._result is None: - self._result = list(set(get(self.expression))) - return self._result - - # overide index access visitor to explore only left part - def _visit_index_access(self, expression): - self._visit_expression(expression.expression_left) - - def _post_assignement_operation(self, expression): - if expression.type != AssignmentOperationType.ASSIGN: - left = get(expression.expression_left) - else: - left = [] - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_binary_operation(self, expression): - left = get(expression.expression_left) - right = get(expression.expression_right) - val = left + right - set_val(expression, val) - - def _post_call_expression(self, expression): - called = get(expression.called) - args = [get(a) for a in expression.arguments if a] - args = [item for sublist in args for item in sublist] - val = called + args - set_val(expression, val) - - def _post_conditional_expression(self, expression): - if_expr = get(expression.if_expression) - else_expr = get(expression.else_expression) - then_expr = get(expression.then_expression) - val = if_expr + else_expr + then_expr - set_val(expression, val) - - def _post_elementary_type_name_expression(self, expression): - set_val(expression, []) - - # save only identifier expression - def _post_identifier(self, expression): - if isinstance(expression.value, Variable): - set_val(expression, [expression.value]) - # elif isinstance(expression.value, SolidityInbuilt): - # set_val(expression, [expression]) - else: - set_val(expression, []) - - def _post_index_access(self, expression): - left = get(expression.expression_left) - val = left - set_val(expression, val) - - def _post_literal(self, expression): - set_val(expression, []) - - def _post_member_access(self, expression): - val = [] - if isinstance(expression.member_name, Expression): - expr = get(expression.member_name) - val = expr - set_val(expression, val) - - def _post_new_array(self, expression): - set_val(expression, []) - - def _post_new_contract(self, expression): - set_val(expression, []) - - def _post_new_elementary_type(self, expression): - set_val(expression, []) - - def _post_tuple_expression(self, expression): - expressions = [get(e) for e in expression.expressions if e] - val = [item for sublist in expressions for item in sublist] - set_val(expression, val) - - def _post_type_conversion(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) - - def _post_unary_operation(self, expression): - expr = get(expression.expression) - val = expr - set_val(expression, val) From 43a27fb09e77dff725700ba93fc5c7c48303ec22 Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Mon, 20 Feb 2023 12:02:55 +0100 Subject: [PATCH 07/34] Black --- slither/core/slither_core.py | 8 ++------ slither/detectors/operations/missing_events_arithmetic.py | 4 +--- slither/detectors/statements/tx_origin.py | 4 +--- slither/slithir/operations/type_conversion.py | 4 +--- 4 files changed, 5 insertions(+), 15 deletions(-) diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index e55a9cf0b..548c04b3a 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -206,10 +206,7 @@ class SlitherCore(Context): isinstance(thing, FunctionContract) and thing.contract_declarer == thing.contract ) - or ( - isinstance(thing, ContractLevel) - and not isinstance(thing, FunctionContract) - ) + or (isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract)) ): self._offset_to_objects[definition.filename][offset].add(thing) @@ -227,8 +224,7 @@ class SlitherCore(Context): and thing.contract_declarer == thing.contract ) or ( - isinstance(thing, ContractLevel) - and not isinstance(thing, FunctionContract) + isinstance(thing, ContractLevel) and not isinstance(thing, FunctionContract) ) ): self._offset_to_objects[definition.filename][offset].add(thing) diff --git a/slither/detectors/operations/missing_events_arithmetic.py b/slither/detectors/operations/missing_events_arithmetic.py index e553e78eb..6e1d5fbb5 100644 --- a/slither/detectors/operations/missing_events_arithmetic.py +++ b/slither/detectors/operations/missing_events_arithmetic.py @@ -70,9 +70,7 @@ contract C { def _detect_missing_events( self, contract: Contract - ) -> List[ - Tuple[FunctionContract, List[Tuple[Node, List[Tuple[Node, FunctionContract]]]]] - ]: + ) -> List[Tuple[FunctionContract, List[Tuple[Node, List[Tuple[Node, FunctionContract]]]]]]: """ Detects if critical contract parameters set by owners and used in arithmetic are missing events :param contract: The contract to check diff --git a/slither/detectors/statements/tx_origin.py b/slither/detectors/statements/tx_origin.py index e281c1d09..34f8173d5 100644 --- a/slither/detectors/statements/tx_origin.py +++ b/slither/detectors/statements/tx_origin.py @@ -57,9 +57,7 @@ Bob is the owner of `TxOrigin`. Bob calls Eve's contract. Eve's contract calls ` ) return False - def detect_tx_origin( - self, contract: Contract - ) -> List[Tuple[FunctionContract, List[Node]]]: + def detect_tx_origin(self, contract: Contract) -> List[Tuple[FunctionContract, List[Node]]]: ret = [] for f in contract.functions: diff --git a/slither/slithir/operations/type_conversion.py b/slither/slithir/operations/type_conversion.py index ce41e3c54..f351f1fdd 100644 --- a/slither/slithir/operations/type_conversion.py +++ b/slither/slithir/operations/type_conversion.py @@ -17,9 +17,7 @@ class TypeConversion(OperationWithLValue): self, result: Union[TemporaryVariableSSA, TemporaryVariable], variable: SourceMapping, - variable_type: Union[ - TypeAliasContract, UserDefinedType, ElementaryType, TypeAliasTopLevel - ], + variable_type: Union[TypeAliasContract, UserDefinedType, ElementaryType, TypeAliasTopLevel], ) -> None: super().__init__() assert is_valid_rvalue(variable) or isinstance(variable, Contract) From 79921b309674fa9582f2224b0875cdfbe7974ff1 Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Mon, 20 Feb 2023 15:01:12 +0100 Subject: [PATCH 08/34] More types --- slither/core/compilation_unit.py | 17 +-- slither/core/expressions/call_expression.py | 8 +- .../expressions/conditional_expression.py | 2 +- slither/core/variables/variable.py | 7 +- .../compiler_bugs/array_by_reference.py | 17 +-- slither/detectors/slither/name_reused.py | 5 +- .../statements/assert_state_change.py | 4 +- slither/detectors/statements/calls_in_loop.py | 1 + .../formatters/attributes/const_functions.py | 8 +- .../formatters/attributes/constant_pragma.py | 20 +-- slither/printers/call/call_graph.py | 117 ++++++++++-------- slither/printers/functions/authorization.py | 22 ++-- slither/printers/functions/cfg.py | 7 +- slither/slithir/operations/call.py | 6 +- slither/slithir/operations/codesize.py | 2 +- slither/slithir/operations/condition.py | 24 ++-- slither/slithir/variables/constant.py | 20 +-- slither/tools/mutator/utils/command_line.py | 2 +- slither/tools/similarity/cache.py | 5 +- .../upgradeability/checks/abstract_checks.py | 4 +- .../tools/upgradeability/checks/constant.py | 17 ++- .../upgradeability/utils/command_line.py | 2 +- slither/utils/code_complexity.py | 4 +- slither/utils/colors.py | 4 +- slither/visitors/expression/expression.py | 16 +-- 25 files changed, 189 insertions(+), 152 deletions(-) diff --git a/slither/core/compilation_unit.py b/slither/core/compilation_unit.py index f54f08ab3..8d7167451 100644 --- a/slither/core/compilation_unit.py +++ b/slither/core/compilation_unit.py @@ -57,7 +57,7 @@ class SlitherCompilationUnit(Context): self._storage_layouts: Dict[str, Dict[str, Tuple[int, int]]] = {} - self._contract_with_missing_inheritance = set() + self._contract_with_missing_inheritance: Set[Contract] = set() self._source_units: Dict[int, str] = {} @@ -88,7 +88,8 @@ class SlitherCompilationUnit(Context): @property def solc_version(self) -> str: - return self._crytic_compile_compilation_unit.compiler_version.version + # TODO: make version a non optional argument of compiler version in cc + return self._crytic_compile_compilation_unit.compiler_version.version # type:ignore @property def crytic_compile_compilation_unit(self) -> CompilationUnit: @@ -162,13 +163,14 @@ class SlitherCompilationUnit(Context): @property def functions_and_modifiers(self) -> List[Function]: - return self.functions + self.modifiers + return self.functions + list(self.modifiers) def propagate_function_calls(self) -> None: for f in self.functions_and_modifiers: for node in f.nodes: for ir in node.irs_ssa: if isinstance(ir, InternalCall): + assert ir.function ir.function.add_reachable_from_node(node, ir) # endregion @@ -181,8 +183,8 @@ class SlitherCompilationUnit(Context): @property def state_variables(self) -> List[StateVariable]: if self._all_state_variables is None: - state_variables = [c.state_variables for c in self.contracts] - state_variables = [item for sublist in state_variables for item in sublist] + state_variabless = [c.state_variables for c in self.contracts] + state_variables = [item for sublist in state_variabless for item in sublist] self._all_state_variables = set(state_variables) return list(self._all_state_variables) @@ -229,7 +231,7 @@ class SlitherCompilationUnit(Context): ################################################################################### @property - def contracts_with_missing_inheritance(self) -> Set: + def contracts_with_missing_inheritance(self) -> Set[Contract]: return self._contract_with_missing_inheritance # endregion @@ -266,6 +268,7 @@ class SlitherCompilationUnit(Context): if var.is_constant or var.is_immutable: continue + assert var.type size, new_slot = var.type.storage_size if new_slot: @@ -285,7 +288,7 @@ class SlitherCompilationUnit(Context): else: offset += size - def storage_layout_of(self, contract, var) -> Tuple[int, int]: + def storage_layout_of(self, contract: Contract, var: StateVariable) -> Tuple[int, int]: return self._storage_layouts[contract.name][var.canonical_name] # endregion diff --git a/slither/core/expressions/call_expression.py b/slither/core/expressions/call_expression.py index 1dbc4074a..6708dda7e 100644 --- a/slither/core/expressions/call_expression.py +++ b/slither/core/expressions/call_expression.py @@ -22,7 +22,7 @@ class CallExpression(Expression): # pylint: disable=too-many-instance-attribute return self._value @call_value.setter - def call_value(self, v): + def call_value(self, v: Optional[Expression]) -> None: self._value = v @property @@ -30,15 +30,15 @@ class CallExpression(Expression): # pylint: disable=too-many-instance-attribute return self._gas @call_gas.setter - def call_gas(self, gas): + def call_gas(self, gas: Optional[Expression]) -> None: self._gas = gas @property - def call_salt(self): + def call_salt(self) -> Optional[Expression]: return self._salt @call_salt.setter - def call_salt(self, salt): + def call_salt(self, salt: Optional[Expression]) -> None: self._salt = salt @property diff --git a/slither/core/expressions/conditional_expression.py b/slither/core/expressions/conditional_expression.py index 818425ba1..3c0afdb4a 100644 --- a/slither/core/expressions/conditional_expression.py +++ b/slither/core/expressions/conditional_expression.py @@ -42,7 +42,7 @@ class ConditionalExpression(Expression): def then_expression(self) -> Expression: return self._then_expression - def __str__(self): + def __str__(self) -> str: return ( "if " + str(self._if_expression) diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index c775e7c98..0d610c928 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -77,12 +77,13 @@ class Variable(SourceMapping): self._name = name @property - def type(self) -> Optional[Union[Type, List[Type]]]: + def type(self) -> Optional[Type]: return self._type @type.setter - def type(self, types: Union[Type, List[Type]]): - self._type = types + def type(self, new_type: Type) -> None: + assert isinstance(new_type, Type) + self._type = new_type @property def is_constant(self) -> bool: diff --git a/slither/detectors/compiler_bugs/array_by_reference.py b/slither/detectors/compiler_bugs/array_by_reference.py index ba4cadcc7..04dfe085a 100644 --- a/slither/detectors/compiler_bugs/array_by_reference.py +++ b/slither/detectors/compiler_bugs/array_by_reference.py @@ -2,6 +2,9 @@ Detects the passing of arrays located in memory to functions which expect to modify arrays via storage reference. """ from typing import List, Set, Tuple, Union + +from slither.core.declarations import Function +from slither.core.variables import Variable from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, @@ -93,12 +96,7 @@ As a result, Bob's usage of the contract is incorrect.""" @staticmethod def detect_calls_passing_ref_to_function( contracts: List[Contract], array_modifying_funcs: Set[FunctionContract] - ) -> List[ - Union[ - Tuple[Node, StateVariable, FunctionContract], - Tuple[Node, LocalVariable, FunctionContract], - ] - ]: + ) -> List[Tuple[Node, Variable, Union[Function, Variable]]]: """ Obtains all calls passing storage arrays by value to a function which cannot write to them successfully. :param contracts: The collection of contracts to check for problematic calls in. @@ -109,12 +107,7 @@ As a result, Bob's usage of the contract is incorrect.""" write to the array unsuccessfully. """ # Define our resulting array. - results: List[ - Union[ - Tuple[Node, StateVariable, FunctionContract], - Tuple[Node, LocalVariable, FunctionContract], - ] - ] = [] + results: List[Tuple[Node, Variable, Union[Function, Variable]]] = [] # Verify we have functions in our list to check for. if not array_modifying_funcs: diff --git a/slither/detectors/slither/name_reused.py b/slither/detectors/slither/name_reused.py index e8a40881a..babce6389 100644 --- a/slither/detectors/slither/name_reused.py +++ b/slither/detectors/slither/name_reused.py @@ -1,7 +1,8 @@ from collections import defaultdict -from typing import Any, List +from typing import List from slither.core.compilation_unit import SlitherCompilationUnit +from slither.core.declarations import Contract from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, @@ -10,7 +11,7 @@ from slither.detectors.abstract_detector import ( from slither.utils.output import Output -def _find_missing_inheritance(compilation_unit: SlitherCompilationUnit) -> List[Any]: +def _find_missing_inheritance(compilation_unit: SlitherCompilationUnit) -> List[Contract]: """ Filter contracts with missing inheritance to return only the "most base" contracts in the inheritance tree. diff --git a/slither/detectors/statements/assert_state_change.py b/slither/detectors/statements/assert_state_change.py index 62299202e..769d730b8 100644 --- a/slither/detectors/statements/assert_state_change.py +++ b/slither/detectors/statements/assert_state_change.py @@ -40,7 +40,9 @@ def detect_assert_state_change( any( ir for ir in node.irs - if isinstance(ir, InternalCall) and ir.function.state_variables_written + if isinstance(ir, InternalCall) + and ir.function + and ir.function.state_variables_written ) ): results.append((function, node)) diff --git a/slither/detectors/statements/calls_in_loop.py b/slither/detectors/statements/calls_in_loop.py index b3a177ee6..d40d18f59 100644 --- a/slither/detectors/statements/calls_in_loop.py +++ b/slither/detectors/statements/calls_in_loop.py @@ -48,6 +48,7 @@ def call_in_loop( continue ret.append(ir.node) if isinstance(ir, (InternalCall)): + assert ir.function call_in_loop(ir.function.entry_point, in_loop_counter, visited, ret) for son in node.sons: diff --git a/slither/formatters/attributes/const_functions.py b/slither/formatters/attributes/const_functions.py index 310abe0b9..feb404f7b 100644 --- a/slither/formatters/attributes/const_functions.py +++ b/slither/formatters/attributes/const_functions.py @@ -34,8 +34,12 @@ def custom_format(compilation_unit: SlitherCompilationUnit, result: Dict) -> Non def _patch( - compilation_unit: SlitherCompilationUnit, result, in_file, modify_loc_start, modify_loc_end -): + compilation_unit: SlitherCompilationUnit, + result: Dict, + in_file: str, + modify_loc_start: int, + modify_loc_end: int, +) -> None: in_file_str = compilation_unit.core.source_code[in_file].encode("utf8") old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end] # Find the keywords view|pure|constant and remove them diff --git a/slither/formatters/attributes/constant_pragma.py b/slither/formatters/attributes/constant_pragma.py index 108a8fa08..1127b1e43 100644 --- a/slither/formatters/attributes/constant_pragma.py +++ b/slither/formatters/attributes/constant_pragma.py @@ -1,5 +1,5 @@ import re -from typing import Dict +from typing import Dict, List, Union from slither.core.compilation_unit import SlitherCompilationUnit from slither.formatters.exceptions import FormatImpossible @@ -21,7 +21,7 @@ PATTERN = re.compile(r"(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)") def custom_format(slither: SlitherCompilationUnit, result: Dict) -> None: elements = result["elements"] - versions_used = [] + versions_used: List[str] = [] for element in elements: versions_used.append("".join(element["type_specific_fields"]["directive"][1:])) solc_version_replace = _analyse_versions(versions_used) @@ -36,7 +36,7 @@ def custom_format(slither: SlitherCompilationUnit, result: Dict) -> None: ) -def _analyse_versions(used_solc_versions): +def _analyse_versions(used_solc_versions: List[str]) -> str: replace_solc_versions = [] for version in used_solc_versions: replace_solc_versions.append(_determine_solc_version_replacement(version)) @@ -45,7 +45,7 @@ def _analyse_versions(used_solc_versions): return replace_solc_versions[0] -def _determine_solc_version_replacement(used_solc_version): +def _determine_solc_version_replacement(used_solc_version: str) -> str: versions = PATTERN.findall(used_solc_version) if len(versions) == 1: version = versions[0] @@ -67,10 +67,16 @@ def _determine_solc_version_replacement(used_solc_version): raise FormatImpossible("Unknown version!") +# pylint: disable=too-many-arguments def _patch( - slither, result, in_file, pragma, modify_loc_start, modify_loc_end -): # pylint: disable=too-many-arguments - in_file_str = slither.source_code[in_file].encode("utf8") + slither: SlitherCompilationUnit, + result: Dict, + in_file: str, + pragma: Union[str, bytes], + modify_loc_start: int, + modify_loc_end: int, +) -> None: + in_file_str = slither.core.source_code[in_file].encode("utf8") old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end] create_patch( result, diff --git a/slither/printers/call/call_graph.py b/slither/printers/call/call_graph.py index e10db3f76..0a4df0c65 100644 --- a/slither/printers/call/call_graph.py +++ b/slither/printers/call/call_graph.py @@ -6,33 +6,36 @@ The output is a dot file named filename.dot """ from collections import defaultdict -from slither.printers.abstract_printer import AbstractPrinter -from slither.core.declarations.solidity_variables import SolidityFunction +from typing import Optional, Union, Dict, Set, Tuple, Sequence + +from slither.core.declarations import Contract, FunctionContract from slither.core.declarations.function import Function +from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.variables.variable import Variable +from slither.printers.abstract_printer import AbstractPrinter -def _contract_subgraph(contract): +def _contract_subgraph(contract: Contract) -> str: return f"cluster_{contract.id}_{contract.name}" # return unique id for contract function to use as node name -def _function_node(contract, function): +def _function_node(contract: Contract, function: Union[Function, Variable]) -> str: return f"{contract.id}_{function.name}" # return unique id for solidity function to use as node name -def _solidity_function_node(solidity_function): +def _solidity_function_node(solidity_function: SolidityFunction) -> str: return f"{solidity_function.name}" # return dot language string to add graph edge -def _edge(from_node, to_node): +def _edge(from_node: str, to_node: str) -> str: return f'"{from_node}" -> "{to_node}"' # return dot language string to add graph node (with optional label) -def _node(node, label=None): +def _node(node: str, label: Optional[str] = None) -> str: return " ".join( ( f'"{node}"', @@ -43,13 +46,13 @@ def _node(node, label=None): # pylint: disable=too-many-arguments def _process_internal_call( - contract, - function, - internal_call, - contract_calls, - solidity_functions, - solidity_calls, -): + contract: Contract, + function: Function, + internal_call: Union[Function, SolidityFunction], + contract_calls: Dict[Contract, Set[str]], + solidity_functions: Set[str], + solidity_calls: Set[str], +) -> None: if isinstance(internal_call, (Function)): contract_calls[contract].add( _edge( @@ -69,11 +72,15 @@ def _process_internal_call( ) -def _render_external_calls(external_calls): +def _render_external_calls(external_calls: Set[str]) -> str: return "\n".join(external_calls) -def _render_internal_calls(contract, contract_functions, contract_calls): +def _render_internal_calls( + contract: Contract, + contract_functions: Dict[Contract, Set[str]], + contract_calls: Dict[Contract, Set[str]], +) -> str: lines = [] lines.append(f"subgraph {_contract_subgraph(contract)} {{") @@ -87,7 +94,7 @@ def _render_internal_calls(contract, contract_functions, contract_calls): return "\n".join(lines) -def _render_solidity_calls(solidity_functions, solidity_calls): +def _render_solidity_calls(solidity_functions: Set[str], solidity_calls: Set[str]) -> str: lines = [] lines.append("subgraph cluster_solidity {") @@ -102,13 +109,13 @@ def _render_solidity_calls(solidity_functions, solidity_calls): def _process_external_call( - contract, - function, - external_call, - contract_functions, - external_calls, - all_contracts, -): + contract: Contract, + function: Function, + external_call: Tuple[Contract, Union[Function, Variable]], + contract_functions: Dict[Contract, Set[str]], + external_calls: Set[str], + all_contracts: Set[Contract], +) -> None: external_contract, external_function = external_call if not external_contract in all_contracts: @@ -133,15 +140,15 @@ def _process_external_call( # pylint: disable=too-many-arguments def _process_function( - contract, - function, - contract_functions, - contract_calls, - solidity_functions, - solidity_calls, - external_calls, - all_contracts, -): + contract: Contract, + function: Function, + contract_functions: Dict[Contract, Set[str]], + contract_calls: Dict[Contract, Set[str]], + solidity_functions: Set[str], + solidity_calls: Set[str], + external_calls: Set[str], + all_contracts: Set[Contract], +) -> None: contract_functions[contract].add( _node(_function_node(contract, function), function.name), ) @@ -166,29 +173,35 @@ def _process_function( ) -def _process_functions(functions): - contract_functions = defaultdict(set) # contract -> contract functions nodes - contract_calls = defaultdict(set) # contract -> contract calls edges +def _process_functions(functions: Sequence[Function]) -> str: + # TODO add support for top level function + + contract_functions: Dict[Contract, Set[str]] = defaultdict( + set + ) # contract -> contract functions nodes + contract_calls: Dict[Contract, Set[str]] = defaultdict(set) # contract -> contract calls edges - solidity_functions = set() # solidity function nodes - solidity_calls = set() # solidity calls edges - external_calls = set() # external calls edges + solidity_functions: Set[str] = set() # solidity function nodes + solidity_calls: Set[str] = set() # solidity calls edges + external_calls: Set[str] = set() # external calls edges all_contracts = set() for function in functions: - all_contracts.add(function.contract_declarer) + if isinstance(function, FunctionContract): + all_contracts.add(function.contract_declarer) for function in functions: - _process_function( - function.contract_declarer, - function, - contract_functions, - contract_calls, - solidity_functions, - solidity_calls, - external_calls, - all_contracts, - ) + if isinstance(function, FunctionContract): + _process_function( + function.contract_declarer, + function, + contract_functions, + contract_calls, + solidity_functions, + solidity_calls, + external_calls, + all_contracts, + ) render_internal_calls = "" for contract in all_contracts: @@ -241,7 +254,9 @@ class PrinterCallGraph(AbstractPrinter): function.canonical_name: function for function in all_functions } content = "\n".join( - ["strict digraph {"] + [_process_functions(all_functions_as_dict.values())] + ["}"] + ["strict digraph {"] + + [_process_functions(list(all_functions_as_dict.values()))] + + ["}"] ) f.write(content) results.append((all_contracts_filename, content)) diff --git a/slither/printers/functions/authorization.py b/slither/printers/functions/authorization.py index ab61d354e..48b94c297 100644 --- a/slither/printers/functions/authorization.py +++ b/slither/printers/functions/authorization.py @@ -1,10 +1,12 @@ """ Module printing summary of the contract """ +from typing import List from slither.printers.abstract_printer import AbstractPrinter from slither.core.declarations.function import Function from slither.utils.myprettytable import MyPrettyTable +from slither.utils.output import Output class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): @@ -15,11 +17,15 @@ class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#variables-written-and-authorization" @staticmethod - def get_msg_sender_checks(function): - all_functions = function.all_internal_calls() + [function] + function.modifiers + def get_msg_sender_checks(function: Function) -> List[str]: + all_functions = ( + [f for f in function.all_internal_calls() if isinstance(f, Function)] + + [function] + + [m for m in function.modifiers if isinstance(m, Function)] + ) - all_nodes = [f.nodes for f in all_functions if isinstance(f, Function)] - all_nodes = [item for sublist in all_nodes for item in sublist] + all_nodes_ = [f.nodes for f in all_functions] + all_nodes = [item for sublist in all_nodes_ for item in sublist] all_conditional_nodes = [ n for n in all_nodes if n.contains_if() or n.contains_require_or_assert() @@ -31,7 +37,7 @@ class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): ] return all_conditional_nodes_on_msg_sender - def output(self, _filename): + def output(self, _filename: str) -> Output: """ _filename is not used Args: @@ -40,7 +46,7 @@ class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): txt = "" all_tables = [] - for contract in self.contracts: + for contract in self.contracts: # type: ignore if contract.is_top_level: continue txt += f"\nContract {contract.name}\n" @@ -49,7 +55,9 @@ class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): ) for function in contract.functions: - state_variables_written = [v.name for v in function.all_state_variables_written()] + state_variables_written = [ + v.name for v in function.all_state_variables_written() if v.name + ] msg_sender_condition = self.get_msg_sender_checks(function) table.add_row( [ diff --git a/slither/printers/functions/cfg.py b/slither/printers/functions/cfg.py index 03e010ff4..3c75f723f 100644 --- a/slither/printers/functions/cfg.py +++ b/slither/printers/functions/cfg.py @@ -1,4 +1,5 @@ from slither.printers.abstract_printer import AbstractPrinter +from slither.utils.output import Output class CFG(AbstractPrinter): @@ -8,7 +9,7 @@ class CFG(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#cfg" - def output(self, filename): + def output(self, filename: str) -> Output: """ _filename is not used Args: @@ -17,10 +18,10 @@ class CFG(AbstractPrinter): info = "" all_files = [] - for contract in self.contracts: + for contract in self.contracts: # type: ignore if contract.is_top_level: continue - for function in contract.functions + contract.modifiers: + for function in contract.functions + list(contract.modifiers): if filename: new_filename = f"{filename}-{contract.name}-{function.full_name}.dot" else: diff --git a/slither/slithir/operations/call.py b/slither/slithir/operations/call.py index 37a2fe0b3..816c56e1d 100644 --- a/slither/slithir/operations/call.py +++ b/slither/slithir/operations/call.py @@ -8,14 +8,14 @@ from slither.slithir.operations.operation import Operation class Call(Operation): def __init__(self) -> None: super().__init__() - self._arguments = [] + self._arguments: List[Variable] = [] @property - def arguments(self): + def arguments(self) -> List[Variable]: return self._arguments @arguments.setter - def arguments(self, v): + def arguments(self, v: List[Variable]) -> None: self._arguments = v # pylint: disable=no-self-use diff --git a/slither/slithir/operations/codesize.py b/slither/slithir/operations/codesize.py index 6640f4fd8..13aa430eb 100644 --- a/slither/slithir/operations/codesize.py +++ b/slither/slithir/operations/codesize.py @@ -29,5 +29,5 @@ class CodeSize(OperationWithLValue): def value(self) -> LocalVariable: return self._value - def __str__(self): + def __str__(self) -> str: return f"{self.lvalue} -> CODESIZE {self.value}" diff --git a/slither/slithir/operations/condition.py b/slither/slithir/operations/condition.py index 41fb3d933..ccec033d9 100644 --- a/slither/slithir/operations/condition.py +++ b/slither/slithir/operations/condition.py @@ -1,13 +1,7 @@ -from typing import List, Union -from slither.slithir.operations.operation import Operation +from typing import List -from slither.slithir.utils.utils import is_valid_rvalue -from slither.core.variables.local_variable import LocalVariable -from slither.slithir.variables.constant import Constant -from slither.slithir.variables.local_variable import LocalIRVariable -from slither.slithir.variables.temporary import TemporaryVariable -from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA -from slither.core.variables.variable import Variable +from slither.slithir.operations.operation import Operation +from slither.slithir.utils.utils import is_valid_rvalue, RVALUE class Condition(Operation): @@ -18,9 +12,7 @@ class Condition(Operation): def __init__( self, - value: Union[ - LocalVariable, TemporaryVariableSSA, TemporaryVariable, Constant, LocalIRVariable - ], + value: RVALUE, ) -> None: assert is_valid_rvalue(value) super().__init__() @@ -29,14 +21,12 @@ class Condition(Operation): @property def read( self, - ) -> List[ - Union[LocalIRVariable, Constant, LocalVariable, TemporaryVariableSSA, TemporaryVariable] - ]: + ) -> List[RVALUE]: return [self.value] @property - def value(self) -> Variable: + def value(self) -> RVALUE: return self._value - def __str__(self): + def __str__(self) -> str: return f"CONDITION {self.value}" diff --git a/slither/slithir/variables/constant.py b/slither/slithir/variables/constant.py index ddfc9e054..5321e5250 100644 --- a/slither/slithir/variables/constant.py +++ b/slither/slithir/variables/constant.py @@ -28,7 +28,7 @@ class Constant(SlithIRVariable): assert isinstance(constant_type, ElementaryType) self._type = constant_type if constant_type.type in Int + Uint + ["address"]: - self._val = convert_string_to_int(val) + self._val: Union[bool, int, str] = convert_string_to_int(val) elif constant_type.type == "bool": self._val = (val == "true") | (val == "True") else: @@ -41,6 +41,8 @@ class Constant(SlithIRVariable): self._type = ElementaryType("string") self._val = val + self._name = str(self._val) + @property def value(self) -> Union[bool, int, str]: """ @@ -63,20 +65,18 @@ class Constant(SlithIRVariable): def __str__(self) -> str: return str(self.value) - @property - def name(self) -> str: - return str(self) - - def __eq__(self, other: Union["Constant", str]) -> bool: + def __eq__(self, other: object) -> bool: return self.value == other - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return self.value != other - def __lt__(self, other): - return self.value < other + def __lt__(self, other: object) -> bool: + if not isinstance(other, (Constant, str)): + raise NotImplementedError + return self.value < other # type: ignore - def __repr__(self): + def __repr__(self) -> str: return f"{str(self.value)}" def __hash__(self) -> int: diff --git a/slither/tools/mutator/utils/command_line.py b/slither/tools/mutator/utils/command_line.py index 840976ccf..feb479c5c 100644 --- a/slither/tools/mutator/utils/command_line.py +++ b/slither/tools/mutator/utils/command_line.py @@ -18,6 +18,6 @@ def output_mutators(mutators_classes: List[Type[AbstractMutator]]) -> None: mutators_list = sorted(mutators_list, key=lambda element: (element[2], element[3], element[0])) idx = 1 for (argument, help_info, fault_class, fault_nature) in mutators_list: - table.add_row([idx, argument, help_info, fault_class, fault_nature]) + table.add_row([str(idx), argument, help_info, fault_class, fault_nature]) idx = idx + 1 print(table) diff --git a/slither/tools/similarity/cache.py b/slither/tools/similarity/cache.py index 53fc7f5f0..ccd64b84b 100644 --- a/slither/tools/similarity/cache.py +++ b/slither/tools/similarity/cache.py @@ -1,4 +1,5 @@ import sys +from typing import Dict, Optional try: import numpy as np @@ -8,7 +9,7 @@ except ImportError: sys.exit(-1) -def load_cache(infile, nsamples=None): +def load_cache(infile: str, nsamples: Optional[int] = None) -> Dict: cache = {} with np.load(infile, allow_pickle=True) as data: array = data["arr_0"][0] @@ -20,5 +21,5 @@ def load_cache(infile, nsamples=None): return cache -def save_cache(cache, outfile): +def save_cache(cache: Dict, outfile: str) -> None: np.savez(outfile, [np.array(cache)]) diff --git a/slither/tools/upgradeability/checks/abstract_checks.py b/slither/tools/upgradeability/checks/abstract_checks.py index 016be2647..a3ab137a3 100644 --- a/slither/tools/upgradeability/checks/abstract_checks.py +++ b/slither/tools/upgradeability/checks/abstract_checks.py @@ -34,6 +34,8 @@ classification_txt = { CheckClassification.HIGH: "High", } +CHECK_INFO = List[Union[str, SupportedOutput]] + class AbstractCheck(metaclass=abc.ABCMeta): ARGUMENT = "" @@ -140,7 +142,7 @@ class AbstractCheck(metaclass=abc.ABCMeta): def generate_result( self, - info: Union[str, List[Union[str, SupportedOutput]]], + info: CHECK_INFO, additional_fields: Optional[Dict] = None, ) -> Output: output = Output( diff --git a/slither/tools/upgradeability/checks/constant.py b/slither/tools/upgradeability/checks/constant.py index a5a80bf5a..bd9814649 100644 --- a/slither/tools/upgradeability/checks/constant.py +++ b/slither/tools/upgradeability/checks/constant.py @@ -1,7 +1,11 @@ +from typing import List + from slither.tools.upgradeability.checks.abstract_checks import ( AbstractCheck, CheckClassification, + CHECK_INFO, ) +from slither.utils.output import Output class WereConstant(AbstractCheck): @@ -47,10 +51,12 @@ Do not remove `constant` from a state variables during an update. REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True - def _check(self): + def _check(self) -> List[Output]: contract_v1 = self.contract contract_v2 = self.contract_v2 + if contract_v2 is None: + raise Exception("were-constant requires a V2 contract") state_variables_v1 = contract_v1.state_variables state_variables_v2 = contract_v2.state_variables @@ -81,7 +87,7 @@ Do not remove `constant` from a state variables during an update. v2_additional_variables -= 1 idx_v2 += 1 continue - info = [state_v1, " was constant, but ", state_v2, "is not.\n"] + info: CHECK_INFO = [state_v1, " was constant, but ", state_v2, "is not.\n"] json = self.generate_result(info) results.append(json) @@ -134,10 +140,13 @@ Do not make an existing state variable `constant`. REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True - def _check(self): + def _check(self) -> List[Output]: contract_v1 = self.contract contract_v2 = self.contract_v2 + if contract_v2 is None: + raise Exception("became-constant requires a V2 contract") + state_variables_v1 = contract_v1.state_variables state_variables_v2 = contract_v2.state_variables @@ -169,7 +178,7 @@ Do not make an existing state variable `constant`. idx_v2 += 1 continue elif state_v2.is_constant: - info = [state_v1, " was not constant but ", state_v2, " is.\n"] + info: CHECK_INFO = [state_v1, " was not constant but ", state_v2, " is.\n"] json = self.generate_result(info) results.append(json) diff --git a/slither/tools/upgradeability/utils/command_line.py b/slither/tools/upgradeability/utils/command_line.py index 88b61ceed..c5767a522 100644 --- a/slither/tools/upgradeability/utils/command_line.py +++ b/slither/tools/upgradeability/utils/command_line.py @@ -63,7 +63,7 @@ def output_detectors(detector_classes: List[Type[AbstractCheck]]) -> None: def output_to_markdown(detector_classes: List[Type[AbstractCheck]], _filter_wiki: str) -> None: - def extract_help(cls: AbstractCheck) -> str: + def extract_help(cls: Type[AbstractCheck]) -> str: if cls.WIKI == "": return cls.HELP return f"[{cls.HELP}]({cls.WIKI})" diff --git a/slither/utils/code_complexity.py b/slither/utils/code_complexity.py index a389663b3..aa7838499 100644 --- a/slither/utils/code_complexity.py +++ b/slither/utils/code_complexity.py @@ -35,7 +35,7 @@ def compute_strongly_connected_components(function: "Function") -> List[List["No components = [] l = [] - def visit(node): + def visit(node: "Node") -> None: if not visited[node]: visited[node] = True for son in node.sons: @@ -45,7 +45,7 @@ def compute_strongly_connected_components(function: "Function") -> List[List["No for n in function.nodes: visit(n) - def assign(node: "Node", root: List["Node"]): + def assign(node: "Node", root: List["Node"]) -> None: if not assigned[node]: assigned[node] = True root.append(node) diff --git a/slither/utils/colors.py b/slither/utils/colors.py index 5d688489b..1a2ff1da3 100644 --- a/slither/utils/colors.py +++ b/slither/utils/colors.py @@ -28,7 +28,7 @@ def enable_windows_virtual_terminal_sequences() -> bool: try: # pylint: disable=import-outside-toplevel - from ctypes import windll, byref + from ctypes import windll, byref # type: ignore from ctypes.wintypes import DWORD, HANDLE kernel32 = windll.kernel32 @@ -65,7 +65,7 @@ def enable_windows_virtual_terminal_sequences() -> bool: return True -def set_colorization_enabled(enabled: bool): +def set_colorization_enabled(enabled: bool) -> None: """ Sets the enabled state of output colorization. :param enabled: Boolean indicating whether output should be colorized. diff --git a/slither/visitors/expression/expression.py b/slither/visitors/expression/expression.py index 464ea1285..0bdd123a3 100644 --- a/slither/visitors/expression/expression.py +++ b/slither/visitors/expression/expression.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Any +from typing import Any from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.binary_operation import BinaryOperation @@ -29,7 +29,7 @@ class ExpressionVisitor: self._result: Any = None self._visit_expression(self.expression) - def result(self) -> Optional[bool]: + def result(self) -> Any: return self._result @property @@ -146,7 +146,7 @@ class ExpressionVisitor: def _visit_new_contract(self, expression: NewContract) -> None: pass - def _visit_new_elementary_type(self, expression): + def _visit_new_elementary_type(self, expression: Expression) -> None: pass def _visit_tuple_expression(self, expression: TupleExpression) -> None: @@ -162,7 +162,7 @@ class ExpressionVisitor: # pre visit - def _pre_visit(self, expression) -> None: # pylint: disable=too-many-branches + def _pre_visit(self, expression: Expression) -> None: # pylint: disable=too-many-branches if isinstance(expression, AssignmentOperation): self._pre_assignement_operation(expression) @@ -251,7 +251,7 @@ class ExpressionVisitor: def _pre_new_contract(self, expression: NewContract) -> None: pass - def _pre_new_elementary_type(self, expression): + def _pre_new_elementary_type(self, expression: NewElementaryType) -> None: pass def _pre_tuple_expression(self, expression: TupleExpression) -> None: @@ -265,7 +265,7 @@ class ExpressionVisitor: # post visit - def _post_visit(self, expression) -> None: # pylint: disable=too-many-branches + def _post_visit(self, expression: Expression) -> None: # pylint: disable=too-many-branches if isinstance(expression, AssignmentOperation): self._post_assignement_operation(expression) @@ -328,7 +328,7 @@ class ExpressionVisitor: def _post_call_expression(self, expression: CallExpression) -> None: pass - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: pass def _post_elementary_type_name_expression( @@ -354,7 +354,7 @@ class ExpressionVisitor: def _post_new_contract(self, expression: NewContract) -> None: pass - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: pass def _post_tuple_expression(self, expression: TupleExpression) -> None: From 561148408ea525f25cfc5961a75b5dadb8f35412 Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Mon, 20 Feb 2023 16:03:28 +0100 Subject: [PATCH 09/34] Improve expression visitors --- slither/detectors/statements/unary.py | 20 +++++-- .../visitors/expression/constants_folding.py | 54 ++++++++++++------- slither/visitors/expression/export_values.py | 37 +++++++++---- slither/visitors/expression/expression.py | 21 ++++---- .../visitors/expression/expression_printer.py | 50 ++++++++++------- slither/visitors/expression/find_calls.py | 13 +++-- .../visitors/expression/has_conditional.py | 10 ++-- slither/visitors/expression/read_var.py | 16 ++++-- slither/visitors/expression/write_var.py | 10 +++- 9 files changed, 150 insertions(+), 81 deletions(-) diff --git a/slither/detectors/statements/unary.py b/slither/detectors/statements/unary.py index 5bb8d9c3c..2a8d78a34 100644 --- a/slither/detectors/statements/unary.py +++ b/slither/detectors/statements/unary.py @@ -4,29 +4,39 @@ Module detecting the incorrect use of unary expressions from typing import List from slither.core.expressions.assignment_operation import AssignmentOperation +from slither.core.expressions.expression import Expression from slither.core.expressions.unary_operation import UnaryOperationType, UnaryOperation from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.utils.output import Output from slither.visitors.expression.expression import ExpressionVisitor - +# pylint: disable=too-few-public-methods class InvalidUnaryExpressionDetector(ExpressionVisitor): + def __init__(self, expression: Expression) -> None: + self.result: bool = False + super().__init__(expression) + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: if isinstance(expression.expression_right, UnaryOperation): if expression.expression_right.type == UnaryOperationType.PLUS_PRE: # This is defined in ExpressionVisitor but pylint # Seems to think its not # pylint: disable=attribute-defined-outside-init - self._result = True + self.result = True +# pylint: disable=too-few-public-methods class InvalidUnaryStateVariableDetector(ExpressionVisitor): + def __init__(self, expression: Expression) -> None: + self.result: bool = False + super().__init__(expression) + def _post_unary_operation(self, expression: UnaryOperation) -> None: if expression.type == UnaryOperationType.PLUS_PRE: # This is defined in ExpressionVisitor but pylint # Seems to think its not # pylint: disable=attribute-defined-outside-init - self._result = True + self.result = True class IncorrectUnaryExpressionDetection(AbstractDetector): @@ -72,7 +82,7 @@ contract Bug{ for variable in c.state_variables: if ( variable.expression - and InvalidUnaryStateVariableDetector(variable.expression).result() + and InvalidUnaryStateVariableDetector(variable.expression).result ): info = [variable, f" uses an dangerous unary operator: {variable.expression}\n"] json = self.generate_result(info) @@ -80,7 +90,7 @@ contract Bug{ for f in c.functions_and_modifiers_declared: for node in f.nodes: - if node.expression and InvalidUnaryExpressionDetector(node.expression).result(): + if node.expression and InvalidUnaryExpressionDetector(node.expression).result: info = [node.function, " uses an dangerous unary operator: ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index 5f419ef99..158306f39 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -1,6 +1,7 @@ from fractions import Fraction -from typing import Union, TYPE_CHECKING +from typing import Union +from slither.core import expressions from slither.core.expressions import ( BinaryOperationType, Literal, @@ -14,9 +15,7 @@ from slither.core.expressions import ( from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int from slither.visitors.expression.expression import ExpressionVisitor - -if TYPE_CHECKING: - from slither.core.solidity_types.elementary_type import ElementaryType +from slither.core.solidity_types.elementary_type import ElementaryType class NotConstant(Exception): @@ -45,11 +44,19 @@ class ConstantFolding(ExpressionVisitor): def __init__( self, expression: CONSTANT_TYPES_OPERATIONS, custom_type: Union[str, "ElementaryType"] ) -> None: - self._type = custom_type + if isinstance(custom_type, str): + custom_type = ElementaryType(custom_type) + self._type: ElementaryType = custom_type super().__init__(expression) + @property + def expression(self) -> CONSTANT_TYPES_OPERATIONS: + # We make the assumption that the expression is always a CONSTANT_TYPES_OPERATIONS + # Other expression are not supported for constant unfolding + return self._expression # type: ignore + def result(self) -> "Literal": - value = get_val(self._expression) + value = get_val(self.expression) if isinstance(value, Fraction): value = int(value) # emulate 256-bit wrapping @@ -62,9 +69,13 @@ class ConstantFolding(ExpressionVisitor): raise NotConstant expr = expression.value.expression # assumption that we won't have infinite loop - if not isinstance(expr, Literal): + # Everything outside of literal + if isinstance( + expr, (BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion) + ): cf = ConstantFolding(expr, self._type) expr = cf.result() + assert isinstance(expr, Literal) set_val(expression, convert_string_to_int(expr.converted_value)) # pylint: disable=too-many-branches @@ -118,7 +129,10 @@ class ConstantFolding(ExpressionVisitor): # Case of uint a = -7; uint[-a] arr; if expression.type == UnaryOperationType.MINUS_PRE: expr = expression.expression - if not isinstance(expr, Literal): + # Everything outside of literal + if isinstance( + expr, (BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion) + ): cf = ConstantFolding(expr, self._type) expr = cf.result() assert isinstance(expr, Literal) @@ -135,34 +149,36 @@ class ConstantFolding(ExpressionVisitor): except ValueError as e: raise NotConstant from e - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: expressions.AssignmentOperation) -> None: raise NotConstant - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: expressions.CallExpression) -> None: raise NotConstant - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: expressions.ConditionalExpression) -> None: raise NotConstant - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression( + self, expression: expressions.ElementaryTypeNameExpression + ) -> None: raise NotConstant - def _post_index_access(self, expression): + def _post_index_access(self, expression: expressions.IndexAccess) -> None: raise NotConstant - def _post_member_access(self, expression): + def _post_member_access(self, expression: expressions.MemberAccess) -> None: raise NotConstant - def _post_new_array(self, expression): + def _post_new_array(self, expression: expressions.NewArray) -> None: raise NotConstant - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: expressions.NewContract) -> None: raise NotConstant - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: expressions.NewElementaryType) -> None: raise NotConstant - def _post_tuple_expression(self, expression): + def _post_tuple_expression(self, expression: expressions.TupleExpression) -> None: if expression.expressions: if len(expression.expressions) == 1: cf = ConstantFolding(expression.expressions[0], self._type) @@ -172,7 +188,7 @@ class ConstantFolding(ExpressionVisitor): return raise NotConstant - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: expressions.TypeConversion) -> None: cf = ConstantFolding(expression.expression, self._type) expr = cf.result() assert isinstance(expr, Literal) diff --git a/slither/visitors/expression/export_values.py b/slither/visitors/expression/export_values.py index f5ca39a96..0c51e7831 100644 --- a/slither/visitors/expression/export_values.py +++ b/slither/visitors/expression/export_values.py @@ -1,4 +1,15 @@ -from typing import Any, List +from typing import Any, List, Optional + +from slither.core.expressions import ( + AssignmentOperation, + ConditionalExpression, + ElementaryTypeNameExpression, + IndexAccess, + NewArray, + NewContract, + UnaryOperation, + NewElementaryType, +) from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.call_expression import CallExpression from slither.core.expressions.identifier import Identifier @@ -25,12 +36,16 @@ def set_val(expression: Expression, val: List[Any]) -> None: class ExportValues(ExpressionVisitor): - def result(self) -> List[Any]: + def __init__(self, expression: Expression) -> None: + self._result: Optional[List[Expression]] = None + super().__init__(expression) + + def result(self) -> List[Expression]: if self._result is None: self._result = list(set(get(self.expression))) return self._result - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: AssignmentOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right @@ -49,20 +64,22 @@ class ExportValues(ExpressionVisitor): val = called + args set_val(expression, val) - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: if_expr = get(expression.if_expression) else_expr = get(expression.else_expression) then_expr = get(expression.then_expression) val = if_expr + else_expr + then_expr set_val(expression, val) - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression( + self, expression: ElementaryTypeNameExpression + ) -> None: set_val(expression, []) def _post_identifier(self, expression: Identifier) -> None: set_val(expression, [expression.value]) - def _post_index_access(self, expression): + def _post_index_access(self, expression: IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = left + right @@ -76,13 +93,13 @@ class ExportValues(ExpressionVisitor): val = expr set_val(expression, val) - def _post_new_array(self, expression): + def _post_new_array(self, expression: NewArray) -> None: set_val(expression, []) - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: set_val(expression, []) def _post_tuple_expression(self, expression: TupleExpression) -> None: @@ -95,7 +112,7 @@ class ExportValues(ExpressionVisitor): val = expr set_val(expression, val) - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: UnaryOperation) -> None: expr = get(expression.expression) val = expr set_val(expression, val) diff --git a/slither/visitors/expression/expression.py b/slither/visitors/expression/expression.py index 464ea1285..41886a102 100644 --- a/slither/visitors/expression/expression.py +++ b/slither/visitors/expression/expression.py @@ -1,5 +1,4 @@ import logging -from typing import Optional, Any from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.binary_operation import BinaryOperation @@ -22,16 +21,14 @@ from slither.exceptions import SlitherError logger = logging.getLogger("ExpressionVisitor") +# pylint: disable=too-few-public-methods class ExpressionVisitor: def __init__(self, expression: Expression) -> None: - # Inherited class must declared their variables prior calling super().__init__ + super().__init__() + # Inherited class must declare their variables prior calling super().__init__ self._expression = expression - self._result: Any = None self._visit_expression(self.expression) - def result(self) -> Optional[bool]: - return self._result - @property def expression(self) -> Expression: return self._expression @@ -146,7 +143,7 @@ class ExpressionVisitor: def _visit_new_contract(self, expression: NewContract) -> None: pass - def _visit_new_elementary_type(self, expression): + def _visit_new_elementary_type(self, expression: NewElementaryType) -> None: pass def _visit_tuple_expression(self, expression: TupleExpression) -> None: @@ -162,7 +159,7 @@ class ExpressionVisitor: # pre visit - def _pre_visit(self, expression) -> None: # pylint: disable=too-many-branches + def _pre_visit(self, expression: Expression) -> None: # pylint: disable=too-many-branches if isinstance(expression, AssignmentOperation): self._pre_assignement_operation(expression) @@ -251,7 +248,7 @@ class ExpressionVisitor: def _pre_new_contract(self, expression: NewContract) -> None: pass - def _pre_new_elementary_type(self, expression): + def _pre_new_elementary_type(self, expression: NewElementaryType) -> None: pass def _pre_tuple_expression(self, expression: TupleExpression) -> None: @@ -265,7 +262,7 @@ class ExpressionVisitor: # post visit - def _post_visit(self, expression) -> None: # pylint: disable=too-many-branches + def _post_visit(self, expression: Expression) -> None: # pylint: disable=too-many-branches if isinstance(expression, AssignmentOperation): self._post_assignement_operation(expression) @@ -328,7 +325,7 @@ class ExpressionVisitor: def _post_call_expression(self, expression: CallExpression) -> None: pass - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: pass def _post_elementary_type_name_expression( @@ -354,7 +351,7 @@ class ExpressionVisitor: def _post_new_contract(self, expression: NewContract) -> None: pass - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: pass def _post_tuple_expression(self, expression: TupleExpression) -> None: diff --git a/slither/visitors/expression/expression_printer.py b/slither/visitors/expression/expression_printer.py index 317e1ace6..601627c02 100644 --- a/slither/visitors/expression/expression_printer.py +++ b/slither/visitors/expression/expression_printer.py @@ -1,97 +1,107 @@ +from typing import Optional + +from slither.core import expressions +from slither.core.expressions.expression import Expression from slither.visitors.expression.expression import ExpressionVisitor -def get(expression): +def get(expression: Expression) -> str: val = expression.context["ExpressionPrinter"] # we delete the item to reduce memory use del expression.context["ExpressionPrinter"] return val -def set_val(expression, val): +def set_val(expression: Expression, val: str) -> None: expression.context["ExpressionPrinter"] = val class ExpressionPrinter(ExpressionVisitor): - def result(self): + def __init__(self, expression: Expression) -> None: + self._result: Optional[str] = None + super().__init__(expression) + + def result(self) -> str: if not self._result: self._result = get(self.expression) return self._result - def _post_assignement_operation(self, expression): + def _post_assignement_operation(self, expression: expressions.AssignmentOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = f"{left} {expression.type} {right}" set_val(expression, val) - def _post_binary_operation(self, expression): + def _post_binary_operation(self, expression: expressions.BinaryOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = f"{left} {expression.type} {right}" set_val(expression, val) - def _post_call_expression(self, expression): + def _post_call_expression(self, expression: expressions.CallExpression) -> None: called = get(expression.called) arguments = ",".join([get(x) for x in expression.arguments if x]) val = f"{called}({arguments})" set_val(expression, val) - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: expressions.ConditionalExpression) -> None: if_expr = get(expression.if_expression) else_expr = get(expression.else_expression) then_expr = get(expression.then_expression) val = f"if {if_expr} then {else_expr} else {then_expr}" set_val(expression, val) - def _post_elementary_type_name_expression(self, expression): + def _post_elementary_type_name_expression( + self, expression: expressions.ElementaryTypeNameExpression + ) -> None: set_val(expression, str(expression.type)) - def _post_identifier(self, expression): + def _post_identifier(self, expression: expressions.Identifier) -> None: set_val(expression, str(expression.value)) - def _post_index_access(self, expression): + def _post_index_access(self, expression: expressions.IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) val = f"{left}[{right}]" set_val(expression, val) - def _post_literal(self, expression): + def _post_literal(self, expression: expressions.Literal) -> None: set_val(expression, str(expression.value)) - def _post_member_access(self, expression): + def _post_member_access(self, expression: expressions.MemberAccess) -> None: expr = get(expression.expression) member_name = str(expression.member_name) val = f"{expr}.{member_name}" set_val(expression, val) - def _post_new_array(self, expression): + def _post_new_array(self, expression: expressions.NewArray) -> None: array = str(expression.array_type) depth = expression.depth val = f"new {array}{'[]' * depth}" set_val(expression, val) - def _post_new_contract(self, expression): + def _post_new_contract(self, expression: expressions.NewContract) -> None: contract = str(expression.contract_name) val = f"new {contract}" set_val(expression, val) - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: expressions.NewElementaryType) -> None: t = str(expression.type) val = f"new {t}" set_val(expression, val) - def _post_tuple_expression(self, expression): - expressions = [get(e) for e in expression.expressions if e] - val = f"({','.join(expressions)})" + def _post_tuple_expression(self, expression: expressions.TupleExpression) -> None: + underlying_expressions = [get(e) for e in expression.expressions if e] + val = f"({','.join(underlying_expressions)})" set_val(expression, val) - def _post_type_conversion(self, expression): + def _post_type_conversion(self, expression: expressions.TypeConversion) -> None: t = str(expression.type) expr = get(expression.expression) val = f"{t}({expr})" set_val(expression, val) - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression: expressions.UnaryOperation) -> None: t = str(expression.type) expr = get(expression.expression) if expression.is_prefix: diff --git a/slither/visitors/expression/find_calls.py b/slither/visitors/expression/find_calls.py index 6653a9759..ce00533ed 100644 --- a/slither/visitors/expression/find_calls.py +++ b/slither/visitors/expression/find_calls.py @@ -1,5 +1,6 @@ -from typing import Any, Union, List +from typing import Any, Union, List, Optional +from slither.core.expressions import NewElementaryType from slither.core.expressions.expression import Expression from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.assignment_operation import AssignmentOperation @@ -32,6 +33,10 @@ def set_val(expression: Expression, val: List[Union[Any, CallExpression]]) -> No class FindCalls(ExpressionVisitor): + def __init__(self, expression: Expression) -> None: + self._result: Optional[List[Expression]] = None + super().__init__(expression) + def result(self) -> List[Expression]: if self._result is None: self._result = list(set(get(self.expression))) @@ -51,8 +56,8 @@ class FindCalls(ExpressionVisitor): def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression.called) - args = [get(a) for a in expression.arguments if a] - args = [item for sublist in args for item in sublist] + argss = [get(a) for a in expression.arguments if a] + args = [item for sublist in argss for item in sublist] val = called + args val += [expression] set_val(expression, val) @@ -93,7 +98,7 @@ class FindCalls(ExpressionVisitor): def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: set_val(expression, []) def _post_tuple_expression(self, expression: TupleExpression) -> None: diff --git a/slither/visitors/expression/has_conditional.py b/slither/visitors/expression/has_conditional.py index b866a696b..613138533 100644 --- a/slither/visitors/expression/has_conditional.py +++ b/slither/visitors/expression/has_conditional.py @@ -1,13 +1,15 @@ +from slither.core.expressions.expression import Expression from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.conditional_expression import ConditionalExpression class HasConditional(ExpressionVisitor): + def __init__(self, expression: Expression) -> None: + self._result: bool = False + super().__init__(expression) + def result(self) -> bool: - # == True, to convert None to false - return self._result is True + return self._result def _post_conditional_expression(self, expression: ConditionalExpression) -> None: - # if self._result is True: - # raise('Slither does not support nested ternary operator') self._result = True diff --git a/slither/visitors/expression/read_var.py b/slither/visitors/expression/read_var.py index e8f5aae67..a0efdde61 100644 --- a/slither/visitors/expression/read_var.py +++ b/slither/visitors/expression/read_var.py @@ -1,5 +1,6 @@ -from typing import Any, List, Union +from typing import Any, List, Union, Optional +from slither.core.expressions import NewElementaryType from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.assignment_operation import ( @@ -40,7 +41,11 @@ def set_val(expression: Expression, val: List[Union[Identifier, IndexAccess, Any class ReadVar(ExpressionVisitor): - def result(self) -> List[Union[Identifier, IndexAccess, Any]]: + def __init__(self, expression: Expression) -> None: + self._result: Optional[List[Expression]] = None + super().__init__(expression) + + def result(self) -> List[Expression]: if self._result is None: self._result = list(set(get(self.expression))) return self._result @@ -69,8 +74,8 @@ class ReadVar(ExpressionVisitor): def _post_call_expression(self, expression: CallExpression) -> None: called = get(expression.called) - args = [get(a) for a in expression.arguments if a] - args = [item for sublist in args for item in sublist] + argss = [get(a) for a in expression.arguments if a] + args = [item for sublist in argss for item in sublist] val = called + args set_val(expression, val) @@ -91,6 +96,7 @@ class ReadVar(ExpressionVisitor): if isinstance(expression.value, Variable): set_val(expression, [expression]) elif isinstance(expression.value, SolidityVariable): + # TODO: investigate if this branch can be reached, and if Identifier.value has the correct type set_val(expression, [expression]) else: set_val(expression, []) @@ -115,7 +121,7 @@ class ReadVar(ExpressionVisitor): def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: set_val(expression, []) def _post_tuple_expression(self, expression: TupleExpression) -> None: diff --git a/slither/visitors/expression/write_var.py b/slither/visitors/expression/write_var.py index 97d3858e7..1c0b6108f 100644 --- a/slither/visitors/expression/write_var.py +++ b/slither/visitors/expression/write_var.py @@ -1,4 +1,6 @@ -from typing import Any, List +from typing import Any, List, Optional + +from slither.core.expressions import NewElementaryType from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.binary_operation import BinaryOperation @@ -32,6 +34,10 @@ def set_val(expression: Expression, val: List[Any]) -> None: class WriteVar(ExpressionVisitor): + def __init__(self, expression: Expression) -> None: + self._result: Optional[List[Expression]] = None + super().__init__(expression) + def result(self) -> List[Any]: if self._result is None: self._result = list(set(get(self.expression))) @@ -123,7 +129,7 @@ class WriteVar(ExpressionVisitor): def _post_new_contract(self, expression: NewContract) -> None: set_val(expression, []) - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: set_val(expression, []) def _post_tuple_expression(self, expression: TupleExpression) -> None: From 060f550b0d55357c6b0e747b82115981a0ac9713 Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Mon, 20 Feb 2023 17:34:46 +0100 Subject: [PATCH 10/34] Fix more types --- slither/core/expressions/identifier.py | 62 ++++++++++- slither/core/expressions/index_access.py | 15 +-- slither/core/expressions/literal.py | 4 +- slither/core/solidity_types/array_type.py | 8 +- slither/slithir/convert.py | 4 +- slither/slithir/operations/binary.py | 6 +- slither/slithir/operations/index.py | 14 +-- slither/slithir/operations/type_conversion.py | 25 ++--- slither/slithir/utils/ssa.py | 3 +- .../expressions/expression_parsing.py | 9 +- .../visitors/expression/constants_folding.py | 85 ++++++++++++--- .../visitors/slithir/expression_to_slithir.py | 101 +++++++++++------- 12 files changed, 223 insertions(+), 113 deletions(-) diff --git a/slither/core/expressions/identifier.py b/slither/core/expressions/identifier.py index 58a1174af..8ffabad89 100644 --- a/slither/core/expressions/identifier.py +++ b/slither/core/expressions/identifier.py @@ -1,16 +1,58 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union +from slither.core.declarations.contract_level import ContractLevel +from slither.core.declarations.top_level import TopLevel from slither.core.expressions.expression import Expression +from slither.core.variables.variable import Variable + if TYPE_CHECKING: - from slither.core.variables.variable import Variable from slither.core.solidity_types.type import Type + from slither.core.declarations import Contract, SolidityVariable, SolidityFunction + from slither.solc_parsing.yul.evm_functions import YulBuiltin class Identifier(Expression): - def __init__(self, value) -> None: + def __init__( + self, + value: Union[ + Variable, + "TopLevel", + "ContractLevel", + "Contract", + "SolidityVariable", + "SolidityFunction", + "YulBuiltin", + ], + ) -> None: super().__init__() - self._value: "Variable" = value + + # pylint: disable=import-outside-toplevel + from slither.core.declarations import Contract, SolidityVariable, SolidityFunction + from slither.solc_parsing.yul.evm_functions import YulBuiltin + + assert isinstance( + value, + ( + Variable, + TopLevel, + ContractLevel, + Contract, + SolidityVariable, + SolidityFunction, + YulBuiltin, + ), + ) + + self._value: Union[ + Variable, + "TopLevel", + "ContractLevel", + "Contract", + "SolidityVariable", + "SolidityFunction", + "YulBuiltin", + ] = value self._type: Optional["Type"] = None @property @@ -22,7 +64,17 @@ class Identifier(Expression): self._type = new_type @property - def value(self) -> "Variable": + def value( + self, + ) -> Union[ + Variable, + "TopLevel", + "ContractLevel", + "Contract", + "SolidityVariable", + "SolidityFunction", + "YulBuiltin", + ]: return self._value def __str__(self) -> str: diff --git a/slither/core/expressions/index_access.py b/slither/core/expressions/index_access.py index f8e630a6e..22f014242 100644 --- a/slither/core/expressions/index_access.py +++ b/slither/core/expressions/index_access.py @@ -1,11 +1,8 @@ -from typing import Union, List, TYPE_CHECKING +from typing import Union, List +from slither.core.expressions.expression import Expression from slither.core.expressions.identifier import Identifier from slither.core.expressions.literal import Literal -from slither.core.expressions.expression import Expression - -if TYPE_CHECKING: - from slither.core.solidity_types.type import Type class IndexAccess(Expression): @@ -13,13 +10,9 @@ class IndexAccess(Expression): self, left_expression: Union["IndexAccess", Identifier], right_expression: Union[Literal, Identifier], - index_type: str, ) -> None: super().__init__() self._expressions = [left_expression, right_expression] - # TODO type of undexAccess is not always a Type - # assert isinstance(index_type, Type) - self._type: "Type" = index_type @property def expressions(self) -> List["Expression"]: @@ -33,9 +26,5 @@ class IndexAccess(Expression): def expression_right(self) -> "Expression": return self._expressions[1] - @property - def type(self) -> "Type": - return self._type - def __str__(self) -> str: return str(self.expression_left) + "[" + str(self.expression_right) + "]" diff --git a/slither/core/expressions/literal.py b/slither/core/expressions/literal.py index 5dace3c41..8848ce966 100644 --- a/slither/core/expressions/literal.py +++ b/slither/core/expressions/literal.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, TYPE_CHECKING +from typing import Optional, Union, TYPE_CHECKING, Any from slither.core.expressions.expression import Expression from slither.core.solidity_types.elementary_type import Fixed, Int, Ufixed, Uint @@ -47,7 +47,7 @@ class Literal(Expression): # be sure to handle any character return str(self._value) - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: if not isinstance(other, Literal): return False return (self.value, self.subdenomination) == (other.value, other.subdenomination) diff --git a/slither/core/solidity_types/array_type.py b/slither/core/solidity_types/array_type.py index 85c2d6ba7..9dfd3cf17 100644 --- a/slither/core/solidity_types/array_type.py +++ b/slither/core/solidity_types/array_type.py @@ -1,22 +1,20 @@ from typing import Union, Optional, Tuple, Any, TYPE_CHECKING from slither.core.expressions.expression import Expression -from slither.core.solidity_types.type import Type -from slither.visitors.expression.constants_folding import ConstantFolding from slither.core.expressions.literal import Literal from slither.core.solidity_types.elementary_type import ElementaryType +from slither.core.solidity_types.type import Type +from slither.visitors.expression.constants_folding import ConstantFolding if TYPE_CHECKING: from slither.core.expressions.binary_operation import BinaryOperation from slither.core.expressions.identifier import Identifier - from slither.core.solidity_types.function_type import FunctionType - from slither.core.solidity_types.type_alias import TypeAliasTopLevel class ArrayType(Type): def __init__( self, - t: Union["TypeAliasTopLevel", "ArrayType", "FunctionType", "ElementaryType"], + t: Type, length: Optional[Union["Identifier", Literal, "BinaryOperation", int]], ) -> None: assert isinstance(t, Type) diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index aa8dfb4ec..cc47ea913 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -1299,7 +1299,7 @@ def convert_to_push_set_val( element_to_add = ReferenceVariable(node) element_to_add.set_type(new_type) - ir_assign_element_to_add = Index(element_to_add, arr, length_val, ElementaryType("uint256")) + ir_assign_element_to_add = Index(element_to_add, arr, length_val) ir_assign_element_to_add.set_expression(ir.expression) ir_assign_element_to_add.set_node(ir.node) ret.append(ir_assign_element_to_add) @@ -1383,7 +1383,7 @@ def convert_to_pop(ir, node): ret.append(ir_sub_1) element_to_delete = ReferenceVariable(node) - ir_assign_element_to_delete = Index(element_to_delete, arr, val, ElementaryType("uint256")) + ir_assign_element_to_delete = Index(element_to_delete, arr, val) ir_length.lvalue.points_to = arr element_to_delete.set_type(ElementaryType("uint256")) ir_assign_element_to_delete.set_expression(ir.expression) diff --git a/slither/slithir/operations/binary.py b/slither/slithir/operations/binary.py index 42f05011d..d1355a965 100644 --- a/slither/slithir/operations/binary.py +++ b/slither/slithir/operations/binary.py @@ -105,7 +105,7 @@ class Binary(OperationWithLValue): def __init__( self, result: Variable, - left_variable: Union[LVALUE, Function], + left_variable: Union[RVALUE, Function], right_variable: Union[RVALUE, Function], operation_type: BinaryType, ) -> None: @@ -127,11 +127,11 @@ class Binary(OperationWithLValue): return [self.variable_left, self.variable_right] @property - def get_variable(self) -> List[Union[RVALUE, LVALUE, Function]]: + def get_variable(self) -> List[Union[RVALUE, Function]]: return self._variables @property - def variable_left(self) -> Union[LVALUE, Function]: + def variable_left(self) -> Union[RVALUE, Function]: return self._variables[0] # type: ignore @property diff --git a/slither/slithir/operations/index.py b/slither/slithir/operations/index.py index 77daa9462..f38a25927 100644 --- a/slither/slithir/operations/index.py +++ b/slither/slithir/operations/index.py @@ -1,7 +1,6 @@ from typing import List, Union from slither.core.declarations import SolidityVariableComposed -from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.variable import Variable from slither.slithir.operations.lvalue import OperationWithLValue @@ -11,11 +10,7 @@ from slither.slithir.variables.reference import ReferenceVariable class Index(OperationWithLValue): def __init__( - self, - result: ReferenceVariable, - left_variable: Variable, - right_variable: RVALUE, - index_type: Union[ElementaryType, str], + self, result: ReferenceVariable, left_variable: Variable, right_variable: RVALUE ) -> None: super().__init__() assert is_valid_lvalue(left_variable) or left_variable == SolidityVariableComposed( @@ -24,7 +19,6 @@ class Index(OperationWithLValue): assert is_valid_rvalue(right_variable) assert isinstance(result, ReferenceVariable) self._variables = [left_variable, right_variable] - self._type = index_type self._lvalue: ReferenceVariable = result @property @@ -43,9 +37,5 @@ class Index(OperationWithLValue): def variable_right(self) -> RVALUE: return self._variables[1] # type: ignore - @property - def index_type(self) -> Union[ElementaryType, str]: - return self._type - - def __str__(self): + def __str__(self) -> str: return f"{self.lvalue}({self.lvalue.type}) -> {self.variable_left}[{self.variable_right}]" diff --git a/slither/slithir/operations/type_conversion.py b/slither/slithir/operations/type_conversion.py index f351f1fdd..e9998bc65 100644 --- a/slither/slithir/operations/type_conversion.py +++ b/slither/slithir/operations/type_conversion.py @@ -1,13 +1,12 @@ from typing import List, Union + from slither.core.declarations import Contract -from slither.core.solidity_types.type import Type -from slither.slithir.operations.lvalue import OperationWithLValue -from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue -import slither.core.declarations.contract from slither.core.solidity_types.elementary_type import ElementaryType -from slither.core.solidity_types.type_alias import TypeAliasContract, TypeAliasTopLevel +from slither.core.solidity_types.type_alias import TypeAlias from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.source_mapping.source_mapping import SourceMapping +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue from slither.slithir.variables.temporary import TemporaryVariable from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA @@ -17,15 +16,15 @@ class TypeConversion(OperationWithLValue): self, result: Union[TemporaryVariableSSA, TemporaryVariable], variable: SourceMapping, - variable_type: Union[TypeAliasContract, UserDefinedType, ElementaryType, TypeAliasTopLevel], + variable_type: Union[TypeAlias, UserDefinedType, ElementaryType], ) -> None: super().__init__() assert is_valid_rvalue(variable) or isinstance(variable, Contract) assert is_valid_lvalue(result) - assert isinstance(variable_type, Type) + assert isinstance(variable_type, (TypeAlias, UserDefinedType, ElementaryType)) self._variable = variable - self._type = variable_type + self._type: Union[TypeAlias, UserDefinedType, ElementaryType] = variable_type self._lvalue = result @property @@ -35,18 +34,12 @@ class TypeConversion(OperationWithLValue): @property def type( self, - ) -> Union[ - TypeAliasContract, - TypeAliasTopLevel, - slither.core.declarations.contract.Contract, - UserDefinedType, - ElementaryType, - ]: + ) -> Union[TypeAlias, UserDefinedType, ElementaryType,]: return self._type @property def read(self) -> List[SourceMapping]: return [self.variable] - def __str__(self): + def __str__(self) -> str: return str(self.lvalue) + f" = CONVERT {self.variable} to {self.type}" diff --git a/slither/slithir/utils/ssa.py b/slither/slithir/utils/ssa.py index 156914b61..8b16cd516 100644 --- a/slither/slithir/utils/ssa.py +++ b/slither/slithir/utils/ssa.py @@ -751,8 +751,7 @@ def copy_ir(ir: Operation, *instances) -> Operation: lvalue = get_variable(ir, lambda x: x.lvalue, *instances) variable_left = get_variable(ir, lambda x: x.variable_left, *instances) variable_right = get_variable(ir, lambda x: x.variable_right, *instances) - index_type = ir.index_type - return Index(lvalue, variable_left, variable_right, index_type) + return Index(lvalue, variable_left, variable_right) if isinstance(ir, InitArray): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) init_values = get_rec_values(ir, lambda x: x.init_values, *instances) diff --git a/slither/solc_parsing/expressions/expression_parsing.py b/slither/solc_parsing/expressions/expression_parsing.py index ea433a921..d0dc4c7e0 100644 --- a/slither/solc_parsing/expressions/expression_parsing.py +++ b/slither/solc_parsing/expressions/expression_parsing.py @@ -481,11 +481,14 @@ def parse_expression(expression: Dict, caller_context: CallerContextExpression) if name == "IndexAccess": if is_compact_ast: - index_type = expression["typeDescriptions"]["typeString"] + # We dont use the index type here, as we recover it later + # We could change the paradigm with the current AST parsing + # And do the type parsing in advanced for most of the operation + # index_type = expression["typeDescriptions"]["typeString"] left = expression["baseExpression"] right = expression.get("indexExpression", None) else: - index_type = expression["attributes"]["type"] + # index_type = expression["attributes"]["type"] children = expression["children"] left = children[0] right = children[1] if len(children) > 1 else None @@ -502,7 +505,7 @@ def parse_expression(expression: Dict, caller_context: CallerContextExpression) left_expression = parse_expression(left, caller_context) right_expression = parse_expression(right, caller_context) - index = IndexAccess(left_expression, right_expression, index_type) + index = IndexAccess(left_expression, right_expression) index.set_offset(src, caller_context.compilation_unit) return index diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index 158306f39..7b1a8f8ee 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -12,6 +12,7 @@ from slither.core.expressions import ( TupleExpression, TypeConversion, ) +from slither.core.variables import Variable from slither.utils.integer_conversion import convert_string_to_fraction, convert_string_to_int from slither.visitors.expression.expression import ExpressionVisitor @@ -65,6 +66,8 @@ class ConstantFolding(ExpressionVisitor): return Literal(value, self._type) def _post_identifier(self, expression: Identifier) -> None: + if not isinstance(expression.value, Variable): + return if not expression.value.is_constant: raise NotConstant expr = expression.value.expression @@ -80,19 +83,58 @@ class ConstantFolding(ExpressionVisitor): # pylint: disable=too-many-branches def _post_binary_operation(self, expression: BinaryOperation) -> None: - left = get_val(expression.expression_left) - right = get_val(expression.expression_right) - if expression.type == BinaryOperationType.POWER: - set_val(expression, left**right) - elif expression.type == BinaryOperationType.MULTIPLICATION: + expression_left = expression.expression_left + expression_right = expression.expression_right + if not isinstance( + expression_left, + (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ): + raise NotConstant + if not isinstance( + expression_right, + (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ): + raise NotConstant + + left = get_val(expression_left) + right = get_val(expression_right) + + if ( + expression.type == BinaryOperationType.POWER + and isinstance(left, (int, Fraction)) + and isinstance(right, (int, Fraction)) + ): + set_val(expression, left**right) #type: ignore + elif ( + expression.type == BinaryOperationType.MULTIPLICATION + and isinstance(left, (int, Fraction)) + and isinstance(right, (int, Fraction)) + ): set_val(expression, left * right) - elif expression.type == BinaryOperationType.DIVISION: - set_val(expression, left / right) - elif expression.type == BinaryOperationType.MODULO: + elif ( + expression.type == BinaryOperationType.DIVISION + and isinstance(left, (int, Fraction)) + and isinstance(right, (int, Fraction)) + ): + # TODO: maybe check for right + left to be int to use // ? + set_val(expression, left // right if isinstance(right, int) else left / right) + elif ( + expression.type == BinaryOperationType.MODULO + and isinstance(left, (int, Fraction)) + and isinstance(right, (int, Fraction)) + ): set_val(expression, left % right) - elif expression.type == BinaryOperationType.ADDITION: + elif ( + expression.type == BinaryOperationType.ADDITION + and isinstance(left, (int, Fraction)) + and isinstance(right, (int, Fraction)) + ): set_val(expression, left + right) - elif expression.type == BinaryOperationType.SUBTRACTION: + elif ( + expression.type == BinaryOperationType.SUBTRACTION + and isinstance(left, (int, Fraction)) + and isinstance(right, (int, Fraction)) + ): set_val(expression, left - right) # Convert to int for operations not supported by Fraction elif expression.type == BinaryOperationType.LEFT_SHIFT: @@ -181,7 +223,20 @@ class ConstantFolding(ExpressionVisitor): def _post_tuple_expression(self, expression: expressions.TupleExpression) -> None: if expression.expressions: if len(expression.expressions) == 1: - cf = ConstantFolding(expression.expressions[0], self._type) + first_expr = expression.expressions[0] + if not isinstance( + first_expr, + ( + Literal, + BinaryOperation, + UnaryOperation, + Identifier, + TupleExpression, + TypeConversion, + ), + ): + raise NotConstant + cf = ConstantFolding(first_expr, self._type) expr = cf.result() assert isinstance(expr, Literal) set_val(expression, convert_string_to_fraction(expr.converted_value)) @@ -189,7 +244,13 @@ class ConstantFolding(ExpressionVisitor): raise NotConstant def _post_type_conversion(self, expression: expressions.TypeConversion) -> None: - cf = ConstantFolding(expression.expression, self._type) + expr = expression.expression + if not isinstance( + expr, + (Literal, BinaryOperation, UnaryOperation, Identifier, TupleExpression, TypeConversion), + ): + raise NotConstant + cf = ConstantFolding(expr, self._type) expr = cf.result() assert isinstance(expr, Literal) set_val(expression, convert_string_to_fraction(expr.converted_value)) diff --git a/slither/visitors/slithir/expression_to_slithir.py b/slither/visitors/slithir/expression_to_slithir.py index c150ee20b..8263a6e33 100644 --- a/slither/visitors/slithir/expression_to_slithir.py +++ b/slither/visitors/slithir/expression_to_slithir.py @@ -1,14 +1,16 @@ import logging -from typing import Union, List, TYPE_CHECKING +from typing import Union, List, TYPE_CHECKING, Any +from slither.core import expressions from slither.core.declarations import ( Function, SolidityVariable, SolidityVariableComposed, SolidityFunction, Contract, + EnumContract, + EnumTopLevel, ) -from slither.core.declarations.enum import Enum from slither.core.expressions import ( AssignmentOperation, AssignmentOperationType, @@ -18,6 +20,8 @@ from slither.core.expressions import ( CallExpression, Identifier, MemberAccess, + ConditionalExpression, + NewElementaryType, ) from slither.core.expressions.binary_operation import BinaryOperation from slither.core.expressions.expression import Expression @@ -27,7 +31,7 @@ from slither.core.expressions.new_array import NewArray from slither.core.expressions.new_contract import NewContract from slither.core.expressions.tuple_expression import TupleExpression from slither.core.expressions.unary_operation import UnaryOperation -from slither.core.solidity_types import ArrayType, ElementaryType, TypeAlias +from slither.core.solidity_types import ArrayType, ElementaryType, TypeAlias, UserDefinedType from slither.core.solidity_types.type import Type from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple @@ -71,18 +75,14 @@ logger = logging.getLogger("VISTIOR:ExpressionToSlithIR") key = "expressionToSlithIR" -def get(expression: Union[Expression, Operation]): +def get(expression: Expression) -> Any: val = expression.context[key] # we delete the item to reduce memory use del expression.context[key] return val -def get_without_removing(expression): - return expression.context[key] - - -def set_val(expression: Union[Expression, Operation], val) -> None: +def set_val(expression: Expression, val: Any) -> None: expression.context[key] = val @@ -121,7 +121,7 @@ def convert_assignment( left: Union[LocalVariable, StateVariable, ReferenceVariable], right: Union[LocalVariable, StateVariable, ReferenceVariable], t: AssignmentOperationType, - return_type, + return_type: Type, ) -> Union[Binary, Assignment]: if t == AssignmentOperationType.ASSIGN: return Assignment(left, right, return_type) @@ -150,6 +150,7 @@ def convert_assignment( class ExpressionToSlithIR(ExpressionVisitor): + # pylint: disable=super-init-not-called def __init__(self, expression: Expression, node: "Node") -> None: from slither.core.cfg.node import NodeType # pylint: disable=import-outside-toplevel @@ -171,11 +172,16 @@ class ExpressionToSlithIR(ExpressionVisitor): def _post_assignement_operation(self, expression: AssignmentOperation) -> None: left = get(expression.expression_left) right = get(expression.expression_right) + operation: Operation if isinstance(left, list): # tuple expression: if isinstance(right, list): # unbox assigment assert len(left) == len(right) for idx, _ in enumerate(left): - if not left[idx] is None: + if ( + not left[idx] is None + and expression.type + and expression.expression_return_type + ): operation = convert_assignment( left[idx], right[idx], @@ -220,7 +226,7 @@ class ExpressionToSlithIR(ExpressionVisitor): operation.set_expression(expression) self._result.append(operation) set_val(expression, left) - else: + elif expression.type and expression.expression_return_type: operation = convert_assignment( left, right, expression.type, expression.expression_return_type ) @@ -276,6 +282,8 @@ class ExpressionToSlithIR(ExpressionVisitor): called = get(expression_called) args = [get(a) for a in expression.arguments if a] + val: Union[TupleVariable, TemporaryVariable] + var: Operation for arg in args: arg_ = Argument(arg) arg_.set_expression(expression) @@ -284,6 +292,7 @@ class ExpressionToSlithIR(ExpressionVisitor): # internal call # If tuple + if expression.type_call.startswith("tuple(") and expression.type_call != "tuple()": val = TupleVariable(self._node) else: @@ -302,7 +311,7 @@ class ExpressionToSlithIR(ExpressionVisitor): ): # wrap: underlying_type -> alias # unwrap: alias -> underlying_type - dest_type = ( + dest_type: Union[TypeAlias, ElementaryType] = ( called if expression_called.member_name == "wrap" else called.underlying_type ) val = TemporaryVariable(self._node) @@ -315,19 +324,19 @@ class ExpressionToSlithIR(ExpressionVisitor): # yul things elif called.name == "caller()": val = TemporaryVariable(self._node) - var = Assignment(val, SolidityVariableComposed("msg.sender"), "uint256") + var = Assignment(val, SolidityVariableComposed("msg.sender"), ElementaryType("uint256")) self._result.append(var) set_val(expression, val) elif called.name == "origin()": val = TemporaryVariable(self._node) - var = Assignment(val, SolidityVariableComposed("tx.origin"), "uint256") + var = Assignment(val, SolidityVariableComposed("tx.origin"), ElementaryType("uint256")) self._result.append(var) set_val(expression, val) elif called.name == "extcodesize(uint256)": - val = ReferenceVariable(self._node) - var = Member(args[0], Constant("codesize"), val) + val_ref = ReferenceVariable(self._node) + var = Member(args[0], Constant("codesize"), val_ref) self._result.append(var) - set_val(expression, val) + set_val(expression, val_ref) elif called.name == "selfbalance()": val = TemporaryVariable(self._node) var = TypeConversion(val, SolidityVariable("this"), ElementaryType("address")) @@ -346,7 +355,7 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, val) elif called.name == "callvalue()": val = TemporaryVariable(self._node) - var = Assignment(val, SolidityVariableComposed("msg.value"), "uint256") + var = Assignment(val, SolidityVariableComposed("msg.value"), ElementaryType("uint256")) self._result.append(var) set_val(expression, val) @@ -373,7 +382,7 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(message_call) set_val(expression, val) - def _post_conditional_expression(self, expression): + def _post_conditional_expression(self, expression: ConditionalExpression) -> None: raise Exception(f"Ternary operator are not convertible to SlithIR {expression}") def _post_elementary_type_name_expression( @@ -388,12 +397,13 @@ class ExpressionToSlithIR(ExpressionVisitor): def _post_index_access(self, expression: IndexAccess) -> None: left = get(expression.expression_left) right = get(expression.expression_right) + operation: Operation # Left can be a type for abi.decode(var, uint[2]) if isinstance(left, Type): # Nested type are not yet supported by abi.decode, so the assumption # Is that the right variable must be a constant assert isinstance(right, Constant) - t = ArrayType(left, right.value) + t = ArrayType(left, int(right.value)) set_val(expression, t) return val = ReferenceVariable(self._node) @@ -406,13 +416,15 @@ class ExpressionToSlithIR(ExpressionVisitor): operation = InitArray(init_array_right, init_array_val) operation.set_expression(expression) self._result.append(operation) - operation = Index(val, left, right, expression.type) + operation = Index(val, left, right) operation.set_expression(expression) self._result.append(operation) set_val(expression, val) def _post_literal(self, expression: Literal) -> None: - cst = Constant(expression.value, expression.type, expression.subdenomination) + expression_type = expression.type + assert isinstance(expression_type, ElementaryType) + cst = Constant(expression.value, expression_type, expression.subdenomination) set_val(expression, cst) def _post_member_access(self, expression: MemberAccess) -> None: @@ -430,25 +442,33 @@ class ExpressionToSlithIR(ExpressionVisitor): assert len(expression.expression.arguments) == 1 val = TemporaryVariable(self._node) type_expression_found = expression.expression.arguments[0] + type_found: Union[ElementaryType, UserDefinedType] if isinstance(type_expression_found, ElementaryTypeNameExpression): - type_found = type_expression_found.type + type_expression_found_type = type_expression_found.type + assert isinstance(type_expression_found_type, ElementaryType) + type_found = type_expression_found_type + min_value = type_found.min + max_value = type_found.max constant_type = type_found else: # type(enum).max/min assert isinstance(type_expression_found, Identifier) - type_found = type_expression_found.value - assert isinstance(type_found, Enum) + type_found_in_expression = type_expression_found.value + assert isinstance(type_found_in_expression, (EnumContract, EnumTopLevel)) + type_found = UserDefinedType(type_found_in_expression) constant_type = None + min_value = type_found_in_expression.min + max_value = type_found_in_expression.max if expression.member_name == "min": op = Assignment( val, - Constant(str(type_found.min), constant_type), + Constant(str(min_value), constant_type), type_found, ) else: op = Assignment( val, - Constant(str(type_found.max), constant_type), + Constant(str(max_value), constant_type), type_found, ) self._result.append(op) @@ -494,11 +514,11 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, expr.custom_errors_as_dict[expression.member_name]) return - val = ReferenceVariable(self._node) - member = Member(expr, Constant(expression.member_name), val) + val_ref = ReferenceVariable(self._node) + member = Member(expr, Constant(expression.member_name), val_ref) member.set_expression(expression) self._result.append(member) - set_val(expression, val) + set_val(expression, val_ref) def _post_new_array(self, expression: NewArray) -> None: val = TemporaryVariable(self._node) @@ -521,7 +541,7 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(operation) set_val(expression, val) - def _post_new_elementary_type(self, expression): + def _post_new_elementary_type(self, expression: NewElementaryType) -> None: # TODO unclear if this is ever used? val = TemporaryVariable(self._node) operation = TmpNewElementaryType(expression.type, val) @@ -530,17 +550,20 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, val) def _post_tuple_expression(self, expression: TupleExpression) -> None: - expressions = [get(e) if e else None for e in expression.expressions] - if len(expressions) == 1: - val = expressions[0] + all_expressions = [get(e) if e else None for e in expression.expressions] + if len(all_expressions) == 1: + val = all_expressions[0] else: - val = expressions + val = all_expressions set_val(expression, val) - def _post_type_conversion(self, expression: TypeConversion) -> None: + def _post_type_conversion(self, expression: expressions.TypeConversion) -> None: + assert expression.expression expr = get(expression.expression) val = TemporaryVariable(self._node) - operation = TypeConversion(val, expr, expression.type) + expression_type = expression.type + assert isinstance(expression_type, (TypeAlias, UserDefinedType, ElementaryType)) + operation = TypeConversion(val, expr, expression_type) val.set_type(expression.type) operation.set_expression(expression) self._result.append(operation) @@ -549,6 +572,7 @@ class ExpressionToSlithIR(ExpressionVisitor): # pylint: disable=too-many-statements def _post_unary_operation(self, expression: UnaryOperation) -> None: value = get(expression.expression) + operation: Operation if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]: lvalue = TemporaryVariable(self._node) operation = Unary(lvalue, value, expression.type) @@ -592,6 +616,7 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, value) elif expression.type in [UnaryOperationType.MINUS_PRE]: lvalue = TemporaryVariable(self._node) + assert isinstance(value.type, ElementaryType) operation = Binary(lvalue, Constant("0", value.type), value, BinaryType.SUBTRACTION) operation.set_expression(expression) self._result.append(operation) From bd2d572f48a751aeb4bfbbf3a9fed355861ccd7f Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Mon, 20 Feb 2023 19:05:56 +0100 Subject: [PATCH 11/34] More types --- examples/scripts/data_dependency.py | 6 ++ examples/scripts/variable_in_condition.py | 1 + slither/__main__.py | 4 +- .../data_dependency/data_dependency.py | 28 +++--- slither/core/declarations/contract.py | 27 +++--- slither/core/declarations/custom_error.py | 6 +- .../core/declarations/using_for_top_level.py | 3 +- slither/core/variables/variable.py | 5 +- .../statements/costly_operations_in_loop.py | 2 +- .../detectors/statements/write_after_write.py | 2 + slither/printers/call/call_graph.py | 3 +- slither/printers/summary/constructor_calls.py | 3 +- slither/printers/summary/contract.py | 19 ++-- slither/printers/summary/variable_order.py | 3 +- slither/slithir/utils/ssa.py | 2 +- slither/slithir/utils/utils.py | 2 +- slither/slithir/variables/local_variable.py | 4 +- slither/slithir/variables/state_variable.py | 4 +- slither/slithir/variables/variable.py | 5 +- slither/solc_parsing/declarations/contract.py | 78 ++++++++------- slither/solc_parsing/declarations/function.py | 10 +- .../variables/variable_declaration.py | 16 ++-- slither/tools/doctor/checks/versions.py | 5 +- slither/tools/read_storage/utils/utils.py | 5 +- .../checks/variable_initialization.py | 8 +- .../upgradeability/checks/variables_order.py | 38 +++++--- .../visitors/expression/constants_folding.py | 2 +- tests/test_ssa_generation.py | 95 +++++++++++-------- 28 files changed, 223 insertions(+), 163 deletions(-) diff --git a/examples/scripts/data_dependency.py b/examples/scripts/data_dependency.py index 478394766..23c82cae1 100644 --- a/examples/scripts/data_dependency.py +++ b/examples/scripts/data_dependency.py @@ -18,6 +18,8 @@ assert len(contracts) == 1 contract = contracts[0] destination = contract.get_state_variable_from_name("destination") source = contract.get_state_variable_from_name("source") +assert source +assert destination print(f"{source} is dependent of {destination}: {is_dependent(source, destination, contract)}") assert not is_dependent(source, destination, contract) @@ -47,9 +49,11 @@ print(f"{destination} is tainted {is_tainted(destination, contract)}") assert is_tainted(destination, contract) destination_indirect_1 = contract.get_state_variable_from_name("destination_indirect_1") +assert destination_indirect_1 print(f"{destination_indirect_1} is tainted {is_tainted(destination_indirect_1, contract)}") assert is_tainted(destination_indirect_1, contract) destination_indirect_2 = contract.get_state_variable_from_name("destination_indirect_2") +assert destination_indirect_2 print(f"{destination_indirect_2} is tainted {is_tainted(destination_indirect_2, contract)}") assert is_tainted(destination_indirect_2, contract) @@ -88,6 +92,8 @@ contract = contracts[0] contract_derived = slither.get_contract_from_name("Derived")[0] destination = contract.get_state_variable_from_name("destination") source = contract.get_state_variable_from_name("source") +assert destination +assert source print(f"{destination} is dependent of {source}: {is_dependent(destination, source, contract)}") assert not is_dependent(destination, source, contract) diff --git a/examples/scripts/variable_in_condition.py b/examples/scripts/variable_in_condition.py index 43dcf41e7..bde41424d 100644 --- a/examples/scripts/variable_in_condition.py +++ b/examples/scripts/variable_in_condition.py @@ -14,6 +14,7 @@ assert len(contracts) == 1 contract = contracts[0] # Get the variable var_a = contract.get_state_variable_from_name("a") +assert var_a # Get the functions reading the variable functions_reading_a = contract.get_functions_reading_from_variable(var_a) diff --git a/slither/__main__.py b/slither/__main__.py index a5d51dcd6..d6c3ea717 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -615,7 +615,9 @@ def parse_args( class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods - def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs + def __call__( + self, parser: Any, *args: Any, **kwargs: Any + ) -> None: # pylint: disable=signature-differs detectors, _ = get_detectors_and_printers() output_detectors(detectors) parser.exit() diff --git a/slither/analyses/data_dependency/data_dependency.py b/slither/analyses/data_dependency/data_dependency.py index 2b66f2bb3..d133cd2dc 100644 --- a/slither/analyses/data_dependency/data_dependency.py +++ b/slither/analyses/data_dependency/data_dependency.py @@ -2,7 +2,7 @@ Compute the data depenency between all the SSA variables """ from collections import defaultdict -from typing import Union, Set, Dict, TYPE_CHECKING +from typing import Union, Set, Dict, TYPE_CHECKING, List from slither.core.cfg.node import Node from slither.core.declarations import ( @@ -20,6 +20,7 @@ from slither.core.solidity_types.type import Type from slither.core.variables.top_level_variable import TopLevelVariable from slither.core.variables.variable import Variable from slither.slithir.operations import Index, OperationWithLValue, InternalCall, Operation +from slither.slithir.utils.utils import LVALUE from slither.slithir.variables import ( Constant, LocalIRVariable, @@ -29,6 +30,7 @@ from slither.slithir.variables import ( TemporaryVariableSSA, TupleVariableSSA, ) +from slither.slithir.variables.variable import SlithIRVariable if TYPE_CHECKING: from slither.core.compilation_unit import SlitherCompilationUnit @@ -393,13 +395,9 @@ def transitive_close_dependencies( while changed: changed = False to_add = defaultdict(set) - [ # pylint: disable=expression-not-assigned - [ + for key, items in context.context[context_key].items(): + for item in items & keys: to_add[key].update(context.context[context_key][item] - {key} - items) - for item in items & keys - ] - for key, items in context.context[context_key].items() - ] for k, v in to_add.items(): # Because we dont have any check on the update operation # We might update an empty set with an empty set @@ -418,20 +416,20 @@ def add_dependency(lvalue: Variable, function: Function, ir: Operation, is_prote function.context[KEY_SSA][lvalue] = set() if not is_protected: function.context[KEY_SSA_UNPROTECTED][lvalue] = set() + read: Union[List[Union[LVALUE, SolidityVariableComposed]], List[SlithIRVariable]] if isinstance(ir, Index): read = [ir.variable_left] - elif isinstance(ir, InternalCall): + elif isinstance(ir, InternalCall) and ir.function: read = ir.function.return_values_ssa else: read = ir.read - # pylint: disable=expression-not-assigned - [function.context[KEY_SSA][lvalue].add(v) for v in read if not isinstance(v, Constant)] + for v in read: + if not isinstance(v, Constant): + function.context[KEY_SSA][lvalue].add(v) if not is_protected: - [ - function.context[KEY_SSA_UNPROTECTED][lvalue].add(v) - for v in read - if not isinstance(v, Constant) - ] + for v in read: + if not isinstance(v, Constant): + function.context[KEY_SSA_UNPROTECTED][lvalue].add(v) def compute_dependency_function(function: Function) -> None: diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 38b4221d9..2c82f9b58 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -49,6 +49,9 @@ if TYPE_CHECKING: LOGGER = logging.getLogger("Contract") +USING_FOR_KEY = Union[str, Type] +USING_FOR_ITEM = List[Union[Type, Function]] + class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ @@ -80,8 +83,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods self._custom_errors: Dict[str, "CustomErrorContract"] = {} # The only str is "*" - self._using_for: Dict[Union[str, Type], List[Type]] = {} - self._using_for_complete: Optional[Dict[Union[str, Type], List[Type]]] = None + self._using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = {} + self._using_for_complete: Optional[Dict[USING_FOR_KEY, USING_FOR_ITEM]] = None self._kind: Optional[str] = None self._is_interface: bool = False self._is_library: bool = False @@ -123,7 +126,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._name @name.setter - def name(self, name: str): + def name(self, name: str) -> None: self._name = name @property @@ -133,7 +136,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._id @id.setter - def id(self, new_id): + def id(self, new_id: int) -> None: """Unique id.""" self._id = new_id @@ -146,7 +149,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._kind @contract_kind.setter - def contract_kind(self, kind): + def contract_kind(self, kind: str) -> None: self._kind = kind @property @@ -154,7 +157,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._is_interface @is_interface.setter - def is_interface(self, is_interface: bool): + def is_interface(self, is_interface: bool) -> None: self._is_interface = is_interface @property @@ -162,7 +165,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._is_library @is_library.setter - def is_library(self, is_library: bool): + def is_library(self, is_library: bool) -> None: self._is_library = is_library # endregion @@ -266,16 +269,18 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### @property - def using_for(self) -> Dict[Union[str, Type], List[Type]]: + def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: return self._using_for @property - def using_for_complete(self) -> Dict[Union[str, Type], List[Type]]: + def using_for_complete(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: """ Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive """ - def _merge_using_for(uf1: Dict, uf2: Dict) -> Dict: + def _merge_using_for( + uf1: Dict[USING_FOR_KEY, USING_FOR_ITEM], uf2: Dict[USING_FOR_KEY, USING_FOR_ITEM] + ) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: result = {**uf1, **uf2} for key, value in result.items(): if key in uf1 and key in uf2: @@ -1452,7 +1457,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods result = func.get_last_ssa_state_variables_instances() for variable_name, instances in result.items(): # TODO: investigate the next operation - last_state_variables_instances[variable_name] += instances + last_state_variables_instances[variable_name] += list(instances) for func in self.functions + list(self.modifiers): func.fix_phi(last_state_variables_instances, initial_state_variables_instances) diff --git a/slither/core/declarations/custom_error.py b/slither/core/declarations/custom_error.py index c566fccec..7e78748c6 100644 --- a/slither/core/declarations/custom_error.py +++ b/slither/core/declarations/custom_error.py @@ -1,4 +1,4 @@ -from typing import List, TYPE_CHECKING, Optional, Type, Union +from typing import List, TYPE_CHECKING, Optional, Type from slither.core.solidity_types import UserDefinedType from slither.core.source_mapping.source_mapping import SourceMapping @@ -42,7 +42,7 @@ class CustomError(SourceMapping): ################################################################################### @staticmethod - def _convert_type_for_solidity_signature(t: Optional[Union[Type, List[Type]]]) -> str: + def _convert_type_for_solidity_signature(t: Optional[Type]) -> str: # pylint: disable=import-outside-toplevel from slither.core.declarations import Contract @@ -72,7 +72,7 @@ class CustomError(SourceMapping): Returns: """ - parameters = [x.type for x in self.parameters] + parameters = [x.type for x in self.parameters if x.type] self._full_name = self.name + "(" + ",".join(map(str, parameters)) + ")" solidity_parameters = map(self._convert_type_for_solidity_signature, parameters) self._solidity_signature = self.name + "(" + ",".join(solidity_parameters) + ")" diff --git a/slither/core/declarations/using_for_top_level.py b/slither/core/declarations/using_for_top_level.py index 27d1f90e4..edf846a5b 100644 --- a/slither/core/declarations/using_for_top_level.py +++ b/slither/core/declarations/using_for_top_level.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, List, Dict, Union +from slither.core.declarations.contract import USING_FOR_KEY, USING_FOR_ITEM from slither.core.solidity_types.type import Type from slither.core.declarations.top_level import TopLevel @@ -14,5 +15,5 @@ class UsingForTopLevel(TopLevel): self.file_scope: "FileScope" = scope @property - def using_for(self) -> Dict[Union[str, Type], List[Type]]: + def using_for(self) -> Dict[USING_FOR_KEY, USING_FOR_ITEM]: return self._using_for diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index 0d610c928..2b777e672 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -160,8 +160,8 @@ class Variable(SourceMapping): return ( self.name, - [str(x) for x in export_nested_types_from_variable(self)], - [str(x) for x in export_return_type_from_variable(self)], + [str(x) for x in export_nested_types_from_variable(self)], # type: ignore + [str(x) for x in export_return_type_from_variable(self)], # type: ignore ) @property @@ -179,4 +179,5 @@ class Variable(SourceMapping): return f'{name}({",".join(parameters)})' def __str__(self) -> str: + assert self._name return self._name diff --git a/slither/detectors/statements/costly_operations_in_loop.py b/slither/detectors/statements/costly_operations_in_loop.py index 6af04329c..53fa12647 100644 --- a/slither/detectors/statements/costly_operations_in_loop.py +++ b/slither/detectors/statements/costly_operations_in_loop.py @@ -43,7 +43,7 @@ def costly_operations_in_loop( if isinstance(ir, OperationWithLValue) and isinstance(ir.lvalue, StateVariable): ret.append(ir.node) break - if isinstance(ir, (InternalCall)): + if isinstance(ir, (InternalCall)) and ir.function: costly_operations_in_loop(ir.function.entry_point, in_loop_counter, visited, ret) for son in node.sons: diff --git a/slither/detectors/statements/write_after_write.py b/slither/detectors/statements/write_after_write.py index 40a82d3ff..1f11921cb 100644 --- a/slither/detectors/statements/write_after_write.py +++ b/slither/detectors/statements/write_after_write.py @@ -37,6 +37,8 @@ def _handle_ir( _remove_states(written) if isinstance(ir, InternalCall): + if not ir.function: + return if ir.function.all_high_level_calls() or ir.function.all_library_calls(): _remove_states(written) diff --git a/slither/printers/call/call_graph.py b/slither/printers/call/call_graph.py index 0a4df0c65..38225e6d7 100644 --- a/slither/printers/call/call_graph.py +++ b/slither/printers/call/call_graph.py @@ -13,6 +13,7 @@ from slither.core.declarations.function import Function from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.variables.variable import Variable from slither.printers.abstract_printer import AbstractPrinter +from slither.utils.output import Output def _contract_subgraph(contract: Contract) -> str: @@ -222,7 +223,7 @@ class PrinterCallGraph(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#call-graph" - def output(self, filename): + def output(self, filename: str) -> Output: """ Output the graph in filename Args: diff --git a/slither/printers/summary/constructor_calls.py b/slither/printers/summary/constructor_calls.py index 665c76546..789811c36 100644 --- a/slither/printers/summary/constructor_calls.py +++ b/slither/printers/summary/constructor_calls.py @@ -5,6 +5,7 @@ from slither.core.declarations import Function from slither.core.source_mapping.source_mapping import Source from slither.printers.abstract_printer import AbstractPrinter from slither.utils import output +from slither.utils.output import Output def _get_source_code(cst: Function) -> str: @@ -17,7 +18,7 @@ class ConstructorPrinter(AbstractPrinter): ARGUMENT = "constructor-calls" HELP = "Print the constructors executed" - def output(self, _filename): + def output(self, _filename: str) -> Output: info = "" for contract in self.slither.contracts_derived: stack_name = [] diff --git a/slither/printers/summary/contract.py b/slither/printers/summary/contract.py index 5af953e20..5fee94416 100644 --- a/slither/printers/summary/contract.py +++ b/slither/printers/summary/contract.py @@ -2,9 +2,13 @@ Module printing summary of the contract """ import collections +from typing import Dict, List + +from slither.core.declarations import FunctionContract from slither.printers.abstract_printer import AbstractPrinter from slither.utils import output from slither.utils.colors import blue, green, magenta +from slither.utils.output import Output class ContractSummary(AbstractPrinter): @@ -13,7 +17,7 @@ class ContractSummary(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#contract-summary" - def output(self, _filename): # pylint: disable=too-many-locals + def output(self, _filename: str) -> Output: # pylint: disable=too-many-locals """ _filename is not used Args: @@ -53,17 +57,16 @@ class ContractSummary(AbstractPrinter): # Order the function with # contract_declarer -> list_functions - public = [ + public_function = [ (f.contract_declarer.name, f) for f in c.functions if (not f.is_shadowed and not f.is_constructor_variables) ] - collect = collections.defaultdict(list) - for a, b in public: + collect: Dict[str, List[FunctionContract]] = collections.defaultdict(list) + for a, b in public_function: collect[a].append(b) - public = list(collect.items()) - for contract, functions in public: + for contract, functions in collect.items(): txt += blue(f" - From {contract}\n") functions = sorted(functions, key=lambda f: f.full_name) @@ -90,7 +93,7 @@ class ContractSummary(AbstractPrinter): self.info(txt) res = self.generate_output(txt) - for contract, additional_fields in all_contracts: - res.add(contract, additional_fields=additional_fields) + for current_contract, current_additional_fields in all_contracts: + res.add(current_contract, additional_fields=current_additional_fields) return res diff --git a/slither/printers/summary/variable_order.py b/slither/printers/summary/variable_order.py index 9dc9e77c2..3325b7a01 100644 --- a/slither/printers/summary/variable_order.py +++ b/slither/printers/summary/variable_order.py @@ -4,6 +4,7 @@ from slither.printers.abstract_printer import AbstractPrinter from slither.utils.myprettytable import MyPrettyTable +from slither.utils.output import Output class VariableOrder(AbstractPrinter): @@ -13,7 +14,7 @@ class VariableOrder(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#variable-order" - def output(self, _filename): + def output(self, _filename: str) -> Output: """ _filename is not used Args: diff --git a/slither/slithir/utils/ssa.py b/slither/slithir/utils/ssa.py index 8b16cd516..9a180d14f 100644 --- a/slither/slithir/utils/ssa.py +++ b/slither/slithir/utils/ssa.py @@ -366,7 +366,7 @@ def last_name( def is_used_later( initial_node: Node, - variable: Union[StateIRVariable, LocalVariable], + variable: Union[StateIRVariable, LocalVariable, TemporaryVariableSSA], ) -> bool: # TODO: does not handle the case where its read and written in the declaration node # It can be problematic if this happens in a loop/if structure diff --git a/slither/slithir/utils/utils.py b/slither/slithir/utils/utils.py index a0ca0bd6f..4619c08bc 100644 --- a/slither/slithir/utils/utils.py +++ b/slither/slithir/utils/utils.py @@ -46,7 +46,7 @@ def is_valid_rvalue(v: SourceMapping) -> bool: ) -def is_valid_lvalue(v) -> bool: +def is_valid_lvalue(v: SourceMapping) -> bool: return isinstance( v, ( diff --git a/slither/slithir/variables/local_variable.py b/slither/slithir/variables/local_variable.py index eb32d4024..35b624a01 100644 --- a/slither/slithir/variables/local_variable.py +++ b/slither/slithir/variables/local_variable.py @@ -41,11 +41,11 @@ class LocalIRVariable( self._non_ssa_version = local_variable @property - def index(self): + def index(self) -> int: return self._index @index.setter - def index(self, idx): + def index(self, idx: int) -> None: self._index = idx @property diff --git a/slither/slithir/variables/state_variable.py b/slither/slithir/variables/state_variable.py index 7bb3a4077..f7fb8ab8a 100644 --- a/slither/slithir/variables/state_variable.py +++ b/slither/slithir/variables/state_variable.py @@ -30,11 +30,11 @@ class StateIRVariable( self._non_ssa_version = state_variable @property - def index(self): + def index(self) -> int: return self._index @index.setter - def index(self, idx): + def index(self, idx: int) -> None: self._index = idx @property diff --git a/slither/slithir/variables/variable.py b/slither/slithir/variables/variable.py index a1a1a6df9..20d203ea4 100644 --- a/slither/slithir/variables/variable.py +++ b/slither/slithir/variables/variable.py @@ -7,8 +7,9 @@ class SlithIRVariable(Variable): self._index = 0 @property - def ssa_name(self): + def ssa_name(self) -> str: + assert self.name return self.name - def __str__(self): + def __str__(self) -> str: return self.ssa_name diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 47ee7ec10..b9dbe9a9f 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -1,6 +1,6 @@ import logging import re -from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set +from typing import Any, List, Dict, Callable, TYPE_CHECKING, Union, Set, Sequence from slither.core.declarations import ( Modifier, @@ -9,10 +9,10 @@ from slither.core.declarations import ( StructureContract, Function, ) -from slither.core.declarations.contract import Contract +from slither.core.declarations.contract import Contract, USING_FOR_KEY from slither.core.declarations.custom_error_contract import CustomErrorContract from slither.core.declarations.function_contract import FunctionContract -from slither.core.solidity_types import ElementaryType, TypeAliasContract, Type +from slither.core.solidity_types import ElementaryType, TypeAliasContract from slither.core.variables.state_variable import StateVariable from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.declarations.custom_error import CustomErrorSolc @@ -302,7 +302,7 @@ class ContractSolc(CallerContextExpression): st.set_contract(self._contract) st.set_offset(struct["src"], self._contract.compilation_unit) - st_parser = StructureContractSolc(st, struct, self) + st_parser = StructureContractSolc(st, struct, self) # type: ignore self._contract.structures_as_dict[st.name] = st self._structures_parser.append(st_parser) @@ -312,7 +312,7 @@ class ContractSolc(CallerContextExpression): for struct in self._structuresNotParsed: self._parse_struct(struct) - self._structuresNotParsed = None + self._structuresNotParsed = [] def _parse_custom_error(self, custom_error: Dict) -> None: ce = CustomErrorContract(self.compilation_unit) @@ -329,7 +329,7 @@ class ContractSolc(CallerContextExpression): for custom_error in self._customErrorParsed: self._parse_custom_error(custom_error) - self._customErrorParsed = None + self._customErrorParsed = [] def parse_state_variables(self) -> None: for father in self._contract.inheritance_reverse: @@ -356,6 +356,7 @@ class ContractSolc(CallerContextExpression): var_parser = StateVariableSolc(var, varNotParsed) self._variables_parser.append(var_parser) + assert var.name self._contract.variables_as_dict[var.name] = var self._contract.add_variables_ordered([var]) @@ -365,7 +366,7 @@ class ContractSolc(CallerContextExpression): modif.set_contract(self._contract) modif.set_contract_declarer(self._contract) - modif_parser = ModifierSolc(modif, modifier_data, self, self.slither_parser) + modif_parser = ModifierSolc(modif, modifier_data, self, self.slither_parser) # type: ignore self._contract.compilation_unit.add_modifier(modif) self._modifiers_no_params.append(modif_parser) self._modifiers_parser.append(modif_parser) @@ -375,7 +376,7 @@ class ContractSolc(CallerContextExpression): def parse_modifiers(self) -> None: for modifier in self._modifiersNotParsed: self._parse_modifier(modifier) - self._modifiersNotParsed = None + self._modifiersNotParsed = [] def _parse_function(self, function_data: Dict) -> None: func = FunctionContract(self._contract.compilation_unit) @@ -383,7 +384,7 @@ class ContractSolc(CallerContextExpression): func.set_contract(self._contract) func.set_contract_declarer(self._contract) - func_parser = FunctionSolc(func, function_data, self, self._slither_parser) + func_parser = FunctionSolc(func, function_data, self, self._slither_parser) # type: ignore self._contract.compilation_unit.add_function(func) self._functions_no_params.append(func_parser) self._functions_parser.append(func_parser) @@ -395,7 +396,7 @@ class ContractSolc(CallerContextExpression): for function in self._functionsNotParsed: self._parse_function(function) - self._functionsNotParsed = None + self._functionsNotParsed = [] # endregion ################################################################################### @@ -439,7 +440,8 @@ class ContractSolc(CallerContextExpression): Cls_parser, self._modifiers_parser, ) - self._contract.set_modifiers(modifiers) + # modifiers will be using Modifier so we can ignore the next type check + self._contract.set_modifiers(modifiers) # type: ignore except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing params {e}") self._modifiers_no_params = [] @@ -459,7 +461,8 @@ class ContractSolc(CallerContextExpression): Cls_parser, self._functions_parser, ) - self._contract.set_functions(functions) + # function will be using FunctionContract so we can ignore the next type check + self._contract.set_functions(functions) # type: ignore except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing params {e}") self._functions_no_params = [] @@ -470,7 +473,7 @@ class ContractSolc(CallerContextExpression): Cls_parser: Callable, element_parser: FunctionSolc, explored_reference_id: Set[str], - parser: List[FunctionSolc], + parser: Union[List[FunctionSolc], List[ModifierSolc]], all_elements: Dict[str, Function], ) -> None: elem = Cls(self._contract.compilation_unit) @@ -508,13 +511,13 @@ class ContractSolc(CallerContextExpression): def _analyze_params_elements( # pylint: disable=too-many-arguments,too-many-locals self, - elements_no_params: List[FunctionSolc], + elements_no_params: Sequence[FunctionSolc], getter: Callable[["ContractSolc"], List[FunctionSolc]], getter_available: Callable[[Contract], List[FunctionContract]], Cls: Callable, Cls_parser: Callable, - parser: List[FunctionSolc], - ) -> Dict[str, Union[FunctionContract, Modifier]]: + parser: Union[List[FunctionSolc], List[ModifierSolc]], + ) -> Dict[str, Function]: """ Analyze the parameters of the given elements (Function or Modifier). The function iterates over the inheritance to create an instance or inherited elements (Function or Modifier) @@ -526,13 +529,13 @@ class ContractSolc(CallerContextExpression): :param Cls: Class to create for collision :return: """ - all_elements = {} + all_elements: Dict[str, Function] = {} - explored_reference_id = set() + explored_reference_id: Set[str] = set() try: for father in self._contract.inheritance: father_parser = self._slither_parser.underlying_contract_to_parser[father] - for element_parser in getter(father_parser): + for element_parser in getter(father_parser): # type: ignore self._analyze_params_element( Cls, Cls_parser, element_parser, explored_reference_id, parser, all_elements ) @@ -597,7 +600,7 @@ class ContractSolc(CallerContextExpression): if self.is_compact_ast: for using_for in self._usingForNotParsed: if "typeName" in using_for and using_for["typeName"]: - type_name = parse_type(using_for["typeName"], self) + type_name: USING_FOR_KEY = parse_type(using_for["typeName"], self) else: type_name = "*" if type_name not in self._contract.using_for: @@ -616,7 +619,7 @@ class ContractSolc(CallerContextExpression): assert children and len(children) <= 2 if len(children) == 2: new = parse_type(children[0], self) - old = parse_type(children[1], self) + old: USING_FOR_KEY = parse_type(children[1], self) else: new = parse_type(children[0], self) old = "*" @@ -627,7 +630,7 @@ class ContractSolc(CallerContextExpression): except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing using for {e}") - def _analyze_function_list(self, function_list: List, type_name: Type) -> None: + def _analyze_function_list(self, function_list: List, type_name: USING_FOR_KEY) -> None: for f in function_list: full_name_split = f["function"]["name"].split(".") if len(full_name_split) == 1: @@ -646,7 +649,9 @@ class ContractSolc(CallerContextExpression): function_name = full_name_split[2] self._analyze_library_function(library_name, function_name, type_name) - def _check_aliased_import(self, first_part: str, function_name: str, type_name: Type) -> None: + def _check_aliased_import( + self, first_part: str, function_name: str, type_name: USING_FOR_KEY + ) -> None: # We check if the first part appear as alias for an import # if it is then function_name must be a top level function # otherwise it's a library function @@ -656,13 +661,13 @@ class ContractSolc(CallerContextExpression): return self._analyze_library_function(first_part, function_name, type_name) - def _analyze_top_level_function(self, function_name: str, type_name: Type) -> None: + def _analyze_top_level_function(self, function_name: str, type_name: USING_FOR_KEY) -> None: for tl_function in self.compilation_unit.functions_top_level: if tl_function.name == function_name: self._contract.using_for[type_name].append(tl_function) def _analyze_library_function( - self, library_name: str, function_name: str, type_name: Type + self, library_name: str, function_name: str, type_name: USING_FOR_KEY ) -> None: # Get the library function found = False @@ -689,22 +694,13 @@ class ContractSolc(CallerContextExpression): # for enum, we can parse and analyze it # at the same time self._analyze_enum(enum) - self._enumsNotParsed = None + self._enumsNotParsed = [] except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing enum {e}") def _analyze_enum( self, - enum: Dict[ - str, - Union[ - str, - int, - List[Dict[str, Union[int, str]]], - Dict[str, str], - List[Dict[str, Union[Dict[str, str], int, str]]], - ], - ], + enum: Dict, ) -> None: # Enum can be parsed in one pass if self.is_compact_ast: @@ -753,13 +749,13 @@ class ContractSolc(CallerContextExpression): event.set_contract(self._contract) event.set_offset(event_to_parse["src"], self._contract.compilation_unit) - event_parser = EventSolc(event, event_to_parse, self) - event_parser.analyze(self) + event_parser = EventSolc(event, event_to_parse, self) # type: ignore + event_parser.analyze(self) # type: ignore self._contract.events_as_dict[event.full_name] = event except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing event {e}") - self._eventsNotParsed = None + self._eventsNotParsed = [] # endregion ################################################################################### @@ -768,7 +764,7 @@ class ContractSolc(CallerContextExpression): ################################################################################### ################################################################################### - def delete_content(self): + def delete_content(self) -> None: """ Remove everything not parsed from the contract This is used only if something went wrong with the inheritance parsing @@ -810,7 +806,7 @@ class ContractSolc(CallerContextExpression): ################################################################################### ################################################################################### - def __hash__(self): + def __hash__(self) -> int: return self._contract.id # endregion diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 9671d9bbe..ba2f225f0 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -242,7 +242,7 @@ class FunctionSolc(CallerContextExpression): if "payable" in attributes: self._function.payable = attributes["payable"] - def analyze_params(self): + def analyze_params(self) -> None: # Can be re-analyzed due to inheritance if self._params_was_analyzed: return @@ -272,7 +272,7 @@ class FunctionSolc(CallerContextExpression): if returns: self._parse_returns(returns) - def analyze_content(self): + def analyze_content(self) -> None: if self._content_was_analyzed: return @@ -308,8 +308,8 @@ class FunctionSolc(CallerContextExpression): for node_parser in self._node_to_nodesolc.values(): node_parser.analyze_expressions(self) - for node_parser in self._node_to_yulobject.values(): - node_parser.analyze_expressions() + for yul_parser in self._node_to_yulobject.values(): + yul_parser.analyze_expressions() self._rewrite_ternary_as_if_else() @@ -1297,7 +1297,7 @@ class FunctionSolc(CallerContextExpression): son.remove_father(node) node.set_sons(new_sons) - def _remove_alone_endif(self): + def _remove_alone_endif(self) -> None: """ Can occur on: if(..){ diff --git a/slither/solc_parsing/variables/variable_declaration.py b/slither/solc_parsing/variables/variable_declaration.py index d21d89875..69b72a521 100644 --- a/slither/solc_parsing/variables/variable_declaration.py +++ b/slither/solc_parsing/variables/variable_declaration.py @@ -1,6 +1,6 @@ import logging import re -from typing import Dict, Optional +from typing import Dict, Optional, Union from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.expressions.expression_parsing import parse_expression @@ -42,12 +42,12 @@ class VariableDeclarationSolc: self._variable = variable self._was_analyzed = False - self._elem_to_parse = None - self._initializedNotParsed = None + self._elem_to_parse: Optional[Union[Dict, UnknownType]] = None + self._initializedNotParsed: Optional[Dict] = None self._is_compact_ast = False - self._reference_id = None + self._reference_id: Optional[int] = None if "nodeType" in variable_data: self._is_compact_ast = True @@ -87,7 +87,7 @@ class VariableDeclarationSolc: declaration = variable_data["children"][0] self._init_from_declaration(declaration, init) elif nodeType == "VariableDeclaration": - self._init_from_declaration(variable_data, False) + self._init_from_declaration(variable_data, None) else: raise ParsingError(f"Incorrect variable declaration type {nodeType}") @@ -101,6 +101,7 @@ class VariableDeclarationSolc: Return the solc id. It can be compared with the referencedDeclaration attr Returns None if it was not parsed (legacy AST) """ + assert self._reference_id return self._reference_id def _handle_comment(self, attributes: Dict) -> None: @@ -127,7 +128,7 @@ class VariableDeclarationSolc: self._variable.visibility = "internal" def _init_from_declaration( - self, var: Dict, init: Optional[bool] + self, var: Dict, init: Optional[Dict] ) -> None: # pylint: disable=too-many-branches if self._is_compact_ast: attributes = var @@ -195,7 +196,7 @@ class VariableDeclarationSolc: self._initializedNotParsed = init elif len(var["children"]) in [0, 1]: self._variable.initialized = False - self._initializedNotParsed = [] + self._initializedNotParsed = None else: assert len(var["children"]) == 2 self._variable.initialized = True @@ -212,5 +213,6 @@ class VariableDeclarationSolc: self._elem_to_parse = None if self._variable.initialized: + assert self._initializedNotParsed self._variable.expression = parse_expression(self._initializedNotParsed, caller_context) self._initializedNotParsed = None diff --git a/slither/tools/doctor/checks/versions.py b/slither/tools/doctor/checks/versions.py index ec7ef1d1f..00662b3e9 100644 --- a/slither/tools/doctor/checks/versions.py +++ b/slither/tools/doctor/checks/versions.py @@ -1,6 +1,6 @@ from importlib import metadata import json -from typing import Optional +from typing import Optional, Any import urllib from packaging.version import parse, Version @@ -17,6 +17,7 @@ def get_installed_version(name: str) -> Optional[Version]: def get_github_version(name: str) -> Optional[Version]: try: + # type: ignore with urllib.request.urlopen( f"https://api.github.com/repos/crytic/{name}/releases/latest" ) as response: @@ -27,7 +28,7 @@ def get_github_version(name: str) -> Optional[Version]: return None -def show_versions(**_kwargs) -> None: +def show_versions(**_kwargs: Any) -> None: versions = { "Slither": (get_installed_version("slither-analyzer"), get_github_version("slither")), "crytic-compile": ( diff --git a/slither/tools/read_storage/utils/utils.py b/slither/tools/read_storage/utils/utils.py index 3e51e2181..4a04a5b6d 100644 --- a/slither/tools/read_storage/utils/utils.py +++ b/slither/tools/read_storage/utils/utils.py @@ -2,6 +2,7 @@ from typing import Union from eth_typing.evm import ChecksumAddress from eth_utils import to_int, to_text, to_checksum_address +from web3 import Web3 def get_offset_value(hex_bytes: bytes, offset: int, size: int) -> bytes: @@ -48,7 +49,7 @@ def coerce_type( if "address" in solidity_type: if not isinstance(value, (str, bytes)): raise TypeError - return to_checksum_address(value) + return to_checksum_address(value) # type: ignore if not isinstance(value, bytes): raise TypeError @@ -56,7 +57,7 @@ def coerce_type( def get_storage_data( - web3, checksum_address: ChecksumAddress, slot: bytes, block: Union[int, str] + web3: Web3, checksum_address: ChecksumAddress, slot: bytes, block: Union[int, str] ) -> bytes: """ Retrieves the storage data from the blockchain at target address and slot. diff --git a/slither/tools/upgradeability/checks/variable_initialization.py b/slither/tools/upgradeability/checks/variable_initialization.py index e8ae9b26c..b4535ddfe 100644 --- a/slither/tools/upgradeability/checks/variable_initialization.py +++ b/slither/tools/upgradeability/checks/variable_initialization.py @@ -1,7 +1,11 @@ +from typing import List + from slither.tools.upgradeability.checks.abstract_checks import ( CheckClassification, AbstractCheck, + CHECK_INFO, ) +from slither.utils.output import Output class VariableWithInit(AbstractCheck): @@ -37,11 +41,11 @@ Using initialize functions to write initial values in state variables. REQUIRE_CONTRACT = True - def _check(self): + def _check(self) -> List[Output]: results = [] for s in self.contract.state_variables_ordered: if s.initialized and not (s.is_constant or s.is_immutable): - info = [s, " is a state variable with an initial value.\n"] + info: CHECK_INFO = [s, " is a state variable with an initial value.\n"] json = self.generate_result(info) results.append(json) return results diff --git a/slither/tools/upgradeability/checks/variables_order.py b/slither/tools/upgradeability/checks/variables_order.py index 030fb0f65..fc83c44c6 100644 --- a/slither/tools/upgradeability/checks/variables_order.py +++ b/slither/tools/upgradeability/checks/variables_order.py @@ -1,7 +1,12 @@ +from typing import List + +from slither.core.declarations import Contract from slither.tools.upgradeability.checks.abstract_checks import ( CheckClassification, AbstractCheck, + CHECK_INFO, ) +from slither.utils.output import Output class MissingVariable(AbstractCheck): @@ -45,9 +50,12 @@ Do not change the order of the state variables in the updated contract. REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True - def _check(self): + def _check(self) -> List[Output]: contract1 = self.contract contract2 = self.contract_v2 + + assert contract2 + order1 = [ variable for variable in contract1.state_variables_ordered @@ -63,7 +71,7 @@ Do not change the order of the state variables in the updated contract. for idx, _ in enumerate(order1): variable1 = order1[idx] if len(order2) <= idx: - info = ["Variable missing in ", contract2, ": ", variable1, "\n"] + info: CHECK_INFO = ["Variable missing in ", contract2, ": ", variable1, "\n"] json = self.generate_result(info) results.append(json) @@ -108,13 +116,14 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s REQUIRE_CONTRACT = True REQUIRE_PROXY = True - def _contract1(self): + def _contract1(self) -> Contract: return self.contract - def _contract2(self): + def _contract2(self) -> Contract: + assert self.proxy return self.proxy - def _check(self): + def _check(self) -> List[Output]: contract1 = self._contract1() contract2 = self._contract2() order1 = [ @@ -128,7 +137,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s if not (variable.is_constant or variable.is_immutable) ] - results = [] + results: List[Output] = [] for idx, _ in enumerate(order1): if len(order2) <= idx: # Handle by MissingVariable @@ -137,7 +146,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s variable1 = order1[idx] variable2 = order2[idx] if (variable1.name != variable2.name) or (variable1.type != variable2.type): - info = [ + info: CHECK_INFO = [ "Different variables between ", contract1, " and ", @@ -190,7 +199,8 @@ Respect the variable order of the original contract in the updated contract. REQUIRE_PROXY = False REQUIRE_CONTRACT_V2 = True - def _contract2(self): + def _contract2(self) -> Contract: + assert self.contract_v2 return self.contract_v2 @@ -235,13 +245,14 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s REQUIRE_CONTRACT = True REQUIRE_PROXY = True - def _contract1(self): + def _contract1(self) -> Contract: return self.contract - def _contract2(self): + def _contract2(self) -> Contract: + assert self.proxy return self.proxy - def _check(self): + def _check(self) -> List[Output]: contract1 = self._contract1() contract2 = self._contract2() order1 = [ @@ -264,7 +275,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s while idx < len(order2): variable2 = order2[idx] - info = ["Extra variables in ", contract2, ": ", variable2, "\n"] + info: CHECK_INFO = ["Extra variables in ", contract2, ": ", variable2, "\n"] json = self.generate_result(info) results.append(json) idx = idx + 1 @@ -299,5 +310,6 @@ Ensure that all the new variables are expected. REQUIRE_PROXY = False REQUIRE_CONTRACT_V2 = True - def _contract2(self): + def _contract2(self) -> Contract: + assert self.contract_v2 return self.contract_v2 diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index 7b1a8f8ee..12eb6be9d 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -104,7 +104,7 @@ class ConstantFolding(ExpressionVisitor): and isinstance(left, (int, Fraction)) and isinstance(right, (int, Fraction)) ): - set_val(expression, left**right) #type: ignore + set_val(expression, left**right) # type: ignore elif ( expression.type == BinaryOperationType.MULTIPLICATION and isinstance(left, (int, Fraction)) diff --git a/tests/test_ssa_generation.py b/tests/test_ssa_generation.py index 9bb008fdf..c7bc8d5cc 100644 --- a/tests/test_ssa_generation.py +++ b/tests/test_ssa_generation.py @@ -6,7 +6,7 @@ from collections import defaultdict from contextlib import contextmanager from inspect import getsourcefile from tempfile import NamedTemporaryFile -from typing import Union, List, Optional +from typing import Union, List, Optional, Dict, Callable import pytest from solc_select import solc_select @@ -15,6 +15,7 @@ from solc_select.solc_select import valid_version as solc_valid_version from slither import Slither from slither.core.cfg.node import Node, NodeType from slither.core.declarations import Function, Contract +from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable from slither.slithir.operations import ( OperationWithLValue, @@ -34,10 +35,11 @@ from slither.slithir.variables import ( ReferenceVariable, LocalIRVariable, StateIRVariable, + TemporaryVariableSSA, ) # Directory of currently executing script. Will be used as basis for temporary file names. -SCRIPT_DIR = pathlib.Path(getsourcefile(lambda: 0)).parent +SCRIPT_DIR = pathlib.Path(getsourcefile(lambda: 0)).parent # type:ignore def valid_version(ver: str) -> bool: @@ -53,15 +55,15 @@ def valid_version(ver: str) -> bool: return False -def have_ssa_if_ir(function: Function): +def have_ssa_if_ir(function: Function) -> None: """Verifies that all nodes in a function that have IR also have SSA IR""" for n in function.nodes: if n.irs: assert n.irs_ssa -# pylint: disable=too-many-branches -def ssa_basic_properties(function: Function): +# pylint: disable=too-many-branches, too-many-locals +def ssa_basic_properties(function: Function) -> None: """Verifies that basic properties of ssa holds 1. Every name is defined only once @@ -75,12 +77,14 @@ def ssa_basic_properties(function: Function): """ ssa_lvalues = set() ssa_rvalues = set() - lvalue_assignments = {} + lvalue_assignments: Dict[str, int] = {} for n in function.nodes: for ir in n.irs: - if isinstance(ir, OperationWithLValue): + if isinstance(ir, OperationWithLValue) and ir.lvalue: name = ir.lvalue.name + if name is None: + continue if name in lvalue_assignments: lvalue_assignments[name] += 1 else: @@ -93,8 +97,9 @@ def ssa_basic_properties(function: Function): ssa_lvalues.add(ssa.lvalue) # 2 (if Local/State Var) - if isinstance(ssa.lvalue, (StateIRVariable, LocalIRVariable)): - assert ssa.lvalue.index > 0 + ssa_lvalue = ssa.lvalue + if isinstance(ssa_lvalue, (StateIRVariable, LocalIRVariable)): + assert ssa_lvalue.index > 0 for rvalue in filter( lambda x: not isinstance(x, (StateIRVariable, Constant)), ssa.read @@ -111,15 +116,18 @@ def ssa_basic_properties(function: Function): undef_vars.add(rvalue.non_ssa_version) # 4 - ssa_defs = defaultdict(int) + ssa_defs: Dict[str, int] = defaultdict(int) for v in ssa_lvalues: - ssa_defs[v.name] += 1 + if v and v.name: + ssa_defs[v.name] += 1 - for (k, n) in lvalue_assignments.items(): - assert ssa_defs[k] >= n + for (k, count) in lvalue_assignments.items(): + assert ssa_defs[k] >= count # Helper 5/6 - def check_property_5_and_6(variables, ssavars): + def check_property_5_and_6( + variables: List[LocalVariable], ssavars: List[LocalIRVariable] + ) -> None: for var in filter(lambda x: x.name, variables): ssa_vars = [x for x in ssavars if x.non_ssa_version == var] assert len(ssa_vars) == 1 @@ -136,7 +144,7 @@ def ssa_basic_properties(function: Function): check_property_5_and_6(function.returns, function.returns_ssa) -def ssa_phi_node_properties(f: Function): +def ssa_phi_node_properties(f: Function) -> None: """Every phi-function should have as many args as predecessors This does not apply if the phi-node refers to state variables, @@ -152,7 +160,7 @@ def ssa_phi_node_properties(f: Function): # TODO (hbrodin): This should probably go into another file, not specific to SSA -def dominance_properties(f: Function): +def dominance_properties(f: Function) -> None: """Verifies properties related to dominators holds 1. Every node have an immediate dominator except entry_node which have none @@ -180,14 +188,16 @@ def dominance_properties(f: Function): assert find_path(node.immediate_dominator, node) -def phi_values_inserted(f: Function): +def phi_values_inserted(f: Function) -> None: """Verifies that phi-values are inserted at the right places For every node that has a dominance frontier, any def (including phi) should be a phi function in its dominance frontier """ - def have_phi_for_var(node: Node, var): + def have_phi_for_var( + node: Node, var: Union[StateIRVariable, LocalIRVariable, TemporaryVariableSSA] + ) -> bool: """Checks if a node has a phi-instruction for var The ssa version would ideally be checked, but then @@ -198,7 +208,14 @@ def phi_values_inserted(f: Function): non_ssa = var.non_ssa_version for ssa in node.irs_ssa: if isinstance(ssa, Phi): - if non_ssa in map(lambda ssa_var: ssa_var.non_ssa_version, ssa.read): + if non_ssa in map( + lambda ssa_var: ssa_var.non_ssa_version, + [ + r + for r in ssa.read + if isinstance(r, (StateIRVariable, LocalIRVariable, TemporaryVariableSSA)) + ], + ): return True return False @@ -206,12 +223,15 @@ def phi_values_inserted(f: Function): for df in node.dominance_frontier: for ssa in node.irs_ssa: if isinstance(ssa, OperationWithLValue): - if is_used_later(node, ssa.lvalue): - assert have_phi_for_var(df, ssa.lvalue) + ssa_lvalue = ssa.lvalue + if isinstance( + ssa_lvalue, (StateIRVariable, LocalIRVariable, TemporaryVariableSSA) + ) and is_used_later(node, ssa_lvalue): + assert have_phi_for_var(df, ssa_lvalue) @contextmanager -def select_solc_version(version: Optional[str]): +def select_solc_version(version: Optional[str]) -> None: """Selects solc version to use for running tests. If no version is provided, latest is used.""" @@ -256,17 +276,17 @@ def slither_from_source(source_code: str, solc_version: Optional[str] = None): pathlib.Path(fname).unlink() -def verify_properties_hold(source_code_or_slither: Union[str, Slither]): +def verify_properties_hold(source_code_or_slither: Union[str, Slither]) -> None: """Ensures that basic properties of SSA hold true""" - def verify_func(func: Function): + def verify_func(func: Function) -> None: have_ssa_if_ir(func) phi_values_inserted(func) ssa_basic_properties(func) ssa_phi_node_properties(func) dominance_properties(func) - def verify(slither): + def verify(slither: Slither) -> None: for cu in slither.compilation_units: for func in cu.functions_and_modifiers: _dump_function(func) @@ -280,11 +300,12 @@ def verify_properties_hold(source_code_or_slither: Union[str, Slither]): if isinstance(source_code_or_slither, Slither): verify(source_code_or_slither) else: + slither: Slither with slither_from_source(source_code_or_slither) as slither: verify(slither) -def _dump_function(f: Function): +def _dump_function(f: Function) -> None: """Helper function to print nodes/ssa ir for a function or modifier""" print(f"---- {f.name} ----") for n in f.nodes: @@ -294,13 +315,13 @@ def _dump_function(f: Function): print("") -def _dump_functions(c: Contract): +def _dump_functions(c: Contract) -> None: """Helper function to print functions and modifiers of a contract""" for f in c.functions_and_modifiers: _dump_function(f) -def get_filtered_ssa(f: Union[Function, Node], flt) -> List[Operation]: +def get_filtered_ssa(f: Union[Function, Node], flt: Callable) -> List[Operation]: """Returns a list of all ssanodes filtered by filter for all nodes in function f""" if isinstance(f, Function): return [ssanode for node in f.nodes for ssanode in node.irs_ssa if flt(ssanode)] @@ -314,7 +335,7 @@ def get_ssa_of_type(f: Union[Function, Node], ssatype) -> List[Operation]: return get_filtered_ssa(f, lambda ssanode: isinstance(ssanode, ssatype)) -def test_multi_write(): +def test_multi_write() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -327,7 +348,7 @@ def test_multi_write(): verify_properties_hold(contract) -def test_single_branch_phi(): +def test_single_branch_phi() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -342,7 +363,7 @@ def test_single_branch_phi(): verify_properties_hold(contract) -def test_basic_phi(): +def test_basic_phi() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -359,7 +380,7 @@ def test_basic_phi(): verify_properties_hold(contract) -def test_basic_loop_phi(): +def test_basic_loop_phi() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -375,7 +396,7 @@ def test_basic_loop_phi(): @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") -def test_phi_propagation_loop(): +def test_phi_propagation_loop() -> None: contract = """ pragma solidity ^0.8.11; contract Test { @@ -396,7 +417,7 @@ def test_phi_propagation_loop(): @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") -def test_free_function_properties(): +def test_free_function_properties() -> None: contract = """ pragma solidity ^0.8.11; @@ -417,7 +438,7 @@ def test_free_function_properties(): verify_properties_hold(contract) -def test_ssa_inter_transactional(): +def test_ssa_inter_transactional() -> None: source = """ pragma solidity ^0.8.11; contract A { @@ -460,7 +481,7 @@ def test_ssa_inter_transactional(): @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") -def test_ssa_phi_callbacks(): +def test_ssa_phi_callbacks() -> None: source = """ pragma solidity ^0.8.11; contract A { @@ -519,7 +540,7 @@ def test_ssa_phi_callbacks(): @pytest.mark.skip(reason="Fails in current slither version. Fix in #1102.") -def test_storage_refers_to(): +def test_storage_refers_to() -> None: """Test the storage aspects of the SSA IR When declaring a var as being storage, start tracking what storage it refers_to. From 77bf3e576e9ba80b900f98fbcb0eb4188e6c9e82 Mon Sep 17 00:00:00 2001 From: Feist Josselin Date: Fri, 24 Feb 2023 20:20:49 +0100 Subject: [PATCH 12/34] More fixes --- .../statements/delegatecall_in_loop.py | 2 +- slither/printers/functions/dominator.py | 3 +- slither/printers/guidance/echidna.py | 6 +- slither/printers/summary/data_depenency.py | 12 +- slither/printers/summary/declaration.py | 3 +- slither/slithir/convert.py | 122 +++++++----------- slither/slithir/operations/delete.py | 2 +- .../slithir/operations/return_operation.py | 25 ++-- slither/slithir/utils/utils.py | 6 +- slither/utils/myprettytable.py | 4 +- 10 files changed, 85 insertions(+), 100 deletions(-) diff --git a/slither/detectors/statements/delegatecall_in_loop.py b/slither/detectors/statements/delegatecall_in_loop.py index d97466edf..bdcf5dcae 100644 --- a/slither/detectors/statements/delegatecall_in_loop.py +++ b/slither/detectors/statements/delegatecall_in_loop.py @@ -42,7 +42,7 @@ def delegatecall_in_loop( and ir.function_name == "delegatecall" ): results.append(ir.node) - if isinstance(ir, (InternalCall)): + if isinstance(ir, (InternalCall)) and ir.function: delegatecall_in_loop(ir.function.entry_point, in_loop_counter, visited, results) for son in node.sons: diff --git a/slither/printers/functions/dominator.py b/slither/printers/functions/dominator.py index f618fd5db..1b32498f9 100644 --- a/slither/printers/functions/dominator.py +++ b/slither/printers/functions/dominator.py @@ -1,4 +1,5 @@ from slither.printers.abstract_printer import AbstractPrinter +from slither.utils.output import Output class Dominator(AbstractPrinter): @@ -8,7 +9,7 @@ class Dominator(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#dominator" - def output(self, filename): + def output(self, filename: str) -> Output: """ _filename is not used Args: diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index 166fa48f5..3a555562f 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -79,7 +79,7 @@ def _is_constant(f: Function) -> bool: # pylint: disable=too-many-branches :return: """ if f.view or f.pure: - if not f.contract.compilation_unit.solc_version.startswith("0.4"): + if not f.compilation_unit.solc_version.startswith("0.4"): return True if f.payable: return False @@ -102,11 +102,11 @@ def _is_constant(f: Function) -> bool: # pylint: disable=too-many-branches if isinstance(ir, HighLevelCall): if isinstance(ir.function, Variable) or ir.function.view or ir.function.pure: # External call to constant functions are ensured to be constant only for solidity >= 0.5 - if f.contract.compilation_unit.solc_version.startswith("0.4"): + if f.compilation_unit.solc_version.startswith("0.4"): return False else: return False - if isinstance(ir, InternalCall): + if isinstance(ir, InternalCall) and ir.function: # Storage write are not properly handled by all_state_variables_written if any(parameter.is_storage for parameter in ir.function.parameters): return False diff --git a/slither/printers/summary/data_depenency.py b/slither/printers/summary/data_depenency.py index 41659a299..f1c0dc8d5 100644 --- a/slither/printers/summary/data_depenency.py +++ b/slither/printers/summary/data_depenency.py @@ -1,19 +1,22 @@ """ Module printing summary of the contract """ +from typing import List +from slither.core.declarations import Contract from slither.printers.abstract_printer import AbstractPrinter -from slither.analyses.data_dependency.data_dependency import get_dependencies +from slither.analyses.data_dependency.data_dependency import get_dependencies, SUPPORTED_TYPES from slither.slithir.variables import TemporaryVariable, ReferenceVariable from slither.utils.myprettytable import MyPrettyTable +from slither.utils.output import Output -def _get(v, c): +def _get(v: SUPPORTED_TYPES, c: Contract) -> List[str]: return list( { d.name for d in get_dependencies(v, c) - if not isinstance(d, (TemporaryVariable, ReferenceVariable)) + if not isinstance(d, (TemporaryVariable, ReferenceVariable)) and d.name } ) @@ -25,7 +28,7 @@ class DataDependency(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#data-dependencies" - def output(self, _filename): + def output(self, _filename: str) -> Output: """ _filename is not used Args: @@ -42,6 +45,7 @@ class DataDependency(AbstractPrinter): txt += f"\nContract {c.name}\n" table = MyPrettyTable(["Variable", "Dependencies"]) for v in c.state_variables: + assert v.name table.add_row([v.name, sorted(_get(v, c))]) txt += str(table) diff --git a/slither/printers/summary/declaration.py b/slither/printers/summary/declaration.py index 5888a1f00..9266d5580 100644 --- a/slither/printers/summary/declaration.py +++ b/slither/printers/summary/declaration.py @@ -1,4 +1,5 @@ from slither.printers.abstract_printer import AbstractPrinter +from slither.utils.output import Output from slither.utils.source_mapping import get_definition, get_implementation, get_references @@ -8,7 +9,7 @@ class Declaration(AbstractPrinter): WIKI = "TODO" - def output(self, _filename): + def output(self, _filename: str) -> Output: """ _filename is not used Args: diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index cc47ea913..e05526d86 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, List, TYPE_CHECKING, Union, Optional +from typing import Any, List, TYPE_CHECKING, Union, Optional, Dict # pylint: disable= too-many-lines,import-outside-toplevel,too-many-branches,too-many-statements,too-many-nested-blocks from slither.core.declarations import ( @@ -13,12 +13,14 @@ from slither.core.declarations import ( SolidityVariableComposed, Structure, ) +from slither.core.declarations.contract import USING_FOR_KEY, USING_FOR_ITEM from slither.core.declarations.custom_error import CustomError from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_top_level import FunctionTopLevel from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.declarations.solidity_variables import SolidityCustomRevert from slither.core.expressions import Identifier, Literal +from slither.core.expressions.expression import Expression from slither.core.solidity_types import ( ArrayType, ElementaryType, @@ -83,28 +85,6 @@ from slither.slithir.variables import TupleVariable from slither.utils.function import get_function_id from slither.utils.type import export_nested_types_from_variable from slither.visitors.slithir.expression_to_slithir import ExpressionToSlithIR -import slither.core.declarations.contract -import slither.core.declarations.function -import slither.core.solidity_types.elementary_type -import slither.core.solidity_types.function_type -import slither.core.solidity_types.user_defined_type -import slither.slithir.operations.assignment -import slither.slithir.operations.binary -import slither.slithir.operations.call -import slither.slithir.operations.high_level_call -import slither.slithir.operations.index -import slither.slithir.operations.init_array -import slither.slithir.operations.internal_call -import slither.slithir.operations.length -import slither.slithir.operations.library_call -import slither.slithir.operations.low_level_call -import slither.slithir.operations.member -import slither.slithir.operations.operation -import slither.slithir.operations.send -import slither.slithir.operations.solidity_call -import slither.slithir.operations.transfer -import slither.slithir.variables.temporary -from slither.core.expressions.expression import Expression if TYPE_CHECKING: from slither.core.cfg.node import Node @@ -112,7 +92,7 @@ if TYPE_CHECKING: logger = logging.getLogger("ConvertToIR") -def convert_expression(expression: Expression, node: "Node") -> List[Any]: +def convert_expression(expression: Expression, node: "Node") -> List[Operation]: # handle standlone expression # such as return true; from slither.core.cfg.node import NodeType @@ -122,8 +102,7 @@ def convert_expression(expression: Expression, node: "Node") -> List[Any]: cond = Condition(cst) cond.set_expression(expression) cond.set_node(node) - result = [cond] - return result + return [cond] if isinstance(expression, Identifier) and node.type in [ NodeType.IF, NodeType.IFLOOP, @@ -131,8 +110,7 @@ def convert_expression(expression: Expression, node: "Node") -> List[Any]: cond = Condition(expression.value) cond.set_expression(expression) cond.set_node(node) - result = [cond] - return result + return [cond] visitor = ExpressionToSlithIR(expression, node) result = visitor.result() @@ -141,15 +119,17 @@ def convert_expression(expression: Expression, node: "Node") -> List[Any]: if result: if node.type in [NodeType.IF, NodeType.IFLOOP]: - assert isinstance(result[-1], (OperationWithLValue)) - cond = Condition(result[-1].lvalue) + prev = result[-1] + assert isinstance(prev, (OperationWithLValue)) and prev.lvalue + cond = Condition(prev.lvalue) cond.set_expression(expression) cond.set_node(node) result.append(cond) elif node.type == NodeType.RETURN: # May return None - if isinstance(result[-1], (OperationWithLValue)): - r = Return(result[-1].lvalue) + prev = result[-1] + if isinstance(prev, (OperationWithLValue)): + r = Return(prev.lvalue) r.set_expression(expression) r.set_node(node) result.append(r) @@ -273,7 +253,7 @@ def _find_function_from_parameter( type_args += ["string"] not_found = True - candidates_kept = [] + candidates_kept: List[Function] = [] for type_arg in type_args: if not not_found: break @@ -336,7 +316,7 @@ def integrate_value_gas(result: List[Operation]) -> List[Operation]: # Find all the assignments assigments = {} for i in result: - if isinstance(i, OperationWithLValue): + if isinstance(i, OperationWithLValue) and i.lvalue: assigments[i.lvalue.name] = i if isinstance(i, TmpCall): if isinstance(i.called, Variable) and i.called.name in assigments: @@ -350,20 +330,25 @@ def integrate_value_gas(result: List[Operation]) -> List[Operation]: for idx, ins in enumerate(result): # value can be shadowed, so we check that the prev ins # is an Argument - if is_value(ins) and isinstance(result[idx - 1], Argument): + if idx == 0: + continue + prev_ins = result[idx - 1] + if is_value(ins) and isinstance(prev_ins, Argument): was_changed = True - result[idx - 1].set_type(ArgumentType.VALUE) - result[idx - 1].call_id = ins.ori.variable_left.name - calls.append(ins.ori.variable_left) + prev_ins.set_type(ArgumentType.VALUE) + # Types checked by is_value + prev_ins.call_id = ins.ori.variable_left.name # type: ignore + calls.append(ins.ori.variable_left) # type: ignore to_remove.append(ins) - variable_to_replace[ins.lvalue.name] = ins.ori.variable_left - elif is_gas(ins) and isinstance(result[idx - 1], Argument): + variable_to_replace[ins.lvalue.name] = ins.ori.variable_left # type: ignore + elif is_gas(ins) and isinstance(prev_ins, Argument): was_changed = True - result[idx - 1].set_type(ArgumentType.GAS) - result[idx - 1].call_id = ins.ori.variable_left.name - calls.append(ins.ori.variable_left) + prev_ins.set_type(ArgumentType.GAS) + # Types checked by is_gas + prev_ins.call_id = ins.ori.variable_left.name # type: ignore + calls.append(ins.ori.variable_left) # type: ignore to_remove.append(ins) - variable_to_replace[ins.lvalue.name] = ins.ori.variable_left + variable_to_replace[ins.lvalue.name] = ins.ori.variable_left # type: ignore # Remove the call to value/gas instruction result = [i for i in result if not i in to_remove] @@ -446,7 +431,7 @@ def propagate_type_and_convert_call(result: List[Operation], node: "Node") -> Li if isinstance(ins, (HighLevelCall, NewContract, InternalDynamicCall)): if ins.call_id in calls_value: ins.call_value = calls_value[ins.call_id] - if ins.call_id in calls_gas: + if ins.call_id in calls_gas and isinstance(ins, (HighLevelCall, InternalDynamicCall)): ins.call_gas = calls_gas[ins.call_id] if isinstance(ins, (Call, NewContract, NewStructure)): @@ -525,17 +510,15 @@ def _convert_type_contract(ir: Member) -> Assignment: raise SlithIRError(f"type({contract.name}).{ir.variable_right} is unknown") -def propagate_types( - ir: slither.slithir.operations.operation.Operation, node: "Node" -): # pylint: disable=too-many-locals +def propagate_types(ir: Operation, node: "Node"): # pylint: disable=too-many-locals # propagate the type node_function = node.function - using_for = ( + using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = ( node_function.contract.using_for_complete if isinstance(node_function, FunctionContract) else {} ) - if isinstance(ir, OperationWithLValue): + if isinstance(ir, OperationWithLValue) and ir.lvalue: # Force assignment in case of missing previous correct type if not ir.lvalue.type: if isinstance(ir, Assignment): @@ -646,11 +629,12 @@ def propagate_types( and not isinstance(ir.variable_left, Contract) and isinstance(ir.variable_left.type, (ElementaryType, ArrayType)) ): - length = Length(ir.variable_left, ir.lvalue) - length.set_expression(ir.expression) - length.lvalue.points_to = ir.variable_left - length.set_node(ir.node) - return length + new_length = Length(ir.variable_left, ir.lvalue) + assert ir.expression + new_length.set_expression(ir.expression) + new_length.lvalue.points_to = ir.variable_left + new_length.set_node(ir.node) + return new_length # This only happen for .balance/code/codehash access on a variable for which we dont know at # early parsing time the type # Like @@ -794,6 +778,7 @@ def propagate_types( ir.lvalue.set_type(ir.array_type) elif isinstance(ir, NewContract): contract = node.file_scope.get_contract_from_name(ir.contract_name) + assert contract ir.lvalue.set_type(UserDefinedType(contract)) elif isinstance(ir, NewElementaryType): ir.lvalue.set_type(ir.type) @@ -837,9 +822,7 @@ def propagate_types( # pylint: disable=too-many-locals -def extract_tmp_call( - ins: TmpCall, contract: Optional[Contract] -) -> slither.slithir.operations.call.Call: +def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]) -> Union[Call, Nop]: assert isinstance(ins, TmpCall) if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType): # If the call is made to a variable member, where the member is this @@ -1328,16 +1311,8 @@ def convert_to_push_set_val( def convert_to_push( - ir: slither.slithir.operations.high_level_call.HighLevelCall, node: "Node" -) -> List[ - Union[ - slither.slithir.operations.length.Length, - slither.slithir.operations.assignment.Assignment, - slither.slithir.operations.binary.Binary, - slither.slithir.operations.index.Index, - slither.slithir.operations.init_array.InitArray, - ] -]: + ir: HighLevelCall, node: "Node" +) -> List[Union[Length, Assignment, Binary, Index, InitArray,]]: """ Convert a call to a series of operations to push a new value onto the array @@ -1357,22 +1332,23 @@ def convert_to_push( return ret -def convert_to_pop(ir, node): +def convert_to_pop(ir: HighLevelCall, node: "Node") -> List[Operation]: """ Convert pop operators Return a list of 6 operations """ - ret = [] + ret: List[Operation] = [] arr = ir.destination length = ReferenceVariable(node) length.set_type(ElementaryType("uint256")) ir_length = Length(arr, length) + assert ir.expression ir_length.set_expression(ir.expression) ir_length.set_node(ir.node) - ir_length.lvalue.points_to = arr + length.points_to = arr ret.append(ir_length) val = TemporaryVariable(node) @@ -1384,6 +1360,8 @@ def convert_to_pop(ir, node): element_to_delete = ReferenceVariable(node) ir_assign_element_to_delete = Index(element_to_delete, arr, val) + # TODO the following is equivalent to length.points_to = arr + # Should it be removed? ir_length.lvalue.points_to = arr element_to_delete.set_type(ElementaryType("uint256")) ir_assign_element_to_delete.set_expression(ir.expression) @@ -1399,7 +1377,7 @@ def convert_to_pop(ir, node): length_to_assign.set_type(ElementaryType("uint256")) ir_length = Length(arr, length_to_assign) ir_length.set_expression(ir.expression) - ir_length.lvalue.points_to = arr + length_to_assign.points_to = arr ir_length.set_node(ir.node) ret.append(ir_length) diff --git a/slither/slithir/operations/delete.py b/slither/slithir/operations/delete.py index 496d170ad..d241033c5 100644 --- a/slither/slithir/operations/delete.py +++ b/slither/slithir/operations/delete.py @@ -36,5 +36,5 @@ class Delete(OperationWithLValue): ) -> Union[StateIRVariable, StateVariable, ReferenceVariable, ReferenceVariableSSA]: return self._variable - def __str__(self): + def __str__(self) -> str: return f"{self.lvalue} = delete {self.variable} " diff --git a/slither/slithir/operations/return_operation.py b/slither/slithir/operations/return_operation.py index c21579763..290572ebf 100644 --- a/slither/slithir/operations/return_operation.py +++ b/slither/slithir/operations/return_operation.py @@ -1,11 +1,10 @@ -from typing import List +from typing import List, Optional, Union, Any from slither.core.declarations import Function +from slither.core.variables.variable import Variable from slither.slithir.operations.operation import Operation - +from slither.slithir.utils.utils import is_valid_rvalue, RVALUE from slither.slithir.variables.tuple import TupleVariable -from slither.slithir.utils.utils import is_valid_rvalue -from slither.core.variables.variable import Variable class Return(Operation): @@ -14,10 +13,13 @@ class Return(Operation): Only present as last operation in RETURN node """ - def __init__(self, values) -> None: + def __init__( + self, values: Optional[Union[RVALUE, TupleVariable, Function, List[RVALUE]]] + ) -> None: # Note: Can return None # ex: return call() # where call() dont return + self._values: List[Union[RVALUE, TupleVariable, Function]] if not isinstance(values, list): assert ( is_valid_rvalue(values) @@ -25,20 +27,19 @@ class Return(Operation): or values is None ) if values is None: - values = [] + self._values = [] else: - values = [values] + self._values = [values] else: # Remove None # Prior Solidity 0.5 # return (0,) # was valid for returns(uint) - values = [v for v in values if not v is None] - self._valid_value(values) + self._values = [v for v in values if not v is None] + self._valid_value(self._values) super().__init__() - self._values = values - def _valid_value(self, value) -> bool: + def _valid_value(self, value: Any) -> bool: if isinstance(value, list): assert all(self._valid_value(v) for v in value) else: @@ -53,5 +54,5 @@ class Return(Operation): def values(self) -> List[Variable]: return self._unroll(self._values) - def __str__(self): + def __str__(self) -> str: return f"RETURN {','.join([f'{x}' for x in self.values])}" diff --git a/slither/slithir/utils/utils.py b/slither/slithir/utils/utils.py index 4619c08bc..49b1a879c 100644 --- a/slither/slithir/utils/utils.py +++ b/slither/slithir/utils/utils.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Optional from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable @@ -31,7 +31,7 @@ LVALUE = Union[ ] -def is_valid_rvalue(v: SourceMapping) -> bool: +def is_valid_rvalue(v: Optional[SourceMapping]) -> bool: return isinstance( v, ( @@ -46,7 +46,7 @@ def is_valid_rvalue(v: SourceMapping) -> bool: ) -def is_valid_lvalue(v: SourceMapping) -> bool: +def is_valid_lvalue(v: Optional[SourceMapping]) -> bool: return isinstance( v, ( diff --git a/slither/utils/myprettytable.py b/slither/utils/myprettytable.py index a1dfd7ac0..af10a6ff2 100644 --- a/slither/utils/myprettytable.py +++ b/slither/utils/myprettytable.py @@ -1,4 +1,4 @@ -from typing import List, Dict +from typing import List, Dict, Union from prettytable import PrettyTable @@ -8,7 +8,7 @@ class MyPrettyTable: self._field_names = field_names self._rows: List = [] - def add_row(self, row: List[str]) -> None: + def add_row(self, row: List[Union[str, List[str]]]) -> None: self._rows.append(row) def to_pretty_table(self) -> PrettyTable: From 197ccb104865d12755a319acdbbf75d3a6a4862e Mon Sep 17 00:00:00 2001 From: webthethird Date: Wed, 8 Mar 2023 16:05:09 -0600 Subject: [PATCH 13/34] Add `Function.interface_signature_str` --- slither/core/declarations/function.py | 28 +++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index a4624feec..cbc0b64ec 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -210,6 +210,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu self._signature: Optional[Tuple[str, List[str], List[str]]] = None self._solidity_signature: Optional[str] = None self._signature_str: Optional[str] = None + self._interface_signature_str: Optional[str] = None self._canonical_name: Optional[str] = None self._is_protected: Optional[bool] = None @@ -1002,6 +1003,33 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ) return self._signature_str + @property + def interface_signature_str(self) -> Optional[str]: + """ + str: func_name(type1,type2) external {payable/view/pure} returns (type3) + Return the function interface as a str (contains the return values) + Returns None if the function is private or internal, or is a constructor/fallback/receive + """ + if self._interface_signature_str is None: + name, parameters, returnVars = self.signature + visibility = self.visibility + if ( + visibility in ["private", "internal"] + or self.is_constructor + or self.is_fallback + or self.is_receive + ): + return None + view = " view" if self.view else "" + pure = " pure" if self.pure else "" + payable = " payable" if self.payable else "" + self._interface_signature_str = ( + name + "(" + ",".join(parameters) + ") external" + payable + pure + view + ) + if len(returnVars) > 0: + self._interface_signature_str += " returns (" + ",".join(returnVars) + ")" + return self._interface_signature_str + # endregion ################################################################################### ################################################################################### From 4da0fb51101bebed3d05a5c6d195f814f1de0ecc Mon Sep 17 00:00:00 2001 From: webthethird Date: Wed, 8 Mar 2023 16:05:25 -0600 Subject: [PATCH 14/34] Add `Structure.interface_def_str` --- slither/core/declarations/structure.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/slither/core/declarations/structure.py b/slither/core/declarations/structure.py index 8f6d8c50a..2800f809c 100644 --- a/slither/core/declarations/structure.py +++ b/slither/core/declarations/structure.py @@ -49,5 +49,11 @@ class Structure(SourceMapping): ret.append(self._elems[e]) return ret + def interface_def_str(self) -> str: + definition = f" struct {self.name} {{\n" + for elem in self.elems_ordered: + definition += f" {elem.type} {elem.name};\n" + definition += " }\n" + def __str__(self) -> str: return self.name From 6a2250f60d63674884756972bda8cbb40f6304ef Mon Sep 17 00:00:00 2001 From: webthethird Date: Wed, 8 Mar 2023 16:05:45 -0600 Subject: [PATCH 15/34] Add `Contract.generate_interface` --- slither/core/declarations/contract.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index eb2ac9a2e..ffd8a0862 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -953,6 +953,20 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ return all((not f.is_implemented) for f in self.functions) + def generate_interface(self) -> str: + interface = f"interface I{self.name} {{\n" + for struct in self.structures: + interface += struct.interface_def_str() + for var in self.state_variables_entry_points: + interface += ( + f" function {var.signature_str.replace('returns', 'external returns ')};\n" + ) + for func in self.functions_entry_points: + if func.is_constructor or func.is_fallback or func.is_receive: + continue + interface += f" function {func.interface_signature_str};\n" + return interface + # endregion ################################################################################### ################################################################################### From 772710cdede90ffd3fbd8c186e385219d0c467aa Mon Sep 17 00:00:00 2001 From: webthethird Date: Wed, 8 Mar 2023 16:22:13 -0600 Subject: [PATCH 16/34] Fix typo --- slither/core/declarations/contract.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index ffd8a0862..1ca38a1a2 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -956,7 +956,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods def generate_interface(self) -> str: interface = f"interface I{self.name} {{\n" for struct in self.structures: - interface += struct.interface_def_str() + if isinstance(struct.interface_def_str(), str): + interface += struct.interface_def_str() for var in self.state_variables_entry_points: interface += ( f" function {var.signature_str.replace('returns', 'external returns ')};\n" @@ -965,6 +966,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods if func.is_constructor or func.is_fallback or func.is_receive: continue interface += f" function {func.interface_signature_str};\n" + interface += "}\n" return interface # endregion From 9e70c3ff884de2fd1be575cfaab44e3da4e7d1ce Mon Sep 17 00:00:00 2001 From: webthethird Date: Wed, 8 Mar 2023 16:36:51 -0600 Subject: [PATCH 17/34] Fix typo --- slither/core/declarations/contract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 1ca38a1a2..35f31ab07 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -966,7 +966,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods if func.is_constructor or func.is_fallback or func.is_receive: continue interface += f" function {func.interface_signature_str};\n" - interface += "}\n" + interface += "}\n\n" return interface # endregion From 9238efc1a46befa6d7911f0d4cae68749dd0c68a Mon Sep 17 00:00:00 2001 From: webthethird Date: Wed, 8 Mar 2023 16:37:33 -0600 Subject: [PATCH 18/34] Change contract type to address in `interface_signature_str` --- slither/core/declarations/function.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index cbc0b64ec..477b6dc51 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -21,6 +21,7 @@ from slither.core.expressions import ( UnaryOperation, ) from slither.core.solidity_types.type import Type +from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable @@ -1023,11 +1024,13 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu view = " view" if self.view else "" pure = " pure" if self.pure else "" payable = " payable" if self.payable else "" + returns = ["address" if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, Contract) + else str(ret.type) for ret in self.returns] self._interface_signature_str = ( name + "(" + ",".join(parameters) + ") external" + payable + pure + view ) if len(returnVars) > 0: - self._interface_signature_str += " returns (" + ",".join(returnVars) + ")" + self._interface_signature_str += " returns (" + ",".join(returns) + ")" return self._interface_signature_str # endregion From 289bd49c3ec4516409f6589ffaee18100ba24825 Mon Sep 17 00:00:00 2001 From: webthethird Date: Wed, 8 Mar 2023 16:58:19 -0600 Subject: [PATCH 19/34] Black format --- slither/core/declarations/function.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 477b6dc51..65e19e013 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -1024,8 +1024,12 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu view = " view" if self.view else "" pure = " pure" if self.pure else "" payable = " payable" if self.payable else "" - returns = ["address" if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, Contract) - else str(ret.type) for ret in self.returns] + returns = [ + "address" + if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, Contract) + else str(ret.type) + for ret in self.returns + ] self._interface_signature_str = ( name + "(" + ",".join(parameters) + ") external" + payable + pure + view ) From 5a25c81a52e108f90ab28ac56445aa40069dfb8a Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 14 Mar 2023 10:30:20 -0500 Subject: [PATCH 20/34] Locally import Contract to resolve pylint in `Function.interface_signature_str` --- slither/core/declarations/function.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 65e19e013..4b4bbf82c 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -1011,8 +1011,10 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu Return the function interface as a str (contains the return values) Returns None if the function is private or internal, or is a constructor/fallback/receive """ + from slither.core.declarations.contract import Contract + if self._interface_signature_str is None: - name, parameters, returnVars = self.signature + name, parameters, return_vars = self.signature visibility = self.visibility if ( visibility in ["private", "internal"] @@ -1033,7 +1035,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu self._interface_signature_str = ( name + "(" + ",".join(parameters) + ") external" + payable + pure + view ) - if len(returnVars) > 0: + if len(return_vars) > 0: self._interface_signature_str += " returns (" + ",".join(returns) + ")" return self._interface_signature_str From 0df2c23eeed659c0a2aa9bd057ac5a93f90086ed Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 14 Mar 2023 10:43:07 -0500 Subject: [PATCH 21/34] Include events in `Contract.generate_interface` --- slither/core/declarations/contract.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 35f31ab07..fc6fa0824 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -955,6 +955,9 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods def generate_interface(self) -> str: interface = f"interface I{self.name} {{\n" + for event in self.events: + name, args = event.signature + interface += f" event {name}({','.join(args)});\n" for struct in self.structures: if isinstance(struct.interface_def_str(), str): interface += struct.interface_def_str() From c507b3b99f8adbe5c504fa09f26d152911dc918c Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 14 Mar 2023 16:14:39 -0500 Subject: [PATCH 22/34] Move code generation to a new util Rather than on the core objects. --- slither/core/declarations/contract.py | 19 ------ slither/core/declarations/function.py | 35 ---------- slither/core/declarations/structure.py | 6 -- slither/utils/code_generation.py | 94 ++++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 60 deletions(-) create mode 100644 slither/utils/code_generation.py diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index fc6fa0824..eb2ac9a2e 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -953,25 +953,6 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ return all((not f.is_implemented) for f in self.functions) - def generate_interface(self) -> str: - interface = f"interface I{self.name} {{\n" - for event in self.events: - name, args = event.signature - interface += f" event {name}({','.join(args)});\n" - for struct in self.structures: - if isinstance(struct.interface_def_str(), str): - interface += struct.interface_def_str() - for var in self.state_variables_entry_points: - interface += ( - f" function {var.signature_str.replace('returns', 'external returns ')};\n" - ) - for func in self.functions_entry_points: - if func.is_constructor or func.is_fallback or func.is_receive: - continue - interface += f" function {func.interface_signature_str};\n" - interface += "}\n\n" - return interface - # endregion ################################################################################### ################################################################################### diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 4b4bbf82c..ad765bfc0 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -1004,41 +1004,6 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ) return self._signature_str - @property - def interface_signature_str(self) -> Optional[str]: - """ - str: func_name(type1,type2) external {payable/view/pure} returns (type3) - Return the function interface as a str (contains the return values) - Returns None if the function is private or internal, or is a constructor/fallback/receive - """ - from slither.core.declarations.contract import Contract - - if self._interface_signature_str is None: - name, parameters, return_vars = self.signature - visibility = self.visibility - if ( - visibility in ["private", "internal"] - or self.is_constructor - or self.is_fallback - or self.is_receive - ): - return None - view = " view" if self.view else "" - pure = " pure" if self.pure else "" - payable = " payable" if self.payable else "" - returns = [ - "address" - if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, Contract) - else str(ret.type) - for ret in self.returns - ] - self._interface_signature_str = ( - name + "(" + ",".join(parameters) + ") external" + payable + pure + view - ) - if len(return_vars) > 0: - self._interface_signature_str += " returns (" + ",".join(returns) + ")" - return self._interface_signature_str - # endregion ################################################################################### ################################################################################### diff --git a/slither/core/declarations/structure.py b/slither/core/declarations/structure.py index 2800f809c..8f6d8c50a 100644 --- a/slither/core/declarations/structure.py +++ b/slither/core/declarations/structure.py @@ -49,11 +49,5 @@ class Structure(SourceMapping): ret.append(self._elems[e]) return ret - def interface_def_str(self) -> str: - definition = f" struct {self.name} {{\n" - for elem in self.elems_ordered: - definition += f" {elem.type} {elem.name};\n" - definition += " }\n" - def __str__(self) -> str: return self.name diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py new file mode 100644 index 000000000..7f5c88b2f --- /dev/null +++ b/slither/utils/code_generation.py @@ -0,0 +1,94 @@ +# Functions for generating Solidity code +from typing import TYPE_CHECKING, List + +if TYPE_CHECKING: + from slither.core.declarations import Function, Contract, Structure + + +def generate_interface(contract: "Contract") -> str: + """ + Generates code for a Solidity interface to the contract. + Args: + contract: A Contract object + + Returns: + A string with the code for an interface, with function stubs for all public or external functions and + state variables, as well as any events or structs declared in the contract. + """ + interface = f"interface I{contract.name} {{\n" + for event in contract.events: + name, args = event.signature + interface += f" event {name}({','.join(args)});\n" + for struct in contract.structures: + interface += generate_struct_interface_str(struct) + for var in contract.state_variables_entry_points: + interface += ( + f" function {var.signature_str.replace('returns', 'external returns ')};\n" + ) + for func in contract.functions_entry_points: + if func.is_constructor or func.is_fallback or func.is_receive: + continue + interface += f" function {generate_interface_function_signature(func)};\n" + interface += "}\n\n" + return interface + + +def generate_interface_function_signature(func: "Function") -> Optional[str]: + """ + Generates a string of the form: + func_name(type1,type2) external {payable/view/pure} returns (type3) + + Args: + func: A Function object + + Returns: + The function interface as a str (contains the return values). + Returns None if the function is private or internal, or is a constructor/fallback/receive. + """ + from slither.core.declarations.contract import Contract + + name, parameters, return_vars = func.signature + visibility = func.visibility + if ( + visibility in ["private", "internal"] + or func.is_constructor + or func.is_fallback + or func.is_receive + ): + return None + view = " view" if func.view else "" + pure = " pure" if func.pure else "" + payable = " payable" if func.payable else "" + returns = [ + "address" + if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, Contract) + else str(ret.type) + for ret in func.returns + ] + _interface_signature_str = ( + name + "(" + ",".join(parameters) + ") external" + payable + pure + view + ) + if len(return_vars) > 0: + _interface_signature_str += " returns (" + ",".join(returns) + ")" + return _interface_signature_str + + +def generate_struct_interface_str(struct: "Structure") -> str: + """ + Generates code for a structure declaration in an interface of the form: + struct struct_name { + elem1_type elem1_name; + elem2_type elem2_name; + ... ... + } + Args: + struct: A Structure object + + Returns: + The structure declaration code as a string. + """ + definition = f" struct {struct.name} {{\n" + for elem in struct.elems_ordered: + definition += f" {elem.type} {elem.name};\n" + definition += " }\n" + return definition From 67ece635f0f1248f88c89219308db4b72f855754 Mon Sep 17 00:00:00 2001 From: webthethird Date: Tue, 14 Mar 2023 16:16:06 -0500 Subject: [PATCH 23/34] Revert unnecessary imports --- slither/core/declarations/function.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index ad765bfc0..a4624feec 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -21,7 +21,6 @@ from slither.core.expressions import ( UnaryOperation, ) from slither.core.solidity_types.type import Type -from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable @@ -211,7 +210,6 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu self._signature: Optional[Tuple[str, List[str], List[str]]] = None self._solidity_signature: Optional[str] = None self._signature_str: Optional[str] = None - self._interface_signature_str: Optional[str] = None self._canonical_name: Optional[str] = None self._is_protected: Optional[bool] = None From 6284e0ac44e6740acab7c9e92caf3e70b426e196 Mon Sep 17 00:00:00 2001 From: webthethird Date: Wed, 15 Mar 2023 09:52:44 -0500 Subject: [PATCH 24/34] Pylint and black --- slither/utils/code_generation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py index 7f5c88b2f..fd0d5c161 100644 --- a/slither/utils/code_generation.py +++ b/slither/utils/code_generation.py @@ -1,8 +1,11 @@ # Functions for generating Solidity code -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, Optional + +from slither.core.declarations.contract import Contract +from slither.core.solidity_types.user_defined_type import UserDefinedType if TYPE_CHECKING: - from slither.core.declarations import Function, Contract, Structure + from slither.core.declarations import Function, Structure def generate_interface(contract: "Contract") -> str: @@ -22,9 +25,7 @@ def generate_interface(contract: "Contract") -> str: for struct in contract.structures: interface += generate_struct_interface_str(struct) for var in contract.state_variables_entry_points: - interface += ( - f" function {var.signature_str.replace('returns', 'external returns ')};\n" - ) + interface += f" function {var.signature_str.replace('returns', 'external returns ')};\n" for func in contract.functions_entry_points: if func.is_constructor or func.is_fallback or func.is_receive: continue @@ -45,7 +46,6 @@ def generate_interface_function_signature(func: "Function") -> Optional[str]: The function interface as a str (contains the return values). Returns None if the function is private or internal, or is a constructor/fallback/receive. """ - from slither.core.declarations.contract import Contract name, parameters, return_vars = func.signature visibility = func.visibility From 631f124d3232c1f2e6efb5be8ceec4184223433d Mon Sep 17 00:00:00 2001 From: webthethird Date: Wed, 15 Mar 2023 15:59:35 -0500 Subject: [PATCH 25/34] Include custom errors in `generate_interface` --- slither/utils/code_generation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py index fd0d5c161..c0ece478c 100644 --- a/slither/utils/code_generation.py +++ b/slither/utils/code_generation.py @@ -22,6 +22,8 @@ def generate_interface(contract: "Contract") -> str: for event in contract.events: name, args = event.signature interface += f" event {name}({','.join(args)});\n" + for error in contract.custom_errors: + interface += f" error {error.solidity_signature};\n" for struct in contract.structures: interface += generate_struct_interface_str(struct) for var in contract.state_variables_entry_points: From 4d8181ee0e8e5b767c84b93d9b24a510169b6e97 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 14:29:01 -0500 Subject: [PATCH 26/34] Check `contract.functions_entry_points` instead of `function.visibility` --- slither/utils/code_generation.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py index c0ece478c..8b3a267b1 100644 --- a/slither/utils/code_generation.py +++ b/slither/utils/code_generation.py @@ -5,7 +5,7 @@ from slither.core.declarations.contract import Contract from slither.core.solidity_types.user_defined_type import UserDefinedType if TYPE_CHECKING: - from slither.core.declarations import Function, Structure + from slither.core.declarations import FunctionContract, Structure def generate_interface(contract: "Contract") -> str: @@ -36,13 +36,13 @@ def generate_interface(contract: "Contract") -> str: return interface -def generate_interface_function_signature(func: "Function") -> Optional[str]: +def generate_interface_function_signature(func: "FunctionContract") -> Optional[str]: """ Generates a string of the form: func_name(type1,type2) external {payable/view/pure} returns (type3) Args: - func: A Function object + func: A FunctionContract object Returns: The function interface as a str (contains the return values). @@ -50,9 +50,8 @@ def generate_interface_function_signature(func: "Function") -> Optional[str]: """ name, parameters, return_vars = func.signature - visibility = func.visibility if ( - visibility in ["private", "internal"] + func not in func.contract.functions_entry_points or func.is_constructor or func.is_fallback or func.is_receive From 4dfe45742a1fceee9c584b604576385d57186669 Mon Sep 17 00:00:00 2001 From: webthethird Date: Fri, 17 Mar 2023 14:33:45 -0500 Subject: [PATCH 27/34] Update docstring --- slither/utils/code_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py index 8b3a267b1..8d7b1ec71 100644 --- a/slither/utils/code_generation.py +++ b/slither/utils/code_generation.py @@ -16,7 +16,7 @@ def generate_interface(contract: "Contract") -> str: Returns: A string with the code for an interface, with function stubs for all public or external functions and - state variables, as well as any events or structs declared in the contract. + state variables, as well as any events, custom errors and/or structs declared in the contract. """ interface = f"interface I{contract.name} {{\n" for event in contract.events: From 1d35316d9b59683956fd20d98fbf7acc40749c78 Mon Sep 17 00:00:00 2001 From: webthethird Date: Mon, 20 Mar 2023 10:06:39 -0500 Subject: [PATCH 28/34] Use `convert_type_for_solidity_signature_to_string` and include Enums --- slither/utils/code_generation.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py index 8d7b1ec71..0cb9af9c8 100644 --- a/slither/utils/code_generation.py +++ b/slither/utils/code_generation.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Optional from slither.core.declarations.contract import Contract from slither.core.solidity_types.user_defined_type import UserDefinedType +from slither.utils.type import convert_type_for_solidity_signature_to_string if TYPE_CHECKING: from slither.core.declarations import FunctionContract, Structure @@ -23,7 +24,11 @@ def generate_interface(contract: "Contract") -> str: name, args = event.signature interface += f" event {name}({','.join(args)});\n" for error in contract.custom_errors: - interface += f" error {error.solidity_signature};\n" + args = [convert_type_for_solidity_signature_to_string(arg.type).replace("(", "").replace(")", "") + for arg in error.parameters] + interface += f" error {error.name}({', '.join(args)});\n" + for enum in contract.enums: + interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n" for struct in contract.structures: interface += generate_struct_interface_str(struct) for var in contract.state_variables_entry_points: @@ -61,9 +66,7 @@ def generate_interface_function_signature(func: "FunctionContract") -> Optional[ pure = " pure" if func.pure else "" payable = " payable" if func.payable else "" returns = [ - "address" - if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, Contract) - else str(ret.type) + convert_type_for_solidity_signature_to_string(ret.type) for ret in func.returns ] _interface_signature_str = ( From 2c30d4619218b9e60c86c8788541ccce4ad2fb2e Mon Sep 17 00:00:00 2001 From: webthethird Date: Mon, 20 Mar 2023 10:08:30 -0500 Subject: [PATCH 29/34] Minor formatting --- slither/utils/code_generation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py index 0cb9af9c8..080c8bd3e 100644 --- a/slither/utils/code_generation.py +++ b/slither/utils/code_generation.py @@ -1,12 +1,11 @@ # Functions for generating Solidity code from typing import TYPE_CHECKING, Optional -from slither.core.declarations.contract import Contract from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.utils.type import convert_type_for_solidity_signature_to_string if TYPE_CHECKING: - from slither.core.declarations import FunctionContract, Structure + from slither.core.declarations import FunctionContract, Structure, Contract def generate_interface(contract: "Contract") -> str: @@ -22,7 +21,7 @@ def generate_interface(contract: "Contract") -> str: interface = f"interface I{contract.name} {{\n" for event in contract.events: name, args = event.signature - interface += f" event {name}({','.join(args)});\n" + interface += f" event {name}({', '.join(args)});\n" for error in contract.custom_errors: args = [convert_type_for_solidity_signature_to_string(arg.type).replace("(", "").replace(")", "") for arg in error.parameters] @@ -69,11 +68,14 @@ def generate_interface_function_signature(func: "FunctionContract") -> Optional[ convert_type_for_solidity_signature_to_string(ret.type) for ret in func.returns ] + parameters = [ + param.replace(f"{func.contract.name}.", "") for param in parameters + ] _interface_signature_str = ( name + "(" + ",".join(parameters) + ") external" + payable + pure + view ) if len(return_vars) > 0: - _interface_signature_str += " returns (" + ",".join(returns) + ")" + _interface_signature_str += " returns (" + ",".join(returns).replace("(", "").replace(")", "") + ")" return _interface_signature_str From dd1b2a84b929c0ef567ac25cc2201a4035a75c96 Mon Sep 17 00:00:00 2001 From: webthethird Date: Mon, 20 Mar 2023 10:09:11 -0500 Subject: [PATCH 30/34] Add test for interface code generation --- tests/code_generation/CodeGeneration.sol | 56 ++++++++++++++++++++++++ tests/test_code_generation.py | 25 +++++++++++ 2 files changed, 81 insertions(+) create mode 100644 tests/code_generation/CodeGeneration.sol create mode 100644 tests/test_code_generation.py diff --git a/tests/code_generation/CodeGeneration.sol b/tests/code_generation/CodeGeneration.sol new file mode 100644 index 000000000..c15017abd --- /dev/null +++ b/tests/code_generation/CodeGeneration.sol @@ -0,0 +1,56 @@ +pragma solidity ^0.8.4; +interface I { + enum SomeEnum { ONE, TWO, THREE } + error ErrorWithEnum(SomeEnum e); +} + +contract TestContract is I { + uint public stateA; + uint private stateB; + address public immutable owner = msg.sender; + mapping(address => mapping(uint => St)) public structs; + + event NoParams(); + event Anonymous() anonymous; + event OneParam(address addr); + event OneParamIndexed(address indexed addr); + + error ErrorSimple(); + error ErrorWithArgs(uint, uint); + error ErrorWithStruct(St s); + + struct St{ + uint v; + } + + function err0() public { + revert ErrorSimple(); + } + function err1() public { + St memory s; + revert ErrorWithStruct(s); + } + function err2(uint a, uint b) public { + revert ErrorWithArgs(a, b); + revert ErrorWithArgs(uint(SomeEnum.ONE), uint(SomeEnum.ONE)); + } + function err3() internal { + revert('test'); + } + function err4() private { + revert ErrorWithEnum(SomeEnum.ONE); + } + + function newSt(uint x) public returns (St memory) { + St memory st; + st.v = x; + structs[msg.sender][x] = st; + return st; + } + function getSt(uint x) public view returns (St memory) { + return structs[msg.sender][x]; + } + function removeSt(St memory st) public { + delete structs[msg.sender][st.v]; + } +} \ No newline at end of file diff --git a/tests/test_code_generation.py b/tests/test_code_generation.py new file mode 100644 index 000000000..9733f7c90 --- /dev/null +++ b/tests/test_code_generation.py @@ -0,0 +1,25 @@ +import os + +from solc_select import solc_select + +from slither import Slither +from slither.core.expressions import Literal +from slither.utils.code_generation import ( + generate_interface, + generate_interface_function_signature, + generate_struct_interface_str, +) + +SLITHER_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +CODE_TEST_ROOT = os.path.join(SLITHER_ROOT, "tests", "code_generation") + + +def test_interface_generation() -> None: + solc_select.switch_global_version("0.8.4", always_install=True) + + sl = Slither(os.path.join(CODE_TEST_ROOT, "CodeGeneration.sol")) + + with open("actual_generated_code.sol", "w", encoding="utf-8") as file: + file.write(generate_interface(sl.get_contract_from_name("TestContract")[0])) + + From b0dcc57f03b08c8273873c2b989283caf21dfcfd Mon Sep 17 00:00:00 2001 From: webthethird Date: Mon, 20 Mar 2023 10:19:13 -0500 Subject: [PATCH 31/34] Add test for interface code generation --- tests/code_generation/TEST_generated_code.sol | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 tests/code_generation/TEST_generated_code.sol diff --git a/tests/code_generation/TEST_generated_code.sol b/tests/code_generation/TEST_generated_code.sol new file mode 100644 index 000000000..62e08bd74 --- /dev/null +++ b/tests/code_generation/TEST_generated_code.sol @@ -0,0 +1,24 @@ +interface ITestContract { + event NoParams(); + event Anonymous(); + event OneParam(address); + event OneParamIndexed(address); + error ErrorWithEnum(uint8); + error ErrorSimple(); + error ErrorWithArgs(uint256, uint256); + error ErrorWithStruct(uint256); + enum SomeEnum { ONE, TWO, THREE } + struct St { + uint256 v; + } + function stateA() external returns (uint256); + function owner() external returns (address); + function structs(address,uint256) external returns (uint256); + function err0() external; + function err1() external; + function err2(uint256,uint256) external; + function newSt(uint256) external returns (uint256); + function getSt(uint256) external view returns (uint256); + function removeSt(uint256) external; +} + From b78fa7f1a20ad2a69ed068fe8819eb6ec6ca6cfc Mon Sep 17 00:00:00 2001 From: webthethird Date: Mon, 20 Mar 2023 10:19:33 -0500 Subject: [PATCH 32/34] Unwrap parameters as well --- slither/utils/code_generation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py index 080c8bd3e..fa1793b4f 100644 --- a/slither/utils/code_generation.py +++ b/slither/utils/code_generation.py @@ -65,17 +65,18 @@ def generate_interface_function_signature(func: "FunctionContract") -> Optional[ pure = " pure" if func.pure else "" payable = " payable" if func.payable else "" returns = [ - convert_type_for_solidity_signature_to_string(ret.type) + convert_type_for_solidity_signature_to_string(ret.type).replace("(", "").replace(")", "") for ret in func.returns ] parameters = [ - param.replace(f"{func.contract.name}.", "") for param in parameters + convert_type_for_solidity_signature_to_string(param.type).replace("(", "").replace(")", "") + for param in func.parameters ] _interface_signature_str = ( name + "(" + ",".join(parameters) + ") external" + payable + pure + view ) if len(return_vars) > 0: - _interface_signature_str += " returns (" + ",".join(returns).replace("(", "").replace(")", "") + ")" + _interface_signature_str += " returns (" + ",".join(returns) + ")" return _interface_signature_str From 7532baae57543cdc4627998728fe6ae0be7a3056 Mon Sep 17 00:00:00 2001 From: webthethird Date: Mon, 20 Mar 2023 10:24:37 -0500 Subject: [PATCH 33/34] Finish testing interface generation --- tests/test_code_generation.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_code_generation.py b/tests/test_code_generation.py index 9733f7c90..35db30a44 100644 --- a/tests/test_code_generation.py +++ b/tests/test_code_generation.py @@ -19,7 +19,12 @@ def test_interface_generation() -> None: sl = Slither(os.path.join(CODE_TEST_ROOT, "CodeGeneration.sol")) - with open("actual_generated_code.sol", "w", encoding="utf-8") as file: - file.write(generate_interface(sl.get_contract_from_name("TestContract")[0])) + actual = generate_interface(sl.get_contract_from_name("TestContract")[0]) + expected_path = os.path.join(CODE_TEST_ROOT, "TEST_generated_code.sol") + + with open(expected_path, "r", encoding="utf-8") as file: + expected = file.read() + + assert actual == expected From 6d685711dd1d6b14605df2e801e579e2647f5c9a Mon Sep 17 00:00:00 2001 From: webthethird Date: Mon, 20 Mar 2023 10:53:29 -0500 Subject: [PATCH 34/34] Pylint and black --- slither/utils/code_generation.py | 9 ++++++--- tests/test_code_generation.py | 5 ----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/slither/utils/code_generation.py b/slither/utils/code_generation.py index fa1793b4f..951bf4702 100644 --- a/slither/utils/code_generation.py +++ b/slither/utils/code_generation.py @@ -1,7 +1,6 @@ # Functions for generating Solidity code from typing import TYPE_CHECKING, Optional -from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.utils.type import convert_type_for_solidity_signature_to_string if TYPE_CHECKING: @@ -23,8 +22,12 @@ def generate_interface(contract: "Contract") -> str: name, args = event.signature interface += f" event {name}({', '.join(args)});\n" for error in contract.custom_errors: - args = [convert_type_for_solidity_signature_to_string(arg.type).replace("(", "").replace(")", "") - for arg in error.parameters] + args = [ + convert_type_for_solidity_signature_to_string(arg.type) + .replace("(", "") + .replace(")", "") + for arg in error.parameters + ] interface += f" error {error.name}({', '.join(args)});\n" for enum in contract.enums: interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n" diff --git a/tests/test_code_generation.py b/tests/test_code_generation.py index 35db30a44..13d1c8fb0 100644 --- a/tests/test_code_generation.py +++ b/tests/test_code_generation.py @@ -3,11 +3,8 @@ import os from solc_select import solc_select from slither import Slither -from slither.core.expressions import Literal from slither.utils.code_generation import ( generate_interface, - generate_interface_function_signature, - generate_struct_interface_str, ) SLITHER_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -26,5 +23,3 @@ def test_interface_generation() -> None: expected = file.read() assert actual == expected - -