Improve slither-flat (#1125)

* Improve slither-flat
- Better support of multiple compilation unit
- Better support if there are multiple contracts with the same name
- Better support of top level object (partially fix #955). The support
for the local import strategy is missing
pull/1130/head
Feist Josselin 3 years ago committed by GitHub
parent 5fd5d4f6ce
commit dc0cec2d7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      slither/core/declarations/contract.py
  2. 2
      slither/solc_parsing/solidity_types/type_parsing.py
  3. 45
      slither/tools/flattening/__main__.py
  4. 152
      slither/tools/flattening/flattening.py

@ -75,7 +75,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
self._custom_errors: Dict[str, "CustomErrorContract"] = {} self._custom_errors: Dict[str, "CustomErrorContract"] = {}
# The only str is "*" # The only str is "*"
self._using_for: Dict[Union[str, Type], List[str]] = {} self._using_for: Dict[Union[str, Type], List[Type]] = {}
self._kind: Optional[str] = None self._kind: Optional[str] = None
self._is_interface: bool = False self._is_interface: bool = False
@ -245,7 +245,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
################################################################################### ###################################################################################
@property @property
def using_for(self) -> Dict[Union[str, Type], List[str]]: def using_for(self) -> Dict[Union[str, Type], List[Type]]:
return self._using_for return self._using_for
# endregion # endregion

@ -198,7 +198,7 @@ def _find_from_type_name( # pylint: disable=too-many-locals,too-many-branches,t
def parse_type( def parse_type(
t: Union[Dict, UnknownType], t: Union[Dict, UnknownType],
caller_context: Union[CallerContextExpression, "SlitherCompilationUnitSolc"], caller_context: Union[CallerContextExpression, "SlitherCompilationUnitSolc"],
): ) -> Type:
""" """
caller_context can be a SlitherCompilationUnitSolc because we recursively call the function caller_context can be a SlitherCompilationUnitSolc because we recursively call the function
and go up in the context's scope. If we are really lost we just go over the SlitherCompilationUnitSolc and go up in the context's scope. If we are really lost we just go over the SlitherCompilationUnitSolc

@ -104,28 +104,31 @@ def main():
args = parse_args() args = parse_args()
slither = Slither(args.filename, **vars(args)) slither = Slither(args.filename, **vars(args))
flat = Flattening(
slither,
external_to_public=args.convert_external,
remove_assert=args.remove_assert,
private_to_internal=args.convert_private,
export_path=args.dir,
pragma_solidity=args.pragma_solidity,
)
try: for compilation_unit in slither.compilation_units:
strategy = Strategy[args.strategy]
except KeyError: flat = Flattening(
to_log = f"{args.strategy} is not a valid strategy, use: {STRATEGIES_NAMES} (default MostDerived)" compilation_unit,
logger.error(to_log) external_to_public=args.convert_external,
return remove_assert=args.remove_assert,
flat.export( private_to_internal=args.convert_private,
strategy=strategy, export_path=args.dir,
target=args.contract, pragma_solidity=args.pragma_solidity,
json=args.json, )
zip=args.zip,
zip_type=args.zip_type, try:
) strategy = Strategy[args.strategy]
except KeyError:
to_log = f"{args.strategy} is not a valid strategy, use: {STRATEGIES_NAMES} (default MostDerived)"
logger.error(to_log)
return
flat.export(
strategy=strategy,
target=args.contract,
json=args.json,
zip=args.zip,
zip_type=args.zip_type,
)
if __name__ == "__main__": if __name__ == "__main__":

@ -1,17 +1,21 @@
import logging import logging
import re import re
import uuid
from collections import namedtuple from collections import namedtuple
from enum import Enum as PythonEnum from enum import Enum as PythonEnum
from pathlib import Path from pathlib import Path
from typing import List, Set, Dict, Optional from typing import List, Set, Dict, Optional
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations import SolidityFunction, EnumContract, StructureContract from slither.core.declarations import SolidityFunction, EnumContract, StructureContract
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.slither_core import SlitherCore from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.top_level import TopLevel
from slither.core.solidity_types import MappingType, ArrayType from slither.core.solidity_types import MappingType, ArrayType
from slither.core.solidity_types.type import Type
from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.user_defined_type import UserDefinedType
from slither.exceptions import SlitherException from slither.exceptions import SlitherException
from slither.slithir.operations import NewContract, TypeConversion, SolidityCall from slither.slithir.operations import NewContract, TypeConversion, SolidityCall, InternalCall
from slither.tools.flattening.export.export import ( from slither.tools.flattening.export.export import (
Export, Export,
export_as_json, export_as_json,
@ -44,7 +48,7 @@ class Flattening:
# pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-locals,too-few-public-methods # pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-locals,too-few-public-methods
def __init__( def __init__(
self, self,
slither: SlitherCore, compilation_unit: SlitherCompilationUnit,
external_to_public=False, external_to_public=False,
remove_assert=False, remove_assert=False,
private_to_internal=False, private_to_internal=False,
@ -52,7 +56,8 @@ class Flattening:
pragma_solidity: Optional[str] = None, pragma_solidity: Optional[str] = None,
): ):
self._source_codes: Dict[Contract, str] = {} self._source_codes: Dict[Contract, str] = {}
self._slither: SlitherCore = slither self._source_codes_top_level: Dict[TopLevel, str] = {}
self._compilation_unit: SlitherCompilationUnit = compilation_unit
self._external_to_public = external_to_public self._external_to_public = external_to_public
self._remove_assert = remove_assert self._remove_assert = remove_assert
self._use_abi_encoder_v2 = False self._use_abi_encoder_v2 = False
@ -63,20 +68,32 @@ class Flattening:
self._check_abi_encoder_v2() self._check_abi_encoder_v2()
for contract in slither.contracts: for contract in compilation_unit.contracts:
self._get_source_code(contract) self._get_source_code(contract)
self._get_source_code_top_level(compilation_unit.structures_top_level)
self._get_source_code_top_level(compilation_unit.enums_top_level)
self._get_source_code_top_level(compilation_unit.variables_top_level)
self._get_source_code_top_level(compilation_unit.functions_top_level)
def _get_source_code_top_level(self, elems: List[TopLevel]) -> None:
for elem in elems:
src_mapping = elem.source_mapping
content = self._compilation_unit.core.source_code[src_mapping["filename_absolute"]]
start = src_mapping["start"]
end = src_mapping["start"] + src_mapping["length"]
self._source_codes_top_level[elem] = content[start:end]
def _check_abi_encoder_v2(self): def _check_abi_encoder_v2(self):
""" """
Check if ABIEncoderV2 is required Check if ABIEncoderV2 is required
Set _use_abi_encorder_v2 Set _use_abi_encorder_v2
:return: :return:
""" """
for compilation_unit in self._slither.compilation_units: for p in self._compilation_unit.pragma_directives:
for p in compilation_unit.pragma_directives: if "ABIEncoderV2" in str(p.directive):
if "ABIEncoderV2" in str(p.directive): self._use_abi_encoder_v2 = True
self._use_abi_encoder_v2 = True return
return
def _get_source_code( def _get_source_code(
self, contract: Contract self, contract: Contract
@ -88,7 +105,7 @@ class Flattening:
:return: :return:
""" """
src_mapping = contract.source_mapping src_mapping = contract.source_mapping
content = self._slither.source_code[src_mapping["filename_absolute"]] content = self._compilation_unit.core.source_code[src_mapping["filename_absolute"]]
start = src_mapping["start"] start = src_mapping["start"]
end = src_mapping["start"] + src_mapping["length"] end = src_mapping["start"] + src_mapping["length"]
@ -132,11 +149,9 @@ class Flattening:
if self._private_to_internal: if self._private_to_internal:
for variable in contract.state_variables_declared: for variable in contract.state_variables_declared:
if variable.visibility == "private": if variable.visibility == "private":
print(variable.source_mapping)
attributes_start = variable.source_mapping["start"] attributes_start = variable.source_mapping["start"]
attributes_end = attributes_start + variable.source_mapping["length"] attributes_end = attributes_start + variable.source_mapping["length"]
attributes = content[attributes_start:attributes_end] attributes = content[attributes_start:attributes_end]
print(attributes)
regex = re.search(r" private ", attributes) regex = re.search(r" private ", attributes)
if regex: if regex:
to_patch.append( to_patch.append(
@ -191,35 +206,54 @@ class Flattening:
ret += f"pragma solidity {self._pragma_solidity};\n" ret += f"pragma solidity {self._pragma_solidity};\n"
else: else:
# TODO support multiple compiler version # TODO support multiple compiler version
ret += f"pragma solidity {list(self._slither.crytic_compile.compilation_units.values())[0].compiler_version.version};\n" ret += f"pragma solidity {list(self._compilation_unit.crytic_compile.compilation_units.values())[0].compiler_version.version};\n"
if self._use_abi_encoder_v2: if self._use_abi_encoder_v2:
ret += "pragma experimental ABIEncoderV2;\n" ret += "pragma experimental ABIEncoderV2;\n"
return ret return ret
def _export_from_type(self, t, contract, exported, list_contract): def _export_from_type(
self,
t: Type,
contract: Contract,
exported: Set[str],
list_contract: List[Contract],
list_top_level: List[TopLevel],
):
if isinstance(t, UserDefinedType): if isinstance(t, UserDefinedType):
if isinstance(t.type, (EnumContract, StructureContract)): t_type = t.type
if t.type.contract != contract and t.type.contract not in exported: if isinstance(t_type, (EnumContract, StructureContract)):
self._export_list_used_contracts(t.type.contract, exported, list_contract) if t_type.contract != contract and t_type.contract not in exported:
self._export_list_used_contracts(
t_type.contract, exported, list_contract, list_top_level
)
else: else:
assert isinstance(t.type, Contract) assert isinstance(t.type, Contract)
if t.type != contract and t.type not in exported: if t.type != contract and t.type not in exported:
self._export_list_used_contracts(t.type, exported, list_contract) self._export_list_used_contracts(
t.type, exported, list_contract, list_top_level
)
elif isinstance(t, MappingType): elif isinstance(t, MappingType):
self._export_from_type(t.type_from, contract, exported, list_contract) self._export_from_type(t.type_from, contract, exported, list_contract, list_top_level)
self._export_from_type(t.type_to, contract, exported, list_contract) self._export_from_type(t.type_to, contract, exported, list_contract, list_top_level)
elif isinstance(t, ArrayType): elif isinstance(t, ArrayType):
self._export_from_type(t.type, contract, exported, list_contract) self._export_from_type(t.type, contract, exported, list_contract, list_top_level)
def _export_list_used_contracts( # pylint: disable=too-many-branches def _export_list_used_contracts( # pylint: disable=too-many-branches
self, contract: Contract, exported: Set[str], list_contract: List[Contract] self,
contract: Contract,
exported: Set[str],
list_contract: List[Contract],
list_top_level: List[TopLevel],
): ):
# TODO: investigate why this happen
if not isinstance(contract, Contract):
return
if contract.name in exported: if contract.name in exported:
return return
exported.add(contract.name) exported.add(contract.name)
for inherited in contract.inheritance: for inherited in contract.inheritance:
self._export_list_used_contracts(inherited, exported, list_contract) self._export_list_used_contracts(inherited, exported, list_contract, list_top_level)
# Find all the external contracts called # Find all the external contracts called
externals = contract.all_library_calls + contract.all_high_level_calls externals = contract.all_library_calls + contract.all_high_level_calls
@ -228,7 +262,16 @@ class Flattening:
externals = list({e[0] for e in externals if e[0] != contract}) externals = list({e[0] for e in externals if e[0] != contract})
for inherited in externals: for inherited in externals:
self._export_list_used_contracts(inherited, exported, list_contract) self._export_list_used_contracts(inherited, exported, list_contract, list_top_level)
for list_libs in contract.using_for.values():
for lib_candidate_type in list_libs:
if isinstance(lib_candidate_type, UserDefinedType):
lib_candidate = lib_candidate_type.type
if isinstance(lib_candidate, Contract):
self._export_list_used_contracts(
lib_candidate, exported, list_contract, list_top_level
)
# Find all the external contracts use as a base type # Find all the external contracts use as a base type
local_vars = [] local_vars = []
@ -236,11 +279,11 @@ class Flattening:
local_vars += f.variables local_vars += f.variables
for v in contract.variables + local_vars: for v in contract.variables + local_vars:
self._export_from_type(v.type, contract, exported, list_contract) self._export_from_type(v.type, contract, exported, list_contract, list_top_level)
for s in contract.structures: for s in contract.structures:
for elem in s.elems.values(): for elem in s.elems.values():
self._export_from_type(elem.type, contract, exported, list_contract) self._export_from_type(elem.type, contract, exported, list_contract, list_top_level)
# Find all convert and "new" operation that can lead to use an external contract # Find all convert and "new" operation that can lead to use an external contract
for f in contract.functions_declared: for f in contract.functions_declared:
@ -248,21 +291,38 @@ class Flattening:
if isinstance(ir, NewContract): if isinstance(ir, NewContract):
if ir.contract_created != contract and not ir.contract_created in exported: if ir.contract_created != contract and not ir.contract_created in exported:
self._export_list_used_contracts( self._export_list_used_contracts(
ir.contract_created, exported, list_contract ir.contract_created, exported, list_contract, list_top_level
) )
if isinstance(ir, TypeConversion): if isinstance(ir, TypeConversion):
self._export_from_type(ir.type, contract, exported, list_contract) self._export_from_type(
ir.type, contract, exported, list_contract, list_top_level
)
for read in ir.read:
if isinstance(read, TopLevel):
if read not in list_top_level:
list_top_level.append(read)
if isinstance(ir, InternalCall):
function_called = ir.function
if isinstance(function_called, FunctionTopLevel):
list_top_level.append(function_called)
if contract not in list_contract: if contract not in list_contract:
list_contract.append(contract) list_contract.append(contract)
def _export_contract_with_inheritance(self, contract) -> Export: def _export_contract_with_inheritance(self, contract) -> Export:
list_contracts: List[Contract] = [] # will contain contract itself list_contracts: List[Contract] = [] # will contain contract itself
self._export_list_used_contracts(contract, set(), list_contracts) list_top_level: List[TopLevel] = []
path = Path(self._export_path, f"{contract.name}.sol") self._export_list_used_contracts(contract, set(), list_contracts, list_top_level)
path = Path(self._export_path, f"{contract.name}_{uuid.uuid4()}.sol")
content = "" content = ""
content += self._pragmas() content += self._pragmas()
for listed_top_level in list_top_level:
content += self._source_codes_top_level[listed_top_level]
content += "\n"
for listed_contract in list_contracts: for listed_contract in list_contracts:
content += self._source_codes[listed_contract] content += self._source_codes[listed_contract]
content += "\n" content += "\n"
@ -271,7 +331,7 @@ class Flattening:
def _export_most_derived(self) -> List[Export]: def _export_most_derived(self) -> List[Export]:
ret: List[Export] = [] ret: List[Export] = []
for contract in self._slither.contracts_derived: for contract in self._compilation_unit.contracts_derived:
ret.append(self._export_contract_with_inheritance(contract)) ret.append(self._export_contract_with_inheritance(contract))
return ret return ret
@ -281,8 +341,13 @@ class Flattening:
content = "" content = ""
content += self._pragmas() content += self._pragmas()
for top_level_content in self._source_codes_top_level.values():
content += "\n"
content += top_level_content
content += "\n"
contract_seen = set() contract_seen = set()
contract_to_explore = list(self._slither.contracts) contract_to_explore = list(self._compilation_unit.contracts)
# We only need the inheritance order here, as solc can compile # We only need the inheritance order here, as solc can compile
# a contract that use another contract type (ex: state variable) that he has not seen yet # a contract that use another contract type (ex: state variable) that he has not seen yet
@ -303,9 +368,17 @@ class Flattening:
def _export_with_import(self) -> List[Export]: def _export_with_import(self) -> List[Export]:
exports: List[Export] = [] exports: List[Export] = []
for contract in self._slither.contracts: for contract in self._compilation_unit.contracts:
list_contracts: List[Contract] = [] # will contain contract itself list_contracts: List[Contract] = [] # will contain contract itself
self._export_list_used_contracts(contract, set(), list_contracts) list_top_level: List[TopLevel] = []
self._export_list_used_contracts(contract, set(), list_contracts, list_top_level)
if list_top_level:
logger.info(
"Top level objects are not yet supported with the local import flattening"
)
for elem in list_top_level:
logger.info(f"Missing {elem} for {contract.name}")
path = Path(self._export_path, f"{contract.name}.sol") path = Path(self._export_path, f"{contract.name}.sol")
@ -341,12 +414,13 @@ class Flattening:
elif strategy == Strategy.LocalImport: elif strategy == Strategy.LocalImport:
exports = self._export_with_import() exports = self._export_with_import()
else: else:
contracts = self._slither.get_contract_from_name(target) contracts = self._compilation_unit.get_contract_from_name(target)
if len(contracts) != 1: if len(contracts) == 0:
logger.error(f"{target} not found") logger.error(f"{target} not found")
return return
contract = contracts[0] exports = []
exports = [self._export_contract_with_inheritance(contract)] for contract in contracts:
exports.append(self._export_contract_with_inheritance(contract))
if json: if json:
export_as_json(exports, json) export_as_json(exports, json)

Loading…
Cancel
Save