Improve handling of top level usages

pull/1852/head
Simone 2 years ago
parent 637b8e2a9d
commit 718e51160b
  1. 41
      slither/tools/flattening/flattening.py

@ -11,6 +11,7 @@ from slither.core.declarations import SolidityFunction, EnumContract, StructureC
from slither.core.declarations.contract import Contract
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.top_level import TopLevel
from slither.core.declarations.solidity_variables import SolidityCustomRevert
from slither.core.solidity_types import MappingType, ArrayType
from slither.core.solidity_types.type import Type
from slither.core.solidity_types.user_defined_type import UserDefinedType
@ -24,6 +25,7 @@ from slither.tools.flattening.export.export import (
)
logger = logging.getLogger("Slither-flattening")
logger.setLevel(logging.INFO)
# index: where to start
# patch_type:
@ -75,6 +77,7 @@ class Flattening:
self._get_source_code_top_level(compilation_unit.structures_top_level)
self._get_source_code_top_level(compilation_unit.enums_top_level)
self._get_source_code_top_level(compilation_unit.custom_errors)
self._get_source_code_top_level(compilation_unit.variables_top_level)
self._get_source_code_top_level(compilation_unit.functions_top_level)
@ -249,12 +252,14 @@ class Flattening:
t: Type,
contract: Contract,
exported: Set[str],
list_contract: List[Contract],
list_top_level: List[TopLevel],
list_contract: Set[Contract],
list_top_level: Set[TopLevel],
):
if isinstance(t, UserDefinedType):
t_type = t.type
if isinstance(t_type, (EnumContract, StructureContract)):
if isinstance(t_type, TopLevel):
list_top_level.add(t_type)
elif isinstance(t_type, (EnumContract, StructureContract)):
if t_type.contract != contract and t_type.contract not in exported:
self._export_list_used_contracts(
t_type.contract, exported, list_contract, list_top_level
@ -275,8 +280,8 @@ class Flattening:
self,
contract: Contract,
exported: Set[str],
list_contract: List[Contract],
list_top_level: List[TopLevel],
list_contract: Set[Contract],
list_top_level: Set[TopLevel],
):
# TODO: investigate why this happen
if not isinstance(contract, Contract):
@ -332,19 +337,21 @@ class Flattening:
for read in ir.read:
if isinstance(read, TopLevel):
if read not in list_top_level:
list_top_level.append(read)
if isinstance(ir, InternalCall):
function_called = ir.function
if isinstance(function_called, FunctionTopLevel):
list_top_level.append(function_called)
list_top_level.add(read)
if isinstance(ir, InternalCall) and isinstance(ir.function, FunctionTopLevel):
list_top_level.add(ir.function)
if (
isinstance(ir, SolidityCall)
and isinstance(ir.function, SolidityCustomRevert)
and isinstance(ir.function.custom_error, TopLevel)
):
list_top_level.add(ir.function.custom_error)
if contract not in list_contract:
list_contract.append(contract)
list_contract.add(contract)
def _export_contract_with_inheritance(self, contract) -> Export:
list_contracts: List[Contract] = [] # will contain contract itself
list_top_level: List[TopLevel] = []
list_contracts: Set[Contract] = set() # will contain contract itself
list_top_level: Set[TopLevel] = set()
self._export_list_used_contracts(contract, set(), list_contracts, list_top_level)
path = Path(self._export_path, f"{contract.name}_{uuid.uuid4()}.sol")
@ -401,8 +408,8 @@ class Flattening:
def _export_with_import(self) -> List[Export]:
exports: List[Export] = []
for contract in self._compilation_unit.contracts:
list_contracts: List[Contract] = [] # will contain contract itself
list_top_level: List[TopLevel] = []
list_contracts: Set[Contract] = set() # will contain contract itself
list_top_level: Set[TopLevel] = set()
self._export_list_used_contracts(contract, set(), list_contracts, list_top_level)
if list_top_level:

Loading…
Cancel
Save