diff --git a/slither/__main__.py b/slither/__main__.py index 528a93e8f..a5d51dcd6 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -66,7 +66,7 @@ def process_single( args: argparse.Namespace, detector_classes: List[Type[AbstractDetector]], printer_classes: List[Type[AbstractPrinter]], -) -> Tuple[Slither, List[Dict], List[Dict], int]: +) -> Tuple[Slither, List[Dict], List[Output], int]: """ The core high-level code for running Slither static analysis. @@ -89,7 +89,7 @@ def process_all( args: argparse.Namespace, detector_classes: List[Type[AbstractDetector]], printer_classes: List[Type[AbstractPrinter]], -) -> Tuple[List[Slither], List[Dict], List[Dict], int]: +) -> Tuple[List[Slither], List[Dict], List[Output], int]: compilations = compile_all(target, **vars(args)) slither_instances = [] results_detectors = [] @@ -144,23 +144,6 @@ def _process( return slither, results_detectors, results_printers, analyzed_contracts_count -# TODO: delete me? -def process_from_asts( - filenames: List[str], - args: argparse.Namespace, - detector_classes: List[Type[AbstractDetector]], - printer_classes: List[Type[AbstractPrinter]], -) -> Tuple[Slither, List[Dict], List[Dict], int]: - all_contracts: List[str] = [] - - for filename in filenames: - with open(filename, encoding="utf8") as file_open: - contract_loaded = json.load(file_open) - all_contracts.append(contract_loaded["ast"]) - - return process_single(all_contracts, args, detector_classes, printer_classes) - - # endregion ################################################################################### ################################################################################### @@ -608,9 +591,6 @@ def parse_args( default=False, ) - # if the json is splitted in different files - parser.add_argument("--splitted", help=argparse.SUPPRESS, action="store_true", default=False) - # Disable the throw/catch on partial analyses parser.add_argument( "--disallow-partial", help=argparse.SUPPRESS, action="store_true", default=False @@ -626,7 +606,7 @@ def parse_args( args.filter_paths = parse_filter_paths(args) # Verify our json-type output is valid - args.json_types = set(args.json_types.split(",")) + args.json_types = set(args.json_types.split(",")) # type:ignore for json_type in args.json_types: if json_type not in JSON_OUTPUT_TYPES: raise Exception(f'Error: "{json_type}" is not a valid JSON result output type.') @@ -697,14 +677,14 @@ class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods class FormatterCryticCompile(logging.Formatter): - def format(self, record): + def format(self, record: logging.LogRecord) -> str: # for i, msg in enumerate(record.msg): if record.msg.startswith("Compilation warnings/errors on "): - txt = record.args[1] - txt = txt.split("\n") + txt = record.args[1] # type:ignore + txt = txt.split("\n") # type:ignore txt = [red(x) if "Error" in x else x for x in txt] txt = "\n".join(txt) - record.args = (record.args[0], txt) + record.args = (record.args[0], txt) # type:ignore return super().format(record) @@ -747,7 +727,7 @@ def main_impl( set_colorization_enabled(False if args.disable_color else sys.stdout.isatty()) # Define some variables for potential JSON output - json_results = {} + json_results: Dict[str, Any] = {} output_error = None outputting_json = args.json is not None outputting_json_stdout = args.json == "-" @@ -796,7 +776,7 @@ def main_impl( crytic_compile_error.setLevel(logging.INFO) results_detectors: List[Dict] = [] - results_printers: List[Dict] = [] + results_printers: List[Output] = [] try: filename = args.filename @@ -809,26 +789,17 @@ def main_impl( number_contracts = 0 slither_instances = [] - if args.splitted: + for filename in filenames: ( slither_instance, - results_detectors, - results_printers, - number_contracts, - ) = process_from_asts(filenames, args, detector_classes, printer_classes) + results_detectors_tmp, + results_printers_tmp, + number_contracts_tmp, + ) = process_single(filename, args, detector_classes, printer_classes) + number_contracts += number_contracts_tmp + results_detectors += results_detectors_tmp + results_printers += results_printers_tmp slither_instances.append(slither_instance) - else: - for filename in filenames: - ( - slither_instance, - results_detectors_tmp, - results_printers_tmp, - number_contracts_tmp, - ) = process_single(filename, args, detector_classes, printer_classes) - number_contracts += number_contracts_tmp - results_detectors += results_detectors_tmp - results_printers += results_printers_tmp - slither_instances.append(slither_instance) # Rely on CryticCompile to discern the underlying type of compilations. else: diff --git a/slither/analyses/data_dependency/data_dependency.py b/slither/analyses/data_dependency/data_dependency.py index b2a154672..448ee393a 100644 --- a/slither/analyses/data_dependency/data_dependency.py +++ b/slither/analyses/data_dependency/data_dependency.py @@ -4,6 +4,7 @@ from collections import defaultdict from typing import Union, Set, Dict, TYPE_CHECKING +from slither.core.cfg.node import Node from slither.core.declarations import ( Contract, Enum, @@ -12,6 +13,7 @@ from slither.core.declarations import ( SolidityVariable, SolidityVariableComposed, Structure, + FunctionContract, ) from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.variables.top_level_variable import TopLevelVariable @@ -40,25 +42,37 @@ if TYPE_CHECKING: Variable_types = Union[Variable, SolidityVariable] +# TODO refactor the data deps to be better suited for top level function object +# Right now we allow to pass a node to ease the API, but we need something +# better +# The deps propagation for top level elements is also not working as expected +Context_types_API = Union[Contract, Function, Node] Context_types = Union[Contract, Function] def is_dependent( variable: Variable_types, source: Variable_types, - context: Context_types, + context: Context_types_API, only_unprotected: bool = False, ) -> bool: """ + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not + Args: variable (Variable) source (Variable) - context (Contract|Function) + context (Contract|Function|Node). only_unprotected (bool): True only unprotected function are considered Returns: bool """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func + if isinstance(variable, Constant): return False if variable == source: @@ -76,10 +90,13 @@ def is_dependent( def is_dependent_ssa( variable: Variable_types, source: Variable_types, - context: Context_types, + context: Context_types_API, only_unprotected: bool = False, ) -> bool: """ + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not + Args: variable (Variable) taint (Variable) @@ -88,7 +105,10 @@ def is_dependent_ssa( Returns: bool """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func context_dict = context.context if isinstance(variable, Constant): return False @@ -112,11 +132,14 @@ GENERIC_TAINT = { def is_tainted( variable: Variable_types, - context: Context_types, + context: Context_types_API, only_unprotected: bool = False, ignore_generic_taint: bool = False, ) -> bool: """ + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not + Args: variable context (Contract|Function) @@ -124,7 +147,10 @@ def is_tainted( Returns: bool """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if isinstance(variable, Constant): return False @@ -139,11 +165,14 @@ def is_tainted( def is_tainted_ssa( variable: Variable_types, - context: Context_types, + context: Context_types_API, only_unprotected: bool = False, ignore_generic_taint: bool = False, -): +) -> bool: """ + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not + Args: variable context (Contract|Function) @@ -151,7 +180,10 @@ def is_tainted_ssa( Returns: bool """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if isinstance(variable, Constant): return False @@ -166,18 +198,23 @@ def is_tainted_ssa( def get_dependencies( variable: Variable_types, - context: Context_types, + context: Context_types_API, only_unprotected: bool = False, ) -> Set[Variable]: """ Return the variables for which `variable` depends on. + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not :param variable: The target :param context: Either a function (interprocedural) or a contract (inter transactional) :param only_unprotected: True if consider only protected functions :return: set(Variable) """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if only_unprotected: return context.context[KEY_NON_SSA_UNPROTECTED].get(variable, set()) @@ -185,16 +222,21 @@ def get_dependencies( def get_all_dependencies( - context: Context_types, only_unprotected: bool = False + context: Context_types_API, only_unprotected: bool = False ) -> Dict[Variable, Set[Variable]]: """ Return the dictionary of dependencies. + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not :param context: Either a function (interprocedural) or a contract (inter transactional) :param only_unprotected: True if consider only protected functions :return: Dict(Variable, set(Variable)) """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if only_unprotected: return context.context[KEY_NON_SSA_UNPROTECTED] @@ -203,18 +245,23 @@ def get_all_dependencies( def get_dependencies_ssa( variable: Variable_types, - context: Context_types, + context: Context_types_API, only_unprotected: bool = False, ) -> Set[Variable]: """ Return the variables for which `variable` depends on (SSA version). + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not :param variable: The target (must be SSA variable) :param context: Either a function (interprocedural) or a contract (inter transactional) :param only_unprotected: True if consider only protected functions :return: set(Variable) """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if only_unprotected: return context.context[KEY_SSA_UNPROTECTED].get(variable, set()) @@ -222,16 +269,21 @@ def get_dependencies_ssa( def get_all_dependencies_ssa( - context: Context_types, only_unprotected: bool = False + context: Context_types_API, only_unprotected: bool = False ) -> Dict[Variable, Set[Variable]]: """ Return the dictionary of dependencies. + If Node is provided as context, the context will be the broader context, either the contract or the function, + depending on if the node is in a top level function or not :param context: Either a function (interprocedural) or a contract (inter transactional) :param only_unprotected: True if consider only protected functions :return: Dict(Variable, set(Variable)) """ - assert isinstance(context, (Contract, Function)) + assert isinstance(context, (Contract, Function, Node)) + if isinstance(context, Node): + func = context.function + context = func.contract if isinstance(func, FunctionContract) else func assert isinstance(only_unprotected, bool) if only_unprotected: return context.context[KEY_SSA_UNPROTECTED] diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index 7643b19b7..82502f889 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -6,7 +6,7 @@ from typing import Optional, List, Set, Dict, Tuple, Union, TYPE_CHECKING from slither.all_exceptions import SlitherException from slither.core.children.child_function import ChildFunction -from slither.core.declarations import Contract, Function +from slither.core.declarations import Contract, Function, FunctionContract from slither.core.declarations.solidity_variables import ( SolidityVariable, SolidityFunction, @@ -33,6 +33,7 @@ from slither.slithir.operations import ( Return, Operation, ) +from slither.slithir.utils.utils import RVALUE from slither.slithir.variables import ( Constant, LocalIRVariable, @@ -146,12 +147,12 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._node_id: int = node_id self._vars_written: List[Variable] = [] - self._vars_read: List[Variable] = [] + self._vars_read: List[Union[Variable, SolidityVariable]] = [] self._ssa_vars_written: List["SlithIRVariable"] = [] self._ssa_vars_read: List["SlithIRVariable"] = [] - self._internal_calls: List["Function"] = [] + self._internal_calls: List[Union["Function", "SolidityFunction"]] = [] self._solidity_calls: List[SolidityFunction] = [] self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls self._library_calls: List["LibraryCallType"] = [] @@ -172,7 +173,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._local_vars_read: List[LocalVariable] = [] self._local_vars_written: List[LocalVariable] = [] - self._slithir_vars: Set["SlithIRVariable"] = set() # non SSA + self._slithir_vars: Set[ + Union["SlithIRVariable", ReferenceVariable, TemporaryVariable, TupleVariable] + ] = set() # non SSA self._ssa_local_vars_read: List[LocalIRVariable] = [] self._ssa_local_vars_written: List[LocalIRVariable] = [] @@ -213,7 +216,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._node_type @type.setter - def type(self, new_type: NodeType): + def type(self, new_type: NodeType) -> None: self._node_type = new_type @property @@ -232,7 +235,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met ################################################################################### @property - def variables_read(self) -> List[Variable]: + def variables_read(self) -> List[Union[Variable, SolidityVariable]]: """ list(Variable): Variables read (local/state/solidity) """ @@ -285,11 +288,13 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._expression_vars_read @variables_read_as_expression.setter - def variables_read_as_expression(self, exprs: List[Expression]): + def variables_read_as_expression(self, exprs: List[Expression]) -> None: self._expression_vars_read = exprs @property - def slithir_variables(self) -> List["SlithIRVariable"]: + def slithir_variables( + self, + ) -> List[Union["SlithIRVariable", ReferenceVariable, TemporaryVariable, TupleVariable]]: return list(self._slithir_vars) @property @@ -339,7 +344,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._expression_vars_written @variables_written_as_expression.setter - def variables_written_as_expression(self, exprs: List[Expression]): + def variables_written_as_expression(self, exprs: List[Expression]) -> None: self._expression_vars_written = exprs # endregion @@ -399,7 +404,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._external_calls_as_expressions @external_calls_as_expressions.setter - def external_calls_as_expressions(self, exprs: List[Expression]): + def external_calls_as_expressions(self, exprs: List[Expression]) -> None: self._external_calls_as_expressions = exprs @property @@ -410,7 +415,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._internal_calls_as_expressions @internal_calls_as_expressions.setter - def internal_calls_as_expressions(self, exprs: List[Expression]): + def internal_calls_as_expressions(self, exprs: List[Expression]) -> None: self._internal_calls_as_expressions = exprs @property @@ -418,10 +423,10 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return list(self._expression_calls) @calls_as_expression.setter - def calls_as_expression(self, exprs: List[Expression]): + def calls_as_expression(self, exprs: List[Expression]) -> None: self._expression_calls = exprs - def can_reenter(self, callstack=None) -> bool: + def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Check if the node can re-enter Do not consider CREATE as potential re-enter, but check if the @@ -567,7 +572,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met """ self._fathers.append(father) - def set_fathers(self, fathers: List["Node"]): + def set_fathers(self, fathers: List["Node"]) -> None: """Set the father nodes Args: @@ -663,20 +668,20 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._irs_ssa @irs_ssa.setter - def irs_ssa(self, irs): + def irs_ssa(self, irs: List[Operation]) -> None: self._irs_ssa = irs def add_ssa_ir(self, ir: Operation) -> None: """ Use to place phi operation """ - ir.set_node(self) + ir.set_node(self) # type: ignore self._irs_ssa.append(ir) def slithir_generation(self) -> None: if self.expression: expression = self.expression - self._irs = convert_expression(expression, self) + self._irs = convert_expression(expression, self) # type:ignore self._find_read_write_call() @@ -713,7 +718,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._dominators @dominators.setter - def dominators(self, dom: Set["Node"]): + def dominators(self, dom: Set["Node"]) -> None: self._dominators = dom @property @@ -725,7 +730,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._immediate_dominator @immediate_dominator.setter - def immediate_dominator(self, idom: "Node"): + def immediate_dominator(self, idom: "Node") -> None: self._immediate_dominator = idom @property @@ -737,7 +742,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met return self._dominance_frontier @dominance_frontier.setter - def dominance_frontier(self, doms: Set["Node"]): + def dominance_frontier(self, doms: Set["Node"]) -> None: """ Returns: set(Node) @@ -789,6 +794,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met def add_phi_origin_local_variable(self, variable: LocalVariable, node: "Node") -> None: if variable.name not in self._phi_origins_local_variables: + assert variable.name self._phi_origins_local_variables[variable.name] = (variable, set()) (v, nodes) = self._phi_origins_local_variables[variable.name] assert v == variable @@ -827,7 +833,8 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met if isinstance(ir, OperationWithLValue): var = ir.lvalue if var and self._is_valid_slithir_var(var): - self._slithir_vars.add(var) + # The type is checked by is_valid_slithir_var + self._slithir_vars.add(var) # type: ignore if not isinstance(ir, (Phi, Index, Member)): self._vars_read += [v for v in ir.read if self._is_non_slithir_var(v)] @@ -835,8 +842,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met if isinstance(var, ReferenceVariable): self._vars_read.append(var.points_to_origin) elif isinstance(ir, (Member, Index)): + # TODO investigate types for member variable left var = ir.variable_left if isinstance(ir, Member) else ir.variable_right - if self._is_non_slithir_var(var): + if var and self._is_non_slithir_var(var): self._vars_read.append(var) if isinstance(var, ReferenceVariable): origin = var.points_to_origin @@ -860,14 +868,21 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._internal_calls.append(ir.function) if isinstance(ir, LowLevelCall): assert isinstance(ir.destination, (Variable, SolidityVariable)) - self._low_level_calls.append((ir.destination, ir.function_name.value)) + self._low_level_calls.append((ir.destination, str(ir.function_name.value))) elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall): + # Todo investigate this if condition + # It does seem right to compare against a contract + # This might need a refactoring if isinstance(ir.destination.type, Contract): self._high_level_calls.append((ir.destination.type, ir.function)) elif ir.destination == SolidityVariable("this"): - self._high_level_calls.append((self.function.contract, ir.function)) + func = self.function + # Can't use this in a top level function + assert isinstance(func, FunctionContract) + self._high_level_calls.append((func.contract, ir.function)) else: try: + # Todo this part needs more tests and documentation self._high_level_calls.append((ir.destination.type.type, ir.function)) except AttributeError as error: # pylint: disable=raise-missing-from @@ -883,7 +898,9 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._vars_read = list(set(self._vars_read)) self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] - self._solidity_vars_read = [v for v in self._vars_read if isinstance(v, SolidityVariable)] + self._solidity_vars_read = [ + v_ for v_ in self._vars_read if isinstance(v_, SolidityVariable) + ] self._vars_written = list(set(self._vars_written)) self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] @@ -895,12 +912,15 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met @staticmethod def _convert_ssa(v: Variable) -> Optional[Union[StateVariable, LocalVariable]]: + non_ssa_var: Optional[Union[StateVariable, LocalVariable]] if isinstance(v, StateIRVariable): contract = v.contract + assert v.name non_ssa_var = contract.get_state_variable_from_name(v.name) return non_ssa_var assert isinstance(v, LocalIRVariable) function = v.function + assert v.name non_ssa_var = function.get_local_variable_from_name(v.name) return non_ssa_var @@ -921,10 +941,11 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._ssa_vars_read.append(origin) elif isinstance(ir, (Member, Index)): - if isinstance(ir.variable_right, (StateIRVariable, LocalIRVariable)): - self._ssa_vars_read.append(ir.variable_right) - if isinstance(ir.variable_right, ReferenceVariable): - origin = ir.variable_right.points_to_origin + variable_right: RVALUE = ir.variable_right + if isinstance(variable_right, (StateIRVariable, LocalIRVariable)): + self._ssa_vars_read.append(variable_right) + if isinstance(variable_right, ReferenceVariable): + origin = variable_right.points_to_origin if isinstance(origin, (StateIRVariable, LocalIRVariable)): self._ssa_vars_read.append(origin) @@ -944,20 +965,20 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met self._ssa_local_vars_read = [v for v in self._ssa_vars_read if isinstance(v, LocalVariable)] self._ssa_vars_written = list(set(self._ssa_vars_written)) self._ssa_state_vars_written = [ - v for v in self._ssa_vars_written if isinstance(v, StateVariable) + v for v in self._ssa_vars_written if v and isinstance(v, StateIRVariable) ] self._ssa_local_vars_written = [ - v for v in self._ssa_vars_written if isinstance(v, LocalVariable) + v for v in self._ssa_vars_written if v and isinstance(v, LocalIRVariable) ] vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read] vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written] - self._vars_read += [v for v in vars_read if v not in self._vars_read] + self._vars_read += [v_ for v_ in vars_read if v_ and v_ not in self._vars_read] self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] - self._vars_written += [v for v in vars_written if v not in self._vars_written] + self._vars_written += [v_ for v_ in vars_written if v_ and v_ not in self._vars_written] self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] @@ -974,7 +995,7 @@ class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-met additional_info += " " + str(self.expression) elif self.variable_declaration: additional_info += " " + str(self.variable_declaration) - txt = self._node_type.value + additional_info + txt = str(self._node_type.value) + additional_info return txt diff --git a/slither/core/children/child_contract.py b/slither/core/children/child_contract.py index 86f9dea53..2c93d9a51 100644 --- a/slither/core/children/child_contract.py +++ b/slither/core/children/child_contract.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from slither.core.source_mapping.source_mapping import SourceMapping @@ -9,11 +9,14 @@ if TYPE_CHECKING: class ChildContract(SourceMapping): def __init__(self) -> None: super().__init__() - self._contract = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._contract: Optional["Contract"] = None def set_contract(self, contract: "Contract") -> None: self._contract = contract @property def contract(self) -> "Contract": - return self._contract + return self._contract # type: ignore diff --git a/slither/core/children/child_event.py b/slither/core/children/child_event.py index df91596e3..e9a2177c5 100644 --- a/slither/core/children/child_event.py +++ b/slither/core/children/child_event.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from slither.core.declarations import Event @@ -7,11 +7,14 @@ if TYPE_CHECKING: class ChildEvent: def __init__(self) -> None: super().__init__() - self._event = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._event: Optional["Event"] = None - def set_event(self, event: "Event"): + def set_event(self, event: "Event") -> None: self._event = event @property def event(self) -> "Event": - return self._event + return self._event # type: ignore diff --git a/slither/core/children/child_expression.py b/slither/core/children/child_expression.py index 0064658c0..2294cf384 100644 --- a/slither/core/children/child_expression.py +++ b/slither/core/children/child_expression.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Union, Optional if TYPE_CHECKING: from slither.core.expressions.expression import Expression @@ -8,11 +8,16 @@ if TYPE_CHECKING: class ChildExpression: def __init__(self) -> None: super().__init__() - self._expression = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._expression: Optional[Union["Expression", "Operation"]] = None def set_expression(self, expression: Union["Expression", "Operation"]) -> None: + # TODO investigate when this can be an operation? + # It was auto generated during an AST or detectors tests self._expression = expression @property def expression(self) -> Union["Expression", "Operation"]: - return self._expression + return self._expression # type: ignore diff --git a/slither/core/children/child_function.py b/slither/core/children/child_function.py index 5367320ca..d79d12c10 100644 --- a/slither/core/children/child_function.py +++ b/slither/core/children/child_function.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from slither.core.declarations import Function @@ -7,11 +7,14 @@ if TYPE_CHECKING: class ChildFunction: def __init__(self) -> None: super().__init__() - self._function = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._function: Optional["Function"] = None def set_function(self, function: "Function") -> None: self._function = function @property def function(self) -> "Function": - return self._function + return self._function # type: ignore diff --git a/slither/core/children/child_inheritance.py b/slither/core/children/child_inheritance.py index 30b32f6c1..1ff1a4967 100644 --- a/slither/core/children/child_inheritance.py +++ b/slither/core/children/child_inheritance.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from slither.core.declarations import Contract @@ -7,11 +7,14 @@ if TYPE_CHECKING: class ChildInheritance: def __init__(self) -> None: super().__init__() - self._contract_declarer = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._contract_declarer: Optional["Contract"] = None def set_contract_declarer(self, contract: "Contract") -> None: self._contract_declarer = contract @property def contract_declarer(self) -> "Contract": - return self._contract_declarer + return self._contract_declarer # type: ignore diff --git a/slither/core/children/child_node.py b/slither/core/children/child_node.py index 8e6e1f0b5..998ec5ea4 100644 --- a/slither/core/children/child_node.py +++ b/slither/core/children/child_node.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from slither.core.compilation_unit import SlitherCompilationUnit @@ -9,14 +9,17 @@ if TYPE_CHECKING: class ChildNode: def __init__(self) -> None: super().__init__() - self._node = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._node: Optional["Node"] = None def set_node(self, node: "Node") -> None: self._node = node @property def node(self) -> "Node": - return self._node + return self._node # type:ignore @property def function(self) -> "Function": @@ -24,7 +27,7 @@ class ChildNode: @property def contract(self) -> "Contract": - return self.node.function.contract + return self.node.function.contract # type: ignore @property def compilation_unit(self) -> "SlitherCompilationUnit": diff --git a/slither/core/children/child_structure.py b/slither/core/children/child_structure.py index abcb041c2..413d7f4df 100644 --- a/slither/core/children/child_structure.py +++ b/slither/core/children/child_structure.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from slither.core.declarations import Structure @@ -7,11 +7,14 @@ if TYPE_CHECKING: class ChildStructure: def __init__(self) -> None: super().__init__() - self._structure = None + # TODO remove all the setters for the child objects + # And make it a constructor arguement + # This will remove the optional + self._structure: Optional["Structure"] = None def set_structure(self, structure: "Structure") -> None: self._structure = structure @property def structure(self) -> "Structure": - return self._structure + return self._structure # type: ignore diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 2d2d10b04..d3c016b4b 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -81,7 +81,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods # The only str is "*" self._using_for: Dict[Union[str, Type], List[Type]] = {} - self._using_for_complete: Dict[Union[str, Type], List[Type]] = None + self._using_for_complete: Optional[Dict[Union[str, Type], List[Type]]] = None self._kind: Optional[str] = None self._is_interface: bool = False self._is_library: bool = False @@ -275,7 +275,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive """ - def _merge_using_for(uf1, uf2): + def _merge_using_for(uf1: Dict, uf2: Dict) -> Dict: result = {**uf1, **uf2} for key, value in result.items(): if key in uf1 and key in uf2: @@ -524,14 +524,14 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ return list(self._functions.values()) - def available_functions_as_dict(self) -> Dict[str, "FunctionContract"]: + def available_functions_as_dict(self) -> Dict[str, "Function"]: if self._available_functions_as_dict is None: self._available_functions_as_dict = { f.full_name: f for f in self._functions.values() if not f.is_shadowed } return self._available_functions_as_dict - def add_function(self, func: "FunctionContract"): + def add_function(self, func: "FunctionContract") -> None: self._functions[func.canonical_name] = func def set_functions(self, functions: Dict[str, "FunctionContract"]) -> None: @@ -699,7 +699,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods list(Contract): Return the list of contracts derived from self """ candidates = self.compilation_unit.contracts - return [c for c in candidates if self in c.inheritance] + return [c for c in candidates if self in c.inheritance] # type: ignore # endregion ################################################################################### @@ -855,7 +855,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods """ return next((e for e in self.enums if e.name == enum_name), None) - def get_enum_from_canonical_name(self, enum_name) -> Optional["Enum"]: + def get_enum_from_canonical_name(self, enum_name: str) -> Optional["Enum"]: """ Return an enum from a canonical name Args: @@ -956,7 +956,9 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### ################################################################################### - def get_summary(self, include_shadowed=True) -> Tuple[str, List[str], List[str], List, List]: + def get_summary( + self, include_shadowed: bool = True + ) -> Tuple[str, List[str], List[str], List, List]: """Return the function summary :param include_shadowed: boolean to indicate if shadowed functions should be included (default True) @@ -1209,7 +1211,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods @property def is_test(self) -> bool: - return is_test_contract(self) or self.is_truffle_migration + return is_test_contract(self) or self.is_truffle_migration # type: ignore # endregion ################################################################################### @@ -1219,7 +1221,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### def update_read_write_using_ssa(self) -> None: - for function in self.functions + self.modifiers: + for function in self.functions + list(self.modifiers): function.update_read_write_using_ssa() # endregion @@ -1254,7 +1256,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._is_upgradeable @is_upgradeable.setter - def is_upgradeable(self, upgradeable: bool): + def is_upgradeable(self, upgradeable: bool) -> None: self._is_upgradeable = upgradeable @property @@ -1283,7 +1285,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._is_upgradeable_proxy @is_upgradeable_proxy.setter - def is_upgradeable_proxy(self, upgradeable_proxy: bool): + def is_upgradeable_proxy(self, upgradeable_proxy: bool) -> None: self._is_upgradeable_proxy = upgradeable_proxy @property @@ -1291,7 +1293,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._upgradeable_version @upgradeable_version.setter - def upgradeable_version(self, version_name: str): + def upgradeable_version(self, version_name: str) -> None: self._upgradeable_version = version_name # endregion @@ -1310,7 +1312,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self._is_incorrectly_parsed @is_incorrectly_constructed.setter - def is_incorrectly_constructed(self, incorrect: bool): + def is_incorrectly_constructed(self, incorrect: bool) -> None: self._is_incorrectly_parsed = incorrect def add_constructor_variables(self) -> None: @@ -1322,8 +1324,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods constructor_variable = FunctionContract(self.compilation_unit) constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES) - constructor_variable.set_contract(self) - constructor_variable.set_contract_declarer(self) + constructor_variable.set_contract(self) # type: ignore + constructor_variable.set_contract_declarer(self) # type: ignore constructor_variable.set_visibility("internal") # For now, source mapping of the constructor variable is the whole contract # Could be improved with a targeted source mapping @@ -1354,8 +1356,8 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods constructor_variable.set_function_type( FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES ) - constructor_variable.set_contract(self) - constructor_variable.set_contract_declarer(self) + constructor_variable.set_contract(self) # type: ignore + constructor_variable.set_contract_declarer(self) # type: ignore constructor_variable.set_visibility("internal") # For now, source mapping of the constructor variable is the whole contract # Could be improved with a targeted source mapping @@ -1436,22 +1438,23 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods all_ssa_state_variables_instances[v.canonical_name] = new_var self._initial_state_variables.append(new_var) - for func in self.functions + self.modifiers: + for func in self.functions + list(self.modifiers): func.generate_slithir_ssa(all_ssa_state_variables_instances) def fix_phi(self) -> None: - last_state_variables_instances = {} - initial_state_variables_instances = {} + last_state_variables_instances: Dict[str, List["StateVariable"]] = {} + initial_state_variables_instances: Dict[str, "StateVariable"] = {} for v in self._initial_state_variables: last_state_variables_instances[v.canonical_name] = [] initial_state_variables_instances[v.canonical_name] = v - for func in self.functions + self.modifiers: + for func in self.functions + list(self.modifiers): result = func.get_last_ssa_state_variables_instances() for variable_name, instances in result.items(): + # TODO: investigate the next operation last_state_variables_instances[variable_name] += instances - for func in self.functions + self.modifiers: + for func in self.functions + list(self.modifiers): func.fix_phi(last_state_variables_instances, initial_state_variables_instances) # endregion @@ -1461,7 +1464,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### ################################################################################### - def __eq__(self, other: SourceMapping) -> bool: + def __eq__(self, other: Any) -> bool: if isinstance(other, str): return other == self.name return NotImplemented @@ -1475,6 +1478,6 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods return self.name def __hash__(self) -> int: - return self._id + return self._id # type:ignore # endregion diff --git a/slither/core/declarations/custom_error.py b/slither/core/declarations/custom_error.py index 5e851c8da..c566fccec 100644 --- a/slither/core/declarations/custom_error.py +++ b/slither/core/declarations/custom_error.py @@ -51,7 +51,7 @@ class CustomError(SourceMapping): return str(t) @property - def solidity_signature(self) -> Optional[str]: + def solidity_signature(self) -> str: """ Return a signature following the Solidity Standard Contract and converted into address @@ -63,7 +63,7 @@ class CustomError(SourceMapping): # (set_solidity_sig was not called before find_variable) if self._solidity_signature is None: raise ValueError("Custom Error not yet built") - return self._solidity_signature + return self._solidity_signature # type: ignore def set_solidity_sig(self) -> None: """ diff --git a/slither/core/declarations/custom_error_contract.py b/slither/core/declarations/custom_error_contract.py index a96f12057..d5b7b6a92 100644 --- a/slither/core/declarations/custom_error_contract.py +++ b/slither/core/declarations/custom_error_contract.py @@ -1,9 +1,14 @@ +from typing import TYPE_CHECKING + from slither.core.children.child_contract import ChildContract from slither.core.declarations.custom_error import CustomError +if TYPE_CHECKING: + from slither.core.declarations import Contract + class CustomErrorContract(CustomError, ChildContract): - def is_declared_by(self, contract): + def is_declared_by(self, contract: "Contract") -> bool: """ Check if the element is declared by the contract :param contract: diff --git a/slither/core/declarations/custom_error_top_level.py b/slither/core/declarations/custom_error_top_level.py index 29a9fd41a..64a6a8535 100644 --- a/slither/core/declarations/custom_error_top_level.py +++ b/slither/core/declarations/custom_error_top_level.py @@ -9,6 +9,6 @@ if TYPE_CHECKING: class CustomErrorTopLevel(CustomError, TopLevel): - def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope"): + def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope") -> None: super().__init__(compilation_unit) self.file_scope: "FileScope" = scope diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index c383fc99b..e77801961 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -47,7 +47,6 @@ if TYPE_CHECKING: from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.scope.scope import FileScope from slither.slithir.variables.state_variable import StateIRVariable - from slither.core.declarations.function_contract import FunctionContract LOGGER = logging.getLogger("Function") ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"]) @@ -298,7 +297,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu def contains_assembly(self, c: bool): self._contains_assembly = c - def can_reenter(self, callstack: Optional[List["FunctionContract"]] = None) -> bool: + def can_reenter(self, callstack: Optional[List[Union["Function", "Variable"]]] = None) -> bool: """ Check if the function can re-enter Follow internal calls. @@ -1720,8 +1719,8 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu def fix_phi( self, - last_state_variables_instances: Dict[str, List["StateIRVariable"]], - initial_state_variables_instances: Dict[str, "StateIRVariable"], + last_state_variables_instances: Dict[str, List["StateVariable"]], + initial_state_variables_instances: Dict[str, "StateVariable"], ) -> None: from slither.slithir.operations import InternalCall, PhiCallback from slither.slithir.variables import Constant, StateIRVariable diff --git a/slither/core/declarations/solidity_variables.py b/slither/core/declarations/solidity_variables.py index 9569cde93..f0e903d7b 100644 --- a/slither/core/declarations/solidity_variables.py +++ b/slither/core/declarations/solidity_variables.py @@ -82,7 +82,7 @@ SOLIDITY_FUNCTIONS: Dict[str, List[str]] = { } -def solidity_function_signature(name): +def solidity_function_signature(name: str) -> str: """ Return the function signature (containing the return value) It is useful if a solidity function is used as a pointer @@ -106,7 +106,7 @@ class SolidityVariable(SourceMapping): assert name in SOLIDITY_VARIABLES or name.endswith(("_slot", "_offset")) @property - def state_variable(self): + def state_variable(self) -> str: if self._name.endswith("_slot"): return self._name[:-5] if self._name.endswith("_offset"): @@ -125,7 +125,7 @@ class SolidityVariable(SourceMapping): def __str__(self) -> str: return self._name - def __eq__(self, other: SourceMapping) -> bool: + def __eq__(self, other: Any) -> bool: return self.__class__ == other.__class__ and self.name == other.name def __hash__(self) -> int: @@ -182,13 +182,13 @@ class SolidityFunction(SourceMapping): return self._return_type @return_type.setter - def return_type(self, r: List[Union[TypeInformation, ElementaryType]]): + def return_type(self, r: List[Union[TypeInformation, ElementaryType]]) -> None: self._return_type = r def __str__(self) -> str: return self._name - def __eq__(self, other: "SolidityFunction") -> bool: + def __eq__(self, other: Any) -> bool: return self.__class__ == other.__class__ and self.name == other.name def __hash__(self) -> int: @@ -201,7 +201,7 @@ class SolidityCustomRevert(SolidityFunction): self._custom_error = custom_error self._return_type: List[Union[TypeInformation, ElementaryType]] = [] - def __eq__(self, other: Union["SolidityCustomRevert", SolidityFunction]) -> bool: + def __eq__(self, other: Any) -> bool: return ( self.__class__ == other.__class__ and self.name == other.name diff --git a/slither/core/dominators/utils.py b/slither/core/dominators/utils.py index ca5c51282..4dd55749d 100644 --- a/slither/core/dominators/utils.py +++ b/slither/core/dominators/utils.py @@ -95,4 +95,5 @@ def compute_dominance_frontier(nodes: List["Node"]) -> None: runner.dominance_frontier = runner.dominance_frontier.union({node}) while runner != node.immediate_dominator: runner.dominance_frontier = runner.dominance_frontier.union({node}) + assert runner.immediate_dominator runner = runner.immediate_dominator diff --git a/slither/core/expressions/assignment_operation.py b/slither/core/expressions/assignment_operation.py index 22aba57fb..7b7bc62d6 100644 --- a/slither/core/expressions/assignment_operation.py +++ b/slither/core/expressions/assignment_operation.py @@ -91,7 +91,7 @@ class AssignmentOperation(ExpressionTyped): super().__init__() left_expression.set_lvalue() self._expressions = [left_expression, right_expression] - self._type: Optional["AssignmentOperationType"] = expression_type + self._type: AssignmentOperationType = expression_type self._expression_return_type: Optional["Type"] = expression_return_type @property diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index e5f4e830a..ba2ec802d 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -482,8 +482,8 @@ class SlitherCore(Context): ################################################################################### @property - def crytic_compile(self) -> Optional[CryticCompile]: - return self._crytic_compile + def crytic_compile(self) -> CryticCompile: + return self._crytic_compile # type: ignore # endregion ################################################################################### diff --git a/slither/core/solidity_types/array_type.py b/slither/core/solidity_types/array_type.py index 9a0b12c00..cdb8c10c7 100644 --- a/slither/core/solidity_types/array_type.py +++ b/slither/core/solidity_types/array_type.py @@ -4,11 +4,11 @@ from slither.core.expressions.expression import Expression from slither.core.solidity_types.type import Type from slither.visitors.expression.constants_folding import ConstantFolding from slither.core.expressions.literal import Literal +from slither.core.solidity_types.elementary_type import ElementaryType if TYPE_CHECKING: from slither.core.expressions.binary_operation import BinaryOperation from slither.core.expressions.identifier import Identifier - from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.solidity_types.function_type import FunctionType from slither.core.solidity_types.type_alias import TypeAliasTopLevel @@ -22,7 +22,7 @@ class ArrayType(Type): assert isinstance(t, Type) if length: if isinstance(length, int): - length = Literal(length, "uint256") + length = Literal(length, ElementaryType("uint256")) assert isinstance(length, Expression) super().__init__() self._type: Type = t diff --git a/slither/core/solidity_types/elementary_type.py b/slither/core/solidity_types/elementary_type.py index ec2b0ef04..a9f45c8d8 100644 --- a/slither/core/solidity_types/elementary_type.py +++ b/slither/core/solidity_types/elementary_type.py @@ -1,5 +1,5 @@ import itertools -from typing import Tuple +from typing import Tuple, Optional, Any from slither.core.solidity_types.type import Type @@ -176,7 +176,7 @@ class ElementaryType(Type): return self.type @property - def size(self) -> int: + def size(self) -> Optional[int]: """ Return the size in bits Return None if the size is not known @@ -219,7 +219,7 @@ class ElementaryType(Type): def __str__(self) -> str: return self._type - def __eq__(self, other) -> bool: + def __eq__(self, other: Any) -> bool: if not isinstance(other, ElementaryType): return False return self.type == other.type diff --git a/slither/core/solidity_types/mapping_type.py b/slither/core/solidity_types/mapping_type.py index a8acb4d9c..9741569ed 100644 --- a/slither/core/solidity_types/mapping_type.py +++ b/slither/core/solidity_types/mapping_type.py @@ -1,4 +1,4 @@ -from typing import Union, Tuple, TYPE_CHECKING +from typing import Union, Tuple, TYPE_CHECKING, Any from slither.core.solidity_types.type import Type @@ -38,7 +38,7 @@ class MappingType(Type): def __str__(self) -> str: return f"mapping({str(self._from)} => {str(self._to)})" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, MappingType): return False return self.type_from == other.type_from and self.type_to == other.type_to diff --git a/slither/core/solidity_types/type_alias.py b/slither/core/solidity_types/type_alias.py index 5b9ea0a37..1da2d4182 100644 --- a/slither/core/solidity_types/type_alias.py +++ b/slither/core/solidity_types/type_alias.py @@ -40,7 +40,7 @@ class TypeAlias(Type): class TypeAliasTopLevel(TypeAlias, TopLevel): - def __init__(self, underlying_type: Type, name: str, scope: "FileScope") -> None: + def __init__(self, underlying_type: ElementaryType, name: str, scope: "FileScope") -> None: super().__init__(underlying_type, name) self.file_scope: "FileScope" = scope @@ -49,7 +49,7 @@ class TypeAliasTopLevel(TypeAlias, TopLevel): class TypeAliasContract(TypeAlias, ChildContract): - def __init__(self, underlying_type: Type, name: str, contract: "Contract") -> None: + def __init__(self, underlying_type: ElementaryType, name: str, contract: "Contract") -> None: super().__init__(underlying_type, name) self._contract: "Contract" = contract diff --git a/slither/core/solidity_types/type_information.py b/slither/core/solidity_types/type_information.py index 2af0b097a..9cef9352c 100644 --- a/slither/core/solidity_types/type_information.py +++ b/slither/core/solidity_types/type_information.py @@ -1,4 +1,4 @@ -from typing import Union, TYPE_CHECKING, Tuple +from typing import Union, TYPE_CHECKING, Tuple, Any from slither.core.solidity_types import ElementaryType from slither.core.solidity_types.type import Type @@ -40,10 +40,10 @@ class TypeInformation(Type): def is_dynamic(self) -> bool: raise NotImplementedError - def __str__(self): + def __str__(self) -> str: return f"type({self.type.name})" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, TypeInformation): return False return self.type == other.type diff --git a/slither/core/source_mapping/source_mapping.py b/slither/core/source_mapping/source_mapping.py index a0fcf354a..4c8742b22 100644 --- a/slither/core/source_mapping/source_mapping.py +++ b/slither/core/source_mapping/source_mapping.py @@ -1,6 +1,6 @@ import re from abc import ABCMeta -from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional +from typing import Dict, Union, List, Tuple, TYPE_CHECKING, Optional, Any from Crypto.Hash import SHA1 from crytic_compile.utils.naming import Filename @@ -102,10 +102,10 @@ class Source: filename_short: str = self.filename.short if self.filename.short else "" return f"{filename_short}{lines}" - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, type(self)): return NotImplemented return ( diff --git a/slither/core/variables/event_variable.py b/slither/core/variables/event_variable.py index f3ad60d0b..e191433dc 100644 --- a/slither/core/variables/event_variable.py +++ b/slither/core/variables/event_variable.py @@ -16,5 +16,5 @@ class EventVariable(ChildEvent, Variable): return self._indexed @indexed.setter - def indexed(self, is_indexed: bool): + def indexed(self, is_indexed: bool) -> bool: self._indexed = is_indexed diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index 8607a8921..c775e7c98 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -55,7 +55,7 @@ class Variable(SourceMapping): return self._initialized @initialized.setter - def initialized(self, is_init: bool): + def initialized(self, is_init: bool) -> None: self._initialized = is_init @property @@ -73,7 +73,7 @@ class Variable(SourceMapping): return self._name @name.setter - def name(self, name): + def name(self, name: str) -> None: self._name = name @property @@ -89,7 +89,7 @@ class Variable(SourceMapping): return self._is_constant @is_constant.setter - def is_constant(self, is_cst: bool): + def is_constant(self, is_cst: bool) -> None: self._is_constant = is_cst @property diff --git a/slither/detectors/abstract_detector.py b/slither/detectors/abstract_detector.py index 8e2dd490d..59f8ca3a0 100644 --- a/slither/detectors/abstract_detector.py +++ b/slither/detectors/abstract_detector.py @@ -59,6 +59,8 @@ ALL_SOLC_VERSIONS_06 = make_solc_versions(6, 0, 12) ALL_SOLC_VERSIONS_07 = make_solc_versions(7, 0, 6) # No VERSIONS_08 as it is still in dev +DETECTOR_INFO = Union[str, List[Union[str, SupportedOutput]]] + class AbstractDetector(metaclass=abc.ABCMeta): ARGUMENT = "" # run the detector with slither.py --ARGUMENT @@ -251,7 +253,7 @@ class AbstractDetector(metaclass=abc.ABCMeta): def generate_result( self, - info: Union[str, List[Union[str, SupportedOutput]]], + info: DETECTOR_INFO, additional_fields: Optional[Dict] = None, ) -> Output: output = Output( diff --git a/slither/detectors/assembly/shift_parameter_mixup.py b/slither/detectors/assembly/shift_parameter_mixup.py index 31dad2371..a4169499a 100644 --- a/slither/detectors/assembly/shift_parameter_mixup.py +++ b/slither/detectors/assembly/shift_parameter_mixup.py @@ -1,5 +1,9 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.slithir.operations import Binary, BinaryType from slither.slithir.variables import Constant from slither.core.declarations.function_contract import FunctionContract @@ -49,7 +53,12 @@ The shift statement will right-shift the constant 8 by `a` bits""" BinaryType.RIGHT_SHIFT, ]: if isinstance(ir.variable_left, Constant): - info = [f, " contains an incorrect shift operation: ", node, "\n"] + info: DETECTOR_INFO = [ + f, + " contains an incorrect shift operation: ", + node, + "\n", + ] json = self.generate_result(info) results.append(json) diff --git a/slither/detectors/attributes/const_functions_asm.py b/slither/detectors/attributes/const_functions_asm.py index e3a938361..6de5062d8 100644 --- a/slither/detectors/attributes/const_functions_asm.py +++ b/slither/detectors/attributes/const_functions_asm.py @@ -7,6 +7,7 @@ from slither.detectors.abstract_detector import ( AbstractDetector, DetectorClassification, ALL_SOLC_VERSIONS_04, + DETECTOR_INFO, ) from slither.formatters.attributes.const_functions import custom_format from slither.utils.output import Output @@ -73,7 +74,10 @@ All the calls to `get` revert, breaking Bob's smart contract execution.""" if f.contains_assembly: attr = "view" if f.view else "pure" - info = [f, f" is declared {attr} but contains assembly code\n"] + info: DETECTOR_INFO = [ + f, + f" is declared {attr} but contains assembly code\n", + ] res = self.generate_result(info, {"contains_assembly": True}) results.append(res) diff --git a/slither/detectors/compiler_bugs/array_by_reference.py b/slither/detectors/compiler_bugs/array_by_reference.py index 83ed69b9b..fffe93847 100644 --- a/slither/detectors/compiler_bugs/array_by_reference.py +++ b/slither/detectors/compiler_bugs/array_by_reference.py @@ -105,7 +105,12 @@ As a result, Bob's usage of the contract is incorrect.""" write to the array unsuccessfully. """ # Define our resulting array. - results = [] + results: List[ + Union[ + Tuple[Node, StateVariable, FunctionContract], + Tuple[Node, LocalVariable, FunctionContract], + ] + ] = [] # Verify we have functions in our list to check for. if not array_modifying_funcs: diff --git a/slither/detectors/erc/erc20/arbitrary_send_erc20.py b/slither/detectors/erc/erc20/arbitrary_send_erc20.py index 17b1fba30..f06005459 100644 --- a/slither/detectors/erc/erc20/arbitrary_send_erc20.py +++ b/slither/detectors/erc/erc20/arbitrary_send_erc20.py @@ -61,12 +61,12 @@ class ArbitrarySendErc20: is_dependent( ir.arguments[0], SolidityVariableComposed("msg.sender"), - node.function.contract, + node, ) or is_dependent( ir.arguments[0], SolidityVariable("this"), - node.function.contract, + node, ) ) ): @@ -79,12 +79,12 @@ class ArbitrarySendErc20: is_dependent( ir.arguments[1], SolidityVariableComposed("msg.sender"), - node.function.contract, + node, ) or is_dependent( ir.arguments[1], SolidityVariable("this"), - node.function.contract, + node, ) ) ): diff --git a/slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py b/slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py index f43b6302e..351f1dcfa 100644 --- a/slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py +++ b/slither/detectors/erc/erc20/arbitrary_send_erc20_no_permit.py @@ -1,5 +1,9 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output from .arbitrary_send_erc20 import ArbitrarySendErc20 @@ -38,7 +42,7 @@ Use `msg.sender` as `from` in transferFrom. arbitrary_sends.detect() for node in arbitrary_sends.no_permit_results: func = node.function - info = [func, " uses arbitrary from in transferFrom: ", node, "\n"] + info: DETECTOR_INFO = [func, " uses arbitrary from in transferFrom: ", node, "\n"] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py b/slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py index 1d311c442..ca4c4a793 100644 --- a/slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py +++ b/slither/detectors/erc/erc20/arbitrary_send_erc20_permit.py @@ -1,5 +1,9 @@ from typing import List -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output from .arbitrary_send_erc20 import ArbitrarySendErc20 @@ -41,7 +45,7 @@ Ensure that the underlying ERC20 token correctly implements a permit function. arbitrary_sends.detect() for node in arbitrary_sends.permit_results: func = node.function - info = [ + info: DETECTOR_INFO = [ func, " uses arbitrary from in transferFrom in combination with permit: ", node, diff --git a/slither/detectors/functions/arbitrary_send_eth.py b/slither/detectors/functions/arbitrary_send_eth.py index 390b1f2ab..e0112c6e1 100644 --- a/slither/detectors/functions/arbitrary_send_eth.py +++ b/slither/detectors/functions/arbitrary_send_eth.py @@ -39,6 +39,10 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]: ret: List[Node] = [] for node in func.nodes: + func = node.function + deps_target: Union[Contract, Function] = ( + func.contract if isinstance(func, FunctionContract) else func + ) for ir in node.irs: if isinstance(ir, SolidityCall): if ir.function == SolidityFunction("ecrecover(bytes32,uint8,bytes32,bytes32)"): @@ -49,7 +53,7 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]: if is_dependent( ir.variable_right, SolidityVariableComposed("msg.sender"), - func.contract, + deps_target, ): return False if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)): @@ -64,11 +68,11 @@ def arbitrary_send(func: Function) -> Union[bool, List[Node]]: if is_dependent( ir.call_value, SolidityVariableComposed("msg.value"), - func.contract, + node, ): continue - if is_tainted(ir.destination, func.contract): + if is_tainted(ir.destination, node): ret.append(node) return ret diff --git a/slither/detectors/statements/array_length_assignment.py b/slither/detectors/statements/array_length_assignment.py index 51302a2c9..f4ae01b88 100644 --- a/slither/detectors/statements/array_length_assignment.py +++ b/slither/detectors/statements/array_length_assignment.py @@ -7,6 +7,7 @@ from slither.detectors.abstract_detector import ( DetectorClassification, ALL_SOLC_VERSIONS_04, ALL_SOLC_VERSIONS_05, + DETECTOR_INFO, ) from slither.core.cfg.node import Node, NodeType from slither.slithir.operations import Assignment, Length @@ -120,7 +121,7 @@ Otherwise, thoroughly review the contract to ensure a user-controlled variable c for contract in self.contracts: array_length_assignments = detect_array_length_assignment(contract) if array_length_assignments: - contract_info = [ + contract_info: DETECTOR_INFO = [ contract, " contract sets array length with a user-controlled value:\n", ] diff --git a/slither/detectors/statements/assembly.py b/slither/detectors/statements/assembly.py index 2c0c49f09..25b5d8034 100644 --- a/slither/detectors/statements/assembly.py +++ b/slither/detectors/statements/assembly.py @@ -6,7 +6,11 @@ from typing import List, Tuple from slither.core.cfg.node import Node, NodeType from slither.core.declarations.contract import Contract from slither.core.declarations.function_contract import FunctionContract -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.detectors.abstract_detector import ( + AbstractDetector, + DetectorClassification, + DETECTOR_INFO, +) from slither.utils.output import Output @@ -52,7 +56,7 @@ class Assembly(AbstractDetector): for c in self.contracts: values = self.detect_assembly(c) for func, nodes in values: - info = [func, " uses assembly\n"] + info: DETECTOR_INFO = [func, " uses assembly\n"] # sort the nodes to get deterministic results nodes.sort(key=lambda x: x.node_id) diff --git a/slither/slithir/operations/call.py b/slither/slithir/operations/call.py index 07304fa99..37a2fe0b3 100644 --- a/slither/slithir/operations/call.py +++ b/slither/slithir/operations/call.py @@ -1,5 +1,7 @@ -from typing import Optional, List +from typing import Optional, List, Union +from slither.core.declarations import Function +from slither.core.variables import Variable from slither.slithir.operations.operation import Operation @@ -16,7 +18,8 @@ class Call(Operation): def arguments(self, v): self._arguments = v - def can_reenter(self, _callstack: Optional[List] = None) -> bool: # pylint: disable=no-self-use + # pylint: disable=no-self-use + def can_reenter(self, _callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass :return: bool diff --git a/slither/slithir/operations/high_level_call.py b/slither/slithir/operations/high_level_call.py index 93fb73bd4..d707e11b3 100644 --- a/slither/slithir/operations/high_level_call.py +++ b/slither/slithir/operations/high_level_call.py @@ -1,5 +1,6 @@ from typing import List, Optional, Union +from slither.core.declarations import Contract from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.core.variables.variable import Variable @@ -32,7 +33,8 @@ class HighLevelCall(Call, OperationWithLValue): assert is_valid_lvalue(result) or result is None self._check_destination(destination) super().__init__() - self._destination = destination + # Contract is only possible for library call, which inherits from highlevelcall + self._destination: Union[Variable, SolidityVariable, Contract] = destination # type: ignore self._function_name = function_name self._nbr_arguments = nbr_arguments self._type_call = type_call @@ -44,8 +46,9 @@ class HighLevelCall(Call, OperationWithLValue): self._call_gas = None # Development function, to be removed once the code is stable - # It is ovveride by LbraryCall - def _check_destination(self, destination: SourceMapping) -> None: # pylint: disable=no-self-use + # It is overridden by LibraryCall + # pylint: disable=no-self-use + def _check_destination(self, destination: Union[Variable, SolidityVariable, Contract]) -> None: assert isinstance(destination, (Variable, SolidityVariable)) @property @@ -79,7 +82,14 @@ class HighLevelCall(Call, OperationWithLValue): return [x for x in all_read if x] + [self.destination] @property - def destination(self) -> SourceMapping: + def destination(self) -> Union[Variable, SolidityVariable, Contract]: + """ + Return a variable or a solidityVariable + Contract is only possible for LibraryCall + + Returns: + + """ return self._destination @property @@ -116,7 +126,7 @@ class HighLevelCall(Call, OperationWithLValue): return True return False - def can_reenter(self, callstack: None = None) -> bool: + def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass For Solidity > 0.5, filter access to public variables and constant/pure/view diff --git a/slither/slithir/operations/index.py b/slither/slithir/operations/index.py index ade84fe1d..77daa9462 100644 --- a/slither/slithir/operations/index.py +++ b/slither/slithir/operations/index.py @@ -1,20 +1,20 @@ from typing import List, Union + from slither.core.declarations import SolidityVariableComposed -from slither.slithir.operations.lvalue import OperationWithLValue -from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue -from slither.slithir.variables.reference import ReferenceVariable from slither.core.solidity_types.elementary_type import ElementaryType from slither.core.source_mapping.source_mapping import SourceMapping from slither.core.variables.variable import Variable -from slither.slithir.variables.reference_ssa import ReferenceVariableSSA +from slither.slithir.operations.lvalue import OperationWithLValue +from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue, RVALUE, LVALUE +from slither.slithir.variables.reference import ReferenceVariable class Index(OperationWithLValue): def __init__( self, - result: Union[ReferenceVariable, ReferenceVariableSSA], + result: ReferenceVariable, left_variable: Variable, - right_variable: SourceMapping, + right_variable: RVALUE, index_type: Union[ElementaryType, str], ) -> None: super().__init__() @@ -25,23 +25,23 @@ class Index(OperationWithLValue): assert isinstance(result, ReferenceVariable) self._variables = [left_variable, right_variable] self._type = index_type - self._lvalue = result + self._lvalue: ReferenceVariable = result @property def read(self) -> List[SourceMapping]: return list(self.variables) @property - def variables(self) -> List[SourceMapping]: - return self._variables + def variables(self) -> List[Union[LVALUE, RVALUE, SolidityVariableComposed]]: + return self._variables # type: ignore @property - def variable_left(self) -> Variable: - return self._variables[0] + def variable_left(self) -> Union[LVALUE, SolidityVariableComposed]: + return self._variables[0] # type: ignore @property - def variable_right(self) -> SourceMapping: - return self._variables[1] + def variable_right(self) -> RVALUE: + return self._variables[1] # type: ignore @property def index_type(self) -> Union[ElementaryType, str]: diff --git a/slither/slithir/operations/library_call.py b/slither/slithir/operations/library_call.py index ebe9bf5ef..1b7f4e8a6 100644 --- a/slither/slithir/operations/library_call.py +++ b/slither/slithir/operations/library_call.py @@ -1,4 +1,7 @@ -from slither.core.declarations import Function +from typing import Union, Optional, List + +from slither.core.declarations import Function, SolidityVariable +from slither.core.variables import Variable from slither.slithir.operations.high_level_call import HighLevelCall from slither.core.declarations.contract import Contract @@ -9,10 +12,10 @@ class LibraryCall(HighLevelCall): """ # Development function, to be removed once the code is stable - def _check_destination(self, destination: Contract) -> None: + def _check_destination(self, destination: Union[Variable, SolidityVariable, Contract]) -> None: assert isinstance(destination, Contract) - def can_reenter(self, callstack: None = None) -> bool: + def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass :return: bool @@ -20,11 +23,11 @@ class LibraryCall(HighLevelCall): if self.is_static_call(): return False # In case of recursion, return False - callstack = [] if callstack is None else callstack - if self.function in callstack: + callstack_local = [] if callstack is None else callstack + if self.function in callstack_local: return False - callstack = callstack + [self.function] - return self.function.can_reenter(callstack) + callstack_local = callstack_local + [self.function] + return self.function.can_reenter(callstack_local) def __str__(self): gas = "" diff --git a/slither/slithir/operations/low_level_call.py b/slither/slithir/operations/low_level_call.py index 7e8c278b8..eac779d27 100644 --- a/slither/slithir/operations/low_level_call.py +++ b/slither/slithir/operations/low_level_call.py @@ -1,4 +1,6 @@ -from typing import List, Union +from typing import List, Union, Optional + +from slither.core.declarations import Function from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.core.variables.variable import Variable @@ -74,7 +76,7 @@ class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-insta # remove None return self._unroll([x for x in all_read if x]) - def can_reenter(self, _callstack: None = None) -> bool: + def can_reenter(self, _callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass :return: bool diff --git a/slither/slithir/operations/lvalue.py b/slither/slithir/operations/lvalue.py index d9b800c92..b983d1c5d 100644 --- a/slither/slithir/operations/lvalue.py +++ b/slither/slithir/operations/lvalue.py @@ -1,4 +1,6 @@ -from typing import Any, List +from typing import Any, List, Optional + +from slither.core.variables import Variable from slither.slithir.operations.operation import Operation @@ -10,16 +12,16 @@ class OperationWithLValue(Operation): def __init__(self) -> None: super().__init__() - self._lvalue = None + self._lvalue: Optional[Variable] = None @property - def lvalue(self): + def lvalue(self) -> Optional[Variable]: return self._lvalue - @property - def used(self) -> List[Any]: - return self.read + [self.lvalue] - @lvalue.setter - def lvalue(self, lvalue): + def lvalue(self, lvalue: Variable) -> None: self._lvalue = lvalue + + @property + def used(self) -> List[Optional[Any]]: + return self.read + [self.lvalue] diff --git a/slither/slithir/operations/member.py b/slither/slithir/operations/member.py index 9a561ea87..0942813cf 100644 --- a/slither/slithir/operations/member.py +++ b/slither/slithir/operations/member.py @@ -5,7 +5,7 @@ from slither.core.declarations.enum import Enum from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder from slither.core.solidity_types import ElementaryType from slither.slithir.operations.lvalue import OperationWithLValue -from slither.slithir.utils.utils import is_valid_rvalue +from slither.slithir.utils.utils import is_valid_rvalue, RVALUE from slither.slithir.variables.constant import Constant from slither.slithir.variables.reference import ReferenceVariable from slither.core.source_mapping.source_mapping import SourceMapping @@ -39,7 +39,9 @@ class Member(OperationWithLValue): assert isinstance(variable_right, Constant) assert isinstance(result, ReferenceVariable) super().__init__() - self._variable_left = variable_left + self._variable_left: Union[ + RVALUE, Contract, Enum, Function, CustomError, SolidityImportPlaceHolder, ElementaryType + ] = variable_left self._variable_right = variable_right self._lvalue = result self._gas = None @@ -50,7 +52,11 @@ class Member(OperationWithLValue): return [self.variable_left, self.variable_right] @property - def variable_left(self) -> SourceMapping: + def variable_left( + self, + ) -> Union[ + RVALUE, Contract, Enum, Function, CustomError, SolidityImportPlaceHolder, ElementaryType + ]: return self._variable_left @property diff --git a/slither/slithir/operations/new_contract.py b/slither/slithir/operations/new_contract.py index 879d12df6..8d3c949df 100644 --- a/slither/slithir/operations/new_contract.py +++ b/slither/slithir/operations/new_contract.py @@ -1,11 +1,13 @@ from typing import Optional, Any, List, Union + +from slither.core.declarations import Function +from slither.core.declarations.contract import Contract +from slither.core.variables import Variable from slither.slithir.operations import Call, OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.variables.constant import Constant -from slither.core.declarations.contract import Contract from slither.slithir.variables.temporary import TemporaryVariable from slither.slithir.variables.temporary_ssa import TemporaryVariableSSA -from slither.core.declarations.function_contract import FunctionContract class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes @@ -58,6 +60,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan def contract_created(self) -> Contract: contract_name = self.contract_name contract_instance = self.node.file_scope.get_contract_from_name(contract_name) + assert contract_instance return contract_instance ################################################################################### @@ -66,7 +69,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan ################################################################################### ################################################################################### - def can_reenter(self, callstack: Optional[List[FunctionContract]] = None) -> bool: + def can_reenter(self, callstack: Optional[List[Union[Function, Variable]]] = None) -> bool: """ Must be called after slithIR analysis pass For Solidity > 0.5, filter access to public variables and constant/pure/view @@ -92,7 +95,7 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan # endregion - def __str__(self): + def __str__(self) -> str: options = "" if self.call_value: options = f"value:{self.call_value} " diff --git a/slither/slithir/operations/solidity_call.py b/slither/slithir/operations/solidity_call.py index b059c55a6..88430f934 100644 --- a/slither/slithir/operations/solidity_call.py +++ b/slither/slithir/operations/solidity_call.py @@ -1,15 +1,16 @@ from typing import Any, List, Union -from slither.core.declarations.solidity_variables import SolidityCustomRevert, SolidityFunction -from slither.slithir.operations.call import Call -from slither.slithir.operations.lvalue import OperationWithLValue + from slither.core.children.child_node import ChildNode +from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.solidity_types.elementary_type import ElementaryType +from slither.slithir.operations.call import Call +from slither.slithir.operations.lvalue import OperationWithLValue class SolidityCall(Call, OperationWithLValue): def __init__( self, - function: Union[SolidityCustomRevert, SolidityFunction], + function: SolidityFunction, nbr_arguments: int, result: ChildNode, type_call: Union[str, List[ElementaryType]], @@ -26,7 +27,7 @@ class SolidityCall(Call, OperationWithLValue): return self._unroll(self.arguments) @property - def function(self) -> Union[SolidityCustomRevert, SolidityFunction]: + def function(self) -> SolidityFunction: return self._function @property diff --git a/slither/slithir/utils/utils.py b/slither/slithir/utils/utils.py index 0a50f8e50..a0ca0bd6f 100644 --- a/slither/slithir/utils/utils.py +++ b/slither/slithir/utils/utils.py @@ -1,3 +1,5 @@ +from typing import Union + from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable @@ -10,6 +12,24 @@ from slither.slithir.variables.reference import ReferenceVariable from slither.slithir.variables.tuple import TupleVariable from slither.core.source_mapping.source_mapping import SourceMapping +RVALUE = Union[ + StateVariable, + LocalVariable, + TopLevelVariable, + TemporaryVariable, + Constant, + SolidityVariable, + ReferenceVariable, +] + +LVALUE = Union[ + StateVariable, + LocalVariable, + TemporaryVariable, + ReferenceVariable, + TupleVariable, +] + def is_valid_rvalue(v: SourceMapping) -> bool: return isinstance( diff --git a/slither/tools/mutator/__main__.py b/slither/tools/mutator/__main__.py index 27e396d0b..84286ce66 100644 --- a/slither/tools/mutator/__main__.py +++ b/slither/tools/mutator/__main__.py @@ -79,9 +79,10 @@ def main() -> None: print(args.codebase) sl = Slither(args.codebase, **vars(args)) - for M in _get_mutators(): - m = M(sl) - m.mutate() + for compilation_unit in sl.compilation_units: + for M in _get_mutators(): + m = M(compilation_unit) + m.mutate() # endregion diff --git a/slither/tools/mutator/mutators/abstract_mutator.py b/slither/tools/mutator/mutators/abstract_mutator.py index 850c3c399..169d8725e 100644 --- a/slither/tools/mutator/mutators/abstract_mutator.py +++ b/slither/tools/mutator/mutators/abstract_mutator.py @@ -3,7 +3,7 @@ import logging from enum import Enum from typing import Optional, Dict -from slither import Slither +from slither.core.compilation_unit import SlitherCompilationUnit from slither.formatters.utils.patches import apply_patch, create_diff logger = logging.getLogger("Slither") @@ -34,8 +34,11 @@ class AbstractMutator(metaclass=abc.ABCMeta): # pylint: disable=too-few-public- FAULTCLASS = FaultClass.Undefined FAULTNATURE = FaultNature.Undefined - def __init__(self, slither: Slither, rate: int = 10, seed: Optional[int] = None): - self.slither = slither + def __init__( + self, compilation_unit: SlitherCompilationUnit, rate: int = 10, seed: Optional[int] = None + ): + self.compilation_unit = compilation_unit + self.slither = compilation_unit.core self.seed = seed self.rate = rate @@ -87,7 +90,7 @@ class AbstractMutator(metaclass=abc.ABCMeta): # pylint: disable=too-few-public- continue for patch in patches: patched_txt, offset = apply_patch(patched_txt, patch, offset) - diff = create_diff(self.slither, original_txt, patched_txt, file) + diff = create_diff(self.compilation_unit, original_txt, patched_txt, file) if not diff: logger.info(f"Impossible to generate patch; empty {patches}") print(diff)