diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index f8d8c8158..dbc8c3921 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -127,10 +127,10 @@ class Function(metaclass=ABCMeta): # pylint: disable=too-many-public-methods self._slithir_variables: Set["SlithIRVariable"] = set() self._parameters: List["LocalVariable"] = [] self._parameters_ssa: List["LocalIRVariable"] = [] - self._parameters_src: Optional[SourceMapping] = None + self._parameters_src: SourceMapping = SourceMapping() self._returns: List["LocalVariable"] = [] self._returns_ssa: List["LocalIRVariable"] = [] - self._returns_src: Optional[SourceMapping] = None + self._returns_src: SourceMapping = SourceMapping() self._return_values: Optional[List["SlithIRVariable"]] = None self._return_values_ssa: Optional[List["SlithIRVariable"]] = None self._vars_read: List["Variable"] = [] @@ -570,6 +570,9 @@ class Function(metaclass=ABCMeta): # pylint: disable=too-many-public-methods def add_parameter_ssa(self, var: "LocalIRVariable"): self._parameters_ssa.append(var) + def parameters_src(self) -> SourceMapping: + return self._parameters_src + # endregion ################################################################################### ################################################################################### @@ -588,6 +591,9 @@ class Function(metaclass=ABCMeta): # pylint: disable=too-many-public-methods return [r.type for r in returns] return None + def returns_src(self) -> SourceMapping: + return self._returns_src + @property def type(self) -> Optional[List[Type]]: """ diff --git a/slither/formatters/attributes/const_functions.py b/slither/formatters/attributes/const_functions.py index 282e9077c..95fc6d8c8 100644 --- a/slither/formatters/attributes/const_functions.py +++ b/slither/formatters/attributes/const_functions.py @@ -22,10 +22,10 @@ def custom_format(slither, result): result, element["source_mapping"]["filename_absolute"], int( - function.parameters_src.source_mapping["start"] - + function.parameters_src.source_mapping["length"] + function.parameters_src().source_mapping["start"] + + function.parameters_src().source_mapping["length"] ), - int(function.returns_src.source_mapping["start"]), + int(function.returns_src().source_mapping["start"]), ) diff --git a/slither/formatters/functions/external_function.py b/slither/formatters/functions/external_function.py index da557e41e..dd2eaf362 100644 --- a/slither/formatters/functions/external_function.py +++ b/slither/formatters/functions/external_function.py @@ -17,8 +17,8 @@ def custom_format(slither, result): slither, result, element["source_mapping"]["filename_absolute"], - int(function.parameters_src.source_mapping["start"]), - int(function.returns_src.source_mapping["start"]), + int(function.parameters_src().source_mapping["start"]), + int(function.returns_src().source_mapping["start"]), ) diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index e437bdda3..4c21e72f1 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -26,7 +26,6 @@ from slither.utils.expression_manipulations import SplitTernaryExpression from slither.visitors.expression.export_values import ExportValues from slither.visitors.expression.has_conditional import HasConditional from slither.solc_parsing.exceptions import ParsingError -from slither.core.source_mapping.source_mapping import SourceMapping if TYPE_CHECKING: from slither.core.expressions.expression import Expression @@ -85,9 +84,6 @@ class FunctionSolc: self._analyze_type() - self.parameters_src = SourceMapping() - self.returns_src = SourceMapping() - self._node_to_nodesolc: Dict[Node, NodeSolc] = dict() self._node_to_yulobject: Dict[Node, YulBlock] = dict() @@ -1089,7 +1085,7 @@ class FunctionSolc: def _parse_params(self, params: Dict): assert params[self.get_key()] == "ParameterList" - self.parameters_src.set_offset(params["src"], self._function.slither) + self._function.parameters_src().set_offset(params["src"], self._function.slither) if self.is_compact_ast: params = params["parameters"] @@ -1105,7 +1101,7 @@ class FunctionSolc: assert returns[self.get_key()] == "ParameterList" - self.returns_src.set_offset(returns["src"], self._function.slither) + self._function.returns_src().set_offset(returns["src"], self._function.slither) if self.is_compact_ast: returns = returns["parameters"] diff --git a/slither/tools/flattening/flattening.py b/slither/tools/flattening/flattening.py index 7790f9af1..2e3d6c650 100644 --- a/slither/tools/flattening/flattening.py +++ b/slither/tools/flattening/flattening.py @@ -5,9 +5,8 @@ from enum import Enum as PythonEnum from pathlib import Path from typing import List, Set, Dict, Optional -from slither.core.declarations import SolidityFunction, Enum +from slither.core.declarations import SolidityFunction, EnumContract, StructureContract from slither.core.declarations.contract import Contract -from slither.core.declarations.structure import Structure from slither.core.solidity_types import MappingType, ArrayType from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.exceptions import SlitherException @@ -100,10 +99,10 @@ class Flattening: continue if f.visibility == "external": attributes_start = ( - f.parameters_src.source_mapping["start"] - + f.parameters_src.source_mapping["length"] + f.parameters_src().source_mapping["start"] + + f.parameters_src().source_mapping["length"] ) - attributes_end = f.returns_src.source_mapping["start"] + attributes_end = f.returns_src().source_mapping["start"] attributes = content[attributes_start:attributes_end] regex = re.search(r"((\sexternal)\s+)|(\sexternal)$|(\)external)$", attributes) if regex: @@ -197,7 +196,7 @@ class Flattening: def _export_from_type(self, t, contract, exported, list_contract): if isinstance(t, UserDefinedType): - if isinstance(t.type, (Enum, Structure)): + if isinstance(t.type, (EnumContract, StructureContract)): if t.type.contract != contract and t.type.contract not in exported: self._export_list_used_contracts(t.type.contract, exported, list_contract) else: