diff --git a/slither/tools/flattening/flattening.py b/slither/tools/flattening/flattening.py index 67b3c00a3..7603f5e93 100644 --- a/slither/tools/flattening/flattening.py +++ b/slither/tools/flattening/flattening.py @@ -11,6 +11,7 @@ from slither.core.declarations import SolidityFunction, EnumContract, StructureC from slither.core.declarations.contract import Contract from slither.core.declarations.function_top_level import FunctionTopLevel from slither.core.declarations.top_level import TopLevel +from slither.core.declarations.solidity_variables import SolidityCustomRevert from slither.core.solidity_types import MappingType, ArrayType from slither.core.solidity_types.type import Type from slither.core.solidity_types.user_defined_type import UserDefinedType @@ -24,6 +25,7 @@ from slither.tools.flattening.export.export import ( ) logger = logging.getLogger("Slither-flattening") +logger.setLevel(logging.INFO) # index: where to start # patch_type: @@ -75,6 +77,7 @@ class Flattening: self._get_source_code_top_level(compilation_unit.structures_top_level) self._get_source_code_top_level(compilation_unit.enums_top_level) + self._get_source_code_top_level(compilation_unit.custom_errors) self._get_source_code_top_level(compilation_unit.variables_top_level) self._get_source_code_top_level(compilation_unit.functions_top_level) @@ -249,12 +252,14 @@ class Flattening: t: Type, contract: Contract, exported: Set[str], - list_contract: List[Contract], - list_top_level: List[TopLevel], + list_contract: Set[Contract], + list_top_level: Set[TopLevel], ): if isinstance(t, UserDefinedType): t_type = t.type - if isinstance(t_type, (EnumContract, StructureContract)): + if isinstance(t_type, TopLevel): + list_top_level.add(t_type) + elif isinstance(t_type, (EnumContract, StructureContract)): if t_type.contract != contract and t_type.contract not in exported: self._export_list_used_contracts( t_type.contract, exported, list_contract, list_top_level @@ -275,8 +280,8 @@ class Flattening: self, contract: Contract, exported: Set[str], - list_contract: List[Contract], - list_top_level: List[TopLevel], + list_contract: Set[Contract], + list_top_level: Set[TopLevel], ): # TODO: investigate why this happen if not isinstance(contract, Contract): @@ -332,19 +337,21 @@ class Flattening: for read in ir.read: if isinstance(read, TopLevel): - if read not in list_top_level: - list_top_level.append(read) - if isinstance(ir, InternalCall): - function_called = ir.function - if isinstance(function_called, FunctionTopLevel): - list_top_level.append(function_called) - - if contract not in list_contract: - list_contract.append(contract) + list_top_level.add(read) + if isinstance(ir, InternalCall) and isinstance(ir.function, FunctionTopLevel): + list_top_level.add(ir.function) + if ( + isinstance(ir, SolidityCall) + and isinstance(ir.function, SolidityCustomRevert) + and isinstance(ir.function.custom_error, TopLevel) + ): + list_top_level.add(ir.function.custom_error) + + list_contract.add(contract) def _export_contract_with_inheritance(self, contract) -> Export: - list_contracts: List[Contract] = [] # will contain contract itself - list_top_level: List[TopLevel] = [] + list_contracts: Set[Contract] = set() # will contain contract itself + list_top_level: Set[TopLevel] = set() self._export_list_used_contracts(contract, set(), list_contracts, list_top_level) path = Path(self._export_path, f"{contract.name}_{uuid.uuid4()}.sol") @@ -401,8 +408,8 @@ class Flattening: def _export_with_import(self) -> List[Export]: exports: List[Export] = [] for contract in self._compilation_unit.contracts: - list_contracts: List[Contract] = [] # will contain contract itself - list_top_level: List[TopLevel] = [] + list_contracts: Set[Contract] = set() # will contain contract itself + list_top_level: Set[TopLevel] = set() self._export_list_used_contracts(contract, set(), list_contracts, list_top_level) if list_top_level: