From 4319bb360542cab6a89d9bacd039d4c70b8b76e5 Mon Sep 17 00:00:00 2001 From: Josselin Date: Mon, 22 Jun 2020 11:46:04 +0200 Subject: [PATCH] Merge slither/tools slither/visitors/expression from dev-0.7 --- slither/tools/demo/__main__.py | 14 +- slither/tools/erc_conformance/__main__.py | 39 ++- slither/tools/erc_conformance/erc/erc20.py | 14 +- slither/tools/erc_conformance/erc/ercs.py | 99 +++--- slither/tools/kspec_coverage/__main__.py | 48 ++- slither/tools/kspec_coverage/analysis.py | 95 ++--- .../tools/kspec_coverage/kspec_coverage.py | 3 +- slither/tools/possible_paths/__main__.py | 21 +- .../tools/possible_paths/possible_paths.py | 39 ++- slither/tools/properties/__main__.py | 98 +++--- slither/tools/properties/addresses/address.py | 8 +- slither/tools/properties/platforms/echidna.py | 6 +- slither/tools/properties/platforms/truffle.py | 68 ++-- slither/tools/properties/properties/erc20.py | 130 +++---- .../properties/ercs/erc20/properties/burn.py | 48 +-- .../ercs/erc20/properties/initialization.py | 135 ++++---- .../properties/ercs/erc20/properties/mint.py | 27 +- .../ercs/erc20/properties/mint_and_burn.py | 30 +- .../ercs/erc20/properties/transfer.py | 327 +++++++++--------- .../ercs/erc20/unit_tests/truffle.py | 36 +- .../tools/properties/properties/properties.py | 2 +- .../solidity/generate_properties.py | 69 ++-- slither/tools/properties/utils.py | 20 +- slither/tools/similarity/__main__.py | 103 +++--- slither/tools/similarity/cache.py | 8 +- slither/tools/similarity/encode.py | 167 +++++---- slither/tools/similarity/info.py | 19 +- slither/tools/similarity/plot.py | 45 +-- slither/tools/similarity/similarity.py | 1 + slither/tools/similarity/test.py | 25 +- slither/tools/similarity/train.py | 39 ++- slither/tools/slither_format/__main__.py | 115 +++--- .../tools/slither_format/slither_format.py | 120 +++---- slither/tools/upgradeability/__main__.py | 158 +++++---- .../upgradeability/checks/abstract_checks.py | 94 +++-- .../tools/upgradeability/checks/all_checks.py | 22 +- .../tools/upgradeability/checks/constant.py | 55 +-- .../upgradeability/checks/functions_ids.py | 83 +++-- .../upgradeability/checks/initialization.py | 173 ++++----- .../checks/variable_initialization.py | 22 +- .../upgradeability/checks/variables_order.py | 108 +++--- .../upgradeability/utils/command_line.py | 104 +++--- .../visitors/expression/constants_folding.py | 28 +- slither/visitors/expression/export_values.py | 7 +- slither/visitors/expression/expression.py | 20 +- .../visitors/expression/expression_printer.py | 21 +- slither/visitors/expression/find_calls.py | 12 +- slither/visitors/expression/find_push.py | 8 +- .../visitors/expression/has_conditional.py | 7 +- slither/visitors/expression/left_value.py | 10 +- slither/visitors/expression/read_var.py | 8 +- slither/visitors/expression/right_value.py | 10 +- slither/visitors/expression/write_var.py | 36 +- 53 files changed, 1671 insertions(+), 1333 deletions(-) diff --git a/slither/tools/demo/__main__.py b/slither/tools/demo/__main__.py index 4bee3b449..03dc984b6 100644 --- a/slither/tools/demo/__main__.py +++ b/slither/tools/demo/__main__.py @@ -9,16 +9,17 @@ logging.getLogger("Slither").setLevel(logging.INFO) logger = logging.getLogger("Slither-demo") + def parse_args(): """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='Demo', - usage='slither-demo filename') + parser = argparse.ArgumentParser(description="Demo", usage="slither-demo filename") - parser.add_argument('filename', - help='The filename of the contract or truffle directory to analyze.') + parser.add_argument( + "filename", help="The filename of the contract or truffle directory to analyze." + ) # Add default arguments from crytic-compile cryticparser.init(parser) @@ -32,7 +33,8 @@ def main(): # Perform slither analysis on the given filename slither = Slither(args.filename, **vars(args)) - logger.info('Analysis done!') + logger.info("Analysis done!") + -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/slither/tools/erc_conformance/__main__.py b/slither/tools/erc_conformance/__main__.py index fe0fbb845..449ae8669 100644 --- a/slither/tools/erc_conformance/__main__.py +++ b/slither/tools/erc_conformance/__main__.py @@ -17,28 +17,29 @@ logger.setLevel(logging.INFO) ch = logging.StreamHandler() ch.setLevel(logging.INFO) -formatter = logging.Formatter('%(message)s') +formatter = logging.Formatter("%(message)s") logger.addHandler(ch) logger.handlers[0].setFormatter(formatter) logger.propagate = False -ADDITIONAL_CHECKS = { - "ERC20": check_erc20 -} +ADDITIONAL_CHECKS = {"ERC20": check_erc20} + def parse_args(): """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='Check the ERC 20 conformance', - usage='slither-erc project contractName') + parser = argparse.ArgumentParser( + description="Check the ERC 20 conformance", usage="slither-erc project contractName" + ) - parser.add_argument('project', - help='The codebase to be tested.') + parser.add_argument("project", help="The codebase to be tested.") - parser.add_argument('contract_name', - help='The name of the contract. Specify the first case contract that follow the standard. Derived contracts will be checked.') + parser.add_argument( + "contract_name", + help="The name of the contract. Specify the first case contract that follow the standard. Derived contracts will be checked.", + ) parser.add_argument( "--erc", @@ -47,22 +48,26 @@ def parse_args(): default="erc20", ) - parser.add_argument('--json', - help='Export the results as a JSON file ("--json -" to export to stdout)', - action='store', - default=False) + parser.add_argument( + "--json", + help='Export the results as a JSON file ("--json -" to export to stdout)', + action="store", + default=False, + ) # Add default arguments from crytic-compile cryticparser.init(parser) return parser.parse_args() + def _log_error(err, args): if args.json: output_to_json(args.json, str(err), {"upgradeability-check": []}) logger.error(err) + def main(): args = parse_args() @@ -76,7 +81,7 @@ def main(): contract = slither.get_contract_from_name(args.contract_name) if not contract: - err = f'Contract not found: {args.contract_name}' + err = f"Contract not found: {args.contract_name}" _log_error(err, args) return # First elem is the function, second is the event @@ -87,7 +92,7 @@ def main(): ADDITIONAL_CHECKS[args.erc.upper()](contract, ret) else: - err = f'Incorrect ERC selected {args.erc}' + err = f"Incorrect ERC selected {args.erc}" _log_error(err, args) return @@ -95,5 +100,5 @@ def main(): output_to_json(args.json, None, {"upgradeability-check": ret}) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/slither/tools/erc_conformance/erc/erc20.py b/slither/tools/erc_conformance/erc/erc20.py index 25473bc84..720b08322 100644 --- a/slither/tools/erc_conformance/erc/erc20.py +++ b/slither/tools/erc_conformance/erc/erc20.py @@ -6,21 +6,25 @@ logger = logging.getLogger("Slither-conformance") def approval_race_condition(contract, ret): - increaseAllowance = contract.get_function_from_signature('increaseAllowance(address,uint256)') + increaseAllowance = contract.get_function_from_signature("increaseAllowance(address,uint256)") if not increaseAllowance: - increaseAllowance = contract.get_function_from_signature('safeIncreaseAllowance(address,uint256)') + increaseAllowance = contract.get_function_from_signature( + "safeIncreaseAllowance(address,uint256)" + ) if increaseAllowance: - txt = f'\t[✓] {contract.name} has {increaseAllowance.full_name}' + txt = f"\t[✓] {contract.name} has {increaseAllowance.full_name}" logger.info(txt) else: - txt = f'\t[ ] {contract.name} is not protected for the ERC20 approval race condition' + txt = f"\t[ ] {contract.name} is not protected for the ERC20 approval race condition" logger.info(txt) lack_of_erc20_race_condition_protection = output.Output(txt) lack_of_erc20_race_condition_protection.add(contract) - ret["lack_of_erc20_race_condition_protection"].append(lack_of_erc20_race_condition_protection.data) + ret["lack_of_erc20_race_condition_protection"].append( + lack_of_erc20_race_condition_protection.data + ) def check_erc20(contract, ret, explored=None): diff --git a/slither/tools/erc_conformance/erc/ercs.py b/slither/tools/erc_conformance/erc/ercs.py index 5af000357..334b2a408 100644 --- a/slither/tools/erc_conformance/erc/ercs.py +++ b/slither/tools/erc_conformance/erc/ercs.py @@ -22,13 +22,15 @@ def _check_signature(erc_function, contract, ret): # The check on state variable is needed until we have a better API to handle state variable getters state_variable_as_function = contract.get_state_variable_from_name(name) - if not state_variable_as_function or not state_variable_as_function.visibility in ['public', 'external']: + if not state_variable_as_function or not state_variable_as_function.visibility in [ + "public", + "external", + ]: txt = f'[ ] {sig} is missing {"" if required else "(optional)"}' logger.info(txt) - missing_func = output.Output(txt, additional_fields={ - "function": sig, - "required": required - }) + missing_func = output.Output( + txt, additional_fields={"function": sig, "required": required} + ) missing_func.add(contract) ret["missing_function"].append(missing_func.data) return @@ -38,10 +40,9 @@ def _check_signature(erc_function, contract, ret): if types != parameters: txt = f'[ ] {sig} is missing {"" if required else "(optional)"}' logger.info(txt) - missing_func = output.Output(txt, additional_fields={ - "function": sig, - "required": required - }) + missing_func = output.Output( + txt, additional_fields={"function": sig, "required": required} + ) missing_func.add(contract) ret["missing_function"].append(missing_func.data) return @@ -53,45 +54,51 @@ def _check_signature(erc_function, contract, ret): function_return_type = function.return_type function_view = function.view - txt = f'[✓] {sig} is present' + txt = f"[✓] {sig} is present" logger.info(txt) if function_return_type: - function_return_type = ','.join([str(x) for x in function_return_type]) + function_return_type = ",".join([str(x) for x in function_return_type]) if function_return_type == return_type: - txt = f'\t[✓] {sig} -> () (correct return value)' + txt = f"\t[✓] {sig} -> () (correct return value)" logger.info(txt) else: - txt = f'\t[ ] {sig} -> () should return {return_type}' + txt = f"\t[ ] {sig} -> () should return {return_type}" logger.info(txt) - incorrect_return = output.Output(txt, additional_fields={ - "expected_return_type": return_type, - "actual_return_type": function_return_type - }) + incorrect_return = output.Output( + txt, + additional_fields={ + "expected_return_type": return_type, + "actual_return_type": function_return_type, + }, + ) incorrect_return.add(function) ret["incorrect_return_type"].append(incorrect_return.data) elif not return_type: - txt = f'\t[✓] {sig} -> () (correct return type)' + txt = f"\t[✓] {sig} -> () (correct return type)" logger.info(txt) else: - txt = f'\t[ ] {sig} -> () should return {return_type}' + txt = f"\t[ ] {sig} -> () should return {return_type}" logger.info(txt) - incorrect_return = output.Output(txt, additional_fields={ - "expected_return_type": return_type, - "actual_return_type": function_return_type - }) + incorrect_return = output.Output( + txt, + additional_fields={ + "expected_return_type": return_type, + "actual_return_type": function_return_type, + }, + ) incorrect_return.add(function) ret["incorrect_return_type"].append(incorrect_return.data) if view: if function_view: - txt = f'\t[✓] {sig} is view' + txt = f"\t[✓] {sig} is view" logger.info(txt) else: - txt = f'\t[ ] {sig} should be view' + txt = f"\t[ ] {sig} should be view" logger.info(txt) should_be_view = output.Output(txt) @@ -103,12 +110,12 @@ def _check_signature(erc_function, contract, ret): event_sig = f'{event.name}({",".join(event.parameters)})' if not function: - txt = f'\t[ ] Must emit be view {event_sig}' + txt = f"\t[ ] Must emit be view {event_sig}" logger.info(txt) - missing_event_emmited = output.Output(txt, additional_fields={ - "missing_event": event_sig - }) + missing_event_emmited = output.Output( + txt, additional_fields={"missing_event": event_sig} + ) missing_event_emmited.add(function) ret["missing_event_emmited"].append(missing_event_emmited.data) @@ -121,15 +128,15 @@ def _check_signature(erc_function, contract, ret): event_found = True break if event_found: - txt = f'\t[✓] {event_sig} is emitted' + txt = f"\t[✓] {event_sig} is emitted" logger.info(txt) else: - txt = f'\t[ ] Must emit be view {event_sig}' + txt = f"\t[ ] Must emit be view {event_sig}" logger.info(txt) - missing_event_emmited = output.Output(txt, additional_fields={ - "missing_event": event_sig - }) + missing_event_emmited = output.Output( + txt, additional_fields={"missing_event": event_sig} + ) missing_event_emmited.add(function) ret["missing_event_emmited"].append(missing_event_emmited.data) @@ -143,31 +150,27 @@ def _check_events(erc_event, contract, ret): event = contract.get_event_from_signature(sig) if not event: - txt = f'[ ] {sig} is missing' + txt = f"[ ] {sig} is missing" logger.info(txt) - missing_event = output.Output(txt, additional_fields={ - "event": sig - }) + missing_event = output.Output(txt, additional_fields={"event": sig}) missing_event.add(contract) ret["missing_event"].append(missing_event.data) return - txt = f'[✓] {sig} is present' + txt = f"[✓] {sig} is present" logger.info(txt) for i, index in enumerate(indexes): if index: if event.elems[i].indexed: - txt = f'\t[✓] parameter {i} is indexed' + txt = f"\t[✓] parameter {i} is indexed" logger.info(txt) else: - txt = f'\t[ ] parameter {i} should be indexed' + txt = f"\t[ ] parameter {i} should be indexed" logger.info(txt) - missing_event_index = output.Output(txt, additional_fields={ - "missing_index": i - }) + missing_event_index = output.Output(txt, additional_fields={"missing_index": i}) missing_event_index.add_event(event) ret["missing_event_index"].append(missing_event_index.data) @@ -179,16 +182,16 @@ def generic_erc_checks(contract, erc_functions, erc_events, ret, explored=None): explored.add(contract) - logger.info(f'# Check {contract.name}\n') + logger.info(f"# Check {contract.name}\n") - logger.info(f'## Check functions') + logger.info(f"## Check functions") for erc_function in erc_functions: _check_signature(erc_function, contract, ret) - logger.info(f'\n## Check events') + logger.info(f"\n## Check events") for erc_event in erc_events: _check_events(erc_event, contract, ret) - logger.info('\n') + logger.info("\n") for derived_contract in contract.derived_contracts: generic_erc_checks(derived_contract, erc_functions, erc_events, ret, explored) diff --git a/slither/tools/kspec_coverage/__main__.py b/slither/tools/kspec_coverage/__main__.py index 47dc5c9fa..33bd3a162 100644 --- a/slither/tools/kspec_coverage/__main__.py +++ b/slither/tools/kspec_coverage/__main__.py @@ -11,35 +11,44 @@ logger.setLevel(logging.INFO) ch = logging.StreamHandler() ch.setLevel(logging.INFO) -formatter = logging.Formatter('%(message)s') +formatter = logging.Formatter("%(message)s") logger.addHandler(ch) logger.handlers[0].setFormatter(formatter) logger.propagate = False + def parse_args(): """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='slither-kspec-coverage', - usage='slither-kspec-coverage contract.sol kspec.md') - - parser.add_argument('contract', help='The filename of the contract or truffle directory to analyze.') - parser.add_argument('kspec', help='The filename of the Klab spec markdown for the analyzed contract(s)') - - parser.add_argument('--version', help='displays the current version', version='0.1.0',action='version') - parser.add_argument('--json', - help='Export the results as a JSON file ("--json -" to export to stdout)', - action='store', - default=False + parser = argparse.ArgumentParser( + description="slither-kspec-coverage", usage="slither-kspec-coverage contract.sol kspec.md" + ) + + parser.add_argument( + "contract", help="The filename of the contract or truffle directory to analyze." + ) + parser.add_argument( + "kspec", help="The filename of the Klab spec markdown for the analyzed contract(s)" ) - cryticparser.init(parser) - - if len(sys.argv) < 2: - parser.print_help(sys.stderr) + parser.add_argument( + "--version", help="displays the current version", version="0.1.0", action="version" + ) + parser.add_argument( + "--json", + help='Export the results as a JSON file ("--json -" to export to stdout)', + action="store", + default=False, + ) + + cryticparser.init(parser) + + if len(sys.argv) < 2: + parser.print_help(sys.stderr) sys.exit(1) - + return parser.parse_args() @@ -53,6 +62,7 @@ def main(): args = parse_args() kspec_coverage(args) - -if __name__ == '__main__': + + +if __name__ == "__main__": main() diff --git a/slither/tools/kspec_coverage/analysis.py b/slither/tools/kspec_coverage/analysis.py index d2daf03d0..08e42ad76 100755 --- a/slither/tools/kspec_coverage/analysis.py +++ b/slither/tools/kspec_coverage/analysis.py @@ -7,25 +7,22 @@ from slither.utils.colors import yellow, green, red from slither.utils import output logging.basicConfig(level=logging.WARNING) -logger = logging.getLogger('Slither.kspec') +logger = logging.getLogger("Slither.kspec") def _refactor_type(type): - return { - 'uint': 'uint256', - 'int': 'int256' - }.get(type, type) + return {"uint": "uint256", "int": "int256"}.get(type, type) def _get_all_covered_kspec_functions(target): # Create a set of our discovered functions which are covered covered_functions = set() - BEHAVIOUR_PATTERN = re.compile('behaviour\s+(\S+)\s+of\s+(\S+)') - INTERFACE_PATTERN = re.compile('interface\s+([^\r\n]+)') + BEHAVIOUR_PATTERN = re.compile("behaviour\s+(\S+)\s+of\s+(\S+)") + INTERFACE_PATTERN = re.compile("interface\s+([^\r\n]+)") # Read the file contents - with open(target, 'r', encoding='utf8') as target_file: + with open(target, "r", encoding="utf8") as target_file: lines = target_file.readlines() # Loop for each line, if a line matches our behaviour regex, and the next one matches our interface regex, @@ -38,10 +35,12 @@ def _get_all_covered_kspec_functions(target): match = INTERFACE_PATTERN.match(lines[i + 1]) if match: function_full_name = match.groups()[0] - start, end = function_full_name.index('(') + 1, function_full_name.index(')') - function_arguments = function_full_name[start:end].split(',') - function_arguments = [_refactor_type(arg.strip().split(' ')[0]) for arg in function_arguments] - function_full_name = function_full_name[:start] + ','.join(function_arguments) + ')' + start, end = function_full_name.index("(") + 1, function_full_name.index(")") + function_arguments = function_full_name[start:end].split(",") + function_arguments = [ + _refactor_type(arg.strip().split(" ")[0]) for arg in function_arguments + ] + function_full_name = function_full_name[:start] + ",".join(function_arguments) + ")" covered_functions.add((contract_name, function_full_name)) i += 1 i += 1 @@ -50,14 +49,25 @@ def _get_all_covered_kspec_functions(target): def _get_slither_functions(slither): # Use contract == contract_declarer to avoid dupplicate - all_functions_declared = [f for f in slither.functions if (f.contract == f.contract_declarer - and f.is_implemented - and not f.is_constructor - and not f.is_constructor_variables)] + all_functions_declared = [ + f + for f in slither.functions + if ( + f.contract == f.contract_declarer + and f.is_implemented + and not f.is_constructor + and not f.is_constructor_variables + ) + ] # Use list(set()) because same state variable instances can be shared accross contracts # TODO: integrate state variables - all_functions_declared += list(set([s for s in slither.state_variables if s.visibility in ['public', 'external']])) - slither_functions = {(function.contract.name, function.full_name): function for function in all_functions_declared} + all_functions_declared += list( + set([s for s in slither.state_variables if s.visibility in ["public", "external"]]) + ) + slither_functions = { + (function.contract.name, function.full_name): function + for function in all_functions_declared + } return slither_functions @@ -110,35 +120,42 @@ def _run_coverage_analysis(args, slither, kspec_functions): else: kspec_missing.append(slither_func) - logger.info('## Check for functions coverage') + logger.info("## Check for functions coverage") json_kspec_present = _generate_output(kspec_present, "[✓]", green, args.json) - json_kspec_missing_functions = _generate_output([f for f in kspec_missing if isinstance(f, Function)], - "[ ] (Missing function)", - red, - args.json) - json_kspec_missing_variables = _generate_output([f for f in kspec_missing if isinstance(f, Variable)], - "[ ] (Missing variable)", - yellow, - args.json) - json_kspec_unresolved = _generate_output_unresolved(kspec_functions_unresolved, - "[ ] (Unresolved)", - yellow, - args.json) + json_kspec_missing_functions = _generate_output( + [f for f in kspec_missing if isinstance(f, Function)], + "[ ] (Missing function)", + red, + args.json, + ) + json_kspec_missing_variables = _generate_output( + [f for f in kspec_missing if isinstance(f, Variable)], + "[ ] (Missing variable)", + yellow, + args.json, + ) + json_kspec_unresolved = _generate_output_unresolved( + kspec_functions_unresolved, "[ ] (Unresolved)", yellow, args.json + ) # Handle unresolved kspecs if args.json: - output.output_to_json(args.json, None, { - "functions_present": json_kspec_present, - "functions_missing": json_kspec_missing_functions, - "variables_missing": json_kspec_missing_variables, - "functions_unresolved": json_kspec_unresolved - }) + output.output_to_json( + args.json, + None, + { + "functions_present": json_kspec_present, + "functions_missing": json_kspec_missing_functions, + "variables_missing": json_kspec_missing_variables, + "functions_unresolved": json_kspec_unresolved, + }, + ) def run_analysis(args, slither, kspec): # Get all of our kspec'd functions (tuple(contract_name, function_name)). - if ',' in kspec: - kspecs = kspec.split(',') + if "," in kspec: + kspecs = kspec.split(",") kspec_functions = set() for kspec in kspecs: kspec_functions |= _get_all_covered_kspec_functions(kspec) diff --git a/slither/tools/kspec_coverage/kspec_coverage.py b/slither/tools/kspec_coverage/kspec_coverage.py index 2ee25477f..86b59be53 100755 --- a/slither/tools/kspec_coverage/kspec_coverage.py +++ b/slither/tools/kspec_coverage/kspec_coverage.py @@ -1,6 +1,7 @@ from slither.tools.kspec_coverage.analysis import run_analysis from slither import Slither + def kspec_coverage(args): contract = args.contract @@ -10,5 +11,3 @@ def kspec_coverage(args): # Run the analysis on the Klab specs run_analysis(args, slither, kspec) - - diff --git a/slither/tools/possible_paths/__main__.py b/slither/tools/possible_paths/__main__.py index 70aea8003..c13a3f390 100644 --- a/slither/tools/possible_paths/__main__.py +++ b/slither/tools/possible_paths/__main__.py @@ -9,18 +9,21 @@ from crytic_compile import cryticparser logging.basicConfig() logging.getLogger("Slither").setLevel(logging.INFO) + def parse_args(): """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='PossiblePaths', - usage='possible_paths.py filename [contract.function targets]') + parser = argparse.ArgumentParser( + description="PossiblePaths", usage="possible_paths.py filename [contract.function targets]" + ) - parser.add_argument('filename', - help='The filename of the contract or truffle directory to analyze.') + parser.add_argument( + "filename", help="The filename of the contract or truffle directory to analyze." + ) - parser.add_argument('targets', nargs='+') + parser.add_argument("targets", nargs="+") cryticparser.init(parser) @@ -62,12 +65,16 @@ def main(): print("\n") # Format all function paths. - reaching_paths_str = [' -> '.join([f"{f.canonical_name}" for f in reaching_path]) for reaching_path in reaching_paths] + reaching_paths_str = [ + " -> ".join([f"{f.canonical_name}" for f in reaching_path]) + for reaching_path in reaching_paths + ] # Print a sorted list of all function paths which can reach the targets. print(f"The following paths reach the specified targets:") for reaching_path in sorted(reaching_paths_str): print(f"{reaching_path}\n") -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/slither/tools/possible_paths/possible_paths.py b/slither/tools/possible_paths/possible_paths.py index e638b00ad..8137f3d1b 100644 --- a/slither/tools/possible_paths/possible_paths.py +++ b/slither/tools/possible_paths/possible_paths.py @@ -1,4 +1,5 @@ -class ResolveFunctionException(Exception): pass +class ResolveFunctionException(Exception): + pass def resolve_function(slither, contract_name, function_name): @@ -16,11 +17,15 @@ def resolve_function(slither, contract_name, function_name): raise ResolveFunctionException(f"Could not resolve target contract: {contract_name}") # Obtain the target function - target_function = next((function for function in contract.functions if function.name == function_name), None) + target_function = next( + (function for function in contract.functions if function.name == function_name), None + ) # Verify we have resolved the function specified. if target_function is None: - raise ResolveFunctionException(f"Could not resolve target function: {contract_name}.{function_name}") + raise ResolveFunctionException( + f"Could not resolve target function: {contract_name}.{function_name}" + ) # Add the resolved function to the new list. return target_function @@ -44,17 +49,23 @@ def resolve_functions(slither, functions): for item in functions: if isinstance(item, str): # If the item is a single string, we assume it is of form 'ContractName.FunctionName'. - parts = item.split('.') + parts = item.split(".") if len(parts) < 2: - raise ResolveFunctionException("Provided string descriptor must be of form 'ContractName.FunctionName'") + raise ResolveFunctionException( + "Provided string descriptor must be of form 'ContractName.FunctionName'" + ) resolved.append(resolve_function(slither, parts[0], parts[1])) elif isinstance(item, tuple): # If the item is a tuple, it should be a 2-tuple providing contract and function names. if len(item) != 2: - raise ResolveFunctionException("Provided tuple descriptor must provide a contract and function name.") + raise ResolveFunctionException( + "Provided tuple descriptor must provide a contract and function name." + ) resolved.append(resolve_function(slither, item[0], item[1])) else: - raise ResolveFunctionException(f"Unexpected function descriptor type to resolve in list: {type(item)}") + raise ResolveFunctionException( + f"Unexpected function descriptor type to resolve in list: {type(item)}" + ) # Return the resolved list. return resolved @@ -66,9 +77,12 @@ def all_function_definitions(function): :param function: The function to obtain all definitions at and beneath. :return: Returns a list composed of the provided function definition and any base definitions. """ - return [function] + [f for c in function.contract.inheritance - for f in c.functions_and_modifiers_declared - if f.full_name == function.full_name] + return [function] + [ + f + for c in function.contract.inheritance + for f in c.functions_and_modifiers_declared + if f.full_name == function.full_name + ] def __find_target_paths(slither, target_function, current_path=[]): @@ -102,7 +116,7 @@ def __find_target_paths(slither, target_function, current_path=[]): results = results.union(path_results) # If this path is external accessible from this point, we add the current path to the list. - if target_function.visibility in ['public', 'external'] and len(current_path) > 1: + if target_function.visibility in ["public", "external"] and len(current_path) > 1: results.add(tuple(current_path)) return results @@ -122,6 +136,3 @@ def find_target_paths(slither, target_functions): results = results.union(__find_target_paths(slither, target_function)) return results - - - diff --git a/slither/tools/properties/__main__.py b/slither/tools/properties/__main__.py index 664d5c35a..25685fb6b 100644 --- a/slither/tools/properties/__main__.py +++ b/slither/tools/properties/__main__.py @@ -16,20 +16,21 @@ logging.getLogger("Slither").setLevel(logging.INFO) logger = logging.getLogger("Slither") ch = logging.StreamHandler() ch.setLevel(logging.INFO) -formatter = logging.Formatter('%(message)s') +formatter = logging.Formatter("%(message)s") logger.addHandler(ch) logger.handlers[0].setFormatter(formatter) logger.propagate = False def _all_scenarios(): - txt = '\n' - txt += '#################### ERC20 ####################\n' + txt = "\n" + txt += "#################### ERC20 ####################\n" for k, value in ERC20_PROPERTIES.items(): - txt += f'{k} - {value.description}\n' + txt += f"{k} - {value.description}\n" return txt + def _all_properties(): table = MyPrettyTable(["Num", "Description", "Scenario"]) idx = 0 @@ -39,6 +40,7 @@ def _all_properties(): idx = idx + 1 return table + class ListScenarios(argparse.Action): def __call__(self, parser, *args, **kwargs): logger.info(_all_scenarios()) @@ -56,43 +58,51 @@ def parse_args(): Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='Demo', - usage='slither-demo filename', - formatter_class=argparse.RawDescriptionHelpFormatter) - - parser.add_argument('filename', - help='The filename of the contract or truffle directory to analyze.') - - parser.add_argument('--contract', - help='The targeted contract.') - - parser.add_argument('--scenario', - help=f'Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable', - default='Transferable') - - parser.add_argument('--list-scenarios', - help='List available scenarios', - action=ListScenarios, - nargs=0, - default=False) - - parser.add_argument('--list-properties', - help='List available properties', - action=ListProperties, - nargs=0, - default=False) - - parser.add_argument('--address-owner', - help=f'Owner address. Default {OWNER_ADDRESS}', - default=None) - - parser.add_argument('--address-user', - help=f'Owner address. Default {USER_ADDRESS}', - default=None) - - parser.add_argument('--address-attacker', - help=f'Attacker address. Default {ATTACKER_ADDRESS}', - default=None) + parser = argparse.ArgumentParser( + description="Demo", + usage="slither-demo filename", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "filename", help="The filename of the contract or truffle directory to analyze." + ) + + parser.add_argument("--contract", help="The targeted contract.") + + parser.add_argument( + "--scenario", + help=f"Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable", + default="Transferable", + ) + + parser.add_argument( + "--list-scenarios", + help="List available scenarios", + action=ListScenarios, + nargs=0, + default=False, + ) + + parser.add_argument( + "--list-properties", + help="List available properties", + action=ListProperties, + nargs=0, + default=False, + ) + + parser.add_argument( + "--address-owner", help=f"Owner address. Default {OWNER_ADDRESS}", default=None + ) + + parser.add_argument( + "--address-user", help=f"Owner address. Default {USER_ADDRESS}", default=None + ) + + parser.add_argument( + "--address-attacker", help=f"Attacker address. Default {ATTACKER_ADDRESS}", default=None + ) # Add default arguments from crytic-compile cryticparser.init(parser) @@ -116,9 +126,9 @@ def main(): contract = slither.contracts[0] else: if args.contract is None: - logger.error(f'Specify the target: --contract ContractName') + logger.error(f"Specify the target: --contract ContractName") else: - logger.error(f'{args.contract} not found') + logger.error(f"{args.contract} not found") return addresses = Addresses(args.address_owner, args.address_user, args.address_attacker) @@ -126,5 +136,5 @@ def main(): generate_erc20(contract, args.scenario, addresses) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/slither/tools/properties/addresses/address.py b/slither/tools/properties/addresses/address.py index e183fd3f7..2068bca23 100644 --- a/slither/tools/properties/addresses/address.py +++ b/slither/tools/properties/addresses/address.py @@ -8,8 +8,12 @@ ATTACKER_ADDRESS = "0xC5fdf4076b8F3A5357c5E395ab970B5B54098Fef" class Addresses: - - def __init__(self, owner: Optional[str] = None, user: Optional[str] = None, attacker: Optional[str] = None): + def __init__( + self, + owner: Optional[str] = None, + user: Optional[str] = None, + attacker: Optional[str] = None, + ): self.owner = owner if owner else OWNER_ADDRESS self.user = user if user else USER_ADDRESS self.attacker = attacker if attacker else ATTACKER_ADDRESS diff --git a/slither/tools/properties/platforms/echidna.py b/slither/tools/properties/platforms/echidna.py index f02783429..6ab372cfe 100644 --- a/slither/tools/properties/platforms/echidna.py +++ b/slither/tools/properties/platforms/echidna.py @@ -11,11 +11,11 @@ def generate_echidna_config(output_dir: Path, addresses: Addresses) -> str: :param addresses: :return: """ - content = 'prefix: crytic_\n' + content = "prefix: crytic_\n" content += f'deployer: "{addresses.owner}"\n' content += f'sender: ["{addresses.user}", "{addresses.attacker}"]\n' content += f'psender: "{addresses.user}"\n' - content += 'coverage: true\n' - filename = 'echidna_config.yaml' + content += "coverage: true\n" + filename = "echidna_config.yaml" write_file(output_dir, filename, content) return filename diff --git a/slither/tools/properties/platforms/truffle.py b/slither/tools/properties/platforms/truffle.py index 0715ef7eb..48627fab4 100644 --- a/slither/tools/properties/platforms/truffle.py +++ b/slither/tools/properties/platforms/truffle.py @@ -7,21 +7,21 @@ from slither.tools.properties.addresses.address import Addresses from slither.tools.properties.properties.properties import PropertyReturn, Property, PropertyCaller from slither.tools.properties.utils import write_file -PATTERN_TRUFFLE_MIGRATION = re.compile('^[0-9]*_') +PATTERN_TRUFFLE_MIGRATION = re.compile("^[0-9]*_") logger = logging.getLogger("Slither") def _extract_caller(p: PropertyCaller): if p == PropertyCaller.OWNER: - return ['owner'] + return ["owner"] if p == PropertyCaller.SENDER: - return ['user'] + return ["user"] if p == PropertyCaller.ATTACKER: - return ['attacker'] + return ["attacker"] if p == PropertyCaller.ALL: - return ['owner', 'user', 'attacker'] + return ["owner", "user", "attacker"] assert p == PropertyCaller.ANY - return ['user'] + return ["user"] def _helpers(): @@ -31,7 +31,7 @@ def _helpers(): - catchRevertThrow: check if the call revert/throw :return: """ - return ''' + return """ async function catchRevertThrowReturnFalse(promise) { try { const ret = await promise; @@ -61,12 +61,17 @@ async function catchRevertThrow(promise) { } assert(false, "Expected revert/throw/or return false"); }; -''' +""" -def generate_unit_test(test_contract: str, filename: str, - unit_tests: List[Property], output_dir: Path, - addresses: Addresses, assert_message: str = ''): +def generate_unit_test( + test_contract: str, + filename: str, + unit_tests: List[Property], + output_dir: Path, + addresses: Addresses, + assert_message: str = "", +): """ Generate unit tests files :param test_contract: @@ -88,37 +93,37 @@ def generate_unit_test(test_contract: str, filename: str, content += f'\tlet attacker = "{addresses.attacker}";\n' for unit_test in unit_tests: content += f'\tit("{unit_test.description}", async () => {{\n' - content += f'\t\tlet instance = await {test_contract}.deployed();\n' + content += f"\t\tlet instance = await {test_contract}.deployed();\n" callers = _extract_caller(unit_test.caller) if unit_test.return_type == PropertyReturn.SUCCESS: for caller in callers: - content += f'\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n' + content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n" if assert_message: content += f'\t\tassert.equal(test_{caller}, true, "{assert_message}");\n' else: - content += f'\t\tassert.equal(test_{caller}, true);\n' + content += f"\t\tassert.equal(test_{caller}, true);\n" elif unit_test.return_type == PropertyReturn.FAIL: for caller in callers: - content += f'\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n' + content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n" if assert_message: content += f'\t\tassert.equal(test_{caller}, false, "{assert_message}");\n' else: - content += f'\t\tassert.equal(test_{caller}, false);\n' + content += f"\t\tassert.equal(test_{caller}, false);\n" elif unit_test.return_type == PropertyReturn.FAIL_OR_THROW: for caller in callers: - content += f'\t\tawait catchRevertThrowReturnFalse(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n' + content += f"\t\tawait catchRevertThrowReturnFalse(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n" elif unit_test.return_type == PropertyReturn.THROW: callers = _extract_caller(unit_test.caller) for caller in callers: - content += f'\t\tawait catchRevertThrow(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n' - content += '\t});\n' + content += f"\t\tawait catchRevertThrow(instance.{unit_test.name[:-2]}.call({{from: {caller}}}));\n" + content += "\t});\n" - content += '});\n' + content += "});\n" - output_dir = Path(output_dir, 'test') + output_dir = Path(output_dir, "test") output_dir.mkdir(exist_ok=True) - output_dir = Path(output_dir, 'crytic') + output_dir = Path(output_dir, "crytic") output_dir.mkdir(exist_ok=True) write_file(output_dir, filename, content) @@ -133,28 +138,31 @@ def generate_migration(test_contract: str, output_dir: Path, owner_address: str) :param owner_address: :return: """ - content = f'''{test_contract} = artifacts.require("{test_contract}"); + content = f"""{test_contract} = artifacts.require("{test_contract}"); module.exports = function(deployer) {{ deployer.deploy({test_contract}, {{from: "{owner_address}"}}); }}; -''' +""" - output_dir = Path(output_dir, 'migrations') + output_dir = Path(output_dir, "migrations") output_dir.mkdir(exist_ok=True) - migration_files = [js_file for js_file in output_dir.iterdir() if js_file.suffix == '.js' - and PATTERN_TRUFFLE_MIGRATION.match(js_file.name)] + migration_files = [ + js_file + for js_file in output_dir.iterdir() + if js_file.suffix == ".js" and PATTERN_TRUFFLE_MIGRATION.match(js_file.name) + ] idx = len(migration_files) - filename = f'{idx + 1}_{test_contract}.js' - potential_previous_filename = f'{idx}_{test_contract}.js' + filename = f"{idx + 1}_{test_contract}.js" + potential_previous_filename = f"{idx}_{test_contract}.js" for m in migration_files: if m.name == potential_previous_filename: write_file(output_dir, potential_previous_filename, content) return if test_contract in m.name: - logger.error(f'Potential conflicts with {m.name}') + logger.error(f"Potential conflicts with {m.name}") write_file(output_dir, filename, content) diff --git a/slither/tools/properties/properties/erc20.py b/slither/tools/properties/properties/erc20.py index 5c83b1168..31f6ff8cc 100644 --- a/slither/tools/properties/properties/erc20.py +++ b/slither/tools/properties/properties/erc20.py @@ -12,27 +12,37 @@ from slither.tools.properties.platforms.echidna import generate_echidna_config from slither.tools.properties.properties.ercs.erc20.properties.burn import ERC20_NotBurnable from slither.tools.properties.properties.ercs.erc20.properties.initialization import ERC20_CONFIG from slither.tools.properties.properties.ercs.erc20.properties.mint import ERC20_NotMintable -from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import ERC20_NotMintableNotBurnable -from slither.tools.properties.properties.ercs.erc20.properties.transfer import ERC20_Transferable, ERC20_Pausable +from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import ( + ERC20_NotMintableNotBurnable, +) +from slither.tools.properties.properties.ercs.erc20.properties.transfer import ( + ERC20_Transferable, + ERC20_Pausable, +) from slither.tools.properties.properties.ercs.erc20.unit_tests.truffle import generate_truffle_test from slither.tools.properties.properties.properties import property_to_solidity, Property -from slither.tools.properties.solidity.generate_properties import generate_solidity_properties, generate_test_contract, \ - generate_solidity_interface +from slither.tools.properties.solidity.generate_properties import ( + generate_solidity_properties, + generate_test_contract, + generate_solidity_interface, +) from slither.utils.colors import red, green logger = logging.getLogger("Slither") -PropertyDescription = namedtuple('PropertyDescription', ['properties', 'description']) +PropertyDescription = namedtuple("PropertyDescription", ["properties", "description"]) ERC20_PROPERTIES = { - "Transferable": PropertyDescription(ERC20_Transferable, 'Test the correct tokens transfer'), - "Pausable": PropertyDescription(ERC20_Pausable, 'Test the pausable functionality'), - "NotMintable": PropertyDescription(ERC20_NotMintable, 'Test that no one can mint tokens'), - "NotMintableNotBurnable": PropertyDescription(ERC20_NotMintableNotBurnable, - 'Test that no one can mint or burn tokens'), - "NotBurnable": PropertyDescription(ERC20_NotBurnable, 'Test that no one can burn tokens'), - "Burnable": PropertyDescription(ERC20_NotBurnable, - 'Test the burn of tokens. Require the "burn(address) returns()" function') + "Transferable": PropertyDescription(ERC20_Transferable, "Test the correct tokens transfer"), + "Pausable": PropertyDescription(ERC20_Pausable, "Test the pausable functionality"), + "NotMintable": PropertyDescription(ERC20_NotMintable, "Test that no one can mint tokens"), + "NotMintableNotBurnable": PropertyDescription( + ERC20_NotMintableNotBurnable, "Test that no one can mint or burn tokens" + ), + "NotBurnable": PropertyDescription(ERC20_NotBurnable, "Test that no one can burn tokens"), + "Burnable": PropertyDescription( + ERC20_NotBurnable, 'Test the burn of tokens. Require the "burn(address) returns()" function' + ), } @@ -54,7 +64,7 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses) :return: """ if contract.slither.crytic_compile.type not in [PlatformType.TRUFFLE, PlatformType.SOLC]: - logging.error(f'{contract.slither.crytic_compile.type} not yet supported by slither-prop') + logging.error(f"{contract.slither.crytic_compile.type} not yet supported by slither-prop") return # Check if the contract is an ERC20 contract and if the functions have the correct visibility @@ -65,7 +75,9 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses) properties = ERC20_PROPERTIES.get(type_property, None) if properties is None: - logger.error(f'{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}') + logger.error( + f"{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}" + ) return properties = properties.properties @@ -78,51 +90,53 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses) # Generate the contract containing the properties generate_solidity_interface(output_dir, addresses) - property_file = generate_solidity_properties(contract, type_property, solidity_properties, output_dir) + property_file = generate_solidity_properties( + contract, type_property, solidity_properties, output_dir + ) # Generate the Test contract initialization_recommendation = _initialization_recommendation(type_property) - contract_filename, contract_name = generate_test_contract(contract, - type_property, - output_dir, - property_file, - initialization_recommendation) + contract_filename, contract_name = generate_test_contract( + contract, type_property, output_dir, property_file, initialization_recommendation + ) # Generate Echidna config file - echidna_config_filename = generate_echidna_config(Path(contract.slither.crytic_compile.target).parent, addresses) + echidna_config_filename = generate_echidna_config( + Path(contract.slither.crytic_compile.target).parent, addresses + ) - unit_test_info = '' + unit_test_info = "" # If truffle, generate unit tests if contract.slither.crytic_compile.type == PlatformType.TRUFFLE: unit_test_info = generate_truffle_test(contract, type_property, unit_tests, addresses) - logger.info('################################################') - logger.info(green(f'Update the constructor in {Path(output_dir, contract_filename)}')) + logger.info("################################################") + logger.info(green(f"Update the constructor in {Path(output_dir, contract_filename)}")) if unit_test_info: logger.info(green(unit_test_info)) - logger.info(green('To run Echidna:')) - txt = f'\t echidna-test {contract.slither.crytic_compile.target} ' - txt += f'--contract {contract_name} --config {echidna_config_filename}' + logger.info(green("To run Echidna:")) + txt = f"\t echidna-test {contract.slither.crytic_compile.target} " + txt += f"--contract {contract_name} --config {echidna_config_filename}" logger.info(green(txt)) def _initialization_recommendation(type_property: str) -> str: - content = '' - content += '\t\t// Add below a minimal configuration:\n' - content += '\t\t// - crytic_owner must have some tokens \n' - content += '\t\t// - crytic_user must have some tokens \n' - content += '\t\t// - crytic_attacker must have some tokens \n' - if type_property in ['Pausable']: - content += '\t\t// - The contract must be paused \n' - if type_property in ['NotMintable', 'NotMintableNotBurnable']: - content += '\t\t// - The contract must not be mintable \n' - if type_property in ['NotBurnable', 'NotMintableNotBurnable']: - content += '\t\t// - The contract must not be burnable \n' - content += '\n' - content += '\n' + content = "" + content += "\t\t// Add below a minimal configuration:\n" + content += "\t\t// - crytic_owner must have some tokens \n" + content += "\t\t// - crytic_user must have some tokens \n" + content += "\t\t// - crytic_attacker must have some tokens \n" + if type_property in ["Pausable"]: + content += "\t\t// - The contract must be paused \n" + if type_property in ["NotMintable", "NotMintableNotBurnable"]: + content += "\t\t// - The contract must not be mintable \n" + if type_property in ["NotBurnable", "NotMintableNotBurnable"]: + content += "\t\t// - The contract must not be burnable \n" + content += "\n" + content += "\n" return content @@ -130,44 +144,44 @@ def _initialization_recommendation(type_property: str) -> str: # TODO: move this to crytic-compile def _platform_to_output_dir(platform: AbstractPlatform) -> Path: if platform.TYPE == PlatformType.TRUFFLE: - return Path(platform.target, 'contracts', 'crytic') + return Path(platform.target, "contracts", "crytic") if platform.TYPE == PlatformType.SOLC: return Path(platform.target).parent def _check_compatibility(contract): - errors = '' + errors = "" if not contract.is_erc20(): - errors = f'{contract} is not ERC20 compliant. Consider checking the contract with slither-check-erc' + errors = f"{contract} is not ERC20 compliant. Consider checking the contract with slither-check-erc" return errors - transfer = contract.get_function_from_signature('transfer(address,uint256)') + transfer = contract.get_function_from_signature("transfer(address,uint256)") - if transfer.visibility != 'public': - errors = f'slither-prop requires {transfer.canonical_name} to be public. Please change the visibility' + if transfer.visibility != "public": + errors = f"slither-prop requires {transfer.canonical_name} to be public. Please change the visibility" - transfer_from = contract.get_function_from_signature('transferFrom(address,address,uint256)') - if transfer_from.visibility != 'public': + transfer_from = contract.get_function_from_signature("transferFrom(address,address,uint256)") + if transfer_from.visibility != "public": if errors: - errors += '\n' - errors += f'slither-prop requires {transfer_from.canonical_name} to be public. Please change the visibility' + errors += "\n" + errors += f"slither-prop requires {transfer_from.canonical_name} to be public. Please change the visibility" - approve = contract.get_function_from_signature('approve(address,uint256)') - if approve.visibility != 'public': + approve = contract.get_function_from_signature("approve(address,uint256)") + if approve.visibility != "public": if errors: - errors += '\n' - errors += f'slither-prop requires {approve.canonical_name} to be public. Please change the visibility' + errors += "\n" + errors += f"slither-prop requires {approve.canonical_name} to be public. Please change the visibility" return errors def _get_properties(contract, properties: List[Property]) -> Tuple[str, List[Property]]: - solidity_properties = '' + solidity_properties = "" if contract.slither.crytic_compile.type == PlatformType.TRUFFLE: - solidity_properties += '\n'.join([property_to_solidity(p) for p in ERC20_CONFIG]) + solidity_properties += "\n".join([property_to_solidity(p) for p in ERC20_CONFIG]) - solidity_properties += '\n'.join([property_to_solidity(p) for p in properties]) + solidity_properties += "\n".join([property_to_solidity(p) for p in properties]) unit_tests = [p for p in properties if p.is_unit_test] return solidity_properties, unit_tests diff --git a/slither/tools/properties/properties/ercs/erc20/properties/burn.py b/slither/tools/properties/properties/ercs/erc20/properties/burn.py index 47e744995..d612abf87 100644 --- a/slither/tools/properties/properties/ercs/erc20/properties/burn.py +++ b/slither/tools/properties/properties/ercs/erc20/properties/burn.py @@ -1,30 +1,38 @@ -from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller +from slither.tools.properties.properties.properties import ( + Property, + PropertyType, + PropertyReturn, + PropertyCaller, +) ERC20_NotBurnable = [ - Property(name='crytic_supply_constant_ERC20PropertiesNotBurnable()', - description='The total supply does not decrease.', - content=''' -\t\treturn initialTotalSupply == this.totalSupply();''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ANY), + Property( + name="crytic_supply_constant_ERC20PropertiesNotBurnable()", + description="The total supply does not decrease.", + content=""" +\t\treturn initialTotalSupply == this.totalSupply();""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ANY, + ), ] # Require burn(address) returns() ERC20_Burnable = [ - Property(name='crytic_supply_constant_ERC20PropertiesNotBurnable()', - description='Cannot burn more than available balance', - content=''' + Property( + name="crytic_supply_constant_ERC20PropertiesNotBurnable()", + description="Cannot burn more than available balance", + content=""" \t\tuint balance = balanceOf(msg.sender); \t\tburn(balance + 1); -\t\treturn false;''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.THROW, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL) +\t\treturn false;""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.THROW, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ) ] - diff --git a/slither/tools/properties/properties/ercs/erc20/properties/initialization.py b/slither/tools/properties/properties/ercs/erc20/properties/initialization.py index c01a1d973..5f954b512 100644 --- a/slither/tools/properties/properties/ercs/erc20/properties/initialization.py +++ b/slither/tools/properties/properties/ercs/erc20/properties/initialization.py @@ -1,65 +1,76 @@ -from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller +from slither.tools.properties.properties.properties import ( + Property, + PropertyType, + PropertyReturn, + PropertyCaller, +) ERC20_CONFIG = [ - - Property(name='init_total_supply()', - description='The total supply is correctly initialized.', - content=''' -\t\treturn this.totalSupply() >= 0 && this.totalSupply() == initialTotalSupply;''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=False, - caller=PropertyCaller.ANY), - - Property(name='init_owner_balance()', - description="Owner's balance is correctly initialized.", - content=''' -\t\treturn initialBalance_owner == this.balanceOf(crytic_owner);''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=False, - caller=PropertyCaller.ANY), - - Property(name='init_user_balance()', - description="User's balance is correctly initialized.", - content=''' -\t\treturn initialBalance_user == this.balanceOf(crytic_user);''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=False, - caller=PropertyCaller.ANY), - - Property(name='init_attacker_balance()', - description="Attacker's balance is correctly initialized.", - content=''' -\t\treturn initialBalance_attacker == this.balanceOf(crytic_attacker);''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=False, - caller=PropertyCaller.ANY), - - Property(name='init_caller_balance()', - description="All the users have a positive balance.", - content=''' -\t\treturn this.balanceOf(msg.sender) >0 ;''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=False, - caller=PropertyCaller.ALL), - + Property( + name="init_total_supply()", + description="The total supply is correctly initialized.", + content=""" +\t\treturn this.totalSupply() >= 0 && this.totalSupply() == initialTotalSupply;""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=False, + caller=PropertyCaller.ANY, + ), + Property( + name="init_owner_balance()", + description="Owner's balance is correctly initialized.", + content=""" +\t\treturn initialBalance_owner == this.balanceOf(crytic_owner);""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=False, + caller=PropertyCaller.ANY, + ), + Property( + name="init_user_balance()", + description="User's balance is correctly initialized.", + content=""" +\t\treturn initialBalance_user == this.balanceOf(crytic_user);""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=False, + caller=PropertyCaller.ANY, + ), + Property( + name="init_attacker_balance()", + description="Attacker's balance is correctly initialized.", + content=""" +\t\treturn initialBalance_attacker == this.balanceOf(crytic_attacker);""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=False, + caller=PropertyCaller.ANY, + ), + Property( + name="init_caller_balance()", + description="All the users have a positive balance.", + content=""" +\t\treturn this.balanceOf(msg.sender) >0 ;""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=False, + caller=PropertyCaller.ALL, + ), # Note: there is a potential overflow on the addition, but we dont consider it - Property(name='init_total_supply_is_balances()', - description="The total supply is the user and owner balance.", - content=''' -\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) == this.totalSupply();''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=False, - caller=PropertyCaller.ANY), -] \ No newline at end of file + Property( + name="init_total_supply_is_balances()", + description="The total supply is the user and owner balance.", + content=""" +\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) == this.totalSupply();""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=False, + caller=PropertyCaller.ANY, + ), +] diff --git a/slither/tools/properties/properties/ercs/erc20/properties/mint.py b/slither/tools/properties/properties/ercs/erc20/properties/mint.py index a1355a2ce..4aafa907e 100644 --- a/slither/tools/properties/properties/ercs/erc20/properties/mint.py +++ b/slither/tools/properties/properties/ercs/erc20/properties/mint.py @@ -1,13 +1,20 @@ -from slither.tools.properties.properties.properties import PropertyType, PropertyReturn, Property, PropertyCaller +from slither.tools.properties.properties.properties import ( + PropertyType, + PropertyReturn, + Property, + PropertyCaller, +) ERC20_NotMintable = [ - Property(name='crytic_supply_constant_ERC20PropertiesNotMintable()', - description='The total supply does not increase.', - content=''' -\t\treturn initialTotalSupply >= totalSupply();''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ANY), + Property( + name="crytic_supply_constant_ERC20PropertiesNotMintable()", + description="The total supply does not increase.", + content=""" +\t\treturn initialTotalSupply >= totalSupply();""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ANY, + ), ] diff --git a/slither/tools/properties/properties/ercs/erc20/properties/mint_and_burn.py b/slither/tools/properties/properties/ercs/erc20/properties/mint_and_burn.py index 58b9709a7..5f99838d0 100644 --- a/slither/tools/properties/properties/ercs/erc20/properties/mint_and_burn.py +++ b/slither/tools/properties/properties/ercs/erc20/properties/mint_and_burn.py @@ -1,14 +1,20 @@ -from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller +from slither.tools.properties.properties.properties import ( + Property, + PropertyType, + PropertyReturn, + PropertyCaller, +) ERC20_NotMintableNotBurnable = [ - - Property(name='crytic_supply_constant_ERC20PropertiesNotMintableNotBurnable()', - description='The total supply does not change.', - content=''' -\t\treturn initialTotalSupply == this.totalSupply();''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ANY), -] \ No newline at end of file + Property( + name="crytic_supply_constant_ERC20PropertiesNotMintableNotBurnable()", + description="The total supply does not change.", + content=""" +\t\treturn initialTotalSupply == this.totalSupply();""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ANY, + ), +] diff --git a/slither/tools/properties/properties/ercs/erc20/properties/transfer.py b/slither/tools/properties/properties/ercs/erc20/properties/transfer.py index 02860e4ab..bea208f02 100644 --- a/slither/tools/properties/properties/ercs/erc20/properties/transfer.py +++ b/slither/tools/properties/properties/ercs/erc20/properties/transfer.py @@ -1,96 +1,108 @@ -from slither.tools.properties.properties.properties import Property, PropertyType, PropertyReturn, PropertyCaller +from slither.tools.properties.properties.properties import ( + Property, + PropertyType, + PropertyReturn, + PropertyCaller, +) ERC20_Transferable = [ - - Property(name='crytic_zero_always_empty_ERC20Properties()', - description='The address 0x0 should not receive tokens.', - content=''' -\t\treturn this.balanceOf(address(0x0)) == 0;''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ANY), - - Property(name='crytic_approve_overwrites()', - description='Allowance can be changed.', - content=''' + Property( + name="crytic_zero_always_empty_ERC20Properties()", + description="The address 0x0 should not receive tokens.", + content=""" +\t\treturn this.balanceOf(address(0x0)) == 0;""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ANY, + ), + Property( + name="crytic_approve_overwrites()", + description="Allowance can be changed.", + content=""" \t\tbool approve_return; \t\tapprove_return = approve(crytic_user, 10); \t\trequire(approve_return); \t\tapprove_return = approve(crytic_user, 20); \t\trequire(approve_return); -\t\treturn this.allowance(msg.sender, crytic_user) == 20;''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_less_than_total_ERC20Properties()', - description='Balance of one user must be less or equal to the total supply.', - content=''' -\t\treturn this.balanceOf(msg.sender) <= totalSupply();''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_totalSupply_consistant_ERC20Properties()', - description='Balance of the crytic users must be less or equal to the total supply.', - content=''' -\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) <= totalSupply();''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ANY), - - Property(name='crytic_revert_transfer_to_zero_ERC20PropertiesTransferable()', - description='No one should be able to send tokens to the address 0x0 (transfer).', - content=''' +\t\treturn this.allowance(msg.sender, crytic_user) == 20;""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_less_than_total_ERC20Properties()", + description="Balance of one user must be less or equal to the total supply.", + content=""" +\t\treturn this.balanceOf(msg.sender) <= totalSupply();""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_totalSupply_consistant_ERC20Properties()", + description="Balance of the crytic users must be less or equal to the total supply.", + content=""" +\t\treturn this.balanceOf(crytic_owner) + this.balanceOf(crytic_user) + this.balanceOf(crytic_attacker) <= totalSupply();""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ANY, + ), + Property( + name="crytic_revert_transfer_to_zero_ERC20PropertiesTransferable()", + description="No one should be able to send tokens to the address 0x0 (transfer).", + content=""" \t\tif (this.balanceOf(msg.sender) == 0){ \t\t\trevert(); \t\t} -\t\treturn transfer(address(0x0), this.balanceOf(msg.sender));''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.FAIL_OR_THROW, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_revert_transferFrom_to_zero_ERC20PropertiesTransferable()', - description='No one should be able to send tokens to the address 0x0 (transferFrom).', - content=''' +\t\treturn transfer(address(0x0), this.balanceOf(msg.sender));""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.FAIL_OR_THROW, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_revert_transferFrom_to_zero_ERC20PropertiesTransferable()", + description="No one should be able to send tokens to the address 0x0 (transferFrom).", + content=""" \t\tuint balance = this.balanceOf(msg.sender); \t\tif (balance == 0){ \t\t\trevert(); \t\t} \t\tapprove(msg.sender, balance); -\t\treturn transferFrom(msg.sender, address(0x0), this.balanceOf(msg.sender));''', - type=PropertyType.CODE_QUALITY, - return_type=PropertyReturn.FAIL_OR_THROW, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_self_transferFrom_ERC20PropertiesTransferable()', - description='Self transferFrom works.', - content=''' +\t\treturn transferFrom(msg.sender, address(0x0), this.balanceOf(msg.sender));""", + type=PropertyType.CODE_QUALITY, + return_type=PropertyReturn.FAIL_OR_THROW, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_self_transferFrom_ERC20PropertiesTransferable()", + description="Self transferFrom works.", + content=""" \t\tuint balance = this.balanceOf(msg.sender); \t\tbool approve_return = approve(msg.sender, balance); \t\tbool transfer_return = transferFrom(msg.sender, msg.sender, balance); -\t\treturn (this.balanceOf(msg.sender) == balance) && approve_return && transfer_return;''', - type=PropertyType.HIGH_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_self_transferFrom_to_other_ERC20PropertiesTransferable()', - description='transferFrom works.', - content=''' +\t\treturn (this.balanceOf(msg.sender) == balance) && approve_return && transfer_return;""", + type=PropertyType.HIGH_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_self_transferFrom_to_other_ERC20PropertiesTransferable()", + description="transferFrom works.", + content=""" \t\tuint balance = this.balanceOf(msg.sender); \t\tbool approve_return = approve(msg.sender, balance); \t\taddress other = crytic_user; @@ -98,29 +110,30 @@ ERC20_Transferable = [ \t\t\tother = crytic_owner; \t\t} \t\tbool transfer_return = transferFrom(msg.sender, other, balance); -\t\treturn (this.balanceOf(msg.sender) == 0) && approve_return && transfer_return;''', - type=PropertyType.HIGH_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - - Property(name='crytic_self_transfer_ERC20PropertiesTransferable()', - description='Self transfer works.', - content=''' +\t\treturn (this.balanceOf(msg.sender) == 0) && approve_return && transfer_return;""", + type=PropertyType.HIGH_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_self_transfer_ERC20PropertiesTransferable()", + description="Self transfer works.", + content=""" \t\tuint balance = this.balanceOf(msg.sender); \t\tbool transfer_return = transfer(msg.sender, balance); -\t\treturn (this.balanceOf(msg.sender) == balance) && transfer_return;''', - type=PropertyType.HIGH_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_transfer_to_other_ERC20PropertiesTransferable()', - description='transfer works.', - content=''' +\t\treturn (this.balanceOf(msg.sender) == balance) && transfer_return;""", + type=PropertyType.HIGH_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_transfer_to_other_ERC20PropertiesTransferable()", + description="transfer works.", + content=""" \t\tuint balance = this.balanceOf(msg.sender); \t\taddress other = crytic_user; \t\tif (other == msg.sender) { @@ -130,74 +143,76 @@ ERC20_Transferable = [ \t\t\tbool transfer_other = transfer(other, 1); \t\t\treturn (this.balanceOf(msg.sender) == balance-1) && (this.balanceOf(other) >= 1) && transfer_other; \t\t} -\t\treturn true;''', - type=PropertyType.HIGH_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_revert_transfer_to_user_ERC20PropertiesTransferable()', - description='Cannot transfer more than the balance.', - content=''' +\t\treturn true;""", + type=PropertyType.HIGH_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_revert_transfer_to_user_ERC20PropertiesTransferable()", + description="Cannot transfer more than the balance.", + content=""" \t\tuint balance = this.balanceOf(msg.sender); \t\tif (balance == (2 ** 256 - 1)) \t\t\treturn true; \t\tbool transfer_other = transfer(crytic_user, balance+1); -\t\treturn transfer_other;''', - type=PropertyType.HIGH_SEVERITY, - return_type=PropertyReturn.FAIL_OR_THROW, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - +\t\treturn transfer_other;""", + type=PropertyType.HIGH_SEVERITY, + return_type=PropertyReturn.FAIL_OR_THROW, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), ] ERC20_Pausable = [ - - Property(name='crytic_revert_transfer_ERC20AlwaysTruePropertiesNotTransferable()', - description='Cannot transfer.', - content=''' -\t\treturn transfer(crytic_user, this.balanceOf(msg.sender));''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.FAIL_OR_THROW, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_revert_transferFrom_ERC20AlwaysTruePropertiesNotTransferable()', - description='Cannot execute transferFrom.', - content=''' + Property( + name="crytic_revert_transfer_ERC20AlwaysTruePropertiesNotTransferable()", + description="Cannot transfer.", + content=""" +\t\treturn transfer(crytic_user, this.balanceOf(msg.sender));""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.FAIL_OR_THROW, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_revert_transferFrom_ERC20AlwaysTruePropertiesNotTransferable()", + description="Cannot execute transferFrom.", + content=""" \t\tapprove(msg.sender, this.balanceOf(msg.sender)); -\t\ttransferFrom(msg.sender, msg.sender, this.balanceOf(msg.sender));''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.FAIL_OR_THROW, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_constantBalance()', - description='Cannot change the balance.', - content=''' -\t\treturn this.balanceOf(crytic_user) == initialBalance_user && this.balanceOf(crytic_attacker) == initialBalance_attacker;''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - - Property(name='crytic_constantAllowance()', - description='Cannot change the allowance.', - content=''' +\t\ttransferFrom(msg.sender, msg.sender, this.balanceOf(msg.sender));""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.FAIL_OR_THROW, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_constantBalance()", + description="Cannot change the balance.", + content=""" +\t\treturn this.balanceOf(crytic_user) == initialBalance_user && this.balanceOf(crytic_attacker) == initialBalance_attacker;""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), + Property( + name="crytic_constantAllowance()", + description="Cannot change the allowance.", + content=""" \t\treturn (this.allowance(crytic_user, crytic_attacker) == initialAllowance_user_attacker) && -\t\t\t(this.allowance(crytic_attacker, crytic_attacker) == initialAllowance_attacker_attacker);''', - type=PropertyType.MEDIUM_SEVERITY, - return_type=PropertyReturn.SUCCESS, - is_unit_test=True, - is_property_test=True, - caller=PropertyCaller.ALL), - +\t\t\t(this.allowance(crytic_attacker, crytic_attacker) == initialAllowance_attacker_attacker);""", + type=PropertyType.MEDIUM_SEVERITY, + return_type=PropertyReturn.SUCCESS, + is_unit_test=True, + is_property_test=True, + caller=PropertyCaller.ALL, + ), ] - - diff --git a/slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py b/slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py index 97bdf167c..614447b4d 100644 --- a/slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py +++ b/slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py @@ -11,26 +11,32 @@ from slither.tools.properties.properties.properties import Property logger = logging.getLogger("Slither") -def generate_truffle_test(contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses) -> str: - test_contract = f'Test{contract.name}{type_property}' - filename_init = f'Initialization{test_contract}.js' - filename = f'{test_contract}.js' +def generate_truffle_test( + contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses +) -> str: + test_contract = f"Test{contract.name}{type_property}" + filename_init = f"Initialization{test_contract}.js" + filename = f"{test_contract}.js" output_dir = Path(contract.slither.crytic_compile.target) generate_migration(test_contract, output_dir, addresses.owner) - generate_unit_test(test_contract, - filename_init, - ERC20_CONFIG, - output_dir, - addresses, - f'Check the constructor of {test_contract}') - - generate_unit_test(test_contract, filename, unit_tests, output_dir, addresses,) - - log_info = '\n' - log_info += 'To run the unit tests:\n' + generate_unit_test( + test_contract, + filename_init, + ERC20_CONFIG, + output_dir, + addresses, + f"Check the constructor of {test_contract}", + ) + + generate_unit_test( + test_contract, filename, unit_tests, output_dir, addresses, + ) + + log_info = "\n" + log_info += "To run the unit tests:\n" log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename_init)}\n" log_info += f"\ttruffle test {Path(output_dir, 'test', 'crytic', filename)}\n" return log_info diff --git a/slither/tools/properties/properties/properties.py b/slither/tools/properties/properties/properties.py index 3280325d8..f90cde7be 100644 --- a/slither/tools/properties/properties/properties.py +++ b/slither/tools/properties/properties/properties.py @@ -36,4 +36,4 @@ class Property(NamedTuple): def property_to_solidity(p: Property): - return f'\tfunction {p.name} public returns(bool){{{p.content}\n\t}}\n' + return f"\tfunction {p.name} public returns(bool){{{p.content}\n\t}}\n" diff --git a/slither/tools/properties/solidity/generate_properties.py b/slither/tools/properties/solidity/generate_properties.py index d02b730c7..006ab24c3 100644 --- a/slither/tools/properties/solidity/generate_properties.py +++ b/slither/tools/properties/solidity/generate_properties.py @@ -9,59 +9,64 @@ from slither.tools.properties.utils import write_file logger = logging.getLogger("Slither") -def generate_solidity_properties(contract: Contract, type_property: str, solidity_properties: str, - output_dir: Path) -> Path: +def generate_solidity_properties( + contract: Contract, type_property: str, solidity_properties: str, output_dir: Path +) -> Path: solidity_import = f'import "./interfaces.sol";\n' solidity_import += f'import "../{contract.source_mapping["filename_short"]}";' - test_contract_name = f'Properties{contract.name}{type_property}' + test_contract_name = f"Properties{contract.name}{type_property}" - solidity_content = f'{solidity_import}\ncontract {test_contract_name} is CryticInterface,{contract.name}' - solidity_content += f'{{\n\n{solidity_properties}\n}}\n' + solidity_content = ( + f"{solidity_import}\ncontract {test_contract_name} is CryticInterface,{contract.name}" + ) + solidity_content += f"{{\n\n{solidity_properties}\n}}\n" - filename = f'{test_contract_name}.sol' + filename = f"{test_contract_name}.sol" write_file(output_dir, filename, solidity_content) return Path(filename) -def generate_test_contract(contract: Contract, - type_property: str, - output_dir: Path, - property_file: Path, - initialization_recommendation: str) -> Tuple[str, str]: - test_contract_name = f'Test{contract.name}{type_property}' - properties_name = f'Properties{contract.name}{type_property}' +def generate_test_contract( + contract: Contract, + type_property: str, + output_dir: Path, + property_file: Path, + initialization_recommendation: str, +) -> Tuple[str, str]: + test_contract_name = f"Test{contract.name}{type_property}" + properties_name = f"Properties{contract.name}{type_property}" - content = '' + content = "" content += f'import "./{property_file}";\n' content += f"contract {test_contract_name} is {properties_name} {{\n" - content += '\tconstructor() public{\n' - content += '\t\t// Existing addresses:\n' - content += '\t\t// - crytic_owner: If the contract has an owner, it must be crytic_owner\n' - content += '\t\t// - crytic_user: Legitimate user\n' - content += '\t\t// - crytic_attacker: Attacker\n' - content += '\t\t// \n' + content += "\tconstructor() public{\n" + content += "\t\t// Existing addresses:\n" + content += "\t\t// - crytic_owner: If the contract has an owner, it must be crytic_owner\n" + content += "\t\t// - crytic_user: Legitimate user\n" + content += "\t\t// - crytic_attacker: Attacker\n" + content += "\t\t// \n" content += initialization_recommendation - content += '\t\t// \n' - content += '\t\t// \n' - content += '\t\t// Update the following if totalSupply and balanceOf are external functions or state variables:\n\n' - content += '\t\tinitialTotalSupply = totalSupply();\n' - content += '\t\tinitialBalance_owner = balanceOf(crytic_owner);\n' - content += '\t\tinitialBalance_user = balanceOf(crytic_user);\n' - content += '\t\tinitialBalance_attacker = balanceOf(crytic_attacker);\n' + content += "\t\t// \n" + content += "\t\t// \n" + content += "\t\t// Update the following if totalSupply and balanceOf are external functions or state variables:\n\n" + content += "\t\tinitialTotalSupply = totalSupply();\n" + content += "\t\tinitialBalance_owner = balanceOf(crytic_owner);\n" + content += "\t\tinitialBalance_user = balanceOf(crytic_user);\n" + content += "\t\tinitialBalance_attacker = balanceOf(crytic_attacker);\n" - content += '\t}\n}\n' + content += "\t}\n}\n" - filename = f'{test_contract_name}.sol' + filename = f"{test_contract_name}.sol" write_file(output_dir, filename, content, allow_overwrite=False) return filename, test_contract_name def generate_solidity_interface(output_dir: Path, addresses: Addresses): - content = f''' + content = f""" contract CryticInterface{{ address internal crytic_owner = address({addresses.owner}); address internal crytic_user = address({addresses.user}); @@ -70,7 +75,7 @@ contract CryticInterface{{ uint internal initialBalance_owner; uint internal initialBalance_user; uint internal initialBalance_attacker; -}}''' +}}""" # Static file, we discard if it exists as it should never change - write_file(output_dir, 'interfaces.sol', content, discard_if_exist=True) + write_file(output_dir, "interfaces.sol", content, discard_if_exist=True) diff --git a/slither/tools/properties/utils.py b/slither/tools/properties/utils.py index 239885319..541d85712 100644 --- a/slither/tools/properties/utils.py +++ b/slither/tools/properties/utils.py @@ -6,11 +6,13 @@ from slither.utils.colors import green, yellow logger = logging.getLogger("Slither") -def write_file(output_dir: Path, - filename: str, - content: str, - allow_overwrite: bool = True, - discard_if_exist: bool = False): +def write_file( + output_dir: Path, + filename: str, + content: str, + allow_overwrite: bool = True, + discard_if_exist: bool = False, +): """ Write the content into output_dir/filename :param output_dir: @@ -25,10 +27,10 @@ def write_file(output_dir: Path, if discard_if_exist: return if not allow_overwrite: - logger.info(yellow(f'{file_to_write} already exist and will not be overwritten')) + logger.info(yellow(f"{file_to_write} already exist and will not be overwritten")) return - logger.info(yellow(f'Overwrite {file_to_write}')) + logger.info(yellow(f"Overwrite {file_to_write}")) else: - logger.info(green(f'Write {file_to_write}')) - with open(file_to_write, 'w') as f: + logger.info(green(f"Write {file_to_write}")) + with open(file_to_write, "w") as f: f.write(content) diff --git a/slither/tools/similarity/__main__.py b/slither/tools/similarity/__main__.py index 239b68b62..85f837115 100755 --- a/slither/tools/similarity/__main__.py +++ b/slither/tools/similarity/__main__.py @@ -8,62 +8,56 @@ import operator from crytic_compile import cryticparser -from .info import info -from .test import test -from .train import train -from .plot import plot +from .info import info +from .test import test +from .train import train +from .plot import plot logging.basicConfig() logger = logging.getLogger("Slither-simil") modes = ["info", "test", "train", "plot"] + def parse_args(): - parser = argparse.ArgumentParser(description='Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector') - - parser.add_argument('mode', - help="|".join(modes)) - - parser.add_argument('model', - help='model.bin') - - parser.add_argument('--filename', - action='store', - dest='filename', - help='contract.sol') - - parser.add_argument('--fname', - action='store', - dest='fname', - help='Target function') - - parser.add_argument('--ext', - action='store', - dest='ext', - help='Extension to filter contracts') - - parser.add_argument('--nsamples', - action='store', - type=int, - dest='nsamples', - help='Number of contract samples used for training') - - parser.add_argument('--ntop', - action='store', - type=int, - dest='ntop', - default=10, - help='Number of more similar contracts to show for testing') - - parser.add_argument('--input', - action='store', - dest='input', - help='File or directory used as input') - - parser.add_argument('--version', - help='displays the current version', - version="0.0", - action='version') + parser = argparse.ArgumentParser( + description="Code similarity detection tool. For usage, see https://github.com/crytic/slither/wiki/Code-Similarity-detector" + ) + + parser.add_argument("mode", help="|".join(modes)) + + parser.add_argument("model", help="model.bin") + + parser.add_argument("--filename", action="store", dest="filename", help="contract.sol") + + parser.add_argument("--fname", action="store", dest="fname", help="Target function") + + parser.add_argument("--ext", action="store", dest="ext", help="Extension to filter contracts") + + parser.add_argument( + "--nsamples", + action="store", + type=int, + dest="nsamples", + help="Number of contract samples used for training", + ) + + parser.add_argument( + "--ntop", + action="store", + type=int, + dest="ntop", + default=10, + help="Number of more similar contracts to show for testing", + ) + + parser.add_argument( + "--input", action="store", dest="input", help="File or directory used as input" + ) + + parser.add_argument( + "--version", help="displays the current version", version="0.0", action="version" + ) cryticparser.init(parser) @@ -74,6 +68,7 @@ def parse_args(): args = parser.parse_args() return args + # endregion ################################################################################### ################################################################################### @@ -81,27 +76,29 @@ def parse_args(): ################################################################################### ################################################################################### + def main(): args = parse_args() default_log = logging.INFO logger.setLevel(default_log) - + mode = args.mode if mode == "info": info(args) elif mode == "train": - train(args) + train(args) elif mode == "test": test(args) elif mode == "plot": plot(args) else: - logger.error('Invalid mode!. It should be one of these: %s' % ", ".join(modes)) + logger.error("Invalid mode!. It should be one of these: %s" % ", ".join(modes)) sys.exit(-1) -if __name__ == '__main__': + +if __name__ == "__main__": main() # endregion diff --git a/slither/tools/similarity/cache.py b/slither/tools/similarity/cache.py index de8896d02..81df40972 100644 --- a/slither/tools/similarity/cache.py +++ b/slither/tools/similarity/cache.py @@ -7,16 +7,18 @@ except ImportError: print("$ pip3 install numpy --user\n") sys.exit(-1) + def load_cache(infile, nsamples=None): cache = dict() with np.load(infile, allow_pickle=True) as data: - array = data['arr_0'][0] - for i,(x,y) in enumerate(array): + array = data["arr_0"][0] + for i, (x, y) in enumerate(array): cache[x] = y if i == nsamples: break return cache + def save_cache(cache, outfile): - np.savez(outfile,[np.array(cache)]) + np.savez(outfile, [np.array(cache)]) diff --git a/slither/tools/similarity/encode.py b/slither/tools/similarity/encode.py index 06b3691ed..3101a1562 100644 --- a/slither/tools/similarity/encode.py +++ b/slither/tools/similarity/encode.py @@ -2,16 +2,47 @@ import logging import os from slither import Slither -from slither.core.declarations import Structure, Enum, SolidityVariableComposed, SolidityVariable, Function +from slither.core.declarations import ( + Structure, + Enum, + SolidityVariableComposed, + SolidityVariable, + Function, +) from slither.core.solidity_types import ElementaryType, ArrayType, MappingType, UserDefinedType from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple from slither.core.variables.state_variable import StateVariable -from slither.slithir.operations import Assignment, Index, Member, Length, Balance, Binary, \ - Unary, Condition, NewArray, NewStructure, NewContract, NewElementaryType, \ - SolidityCall, Push, Delete, EventCall, LibraryCall, InternalDynamicCall, \ - HighLevelCall, LowLevelCall, TypeConversion, Return, Transfer, Send, Unpack, InitArray, InternalCall -from slither.slithir.variables import TemporaryVariable, TupleVariable, Constant, ReferenceVariable +from slither.slithir.operations import ( + Assignment, + Index, + AccessMember, + Length, + Balance, + Binary, + Unary, + Condition, + NewArray, + NewStructure, + NewContract, + NewElementaryType, + SolidityCall, + Push, + Delete, + EventCall, + LibraryCall, + InternalDynamicCall, + HighLevelCall, + LowLevelCall, + TypeConversion, + Return, + Transfer, + Send, + Unpack, + InitArray, + InternalCall, +) +from slither.slithir.variables import TemporaryVariable, TupleVariable, Constant, IndexVariable from .cache import load_cache simil_logger = logging.getLogger("Slither-simil") @@ -20,11 +51,12 @@ compiler_logger.setLevel(logging.CRITICAL) slither_logger = logging.getLogger("Slither") slither_logger.setLevel(logging.CRITICAL) + def parse_target(target): if target is None: return None, None - parts = target.split('.') + parts = target.split(".") if len(parts) == 1: return None, parts[0] elif len(parts) == 2: @@ -32,25 +64,27 @@ def parse_target(target): else: simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'") + def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs): r = dict() if infile.endswith(".npz"): r = load_cache(infile, nsamples=nsamples) - else: + else: contracts = load_contracts(infile, ext=ext, nsamples=nsamples) for contract in contracts: - for x,ir in encode_contract(contract, **kwargs).items(): + for x, ir in encode_contract(contract, **kwargs).items(): if ir != []: y = " ".join(ir) r[x] = vmodel.get_sentence_vector(y) return r + def load_contracts(dirname, ext=None, nsamples=None, **kwargs): r = [] walk = list(os.walk(dirname)) for x, y, files in walk: - for f in files: + for f in files: if ext is None or f.endswith(ext): r.append(x + "/".join(y) + "/" + f) @@ -60,6 +94,7 @@ def load_contracts(dirname, ext=None, nsamples=None, **kwargs): # TODO: shuffle return r[:nsamples] + def ntype(_type): if isinstance(_type, ElementaryType): _type = str(_type) @@ -79,8 +114,8 @@ def ntype(_type): else: _type = str(_type) - _type = _type.replace(" memory","") - _type = _type.replace(" storage ref","") + _type = _type.replace(" memory", "") + _type = _type.replace(" storage ref", "") if "struct" in _type: return "struct" @@ -93,92 +128,94 @@ def ntype(_type): elif "mapping" in _type: return "mapping" else: - return _type.replace(" ","_") + return _type.replace(" ", "_") + def encode_ir(ir): # operations if isinstance(ir, Assignment): - return '({}):=({})'.format(encode_ir(ir.lvalue), encode_ir(ir.rvalue)) + return "({}):=({})".format(encode_ir(ir.lvalue), encode_ir(ir.rvalue)) if isinstance(ir, Index): - return 'index({})'.format(ntype(ir._type)) - if isinstance(ir, Member): - return 'member' #.format(ntype(ir._type)) + return "index({})".format(ntype(ir._type)) + if isinstance(ir, AccessMember): + return "member" # .format(ntype(ir._type)) if isinstance(ir, Length): - return 'length' + return "length" if isinstance(ir, Balance): - return 'balance' + return "balance" if isinstance(ir, Binary): - return 'binary({})'.format(ir.type_str) + return "binary({})".format(str(ir.type)) if isinstance(ir, Unary): - return 'unary({})'.format(ir.type_str) + return "unary({})".format(str(ir.type)) if isinstance(ir, Condition): - return 'condition({})'.format(encode_ir(ir.value)) + return "condition({})".format(encode_ir(ir.value)) if isinstance(ir, NewStructure): - return 'new_structure' + return "new_structure" if isinstance(ir, NewContract): - return 'new_contract' + return "new_contract" if isinstance(ir, NewArray): - return 'new_array({})'.format(ntype(ir._array_type)) + return "new_array({})".format(ntype(ir._array_type)) if isinstance(ir, NewElementaryType): - return 'new_elementary({})'.format(ntype(ir._type)) + return "new_elementary({})".format(ntype(ir._type)) if isinstance(ir, Push): - return 'push({},{})'.format(encode_ir(ir.value), encode_ir(ir.lvalue)) + return "push({},{})".format(encode_ir(ir.value), encode_ir(ir.lvalue)) if isinstance(ir, Delete): - return 'delete({},{})'.format(encode_ir(ir.lvalue), encode_ir(ir.variable)) + return "delete({},{})".format(encode_ir(ir.lvalue), encode_ir(ir.variable)) if isinstance(ir, SolidityCall): - return 'solidity_call({})'.format(ir.function.full_name) + return "solidity_call({})".format(ir.function.full_name) if isinstance(ir, InternalCall): - return 'internal_call({})'.format(ntype(ir._type_call)) - if isinstance(ir, EventCall): # is this useful? - return 'event' + return "internal_call({})".format(ntype(ir._type_call)) + if isinstance(ir, EventCall): # is this useful? + return "event" if isinstance(ir, LibraryCall): - return 'library_call' + return "library_call" if isinstance(ir, InternalDynamicCall): - return 'internal_dynamic_call' - if isinstance(ir, HighLevelCall): # TODO: improve - return 'high_level_call' - if isinstance(ir, LowLevelCall): # TODO: improve - return 'low_level_call' + return "internal_dynamic_call" + if isinstance(ir, HighLevelCall): # TODO: improve + return "high_level_call" + if isinstance(ir, LowLevelCall): # TODO: improve + return "low_level_call" if isinstance(ir, TypeConversion): - return 'type_conversion({})'.format(ntype(ir.type)) - if isinstance(ir, Return): # this can be improved using values - return 'return' #.format(ntype(ir.type)) + return "type_conversion({})".format(ntype(ir.type)) + if isinstance(ir, Return): # this can be improved using values + return "return" # .format(ntype(ir.type)) if isinstance(ir, Transfer): - return 'transfer({})'.format(encode_ir(ir.call_value)) + return "transfer({})".format(encode_ir(ir.call_value)) if isinstance(ir, Send): - return 'send({})'.format(encode_ir(ir.call_value)) - if isinstance(ir, Unpack): # TODO: improve - return 'unpack' - if isinstance(ir, InitArray): # TODO: improve - return 'init_array' - if isinstance(ir, Function): # TODO: investigate this - return 'function_solc' + return "send({})".format(encode_ir(ir.call_value)) + if isinstance(ir, Unpack): # TODO: improve + return "unpack" + if isinstance(ir, InitArray): # TODO: improve + return "init_array" + if isinstance(ir, Function): # TODO: investigate this + return "function_solc" # variables if isinstance(ir, Constant): - return 'constant({})'.format(ntype(ir._type)) + return "constant({})".format(ntype(ir._type)) if isinstance(ir, SolidityVariableComposed): - return 'solidity_variable_composed({})'.format(ir.name) + return "solidity_variable_composed({})".format(ir.name) if isinstance(ir, SolidityVariable): - return 'solidity_variable{}'.format(ir.name) + return "solidity_variable{}".format(ir.name) if isinstance(ir, TemporaryVariable): - return 'temporary_variable' - if isinstance(ir, ReferenceVariable): - return 'reference({})'.format(ntype(ir._type)) + return "temporary_variable" + if isinstance(ir, IndexVariable): + return "reference({})".format(ntype(ir._type)) if isinstance(ir, LocalVariable): - return 'local_solc_variable({})'.format(ir._location) + return "local_solc_variable({})".format(ir._location) if isinstance(ir, StateVariable): - return 'state_solc_variable({})'.format(ntype(ir._type)) + return "state_solc_variable({})".format(ntype(ir._type)) if isinstance(ir, LocalVariableInitFromTuple): - return 'local_variable_init_tuple' + return "local_variable_init_tuple" if isinstance(ir, TupleVariable): - return 'tuple_variable' + return "tuple_variable" # default else: - simil_logger.error(type(ir),"is missing encoding!") - return '' - + simil_logger.error(type(ir), "is missing encoding!") + return "" + + def encode_contract(cfilename, **kwargs): r = dict() @@ -186,7 +223,7 @@ def encode_contract(cfilename, **kwargs): try: slither = Slither(cfilename, **kwargs) except: - simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs['solc']) + simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs["solc"]) return r # Iterate over all the contracts @@ -198,7 +235,7 @@ def encode_contract(cfilename, **kwargs): if function.nodes == [] or function.is_constructor_variables: continue - x = (cfilename,contract.name,function.name) + x = (cfilename, contract.name, function.name) r[x] = [] @@ -210,5 +247,3 @@ def encode_contract(cfilename, **kwargs): for ir in node.irs: r[x].append(encode_ir(ir)) return r - - diff --git a/slither/tools/similarity/info.py b/slither/tools/similarity/info.py index e250aa991..b577bfd93 100644 --- a/slither/tools/similarity/info.py +++ b/slither/tools/similarity/info.py @@ -3,18 +3,19 @@ import sys import os.path import traceback -from .model import load_model +from .model import load_model from .encode import parse_target, encode_contract logging.basicConfig() logger = logging.getLogger("Slither-simil") + def info(args): try: model = args.model - if os.path.isfile(model): + if os.path.isfile(model): model = load_model(model) else: model = None @@ -22,22 +23,22 @@ def info(args): filename = args.filename contract, fname = parse_target(args.fname) solc = args.solc - + if filename is None and contract is None and fname is None: - logger.info("%s uses the following words:",args.model) + logger.info("%s uses the following words:", args.model) for word in model.get_words(): logger.info(word) sys.exit(0) if filename is None or contract is None or fname is None: - logger.error('The encode mode requires filename, contract and fname parameters.') + logger.error("The encode mode requires filename, contract and fname parameters.") sys.exit(-1) irs = encode_contract(filename, **vars(args)) if len(irs) == 0: sys.exit(-1) - - x = (filename,contract,fname) + + x = (filename, contract, fname) y = " ".join(irs[x]) logger.info("Function {} in contract {} is encoded as:".format(fname, contract)) @@ -47,8 +48,6 @@ def info(args): logger.info(fvector) except Exception: - logger.error('Error in %s' % args.filename) + logger.error("Error in %s" % args.filename) logger.error(traceback.format_exc()) sys.exit(-1) - - diff --git a/slither/tools/similarity/plot.py b/slither/tools/similarity/plot.py index 05d8bf921..75ef90c15 100644 --- a/slither/tools/similarity/plot.py +++ b/slither/tools/similarity/plot.py @@ -5,7 +5,7 @@ import operator import numpy as np import random -from .model import load_model +from .model import load_model from .encode import load_and_encode, parse_target try: @@ -17,10 +17,13 @@ except ImportError: logger = logging.getLogger("Slither-simil") + def plot(args): if decomposition is None or plt is None: - logger.error("ERROR: In order to use plot mode in slither-simil, you need to install sklearn and matplotlib:") + logger.error( + "ERROR: In order to use plot mode in slither-simil, you need to install sklearn and matplotlib:" + ) logger.error("$ pip3 install sklearn matplotlib --user") sys.exit(-1) @@ -29,50 +32,50 @@ def plot(args): model = args.model model = load_model(model) filename = args.filename - #contract = args.contract + # contract = args.contract contract, fname = parse_target(args.fname) - #solc = args.solc + # solc = args.solc infile = args.input - #ext = args.filter - #nsamples = args.nsamples + # ext = args.filter + # nsamples = args.nsamples if fname is None or infile is None: - logger.error('The plot mode requieres fname and input parameters.') + logger.error("The plot mode requieres fname and input parameters.") sys.exit(-1) - logger.info('Loading data..') + logger.info("Loading data..") cache = load_and_encode(infile, **vars(args)) data = list() fs = list() - logger.info('Procesing data..') - for (f,c,n),y in cache.items(): + logger.info("Procesing data..") + for (f, c, n), y in cache.items(): if (c == contract or contract is None) and n == fname: fs.append(f) data.append(y) if len(data) == 0: - logger.error('No contract was found with function %s', fname) + logger.error("No contract was found with function %s", fname) sys.exit(-1) data = np.array(data) pca = decomposition.PCA(n_components=2) tdata = pca.fit_transform(data) - logger.info('Plotting data..') - plt.figure(figsize=(20,10)) - assert(len(tdata) == len(fs)) - for ([x,y],l) in zip(tdata, fs): + logger.info("Plotting data..") + plt.figure(figsize=(20, 10)) + assert len(tdata) == len(fs) + for ([x, y], l) in zip(tdata, fs): x = random.gauss(0, 0.01) + x y = random.gauss(0, 0.01) + y - plt.scatter(x, y, c='blue') - plt.text(x-0.001,y+0.001, l) + plt.scatter(x, y, c="blue") + plt.text(x - 0.001, y + 0.001, l) + + logger.info("Saving figure to plot.png..") + plt.savefig("plot.png", bbox_inches="tight") - logger.info('Saving figure to plot.png..') - plt.savefig('plot.png', bbox_inches='tight') - except Exception: - logger.error('Error in %s' % args.filename) + logger.error("Error in %s" % args.filename) logger.error(traceback.format_exc()) sys.exit(-1) diff --git a/slither/tools/similarity/similarity.py b/slither/tools/similarity/similarity.py index 4cc3f2b35..3cf30acda 100644 --- a/slither/tools/similarity/similarity.py +++ b/slither/tools/similarity/similarity.py @@ -1,5 +1,6 @@ import numpy as np + def similarity(v1, v2): n1 = np.linalg.norm(v1) n2 = np.linalg.norm(v2) diff --git a/slither/tools/similarity/test.py b/slither/tools/similarity/test.py index 15a39cc13..89043a5a1 100755 --- a/slither/tools/similarity/test.py +++ b/slither/tools/similarity/test.py @@ -5,50 +5,51 @@ import traceback import operator import numpy as np -from .model import load_model -from .encode import encode_contract, load_and_encode, parse_target -from .cache import save_cache +from .model import load_model +from .encode import encode_contract, load_and_encode, parse_target +from .cache import save_cache from .similarity import similarity logger = logging.getLogger("Slither-simil") + def test(args): try: model = args.model model = load_model(model) filename = args.filename - contract, fname = parse_target(args.fname) + contract, fname = parse_target(args.fname) infile = args.input ntop = args.ntop if filename is None or contract is None or fname is None or infile is None: - logger.error('The test mode requires filename, contract, fname and input parameters.') + logger.error("The test mode requires filename, contract, fname and input parameters.") sys.exit(-1) irs = encode_contract(filename, **vars(args)) if len(irs) == 0: sys.exit(-1) - y = " ".join(irs[(filename,contract,fname)]) - + y = " ".join(irs[(filename, contract, fname)]) + fvector = model.get_sentence_vector(y) cache = load_and_encode(infile, model, **vars(args)) - #save_cache("cache.npz", cache) + # save_cache("cache.npz", cache) r = dict() - for x,y in cache.items(): + for x, y in cache.items(): r[x] = similarity(fvector, y) r = sorted(r.items(), key=operator.itemgetter(1), reverse=True) logger.info("Reviewed %d functions, listing the %d most similar ones:", len(r), ntop) format_table = "{: <65} {: <20} {: <20} {: <10}" logger.info(format_table.format(*["filename", "contract", "function", "score"])) - for x,score in r[:ntop]: + for x, score in r[:ntop]: score = str(round(score, 3)) - logger.info(format_table.format(*(list(x)+[score]))) + logger.info(format_table.format(*(list(x) + [score]))) except Exception: - logger.error('Error in %s' % args.filename) + logger.error("Error in %s" % args.filename) logger.error(traceback.format_exc()) sys.exit(-1) diff --git a/slither/tools/similarity/train.py b/slither/tools/similarity/train.py index e810450a6..3052ae6c5 100755 --- a/slither/tools/similarity/train.py +++ b/slither/tools/similarity/train.py @@ -5,12 +5,13 @@ import traceback import operator import os -from .model import train_unsupervised -from .encode import encode_contract, load_contracts -from .cache import save_cache +from .model import train_unsupervised +from .encode import encode_contract, load_contracts +from .cache import save_cache logger = logging.getLogger("Slither-simil") + def train(args): try: @@ -20,35 +21,37 @@ def train(args): nsamples = args.nsamples if dirname is None: - logger.error('The train mode requires the input parameter.') + logger.error("The train mode requires the input parameter.") sys.exit(-1) contracts = load_contracts(dirname, **vars(args)) - logger.info('Saving extracted data into %s', last_data_train_filename) + logger.info("Saving extracted data into %s", last_data_train_filename) cache = [] - with open(last_data_train_filename, 'w') as f: + with open(last_data_train_filename, "w") as f: for filename in contracts: - #cache[filename] = dict() - for (filename, contract, function), ir in encode_contract(filename, **vars(args)).items(): + # cache[filename] = dict() + for (filename, contract, function), ir in encode_contract( + filename, **vars(args) + ).items(): if ir != []: x = " ".join(ir) - f.write(x+"\n") + f.write(x + "\n") cache.append((os.path.split(filename)[-1], contract, function, x)) - logger.info('Starting training') - model = train_unsupervised(input=last_data_train_filename, model='skipgram') - logger.info('Training complete') - logger.info('Saving model') + logger.info("Starting training") + model = train_unsupervised(input=last_data_train_filename, model="skipgram") + logger.info("Training complete") + logger.info("Saving model") model.save_model(model_filename) - for i,(filename, contract, function, irs) in enumerate(cache): + for i, (filename, contract, function, irs) in enumerate(cache): cache[i] = ((filename, contract, function), model.get_sentence_vector(irs)) - logger.info('Saving cache in cache.npz') + logger.info("Saving cache in cache.npz") save_cache(cache, "cache.npz") - logger.info('Done!') - + logger.info("Done!") + except Exception: - logger.error('Error in %s' % args.filename) + logger.error("Error in %s" % args.filename) logger.error(traceback.format_exc()) sys.exit(-1) diff --git a/slither/tools/slither_format/__main__.py b/slither/tools/slither_format/__main__.py index 66787ee0e..b99485236 100644 --- a/slither/tools/slither_format/__main__.py +++ b/slither/tools/slither_format/__main__.py @@ -10,63 +10,77 @@ logging.basicConfig() logger = logging.getLogger("Slither").setLevel(logging.INFO) # Slither detectors for which slither-format currently works -available_detectors = ["unused-state", - "solc-version", - "pragma", - "naming-convention", - "external-function", - "constable-states", - "constant-function-asm", - "constatnt-function-state"] +available_detectors = [ + "unused-state", + "solc-version", + "pragma", + "naming-convention", + "external-function", + "constable-states", + "constant-function-asm", + "constatnt-function-state", +] detectors_to_run = [] + def parse_args(): """ Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description='slither_format', - usage='slither_format filename') - - parser.add_argument('filename', help='The filename of the contract or truffle directory to analyze.') - parser.add_argument('--verbose-test', '-v', help='verbose mode output for testing',action='store_true',default=False) - parser.add_argument('--verbose-json', '-j', help='verbose json output',action='store_true',default=False) - parser.add_argument('--version', - help='displays the current version', - version='0.1.0', - action='version') - - parser.add_argument('--config-file', - help='Provide a config file (default: slither.config.json)', - action='store', - dest='config_file', - default='slither.config.json') - - - group_detector = parser.add_argument_group('Detectors') - group_detector.add_argument('--detect', - help='Comma-separated list of detectors, defaults to all, ' - 'available detectors: {}'.format( - ', '.join(d for d in available_detectors)), - action='store', - dest='detectors_to_run', - default='all') - - group_detector.add_argument('--exclude', - help='Comma-separated list of detectors to exclude,' - 'available detectors: {}'.format( - ', '.join(d for d in available_detectors)), - action='store', - dest='detectors_to_exclude', - default='all') - - cryticparser.init(parser) - - if len(sys.argv) == 1: - parser.print_help(sys.stderr) + parser = argparse.ArgumentParser(description="slither_format", usage="slither_format filename") + + parser.add_argument( + "filename", help="The filename of the contract or truffle directory to analyze." + ) + parser.add_argument( + "--verbose-test", + "-v", + help="verbose mode output for testing", + action="store_true", + default=False, + ) + parser.add_argument( + "--verbose-json", "-j", help="verbose json output", action="store_true", default=False + ) + parser.add_argument( + "--version", help="displays the current version", version="0.1.0", action="version" + ) + + parser.add_argument( + "--config-file", + help="Provide a config file (default: slither.config.json)", + action="store", + dest="config_file", + default="slither.config.json", + ) + + group_detector = parser.add_argument_group("Detectors") + group_detector.add_argument( + "--detect", + help="Comma-separated list of detectors, defaults to all, " + "available detectors: {}".format(", ".join(d for d in available_detectors)), + action="store", + dest="detectors_to_run", + default="all", + ) + + group_detector.add_argument( + "--exclude", + help="Comma-separated list of detectors to exclude," + "available detectors: {}".format(", ".join(d for d in available_detectors)), + action="store", + dest="detectors_to_exclude", + default="all", + ) + + cryticparser.init(parser) + + if len(sys.argv) == 1: + parser.print_help(sys.stderr) sys.exit(1) - + return parser.parse_args() @@ -80,11 +94,12 @@ def main(): read_config_file(args) - # Perform slither analysis on the given filename slither = Slither(args.filename, **vars(args)) # Format the input files based on slither analysis slither_format(slither, **vars(args)) -if __name__ == '__main__': + + +if __name__ == "__main__": main() diff --git a/slither/tools/slither_format/slither_format.py b/slither/tools/slither_format/slither_format.py index 597c17753..659b69557 100644 --- a/slither/tools/slither_format/slither_format.py +++ b/slither/tools/slither_format/slither_format.py @@ -11,27 +11,29 @@ from slither.detectors.attributes.const_functions_state import ConstantFunctions from slither.utils.colors import yellow logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('Slither.Format') +logger = logging.getLogger("Slither.Format") all_detectors = { - 'unused-state': UnusedStateVars, - 'solc-version': IncorrectSolc, - 'pragma': ConstantPragma, - 'naming-convention': NamingConvention, - 'external-function': ExternalFunction, - 'constable-states' : ConstCandidateStateVars, - 'constant-function-asm': ConstantFunctionsAsm, - 'constant-functions-state': ConstantFunctionsState + "unused-state": UnusedStateVars, + "solc-version": IncorrectSolc, + "pragma": ConstantPragma, + "naming-convention": NamingConvention, + "external-function": ExternalFunction, + "constable-states": ConstCandidateStateVars, + "constant-function-asm": ConstantFunctionsAsm, + "constant-functions-state": ConstantFunctionsState, } + def slither_format(slither, **kwargs): - '''' + """' Keyword Args: detectors_to_run (str): Comma-separated list of detectors, defaults to all - ''' + """ - detectors_to_run = choose_detectors(kwargs.get('detectors_to_run', 'all'), - kwargs.get('detectors_to_exclude', '')) + detectors_to_run = choose_detectors( + kwargs.get("detectors_to_run", "all"), kwargs.get("detectors_to_exclude", "") + ) for detector in detectors_to_run: slither.register_detector(detector) @@ -42,32 +44,32 @@ def slither_format(slither, **kwargs): detector_results = [x for x in detector_results if x] # remove empty results detector_results = [item for sublist in detector_results for item in sublist] # flatten - export = Path('crytic-export', 'patches') + export = Path("crytic-export", "patches") export.mkdir(parents=True, exist_ok=True) counter_result = 0 - logger.info(yellow('slither-format is in beta, carefully review each patch before merging it.')) + logger.info(yellow("slither-format is in beta, carefully review each patch before merging it.")) for result in detector_results: - if not 'patches' in result: + if not "patches" in result: continue one_line_description = result["description"].split("\n")[0] - export_result = Path(export, f'{counter_result}') + export_result = Path(export, f"{counter_result}") export_result.mkdir(parents=True, exist_ok=True) counter_result += 1 counter = 0 - logger.info(f'Issue: {one_line_description}') - logger.info(f'Generated: ({export_result})') + logger.info(f"Issue: {one_line_description}") + logger.info(f"Generated: ({export_result})") - for file, diff, in result['patches_diff'].items(): - filename = f'fix_{counter}.patch' + for file, diff, in result["patches_diff"].items(): + filename = f"fix_{counter}.patch" path = Path(export_result, filename) - logger.info(f'\t- {filename}') - with open(path, 'w') as f: + logger.info(f"\t- {filename}") + with open(path, "w") as f: f.write(diff) counter += 1 @@ -79,26 +81,28 @@ def slither_format(slither, **kwargs): ################################################################################### ################################################################################### + def choose_detectors(detectors_to_run, detectors_to_exclude): # If detectors are specified, run only these ones cls_detectors_to_run = [] - exclude = detectors_to_exclude.split(',') - if detectors_to_run == 'all': + exclude = detectors_to_exclude.split(",") + if detectors_to_run == "all": for d in all_detectors: if d in exclude: continue cls_detectors_to_run.append(all_detectors[d]) else: - exclude = detectors_to_exclude.split(',') - for d in detectors_to_run.split(','): + exclude = detectors_to_exclude.split(",") + for d in detectors_to_run.split(","): if d in all_detectors: if d in exclude: continue cls_detectors_to_run.append(all_detectors[d]) else: - raise Exception('Error: {} is not a detector'.format(d)) + raise Exception("Error: {} is not a detector".format(d)) return cls_detectors_to_run + # endregion ################################################################################### ################################################################################### @@ -106,6 +110,7 @@ def choose_detectors(detectors_to_run, detectors_to_exclude): ################################################################################### ################################################################################### + def print_patches(number_of_slither_results, patches): logger.info("Number of Slither results: " + str(number_of_slither_results)) number_of_patches = 0 @@ -115,39 +120,38 @@ def print_patches(number_of_slither_results, patches): for file in patches: logger.info("Patch file: " + file) for patch in patches[file]: - logger.info("Detector: " + patch['detector']) - logger.info("Old string: " + patch['old_string'].replace("\n","")) - logger.info("New string: " + patch['new_string'].replace("\n","")) - logger.info("Location start: " + str(patch['start'])) - logger.info("Location end: " + str(patch['end'])) + logger.info("Detector: " + patch["detector"]) + logger.info("Old string: " + patch["old_string"].replace("\n", "")) + logger.info("New string: " + patch["new_string"].replace("\n", "")) + logger.info("Location start: " + str(patch["start"])) + logger.info("Location end: " + str(patch["end"])) + def print_patches_json(number_of_slither_results, patches): - print('{',end='') - print("\"Number of Slither results\":" + '"' + str(number_of_slither_results) + '",') - print("\"Number of patchlets\":" + "\"" + str(len(patches)) + "\"", ',') - print("\"Patchlets\":" + '[') + print("{", end="") + print('"Number of Slither results":' + '"' + str(number_of_slither_results) + '",') + print('"Number of patchlets":' + '"' + str(len(patches)) + '"', ",") + print('"Patchlets":' + "[") for index, file in enumerate(patches): if index > 0: - print(',') - print('{',end='') - print("\"Patch file\":" + '"' + file + '",') - print("\"Number of patches\":" + "\"" + str(len(patches[file])) + "\"", ',') - print("\"Patches\":" + '[') + print(",") + print("{", end="") + print('"Patch file":' + '"' + file + '",') + print('"Number of patches":' + '"' + str(len(patches[file])) + '"', ",") + print('"Patches":' + "[") for index, patch in enumerate(patches[file]): if index > 0: - print(',') - print('{',end='') - print("\"Detector\":" + '"' + patch['detector'] + '",') - print("\"Old string\":" + '"' + patch['old_string'].replace("\n","") + '",') - print("\"New string\":" + '"' + patch['new_string'].replace("\n","") + '",') - print("\"Location start\":" + '"' + str(patch['start']) + '",') - print("\"Location end\":" + '"' + str(patch['end']) + '"') - if 'overlaps' in patch: - print("\"Overlaps\":" + "Yes") - print('}',end='') - print(']',end='') - print('}',end='') - print(']',end='') - print('}') - - + print(",") + print("{", end="") + print('"Detector":' + '"' + patch["detector"] + '",') + print('"Old string":' + '"' + patch["old_string"].replace("\n", "") + '",') + print('"New string":' + '"' + patch["new_string"].replace("\n", "") + '",') + print('"Location start":' + '"' + str(patch["start"]) + '",') + print('"Location end":' + '"' + str(patch["end"]) + '"') + if "overlaps" in patch: + print('"Overlaps":' + "Yes") + print("}", end="") + print("]", end="") + print("}", end="") + print("]", end="") + print("}") diff --git a/slither/tools/upgradeability/__main__.py b/slither/tools/upgradeability/__main__.py index 52b3be060..8903c01c5 100644 --- a/slither/tools/upgradeability/__main__.py +++ b/slither/tools/upgradeability/__main__.py @@ -12,7 +12,12 @@ from slither.utils.colors import red from slither.utils.output import output_to_json from .checks import all_checks from .checks.abstract_checks import AbstractCheck -from .utils.command_line import output_detectors_json, output_wiki, output_detectors, output_to_markdown +from .utils.command_line import ( + output_detectors_json, + output_wiki, + output_detectors, + output_to_markdown, +) logging.basicConfig() logger = logging.getLogger("Slither") @@ -21,49 +26,53 @@ logger.setLevel(logging.INFO) def parse_args(): parser = argparse.ArgumentParser( - description='Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.', - usage="slither-check-upgradeability contract.sol ContractName") - - parser.add_argument('contract.sol', help='Codebase to analyze') - parser.add_argument('ContractName', help='Contract name (logic contract)') - - parser.add_argument('--proxy-name', help='Proxy name') - parser.add_argument('--proxy-filename', help='Proxy filename (if different)') - - parser.add_argument('--new-contract-name', help='New contract name (if changed)') - parser.add_argument('--new-contract-filename', help='New implementation filename (if different)') - - parser.add_argument('--json', - help='Export the results as a JSON file ("--json -" to export to stdout)', - action='store', - default=False) - - parser.add_argument('--list-detectors', - help='List available detectors', - action=ListDetectors, - nargs=0, - default=False) - - parser.add_argument('--markdown-root', - help='URL for markdown generation', - action='store', - default="") - - parser.add_argument('--wiki-detectors', - help=argparse.SUPPRESS, - action=OutputWiki, - default=False) - - parser.add_argument('--list-detectors-json', - help=argparse.SUPPRESS, - action=ListDetectorsJson, - nargs=0, - default=False) - - parser.add_argument('--markdown', - help=argparse.SUPPRESS, - action=OutputMarkdown, - default=False) + description="Slither Upgradeability Checks. For usage information see https://github.com/crytic/slither/wiki/Upgradeability-Checks.", + usage="slither-check-upgradeability contract.sol ContractName", + ) + + parser.add_argument("contract.sol", help="Codebase to analyze") + parser.add_argument("ContractName", help="Contract name (logic contract)") + + parser.add_argument("--proxy-name", help="Proxy name") + parser.add_argument("--proxy-filename", help="Proxy filename (if different)") + + parser.add_argument("--new-contract-name", help="New contract name (if changed)") + parser.add_argument( + "--new-contract-filename", help="New implementation filename (if different)" + ) + + parser.add_argument( + "--json", + help='Export the results as a JSON file ("--json -" to export to stdout)', + action="store", + default=False, + ) + + parser.add_argument( + "--list-detectors", + help="List available detectors", + action=ListDetectors, + nargs=0, + default=False, + ) + + parser.add_argument( + "--markdown-root", help="URL for markdown generation", action="store", default="" + ) + + parser.add_argument( + "--wiki-detectors", help=argparse.SUPPRESS, action=OutputWiki, default=False + ) + + parser.add_argument( + "--list-detectors-json", + help=argparse.SUPPRESS, + action=ListDetectorsJson, + nargs=0, + default=False, + ) + + parser.add_argument("--markdown", help=argparse.SUPPRESS, action=OutputMarkdown, default=False) cryticparser.init(parser) @@ -80,6 +89,7 @@ def parse_args(): ################################################################################### ################################################################################### + def _get_checks(): detectors = [getattr(all_checks, name) for name in dir(all_checks)] detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractCheck)] @@ -123,13 +133,18 @@ def _run_checks(detectors): def _checks_on_contract(detectors, contract): - detectors = [d(logger, contract) for d in detectors if (not d.REQUIRE_PROXY and - not d.REQUIRE_CONTRACT_V2)] + detectors = [ + d(logger, contract) + for d in detectors + if (not d.REQUIRE_PROXY and not d.REQUIRE_CONTRACT_V2) + ] return _run_checks(detectors), len(detectors) def _checks_on_contract_update(detectors, contract_v1, contract_v2): - detectors = [d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2] + detectors = [ + d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2 + ] return _run_checks(detectors), len(detectors) @@ -147,15 +162,11 @@ def _checks_on_contract_and_proxy(detectors, contract, proxy): def main(): - json_results = { - 'proxy-present': False, - 'contract_v2-present': False, - 'detectors': [] - } + json_results = {"proxy-present": False, "contract_v2-present": False, "detectors": []} args = parse_args() - v1_filename = vars(args)['contract.sol'] + v1_filename = vars(args)["contract.sol"] number_detectors_run = 0 detectors = _get_checks() try: @@ -165,14 +176,14 @@ def main(): v1_name = args.ContractName v1_contract = v1.get_contract_from_name(v1_name) if v1_contract is None: - info = 'Contract {} not found in {}'.format(v1_name, v1.filename) + info = "Contract {} not found in {}".format(v1_name, v1.filename) logger.error(red(info)) if args.json: output_to_json(args.json, str(info), json_results) return detectors_results, number_detectors = _checks_on_contract(detectors, v1_contract) - json_results['detectors'] += detectors_results + json_results["detectors"] += detectors_results number_detectors_run += number_detectors # Analyze Proxy @@ -185,15 +196,17 @@ def main(): proxy_contract = proxy.get_contract_from_name(args.proxy_name) if proxy_contract is None: - info = 'Proxy {} not found in {}'.format(args.proxy_name, proxy.filename) + info = "Proxy {} not found in {}".format(args.proxy_name, proxy.filename) logger.error(red(info)) if args.json: output_to_json(args.json, str(info), json_results) return - json_results['proxy-present'] = True + json_results["proxy-present"] = True - detectors_results, number_detectors = _checks_on_contract_and_proxy(detectors, v1_contract, proxy_contract) - json_results['detectors'] += detectors_results + detectors_results, number_detectors = _checks_on_contract_and_proxy( + detectors, v1_contract, proxy_contract + ) + json_results["detectors"] += detectors_results number_detectors_run += number_detectors # Analyze new version if args.new_contract_name: @@ -204,30 +217,36 @@ def main(): v2_contract = v2.get_contract_from_name(args.new_contract_name) if v2_contract is None: - info = 'New logic contract {} not found in {}'.format(args.new_contract_name, v2.filename) + info = "New logic contract {} not found in {}".format( + args.new_contract_name, v2.filename + ) logger.error(red(info)) if args.json: output_to_json(args.json, str(info), json_results) return - json_results['contract_v2-present'] = True + json_results["contract_v2-present"] = True if proxy_contract: - detectors_results, _ = _checks_on_contract_and_proxy(detectors, - v2_contract, - proxy_contract) + detectors_results, _ = _checks_on_contract_and_proxy( + detectors, v2_contract, proxy_contract + ) - json_results['detectors'] += detectors_results + json_results["detectors"] += detectors_results - detectors_results, number_detectors = _checks_on_contract_update(detectors, v1_contract, v2_contract) - json_results['detectors'] += detectors_results + detectors_results, number_detectors = _checks_on_contract_update( + detectors, v1_contract, v2_contract + ) + json_results["detectors"] += detectors_results number_detectors_run += number_detectors # If there is a V2, we run the contract-only check on the V2 detectors_results, _ = _checks_on_contract(detectors, v2_contract) - json_results['detectors'] += detectors_results + json_results["detectors"] += detectors_results number_detectors_run += number_detectors - logger.info(f'{len(json_results["detectors"])} findings, {number_detectors_run} detectors run') + logger.info( + f'{len(json_results["detectors"])} findings, {number_detectors_run} detectors run' + ) if args.json: output_to_json(args.json, None, json_results) @@ -237,4 +256,5 @@ def main(): output_to_json(args.json, str(e), json_results) return + # endregion diff --git a/slither/tools/upgradeability/checks/abstract_checks.py b/slither/tools/upgradeability/checks/abstract_checks.py index 94d55e107..05a0d1182 100644 --- a/slither/tools/upgradeability/checks/abstract_checks.py +++ b/slither/tools/upgradeability/checks/abstract_checks.py @@ -19,28 +19,28 @@ classification_colors = { CheckClassification.INFORMATIONAL: green, CheckClassification.LOW: yellow, CheckClassification.MEDIUM: yellow, - CheckClassification.HIGH: red + CheckClassification.HIGH: red, } classification_txt = { - CheckClassification.INFORMATIONAL: 'Informational', - CheckClassification.LOW: 'Low', - CheckClassification.MEDIUM: 'Medium', - CheckClassification.HIGH: 'High', + CheckClassification.INFORMATIONAL: "Informational", + CheckClassification.LOW: "Low", + CheckClassification.MEDIUM: "Medium", + CheckClassification.HIGH: "High", } class AbstractCheck(metaclass=abc.ABCMeta): - ARGUMENT = '' - HELP = '' + ARGUMENT = "" + HELP = "" IMPACT = None - WIKI = '' + WIKI = "" - WIKI_TITLE = '' - WIKI_DESCRIPTION = '' - WIKI_EXPLOIT_SCENARIO = '' - WIKI_RECOMMENDATION = '' + WIKI_TITLE = "" + WIKI_DESCRIPTION = "" + WIKI_EXPLOIT_SCENARIO = "" + WIKI_RECOMMENDATION = "" REQUIRE_CONTRACT = False REQUIRE_PROXY = False @@ -53,43 +53,69 @@ class AbstractCheck(metaclass=abc.ABCMeta): self.contract_v2 = contract_v2 if not self.ARGUMENT: - raise IncorrectCheckInitialization('NAME is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "NAME is not initialized {}".format(self.__class__.__name__) + ) if not self.HELP: - raise IncorrectCheckInitialization('HELP is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "HELP is not initialized {}".format(self.__class__.__name__) + ) if not self.WIKI: - raise IncorrectCheckInitialization('WIKI is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "WIKI is not initialized {}".format(self.__class__.__name__) + ) if not self.WIKI_TITLE: - raise IncorrectCheckInitialization('WIKI_TITLE is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "WIKI_TITLE is not initialized {}".format(self.__class__.__name__) + ) if not self.WIKI_DESCRIPTION: - raise IncorrectCheckInitialization('WIKI_DESCRIPTION is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "WIKI_DESCRIPTION is not initialized {}".format(self.__class__.__name__) + ) - if not self.WIKI_EXPLOIT_SCENARIO and self.IMPACT not in [CheckClassification.INFORMATIONAL]: - raise IncorrectCheckInitialization('WIKI_EXPLOIT_SCENARIO is not initialized {}'.format(self.__class__.__name__)) + if not self.WIKI_EXPLOIT_SCENARIO and self.IMPACT not in [ + CheckClassification.INFORMATIONAL + ]: + raise IncorrectCheckInitialization( + "WIKI_EXPLOIT_SCENARIO is not initialized {}".format(self.__class__.__name__) + ) if not self.WIKI_RECOMMENDATION: - raise IncorrectCheckInitialization('WIKI_RECOMMENDATION is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "WIKI_RECOMMENDATION is not initialized {}".format(self.__class__.__name__) + ) if self.REQUIRE_PROXY and self.REQUIRE_CONTRACT_V2: # This is not a fundatemenal issues # But it requires to change __main__ to avoid running two times the detectors - txt = 'REQUIRE_PROXY and REQUIRE_CONTRACT_V2 needs change in __main___ {}'.format(self.__class__.__name__) + txt = "REQUIRE_PROXY and REQUIRE_CONTRACT_V2 needs change in __main___ {}".format( + self.__class__.__name__ + ) raise IncorrectCheckInitialization(txt) - if self.IMPACT not in [CheckClassification.LOW, - CheckClassification.MEDIUM, - CheckClassification.HIGH, - CheckClassification.INFORMATIONAL]: - raise IncorrectCheckInitialization('IMPACT is not initialized {}'.format(self.__class__.__name__)) + if self.IMPACT not in [ + CheckClassification.LOW, + CheckClassification.MEDIUM, + CheckClassification.HIGH, + CheckClassification.INFORMATIONAL, + ]: + raise IncorrectCheckInitialization( + "IMPACT is not initialized {}".format(self.__class__.__name__) + ) if self.REQUIRE_CONTRACT_V2 and contract_v2 is None: - raise IncorrectCheckInitialization('ContractV2 is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "ContractV2 is not initialized {}".format(self.__class__.__name__) + ) if self.REQUIRE_PROXY and proxy is None: - raise IncorrectCheckInitialization('Proxy is not initialized {}'.format(self.__class__.__name__)) + raise IncorrectCheckInitialization( + "Proxy is not initialized {}".format(self.__class__.__name__) + ) @abc.abstractmethod def _check(self): @@ -102,19 +128,17 @@ class AbstractCheck(metaclass=abc.ABCMeta): all_results = [r.data for r in all_results] if all_results: if self.logger: - info = '\n' + info = "\n" for idx, result in enumerate(all_results): - info += result['description'] - info += 'Reference: {}'.format(self.WIKI) + info += result["description"] + info += "Reference: {}".format(self.WIKI) self._log(info) return all_results def generate_result(self, info, additional_fields=None): - output = Output(info, - additional_fields, - markdown_root=self.contract.slither.markdown_root) + output = Output(info, additional_fields, markdown_root=self.contract.slither.markdown_root) - output.data['check'] = self.ARGUMENT + output.data["check"] = self.ARGUMENT return output diff --git a/slither/tools/upgradeability/checks/all_checks.py b/slither/tools/upgradeability/checks/all_checks.py index 1c41316c6..fcedb69f8 100644 --- a/slither/tools/upgradeability/checks/all_checks.py +++ b/slither/tools/upgradeability/checks/all_checks.py @@ -1,11 +1,23 @@ -from .initialization import (InitializablePresent, InitializableInherited, - InitializableInitializer, MissingInitializerModifier, MissingCalls, MultipleCalls, InitializeTarget) +from .initialization import ( + InitializablePresent, + InitializableInherited, + InitializableInitializer, + MissingInitializerModifier, + MissingCalls, + MultipleCalls, + InitializeTarget, +) from .functions_ids import IDCollision, FunctionShadowing from .variable_initialization import VariableWithInit -from .variables_order import (MissingVariable, DifferentVariableContractProxy, - DifferentVariableContractNewContract, ExtraVariablesProxy, ExtraVariablesNewContract) +from .variables_order import ( + MissingVariable, + DifferentVariableContractProxy, + DifferentVariableContractNewContract, + ExtraVariablesProxy, + ExtraVariablesNewContract, +) -from .constant import WereConstant, BecameConstant \ No newline at end of file +from .constant import WereConstant, BecameConstant diff --git a/slither/tools/upgradeability/checks/constant.py b/slither/tools/upgradeability/checks/constant.py index 60f37f8d3..e1d547e28 100644 --- a/slither/tools/upgradeability/checks/constant.py +++ b/slither/tools/upgradeability/checks/constant.py @@ -2,17 +2,17 @@ from slither.tools.upgradeability.checks.abstract_checks import AbstractCheck, C class WereConstant(AbstractCheck): - ARGUMENT = 'were-constant' + ARGUMENT = "were-constant" IMPACT = CheckClassification.HIGH - HELP = 'Variables that should be constant' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-be-constant' - WIKI_TITLE = 'Variables that should be constant' - WIKI_DESCRIPTION = ''' + HELP = "Variables that should be constant" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-be-constant" + WIKI_TITLE = "Variables that should be constant" + WIKI_DESCRIPTION = """ Detect state variables that should be `constant̀`. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ uint variable1; @@ -28,11 +28,11 @@ contract ContractV2{ ``` Because `variable2` is not anymore a `constant`, the storage location of `variable3` will be different. As a result, `ContractV2` will have a corrupted storage layout. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Do not remove `constant` from a state variables during an update. -''' +""" REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True @@ -66,12 +66,13 @@ Do not remove `constant` from a state variables during an update. if state_v1.is_constant: if not state_v2.is_constant: # If v2 has additional non constant variables, we need to skip them - if ((state_v1.name != state_v2.name or state_v1.type != state_v2.type) - and v2_additional_variables > 0): + if ( + state_v1.name != state_v2.name or state_v1.type != state_v2.type + ) and v2_additional_variables > 0: v2_additional_variables -= 1 idx_v2 += 1 continue - info = [state_v1, ' was constant, but ', state_v2, 'is not.\n'] + info = [state_v1, " was constant, but ", state_v2, "is not.\n"] json = self.generate_result(info) results.append(json) @@ -80,19 +81,20 @@ Do not remove `constant` from a state variables during an update. return results + class BecameConstant(AbstractCheck): - ARGUMENT = 'became-constant' + ARGUMENT = "became-constant" IMPACT = CheckClassification.HIGH - HELP = 'Variables that should not be constant' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-not-be-constant' - WIKI_TITLE = 'Variables that should not be constant' + HELP = "Variables that should not be constant" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#variables-that-should-not-be-constant" + WIKI_TITLE = "Variables that should not be constant" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect state variables that should not be `constant̀`. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ uint variable1; @@ -108,11 +110,11 @@ contract ContractV2{ ``` Because `variable2` is now a `constant`, the storage location of `variable3` will be different. As a result, `ContractV2` will have a corrupted storage layout. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Do not make an existing state variable `constant`. -''' +""" REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True @@ -146,13 +148,14 @@ Do not make an existing state variable `constant`. if state_v1.is_constant: if not state_v2.is_constant: # If v2 has additional non constant variables, we need to skip them - if ((state_v1.name != state_v2.name or state_v1.type != state_v2.type) - and v2_additional_variables > 0): + if ( + state_v1.name != state_v2.name or state_v1.type != state_v2.type + ) and v2_additional_variables > 0: v2_additional_variables -= 1 idx_v2 += 1 continue elif state_v2.is_constant: - info = [state_v1, ' was not constant but ', state_v2, ' is.\n'] + info = [state_v1, " was not constant but ", state_v2, " is.\n"] json = self.generate_result(info) results.append(json) diff --git a/slither/tools/upgradeability/checks/functions_ids.py b/slither/tools/upgradeability/checks/functions_ids.py index ecacbb798..cbc822d20 100644 --- a/slither/tools/upgradeability/checks/functions_ids.py +++ b/slither/tools/upgradeability/checks/functions_ids.py @@ -5,11 +5,16 @@ from slither.utils.function import get_function_id def get_signatures(c): functions = c.functions - functions = [f.full_name for f in functions if f.visibility in ['public', 'external'] and - not f.is_constructor and not f.is_fallback] + functions = [ + f.full_name + for f in functions + if f.visibility in ["public", "external"] and not f.is_constructor and not f.is_fallback + ] variables = c.state_variables - variables = [variable.name + '()' for variable in variables if variable.visibility in ['public']] + variables = [ + variable.name + "()" for variable in variables if variable.visibility in ["public"] + ] return list(set(functions + variables)) @@ -21,26 +26,26 @@ def _get_function_or_variable(contract, signature): for variable in contract.state_variables: # Todo: can lead to incorrect variable in case of shadowing - if variable.visibility in ['public']: - if variable.name + '()' == signature: + if variable.visibility in ["public"]: + if variable.name + "()" == signature: return variable - raise SlitherError(f'Function id checks: {signature} not found in {contract.name}') + raise SlitherError(f"Function id checks: {signature} not found in {contract.name}") class IDCollision(AbstractCheck): - ARGUMENT = 'function-id-collision' + ARGUMENT = "function-id-collision" IMPACT = CheckClassification.HIGH - HELP = 'Functions ids collision' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-ids-collisions' - WIKI_TITLE = 'Functions ids collisions' + HELP = "Functions ids collision" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-ids-collisions" + WIKI_TITLE = "Functions ids collisions" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect function id collision between the contract and the proxy. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ function gsf() public { @@ -56,11 +61,11 @@ contract Proxy{ ``` `Proxy.tgeo()` and `Contract.gsf()` have the same function id (0x67e43e43). As a result, `Proxy.tgeo()` will shadow Contract.gsf()`. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Rename the function. Avoid public functions in the proxy. -''' +""" REQUIRE_CONTRACT = True REQUIRE_PROXY = True @@ -77,11 +82,18 @@ Rename the function. Avoid public functions in the proxy. for (k, _) in signatures_ids_implem.items(): if k in signatures_ids_proxy: if signatures_ids_implem[k] != signatures_ids_proxy[k]: - implem_function = _get_function_or_variable(self.contract, signatures_ids_implem[k]) + implem_function = _get_function_or_variable( + self.contract, signatures_ids_implem[k] + ) proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k]) - info = ['Function id collision found: ', implem_function, - ' ', proxy_function, '\n'] + info = [ + "Function id collision found: ", + implem_function, + " ", + proxy_function, + "\n", + ] json = self.generate_result(info) results.append(json) @@ -89,18 +101,18 @@ Rename the function. Avoid public functions in the proxy. class FunctionShadowing(AbstractCheck): - ARGUMENT = 'function-shadowing' + ARGUMENT = "function-shadowing" IMPACT = CheckClassification.HIGH - HELP = 'Functions shadowing' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-shadowing' - WIKI_TITLE = 'Functions shadowing' + HELP = "Functions shadowing" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#functions-shadowing" + WIKI_TITLE = "Functions shadowing" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect function shadowing between the contract and the proxy. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ function get() public { @@ -115,11 +127,11 @@ contract Proxy{ } ``` `Proxy.get` will shadow any call to `get()`. As a result `get()` is never executed in the logic contract and cannot be updated. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Rename the function. Avoid public functions in the proxy. -''' +""" REQUIRE_CONTRACT = True REQUIRE_PROXY = True @@ -136,11 +148,18 @@ Rename the function. Avoid public functions in the proxy. for (k, _) in signatures_ids_implem.items(): if k in signatures_ids_proxy: if signatures_ids_implem[k] == signatures_ids_proxy[k]: - implem_function = _get_function_or_variable(self.contract, signatures_ids_implem[k]) + implem_function = _get_function_or_variable( + self.contract, signatures_ids_implem[k] + ) proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k]) - info = ['Function shadowing found: ', implem_function, - ' ', proxy_function, '\n'] + info = [ + "Function shadowing found: ", + implem_function, + " ", + proxy_function, + "\n", + ] json = self.generate_result(info) results.append(json) diff --git a/slither/tools/upgradeability/checks/initialization.py b/slither/tools/upgradeability/checks/initialization.py index 3809f294b..2e37457dc 100644 --- a/slither/tools/upgradeability/checks/initialization.py +++ b/slither/tools/upgradeability/checks/initialization.py @@ -13,16 +13,20 @@ class MultipleInitTarget(Exception): def _get_initialize_functions(contract): - return [f for f in contract.functions if f.name == 'initialize' and f.is_implemented] + return [f for f in contract.functions if f.name == "initialize" and f.is_implemented] def _get_all_internal_calls(function): all_ir = function.all_slithir_operations() - return [i.function for i in all_ir if isinstance(i, InternalCall) and i.function_name == "initialize"] + return [ + i.function + for i in all_ir + if isinstance(i, InternalCall) and i.function_name == "initialize" + ] def _get_most_derived_init(contract): - init_functions = [f for f in contract.functions if not f.is_shadowed and f.name == 'initialize'] + init_functions = [f for f in contract.functions if not f.is_shadowed and f.name == "initialize"] if len(init_functions) > 1: if len([f for f in init_functions if f.contract_declarer == contract]) == 1: return next((f for f in init_functions if f.contract_declarer == contract)) @@ -33,80 +37,82 @@ def _get_most_derived_init(contract): class InitializablePresent(AbstractCheck): - ARGUMENT = 'init-missing' + ARGUMENT = "init-missing" IMPACT = CheckClassification.INFORMATIONAL - HELP = 'Initializable is missing' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing' - WIKI_TITLE = 'Initializable is missing' - WIKI_DESCRIPTION = ''' + HELP = "Initializable is missing" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-missing" + WIKI_TITLE = "Initializable is missing" + WIKI_DESCRIPTION = """ Detect if a contract `Initializable` is present. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Review manually the contract's initialization.. Consider using a `Initializable` contract to follow [standard practice](https://docs.openzeppelin.com/upgrades/2.7/writing-upgradeable). -''' +""" def _check(self): - initializable = self.contract.slither.get_contract_from_name('Initializable') + initializable = self.contract.slither.get_contract_from_name("Initializable") if initializable is None: - info = ["Initializable contract not found, the contract does not follow a standard initalization schema.\n"] + info = [ + "Initializable contract not found, the contract does not follow a standard initalization schema.\n" + ] json = self.generate_result(info) return [json] return [] class InitializableInherited(AbstractCheck): - ARGUMENT = 'init-inherited' + ARGUMENT = "init-inherited" IMPACT = CheckClassification.INFORMATIONAL - HELP = 'Initializable is not inherited' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-not-inherited' - WIKI_TITLE = 'Initializable is not inherited' + HELP = "Initializable is not inherited" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializable-is-not-inherited" + WIKI_TITLE = "Initializable is not inherited" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect if `Initializable` is inherited. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Review manually the contract's initialization. Consider inheriting `Initializable`. -''' +""" REQUIRE_CONTRACT = True def _check(self): - initializable = self.contract.slither.get_contract_from_name('Initializable') + initializable = self.contract.slither.get_contract_from_name("Initializable") # See InitializablePresent if initializable is None: return [] if initializable not in self.contract.inheritance: - info = [self.contract, ' does not inherit from ', initializable, '.\n'] + info = [self.contract, " does not inherit from ", initializable, ".\n"] json = self.generate_result(info) return [json] return [] class InitializableInitializer(AbstractCheck): - ARGUMENT = 'initializer-missing' + ARGUMENT = "initializer-missing" IMPACT = CheckClassification.INFORMATIONAL - HELP = 'initializer() is missing' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-missing' - WIKI_TITLE = 'initializer() is missing' + HELP = "initializer() is missing" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-missing" + WIKI_TITLE = "initializer() is missing" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect the lack of `Initializable.initializer()` modifier. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Review manually the contract's initialization. Consider inheriting a `Initializable.initializer()` modifier. -''' +""" REQUIRE_CONTRACT = True def _check(self): - initializable = self.contract.slither.get_contract_from_name('Initializable') + initializable = self.contract.slither.get_contract_from_name("Initializable") # See InitializablePresent if initializable is None: return [] @@ -114,26 +120,26 @@ Review manually the contract's initialization. Consider inheriting a `Initializa if initializable not in self.contract.inheritance: return [] - initializer = self.contract.get_modifier_from_canonical_name('Initializable.initializer()') + initializer = self.contract.get_modifier_from_canonical_name("Initializable.initializer()") if initializer is None: - info = ['Initializable.initializer() does not exist.\n'] + info = ["Initializable.initializer() does not exist.\n"] json = self.generate_result(info) return [json] return [] class MissingInitializerModifier(AbstractCheck): - ARGUMENT = 'missing-init-modifier' + ARGUMENT = "missing-init-modifier" IMPACT = CheckClassification.HIGH - HELP = 'initializer() is not called' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-not-called' - WIKI_TITLE = 'initializer() is not called' - WIKI_DESCRIPTION = ''' + HELP = "initializer() is not called" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initializer-is-not-called" + WIKI_TITLE = "initializer() is not called" + WIKI_DESCRIPTION = """ Detect if `Initializable.initializer()` is called. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ function initialize() public{ @@ -143,23 +149,23 @@ contract Contract{ ``` `initialize` should have the `initializer` modifier to prevent someone from initializing the contract multiple times. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Use `Initializable.initializer()`. -''' +""" REQUIRE_CONTRACT = True def _check(self): - initializable = self.contract.slither.get_contract_from_name('Initializable') + initializable = self.contract.slither.get_contract_from_name("Initializable") # See InitializablePresent if initializable is None: return [] # See InitializableInherited if initializable not in self.contract.inheritance: return [] - initializer = self.contract.get_modifier_from_canonical_name('Initializable.initializer()') + initializer = self.contract.get_modifier_from_canonical_name("Initializable.initializer()") # InitializableInitializer if initializer is None: return [] @@ -168,24 +174,24 @@ Use `Initializable.initializer()`. all_init_functions = _get_initialize_functions(self.contract) for f in all_init_functions: if initializer not in f.modifiers: - info = [f, ' does not call the initializer modifier.\n'] + info = [f, " does not call the initializer modifier.\n"] json = self.generate_result(info) results.append(json) return results class MissingCalls(AbstractCheck): - ARGUMENT = 'missing-calls' + ARGUMENT = "missing-calls" IMPACT = CheckClassification.HIGH - HELP = 'Missing calls to init functions' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-not-called' - WIKI_TITLE = 'Initialize functions are not called' - WIKI_DESCRIPTION = ''' + HELP = "Missing calls to init functions" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-not-called" + WIKI_TITLE = "Initialize functions are not called" + WIKI_DESCRIPTION = """ Detect missing calls to initialize functions. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Base{ function initialize() public{ @@ -200,11 +206,11 @@ contract Derived is Base{ ``` `Derived.initialize` does not call `Base.initialize` leading the contract to not be correctly initialized. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Ensure all the initialize functions are reached by the most derived initialize function. -''' +""" REQUIRE_CONTRACT = True @@ -215,7 +221,7 @@ Ensure all the initialize functions are reached by the most derived initialize f try: most_derived_init = _get_most_derived_init(self.contract) except MultipleInitTarget: - logger.error(red(f'Too many init targets in {self.contract}')) + logger.error(red(f"Too many init targets in {self.contract}")) return [] if most_derived_init is None: @@ -225,24 +231,24 @@ Ensure all the initialize functions are reached by the most derived initialize f all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init] missing_calls = [f for f in all_init_functions if not f in all_init_functions_called] for f in missing_calls: - info = ['Missing call to ', f, ' in ', most_derived_init, '.\n'] + info = ["Missing call to ", f, " in ", most_derived_init, ".\n"] json = self.generate_result(info) results.append(json) return results class MultipleCalls(AbstractCheck): - ARGUMENT = 'multiple-calls' + ARGUMENT = "multiple-calls" IMPACT = CheckClassification.HIGH - HELP = 'Init functions called multiple times' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-called-multiple-times' - WIKI_TITLE = 'Initialize functions are called multiple times' - WIKI_DESCRIPTION = ''' + HELP = "Init functions called multiple times" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-functions-are-called-multiple-times" + WIKI_TITLE = "Initialize functions are called multiple times" + WIKI_DESCRIPTION = """ Detect multiple calls to a initialize function. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Base{ function initialize(uint) public{ @@ -264,11 +270,11 @@ contract DerivedDerived is Derived{ ``` `Base.initialize(uint)` is called two times in `DerivedDerived.initiliaze` execution, leading to a potential corruption. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Call only one time every initialize function. -''' +""" REQUIRE_CONTRACT = True @@ -280,38 +286,41 @@ Call only one time every initialize function. most_derived_init = _get_most_derived_init(self.contract) except MultipleInitTarget: # Should be already reported by MissingCalls - #logger.error(red(f'Too many init targets in {self.contract}')) + # logger.error(red(f'Too many init targets in {self.contract}')) return [] if most_derived_init is None: return [] all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init] - double_calls = list(set([f for f in all_init_functions_called if all_init_functions_called.count(f) > 1])) + double_calls = list( + set([f for f in all_init_functions_called if all_init_functions_called.count(f) > 1]) + ) for f in double_calls: - info = [f, ' is called multiple times in ', most_derived_init, '.\n'] + info = [f, " is called multiple times in ", most_derived_init, ".\n"] json = self.generate_result(info) results.append(json) return results + class InitializeTarget(AbstractCheck): - ARGUMENT = 'initialize-target' + ARGUMENT = "initialize-target" IMPACT = CheckClassification.INFORMATIONAL - HELP = 'Initialize function that must be called' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-function' - WIKI_TITLE = 'Initialize function' + HELP = "Initialize function that must be called" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#initialize-function" + WIKI_TITLE = "Initialize function" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Show the function that must be called at deployment. This finding does not have an immediate security impact and is informative. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Ensure that the function is called at deployment. -''' +""" REQUIRE_CONTRACT = True @@ -322,12 +331,12 @@ Ensure that the function is called at deployment. most_derived_init = _get_most_derived_init(self.contract) except MultipleInitTarget: # Should be already reported by MissingCalls - #logger.error(red(f'Too many init targets in {self.contract}')) + # logger.error(red(f'Too many init targets in {self.contract}')) return [] if most_derived_init is None: return [] - info = [self.contract, f' needs to be initialized by ', most_derived_init, '.\n'] + info = [self.contract, f" needs to be initialized by ", most_derived_init, ".\n"] json = self.generate_result(info) return [json] diff --git a/slither/tools/upgradeability/checks/variable_initialization.py b/slither/tools/upgradeability/checks/variable_initialization.py index 9c03e0789..7b9316ef7 100644 --- a/slither/tools/upgradeability/checks/variable_initialization.py +++ b/slither/tools/upgradeability/checks/variable_initialization.py @@ -2,29 +2,29 @@ from slither.tools.upgradeability.checks.abstract_checks import CheckClassificat class VariableWithInit(AbstractCheck): - ARGUMENT = 'variables-initialized' + ARGUMENT = "variables-initialized" IMPACT = CheckClassification.HIGH - HELP = 'State variables with an initial value' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#state-variable-initialized' - WIKI_TITLE = 'State variable initialized' + HELP = "State variables with an initial value" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#state-variable-initialized" + WIKI_TITLE = "State variable initialized" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect state variables that are initialized. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ uint variable = 10; } ``` Using `Contract` will the delegatecall proxy pattern will lead `variable` to be 0 when called through the proxy. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Using initialize functions to write initial values in state variables. -''' +""" REQUIRE_CONTRACT = True @@ -32,7 +32,7 @@ Using initialize functions to write initial values in state variables. results = [] for s in self.contract.state_variables: if s.initialized and not s.is_constant: - info = [s, ' is a state variable with an initial value.\n'] + info = [s, " is a state variable with an initial value.\n"] json = self.generate_result(info) results.append(json) return results diff --git a/slither/tools/upgradeability/checks/variables_order.py b/slither/tools/upgradeability/checks/variables_order.py index f300b9dca..735a8a1e1 100644 --- a/slither/tools/upgradeability/checks/variables_order.py +++ b/slither/tools/upgradeability/checks/variables_order.py @@ -2,16 +2,16 @@ from slither.tools.upgradeability.checks.abstract_checks import CheckClassificat class MissingVariable(AbstractCheck): - ARGUMENT = 'missing-variables' + ARGUMENT = "missing-variables" IMPACT = CheckClassification.MEDIUM - HELP = 'Variable missing in the v2' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-variables' - WIKI_TITLE = 'Missing variables' - WIKI_DESCRIPTION = ''' + HELP = "Variable missing in the v2" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-variables" + WIKI_TITLE = "Missing variables" + WIKI_DESCRIPTION = """ Detect variables that were present in the original contracts but are not in the updated one. -''' - WIKI_EXPLOIT_SCENARIO = ''' +""" + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract V1{ uint variable1; @@ -25,11 +25,11 @@ contract V2{ The new version, `V2` does not contain `variable1`. If a new variable is added in an update of `V2`, this variable will hold the latest value of `variable2` and will be corrupted. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Do not change the order of the state variables in the updated contract. -''' +""" REQUIRE_CONTRACT = True REQUIRE_CONTRACT_V2 = True @@ -44,7 +44,7 @@ Do not change the order of the state variables in the updated contract. for idx in range(0, len(order1)): variable1 = order1[idx] if len(order2) <= idx: - info = ['Variable missing in ', contract2, ': ', variable1, '\n'] + info = ["Variable missing in ", contract2, ": ", variable1, "\n"] json = self.generate_result(info) results.append(json) @@ -52,18 +52,18 @@ Do not change the order of the state variables in the updated contract. class DifferentVariableContractProxy(AbstractCheck): - ARGUMENT = 'order-vars-proxy' + ARGUMENT = "order-vars-proxy" IMPACT = CheckClassification.HIGH - HELP = 'Incorrect vars order with the proxy' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-proxy' - WIKI_TITLE = 'Incorrect variables with the proxy' + HELP = "Incorrect vars order with the proxy" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-proxy" + WIKI_TITLE = "Incorrect variables with the proxy" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect variables that are different between the contract and the proxy. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ uint variable1; @@ -74,11 +74,11 @@ contract Proxy{ } ``` `Contract` and `Proxy` do not have the same storage layout. As a result the storage of both contracts can be corrupted. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the same layout than in the contract. -''' +""" REQUIRE_CONTRACT = True REQUIRE_PROXY = True @@ -104,9 +104,9 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s variable1 = order1[idx] variable2 = order2[idx] if (variable1.name != variable2.name) or (variable1.type != variable2.type): - info = ['Different variables between ', contract1, ' and ', contract2, '\n'] - info += [f'\t ', variable1, '\n'] - info += [f'\t ', variable2, '\n'] + info = ["Different variables between ", contract1, " and ", contract2, "\n"] + info += [f"\t ", variable1, "\n"] + info += [f"\t ", variable2, "\n"] json = self.generate_result(info) results.append(json) @@ -114,17 +114,17 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s class DifferentVariableContractNewContract(DifferentVariableContractProxy): - ARGUMENT = 'order-vars-contracts' + ARGUMENT = "order-vars-contracts" - HELP = 'Incorrect vars order with the v2' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-v2' - WIKI_TITLE = 'Incorrect variables with the v2' + HELP = "Incorrect vars order with the v2" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#incorrect-variables-with-the-v2" + WIKI_TITLE = "Incorrect variables with the v2" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect variables that are different between the original contract and the updated one. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ uint variable1; @@ -135,11 +135,11 @@ contract ContractV2{ } ``` `Contract` and `ContractV2` do not have the same storage layout. As a result the storage of both contracts can be corrupted. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Respect the variable order of the original contract in the updated contract. -''' +""" REQUIRE_CONTRACT = True REQUIRE_PROXY = False @@ -150,18 +150,20 @@ Respect the variable order of the original contract in the updated contract. class ExtraVariablesProxy(AbstractCheck): - ARGUMENT = 'extra-vars-proxy' + ARGUMENT = "extra-vars-proxy" IMPACT = CheckClassification.MEDIUM - HELP = 'Extra vars in the proxy' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-proxy' - WIKI_TITLE = 'Extra variables in the proxy' + HELP = "Extra vars in the proxy" + WIKI = ( + "https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-proxy" + ) + WIKI_TITLE = "Extra variables in the proxy" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Detect variables that are in the proxy and not in the contract. -''' +""" - WIKI_EXPLOIT_SCENARIO = ''' + WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Contract{ uint variable1; @@ -173,11 +175,11 @@ contract Proxy{ } ``` `Proxy` contains additional variables. A future update of `Contract` is likely to corrupt the proxy. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the same layout than in the contract. -''' +""" REQUIRE_CONTRACT = True REQUIRE_PROXY = True @@ -203,7 +205,7 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s while idx < len(order2): variable2 = order2[idx] - info = ['Extra variables in ', contract2, ': ', variable2, '\n'] + info = ["Extra variables in ", contract2, ": ", variable2, "\n"] json = self.generate_result(info) results.append(json) idx = idx + 1 @@ -212,21 +214,21 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s class ExtraVariablesNewContract(ExtraVariablesProxy): - ARGUMENT = 'extra-vars-v2' + ARGUMENT = "extra-vars-v2" - HELP = 'Extra vars in the v2' - WIKI = 'https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-v2' - WIKI_TITLE = 'Extra variables in the v2' + HELP = "Extra vars in the v2" + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-v2" + WIKI_TITLE = "Extra variables in the v2" - WIKI_DESCRIPTION = ''' + WIKI_DESCRIPTION = """ Show new variables in the updated contract. This finding does not have an immediate security impact and is informative. -''' +""" - WIKI_RECOMMENDATION = ''' + WIKI_RECOMMENDATION = """ Ensure that all the new variables are expected. -''' +""" IMPACT = CheckClassification.INFORMATIONAL diff --git a/slither/tools/upgradeability/utils/command_line.py b/slither/tools/upgradeability/utils/command_line.py index 2af6daa54..57a6ef88d 100644 --- a/slither/tools/upgradeability/utils/command_line.py +++ b/slither/tools/upgradeability/utils/command_line.py @@ -4,7 +4,9 @@ from slither.utils.myprettytable import MyPrettyTable def output_wiki(detector_classes, filter_wiki): # Sort by impact, confidence, and name - detectors_list = sorted(detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT)) + detectors_list = sorted( + detector_classes, key=lambda element: (element.IMPACT, element.ARGUMENT) + ) for detector in detectors_list: if filter_wiki not in detector.WIKI: @@ -16,16 +18,16 @@ def output_wiki(detector_classes, filter_wiki): exploit_scenario = detector.WIKI_EXPLOIT_SCENARIO recommendation = detector.WIKI_RECOMMENDATION - print('\n## {}'.format(title)) - print('### Configuration') - print('* Check: `{}`'.format(argument)) - print('* Severity: `{}`'.format(impact)) - print('\n### Description') + print("\n## {}".format(title)) + print("### Configuration") + print("* Check: `{}`".format(argument)) + print("* Severity: `{}`".format(impact)) + print("\n### Description") print(description) if exploit_scenario: - print('\n### Exploit Scenario:') + print("\n### Exploit Scenario:") print(exploit_scenario) - print('\n### Recommendation') + print("\n### Recommendation") print(recommendation) @@ -38,27 +40,31 @@ def output_detectors(detector_classes): require_proxy = detector.REQUIRE_PROXY require_v2 = detector.REQUIRE_CONTRACT_V2 detectors_list.append((argument, help_info, impact, require_proxy, require_v2)) - table = MyPrettyTable(["Num", - "Check", - "What it Detects", - "Impact", - "Proxy", - "Contract V2"]) + table = MyPrettyTable(["Num", "Check", "What it Detects", "Impact", "Proxy", "Contract V2"]) # Sort by impact, confidence, and name detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) idx = 1 for (argument, help_info, impact, proxy, v2) in detectors_list: - table.add_row([idx, argument, help_info, classification_txt[impact], 'X' if proxy else '', 'X' if v2 else '']) + table.add_row( + [ + idx, + argument, + help_info, + classification_txt[impact], + "X" if proxy else "", + "X" if v2 else "", + ] + ) idx = idx + 1 print(table) def output_to_markdown(detector_classes, filter_wiki): def extract_help(cls): - if cls.WIKI == '': + if cls.WIKI == "": return cls.HELP - return '[{}]({})'.format(cls.HELP, cls.WIKI) + return "[{}]({})".format(cls.HELP, cls.WIKI) detectors_list = [] for detector in detector_classes: @@ -73,12 +79,16 @@ def output_to_markdown(detector_classes, filter_wiki): detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) idx = 1 for (argument, help_info, impact, proxy, v2) in detectors_list: - print('{} | `{}` | {} | {} | {} | {}'.format(idx, - argument, - help_info, - classification_txt[impact], - 'X' if proxy else '', - 'X' if v2 else '')) + print( + "{} | `{}` | {} | {} | {} | {}".format( + idx, + argument, + help_info, + classification_txt[impact], + "X" if proxy else "", + "X" if v2 else "", + ) + ) idx = idx + 1 @@ -92,26 +102,42 @@ def output_detectors_json(detector_classes): wiki_description = detector.WIKI_DESCRIPTION wiki_exploit_scenario = detector.WIKI_EXPLOIT_SCENARIO wiki_recommendation = detector.WIKI_RECOMMENDATION - detectors_list.append((argument, - help_info, - impact, - wiki_url, - wiki_description, - wiki_exploit_scenario, - wiki_recommendation)) + detectors_list.append( + ( + argument, + help_info, + impact, + wiki_url, + wiki_description, + wiki_exploit_scenario, + wiki_recommendation, + ) + ) # Sort by impact, confidence, and name detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) idx = 1 table = [] - for (argument, help_info, impact, wiki_url, description, exploit, recommendation) in detectors_list: - table.append({'index': idx, - 'check': argument, - 'title': help_info, - 'impact': classification_txt[impact], - 'wiki_url': wiki_url, - 'description': description, - 'exploit_scenario': exploit, - 'recommendation': recommendation}) + for ( + argument, + help_info, + impact, + wiki_url, + description, + exploit, + recommendation, + ) in detectors_list: + table.append( + { + "index": idx, + "check": argument, + "title": help_info, + "impact": classification_txt[impact], + "wiki_url": wiki_url, + "description": description, + "exploit_scenario": exploit, + "recommendation": recommendation, + } + ) idx = idx + 1 return table diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index 372a383b5..4a82ccfae 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -3,11 +3,13 @@ import logging from .expression import ExpressionVisitor from slither.core.expressions import BinaryOperationType, Literal + class NotConstant(Exception): pass -KEY = 'ConstantFolding' +KEY = "ConstantFolding" + def get_val(expression): val = expression.context[KEY] @@ -15,11 +17,12 @@ def get_val(expression): del expression.context[KEY] return val + def set_val(expression, val): expression.context[KEY] = val -class ConstantFolding(ExpressionVisitor): +class ConstantFolding(ExpressionVisitor): def __init__(self, expression, type): self._type = type super(ConstantFolding, self).__init__(expression) @@ -40,24 +43,24 @@ class ConstantFolding(ExpressionVisitor): def _post_binary_operation(self, expression): left = get_val(expression.expression_left) right = get_val(expression.expression_right) - if expression.type == BinaryOperationType.POWER: + if expression.type == BinaryOperationType.POWER: set_val(expression, left ** right) - elif expression.type == BinaryOperationType.MULTIPLICATION: + elif expression.type == BinaryOperationType.MULTIPLICATION: set_val(expression, left * right) - elif expression.type == BinaryOperationType.DIVISION: + elif expression.type == BinaryOperationType.DIVISION: set_val(expression, left / right) - elif expression.type == BinaryOperationType.MODULO: + elif expression.type == BinaryOperationType.MODULO: set_val(expression, left % right) - elif expression.type == BinaryOperationType.ADDITION: + elif expression.type == BinaryOperationType.ADDITION: set_val(expression, left + right) - elif expression.type == BinaryOperationType.SUBTRACTION: - if(left-right) <0: + elif expression.type == BinaryOperationType.SUBTRACTION: + if (left - right) < 0: # Could trigger underflow raise NotConstant set_val(expression, left - right) - elif expression.type == BinaryOperationType.LEFT_SHIFT: + elif expression.type == BinaryOperationType.LEFT_SHIFT: set_val(expression, left << right) - elif expression.type == BinaryOperationType.RIGHT_SHIFT: + elif expression.type == BinaryOperationType.RIGHT_SHIFT: set_val(expression, left >> right) else: raise NotConstant @@ -110,6 +113,3 @@ class ConstantFolding(ExpressionVisitor): def _post_type_conversion(self, expression): raise NotConstant - - - diff --git a/slither/visitors/expression/export_values.py b/slither/visitors/expression/export_values.py index 15afb4dfa..f478a83e1 100644 --- a/slither/visitors/expression/export_values.py +++ b/slither/visitors/expression/export_values.py @@ -1,11 +1,11 @@ - from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.assignment_operation import AssignmentOperationType from slither.core.variables.variable import Variable -key = 'ExportValues' +key = "ExportValues" + def get(expression): val = expression.context[key] @@ -13,11 +13,12 @@ def get(expression): del expression.context[key] return val + def set_val(expression, val): expression.context[key] = val -class ExportValues(ExpressionVisitor): +class ExportValues(ExpressionVisitor): def result(self): if self._result is None: self._result = list(set(get(self.expression))) diff --git a/slither/visitors/expression/expression.py b/slither/visitors/expression/expression.py index c0f1b35d4..01cdf81d9 100644 --- a/slither/visitors/expression/expression.py +++ b/slither/visitors/expression/expression.py @@ -1,10 +1,12 @@ import logging +from typing import Any from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.binary_operation import BinaryOperation from slither.core.expressions.call_expression import CallExpression from slither.core.expressions.conditional_expression import ConditionalExpression from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression +from slither.core.expressions.expression import Expression from slither.core.expressions.identifier import Identifier from slither.core.expressions.index_access import IndexAccess from slither.core.expressions.literal import Literal @@ -19,24 +21,24 @@ from slither.exceptions import SlitherError logger = logging.getLogger("ExpressionVisitor") -class ExpressionVisitor: - def __init__(self, expression): +class ExpressionVisitor: + def __init__(self, expression: Expression): # Inherited class must declared their variables prior calling super().__init__ self._expression = expression - self._result = None + self._result: Any = None self._visit_expression(self.expression) def result(self): return self._result @property - def expression(self): + def expression(self) -> Expression: return self._expression # visit an expression # call pre_visit, visit_expression_name, post_visit - def _visit_expression(self, expression): + def _visit_expression(self, expression: Expression): self._pre_visit(expression) if isinstance(expression, AssignmentOperation): @@ -88,7 +90,7 @@ class ExpressionVisitor: pass else: - raise SlitherError('Expression not handled: {}'.format(expression)) + raise SlitherError("Expression not handled: {}".format(expression)) self._post_visit(expression) @@ -207,7 +209,7 @@ class ExpressionVisitor: pass else: - raise SlitherError('Expression not handled: {}'.format(expression)) + raise SlitherError("Expression not handled: {}".format(expression)) # pre_expression_name @@ -308,7 +310,7 @@ class ExpressionVisitor: pass else: - raise SlitherError('Expression not handled: {}'.format(expression)) + raise SlitherError("Expression not handled: {}".format(expression)) # post_expression_name @@ -356,5 +358,3 @@ class ExpressionVisitor: def _post_unary_operation(self, expression): pass - - diff --git a/slither/visitors/expression/expression_printer.py b/slither/visitors/expression/expression_printer.py index dce527bbf..0487cd0b2 100644 --- a/slither/visitors/expression/expression_printer.py +++ b/slither/visitors/expression/expression_printer.py @@ -1,17 +1,18 @@ - from slither.visitors.expression.expression import ExpressionVisitor + def get(expression): - val = expression.context['ExpressionPrinter'] + val = expression.context["ExpressionPrinter"] # we delete the item to reduce memory use - del expression.context['ExpressionPrinter'] + del expression.context["ExpressionPrinter"] return val + def set_val(expression, val): - expression.context['ExpressionPrinter'] = val + expression.context["ExpressionPrinter"] = val -class ExpressionPrinter(ExpressionVisitor): +class ExpressionPrinter(ExpressionVisitor): def result(self): if not self._result: self._result = get(self.expression) @@ -20,19 +21,19 @@ class ExpressionPrinter(ExpressionVisitor): def _post_assignement_operation(self, expression): left = get(expression.expression_left) right = get(expression.expression_right) - val = "{} {} {}".format(left, expression.type_str, right) + val = "{} {} {}".format(left, expression.type, right) set_val(expression, val) def _post_binary_operation(self, expression): left = get(expression.expression_left) right = get(expression.expression_right) - val = "{} {} {}".format(left, expression.type_str, right) + val = "{} {} {}".format(left, expression.type, right) set_val(expression, val) def _post_call_expression(self, expression): called = get(expression.called) arguments = [get(x) for x in expression.arguments if x] - val = "{}({})".format(called, ','.join(arguments)) + val = "{}({})".format(called, ",".join(arguments)) set_val(expression, val) def _post_conditional_expression(self, expression): @@ -66,7 +67,7 @@ class ExpressionPrinter(ExpressionVisitor): def _post_new_array(self, expression): array = str(expression.array_type) depth = expression.depth - val = "new {}{}".format(array, '[]'*depth) + val = "new {}{}".format(array, "[]" * depth) set_val(expression, val) def _post_new_contract(self, expression): @@ -81,7 +82,7 @@ class ExpressionPrinter(ExpressionVisitor): def _post_tuple_expression(self, expression): expressions = [get(e) for e in expression.expressions if e] - val = "({})".format(','.join(expressions)) + val = "({})".format(",".join(expressions)) set_val(expression, val) def _post_type_conversion(self, expression): diff --git a/slither/visitors/expression/find_calls.py b/slither/visitors/expression/find_calls.py index 6af072329..9b9141c76 100644 --- a/slither/visitors/expression/find_calls.py +++ b/slither/visitors/expression/find_calls.py @@ -1,11 +1,10 @@ +from typing import List +from slither.core.expressions.expression import Expression from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.assignment_operation import AssignmentOperationType +key = "FindCall" -from slither.core.variables.variable import Variable - -key = 'FindCall' def get(expression): val = expression.context[key] @@ -13,12 +12,13 @@ def get(expression): del expression.context[key] return val + def set_val(expression, val): expression.context[key] = val -class FindCalls(ExpressionVisitor): - def result(self): +class FindCalls(ExpressionVisitor): + def result(self) -> List[Expression]: if self._result is None: self._result = list(set(get(self.expression))) return self._result diff --git a/slither/visitors/expression/find_push.py b/slither/visitors/expression/find_push.py index 76e0d7a56..b3e79b0c4 100644 --- a/slither/visitors/expression/find_push.py +++ b/slither/visitors/expression/find_push.py @@ -4,7 +4,8 @@ from slither.core.expressions.index_access import IndexAccess from slither.visitors.expression.right_value import RightValue -key = 'FindPush' +key = "FindPush" + def get(expression): val = expression.context[key] @@ -12,11 +13,12 @@ def get(expression): del expression.context[key] return val + def set_val(expression, val): expression.context[key] = val -class FindPush(ExpressionVisitor): +class FindPush(ExpressionVisitor): def result(self): if self._result is None: self._result = list(set(get(self.expression))) @@ -66,7 +68,7 @@ class FindPush(ExpressionVisitor): def _post_member_access(self, expression): val = [] - if expression.member_name == 'push': + if expression.member_name == "push": right = RightValue(expression.expression) val = right.result() set_val(expression, val) diff --git a/slither/visitors/expression/has_conditional.py b/slither/visitors/expression/has_conditional.py index 5378e4d98..906f522ff 100644 --- a/slither/visitors/expression/has_conditional.py +++ b/slither/visitors/expression/has_conditional.py @@ -1,13 +1,12 @@ - from slither.visitors.expression.expression import ExpressionVisitor -class HasConditional(ExpressionVisitor): +class HasConditional(ExpressionVisitor): def result(self): # == True, to convert None to false return self._result is True def _post_conditional_expression(self, expression): -# if self._result is True: -# raise('Slither does not support nested ternary operator') + # if self._result is True: + # raise('Slither does not support nested ternary operator') self._result = True diff --git a/slither/visitors/expression/left_value.py b/slither/visitors/expression/left_value.py index c23c3c06c..3b16c8c26 100644 --- a/slither/visitors/expression/left_value.py +++ b/slither/visitors/expression/left_value.py @@ -6,7 +6,8 @@ from slither.core.expressions.assignment_operation import AssignmentOperationTyp from slither.core.variables.variable import Variable -key = 'LeftValue' +key = "LeftValue" + def get(expression): val = expression.context[key] @@ -14,11 +15,12 @@ def get(expression): del expression.context[key] return val + def set_val(expression, val): expression.context[key] = val -class LeftValue(ExpressionVisitor): +class LeftValue(ExpressionVisitor): def result(self): if self._result is None: self._result = list(set(get(self.expression))) @@ -64,8 +66,8 @@ class LeftValue(ExpressionVisitor): def _post_identifier(self, expression): if isinstance(expression.value, Variable): set_val(expression, [expression.value]) -# elif isinstance(expression.value, SolidityInbuilt): -# set_val(expression, [expression]) + # elif isinstance(expression.value, SolidityInbuilt): + # set_val(expression, [expression]) else: set_val(expression, []) diff --git a/slither/visitors/expression/read_var.py b/slither/visitors/expression/read_var.py index ae4882d84..8fe063a2f 100644 --- a/slither/visitors/expression/read_var.py +++ b/slither/visitors/expression/read_var.py @@ -1,4 +1,3 @@ - from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.assignment_operation import AssignmentOperationType @@ -6,7 +5,8 @@ from slither.core.expressions.assignment_operation import AssignmentOperationTyp from slither.core.variables.variable import Variable from slither.core.declarations.solidity_variables import SolidityVariable -key = 'ReadVar' +key = "ReadVar" + def get(expression): val = expression.context[key] @@ -14,17 +14,17 @@ def get(expression): del expression.context[key] return val + def set_val(expression, val): expression.context[key] = val -class ReadVar(ExpressionVisitor): +class ReadVar(ExpressionVisitor): def result(self): if self._result is None: self._result = list(set(get(self.expression))) return self._result - # overide assignement # dont explore if its direct assignement (we explore if its +=, -=, ...) def _visit_assignement_operation(self, expression): diff --git a/slither/visitors/expression/right_value.py b/slither/visitors/expression/right_value.py index 718ed392a..5a97846bc 100644 --- a/slither/visitors/expression/right_value.py +++ b/slither/visitors/expression/right_value.py @@ -10,7 +10,8 @@ from slither.core.expressions.expression import Expression from slither.core.variables.variable import Variable -key = 'RightValue' +key = "RightValue" + def get(expression): val = expression.context[key] @@ -18,11 +19,12 @@ def get(expression): del expression.context[key] return val + def set_val(expression, val): expression.context[key] = val -class RightValue(ExpressionVisitor): +class RightValue(ExpressionVisitor): def result(self): if self._result is None: self._result = list(set(get(self.expression))) @@ -68,8 +70,8 @@ class RightValue(ExpressionVisitor): def _post_identifier(self, expression): if isinstance(expression.value, Variable): set_val(expression, [expression.value]) -# elif isinstance(expression.value, SolidityInbuilt): -# set_val(expression, [expression]) + # elif isinstance(expression.value, SolidityInbuilt): + # set_val(expression, [expression]) else: set_val(expression, []) diff --git a/slither/visitors/expression/write_var.py b/slither/visitors/expression/write_var.py index 0368b4ad1..7267415a6 100644 --- a/slither/visitors/expression/write_var.py +++ b/slither/visitors/expression/write_var.py @@ -1,4 +1,3 @@ - from slither.visitors.expression.expression import ExpressionVisitor from slither.core.expressions.assignment_operation import AssignmentOperation @@ -10,7 +9,8 @@ from slither.core.expressions.member_access import MemberAccess from slither.core.expressions.index_access import IndexAccess -key = 'WriteVar' +key = "WriteVar" + def get(expression): val = expression.context[key] @@ -18,11 +18,12 @@ def get(expression): del expression.context[key] return val + def set_val(expression, val): expression.context[key] = val -class WriteVar(ExpressionVisitor): +class WriteVar(ExpressionVisitor): def result(self): if self._result is None: self._result = list(set(get(self.expression))) @@ -71,27 +72,28 @@ class WriteVar(ExpressionVisitor): set_val(expression, [expression]) else: set_val(expression, []) -# if isinstance(expression.value, Variable): -# set_val(expression, [expression.value]) -# else: -# set_val(expression, []) + + # if isinstance(expression.value, Variable): + # set_val(expression, [expression.value]) + # else: + # set_val(expression, []) def _post_index_access(self, expression): left = get(expression.expression_left) right = get(expression.expression_right) val = left + right if expression.is_lvalue: - # val += [expression] + # val += [expression] val += [expression.expression_left] - # n = expression.expression_left - # parse all the a.b[..].c[..] - # while isinstance(n, (IndexAccess, MemberAccess)): - # if isinstance(n, IndexAccess): - # val += [n.expression_left] - # n = n.expression_left - # else: - # val += [n.expression] - # n = n.expression + # n = expression.expression_left + # parse all the a.b[..].c[..] + # while isinstance(n, (IndexAccess, MemberAccess)): + # if isinstance(n, IndexAccess): + # val += [n.expression_left] + # n = n.expression_left + # else: + # val += [n.expression] + # n = n.expression set_val(expression, val) def _post_literal(self, expression):