diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index ec0a20c8a..2db1cd964 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -75,7 +75,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods self._custom_errors: Dict[str, "CustomErrorContract"] = {} # The only str is "*" - self._using_for: Dict[Union[str, Type], List[str]] = {} + self._using_for: Dict[Union[str, Type], List[Type]] = {} self._kind: Optional[str] = None self._is_interface: bool = False @@ -245,7 +245,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods ################################################################################### @property - def using_for(self) -> Dict[Union[str, Type], List[str]]: + def using_for(self) -> Dict[Union[str, Type], List[Type]]: return self._using_for # endregion diff --git a/slither/solc_parsing/solidity_types/type_parsing.py b/slither/solc_parsing/solidity_types/type_parsing.py index 118b97ce3..4c4282f5d 100644 --- a/slither/solc_parsing/solidity_types/type_parsing.py +++ b/slither/solc_parsing/solidity_types/type_parsing.py @@ -198,7 +198,7 @@ def _find_from_type_name( # pylint: disable=too-many-locals,too-many-branches,t def parse_type( t: Union[Dict, UnknownType], caller_context: Union[CallerContextExpression, "SlitherCompilationUnitSolc"], -): +) -> Type: """ caller_context can be a SlitherCompilationUnitSolc because we recursively call the function and go up in the context's scope. If we are really lost we just go over the SlitherCompilationUnitSolc diff --git a/slither/tools/flattening/__main__.py b/slither/tools/flattening/__main__.py index 308dd8106..e0888f222 100644 --- a/slither/tools/flattening/__main__.py +++ b/slither/tools/flattening/__main__.py @@ -104,28 +104,31 @@ def main(): args = parse_args() slither = Slither(args.filename, **vars(args)) - flat = Flattening( - slither, - external_to_public=args.convert_external, - remove_assert=args.remove_assert, - private_to_internal=args.convert_private, - export_path=args.dir, - pragma_solidity=args.pragma_solidity, - ) - try: - strategy = Strategy[args.strategy] - except KeyError: - to_log = f"{args.strategy} is not a valid strategy, use: {STRATEGIES_NAMES} (default MostDerived)" - logger.error(to_log) - return - flat.export( - strategy=strategy, - target=args.contract, - json=args.json, - zip=args.zip, - zip_type=args.zip_type, - ) + for compilation_unit in slither.compilation_units: + + flat = Flattening( + compilation_unit, + external_to_public=args.convert_external, + remove_assert=args.remove_assert, + private_to_internal=args.convert_private, + export_path=args.dir, + pragma_solidity=args.pragma_solidity, + ) + + try: + strategy = Strategy[args.strategy] + except KeyError: + to_log = f"{args.strategy} is not a valid strategy, use: {STRATEGIES_NAMES} (default MostDerived)" + logger.error(to_log) + return + flat.export( + strategy=strategy, + target=args.contract, + json=args.json, + zip=args.zip, + zip_type=args.zip_type, + ) if __name__ == "__main__": diff --git a/slither/tools/flattening/flattening.py b/slither/tools/flattening/flattening.py index a1dbc0a9d..f067bde6d 100644 --- a/slither/tools/flattening/flattening.py +++ b/slither/tools/flattening/flattening.py @@ -1,17 +1,21 @@ import logging import re +import uuid from collections import namedtuple from enum import Enum as PythonEnum from pathlib import Path from typing import List, Set, Dict, Optional +from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.declarations import SolidityFunction, EnumContract, StructureContract from slither.core.declarations.contract import Contract -from slither.core.slither_core import SlitherCore +from slither.core.declarations.function_top_level import FunctionTopLevel +from slither.core.declarations.top_level import TopLevel from slither.core.solidity_types import MappingType, ArrayType +from slither.core.solidity_types.type import Type from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.exceptions import SlitherException -from slither.slithir.operations import NewContract, TypeConversion, SolidityCall +from slither.slithir.operations import NewContract, TypeConversion, SolidityCall, InternalCall from slither.tools.flattening.export.export import ( Export, export_as_json, @@ -44,7 +48,7 @@ class Flattening: # pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-locals,too-few-public-methods def __init__( self, - slither: SlitherCore, + compilation_unit: SlitherCompilationUnit, external_to_public=False, remove_assert=False, private_to_internal=False, @@ -52,7 +56,8 @@ class Flattening: pragma_solidity: Optional[str] = None, ): self._source_codes: Dict[Contract, str] = {} - self._slither: SlitherCore = slither + self._source_codes_top_level: Dict[TopLevel, str] = {} + self._compilation_unit: SlitherCompilationUnit = compilation_unit self._external_to_public = external_to_public self._remove_assert = remove_assert self._use_abi_encoder_v2 = False @@ -63,20 +68,32 @@ class Flattening: self._check_abi_encoder_v2() - for contract in slither.contracts: + for contract in compilation_unit.contracts: self._get_source_code(contract) + self._get_source_code_top_level(compilation_unit.structures_top_level) + self._get_source_code_top_level(compilation_unit.enums_top_level) + self._get_source_code_top_level(compilation_unit.variables_top_level) + self._get_source_code_top_level(compilation_unit.functions_top_level) + + def _get_source_code_top_level(self, elems: List[TopLevel]) -> None: + for elem in elems: + src_mapping = elem.source_mapping + content = self._compilation_unit.core.source_code[src_mapping["filename_absolute"]] + start = src_mapping["start"] + end = src_mapping["start"] + src_mapping["length"] + self._source_codes_top_level[elem] = content[start:end] + def _check_abi_encoder_v2(self): """ Check if ABIEncoderV2 is required Set _use_abi_encorder_v2 :return: """ - for compilation_unit in self._slither.compilation_units: - for p in compilation_unit.pragma_directives: - if "ABIEncoderV2" in str(p.directive): - self._use_abi_encoder_v2 = True - return + for p in self._compilation_unit.pragma_directives: + if "ABIEncoderV2" in str(p.directive): + self._use_abi_encoder_v2 = True + return def _get_source_code( self, contract: Contract @@ -88,7 +105,7 @@ class Flattening: :return: """ src_mapping = contract.source_mapping - content = self._slither.source_code[src_mapping["filename_absolute"]] + content = self._compilation_unit.core.source_code[src_mapping["filename_absolute"]] start = src_mapping["start"] end = src_mapping["start"] + src_mapping["length"] @@ -132,11 +149,9 @@ class Flattening: if self._private_to_internal: for variable in contract.state_variables_declared: if variable.visibility == "private": - print(variable.source_mapping) attributes_start = variable.source_mapping["start"] attributes_end = attributes_start + variable.source_mapping["length"] attributes = content[attributes_start:attributes_end] - print(attributes) regex = re.search(r" private ", attributes) if regex: to_patch.append( @@ -191,35 +206,54 @@ class Flattening: ret += f"pragma solidity {self._pragma_solidity};\n" else: # TODO support multiple compiler version - ret += f"pragma solidity {list(self._slither.crytic_compile.compilation_units.values())[0].compiler_version.version};\n" + ret += f"pragma solidity {list(self._compilation_unit.crytic_compile.compilation_units.values())[0].compiler_version.version};\n" if self._use_abi_encoder_v2: ret += "pragma experimental ABIEncoderV2;\n" return ret - def _export_from_type(self, t, contract, exported, list_contract): + def _export_from_type( + self, + t: Type, + contract: Contract, + exported: Set[str], + list_contract: List[Contract], + list_top_level: List[TopLevel], + ): if isinstance(t, UserDefinedType): - if isinstance(t.type, (EnumContract, StructureContract)): - if t.type.contract != contract and t.type.contract not in exported: - self._export_list_used_contracts(t.type.contract, exported, list_contract) + t_type = t.type + if isinstance(t_type, (EnumContract, StructureContract)): + if t_type.contract != contract and t_type.contract not in exported: + self._export_list_used_contracts( + t_type.contract, exported, list_contract, list_top_level + ) else: assert isinstance(t.type, Contract) if t.type != contract and t.type not in exported: - self._export_list_used_contracts(t.type, exported, list_contract) + self._export_list_used_contracts( + t.type, exported, list_contract, list_top_level + ) elif isinstance(t, MappingType): - self._export_from_type(t.type_from, contract, exported, list_contract) - self._export_from_type(t.type_to, contract, exported, list_contract) + self._export_from_type(t.type_from, contract, exported, list_contract, list_top_level) + self._export_from_type(t.type_to, contract, exported, list_contract, list_top_level) elif isinstance(t, ArrayType): - self._export_from_type(t.type, contract, exported, list_contract) + self._export_from_type(t.type, contract, exported, list_contract, list_top_level) def _export_list_used_contracts( # pylint: disable=too-many-branches - self, contract: Contract, exported: Set[str], list_contract: List[Contract] + self, + contract: Contract, + exported: Set[str], + list_contract: List[Contract], + list_top_level: List[TopLevel], ): + # TODO: investigate why this happen + if not isinstance(contract, Contract): + return if contract.name in exported: return exported.add(contract.name) for inherited in contract.inheritance: - self._export_list_used_contracts(inherited, exported, list_contract) + self._export_list_used_contracts(inherited, exported, list_contract, list_top_level) # Find all the external contracts called externals = contract.all_library_calls + contract.all_high_level_calls @@ -228,7 +262,16 @@ class Flattening: externals = list({e[0] for e in externals if e[0] != contract}) for inherited in externals: - self._export_list_used_contracts(inherited, exported, list_contract) + self._export_list_used_contracts(inherited, exported, list_contract, list_top_level) + + for list_libs in contract.using_for.values(): + for lib_candidate_type in list_libs: + if isinstance(lib_candidate_type, UserDefinedType): + lib_candidate = lib_candidate_type.type + if isinstance(lib_candidate, Contract): + self._export_list_used_contracts( + lib_candidate, exported, list_contract, list_top_level + ) # Find all the external contracts use as a base type local_vars = [] @@ -236,11 +279,11 @@ class Flattening: local_vars += f.variables for v in contract.variables + local_vars: - self._export_from_type(v.type, contract, exported, list_contract) + self._export_from_type(v.type, contract, exported, list_contract, list_top_level) for s in contract.structures: for elem in s.elems.values(): - self._export_from_type(elem.type, contract, exported, list_contract) + self._export_from_type(elem.type, contract, exported, list_contract, list_top_level) # Find all convert and "new" operation that can lead to use an external contract for f in contract.functions_declared: @@ -248,21 +291,38 @@ class Flattening: if isinstance(ir, NewContract): if ir.contract_created != contract and not ir.contract_created in exported: self._export_list_used_contracts( - ir.contract_created, exported, list_contract + ir.contract_created, exported, list_contract, list_top_level ) if isinstance(ir, TypeConversion): - self._export_from_type(ir.type, contract, exported, list_contract) + self._export_from_type( + ir.type, contract, exported, list_contract, list_top_level + ) + + for read in ir.read: + if isinstance(read, TopLevel): + if read not in list_top_level: + list_top_level.append(read) + if isinstance(ir, InternalCall): + function_called = ir.function + if isinstance(function_called, FunctionTopLevel): + list_top_level.append(function_called) + if contract not in list_contract: list_contract.append(contract) def _export_contract_with_inheritance(self, contract) -> Export: list_contracts: List[Contract] = [] # will contain contract itself - self._export_list_used_contracts(contract, set(), list_contracts) - path = Path(self._export_path, f"{contract.name}.sol") + list_top_level: List[TopLevel] = [] + self._export_list_used_contracts(contract, set(), list_contracts, list_top_level) + path = Path(self._export_path, f"{contract.name}_{uuid.uuid4()}.sol") content = "" content += self._pragmas() + for listed_top_level in list_top_level: + content += self._source_codes_top_level[listed_top_level] + content += "\n" + for listed_contract in list_contracts: content += self._source_codes[listed_contract] content += "\n" @@ -271,7 +331,7 @@ class Flattening: def _export_most_derived(self) -> List[Export]: ret: List[Export] = [] - for contract in self._slither.contracts_derived: + for contract in self._compilation_unit.contracts_derived: ret.append(self._export_contract_with_inheritance(contract)) return ret @@ -281,8 +341,13 @@ class Flattening: content = "" content += self._pragmas() + for top_level_content in self._source_codes_top_level.values(): + content += "\n" + content += top_level_content + content += "\n" + contract_seen = set() - contract_to_explore = list(self._slither.contracts) + contract_to_explore = list(self._compilation_unit.contracts) # We only need the inheritance order here, as solc can compile # a contract that use another contract type (ex: state variable) that he has not seen yet @@ -303,9 +368,17 @@ class Flattening: def _export_with_import(self) -> List[Export]: exports: List[Export] = [] - for contract in self._slither.contracts: + for contract in self._compilation_unit.contracts: list_contracts: List[Contract] = [] # will contain contract itself - self._export_list_used_contracts(contract, set(), list_contracts) + list_top_level: List[TopLevel] = [] + self._export_list_used_contracts(contract, set(), list_contracts, list_top_level) + + if list_top_level: + logger.info( + "Top level objects are not yet supported with the local import flattening" + ) + for elem in list_top_level: + logger.info(f"Missing {elem} for {contract.name}") path = Path(self._export_path, f"{contract.name}.sol") @@ -341,12 +414,13 @@ class Flattening: elif strategy == Strategy.LocalImport: exports = self._export_with_import() else: - contracts = self._slither.get_contract_from_name(target) - if len(contracts) != 1: + contracts = self._compilation_unit.get_contract_from_name(target) + if len(contracts) == 0: logger.error(f"{target} not found") return - contract = contracts[0] - exports = [self._export_contract_with_inheritance(contract)] + exports = [] + for contract in contracts: + exports.append(self._export_contract_with_inheritance(contract)) if json: export_as_json(exports, json)