Several improvements to slither-flat:

- Allow to specify a strategy (as defined in #495)
- Add json and zip export
- Clean up architecture
- Clean up logging and --help
pull/496/head
Josselin 5 years ago
parent 4648b110a9
commit ed34d85617
  1. 112
      slither/tools/flattening/__main__.py
  2. 0
      slither/tools/flattening/export/__init__.y.py
  3. 55
      slither/tools/flattening/export/export.py
  4. 280
      slither/tools/flattening/flattening.py

@ -1,44 +1,88 @@
import argparse
import logging
from slither import Slither
import sys
from crytic_compile import cryticparser
from .flattening import Flattening
from crytic_compile.utils.zip import ZIP_TYPES_ACCEPTED
from slither import Slither
from slither.tools.flattening.flattening import (
Flattening,
Strategy,
STRATEGIES_NAMES,
DEFAULT_EXPORT_PATH,
)
logging.basicConfig()
logging.getLogger("Slither").setLevel(logging.INFO)
logger = logging.getLogger("Slither-flattening")
logger = logging.getLogger("Slither")
logger.setLevel(logging.INFO)
def parse_args():
"""
Parse the underlying arguments for the program.
:return: Returns the arguments for the program.
"""
parser = argparse.ArgumentParser(description='Contracts flattening',
usage='slither-flat filename')
parser = argparse.ArgumentParser(
description="Contracts flattening. See https://github.com/crytic/slither/wiki/Contract-Flattening",
usage="slither-flat filename",
)
parser.add_argument("filename", help="The filename of the contract or project to analyze.")
parser.add_argument("--contract", help="Flatten one contract.", default=None)
parser.add_argument('filename',
help='The filename of the contract or project to analyze.')
parser.add_argument(
"--strategy",
help=f"Flatenning strategy: {STRATEGIES_NAMES} (default: MostDerived).",
default=Strategy.MostDerived.name,
)
parser.add_argument('--convert-external',
help='Convert external to public.',
action='store_true')
group_export = parser.add_argument_group("Export options")
parser.add_argument('--convert-private',
help='Convert private variables to internal.',
action='store_true')
group_export.add_argument(
"--dir", help=f"Export directory (default: {DEFAULT_EXPORT_PATH}).", default=None
)
parser.add_argument('--remove-assert',
help='Remove call to assert().',
action='store_true')
group_export.add_argument(
"--json",
help='Export the results as a JSON file ("--json -" to export to stdout)',
action="store",
default=None,
)
parser.add_argument('--contract',
help='Flatten a specific contract (default: all most derived contracts).',
default=None)
parser.add_argument(
"--zip", help="Export all the files to a zip file", action="store", default=None,
)
parser.add_argument(
"--zip-type",
help=f"Zip compression type. One of {','.join(ZIP_TYPES_ACCEPTED.keys())}. Default lzma",
action="store",
default=None,
)
group_patching = parser.add_argument_group("Patching options")
group_patching.add_argument(
"--convert-external", help="Convert external to public.", action="store_true"
)
group_patching.add_argument(
"--convert-private", help="Convert private variables to internal.", action="store_true"
)
group_patching.add_argument(
"--remove-assert", help="Remove call to assert().", action="store_true"
)
# Add default arguments from crytic-compile
cryticparser.init(parser)
if len(sys.argv) == 1:
parser.print_help(sys.stderr)
sys.exit(1)
return parser.parse_args()
@ -46,13 +90,29 @@ def main():
args = parse_args()
slither = Slither(args.filename, **vars(args))
flat = Flattening(slither,
external_to_public=args.convert_external,
remove_assert=args.remove_assert,
private_to_internal=args.convert_private)
flat = Flattening(
slither,
external_to_public=args.convert_external,
remove_assert=args.remove_assert,
private_to_internal=args.convert_private,
export_path=args.dir,
)
flat.export(target=args.contract)
try:
strategy = Strategy[args.strategy]
except KeyError:
logger.error(
f"{args.strategy} is not a valid strategy, use: {STRATEGIES_NAMES} (default MostDerived)"
)
return
flat.export(
strategy=strategy,
target=args.contract,
json=args.json,
zip=args.zip,
zip_type=args.zip_type,
)
if __name__ == '__main__':
if __name__ == "__main__":
main()

@ -0,0 +1,55 @@
import json
import logging
# https://docs.python.org/3/library/zipfile.html#zipfile-objects
import zipfile
from collections import namedtuple
from typing import List
ZIP_TYPES_ACCEPTED = {
"lzma": zipfile.ZIP_LZMA,
"stored": zipfile.ZIP_STORED,
"deflated": zipfile.ZIP_DEFLATED,
"bzip2": zipfile.ZIP_BZIP2,
}
Export = namedtuple("Export", ["filename", "content"])
logger = logging.getLogger("Slither")
def save_to_zip(files: List[Export], zip_filename: str, zip_type: str = "lzma"):
"""
Save projects to a zip
"""
logger.info(f"Export {zip_filename}")
with zipfile.ZipFile(
zip_filename, "w", compression=ZIP_TYPES_ACCEPTED.get(zip_type, zipfile.ZIP_LZMA)
) as file_desc:
for f in files:
file_desc.writestr(str(f.filename), f.content)
def save_to_disk(files: List[Export]):
"""
Save projects to a zip
"""
for file in files:
with open(file.filename, "w") as f:
logger.info(f"Export {file.filename}")
f.write(file.content)
def export_as_json(files: List[Export], filename: str):
"""
Save projects to a zip
"""
files_as_dict = {str(f.filename): f.content for f in files}
if filename == "-":
print(json.dumps(files_as_dict))
else:
logger.info(f"Export {filename}")
with open(filename, "w") as f:
json.dump(files_as_dict, f)

@ -1,16 +1,18 @@
from pathlib import Path
import re
import logging
import re
from collections import namedtuple
from enum import Enum as PythonEnum
from pathlib import Path
from typing import List, Set, Dict, Optional
from slither.core.declarations import SolidityFunction
from slither.core.declarations import SolidityFunction, Enum
from slither.core.declarations.contract import Contract
from slither.core.declarations.structure import Structure
from slither.core.solidity_types import MappingType, ArrayType
from slither.exceptions import SlitherException
from slither.core.solidity_types.user_defined_type import UserDefinedType
from slither.core.declarations.structure import Structure
from slither.core.declarations.enum import Enum
from slither.core.declarations.contract import Contract
from slither.exceptions import SlitherException
from slither.slithir.operations import NewContract, TypeConversion, SolidityCall
from slither.tools.flattening.export.export import Export, export_as_json, save_to_zip, save_to_disk
logger = logging.getLogger("Slither-flattening")
@ -19,36 +21,65 @@ logger = logging.getLogger("Slither-flattening")
# - public_to_external: public to external (external-to-public)
# - calldata_to_memory: calldata to memory (external-to-public)
# - line_removal: remove the line (remove-assert)
Patch = namedtuple('PatchExternal', ['index', 'patch_type'])
Patch = namedtuple("PatchExternal", ["index", "patch_type"])
class Flattening:
DEFAULT_EXPORT_PATH = Path('crytic-export/flattening')
class Strategy(PythonEnum):
MostDerived = 0
OneFile = 1
LocalImport = 2
STRATEGIES_NAMES = ",".join([i.name for i in Strategy])
DEFAULT_EXPORT_PATH = Path("crytic-export/flattening")
def __init__(self, slither, external_to_public=False, remove_assert=False, private_to_internal=False):
self._source_codes = {}
class Flattening:
def __init__(
self,
slither,
external_to_public=False,
remove_assert=False,
private_to_internal=False,
export_path: Optional[str] = None,
):
self._source_codes: Dict[Contract, str] = {}
self._slither = slither
self._external_to_public = external_to_public
self._remove_assert = remove_assert
self._use_abi_encoder_v2 = False
self._private_to_internal = private_to_internal
self._export_path: Path = DEFAULT_EXPORT_PATH if export_path is None else Path(export_path)
self._check_abi_encoder_v2()
for contract in slither.contracts:
self._get_source_code(contract)
def _check_abi_encoder_v2(self):
"""
Check if ABIEncoderV2 is required
Set _use_abi_encorder_v2
:return:
"""
for p in self._slither.pragma_directives:
if 'ABIEncoderV2' in str(p.directive):
if "ABIEncoderV2" in str(p.directive):
self._use_abi_encoder_v2 = True
return
def _get_source_code(self, contract):
def _get_source_code(self, contract: Contract):
"""
Save the source code of the contract in self._source_codes
Patch the source code
:param contract:
:return:
"""
src_mapping = contract.source_mapping
content = self._slither.source_code[src_mapping['filename_absolute']].encode('utf8')
start = src_mapping['start']
end = src_mapping['start'] + src_mapping['length']
content = self._slither.source_code[src_mapping["filename_absolute"]].encode("utf8")
start = src_mapping["start"]
end = src_mapping["start"] + src_mapping["length"]
to_patch = []
# interface must use external
@ -57,45 +88,57 @@ class Flattening:
# fallback must be external
if f.is_fallback or f.is_constructor_variables:
continue
if f.visibility == 'external':
attributes_start = (f.parameters_src.source_mapping['start'] +
f.parameters_src.source_mapping['length'])
attributes_end = f.returns_src.source_mapping['start']
if f.visibility == "external":
attributes_start = (
f.parameters_src.source_mapping["start"]
+ f.parameters_src.source_mapping["length"]
)
attributes_end = f.returns_src.source_mapping["start"]
attributes = content[attributes_start:attributes_end]
regex = re.search(r'((\sexternal)\s+)|(\sexternal)$|(\)external)$', attributes)
regex = re.search(r"((\sexternal)\s+)|(\sexternal)$|(\)external)$", attributes)
if regex:
to_patch.append(Patch(attributes_start + regex.span()[0] + 1, 'public_to_external'))
to_patch.append(
Patch(attributes_start + regex.span()[0] + 1, "public_to_external")
)
else:
raise SlitherException(f'External keyword not found {f.name} {attributes}')
raise SlitherException(f"External keyword not found {f.name} {attributes}")
for var in f.parameters:
if var.location == "calldata":
calldata_start = var.source_mapping['start']
calldata_end = calldata_start + var.source_mapping['length']
calldata_idx = content[calldata_start:calldata_end].find(' calldata ')
to_patch.append(Patch(calldata_start + calldata_idx + 1, 'calldata_to_memory'))
calldata_start = var.source_mapping["start"]
calldata_end = calldata_start + var.source_mapping["length"]
calldata_idx = content[calldata_start:calldata_end].find(" calldata ")
to_patch.append(
Patch(calldata_start + calldata_idx + 1, "calldata_to_memory")
)
if self._private_to_internal:
for variable in contract.state_variables_declared:
if variable.visibility == 'private':
if variable.visibility == "private":
print(variable.source_mapping)
attributes_start = variable.source_mapping['start']
attributes_end = attributes_start + variable.source_mapping['length']
attributes_start = variable.source_mapping["start"]
attributes_end = attributes_start + variable.source_mapping["length"]
attributes = content[attributes_start:attributes_end]
print(attributes)
regex = re.search(r' private ', attributes)
regex = re.search(r" private ", attributes)
if regex:
to_patch.append(Patch(attributes_start + regex.span()[0] + 1, 'private_to_internal'))
to_patch.append(
Patch(attributes_start + regex.span()[0] + 1, "private_to_internal")
)
else:
raise SlitherException(f'private keyword not found {v.name} {attributes}')
raise SlitherException(f"private keyword not found {v.name} {attributes}")
if self._remove_assert:
for function in contract.functions_and_modifiers_declared:
for node in function.nodes:
for ir in node.irs:
if isinstance(ir, SolidityCall) and ir.function == SolidityFunction('assert(bool)'):
to_patch.append(Patch(node.source_mapping['start'], 'line_removal'))
logger.info(f'Code commented: {node.expression} ({node.source_mapping_str})')
if isinstance(ir, SolidityCall) and ir.function == SolidityFunction(
"assert(bool)"
):
to_patch.append(Patch(node.source_mapping["start"], "line_removal"))
logger.info(
f"Code commented: {node.expression} ({node.source_mapping_str})"
)
to_patch.sort(key=lambda x: x.index, reverse=True)
@ -104,39 +147,53 @@ class Flattening:
patch_type = patch.patch_type
index = patch.index
index = index - start
if patch_type == 'public_to_external':
content = content[:index] + 'public' + content[index + len('external'):]
if patch_type == 'private_to_internal':
content = content[:index] + 'internal' + content[index + len('private'):]
elif patch_type == 'calldata_to_memory':
content = content[:index] + 'memory' + content[index + len('calldata'):]
if patch_type == "public_to_external":
content = content[:index] + "public" + content[index + len("external") :]
if patch_type == "private_to_internal":
content = content[:index] + "internal" + content[index + len("private") :]
elif patch_type == "calldata_to_memory":
content = content[:index] + "memory" + content[index + len("calldata") :]
else:
assert patch_type == 'line_removal'
content = content[:index] + ' // ' + content[index:]
assert patch_type == "line_removal"
content = content[:index] + " // " + content[index:]
self._source_codes[contract] = content.decode('utf8')
self._source_codes[contract] = content.decode("utf8")
def _pragmas(self) -> str:
"""
Return the required pragmas
:return:
"""
ret = ""
if self._slither.solc_version:
ret += f"pragma solidity {self._slither.solc_version};\n"
if self._use_abi_encoder_v2:
ret += "pragma experimental ABIEncoderV2;\n"
return ret
def _export_from_type(self, t, contract, exported, list_contract):
if isinstance(t, UserDefinedType):
if isinstance(t.type, (Enum, Structure)):
if t.type.contract != contract and t.type.contract not in exported:
self._export_contract(t.type.contract, exported, list_contract)
self._export_list_used_contracts(t.type.contract, exported, list_contract)
else:
assert isinstance(t.type, Contract)
if t.type != contract and t.type not in exported:
self._export_contract(t.type, exported, list_contract)
self._export_list_used_contracts(t.type, exported, list_contract)
elif isinstance(t, MappingType):
self._export_from_type(t.type_from, contract, exported, list_contract)
self._export_from_type(t.type_to, contract, exported, list_contract)
elif isinstance(t, ArrayType):
self._export_from_type(t.type, contract, exported, list_contract)
def _export_contract(self, contract, exported, list_contract):
def _export_list_used_contracts(
self, contract: Contract, exported: Set[str], list_contract: List[Contract]
):
if contract.name in exported:
return
exported.add(contract.name)
for inherited in contract.inheritance:
self._export_contract(inherited, exported, list_contract)
self._export_list_used_contracts(inherited, exported, list_contract)
# Find all the external contracts called
externals = contract.all_library_calls + contract.all_high_level_calls
@ -145,7 +202,7 @@ class Flattening:
externals = list(set([e[0] for e in externals if e[0] != contract]))
for inherited in externals:
self._export_contract(inherited, exported, list_contract)
self._export_list_used_contracts(inherited, exported, list_contract)
# Find all the external contracts use as a base type
local_vars = []
@ -164,36 +221,111 @@ class Flattening:
for ir in f.slithir_operations:
if isinstance(ir, NewContract):
if ir.contract_created != contract and not ir.contract_created in exported:
self._export_contract(ir.contract_created, exported, list_contract)
self._export_list_used_contracts(
ir.contract_created, exported, list_contract
)
if isinstance(ir, TypeConversion):
self._export_from_type(ir.type, contract, exported, list_contract)
list_contract.append(self._source_codes[contract])
if contract not in list_contract:
list_contract.append(contract)
def _export_contract_with_inheritance(self, contract) -> Export:
list_contracts: List[Contract] = [] # will contain contract itself
self._export_list_used_contracts(contract, set(), list_contracts)
path = Path(self._export_path, f"{contract.name}.sol")
def _export(self, contract, ret):
self._export_contract(contract, set(), ret)
path = Path(self.DEFAULT_EXPORT_PATH, f'{contract.name}.sol')
logger.info(f'Export {path}')
with open(path, 'w') as f:
if self._slither.solc_version:
f.write(f'pragma solidity {self._slither.solc_version};\n')
if self._use_abi_encoder_v2:
f.write('pragma experimental ABIEncoderV2;\n')
f.write('\n'.join(ret))
f.write('\n')
content = ""
content += self._pragmas()
def export(self, target=None):
for contract in list_contracts:
content += self._source_codes[contract]
content += "\n"
if not self.DEFAULT_EXPORT_PATH.exists():
self.DEFAULT_EXPORT_PATH.mkdir(parents=True)
return Export(filename=path, content=content)
def _export_most_derived(self) -> List[Export]:
ret: List[Export] = []
for contract in self._slither.contracts_derived:
ret.append(self._export_contract_with_inheritance(contract))
return ret
def _export_all(self) -> List[Export]:
path = Path(self._export_path, f"export.sol")
content = ""
content += self._pragmas()
contract_seen = set()
contract_to_explore = list(self._slither.contracts)
# We only need the inheritance order here, as solc can compile
# a contract that use another contract type (ex: state variable) that he has not seen yet
while contract_to_explore:
next = contract_to_explore.pop(0)
if not next.inheritance or all(
(father in contract_seen for father in next.inheritance)
):
content += "\n"
content += self._source_codes[next]
content += "\n"
contract_seen.add(next)
else:
contract_to_explore.append(next)
return [Export(filename=path, content=content)]
def _export_with_import(self) -> List[Export]:
exports: List[Export] = []
for contract in self._slither.contracts:
list_contracts: List[Contract] = [] # will contain contract itself
self._export_list_used_contracts(contract, set(), list_contracts)
path = Path(self._export_path, f"{contract.name}.sol")
content = ""
content += self._pragmas()
for used_contract in list_contracts:
if used_contract != contract:
content += f"import './{used_contract.name}.sol';\n"
content += "\n"
content += self._source_codes[contract]
content += "\n"
exports.append(Export(filename=path, content=content))
return exports
def export(
self,
strategy: Strategy,
target: Optional[str] = None,
json: Optional[str] = None,
zip: Optional[str] = None,
zip_type: Optional[str] = None,
):
if not self._export_path.exists():
self._export_path.mkdir(parents=True)
exports: List[Export] = []
if target is None:
for contract in self._slither.contracts_derived:
ret = []
self._export(contract, ret)
if strategy == Strategy.MostDerived:
exports = self._export_most_derived()
elif strategy == Strategy.OneFile:
exports = self._export_all()
elif strategy == Strategy.LocalImport:
exports = self._export_with_import()
else:
contract = self._slither.get_contract_from_name(target)
if contract is None:
logger.error(f'{target} not found')
else:
ret = []
self._export(contract, ret)
logger.error(f"{target} not found")
return
exports = [self._export_contract_with_inheritance(contract)]
if json:
export_as_json(exports, json)
elif zip:
save_to_zip(exports, zip, zip_type)
else:
save_to_disk(exports)

Loading…
Cancel
Save