diff --git a/slither/tools/flattening/__main__.py b/slither/tools/flattening/__main__.py index 735d924a2..5c6f2c24b 100644 --- a/slither/tools/flattening/__main__.py +++ b/slither/tools/flattening/__main__.py @@ -24,6 +24,10 @@ def parse_args(): help='Convert external to public.', action='store_true') + parser.add_argument('--remove-assert', + help='Remove call to assert().', + action='store_true') + parser.add_argument('--contract', help='Flatten a specific contract (default: all most derived contracts).', default=None) @@ -38,7 +42,7 @@ def main(): args = parse_args() slither = Slither(args.filename, **vars(args)) - flat = Flattening(slither, external_to_public=args.convert_external) + flat = Flattening(slither, external_to_public=args.convert_external, remove_assert=args.remove_assert) flat.export(target=args.contract) diff --git a/slither/tools/flattening/flattening.py b/slither/tools/flattening/flattening.py index df4686368..1fdf73630 100644 --- a/slither/tools/flattening/flattening.py +++ b/slither/tools/flattening/flattening.py @@ -1,23 +1,33 @@ from pathlib import Path import re import logging +from collections import namedtuple + +from slither.core.declarations import SolidityFunction from slither.exceptions import SlitherException from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.declarations.structure import Structure from slither.core.declarations.enum import Enum from slither.core.declarations.contract import Contract -from slither.slithir.operations import NewContract, TypeConversion +from slither.slithir.operations import NewContract, TypeConversion, SolidityCall logger = logging.getLogger("Slither-flattening") -class Flattening: +# index: where to start +# patch_type: +# - public_to_external: public to external (external-to-public) +# - calldata_to_memory: calldata to memory (external-to-public) +# - line_removal: remove the line (remove-assert) +Patch = namedtuple('PatchExternal', ['index', 'patch_type']) +class Flattening: DEFAULT_EXPORT_PATH = Path('crytic-export/flattening') - def __init__(self, slither, external_to_public=False): + def __init__(self, slither, external_to_public=False, remove_assert=False): self._source_codes = {} self._slither = slither self._external_to_public = external_to_public + self._remove_assert = remove_assert self._use_abi_encoder_v2 = False self._check_abi_encoder_v2() @@ -36,13 +46,11 @@ class Flattening: content = self._slither.source_code[src_mapping['filename_absolute']] start = src_mapping['start'] end = src_mapping['start'] + src_mapping['length'] + first_line = src_mapping['lines'][0] + to_patch = [] # interface must use external if self._external_to_public and contract.contract_kind != "interface": - # to_patch is a list of (index, bool). The bool indicates - # if the index is for external -> public (true) - # or a calldata -> memory (false) - to_patch = [] for f in contract.functions_declared: # fallback must be external if f.is_fallback or f.is_constructor_variables: @@ -54,7 +62,7 @@ class Flattening: attributes = content[attributes_start:attributes_end] regex = re.search(r'((\sexternal)\s+)|(\sexternal)$|(\)external)$', attributes) if regex: - to_patch.append((attributes_start + regex.span()[0] + 1, True)) + to_patch.append(Patch(attributes_start + regex.span()[0] + 1, 'public_to_external')) else: raise SlitherException(f'External keyword not found {f.name} {attributes}') @@ -63,23 +71,33 @@ class Flattening: calldata_start = var.source_mapping['start'] calldata_end = calldata_start + var.source_mapping['length'] calldata_idx = content[calldata_start:calldata_end].find(' calldata ') - to_patch.append((calldata_start + calldata_idx + 1, False)) - - to_patch.sort(key=lambda x:x[0], reverse=True) - - content = content[start:end] - for (index, is_external) in to_patch: - index = index - start - if is_external: - content = content[:index] + 'public' + content[index + len('external'):] - else: - content = content[:index] + 'memory' + content[index + len('calldata'):] - else: - content = content[start:end] + to_patch.append(Patch(calldata_start + calldata_idx + 1, 'calldata_to_memory')) + + if self._remove_assert: + for function in contract.functions_and_modifiers_declared: + for node in function.nodes: + for ir in node.irs: + if isinstance(ir, SolidityCall) and ir.function == SolidityFunction('assert(bool)'): + to_patch.append(Patch(node.source_mapping['start'], 'line_removal')) + logger.info(f'Code commented: {node.expression} ({node.source_mapping_str})') + + to_patch.sort(key=lambda x: x.index, reverse=True) + + content = content[start:end] + for patch in to_patch: + patch_type = patch.patch_type + index = patch.index + index = index - start + if patch_type == 'public_to_external': + content = content[:index] + 'public' + content[index + len('external'):] + elif patch_type == 'calldata_to_memory': + content = content[:index] + 'memory' + content[index + len('calldata'):] + else: + assert patch_type == 'line_removal' + content = content[:index] + ' // ' + content[index:] self._source_codes[contract] = content - def _export_from_type(self, t, contract, exported, list_contract): if isinstance(t, UserDefinedType): if isinstance(t.type, (Enum, Structure)): @@ -152,4 +170,3 @@ class Flattening: else: ret = [] self._export(contract, ret) - diff --git a/tests/flat/file1.sol b/tests/flat/file1.sol new file mode 100644 index 000000000..de3d0e0f8 --- /dev/null +++ b/tests/flat/file1.sol @@ -0,0 +1,7 @@ +contract A{ + + function test(bytes calldata b) external{ + + } + +} diff --git a/tests/flat/file2.sol b/tests/flat/file2.sol new file mode 100644 index 000000000..603b9ffa7 --- /dev/null +++ b/tests/flat/file2.sol @@ -0,0 +1,9 @@ +import "file1.sol"; + +contract B is A{ + + function test() public{ + assert(true); + } + +}