From ed34d85617c3aeba8b6da9f926586867757354bc Mon Sep 17 00:00:00 2001 From: Josselin Date: Sat, 6 Jun 2020 17:44:36 +0200 Subject: [PATCH] Several improvements to slither-flat: - Allow to specify a strategy (as defined in #495) - Add json and zip export - Clean up architecture - Clean up logging and --help --- slither/tools/flattening/__main__.py | 112 +++++-- slither/tools/flattening/export/__init__.y.py | 0 slither/tools/flattening/export/export.py | 55 ++++ slither/tools/flattening/flattening.py | 280 +++++++++++++----- 4 files changed, 347 insertions(+), 100 deletions(-) create mode 100644 slither/tools/flattening/export/__init__.y.py create mode 100644 slither/tools/flattening/export/export.py diff --git a/slither/tools/flattening/__main__.py b/slither/tools/flattening/__main__.py index 30ce746d9..3486150f4 100644 --- a/slither/tools/flattening/__main__.py +++ b/slither/tools/flattening/__main__.py @@ -1,44 +1,88 @@ import argparse import logging -from slither import Slither +import sys + from crytic_compile import cryticparser -from .flattening import Flattening +from crytic_compile.utils.zip import ZIP_TYPES_ACCEPTED + +from slither import Slither +from slither.tools.flattening.flattening import ( + Flattening, + Strategy, + STRATEGIES_NAMES, + DEFAULT_EXPORT_PATH, +) logging.basicConfig() -logging.getLogger("Slither").setLevel(logging.INFO) -logger = logging.getLogger("Slither-flattening") +logger = logging.getLogger("Slither") logger.setLevel(logging.INFO) + def parse_args(): """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='Contracts flattening', - usage='slither-flat filename') + parser = argparse.ArgumentParser( + description="Contracts flattening. See https://github.com/crytic/slither/wiki/Contract-Flattening", + usage="slither-flat filename", + ) + + parser.add_argument("filename", help="The filename of the contract or project to analyze.") + + parser.add_argument("--contract", help="Flatten one contract.", default=None) - parser.add_argument('filename', - help='The filename of the contract or project to analyze.') + parser.add_argument( + "--strategy", + help=f"Flatenning strategy: {STRATEGIES_NAMES} (default: MostDerived).", + default=Strategy.MostDerived.name, + ) - parser.add_argument('--convert-external', - help='Convert external to public.', - action='store_true') + group_export = parser.add_argument_group("Export options") - parser.add_argument('--convert-private', - help='Convert private variables to internal.', - action='store_true') + group_export.add_argument( + "--dir", help=f"Export directory (default: {DEFAULT_EXPORT_PATH}).", default=None + ) - parser.add_argument('--remove-assert', - help='Remove call to assert().', - action='store_true') + group_export.add_argument( + "--json", + help='Export the results as a JSON file ("--json -" to export to stdout)', + action="store", + default=None, + ) - parser.add_argument('--contract', - help='Flatten a specific contract (default: all most derived contracts).', - default=None) + parser.add_argument( + "--zip", help="Export all the files to a zip file", action="store", default=None, + ) + + parser.add_argument( + "--zip-type", + help=f"Zip compression type. One of {','.join(ZIP_TYPES_ACCEPTED.keys())}. Default lzma", + action="store", + default=None, + ) + + group_patching = parser.add_argument_group("Patching options") + + group_patching.add_argument( + "--convert-external", help="Convert external to public.", action="store_true" + ) + + group_patching.add_argument( + "--convert-private", help="Convert private variables to internal.", action="store_true" + ) + + group_patching.add_argument( + "--remove-assert", help="Remove call to assert().", action="store_true" + ) # Add default arguments from crytic-compile cryticparser.init(parser) + if len(sys.argv) == 1: + parser.print_help(sys.stderr) + sys.exit(1) + return parser.parse_args() @@ -46,13 +90,29 @@ def main(): args = parse_args() slither = Slither(args.filename, **vars(args)) - flat = Flattening(slither, - external_to_public=args.convert_external, - remove_assert=args.remove_assert, - private_to_internal=args.convert_private) + flat = Flattening( + slither, + external_to_public=args.convert_external, + remove_assert=args.remove_assert, + private_to_internal=args.convert_private, + export_path=args.dir, + ) - flat.export(target=args.contract) + try: + strategy = Strategy[args.strategy] + except KeyError: + logger.error( + f"{args.strategy} is not a valid strategy, use: {STRATEGIES_NAMES} (default MostDerived)" + ) + return + flat.export( + strategy=strategy, + target=args.contract, + json=args.json, + zip=args.zip, + zip_type=args.zip_type, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/slither/tools/flattening/export/__init__.y.py b/slither/tools/flattening/export/__init__.y.py new file mode 100644 index 000000000..e69de29bb diff --git a/slither/tools/flattening/export/export.py b/slither/tools/flattening/export/export.py new file mode 100644 index 000000000..94e13d78a --- /dev/null +++ b/slither/tools/flattening/export/export.py @@ -0,0 +1,55 @@ +import json +import logging + +# https://docs.python.org/3/library/zipfile.html#zipfile-objects +import zipfile +from collections import namedtuple +from typing import List + +ZIP_TYPES_ACCEPTED = { + "lzma": zipfile.ZIP_LZMA, + "stored": zipfile.ZIP_STORED, + "deflated": zipfile.ZIP_DEFLATED, + "bzip2": zipfile.ZIP_BZIP2, +} + +Export = namedtuple("Export", ["filename", "content"]) + +logger = logging.getLogger("Slither") + + +def save_to_zip(files: List[Export], zip_filename: str, zip_type: str = "lzma"): + """ + Save projects to a zip + """ + logger.info(f"Export {zip_filename}") + with zipfile.ZipFile( + zip_filename, "w", compression=ZIP_TYPES_ACCEPTED.get(zip_type, zipfile.ZIP_LZMA) + ) as file_desc: + for f in files: + file_desc.writestr(str(f.filename), f.content) + + +def save_to_disk(files: List[Export]): + """ + Save projects to a zip + """ + for file in files: + with open(file.filename, "w") as f: + logger.info(f"Export {file.filename}") + f.write(file.content) + + +def export_as_json(files: List[Export], filename: str): + """ + Save projects to a zip + """ + + files_as_dict = {str(f.filename): f.content for f in files} + + if filename == "-": + print(json.dumps(files_as_dict)) + else: + logger.info(f"Export {filename}") + with open(filename, "w") as f: + json.dump(files_as_dict, f) diff --git a/slither/tools/flattening/flattening.py b/slither/tools/flattening/flattening.py index e4bb27569..21638b05c 100644 --- a/slither/tools/flattening/flattening.py +++ b/slither/tools/flattening/flattening.py @@ -1,16 +1,18 @@ -from pathlib import Path -import re import logging +import re from collections import namedtuple +from enum import Enum as PythonEnum +from pathlib import Path +from typing import List, Set, Dict, Optional -from slither.core.declarations import SolidityFunction +from slither.core.declarations import SolidityFunction, Enum +from slither.core.declarations.contract import Contract +from slither.core.declarations.structure import Structure from slither.core.solidity_types import MappingType, ArrayType -from slither.exceptions import SlitherException from slither.core.solidity_types.user_defined_type import UserDefinedType -from slither.core.declarations.structure import Structure -from slither.core.declarations.enum import Enum -from slither.core.declarations.contract import Contract +from slither.exceptions import SlitherException from slither.slithir.operations import NewContract, TypeConversion, SolidityCall +from slither.tools.flattening.export.export import Export, export_as_json, save_to_zip, save_to_disk logger = logging.getLogger("Slither-flattening") @@ -19,36 +21,65 @@ logger = logging.getLogger("Slither-flattening") # - public_to_external: public to external (external-to-public) # - calldata_to_memory: calldata to memory (external-to-public) # - line_removal: remove the line (remove-assert) -Patch = namedtuple('PatchExternal', ['index', 'patch_type']) +Patch = namedtuple("PatchExternal", ["index", "patch_type"]) -class Flattening: - DEFAULT_EXPORT_PATH = Path('crytic-export/flattening') +class Strategy(PythonEnum): + MostDerived = 0 + OneFile = 1 + LocalImport = 2 + + +STRATEGIES_NAMES = ",".join([i.name for i in Strategy]) + +DEFAULT_EXPORT_PATH = Path("crytic-export/flattening") + - def __init__(self, slither, external_to_public=False, remove_assert=False, private_to_internal=False): - self._source_codes = {} +class Flattening: + def __init__( + self, + slither, + external_to_public=False, + remove_assert=False, + private_to_internal=False, + export_path: Optional[str] = None, + ): + self._source_codes: Dict[Contract, str] = {} self._slither = slither self._external_to_public = external_to_public self._remove_assert = remove_assert self._use_abi_encoder_v2 = False self._private_to_internal = private_to_internal + self._export_path: Path = DEFAULT_EXPORT_PATH if export_path is None else Path(export_path) + self._check_abi_encoder_v2() for contract in slither.contracts: self._get_source_code(contract) def _check_abi_encoder_v2(self): + """ + Check if ABIEncoderV2 is required + Set _use_abi_encorder_v2 + :return: + """ for p in self._slither.pragma_directives: - if 'ABIEncoderV2' in str(p.directive): + if "ABIEncoderV2" in str(p.directive): self._use_abi_encoder_v2 = True return - def _get_source_code(self, contract): + def _get_source_code(self, contract: Contract): + """ + Save the source code of the contract in self._source_codes + Patch the source code + :param contract: + :return: + """ src_mapping = contract.source_mapping - content = self._slither.source_code[src_mapping['filename_absolute']].encode('utf8') - start = src_mapping['start'] - end = src_mapping['start'] + src_mapping['length'] + content = self._slither.source_code[src_mapping["filename_absolute"]].encode("utf8") + start = src_mapping["start"] + end = src_mapping["start"] + src_mapping["length"] to_patch = [] # interface must use external @@ -57,45 +88,57 @@ class Flattening: # fallback must be external if f.is_fallback or f.is_constructor_variables: continue - if f.visibility == 'external': - attributes_start = (f.parameters_src.source_mapping['start'] + - f.parameters_src.source_mapping['length']) - attributes_end = f.returns_src.source_mapping['start'] + if f.visibility == "external": + attributes_start = ( + f.parameters_src.source_mapping["start"] + + f.parameters_src.source_mapping["length"] + ) + attributes_end = f.returns_src.source_mapping["start"] attributes = content[attributes_start:attributes_end] - regex = re.search(r'((\sexternal)\s+)|(\sexternal)$|(\)external)$', attributes) + regex = re.search(r"((\sexternal)\s+)|(\sexternal)$|(\)external)$", attributes) if regex: - to_patch.append(Patch(attributes_start + regex.span()[0] + 1, 'public_to_external')) + to_patch.append( + Patch(attributes_start + regex.span()[0] + 1, "public_to_external") + ) else: - raise SlitherException(f'External keyword not found {f.name} {attributes}') + raise SlitherException(f"External keyword not found {f.name} {attributes}") for var in f.parameters: if var.location == "calldata": - calldata_start = var.source_mapping['start'] - calldata_end = calldata_start + var.source_mapping['length'] - calldata_idx = content[calldata_start:calldata_end].find(' calldata ') - to_patch.append(Patch(calldata_start + calldata_idx + 1, 'calldata_to_memory')) + calldata_start = var.source_mapping["start"] + calldata_end = calldata_start + var.source_mapping["length"] + calldata_idx = content[calldata_start:calldata_end].find(" calldata ") + to_patch.append( + Patch(calldata_start + calldata_idx + 1, "calldata_to_memory") + ) if self._private_to_internal: for variable in contract.state_variables_declared: - if variable.visibility == 'private': + if variable.visibility == "private": print(variable.source_mapping) - attributes_start = variable.source_mapping['start'] - attributes_end = attributes_start + variable.source_mapping['length'] + attributes_start = variable.source_mapping["start"] + attributes_end = attributes_start + variable.source_mapping["length"] attributes = content[attributes_start:attributes_end] print(attributes) - regex = re.search(r' private ', attributes) + regex = re.search(r" private ", attributes) if regex: - to_patch.append(Patch(attributes_start + regex.span()[0] + 1, 'private_to_internal')) + to_patch.append( + Patch(attributes_start + regex.span()[0] + 1, "private_to_internal") + ) else: - raise SlitherException(f'private keyword not found {v.name} {attributes}') + raise SlitherException(f"private keyword not found {v.name} {attributes}") if self._remove_assert: for function in contract.functions_and_modifiers_declared: for node in function.nodes: for ir in node.irs: - if isinstance(ir, SolidityCall) and ir.function == SolidityFunction('assert(bool)'): - to_patch.append(Patch(node.source_mapping['start'], 'line_removal')) - logger.info(f'Code commented: {node.expression} ({node.source_mapping_str})') + if isinstance(ir, SolidityCall) and ir.function == SolidityFunction( + "assert(bool)" + ): + to_patch.append(Patch(node.source_mapping["start"], "line_removal")) + logger.info( + f"Code commented: {node.expression} ({node.source_mapping_str})" + ) to_patch.sort(key=lambda x: x.index, reverse=True) @@ -104,39 +147,53 @@ class Flattening: patch_type = patch.patch_type index = patch.index index = index - start - if patch_type == 'public_to_external': - content = content[:index] + 'public' + content[index + len('external'):] - if patch_type == 'private_to_internal': - content = content[:index] + 'internal' + content[index + len('private'):] - elif patch_type == 'calldata_to_memory': - content = content[:index] + 'memory' + content[index + len('calldata'):] + if patch_type == "public_to_external": + content = content[:index] + "public" + content[index + len("external") :] + if patch_type == "private_to_internal": + content = content[:index] + "internal" + content[index + len("private") :] + elif patch_type == "calldata_to_memory": + content = content[:index] + "memory" + content[index + len("calldata") :] else: - assert patch_type == 'line_removal' - content = content[:index] + ' // ' + content[index:] + assert patch_type == "line_removal" + content = content[:index] + " // " + content[index:] - self._source_codes[contract] = content.decode('utf8') + self._source_codes[contract] = content.decode("utf8") + + def _pragmas(self) -> str: + """ + Return the required pragmas + :return: + """ + ret = "" + if self._slither.solc_version: + ret += f"pragma solidity {self._slither.solc_version};\n" + if self._use_abi_encoder_v2: + ret += "pragma experimental ABIEncoderV2;\n" + return ret def _export_from_type(self, t, contract, exported, list_contract): if isinstance(t, UserDefinedType): if isinstance(t.type, (Enum, Structure)): if t.type.contract != contract and t.type.contract not in exported: - self._export_contract(t.type.contract, exported, list_contract) + self._export_list_used_contracts(t.type.contract, exported, list_contract) else: assert isinstance(t.type, Contract) if t.type != contract and t.type not in exported: - self._export_contract(t.type, exported, list_contract) + self._export_list_used_contracts(t.type, exported, list_contract) elif isinstance(t, MappingType): self._export_from_type(t.type_from, contract, exported, list_contract) self._export_from_type(t.type_to, contract, exported, list_contract) elif isinstance(t, ArrayType): self._export_from_type(t.type, contract, exported, list_contract) - def _export_contract(self, contract, exported, list_contract): + def _export_list_used_contracts( + self, contract: Contract, exported: Set[str], list_contract: List[Contract] + ): if contract.name in exported: return exported.add(contract.name) for inherited in contract.inheritance: - self._export_contract(inherited, exported, list_contract) + self._export_list_used_contracts(inherited, exported, list_contract) # Find all the external contracts called externals = contract.all_library_calls + contract.all_high_level_calls @@ -145,7 +202,7 @@ class Flattening: externals = list(set([e[0] for e in externals if e[0] != contract])) for inherited in externals: - self._export_contract(inherited, exported, list_contract) + self._export_list_used_contracts(inherited, exported, list_contract) # Find all the external contracts use as a base type local_vars = [] @@ -164,36 +221,111 @@ class Flattening: for ir in f.slithir_operations: if isinstance(ir, NewContract): if ir.contract_created != contract and not ir.contract_created in exported: - self._export_contract(ir.contract_created, exported, list_contract) + self._export_list_used_contracts( + ir.contract_created, exported, list_contract + ) if isinstance(ir, TypeConversion): self._export_from_type(ir.type, contract, exported, list_contract) - list_contract.append(self._source_codes[contract]) + if contract not in list_contract: + list_contract.append(contract) + + def _export_contract_with_inheritance(self, contract) -> Export: + list_contracts: List[Contract] = [] # will contain contract itself + self._export_list_used_contracts(contract, set(), list_contracts) + path = Path(self._export_path, f"{contract.name}.sol") - def _export(self, contract, ret): - self._export_contract(contract, set(), ret) - path = Path(self.DEFAULT_EXPORT_PATH, f'{contract.name}.sol') - logger.info(f'Export {path}') - with open(path, 'w') as f: - if self._slither.solc_version: - f.write(f'pragma solidity {self._slither.solc_version};\n') - if self._use_abi_encoder_v2: - f.write('pragma experimental ABIEncoderV2;\n') - f.write('\n'.join(ret)) - f.write('\n') + content = "" + content += self._pragmas() - def export(self, target=None): + for contract in list_contracts: + content += self._source_codes[contract] + content += "\n" - if not self.DEFAULT_EXPORT_PATH.exists(): - self.DEFAULT_EXPORT_PATH.mkdir(parents=True) + return Export(filename=path, content=content) + def _export_most_derived(self) -> List[Export]: + ret: List[Export] = [] + for contract in self._slither.contracts_derived: + ret.append(self._export_contract_with_inheritance(contract)) + return ret + + def _export_all(self) -> List[Export]: + path = Path(self._export_path, f"export.sol") + + content = "" + content += self._pragmas() + + contract_seen = set() + contract_to_explore = list(self._slither.contracts) + + # We only need the inheritance order here, as solc can compile + # a contract that use another contract type (ex: state variable) that he has not seen yet + while contract_to_explore: + next = contract_to_explore.pop(0) + + if not next.inheritance or all( + (father in contract_seen for father in next.inheritance) + ): + content += "\n" + content += self._source_codes[next] + content += "\n" + contract_seen.add(next) + else: + contract_to_explore.append(next) + + return [Export(filename=path, content=content)] + + def _export_with_import(self) -> List[Export]: + exports: List[Export] = [] + for contract in self._slither.contracts: + list_contracts: List[Contract] = [] # will contain contract itself + self._export_list_used_contracts(contract, set(), list_contracts) + + path = Path(self._export_path, f"{contract.name}.sol") + + content = "" + content += self._pragmas() + for used_contract in list_contracts: + if used_contract != contract: + content += f"import './{used_contract.name}.sol';\n" + content += "\n" + content += self._source_codes[contract] + content += "\n" + exports.append(Export(filename=path, content=content)) + return exports + + def export( + self, + strategy: Strategy, + target: Optional[str] = None, + json: Optional[str] = None, + zip: Optional[str] = None, + zip_type: Optional[str] = None, + ): + + if not self._export_path.exists(): + self._export_path.mkdir(parents=True) + + exports: List[Export] = [] if target is None: - for contract in self._slither.contracts_derived: - ret = [] - self._export(contract, ret) + if strategy == Strategy.MostDerived: + exports = self._export_most_derived() + elif strategy == Strategy.OneFile: + exports = self._export_all() + elif strategy == Strategy.LocalImport: + exports = self._export_with_import() else: contract = self._slither.get_contract_from_name(target) if contract is None: - logger.error(f'{target} not found') - else: - ret = [] - self._export(contract, ret) + logger.error(f"{target} not found") + return + exports = [self._export_contract_with_inheritance(contract)] + + if json: + export_as_json(exports, json) + + elif zip: + save_to_zip(exports, zip, zip_type) + + else: + save_to_disk(exports)