Improve handling of top level usages

pull/1852/head
Simone 2 years ago
parent 637b8e2a9d
commit 718e51160b
  1. 43
      slither/tools/flattening/flattening.py

@ -11,6 +11,7 @@ from slither.core.declarations import SolidityFunction, EnumContract, StructureC
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.function_top_level import FunctionTopLevel from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.top_level import TopLevel from slither.core.declarations.top_level import TopLevel
from slither.core.declarations.solidity_variables import SolidityCustomRevert
from slither.core.solidity_types import MappingType, ArrayType from slither.core.solidity_types import MappingType, ArrayType
from slither.core.solidity_types.type import Type from slither.core.solidity_types.type import Type
from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.user_defined_type import UserDefinedType
@ -24,6 +25,7 @@ from slither.tools.flattening.export.export import (
) )
logger = logging.getLogger("Slither-flattening") logger = logging.getLogger("Slither-flattening")
logger.setLevel(logging.INFO)
# index: where to start # index: where to start
# patch_type: # patch_type:
@ -75,6 +77,7 @@ class Flattening:
self._get_source_code_top_level(compilation_unit.structures_top_level) self._get_source_code_top_level(compilation_unit.structures_top_level)
self._get_source_code_top_level(compilation_unit.enums_top_level) self._get_source_code_top_level(compilation_unit.enums_top_level)
self._get_source_code_top_level(compilation_unit.custom_errors)
self._get_source_code_top_level(compilation_unit.variables_top_level) self._get_source_code_top_level(compilation_unit.variables_top_level)
self._get_source_code_top_level(compilation_unit.functions_top_level) self._get_source_code_top_level(compilation_unit.functions_top_level)
@ -249,12 +252,14 @@ class Flattening:
t: Type, t: Type,
contract: Contract, contract: Contract,
exported: Set[str], exported: Set[str],
list_contract: List[Contract], list_contract: Set[Contract],
list_top_level: List[TopLevel], list_top_level: Set[TopLevel],
): ):
if isinstance(t, UserDefinedType): if isinstance(t, UserDefinedType):
t_type = t.type t_type = t.type
if isinstance(t_type, (EnumContract, StructureContract)): if isinstance(t_type, TopLevel):
list_top_level.add(t_type)
elif isinstance(t_type, (EnumContract, StructureContract)):
if t_type.contract != contract and t_type.contract not in exported: if t_type.contract != contract and t_type.contract not in exported:
self._export_list_used_contracts( self._export_list_used_contracts(
t_type.contract, exported, list_contract, list_top_level t_type.contract, exported, list_contract, list_top_level
@ -275,8 +280,8 @@ class Flattening:
self, self,
contract: Contract, contract: Contract,
exported: Set[str], exported: Set[str],
list_contract: List[Contract], list_contract: Set[Contract],
list_top_level: List[TopLevel], list_top_level: Set[TopLevel],
): ):
# TODO: investigate why this happen # TODO: investigate why this happen
if not isinstance(contract, Contract): if not isinstance(contract, Contract):
@ -332,19 +337,21 @@ class Flattening:
for read in ir.read: for read in ir.read:
if isinstance(read, TopLevel): if isinstance(read, TopLevel):
if read not in list_top_level: list_top_level.add(read)
list_top_level.append(read) if isinstance(ir, InternalCall) and isinstance(ir.function, FunctionTopLevel):
if isinstance(ir, InternalCall): list_top_level.add(ir.function)
function_called = ir.function if (
if isinstance(function_called, FunctionTopLevel): isinstance(ir, SolidityCall)
list_top_level.append(function_called) and isinstance(ir.function, SolidityCustomRevert)
and isinstance(ir.function.custom_error, TopLevel)
if contract not in list_contract: ):
list_contract.append(contract) list_top_level.add(ir.function.custom_error)
list_contract.add(contract)
def _export_contract_with_inheritance(self, contract) -> Export: def _export_contract_with_inheritance(self, contract) -> Export:
list_contracts: List[Contract] = [] # will contain contract itself list_contracts: Set[Contract] = set() # will contain contract itself
list_top_level: List[TopLevel] = [] list_top_level: Set[TopLevel] = set()
self._export_list_used_contracts(contract, set(), list_contracts, list_top_level) self._export_list_used_contracts(contract, set(), list_contracts, list_top_level)
path = Path(self._export_path, f"{contract.name}_{uuid.uuid4()}.sol") path = Path(self._export_path, f"{contract.name}_{uuid.uuid4()}.sol")
@ -401,8 +408,8 @@ class Flattening:
def _export_with_import(self) -> List[Export]: def _export_with_import(self) -> List[Export]:
exports: List[Export] = [] exports: List[Export] = []
for contract in self._compilation_unit.contracts: for contract in self._compilation_unit.contracts:
list_contracts: List[Contract] = [] # will contain contract itself list_contracts: Set[Contract] = set() # will contain contract itself
list_top_level: List[TopLevel] = [] list_top_level: Set[TopLevel] = set()
self._export_list_used_contracts(contract, set(), list_contracts, list_top_level) self._export_list_used_contracts(contract, set(), list_contracts, list_top_level)
if list_top_level: if list_top_level:

Loading…
Cancel
Save