Improve slither-flat (#1125)

* Improve slither-flat
- Better support of multiple compilation unit
- Better support if there are multiple contracts with the same name
- Better support of top level object (partially fix #955). The support
for the local import strategy is missing
pull/1130/head
Feist Josselin 3 years ago committed by GitHub
parent 5fd5d4f6ce
commit dc0cec2d7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      slither/core/declarations/contract.py
  2. 2
      slither/solc_parsing/solidity_types/type_parsing.py
  3. 45
      slither/tools/flattening/__main__.py
  4. 152
      slither/tools/flattening/flattening.py

@ -75,7 +75,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
self._custom_errors: Dict[str, "CustomErrorContract"] = {}
# The only str is "*"
self._using_for: Dict[Union[str, Type], List[str]] = {}
self._using_for: Dict[Union[str, Type], List[Type]] = {}
self._kind: Optional[str] = None
self._is_interface: bool = False
@ -245,7 +245,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
###################################################################################
@property
def using_for(self) -> Dict[Union[str, Type], List[str]]:
def using_for(self) -> Dict[Union[str, Type], List[Type]]:
return self._using_for
# endregion

@ -198,7 +198,7 @@ def _find_from_type_name( # pylint: disable=too-many-locals,too-many-branches,t
def parse_type(
t: Union[Dict, UnknownType],
caller_context: Union[CallerContextExpression, "SlitherCompilationUnitSolc"],
):
) -> Type:
"""
caller_context can be a SlitherCompilationUnitSolc because we recursively call the function
and go up in the context's scope. If we are really lost we just go over the SlitherCompilationUnitSolc

@ -104,28 +104,31 @@ def main():
args = parse_args()
slither = Slither(args.filename, **vars(args))
flat = Flattening(
slither,
external_to_public=args.convert_external,
remove_assert=args.remove_assert,
private_to_internal=args.convert_private,
export_path=args.dir,
pragma_solidity=args.pragma_solidity,
)
try:
strategy = Strategy[args.strategy]
except KeyError:
to_log = f"{args.strategy} is not a valid strategy, use: {STRATEGIES_NAMES} (default MostDerived)"
logger.error(to_log)
return
flat.export(
strategy=strategy,
target=args.contract,
json=args.json,
zip=args.zip,
zip_type=args.zip_type,
)
for compilation_unit in slither.compilation_units:
flat = Flattening(
compilation_unit,
external_to_public=args.convert_external,
remove_assert=args.remove_assert,
private_to_internal=args.convert_private,
export_path=args.dir,
pragma_solidity=args.pragma_solidity,
)
try:
strategy = Strategy[args.strategy]
except KeyError:
to_log = f"{args.strategy} is not a valid strategy, use: {STRATEGIES_NAMES} (default MostDerived)"
logger.error(to_log)
return
flat.export(
strategy=strategy,
target=args.contract,
json=args.json,
zip=args.zip,
zip_type=args.zip_type,
)
if __name__ == "__main__":

@ -1,17 +1,21 @@
import logging
import re
import uuid
from collections import namedtuple
from enum import Enum as PythonEnum
from pathlib import Path
from typing import List, Set, Dict, Optional
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations import SolidityFunction, EnumContract, StructureContract
from slither.core.declarations.contract import Contract
from slither.core.slither_core import SlitherCore
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.top_level import TopLevel
from slither.core.solidity_types import MappingType, ArrayType
from slither.core.solidity_types.type import Type
from slither.core.solidity_types.user_defined_type import UserDefinedType
from slither.exceptions import SlitherException
from slither.slithir.operations import NewContract, TypeConversion, SolidityCall
from slither.slithir.operations import NewContract, TypeConversion, SolidityCall, InternalCall
from slither.tools.flattening.export.export import (
Export,
export_as_json,
@ -44,7 +48,7 @@ class Flattening:
# pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-locals,too-few-public-methods
def __init__(
self,
slither: SlitherCore,
compilation_unit: SlitherCompilationUnit,
external_to_public=False,
remove_assert=False,
private_to_internal=False,
@ -52,7 +56,8 @@ class Flattening:
pragma_solidity: Optional[str] = None,
):
self._source_codes: Dict[Contract, str] = {}
self._slither: SlitherCore = slither
self._source_codes_top_level: Dict[TopLevel, str] = {}
self._compilation_unit: SlitherCompilationUnit = compilation_unit
self._external_to_public = external_to_public
self._remove_assert = remove_assert
self._use_abi_encoder_v2 = False
@ -63,20 +68,32 @@ class Flattening:
self._check_abi_encoder_v2()
for contract in slither.contracts:
for contract in compilation_unit.contracts:
self._get_source_code(contract)
self._get_source_code_top_level(compilation_unit.structures_top_level)
self._get_source_code_top_level(compilation_unit.enums_top_level)
self._get_source_code_top_level(compilation_unit.variables_top_level)
self._get_source_code_top_level(compilation_unit.functions_top_level)
def _get_source_code_top_level(self, elems: List[TopLevel]) -> None:
for elem in elems:
src_mapping = elem.source_mapping
content = self._compilation_unit.core.source_code[src_mapping["filename_absolute"]]
start = src_mapping["start"]
end = src_mapping["start"] + src_mapping["length"]
self._source_codes_top_level[elem] = content[start:end]
def _check_abi_encoder_v2(self):
"""
Check if ABIEncoderV2 is required
Set _use_abi_encorder_v2
:return:
"""
for compilation_unit in self._slither.compilation_units:
for p in compilation_unit.pragma_directives:
if "ABIEncoderV2" in str(p.directive):
self._use_abi_encoder_v2 = True
return
for p in self._compilation_unit.pragma_directives:
if "ABIEncoderV2" in str(p.directive):
self._use_abi_encoder_v2 = True
return
def _get_source_code(
self, contract: Contract
@ -88,7 +105,7 @@ class Flattening:
:return:
"""
src_mapping = contract.source_mapping
content = self._slither.source_code[src_mapping["filename_absolute"]]
content = self._compilation_unit.core.source_code[src_mapping["filename_absolute"]]
start = src_mapping["start"]
end = src_mapping["start"] + src_mapping["length"]
@ -132,11 +149,9 @@ class Flattening:
if self._private_to_internal:
for variable in contract.state_variables_declared:
if variable.visibility == "private":
print(variable.source_mapping)
attributes_start = variable.source_mapping["start"]
attributes_end = attributes_start + variable.source_mapping["length"]
attributes = content[attributes_start:attributes_end]
print(attributes)
regex = re.search(r" private ", attributes)
if regex:
to_patch.append(
@ -191,35 +206,54 @@ class Flattening:
ret += f"pragma solidity {self._pragma_solidity};\n"
else:
# TODO support multiple compiler version
ret += f"pragma solidity {list(self._slither.crytic_compile.compilation_units.values())[0].compiler_version.version};\n"
ret += f"pragma solidity {list(self._compilation_unit.crytic_compile.compilation_units.values())[0].compiler_version.version};\n"
if self._use_abi_encoder_v2:
ret += "pragma experimental ABIEncoderV2;\n"
return ret
def _export_from_type(self, t, contract, exported, list_contract):
def _export_from_type(
self,
t: Type,
contract: Contract,
exported: Set[str],
list_contract: List[Contract],
list_top_level: List[TopLevel],
):
if isinstance(t, UserDefinedType):
if isinstance(t.type, (EnumContract, StructureContract)):
if t.type.contract != contract and t.type.contract not in exported:
self._export_list_used_contracts(t.type.contract, exported, list_contract)
t_type = t.type
if isinstance(t_type, (EnumContract, StructureContract)):
if t_type.contract != contract and t_type.contract not in exported:
self._export_list_used_contracts(
t_type.contract, exported, list_contract, list_top_level
)
else:
assert isinstance(t.type, Contract)
if t.type != contract and t.type not in exported:
self._export_list_used_contracts(t.type, exported, list_contract)
self._export_list_used_contracts(
t.type, exported, list_contract, list_top_level
)
elif isinstance(t, MappingType):
self._export_from_type(t.type_from, contract, exported, list_contract)
self._export_from_type(t.type_to, contract, exported, list_contract)
self._export_from_type(t.type_from, contract, exported, list_contract, list_top_level)
self._export_from_type(t.type_to, contract, exported, list_contract, list_top_level)
elif isinstance(t, ArrayType):
self._export_from_type(t.type, contract, exported, list_contract)
self._export_from_type(t.type, contract, exported, list_contract, list_top_level)
def _export_list_used_contracts( # pylint: disable=too-many-branches
self, contract: Contract, exported: Set[str], list_contract: List[Contract]
self,
contract: Contract,
exported: Set[str],
list_contract: List[Contract],
list_top_level: List[TopLevel],
):
# TODO: investigate why this happen
if not isinstance(contract, Contract):
return
if contract.name in exported:
return
exported.add(contract.name)
for inherited in contract.inheritance:
self._export_list_used_contracts(inherited, exported, list_contract)
self._export_list_used_contracts(inherited, exported, list_contract, list_top_level)
# Find all the external contracts called
externals = contract.all_library_calls + contract.all_high_level_calls
@ -228,7 +262,16 @@ class Flattening:
externals = list({e[0] for e in externals if e[0] != contract})
for inherited in externals:
self._export_list_used_contracts(inherited, exported, list_contract)
self._export_list_used_contracts(inherited, exported, list_contract, list_top_level)
for list_libs in contract.using_for.values():
for lib_candidate_type in list_libs:
if isinstance(lib_candidate_type, UserDefinedType):
lib_candidate = lib_candidate_type.type
if isinstance(lib_candidate, Contract):
self._export_list_used_contracts(
lib_candidate, exported, list_contract, list_top_level
)
# Find all the external contracts use as a base type
local_vars = []
@ -236,11 +279,11 @@ class Flattening:
local_vars += f.variables
for v in contract.variables + local_vars:
self._export_from_type(v.type, contract, exported, list_contract)
self._export_from_type(v.type, contract, exported, list_contract, list_top_level)
for s in contract.structures:
for elem in s.elems.values():
self._export_from_type(elem.type, contract, exported, list_contract)
self._export_from_type(elem.type, contract, exported, list_contract, list_top_level)
# Find all convert and "new" operation that can lead to use an external contract
for f in contract.functions_declared:
@ -248,21 +291,38 @@ class Flattening:
if isinstance(ir, NewContract):
if ir.contract_created != contract and not ir.contract_created in exported:
self._export_list_used_contracts(
ir.contract_created, exported, list_contract
ir.contract_created, exported, list_contract, list_top_level
)
if isinstance(ir, TypeConversion):
self._export_from_type(ir.type, contract, exported, list_contract)
self._export_from_type(
ir.type, contract, exported, list_contract, list_top_level
)
for read in ir.read:
if isinstance(read, TopLevel):
if read not in list_top_level:
list_top_level.append(read)
if isinstance(ir, InternalCall):
function_called = ir.function
if isinstance(function_called, FunctionTopLevel):
list_top_level.append(function_called)
if contract not in list_contract:
list_contract.append(contract)
def _export_contract_with_inheritance(self, contract) -> Export:
list_contracts: List[Contract] = [] # will contain contract itself
self._export_list_used_contracts(contract, set(), list_contracts)
path = Path(self._export_path, f"{contract.name}.sol")
list_top_level: List[TopLevel] = []
self._export_list_used_contracts(contract, set(), list_contracts, list_top_level)
path = Path(self._export_path, f"{contract.name}_{uuid.uuid4()}.sol")
content = ""
content += self._pragmas()
for listed_top_level in list_top_level:
content += self._source_codes_top_level[listed_top_level]
content += "\n"
for listed_contract in list_contracts:
content += self._source_codes[listed_contract]
content += "\n"
@ -271,7 +331,7 @@ class Flattening:
def _export_most_derived(self) -> List[Export]:
ret: List[Export] = []
for contract in self._slither.contracts_derived:
for contract in self._compilation_unit.contracts_derived:
ret.append(self._export_contract_with_inheritance(contract))
return ret
@ -281,8 +341,13 @@ class Flattening:
content = ""
content += self._pragmas()
for top_level_content in self._source_codes_top_level.values():
content += "\n"
content += top_level_content
content += "\n"
contract_seen = set()
contract_to_explore = list(self._slither.contracts)
contract_to_explore = list(self._compilation_unit.contracts)
# We only need the inheritance order here, as solc can compile
# a contract that use another contract type (ex: state variable) that he has not seen yet
@ -303,9 +368,17 @@ class Flattening:
def _export_with_import(self) -> List[Export]:
exports: List[Export] = []
for contract in self._slither.contracts:
for contract in self._compilation_unit.contracts:
list_contracts: List[Contract] = [] # will contain contract itself
self._export_list_used_contracts(contract, set(), list_contracts)
list_top_level: List[TopLevel] = []
self._export_list_used_contracts(contract, set(), list_contracts, list_top_level)
if list_top_level:
logger.info(
"Top level objects are not yet supported with the local import flattening"
)
for elem in list_top_level:
logger.info(f"Missing {elem} for {contract.name}")
path = Path(self._export_path, f"{contract.name}.sol")
@ -341,12 +414,13 @@ class Flattening:
elif strategy == Strategy.LocalImport:
exports = self._export_with_import()
else:
contracts = self._slither.get_contract_from_name(target)
if len(contracts) != 1:
contracts = self._compilation_unit.get_contract_from_name(target)
if len(contracts) == 0:
logger.error(f"{target} not found")
return
contract = contracts[0]
exports = [self._export_contract_with_inheritance(contract)]
exports = []
for contract in contracts:
exports.append(self._export_contract_with_inheritance(contract))
if json:
export_as_json(exports, json)

Loading…
Cancel
Save