diff --git a/examples/scripts/call_graph.py b/examples/scripts/call_graph.py deleted file mode 100644 index b19745cde..000000000 --- a/examples/scripts/call_graph.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -import logging -import argparse -from slither import Slither -from slither.printers.all_printers import PrinterCallGraph -from slither.core.declarations.function import Function - -logging.basicConfig() -logging.getLogger("Slither").setLevel(logging.INFO) -logging.getLogger("Printers").setLevel(logging.INFO) - - -class PrinterCallGraphStateChange(PrinterCallGraph): - def _process_function( - self, - contract, - function, - contract_functions, - contract_calls, - solidity_functions, - solidity_calls, - external_calls, - all_contracts, - ): - if function.view or function.pure: - return - super()._process_function( - contract, - function, - contract_functions, - contract_calls, - solidity_functions, - solidity_calls, - external_calls, - all_contracts, - ) - - def _process_internal_call( - self, contract, function, internal_call, contract_calls, solidity_functions, solidity_calls - ): - if isinstance(internal_call, Function): - if internal_call.view or internal_call.pure: - return - super()._process_internal_call( - contract, function, internal_call, contract_calls, solidity_functions, solidity_calls - ) - - def _process_external_call( - self, contract, function, external_call, contract_functions, external_calls, all_contracts - ): - if isinstance(external_call[1], Function): - if external_call[1].view or external_call[1].pure: - return - super()._process_external_call( - contract, function, external_call, contract_functions, external_calls, all_contracts - ) - - -def parse_args(): - """ - """ - parser = argparse.ArgumentParser( - description="Call graph printer. Similar to --print call-graph, but without printing the view/pure functions", - usage="call_graph.py filename", - ) - - parser.add_argument( - "filename", help="The filename of the contract or truffle directory to analyze." - ) - - parser.add_argument("--solc", help="solc path", default="solc") - - return parser.parse_args() - - -def main(): - - args = parse_args() - slither = Slither(args.filename, is_truffle=os.path.isdir(args.filename), solc=args.solc) - - slither.register_printer(PrinterCallGraphStateChange) - - slither.run_printers() - - -if __name__ == "__main__": - main() diff --git a/examples/scripts/convert_to_evm_ins.py b/examples/scripts/convert_to_evm_ins.py deleted file mode 100644 index 4c98af58f..000000000 --- a/examples/scripts/convert_to_evm_ins.py +++ /dev/null @@ -1,39 +0,0 @@ -import sys -from slither.slither import Slither -from slither.evm.convert import SourceToEVM - -if len(sys.argv) != 2: - print("python3 function_called.py functions_called.sol") - exit(-1) - -# Init slither -slither = Slither(sys.argv[1]) - -# Get the contract evm instructions -contract = slither.get_contract_from_name("Test") -contract_ins = SourceToEVM.get_evm_instructions(contract) -print("## Contract evm instructions: {} ##".format(contract.name)) -for ins in contract_ins: - print(str(ins)) - -# Get the constructor evm instructions -constructor = contract.constructor -print("## Function evm instructions: {} ##".format(constructor.name)) -constructor_ins = SourceToEVM.get_evm_instructions(constructor) -for ins in constructor_ins: - print(str(ins)) - -# Get the function evm instructions -function = contract.get_function_from_signature("foo()") -print("## Function evm instructions: {} ##".format(function.name)) -function_ins = SourceToEVM.get_evm_instructions(function) -for ins in function_ins: - print(str(ins)) - -# Get the node evm instructions -nodes = function.nodes -for node in nodes: - node_ins = SourceToEVM.get_evm_instructions(node) - print("Node evm instructions: {}".format(str(node))) - for ins in node_ins: - print(str(ins)) diff --git a/examples/scripts/convert_to_ir.py b/examples/scripts/convert_to_ir.py index 9ab9d51f2..583170bd6 100644 --- a/examples/scripts/convert_to_ir.py +++ b/examples/scripts/convert_to_ir.py @@ -5,7 +5,7 @@ from slither.slithir.convert import convert_expression if len(sys.argv) != 2: print("python function_called.py functions_called.sol") - exit(-1) + sys.exit(-1) # Init slither slither = Slither(sys.argv[1]) @@ -21,7 +21,7 @@ nodes = test.nodes for node in nodes: if node.expression: print("Expression:\n\t{}".format(node.expression)) - irs = convert_expression(node.expression) + irs = convert_expression(node.expression, node) print("IR expressions:") for ir in irs: print("\t{}".format(ir)) diff --git a/examples/scripts/data_dependency.py b/examples/scripts/data_dependency.py index e28d8c7a2..7dbb7938f 100644 --- a/examples/scripts/data_dependency.py +++ b/examples/scripts/data_dependency.py @@ -1,15 +1,15 @@ import sys + from slither import Slither from slither.analyses.data_dependency.data_dependency import ( is_dependent, is_tainted, - pprint_dependency, ) from slither.core.declarations.solidity_variables import SolidityVariableComposed if len(sys.argv) != 2: print("Usage: python data_dependency.py file.sol") - exit(-1) + sys.exit(-1) slither = Slither(sys.argv[1]) diff --git a/examples/scripts/export_dominator_tree_to_dot.py b/examples/scripts/export_dominator_tree_to_dot.py index f53592ca4..695450d1f 100644 --- a/examples/scripts/export_dominator_tree_to_dot.py +++ b/examples/scripts/export_dominator_tree_to_dot.py @@ -4,7 +4,7 @@ from slither.slither import Slither if len(sys.argv) != 2: print("python export_dominator_tree_to_dot.py contract.sol") - exit(-1) + sys.exit(-1) # Init slither slither = Slither(sys.argv[1]) diff --git a/examples/scripts/export_to_dot.py b/examples/scripts/export_to_dot.py index a2ee35333..9734a6d98 100644 --- a/examples/scripts/export_to_dot.py +++ b/examples/scripts/export_to_dot.py @@ -4,7 +4,7 @@ from slither.slither import Slither if len(sys.argv) != 2: print("python function_called.py contract.sol") - exit(-1) + sys.exit(-1) # Init slither slither = Slither(sys.argv[1]) diff --git a/examples/scripts/functions_called.py b/examples/scripts/functions_called.py index 6bec75105..9aaa3b653 100644 --- a/examples/scripts/functions_called.py +++ b/examples/scripts/functions_called.py @@ -3,7 +3,7 @@ from slither.slither import Slither if len(sys.argv) != 2: print("python functions_called.py functions_called.sol") - exit(-1) + sys.exit(-1) # Init slither slither = Slither(sys.argv[1]) diff --git a/examples/scripts/functions_writing.py b/examples/scripts/functions_writing.py index 5a9253811..1fd83e0e4 100644 --- a/examples/scripts/functions_writing.py +++ b/examples/scripts/functions_writing.py @@ -3,7 +3,7 @@ from slither.slither import Slither if len(sys.argv) != 2: print("python function_writing.py functions_writing.sol") - exit(-1) + sys.exit(-1) # Init slither slither = Slither(sys.argv[1]) diff --git a/examples/scripts/possible_paths.py b/examples/scripts/possible_paths.py index c20ec0b1e..b84d3651f 100644 --- a/examples/scripts/possible_paths.py +++ b/examples/scripts/possible_paths.py @@ -84,8 +84,8 @@ def all_function_definitions(function): ] -def __find_target_paths(target_function, current_path=[]): - +def __find_target_paths(target_function, current_path=None): + current_path = current_path if current_path else [] # Create our results list results = set() @@ -184,17 +184,17 @@ slither = Slither(args.filename, is_truffle=args.is_truffle) targets = resolve_functions(args.targets) # Print out all target functions. -print(f"Target functions:") +print("Target functions:") for target in targets: print(f"-{target.contract.name}.{target.full_name}") print("\n") # Obtain all paths which reach the target functions. reaching_paths = find_target_paths(targets) -reaching_functions = set([y for x in reaching_paths for y in x if y not in targets]) +reaching_functions = {y for x in reaching_paths for y in x if y not in targets} # Print out all function names which can reach the targets. -print(f"The following functions reach the specified targets:") +print("The following functions reach the specified targets:") for function_desc in sorted([f"{f.canonical_name}" for f in reaching_functions]): print(f"-{function_desc}") print("\n") @@ -205,6 +205,6 @@ reaching_paths_str = [ ] # Print a sorted list of all function paths which can reach the targets. -print(f"The following paths reach the specified targets:") +print("The following paths reach the specified targets:") for reaching_path in sorted(reaching_paths_str): print(f"{reaching_path}\n") diff --git a/examples/scripts/slithIR.py b/examples/scripts/slithIR.py index 6d74833c5..5cf14fbb6 100644 --- a/examples/scripts/slithIR.py +++ b/examples/scripts/slithIR.py @@ -3,7 +3,7 @@ from slither import Slither if len(sys.argv) != 2: print("python slithIR.py contract.sol") - exit(-1) + sys.exit(-1) # Init slither slither = Slither(sys.argv[1]) diff --git a/examples/scripts/taint_mapping.py b/examples/scripts/taint_mapping.py index 922d739c8..c1754e635 100644 --- a/examples/scripts/taint_mapping.py +++ b/examples/scripts/taint_mapping.py @@ -58,7 +58,7 @@ def check_call(func, taints): if __name__ == "__main__": if len(sys.argv) != 2: print("python taint_mapping.py taint.sol") - exit(-1) + sys.exit(-1) # Init slither slither = Slither(sys.argv[1]) @@ -79,11 +79,11 @@ if __name__ == "__main__": visit_node(function.entry_point, []) print("All variables tainted : {}".format([str(v) for v in slither.context[KEY]])) + for function in contract.functions: + check_call(function, slither.context[KEY]) + print( "All state variables tainted : {}".format( [str(v) for v in prev_taints if isinstance(v, StateVariable)] ) ) - - for function in contract.functions: - check_call(function, slither.context[KEY]) diff --git a/examples/scripts/variable_in_condition.py b/examples/scripts/variable_in_condition.py index 5bf022ed8..91efded15 100644 --- a/examples/scripts/variable_in_condition.py +++ b/examples/scripts/variable_in_condition.py @@ -3,7 +3,7 @@ from slither.slither import Slither if len(sys.argv) != 2: print("python variable_in_condition.py variable_in_condition.sol") - exit(-1) + sys.exit(-1) # Init slither slither = Slither(sys.argv[1]) diff --git a/pyproject.toml b/pyproject.toml index f5d8f35e5..3bf575997 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,5 +9,12 @@ missing-function-docstring, unnecessary-lambda, bad-continuation, cyclic-import, -line-too-long +line-too-long, +invalid-name, +fixme, +too-many-return-statements, +too-many-ancestors, +logging-fstring-interpolation, +logging-not-lazy, +duplicate-code """ diff --git a/scripts/json_diff.py b/scripts/json_diff.py index 9b1844906..9422f6b8d 100644 --- a/scripts/json_diff.py +++ b/scripts/json_diff.py @@ -1,11 +1,12 @@ import sys import json -from deepdiff import DeepDiff # pip install deepdiff from pprint import pprint +from deepdiff import DeepDiff # pip install deepdiff + if len(sys.argv) != 3: print("Usage: python json_diff.py 1.json 2.json") - exit(-1) + sys.exit(-1) with open(sys.argv[1], encoding="utf8") as f: d1 = json.load(f) diff --git a/slither/__main__.py b/slither/__main__.py index 198a1f304..3fbd2b652 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -10,8 +10,10 @@ import sys import traceback from pkg_resources import iter_entry_points, require + from crytic_compile import cryticparser from crytic_compile.platform.standard import generate_standard_export +from crytic_compile import compile_all, is_supported from slither.detectors import all_detectors from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification @@ -34,7 +36,6 @@ from slither.utils.command_line import ( JSON_OUTPUT_TYPES, DEFAULT_JSON_OUTPUT_TYPES, ) -from crytic_compile import compile_all, is_supported from slither.exceptions import SlitherException logging.basicConfig() @@ -80,7 +81,12 @@ def process_all(target, args, detector_classes, printer_classes): results_printers.extend(current_results_printers) slither_instances.append(slither) analyzed_contracts_count += current_analyzed_count - return slither_instances, results_detectors, results_printers, analyzed_contracts_count + return ( + slither_instances, + results_detectors, + results_printers, + analyzed_contracts_count, + ) def _process(slither, detector_classes, printer_classes): @@ -98,7 +104,9 @@ def _process(slither, detector_classes, printer_classes): if not printer_classes: detector_results = slither.run_detectors() detector_results = [x for x in detector_results if x] # remove empty results - detector_results = [item for sublist in detector_results for item in sublist] # flatten + detector_results = [ + item for sublist in detector_results for item in sublist + ] # flatten results_detectors.extend(detector_results) else: @@ -113,8 +121,8 @@ def process_from_asts(filenames, args, detector_classes, printer_classes): all_contracts = [] for filename in filenames: - with open(filename, encoding="utf8") as f: - contract_loaded = json.load(f) + with open(filename, encoding="utf8") as file_open: + contract_loaded = json.load(file_open) all_contracts.append(contract_loaded["ast"]) return process_single(all_contracts, args, detector_classes, printer_classes) @@ -128,7 +136,7 @@ def process_from_asts(filenames, args, detector_classes, printer_classes): ################################################################################### -def exit(results): +def my_exit(results): if not results: sys.exit(0) sys.exit(len(results)) @@ -148,10 +156,14 @@ def get_detectors_and_printers(): """ detectors = [getattr(all_detectors, name) for name in dir(all_detectors)] - detectors = [d for d in detectors if inspect.isclass(d) and issubclass(d, AbstractDetector)] + detectors = [ + d for d in detectors if inspect.isclass(d) and issubclass(d, AbstractDetector) + ] printers = [getattr(all_printers, name) for name in dir(all_printers)] - printers = [p for p in printers if inspect.isclass(p) and issubclass(p, AbstractPrinter)] + printers = [ + p for p in printers if inspect.isclass(p) and issubclass(p, AbstractPrinter) + ] # Handle plugins! for entry_point in iter_entry_points(group="slither_analyzer.plugin", name=None): @@ -159,11 +171,16 @@ def get_detectors_and_printers(): plugin_detectors, plugin_printers = make_plugin() - if not all(issubclass(d, AbstractDetector) for d in plugin_detectors): - raise Exception("Error when loading plugin %s, %r is not a detector" % (entry_point, d)) - - if not all(issubclass(p, AbstractPrinter) for p in plugin_printers): - raise Exception("Error when loading plugin %s, %r is not a printer" % (entry_point, p)) + detector = None + if not all(issubclass(detector, AbstractDetector) for detector in plugin_detectors): + raise Exception( + "Error when loading plugin %s, %r is not a detector" % (entry_point, detector) + ) + printer = None + if not all(issubclass(printer, AbstractPrinter) for printer in plugin_printers): + raise Exception( + "Error when loading plugin %s, %r is not a printer" % (entry_point, printer) + ) # We convert those to lists in case someone returns a tuple detectors += list(plugin_detectors) @@ -171,7 +188,7 @@ def get_detectors_and_printers(): return detectors, printers - +# pylint: disable=too-many-branches def choose_detectors(args, all_detector_classes): # If detectors are specified, run only these ones @@ -182,35 +199,43 @@ def choose_detectors(args, all_detector_classes): detectors_to_run = all_detector_classes if args.detectors_to_exclude: detectors_excluded = args.detectors_to_exclude.split(",") - for d in detectors: - if d in detectors_excluded: - detectors_to_run.remove(detectors[d]) + for detector in detectors: + if detector in detectors_excluded: + detectors_to_run.remove(detectors[detector]) else: - for d in args.detectors_to_run.split(","): - if d in detectors: - detectors_to_run.append(detectors[d]) + for detector in args.detectors_to_run.split(","): + if detector in detectors: + detectors_to_run.append(detectors[detector]) else: - raise Exception("Error: {} is not a detector".format(d)) + raise Exception("Error: {} is not a detector".format(detector)) detectors_to_run = sorted(detectors_to_run, key=lambda x: x.IMPACT) return detectors_to_run if args.exclude_optimization: detectors_to_run = [ - d for d in detectors_to_run if d.IMPACT != DetectorClassification.OPTIMIZATION + d + for d in detectors_to_run + if d.IMPACT != DetectorClassification.OPTIMIZATION ] if args.exclude_informational: detectors_to_run = [ - d for d in detectors_to_run if d.IMPACT != DetectorClassification.INFORMATIONAL + d + for d in detectors_to_run + if d.IMPACT != DetectorClassification.INFORMATIONAL ] if args.exclude_low: - detectors_to_run = [d for d in detectors_to_run if d.IMPACT != DetectorClassification.LOW] + detectors_to_run = [ + d for d in detectors_to_run if d.IMPACT != DetectorClassification.LOW + ] if args.exclude_medium: detectors_to_run = [ d for d in detectors_to_run if d.IMPACT != DetectorClassification.MEDIUM ] if args.exclude_high: - detectors_to_run = [d for d in detectors_to_run if d.IMPACT != DetectorClassification.HIGH] + detectors_to_run = [ + d for d in detectors_to_run if d.IMPACT != DetectorClassification.HIGH + ] if args.detectors_to_exclude: detectors_to_run = [ d for d in detectors_to_run if d.ARGUMENT not in args.detectors_to_exclude @@ -232,11 +257,11 @@ def choose_printers(args, all_printer_classes): return all_printer_classes printers = {p.ARGUMENT: p for p in all_printer_classes} - for p in args.printers_to_run.split(","): - if p in printers: - printers_to_run.append(printers[p]) + for printer in args.printers_to_run.split(","): + if printer in printers: + printers_to_run.append(printers[printer]) else: - raise Exception("Error: {} is not a printer".format(p)) + raise Exception("Error: {} is not a printer".format(printer)) return printers_to_run @@ -278,7 +303,9 @@ def parse_args(detector_classes, printer_classes): group_detector.add_argument( "--detect", help="Comma-separated list of detectors, defaults to all, " - "available detectors: {}".format(", ".join(d.ARGUMENT for d in detector_classes)), + "available detectors: {}".format( + ", ".join(d.ARGUMENT for d in detector_classes) + ), action="store", dest="detectors_to_run", default=defaults_flag_in_config["detectors_to_run"], @@ -287,7 +314,7 @@ def parse_args(detector_classes, printer_classes): group_printer.add_argument( "--print", help="Comma-separated list fo contract information printers, " - "available printers: {}".format(", ".join(d.ARGUMENT for d in printer_classes)), + "available printers: {}".format(", ".join(d.ARGUMENT for d in printer_classes)), action="store", dest="printers_to_run", default=defaults_flag_in_config["printers_to_run"], @@ -368,9 +395,9 @@ def parse_args(detector_classes, printer_classes): group_misc.add_argument( "--json-types", - help=f"Comma-separated list of result types to output to JSON, defaults to " - + f'{",".join(output_type for output_type in DEFAULT_JSON_OUTPUT_TYPES)}. ' - + f'Available types: {",".join(output_type for output_type in JSON_OUTPUT_TYPES)}', + help="Comma-separated list of result types to output to JSON, defaults to " + + f'{",".join(output_type for output_type in DEFAULT_JSON_OUTPUT_TYPES)}. ' + + f'Available types: {",".join(output_type for output_type in JSON_OUTPUT_TYPES)}', action="store", default=defaults_flag_in_config["json-types"], ) @@ -390,7 +417,10 @@ def parse_args(detector_classes, printer_classes): ) group_misc.add_argument( - "--markdown-root", help="URL for markdown generation", action="store", default="" + "--markdown-root", + help="URL for markdown generation", + action="store", + default="", ) group_misc.add_argument( @@ -425,7 +455,10 @@ def parse_args(detector_classes, printer_classes): ) group_misc.add_argument( - "--solc-ast", help="Provide the contract as a json AST", action="store_true", default=False + "--solc-ast", + help="Provide the contract as a json AST", + action="store_true", + default=False, ) group_misc.add_argument( @@ -436,9 +469,13 @@ def parse_args(detector_classes, printer_classes): ) # debugger command - parser.add_argument("--debug", help=argparse.SUPPRESS, action="store_true", default=False) + parser.add_argument( + "--debug", help=argparse.SUPPRESS, action="store_true", default=False + ) - parser.add_argument("--markdown", help=argparse.SUPPRESS, action=OutputMarkdown, default=False) + parser.add_argument( + "--markdown", help=argparse.SUPPRESS, action=OutputMarkdown, default=False + ) group_misc.add_argument( "--checklist", help=argparse.SUPPRESS, action="store_true", default=False @@ -471,7 +508,9 @@ def parse_args(detector_classes, printer_classes): ) # if the json is splitted in different files - parser.add_argument("--splitted", help=argparse.SUPPRESS, action="store_true", default=False) + parser.add_argument( + "--splitted", help=argparse.SUPPRESS, action="store_true", default=False + ) # Disable the throw/catch on partial analyses parser.add_argument( @@ -491,41 +530,43 @@ def parse_args(detector_classes, printer_classes): args.json_types = set(args.json_types.split(",")) for json_type in args.json_types: if json_type not in JSON_OUTPUT_TYPES: - raise Exception(f'Error: "{json_type}" is not a valid JSON result output type.') + raise Exception( + f'Error: "{json_type}" is not a valid JSON result output type.' + ) return args -class ListDetectors(argparse.Action): - def __call__(self, parser, *args, **kwargs): +class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods + def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs detectors, _ = get_detectors_and_printers() output_detectors(detectors) parser.exit() -class ListDetectorsJson(argparse.Action): - def __call__(self, parser, *args, **kwargs): +class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-methods + def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs detectors, _ = get_detectors_and_printers() detector_types_json = output_detectors_json(detectors) print(json.dumps(detector_types_json)) parser.exit() -class ListPrinters(argparse.Action): - def __call__(self, parser, *args, **kwargs): +class ListPrinters(argparse.Action): # pylint: disable=too-few-public-methods + def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs _, printers = get_detectors_and_printers() output_printers(printers) parser.exit() -class OutputMarkdown(argparse.Action): +class OutputMarkdown(argparse.Action): # pylint: disable=too-few-public-methods def __call__(self, parser, args, values, option_string=None): detectors, printers = get_detectors_and_printers() output_to_markdown(detectors, printers, values) parser.exit() -class OutputWiki(argparse.Action): +class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods def __call__(self, parser, args, values, option_string=None): detectors, _ = get_detectors_and_printers() output_wiki(detectors, values) @@ -569,6 +610,7 @@ def main(): main_impl(all_detector_classes=detectors, all_printer_classes=printers) +# pylint: disable=too-many-statements,too-many-branches,too-many-locals def main_impl(all_detector_classes, all_printer_classes): """ :param all_detector_classes: A list of all detectors that can be included/excluded. @@ -588,9 +630,8 @@ def main_impl(all_detector_classes, all_printer_classes): outputting_json_stdout = args.json == "-" outputting_zip = args.zip is not None if args.zip_type not in ZIP_TYPES_ACCEPTED.keys(): - logger.error( - f'Zip type not accepted, it must be one of {",".join(ZIP_TYPES_ACCEPTED.keys())}' - ) + to_log = f'Zip type not accepted, it must be one of {",".join(ZIP_TYPES_ACCEPTED.keys())}' + logger.error(to_log) # If we are outputting JSON, capture all standard output. If we are outputting to stdout, we block typical stdout # output. @@ -616,8 +657,8 @@ def main_impl(all_detector_classes, all_printer_classes): ("Printers", default_log), # ('CryticCompile', default_log) ]: - l = logging.getLogger(l_name) - l.setLevel(l_level) + logger_level = logging.getLogger(l_name) + logger_level.setLevel(l_level) console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) @@ -649,7 +690,9 @@ def main_impl(all_detector_classes, all_printer_classes): results_detectors, results_printers, number_contracts, - ) = process_from_asts(filenames, args, detector_classes, printer_classes) + ) = process_from_asts( + filenames, args, detector_classes, printer_classes + ) slither_instances.append(slither_instance) else: for filename in filenames: @@ -658,7 +701,9 @@ def main_impl(all_detector_classes, all_printer_classes): results_detectors_tmp, results_printers_tmp, number_contracts_tmp, - ) = process_single(filename, args, detector_classes, printer_classes) + ) = process_single( + filename, args, detector_classes, printer_classes + ) number_contracts += number_contracts_tmp results_detectors += results_detectors_tmp results_printers += results_printers_tmp @@ -728,17 +773,19 @@ def main_impl(all_detector_classes, all_printer_classes): if args.ignore_return_value: return - except SlitherException as se: - output_error = str(se) + except SlitherException as slither_exception: + output_error = str(slither_exception) traceback.print_exc() logging.error(red("Error:")) logging.error(red(output_error)) - logging.error("Please report an issue to https://github.com/crytic/slither/issues") + logging.error( + "Please report an issue to https://github.com/crytic/slither/issues" + ) - except Exception: + except Exception: # pylint: disable=broad-except output_error = traceback.format_exc() logging.error(traceback.print_exc()) - logging.error("Error in %s" % args.filename) + logging.error(f"Error in {args.filename}") # pylint: disable=logging-fstring-interpolation logging.error(output_error) # If we are outputting JSON, capture the redirected output and disable the redirect to output the final JSON. @@ -749,7 +796,9 @@ def main_impl(all_detector_classes, all_printer_classes): "stderr": StandardOutputCapture.get_stderr_output(), } StandardOutputCapture.disable() - output_to_json(None if outputting_json_stdout else args.json, output_error, json_results) + output_to_json( + None if outputting_json_stdout else args.json, output_error, json_results + ) if outputting_zip: output_to_zip(args.zip, output_error, json_results, args.zip_type) @@ -758,7 +807,7 @@ def main_impl(all_detector_classes, all_printer_classes): if output_error: sys.exit(-1) else: - exit(results_detectors) + my_exit(results_detectors) if __name__ == "__main__": diff --git a/slither/all_exceptions.py b/slither/all_exceptions.py index 560078f26..1261a3bf3 100644 --- a/slither/all_exceptions.py +++ b/slither/all_exceptions.py @@ -1,6 +1,7 @@ """ This module import all slither exceptions """ +# pylint: disable=unused-import from slither.slithir.exceptions import SlithIRError from slither.solc_parsing.exceptions import ParsingError, VariableNotFound from slither.core.exceptions import SlitherCoreError diff --git a/slither/analyses/data_dependency/data_dependency.py b/slither/analyses/data_dependency/data_dependency.py index 57b6565fc..be943bb5d 100644 --- a/slither/analyses/data_dependency/data_dependency.py +++ b/slither/analyses/data_dependency/data_dependency.py @@ -112,7 +112,9 @@ def is_tainted(variable, context, only_unprotected=False, ignore_generic_taint=F ) -def is_tainted_ssa(variable, context, only_unprotected=False, ignore_generic_taint=False): +def is_tainted_ssa( + variable, context, only_unprotected=False, ignore_generic_taint=False +): """ Args: variable @@ -135,7 +137,9 @@ def is_tainted_ssa(variable, context, only_unprotected=False, ignore_generic_tai def get_dependencies( - variable: Variable, context: Union[Contract, Function], only_unprotected: bool = False + variable: Variable, + context: Union[Contract, Function], + only_unprotected: bool = False, ) -> Set[Variable]: """ Return the variables for which `variable` depends on. @@ -170,7 +174,9 @@ def get_all_dependencies( def get_dependencies_ssa( - variable: Variable, context: Union[Contract, Function], only_unprotected: bool = False + variable: Variable, + context: Union[Contract, Function], + only_unprotected: bool = False, ) -> Set[Variable]: """ Return the variables for which `variable` depends on (SSA version). @@ -272,8 +278,11 @@ def compute_dependency_contract(contract, slither): compute_dependency_function(function) propagate_function(contract, function, KEY_SSA, KEY_NON_SSA) - propagate_function(contract, function, KEY_SSA_UNPROTECTED, KEY_NON_SSA_UNPROTECTED) + propagate_function( + contract, function, KEY_SSA_UNPROTECTED, KEY_NON_SSA_UNPROTECTED + ) + # pylint: disable=expression-not-assigned if function.visibility in ["public", "external"]: [slither.context[KEY_INPUT].add(p) for p in function.parameters] [slither.context[KEY_INPUT_SSA].add(p) for p in function.parameters_ssa] @@ -296,11 +305,12 @@ def propagate_function(contract, function, context_key, context_key_non_ssa): def transitive_close_dependencies(context, context_key, context_key_non_ssa): # transitive closure changed = True - while changed: + while changed: # pylint: disable=too-many-nested-blocks changed = False # Need to create new set() as its changed during iteration data_depencencies = { - k: set([v for v in values]) for k, values in context.context[context_key].items() + k: set(values) + for k, values in context.context[context_key].items() } for key, items in data_depencencies.items(): for item in items: @@ -310,7 +320,9 @@ def transitive_close_dependencies(context, context_key, context_key_non_ssa): if not additional_item in items and additional_item != key: changed = True context.context[context_key][key].add(additional_item) - context.context[context_key_non_ssa] = convert_to_non_ssa(context.context[context_key]) + context.context[context_key_non_ssa] = convert_to_non_ssa( + context.context[context_key] + ) def propagate_contract(contract, context_key, context_key_non_ssa): @@ -328,7 +340,12 @@ def add_dependency(lvalue, function, ir, is_protected): read = ir.function.return_values_ssa else: read = ir.read - [function.context[KEY_SSA][lvalue].add(v) for v in read if not isinstance(v, Constant)] + # pylint: disable=expression-not-assigned + [ + function.context[KEY_SSA][lvalue].add(v) + for v in read + if not isinstance(v, Constant) + ] if not is_protected: [ function.context[KEY_SSA_UNPROTECTED][lvalue].add(v) @@ -375,7 +392,17 @@ def convert_variable_to_non_ssa(v): ): return v.non_ssa_version assert isinstance( - v, (Constant, SolidityVariable, Contract, Enum, SolidityFunction, Structure, Function, Type) + v, + ( + Constant, + SolidityVariable, + Contract, + Enum, + SolidityFunction, + Structure, + Function, + Type, + ), ) return v @@ -387,6 +414,6 @@ def convert_to_non_ssa(data_depencies): var = convert_variable_to_non_ssa(k) if not var in ret: ret[var] = set() - ret[var] = ret[var].union(set([convert_variable_to_non_ssa(v) for v in values])) + ret[var] = ret[var].union({convert_variable_to_non_ssa(v) for v in values}) return ret diff --git a/slither/analyses/evm/convert.py b/slither/analyses/evm/convert.py index 827f4da29..b359cb491 100644 --- a/slither/analyses/evm/convert.py +++ b/slither/analyses/evm/convert.py @@ -3,7 +3,7 @@ from slither.core.declarations import Contract, Function from slither.core.cfg.node import Node from slither.utils.function import get_function_id from slither.exceptions import SlitherError -from .evm_cfg_builder import load_evm_cfg_builder +from slither.analyses.evm.evm_cfg_builder import load_evm_cfg_builder logger = logging.getLogger("ConvertToEVM") @@ -101,16 +101,17 @@ def _get_evm_instructions_function(function_info): # Todo: Could rename it appropriately in evm-cfg-builder # by detecting that init bytecode is being parsed. name = "_dispatcher" - hash = "" + func_hash = "" else: cfg = function_info["contract_info"]["cfg"] name = function.name # Get first four bytes of function singature's keccak-256 hash used as function selector - hash = str(hex(get_function_id(function.full_name))) + func_hash = str(hex(get_function_id(function.full_name))) - function_evm = _get_function_evm(cfg, name, hash) + function_evm = _get_function_evm(cfg, name, func_hash) if function_evm is None: - logger.error("Function " + function.name + " not found in the EVM code") + to_log = "Function " + function.name + " not found in the EVM code" + logger.error(to_log) raise SlitherError("Function " + function.name + " not found in the EVM code") function_ins = [] @@ -137,7 +138,10 @@ def _get_evm_instructions_node(node_info): # Get evm instructions corresponding to node's source line number node_source_line = ( - contract_file[0 : node_info["node"].source_mapping["start"]].count("\n".encode("utf-8")) + 1 + contract_file[0 : node_info["node"].source_mapping["start"]].count( + "\n".encode("utf-8") + ) + + 1 ) node_pcs = contract_pcs.get(node_source_line, []) node_ins = [] @@ -153,14 +157,19 @@ def _get_function_evm(cfg, function_name, function_hash): if function_evm.name[:2] == "0x" and function_evm.name == function_hash: return function_evm # Match function name - elif function_evm.name[:2] != "0x" and function_evm.name.split("(")[0] == function_name: + if ( + function_evm.name[:2] != "0x" + and function_evm.name.split("(")[0] == function_name + ): return function_evm return None - -def generate_source_to_evm_ins_mapping(evm_instructions, srcmap_runtime, slither, filename): +# pylint: disable=too-many-locals +def generate_source_to_evm_ins_mapping( + evm_instructions, srcmap_runtime, slither, filename +): """ - Generate Solidity source to EVM instruction mapping using evm_cfg_builder:cfg.instructions + Generate Solidity source to EVM instruction mapping using evm_cfg_builder:cfg.instructions and solc:srcmap_runtime Returns: Solidity source to EVM instruction mapping @@ -180,11 +189,11 @@ def generate_source_to_evm_ins_mapping(evm_instructions, srcmap_runtime, slither mapping_item = mapping.split(":") mapping_item += prev_mapping[len(mapping_item) :] - for i in range(len(mapping_item)): + for i, _ in enumerate(mapping_item): if mapping_item[i] == "": mapping_item[i] = int(prev_mapping[i]) - offset, length, file_id, _ = mapping_item + offset, _length, file_id, _ = mapping_item prev_mapping = mapping_item if file_id == "-1": @@ -198,6 +207,8 @@ def generate_source_to_evm_ins_mapping(evm_instructions, srcmap_runtime, slither # Append evm instructions to the corresponding source line number # Note: Some evm instructions in mapping are not necessarily in program execution order # Note: The order depends on how solc creates the srcmap_runtime - source_to_evm_mapping.setdefault(line_number, []).append(evm_instructions[idx].pc) + source_to_evm_mapping.setdefault(line_number, []).append( + evm_instructions[idx].pc + ) return source_to_evm_mapping diff --git a/slither/analyses/evm/evm_cfg_builder.py b/slither/analyses/evm/evm_cfg_builder.py index d299a055f..9ceb298a7 100644 --- a/slither/analyses/evm/evm_cfg_builder.py +++ b/slither/analyses/evm/evm_cfg_builder.py @@ -7,6 +7,7 @@ logger = logging.getLogger("ConvertToEVM") def load_evm_cfg_builder(): try: # Avoiding the addition of evm_cfg_builder as permanent dependency + # pylint: disable=import-outside-toplevel from evm_cfg_builder.cfg import CFG return CFG diff --git a/slither/analyses/write/are_variables_written.py b/slither/analyses/write/are_variables_written.py index a0c591711..7d9433511 100644 --- a/slither/analyses/write/are_variables_written.py +++ b/slither/analyses/write/are_variables_written.py @@ -2,7 +2,7 @@ Detect if all the given variables are written in all the paths of the function """ from collections import defaultdict -from typing import Dict, Tuple, Set, List, Optional +from typing import Dict, Set, List from slither.core.cfg.node import NodeType, Node from slither.core.declarations import SolidityFunction @@ -18,7 +18,7 @@ from slither.slithir.operations import ( from slither.slithir.variables import ReferenceVariable, TemporaryVariable -class State: +class State: # pylint: disable=too-few-public-methods def __init__(self): # Map node -> list of variables set # Were each variables set represents a configuration of a path @@ -33,8 +33,12 @@ class State: self.nodes: Dict[Node, List[Set[Variable]]] = defaultdict(list) +# pylint: disable=too-many-branches def _visit( - node: Node, state: State, variables_written: Set[Variable], variables_to_write: List[Variable] + node: Node, + state: State, + variables_written: Set[Variable], + variables_to_write: List[Variable], ): """ Explore all the nodes to look for values not written when the node's function return @@ -51,7 +55,10 @@ def _visit( for ir in node.irs: if isinstance(ir, SolidityCall): # TODO convert the revert to a THROW node - if ir.function in [SolidityFunction("revert(string)"), SolidityFunction("revert()")]: + if ir.function in [ + SolidityFunction("revert(string)"), + SolidityFunction("revert()"), + ]: return [] if not isinstance(ir, OperationWithLValue): @@ -61,7 +68,9 @@ def _visit( if isinstance(ir, (Length, Balance)): refs[ir.lvalue] = ir.value - if ir.lvalue and not isinstance(ir.lvalue, (TemporaryVariable, ReferenceVariable)): + if ir.lvalue and not isinstance( + ir.lvalue, (TemporaryVariable, ReferenceVariable) + ): variables_written.add(ir.lvalue) lvalue = ir.lvalue diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index a72a1b0dd..54817b2dc 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -56,6 +56,8 @@ if TYPE_CHECKING: ) +# pylint: disable=too-many-lines,too-many-branches,too-many-instance-attributes + ################################################################################### ################################################################################### # region NodeType @@ -140,8 +142,9 @@ class NodeType(Enum): # endregion - -class Node(SourceMapping, ChildFunction): +# I am not sure why, but pylint reports a lot of "no-member" issue that are not real (Josselin) +# pylint: disable=no-member +class Node(SourceMapping, ChildFunction): # pylint: disable=too-many-public-methods """ Node class @@ -166,8 +169,12 @@ class Node(SourceMapping, ChildFunction): self._dominance_frontier: Set["Node"] = set() # Phi origin # key are variable name - self._phi_origins_state_variables: Dict[str, Tuple[StateVariable, Set["Node"]]] = {} - self._phi_origins_local_variables: Dict[str, Tuple[LocalVariable, Set["Node"]]] = {} + self._phi_origins_state_variables: Dict[ + str, Tuple[StateVariable, Set["Node"]] + ] = {} + self._phi_origins_local_variables: Dict[ + str, Tuple[LocalVariable, Set["Node"]] + ] = {} # self._phi_origins_member_variables: Dict[str, Tuple[MemberVariable, Set["Node"]]] = {} self._expression: Optional[Expression] = None @@ -180,7 +187,7 @@ class Node(SourceMapping, ChildFunction): self._ssa_vars_written: List["SlithIRVariable"] = [] self._ssa_vars_read: List["SlithIRVariable"] = [] - self._internal_calls: List[Function] = [] + self._internal_calls: List["Function"] = [] self._solidity_calls: List[SolidityFunction] = [] self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls self._library_calls: List["LibraryCallType"] = [] @@ -457,6 +464,7 @@ class Node(SourceMapping, ChildFunction): :param callstack: used internally to check for recursion :return bool: """ + # pylint: disable=import-outside-toplevel from slither.slithir.operations import Call if self._can_reenter is None: @@ -472,6 +480,7 @@ class Node(SourceMapping, ChildFunction): Check if the node can send eth :return bool: """ + # pylint: disable=import-outside-toplevel from slither.slithir.operations import Call if self._can_send_eth is None: @@ -712,7 +721,9 @@ class Node(SourceMapping, ChildFunction): @staticmethod def _is_non_slithir_var(var: Variable): - return not isinstance(var, (Constant, ReferenceVariable, TemporaryVariable, TupleVariable)) + return not isinstance( + var, (Constant, ReferenceVariable, TemporaryVariable, TupleVariable) + ) @staticmethod def _is_valid_slithir_var(var: Variable): @@ -793,11 +804,15 @@ class Node(SourceMapping, ChildFunction): ################################################################################### @property - def phi_origins_local_variables(self) -> Dict[str, Tuple[LocalVariable, Set["Node"]]]: + def phi_origins_local_variables( + self, + ) -> Dict[str, Tuple[LocalVariable, Set["Node"]]]: return self._phi_origins_local_variables @property - def phi_origins_state_variables(self) -> Dict[str, Tuple[StateVariable, Set["Node"]]]: + def phi_origins_state_variables( + self, + ) -> Dict[str, Tuple[StateVariable, Set["Node"]]]: return self._phi_origins_state_variables # @property @@ -835,11 +850,12 @@ class Node(SourceMapping, ChildFunction): ################################################################################### ################################################################################### - def _find_read_write_call(self): + def _find_read_write_call(self): # pylint: disable=too-many-statements for ir in self.irs: - self._slithir_vars |= set([v for v in ir.read if self._is_valid_slithir_var(v)]) + self._slithir_vars |= {v for v in ir.read if self._is_valid_slithir_var(v)} + if isinstance(ir, OperationWithLValue): var = ir.lvalue if var and self._is_valid_slithir_var(var): @@ -884,7 +900,9 @@ class Node(SourceMapping, ChildFunction): self._high_level_calls.append((self.function.contract, ir.function)) else: try: - self._high_level_calls.append((ir.destination.type.type, ir.function)) + self._high_level_calls.append( + (ir.destination.type.type, ir.function) + ) except AttributeError: raise SlitherException( f"Function not found on {ir}. Please try compiling with a recent Solidity version." @@ -895,12 +913,22 @@ class Node(SourceMapping, ChildFunction): self._library_calls.append((ir.destination, ir.function)) self._vars_read = list(set(self._vars_read)) - self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] - self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] - self._solidity_vars_read = [v for v in self._vars_read if isinstance(v, SolidityVariable)] + self._state_vars_read = [ + v for v in self._vars_read if isinstance(v, StateVariable) + ] + self._local_vars_read = [ + v for v in self._vars_read if isinstance(v, LocalVariable) + ] + self._solidity_vars_read = [ + v for v in self._vars_read if isinstance(v, SolidityVariable) + ] self._vars_written = list(set(self._vars_written)) - self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] - self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] + self._state_vars_written = [ + v for v in self._vars_written if isinstance(v, StateVariable) + ] + self._local_vars_written = [ + v for v in self._vars_written if isinstance(v, LocalVariable) + ] self._internal_calls = list(set(self._internal_calls)) self._solidity_calls = list(set(self._solidity_calls)) self._high_level_calls = list(set(self._high_level_calls)) @@ -926,7 +954,9 @@ class Node(SourceMapping, ChildFunction): continue if not isinstance(ir, (Phi, Index, Member)): self._ssa_vars_read += [ - v for v in ir.read if isinstance(v, (StateIRVariable, LocalIRVariable)) + v + for v in ir.read + if isinstance(v, (StateIRVariable, LocalIRVariable)) ] for var in ir.read: if isinstance(var, ReferenceVariable): @@ -954,8 +984,12 @@ class Node(SourceMapping, ChildFunction): continue self._ssa_vars_written.append(var) self._ssa_vars_read = list(set(self._ssa_vars_read)) - self._ssa_state_vars_read = [v for v in self._ssa_vars_read if isinstance(v, StateVariable)] - self._ssa_local_vars_read = [v for v in self._ssa_vars_read if isinstance(v, LocalVariable)] + self._ssa_state_vars_read = [ + v for v in self._ssa_vars_read if isinstance(v, StateVariable) + ] + self._ssa_local_vars_read = [ + v for v in self._ssa_vars_read if isinstance(v, LocalVariable) + ] self._ssa_vars_written = list(set(self._ssa_vars_written)) self._ssa_state_vars_written = [ v for v in self._ssa_vars_written if isinstance(v, StateVariable) @@ -968,12 +1002,20 @@ class Node(SourceMapping, ChildFunction): vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written] self._vars_read += [v for v in vars_read if v not in self._vars_read] - self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] - self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] + self._state_vars_read = [ + v for v in self._vars_read if isinstance(v, StateVariable) + ] + self._local_vars_read = [ + v for v in self._vars_read if isinstance(v, LocalVariable) + ] self._vars_written += [v for v in vars_written if v not in self._vars_written] - self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] - self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] + self._state_vars_written = [ + v for v in self._vars_written if isinstance(v, StateVariable) + ] + self._local_vars_written = [ + v for v in self._vars_written if isinstance(v, LocalVariable) + ] # endregion ################################################################################### @@ -1024,11 +1066,11 @@ def recheable(node: Node) -> Set[Node]: nodes = node.sons visited = set() while nodes: - next = nodes[0] + next_node = nodes[0] nodes = nodes[1:] - if next not in visited: - visited.add(next) - for son in next.sons: + if next_node not in visited: + visited.add(next_node) + for son in next_node.sons: if son not in visited: nodes.append(son) return visited diff --git a/slither/core/children/child_node.py b/slither/core/children/child_node.py index a90acfe48..5cf0c808d 100644 --- a/slither/core/children/child_node.py +++ b/slither/core/children/child_node.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from slither.core.declarations import Function, Contract -class ChildNode(object): +class ChildNode: def __init__(self): super(ChildNode, self).__init__() self._node = None diff --git a/slither/core/context/context.py b/slither/core/context/context.py index d16178a58..3cf84363f 100644 --- a/slither/core/context/context.py +++ b/slither/core/context/context.py @@ -2,7 +2,7 @@ from collections import defaultdict from typing import Dict -class Context: +class Context: # pylint: disable=too-few-public-methods def __init__(self): super(Context, self).__init__() self._context = {"MEMBERS": defaultdict(None)} diff --git a/slither/core/declarations/__init__.py b/slither/core/declarations/__init__.py index b6cddc787..89de67e6b 100644 --- a/slither/core/declarations/__init__.py +++ b/slither/core/declarations/__init__.py @@ -5,5 +5,9 @@ from .function import Function from .import_directive import Import from .modifier import Modifier from .pragma_directive import Pragma -from .solidity_variables import SolidityVariable, SolidityVariableComposed, SolidityFunction +from .solidity_variables import ( + SolidityVariable, + SolidityVariableComposed, + SolidityFunction, +) from .structure import Structure diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 18ed5ac10..69006bf33 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -3,9 +3,9 @@ """ import logging from pathlib import Path +from typing import Optional, List, Dict, Callable, Tuple, TYPE_CHECKING, Union from crytic_compile.platform import Type as PlatformType -from typing import Optional, List, Dict, Callable, Tuple, TYPE_CHECKING, Union from slither.core.children.child_slither import ChildSlither from slither.core.solidity_types.type import Type @@ -22,6 +22,7 @@ from slither.utils.erc import ( ) from slither.utils.tests_pattern import is_test_contract +# pylint: disable=too-many-lines,too-many-instance-attributes,import-outside-toplevel,too-many-nested-blocks if TYPE_CHECKING: from slither.utils.type_helpers import LibraryCallType, HighLevelCallType from slither.core.declarations import Enum, Event, Modifier @@ -33,7 +34,7 @@ if TYPE_CHECKING: LOGGER = logging.getLogger("Contract") -class Contract(ChildSlither, SourceMapping): +class Contract(ChildSlither, SourceMapping): # pylint: disable=too-many-public-methods """ Contract class """ @@ -43,7 +44,9 @@ class Contract(ChildSlither, SourceMapping): self._name: Optional[str] = None self._id: Optional[int] = None - self._inheritance: List["Contract"] = [] # all contract inherited, c3 linearization + self._inheritance: List[ + "Contract" + ] = [] # all contract inherited, c3 linearization self._immediate_inheritance: List["Contract"] = [] # immediate inheritance # Constructors called on contract's definition @@ -338,7 +341,11 @@ class Contract(ChildSlither, SourceMapping): On "contract B is A(){..}" it returns the constructor of A """ - return [c.constructor for c in self._explicit_base_constructor_calls if c.constructor] + return [ + c.constructor + for c in self._explicit_base_constructor_calls + if c.constructor + ] # endregion ################################################################################### @@ -355,12 +362,16 @@ class Contract(ChildSlither, SourceMapping): """ if self._signatures is None: sigs = [ - v.full_name for v in self.state_variables if v.visibility in ["public", "external"] + v.full_name + for v in self.state_variables + if v.visibility in ["public", "external"] ] - sigs += set( - [f.full_name for f in self.functions if f.visibility in ["public", "external"]] - ) + sigs += { + f.full_name + for f in self.functions + if f.visibility in ["public", "external"] + } self._signatures = list(set(sigs)) return self._signatures @@ -377,13 +388,11 @@ class Contract(ChildSlither, SourceMapping): if v.visibility in ["public", "external"] ] - sigs += set( - [ - f.full_name - for f in self.functions_declared - if f.visibility in ["public", "external"] - ] - ) + sigs += { + f.full_name + for f in self.functions_declared + if f.visibility in ["public", "external"] + } self._signatures_declared = list(set(sigs)) return self._signatures_declared @@ -397,6 +406,9 @@ class Contract(ChildSlither, SourceMapping): def available_functions_as_dict(self) -> Dict[str, "Function"]: return {f.full_name: f for f in self._functions.values() if not f.is_shadowed} + def add_function(self, func: "Function"): + self._functions[func.canonical_name] = func + def set_functions(self, functions: Dict[str, "Function"]): """ Set the functions @@ -569,19 +581,25 @@ class Contract(ChildSlither, SourceMapping): ################################################################################### ################################################################################### - def get_functions_reading_from_variable(self, variable: "Variable") -> List["Function"]: + def get_functions_reading_from_variable( + self, variable: "Variable" + ) -> List["Function"]: """ Return the functions reading the variable """ return [f for f in self.functions if f.is_reading(variable)] - def get_functions_writing_to_variable(self, variable: "Variable") -> List["Function"]: + def get_functions_writing_to_variable( + self, variable: "Variable" + ) -> List["Function"]: """ Return the functions writting the variable """ return [f for f in self.functions if f.is_writing(variable)] - def get_function_from_signature(self, function_signature: str) -> Optional["Function"]: + def get_function_from_signature( + self, function_signature: str + ) -> Optional["Function"]: """ Return a function from a signature Args: @@ -590,22 +608,34 @@ class Contract(ChildSlither, SourceMapping): Function """ return next( - (f for f in self.functions if f.full_name == function_signature and not f.is_shadowed), + ( + f + for f in self.functions + if f.full_name == function_signature and not f.is_shadowed + ), None, ) - def get_modifier_from_signature(self, modifier_signature: str) -> Optional["Modifier"]: + def get_modifier_from_signature( + self, modifier_signature: str + ) -> Optional["Modifier"]: """ Return a modifier from a signature :param modifier_signature: """ return next( - (m for m in self.modifiers if m.full_name == modifier_signature and not m.is_shadowed), + ( + m + for m in self.modifiers + if m.full_name == modifier_signature and not m.is_shadowed + ), None, ) - def get_function_from_canonical_name(self, canonical_name: str) -> Optional["Function"]: + def get_function_from_canonical_name( + self, canonical_name: str + ) -> Optional["Function"]: """ Return a function from a a canonical name (contract.signature()) Args: @@ -613,9 +643,13 @@ class Contract(ChildSlither, SourceMapping): Returns: Function """ - return next((f for f in self.functions if f.canonical_name == canonical_name), None) + return next( + (f for f in self.functions if f.canonical_name == canonical_name), None + ) - def get_modifier_from_canonical_name(self, canonical_name: str) -> Optional["Modifier"]: + def get_modifier_from_canonical_name( + self, canonical_name: str + ) -> Optional["Modifier"]: """ Return a modifier from a canonical name (contract.signature()) Args: @@ -623,9 +657,13 @@ class Contract(ChildSlither, SourceMapping): Returns: Modifier """ - return next((m for m in self.modifiers if m.canonical_name == canonical_name), None) + return next( + (m for m in self.modifiers if m.canonical_name == canonical_name), None + ) - def get_state_variable_from_name(self, variable_name: str) -> Optional["StateVariable"]: + def get_state_variable_from_name( + self, variable_name: str + ) -> Optional["StateVariable"]: """ Return a state variable from a name @@ -655,7 +693,9 @@ class Contract(ChildSlither, SourceMapping): """ return next((st for st in self.structures if st.name == structure_name), None) - def get_structure_from_canonical_name(self, structure_name: str) -> Optional["Structure"]: + def get_structure_from_canonical_name( + self, structure_name: str + ) -> Optional["Structure"]: """ Return a structure from a canonical name Args: @@ -663,7 +703,9 @@ class Contract(ChildSlither, SourceMapping): Returns: Structure """ - return next((st for st in self.structures if st.canonical_name == structure_name), None) + return next( + (st for st in self.structures if st.canonical_name == structure_name), None + ) def get_event_from_signature(self, event_signature: str) -> Optional["Event"]: """ @@ -675,7 +717,9 @@ class Contract(ChildSlither, SourceMapping): """ return next((e for e in self.events if e.full_name == event_signature), None) - def get_event_from_canonical_name(self, event_canonical_name: str) -> Optional["Event"]: + def get_event_from_canonical_name( + self, event_canonical_name: str + ) -> Optional["Event"]: """ Return an event from a canonical name Args: @@ -683,7 +727,9 @@ class Contract(ChildSlither, SourceMapping): Returns: Event """ - return next((e for e in self.events if e.canonical_name == event_canonical_name), None) + return next( + (e for e in self.events if e.canonical_name == event_canonical_name), None + ) def get_enum_from_name(self, enum_name: str) -> Optional["Enum"]: """ @@ -775,7 +821,9 @@ class Contract(ChildSlither, SourceMapping): list((Contract, Function): List all of the libraries func called """ all_high_level_calls = [f.all_library_calls() for f in self.functions + self.modifiers] # type: ignore - all_high_level_calls = [item for sublist in all_high_level_calls for item in sublist] + all_high_level_calls = [ + item for sublist in all_high_level_calls for item in sublist + ] return list(set(all_high_level_calls)) @property @@ -784,7 +832,9 @@ class Contract(ChildSlither, SourceMapping): list((Contract, Function|Variable)): List all of the external high level calls """ all_high_level_calls = [f.all_high_level_calls() for f in self.functions + self.modifiers] # type: ignore - all_high_level_calls = [item for sublist in all_high_level_calls for item in sublist] + all_high_level_calls = [ + item for sublist in all_high_level_calls for item in sublist + ] return list(set(all_high_level_calls)) # endregion @@ -804,10 +854,14 @@ class Contract(ChildSlither, SourceMapping): (str, list, list, list, list): (name, inheritance, variables, fuction summaries, modifier summaries) """ func_summaries = [ - f.get_summary() for f in self.functions if (not f.is_shadowed or include_shadowed) + f.get_summary() + for f in self.functions + if (not f.is_shadowed or include_shadowed) ] modif_summaries = [ - f.get_summary() for f in self.modifiers if (not f.is_shadowed or include_shadowed) + f.get_summary() + for f in self.modifiers + if (not f.is_shadowed or include_shadowed) ] return ( self.name, @@ -971,7 +1025,9 @@ class Contract(ChildSlither, SourceMapping): def is_from_dependency(self) -> bool: if self.slither.crytic_compile is None: return False - return self.slither.crytic_compile.is_dependency(self.source_mapping["filename_absolute"]) + return self.slither.crytic_compile.is_dependency( + self.source_mapping["filename_absolute"] + ) # endregion ################################################################################### @@ -991,7 +1047,9 @@ class Contract(ChildSlither, SourceMapping): if self.name == "Migrations": paths = Path(self.source_mapping["filename_absolute"]).parts if len(paths) >= 2: - return paths[-2] == "contracts" and paths[-1] == "migrations.sol" + return ( + paths[-2] == "contracts" and paths[-1] == "migrations.sol" + ) return False @property @@ -1027,7 +1085,10 @@ class Contract(ChildSlither, SourceMapping): else: for c in self.inheritance + [self]: # This might lead to false positive - if "upgradeable" in c.name.lower() or "upgradable" in c.name.lower(): + if ( + "upgradeable" in c.name.lower() + or "upgradable" in c.name.lower() + ): self._is_upgradeable = True break return self._is_upgradeable @@ -1043,7 +1104,10 @@ class Contract(ChildSlither, SourceMapping): if f.is_fallback: for node in f.all_nodes(): for ir in node.irs: - if isinstance(ir, LowLevelCall) and ir.function_name == "delegatecall": + if ( + isinstance(ir, LowLevelCall) + and ir.function_name == "delegatecall" + ): self._is_upgradeable_proxy = True return self._is_upgradeable_proxy if node.type == NodeType.ASSEMBLY: @@ -1079,21 +1143,29 @@ class Contract(ChildSlither, SourceMapping): if variable_candidate.expression and not variable_candidate.is_constant: constructor_variable = Function() - constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES) + constructor_variable.set_function_type( + FunctionType.CONSTRUCTOR_VARIABLES + ) constructor_variable.set_contract(self) constructor_variable.set_contract_declarer(self) constructor_variable.set_visibility("internal") # For now, source mapping of the constructor variable is the whole contract # Could be improved with a targeted source mapping constructor_variable.set_offset(self.source_mapping, self.slither) - self._functions[constructor_variable.canonical_name] = constructor_variable + self._functions[ + constructor_variable.canonical_name + ] = constructor_variable - prev_node = self._create_node(constructor_variable, 0, variable_candidate) + prev_node = self._create_node( + constructor_variable, 0, variable_candidate + ) variable_candidate.node_initialization = prev_node counter = 1 for v in self.state_variables[idx + 1 :]: if v.expression and not v.is_constant: - next_node = self._create_node(constructor_variable, counter, v) + next_node = self._create_node( + constructor_variable, counter, v + ) v.node_initialization = next_node prev_node.add_son(next_node) next_node.add_father(prev_node) @@ -1113,14 +1185,20 @@ class Contract(ChildSlither, SourceMapping): # For now, source mapping of the constructor variable is the whole contract # Could be improved with a targeted source mapping constructor_variable.set_offset(self.source_mapping, self.slither) - self._functions[constructor_variable.canonical_name] = constructor_variable + self._functions[ + constructor_variable.canonical_name + ] = constructor_variable - prev_node = self._create_node(constructor_variable, 0, variable_candidate) + prev_node = self._create_node( + constructor_variable, 0, variable_candidate + ) variable_candidate.node_initialization = prev_node counter = 1 for v in self.state_variables[idx + 1 :]: if v.expression and v.is_constant: - next_node = self._create_node(constructor_variable, counter, v) + next_node = self._create_node( + constructor_variable, counter, v + ) v.node_initialization = next_node prev_node.add_son(next_node) next_node.add_father(prev_node) @@ -1142,7 +1220,10 @@ class Contract(ChildSlither, SourceMapping): node.set_function(func) func.add_node(node) expression = AssignmentOperation( - Identifier(variable), variable.expression, AssignmentOperationType.ASSIGN, variable.type + Identifier(variable), + variable.expression, + AssignmentOperationType.ASSIGN, + variable.type, ) expression.set_offset(variable.source_mapping, self.slither) @@ -1194,7 +1275,9 @@ class Contract(ChildSlither, SourceMapping): last_state_variables_instances[variable_name] += instances for func in self.functions + self.modifiers: - func.fix_phi(last_state_variables_instances, initial_state_variables_instances) + func.fix_phi( + last_state_variables_instances, initial_state_variables_instances + ) @property def is_top_level(self) -> bool: diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 7fe309bba..d2b2e5d4b 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -14,7 +14,12 @@ from slither.core.declarations.solidity_variables import ( SolidityVariable, SolidityVariableComposed, ) -from slither.core.expressions import Identifier, IndexAccess, MemberAccess, UnaryOperation +from slither.core.expressions import ( + Identifier, + IndexAccess, + MemberAccess, + UnaryOperation, +) from slither.core.solidity_types import UserDefinedType from slither.core.solidity_types.type import Type from slither.core.source_mapping.source_mapping import SourceMapping @@ -23,6 +28,8 @@ from slither.core.variables.local_variable import LocalVariable from slither.core.variables.state_variable import StateVariable from slither.utils.utils import unroll +# pylint: disable=import-outside-toplevel,too-many-instance-attributes,too-many-statements,too-many-lines + if TYPE_CHECKING: from slither.utils.type_helpers import ( InternalCallType, @@ -46,7 +53,10 @@ ReacheableNode = namedtuple("ReacheableNode", ["node", "ir"]) class ModifierStatements: def __init__( - self, modifier: Union["Contract", "Function"], entry_point: "Node", nodes: List["Node"] + self, + modifier: Union["Contract", "Function"], + entry_point: "Node", + nodes: List["Node"], ): self._modifier = modifier self._entry_point = entry_point @@ -79,10 +89,26 @@ class FunctionType(Enum): FALLBACK = 2 RECEIVE = 3 CONSTRUCTOR_VARIABLES = 10 # Fake function to hold variable declaration statements - CONSTRUCTOR_CONSTANT_VARIABLES = 11 # Fake function to hold variable declaration statements + CONSTRUCTOR_CONSTANT_VARIABLES = ( + 11 # Fake function to hold variable declaration statements + ) -class Function(ChildContract, ChildInheritance, SourceMapping): +def _filter_state_variables_written(expressions: List["Expression"]): + ret = [] + for expression in expressions: + if isinstance(expression, Identifier): + ret.append(expression) + if isinstance(expression, UnaryOperation): + ret.append(expression.expression) + if isinstance(expression, MemberAccess): + ret.append(expression.expression) + if isinstance(expression, IndexAccess): + ret.append(expression.expression_left) + return ret + + +class Function(ChildContract, ChildInheritance, SourceMapping): # pylint: disable=too-many-public-methods """ Function class """ @@ -147,13 +173,21 @@ class Function(ChildContract, ChildInheritance, SourceMapping): self._all_state_variables_written: Optional[List["StateVariable"]] = None self._all_slithir_variables: Optional[List["SlithIRVariable"]] = None self._all_nodes: Optional[List["Node"]] = None - self._all_conditional_state_variables_read: Optional[List["StateVariable"]] = None - self._all_conditional_state_variables_read_with_loop: Optional[List["StateVariable"]] = None - self._all_conditional_solidity_variables_read: Optional[List["SolidityVariable"]] = None + self._all_conditional_state_variables_read: Optional[ + List["StateVariable"] + ] = None + self._all_conditional_state_variables_read_with_loop: Optional[ + List["StateVariable"] + ] = None + self._all_conditional_solidity_variables_read: Optional[ + List["SolidityVariable"] + ] = None self._all_conditional_solidity_variables_read_with_loop: Optional[ List["SolidityVariable"] ] = None - self._all_solidity_variables_used_as_args: Optional[List["SolidityVariable"]] = None + self._all_solidity_variables_used_as_args: Optional[ + List["SolidityVariable"] + ] = None self._is_shadowed: bool = False self._shadows: bool = False @@ -187,13 +221,13 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ if self._name == "" and self._function_type == FunctionType.CONSTRUCTOR: return "constructor" - elif self._function_type == FunctionType.FALLBACK: + if self._function_type == FunctionType.FALLBACK: return "fallback" - elif self._function_type == FunctionType.RECEIVE: + if self._function_type == FunctionType.RECEIVE: return "receive" - elif self._function_type == FunctionType.CONSTRUCTOR_VARIABLES: + if self._function_type == FunctionType.CONSTRUCTOR_VARIABLES: return "slitherConstructorVariables" - elif self._function_type == FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES: + if self._function_type == FunctionType.CONSTRUCTOR_CONSTANT_VARIABLES: return "slitherConstructorConstantVariables" return self._name @@ -815,15 +849,13 @@ class Function(ChildContract, ChildInheritance, SourceMapping): if self._return_values is None: return_values = list() returns = [n for n in self.nodes if n.type == NodeType.RETURN] - [ + [ # pylint: disable=expression-not-assigned return_values.extend(ir.values) for node in returns for ir in node.irs if isinstance(ir, Return) ] - self._return_values = list( - set([x for x in return_values if not isinstance(x, Constant)]) - ) + self._return_values = list({x for x in return_values if not isinstance(x, Constant)}) return self._return_values @property @@ -838,15 +870,13 @@ class Function(ChildContract, ChildInheritance, SourceMapping): if self._return_values_ssa is None: return_values_ssa = list() returns = [n for n in self.nodes if n.type == NodeType.RETURN] - [ + [ # pylint: disable=expression-not-assigned return_values_ssa.extend(ir.values) for node in returns for ir in node.irs_ssa if isinstance(ir, Return) ] - self._return_values_ssa = list( - set([x for x in return_values_ssa if not isinstance(x, Constant)]) - ) + self._return_values_ssa = list({x for x in return_values_ssa if not isinstance(x, Constant)}) return self._return_values_ssa # endregion @@ -900,7 +930,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): Contract and converted into address :return: the solidity signature """ - parameters = [self._convert_type_for_solidity_signature(x.type) for x in self.parameters] + parameters = [ + self._convert_type_for_solidity_signature(x.type) for x in self.parameters + ] return self.name + "(" + ",".join(parameters) + ")" @property @@ -922,7 +954,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping): Return the function signature as a str (contains the return values) """ name, parameters, returnVars = self.signature - return name + "(" + ",".join(parameters) + ") returns(" + ",".join(returnVars) + ")" + return ( + name + + "(" + + ",".join(parameters) + + ") returns(" + + ",".join(returnVars) + + ")" + ) # endregion ################################################################################### @@ -977,10 +1016,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping): values = f_new_values(self) explored = [self] to_explore = [ - c for c in self.internal_calls if isinstance(c, Function) and c not in explored + c + for c in self.internal_calls + if isinstance(c, Function) and c not in explored ] to_explore += [ - c for (_, c) in self.library_calls if isinstance(c, Function) and c not in explored + c + for (_, c) in self.library_calls + if isinstance(c, Function) and c not in explored ] to_explore += [m for m in self.modifiers if m not in explored] @@ -1003,7 +1046,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): for (_, c) in f.library_calls if isinstance(c, Function) and c not in explored and c not in to_explore ] - to_explore += [m for m in f.modifiers if m not in explored and m not in to_explore] + to_explore += [ + m for m in f.modifiers if m not in explored and m not in to_explore + ] return list(set(values)) @@ -1029,7 +1074,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ recursive version of slithir_variables """ if self._all_slithir_variables is None: - self._all_slithir_variables = self._explore_functions(lambda x: x.slithir_variables) + self._all_slithir_variables = self._explore_functions( + lambda x: x.slithir_variables + ) return self._all_slithir_variables def all_nodes(self) -> List["Node"]: @@ -1047,10 +1094,10 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return self._all_expressions def all_slithir_operations(self) -> List["Operation"]: - """ - """ if self._all_slithir_operations is None: - self._all_slithir_operations = self._explore_functions(lambda x: x.slithir_operations) + self._all_slithir_operations = self._explore_functions( + lambda x: x.slithir_operations + ) return self._all_slithir_operations def all_state_variables_written(self) -> List[StateVariable]: @@ -1066,21 +1113,27 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ recursive version of internal_calls """ if self._all_internals_calls is None: - self._all_internals_calls = self._explore_functions(lambda x: x.internal_calls) + self._all_internals_calls = self._explore_functions( + lambda x: x.internal_calls + ) return self._all_internals_calls def all_low_level_calls(self) -> List["LowLevelCallType"]: """ recursive version of low_level calls """ if self._all_low_level_calls is None: - self._all_low_level_calls = self._explore_functions(lambda x: x.low_level_calls) + self._all_low_level_calls = self._explore_functions( + lambda x: x.low_level_calls + ) return self._all_low_level_calls def all_high_level_calls(self) -> List["HighLevelCallType"]: """ recursive version of high_level calls """ if self._all_high_level_calls is None: - self._all_high_level_calls = self._explore_functions(lambda x: x.high_level_calls) + self._all_high_level_calls = self._explore_functions( + lambda x: x.high_level_calls + ) return self._all_high_level_calls def all_library_calls(self) -> List["LibraryCallType"]: @@ -1094,15 +1147,23 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ recursive version of solidity calls """ if self._all_solidity_calls is None: - self._all_solidity_calls = self._explore_functions(lambda x: x.solidity_calls) + self._all_solidity_calls = self._explore_functions( + lambda x: x.solidity_calls + ) return self._all_solidity_calls @staticmethod - def _explore_func_cond_read(func: "Function", include_loop: bool) -> List["StateVariable"]: - ret = [n.state_variables_read for n in func.nodes if n.is_conditional(include_loop)] + def _explore_func_cond_read( + func: "Function", include_loop: bool + ) -> List["StateVariable"]: + ret = [ + n.state_variables_read for n in func.nodes if n.is_conditional(include_loop) + ] return [item for sublist in ret for item in sublist] - def all_conditional_state_variables_read(self, include_loop=True) -> List["StateVariable"]: + def all_conditional_state_variables_read( + self, include_loop=True + ) -> List["StateVariable"]: """ Return the state variable used in a condition @@ -1133,12 +1194,16 @@ class Function(ChildContract, ChildInheritance, SourceMapping): @staticmethod def _explore_func_conditional( - func: "Function", f: Callable[["Node"], List[SolidityVariable]], include_loop: bool + func: "Function", + f: Callable[["Node"], List[SolidityVariable]], + include_loop: bool, ): ret = [f(n) for n in func.nodes if n.is_conditional(include_loop)] return [item for sublist in ret for item in sublist] - def all_conditional_solidity_variables_read(self, include_loop=True) -> List[SolidityVariable]: + def all_conditional_solidity_variables_read( + self, include_loop=True + ) -> List[SolidityVariable]: """ Return the Soldiity variables directly used in a condtion @@ -1174,7 +1239,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return [var for var in ret if isinstance(var, SolidityVariable)] @staticmethod - def _explore_func_nodes(func: "Function", f: Callable[["Node"], List[SolidityVariable]]): + def _explore_func_nodes( + func: "Function", f: Callable[["Node"], List[SolidityVariable]] + ): ret = [f(n) for n in func.nodes] return [item for sublist in ret for item in sublist] @@ -1187,7 +1254,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ if self._all_solidity_variables_used_as_args is None: self._all_solidity_variables_used_as_args = self._explore_functions( - lambda x: self._explore_func_nodes(x, self._solidity_variable_in_internal_calls) + lambda x: self._explore_func_nodes( + x, self._solidity_variable_in_internal_calls + ) ) return self._all_solidity_variables_used_as_args @@ -1217,7 +1286,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### ################################################################################### - def get_local_variable_from_name(self, variable_name: str) -> Optional[LocalVariable]: + def get_local_variable_from_name( + self, variable_name: str + ) -> Optional[LocalVariable]: """ Return a local variable from a name @@ -1271,7 +1342,11 @@ class Function(ChildContract, ChildInheritance, SourceMapping): for node in self.nodes: f.write('{}[label="{}"];\n'.format(node.node_id, description(node))) if node.immediate_dominator: - f.write("{}->{};\n".format(node.immediate_dominator.node_id, node.node_id)) + f.write( + "{}->{};\n".format( + node.immediate_dominator.node_id, node.node_id + ) + ) f.write("}\n") @@ -1305,10 +1380,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping): if node.type in [NodeType.IF, NodeType.IFLOOP]: true_node = node.son_true if true_node: - content += '{}->{}[label="True"];\n'.format(node.node_id, true_node.node_id) + content += '{}->{}[label="True"];\n'.format( + node.node_id, true_node.node_id + ) false_node = node.son_false if false_node: - content += '{}->{}[label="False"];\n'.format(node.node_id, false_node.node_id) + content += '{}->{}[label="False"];\n'.format( + node.node_id, false_node.node_id + ) else: for son in node.sons: content += "{}->{};\n".format(node.node_id, son.node_id) @@ -1353,7 +1432,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): Returns: bool: True if the variable is read """ - variables_reads = [n.variables_read for n in self.nodes if n.contains_require_or_assert()] + variables_reads = [ + n.variables_read for n in self.nodes if n.contains_require_or_assert() + ] variables_read = [item for sublist in variables_reads for item in sublist] return variable in variables_read @@ -1401,7 +1482,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): if self.is_constructor: return True - conditional_vars = self.all_conditional_solidity_variables_read(include_loop=False) + conditional_vars = self.all_conditional_solidity_variables_read( + include_loop=False + ) args_vars = self.all_solidity_variables_used_as_args() return SolidityVariableComposed("msg.sender") in conditional_vars + args_vars @@ -1412,19 +1495,6 @@ class Function(ChildContract, ChildInheritance, SourceMapping): ################################################################################### ################################################################################### - def _filter_state_variables_written(self, expressions: List["Expression"]): - ret = [] - for expression in expressions: - if isinstance(expression, Identifier): - ret.append(expression) - if isinstance(expression, UnaryOperation): - ret.append(expression.expression) - if isinstance(expression, MemberAccess): - ret.append(expression.expression) - if isinstance(expression, IndexAccess): - ret.append(expression.expression_left) - return ret - def _analyze_read_write(self): """ Compute variables read/written/... @@ -1436,7 +1506,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): # Remove dupplicate if they share the same string representation write_var = [ next(obj) - for i, obj in groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x)) + for i, obj in groupby( + sorted(write_var, key=lambda x: str(x)), lambda x: str(x) + ) ] self._expression_vars_written = write_var @@ -1447,7 +1519,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): # Remove dupplicate if they share the same string representation write_var = [ next(obj) - for i, obj in groupby(sorted(write_var, key=lambda x: str(x)), lambda x: str(x)) + for i, obj in groupby( + sorted(write_var, key=lambda x: str(x)), lambda x: str(x) + ) ] self._vars_written = write_var @@ -1457,7 +1531,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): # Remove dupplicate if they share the same string representation read_var = [ next(obj) - for i, obj in groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x)) + for i, obj in groupby( + sorted(read_var, key=lambda x: str(x)), lambda x: str(x) + ) ] self._expression_vars_read = read_var @@ -1467,14 +1543,18 @@ class Function(ChildContract, ChildInheritance, SourceMapping): # Remove dupplicate if they share the same string representation read_var = [ next(obj) - for i, obj in groupby(sorted(read_var, key=lambda x: str(x)), lambda x: str(x)) + for i, obj in groupby( + sorted(read_var, key=lambda x: str(x)), lambda x: str(x) + ) ] self._vars_read = read_var self._state_vars_written = [ x for x in self.variables_written if isinstance(x, StateVariable) ] - self._state_vars_read = [x for x in self.variables_read if isinstance(x, StateVariable)] + self._state_vars_read = [ + x for x in self.variables_read if isinstance(x, StateVariable) + ] self._solidity_vars_read = [ x for x in self.variables_read if isinstance(x, SolidityVariable) ] @@ -1483,7 +1563,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): slithir_variables = [x.slithir_variables for x in self.nodes] slithir_variables = [x for x in slithir_variables if x] - self._slithir_variables = [item for sublist in slithir_variables for item in sublist] + self._slithir_variables = [ + item for sublist in slithir_variables for item in sublist + ] def _analyze_calls(self): calls = [x.calls_as_expression for x in self.nodes] @@ -1496,7 +1578,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): internal_calls = [item for sublist in internal_calls for item in sublist] self._internal_calls = list(set(internal_calls)) - self._solidity_calls = [c for c in internal_calls if isinstance(c, SolidityFunction)] + self._solidity_calls = [ + c for c in internal_calls if isinstance(c, SolidityFunction) + ] low_level_calls = [x.low_level_calls for x in self.nodes] low_level_calls = [x for x in low_level_calls if x] @@ -1513,7 +1597,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): library_calls = [item for sublist in library_calls for item in sublist] self._library_calls = list(set(library_calls)) - external_calls_as_expressions = [x.external_calls_as_expressions for x in self.nodes] + external_calls_as_expressions = [ + x.external_calls_as_expressions for x in self.nodes + ] external_calls_as_expressions = [x for x in external_calls_as_expressions if x] external_calls_as_expressions = [ item for sublist in external_calls_as_expressions for item in sublist @@ -1548,6 +1634,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): def _get_last_ssa_variable_instances( self, target_state: bool, target_local: bool ) -> Dict[str, Set["SlithIRVariable"]]: + # pylint: disable=too-many-locals,too-many-branches from slither.slithir.variables import ReferenceVariable from slither.slithir.operations import OperationWithLValue from slither.core.cfg.node import NodeType @@ -1603,11 +1690,19 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return ret - def get_last_ssa_state_variables_instances(self) -> Dict[str, Set["SlithIRVariable"]]: - return self._get_last_ssa_variable_instances(target_state=True, target_local=False) + def get_last_ssa_state_variables_instances( + self, + ) -> Dict[str, Set["SlithIRVariable"]]: + return self._get_last_ssa_variable_instances( + target_state=True, target_local=False + ) - def get_last_ssa_local_variables_instances(self) -> Dict[str, Set["SlithIRVariable"]]: - return self._get_last_ssa_variable_instances(target_state=False, target_local=True) + def get_last_ssa_local_variables_instances( + self, + ) -> Dict[str, Set["SlithIRVariable"]]: + return self._get_last_ssa_variable_instances( + target_state=False, target_local=True + ) @staticmethod def _unchange_phi(ir: "Operation"): @@ -1619,7 +1714,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): return True return ir.rvalues[0] == ir.lvalue - def fix_phi(self, last_state_variables_instances, initial_state_variables_instances): + def fix_phi( + self, last_state_variables_instances, initial_state_variables_instances + ): from slither.slithir.operations import InternalCall, PhiCallback from slither.slithir.variables import Constant, StateIRVariable @@ -1627,28 +1724,40 @@ class Function(ChildContract, ChildInheritance, SourceMapping): for ir in node.irs_ssa: if node == self.entry_point: if isinstance(ir.lvalue, StateIRVariable): - additional = [initial_state_variables_instances[ir.lvalue.canonical_name]] - additional += last_state_variables_instances[ir.lvalue.canonical_name] + additional = [ + initial_state_variables_instances[ir.lvalue.canonical_name] + ] + additional += last_state_variables_instances[ + ir.lvalue.canonical_name + ] ir.rvalues = list(set(additional + ir.rvalues)) # function parameter else: # find index of the parameter idx = self.parameters.index(ir.lvalue.non_ssa_version) # find non ssa version of that index - additional = [n.ir.arguments[idx] for n in self.reachable_from_nodes] + additional = [ + n.ir.arguments[idx] for n in self.reachable_from_nodes + ] additional = unroll(additional) - additional = [a for a in additional if not isinstance(a, Constant)] + additional = [ + a for a in additional if not isinstance(a, Constant) + ] ir.rvalues = list(set(additional + ir.rvalues)) if isinstance(ir, PhiCallback): callee_ir = ir.callee_ir if isinstance(callee_ir, InternalCall): - last_ssa = callee_ir.function.get_last_ssa_state_variables_instances() + last_ssa = ( + callee_ir.function.get_last_ssa_state_variables_instances() + ) if ir.lvalue.canonical_name in last_ssa: ir.rvalues = list(last_ssa[ir.lvalue.canonical_name]) else: ir.rvalues = [ir.lvalue] else: - additional = last_state_variables_instances[ir.lvalue.canonical_name] + additional = last_state_variables_instances[ + ir.lvalue.canonical_name + ] ir.rvalues = list(set(additional + ir.rvalues)) node.irs_ssa = [ir for ir in node.irs_ssa if not self._unchange_phi(ir)] @@ -1662,7 +1771,10 @@ class Function(ChildContract, ChildInheritance, SourceMapping): def generate_slithir_ssa(self, all_ssa_state_variables_instances): from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa - from slither.core.dominators.utils import compute_dominance_frontier, compute_dominators + from slither.core.dominators.utils import ( + compute_dominance_frontier, + compute_dominators, + ) compute_dominators(self.nodes) compute_dominance_frontier(self.nodes) diff --git a/slither/core/declarations/pragma_directive.py b/slither/core/declarations/pragma_directive.py index d43a62200..39b818599 100644 --- a/slither/core/declarations/pragma_directive.py +++ b/slither/core/declarations/pragma_directive.py @@ -32,7 +32,10 @@ class Pragma(SourceMapping): @property def is_abi_encoder_v2(self) -> bool: if len(self._directive) == 2: - return self._directive[0] == "experimental" and self._directive[1] == "ABIEncoderV2" + return ( + self._directive[0] == "experimental" + and self._directive[1] == "ABIEncoderV2" + ) return False def __str__(self): diff --git a/slither/core/declarations/solidity_variables.py b/slither/core/declarations/solidity_variables.py index db64ae801..f978021fc 100644 --- a/slither/core/declarations/solidity_variables.py +++ b/slither/core/declarations/solidity_variables.py @@ -3,6 +3,7 @@ from typing import List, Dict, Union from slither.core.context.context import Context from slither.core.solidity_types import ElementaryType, TypeInformation +from slither.exceptions import SlitherException SOLIDITY_VARIABLES = { "now": "uint256", @@ -90,8 +91,12 @@ class SolidityVariable(Context): self._name = name # dev function, will be removed once the code is stable - def _check_name(self, name: str): - assert name in SOLIDITY_VARIABLES or name.endswith("_slot") or name.endswith("_offset") + def _check_name(self, name: str): # pylint: disable=no-self-use + assert ( + name in SOLIDITY_VARIABLES + or name.endswith("_slot") + or name.endswith("_offset") + ) @property def state_variable(self): @@ -99,6 +104,8 @@ class SolidityVariable(Context): return self._name[:-5] if self._name.endswith("_offset"): return self._name[:-7] + to_log = f"Incorrect YUL parsing. {self} is not a solidity variable that can be seen as a state variable" + raise SlitherException(to_log) @property def name(self) -> str: @@ -119,8 +126,6 @@ class SolidityVariable(Context): class SolidityVariableComposed(SolidityVariable): - def __init__(self, name: str): - super(SolidityVariableComposed, self).__init__(name) def _check_name(self, name: str): assert name in SOLIDITY_VARIABLES_COMPOSED diff --git a/slither/core/dominators/node_dominator_tree.py b/slither/core/dominators/node_dominator_tree.py index b97279065..c561cef75 100644 --- a/slither/core/dominators/node_dominator_tree.py +++ b/slither/core/dominators/node_dominator_tree.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from slither.core.cfg.node import Node -class DominatorNode(object): +class DominatorNode: def __init__(self): self._succ: Set["Node"] = set() self._nodes: List["Node"] = [] diff --git a/slither/core/dominators/utils.py b/slither/core/dominators/utils.py index b41409a58..0470bb73a 100644 --- a/slither/core/dominators/utils.py +++ b/slither/core/dominators/utils.py @@ -43,6 +43,7 @@ def compute_dominators(nodes: List["Node"]): for dominator in node.dominators: if dominator != node: + # pylint: disable=expression-not-assigned [ idom_candidates.remove(d) for d in dominator.dominators diff --git a/slither/core/expressions/assignment_operation.py b/slither/core/expressions/assignment_operation.py index c1677d9b5..7f6d18d58 100644 --- a/slither/core/expressions/assignment_operation.py +++ b/slither/core/expressions/assignment_operation.py @@ -50,7 +50,9 @@ class AssignmentOperationType(Enum): if operation_type == "%=": return AssignmentOperationType.ASSIGN_MODULO - raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type)) + raise SlitherCoreError( + "get_type: Unknown operation type {})".format(operation_type) + ) def __str__(self): if self == AssignmentOperationType.ASSIGN: @@ -115,4 +117,10 @@ class AssignmentOperation(ExpressionTyped): return self._type def __str__(self): - return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right) + return ( + str(self.expression_left) + + " " + + str(self.type) + + " " + + str(self.expression_right) + ) diff --git a/slither/core/expressions/binary_operation.py b/slither/core/expressions/binary_operation.py index 591319d29..e4dd04e8e 100644 --- a/slither/core/expressions/binary_operation.py +++ b/slither/core/expressions/binary_operation.py @@ -41,7 +41,7 @@ class BinaryOperationType(Enum): RIGHT_SHIFT_ARITHMETIC = 23 @staticmethod - def get_type(operation_type: "BinaryOperation"): + def get_type(operation_type: "BinaryOperation"): # pylint: disable=too-many-branches if operation_type == "**": return BinaryOperationType.POWER if operation_type == "*": @@ -91,9 +91,11 @@ class BinaryOperationType(Enum): if operation_type == ">>'": return BinaryOperationType.RIGHT_SHIFT_ARITHMETIC - raise SlitherCoreError("get_type: Unknown operation type {})".format(operation_type)) + raise SlitherCoreError( + "get_type: Unknown operation type {})".format(operation_type) + ) - def __str__(self): + def __str__(self): # pylint: disable=too-many-branches if self == BinaryOperationType.POWER: return "**" if self == BinaryOperationType.MULTIPLICATION: @@ -170,4 +172,10 @@ class BinaryOperation(ExpressionTyped): return self._type def __str__(self): - return str(self.expression_left) + " " + str(self.type) + " " + str(self.expression_right) + return ( + str(self.expression_left) + + " " + + str(self.type) + + " " + + str(self.expression_right) + ) diff --git a/slither/core/expressions/call_expression.py b/slither/core/expressions/call_expression.py index 848eb221a..bb5c8aa4d 100644 --- a/slither/core/expressions/call_expression.py +++ b/slither/core/expressions/call_expression.py @@ -3,7 +3,7 @@ from typing import Optional, List from slither.core.expressions.expression import Expression -class CallExpression(Expression): +class CallExpression(Expression): # pylint: disable=too-many-instance-attributes def __init__(self, called, arguments, type_call): assert isinstance(called, Expression) super(CallExpression, self).__init__() diff --git a/slither/core/expressions/expression_typed.py b/slither/core/expressions/expression_typed.py index 8aaf2b756..f077aff02 100644 --- a/slither/core/expressions/expression_typed.py +++ b/slither/core/expressions/expression_typed.py @@ -14,3 +14,7 @@ class ExpressionTyped(Expression): @property def type(self): return self._type + + @type.setter + def type(self, new_type: "Type"): + self._type = new_type diff --git a/slither/core/expressions/literal.py b/slither/core/expressions/literal.py index 620dc83f2..072dd83aa 100644 --- a/slither/core/expressions/literal.py +++ b/slither/core/expressions/literal.py @@ -8,10 +8,10 @@ if TYPE_CHECKING: class Literal(Expression): - def __init__(self, value, type, subdenomination=None): + def __init__(self, value, custom_type, subdenomination=None): super(Literal, self).__init__() self._value: Union[int, str] = value - self._type = type + self._type = custom_type self._subdenomination: Optional[str] = subdenomination @property diff --git a/slither/core/expressions/super_call_expression.py b/slither/core/expressions/super_call_expression.py index c8b0dd9f8..17179ddc8 100644 --- a/slither/core/expressions/super_call_expression.py +++ b/slither/core/expressions/super_call_expression.py @@ -1,4 +1,3 @@ -from slither.core.expressions.expression import Expression from slither.core.expressions.call_expression import CallExpression diff --git a/slither/core/expressions/unary_operation.py b/slither/core/expressions/unary_operation.py index 72d2f8410..880863af7 100644 --- a/slither/core/expressions/unary_operation.py +++ b/slither/core/expressions/unary_operation.py @@ -41,7 +41,9 @@ class UnaryOperationType(Enum): return UnaryOperationType.PLUSPLUS_POST if operation_type == "--": return UnaryOperationType.MINUSMINUS_POST - raise SlitherCoreError("get_type: Unknown operation type {}".format(operation_type)) + raise SlitherCoreError( + "get_type: Unknown operation type {}".format(operation_type) + ) def __str__(self): if self == UnaryOperationType.BANG: @@ -76,13 +78,15 @@ class UnaryOperationType(Enum): UnaryOperationType.MINUS_PRE, ]: return True - elif operation_type in [ + if operation_type in [ UnaryOperationType.PLUSPLUS_POST, UnaryOperationType.MINUSMINUS_POST, ]: return False - raise SlitherCoreError("is_prefix: Unknown operation type {}".format(operation_type)) + raise SlitherCoreError( + "is_prefix: Unknown operation type {}".format(operation_type) + ) class UnaryOperation(ExpressionTyped): @@ -117,5 +121,4 @@ class UnaryOperation(ExpressionTyped): def __str__(self): if self.is_prefix: return str(self.type) + " " + str(self._expression) - else: - return str(self._expression) + " " + str(self.type) + return str(self._expression) + " " + str(self.type) diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index edf2ae9ab..ea1639629 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -12,7 +12,15 @@ from typing import Optional, Dict, List, Set, Union, Tuple from crytic_compile import CryticCompile from slither.core.context.context import Context -from slither.core.declarations import Contract, Pragma, Import, Function, Modifier, Structure, Enum +from slither.core.declarations import ( + Contract, + Pragma, + Import, + Function, + Modifier, + Structure, + Enum, +) from slither.core.variables.state_variable import StateVariable from slither.slithir.operations import InternalCall from slither.slithir.variables import Constant @@ -22,7 +30,14 @@ logger = logging.getLogger("Slither") logging.basicConfig() -class SlitherCore(Context): +def _relative_path_format(path: str) -> str: + """ + Strip relative paths of "." and ".." + """ + return path.split("..")[-1].strip(".").strip("/") + + +class SlitherCore(Context): # pylint: disable=too-many-instance-attributes,too-many-public-methods """ Slither static analyzer """ @@ -115,6 +130,10 @@ class SlitherCore(Context): return self.crytic_compile.compiler_version.version return self._solc_version + @solc_version.setter + def solc_version(self, version: str): + self._solc_version = version + @property def pragma_directives(self) -> List[Pragma]: """ list(core.declarations.Pragma): Pragma directives.""" @@ -142,14 +161,20 @@ class SlitherCore(Context): """list(Contract): List of contracts that are derived and not inherited.""" inheritance = (x.inheritance for x in self.contracts) inheritance = [item for sublist in inheritance for item in sublist] - return [c for c in self._contracts.values() if c not in inheritance and not c.is_top_level] + return [ + c + for c in self._contracts.values() + if c not in inheritance and not c.is_top_level + ] @property def contracts_as_dict(self) -> Dict[str, Contract]: """list(dict(str: Contract): List of contracts as dict: name -> Contract.""" return self._contracts - def get_contract_from_name(self, contract_name: Union[str, Constant]) -> Optional[Contract]: + def get_contract_from_name( + self, contract_name: Union[str, Constant] + ) -> Optional[Contract]: """ Return a contract from a name Args: @@ -245,12 +270,6 @@ class SlitherCore(Context): ################################################################################### ################################################################################### - def relative_path_format(self, path: str) -> str: - """ - Strip relative paths of "." and ".." - """ - return path.split("..")[-1].strip(".").strip("/") - def valid_result(self, r: Dict) -> bool: """ Check if the result is valid @@ -272,7 +291,7 @@ class SlitherCore(Context): for path in self._paths_to_filter: try: if any( - bool(re.search(self.relative_path_format(path), src_mapping)) + bool(re.search(_relative_path_format(path), src_mapping)) for src_mapping in source_mapping_elements ): matching = True @@ -287,11 +306,15 @@ class SlitherCore(Context): if r["elements"] and matching: return False if r["elements"] and self._exclude_dependencies: - return not all(element["source_mapping"]["is_dependency"] for element in r["elements"]) + return not all( + element["source_mapping"]["is_dependency"] for element in r["elements"] + ) if r["id"] in self._previous_results_ids: return False # Conserve previous result filtering. This is conserved for compatibility, but is meant to be removed - return not r["description"] in [pr["description"] for pr in self._previous_results] + return not r["description"] in [ + pr["description"] for pr in self._previous_results + ] def load_previous_results(self): filename = self._previous_results_filename @@ -305,7 +328,11 @@ class SlitherCore(Context): self._previous_results_ids.add(r["id"]) except json.decoder.JSONDecodeError: logger.error( - red("Impossible to decode {}. Consider removing the file".format(filename)) + red( + "Impossible to decode {}. Consider removing the file".format( + filename + ) + ) ) def write_results_to_hide(self): @@ -404,7 +431,10 @@ class SlitherCore(Context): slot += 1 offset = 0 - self._storage_layouts[contract.name][var.canonical_name] = (slot, offset) + self._storage_layouts[contract.name][var.canonical_name] = ( + slot, + offset, + ) if new_slot: slot += math.ceil(size / 32) else: diff --git a/slither/core/solidity_types/elementary_type.py b/slither/core/solidity_types/elementary_type.py index 3fc53dc29..d5d006597 100644 --- a/slither/core/solidity_types/elementary_type.py +++ b/slither/core/solidity_types/elementary_type.py @@ -123,7 +123,9 @@ MN = list(itertools.product(M, N)) Fixed = ["fixed{}x{}".format(m, n) for (m, n) in MN] + ["fixed"] Ufixed = ["ufixed{}x{}".format(m, n) for (m, n) in MN] + ["ufixed"] -ElementaryTypeName = ["address", "bool", "string", "var"] + Int + Uint + Byte + Fixed + Ufixed +ElementaryTypeName = ( + ["address", "bool", "string", "var"] + Int + Uint + Byte + Fixed + Ufixed +) class NonElementaryType(Exception): diff --git a/slither/core/solidity_types/function_type.py b/slither/core/solidity_types/function_type.py index 6a96bf368..9f5be8870 100644 --- a/slither/core/solidity_types/function_type.py +++ b/slither/core/solidity_types/function_type.py @@ -6,7 +6,9 @@ from slither.core.variables.function_type_variable import FunctionTypeVariable class FunctionType(Type): def __init__( - self, params: List[FunctionTypeVariable], return_values: List[FunctionTypeVariable] + self, + params: List[FunctionTypeVariable], + return_values: List[FunctionTypeVariable], ): assert all(isinstance(x, FunctionTypeVariable) for x in params) assert all(isinstance(x, FunctionTypeVariable) for x in return_values) diff --git a/slither/core/solidity_types/type_information.py b/slither/core/solidity_types/type_information.py index 71c7ca3c5..62955fb49 100644 --- a/slither/core/solidity_types/type_information.py +++ b/slither/core/solidity_types/type_information.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: # https://solidity.readthedocs.io/en/latest/units-and-global-variables.html#type-information class TypeInformation(Type): def __init__(self, c): + # pylint: disable=import-outside-toplevel from slither.core.declarations.contract import Contract assert isinstance(c, Contract) diff --git a/slither/core/solidity_types/user_defined_type.py b/slither/core/solidity_types/user_defined_type.py index 4aed394e8..d9a06e6bd 100644 --- a/slither/core/solidity_types/user_defined_type.py +++ b/slither/core/solidity_types/user_defined_type.py @@ -2,13 +2,14 @@ from typing import Union, TYPE_CHECKING, Tuple import math from slither.core.solidity_types.type import Type +from slither.exceptions import SlitherException if TYPE_CHECKING: from slither.core.declarations.structure import Structure from slither.core.declarations.enum import Enum from slither.core.declarations.contract import Contract - +# pylint: disable=import-outside-toplevel class UserDefinedType(Type): def __init__(self, t): from slither.core.declarations.structure import Structure @@ -31,9 +32,9 @@ class UserDefinedType(Type): if isinstance(self._type, Contract): return 20, False - elif isinstance(self._type, Enum): + if isinstance(self._type, Enum): return int(math.ceil(math.log2(len(self._type.values)) / 8)), False - elif isinstance(self._type, Structure): + if isinstance(self._type, Structure): # todo there's some duplicate logic here and slither_core, can we refactor this? slot = 0 offset = 0 @@ -54,6 +55,8 @@ class UserDefinedType(Type): if offset > 0: slot += 1 return slot * 32, True + to_log = f"{self} does not have storage size" + raise SlitherException(to_log) def __str__(self): from slither.core.declarations.structure import Structure diff --git a/slither/core/source_mapping/source_mapping.py b/slither/core/source_mapping/source_mapping.py index 31e1ff34e..ccb51ab02 100644 --- a/slither/core/source_mapping/source_mapping.py +++ b/slither/core/source_mapping/source_mapping.py @@ -58,7 +58,7 @@ class SourceMapping(Context): return lines, starting_column, ending_column @staticmethod - def _convert_source_mapping(offset: str, slither): + def _convert_source_mapping(offset: str, slither): # pylint: disable=too-many-locals """ Convert a text offset to a real offset see https://solidity.readthedocs.io/en/develop/miscellaneous.html#source-mappings @@ -112,10 +112,14 @@ class SourceMapping(Context): if slither.crytic_compile and filename in slither.crytic_compile.src_content: source_code = slither.crytic_compile.src_content[filename] - (lines, starting_column, ending_column) = SourceMapping._compute_line(source_code, s, l) + (lines, starting_column, ending_column) = SourceMapping._compute_line( + source_code, s, l + ) elif filename in slither.source_code: source_code = slither.source_code[filename] - (lines, starting_column, ending_column) = SourceMapping._compute_line(source_code, s, l) + (lines, starting_column, ending_column) = SourceMapping._compute_line( + source_code, s, l + ) else: (lines, starting_column, ending_column) = ([], None, None) @@ -145,7 +149,7 @@ class SourceMapping(Context): elif len(lines) == 1: lines = "#{}{}".format(line_descr, lines[0]) else: - lines = "#{}{}-{}{}".format(line_descr, lines[0], line_descr, lines[-1]) + lines = f"#{line_descr}{lines[0]}-{line_descr}{lines[-1]}" return lines def source_mapping_to_markdown(self, markdown_root: str) -> str: diff --git a/slither/core/variables/event_variable.py b/slither/core/variables/event_variable.py index 3cf000273..f5e7ff6f1 100644 --- a/slither/core/variables/event_variable.py +++ b/slither/core/variables/event_variable.py @@ -1,4 +1,4 @@ -from .variable import Variable +from slither.core.variables.variable import Variable from slither.core.children.child_event import ChildEvent diff --git a/slither/core/variables/local_variable.py b/slither/core/variables/local_variable.py index 2a0143f64..660f7bb5b 100644 --- a/slither/core/variables/local_variable.py +++ b/slither/core/variables/local_variable.py @@ -1,6 +1,6 @@ from typing import Optional -from .variable import Variable +from slither.core.variables.variable import Variable from slither.core.children.child_function import ChildFunction from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.array_type import ArrayType diff --git a/slither/core/variables/state_variable.py b/slither/core/variables/state_variable.py index 5bb3f1a1c..47686b487 100644 --- a/slither/core/variables/state_variable.py +++ b/slither/core/variables/state_variable.py @@ -1,6 +1,6 @@ from typing import Optional, TYPE_CHECKING, Tuple, List -from .variable import Variable +from slither.core.variables.variable import Variable from slither.core.children.child_contract import ChildContract from slither.utils.type import export_nested_types_from_variable @@ -34,7 +34,11 @@ class StateVariable(ChildContract, Variable): Return the signature of the state variable as a function signature :return: (str, list(str), list(str)), as (name, list parameters type, list return values type) """ - return self.name, [str(x) for x in export_nested_types_from_variable(self)], str(self.type) + return ( + self.name, + [str(x) for x in export_nested_types_from_variable(self)], + str(self.type), + ) @property def signature_str(self) -> str: @@ -43,7 +47,14 @@ class StateVariable(ChildContract, Variable): :return: str: func_name(type1,type2) returns(type3) """ name, parameters, returnVars = self.signature - return name + "(" + ",".join(parameters) + ") returns(" + ",".join(returnVars) + ")" + return ( + name + + "(" + + ",".join(parameters) + + ") returns(" + + ",".join(returnVars) + + ")" + ) # endregion ################################################################################### diff --git a/slither/core/variables/structure_variable.py b/slither/core/variables/structure_variable.py index c1b34d4f1..c6034da63 100644 --- a/slither/core/variables/structure_variable.py +++ b/slither/core/variables/structure_variable.py @@ -1,4 +1,4 @@ -from .variable import Variable +from slither.core.variables.variable import Variable from slither.core.children.child_structure import ChildStructure diff --git a/slither/core/variables/variable.py b/slither/core/variables/variable.py index bd3413dc6..be0ae9558 100644 --- a/slither/core/variables/variable.py +++ b/slither/core/variables/variable.py @@ -112,6 +112,7 @@ class Variable(SourceMapping): Return the name of the variable as a function signature :return: """ + # pylint: disable=import-outside-toplevel from slither.core.solidity_types import ArrayType, MappingType from slither.utils.type import export_nested_types_from_variable @@ -120,7 +121,9 @@ class Variable(SourceMapping): assert return_type if isinstance(return_type, (ArrayType, MappingType)): - variable_getter_args = ",".join(map(str, export_nested_types_from_variable(self))) + variable_getter_args = ",".join( + map(str, export_nested_types_from_variable(self)) + ) return f"{self.name}({variable_getter_args})" diff --git a/slither/detectors/abstract_detector.py b/slither/detectors/abstract_detector.py index b3008d43c..d34e14e09 100644 --- a/slither/detectors/abstract_detector.py +++ b/slither/detectors/abstract_detector.py @@ -11,7 +11,7 @@ class IncorrectDetectorInitialization(Exception): pass -class DetectorClassification: +class DetectorClassification: # pylint: disable=too-few-public-methods HIGH = 0 MEDIUM = 1 LOW = 2 @@ -87,12 +87,16 @@ class AbstractDetector(metaclass=abc.ABCMeta): DetectorClassification.OPTIMIZATION, ]: raise IncorrectDetectorInitialization( - "WIKI_EXPLOIT_SCENARIO is not initialized {}".format(self.__class__.__name__) + "WIKI_EXPLOIT_SCENARIO is not initialized {}".format( + self.__class__.__name__ + ) ) if not self.WIKI_RECOMMENDATION: raise IncorrectDetectorInitialization( - "WIKI_RECOMMENDATION is not initialized {}".format(self.__class__.__name__) + "WIKI_RECOMMENDATION is not initialized {}".format( + self.__class__.__name__ + ) ) if re.match("^[a-zA-Z0-9_-]*$", self.ARGUMENT) is None: @@ -131,12 +135,14 @@ class AbstractDetector(metaclass=abc.ABCMeta): """TODO Documentation""" return [] + # pylint: disable=too-many-branches def detect(self): all_results = self._detect() # Keep only dictionaries all_results = [r.data for r in all_results] results = [] # only keep valid result, and remove dupplicate + # pylint: disable=expression-not-assigned [ results.append(r) for r in all_results @@ -173,15 +179,19 @@ class AbstractDetector(metaclass=abc.ABCMeta): ) continue for patch in patches: - patched_txt, offset = apply_patch(patched_txt, patch, offset) - diff = create_diff(self.slither, original_txt, patched_txt, file) + patched_txt, offset = apply_patch( + patched_txt, patch, offset + ) + diff = create_diff( + self.slither, original_txt, patched_txt, file + ) if not diff: self._log(f"Impossible to generate patch; empty {result}") else: result["patches_diff"][file] = diff - except FormatImpossible as e: - self._log(f'\nImpossible to patch:\n\t{result["description"]}\t{e}') + except FormatImpossible as exception: + self._log(f'\nImpossible to patch:\n\t{result["description"]}\t{exception}') if results and self.slither.triage_mode: while True: @@ -206,7 +216,9 @@ class AbstractDetector(metaclass=abc.ABCMeta): ) return [r for (idx, r) in enumerate(results) if idx not in indexes] except ValueError: - self.logger.error(yellow("Malformed input. Example of valid input: 0,1,2,3")) + self.logger.error( + yellow("Malformed input. Example of valid input: 0,1,2,3") + ) return results @property @@ -228,6 +240,6 @@ class AbstractDetector(metaclass=abc.ABCMeta): return output @staticmethod - def _format(slither, result): + def _format(_slither, _result): """Implement format""" return diff --git a/slither/detectors/all_detectors.py b/slither/detectors/all_detectors.py index e5088bd62..789c85c31 100644 --- a/slither/detectors/all_detectors.py +++ b/slither/detectors/all_detectors.py @@ -1,3 +1,4 @@ +# pylint: disable=unused-import,relative-beyond-top-level from .examples.backdoor import Backdoor from .variables.uninitialized_state_variables import UninitializedStateVarsDetection from .variables.uninitialized_storage_variables import UninitializedStorageVars diff --git a/slither/detectors/attributes/const_functions_asm.py b/slither/detectors/attributes/const_functions_asm.py index a72d3cba8..ab983e9b0 100644 --- a/slither/detectors/attributes/const_functions_asm.py +++ b/slither/detectors/attributes/const_functions_asm.py @@ -3,7 +3,7 @@ Module detecting constant functions Recursively check the called functions """ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.formatters.attributes.const_functions import format +from slither.formatters.attributes.const_functions import custom_format class ConstantFunctionsAsm(AbstractDetector): @@ -40,9 +40,7 @@ contract Constant{ `Constant` was deployed with Solidity 0.4.25. Bob writes a smart contract that interacts with `Constant` in Solidity 0.5.0. All the calls to `get` revert, breaking Bob's smart contract execution.""" - WIKI_RECOMMENDATION = ( - "Ensure the attributes of contracts compiled prior to Solidity 0.5.0 are correct." - ) + WIKI_RECOMMENDATION = "Ensure the attributes of contracts compiled prior to Solidity 0.5.0 are correct." def _detect(self): """ Detect the constant function using assembly code @@ -71,4 +69,4 @@ All the calls to `get` revert, breaking Bob's smart contract execution.""" @staticmethod def _format(slither, result): - format(slither, result) + custom_format(slither, result) diff --git a/slither/detectors/attributes/const_functions_state.py b/slither/detectors/attributes/const_functions_state.py index 19b86f465..9ff706c74 100644 --- a/slither/detectors/attributes/const_functions_state.py +++ b/slither/detectors/attributes/const_functions_state.py @@ -3,7 +3,7 @@ Module detecting constant functions Recursively check the called functions """ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.formatters.attributes.const_functions import format +from slither.formatters.attributes.const_functions import custom_format class ConstantFunctionsState(AbstractDetector): @@ -40,9 +40,7 @@ contract Constant{ `Constant` was deployed with Solidity 0.4.25. Bob writes a smart contract that interacts with `Constant` in Solidity 0.5.0. All the calls to `get` revert, breaking Bob's smart contract execution.""" - WIKI_RECOMMENDATION = ( - "Ensure that attributes of contracts compiled prior to Solidity 0.5.0 are correct." - ) + WIKI_RECOMMENDATION = "Ensure that attributes of contracts compiled prior to Solidity 0.5.0 are correct." def _detect(self): """ Detect the constant function changing the state @@ -63,7 +61,10 @@ All the calls to `get` revert, breaking Bob's smart contract execution.""" if variables_written: attr = "view" if f.view else "pure" - info = [f, f" is declared {attr} but changes state variables:\n"] + info = [ + f, + f" is declared {attr} but changes state variables:\n", + ] for variable_written in variables_written: info += ["\t- ", variable_written, "\n"] @@ -76,4 +77,4 @@ All the calls to `get` revert, breaking Bob's smart contract execution.""" @staticmethod def _format(slither, result): - format(slither, result) + custom_format(slither, result) diff --git a/slither/detectors/attributes/constant_pragma.py b/slither/detectors/attributes/constant_pragma.py index 45b40ae04..d542d70a0 100644 --- a/slither/detectors/attributes/constant_pragma.py +++ b/slither/detectors/attributes/constant_pragma.py @@ -3,7 +3,7 @@ """ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.formatters.attributes.constant_pragma import format +from slither.formatters.attributes.constant_pragma import custom_format class ConstantPragma(AbstractDetector): @@ -43,4 +43,4 @@ class ConstantPragma(AbstractDetector): @staticmethod def _format(slither, result): - format(slither, result) + custom_format(slither, result) diff --git a/slither/detectors/attributes/incorrect_solc.py b/slither/detectors/attributes/incorrect_solc.py index 62baeb5a9..8b9770540 100644 --- a/slither/detectors/attributes/incorrect_solc.py +++ b/slither/detectors/attributes/incorrect_solc.py @@ -4,7 +4,7 @@ import re from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.formatters.attributes.incorrect_solc import format +from slither.formatters.attributes.incorrect_solc import custom_format # group: # 0: ^ > >= < <= (optional) @@ -13,6 +13,7 @@ from slither.formatters.attributes.incorrect_solc import format # 3: version number # 4: version number +# pylint: disable=anomalous-backslash-in-string PATTERN = re.compile("(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)") @@ -45,12 +46,8 @@ Consider using the latest version of Solidity for testing.""" OLD_VERSION_TXT = "allows old versions" LESS_THAN_TXT = "uses lesser than" - TOO_RECENT_VERSION_TXT = ( - "necessitates a version too recent to be trusted. Consider deploying with 0.6.11" - ) - BUGGY_VERSION_TXT = ( - "is known to contain severe issues (https://solidity.readthedocs.io/en/latest/bugs.html)" - ) + TOO_RECENT_VERSION_TXT = "necessitates a version too recent to be trusted. Consider deploying with 0.6.11" + BUGGY_VERSION_TXT = "is known to contain severe issues (https://solidity.readthedocs.io/en/latest/bugs.html)" # Indicates the allowed versions. Must be formatted in increasing order. ALLOWED_VERSIONS = [ @@ -85,7 +82,9 @@ Consider using the latest version of Solidity for testing.""" return self.LESS_THAN_TXT version_number = ".".join(version[2:]) if version_number not in self.ALLOWED_VERSIONS: - if list(map(int, version[2:])) > list(map(int, self.ALLOWED_VERSIONS[-1].split("."))): + if list(map(int, version[2:])) > list( + map(int, self.ALLOWED_VERSIONS[-1].split(".")) + ): return self.TOO_RECENT_VERSION_TXT return self.OLD_VERSION_TXT return None @@ -97,7 +96,7 @@ Consider using the latest version of Solidity for testing.""" if len(versions) == 1: version = versions[0] return self._check_version(version) - elif len(versions) == 2: + if len(versions) == 2: version_left = versions[0] version_right = versions[1] # Only allow two elements if the second one is @@ -109,8 +108,7 @@ Consider using the latest version of Solidity for testing.""" ]: return self.COMPLEX_PRAGMA_TXT return self._check_version(version_left) - else: - return self.COMPLEX_PRAGMA_TXT + return self.COMPLEX_PRAGMA_TXT def _detect(self): """ @@ -121,14 +119,13 @@ Consider using the latest version of Solidity for testing.""" results = [] pragma = self.slither.pragma_directives disallowed_pragmas = [] - detected_version = False + for p in pragma: # Skip any pragma directives which do not refer to version if len(p.directive) < 1 or p.directive[0] != "solidity": continue # This is version, so we test if this is disallowed. - detected_version = True reason = self._check_pragma(p.version) if reason: disallowed_pragmas.append((reason, p)) @@ -162,4 +159,4 @@ Consider using the latest version of Solidity for testing.""" @staticmethod def _format(slither, result): - format(slither, result) + custom_format(slither, result) diff --git a/slither/detectors/attributes/locked_ether.py b/slither/detectors/attributes/locked_ether.py index cba7d63c4..8438aa20a 100644 --- a/slither/detectors/attributes/locked_ether.py +++ b/slither/detectors/attributes/locked_ether.py @@ -14,9 +14,7 @@ from slither.slithir.operations import ( ) -class LockedEther(AbstractDetector): - """ - """ +class LockedEther(AbstractDetector): # pylint: disable=too-many-nested-blocks ARGUMENT = "locked-ether" HELP = "Contracts that lock ether" @@ -26,7 +24,9 @@ class LockedEther(AbstractDetector): WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#contracts-that-lock-ether" WIKI_TITLE = "Contracts that lock Ether" - WIKI_DESCRIPTION = "Contract with a `payable` function, but without a withdrawal capacity." + WIKI_DESCRIPTION = ( + "Contract with a `payable` function, but without a withdrawal capacity." + ) WIKI_EXPLOIT_SCENARIO = """ ```solidity pragma solidity 0.4.24; @@ -44,7 +44,7 @@ Every Ether sent to `Locked` will be lost.""" functions = contract.all_functions_called to_explore = functions explored = [] - while to_explore: + while to_explore: # pylint: disable=too-many-nested-blocks functions = to_explore explored += to_explore to_explore = [] @@ -55,7 +55,8 @@ Every Ether sent to `Locked` will be lost.""" for node in function.nodes: for ir in node.irs: if isinstance( - ir, (Send, Transfer, HighLevelCall, LowLevelCall, NewContract) + ir, + (Send, Transfer, HighLevelCall, LowLevelCall, NewContract), ): if ir.call_value and ir.call_value != 0: return False @@ -77,13 +78,15 @@ Every Ether sent to `Locked` will be lost.""" for contract in self.slither.contracts_derived: if contract.is_signature_only(): continue - funcs_payable = [function for function in contract.functions if function.payable] + funcs_payable = [ + function for function in contract.functions if function.payable + ] if funcs_payable: if self.do_no_send_ether(contract): info = [f"Contract locking ether found in {self.filename}:\n"] info += ["\tContract ", contract, " has payable functions:\n"] for function in funcs_payable: - info += [f"\t - ", function, "\n"] + info += ["\t - ", function, "\n"] info += "\tBut does not have a function to withdraw the ether\n" json = self.generate_result(info) diff --git a/slither/detectors/erc/incorrect_erc20_interface.py b/slither/detectors/erc/incorrect_erc20_interface.py index 1f8ea7aaa..5f1d10566 100644 --- a/slither/detectors/erc/incorrect_erc20_interface.py +++ b/slither/detectors/erc/incorrect_erc20_interface.py @@ -36,7 +36,11 @@ contract Token{ def incorrect_erc20_interface(signature): (name, parameters, returnVars) = signature - if name == "transfer" and parameters == ["address", "uint256"] and returnVars != ["bool"]: + if ( + name == "transfer" + and parameters == ["address", "uint256"] + and returnVars != ["bool"] + ): return True if ( @@ -46,7 +50,11 @@ contract Token{ ): return True - if name == "approve" and parameters == ["address", "uint256"] and returnVars != ["bool"]: + if ( + name == "approve" + and parameters == ["address", "uint256"] + and returnVars != ["bool"] + ): return True if ( @@ -56,7 +64,11 @@ contract Token{ ): return True - if name == "balanceOf" and parameters == ["address"] and returnVars != ["uint256"]: + if ( + name == "balanceOf" + and parameters == ["address"] + and returnVars != ["uint256"] + ): return True if name == "totalSupply" and parameters == [] and returnVars != ["uint256"]: @@ -98,10 +110,17 @@ contract Token{ """ results = [] for c in self.slither.contracts_derived: - functions = IncorrectERC20InterfaceDetection.detect_incorrect_erc20_interface(c) + functions = IncorrectERC20InterfaceDetection.detect_incorrect_erc20_interface( + c + ) if functions: for function in functions: - info = [c, " has incorrect ERC20 function interface:", function, "\n"] + info = [ + c, + " has incorrect ERC20 function interface:", + function, + "\n", + ] json = self.generate_result(info) results.append(json) diff --git a/slither/detectors/erc/incorrect_erc721_interface.py b/slither/detectors/erc/incorrect_erc721_interface.py index 53703a2cf..6f3ccff06 100644 --- a/slither/detectors/erc/incorrect_erc721_interface.py +++ b/slither/detectors/erc/incorrect_erc721_interface.py @@ -14,9 +14,7 @@ class IncorrectERC721InterfaceDetection(AbstractDetector): IMPACT = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.HIGH - WIKI = ( - "https://github.com/crytic/slither/wiki/Detector-Documentation#incorrect-erc721-interface" - ) + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#incorrect-erc721-interface" WIKI_TITLE = "Incorrect erc721 interface" WIKI_DESCRIPTION = "Incorrect return values for `ERC721` functions. A contract compiled with solidity > 0.4.22 interacting with these functions will fail to execute them, as the return value is missing." @@ -29,18 +27,24 @@ contract Token{ ``` `Token.ownerOf` does not return an address like `ERC721` expects. Bob deploys the token. Alice creates a contract that interacts with it but assumes a correct `ERC721` interface implementation. Alice's contract is unable to interact with Bob's contract.""" - WIKI_RECOMMENDATION = ( - "Set the appropriate return values and vtypes for the defined `ERC721` functions." - ) + WIKI_RECOMMENDATION = "Set the appropriate return values and vtypes for the defined `ERC721` functions." @staticmethod def incorrect_erc721_interface(signature): (name, parameters, returnVars) = signature # ERC721 - if name == "balanceOf" and parameters == ["address"] and returnVars != ["uint256"]: + if ( + name == "balanceOf" + and parameters == ["address"] + and returnVars != ["uint256"] + ): return True - if name == "ownerOf" and parameters == ["uint256"] and returnVars != ["address"]: + if ( + name == "ownerOf" + and parameters == ["uint256"] + and returnVars != ["address"] + ): return True if ( name == "safeTransferFrom" @@ -60,11 +64,23 @@ contract Token{ and returnVars != [] ): return True - if name == "approve" and parameters == ["address", "uint256"] and returnVars != []: + if ( + name == "approve" + and parameters == ["address", "uint256"] + and returnVars != [] + ): return True - if name == "setApprovalForAll" and parameters == ["address", "bool"] and returnVars != []: + if ( + name == "setApprovalForAll" + and parameters == ["address", "bool"] + and returnVars != [] + ): return True - if name == "getApproved" and parameters == ["uint256"] and returnVars != ["address"]: + if ( + name == "getApproved" + and parameters == ["uint256"] + and returnVars != ["address"] + ): return True if ( name == "isApprovedForAll" @@ -74,7 +90,11 @@ contract Token{ return True # ERC165 (dependency) - if name == "supportsInterface" and parameters == ["bytes4"] and returnVars != ["bool"]: + if ( + name == "supportsInterface" + and parameters == ["bytes4"] + and returnVars != ["bool"] + ): return True return False @@ -107,10 +127,17 @@ contract Token{ """ results = [] for c in self.slither.contracts_derived: - functions = IncorrectERC721InterfaceDetection.detect_incorrect_erc721_interface(c) + functions = IncorrectERC721InterfaceDetection.detect_incorrect_erc721_interface( + c + ) if functions: for function in functions: - info = [c, " has incorrect ERC721 function interface:", function, "\n"] + info = [ + c, + " has incorrect ERC721 function interface:", + function, + "\n", + ] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/erc/unindexed_event_parameters.py b/slither/detectors/erc/unindexed_event_parameters.py index 757017946..755ddeeff 100644 --- a/slither/detectors/erc/unindexed_event_parameters.py +++ b/slither/detectors/erc/unindexed_event_parameters.py @@ -76,7 +76,11 @@ Failure to include these keywords will exclude the parameter data in the transac # Add each problematic event definition to our result list for (event, parameter) in unindexed_params: - info = ["ERC20 event ", event, f"does not index parameter {parameter}\n"] + info = [ + "ERC20 event ", + event, + f"does not index parameter {parameter}\n", + ] # Add the events to the JSON (note: we do not add the params/vars as they have no source mapping). res = self.generate_result(info) diff --git a/slither/detectors/examples/backdoor.py b/slither/detectors/examples/backdoor.py index 36d8326e4..b57e6828d 100644 --- a/slither/detectors/examples/backdoor.py +++ b/slither/detectors/examples/backdoor.py @@ -6,7 +6,9 @@ class Backdoor(AbstractDetector): Detect function named backdoor """ - ARGUMENT = "backdoor" # slither will launch the detector with slither.py --mydetector + ARGUMENT = ( + "backdoor" # slither will launch the detector with slither.py --mydetector + ) HELP = "Function named backdoor (detector example)" IMPACT = DetectorClassification.HIGH CONFIDENCE = DetectorClassification.HIGH diff --git a/slither/detectors/functions/arbitrary_send.py b/slither/detectors/functions/arbitrary_send.py index 1b63d5e1c..04d23595e 100644 --- a/slither/detectors/functions/arbitrary_send.py +++ b/slither/detectors/functions/arbitrary_send.py @@ -11,7 +11,10 @@ """ from slither.core.declarations import Function from slither.analyses.data_dependency.data_dependency import is_tainted, is_dependent -from slither.core.declarations.solidity_variables import SolidityFunction, SolidityVariableComposed +from slither.core.declarations.solidity_variables import ( + SolidityFunction, + SolidityVariableComposed, +) from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import ( HighLevelCall, @@ -23,10 +26,70 @@ from slither.slithir.operations import ( ) -class ArbitrarySend(AbstractDetector): +# pylint: disable=too-many-nested-blocks,too-many-branches +def arbitrary_send(func): + if func.is_protected(): + return [] + + ret = [] + for node in func.nodes: + for ir in node.irs: + if isinstance(ir, SolidityCall): + if ir.function == SolidityFunction( + "ecrecover(bytes32,uint8,bytes32,bytes32)" + ): + return False + if isinstance(ir, Index): + if ir.variable_right == SolidityVariableComposed("msg.sender"): + return False + if is_dependent( + ir.variable_right, + SolidityVariableComposed("msg.sender"), + func.contract, + ): + return False + if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)): + if isinstance(ir, (HighLevelCall)): + if isinstance(ir.function, Function): + if ( + ir.function.full_name + == "transferFrom(address,address,uint256)" + ): + return False + if ir.call_value is None: + continue + if ir.call_value == SolidityVariableComposed("msg.value"): + continue + if is_dependent( + ir.call_value, + SolidityVariableComposed("msg.value"), + func.contract, + ): + continue + + if is_tainted(ir.destination, func.contract): + ret.append(node) + + return ret + + +def detect_arbitrary_send(contract): """ + Detect arbitrary send + Args: + contract (Contract) + Returns: + list((Function), (list (Node))) """ + ret = [] + for f in [f for f in contract.functions if f.contract_declarer == contract]: + nodes = arbitrary_send(f) + if nodes: + ret.append((f, nodes)) + return ret + +class ArbitrarySend(AbstractDetector): ARGUMENT = "arbitrary-send" HELP = "Functions that send Ether to arbitrary destinations" IMPACT = DetectorClassification.HIGH @@ -35,7 +98,9 @@ class ArbitrarySend(AbstractDetector): WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#functions-that-send-ether-to-arbitrary-destinations" WIKI_TITLE = "Functions that send Ether to arbitrary destinations" - WIKI_DESCRIPTION = "Unprotected call to a function sending Ether to an arbitrary address." + WIKI_DESCRIPTION = ( + "Unprotected call to a function sending Ether to an arbitrary address." + ) WIKI_EXPLOIT_SCENARIO = """ ```solidity contract ArbitrarySend{ @@ -51,60 +116,9 @@ contract ArbitrarySend{ ``` Bob calls `setDestination` and `withdraw`. As a result he withdraws the contract's balance.""" - WIKI_RECOMMENDATION = "Ensure that an arbitrary user cannot withdraw unauthorized funds." - - def arbitrary_send(self, func): - """ - """ - if func.is_protected(): - return [] - - ret = [] - for node in func.nodes: - for ir in node.irs: - if isinstance(ir, SolidityCall): - if ir.function == SolidityFunction("ecrecover(bytes32,uint8,bytes32,bytes32)"): - return False - if isinstance(ir, Index): - if ir.variable_right == SolidityVariableComposed("msg.sender"): - return False - if is_dependent( - ir.variable_right, SolidityVariableComposed("msg.sender"), func.contract - ): - return False - if isinstance(ir, (HighLevelCall, LowLevelCall, Transfer, Send)): - if isinstance(ir, (HighLevelCall)): - if isinstance(ir.function, Function): - if ir.function.full_name == "transferFrom(address,address,uint256)": - return False - if ir.call_value is None: - continue - if ir.call_value == SolidityVariableComposed("msg.value"): - continue - if is_dependent( - ir.call_value, SolidityVariableComposed("msg.value"), func.contract - ): - continue - - if is_tainted(ir.destination, func.contract): - ret.append(node) - - return ret - - def detect_arbitrary_send(self, contract): - """ - Detect arbitrary send - Args: - contract (Contract) - Returns: - list((Function), (list (Node))) - """ - ret = [] - for f in [f for f in contract.functions if f.contract_declarer == contract]: - nodes = self.arbitrary_send(f) - if nodes: - ret.append((f, nodes)) - return ret + WIKI_RECOMMENDATION = ( + "Ensure that an arbitrary user cannot withdraw unauthorized funds." + ) def _detect(self): """ @@ -112,8 +126,8 @@ Bob calls `setDestination` and `withdraw`. As a result he withdraws the contract results = [] for c in self.contracts: - arbitrary_send = self.detect_arbitrary_send(c) - for (func, nodes) in arbitrary_send: + arbitrary_send_result = detect_arbitrary_send(c) + for (func, nodes) in arbitrary_send_result: info = [func, " sends eth to arbitrary user\n"] info += ["\tDangerous calls:\n"] diff --git a/slither/detectors/functions/complex_function.py b/slither/detectors/functions/complex_function.py deleted file mode 100644 index c9d7067a9..000000000 --- a/slither/detectors/functions/complex_function.py +++ /dev/null @@ -1,105 +0,0 @@ -from slither.core.declarations.solidity_variables import SolidityFunction, SolidityVariableComposed -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.slithir.operations import HighLevelCall, LowLevelCall, LibraryCall -from slither.utils.code_complexity import compute_cyclomatic_complexity - - -class ComplexFunction(AbstractDetector): - """ - Module detecting complex functions - A complex function is defined by: - - high cyclomatic complexity - - numerous writes to state variables - - numerous external calls - """ - - ARGUMENT = "complex-function" - HELP = "Complex functions" - IMPACT = DetectorClassification.INFORMATIONAL - CONFIDENCE = DetectorClassification.MEDIUM - - MAX_STATE_VARIABLES = 10 - MAX_EXTERNAL_CALLS = 5 - MAX_CYCLOMATIC_COMPLEXITY = 7 - - CAUSE_CYCLOMATIC = "cyclomatic" - CAUSE_EXTERNAL_CALL = "external_calls" - CAUSE_STATE_VARS = "state_vars" - - STANDARD_JSON = True - - @staticmethod - def detect_complex_func(func): - """Detect the cyclomatic complexity of the contract functions - shouldn't be greater than 7 - """ - result = [] - code_complexity = compute_cyclomatic_complexity(func) - - if code_complexity > ComplexFunction.MAX_CYCLOMATIC_COMPLEXITY: - result.append({"func": func, "cause": ComplexFunction.CAUSE_CYCLOMATIC}) - - """Detect the number of external calls in the func - shouldn't be greater than 5 - """ - count = 0 - for node in func.nodes: - for ir in node.irs: - if isinstance(ir, (HighLevelCall, LowLevelCall, LibraryCall)): - count += 1 - - if count > ComplexFunction.MAX_EXTERNAL_CALLS: - result.append({"func": func, "cause": ComplexFunction.CAUSE_EXTERNAL_CALL}) - - """Checks the number of the state variables written - shouldn't be greater than 10 - """ - if len(func.state_variables_written) > ComplexFunction.MAX_STATE_VARIABLES: - result.append({"func": func, "cause": ComplexFunction.CAUSE_STATE_VARS}) - - return result - - def detect_complex(self, contract): - ret = [] - - for func in contract.all_functions_called: - result = self.detect_complex_func(func) - ret.extend(result) - - return ret - - def detect(self): - results = [] - - for contract in self.contracts: - issues = self.detect_complex(contract) - - for issue in issues: - func, cause = issue.values() - - txt = "{} ({}) is a complex function:\n" - - if cause == self.CAUSE_EXTERNAL_CALL: - txt += "\t- Reason: High number of external calls" - if cause == self.CAUSE_CYCLOMATIC: - txt += "\t- Reason: High number of branches" - if cause == self.CAUSE_STATE_VARS: - txt += "\t- Reason: High number of modified state variables" - - info = txt.format(func.canonical_name, func.source_mapping_str) - info = info + "\n" - self.log(info) - - res = self.generate_result(info) - res.add( - func, - { - "high_number_of_external_calls": cause == self.CAUSE_EXTERNAL_CALL, - "high_number_of_branches": cause == self.CAUSE_CYCLOMATIC, - "high_number_of_state_variables": cause == self.CAUSE_STATE_VARS, - }, - ) - - results.append(res) - - return results diff --git a/slither/detectors/functions/external_function.py b/slither/detectors/functions/external_function.py index eb815565b..331a58ee9 100644 --- a/slither/detectors/functions/external_function.py +++ b/slither/detectors/functions/external_function.py @@ -1,7 +1,7 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import SolidityCall from slither.slithir.operations import InternalCall, InternalDynamicCall -from slither.formatters.functions.external_function import format +from slither.formatters.functions.external_function import custom_format class ExternalFunction(AbstractDetector): @@ -81,7 +81,9 @@ class ExternalFunction(AbstractDetector): # Somehow we couldn't resolve it, which shouldn't happen, as the provided function should be found if we could # not find some any more basic. - raise Exception("Could not resolve the base-most function for the provided function.") + raise Exception( + "Could not resolve the base-most function for the provided function." + ) @staticmethod def get_all_function_definitions(base_most_function): @@ -105,7 +107,7 @@ class ExternalFunction(AbstractDetector): def function_parameters_written(function): return any(p in function.variables_written for p in function.parameters) - def _detect(self): + def _detect(self): # pylint: disable=too-many-locals,too-many-branches results = [] # Create a set to track contracts with dynamic calls. All contracts with dynamic calls could potentially be @@ -157,20 +159,22 @@ class ExternalFunction(AbstractDetector): all_function_definitions = set( self.get_all_function_definitions(base_most_function) ) - completed_functions = completed_functions.union(all_function_definitions) + completed_functions = completed_functions.union( + all_function_definitions + ) # Filter false-positives: Determine if any of these sources have dynamic calls, if so, flag all of these # function definitions, and then flag all functions in all contracts that make dynamic calls. - sources_with_dynamic_calls = set(all_possible_sources) & dynamic_call_contracts + sources_with_dynamic_calls = ( + set(all_possible_sources) & dynamic_call_contracts + ) if sources_with_dynamic_calls: - functions_in_dynamic_call_sources = set( - [ - f - for dyn_contract in sources_with_dynamic_calls - for f in dyn_contract.functions - if not f.is_constructor - ] - ) + functions_in_dynamic_call_sources = { + f + for dyn_contract in sources_with_dynamic_calls + for f in dyn_contract.functions + if not f.is_constructor + } completed_functions = completed_functions.union( functions_in_dynamic_call_sources ) @@ -200,10 +204,12 @@ class ExternalFunction(AbstractDetector): function_definition = all_function_definitions[0] all_function_definitions = all_function_definitions[1:] - info = [f"{function_definition.full_name} should be declared external:\n"] - info += [f"\t- ", function_definition, "\n"] + info = [ + f"{function_definition.full_name} should be declared external:\n" + ] + info += ["\t- ", function_definition, "\n"] for other_function_definition in all_function_definitions: - info += [f"\t- ", other_function_definition, "\n"] + info += ["\t- ", other_function_definition, "\n"] res = self.generate_result(info) @@ -213,4 +219,4 @@ class ExternalFunction(AbstractDetector): @staticmethod def _format(slither, result): - format(slither, result) + custom_format(slither, result) diff --git a/slither/detectors/functions/suicidal.py b/slither/detectors/functions/suicidal.py index 44c4aa919..0afafa2b5 100644 --- a/slither/detectors/functions/suicidal.py +++ b/slither/detectors/functions/suicidal.py @@ -20,7 +20,9 @@ class Suicidal(AbstractDetector): WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#suicidal" WIKI_TITLE = "Suicidal" - WIKI_DESCRIPTION = "Unprotected call to a function executing `selfdestruct`/`suicide`." + WIKI_DESCRIPTION = ( + "Unprotected call to a function executing `selfdestruct`/`suicide`." + ) WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Suicidal{ diff --git a/slither/detectors/naming_convention/naming_convention.py b/slither/detectors/naming_convention/naming_convention.py index 24e944c6b..1ddee0854 100644 --- a/slither/detectors/naming_convention/naming_convention.py +++ b/slither/detectors/naming_convention/naming_convention.py @@ -1,6 +1,6 @@ import re from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.formatters.naming_convention.naming_convention import format +from slither.formatters.naming_convention.naming_convention import custom_format class NamingConvention(AbstractDetector): @@ -54,7 +54,7 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2 def should_avoid_name(name): return re.search("^[lOI]$", name) is not None - def _detect(self): + def _detect(self): # pylint: disable=too-many-branches,too-many-statements results = [] for contract in self.contracts: @@ -91,7 +91,9 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2 "private", ] and self.is_mixed_case_with_underscore(func.name): continue - if func.name.startswith("echidna_") or func.name.startswith("crytic_"): + if func.name.startswith("echidna_") or func.name.startswith( + "crytic_" + ): continue info = ["Function ", func, " is not in mixedCase\n"] @@ -106,22 +108,34 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2 if argument in func.variables_read_or_written: correct_naming = self.is_mixed_case(argument.name) else: - correct_naming = self.is_mixed_case_with_underscore(argument.name) + correct_naming = self.is_mixed_case_with_underscore( + argument.name + ) if not correct_naming: info = ["Parameter ", argument, " is not in mixedCase\n"] res = self.generate_result(info) - res.add(argument, {"target": "parameter", "convention": "mixedCase"}) + res.add( + argument, {"target": "parameter", "convention": "mixedCase"} + ) results.append(res) for var in contract.state_variables_declared: if self.should_avoid_name(var.name): if not self.is_upper_case_with_underscores(var.name): - info = ["Variable ", var, " used l, O, I, which should not be used\n"] + info = [ + "Variable ", + var, + " used l, O, I, which should not be used\n", + ] res = self.generate_result(info) res.add( - var, {"target": "variable", "convention": "l_O_I_should_not_be_used"} + var, + { + "target": "variable", + "convention": "l_O_I_should_not_be_used", + }, ) results.append(res) @@ -131,7 +145,11 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2 continue if not self.is_upper_case_with_underscores(var.name): - info = ["Constant ", var, " is not in UPPER_CASE_WITH_UNDERSCORES\n"] + info = [ + "Constant ", + var, + " is not in UPPER_CASE_WITH_UNDERSCORES\n", + ] res = self.generate_result(info) res.add( @@ -175,4 +193,4 @@ Solidity defines a [naming convention](https://solidity.readthedocs.io/en/v0.4.2 @staticmethod def _format(slither, result): - format(slither, result) + custom_format(slither, result) diff --git a/slither/detectors/operations/block_timestamp.py b/slither/detectors/operations/block_timestamp.py index cf65bf5b3..9a68a6bd5 100644 --- a/slither/detectors/operations/block_timestamp.py +++ b/slither/detectors/operations/block_timestamp.py @@ -7,7 +7,10 @@ from typing import List, Tuple from slither.analyses.data_dependency.data_dependency import is_dependent from slither.core.cfg.node import Node from slither.core.declarations import Function, Contract -from slither.core.declarations.solidity_variables import SolidityVariableComposed, SolidityVariable +from slither.core.declarations.solidity_variables import ( + SolidityVariableComposed, + SolidityVariable, +) from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import Binary, BinaryType @@ -17,7 +20,9 @@ def _timestamp(func: Function) -> List[Node]: for node in func.nodes: if node.contains_require_or_assert(): for var in node.variables_read: - if is_dependent(var, SolidityVariableComposed("block.timestamp"), func.contract): + if is_dependent( + var, SolidityVariableComposed("block.timestamp"), func.contract + ): ret.add(node) if is_dependent(var, SolidityVariable("now"), func.contract): ret.add(node) @@ -33,7 +38,9 @@ def _timestamp(func: Function) -> List[Node]: return sorted(list(ret), key=lambda x: x.node_id) -def _detect_dangerous_timestamp(contract: Contract) -> List[Tuple[Function, List[Node]]]: +def _detect_dangerous_timestamp( + contract: Contract, +) -> List[Tuple[Function, List[Node]]]: """ Args: contract (Contract) @@ -49,20 +56,18 @@ def _detect_dangerous_timestamp(contract: Contract) -> List[Tuple[Function, List class Timestamp(AbstractDetector): - """ - """ ARGUMENT = "timestamp" HELP = "Dangerous usage of `block.timestamp`" IMPACT = DetectorClassification.LOW CONFIDENCE = DetectorClassification.MEDIUM - WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#block-timestamp" + WIKI = ( + "https://github.com/crytic/slither/wiki/Detector-Documentation#block-timestamp" + ) WIKI_TITLE = "Block timestamp" - WIKI_DESCRIPTION = ( - "Dangerous usage of `block.timestamp`. `block.timestamp` can be manipulated by miners." - ) + WIKI_DESCRIPTION = "Dangerous usage of `block.timestamp`. `block.timestamp` can be manipulated by miners." WIKI_EXPLOIT_SCENARIO = """"Bob's contract relies on `block.timestamp` for its randomness. Eve is a miner and manipulates `block.timestamp` to exploit Bob's contract.""" WIKI_RECOMMENDATION = "Avoid relying on `block.timestamp`." diff --git a/slither/detectors/operations/low_level_calls.py b/slither/detectors/operations/low_level_calls.py index 81dbf5ce1..180120a14 100644 --- a/slither/detectors/operations/low_level_calls.py +++ b/slither/detectors/operations/low_level_calls.py @@ -16,7 +16,9 @@ class LowLevelCalls(AbstractDetector): IMPACT = DetectorClassification.INFORMATIONAL CONFIDENCE = DetectorClassification.HIGH - WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#low-level-calls" + WIKI = ( + "https://github.com/crytic/slither/wiki/Detector-Documentation#low-level-calls" + ) WIKI_TITLE = "Low-level calls" WIKI_DESCRIPTION = "The use of low-level calls is error-prone. Low-level calls do not check for [code existence](https://solidity.readthedocs.io/en/v0.4.25/control-structures.html#error-handling-assert-require-revert-and-exceptions) or call success." diff --git a/slither/detectors/operations/unchecked_low_level_return_values.py b/slither/detectors/operations/unchecked_low_level_return_values.py index a9c2da38d..391077b6f 100644 --- a/slither/detectors/operations/unchecked_low_level_return_values.py +++ b/slither/detectors/operations/unchecked_low_level_return_values.py @@ -2,7 +2,7 @@ Module detecting unused return values from low level """ from slither.detectors.abstract_detector import DetectorClassification -from .unused_return_values import UnusedReturnValues +from slither.detectors.operations.unused_return_values import UnusedReturnValues from slither.slithir.operations import LowLevelCall @@ -32,9 +32,11 @@ The return value of the low-level call is not checked, so if the call fails, the If the low level is used to prevent blocking operations, consider logging failed calls. """ - WIKI_RECOMMENDATION = "Ensure that the return value of a low-level call is checked or logged." + WIKI_RECOMMENDATION = ( + "Ensure that the return value of a low-level call is checked or logged." + ) _txt_description = "low-level calls" - def _is_instance(self, ir): + def _is_instance(self, ir): # pylint: disable=no-self-use return isinstance(ir, LowLevelCall) diff --git a/slither/detectors/operations/unchecked_send_return_value.py b/slither/detectors/operations/unchecked_send_return_value.py index b2437609a..469045a84 100644 --- a/slither/detectors/operations/unchecked_send_return_value.py +++ b/slither/detectors/operations/unchecked_send_return_value.py @@ -3,7 +3,7 @@ Module detecting unused return values from send """ from slither.detectors.abstract_detector import DetectorClassification -from .unused_return_values import UnusedReturnValues +from slither.detectors.operations.unused_return_values import UnusedReturnValues from slither.slithir.operations import Send @@ -17,7 +17,9 @@ class UncheckedSend(UnusedReturnValues): IMPACT = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM - WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#unchecked-send" + WIKI = ( + "https://github.com/crytic/slither/wiki/Detector-Documentation#unchecked-send" + ) WIKI_TITLE = "Unchecked Send" WIKI_DESCRIPTION = "The return value of a `send` is not checked." @@ -37,5 +39,5 @@ If `send` is used to prevent blocking operations, consider logging the failed `s _txt_description = "send calls" - def _is_instance(self, ir): + def _is_instance(self, ir): # pylint: disable=no-self-use return isinstance(ir, Send) diff --git a/slither/detectors/operations/unused_return_values.py b/slither/detectors/operations/unused_return_values.py index a64212b5e..741993d70 100644 --- a/slither/detectors/operations/unused_return_values.py +++ b/slither/detectors/operations/unused_return_values.py @@ -2,9 +2,11 @@ Module detecting unused return values from external calls """ -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.slithir.operations import HighLevelCall, InternalCall, InternalDynamicCall from slither.core.variables.state_variable import StateVariable +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +from slither.slithir.operations import HighLevelCall + + class UnusedReturnValues(AbstractDetector): @@ -20,9 +22,7 @@ class UnusedReturnValues(AbstractDetector): WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#unused-return" WIKI_TITLE = "Unused return" - WIKI_DESCRIPTION = ( - "The return value of an external call is not stored in a local or state variable." - ) + WIKI_DESCRIPTION = "The return value of an external call is not stored in a local or state variable." WIKI_EXPLOIT_SCENARIO = """ ```solidity contract MyConc{ @@ -34,14 +34,16 @@ contract MyConc{ ``` `MyConc` calls `add` of `SafeMath`, but does not store the result in `a`. As a result, the computation has no effect.""" - WIKI_RECOMMENDATION = "Ensure that all the return values of the function calls are used." + WIKI_RECOMMENDATION = ( + "Ensure that all the return values of the function calls are used." + ) _txt_description = "external calls" - def _is_instance(self, ir): + def _is_instance(self, ir): # pylint: disable=no-self-use return isinstance(ir, HighLevelCall) - def detect_unused_return_values(self, f): + def detect_unused_return_values(self, f): # pylint: disable=no-self-use """ Return the nodes where the return value of a call is unused Args: @@ -76,7 +78,7 @@ contract MyConc{ if unused_return: for node in unused_return: - info = [f, f" ignores return value by ", node, "\n"] + info = [f, " ignores return value by ", node, "\n"] res = self.generate_result(info) diff --git a/slither/detectors/operations/void_constructor.py b/slither/detectors/operations/void_constructor.py index 96922c7a7..1d8414d3c 100644 --- a/slither/detectors/operations/void_constructor.py +++ b/slither/detectors/operations/void_constructor.py @@ -9,7 +9,9 @@ class VoidConstructor(AbstractDetector): IMPACT = DetectorClassification.LOW CONFIDENCE = DetectorClassification.HIGH - WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#void-constructor" + WIKI = ( + "https://github.com/crytic/slither/wiki/Detector-Documentation#void-constructor" + ) WIKI_TITLE = "Void constructor" WIKI_DESCRIPTION = "Detect the call to a constructor that is not implemented" diff --git a/slither/detectors/reentrancy/reentrancy.py b/slither/detectors/reentrancy/reentrancy.py index eb238ea96..89318e9f0 100644 --- a/slither/detectors/reentrancy/reentrancy.py +++ b/slither/detectors/reentrancy/reentrancy.py @@ -16,7 +16,10 @@ from slither.slithir.operations import Call, EventCall def union_dict(d1, d2): - d3 = {k: d1.get(k, set()) | d2.get(k, set()) for k in set(list(d1.keys()) + list(d2.keys()))} + d3 = { + k: d1.get(k, set()) | d2.get(k, set()) + for k in set(list(d1.keys()) + list(d2.keys())) + } return defaultdict(set, d3) @@ -40,7 +43,8 @@ def is_subset( def to_hashable(d: Dict[Node, Set[Node]]): list_tuple = list( - tuple((k, tuple(sorted(values, key=lambda x: x.node_id)))) for k, values in d.items() + tuple((k, tuple(sorted(values, key=lambda x: x.node_id)))) + for k, values in d.items() ) return tuple(sorted(list_tuple, key=lambda x: x[0].node_id)) @@ -125,9 +129,12 @@ class AbstractState: if key != skip_father }, ) - self._reads = union_dict(self._reads, father.context[detector.KEY].reads) + self._reads = union_dict( + self._reads, father.context[detector.KEY].reads + ) self._reads_prior_calls = union_dict( - self.reads_prior_calls, father.context[detector.KEY].reads_prior_calls + self.reads_prior_calls, + father.context[detector.KEY].reads_prior_calls, ) def analyze_node(self, node, detector): @@ -178,14 +185,36 @@ class AbstractState: self._send_eth = union_dict(self._send_eth, fathers.send_eth) self._calls = union_dict(self._calls, fathers.calls) self._reads = union_dict(self._reads, fathers.reads) - self._reads_prior_calls = union_dict(self._reads_prior_calls, fathers.reads_prior_calls) + self._reads_prior_calls = union_dict( + self._reads_prior_calls, fathers.reads_prior_calls + ) def does_not_bring_new_info(self, new_info): if is_subset(new_info.calls, self.calls): if is_subset(new_info.send_eth, self.send_eth): if is_subset(new_info.reads, self.reads): - if dict_are_equal(new_info.reads_prior_calls, self.reads_prior_calls): + if dict_are_equal( + new_info.reads_prior_calls, self.reads_prior_calls + ): return True + return False + + +def _filter_if(node): + """ + Check if the node is a condtional node where + there is an external call checked + Heuristic: + - The call is a IF node + - It contains a, external call + - The condition is the negation (!) + + This will work only on naive implementation + """ + return ( + isinstance(node.expression, UnaryOperation) + and node.expression.type == UnaryOperationType.BANG + ) class Reentrancy(AbstractDetector): @@ -215,22 +244,6 @@ class Reentrancy(AbstractDetector): """ return isinstance(ir, Call) and ir.can_send_eth() - def _filter_if(self, node): - """ - Check if the node is a condtional node where - there is an external call checked - Heuristic: - - The call is a IF node - - It contains a, external call - - The condition is the negation (!) - - This will work only on naive implementation - """ - return ( - isinstance(node.expression, UnaryOperation) - and node.expression.type == UnaryOperationType.BANG - ) - def _explore(self, node, visited, skip_father=None): """ Explore the CFG and look for re-entrancy @@ -266,7 +279,7 @@ class Reentrancy(AbstractDetector): sons = node.sons if contains_call and node.type in [NodeType.IF, NodeType.IFLOOP]: - if self._filter_if(node): + if _filter_if(node): son = sons[0] self._explore(son, visited, node) sons = sons[1:] @@ -279,8 +292,6 @@ class Reentrancy(AbstractDetector): self._explore(son, visited) def detect_reentrancy(self, contract): - """ - """ for function in contract.functions_and_modifiers_declared: if function.is_implemented: if self.KEY in function.context: @@ -296,7 +307,7 @@ class Reentrancy(AbstractDetector): # new variables written # This speedup the exploration through a light fixpoint # Its particular useful on 'complex' functions with several loops and conditions - self.visited_all_paths = {} + self.visited_all_paths = {} # pylint: disable=attribute-defined-outside-init for c in self.contracts: self.detect_reentrancy(c) diff --git a/slither/detectors/reentrancy/reentrancy_benign.py b/slither/detectors/reentrancy/reentrancy_benign.py index 05e6dd3f9..79308bd0e 100644 --- a/slither/detectors/reentrancy/reentrancy_benign.py +++ b/slither/detectors/reentrancy/reentrancy_benign.py @@ -20,9 +20,7 @@ class ReentrancyBenign(Reentrancy): IMPACT = DetectorClassification.LOW CONFIDENCE = DetectorClassification.MEDIUM - WIKI = ( - "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-2" - ) + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-2" WIKI_TITLE = "Reentrancy vulnerabilities" WIKI_DESCRIPTION = """ @@ -62,13 +60,15 @@ Only report reentrancy that acts as a double call (see `reentrancy-eth`, `reentr for v in node.context[self.KEY].written if v in node.context[self.KEY].reads_prior_calls[c] ] - not_read_then_written = set( - [ - FindingValue(v, node, tuple(sorted(nodes, key=lambda x: x.node_id))) - for (v, nodes) in node.context[self.KEY].written.items() - if v not in read_then_written - ] - ) + not_read_then_written = { + FindingValue( + v, + node, + tuple(sorted(nodes, key=lambda x: x.node_id)), + ) + for (v, nodes) in node.context[self.KEY].written.items() + if v not in read_then_written + } if not_read_then_written: # calls are ordered finding_key = FindingKey( @@ -79,7 +79,7 @@ Only report reentrancy that acts as a double call (see `reentrancy-eth`, `reentr result[finding_key] |= not_read_then_written return result - def _detect(self): + def _detect(self): # pylint: disable=too-many-branches """ """ @@ -88,12 +88,16 @@ Only report reentrancy that acts as a double call (see `reentrancy-eth`, `reentr results = [] - result_sorted = sorted(list(reentrancies.items()), key=lambda x: x[0].function.name) + result_sorted = sorted( + list(reentrancies.items()), key=lambda x: x[0].function.name + ) varsWritten: List[FindingValue] for (func, calls, send_eth), varsWritten in result_sorted: calls = sorted(list(set(calls)), key=lambda x: x[0].node_id) send_eth = sorted(list(set(send_eth)), key=lambda x: x[0].node_id) - varsWritten = sorted(varsWritten, key=lambda x: (x.variable.name, x.node.node_id)) + varsWritten = sorted( + varsWritten, key=lambda x: (x.variable.name, x.node.node_id) + ) info = ["Reentrancy in ", func, ":\n"] @@ -128,18 +132,24 @@ Only report reentrancy that acts as a double call (see `reentrancy-eth`, `reentr res.add(call_info, {"underlying_type": "external_calls"}) for call_list_info in calls_list: if call_list_info != call_info: - res.add(call_list_info, {"underlying_type": "external_calls_sending_eth"}) + res.add( + call_list_info, + {"underlying_type": "external_calls_sending_eth"}, + ) # # If the calls are not the same ones that send eth, add the eth sending nodes. if calls != send_eth: for (call_info, calls_list) in calls: - res.add(call_info, {"underlying_type": "external_calls_sending_eth"}) + res.add( + call_info, {"underlying_type": "external_calls_sending_eth"} + ) for call_list_info in calls_list: if call_list_info != call_info: res.add( - call_list_info, {"underlying_type": "external_calls_sending_eth"} + call_list_info, + {"underlying_type": "external_calls_sending_eth"}, ) # Add all variables written via nodes which write them. diff --git a/slither/detectors/reentrancy/reentrancy_eth.py b/slither/detectors/reentrancy/reentrancy_eth.py index 729f55443..7c3c2e198 100644 --- a/slither/detectors/reentrancy/reentrancy_eth.py +++ b/slither/detectors/reentrancy/reentrancy_eth.py @@ -20,9 +20,7 @@ class ReentrancyEth(Reentrancy): IMPACT = DetectorClassification.HIGH CONFIDENCE = DetectorClassification.MEDIUM - WIKI = ( - "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities" - ) + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities" WIKI_TITLE = "Reentrancy vulnerabilities" WIKI_DESCRIPTION = """ @@ -48,7 +46,7 @@ Bob uses the re-entrancy bug to call `withdrawBalance` two times, and withdraw m def find_reentrancies(self): result = defaultdict(set) - for contract in self.contracts: + for contract in self.contracts: # pylint: disable=too-many-nested-blocks for f in contract.functions_and_modifiers_declared: for node in f.nodes: # dead code @@ -61,15 +59,17 @@ Bob uses the re-entrancy bug to call `withdrawBalance` two times, and withdraw m for c in node.context[self.KEY].calls: if c == node: continue - read_then_written |= set( - [ - FindingValue( - v, node, tuple(sorted(nodes, key=lambda x: x.node_id)) - ) - for (v, nodes) in node.context[self.KEY].written.items() - if v in node.context[self.KEY].reads_prior_calls[c] - ] - ) + read_then_written |= { + FindingValue( + v, + node, + tuple(sorted(nodes, key=lambda x: x.node_id)), + ) + for (v, nodes) in node.context[ + self.KEY + ].written.items() + if v in node.context[self.KEY].reads_prior_calls[c] + } if read_then_written: # calls are ordered @@ -82,7 +82,7 @@ Bob uses the re-entrancy bug to call `withdrawBalance` two times, and withdraw m result[finding_key] |= set(read_then_written) return result - def _detect(self): + def _detect(self): # pylint: disable=too-many-branches """ """ super()._detect() @@ -91,12 +91,16 @@ Bob uses the re-entrancy bug to call `withdrawBalance` two times, and withdraw m results = [] - result_sorted = sorted(list(reentrancies.items()), key=lambda x: x[0].function.name) + result_sorted = sorted( + list(reentrancies.items()), key=lambda x: x[0].function.name + ) varsWritten: List[FindingValue] for (func, calls, send_eth), varsWritten in result_sorted: calls = sorted(list(set(calls)), key=lambda x: x[0].node_id) send_eth = sorted(list(set(send_eth)), key=lambda x: x[0].node_id) - varsWritten = sorted(varsWritten, key=lambda x: (x.variable.name, x.node.node_id)) + varsWritten = sorted( + varsWritten, key=lambda x: (x.variable.name, x.node.node_id) + ) info = ["Reentrancy in ", func, ":\n"] info += ["\tExternal calls:\n"] @@ -130,16 +134,22 @@ Bob uses the re-entrancy bug to call `withdrawBalance` two times, and withdraw m res.add(call_info, {"underlying_type": "external_calls"}) for call_list_info in calls_list: if call_list_info != call_info: - res.add(call_list_info, {"underlying_type": "external_calls_sending_eth"}) + res.add( + call_list_info, + {"underlying_type": "external_calls_sending_eth"}, + ) # If the calls are not the same ones that send eth, add the eth sending nodes. if calls != send_eth: for (call_info, calls_list) in send_eth: - res.add(call_info, {"underlying_type": "external_calls_sending_eth"}) + res.add( + call_info, {"underlying_type": "external_calls_sending_eth"} + ) for call_list_info in calls_list: if call_list_info != call_info: res.add( - call_list_info, {"underlying_type": "external_calls_sending_eth"} + call_list_info, + {"underlying_type": "external_calls_sending_eth"}, ) # Add all variables written via nodes which write them. diff --git a/slither/detectors/reentrancy/reentrancy_events.py b/slither/detectors/reentrancy/reentrancy_events.py index c23f3110d..822a88886 100644 --- a/slither/detectors/reentrancy/reentrancy_events.py +++ b/slither/detectors/reentrancy/reentrancy_events.py @@ -19,9 +19,7 @@ class ReentrancyEvent(Reentrancy): IMPACT = DetectorClassification.LOW CONFIDENCE = DetectorClassification.MEDIUM - WIKI = ( - "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-3" - ) + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-3" WIKI_TITLE = "Reentrancy vulnerabilities" WIKI_DESCRIPTION = """ @@ -60,19 +58,19 @@ If `d.()` re-enters, the `Counter` events will be shown in an incorrect order, w calls=to_hashable(node.context[self.KEY].calls), send_eth=to_hashable(node.context[self.KEY].send_eth), ) - finding_vars = set( - [ - FindingValue( - e, e.node, tuple(sorted(nodes, key=lambda x: x.node_id)) - ) - for (e, nodes) in node.context[self.KEY].events.items() - ] - ) + finding_vars = { + FindingValue( + e, + e.node, + tuple(sorted(nodes, key=lambda x: x.node_id)), + ) + for (e, nodes) in node.context[self.KEY].events.items() + } if finding_vars: result[finding_key] |= finding_vars return result - def _detect(self): + def _detect(self): # pylint: disable=too-many-branches """ """ super()._detect() @@ -85,7 +83,9 @@ If `d.()` re-enters, the `Counter` events will be shown in an incorrect order, w for (func, calls, send_eth), events in result_sorted: calls = sorted(list(set(calls)), key=lambda x: x[0].node_id) send_eth = sorted(list(set(send_eth)), key=lambda x: x[0].node_id) - events = sorted(events, key=lambda x: (str(x.variable.name), x.node.node_id)) + events = sorted( + events, key=lambda x: (str(x.variable.name), x.node.node_id) + ) info = ["Reentrancy in ", func, ":\n"] info += ["\tExternal calls:\n"] @@ -119,18 +119,24 @@ If `d.()` re-enters, the `Counter` events will be shown in an incorrect order, w res.add(call_info, {"underlying_type": "external_calls"}) for call_list_info in calls_list: if call_list_info != call_info: - res.add(call_list_info, {"underlying_type": "external_calls_sending_eth"}) + res.add( + call_list_info, + {"underlying_type": "external_calls_sending_eth"}, + ) # # If the calls are not the same ones that send eth, add the eth sending nodes. if calls != send_eth: for (call_info, calls_list) in send_eth: - res.add(call_info, {"underlying_type": "external_calls_sending_eth"}) + res.add( + call_info, {"underlying_type": "external_calls_sending_eth"} + ) for call_list_info in calls_list: if call_list_info != call_info: res.add( - call_list_info, {"underlying_type": "external_calls_sending_eth"} + call_list_info, + {"underlying_type": "external_calls_sending_eth"}, ) for finding_value in events: diff --git a/slither/detectors/reentrancy/reentrancy_no_gas.py b/slither/detectors/reentrancy/reentrancy_no_gas.py index 02dda3da2..d4b261f51 100644 --- a/slither/detectors/reentrancy/reentrancy_no_gas.py +++ b/slither/detectors/reentrancy/reentrancy_no_gas.py @@ -23,9 +23,7 @@ class ReentrancyNoGas(Reentrancy): IMPACT = DetectorClassification.INFORMATIONAL CONFIDENCE = DetectorClassification.MEDIUM - WIKI = ( - "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-4" - ) + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-4" WIKI_TITLE = "Reentrancy vulnerabilities" WIKI_DESCRIPTION = """ @@ -71,25 +69,27 @@ Only report reentrancy that is based on `transfer` or `send`.""" calls=to_hashable(node.context[self.KEY].calls), send_eth=to_hashable(node.context[self.KEY].send_eth), ) - finding_vars = set( - [ - FindingValue(v, node, tuple(sorted(nodes, key=lambda x: x.node_id))) - for (v, nodes) in node.context[self.KEY].written.items() - ] - ) - finding_vars |= set( - [ - FindingValue( - e, e.node, tuple(sorted(nodes, key=lambda x: x.node_id)) - ) - for (e, nodes) in node.context[self.KEY].events.items() - ] - ) + finding_vars = { + FindingValue( + v, + node, + tuple(sorted(nodes, key=lambda x: x.node_id)), + ) + for (v, nodes) in node.context[self.KEY].written.items() + } + finding_vars |= { + FindingValue( + e, + e.node, + tuple(sorted(nodes, key=lambda x: x.node_id)), + ) + for (e, nodes) in node.context[self.KEY].events.items() + } if finding_vars: result[finding_key] |= finding_vars return result - def _detect(self): + def _detect(self): # pylint: disable=too-many-branches,too-many-locals """ """ @@ -123,7 +123,9 @@ Only report reentrancy that is based on `transfer` or `send`.""" for (v, node, nodes) in varsWrittenOrEvent if isinstance(v, Variable) ] - varsWritten = sorted(varsWritten, key=lambda x: (x.variable.name, x.node.node_id)) + varsWritten = sorted( + varsWritten, key=lambda x: (x.variable.name, x.node.node_id) + ) if varsWritten: info += ["\tState variables written after the call(s):\n"] for finding_value in varsWritten: @@ -157,18 +159,24 @@ Only report reentrancy that is based on `transfer` or `send`.""" res.add(call_info, {"underlying_type": "external_calls"}) for call_list_info in calls_list: if call_list_info != call_info: - res.add(call_list_info, {"underlying_type": "external_calls_sending_eth"}) + res.add( + call_list_info, + {"underlying_type": "external_calls_sending_eth"}, + ) # # If the calls are not the same ones that send eth, add the eth sending nodes. if calls != send_eth: for (call_info, calls_list) in send_eth: - res.add(call_info, {"underlying_type": "external_calls_sending_eth"}) + res.add( + call_info, {"underlying_type": "external_calls_sending_eth"} + ) for call_list_info in calls_list: if call_list_info != call_info: res.add( - call_list_info, {"underlying_type": "external_calls_sending_eth"} + call_list_info, + {"underlying_type": "external_calls_sending_eth"}, ) # Add all variables written via nodes which write them. diff --git a/slither/detectors/reentrancy/reentrancy_read_before_write.py b/slither/detectors/reentrancy/reentrancy_read_before_write.py index e08dcd79e..41bc6cc7f 100644 --- a/slither/detectors/reentrancy/reentrancy_read_before_write.py +++ b/slither/detectors/reentrancy/reentrancy_read_before_write.py @@ -19,9 +19,7 @@ class ReentrancyReadBeforeWritten(Reentrancy): IMPACT = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM - WIKI = ( - "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-1" - ) + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#reentrancy-vulnerabilities-1" WIKI_TITLE = "Reentrancy vulnerabilities" WIKI_DESCRIPTION = """ @@ -45,26 +43,31 @@ Do not report reentrancies that involve Ether (see `reentrancy-eth`).""" def find_reentrancies(self): result = defaultdict(set) - for contract in self.contracts: + for contract in self.contracts: # pylint: disable=too-many-nested-blocks for f in contract.functions_and_modifiers_declared: for node in f.nodes: # dead code if self.KEY not in node.context: continue - if node.context[self.KEY].calls and not node.context[self.KEY].send_eth: + if ( + node.context[self.KEY].calls + and not node.context[self.KEY].send_eth + ): read_then_written = set() for c in node.context[self.KEY].calls: if c == node: continue - read_then_written |= set( - [ - FindingValue( - v, node, tuple(sorted(nodes, key=lambda x: x.node_id)) - ) - for (v, nodes) in node.context[self.KEY].written.items() - if v in node.context[self.KEY].reads_prior_calls[c] - ] - ) + read_then_written |= { + FindingValue( + v, + node, + tuple(sorted(nodes, key=lambda x: x.node_id)), + ) + for (v, nodes) in node.context[ + self.KEY + ].written.items() + if v in node.context[self.KEY].reads_prior_calls[c] + } # We found a potential re-entrancy bug if read_then_written: @@ -76,7 +79,7 @@ Do not report reentrancies that involve Ether (see `reentrancy-eth`).""" result[finding_key] |= read_then_written return result - def _detect(self): + def _detect(self): # pylint: disable=too-many-branches """ """ @@ -85,10 +88,14 @@ Do not report reentrancies that involve Ether (see `reentrancy-eth`).""" results = [] - result_sorted = sorted(list(reentrancies.items()), key=lambda x: x[0].function.name) + result_sorted = sorted( + list(reentrancies.items()), key=lambda x: x[0].function.name + ) for (func, calls), varsWritten in result_sorted: calls = sorted(list(set(calls)), key=lambda x: x[0].node_id) - varsWritten = sorted(varsWritten, key=lambda x: (x.variable.name, x.node.node_id)) + varsWritten = sorted( + varsWritten, key=lambda x: (x.variable.name, x.node.node_id) + ) info = ["Reentrancy in ", func, ":\n"] @@ -116,7 +123,10 @@ Do not report reentrancies that involve Ether (see `reentrancy-eth`).""" res.add(call_info, {"underlying_type": "external_calls"}) for call_list_info in calls_list: if call_list_info != call_info: - res.add(call_list_info, {"underlying_type": "external_calls_sending_eth"}) + res.add( + call_list_info, + {"underlying_type": "external_calls_sending_eth"}, + ) # Add all variables written via nodes which write them. for finding_value in varsWritten: diff --git a/slither/detectors/shadowing/abstract.py b/slither/detectors/shadowing/abstract.py index e8fc20e52..47a596e77 100644 --- a/slither/detectors/shadowing/abstract.py +++ b/slither/detectors/shadowing/abstract.py @@ -6,6 +6,20 @@ Recursively check the called functions from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +def detect_shadowing(contract): + ret = [] + variables_fathers = [] + for father in contract.inheritance: + if all(not f.is_implemented for f in father.functions + father.modifiers): + variables_fathers += father.state_variables_declared + + for var in contract.state_variables_declared: + shadow = [v for v in variables_fathers if v.name == var.name] + if shadow: + ret.append([var] + shadow) + return ret + + class ShadowingAbstractDetection(AbstractDetector): """ Shadowing detection @@ -34,19 +48,6 @@ contract DerivedContract is BaseContract{ WIKI_RECOMMENDATION = "Remove the state variable shadowing." - def detect_shadowing(self, contract): - ret = [] - variables_fathers = [] - for father in contract.inheritance: - if all(not f.is_implemented for f in father.functions + father.modifiers): - variables_fathers += father.state_variables_declared - - for var in contract.state_variables_declared: - shadow = [v for v in variables_fathers if v.name == var.name] - if shadow: - ret.append([var] + shadow) - return ret - def _detect(self): """ Detect shadowing @@ -56,8 +57,8 @@ contract DerivedContract is BaseContract{ """ results = [] - for c in self.contracts: - shadowing = self.detect_shadowing(c) + for contract in self.contracts: + shadowing = detect_shadowing(contract) if shadowing: for all_variables in shadowing: shadow = all_variables[0] diff --git a/slither/detectors/shadowing/builtin_symbols.py b/slither/detectors/shadowing/builtin_symbols.py index e6a06dc1e..924388d86 100644 --- a/slither/detectors/shadowing/builtin_symbols.py +++ b/slither/detectors/shadowing/builtin_symbols.py @@ -179,7 +179,10 @@ contract Bug { shadow_type = shadow[0] shadow_object = shadow[1] - info = [shadow_object, f' ({shadow_type}) shadows built-in symbol"\n'] + info = [ + shadow_object, + f' ({shadow_type}) shadows built-in symbol"\n', + ] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/shadowing/local.py b/slither/detectors/shadowing/local.py index 9323c4cf4..a71334ecd 100644 --- a/slither/detectors/shadowing/local.py +++ b/slither/detectors/shadowing/local.py @@ -47,7 +47,7 @@ contract Bug { OVERSHADOWED_STATE_VARIABLE = "state variable" OVERSHADOWED_EVENT = "event" - def detect_shadowing_definitions(self, contract): + def detect_shadowing_definitions(self, contract): # pylint: disable=too-many-branches """ Detects if functions, access modifiers, events, state variables, and local variables are named after reserved keywords. Any such definitions are returned in a list. @@ -68,11 +68,15 @@ contract Bug { # Check functions for scope_function in scope_contract.functions_declared: if variable.name == scope_function.name: - overshadowed.append((self.OVERSHADOWED_FUNCTION, scope_function)) + overshadowed.append( + (self.OVERSHADOWED_FUNCTION, scope_function) + ) # Check modifiers for scope_modifier in scope_contract.modifiers_declared: if variable.name == scope_modifier.name: - overshadowed.append((self.OVERSHADOWED_MODIFIER, scope_modifier)) + overshadowed.append( + (self.OVERSHADOWED_MODIFIER, scope_modifier) + ) # Check events for scope_event in scope_contract.events_declared: if variable.name == scope_event.name: @@ -108,7 +112,11 @@ contract Bug { overshadowed = shadow[1] info = [local_variable, " shadows:\n"] for overshadowed_entry in overshadowed: - info += ["\t- ", overshadowed_entry[1], f" ({overshadowed_entry[0]})\n"] + info += [ + "\t- ", + overshadowed_entry[1], + f" ({overshadowed_entry[0]})\n", + ] # Generate relevant JSON data for this shadowing definition. res = self.generate_result(info) diff --git a/slither/detectors/shadowing/state.py b/slither/detectors/shadowing/state.py index ea56169a1..4a664d0b6 100644 --- a/slither/detectors/shadowing/state.py +++ b/slither/detectors/shadowing/state.py @@ -5,6 +5,20 @@ Module detecting shadowing of state variables from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification +def detect_shadowing(contract): + ret = [] + variables_fathers = [] + for father in contract.inheritance: + if any(f.is_implemented for f in father.functions + father.modifiers): + variables_fathers += father.state_variables_declared + + for var in contract.state_variables_declared: + shadow = [v for v in variables_fathers if v.name == var.name] + if shadow: + ret.append([var] + shadow) + return ret + + class StateShadowing(AbstractDetector): """ Shadowing of state variable @@ -47,19 +61,6 @@ contract DerivedContract is BaseContract{ WIKI_RECOMMENDATION = "Remove the state variable shadowing." - def detect_shadowing(self, contract): - ret = [] - variables_fathers = [] - for father in contract.inheritance: - if any(f.is_implemented for f in father.functions + father.modifiers): - variables_fathers += father.state_variables_declared - - for var in contract.state_variables_declared: - shadow = [v for v in variables_fathers if v.name == var.name] - if shadow: - ret.append([var] + shadow) - return ret - def _detect(self): """ Detect shadowing @@ -70,7 +71,7 @@ contract DerivedContract is BaseContract{ """ results = [] for c in self.contracts: - shadowing = self.detect_shadowing(c) + shadowing = detect_shadowing(c) if shadowing: for all_variables in shadowing: shadow = all_variables[0] diff --git a/slither/detectors/slither/name_reused.py b/slither/detectors/slither/name_reused.py index 9f0b1c3c3..989261418 100644 --- a/slither/detectors/slither/name_reused.py +++ b/slither/detectors/slither/name_reused.py @@ -42,14 +42,16 @@ As a result, the second contract cannot be analyzed. """ WIKI_RECOMMENDATION = "Rename the contract." - def _detect(self): + def _detect(self): # pylint: disable=too-many-locals,too-many-branches results = [] names_reused = self.slither.contract_name_collisions # First show the contracts that we know are missing incorrectly_constructed = [ - contract for contract in self.contracts if contract.is_incorrectly_constructed + contract + for contract in self.contracts + if contract.is_incorrectly_constructed ] inheritance_corrupted = defaultdict(list) @@ -66,7 +68,9 @@ As a result, the second contract cannot be analyzed. info += ["\t- ", file, "\n"] if contract_name in inheritance_corrupted: - info += ["\tAs a result, the inherited contracts are not correctly analyzed:\n"] + info += [ + "\tAs a result, the inherited contracts are not correctly analyzed:\n" + ] for corrupted in inheritance_corrupted[contract_name]: info += ["\t\t- ", corrupted, "\n"] res = self.generate_result(info) @@ -79,14 +83,18 @@ As a result, the second contract cannot be analyzed. for b in most_base_with_missing_inheritance: info = [b, " inherits from a contract for which the name is reused.\n"] if b.inheritance: - info += ["\t- Slither could not determine which contract has a duplicate name:\n"] + info += [ + "\t- Slither could not determine which contract has a duplicate name:\n" + ] for inheritance in b.inheritance: info += ["\t\t-", inheritance, "\n"] info += ["\t- Check if:\n"] info += ["\t\t- A inherited contract is missing from this list,\n"] info += ["\t\t- The contract are imported from the correct files.\n"] if b.derived_contracts: - info += [f"\t- This issue impacts the contracts inheriting from {b.name}:\n"] + info += [ + f"\t- This issue impacts the contracts inheriting from {b.name}:\n" + ] for derived in b.derived_contracts: info += ["\t\t-", derived, "\n"] res = self.generate_result(info) diff --git a/slither/detectors/source/rtlo.py b/slither/detectors/source/rtlo.py index 56d66cfa4..03e640586 100644 --- a/slither/detectors/source/rtlo.py +++ b/slither/detectors/source/rtlo.py @@ -1,5 +1,5 @@ -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification import re +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification class RightToLeftOverride(AbstractDetector): @@ -65,25 +65,27 @@ contract Token # If we couldn't find the character in the remainder of source, stop. if result_index == -1: break - else: - # We found another instance of the character, define our output - idx = start_index + result_index - - relative = self.slither.crytic_compile.filename_lookup(filename).relative - info = f"{relative} contains a unicode right-to-left-override character at byte offset {idx}:\n" - - # We have a patch, so pattern.find will return at least one result - - info += f"\t- {pattern.findall(source_encoded)[0]}\n" - res = self.generate_result(info) - res.add_other( - "rtlo-character", - (filename, idx, len(self.RTLO_CHARACTER_ENCODED)), - self.slither, - ) - results.append(res) - - # Advance the start index for the next iteration - start_index = result_index + 1 + + # We found another instance of the character, define our output + idx = start_index + result_index + + relative = self.slither.crytic_compile.filename_lookup( + filename + ).relative + info = f"{relative} contains a unicode right-to-left-override character at byte offset {idx}:\n" + + # We have a patch, so pattern.find will return at least one result + + info += f"\t- {pattern.findall(source_encoded)[0]}\n" + res = self.generate_result(info) + res.add_other( + "rtlo-character", + (filename, idx, len(self.RTLO_CHARACTER_ENCODED)), + self.slither, + ) + results.append(res) + + # Advance the start index for the next iteration + start_index = result_index + 1 return results diff --git a/slither/detectors/statements/assembly.py b/slither/detectors/statements/assembly.py index a16634d80..0558f266d 100644 --- a/slither/detectors/statements/assembly.py +++ b/slither/detectors/statements/assembly.py @@ -16,7 +16,9 @@ class Assembly(AbstractDetector): IMPACT = DetectorClassification.INFORMATIONAL CONFIDENCE = DetectorClassification.HIGH - WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#assembly-usage" + WIKI = ( + "https://github.com/crytic/slither/wiki/Detector-Documentation#assembly-usage" + ) WIKI_TITLE = "Assembly usage" WIKI_DESCRIPTION = "The use of assembly is error-prone and should be avoided." diff --git a/slither/detectors/statements/boolean_constant_equality.py b/slither/detectors/statements/boolean_constant_equality.py index d82573518..fd5268ecc 100644 --- a/slither/detectors/statements/boolean_constant_equality.py +++ b/slither/detectors/statements/boolean_constant_equality.py @@ -3,7 +3,10 @@ Module detecting misuse of Boolean constants """ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.slithir.operations import Assignment, Call, Return, InitArray, Binary, BinaryType +from slither.slithir.operations import ( + Binary, + BinaryType, +) from slither.slithir.variables import Constant @@ -17,7 +20,9 @@ class BooleanEquality(AbstractDetector): IMPACT = DetectorClassification.INFORMATIONAL CONFIDENCE = DetectorClassification.HIGH - WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#boolean-equality" + WIKI = ( + "https://github.com/crytic/slither/wiki/Detector-Documentation#boolean-equality" + ) WIKI_TITLE = "Boolean equality" WIKI_DESCRIPTION = """Detects the comparison to boolean constants.""" @@ -44,7 +49,7 @@ Boolean constants can be used directly and do not need to be compare to `true` o results = [] # Loop for each function and modifier. - for function in contract.functions_and_modifiers_declared: + for function in contract.functions_and_modifiers_declared: # pylint: disable=too-many-nested-blocks f_results = set() # Loop for every node in this function, looking for boolean constants @@ -54,7 +59,7 @@ Boolean constants can be used directly and do not need to be compare to `true` o if ir.type in [BinaryType.EQUAL, BinaryType.NOT_EQUAL]: for r in ir.read: if isinstance(r, Constant): - if type(r.value) is bool: + if isinstance(r.value, bool): f_results.add(node) results.append((function, f_results)) @@ -71,7 +76,12 @@ Boolean constants can be used directly and do not need to be compare to `true` o if boolean_constant_misuses: for (func, nodes) in boolean_constant_misuses: for node in nodes: - info = [func, " compares to a boolean constant:\n\t-", node, "\n"] + info = [ + func, + " compares to a boolean constant:\n\t-", + node, + "\n", + ] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/boolean_constant_misuse.py b/slither/detectors/statements/boolean_constant_misuse.py index c4e574a00..110c0745b 100644 --- a/slither/detectors/statements/boolean_constant_misuse.py +++ b/slither/detectors/statements/boolean_constant_misuse.py @@ -26,9 +26,7 @@ class BooleanConstantMisuse(AbstractDetector): IMPACT = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM - WIKI = ( - "https://github.com/crytic/slither/wiki/Detector-Documentation#misuse-of-a-boolean-constant" - ) + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#misuse-of-a-boolean-constant" WIKI_TITLE = "Misuse of a Boolean constant" WIKI_DESCRIPTION = """Detects the misuse of a Boolean constant.""" @@ -56,7 +54,7 @@ Other uses (in complex expressions, as conditionals) indicate either an error or WIKI_RECOMMENDATION = """Verify and simplify the condition.""" @staticmethod - def _detect_boolean_constant_misuses(contract): + def _detect_boolean_constant_misuses(contract): # pylint: disable=too-many-branches """ Detects and returns all nodes which misuse a Boolean constant. :param contract: Contract to detect assignment within. @@ -67,7 +65,7 @@ Other uses (in complex expressions, as conditionals) indicate either an error or results = [] # Loop for each function and modifier. - for function in contract.functions_declared: + for function in contract.functions_declared: # pylint: disable=too-many-nested-blocks f_results = set() # Loop for every node in this function, looking for boolean constants @@ -88,13 +86,17 @@ Other uses (in complex expressions, as conditionals) indicate either an error or # It's ok to use a bare boolean constant in these contexts continue if isinstance(ir, Binary): - if ir.type in [BinaryType.ADDITION, BinaryType.EQUAL, BinaryType.NOT_EQUAL]: + if ir.type in [ + BinaryType.ADDITION, + BinaryType.EQUAL, + BinaryType.NOT_EQUAL, + ]: # Comparing to a Boolean constant is dubious style, but harmless # Equal is catch by another detector (informational severity) continue for r in ir.read: if isinstance(r, Constant): - if type(r.value) is bool: + if isinstance(r.value, bool): f_results.add(node) results.append((function, f_results)) @@ -111,7 +113,12 @@ Other uses (in complex expressions, as conditionals) indicate either an error or if boolean_constant_misuses: for (func, nodes) in boolean_constant_misuses: for node in nodes: - info = [func, " uses a Boolean constant improperly:\n\t-", node, "\n"] + info = [ + func, + " uses a Boolean constant improperly:\n\t-", + node, + "\n", + ] res = self.generate_result(info) results.append(res) diff --git a/slither/detectors/statements/calls_in_loop.py b/slither/detectors/statements/calls_in_loop.py index bcc8e15f5..3dc2f3219 100644 --- a/slither/detectors/statements/calls_in_loop.py +++ b/slither/detectors/statements/calls_in_loop.py @@ -1,13 +1,15 @@ -""" -""" from slither.core.cfg.node import NodeType from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.slithir.operations import HighLevelCall, LibraryCall, LowLevelCall, Send, Transfer +from slither.slithir.operations import ( + HighLevelCall, + LibraryCall, + LowLevelCall, + Send, + Transfer, +) class MultipleCallsInLoop(AbstractDetector): - """ - """ ARGUMENT = "calls-loop" HELP = "Multiple calls in a loop" diff --git a/slither/detectors/statements/controlled_delegatecall.py b/slither/detectors/statements/controlled_delegatecall.py index 784301c8c..9c1a1f6f1 100644 --- a/slither/detectors/statements/controlled_delegatecall.py +++ b/slither/detectors/statements/controlled_delegatecall.py @@ -3,9 +3,20 @@ from slither.slithir.operations import LowLevelCall from slither.analyses.data_dependency.data_dependency import is_tainted +def controlled_delegatecall(function): + ret = [] + for node in function.nodes: + for ir in node.irs: + if isinstance(ir, LowLevelCall) and ir.function_name in [ + "delegatecall", + "callcode", + ]: + if is_tainted(ir.destination, function.contract): + ret.append(node) + return ret + + class ControlledDelegateCall(AbstractDetector): - """ - """ ARGUMENT = "controlled-delegatecall" HELP = "Controlled delegatecall destination" @@ -15,7 +26,9 @@ class ControlledDelegateCall(AbstractDetector): WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#controlled-delegatecall" WIKI_TITLE = "Controlled Delegatecall" - WIKI_DESCRIPTION = "`Delegatecall` or `callcode` to an address controlled by the user." + WIKI_DESCRIPTION = ( + "`Delegatecall` or `callcode` to an address controlled by the user." + ) WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Delegatecall{ @@ -28,18 +41,6 @@ Bob calls `delegate` and delegates the execution to his malicious contract. As a WIKI_RECOMMENDATION = "Avoid using `delegatecall`. Use only trusted destinations." - def controlled_delegatecall(self, function): - ret = [] - for node in function.nodes: - for ir in node.irs: - if isinstance(ir, LowLevelCall) and ir.function_name in [ - "delegatecall", - "callcode", - ]: - if is_tainted(ir.destination, function.contract): - ret.append(node) - return ret - def _detect(self): results = [] @@ -49,9 +50,12 @@ Bob calls `delegate` and delegates the execution to his malicious contract. As a # As functions to upgrades the destination lead to too many FPs if contract.is_upgradeable_proxy and f.is_protected(): continue - nodes = self.controlled_delegatecall(f) + nodes = controlled_delegatecall(f) if nodes: - func_info = [f, " uses delegatecall to a input-controlled function id\n"] + func_info = [ + f, + " uses delegatecall to a input-controlled function id\n", + ] for node in nodes: node_info = func_info + ["\t- ", node, "\n"] diff --git a/slither/detectors/statements/deprecated_calls.py b/slither/detectors/statements/deprecated_calls.py index a12d3403d..f1c687d68 100644 --- a/slither/detectors/statements/deprecated_calls.py +++ b/slither/detectors/statements/deprecated_calls.py @@ -2,12 +2,15 @@ Module detecting deprecated standards. """ -from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.visitors.expression.export_values import ExportValues -from slither.core.declarations.solidity_variables import SolidityVariableComposed, SolidityFunction from slither.core.cfg.node import NodeType +from slither.core.declarations.solidity_variables import ( + SolidityVariableComposed, + SolidityFunction, +) +from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification from slither.slithir.operations import LowLevelCall -from slither.solc_parsing.variables.state_variable import StateVariableSolc, StateVariable +from slither.visitors.expression.export_values import ExportValues + # Reference: https://smartcontractsecurity.github.io/SWC-registry/docs/SWC-111 class DeprecatedStandards(AbstractDetector): @@ -124,7 +127,7 @@ contract ContractWithDeprecatedReferences { results.append((state_variable, deprecated_results)) # Loop through all functions + modifiers in this contract. - for function in contract.functions_and_modifiers_declared: + for function in contract.functions_and_modifiers_declared: # pylint: disable=too-many-nested-blocks # Loop through each node in this function. for node in function.nodes: # Detect deprecated references in the node. @@ -153,14 +156,16 @@ contract ContractWithDeprecatedReferences { """ results = [] for contract in self.contracts: - deprecated_references = self.detect_deprecated_references_in_contract(contract) + deprecated_references = self.detect_deprecated_references_in_contract( + contract + ) if deprecated_references: for deprecated_reference in deprecated_references: source_object = deprecated_reference[0] deprecated_entries = deprecated_reference[1] info = ["Deprecated standard detected ", source_object, ":\n"] - for (dep_id, original_desc, recommended_disc) in deprecated_entries: + for (_dep_id, original_desc, recommended_disc) in deprecated_entries: info += [ f'\t- Usage of "{original_desc}" should be replaced with "{recommended_disc}"\n' ] diff --git a/slither/detectors/statements/divide_before_multiply.py b/slither/detectors/statements/divide_before_multiply.py index 25ced7972..ebe99dabf 100644 --- a/slither/detectors/statements/divide_before_multiply.py +++ b/slither/detectors/statements/divide_before_multiply.py @@ -79,7 +79,7 @@ In general, it's usually a good idea to re-arrange arithmetic to perform multipl WIKI_RECOMMENDATION = """Consider ordering multiplication before division.""" - def _explore(self, node, explored, f_results, divisions): + def _explore(self, node, explored, f_results, divisions): # pylint: disable=too-many-branches if node in explored: return explored.add(node) @@ -111,7 +111,9 @@ In general, it's usually a good idea to re-arrange arithmetic to perform multipl if node in divisions[r]: nodes += [n for n in divisions[r] if n not in nodes] else: - nodes += [n for n in divisions[r] + [node] if n not in nodes] + nodes += [ + n for n in divisions[r] + [node] if n not in nodes + ] if nodes: node_results = nodes @@ -170,11 +172,16 @@ In general, it's usually a good idea to re-arrange arithmetic to perform multipl """ results = [] for contract in self.contracts: - divisions_before_multiplications = self.detect_divide_before_multiply(contract) + divisions_before_multiplications = self.detect_divide_before_multiply( + contract + ) if divisions_before_multiplications: for (func, nodes) in divisions_before_multiplications: - info = [func, " performs a multiplication on the result of a division:\n"] + info = [ + func, + " performs a multiplication on the result of a division:\n", + ] for node in nodes: info += ["\t-", node, "\n"] diff --git a/slither/detectors/statements/incorrect_strict_equality.py b/slither/detectors/statements/incorrect_strict_equality.py index d4300534f..a327ffa00 100644 --- a/slither/detectors/statements/incorrect_strict_equality.py +++ b/slither/detectors/statements/incorrect_strict_equality.py @@ -6,12 +6,21 @@ from slither.analyses.data_dependency.data_dependency import is_dependent_ssa from slither.core.declarations import Function from slither.detectors.abstract_detector import AbstractDetector, DetectorClassification -from slither.slithir.operations import Assignment, Balance, Binary, BinaryType, HighLevelCall +from slither.slithir.operations import ( + Assignment, + Balance, + Binary, + BinaryType, + HighLevelCall, +) from slither.core.solidity_types import MappingType, ElementaryType from slither.core.variables.state_variable import StateVariable -from slither.core.declarations.solidity_variables import SolidityVariable, SolidityVariableComposed +from slither.core.declarations.solidity_variables import ( + SolidityVariable, + SolidityVariableComposed, +) class IncorrectStrictEquality(AbstractDetector): @@ -20,12 +29,12 @@ class IncorrectStrictEquality(AbstractDetector): IMPACT = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.HIGH - WIKI = ( - "https://github.com/crytic/slither/wiki/Detector-Documentation#dangerous-strict-equalities" - ) + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#dangerous-strict-equalities" WIKI_TITLE = "Dangerous strict equalities" - WIKI_DESCRIPTION = "Use of strict equalities that can be easily manipulated by an attacker." + WIKI_DESCRIPTION = ( + "Use of strict equalities that can be easily manipulated by an attacker." + ) WIKI_EXPLOIT_SCENARIO = """ ```solidity contract Crowdsale{ @@ -36,9 +45,7 @@ contract Crowdsale{ `Crowdsale` relies on `fund_reached` to know when to stop the sale of tokens. `Crowdsale` reaches 100 Ether. Bob sends 0.1 Ether. As a result, `fund_reached` is always false and the `crowdsale` never ends.""" - WIKI_RECOMMENDATION = ( - """Don't use strict equality to determine if an account has enough Ether or tokens.""" - ) + WIKI_RECOMMENDATION = """Don't use strict equality to determine if an account has enough Ether or tokens.""" sources_taint = [ SolidityVariable("now"), @@ -98,7 +105,9 @@ contract Crowdsale{ for ir in node.irs_ssa: # Filter to only tainted equality (==) comparisons - if self.is_direct_comparison(ir) and self.is_any_tainted(ir.used, taints, func): + if self.is_direct_comparison(ir) and self.is_any_tainted( + ir.used, taints, func + ): if func not in results: results[func] = [] results[func].append(node) @@ -133,7 +142,7 @@ contract Crowdsale{ # Output each node with the function info header as a separate result. for node in nodes: - node_info = func_info + [f"\t- ", node, "\n"] + node_info = func_info + ["\t- ", node, "\n"] res = self.generate_result(node_info) results.append(res) diff --git a/slither/detectors/statements/too_many_digits.py b/slither/detectors/statements/too_many_digits.py index ec0d557dc..de39781e0 100644 --- a/slither/detectors/statements/too_many_digits.py +++ b/slither/detectors/statements/too_many_digits.py @@ -16,7 +16,9 @@ class TooManyDigits(AbstractDetector): IMPACT = DetectorClassification.INFORMATIONAL CONFIDENCE = DetectorClassification.MEDIUM - WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#too-many-digits" + WIKI = ( + "https://github.com/crytic/slither/wiki/Detector-Documentation#too-many-digits" + ) WIKI_TITLE = "Too many digits" WIKI_DESCRIPTION = """ Literals with many digits are difficult to read and review. diff --git a/slither/detectors/statements/tx_origin.py b/slither/detectors/statements/tx_origin.py index 88adab962..c67b08bb5 100644 --- a/slither/detectors/statements/tx_origin.py +++ b/slither/detectors/statements/tx_origin.py @@ -15,9 +15,7 @@ class TxOrigin(AbstractDetector): IMPACT = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.MEDIUM - WIKI = ( - "https://github.com/crytic/slither/wiki/Detector-Documentation#dangerous-usage-of-txorigin" - ) + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#dangerous-usage-of-txorigin" WIKI_TITLE = "Dangerous usage of `tx.origin`" WIKI_DESCRIPTION = "`tx.origin`-based protection can be abused by a malicious contract if a legitimate user interacts with the malicious contract." diff --git a/slither/detectors/statements/type_based_tautology.py b/slither/detectors/statements/type_based_tautology.py index f3bd276db..eea70e844 100644 --- a/slither/detectors/statements/type_based_tautology.py +++ b/slither/detectors/statements/type_based_tautology.py @@ -8,6 +8,56 @@ from slither.slithir.variables import Constant from slither.core.solidity_types.elementary_type import Int, Uint +def typeRange(t): + bits = int(t.split("int")[1]) + if t in Uint: + return 0, (2 ** bits) - 1 + if t in Int: + v = (2 ** (bits - 1)) - 1 + return -v, v + return None + + +def _detect_tautology_or_contradiction(low, high, cval, op): + """ + Return true if "[low high] op cval " is always true or always false + :param low: + :param high: + :param cval: + :param op: + :return: + """ + if op == BinaryType.LESS: + # a < cval + # its a tautology if + # high(a) < cval + # its a contradiction if + # low(a) >= cval + return high < cval or low >= cval + if op == BinaryType.GREATER: + # a > cval + # its a tautology if + # low(a) > cval + # its a contradiction if + # high(a) <= cval + return low > cval or high <= cval + if op == BinaryType.LESS_EQUAL: + # a <= cval + # its a tautology if + # high(a) <= cval + # its a contradiction if + # low(a) > cval + return (high <= cval) or (low > cval) + if op == BinaryType.GREATER_EQUAL: + # a >= cval + # its a tautology if + # low(a) >= cval + # its a contradiction if + # high(a) < cval + return (low >= cval) or (high < cval) + return False + + class TypeBasedTautology(AbstractDetector): """ Type-based tautology or contradiction @@ -18,9 +68,7 @@ class TypeBasedTautology(AbstractDetector): IMPACT = DetectorClassification.MEDIUM CONFIDENCE = DetectorClassification.HIGH - WIKI = ( - "https://github.com/crytic/slither/wiki/Detector-Documentation#tautology-or-contradiction" - ) + WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#tautology-or-contradiction" WIKI_TITLE = "Tautology or contradiction" WIKI_DESCRIPTION = """Detects expressions that are tautologies or contradictions.""" @@ -50,14 +98,6 @@ contract A { """Fix the incorrect comparison by changing the value type or the comparison.""" ) - def typeRange(self, t): - bits = int(t.split("int")[1]) - if t in Uint: - return (0, (2 ** bits) - 1) - if t in Int: - v = (2 ** (bits - 1)) - 1 - return (-v, v) - flip_table = { BinaryType.GREATER: BinaryType.LESS, BinaryType.GREATER_EQUAL: BinaryType.LESS_EQUAL, @@ -65,45 +105,6 @@ contract A { BinaryType.LESS_EQUAL: BinaryType.GREATER_EQUAL, } - def _detect_tautology_or_contradiction(self, low, high, cval, op): - """ - Return true if "[low high] op cval " is always true or always false - :param low: - :param high: - :param cval: - :param op: - :return: - """ - if op == BinaryType.LESS: - # a < cval - # its a tautology if - # high(a) < cval - # its a contradiction if - # low(a) >= cval - return high < cval or low >= cval - elif op == BinaryType.GREATER: - # a > cval - # its a tautology if - # low(a) > cval - # its a contradiction if - # high(a) <= cval - return low > cval or high <= cval - elif op == BinaryType.LESS_EQUAL: - # a <= cval - # its a tautology if - # high(a) <= cval - # its a contradiction if - # low(a) > cval - return (high <= cval) or (low > cval) - elif op == BinaryType.GREATER_EQUAL: - # a >= cval - # its a tautology if - # low(a) >= cval - # its a contradiction if - # high(a) < cval - return (low >= cval) or (high < cval) - return False - def detect_type_based_tautologies(self, contract): """ Detects and returns all nodes with tautology/contradiction comparisons (based on type alone). @@ -116,7 +117,7 @@ contract A { allInts = Int + Uint # Loop for each function and modifier. - for function in contract.functions_declared: + for function in contract.functions_declared: # pylint: disable=too-many-nested-blocks f_results = set() for node in function.nodes: @@ -127,8 +128,8 @@ contract A { cval = ir.variable_left.value rtype = str(ir.variable_right.type) if rtype in allInts: - (low, high) = self.typeRange(rtype) - if self._detect_tautology_or_contradiction( + (low, high) = typeRange(rtype) + if _detect_tautology_or_contradiction( low, high, cval, self.flip_table[ir.type] ): f_results.add(node) @@ -137,8 +138,8 @@ contract A { cval = ir.variable_right.value ltype = str(ir.variable_left.type) if ltype in allInts: - (low, high) = self.typeRange(ltype) - if self._detect_tautology_or_contradiction( + (low, high) = typeRange(ltype) + if _detect_tautology_or_contradiction( low, high, cval, ir.type ): f_results.add(node) diff --git a/slither/detectors/variables/possible_const_state_variables.py b/slither/detectors/variables/possible_const_state_variables.py index 39930ed28..a14db7c3c 100644 --- a/slither/detectors/variables/possible_const_state_variables.py +++ b/slither/detectors/variables/possible_const_state_variables.py @@ -7,7 +7,7 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi from slither.visitors.expression.export_values import ExportValues from slither.core.declarations.solidity_variables import SolidityFunction from slither.core.variables.state_variable import StateVariable -from slither.formatters.variables.possible_const_state_variables import format +from slither.formatters.variables.possible_const_state_variables import custom_format class ConstCandidateStateVars(AbstractDetector): @@ -26,8 +26,12 @@ class ConstCandidateStateVars(AbstractDetector): WIKI = "https://github.com/crytic/slither/wiki/Detector-Documentation#state-variables-that-could-be-declared-constant" WIKI_TITLE = "State variables that could be declared constant" - WIKI_DESCRIPTION = "Constant state variables should be declared constant to save gas." - WIKI_RECOMMENDATION = "Add the `constant` attributes to state variables that never change." + WIKI_DESCRIPTION = ( + "Constant state variables should be declared constant to save gas." + ) + WIKI_RECOMMENDATION = ( + "Add the `constant` attributes to state variables that never change." + ) @staticmethod def _valid_candidate(v): @@ -61,7 +65,10 @@ class ConstCandidateStateVars(AbstractDetector): if not values: return True if all( - (val in self.valid_solidity_function or self._is_constant_var(val) for val in values) + ( + val in self.valid_solidity_function or self._is_constant_var(val) + for val in values + ) ): return True return False @@ -72,18 +79,18 @@ class ConstCandidateStateVars(AbstractDetector): results = [] all_variables = [c.state_variables for c in self.slither.contracts] - all_variables = set([item for sublist in all_variables for item in sublist]) - all_non_constant_elementary_variables = set( - [v for v in all_variables if self._valid_candidate(v)] - ) + all_variables = {item for sublist in all_variables for item in sublist} + all_non_constant_elementary_variables = {v for v in all_variables if self._valid_candidate(v)} all_functions = [c.all_functions_called for c in self.slither.contracts] - all_functions = list(set([item for sublist in all_functions for item in sublist])) + all_functions = list({item for sublist in all_functions for item in sublist}) all_variables_written = [ - f.state_variables_written for f in all_functions if not f.is_constructor_variables + f.state_variables_written + for f in all_functions + if not f.is_constructor_variables ] - all_variables_written = set([item for sublist in all_variables_written for item in sublist]) + all_variables_written = {item for sublist in all_variables_written for item in sublist} constable_variables = [ v @@ -91,7 +98,9 @@ class ConstCandidateStateVars(AbstractDetector): if (not v in all_variables_written) and self._constant_initial_expression(v) ] # Order for deterministic results - constable_variables = sorted(constable_variables, key=lambda x: x.canonical_name) + constable_variables = sorted( + constable_variables, key=lambda x: x.canonical_name + ) # Create a result for each finding for v in constable_variables: @@ -103,4 +112,4 @@ class ConstCandidateStateVars(AbstractDetector): @staticmethod def _format(slither, result): - format(slither, result) + custom_format(slither, result) diff --git a/slither/detectors/variables/uninitialized_local_variables.py b/slither/detectors/variables/uninitialized_local_variables.py index bb663ba7b..4b5cfca45 100644 --- a/slither/detectors/variables/uninitialized_local_variables.py +++ b/slither/detectors/variables/uninitialized_local_variables.py @@ -9,8 +9,6 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi class UninitializedLocalVars(AbstractDetector): - """ - """ ARGUMENT = "uninitialized-local" HELP = "Uninitialized local variables" @@ -55,7 +53,9 @@ Bob calls `transfer`. As a result, all Ether is sent to the address `0x0` and is else: self.visited_all_paths[node] = [] - self.visited_all_paths[node] = list(set(self.visited_all_paths[node] + fathers_context)) + self.visited_all_paths[node] = list( + set(self.visited_all_paths[node] + fathers_context) + ) if self.key in node.context: fathers_context += node.context[self.key] @@ -66,7 +66,9 @@ Bob calls `transfer`. As a result, all Ether is sent to the address `0x0` and is self.results.append((function, uninitialized_local_variable)) # Only save the local variables that are not yet written - uninitialized_local_variables = list(set(fathers_context) - set(node.variables_written)) + uninitialized_local_variables = list( + set(fathers_context) - set(node.variables_written) + ) node.context[self.key] = uninitialized_local_variables for son in node.sons: @@ -81,6 +83,7 @@ Bob calls `transfer`. As a result, all Ether is sent to the address `0x0` and is """ results = [] + # pylint: disable=attribute-defined-outside-init self.results = [] self.visited_all_paths = {} @@ -91,14 +94,21 @@ Bob calls `transfer`. As a result, all Ether is sent to the address `0x0` and is continue # dont consider storage variable, as they are detected by another detector uninitialized_local_variables = [ - v for v in function.local_variables if not v.is_storage and v.uninitialized + v + for v in function.local_variables + if not v.is_storage and v.uninitialized ] - function.entry_point.context[self.key] = uninitialized_local_variables + function.entry_point.context[ + self.key + ] = uninitialized_local_variables self._detect_uninitialized(function, function.entry_point, []) all_results = list(set(self.results)) for (function, uninitialized_local_variable) in all_results: - info = [uninitialized_local_variable, " is a local variable never initialized\n"] + info = [ + uninitialized_local_variable, + " is a local variable never initialized\n", + ] json = self.generate_result(info) results.append(json) diff --git a/slither/detectors/variables/uninitialized_state_variables.py b/slither/detectors/variables/uninitialized_state_variables.py index 97b12e6be..8c5ddef9f 100644 --- a/slither/detectors/variables/uninitialized_state_variables.py +++ b/slither/detectors/variables/uninitialized_state_variables.py @@ -47,11 +47,12 @@ Initialize all the variables. If a variable is meant to be initialized to zero, @staticmethod def _written_variables(contract): ret = [] + # pylint: disable=too-many-nested-blocks for f in contract.all_functions_called + contract.modifiers: for n in f.nodes: ret += n.state_variables_written for ir in n.irs: - if isinstance(ir, LibraryCall) or isinstance(ir, InternalCall): + if isinstance(ir, (LibraryCall, InternalCall)): idx = 0 if ir.function: for param in ir.function.parameters: @@ -69,6 +70,7 @@ Initialize all the variables. If a variable is meant to be initialized to zero, def _variable_written_in_proxy(self): # Hack to memoize without having it define in the init if hasattr(self, "__variables_written_in_proxy"): + # pylint: disable=access-member-before-definition return self.__variables_written_in_proxy variables_written_in_proxy = [] @@ -76,7 +78,8 @@ Initialize all the variables. If a variable is meant to be initialized to zero, if c.is_upgradeable_proxy: variables_written_in_proxy += self._written_variables(c) - self.__variables_written_in_proxy = list(set([v.name for v in variables_written_in_proxy])) + # pylint: disable=attribute-defined-outside-init + self.__variables_written_in_proxy = list({v.name for v in variables_written_in_proxy}) return self.__variables_written_in_proxy def _written_variables_in_proxy(self, contract): diff --git a/slither/detectors/variables/uninitialized_storage_variables.py b/slither/detectors/variables/uninitialized_storage_variables.py index d4cd82081..8fe0633ae 100644 --- a/slither/detectors/variables/uninitialized_storage_variables.py +++ b/slither/detectors/variables/uninitialized_storage_variables.py @@ -9,8 +9,6 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi class UninitializedStorageVars(AbstractDetector): - """ - """ ARGUMENT = "uninitialized-storage" HELP = "Uninitialized storage variables" @@ -63,7 +61,9 @@ Bob calls `func`. As a result, `owner` is overridden to `0`. else: self.visited_all_paths[node] = [] - self.visited_all_paths[node] = list(set(self.visited_all_paths[node] + fathers_context)) + self.visited_all_paths[node] = list( + set(self.visited_all_paths[node] + fathers_context) + ) if self.key in node.context: fathers_context += node.context[self.key] @@ -74,7 +74,9 @@ Bob calls `func`. As a result, `owner` is overridden to `0`. self.results.append((function, uninitialized_storage_variable)) # Only save the storage variables that are not yet written - uninitialized_storage_variables = list(set(fathers_context) - set(node.variables_written)) + uninitialized_storage_variables = list( + set(fathers_context) - set(node.variables_written) + ) node.context[self.key] = uninitialized_storage_variables for son in node.sons: @@ -89,6 +91,7 @@ Bob calls `func`. As a result, `owner` is overridden to `0`. """ results = [] + # pylint: disable=attribute-defined-outside-init self.results = [] self.visited_all_paths = {} @@ -96,13 +99,20 @@ Bob calls `func`. As a result, `owner` is overridden to `0`. for function in contract.functions: if function.is_implemented: uninitialized_storage_variables = [ - v for v in function.local_variables if v.is_storage and v.uninitialized + v + for v in function.local_variables + if v.is_storage and v.uninitialized ] - function.entry_point.context[self.key] = uninitialized_storage_variables + function.entry_point.context[ + self.key + ] = uninitialized_storage_variables self._detect_uninitialized(function, function.entry_point, []) for (function, uninitialized_storage_variable) in self.results: - info = [uninitialized_storage_variable, " is a storage variable never initialized\n"] + info = [ + uninitialized_storage_variable, + " is a storage variable never initialized\n", + ] json = self.generate_result(info) results.append(json) diff --git a/slither/detectors/variables/unused_state_variables.py b/slither/detectors/variables/unused_state_variables.py index 17c2e36dd..ac1410358 100644 --- a/slither/detectors/variables/unused_state_variables.py +++ b/slither/detectors/variables/unused_state_variables.py @@ -6,7 +6,45 @@ from slither.detectors.abstract_detector import AbstractDetector, DetectorClassi from slither.core.solidity_types import ArrayType from slither.visitors.expression.export_values import ExportValues from slither.core.variables.state_variable import StateVariable -from slither.formatters.variables.unused_state_variables import format +from slither.formatters.variables.unused_state_variables import custom_format + + +def detect_unused(contract): + if contract.is_signature_only(): + return None + # Get all the variables read in all the functions and modifiers + + all_functions = contract.all_functions_called + contract.modifiers + variables_used = [x.state_variables_read for x in all_functions] + variables_used += [ + x.state_variables_written + for x in all_functions + if not x.is_constructor_variables + ] + + array_candidates = [x.variables for x in all_functions] + array_candidates = [ + i for sl in array_candidates for i in sl + ] + contract.state_variables + array_candidates = [ + x.type.length + for x in array_candidates + if isinstance(x.type, ArrayType) and x.type.length + ] + array_candidates = [ExportValues(x).result() for x in array_candidates] + array_candidates = [i for sl in array_candidates for i in sl] + array_candidates = [v for v in array_candidates if isinstance(v, StateVariable)] + + # Flat list + variables_used = [item for sublist in variables_used for item in sublist] + variables_used = list(set(variables_used + array_candidates)) + + # Return the variables unused that are not public + return [ + x + for x in contract.variables + if x not in variables_used and x.visibility != "public" + ] class UnusedStateVars(AbstractDetector): @@ -26,43 +64,12 @@ class UnusedStateVars(AbstractDetector): WIKI_EXPLOIT_SCENARIO = "" WIKI_RECOMMENDATION = "Remove unused state variables." - def detect_unused(self, contract): - if contract.is_signature_only(): - return None - # Get all the variables read in all the functions and modifiers - - all_functions = contract.all_functions_called + contract.modifiers - variables_used = [x.state_variables_read for x in all_functions] - variables_used += [ - x.state_variables_written for x in all_functions if not x.is_constructor_variables - ] - - array_candidates = [x.variables for x in all_functions] - array_candidates = [i for sl in array_candidates for i in sl] + contract.state_variables - array_candidates = [ - x.type.length - for x in array_candidates - if isinstance(x.type, ArrayType) and x.type.length - ] - array_candidates = [ExportValues(x).result() for x in array_candidates] - array_candidates = [i for sl in array_candidates for i in sl] - array_candidates = [v for v in array_candidates if isinstance(v, StateVariable)] - - # Flat list - variables_used = [item for sublist in variables_used for item in sublist] - variables_used = list(set(variables_used + array_candidates)) - - # Return the variables unused that are not public - return [ - x for x in contract.variables if x not in variables_used and x.visibility != "public" - ] - def _detect(self): """ Detect unused state variables """ results = [] for c in self.slither.contracts_derived: - unusedVars = self.detect_unused(c) + unusedVars = detect_unused(c) if unusedVars: for var in unusedVars: info = [var, " is never used in ", c, "\n"] @@ -73,4 +80,4 @@ class UnusedStateVars(AbstractDetector): @staticmethod def _format(slither, result): - format(slither, result) + custom_format(slither, result) diff --git a/slither/formatters/attributes/const_functions.py b/slither/formatters/attributes/const_functions.py index baa5d30d9..282e9077c 100644 --- a/slither/formatters/attributes/const_functions.py +++ b/slither/formatters/attributes/const_functions.py @@ -3,7 +3,7 @@ from slither.formatters.exceptions import FormatError from slither.formatters.utils.patches import create_patch -def format(slither, result): +def custom_format(slither, result): elements = result["elements"] for element in elements: if element["type"] != "function": diff --git a/slither/formatters/attributes/constant_pragma.py b/slither/formatters/attributes/constant_pragma.py index 94d0d2663..9f311cb98 100644 --- a/slither/formatters/attributes/constant_pragma.py +++ b/slither/formatters/attributes/constant_pragma.py @@ -5,6 +5,8 @@ from slither.formatters.utils.patches import create_patch # Indicates the recommended versions for replacement REPLACEMENT_VERSIONS = ["^0.4.25", "^0.5.3"] +# pylint: disable=anomalous-backslash-in-string + # group: # 0: ^ > >= < <= (optional) # 1: ' ' (optional) @@ -14,7 +16,7 @@ REPLACEMENT_VERSIONS = ["^0.4.25", "^0.5.3"] PATTERN = re.compile("(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)") -def format(slither, result): +def custom_format(slither, result): elements = result["elements"] versions_used = [] for element in elements: @@ -35,10 +37,11 @@ def _analyse_versions(used_solc_versions): replace_solc_versions = list() for version in used_solc_versions: replace_solc_versions.append(_determine_solc_version_replacement(version)) - if not all(version == replace_solc_versions[0] for version in replace_solc_versions): + if not all( + version == replace_solc_versions[0] for version in replace_solc_versions + ): raise FormatImpossible("Multiple incompatible versions!") - else: - return replace_solc_versions[0] + return replace_solc_versions[0] def _determine_solc_version_replacement(used_solc_version): @@ -48,24 +51,28 @@ def _determine_solc_version_replacement(used_solc_version): minor_version = ".".join(version[2:])[2] if minor_version == "4": return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ";" - elif minor_version == "5": + if minor_version == "5": return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ";" - else: - raise FormatImpossible("Unknown version!") - elif len(versions) == 2: + raise FormatImpossible("Unknown version!") + if len(versions) == 2: version_right = versions[1] minor_version_right = ".".join(version_right[2:])[2] if minor_version_right == "4": # Replace with 0.4.25 return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ";" - elif minor_version_right in ["5", "6"]: + if minor_version_right in ["5", "6"]: # Replace with 0.5.3 return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ";" + raise FormatImpossible("Unknown version!") - -def _patch(slither, result, in_file, pragma, modify_loc_start, modify_loc_end): +def _patch(slither, result, in_file, pragma, modify_loc_start, modify_loc_end): # pylint: disable=too-many-arguments in_file_str = slither.source_code[in_file].encode("utf8") old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end] create_patch( - result, in_file, int(modify_loc_start), int(modify_loc_end), old_str_of_interest, pragma + result, + in_file, + int(modify_loc_start), + int(modify_loc_end), + old_str_of_interest, + pragma, ) diff --git a/slither/formatters/attributes/incorrect_solc.py b/slither/formatters/attributes/incorrect_solc.py index ccec963ff..6d0727701 100644 --- a/slither/formatters/attributes/incorrect_solc.py +++ b/slither/formatters/attributes/incorrect_solc.py @@ -12,10 +12,12 @@ REPLACEMENT_VERSIONS = ["^0.4.25", "^0.5.3"] # 2: version number # 3: version number # 4: version number + +# pylint: disable=anomalous-backslash-in-string PATTERN = re.compile("(\^|>|>=|<|<=)?([ ]+)?(\d+)\.(\d+)\.(\d+)") -def format(slither, result): +def custom_format(slither, result): elements = result["elements"] for element in elements: solc_version_replace = _determine_solc_version_replacement( @@ -40,22 +42,22 @@ def _determine_solc_version_replacement(used_solc_version): if minor_version == "4": # Replace with 0.4.25 return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ";" - elif minor_version == "5": + if minor_version == "5": # Replace with 0.5.3 return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ";" - else: - raise FormatImpossible(f"Unknown version {versions}") - elif len(versions) == 2: + raise FormatImpossible(f"Unknown version {versions}") + if len(versions) == 2: version_right = versions[1] minor_version_right = ".".join(version_right[2:])[2] if minor_version_right == "4": # Replace with 0.4.25 return "pragma solidity " + REPLACEMENT_VERSIONS[0] + ";" - elif minor_version_right in ["5", "6"]: + if minor_version_right in ["5", "6"]: # Replace with 0.5.3 return "pragma solidity " + REPLACEMENT_VERSIONS[1] + ";" + return None - +#pylint: disable=too-many-arguments def _patch(slither, result, in_file, solc_version, modify_loc_start, modify_loc_end): in_file_str = slither.source_code[in_file].encode("utf8") old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end] diff --git a/slither/formatters/functions/external_function.py b/slither/formatters/functions/external_function.py index 17f6c1361..2ce9f3e96 100644 --- a/slither/formatters/functions/external_function.py +++ b/slither/formatters/functions/external_function.py @@ -2,7 +2,7 @@ import re from slither.formatters.utils.patches import create_patch -def format(slither, result): +def custom_format(slither, result): elements = result["elements"] for element in elements: target_contract = slither.get_contract_from_name( @@ -27,16 +27,22 @@ def _patch(slither, result, in_file, modify_loc_start, modify_loc_end): old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end] # Search for 'public' keyword which is in-between the function name and modifier name (if present) # regex: 'public' could have spaces around or be at the end of the line - m = re.search(r"((\spublic)\s+)|(\spublic)$|(\)public)$", old_str_of_interest.decode("utf-8")) + m = re.search( + r"((\spublic)\s+)|(\spublic)$|(\)public)$", old_str_of_interest.decode("utf-8") + ) if m is None: # No visibility specifier exists; public by default. create_patch( result, in_file, # start after the function definition's closing paranthesis - modify_loc_start + len(old_str_of_interest.decode("utf-8").split(")")[0]) + 1, + modify_loc_start + + len(old_str_of_interest.decode("utf-8").split(")")[0]) + + 1, # end is same as start because we insert the keyword `external` at that location - modify_loc_start + len(old_str_of_interest.decode("utf-8").split(")")[0]) + 1, + modify_loc_start + + len(old_str_of_interest.decode("utf-8").split(")")[0]) + + 1, "", " external", ) # replace_text is `external` diff --git a/slither/formatters/naming_convention/naming_convention.py b/slither/formatters/naming_convention/naming_convention.py index 018920800..719f1517f 100644 --- a/slither/formatters/naming_convention/naming_convention.py +++ b/slither/formatters/naming_convention/naming_convention.py @@ -21,7 +21,9 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger("Slither.Format") -def format(slither, result): +# pylint: disable=anomalous-backslash-in-string + +def custom_format(slither, result): elements = result["elements"] for element in elements: target = element["additional_fields"]["target"] @@ -126,13 +128,15 @@ def _name_already_use(slither, name): if not KEY in slither.context: all_names = set() for contract in slither.contracts_derived: - all_names = all_names.union(set([st.name for st in contract.structures])) - all_names = all_names.union(set([f.name for f in contract.functions_and_modifiers])) - all_names = all_names.union(set([e.name for e in contract.enums])) - all_names = all_names.union(set([s.name for s in contract.state_variables])) + all_names = all_names.union({st.name for st in contract.structures}) + all_names = all_names.union( + {f.name for f in contract.functions_and_modifiers} + ) + all_names = all_names.union({e.name for e in contract.enums}) + all_names = all_names.union({s.name for s in contract.state_variables}) for function in contract.functions: - all_names = all_names.union(set([v.name for v in function.variables])) + all_names = all_names.union({v.name for v in function.variables}) slither.context[KEY] = all_names return name in slither.context[KEY] @@ -144,13 +148,17 @@ def _convert_CapWords(original_name, slither): while "_" in name: offset = name.find("_") if len(name) > offset: - name = name[0:offset] + name[offset + 1].upper() + name[offset + 1 :] + name = name[0:offset] + name[offset + 1].upper() + name[offset + 1:] if _name_already_use(slither, name): - raise FormatImpossible(f"{original_name} cannot be converted to {name} (already used)") + raise FormatImpossible( + f"{original_name} cannot be converted to {name} (already used)" + ) if name in SOLIDITY_KEYWORDS: - raise FormatImpossible(f"{original_name} cannot be converted to {name} (Solidity keyword)") + raise FormatImpossible( + f"{original_name} cannot be converted to {name} (Solidity keyword)" + ) return name @@ -162,21 +170,29 @@ def _convert_mixedCase(original_name, slither): while "_" in name: offset = name.find("_") if len(name) > offset: - name = name[0:offset] + name[offset + 1].upper() + name[offset + 2 :] + name = name[0:offset] + name[offset + 1].upper() + name[offset + 2:] name = name[0].lower() + name[1:] if _name_already_use(slither, name): - raise FormatImpossible(f"{original_name} cannot be converted to {name} (already used)") + raise FormatImpossible( + f"{original_name} cannot be converted to {name} (already used)" + ) if name in SOLIDITY_KEYWORDS: - raise FormatImpossible(f"{original_name} cannot be converted to {name} (Solidity keyword)") + raise FormatImpossible( + f"{original_name} cannot be converted to {name} (Solidity keyword)" + ) return name def _convert_UPPER_CASE_WITH_UNDERSCORES(name, slither): if _name_already_use(slither, name.upper()): - raise FormatImpossible(f"{name} cannot be converted to {name.upper()} (already used)") + raise FormatImpossible( + f"{name} cannot be converted to {name.upper()} (already used)" + ) if name.upper() in SOLIDITY_KEYWORDS: - raise FormatImpossible(f"{name} cannot be converted to {name.upper()} (Solidity keyword)") + raise FormatImpossible( + f"{name} cannot be converted to {name.upper()} (Solidity keyword)" + ) return name.upper() @@ -210,15 +226,18 @@ def _get_from_contract(slither, element, name, getter): def _patch(slither, result, element, _target): - if _target == "contract": target = slither.get_contract_from_name(element["name"]) elif _target == "structure": - target = _get_from_contract(slither, element, element["name"], "get_structure_from_name") + target = _get_from_contract( + slither, element, element["name"], "get_structure_from_name" + ) elif _target == "event": - target = _get_from_contract(slither, element, element["name"], "get_event_from_name") + target = _get_from_contract( + slither, element, element["name"], "get_event_from_name" + ) elif _target == "function": # Avoid constructor (FP?) @@ -230,15 +249,17 @@ def _patch(slither, result, element, _target): elif _target == "modifier": modifier_sig = element["type_specific_fields"]["signature"] - target = _get_from_contract(slither, element, modifier_sig, "get_modifier_from_signature") + target = _get_from_contract( + slither, element, modifier_sig, "get_modifier_from_signature" + ) elif _target == "parameter": - contract_name = element["type_specific_fields"]["parent"]["type_specific_fields"]["parent"][ - "name" - ] - function_sig = element["type_specific_fields"]["parent"]["type_specific_fields"][ - "signature" - ] + contract_name = element["type_specific_fields"]["parent"][ + "type_specific_fields" + ]["parent"]["name"] + function_sig = element["type_specific_fields"]["parent"][ + "type_specific_fields" + ]["signature"] param_name = element["name"] contract = slither.get_contract_from_name(contract_name) function = contract.get_function_from_signature(function_sig) @@ -247,12 +268,12 @@ def _patch(slither, result, element, _target): elif _target in ["variable", "variable_constant"]: # Local variable if element["type_specific_fields"]["parent"] == "function": - contract_name = element["type_specific_fields"]["parent"]["type_specific_fields"][ - "parent" - ]["name"] - function_sig = element["type_specific_fields"]["parent"]["type_specific_fields"][ - "signature" - ] + contract_name = element["type_specific_fields"]["parent"][ + "type_specific_fields" + ]["parent"]["name"] + function_sig = element["type_specific_fields"]["parent"][ + "type_specific_fields" + ]["signature"] var_name = element["name"] contract = slither.get_contract_from_name(contract_name) function = contract.get_function_from_signature(function_sig) @@ -271,7 +292,9 @@ def _patch(slither, result, element, _target): else: raise FormatError("Unknown naming convention! " + _target) - _explore(slither, result, target, conventions[element["additional_fields"]["convention"]]) + _explore( + slither, result, target, conventions[element["additional_fields"]["convention"]] + ) # endregion @@ -288,7 +311,13 @@ def _patch(slither, result, element, _target): RE_MAPPING_FROM = b"([a-zA-Z0-9\._\[\]]*)" RE_MAPPING_TO = b"([\=\>\(\) a-zA-Z0-9\._\[\]\ ]*)" RE_MAPPING = ( - b"[ ]*mapping[ ]*\([ ]*" + RE_MAPPING_FROM + b"[ ]*" + b"=>" + b"[ ]*" + RE_MAPPING_TO + b"\)" + b"[ ]*mapping[ ]*\([ ]*" + + RE_MAPPING_FROM + + b"[ ]*" + + b"=>" + + b"[ ]*" + + RE_MAPPING_TO + + b"\)" ) @@ -301,15 +330,17 @@ def _is_var_declaration(slither, filename, start): :return: """ v = "var " - return slither.source_code[filename][start : start + len(v)] == v + return slither.source_code[filename][start: start + len(v)] == v -def _explore_type(slither, result, target, convert, type, filename_source_code, start, end): - if isinstance(type, UserDefinedType): +def _explore_type( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches + slither, result, target, convert, custom_type, filename_source_code, start, end +): + if isinstance(custom_type, UserDefinedType): # Patch type based on contract/enum - if isinstance(type.type, (Enum, Contract)): - if type.type == target: - old_str = type.type.name + if isinstance(custom_type.type, (Enum, Contract)): + if custom_type.type == target: + old_str = custom_type.type.name new_str = convert(old_str, slither) loc_start = start @@ -318,13 +349,15 @@ def _explore_type(slither, result, target, convert, type, filename_source_code, else: loc_end = loc_start + len(old_str) - create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) + create_patch( + result, filename_source_code, loc_start, loc_end, old_str, new_str + ) else: # Patch type based on structure - assert isinstance(type.type, Structure) - if type.type == target: - old_str = type.type.name + assert isinstance(custom_type.type, Structure) + if custom_type.type == target: + old_str = custom_type.type.name new_str = convert(old_str, slither) loc_start = start @@ -333,15 +366,17 @@ def _explore_type(slither, result, target, convert, type, filename_source_code, else: loc_end = loc_start + len(old_str) - create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) + create_patch( + result, filename_source_code, loc_start, loc_end, old_str, new_str + ) # Structure contain a list of elements, that might need patching # .elems return a list of VariableStructure _explore_variables_declaration( - slither, type.type.elems.values(), result, target, convert + slither, custom_type.type.elems.values(), result, target, convert ) - if isinstance(type, MappingType): + if isinstance(custom_type, MappingType): # Mapping has three steps: # Convert the "from" type # Convert the "to" type @@ -350,39 +385,42 @@ def _explore_type(slither, result, target, convert, type, filename_source_code, # Do the comparison twice, so we can factor together the re matching # mapping can only have elementary type in type_from - if isinstance(type.type_to, (UserDefinedType, MappingType)) or target in [ - type.type_from, - type.type_to, + if isinstance(custom_type.type_to, (UserDefinedType, MappingType)) or target in [ + custom_type.type_from, + custom_type.type_to, ]: full_txt_start = start full_txt_end = end full_txt = slither.source_code[filename_source_code].encode("utf8")[ - full_txt_start:full_txt_end - ] + full_txt_start:full_txt_end + ] re_match = re.match(RE_MAPPING, full_txt) assert re_match - if type.type_from == target: - old_str = type.type_from.name + if custom_type.type_from == target: + old_str = custom_type.type_from.name new_str = convert(old_str, slither) loc_start = start + re_match.start(1) loc_end = loc_start + len(old_str) - create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) - - if type.type_to == target: + create_patch( + result, filename_source_code, loc_start, loc_end, old_str, new_str + ) - old_str = type.type_to.name + if custom_type.type_to == target: + old_str = custom_type.type_to.name new_str = convert(old_str, slither) loc_start = start + re_match.start(2) loc_end = loc_start + len(old_str) - create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) + create_patch( + result, filename_source_code, loc_start, loc_end, old_str, new_str + ) - if isinstance(type.type_to, (UserDefinedType, MappingType)): + if isinstance(custom_type.type_to, (UserDefinedType, MappingType)): loc_start = start + re_match.start(2) loc_end = start + re_match.end(2) _explore_type( @@ -390,15 +428,15 @@ def _explore_type(slither, result, target, convert, type, filename_source_code, result, target, convert, - type.type_to, + custom_type.type_to, filename_source_code, loc_start, loc_end, ) -def _explore_variables_declaration( - slither, variables, result, target, convert, patch_comment=False +def _explore_variables_declaration( # pylint: disable=too-many-arguments,too-many-locals,too-many-nested-blocks + slither, variables, result, target, convert, patch_comment=False ): for variable in variables: # First explore the type of the variable @@ -406,8 +444,8 @@ def _explore_variables_declaration( full_txt_start = variable.source_mapping["start"] full_txt_end = full_txt_start + variable.source_mapping["length"] full_txt = slither.source_code[filename_source_code].encode("utf8")[ - full_txt_start:full_txt_end - ] + full_txt_start:full_txt_end + ] _explore_type( slither, @@ -428,23 +466,28 @@ def _explore_variables_declaration( loc_start = full_txt_start + full_txt.find(old_str.encode("utf8")) loc_end = loc_start + len(old_str) - create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) + create_patch( + result, filename_source_code, loc_start, loc_end, old_str, new_str + ) # Patch comment only makes sense for local variable declaration in the parameter list if patch_comment and isinstance(variable, LocalVariable): - if "lines" in variable.source_mapping and variable.source_mapping["lines"]: + if ( + "lines" in variable.source_mapping + and variable.source_mapping["lines"] + ): func = variable.function end_line = func.source_mapping["lines"][0] if variable in func.parameters: idx = len(func.parameters) - func.parameters.index(variable) + 1 first_line = end_line - idx - 2 - potential_comments = slither.source_code[filename_source_code].encode( - "utf8" - ) - potential_comments = potential_comments.splitlines(keepends=True)[ - first_line : end_line - 1 - ] + potential_comments = slither.source_code[ + filename_source_code + ].encode("utf8") + potential_comments = potential_comments.splitlines( + keepends=True + )[first_line: end_line - 1] idx_beginning = func.source_mapping["start"] idx_beginning += -func.source_mapping["starting_column"] + 1 @@ -475,7 +518,9 @@ def _explore_variables_declaration( def _explore_structures_declaration(slither, structures, result, target, convert): for st in structures: # Explore the variable declared within the structure (VariableStructure) - _explore_variables_declaration(slither, st.elems.values(), result, target, convert) + _explore_variables_declaration( + slither, st.elems.values(), result, target, convert + ) # If the structure is the target if st == target: @@ -486,16 +531,20 @@ def _explore_structures_declaration(slither, structures, result, target, convert full_txt_start = st.source_mapping["start"] full_txt_end = full_txt_start + st.source_mapping["length"] full_txt = slither.source_code[filename_source_code].encode("utf8")[ - full_txt_start:full_txt_end - ] + full_txt_start:full_txt_end + ] # The name is after the space matches = re.finditer(b"struct[ ]*", full_txt) # Look for the end offset of the largest list of ' ' - loc_start = full_txt_start + max(matches, key=lambda x: len(x.group())).end() + loc_start = ( + full_txt_start + max(matches, key=lambda x: len(x.group())).end() + ) loc_end = loc_start + len(old_str) - create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) + create_patch( + result, filename_source_code, loc_start, loc_end, old_str, new_str + ) def _explore_events_declaration(slither, events, result, target, convert): @@ -513,47 +562,52 @@ def _explore_events_declaration(slither, events, result, target, convert): loc_start = event.source_mapping["start"] loc_end = loc_start + len(old_str) - create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) + create_patch( + result, filename_source_code, loc_start, loc_end, old_str, new_str + ) def get_ir_variables(ir): - vars = ir.read + all_vars = ir.read if isinstance(ir, (InternalCall, InternalDynamicCall, HighLevelCall)): - vars += [ir.function] + all_vars += [ir.function] if isinstance(ir, (HighLevelCall, Send, LowLevelCall, Transfer)): - vars += [ir.call_value] + all_vars += [ir.call_value] if isinstance(ir, (HighLevelCall, LowLevelCall)): - vars += [ir.call_gas] + all_vars += [ir.call_gas] if isinstance(ir, OperationWithLValue): - vars += [ir.lvalue] + all_vars += [ir.lvalue] - return [v for v in vars if v] + return [v for v in all_vars if v] def _explore_irs(slither, irs, result, target, convert): + # pylint: disable=too-many-locals if irs is None: return for ir in irs: for v in get_ir_variables(ir): if target == v or ( - isinstance(target, Function) - and isinstance(v, Function) - and v.canonical_name == target.canonical_name + isinstance(target, Function) + and isinstance(v, Function) + and v.canonical_name == target.canonical_name ): source_mapping = ir.expression.source_mapping filename_source_code = source_mapping["filename_absolute"] full_txt_start = source_mapping["start"] full_txt_end = full_txt_start + source_mapping["length"] full_txt = slither.source_code[filename_source_code].encode("utf8")[ - full_txt_start:full_txt_end - ] + full_txt_start:full_txt_end + ] if not target.name.encode("utf8") in full_txt: - raise FormatError(f"{target} not found in {full_txt} ({source_mapping}") + raise FormatError( + f"{target} not found in {full_txt} ({source_mapping}" + ) old_str = target.name.encode("utf8") new_str = convert(old_str, slither) @@ -562,24 +616,37 @@ def _explore_irs(slither, irs, result, target, convert): # Can be found multiple time on the same IR # We patch one by one while old_str in full_txt: - target_found_at = full_txt.find((old_str)) - full_txt = full_txt[target_found_at + 1 :] + full_txt = full_txt[target_found_at + 1:] counter += target_found_at loc_start = full_txt_start + counter loc_end = loc_start + len(old_str) - create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) + create_patch( + result, + filename_source_code, + loc_start, + loc_end, + old_str, + new_str, + ) def _explore_functions(slither, functions, result, target, convert): for function in functions: - _explore_variables_declaration(slither, function.variables, result, target, convert, True) - _explore_irs(slither, function.all_slithir_operations(), result, target, convert) + _explore_variables_declaration( + slither, function.variables, result, target, convert, True + ) + _explore_irs( + slither, function.all_slithir_operations(), result, target, convert + ) - if isinstance(target, Function) and function.canonical_name == target.canonical_name: + if ( + isinstance(target, Function) + and function.canonical_name == target.canonical_name + ): old_str = function.name new_str = convert(old_str, slither) @@ -587,8 +654,8 @@ def _explore_functions(slither, functions, result, target, convert): full_txt_start = function.source_mapping["start"] full_txt_end = full_txt_start + function.source_mapping["length"] full_txt = slither.source_code[filename_source_code].encode("utf8")[ - full_txt_start:full_txt_end - ] + full_txt_start:full_txt_end + ] # The name is after the space if isinstance(target, Modifier): @@ -596,10 +663,14 @@ def _explore_functions(slither, functions, result, target, convert): else: matches = re.finditer(b"function([ ]*)", full_txt) # Look for the end offset of the largest list of ' ' - loc_start = full_txt_start + max(matches, key=lambda x: len(x.group())).end() + loc_start = ( + full_txt_start + max(matches, key=lambda x: len(x.group())).end() + ) loc_end = loc_start + len(old_str) - create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) + create_patch( + result, filename_source_code, loc_start, loc_end, old_str, new_str + ) def _explore_enums(slither, enums, result, target, convert): @@ -612,22 +683,32 @@ def _explore_enums(slither, enums, result, target, convert): full_txt_start = enum.source_mapping["start"] full_txt_end = full_txt_start + enum.source_mapping["length"] full_txt = slither.source_code[filename_source_code].encode("utf8")[ - full_txt_start:full_txt_end - ] + full_txt_start:full_txt_end + ] # The name is after the space matches = re.finditer(b"enum([ ]*)", full_txt) # Look for the end offset of the largest list of ' ' - loc_start = full_txt_start + max(matches, key=lambda x: len(x.group())).end() + loc_start = ( + full_txt_start + max(matches, key=lambda x: len(x.group())).end() + ) loc_end = loc_start + len(old_str) - create_patch(result, filename_source_code, loc_start, loc_end, old_str, new_str) + create_patch( + result, filename_source_code, loc_start, loc_end, old_str, new_str + ) def _explore_contract(slither, contract, result, target, convert): - _explore_variables_declaration(slither, contract.state_variables, result, target, convert) - _explore_structures_declaration(slither, contract.structures, result, target, convert) - _explore_functions(slither, contract.functions_and_modifiers, result, target, convert) + _explore_variables_declaration( + slither, contract.state_variables, result, target, convert + ) + _explore_structures_declaration( + slither, contract.structures, result, target, convert + ) + _explore_functions( + slither, contract.functions_and_modifiers, result, target, convert + ) _explore_enums(slither, contract.enums, result, target, convert) if contract == target: @@ -635,8 +716,8 @@ def _explore_contract(slither, contract, result, target, convert): full_txt_start = contract.source_mapping["start"] full_txt_end = full_txt_start + contract.source_mapping["length"] full_txt = slither.source_code[filename_source_code].encode("utf8")[ - full_txt_start:full_txt_end - ] + full_txt_start:full_txt_end + ] old_str = contract.name new_str = convert(old_str, slither) @@ -655,5 +736,4 @@ def _explore(slither, result, target, convert): for contract in slither.contracts_derived: _explore_contract(slither, contract, result, target, convert) - # endregion diff --git a/slither/formatters/utils/patches.py b/slither/formatters/utils/patches.py index 72b25a97e..323bf6c42 100644 --- a/slither/formatters/utils/patches.py +++ b/slither/formatters/utils/patches.py @@ -3,7 +3,7 @@ import difflib from collections import defaultdict -def create_patch(result, file, start, end, old_str, new_str): +def create_patch(result, file, start, end, old_str, new_str): # pylint: disable=too-many-arguments if isinstance(old_str, bytes): old_str = old_str.decode("utf8") if isinstance(new_str, bytes): diff --git a/slither/formatters/variables/possible_const_state_variables.py b/slither/formatters/variables/possible_const_state_variables.py index da52aa89f..3bb81d4b5 100644 --- a/slither/formatters/variables/possible_const_state_variables.py +++ b/slither/formatters/variables/possible_const_state_variables.py @@ -3,7 +3,7 @@ from slither.formatters.exceptions import FormatError, FormatImpossible from slither.formatters.utils.patches import create_patch -def format(slither, result): +def custom_format(slither, result): elements = result["elements"] for element in elements: @@ -12,7 +12,9 @@ def format(slither, result): contract = slither.get_contract_from_name(contract_name) var = contract.get_state_variable_from_name(element["name"]) if not var.expression: - raise FormatImpossible(f"{var.name} is uninitialized and cannot become constant.") + raise FormatImpossible( + f"{var.name} is uninitialized and cannot become constant." + ) _patch( slither, @@ -25,7 +27,9 @@ def format(slither, result): ) -def _patch(slither, result, in_file, match_text, replace_text, modify_loc_start, modify_loc_end): +def _patch( # pylint: disable=too-many-arguments + slither, result, in_file, match_text, replace_text, modify_loc_start, modify_loc_end +): in_file_str = slither.source_code[in_file].encode("utf8") old_str_of_interest = in_file_str[modify_loc_start:modify_loc_end] # Add keyword `constant` before the variable name diff --git a/slither/formatters/variables/unused_state_variables.py b/slither/formatters/variables/unused_state_variables.py index 8f8262383..332371d8c 100644 --- a/slither/formatters/variables/unused_state_variables.py +++ b/slither/formatters/variables/unused_state_variables.py @@ -1,7 +1,7 @@ from slither.formatters.utils.patches import create_patch -def format(slither, result): +def custom_format(slither, result): elements = result["elements"] for element in elements: if element["type"] == "variable": @@ -26,7 +26,11 @@ def _patch(slither, result, in_file, modify_loc_start): in_file, int(modify_loc_start), # Remove the entire declaration until the semicolon - int(modify_loc_start + len(old_str_of_interest.decode("utf-8").partition(";")[0]) + 1), + int( + modify_loc_start + + len(old_str_of_interest.decode("utf-8").partition(";")[0]) + + 1 + ), old_str, "", ) diff --git a/slither/printers/abstract_printer.py b/slither/printers/abstract_printer.py index 7d6968757..7a60aa1ae 100644 --- a/slither/printers/abstract_printer.py +++ b/slither/printers/abstract_printer.py @@ -41,10 +41,10 @@ class AbstractPrinter(metaclass=abc.ABCMeta): def generate_output(self, info, additional_fields=None): if additional_fields is None: additional_fields = {} - d = output.Output(info, additional_fields) - d.data["printer"] = self.ARGUMENT + printer_output = output.Output(info, additional_fields) + printer_output.data["printer"] = self.ARGUMENT - return d + return printer_output @abc.abstractmethod def output(self, filename): diff --git a/slither/printers/all_printers.py b/slither/printers/all_printers.py index 699c0ae94..833ff6494 100644 --- a/slither/printers/all_printers.py +++ b/slither/printers/all_printers.py @@ -1,3 +1,4 @@ +# pylint: disable=unused-import,relative-beyond-top-level from .summary.function import FunctionSummary from .summary.contract import ContractSummary from .inheritance.inheritance import PrinterInheritance diff --git a/slither/printers/call/call_graph.py b/slither/printers/call/call_graph.py index 3085e1513..d0b1be3b9 100644 --- a/slither/printers/call/call_graph.py +++ b/slither/printers/call/call_graph.py @@ -36,138 +36,175 @@ def _node(node, label=None): return " ".join((f'"{node}"', f'[label="{label}"]' if label is not None else "",)) -class PrinterCallGraph(AbstractPrinter): - ARGUMENT = "call-graph" - HELP = "Export the call-graph of the contracts to a dot file" +# pylint: disable=too-many-arguments +def _process_internal_call( + contract, + function, + internal_call, + contract_calls, + solidity_functions, + solidity_calls, +): + if isinstance(internal_call, (Function)): + contract_calls[contract].add( + _edge( + _function_node(contract, function), + _function_node(contract, internal_call), + ) + ) + elif isinstance(internal_call, (SolidityFunction)): + solidity_functions.add(_node(_solidity_function_node(internal_call)), ) + solidity_calls.add( + _edge( + _function_node(contract, function), + _solidity_function_node(internal_call), + ) + ) - WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#call-graph" - def _process_functions(self, functions): +def _render_external_calls(external_calls): + return "\n".join(external_calls) - contract_functions = defaultdict(set) # contract -> contract functions nodes - contract_calls = defaultdict(set) # contract -> contract calls edges - solidity_functions = set() # solidity function nodes - solidity_calls = set() # solidity calls edges - external_calls = set() # external calls edges +def _render_internal_calls(contract, contract_functions, contract_calls): + lines = [] - all_contracts = set() + lines.append(f"subgraph {_contract_subgraph(contract)} {{") + lines.append(f'label = "{contract.name}"') - for function in functions: - all_contracts.add(function.contract_declarer) - for function in functions: - self._process_function( - function.contract_declarer, - function, - contract_functions, - contract_calls, - solidity_functions, - solidity_calls, - external_calls, - all_contracts, - ) + lines.extend(contract_functions[contract]) + lines.extend(contract_calls[contract]) - render_internal_calls = "" - for contract in all_contracts: - render_internal_calls += self._render_internal_calls( - contract, contract_functions, contract_calls - ) + lines.append("}") + + return "\n".join(lines) + + +def _render_solidity_calls(solidity_functions, solidity_calls): + lines = [] - render_solidity_calls = self._render_solidity_calls(solidity_functions, solidity_calls) + lines.append("subgraph cluster_solidity {") + lines.append('label = "[Solidity]"') - render_external_calls = self._render_external_calls(external_calls) + lines.extend(solidity_functions) + lines.extend(solidity_calls) - return render_internal_calls + render_solidity_calls + render_external_calls + lines.append("}") - def _process_function( - self, + return "\n".join(lines) + + +def _process_external_call( contract, function, + external_call, contract_functions, - contract_calls, - solidity_functions, - solidity_calls, external_calls, all_contracts, - ): - contract_functions[contract].add(_node(_function_node(contract, function), function.name),) - - for internal_call in function.internal_calls: - self._process_internal_call( - contract, - function, - internal_call, - contract_calls, - solidity_functions, - solidity_calls, - ) - for external_call in function.high_level_calls: - self._process_external_call( - contract, function, external_call, contract_functions, external_calls, all_contracts - ) - - def _process_internal_call( - self, contract, function, internal_call, contract_calls, solidity_functions, solidity_calls - ): - if isinstance(internal_call, (Function)): - contract_calls[contract].add( - _edge(_function_node(contract, function), _function_node(contract, internal_call),) - ) - elif isinstance(internal_call, (SolidityFunction)): - solidity_functions.add(_node(_solidity_function_node(internal_call)),) - solidity_calls.add( - _edge(_function_node(contract, function), _solidity_function_node(internal_call),) - ) +): + external_contract, external_function = external_call - def _process_external_call( - self, contract, function, external_call, contract_functions, external_calls, all_contracts - ): - external_contract, external_function = external_call - - if not external_contract in all_contracts: - return - - # add variable as node to respective contract - if isinstance(external_function, (Variable)): - contract_functions[external_contract].add( - _node(_function_node(external_contract, external_function), external_function.name) - ) + if not external_contract in all_contracts: + return - external_calls.add( - _edge( - _function_node(contract, function), + # add variable as node to respective contract + if isinstance(external_function, (Variable)): + contract_functions[external_contract].add( + _node( _function_node(external_contract, external_function), + external_function.name, ) ) - def _render_internal_calls(self, contract, contract_functions, contract_calls): - lines = [] + external_calls.add( + _edge( + _function_node(contract, function), + _function_node(external_contract, external_function), + ) + ) + - lines.append(f"subgraph {_contract_subgraph(contract)} {{") - lines.append(f'label = "{contract.name}"') +# pylint: disable=too-many-arguments +def _process_function( + contract, + function, + contract_functions, + contract_calls, + solidity_functions, + solidity_calls, + external_calls, + all_contracts, +): + contract_functions[contract].add( + _node(_function_node(contract, function), function.name), + ) + + for internal_call in function.internal_calls: + _process_internal_call( + contract, + function, + internal_call, + contract_calls, + solidity_functions, + solidity_calls, + ) + for external_call in function.high_level_calls: + _process_external_call( + contract, + function, + external_call, + contract_functions, + external_calls, + all_contracts, + ) - lines.extend(contract_functions[contract]) - lines.extend(contract_calls[contract]) - lines.append("}") +def _process_functions(functions): + contract_functions = defaultdict(set) # contract -> contract functions nodes + contract_calls = defaultdict(set) # contract -> contract calls edges + + solidity_functions = set() # solidity function nodes + solidity_calls = set() # solidity calls edges + external_calls = set() # external calls edges + + all_contracts = set() + + for function in functions: + all_contracts.add(function.contract_declarer) + for function in functions: + _process_function( + function.contract_declarer, + function, + contract_functions, + contract_calls, + solidity_functions, + solidity_calls, + external_calls, + all_contracts, + ) - return "\n".join(lines) + render_internal_calls = "" + for contract in all_contracts: + render_internal_calls += _render_internal_calls( + contract, contract_functions, contract_calls + ) - def _render_solidity_calls(self, solidity_functions, solidity_calls): - lines = [] + render_solidity_calls = _render_solidity_calls( + solidity_functions, solidity_calls + ) - lines.append("subgraph cluster_solidity {") - lines.append('label = "[Solidity]"') + render_external_calls = _render_external_calls(external_calls) - lines.extend(solidity_functions) - lines.extend(solidity_calls) + return render_internal_calls + render_solidity_calls + render_external_calls - lines.append("}") - return "\n".join(lines) +class PrinterCallGraph(AbstractPrinter): + ARGUMENT = "call-graph" + HELP = "Export the call-graph of the contracts to a dot file" - def _render_external_calls(self, external_calls): - return "\n".join(external_calls) + WIKI = ( + "https://github.com/trailofbits/slither/wiki/Printer-documentation#call-graph" + ) def output(self, filename): """ @@ -186,7 +223,9 @@ class PrinterCallGraph(AbstractPrinter): with open(filename, "w", encoding="utf8") as f: info += f"Call Graph: {filename}\n" content = "\n".join( - ["strict digraph {"] + [self._process_functions(self.slither.functions)] + ["}"] + ["strict digraph {"] + + [_process_functions(self.slither.functions)] + + ["}"] ) f.write(content) results.append((filename, content)) @@ -196,7 +235,7 @@ class PrinterCallGraph(AbstractPrinter): info += f"Call Graph: {derived_contract.name}.dot\n" content = "\n".join( ["strict digraph {"] - + [self._process_functions(derived_contract.functions)] + + [_process_functions(derived_contract.functions)] + ["}"] ) f.write(content) @@ -204,7 +243,7 @@ class PrinterCallGraph(AbstractPrinter): self.info(info) res = self.generate_output(info) - for filename, content in results: - res.add_file(filename, content) + for filename_result, content in results: + res.add_file(filename_result, content) return res diff --git a/slither/printers/functions/authorization.py b/slither/printers/functions/authorization.py index af66b3b9d..405bdce49 100644 --- a/slither/printers/functions/authorization.py +++ b/slither/printers/functions/authorization.py @@ -49,10 +49,16 @@ class PrinterWrittenVariablesAndAuthorization(AbstractPrinter): ) for function in contract.functions: - state_variables_written = [v.name for v in function.all_state_variables_written()] + state_variables_written = [ + v.name for v in function.all_state_variables_written() + ] msg_sender_condition = self.get_msg_sender_checks(function) table.add_row( - [function.name, str(state_variables_written), str(msg_sender_condition)] + [ + function.name, + str(state_variables_written), + str(msg_sender_condition), + ] ) all_tables.append((contract.name, table)) txt += str(table) + "\n" diff --git a/slither/printers/functions/cfg.py b/slither/printers/functions/cfg.py index 963c98bba..340656b88 100644 --- a/slither/printers/functions/cfg.py +++ b/slither/printers/functions/cfg.py @@ -1,6 +1,3 @@ -""" -""" - from slither.printers.abstract_printer import AbstractPrinter @@ -11,7 +8,7 @@ class CFG(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#cfg" - def output(self, original_filename): + def output(self, filename): """ _filename is not used Args: @@ -24,9 +21,9 @@ class CFG(AbstractPrinter): if contract.is_top_level: continue for function in contract.functions + contract.modifiers: - if original_filename: + if filename: filename = "{}-{}-{}.dot".format( - original_filename, contract.name, function.full_name + filename, contract.name, function.full_name ) else: filename = "{}-{}.dot".format(contract.name, function.full_name) @@ -39,6 +36,6 @@ class CFG(AbstractPrinter): self.info(info) res = self.generate_output(info) - for filename, content in all_files: - res.add_file(filename, content) + for filename_result, content in all_files: + res.add_file(filename_result, content) return res diff --git a/slither/printers/guidance/echidna.py b/slither/printers/guidance/echidna.py index 46031be45..2afe0a101 100644 --- a/slither/printers/guidance/echidna.py +++ b/slither/printers/guidance/echidna.py @@ -1,9 +1,6 @@ -""" -""" - import json from collections import defaultdict -from typing import Dict, List, Set, Tuple, Union, NamedTuple +from typing import Dict, List, Set, Tuple, NamedTuple from slither.analyses.data_dependency.data_dependency import is_dependent from slither.core.cfg.node import Node @@ -32,20 +29,22 @@ from slither.slithir.operations import ( TypeConversion, Balance, ) -from slither.slithir.operations.binary import Binary, BinaryType +from slither.slithir.operations.binary import Binary from slither.slithir.variables import Constant def _get_name(f: Function) -> str: if f.is_fallback or f.is_receive: - return f"()" + return "()" return f.solidity_signature def _extract_payable(slither: SlitherCore) -> Dict[str, List[str]]: ret: Dict[str, List[str]] = {} for contract in slither.contracts: - payable_functions = [_get_name(f) for f in contract.functions_entry_points if f.payable] + payable_functions = [ + _get_name(f) for f in contract.functions_entry_points if f.payable + ] if payable_functions: ret[contract.name] = payable_functions return ret @@ -67,7 +66,7 @@ def _extract_solidity_variable_usage( return ret -def _is_constant(f: Function) -> bool: +def _is_constant(f: Function) -> bool: # pylint: disable=too-many-branches """ Heuristic: - If view/pure with Solidity >= 0.4 -> Return true @@ -79,7 +78,9 @@ def _is_constant(f: Function) -> bool: :return: """ if f.view or f.pure: - if not f.contract.slither.crytic_compile.compiler_version.version.startswith("0.4"): + if not f.contract.slither.crytic_compile.compiler_version.version.startswith( + "0.4" + ): return True if f.payable: return False @@ -100,9 +101,15 @@ def _is_constant(f: Function) -> bool: ]: return False if isinstance(ir, HighLevelCall): - if isinstance(ir.function, Variable) or ir.function.view or ir.function.pure: + if ( + isinstance(ir.function, Variable) + or ir.function.view + or ir.function.pure + ): # External call to constant functions are ensured to be constant only for solidity >= 0.5 - if f.contract.slither.crytic_compile.compiler_version.version.startswith("0.4"): + if f.contract.slither.crytic_compile.compiler_version.version.startswith( + "0.4" + ): return False else: return False @@ -116,9 +123,13 @@ def _is_constant(f: Function) -> bool: def _extract_constant_functions(slither: SlitherCore) -> Dict[str, List[str]]: ret: Dict[str, List[str]] = {} for contract in slither.contracts: - cst_functions = [_get_name(f) for f in contract.functions_entry_points if _is_constant(f)] + cst_functions = [ + _get_name(f) for f in contract.functions_entry_points if _is_constant(f) + ] cst_functions += [ - v.function_name for v in contract.state_variables if v.visibility in ["public"] + v.function_name + for v in contract.state_variables + if v.visibility in ["public"] ] if cst_functions: ret[contract.name] = cst_functions @@ -141,8 +152,12 @@ def _extract_assert(slither: SlitherCore) -> Dict[str, List[str]]: # Create a named tuple that is serialization in json def json_serializable(cls): + # pylint: disable=unnecessary-comprehension def as_dict(self): - yield {name: value for name, value in zip(self._fields, iter(super(cls, self).__iter__()))} + yield { + name: value + for name, value in zip(self._fields, iter(super(cls, self).__iter__())) + } cls.__iter__ = as_dict return cls @@ -157,7 +172,7 @@ class ConstantValue(NamedTuple): type: str -def _extract_constants_from_irs( +def _extract_constants_from_irs( # pylint: disable=too-many-branches,too-many-nested-blocks irs: List[Operation], all_cst_used: List[ConstantValue], all_cst_used_in_binary: Dict[str, List[ConstantValue]], @@ -185,14 +200,13 @@ def _extract_constants_from_irs( if r.node_initialization.irs: if r.node_initialization in context_explored: continue - else: - context_explored.add(r.node_initialization) - _extract_constants_from_irs( - r.node_initialization.irs, - all_cst_used, - all_cst_used_in_binary, - context_explored, - ) + context_explored.add(r.node_initialization) + _extract_constants_from_irs( + r.node_initialization.irs, + all_cst_used, + all_cst_used_in_binary, + context_explored, + ) def _extract_constants( @@ -201,7 +215,9 @@ def _extract_constants( # contract -> function -> [ {"value": value, "type": type} ] ret_cst_used: Dict[str, Dict[str, List[ConstantValue]]] = defaultdict(dict) # contract -> function -> binary_operand -> [ {"value": value, "type": type ] - ret_cst_used_in_binary: Dict[str, Dict[str, Dict[str, List[ConstantValue]]]] = defaultdict(dict) + ret_cst_used_in_binary: Dict[ + str, Dict[str, Dict[str, List[ConstantValue]]] + ] = defaultdict(dict) for contract in slither.contracts: for function in contract.functions_entry_points: all_cst_used: List = [] @@ -219,7 +235,9 @@ def _extract_constants( # Note: use list(set()) instead of set # As this is meant to be serialized in JSON, and JSON does not support set if all_cst_used: - ret_cst_used[contract.name][_get_name(function)] = list(set(all_cst_used)) + ret_cst_used[contract.name][_get_name(function)] = list( + set(all_cst_used) + ) if all_cst_used_in_binary: ret_cst_used_in_binary[contract.name][_get_name(function)] = { k: list(set(v)) for k, v in all_cst_used_in_binary.items() @@ -227,7 +245,9 @@ def _extract_constants( return ret_cst_used, ret_cst_used_in_binary -def _extract_function_relations(slither: SlitherCore) -> Dict[str, Dict[str, Dict[str, List[str]]]]: +def _extract_function_relations( + slither: SlitherCore, +) -> Dict[str, Dict[str, Dict[str, List[str]]]]: # contract -> function -> [functions] ret: Dict[str, Dict[str, Dict[str, List[str]]]] = defaultdict(dict) for contract in slither.contracts: @@ -241,10 +261,15 @@ def _extract_function_relations(slither: SlitherCore) -> Dict[str, Dict[str, Dic for function in contract.functions_entry_points } for function in contract.functions_entry_points: - ret[contract.name][_get_name(function)] = {"impacts": [], "is_impacted_by": []} + ret[contract.name][_get_name(function)] = { + "impacts": [], + "is_impacted_by": [], + } for candidate, varsWritten in written.items(): if any((r in varsWritten for r in function.all_state_variables_read())): - ret[contract.name][_get_name(function)]["is_impacted_by"].append(candidate) + ret[contract.name][_get_name(function)]["is_impacted_by"].append( + candidate + ) for candidate, varsRead in read.items(): if any((r in varsRead for r in function.all_state_variables_written())): ret[contract.name][_get_name(function)]["impacts"].append(candidate) @@ -292,7 +317,7 @@ def _call_a_parameter(slither: SlitherCore) -> Dict[str, List[Dict]]: """ # contract -> [ (function, idx, interface_called) ] ret: Dict[str, List[Dict]] = defaultdict(list) - for contract in slither.contracts: + for contract in slither.contracts: # pylint: disable=too-many-nested-blocks for function in contract.functions_entry_points: for ir in function.all_slithir_operations(): if isinstance(ir, HighLevelCall): @@ -324,7 +349,7 @@ class Echidna(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#echidna" - def output(self, filename): + def output(self, filename): # pylint: disable=too-many-locals """ Output the inheritance relation diff --git a/slither/printers/inheritance/inheritance.py b/slither/printers/inheritance/inheritance.py index 79d009f83..6ad63ddf8 100644 --- a/slither/printers/inheritance/inheritance.py +++ b/slither/printers/inheritance/inheritance.py @@ -12,7 +12,9 @@ class PrinterInheritance(AbstractPrinter): ARGUMENT = "inheritance" HELP = "Print the inheritance relations between contracts" - WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#inheritance" + WIKI = ( + "https://github.com/trailofbits/slither/wiki/Printer-documentation#inheritance" + ) def _get_child_contracts(self, base): # Generate function to get all child contracts of a base contract @@ -31,7 +33,7 @@ class PrinterInheritance(AbstractPrinter): info = "Inheritance\n" if not self.contracts: - return + return [] info += blue("Child_Contract -> ") + green("Immediate_Base_Contracts") info += green(" [Not_Immediate_Base_Contracts]") @@ -49,14 +51,18 @@ class PrinterInheritance(AbstractPrinter): not_immediate = [i for i in child.inheritance if i not in immediate] info += " -> " + green(", ".join(map(str, immediate))) + "\n" - result["child_to_base"][child.name]["immediate"] = list(map(str, immediate)) + result["child_to_base"][child.name]["immediate"] = list( + map(str, immediate) + ) if not_immediate: info += ", [" + green(", ".join(map(str, not_immediate))) + "]\n" result["child_to_base"][child.name]["not_immediate"] = list( map(str, not_immediate) ) - info += green("\n\nBase_Contract -> ") + blue("Immediate_Child_Contracts") + "\n" + info += ( + green("\n\nBase_Contract -> ") + blue("Immediate_Child_Contracts") + "\n" + ) info += blue(" [Not_Immediate_Child_Contracts]") + "\n" result["base_to_child"] = {} @@ -68,14 +74,22 @@ class PrinterInheritance(AbstractPrinter): result["base_to_child"][base.name] = {"immediate": [], "not_immediate": []} if children: - immediate = [child for child in children if base in child.immediate_inheritance] + immediate = [ + child for child in children if base in child.immediate_inheritance + ] not_immediate = [child for child in children if not child in immediate] info += " -> " + blue(", ".join(map(str, immediate))) + "\n" - result["base_to_child"][base.name]["immediate"] = list(map(str, immediate)) + result["base_to_child"][base.name]["immediate"] = list( + map(str, immediate) + ) if not_immediate: - info += ", [" + blue(", ".join(map(str, not_immediate))) + "]" + "\n" - result["base_to_child"][base.name]["not_immediate"] = list(map(str, immediate)) + info += ( + ", [" + blue(", ".join(map(str, not_immediate))) + "]" + "\n" + ) + result["base_to_child"][base.name]["not_immediate"] = list( + map(str, immediate) + ) self.info(info) res = self.generate_output(info, additional_fields=result) diff --git a/slither/printers/inheritance/inheritance_graph.py b/slither/printers/inheritance/inheritance_graph.py index a560abe25..624467917 100644 --- a/slither/printers/inheritance/inheritance_graph.py +++ b/slither/printers/inheritance/inheritance_graph.py @@ -15,6 +15,22 @@ from slither.utils.inheritance_analysis import ( ) +def _get_pattern_func(func): + # Html pattern, each line is a row in a table + func_name = func.full_name + pattern = ' %s' + pattern_shadow = ( + ' %s' + ) + if func.shadows: + return pattern_shadow % func_name + return pattern % func_name + + +def _get_port_id(var, contract): + return "%s%s" % (var.name, contract.name) + + class PrinterInheritanceGraph(AbstractPrinter): ARGUMENT = "inheritance-graph" HELP = "Export the inheritance graph of each contract to a dot file" @@ -25,7 +41,7 @@ class PrinterInheritanceGraph(AbstractPrinter): super(PrinterInheritanceGraph, self).__init__(slither, logger) inheritance = [x.inheritance for x in slither.contracts] - self.inheritance = set([item for sublist in inheritance for item in sublist]) + self.inheritance = {item for sublist in inheritance for item in sublist} self.overshadowing_state_variables = {} shadows = detect_state_variable_shadowing(slither.contracts) @@ -36,37 +52,30 @@ class PrinterInheritanceGraph(AbstractPrinter): # Add overshadowing variable entry. if overshadowing_state_var not in self.overshadowing_state_variables: self.overshadowing_state_variables[overshadowing_state_var] = set() - self.overshadowing_state_variables[overshadowing_state_var].add(overshadowed_state_var) - - def _get_pattern_func(self, func, contract): - # Html pattern, each line is a row in a table - func_name = func.full_name - pattern = ' %s' - pattern_shadow = ' %s' - if func.shadows: - return pattern_shadow % func_name - return pattern % func_name + self.overshadowing_state_variables[overshadowing_state_var].add( + overshadowed_state_var + ) - def _get_pattern_var(self, var, contract): + def _get_pattern_var(self, var): # Html pattern, each line is a row in a table var_name = var.name pattern = ' %s' - pattern_contract = ( - ' %s (%s)' + pattern_contract = ' %s (%s)' + pattern_shadow = ( + ' %s' ) - pattern_shadow = ' %s' pattern_contract_shadow = ' %s (%s)' - if isinstance(var.type, UserDefinedType) and isinstance(var.type.type, Contract): + if isinstance(var.type, UserDefinedType) and isinstance( + var.type.type, Contract + ): if var in self.overshadowing_state_variables: return pattern_contract_shadow % (var_name, var.type.type.name) - else: - return pattern_contract % (var_name, var.type.type.name) - else: - if var in self.overshadowing_state_variables: - return pattern_shadow % var_name - else: - return pattern % var_name + return pattern_contract % (var_name, var.type.type.name) + + if var in self.overshadowing_state_variables: + return pattern_shadow % var_name + return pattern % var_name @staticmethod def _get_indirect_shadowing_information(contract): @@ -89,9 +98,6 @@ class PrinterInheritanceGraph(AbstractPrinter): ) return "\n".join(result) - def _get_port_id(self, var, contract): - return "%s%s" % (var.name, contract.name) - def _summary(self, contract): """ Build summary using HTML @@ -112,7 +118,7 @@ class PrinterInheritanceGraph(AbstractPrinter): # Functions visibilities = ["public", "external"] public_functions = [ - self._get_pattern_func(f, contract) + _get_pattern_func(f) for f in contract.functions if not f.is_constructor and not f.is_constructor_variables @@ -121,7 +127,7 @@ class PrinterInheritanceGraph(AbstractPrinter): ] public_functions = "".join(public_functions) private_functions = [ - self._get_pattern_func(f, contract) + _get_pattern_func(f) for f in contract.functions if not f.is_constructor and not f.is_constructor_variables @@ -132,7 +138,7 @@ class PrinterInheritanceGraph(AbstractPrinter): # Modifiers modifiers = [ - self._get_pattern_func(m, contract) + _get_pattern_func(m) for m in contract.modifiers if m.contract_declarer == contract ] @@ -140,21 +146,23 @@ class PrinterInheritanceGraph(AbstractPrinter): # Public variables public_variables = [ - self._get_pattern_var(v, contract) + self._get_pattern_var(v) for v in contract.state_variables_declared if v.visibility in visibilities ] public_variables = "".join(public_variables) private_variables = [ - self._get_pattern_var(v, contract) + self._get_pattern_var(v) for v in contract.state_variables_declared if v.visibility not in visibilities ] private_variables = "".join(private_variables) # Obtain any indirect shadowing information for this node. - indirect_shadowing_information = self._get_indirect_shadowing_information(contract) + indirect_shadowing_information = self._get_indirect_shadowing_information( + contract + ) # Build the node label ret += '%s[shape="box"' % contract.name diff --git a/slither/printers/summary/constructor_calls.py b/slither/printers/summary/constructor_calls.py index c6205b02c..38c0c0089 100644 --- a/slither/printers/summary/constructor_calls.py +++ b/slither/printers/summary/constructor_calls.py @@ -6,7 +6,9 @@ from slither.utils import output class ConstructorPrinter(AbstractPrinter): - WIKI = "https://github.com/crytic/slither/wiki/Printer-documentation#constructor-calls" + WIKI = ( + "https://github.com/crytic/slither/wiki/Printer-documentation#constructor-calls" + ) ARGUMENT = "constructor-calls" HELP = "Print the constructors executed" diff --git a/slither/printers/summary/contract.py b/slither/printers/summary/contract.py index 7196ad7ae..36bbc482b 100644 --- a/slither/printers/summary/contract.py +++ b/slither/printers/summary/contract.py @@ -13,7 +13,7 @@ class ContractSummary(AbstractPrinter): WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#contract-summary" - def output(self, _filename): + def output(self, _filename): # pylint: disable=too-many-locals """ _filename is not used Args: @@ -71,14 +71,25 @@ class ContractSummary(AbstractPrinter): for function in functions: if function.visibility in ["external", "public"]: txt += green( - " - {} ({})\n".format(function.full_name, function.visibility) + " - {} ({})\n".format( + function.full_name, function.visibility + ) ) if function.visibility in ["internal", "private"]: txt += magenta( - " - {} ({})\n".format(function.full_name, function.visibility) + " - {} ({})\n".format( + function.full_name, function.visibility + ) + ) + if function.visibility not in [ + "external", + "public", + "internal", + "private", + ]: + txt += " - {}  ({})\n".format( + function.full_name, function.visibility ) - if function.visibility not in ["external", "public", "internal", "private"]: - txt += " - {}  ({})\n".format(function.full_name, function.visibility) additional_fields.add( function, additional_fields={"visibility": function.visibility} diff --git a/slither/printers/summary/data_depenency.py b/slither/printers/summary/data_depenency.py index cef518119..78dace88e 100644 --- a/slither/printers/summary/data_depenency.py +++ b/slither/printers/summary/data_depenency.py @@ -9,15 +9,11 @@ from slither.utils.myprettytable import MyPrettyTable def _get(v, c): - return list( - set( - [ - d.name - for d in get_dependencies(v, c) - if not isinstance(d, (TemporaryVariable, ReferenceVariable)) - ] - ) - ) + return list({ + d.name + for d in get_dependencies(v, c) + if not isinstance(d, (TemporaryVariable, ReferenceVariable)) + }) class DataDependency(AbstractPrinter): diff --git a/slither/printers/summary/evm.py b/slither/printers/summary/evm.py index c7eb472f0..540f9bd38 100644 --- a/slither/printers/summary/evm.py +++ b/slither/printers/summary/evm.py @@ -2,10 +2,57 @@ Module printing evm mapping of the contract """ from slither.printers.abstract_printer import AbstractPrinter -from slither.analyses.evm import generate_source_to_evm_ins_mapping, load_evm_cfg_builder +from slither.analyses.evm import ( + generate_source_to_evm_ins_mapping, + load_evm_cfg_builder, +) from slither.utils.colors import blue, green, magenta, red +def _extract_evm_info(slither): + """ + Extract evm information for all derived contracts using evm_cfg_builder + + Returns: evm CFG and Solidity source to Program Counter (pc) mapping + """ + + evm_info = {} + + CFG = load_evm_cfg_builder() + + for contract in slither.contracts_derived: + contract_bytecode_runtime = slither.crytic_compile.bytecode_runtime( + contract.name + ) + contract_srcmap_runtime = slither.crytic_compile.srcmap_runtime( + contract.name + ) + cfg = CFG(contract_bytecode_runtime) + evm_info["cfg", contract.name] = cfg + evm_info["mapping", contract.name] = generate_source_to_evm_ins_mapping( + cfg.instructions, + contract_srcmap_runtime, + slither, + contract.source_mapping["filename_absolute"], + ) + + contract_bytecode_init = slither.crytic_compile.bytecode_init(contract.name) + contract_srcmap_init = slither.crytic_compile.srcmap_init(contract.name) + cfg_init = CFG(contract_bytecode_init) + + evm_info["cfg_init", contract.name] = cfg_init + evm_info[ + "mapping_init", contract.name + ] = generate_source_to_evm_ins_mapping( + cfg_init.instructions, + contract_srcmap_init, + slither, + contract.source_mapping["filename_absolute"], + ) + + return evm_info + + class PrinterEVM(AbstractPrinter): ARGUMENT = "evm" HELP = "Print the evm instructions of nodes in functions" @@ -26,7 +73,7 @@ class PrinterEVM(AbstractPrinter): self.info(red(txt)) res = self.generate_output(txt) return res - evm_info = self._extract_evm_info(self.slither) + evm_info = _extract_evm_info(self.slither) for contract in self.slither.contracts_derived: txt += blue("Contract {}\n".format(contract.name)) @@ -55,12 +102,15 @@ class PrinterEVM(AbstractPrinter): for node in function.nodes: txt += green("\t\tNode: " + str(node) + "\n") node_source_line = ( - contract_file[0 : node.source_mapping["start"]].count("\n".encode("utf-8")) + contract_file[0 : node.source_mapping["start"]].count( + "\n".encode("utf-8") + ) + 1 ) txt += green( "\t\tSource line {}: {}\n".format( - node_source_line, contract_file_lines[node_source_line - 1].rstrip() + node_source_line, + contract_file_lines[node_source_line - 1].rstrip(), ) ) txt += magenta("\t\tEVM Instructions:\n") @@ -77,12 +127,15 @@ class PrinterEVM(AbstractPrinter): for node in modifier.nodes: txt += green("\t\tNode: " + str(node) + "\n") node_source_line = ( - contract_file[0 : node.source_mapping["start"]].count("\n".encode("utf-8")) + contract_file[0 : node.source_mapping["start"]].count( + "\n".encode("utf-8") + ) + 1 ) txt += green( "\t\tSource line {}: {}\n".format( - node_source_line, contract_file_lines[node_source_line - 1].rstrip() + node_source_line, + contract_file_lines[node_source_line - 1].rstrip(), ) ) txt += magenta("\t\tEVM Instructions:\n") @@ -97,40 +150,3 @@ class PrinterEVM(AbstractPrinter): self.info(txt) res = self.generate_output(txt) return res - - def _extract_evm_info(self, slither): - """ - Extract evm information for all derived contracts using evm_cfg_builder - - Returns: evm CFG and Solidity source to Program Counter (pc) mapping - """ - - evm_info = {} - - CFG = load_evm_cfg_builder() - - for contract in slither.contracts_derived: - contract_bytecode_runtime = slither.crytic_compile.bytecode_runtime(contract.name) - contract_srcmap_runtime = slither.crytic_compile.srcmap_runtime(contract.name) - cfg = CFG(contract_bytecode_runtime) - evm_info["cfg", contract.name] = cfg - evm_info["mapping", contract.name] = generate_source_to_evm_ins_mapping( - cfg.instructions, - contract_srcmap_runtime, - slither, - contract.source_mapping["filename_absolute"], - ) - - contract_bytecode_init = slither.crytic_compile.bytecode_init(contract.name) - contract_srcmap_init = slither.crytic_compile.srcmap_init(contract.name) - cfg_init = CFG(contract_bytecode_init) - - evm_info["cfg_init", contract.name] = cfg_init - evm_info["mapping_init", contract.name] = generate_source_to_evm_ins_mapping( - cfg_init.instructions, - contract_srcmap_init, - slither, - contract.source_mapping["filename_absolute"], - ) - - return evm_info diff --git a/slither/printers/summary/function.py b/slither/printers/summary/function.py index 31d6ce75f..5c14d0681 100644 --- a/slither/printers/summary/function.py +++ b/slither/printers/summary/function.py @@ -22,7 +22,7 @@ class FunctionSummary(AbstractPrinter): return "\n".join(l) return str(l) - def output(self, _filename): + def output(self, _filename): # pylint: disable=too-many-locals """ _filename is not used Args: @@ -65,11 +65,26 @@ class FunctionSummary(AbstractPrinter): internal_calls = self._convert(internal_calls) external_calls = self._convert(external_calls) table.add_row( - [f_name, visi, modifiers, read, write, internal_calls, external_calls] + [ + f_name, + visi, + modifiers, + read, + write, + internal_calls, + external_calls, + ] ) txt += "\n \n" + str(table) table = MyPrettyTable( - ["Modifiers", "Visibility", "Read", "Write", "Internal Calls", "External Calls"] + [ + "Modifiers", + "Visibility", + "Read", + "Write", + "Internal Calls", + "External Calls", + ] ) for ( _c_name, @@ -85,7 +100,9 @@ class FunctionSummary(AbstractPrinter): write = self._convert(write) internal_calls = self._convert(internal_calls) external_calls = self._convert(external_calls) - table.add_row([f_name, visi, read, write, internal_calls, external_calls]) + table.add_row( + [f_name, visi, read, write, internal_calls, external_calls] + ) txt += "\n\n" + str(table) txt += "\n" self.info(txt) diff --git a/slither/printers/summary/function_ids.py b/slither/printers/summary/function_ids.py index 50dbafd87..ad5788260 100644 --- a/slither/printers/summary/function_ids.py +++ b/slither/printers/summary/function_ids.py @@ -11,7 +11,9 @@ class FunctionIds(AbstractPrinter): ARGUMENT = "function-id" HELP = "Print the keccack256 signature of the functions" - WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#function-id" + WIKI = ( + "https://github.com/trailofbits/slither/wiki/Printer-documentation#function-id" + ) def output(self, _filename): """ @@ -28,7 +30,9 @@ class FunctionIds(AbstractPrinter): for function in contract.functions: if function.visibility in ["public", "external"]: function_id = get_function_id(function.solidity_signature) - table.add_row([function.solidity_signature, f"{function_id:#0{10}x}"]) + table.add_row( + [function.solidity_signature, f"{function_id:#0{10}x}"] + ) for variable in contract.state_variables: if variable.visibility in ["public"]: sig = variable.function_name diff --git a/slither/printers/summary/human_summary.py b/slither/printers/summary/human_summary.py index 77fc25581..78b867566 100644 --- a/slither/printers/summary/human_summary.py +++ b/slither/printers/summary/human_summary.py @@ -8,7 +8,13 @@ from typing import Tuple, List, Dict from slither.core.declarations import SolidityFunction, Function from slither.core.variables.state_variable import StateVariable from slither.printers.abstract_printer import AbstractPrinter -from slither.slithir.operations import LowLevelCall, HighLevelCall, Transfer, Send, SolidityCall +from slither.slithir.operations import ( + LowLevelCall, + HighLevelCall, + Transfer, + Send, + SolidityCall, +) from slither.utils import output from slither.utils.code_complexity import compute_cyclomatic_complexity from slither.utils.colors import green, red, yellow @@ -41,7 +47,8 @@ class PrinterHumanSummary(AbstractPrinter): mint_unlimited = None # no minting race_condition_mitigated = ( - "increaseApproval" in functions_name or "safeIncreaseAllowance" in functions_name + "increaseApproval" in functions_name + or "safeIncreaseAllowance" in functions_name ) return pause, mint_unlimited, race_condition_mitigated @@ -49,7 +56,9 @@ class PrinterHumanSummary(AbstractPrinter): def get_summary_erc20(self, contract): txt = "" - pause, mint_unlimited, race_condition_mitigated = self._get_summary_erc20(contract) + pause, mint_unlimited, race_condition_mitigated = self._get_summary_erc20( + contract + ) if pause: txt += yellow("Pausable") + "\n" @@ -80,11 +89,15 @@ class PrinterHumanSummary(AbstractPrinter): issues_optimization = [c.detect() for c in checks_optimization] issues_optimization = [c for c in issues_optimization if c] - issues_optimization = [item for sublist in issues_optimization for item in sublist] + issues_optimization = [ + item for sublist in issues_optimization for item in sublist + ] issues_informational = [c.detect() for c in checks_informational] issues_informational = [c for c in issues_informational if c] - issues_informational = [item for sublist in issues_informational for item in sublist] + issues_informational = [ + item for sublist in issues_informational for item in sublist + ] issues_low = [c.detect() for c in checks_low] issues_low = [c for c in issues_low if c] @@ -99,7 +112,11 @@ class PrinterHumanSummary(AbstractPrinter): issues_high = [item for sublist in issues_high for item in sublist] all_results = ( - issues_optimization + issues_informational + issues_low + issues_medium + issues_high + issues_optimization + + issues_informational + + issues_low + + issues_medium + + issues_high ) return ( @@ -112,7 +129,14 @@ class PrinterHumanSummary(AbstractPrinter): ) def get_detectors_result(self) -> Tuple[str, List[Dict], int, int, int, int, int]: - all_results, optimization, informational, low, medium, high = self._get_detectors_result() + ( + all_results, + optimization, + informational, + low, + medium, + high, + ) = self._get_detectors_result() txt = "Number of optimization issues: {}\n".format(green(optimization)) txt += "Number of informational issues: {}\n".format(green(informational)) txt += "Number of low issues: {}\n".format(green(low)) @@ -191,7 +215,7 @@ class PrinterHumanSummary(AbstractPrinter): def _number_contracts(self): if self.slither.crytic_compile is None: - len(self.slither.contracts), 0 + return len(self.slither.contracts), 0, 0 contracts = [c for c in self.slither.contracts if not c.is_top_level] deps = [c for c in contracts if c.is_from_dependency()] tests = [c for c in contracts if c.is_test] @@ -212,7 +236,7 @@ class PrinterHumanSummary(AbstractPrinter): ercs += contract.ercs() return list(set(ercs)) - def _get_features(self, contract): + def _get_features(self, contract): # pylint: disable=too-many-branches has_payable = False can_send_eth = False @@ -241,7 +265,10 @@ class PrinterHumanSummary(AbstractPrinter): has_assembly = True for ir in function.slithir_operations: - if isinstance(ir, (LowLevelCall, HighLevelCall, Send, Transfer)) and ir.call_value: + if ( + isinstance(ir, (LowLevelCall, HighLevelCall, Send, Transfer)) + and ir.call_value + ): can_send_eth = True if isinstance(ir, SolidityCall) and ir.function in [ SolidityFunction("suicide(address)"), @@ -277,7 +304,7 @@ class PrinterHumanSummary(AbstractPrinter): "Proxy": contract.is_upgradeable_proxy, } - def output(self, _filename): + def output(self, _filename): # pylint: disable=too-many-locals,too-many-statements """ _filename is not used Args: @@ -308,7 +335,11 @@ class PrinterHumanSummary(AbstractPrinter): txt += f"Number of assembly lines: {total_asm_lines}\n" results["number_lines_assembly"] = total_asm_lines - number_contracts, number_contracts_deps, number_contracts_tests = self._number_contracts() + ( + number_contracts, + number_contracts_deps, + number_contracts_tests, + ) = self._number_contracts() txt += f"Number of contracts: {number_contracts} (+ {number_contracts_deps} in dependencies, + {number_contracts_tests} tests) \n\n" ( @@ -358,10 +389,23 @@ class PrinterHumanSummary(AbstractPrinter): erc20_info += self.get_summary_erc20(contract) features = "\n".join( - [name for name, to_print in self._get_features(contract).items() if to_print] + [ + name + for name, to_print in self._get_features(contract).items() + if to_print + ] ) - table.add_row([contract.name, number_functions, ercs, erc20_info, is_complex, features]) + table.add_row( + [ + contract.name, + number_functions, + ercs, + erc20_info, + is_complex, + features, + ] + ) self.info(txt + "\n" + str(table)) @@ -376,11 +420,15 @@ class PrinterHumanSummary(AbstractPrinter): "is_erc20": contract.is_erc20(), "number_functions": self._number_functions(contract), "features": [ - name for name, to_print in self._get_features(contract).items() if to_print + name + for name, to_print in self._get_features(contract).items() + if to_print ], } if contract_d["is_erc20"]: - pause, mint_limited, race_condition_mitigated = self._get_summary_erc20(contract) + pause, mint_limited, race_condition_mitigated = self._get_summary_erc20( + contract + ) contract_d["erc20_pause"] = pause if mint_limited is not None: contract_d["erc20_can_mint"] = True diff --git a/slither/printers/summary/require_calls.py b/slither/printers/summary/require_calls.py index 6b68f99bb..c5bf49217 100644 --- a/slither/printers/summary/require_calls.py +++ b/slither/printers/summary/require_calls.py @@ -46,7 +46,10 @@ class RequireOrAssert(AbstractPrinter): ] require = [ir.node for ir in require] table.add_row( - [function.name, self._convert([str(m.expression) for m in set(require)])] + [ + function.name, + self._convert([str(m.expression) for m in set(require)]), + ] ) txt += "\n" + str(table) self.info(txt) diff --git a/slither/printers/summary/slithir_ssa.py b/slither/printers/summary/slithir_ssa.py index 8c01c78ec..15566fa69 100644 --- a/slither/printers/summary/slithir_ssa.py +++ b/slither/printers/summary/slithir_ssa.py @@ -10,7 +10,9 @@ class PrinterSlithIRSSA(AbstractPrinter): ARGUMENT = "slithir-ssa" HELP = "Print the slithIR representation of the functions" - WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#slithir-ssa" + WIKI = ( + "https://github.com/trailofbits/slither/wiki/Printer-documentation#slithir-ssa" + ) def output(self, _filename): """ diff --git a/slither/printers/summary/variable_order.py b/slither/printers/summary/variable_order.py index 3cb26cfa5..4f633ee44 100644 --- a/slither/printers/summary/variable_order.py +++ b/slither/printers/summary/variable_order.py @@ -30,7 +30,9 @@ class VariableOrder(AbstractPrinter): for variable in contract.state_variables_ordered: if not variable.is_constant: slot, offset = self.slither.storage_layout_of(contract, variable) - table.add_row([variable.canonical_name, str(variable.type), slot, offset]) + table.add_row( + [variable.canonical_name, str(variable.type), slot, offset] + ) all_tables.append((contract.name, table)) txt += str(table) + "\n" diff --git a/slither/slither.py b/slither/slither.py index b44ad66d8..bbc854bd0 100644 --- a/slither/slither.py +++ b/slither/slither.py @@ -16,7 +16,20 @@ logger_detector = logging.getLogger("Detectors") logger_printer = logging.getLogger("Printers") -class Slither(SlitherCore): +def _check_common_things(thing_name, cls, base_cls, instances_list): + + if not issubclass(cls, base_cls) or cls is base_cls: + raise Exception( + "You can't register {!r} as a {}. You need to pass a class that inherits from {}".format( + cls, thing_name, base_cls.__name__ + ) + ) + + if any(type(obj) == cls for obj in instances_list): # pylint: disable=unidiomatic-typecheck + raise Exception("You can't register {!r} twice.".format(cls)) + + +class Slither(SlitherCore): # pylint: disable=too-many-instance-attributes def __init__(self, target, **kwargs): """ Args: @@ -93,7 +106,8 @@ class Slither(SlitherCore): with open(filename, encoding="utf8") as astFile: stdout = astFile.read() if not stdout: - raise SlitherError("Empty AST file: %s", filename) + to_log = f"Empty AST file: {filename}" + raise SlitherError(to_log) contracts_json = stdout.split("\n=") self._parser = SlitherSolc(filename, self) @@ -128,17 +142,25 @@ class Slither(SlitherCore): @property def detectors_informational(self): - return [d for d in self.detectors if d.IMPACT == DetectorClassification.INFORMATIONAL] + return [ + d + for d in self.detectors + if d.IMPACT == DetectorClassification.INFORMATIONAL + ] @property def detectors_optimization(self): - return [d for d in self.detectors if d.IMPACT == DetectorClassification.OPTIMIZATION] + return [ + d for d in self.detectors if d.IMPACT == DetectorClassification.OPTIMIZATION + ] def register_detector(self, detector_class): """ :param detector_class: Class inheriting from `AbstractDetector`. """ - self._check_common_things("detector", detector_class, AbstractDetector, self._detectors) + _check_common_things( + "detector", detector_class, AbstractDetector, self._detectors + ) instance = detector_class(self, logger_detector) self._detectors.append(instance) @@ -147,7 +169,9 @@ class Slither(SlitherCore): """ :param printer_class: Class inheriting from `AbstractPrinter`. """ - self._check_common_things("printer", printer_class, AbstractPrinter, self._printers) + _check_common_things( + "printer", printer_class, AbstractPrinter, self._printers + ) instance = printer_class(self, logger_printer) self._printers.append(instance) @@ -169,32 +193,6 @@ class Slither(SlitherCore): return [p.output(self.filename).data for p in self._printers] - def _check_common_things(self, thing_name, cls, base_cls, instances_list): - - if not issubclass(cls, base_cls) or cls is base_cls: - raise Exception( - "You can't register {!r} as a {}. You need to pass a class that inherits from {}".format( - cls, thing_name, base_cls.__name__ - ) - ) - - if any(type(obj) == cls for obj in instances_list): - raise Exception("You can't register {!r} twice.".format(cls)) - - def _run_solc(self, filename, solc, disable_solc_warnings, solc_arguments, ast_format): - if not os.path.isfile(filename): - raise SlitherError( - "{} does not exist (are you in the correct directory?)".format(filename) - ) - assert filename.endswith("json") - with open(filename, encoding="utf8") as astFile: - stdout = astFile.read() - if not stdout: - raise SlitherError("Empty AST file: %s", filename) - stdout = stdout.split("\n=") - - return stdout - @property def triage_mode(self): return self._triage_mode diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index 5149a6331..16a404d24 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -1,6 +1,7 @@ import logging from typing import List +# pylint: disable= too-many-lines,import-outside-toplevel,too-many-branches,too-many-statements,too-many-nested-blocks from slither.core.declarations import ( Contract, Enum, @@ -87,7 +88,10 @@ def convert_expression(expression, node): cond.set_node(node) result = [cond] return result - if isinstance(expression, Identifier) and node.type in [NodeType.IF, NodeType.IFLOOP]: + if isinstance(expression, Identifier) and node.type in [ + NodeType.IF, + NodeType.IFLOOP, + ]: cond = Condition(expression.value) cond.set_expression(expression) cond.set_node(node) @@ -169,7 +173,9 @@ def get_canonical_names(ir, function_name, contract_name): # list of list of arguments argss = convert_arguments(ir.arguments) - return [sig.format(f"{contract_name}.{function_name}", ",".join(args)) for args in argss] + return [ + sig.format(f"{contract_name}.{function_name}", ",".join(args)) for args in argss + ] def convert_arguments(arguments): @@ -202,7 +208,8 @@ def convert_arguments(arguments): def is_temporary(ins): return isinstance( - ins, (Argument, TmpNewElementaryType, TmpNewContract, TmpNewArray, TmpNewStructure) + ins, + (Argument, TmpNewElementaryType, TmpNewContract, TmpNewArray, TmpNewStructure), ) @@ -240,8 +247,7 @@ def integrate_value_gas(result): variable_to_replace = {} # Replace call to value, gas to an argument of the real call - for idx in range(len(result)): - ins = result[idx] + for idx, ins in enumerate(result): # value can be shadowed, so we check that the prev ins # is an Argument if is_value(ins) and isinstance(result[idx - 1], Argument): @@ -276,7 +282,7 @@ def integrate_value_gas(result): was_changed = True ins.call_id = variable_to_replace[ins.call_id].name - calls = list(set([str(c) for c in calls])) + calls = list({str(c) for c in calls}) idx = 0 calls_d = {} for call in calls: @@ -395,7 +401,9 @@ def _convert_type_contract(ir, slither): "The codebase uses type(x).creationCode, but crytic-compile was not used. As a result, the bytecode cannot be found" ) bytecode = "MISSING_BYTECODE" - assignment = Assignment(ir.lvalue, Constant(str(bytecode)), ElementaryType("bytes")) + assignment = Assignment( + ir.lvalue, Constant(str(bytecode)), ElementaryType("bytes") + ) assignment.set_expression(ir.expression) assignment.set_node(ir.node) assignment.lvalue.set_type(ElementaryType("bytes")) @@ -408,7 +416,9 @@ def _convert_type_contract(ir, slither): "The codebase uses type(x).runtimeCode, but crytic-compile was not used. As a result, the bytecode cannot be found" ) bytecode = "MISSING_BYTECODE" - assignment = Assignment(ir.lvalue, Constant(str(bytecode)), ElementaryType("bytes")) + assignment = Assignment( + ir.lvalue, Constant(str(bytecode)), ElementaryType("bytes") + ) assignment.set_expression(ir.expression) assignment.set_node(ir.node) assignment.lvalue.set_type(ElementaryType("bytes")) @@ -420,7 +430,7 @@ def _convert_type_contract(ir, slither): interfaceId = interfaceId ^ get_function_id(entry_point.full_name) assignment = Assignment( ir.lvalue, - Constant(str(interfaceId), type=ElementaryType("bytes4")), + Constant(str(interfaceId), constant_type=ElementaryType("bytes4")), ElementaryType("bytes4"), ) assignment.set_expression(ir.expression) @@ -429,7 +439,9 @@ def _convert_type_contract(ir, slither): return assignment if ir.variable_right == "name": - assignment = Assignment(ir.lvalue, Constant(contract.name), ElementaryType("string")) + assignment = Assignment( + ir.lvalue, Constant(contract.name), ElementaryType("string") + ) assignment.set_expression(ir.expression) assignment.set_node(ir.node) assignment.lvalue.set_type(ElementaryType("string")) @@ -438,7 +450,7 @@ def _convert_type_contract(ir, slither): raise SlithIRError(f"type({contract.name}).{ir.variable_right} is unknown") -def propagate_types(ir, node): +def propagate_types(ir, node): # pylint: disable=too-many-locals # propagate the type using_for = node.function.contract.using_for if isinstance(ir, OperationWithLValue): @@ -461,7 +473,7 @@ def propagate_types(ir, node): # Temporary operation (they are removed later) if t is None: - return + return None if isinstance(t, ElementaryType) and t.name == "address": if can_be_solidity_func(ir): @@ -478,7 +490,9 @@ def propagate_types(ir, node): t_type = t.type if isinstance(t_type, Contract): contract = node.slither.get_contract_from_name(t_type.name) - return convert_type_of_high_and_internal_level_call(ir, contract) + return convert_type_of_high_and_internal_level_call( + ir, contract + ) # Convert HighLevelCall to LowLevelCall if isinstance(t, ElementaryType) and t.name == "address": @@ -513,7 +527,9 @@ def propagate_types(ir, node): elif isinstance(ir, InternalCall): # if its not a tuple, return a singleton if ir.function is None: - convert_type_of_high_and_internal_level_call(ir, node.function.contract) + convert_type_of_high_and_internal_level_call( + ir, node.function.contract + ) return_type = ir.function.return_type if return_type: if len(return_type) == 1: @@ -566,7 +582,9 @@ def propagate_types(ir, node): b.set_expression(ir.expression) b.set_node(ir.node) return b - if ir.variable_right == "selector" and isinstance(ir.variable_left.type, Function): + if ir.variable_right == "selector" and isinstance( + ir.variable_left.type, Function + ): assignment = Assignment( ir.lvalue, Constant(str(get_function_id(ir.variable_left.type.full_name))), @@ -586,7 +604,8 @@ def propagate_types(ir, node): if ( left == SolidityVariable("this") and isinstance(ir.variable_right, Constant) - and str(ir.variable_right) in [x.name for x in ir.function.contract.functions] + and str(ir.variable_right) + in [x.name for x in ir.function.contract.functions] ): # Assumption that this.function_name can only compile if # And the contract does not have two functions starting with function_name @@ -635,7 +654,12 @@ def propagate_types(ir, node): # We dont need to check for function collision, as solc prevents the use of selector # if there are multiple functions with the same name f = next( - (f for f in type_t.functions if f.name == ir.variable_right), None + ( + f + for f in type_t.functions + if f.name == ir.variable_right + ), + None, ) if f: ir.lvalue.set_type(f) @@ -697,8 +721,10 @@ def propagate_types(ir, node): # temporary operation; they will be removed pass else: - raise SlithIRError("Not handling {} during type propgation".format(type(ir))) - + raise SlithIRError( + "Not handling {} during type propgation".format(type(ir)) + ) + return None def extract_tmp_call(ins, contract): assert isinstance(ins, TmpCall) @@ -769,7 +795,10 @@ def extract_tmp_call(ins, contract): ins.called = SolidityFunction("blockhash(uint256)") elif str(ins.called) == "this.balance": s = SolidityCall( - SolidityFunction("this.balance()"), ins.nbr_arguments, ins.lvalue, ins.type_call + SolidityFunction("this.balance()"), + ins.nbr_arguments, + ins.lvalue, + ins.type_call, ) s.set_expression(ins.expression) return s @@ -867,7 +896,7 @@ def convert_to_low_level(ir): ir.set_expression(prev_ir.expression) ir.set_node(prev_ir.node) return ir - elif ir.function_name == "send": + if ir.function_name == "send": assert len(ir.arguments) == 1 prev_ir = ir ir = Send(ir.destination, ir.arguments[0], ir.lvalue) @@ -875,7 +904,7 @@ def convert_to_low_level(ir): ir.set_node(prev_ir.node) ir.lvalue.set_type(ElementaryType("bool")) return ir - elif ir.function_name in ["call", "delegatecall", "callcode", "staticcall"]: + if ir.function_name in ["call", "delegatecall", "callcode", "staticcall"]: new_ir = LowLevelCall( ir.destination, ir.function_name, ir.nbr_arguments, ir.lvalue, ir.type_call ) @@ -921,7 +950,7 @@ def convert_to_solidity_func(ir): and len(new_ir.arguments) == 2 and isinstance(new_ir.arguments[1], list) ): - types = [x for x in new_ir.arguments[1]] + types = list(new_ir.arguments[1]) new_ir.lvalue.set_type(types) # abi.decode where the type to decode is a singleton # abi.decode(a, (uint)) @@ -1021,13 +1050,17 @@ def convert_to_pop(ir, node): val = TemporaryVariable(node) - ir_sub_1 = Binary(val, length, Constant("1", ElementaryType("uint256")), BinaryType.SUBTRACTION) + ir_sub_1 = Binary( + val, length, Constant("1", ElementaryType("uint256")), BinaryType.SUBTRACTION + ) ir_sub_1.set_expression(ir.expression) ir_sub_1.set_node(ir.node) ret.append(ir_sub_1) element_to_delete = ReferenceVariable(node) - ir_assign_element_to_delete = Index(element_to_delete, arr, val, ElementaryType("uint256")) + ir_assign_element_to_delete = Index( + element_to_delete, arr, val, ElementaryType("uint256") + ) ir_length.lvalue.points_to = arr element_to_delete.set_type(ElementaryType("uint256")) ir_assign_element_to_delete.set_expression(ir.expression) @@ -1055,12 +1088,16 @@ def convert_to_pop(ir, node): return ret -def look_for_library(contract, ir, node, using_for, t): +def look_for_library(contract, ir, using_for, t): for destination in using_for[t]: lib_contract = contract.slither.get_contract_from_name(str(destination)) if lib_contract: lib_call = LibraryCall( - lib_contract, ir.function_name, ir.nbr_arguments, ir.lvalue, ir.type_call + lib_contract, + ir.function_name, + ir.nbr_arguments, + ir.lvalue, + ir.type_call, ) lib_call.set_expression(ir.expression) lib_call.set_node(ir.node) @@ -1080,12 +1117,12 @@ def convert_to_library(ir, node, using_for): contract = node.function.contract_declarer t = ir.destination.type if t in using_for: - new_ir = look_for_library(contract, ir, node, using_for, t) + new_ir = look_for_library(contract, ir, using_for, t) if new_ir: return new_ir if "*" in using_for: - new_ir = look_for_library(contract, ir, node, using_for, "*") + new_ir = look_for_library(contract, ir, using_for, "*") if new_ir: return new_ir @@ -1120,7 +1157,9 @@ def convert_type_library_call(ir, lib_contract): # myFunc(uint) # can be called with an uint8 for function in lib_contract.functions: - if function.name == ir.function_name and len(function.parameters) == len(ir.arguments): + if function.name == ir.function_name and len(function.parameters) == len( + ir.arguments + ): func = function break if not func: @@ -1149,7 +1188,9 @@ def _convert_to_structure_to_list(return_type: Type) -> List[Type]: :param return_type: :return: """ - if isinstance(return_type, UserDefinedType) and isinstance(return_type.type, Structure): + if isinstance(return_type, UserDefinedType) and isinstance( + return_type.type, Structure + ): ret = [] for v in return_type.type.elems_ordered: ret += _convert_to_structure_to_list(v.type) @@ -1207,7 +1248,9 @@ def convert_type_of_high_and_internal_level_call(ir, contract): # myFunc(uint) # can be called with an uint8 for function in contract.functions: - if function.name == ir.function_name and len(function.parameters) == len(ir.arguments): + if function.name == ir.function_name and len(function.parameters) == len( + ir.arguments + ): func = function break # lowlelvel lookup needs to be done at last step @@ -1217,7 +1260,8 @@ def convert_type_of_high_and_internal_level_call(ir, contract): if can_be_solidity_func(ir): return convert_to_solidity_func(ir) if not func: - logger.error("Function not found {}".format(sig)) + to_log = "Function not found {}".format(sig) + logger.error(to_log) ir.function = func if isinstance(func, Function): return_type = func.return_type @@ -1292,7 +1336,14 @@ def remove_temporary(result): ins for ins in result if not isinstance( - ins, (Argument, TmpNewElementaryType, TmpNewContract, TmpNewArray, TmpNewStructure) + ins, + ( + Argument, + TmpNewElementaryType, + TmpNewContract, + TmpNewArray, + TmpNewStructure, + ), ) ] @@ -1318,7 +1369,9 @@ def remove_unused(result): # and reference that are written for ins in result: to_keep += [str(x) for x in ins.read] - if isinstance(ins, OperationWithLValue) and not isinstance(ins, (Index, Member)): + if isinstance(ins, OperationWithLValue) and not isinstance( + ins, (Index, Member) + ): if isinstance(ins.lvalue, ReferenceVariable): to_keep += [str(ins.lvalue)] @@ -1356,13 +1409,12 @@ def convert_constant_types(irs): if ir.lvalue.type.type in ElementaryTypeInt: if isinstance(ir.rvalue, Function): continue - elif isinstance(ir.rvalue, TupleVariable): + if isinstance(ir.rvalue, TupleVariable): # TODO: fix missing Unpack conversion continue - else: - if ir.rvalue.type.type != "int256": - ir.rvalue.set_type(ElementaryType("int256")) - was_changed = True + if ir.rvalue.type.type != "int256": + ir.rvalue.set_type(ElementaryType("int256")) + was_changed = True if isinstance(ir, Binary): if isinstance(ir.lvalue.type, ElementaryType): if ir.lvalue.type.type in ElementaryTypeInt: diff --git a/slither/slithir/operations/assignment.py b/slither/slithir/operations/assignment.py index 66d6acd6b..7954dcf7d 100644 --- a/slither/slithir/operations/assignment.py +++ b/slither/slithir/operations/assignment.py @@ -41,7 +41,9 @@ class Assignment(OperationWithLValue): points = self.lvalue.points_to while isinstance(points, ReferenceVariable): points = points.points_to - return "{} (->{}) := {}({})".format(self.lvalue, points, self.rvalue, self.rvalue.type) + return "{} (->{}) := {}({})".format( + self.lvalue, points, self.rvalue, self.rvalue.type + ) return "{}({}) := {}({})".format( self.lvalue, self.lvalue.type, self.rvalue, self.rvalue.type ) diff --git a/slither/slithir/operations/balance.py b/slither/slithir/operations/balance.py index 45a3b0313..a8ca4c3c4 100644 --- a/slither/slithir/operations/balance.py +++ b/slither/slithir/operations/balance.py @@ -5,6 +5,7 @@ from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue class Balance(OperationWithLValue): def __init__(self, value, lvalue): + super().__init__() assert is_valid_rvalue(value) assert is_valid_lvalue(lvalue) self._value = value diff --git a/slither/slithir/operations/binary.py b/slither/slithir/operations/binary.py index 553d0e8fe..262919222 100644 --- a/slither/slithir/operations/binary.py +++ b/slither/slithir/operations/binary.py @@ -45,7 +45,7 @@ class BinaryType(Enum): ] @staticmethod - def get_type(operation_type): + def get_type(operation_type): # pylint: disable=too-many-branches if operation_type == "**": return BinaryType.POWER if operation_type == "*": @@ -85,9 +85,11 @@ class BinaryType(Enum): if operation_type == "||": return BinaryType.OROR - raise SlithIRError("get_type: Unknown operation type {})".format(operation_type)) + raise SlithIRError( + "get_type: Unknown operation type {})".format(operation_type) + ) - def __str__(self): + def __str__(self): # pylint: disable=too-many-branches if self == BinaryType.POWER: return "**" if self == BinaryType.MULTIPLICATION: @@ -126,7 +128,9 @@ class BinaryType(Enum): return "&&" if self == BinaryType.OROR: return "||" - raise SlithIRError("str: Unknown operation type {} {})".format(self, type(self))) + raise SlithIRError( + "str: Unknown operation type {} {})".format(self, type(self)) + ) class Binary(OperationWithLValue): @@ -174,7 +178,11 @@ class Binary(OperationWithLValue): while isinstance(points, ReferenceVariable): points = points.points_to return "{}(-> {}) = {} {} {}".format( - str(self.lvalue), points, self.variable_left, self.type_str, self.variable_right + str(self.lvalue), + points, + self.variable_left, + self.type_str, + self.variable_right, ) return "{}({}) = {} {} {}".format( str(self.lvalue), diff --git a/slither/slithir/operations/call.py b/slither/slithir/operations/call.py index bc57eab8f..8f851d831 100644 --- a/slither/slithir/operations/call.py +++ b/slither/slithir/operations/call.py @@ -14,14 +14,14 @@ class Call(Operation): def arguments(self, v): self._arguments = v - def can_reenter(self, callstack=None): + def can_reenter(self, _callstack=None): # pylint: disable=no-self-use """ Must be called after slithIR analysis pass :return: bool """ return False - def can_send_eth(self): + def can_send_eth(self): # pylint: disable=no-self-use """ Must be called after slithIR analysis pass :return: bool diff --git a/slither/slithir/operations/codesize.py b/slither/slithir/operations/codesize.py index 10e221509..c1cab450f 100644 --- a/slither/slithir/operations/codesize.py +++ b/slither/slithir/operations/codesize.py @@ -5,6 +5,7 @@ from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue class CodeSize(OperationWithLValue): def __init__(self, value, lvalue): + super().__init__() assert is_valid_rvalue(value) assert is_valid_lvalue(lvalue) self._value = value diff --git a/slither/slithir/operations/high_level_call.py b/slither/slithir/operations/high_level_call.py index de42140a8..bb3d7b12c 100644 --- a/slither/slithir/operations/high_level_call.py +++ b/slither/slithir/operations/high_level_call.py @@ -15,6 +15,7 @@ class HighLevelCall(Call, OperationWithLValue): High level message call """ + # pylint: disable=too-many-arguments,too-many-instance-attributes def __init__(self, destination, function_name, nbr_arguments, result, type_call): assert isinstance(function_name, Constant) assert is_valid_lvalue(result) or result is None @@ -33,7 +34,7 @@ class HighLevelCall(Call, OperationWithLValue): # Development function, to be removed once the code is stable # It is ovveride by LbraryCall - def _check_destination(self, destination): + def _check_destination(self, destination): # pylint: disable=no-self-use assert isinstance(destination, (Variable, SolidityVariable)) @property @@ -62,7 +63,9 @@ class HighLevelCall(Call, OperationWithLValue): @property def read(self): - all_read = [self.destination, self.call_gas, self.call_value] + self._unroll(self.arguments) + all_read = [self.destination, self.call_gas, self.call_value] + self._unroll( + self.arguments + ) # remove None return [x for x in all_read if x] + [self.destination] @@ -106,7 +109,9 @@ class HighLevelCall(Call, OperationWithLValue): """ # If solidity >0.5, STATICCALL is used if self.slither.solc_version and self.slither.solc_version >= "0.5.0": - if isinstance(self.function, Function) and (self.function.view or self.function.pure): + if isinstance(self.function, Function) and ( + self.function.view or self.function.pure + ): return False if isinstance(self.function, Variable): return False @@ -154,7 +159,9 @@ class HighLevelCall(Call, OperationWithLValue): if not self.lvalue: lvalue = "" elif isinstance(self.lvalue.type, (list,)): - lvalue = "{}({}) = ".format(self.lvalue, ",".join(str(x) for x in self.lvalue.type)) + lvalue = "{}({}) = ".format( + self.lvalue, ",".join(str(x) for x in self.lvalue.type) + ) else: lvalue = "{}({}) = ".format(self.lvalue, self.lvalue.type) return txt.format( diff --git a/slither/slithir/operations/index.py b/slither/slithir/operations/index.py index 31093632f..c294cb6f9 100644 --- a/slither/slithir/operations/index.py +++ b/slither/slithir/operations/index.py @@ -7,9 +7,9 @@ from slither.slithir.variables.reference import ReferenceVariable class Index(OperationWithLValue): def __init__(self, result, left_variable, right_variable, index_type): super(Index, self).__init__() - assert is_valid_lvalue(left_variable) or left_variable == SolidityVariableComposed( - "msg.data" - ) + assert is_valid_lvalue( + left_variable + ) or left_variable == SolidityVariableComposed("msg.data") assert is_valid_rvalue(right_variable) assert isinstance(result, ReferenceVariable) self._variables = [left_variable, right_variable] diff --git a/slither/slithir/operations/internal_call.py b/slither/slithir/operations/internal_call.py index e2bcfbb14..af151a096 100644 --- a/slither/slithir/operations/internal_call.py +++ b/slither/slithir/operations/internal_call.py @@ -4,7 +4,7 @@ from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue -class InternalCall(Call, OperationWithLValue): +class InternalCall(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes def __init__(self, function, nbr_arguments, result, type_call): super(InternalCall, self).__init__() if isinstance(function, Function): @@ -60,7 +60,9 @@ class InternalCall(Call, OperationWithLValue): if not self.lvalue: lvalue = "" elif isinstance(self.lvalue.type, (list,)): - lvalue = "{}({}) = ".format(self.lvalue, ",".join(str(x) for x in self.lvalue.type)) + lvalue = "{}({}) = ".format( + self.lvalue, ",".join(str(x) for x in self.lvalue.type) + ) else: lvalue = "{}({}) = ".format(self.lvalue, self.lvalue.type) if self.is_modifier_call: diff --git a/slither/slithir/operations/internal_dynamic_call.py b/slither/slithir/operations/internal_dynamic_call.py index 509996ced..c401465be 100644 --- a/slither/slithir/operations/internal_dynamic_call.py +++ b/slither/slithir/operations/internal_dynamic_call.py @@ -5,7 +5,7 @@ from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue -class InternalDynamicCall(Call, OperationWithLValue): +class InternalDynamicCall(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes def __init__(self, lvalue, function, function_type): assert isinstance(function_type, FunctionType) assert isinstance(function, Variable) @@ -66,7 +66,9 @@ class InternalDynamicCall(Call, OperationWithLValue): if not self.lvalue: lvalue = "" elif isinstance(self.lvalue.type, (list,)): - lvalue = "{}({}) = ".format(self.lvalue, ",".join(str(x) for x in self.lvalue.type)) + lvalue = "{}({}) = ".format( + self.lvalue, ",".join(str(x) for x in self.lvalue.type) + ) else: lvalue = "{}({}) = ".format(self.lvalue, self.lvalue.type) txt = "{}INTERNAL_DYNAMIC_CALL {}({}) {} {}" diff --git a/slither/slithir/operations/length.py b/slither/slithir/operations/length.py index 7f134fb4d..ebc559451 100644 --- a/slither/slithir/operations/length.py +++ b/slither/slithir/operations/length.py @@ -5,6 +5,7 @@ from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue class Length(OperationWithLValue): def __init__(self, value, lvalue): + super().__init__() assert is_valid_rvalue(value) assert is_valid_lvalue(lvalue) self._value = value diff --git a/slither/slithir/operations/library_call.py b/slither/slithir/operations/library_call.py index 7947c8ae3..4cae00b60 100644 --- a/slither/slithir/operations/library_call.py +++ b/slither/slithir/operations/library_call.py @@ -33,10 +33,16 @@ class LibraryCall(HighLevelCall): if not self.lvalue: lvalue = "" elif isinstance(self.lvalue.type, (list,)): - lvalue = "{}({}) = ".format(self.lvalue, ",".join(str(x) for x in self.lvalue.type)) + lvalue = "{}({}) = ".format( + self.lvalue, ",".join(str(x) for x in self.lvalue.type) + ) else: lvalue = "{}({}) = ".format(self.lvalue, self.lvalue.type) txt = "{}LIBRARY_CALL, dest:{}, function:{}, arguments:{} {}" return txt.format( - lvalue, self.destination, self.function_name, [str(x) for x in arguments], gas + lvalue, + self.destination, + self.function_name, + [str(x) for x in arguments], + gas, ) diff --git a/slither/slithir/operations/low_level_call.py b/slither/slithir/operations/low_level_call.py index f36601a8c..eea585275 100644 --- a/slither/slithir/operations/low_level_call.py +++ b/slither/slithir/operations/low_level_call.py @@ -6,12 +6,13 @@ from slither.core.declarations.solidity_variables import SolidityVariable from slither.slithir.variables.constant import Constant -class LowLevelCall(Call, OperationWithLValue): +class LowLevelCall(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes """ High level message call """ def __init__(self, destination, function_name, nbr_arguments, result, type_call): + # pylint: disable=too-many-arguments assert isinstance(destination, (Variable, SolidityVariable)) assert isinstance(function_name, Constant) super(LowLevelCall, self).__init__() @@ -55,7 +56,7 @@ class LowLevelCall(Call, OperationWithLValue): # remove None return self._unroll([x for x in all_read if x]) - def can_reenter(self, callstack=None): + def can_reenter(self, _callstack=None): """ Must be called after slithIR analysis pass :return: bool diff --git a/slither/slithir/operations/member.py b/slither/slithir/operations/member.py index 656aaa442..06592b0d4 100644 --- a/slither/slithir/operations/member.py +++ b/slither/slithir/operations/member.py @@ -8,7 +8,9 @@ from slither.slithir.variables.reference import ReferenceVariable class Member(OperationWithLValue): def __init__(self, variable_left, variable_right, result): - assert is_valid_rvalue(variable_left) or isinstance(variable_left, (Contract, Enum)) + assert is_valid_rvalue(variable_left) or isinstance( + variable_left, (Contract, Enum) + ) assert isinstance(variable_right, Constant) assert isinstance(result, ReferenceVariable) super(Member, self).__init__() diff --git a/slither/slithir/operations/new_contract.py b/slither/slithir/operations/new_contract.py index f535e6e02..2ca6dfc2f 100644 --- a/slither/slithir/operations/new_contract.py +++ b/slither/slithir/operations/new_contract.py @@ -3,7 +3,7 @@ from slither.slithir.utils.utils import is_valid_lvalue from slither.slithir.variables.constant import Constant -class NewContract(Call, OperationWithLValue): +class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instance-attributes def __init__(self, contract_name, lvalue): assert isinstance(contract_name, Constant) assert is_valid_lvalue(lvalue) @@ -92,4 +92,6 @@ class NewContract(Call, OperationWithLValue): if self.call_salt: options += "salt:{} ".format(self.call_salt) args = [str(a) for a in self.arguments] - return "{} = new {}({}) {}".format(self.lvalue, self.contract_name, ",".join(args), options) + return "{} = new {}({}) {}".format( + self.lvalue, self.contract_name, ",".join(args), options + ) diff --git a/slither/slithir/operations/new_structure.py b/slither/slithir/operations/new_structure.py index f17acb47a..640d366ad 100644 --- a/slither/slithir/operations/new_structure.py +++ b/slither/slithir/operations/new_structure.py @@ -29,4 +29,6 @@ class NewStructure(Call, OperationWithLValue): def __str__(self): args = [str(a) for a in self.arguments] - return "{} = new {}({})".format(self.lvalue, self.structure_name, ",".join(args)) + return "{} = new {}({})".format( + self.lvalue, self.structure_name, ",".join(args) + ) diff --git a/slither/slithir/operations/operation.py b/slither/slithir/operations/operation.py index e0df640c1..34a7ab935 100644 --- a/slither/slithir/operations/operation.py +++ b/slither/slithir/operations/operation.py @@ -12,7 +12,7 @@ class AbstractOperation(abc.ABC): """ Return the list of variables READ """ - pass + pass # pylint: disable=unnecessary-pass @property @abc.abstractmethod @@ -20,7 +20,7 @@ class AbstractOperation(abc.ABC): """ Return the list of variables used """ - pass + pass # pylint: disable=unnecessary-pass class Operation(Context, ChildExpression, ChildNode, AbstractOperation): diff --git a/slither/slithir/operations/push.py b/slither/slithir/operations/push.py index dc836c5c0..db3865be7 100644 --- a/slither/slithir/operations/push.py +++ b/slither/slithir/operations/push.py @@ -1,9 +1,11 @@ +from slither.core.declarations import Function from slither.slithir.operations.lvalue import OperationWithLValue from slither.slithir.utils.utils import is_valid_lvalue, is_valid_rvalue class Push(OperationWithLValue): def __init__(self, array, value): + super().__init__() assert is_valid_rvalue(value) or isinstance(value, Function) assert is_valid_lvalue(array) self._value = value diff --git a/slither/slithir/operations/return_operation.py b/slither/slithir/operations/return_operation.py index 9297820ec..01667d42f 100644 --- a/slither/slithir/operations/return_operation.py +++ b/slither/slithir/operations/return_operation.py @@ -39,7 +39,9 @@ class Return(Operation): if isinstance(value, list): assert all(self._valid_value(v) for v in value) else: - assert is_valid_rvalue(value) or isinstance(value, (TupleVariable, Function)) + assert is_valid_rvalue(value) or isinstance( + value, (TupleVariable, Function) + ) return True @property diff --git a/slither/slithir/operations/solidity_call.py b/slither/slithir/operations/solidity_call.py index 3823bccf8..4dce69234 100644 --- a/slither/slithir/operations/solidity_call.py +++ b/slither/slithir/operations/solidity_call.py @@ -35,7 +35,10 @@ class SolidityCall(Call, OperationWithLValue): and isinstance(self.arguments[1], list) ): args = ( - str(self.arguments[0]) + "(" + ",".join([str(a) for a in self.arguments[1]]) + ")" + str(self.arguments[0]) + + "(" + + ",".join([str(a) for a in self.arguments[1]]) + + ")" ) else: args = ",".join([str(a) for a in self.arguments]) @@ -43,7 +46,9 @@ class SolidityCall(Call, OperationWithLValue): lvalue = "" if self.lvalue: if isinstance(self.lvalue.type, (list,)): - lvalue = "{}({}) = ".format(self.lvalue, ",".join(str(x) for x in self.lvalue.type)) + lvalue = "{}({}) = ".format( + self.lvalue, ",".join(str(x) for x in self.lvalue.type) + ) else: lvalue = "{}({}) = ".format(self.lvalue, self.lvalue.type) return lvalue + "SOLIDITY_CALL {}({})".format(self.function.full_name, args) diff --git a/slither/slithir/tmp_operations/tmp_call.py b/slither/slithir/tmp_operations/tmp_call.py index f56631a6b..313032b76 100644 --- a/slither/slithir/tmp_operations/tmp_call.py +++ b/slither/slithir/tmp_operations/tmp_call.py @@ -9,11 +9,18 @@ from slither.core.variables.variable import Variable from slither.slithir.operations.lvalue import OperationWithLValue -class TmpCall(OperationWithLValue): +class TmpCall(OperationWithLValue): # pylint: disable=too-many-instance-attributes def __init__(self, called, nbr_arguments, result, type_call): assert isinstance( called, - (Contract, Variable, SolidityVariableComposed, SolidityFunction, Structure, Event), + ( + Contract, + Variable, + SolidityVariableComposed, + SolidityFunction, + Structure, + Event, + ), ) super(TmpCall, self).__init__() self._called = called @@ -86,4 +93,8 @@ class TmpCall(OperationWithLValue): self._ori = ori def __str__(self): - return str(self.lvalue) + " = TMPCALL{} ".format(self.nbr_arguments) + str(self._called) + return ( + str(self.lvalue) + + " = TMPCALL{} ".format(self.nbr_arguments) + + str(self._called) + ) diff --git a/slither/slithir/utils/ssa.py b/slither/slithir/utils/ssa.py index 063f7d96b..90d85ba96 100644 --- a/slither/slithir/utils/ssa.py +++ b/slither/slithir/utils/ssa.py @@ -79,13 +79,13 @@ def transform_slithir_vars_to_ssa(function): variables += [ir.lvalue] tmp_variables = [v for v in variables if isinstance(v, TemporaryVariable)] - for idx in range(len(tmp_variables)): + for idx, _ in enumerate(tmp_variables): tmp_variables[idx].index = idx ref_variables = [v for v in variables if isinstance(v, ReferenceVariable)] - for idx in range(len(ref_variables)): + for idx, _ in enumerate(ref_variables): ref_variables[idx].index = idx tuple_variables = [v for v in variables if isinstance(v, TupleVariable)] - for idx in range(len(tuple_variables)): + for idx, _ in enumerate(tuple_variables): tuple_variables[idx].index = idx @@ -95,6 +95,7 @@ def transform_slithir_vars_to_ssa(function): ################################################################################### ################################################################################### +# pylint: disable=too-many-arguments,too-many-locals,too-many-nested-blocks,too-many-statements,too-many-branches def add_ssa_ir(function, all_state_variables_instances): """ @@ -123,7 +124,9 @@ def add_ssa_ir(function, all_state_variables_instances): for (_, variable_instance) in all_state_variables_instances.items(): if is_used_later(function.entry_point, variable_instance): # rvalues are fixed in solc_parsing.declaration.function - function.entry_point.add_ssa_ir(Phi(StateIRVariable(variable_instance), set())) + function.entry_point.add_ssa_ir( + Phi(StateIRVariable(variable_instance), set()) + ) add_phi_origins(function.entry_point, init_definition, dict()) @@ -171,7 +174,9 @@ def add_ssa_ir(function, all_state_variables_instances): init_state_variables_instances = dict(all_state_variables_instances) initiate_all_local_variables_instances( - function.nodes, init_local_variables_instances, all_init_local_variables_instances + function.nodes, + init_local_variables_instances, + all_init_local_variables_instances, ) generate_ssa_irs( @@ -259,14 +264,18 @@ def generate_ssa_irs( node.add_ssa_ir(new_ir) - if isinstance(ir, (InternalCall, HighLevelCall, InternalDynamicCall, LowLevelCall)): + if isinstance( + ir, (InternalCall, HighLevelCall, InternalDynamicCall, LowLevelCall) + ): if isinstance(ir, LibraryCall): continue for variable in all_state_variables_instances.values(): if not is_used_later(node, variable): continue new_var = StateIRVariable(variable) - new_var.index = all_state_variables_instances[variable.canonical_name].index + 1 + new_var.index = ( + all_state_variables_instances[variable.canonical_name].index + 1 + ) all_state_variables_instances[variable.canonical_name] = new_var state_variables_instances[variable.canonical_name] = new_var phi_ir = PhiCallback(new_var, {node}, new_ir, variable) @@ -402,7 +411,9 @@ def update_lvalue( local_variables_instances[lvalue.name] = new_var else: new_var = StateIRVariable(lvalue) - new_var.index = all_state_variables_instances[lvalue.canonical_name].index + 1 + new_var.index = ( + all_state_variables_instances[lvalue.canonical_name].index + 1 + ) all_state_variables_instances[lvalue.canonical_name] = new_var state_variables_instances[lvalue.canonical_name] = new_var if update_through_ref: @@ -457,7 +468,8 @@ def fix_phi_rvalues_and_storage_ref( for ir in node.irs_ssa: if isinstance(ir, (Phi)) and not ir.rvalues: variables = [ - last_name(dst, ir.lvalue, init_local_variables_instances) for dst in ir.nodes + last_name(dst, ir.lvalue, init_local_variables_instances) + for dst in ir.nodes ] ir.rvalues = variables if isinstance(ir, (Phi, PhiCallback)): @@ -504,7 +516,8 @@ def add_phi_origins(node, local_variables_definition, state_variables_definition # We keep the instance as we want to avoid to add __hash__ on v.name in Variable # That might work for this used, but could create collision for other uses local_variables_definition = dict( - local_variables_definition, **{v.name: (v, node) for v in node.local_variables_written} + local_variables_definition, + **{v.name: (v, node) for v in node.local_variables_written}, ) state_variables_definition = dict( state_variables_definition, @@ -564,9 +577,12 @@ def get( local_variables_instances[variable.name] = new_var all_local_variables_instances[variable.name] = new_var return new_var - if isinstance(variable, StateVariable) and variable.canonical_name in state_variables_instances: + if ( + isinstance(variable, StateVariable) + and variable.canonical_name in state_variables_instances + ): return state_variables_instances[variable.canonical_name] - elif isinstance(variable, ReferenceVariable): + if isinstance(variable, ReferenceVariable): if not variable.index in reference_variables_instances: new_variable = ReferenceVariableSSA(variable) if variable.points_to: @@ -582,13 +598,13 @@ def get( new_variable.set_type(variable.type) reference_variables_instances[variable.index] = new_variable return reference_variables_instances[variable.index] - elif isinstance(variable, TemporaryVariable): + if isinstance(variable, TemporaryVariable): if not variable.index in temporary_variables_instances: new_variable = TemporaryVariableSSA(variable) new_variable.set_type(variable.type) temporary_variables_instances[variable.index] = new_variable return temporary_variables_instances[variable.index] - elif isinstance(variable, TupleVariable): + if isinstance(variable, TupleVariable): if not variable.index in tuple_variables_instances: new_variable = TupleVariableSSA(variable) new_variable.set_type(variable.type) @@ -596,12 +612,22 @@ def get( return tuple_variables_instances[variable.index] assert isinstance( variable, - (Constant, SolidityVariable, Contract, Enum, SolidityFunction, Structure, Function, Type), + ( + Constant, + SolidityVariable, + Contract, + Enum, + SolidityFunction, + Structure, + Function, + Type, + ), ) # type for abi.decode(.., t) return variable def get_variable(ir, f, *instances): + # pylint: disable=no-value-for-parameter variable = f(ir) variable = get(variable, *instances) return variable @@ -609,6 +635,7 @@ def get_variable(ir, f, *instances): def _get_traversal(values, *instances): ret = [] + # pylint: disable=no-value-for-parameter for v in values: if isinstance(v, list): v = _get_traversal(v, *instances) @@ -646,57 +673,61 @@ def copy_ir(ir, *instances): rvalue = get_variable(ir, lambda x: x.rvalue, *instances) variable_return_type = ir.variable_return_type return Assignment(lvalue, rvalue, variable_return_type) - elif isinstance(ir, Balance): + if isinstance(ir, Balance): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) value = get_variable(ir, lambda x: x.value, *instances) return Balance(value, lvalue) - elif isinstance(ir, Binary): + if isinstance(ir, Binary): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) variable_left = get_variable(ir, lambda x: x.variable_left, *instances) variable_right = get_variable(ir, lambda x: x.variable_right, *instances) operation_type = ir.type return Binary(lvalue, variable_left, variable_right, operation_type) - elif isinstance(ir, CodeSize): + if isinstance(ir, CodeSize): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) value = get_variable(ir, lambda x: x.value, *instances) return CodeSize(value, lvalue) - elif isinstance(ir, Condition): + if isinstance(ir, Condition): val = get_variable(ir, lambda x: x.value, *instances) return Condition(val) - elif isinstance(ir, Delete): + if isinstance(ir, Delete): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) variable = get_variable(ir, lambda x: x.variable, *instances) return Delete(lvalue, variable) - elif isinstance(ir, EventCall): + if isinstance(ir, EventCall): name = ir.name return EventCall(name) - elif isinstance(ir, HighLevelCall): # include LibraryCall + if isinstance(ir, HighLevelCall): # include LibraryCall destination = get_variable(ir, lambda x: x.destination, *instances) function_name = ir.function_name nbr_arguments = ir.nbr_arguments lvalue = get_variable(ir, lambda x: x.lvalue, *instances) type_call = ir.type_call if isinstance(ir, LibraryCall): - new_ir = LibraryCall(destination, function_name, nbr_arguments, lvalue, type_call) + new_ir = LibraryCall( + destination, function_name, nbr_arguments, lvalue, type_call + ) else: - new_ir = HighLevelCall(destination, function_name, nbr_arguments, lvalue, type_call) + new_ir = HighLevelCall( + destination, function_name, nbr_arguments, lvalue, type_call + ) new_ir.call_id = ir.call_id new_ir.call_value = get_variable(ir, lambda x: x.call_value, *instances) new_ir.call_gas = get_variable(ir, lambda x: x.call_gas, *instances) new_ir.arguments = get_arguments(ir, *instances) new_ir.function = ir.function return new_ir - elif isinstance(ir, Index): + if isinstance(ir, Index): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) variable_left = get_variable(ir, lambda x: x.variable_left, *instances) variable_right = get_variable(ir, lambda x: x.variable_right, *instances) index_type = ir.index_type return Index(lvalue, variable_left, variable_right, index_type) - elif isinstance(ir, InitArray): + if isinstance(ir, InitArray): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) init_values = get_rec_values(ir, lambda x: x.init_values, *instances) return InitArray(init_values, lvalue) - elif isinstance(ir, InternalCall): + if isinstance(ir, InternalCall): function = ir.function nbr_arguments = ir.nbr_arguments lvalue = get_variable(ir, lambda x: x.lvalue, *instances) @@ -704,44 +735,46 @@ def copy_ir(ir, *instances): new_ir = InternalCall(function, nbr_arguments, lvalue, type_call) new_ir.arguments = get_arguments(ir, *instances) return new_ir - elif isinstance(ir, InternalDynamicCall): + if isinstance(ir, InternalDynamicCall): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) function = get_variable(ir, lambda x: x.function, *instances) function_type = ir.function_type new_ir = InternalDynamicCall(lvalue, function, function_type) new_ir.arguments = get_arguments(ir, *instances) return new_ir - elif isinstance(ir, LowLevelCall): + if isinstance(ir, LowLevelCall): destination = get_variable(ir, lambda x: x.destination, *instances) function_name = ir.function_name nbr_arguments = ir.nbr_arguments lvalue = get_variable(ir, lambda x: x.lvalue, *instances) type_call = ir.type_call - new_ir = LowLevelCall(destination, function_name, nbr_arguments, lvalue, type_call) + new_ir = LowLevelCall( + destination, function_name, nbr_arguments, lvalue, type_call + ) new_ir.call_id = ir.call_id new_ir.call_value = get_variable(ir, lambda x: x.call_value, *instances) new_ir.call_gas = get_variable(ir, lambda x: x.call_gas, *instances) new_ir.arguments = get_arguments(ir, *instances) return new_ir - elif isinstance(ir, Member): + if isinstance(ir, Member): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) variable_left = get_variable(ir, lambda x: x.variable_left, *instances) variable_right = get_variable(ir, lambda x: x.variable_right, *instances) return Member(variable_left, variable_right, lvalue) - elif isinstance(ir, NewArray): + if isinstance(ir, NewArray): depth = ir.depth array_type = ir.array_type lvalue = get_variable(ir, lambda x: x.lvalue, *instances) new_ir = NewArray(depth, array_type, lvalue) new_ir.arguments = get_rec_values(ir, lambda x: x.arguments, *instances) return new_ir - elif isinstance(ir, NewElementaryType): + if isinstance(ir, NewElementaryType): new_type = ir.type lvalue = get_variable(ir, lambda x: x.lvalue, *instances) new_ir = NewElementaryType(new_type, lvalue) new_ir.arguments = get_arguments(ir, *instances) return new_ir - elif isinstance(ir, NewContract): + if isinstance(ir, NewContract): contract_name = ir.contract_name lvalue = get_variable(ir, lambda x: x.lvalue, *instances) new_ir = NewContract(contract_name, lvalue) @@ -749,27 +782,27 @@ def copy_ir(ir, *instances): new_ir.call_value = get_variable(ir, lambda x: x.call_value, *instances) new_ir.call_salt = get_variable(ir, lambda x: x.call_salt, *instances) return new_ir - elif isinstance(ir, NewStructure): + if isinstance(ir, NewStructure): structure = ir.structure lvalue = get_variable(ir, lambda x: x.lvalue, *instances) new_ir = NewStructure(structure, lvalue) new_ir.arguments = get_arguments(ir, *instances) return new_ir - elif isinstance(ir, Nop): + if isinstance(ir, Nop): return Nop() - elif isinstance(ir, Push): + if isinstance(ir, Push): array = get_variable(ir, lambda x: x.array, *instances) lvalue = get_variable(ir, lambda x: x.lvalue, *instances) return Push(array, lvalue) - elif isinstance(ir, Return): + if isinstance(ir, Return): values = get_rec_values(ir, lambda x: x.values, *instances) return Return(values) - elif isinstance(ir, Send): + if isinstance(ir, Send): destination = get_variable(ir, lambda x: x.destination, *instances) value = get_variable(ir, lambda x: x.call_value, *instances) lvalue = get_variable(ir, lambda x: x.lvalue, *instances) return Send(destination, value, lvalue) - elif isinstance(ir, SolidityCall): + if isinstance(ir, SolidityCall): function = ir.function nbr_arguments = ir.nbr_arguments lvalue = get_variable(ir, lambda x: x.lvalue, *instances) @@ -777,26 +810,26 @@ def copy_ir(ir, *instances): new_ir = SolidityCall(function, nbr_arguments, lvalue, type_call) new_ir.arguments = get_arguments(ir, *instances) return new_ir - elif isinstance(ir, Transfer): + if isinstance(ir, Transfer): destination = get_variable(ir, lambda x: x.destination, *instances) value = get_variable(ir, lambda x: x.call_value, *instances) return Transfer(destination, value) - elif isinstance(ir, TypeConversion): + if isinstance(ir, TypeConversion): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) variable = get_variable(ir, lambda x: x.variable, *instances) variable_type = ir.type return TypeConversion(lvalue, variable, variable_type) - elif isinstance(ir, Unary): + if isinstance(ir, Unary): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) rvalue = get_variable(ir, lambda x: x.rvalue, *instances) operation_type = ir.type return Unary(lvalue, rvalue, operation_type) - elif isinstance(ir, Unpack): + if isinstance(ir, Unpack): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) tuple_var = get_variable(ir, lambda x: x.tuple, *instances) idx = ir.index return Unpack(lvalue, tuple_var, idx) - elif isinstance(ir, Length): + if isinstance(ir, Length): lvalue = get_variable(ir, lambda x: x.lvalue, *instances) value = get_variable(ir, lambda x: x.value, *instances) return Length(value, lvalue) diff --git a/slither/slithir/utils/utils.py b/slither/slithir/utils/utils.py index 31a637d19..796bb822c 100644 --- a/slither/slithir/utils/utils.py +++ b/slither/slithir/utils/utils.py @@ -25,5 +25,12 @@ def is_valid_rvalue(v): def is_valid_lvalue(v): return isinstance( - v, (StateVariable, LocalVariable, TemporaryVariable, ReferenceVariable, TupleVariable) + v, + ( + StateVariable, + LocalVariable, + TemporaryVariable, + ReferenceVariable, + TupleVariable, + ), ) diff --git a/slither/slithir/variables/constant.py b/slither/slithir/variables/constant.py index 5e47b6677..dadfbc99b 100644 --- a/slither/slithir/variables/constant.py +++ b/slither/slithir/variables/constant.py @@ -1,15 +1,15 @@ from functools import total_ordering from decimal import Decimal -from .variable import SlithIRVariable +from slither.slithir.variables.variable import SlithIRVariable +from slither.slithir.exceptions import SlithIRError from slither.core.solidity_types.elementary_type import ElementaryType, Int, Uint from slither.utils.arithmetic import convert_subdenomination -from ..exceptions import SlithIRError @total_ordering class Constant(SlithIRVariable): - def __init__(self, val, type=None, subdenomination=None): + def __init__(self, val, constant_type=None, subdenomination=None): # pylint: disable=too-many-branches super(Constant, self).__init__() assert isinstance(val, str) @@ -19,10 +19,10 @@ class Constant(SlithIRVariable): if subdenomination: val = str(convert_subdenomination(val, subdenomination)) - if type: - assert isinstance(type, ElementaryType) - self._type = type - if type.type in Int + Uint + ["address"]: + if constant_type: # pylint: disable=too-many-nested-blocks + assert isinstance(constant_type, ElementaryType) + self._type = constant_type + if constant_type.type in Int + Uint + ["address"]: if val.startswith("0x") or val.startswith("0X"): self._val = int(val, 16) else: @@ -38,13 +38,12 @@ class Constant(SlithIRVariable): raise SlithIRError( f"{base}e{expo} is too large to fit in any Solidity integer size" ) - else: - self._val = 0 + self._val = 0 else: self._val = int(Decimal(base) * Decimal(10 ** expo)) else: self._val = int(Decimal(val)) - elif type.type == "bool": + elif constant_type.type == "bool": self._val = (val == "true") | (val == "True") else: self._val = val diff --git a/slither/slithir/variables/local_variable.py b/slither/slithir/variables/local_variable.py index 5ee9e6660..2710b8bbb 100644 --- a/slither/slithir/variables/local_variable.py +++ b/slither/slithir/variables/local_variable.py @@ -1,10 +1,9 @@ -from .variable import SlithIRVariable -from .temporary import TemporaryVariable from slither.core.variables.local_variable import LocalVariable -from slither.core.children.child_node import ChildNode +from slither.slithir.variables.temporary import TemporaryVariable +from slither.slithir.variables.variable import SlithIRVariable -class LocalIRVariable(LocalVariable, SlithIRVariable): +class LocalIRVariable(LocalVariable, SlithIRVariable): # pylint: disable=too-many-instance-attributes def __init__(self, local_variable): assert isinstance(local_variable, LocalVariable) @@ -68,5 +67,7 @@ class LocalIRVariable(LocalVariable, SlithIRVariable): @property def ssa_name(self): if self.is_storage: - return "{}_{} (-> {})".format(self._name, self.index, [v.name for v in self.refers_to]) + return "{}_{} (-> {})".format( + self._name, self.index, [v.name for v in self.refers_to] + ) return "{}_{}".format(self._name, self.index) diff --git a/slither/slithir/variables/reference.py b/slither/slithir/variables/reference.py index 457c9e660..9c7a637c4 100644 --- a/slither/slithir/variables/reference.py +++ b/slither/slithir/variables/reference.py @@ -1,7 +1,6 @@ -from .variable import SlithIRVariable from slither.core.children.child_node import ChildNode -from slither.core.variables.variable import Variable from slither.core.declarations import Contract, Enum, SolidityVariable, Function +from slither.core.variables.variable import Variable class ReferenceVariable(ChildNode, Variable): @@ -45,7 +44,7 @@ class ReferenceVariable(ChildNode, Variable): def points_to(self, points_to): # Can only be a rvalue of # Member or Index operator - from slither.slithir.utils.utils import is_valid_lvalue + from slither.slithir.utils.utils import is_valid_lvalue # pylint: disable=import-outside-toplevel assert is_valid_lvalue(points_to) or isinstance( points_to, (SolidityVariable, Contract, Enum) diff --git a/slither/slithir/variables/reference_ssa.py b/slither/slithir/variables/reference_ssa.py index 4f67ad327..a5cec139e 100644 --- a/slither/slithir/variables/reference_ssa.py +++ b/slither/slithir/variables/reference_ssa.py @@ -4,7 +4,6 @@ as the ReferenceVariable are in SSA form in both version """ from .reference import ReferenceVariable -from .variable import SlithIRVariable class ReferenceVariableSSA(ReferenceVariable): diff --git a/slither/slithir/variables/state_variable.py b/slither/slithir/variables/state_variable.py index bbabce205..43a344f4b 100644 --- a/slither/slithir/variables/state_variable.py +++ b/slither/slithir/variables/state_variable.py @@ -1,13 +1,12 @@ -from .variable import SlithIRVariable from slither.core.variables.state_variable import StateVariable -from slither.core.children.child_node import ChildNode +from slither.slithir.variables.variable import SlithIRVariable -class StateIRVariable(StateVariable, SlithIRVariable): +class StateIRVariable(StateVariable, SlithIRVariable): # pylint: disable=too-many-instance-attributes def __init__(self, state_variable): assert isinstance(state_variable, StateVariable) - super(StateVariable, self).__init__() + super(StateIRVariable, self).__init__() # initiate ChildContract self.set_contract(state_variable.contract) diff --git a/slither/slithir/variables/temporary.py b/slither/slithir/variables/temporary.py index 6940a303e..b304d73af 100644 --- a/slither/slithir/variables/temporary.py +++ b/slither/slithir/variables/temporary.py @@ -1,6 +1,5 @@ -from .variable import SlithIRVariable -from slither.core.variables.variable import Variable from slither.core.children.child_node import ChildNode +from slither.core.variables.variable import Variable class TemporaryVariable(ChildNode, Variable): diff --git a/slither/slithir/variables/tuple.py b/slither/slithir/variables/tuple.py index 57ae3b6ad..6e2a4b039 100644 --- a/slither/slithir/variables/tuple.py +++ b/slither/slithir/variables/tuple.py @@ -1,8 +1,5 @@ -from .variable import SlithIRVariable -from slither.core.variables.variable import Variable from slither.core.children.child_node import ChildNode - -from slither.core.solidity_types.type import Type +from slither.slithir.variables.variable import SlithIRVariable class TupleVariable(ChildNode, SlithIRVariable): diff --git a/slither/slithir/variables/tuple_ssa.py b/slither/slithir/variables/tuple_ssa.py index 3db2ad2f4..fab0906cb 100644 --- a/slither/slithir/variables/tuple_ssa.py +++ b/slither/slithir/variables/tuple_ssa.py @@ -4,7 +4,6 @@ as the TupleVariable are in SSA form in both version """ from .tuple import TupleVariable -from .variable import SlithIRVariable class TupleVariableSSA(TupleVariable): diff --git a/slither/solc_parsing/cfg/node.py b/slither/solc_parsing/cfg/node.py index 2d2bec0f3..5e9b42480 100644 --- a/slither/solc_parsing/cfg/node.py +++ b/slither/solc_parsing/cfg/node.py @@ -44,7 +44,9 @@ class NodeSolc: AssignmentOperationType.ASSIGN, self._node.variable_declaration.type, ) - _expression.set_offset(self._node.expression.source_mapping, self._node.slither) + _expression.set_offset( + self._node.expression.source_mapping, self._node.slither + ) self._node.add_expression(_expression, bypass_verif_empty=True) expression = self._node.expression @@ -57,8 +59,12 @@ class NodeSolc: find_call = FindCalls(expression) self._node.calls_as_expression = find_call.result() self._node.external_calls_as_expressions = [ - c for c in self._node.calls_as_expression if not isinstance(c.called, Identifier) + c + for c in self._node.calls_as_expression + if not isinstance(c.called, Identifier) ] self._node.internal_calls_as_expressions = [ - c for c in self._node.calls_as_expression if isinstance(c.called, Identifier) + c + for c in self._node.calls_as_expression + if isinstance(c.called, Identifier) ] diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index ad170f91f..a49c7cf24 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from slither.solc_parsing.slitherSolc import SlitherSolc from slither.core.slither_core import SlitherCore +# pylint: disable=too-many-instance-attributes,import-outside-toplevel,too-many-nested-blocks,too-many-public-methods class ContractSolc: def __init__(self, slither_parser: "SlitherSolc", contract: Contract, data): @@ -103,6 +104,14 @@ class ContractSolc: def modifiers_parser(self) -> List["ModifierSolc"]: return self._modifiers_parser + @property + def structures_not_parsed(self) -> List[Dict]: + return self._structuresNotParsed + + @property + def enums_not_parsed(self) -> List[Dict]: + return self._enumsNotParsed + ################################################################################### ################################################################################### # region AST @@ -153,11 +162,11 @@ class ContractSolc: if "baseContracts" in self._data: for elem in self._data["baseContracts"]: if elem["nodeType"] == "InheritanceSpecifier": - self._remapping[elem["baseName"]["referencedDeclaration"]] = elem["baseName"][ - "name" - ] + self._remapping[elem["baseName"]["referencedDeclaration"]] = elem[ + "baseName" + ]["name"] - def _parse_base_contract_info(self): + def _parse_base_contract_info(self): # pylint: disable=too-many-branches # Parse base contracts (immediate, non-linearized) if self.is_compact_ast: # Parse base contracts + constructors in compact-ast @@ -172,21 +181,31 @@ class ContractSolc: continue # Obtain our contract reference and add it to our base contract list - referencedDeclaration = base_contract["baseName"]["referencedDeclaration"] + referencedDeclaration = base_contract["baseName"][ + "referencedDeclaration" + ] self.baseContracts.append(referencedDeclaration) # If we have defined arguments in our arguments object, this is a constructor invocation. # (note: 'arguments' can be [], which is not the same as None. [] implies a constructor was # called with no arguments, while None implies no constructor was called). - if "arguments" in base_contract and base_contract["arguments"] is not None: - self.baseConstructorContractsCalled.append(referencedDeclaration) + if ( + "arguments" in base_contract + and base_contract["arguments"] is not None + ): + self.baseConstructorContractsCalled.append( + referencedDeclaration + ) else: # Parse base contracts + constructors in legacy-ast if "children" in self._data: for base_contract in self._data["children"]: if base_contract["name"] != "InheritanceSpecifier": continue - if "children" not in base_contract or len(base_contract["children"]) == 0: + if ( + "children" not in base_contract + or len(base_contract["children"]) == 0 + ): continue # Obtain all items for this base contract specification (base contract, followed by arguments) base_contract_items = base_contract["children"] @@ -197,7 +216,8 @@ class ContractSolc: continue if ( "attributes" not in base_contract_items[0] - or "referencedDeclaration" not in base_contract_items[0]["attributes"] + or "referencedDeclaration" + not in base_contract_items[0]["attributes"] ): continue @@ -213,7 +233,9 @@ class ContractSolc: or "arguments" not in base_contract["attributes"] or base_contract["attributes"]["arguments"] is not None ): - self.baseConstructorContractsCalled.append(referencedDeclaration) + self.baseConstructorContractsCalled.append( + referencedDeclaration + ) def _parse_contract_items(self): if not self.get_children() in self._data: # empty contract @@ -344,9 +366,8 @@ class ContractSolc: def log_incorrect_parsing(self, error): if self._contract.slither.disallow_partial: raise ParsingError(error) - else: - LOGGER.error(error) - self._contract.is_incorrectly_parsed = True + LOGGER.error(error) + self._contract.is_incorrectly_parsed = True def analyze_content_modifiers(self): try: @@ -361,7 +382,6 @@ class ContractSolc: function_parser.analyze_content() except (VariableNotFound, KeyError, ParsingError) as e: self.log_incorrect_parsing(f"Missing function {e}") - return def analyze_params_modifiers(self): try: @@ -403,7 +423,7 @@ class ContractSolc: self.log_incorrect_parsing(f"Missing params {e}") self._functions_no_params = [] - def _analyze_params_elements( + def _analyze_params_elements( # pylint: disable=too-many-arguments,too-many-locals self, elements_no_params: List[FunctionSolc], getter: Callable[["ContractSolc"], List[FunctionSolc]], @@ -427,16 +447,23 @@ class ContractSolc: try: for father in self._contract.inheritance: - father_parser = self._slither_parser.underlying_contract_to_parser[father] + father_parser = self._slither_parser.underlying_contract_to_parser[ + father + ] for element_parser in getter(father_parser): elem = Cls() elem.set_contract(self._contract) - elem.set_contract_declarer(element_parser.underlying_function.contract_declarer) + elem.set_contract_declarer( + element_parser.underlying_function.contract_declarer + ) elem.set_offset( - element_parser.function_not_parsed["src"], self._contract.slither + element_parser.function_not_parsed["src"], + self._contract.slither, ) - elem_parser = Cls_parser(elem, element_parser.function_not_parsed, self,) + elem_parser = Cls_parser( + elem, element_parser.function_not_parsed, self, + ) elem_parser.analyze_params() if isinstance(elem, Modifier): self._contract.slither.add_modifier(elem) @@ -465,7 +492,9 @@ class ContractSolc: if has_constructor: _accessible_functions = { - k: v for (k, v) in accessible_elements.items() if not v.is_constructor + k: v + for (k, v) in accessible_elements.items() + if not v.is_constructor } for element_parser in elements_no_params: @@ -477,7 +506,10 @@ class ContractSolc: ] = element_parser.underlying_function for element in all_elements.values(): - if accessible_elements[element.full_name] != all_elements[element.canonical_name]: + if ( + accessible_elements[element.full_name] + != all_elements[element.canonical_name] + ): element.is_shadowed = True accessible_elements[element.full_name].shadows = True except (VariableNotFound, KeyError) as e: @@ -492,7 +524,6 @@ class ContractSolc: var_parser.analyze(self) except (VariableNotFound, KeyError) as e: LOGGER.error(e) - pass def analyze_state_variables(self): try: @@ -571,7 +602,7 @@ class ContractSolc: new_enum.set_offset(enum["src"], self._contract.slither) self._contract.enums_as_dict[canonicalName] = new_enum - def _analyze_struct(self, struct: StructureSolc): + def _analyze_struct(self, struct: StructureSolc): # pylint: disable=no-self-use struct.analyze() def analyze_structs(self): diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 905095888..8e20a776e 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -1,15 +1,19 @@ -""" -""" import logging from typing import Dict, Optional, Union, List, TYPE_CHECKING from slither.core.cfg.node import NodeType, link_nodes, insert_node, Node from slither.core.declarations.contract import Contract -from slither.core.declarations.function import Function, ModifierStatements, FunctionType +from slither.core.declarations.function import ( + Function, + ModifierStatements, + FunctionType, +) from slither.core.expressions import AssignmentOperation from slither.core.variables.local_variable import LocalVariable -from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple +from slither.core.variables.local_variable_init_from_tuple import ( + LocalVariableInitFromTuple, +) from slither.solc_parsing.cfg.node import NodeSolc from slither.solc_parsing.expressions.expression_parsing import parse_expression @@ -17,7 +21,9 @@ from slither.solc_parsing.variables.local_variable import LocalVariableSolc from slither.solc_parsing.variables.local_variable_init_from_tuple import ( LocalVariableInitFromTupleSolc, ) -from slither.solc_parsing.variables.variable_declaration import MultipleVariablesDeclaration +from slither.solc_parsing.variables.variable_declaration import ( + MultipleVariablesDeclaration, +) from slither.solc_parsing.yul.parse_yul import YulBlock from slither.utils.expression_manipulations import SplitTernaryExpression from slither.visitors.expression.export_values import ExportValues @@ -38,14 +44,14 @@ def link_underlying_nodes(node1: NodeSolc, node2: NodeSolc): link_nodes(node1.underlying_node, node2.underlying_node) +# pylint: disable=too-many-lines,too-many-branches,too-many-locals,too-many-statements,too-many-instance-attributes + class FunctionSolc: - """ - """ # elems = [(type, name)] def __init__( - self, function: Function, function_data: Dict, contract_parser: "ContractSolc", + self, function: Function, function_data: Dict, contract_parser: "ContractSolc", ): self._slither_parser: "SlitherSolc" = contract_parser.slither_parser self._contract_parser = contract_parser @@ -137,12 +143,12 @@ class FunctionSolc: @property def variables_renamed( - self, + self, ) -> Dict[int, Union[LocalVariableSolc, LocalVariableInitFromTupleSolc]]: return self._variables_renamed def _add_local_variable( - self, local_var_parser: Union[LocalVariableSolc, LocalVariableInitFromTupleSolc] + self, local_var_parser: Union[LocalVariableSolc, LocalVariableInitFromTupleSolc] ): # If two local variables have the same name # We add a suffix to the new variable @@ -348,9 +354,13 @@ class FunctionSolc: condition_node = self._new_node(NodeType.IF, condition["src"]) condition_node.add_unparsed_expression(condition) link_underlying_nodes(node, condition_node) - trueStatement = self._parse_statement(if_statement["trueBody"], condition_node) + trueStatement = self._parse_statement( + if_statement["trueBody"], condition_node + ) if if_statement["falseBody"]: - falseStatement = self._parse_statement(if_statement["falseBody"], condition_node) + falseStatement = self._parse_statement( + if_statement["falseBody"], condition_node + ) else: children = if_statement[self.get_children("children")] condition = children[0] @@ -378,7 +388,9 @@ class FunctionSolc: node_startWhile = self._new_node(NodeType.STARTLOOP, whilte_statement["src"]) if self.is_compact_ast: - node_condition = self._new_node(NodeType.IFLOOP, whilte_statement["condition"]["src"]) + node_condition = self._new_node( + NodeType.IFLOOP, whilte_statement["condition"]["src"] + ) node_condition.add_unparsed_expression(whilte_statement["condition"]) statement = self._parse_statement(whilte_statement["body"], node_condition) else: @@ -534,7 +546,9 @@ class FunctionSolc: if hasLoopExpression: if len(children) > 2: if children[-2][self.get_key()] == "ExpressionStatement": - node_LoopExpression = self._parse_statement(children[-2], node_statement) + node_LoopExpression = self._parse_statement( + children[-2], node_statement + ) if not hasCondition: link_underlying_nodes(node_LoopExpression, node_endLoop) @@ -547,12 +561,18 @@ class FunctionSolc: def _parse_dowhile(self, do_while_statement: Dict, node: NodeSolc) -> NodeSolc: - node_startDoWhile = self._new_node(NodeType.STARTLOOP, do_while_statement["src"]) + node_startDoWhile = self._new_node( + NodeType.STARTLOOP, do_while_statement["src"] + ) if self.is_compact_ast: - node_condition = self._new_node(NodeType.IFLOOP, do_while_statement["condition"]["src"]) + node_condition = self._new_node( + NodeType.IFLOOP, do_while_statement["condition"]["src"] + ) node_condition.add_unparsed_expression(do_while_statement["condition"]) - statement = self._parse_statement(do_while_statement["body"], node_condition) + statement = self._parse_statement( + do_while_statement["body"], node_condition + ) else: children = do_while_statement[self.get_children("children")] # same order in the AST as while @@ -568,7 +588,10 @@ class FunctionSolc: if not node_condition.underlying_node.sons: link_underlying_nodes(node_startDoWhile, node_condition) else: - link_nodes(node_startDoWhile.underlying_node, node_condition.underlying_node.sons[0]) + link_nodes( + node_startDoWhile.underlying_node, + node_condition.underlying_node.sons[0], + ) link_underlying_nodes(statement, node_condition) link_underlying_nodes(node_condition, node_endDoWhile) return node_endDoWhile @@ -577,7 +600,9 @@ class FunctionSolc: externalCall = statement.get("externalCall", None) if externalCall is None: - raise ParsingError("Try/Catch not correctly parsed by Slither %s" % statement) + raise ParsingError( + "Try/Catch not correctly parsed by Slither %s" % statement + ) new_node = self._new_node(NodeType.TRY, statement["src"]) new_node.add_unparsed_expression(externalCall) @@ -629,8 +654,8 @@ class FunctionSolc: count = len(variables) if ( - statement["initialValue"]["nodeType"] == "TupleExpression" - and len(statement["initialValue"]["components"]) == count + statement["initialValue"]["nodeType"] == "TupleExpression" + and len(statement["initialValue"]["components"]) == count ): inits = statement["initialValue"]["components"] i = 0 @@ -648,7 +673,9 @@ class FunctionSolc: "declarations": [variable], "initialValue": init, } - new_node = self._parse_variable_definition(new_statement, new_node) + new_node = self._parse_variable_definition( + new_statement, new_node + ) else: # If we have @@ -681,7 +708,9 @@ class FunctionSolc: "nodeType": "Identifier", "src": v["src"], "name": v["name"], - "typeDescriptions": {"typeString": v["typeDescriptions"]["typeString"]}, + "typeDescriptions": { + "typeString": v["typeDescriptions"]["typeString"] + }, } var_identifiers.append(identifier) @@ -732,7 +761,9 @@ class FunctionSolc: self.get_children("children"): [variable, init], } - new_node = self._parse_variable_definition(new_statement, new_node) + new_node = self._parse_variable_definition( + new_statement, new_node + ) else: # If we have # var (a, b) = f() @@ -788,7 +819,7 @@ class FunctionSolc: return new_node def _parse_variable_definition_init_tuple( - self, statement: Dict, index: int, node: NodeSolc + self, statement: Dict, index: int, node: NodeSolc ) -> NodeSolc: local_var = LocalVariableInitFromTuple() local_var.set_function(self._function) @@ -838,7 +869,7 @@ class FunctionSolc: node = exitpoint else: asm_node = self._new_node(NodeType.ASSEMBLY, statement["src"]) - self._function._contains_assembly = True + self._function.contains_assembly = True # Added with solc 0.4.12 if "operations" in statement: asm_node.underlying_node.add_inline_asm(statement["operations"]) @@ -864,8 +895,8 @@ class FunctionSolc: return_node.add_unparsed_expression(statement["expression"]) else: if ( - self.get_children("children") in statement - and statement[self.get_children("children")] + self.get_children("children") in statement + and statement[self.get_children("children")] ): assert len(statement[self.get_children("children")]) == 1 expression = statement[self.get_children("children")][0] @@ -951,7 +982,9 @@ class FunctionSolc: ################################################################################### ################################################################################### - def _find_end_loop(self, node: Node, visited: List[Node], counter: int) -> Optional[Node]: + def _find_end_loop( + self, node: Node, visited: List[Node], counter: int + ) -> Optional[Node]: # counter allows to explore nested loop if node in visited: return None @@ -1092,7 +1125,9 @@ class FunctionSolc: node_parser.add_unparsed_expression(modifier) # The latest entry point is the entry point, or the latest modifier call if self._function.modifiers: - latest_entry_point = self._function.modifiers_statements[-1].nodes[-1] + latest_entry_point = self._function.modifiers_statements[-1].nodes[ + -1 + ] else: latest_entry_point = self._function.entry_point insert_node(latest_entry_point, node_parser.underlying_node) @@ -1111,7 +1146,9 @@ class FunctionSolc: if self._function.explicit_base_constructor_calls_statements: latest_entry_point = self._function.explicit_base_constructor_calls_statements[ -1 - ].nodes[-1] + ].nodes[ + -1 + ] else: latest_entry_point = self._function.entry_point insert_node(latest_entry_point, node_parser.underlying_node) @@ -1131,7 +1168,7 @@ class FunctionSolc: ################################################################################### def _remove_incorrect_edges(self): - for node in self._node_to_nodesolc.keys(): + for node in self._node_to_nodesolc: if node.type in [NodeType.RETURN, NodeType.THROW]: for son in node.sons: son.remove_father(node) @@ -1166,13 +1203,15 @@ class FunctionSolc: while set(prev_nodes) != set(self._node_to_nodesolc.keys()): prev_nodes = self._node_to_nodesolc.keys() to_remove: List[Node] = [] - for node in self._node_to_nodesolc.keys(): + for node in self._node_to_nodesolc: if node.type == NodeType.ENDIF and not node.fathers: for son in node.sons: son.remove_father(node) node.set_sons([]) to_remove.append(node) - self._function.nodes = [n for n in self._function.nodes if n not in to_remove] + self._function.nodes = [ + n for n in self._function.nodes if n not in to_remove + ] for remove in to_remove: if remove in self._node_to_nodesolc: del self._node_to_nodesolc[remove] @@ -1189,7 +1228,7 @@ class FunctionSolc: updated = False while ternary_found: ternary_found = False - for node in self._node_to_nodesolc.keys(): + for node in self._node_to_nodesolc: has_cond = HasConditional(node.expression) if has_cond.result(): st = SplitTernaryExpression(node.expression) @@ -1207,18 +1246,20 @@ class FunctionSolc: return updated def _split_ternary_node( - self, - node: Node, - condition: "Expression", - true_expr: "Expression", - false_expr: "Expression", + self, + node: Node, + condition: "Expression", + true_expr: "Expression", + false_expr: "Expression", ): condition_node = self._new_node(NodeType.IF, node.source_mapping) condition_node.underlying_node.add_expression(condition) condition_node.analyze_expressions(self) if node.type == NodeType.VARIABLE: - condition_node.underlying_node.add_variable_declaration(node.variable_declaration) + condition_node.underlying_node.add_variable_declaration( + node.variable_declaration + ) true_node_parser = self._new_node(NodeType.EXPRESSION, node.source_mapping) if node.type == NodeType.VARIABLE: @@ -1253,12 +1294,20 @@ class FunctionSolc: link_underlying_nodes(condition_node, true_node_parser) link_underlying_nodes(condition_node, false_node_parser) - if true_node_parser.underlying_node.type not in [NodeType.THROW, NodeType.RETURN]: + if true_node_parser.underlying_node.type not in [ + NodeType.THROW, + NodeType.RETURN, + ]: link_underlying_nodes(true_node_parser, endif_node) - if false_node_parser.underlying_node.type not in [NodeType.THROW, NodeType.RETURN]: + if false_node_parser.underlying_node.type not in [ + NodeType.THROW, + NodeType.RETURN, + ]: link_underlying_nodes(false_node_parser, endif_node) - self._function.nodes = [n for n in self._function.nodes if n.node_id != node.node_id] + self._function.nodes = [ + n for n in self._function.nodes if n.node_id != node.node_id + ] del self._node_to_nodesolc[node] # endregion diff --git a/slither/solc_parsing/declarations/modifier.py b/slither/solc_parsing/declarations/modifier.py index 99268804a..7256a32e3 100644 --- a/slither/solc_parsing/declarations/modifier.py +++ b/slither/solc_parsing/declarations/modifier.py @@ -14,7 +14,9 @@ if TYPE_CHECKING: class ModifierSolc(FunctionSolc): - def __init__(self, modifier: Modifier, function_data: Dict, contract_parser: "ContractSolc"): + def __init__( + self, modifier: Modifier, function_data: Dict, contract_parser: "ContractSolc" + ): super().__init__(modifier, function_data, contract_parser) # _modifier is equal to _function, but keep it here to prevent # confusion for mypy in underlying_function diff --git a/slither/solc_parsing/declarations/structure.py b/slither/solc_parsing/declarations/structure.py index 513d2a5b8..0fa2a36a0 100644 --- a/slither/solc_parsing/declarations/structure.py +++ b/slither/solc_parsing/declarations/structure.py @@ -11,14 +11,14 @@ if TYPE_CHECKING: from slither.solc_parsing.declarations.contract import ContractSolc -class StructureSolc: +class StructureSolc: # pylint: disable=too-few-public-methods """ Structure class """ # elems = [(type, name)] - def __init__( + def __init__( # pylint: disable=too-many-arguments self, st: Structure, name: str, diff --git a/slither/solc_parsing/expressions/expression_parsing.py b/slither/solc_parsing/expressions/expression_parsing.py index 7d14a0026..1326d2515 100644 --- a/slither/solc_parsing/expressions/expression_parsing.py +++ b/slither/solc_parsing/expressions/expression_parsing.py @@ -17,10 +17,15 @@ from slither.core.expressions.assignment_operation import ( AssignmentOperation, AssignmentOperationType, ) -from slither.core.expressions.binary_operation import BinaryOperation, BinaryOperationType +from slither.core.expressions.binary_operation import ( + BinaryOperation, + BinaryOperationType, +) from slither.core.expressions.call_expression import CallExpression from slither.core.expressions.conditional_expression import ConditionalExpression -from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression +from slither.core.expressions.elementary_type_name_expression import ( + ElementaryTypeNameExpression, +) from slither.core.expressions.identifier import Identifier from slither.core.expressions.index_access import IndexAccess from slither.core.expressions.literal import Literal @@ -33,7 +38,12 @@ from slither.core.expressions.super_identifier import SuperIdentifier from slither.core.expressions.tuple_expression import TupleExpression from slither.core.expressions.type_conversion import TypeConversion from slither.core.expressions.unary_operation import UnaryOperation, UnaryOperationType -from slither.core.solidity_types import ArrayType, ElementaryType, FunctionType, MappingType +from slither.core.solidity_types import ( + ArrayType, + ElementaryType, + FunctionType, + MappingType, +) from slither.core.variables.variable import Variable from slither.solc_parsing.exceptions import ParsingError, VariableNotFound from slither.solc_parsing.solidity_types.type_parsing import UnknownType, parse_type @@ -45,6 +55,8 @@ if TYPE_CHECKING: logger = logging.getLogger("ExpressionParsing") +# pylint: disable=anomalous-backslash-in-string,import-outside-toplevel,too-many-branches,too-many-locals + ################################################################################### ################################################################################### # region Helpers @@ -68,13 +80,20 @@ def get_pointer_name(variable: Variable): return None -def find_variable( +def find_variable( # pylint: disable=too-many-locals,too-many-statements var_name: str, caller_context: CallerContext, referenced_declaration: Optional[int] = None, is_super=False, ) -> Union[ - Variable, Function, Contract, SolidityVariable, SolidityFunction, Event, Enum, Structure + Variable, + Function, + Contract, + SolidityVariable, + SolidityFunction, + Event, + Enum, + Structure, ]: from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.function import FunctionSolc @@ -218,7 +237,9 @@ def find_variable( if function_candidate.referenced_declaration == referenced_declaration: return function_candidate.underlying_function - raise VariableNotFound("Variable not found: {} (context {})".format(var_name, caller_context)) + raise VariableNotFound( + "Variable not found: {} (context {})".format(var_name, caller_context) + ) # endregion @@ -273,7 +294,7 @@ def filter_name(value: str) -> str: ################################################################################### -def parse_call(expression: Dict, caller_context): +def parse_call(expression: Dict, caller_context): # pylint: disable=too-many-statements src = expression["src"] if caller_context.is_compact_ast: attributes = expression @@ -323,7 +344,9 @@ def parse_call(expression: Dict, caller_context): if expression["expression"][caller_context.get_key()] == "FunctionCallOptions": call_with_options = expression["expression"] for idx, name in enumerate(call_with_options.get("names", [])): - option = parse_expression(call_with_options["options"][idx], caller_context) + option = parse_expression( + call_with_options["options"][idx], caller_context + ) if name == "value": call_value = option if name == "gas": @@ -332,7 +355,9 @@ def parse_call(expression: Dict, caller_context): call_salt = option arguments = [] if expression["arguments"]: - arguments = [parse_expression(a, caller_context) for a in expression["arguments"]] + arguments = [ + parse_expression(a, caller_context) for a in expression["arguments"] + ] else: children = expression["children"] called = parse_expression(children[0], caller_context) @@ -394,6 +419,7 @@ def _parse_elementary_type_name_expression( def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expression": + # pylint: disable=too-many-nested-blocks,too-many-statements """ Returns: @@ -432,7 +458,9 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres else: attributes = expression["attributes"] assert "prefix" in attributes - operation_type = UnaryOperationType.get_type(attributes["operator"], attributes["prefix"]) + operation_type = UnaryOperationType.get_type( + attributes["operator"], attributes["prefix"] + ) if is_compact_ast: expression = parse_expression(expression["subExpression"], caller_context) @@ -443,7 +471,7 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres unary_op.set_offset(src, caller_context.slither) return unary_op - elif name == "BinaryOperation": + if name == "BinaryOperation": if is_compact_ast: attributes = expression else: @@ -451,50 +479,60 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres operation_type = BinaryOperationType.get_type(attributes["operator"]) if is_compact_ast: - left_expression = parse_expression(expression["leftExpression"], caller_context) - right_expression = parse_expression(expression["rightExpression"], caller_context) + left_expression = parse_expression( + expression["leftExpression"], caller_context + ) + right_expression = parse_expression( + expression["rightExpression"], caller_context + ) else: assert len(expression["children"]) == 2 - left_expression = parse_expression(expression["children"][0], caller_context) - right_expression = parse_expression(expression["children"][1], caller_context) + left_expression = parse_expression( + expression["children"][0], caller_context + ) + right_expression = parse_expression( + expression["children"][1], caller_context + ) binary_op = BinaryOperation(left_expression, right_expression, operation_type) binary_op.set_offset(src, caller_context.slither) return binary_op - elif name in "FunctionCall": + if name in "FunctionCall": return parse_call(expression, caller_context) - elif name == "FunctionCallOptions": + if name == "FunctionCallOptions": # call/gas info are handled in parse_call called = parse_expression(expression["expression"], caller_context) assert isinstance(called, (MemberAccess, NewContract)) return called - elif name == "TupleExpression": - """ - For expression like - (a,,c) = (1,2,3) - the AST provides only two children in the left side - We check the type provided (tuple(uint256,,uint256)) - To determine that there is an empty variable - Otherwhise we would not be able to determine that - a = 1, c = 3, and 2 is lost - - Note: this is only possible with Solidity >= 0.4.12 - """ + if name == "TupleExpression": + # For expression like + # (a,,c) = (1,2,3) + # the AST provides only two children in the left side + # We check the type provided (tuple(uint256,,uint256)) + # To determine that there is an empty variable + # Otherwhise we would not be able to determine that + # a = 1, c = 3, and 2 is lost + # + # Note: this is only possible with Solidity >= 0.4.12 if is_compact_ast: expressions = [ - parse_expression(e, caller_context) if e else None for e in expression["components"] + parse_expression(e, caller_context) if e else None + for e in expression["components"] ] else: if "children" not in expression: attributes = expression["attributes"] components = attributes["components"] expressions = [ - parse_expression(c, caller_context) if c else None for c in components + parse_expression(c, caller_context) if c else None + for c in components ] else: - expressions = [parse_expression(e, caller_context) for e in expression["children"]] + expressions = [ + parse_expression(e, caller_context) for e in expression["children"] + ] # Add none for empty tuple items if "attributes" in expression: if "type" in expression["attributes"]: @@ -502,32 +540,42 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres if ",," in t or "(," in t or ",)" in t: t = t[len("tuple(") : -1] elems = t.split(",") - for idx in range(len(elems)): + for idx, _ in enumerate(elems): if elems[idx] == "": expressions.insert(idx, None) t = TupleExpression(expressions) t.set_offset(src, caller_context.slither) return t - elif name == "Conditional": + if name == "Conditional": if is_compact_ast: if_expression = parse_expression(expression["condition"], caller_context) - then_expression = parse_expression(expression["trueExpression"], caller_context) - else_expression = parse_expression(expression["falseExpression"], caller_context) + then_expression = parse_expression( + expression["trueExpression"], caller_context + ) + else_expression = parse_expression( + expression["falseExpression"], caller_context + ) else: children = expression["children"] assert len(children) == 3 if_expression = parse_expression(children[0], caller_context) then_expression = parse_expression(children[1], caller_context) else_expression = parse_expression(children[2], caller_context) - conditional = ConditionalExpression(if_expression, then_expression, else_expression) + conditional = ConditionalExpression( + if_expression, then_expression, else_expression + ) conditional.set_offset(src, caller_context.slither) return conditional - elif name == "Assignment": + if name == "Assignment": if is_compact_ast: - left_expression = parse_expression(expression["leftHandSide"], caller_context) - right_expression = parse_expression(expression["rightHandSide"], caller_context) + left_expression = parse_expression( + expression["leftHandSide"], caller_context + ) + right_expression = parse_expression( + expression["rightHandSide"], caller_context + ) operation_type = AssignmentOperationType.get_type(expression["operator"]) @@ -548,7 +596,7 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres assignement.set_offset(src, caller_context.slither) return assignement - elif name == "Literal": + if name == "Literal": subdenomination = None @@ -599,7 +647,7 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres literal.set_offset(src, caller_context.slither) return literal - elif name == "Identifier": + if name == "Identifier": assert "children" not in expression t = None @@ -613,7 +661,9 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres t = expression["attributes"]["type"] if t: - found = re.findall("[struct|enum|function|modifier] \(([\[\] ()a-zA-Z0-9\.,_]*)\)", t) + found = re.findall( + "[struct|enum|function|modifier] \(([\[\] ()a-zA-Z0-9\.,_]*)\)", t + ) assert len(found) <= 1 if found: value = value + "(" + found[0] + ")" @@ -630,7 +680,7 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres identifier.set_offset(src, caller_context.slither) return identifier - elif name == "IndexAccess": + if name == "IndexAccess": if is_compact_ast: index_type = expression["typeDescriptions"]["typeString"] left = expression["baseExpression"] @@ -658,11 +708,13 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres index.set_offset(src, caller_context.slither) return index - elif name == "MemberAccess": + if name == "MemberAccess": if caller_context.is_compact_ast: member_name = expression["memberName"] member_type = expression["typeDescriptions"]["typeString"] - member_expression = parse_expression(expression["expression"], caller_context) + member_expression = parse_expression( + expression["expression"], caller_context + ) else: member_name = expression["attributes"]["member_name"] member_type = expression["attributes"]["type"] @@ -685,11 +737,13 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres return idx return member_access - elif name == "ElementaryTypeNameExpression": - return _parse_elementary_type_name_expression(expression, is_compact_ast, caller_context) + if name == "ElementaryTypeNameExpression": + return _parse_elementary_type_name_expression( + expression, is_compact_ast, caller_context + ) # NewExpression is not a root expression, it's always the child of another expression - elif name == "NewExpression": + if name == "NewExpression": if is_compact_ast: type_name = expression["typeName"] @@ -715,7 +769,9 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres array_type = ElementaryType(type_name["attributes"]["name"]) elif type_name[caller_context.get_key()] == "UserDefinedTypeName": if is_compact_ast: - array_type = parse_type(UnknownType(type_name["name"]), caller_context) + array_type = parse_type( + UnknownType(type_name["name"]), caller_context + ) else: array_type = parse_type( UnknownType(type_name["attributes"]["name"]), caller_context @@ -747,13 +803,15 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres new.set_offset(src, caller_context.slither) return new - elif name == "ModifierInvocation": + if name == "ModifierInvocation": if is_compact_ast: called = parse_expression(expression["modifierName"], caller_context) arguments = [] if expression["arguments"]: - arguments = [parse_expression(a, caller_context) for a in expression["arguments"]] + arguments = [ + parse_expression(a, caller_context) for a in expression["arguments"] + ] else: children = expression["children"] called = parse_expression(children[0], caller_context) @@ -763,7 +821,7 @@ def parse_expression(expression: Dict, caller_context: CallerContext) -> "Expres call.set_offset(src, caller_context.slither) return call - elif name == "IndexRangeAccess": + if name == "IndexRangeAccess": # For now, we convert array slices to a direct array access # As a result the generated IR will lose the slices information # As far as I understand, array slice are only used in abi.decode diff --git a/slither/solc_parsing/slitherSolc.py b/slither/solc_parsing/slitherSolc.py index 584e4dbd4..3b0751ce8 100644 --- a/slither/solc_parsing/slitherSolc.py +++ b/slither/solc_parsing/slitherSolc.py @@ -7,10 +7,6 @@ from typing import List, Dict from slither.core.declarations import Contract from slither.exceptions import SlitherException -logging.basicConfig() -logger = logging.getLogger("SlitherSolcParsing") -logger.setLevel(logging.INFO) - from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.function import FunctionSolc from slither.core.slither_core import SlitherCore @@ -18,8 +14,12 @@ from slither.core.declarations.pragma_directive import Pragma from slither.core.declarations.import_directive import Import from slither.analyses.data_dependency.data_dependency import compute_dependency +logging.basicConfig() +logger = logging.getLogger("SlitherSolcParsing") +logger.setLevel(logging.INFO) class SlitherSolc: + # pylint: disable=no-self-use def __init__(self, filename: str, core: SlitherCore): super(SlitherSolc, self).__init__() core.filename = filename @@ -82,16 +82,17 @@ class SlitherSolc: data_loaded = json.loads(json_data) # Truffle AST if "ast" in data_loaded: - self.parse_contracts_from_loaded_json(data_loaded["ast"], data_loaded["sourcePath"]) + self.parse_contracts_from_loaded_json( + data_loaded["ast"], data_loaded["sourcePath"] + ) return True # solc AST, where the non-json text was removed + if "attributes" in data_loaded: + filename = data_loaded["attributes"]["absolutePath"] else: - if "attributes" in data_loaded: - filename = data_loaded["attributes"]["absolutePath"] - else: - filename = data_loaded["absolutePath"] - self.parse_contracts_from_loaded_json(data_loaded, filename) - return True + filename = data_loaded["absolutePath"] + self.parse_contracts_from_loaded_json(data_loaded, filename) + return True except ValueError: first = json_data.find("{") @@ -105,7 +106,7 @@ class SlitherSolc: return True return False - def parse_contracts_from_loaded_json(self, data_loaded: Dict, filename: str): + def parse_contracts_from_loaded_json(self, data_loaded: Dict, filename: str): # pylint: disable=too-many-branches if "nodeType" in data_loaded: self._is_compact_ast = True @@ -115,11 +116,11 @@ class SlitherSolc: self._core.add_source_code(sourcePath) if data_loaded[self.get_key()] == "root": - self._core._solc_version = "0.3" + self._core.solc_version = "0.3" logger.error("solc <0.4 is not supported") return - elif data_loaded[self.get_key()] == "SourceUnit": - self._core._solc_version = "0.4" + if data_loaded[self.get_key()] == "SourceUnit": + self._core.solc_version = "0.4" self._parse_source_unit(data_loaded, filename) else: logger.error("solc version is not supported") @@ -153,18 +154,25 @@ class SlitherSolc: if self.is_compact_ast: import_directive = Import(contract_data["absolutePath"]) else: - import_directive = Import(contract_data["attributes"]["absolutePath"]) + import_directive = Import( + contract_data["attributes"]["absolutePath"] + ) import_directive.set_offset(contract_data["src"], self._core) self._core.import_directives.append(import_directive) - elif contract_data[self.get_key()] in ["StructDefinition", "EnumDefinition"]: + elif contract_data[self.get_key()] in [ + "StructDefinition", + "EnumDefinition", + ]: # This can only happen for top-level structure and enum # They were introduced with 0.6.5 - assert self._is_compact_ast # Do not support top level definition for legacy AST + assert ( + self._is_compact_ast + ) # Do not support top level definition for legacy AST fake_contract_data = { "name": f"SlitherInternalTopLevelContract{self._top_level_contracts_counter}", "id": -1000 - + self._top_level_contracts_counter, # TODO: determine if collission possible + + self._top_level_contracts_counter, # TODO: determine if collission possible "linearizedBaseContracts": [], "fullyImplemented": True, "contractKind": "SLitherInternal", @@ -176,11 +184,11 @@ class SlitherSolc: contract.set_offset(contract_data["src"], self._core) if contract_data[self.get_key()] == "StructDefinition": - top_level_contract._structuresNotParsed.append( + top_level_contract.structures_not_parsed.append( contract_data ) # Todo add proper setters else: - top_level_contract._enumsNotParsed.append( + top_level_contract.enums_not_parsed.append( contract_data ) # Todo add proper setters @@ -188,7 +196,7 @@ class SlitherSolc: def _parse_source_unit(self, data: Dict, filename: str): if data[self.get_key()] != "SourceUnit": - return -1 # handle solc prior 0.3.6 + return # handle solc prior 0.3.6 # match any char for filename # filename can contain space, /, -, .. @@ -232,7 +240,7 @@ class SlitherSolc: def analyzed(self) -> bool: return self._analyzed - def analyze_contracts(self): + def analyze_contracts(self): # pylint: disable=too-many-statements,too-many-branches if not self._underlying_contract_to_parser: logger.info( f"No contract were found in {self._core.filename}, check the correct compilation" @@ -242,10 +250,10 @@ class SlitherSolc: # First we save all the contracts in a dict # the key is the contractid - for contract in self._underlying_contract_to_parser.keys(): + for contract in self._underlying_contract_to_parser: if ( - contract.name.startswith("SlitherInternalTopLevelContract") - and not contract.is_top_level + contract.name.startswith("SlitherInternalTopLevelContract") + and not contract.is_top_level ): raise SlitherException( """Your codebase has a contract named 'SlitherInternalTopLevelContract'. @@ -286,7 +294,9 @@ Please rename it, this name is reserved for Slither's internals""" # Resolve immediate base contracts for i in contract_parser.baseContracts: if i in contract_parser.remapping: - fathers.append(self._core.get_contract_from_name(contract_parser.remapping[i])) + fathers.append( + self._core.get_contract_from_name(contract_parser.remapping[i]) + ) elif i in self._contracts_by_id: fathers.append(self._contracts_by_id[i]) else: @@ -311,7 +321,9 @@ Please rename it, this name is reserved for Slither's internals""" self._core.contracts_with_missing_inheritance.add( contract_parser.underlying_contract ) - contract_parser.log_incorrect_parsing(f"Missing inheritance {contract_parser}") + contract_parser.log_incorrect_parsing( + f"Missing inheritance {contract_parser}" + ) contract_parser.set_is_analyzed(True) contract_parser.delete_content() @@ -319,17 +331,23 @@ Please rename it, this name is reserved for Slither's internals""" # Any contract can refer another contract enum without need for inheritance self._analyze_all_enums(contracts_to_be_analyzed) + # pylint: disable=expression-not-assigned [c.set_is_analyzed(False) for c in self._underlying_contract_to_parser.values()] libraries = [ - c for c in contracts_to_be_analyzed if c.underlying_contract.contract_kind == "library" + c + for c in contracts_to_be_analyzed + if c.underlying_contract.contract_kind == "library" ] contracts_to_be_analyzed = [ - c for c in contracts_to_be_analyzed if c.underlying_contract.contract_kind != "library" + c + for c in contracts_to_be_analyzed + if c.underlying_contract.contract_kind != "library" ] # We first parse the struct/variables/functions/contract self._analyze_first_part(contracts_to_be_analyzed, libraries) + # pylint: disable=expression-not-assigned [c.set_is_analyzed(False) for c in self._underlying_contract_to_parser.values()] # We analyze the struct and parse and analyze the events @@ -362,10 +380,11 @@ Please rename it, this name is reserved for Slither's internals""" self._analyze_enums(contract) else: contracts_to_be_analyzed += [contract] - return def _analyze_first_part( - self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc] + self, + contracts_to_be_analyzed: List[ContractSolc], + libraries: List[ContractSolc], ): for lib in libraries: self._parse_struct_var_modifiers_functions(lib) @@ -388,10 +407,11 @@ Please rename it, this name is reserved for Slither's internals""" else: contracts_to_be_analyzed += [contract] - return def _analyze_second_part( - self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc] + self, + contracts_to_be_analyzed: List[ContractSolc], + libraries: List[ContractSolc], ): for lib in libraries: self._analyze_struct_events(lib) @@ -414,10 +434,11 @@ Please rename it, this name is reserved for Slither's internals""" else: contracts_to_be_analyzed += [contract] - return def _analyze_third_part( - self, contracts_to_be_analyzed: List[ContractSolc], libraries: List[ContractSolc] + self, + contracts_to_be_analyzed: List[ContractSolc], + libraries: List[ContractSolc], ): for lib in libraries: self._analyze_variables_modifiers_functions(lib) @@ -440,7 +461,6 @@ Please rename it, this name is reserved for Slither's internals""" else: contracts_to_be_analyzed += [contract] - return def _analyze_enums(self, contract: ContractSolc): # Enum must be analyzed first diff --git a/slither/solc_parsing/solidity_types/type_parsing.py b/slither/solc_parsing/solidity_types/type_parsing.py index 4ec8f5112..16d9b23b6 100644 --- a/slither/solc_parsing/solidity_types/type_parsing.py +++ b/slither/solc_parsing/solidity_types/type_parsing.py @@ -1,7 +1,11 @@ import logging +import re from typing import List, TYPE_CHECKING, Union, Dict -from slither.core.solidity_types.elementary_type import ElementaryType, ElementaryTypeName +from slither.core.solidity_types.elementary_type import ( + ElementaryType, + ElementaryTypeName, +) from slither.core.solidity_types.type import Type from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.core.solidity_types.array_type import ArrayType @@ -15,15 +19,15 @@ from slither.core.declarations.contract import Contract from slither.core.expressions.literal import Literal from slither.solc_parsing.exceptions import ParsingError -import re if TYPE_CHECKING: from slither.core.declarations import Structure, Enum logger = logging.getLogger("TypeParsing") +# pylint: disable=anomalous-backslash-in-string -class UnknownType: +class UnknownType: # pylint: disable=too-few-public-methods def __init__(self, name): self._name = name @@ -32,7 +36,7 @@ class UnknownType: return self._name -def _find_from_type_name( +def _find_from_type_name( # pylint: disable=too-many-locals,too-many-branches,too-many-statements name: str, contract: Contract, contracts: List[Contract], @@ -46,8 +50,7 @@ def _find_from_type_name( depth = name.count("[") if depth: return ArrayType(ElementaryType(name_elementary), Literal(depth, "uint256")) - else: - return ElementaryType(name_elementary) + return ElementaryType(name_elementary) # We first look for contract # To avoid collision # Ex: a structure with the name of a contract @@ -71,25 +74,33 @@ def _find_from_type_name( all_enums = [item for sublist in all_enums for item in sublist] var_type = next((e for e in all_enums if e.name == enum_name), None) if not var_type: - var_type = next((e for e in all_enums if e.canonical_name == enum_name), None) + var_type = next( + (e for e in all_enums if e.canonical_name == enum_name), None + ) if not var_type: # any contract can refer to another contract's structure name_struct = name if name_struct.startswith("struct "): name_struct = name_struct[len("struct ") :] - name_struct = name_struct.split(" ")[0] # remove stuff like storage pointer at the end + name_struct = name_struct.split(" ")[ + 0 + ] # remove stuff like storage pointer at the end all_structures = [c.structures for c in contracts] all_structures = [item for sublist in all_structures for item in sublist] var_type = next((st for st in all_structures if st.name == name_struct), None) if not var_type: - var_type = next((st for st in all_structures if st.canonical_name == name_struct), None) + var_type = next( + (st for st in all_structures if st.canonical_name == name_struct), None + ) # case where struct xxx.xx[] where not well formed in the AST if not var_type: depth = 0 while name_struct.endswith("[]"): name_struct = name_struct[0:-2] depth += 1 - var_type = next((st for st in all_structures if st.canonical_name == name_struct), None) + var_type = next( + (st for st in all_structures if st.canonical_name == name_struct), None + ) if var_type: return ArrayType(UserDefinedType(var_type), Literal(depth, "uint256")) @@ -104,7 +115,8 @@ def _find_from_type_name( params = found[0][0].split(",") return_values = found[0][1].split(",") params = [ - _find_from_type_name(p, contract, contracts, structures, enums) for p in params + _find_from_type_name(p, contract, contracts, structures, enums) + for p in params ] return_values = [ _find_from_type_name(r, contract, contracts, structures, enums) @@ -125,16 +137,21 @@ def _find_from_type_name( if name.startswith("mapping("): # nested mapping declared with var if name.count("mapping(") == 1: - found = re.findall("mapping\(([a-zA-Z0-9\.]*) => ([a-zA-Z0-9\.\[\]]*)\)", name) + found = re.findall( + "mapping\(([a-zA-Z0-9\.]*) => ([a-zA-Z0-9\.\[\]]*)\)", name + ) else: found = re.findall( - "mapping\(([a-zA-Z0-9\.]*) => (mapping\([=> a-zA-Z0-9\.\[\]]*\))\)", name + "mapping\(([a-zA-Z0-9\.]*) => (mapping\([=> a-zA-Z0-9\.\[\]]*\))\)", + name, ) assert len(found) == 1 from_ = found[0][0] to_ = found[0][1] - from_type = _find_from_type_name(from_, contract, contracts, structures, enums) + from_type = _find_from_type_name( + from_, contract, contracts, structures, enums + ) to_type = _find_from_type_name(to_, contract, contracts, structures, enums) return MappingType(from_type, to_type) @@ -146,8 +163,12 @@ def _find_from_type_name( def parse_type(t: Union[Dict, UnknownType], caller_context): # local import to avoid circular dependency + # pylint: disable=too-many-locals,too-many-branches,too-many-statements + # pylint: disable=import-outside-toplevel from slither.solc_parsing.expressions.expression_parsing import parse_expression - from slither.solc_parsing.variables.function_type_variable import FunctionTypeVariableSolc + from slither.solc_parsing.variables.function_type_variable import ( + FunctionTypeVariableSolc, + ) from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.function import FunctionSolc @@ -174,15 +195,19 @@ def parse_type(t: Union[Dict, UnknownType], caller_context): if isinstance(t, UnknownType): return _find_from_type_name(t.name, contract, contracts, structures, enums) - elif t[key] == "ElementaryTypeName": + if t[key] == "ElementaryTypeName": if is_compact_ast: return ElementaryType(t["name"]) return ElementaryType(t["attributes"][key]) - elif t[key] == "UserDefinedTypeName": + if t[key] == "UserDefinedTypeName": if is_compact_ast: return _find_from_type_name( - t["typeDescriptions"]["typeString"], contract, contracts, structures, enums + t["typeDescriptions"]["typeString"], + contract, + contracts, + structures, + enums, ) # Determine if we have a type node (otherwise we use the name node, as some older solc did not have 'type'). @@ -191,7 +216,7 @@ def parse_type(t: Union[Dict, UnknownType], caller_context): t["attributes"][type_name_key], contract, contracts, structures, enums ) - elif t[key] == "ArrayTypeName": + if t[key] == "ArrayTypeName": length = None if is_compact_ast: if t["length"]: @@ -205,7 +230,7 @@ def parse_type(t: Union[Dict, UnknownType], caller_context): array_type = parse_type(t["children"][0], contract_parser) return ArrayType(array_type, length) - elif t[key] == "Mapping": + if t[key] == "Mapping": if is_compact_ast: mappingFrom = parse_type(t["keyType"], contract_parser) @@ -218,7 +243,7 @@ def parse_type(t: Union[Dict, UnknownType], caller_context): return MappingType(mappingFrom, mappingTo) - elif t[key] == "FunctionTypeName": + if t[key] == "FunctionTypeName": if is_compact_ast: params = t["parameterTypes"] diff --git a/slither/solc_parsing/variables/event_variable.py b/slither/solc_parsing/variables/event_variable.py index 6d743ba11..038479dd8 100644 --- a/slither/solc_parsing/variables/event_variable.py +++ b/slither/solc_parsing/variables/event_variable.py @@ -1,6 +1,6 @@ from typing import Dict -from .variable_declaration import VariableDeclarationSolc +from slither.solc_parsing.variables.variable_declaration import VariableDeclarationSolc from slither.core.variables.event_variable import EventVariable diff --git a/slither/solc_parsing/variables/local_variable.py b/slither/solc_parsing/variables/local_variable.py index c2594c938..ba2645ea1 100644 --- a/slither/solc_parsing/variables/local_variable.py +++ b/slither/solc_parsing/variables/local_variable.py @@ -1,6 +1,6 @@ from typing import Dict -from .variable_declaration import VariableDeclarationSolc +from slither.solc_parsing.variables.variable_declaration import VariableDeclarationSolc from slither.core.variables.local_variable import LocalVariable diff --git a/slither/solc_parsing/variables/local_variable_init_from_tuple.py b/slither/solc_parsing/variables/local_variable_init_from_tuple.py index 3384482c1..63c9cdde3 100644 --- a/slither/solc_parsing/variables/local_variable_init_from_tuple.py +++ b/slither/solc_parsing/variables/local_variable_init_from_tuple.py @@ -1,11 +1,15 @@ from typing import Dict -from .variable_declaration import VariableDeclarationSolc -from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple +from slither.solc_parsing.variables.variable_declaration import VariableDeclarationSolc +from slither.core.variables.local_variable_init_from_tuple import ( + LocalVariableInitFromTuple, +) class LocalVariableInitFromTupleSolc(VariableDeclarationSolc): - def __init__(self, variable: LocalVariableInitFromTuple, variable_data: Dict, index: int): + def __init__( + self, variable: LocalVariableInitFromTuple, variable_data: Dict, index: int + ): super(LocalVariableInitFromTupleSolc, self).__init__(variable, variable_data) variable.tuple_index = index diff --git a/slither/solc_parsing/variables/state_variable.py b/slither/solc_parsing/variables/state_variable.py index 398b8ff3c..8990d8e86 100644 --- a/slither/solc_parsing/variables/state_variable.py +++ b/slither/solc_parsing/variables/state_variable.py @@ -1,6 +1,6 @@ from typing import Dict -from .variable_declaration import VariableDeclarationSolc +from slither.solc_parsing.variables.variable_declaration import VariableDeclarationSolc from slither.core.variables.state_variable import StateVariable diff --git a/slither/solc_parsing/variables/structure_variable.py b/slither/solc_parsing/variables/structure_variable.py index 750778678..a41c47e35 100644 --- a/slither/solc_parsing/variables/structure_variable.py +++ b/slither/solc_parsing/variables/structure_variable.py @@ -1,6 +1,6 @@ from typing import Dict -from .variable_declaration import VariableDeclarationSolc +from slither.solc_parsing.variables.variable_declaration import VariableDeclarationSolc from slither.core.variables.structure_variable import StructureVariable diff --git a/slither/solc_parsing/variables/variable_declaration.py b/slither/solc_parsing/variables/variable_declaration.py index 61cc8d403..2c0654cd6 100644 --- a/slither/solc_parsing/variables/variable_declaration.py +++ b/slither/solc_parsing/variables/variable_declaration.py @@ -7,7 +7,10 @@ from slither.core.variables.variable import Variable from slither.solc_parsing.solidity_types.type_parsing import parse_type, UnknownType -from slither.core.solidity_types.elementary_type import ElementaryType, NonElementaryType +from slither.core.solidity_types.elementary_type import ( + ElementaryType, + NonElementaryType, +) from slither.solc_parsing.exceptions import ParsingError logger = logging.getLogger("VariableDeclarationSolcParsing") @@ -19,12 +22,12 @@ class MultipleVariablesDeclaration(Exception): var (a,b) = ... It should occur only on local variable definition """ - + # pylint: disable=unnecessary-pass pass class VariableDeclarationSolc: - def __init__(self, variable: Variable, variable_data: Dict): + def __init__(self, variable: Variable, variable_data: Dict): # pylint: disable=too-many-branches """ A variable can be declared through a statement, or directly. If it is through a statement, the following children may contain @@ -45,7 +48,10 @@ class VariableDeclarationSolc: if "nodeType" in variable_data: self._is_compact_ast = True nodeType = variable_data["nodeType"] - if nodeType in ["VariableDeclarationStatement", "VariableDefinitionStatement"]: + if nodeType in [ + "VariableDeclarationStatement", + "VariableDefinitionStatement", + ]: if len(variable_data["declarations"]) > 1: raise MultipleVariablesDeclaration init = None @@ -55,12 +61,17 @@ class VariableDeclarationSolc: elif nodeType == "VariableDeclaration": self._init_from_declaration(variable_data, variable_data["value"]) else: - raise ParsingError("Incorrect variable declaration type {}".format(nodeType)) + raise ParsingError( + "Incorrect variable declaration type {}".format(nodeType) + ) else: nodeType = variable_data["name"] - if nodeType in ["VariableDeclarationStatement", "VariableDefinitionStatement"]: + if nodeType in [ + "VariableDeclarationStatement", + "VariableDefinitionStatement", + ]: if len(variable_data["children"]) == 2: init = variable_data["children"][1] elif len(variable_data["children"]) == 1: @@ -76,7 +87,9 @@ class VariableDeclarationSolc: elif nodeType == "VariableDeclaration": self._init_from_declaration(variable_data, False) else: - raise ParsingError("Incorrect variable declaration type {}".format(nodeType)) + raise ParsingError( + "Incorrect variable declaration type {}".format(nodeType) + ) @property def underlying_variable(self) -> Variable: @@ -96,7 +109,7 @@ class VariableDeclarationSolc: else: self._variable.visibility = "internal" - def _init_from_declaration(self, var: Dict, init: bool): + def _init_from_declaration(self, var: Dict, init: bool): # pylint: disable=too-many-branches if self._is_compact_ast: attributes = var self._typeName = attributes["typeDescriptions"]["typeString"] @@ -170,5 +183,7 @@ class VariableDeclarationSolc: self._elem_to_parse = None if self._variable.initialized: - self._variable.expression = parse_expression(self._initializedNotParsed, caller_context) + self._variable.expression = parse_expression( + self._initializedNotParsed, caller_context + ) self._initializedNotParsed = None diff --git a/slither/solc_parsing/yul/evm_functions.py b/slither/solc_parsing/yul/evm_functions.py index 56f1e7f99..0276d4bf7 100644 --- a/slither/solc_parsing/yul/evm_functions.py +++ b/slither/solc_parsing/yul/evm_functions.py @@ -263,7 +263,7 @@ binary_ops = { } -class YulBuiltin: +class YulBuiltin: # pylint: disable=too-few-public-methods def __init__(self, name): self._name = name diff --git a/slither/solc_parsing/yul/parse_yul.py b/slither/solc_parsing/yul/parse_yul.py index d663e48ed..527d90777 100644 --- a/slither/solc_parsing/yul/parse_yul.py +++ b/slither/solc_parsing/yul/parse_yul.py @@ -3,7 +3,12 @@ import json from typing import Optional, Dict, List, Union from slither.core.cfg.node import NodeType, Node, link_nodes -from slither.core.declarations import Function, SolidityFunction, SolidityVariable, Contract +from slither.core.declarations import ( + Function, + SolidityFunction, + SolidityVariable, + Contract, +) from slither.core.expressions import ( Literal, AssignmentOperation, @@ -19,7 +24,8 @@ from slither.core.slither_core import SlitherCore from slither.core.solidity_types import ElementaryType from slither.core.variables.local_variable import LocalVariable from slither.exceptions import SlitherException -from slither.solc_parsing.yul.evm_functions import * +from slither.solc_parsing.yul.evm_functions import format_function_descriptor, builtins, YulBuiltin, unary_ops, \ + binary_ops from slither.visitors.expression.find_calls import FindCalls from slither.visitors.expression.read_var import ReadVar from slither.visitors.expression.write_var import WriteVar @@ -55,7 +61,9 @@ class YulNode: AssignmentOperationType.ASSIGN, self._node.variable_declaration.type, ) - _expression.set_offset(self._node.expression.source_mapping, self._node.slither) + _expression.set_offset( + self._node.expression.source_mapping, self._node.slither + ) self._node.add_expression(_expression, bypass_verif_empty=True) expression = self._node.expression @@ -68,10 +76,14 @@ class YulNode: find_call = FindCalls(expression) self._node.calls_as_expression = find_call.result() self._node.external_calls_as_expressions = [ - c for c in self._node.calls_as_expression if not isinstance(c.called, Identifier) + c + for c in self._node.calls_as_expression + if not isinstance(c.called, Identifier) ] self._node.internal_calls_as_expressions = [ - c for c in self._node.calls_as_expression if isinstance(c.called, Identifier) + c + for c in self._node.calls_as_expression + if isinstance(c.called, Identifier) ] @@ -80,11 +92,17 @@ def link_underlying_nodes(node1: YulNode, node2: YulNode): class YulScope(metaclass=abc.ABCMeta): - __slots__ = ["_contract", "_id", "_yul_local_variables", "_yul_local_functions", "_parent_func"] - - def __init__(self, contract: Contract, id: List[str], parent_func: Function = None): + __slots__ = [ + "_contract", + "_id", + "_yul_local_variables", + "_yul_local_functions", + "_parent_func", + ] + + def __init__(self, contract: Contract, yul_id: List[str], parent_func: Function = None): self._contract = contract - self._id: List[str] = id + self._id: List[str] = yul_id self._yul_local_variables: List[YulLocalVariable] = [] self._yul_local_functions: List[YulFunction] = [] self._parent_func = parent_func @@ -119,17 +137,25 @@ class YulScope(metaclass=abc.ABCMeta): def get_yul_local_variable_from_name(self, variable_name): return next( - (v for v in self._yul_local_variables if v.underlying.name == variable_name), None + ( + v + for v in self._yul_local_variables + if v.underlying.name == variable_name + ), + None, ) def add_yul_local_function(self, func): self._yul_local_functions.append(func) def get_yul_local_function_from_name(self, func_name): - return next((v for v in self._yul_local_functions if v.underlying.name == func_name), None) + return next( + (v for v in self._yul_local_functions if v.underlying.name == func_name), + None, + ) -class YulLocalVariable: +class YulLocalVariable: # pylint: disable=too-few-public-methods __slots__ = ["_variable", "_root"] def __init__(self, var: LocalVariable, root: YulScope, ast: Dict): @@ -155,7 +181,9 @@ class YulFunction(YulScope): __slots__ = ["_function", "_root", "_ast", "_nodes", "_entrypoint"] def __init__(self, func: Function, root: YulScope, ast: Dict): - super().__init__(root.contract, root.id + [ast["name"]], parent_func=root.parent_func) + super().__init__( + root.contract, root.id + [ast["name"]], parent_func=root.parent_func + ) assert ast["nodeType"] == "YulFunctionDefinition" @@ -203,7 +231,9 @@ class YulFunction(YulScope): for ret in self._ast.get("returnVariables", []): node = convert_yul(self, node, ret) - self._function.add_return(self.get_yul_local_variable_from_name(ret["name"]).underlying) + self._function.add_return( + self.get_yul_local_variable_from_name(ret["name"]).underlying + ) convert_yul(self, node, self._ast["body"]) @@ -231,8 +261,8 @@ class YulBlock(YulScope): __slots__ = ["_entrypoint", "_parent_func", "_nodes"] - def __init__(self, contract: Contract, entrypoint: Node, id: List[str], **kwargs): - super().__init__(contract, id, **kwargs) + def __init__(self, contract: Contract, entrypoint: Node, yul_id: List[str], **kwargs): + super().__init__(contract, yul_id, **kwargs) self._entrypoint: YulNode = YulNode(entrypoint, self) self._nodes: List[YulNode] = [] @@ -269,22 +299,22 @@ class YulBlock(YulScope): ################################################################################### ################################################################################### -""" -The functions in this region, at a high level, will extract the control flow -structures and metadata from the input AST. These include things like function -definitions and local variables. -Each function takes three parameters: - 1) root is the current YulScope, where you can find things like local variables - 2) parent is the previous YulNode, which you'll have to link to - 3) ast is a dictionary and is the current node in the Yul ast being converted - -Each function must return a single parameter: - 1) the new YulNode that the CFG ends at +# The functions in this region, at a high level, will extract the control flow +# structures and metadata from the input AST. These include things like function +# definitions and local variables. +# +# Each function takes three parameters: +# 1) root is the current YulScope, where you can find things like local variables +# 2) parent is the previous YulNode, which you'll have to link to +# 3) ast is a dictionary and is the current node in the Yul ast being converted +# +# Each function must return a single parameter: +# 1) the new YulNode that the CFG ends at +# +# The entrypoint is the function at the end of this region, `convert_yul`, which +# dispatches to a specialized function based on a lookup dictionary. -The entrypoint is the function at the end of this region, `convert_yul`, which -dispatches to a specialized function based on a lookup dictionary. -""" def convert_yul_block(root: YulScope, parent: YulNode, ast: Dict) -> YulNode: @@ -293,11 +323,13 @@ def convert_yul_block(root: YulScope, parent: YulNode, ast: Dict) -> YulNode: return parent -def convert_yul_function_definition(root: YulScope, parent: YulNode, ast: Dict) -> YulNode: +def convert_yul_function_definition( + root: YulScope, parent: YulNode, ast: Dict +) -> YulNode: func = Function() yul_function = YulFunction(func, root, ast) - root.contract._functions[func.canonical_name] = func + root.contract.add_function(func) root.slither.add_function(func) root.add_yul_local_function(yul_function) @@ -307,7 +339,9 @@ def convert_yul_function_definition(root: YulScope, parent: YulNode, ast: Dict) return parent -def convert_yul_variable_declaration(root: YulScope, parent: YulNode, ast: Dict) -> YulNode: +def convert_yul_variable_declaration( + root: YulScope, parent: YulNode, ast: Dict +) -> YulNode: for variable_ast in ast["variables"]: parent = convert_yul(root, parent, variable_ast) @@ -325,7 +359,9 @@ def convert_yul_assignment(root: YulScope, parent: YulNode, ast: Dict) -> YulNod return node -def convert_yul_expression_statement(root: YulScope, parent: YulNode, ast: Dict) -> YulNode: +def convert_yul_expression_statement( + root: YulScope, parent: YulNode, ast: Dict +) -> YulNode: src = ast["src"] expression_ast = ast["expression"] @@ -394,7 +430,7 @@ def convert_yul_switch(root: YulScope, parent: YulNode, ast: Dict) -> YulNode: ], } - last_if = None + last_if: Optional[Dict] = None default_ast = None @@ -418,7 +454,11 @@ def convert_yul_switch(root: YulScope, parent: YulNode, ast: Dict) -> YulNode: "name": "eq", }, "arguments": [ - {"nodeType": "YulIdentifier", "src": case_ast["src"], "name": switch_expr_var,}, + { + "nodeType": "YulIdentifier", + "src": case_ast["src"], + "name": switch_expr_var, + }, value_ast, ], }, @@ -426,7 +466,7 @@ def convert_yul_switch(root: YulScope, parent: YulNode, ast: Dict) -> YulNode: } if last_if: - last_if["false_body"] = current_if + last_if["false_body"] = current_if # pylint: disable=unsupported-assignment-operation else: rewritten_switch["statements"].append(current_if) @@ -545,7 +585,7 @@ Each function takes three parameters: 1) root is the same root as above 2) node is the CFG node which stores this expression 3) ast is the same ast as above - + Each function must return a single parameter: 1) The operation that was parsed, or None @@ -579,11 +619,15 @@ def parse_yul_variable_declaration( return _parse_yul_assignment_common(root, node, ast, "variables") -def parse_yul_assignment(root: YulScope, node: YulNode, ast: Dict) -> Optional[Expression]: +def parse_yul_assignment( + root: YulScope, node: YulNode, ast: Dict +) -> Optional[Expression]: return _parse_yul_assignment_common(root, node, ast, "variableNames") -def parse_yul_function_call(root: YulScope, node: YulNode, ast: Dict) -> Optional[Expression]: +def parse_yul_function_call( + root: YulScope, node: YulNode, ast: Dict +) -> Optional[Expression]: args = [parse_yul(root, node, arg) for arg in ast["arguments"]] ident = parse_yul(root, node, ast["functionName"]) @@ -602,17 +646,23 @@ def parse_yul_function_call(root: YulScope, node: YulNode, ast: Dict) -> Optiona if name in unary_ops: return UnaryOperation(args[0], unary_ops[name]) - ident = Identifier(SolidityFunction(format_function_descriptor(ident.value.name))) + ident = Identifier( + SolidityFunction(format_function_descriptor(ident.value.name)) + ) if isinstance(ident.value, Function): return CallExpression(ident, args, vars_to_typestr(ident.value.returns)) - elif isinstance(ident.value, SolidityFunction): + if isinstance(ident.value, SolidityFunction): return CallExpression(ident, args, vars_to_typestr(ident.value.return_type)) - else: - raise SlitherException(f"unexpected function call target type {str(type(ident.value))}") + + raise SlitherException( + f"unexpected function call target type {str(type(ident.value))}" + ) -def parse_yul_identifier(root: YulScope, node: YulNode, ast: Dict) -> Optional[Expression]: +def parse_yul_identifier( + root: YulScope, _node: YulNode, ast: Dict +) -> Optional[Expression]: name = ast["name"] if name in builtins: @@ -654,7 +704,7 @@ def parse_yul_identifier(root: YulScope, node: YulNode, ast: Dict) -> Optional[E raise SlitherException(f"unresolved reference to identifier {name}") -def parse_yul_literal(root: YulScope, node: YulNode, ast: Dict) -> Optional[Expression]: +def parse_yul_literal(_root: YulScope, _node: YulNode, ast: Dict) -> Optional[Expression]: type_ = ast["type"] value = ast["value"] @@ -664,16 +714,22 @@ def parse_yul_literal(root: YulScope, node: YulNode, ast: Dict) -> Optional[Expr return Literal(value, ElementaryType(type_)) -def parse_yul_typed_name(root: YulScope, node: YulNode, ast: Dict) -> Optional[Expression]: +def parse_yul_typed_name( + root: YulScope, _node: YulNode, ast: Dict +) -> Optional[Expression]: var = root.get_yul_local_variable_from_name(ast["name"]) i = Identifier(var.underlying) - i._type = var.underlying.type + i.type = var.underlying.type return i -def parse_yul_unsupported(root: YulScope, node: YulNode, ast: Dict) -> Optional[Expression]: - raise SlitherException(f"no parser available for {ast['nodeType']} {json.dumps(ast, indent=2)}") +def parse_yul_unsupported( + _root: YulScope, _node: YulNode, ast: Dict +) -> Optional[Expression]: + raise SlitherException( + f"no parser available for {ast['nodeType']} {json.dumps(ast, indent=2)}" + ) def parse_yul(root: YulScope, node: YulNode, ast: Dict) -> Optional[Expression]: @@ -706,7 +762,7 @@ def vars_to_typestr(rets: List[Expression]) -> str: return "tuple({})".format(",".join(str(ret.type) for ret in rets)) -def vars_to_val(vars): - if len(vars) == 1: - return vars[0] - return TupleExpression(vars) +def vars_to_val(vars_to_convert): + if len(vars_to_convert) == 1: + return vars_to_convert[0] + return TupleExpression(vars_to_convert) diff --git a/slither/tools/demo/__main__.py b/slither/tools/demo/__main__.py index 03dc984b6..d5c363ae2 100644 --- a/slither/tools/demo/__main__.py +++ b/slither/tools/demo/__main__.py @@ -1,4 +1,3 @@ -import os import argparse import logging from slither import Slither @@ -31,7 +30,7 @@ def main(): args = parse_args() # Perform slither analysis on the given filename - slither = Slither(args.filename, **vars(args)) + _slither = Slither(args.filename, **vars(args)) logger.info("Analysis done!") diff --git a/slither/tools/erc_conformance/__main__.py b/slither/tools/erc_conformance/__main__.py index bb98842e2..9eb977ff6 100644 --- a/slither/tools/erc_conformance/__main__.py +++ b/slither/tools/erc_conformance/__main__.py @@ -2,8 +2,8 @@ import argparse import logging from collections import defaultdict -from slither import Slither from crytic_compile import cryticparser +from slither import Slither from slither.utils.erc import ERCS from slither.utils.output import output_to_json from .erc.ercs import generic_erc_checks @@ -31,7 +31,8 @@ def parse_args(): :return: Returns the arguments for the program. """ parser = argparse.ArgumentParser( - description="Check the ERC 20 conformance", usage="slither-check-erc project contractName" + description="Check the ERC 20 conformance", + usage="slither-check-erc project contractName", ) parser.add_argument("project", help="The codebase to be tested.") diff --git a/slither/tools/erc_conformance/erc/erc20.py b/slither/tools/erc_conformance/erc/erc20.py index 720b08322..d5c69a036 100644 --- a/slither/tools/erc_conformance/erc/erc20.py +++ b/slither/tools/erc_conformance/erc/erc20.py @@ -6,7 +6,9 @@ logger = logging.getLogger("Slither-conformance") def approval_race_condition(contract, ret): - increaseAllowance = contract.get_function_from_signature("increaseAllowance(address,uint256)") + increaseAllowance = contract.get_function_from_signature( + "increaseAllowance(address,uint256)" + ) if not increaseAllowance: increaseAllowance = contract.get_function_from_signature( diff --git a/slither/tools/erc_conformance/erc/ercs.py b/slither/tools/erc_conformance/erc/ercs.py index e297d7918..5ba0c1db7 100644 --- a/slither/tools/erc_conformance/erc/ercs.py +++ b/slither/tools/erc_conformance/erc/ercs.py @@ -2,11 +2,15 @@ import logging from slither.slithir.operations import EventCall from slither.utils import output -from slither.utils.type import export_nested_types_from_variable, export_return_type_from_variable +from slither.utils.type import ( + export_nested_types_from_variable, + export_return_type_from_variable, +) logger = logging.getLogger("Slither-conformance") +# pylint: disable=too-many-locals,too-many-branches,too-many-statements def _check_signature(erc_function, contract, ret): name = erc_function.name parameters = erc_function.parameters @@ -22,10 +26,10 @@ def _check_signature(erc_function, contract, ret): # The check on state variable is needed until we have a better API to handle state variable getters state_variable_as_function = contract.get_state_variable_from_name(name) - if not state_variable_as_function or not state_variable_as_function.visibility in [ - "public", - "external", - ]: + if ( + not state_variable_as_function + or not state_variable_as_function.visibility in ["public", "external",] + ): txt = f'[ ] {sig} is missing {"" if required else "(optional)"}' logger.info(txt) missing_func = output.Output( @@ -35,7 +39,10 @@ def _check_signature(erc_function, contract, ret): ret["missing_function"].append(missing_func.data) return - types = [str(x) for x in export_nested_types_from_variable(state_variable_as_function)] + types = [ + str(x) + for x in export_nested_types_from_variable(state_variable_as_function) + ] if types != parameters: txt = f'[ ] {sig} is missing {"" if required else "(optional)"}' @@ -47,13 +54,15 @@ def _check_signature(erc_function, contract, ret): ret["missing_function"].append(missing_func.data) return - function_return_type = [export_return_type_from_variable(state_variable_as_function)] + function_return_type = [ + export_return_type_from_variable(state_variable_as_function) + ] function = state_variable_as_function function_view = True else: - function_return_type = function.return_type - function_view = function.view + function_return_type = function.return_type # pylint: disable=no-member + function_view = function.view # pylint: disable=no-member txt = f"[✓] {sig} is present" logger.info(txt) @@ -106,7 +115,7 @@ def _check_signature(erc_function, contract, ret): should_be_view.add(function) ret["should_be_view"].append(should_be_view.data) - if events: + if events: # pylint: disable=too-many-nested-blocks for event in events: event_sig = f'{event.name}({",".join(event.parameters)})' @@ -171,7 +180,9 @@ def _check_events(erc_event, contract, ret): txt = f"\t[ ] parameter {i} should be indexed" logger.info(txt) - missing_event_index = output.Output(txt, additional_fields={"missing_index": i}) + missing_event_index = output.Output( + txt, additional_fields={"missing_index": i} + ) missing_event_index.add_event(event) ret["missing_event_index"].append(missing_event_index.data) @@ -185,10 +196,10 @@ def generic_erc_checks(contract, erc_functions, erc_events, ret, explored=None): logger.info(f"# Check {contract.name}\n") - logger.info(f"## Check functions") + logger.info("## Check functions") for erc_function in erc_functions: _check_signature(erc_function, contract, ret) - logger.info(f"\n## Check events") + logger.info("\n## Check events") for erc_event in erc_events: _check_events(erc_event, contract, ret) diff --git a/slither/tools/flattening/__main__.py b/slither/tools/flattening/__main__.py index 0639c481b..4487dea27 100644 --- a/slither/tools/flattening/__main__.py +++ b/slither/tools/flattening/__main__.py @@ -28,20 +28,24 @@ def parse_args(): usage="slither-flat filename", ) - parser.add_argument("filename", help="The filename of the contract or project to analyze.") + parser.add_argument( + "filename", help="The filename of the contract or project to analyze." + ) parser.add_argument("--contract", help="Flatten one contract.", default=None) parser.add_argument( "--strategy", help=f"Flatenning strategy: {STRATEGIES_NAMES} (default: MostDerived).", - default=Strategy.MostDerived.name, + default=Strategy.MostDerived.name, # pylint: disable=no-member ) group_export = parser.add_argument_group("Export options") group_export.add_argument( - "--dir", help=f"Export directory (default: {DEFAULT_EXPORT_PATH}).", default=None + "--dir", + help=f"Export directory (default: {DEFAULT_EXPORT_PATH}).", + default=None, ) group_export.add_argument( @@ -52,7 +56,10 @@ def parse_args(): ) parser.add_argument( - "--zip", help="Export all the files to a zip file", action="store", default=None, + "--zip", + help="Export all the files to a zip file", + action="store", + default=None, ) parser.add_argument( @@ -69,7 +76,9 @@ def parse_args(): ) group_patching.add_argument( - "--convert-private", help="Convert private variables to internal.", action="store_true" + "--convert-private", + help="Convert private variables to internal.", + action="store_true", ) group_patching.add_argument( @@ -109,9 +118,8 @@ def main(): try: strategy = Strategy[args.strategy] except KeyError: - logger.error( - f"{args.strategy} is not a valid strategy, use: {STRATEGIES_NAMES} (default MostDerived)" - ) + to_log = f"{args.strategy} is not a valid strategy, use: {STRATEGIES_NAMES} (default MostDerived)" + logger.error(to_log) return flat.export( strategy=strategy, diff --git a/slither/tools/flattening/export/export.py b/slither/tools/flattening/export/export.py index 94e13d78a..75b9ac941 100644 --- a/slither/tools/flattening/export/export.py +++ b/slither/tools/flattening/export/export.py @@ -24,7 +24,9 @@ def save_to_zip(files: List[Export], zip_filename: str, zip_type: str = "lzma"): """ logger.info(f"Export {zip_filename}") with zipfile.ZipFile( - zip_filename, "w", compression=ZIP_TYPES_ACCEPTED.get(zip_type, zipfile.ZIP_LZMA) + zip_filename, + "w", + compression=ZIP_TYPES_ACCEPTED.get(zip_type, zipfile.ZIP_LZMA), ) as file_desc: for f in files: file_desc.writestr(str(f.filename), f.content) diff --git a/slither/tools/flattening/flattening.py b/slither/tools/flattening/flattening.py index 1ddcf42d5..2f99700f9 100644 --- a/slither/tools/flattening/flattening.py +++ b/slither/tools/flattening/flattening.py @@ -12,7 +12,12 @@ from slither.core.solidity_types import MappingType, ArrayType from slither.core.solidity_types.user_defined_type import UserDefinedType from slither.exceptions import SlitherException from slither.slithir.operations import NewContract, TypeConversion, SolidityCall -from slither.tools.flattening.export.export import Export, export_as_json, save_to_zip, save_to_disk +from slither.tools.flattening.export.export import ( + Export, + export_as_json, + save_to_zip, + save_to_disk, +) logger = logging.getLogger("Slither-flattening") @@ -36,6 +41,7 @@ DEFAULT_EXPORT_PATH = Path("crytic-export/flattening") class Flattening: + # pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-locals,too-few-public-methods def __init__( self, slither, @@ -53,7 +59,9 @@ class Flattening: self._private_to_internal = private_to_internal self._pragma_solidity = pragma_solidity - self._export_path: Path = DEFAULT_EXPORT_PATH if export_path is None else Path(export_path) + self._export_path: Path = DEFAULT_EXPORT_PATH if export_path is None else Path( + export_path + ) self._check_abi_encoder_v2() @@ -71,7 +79,7 @@ class Flattening: self._use_abi_encoder_v2 = True return - def _get_source_code(self, contract: Contract): + def _get_source_code(self, contract: Contract): # pylint: disable=too-many-branches,too-many-statements """ Save the source code of the contract in self._source_codes Patch the source code @@ -79,7 +87,9 @@ class Flattening: :return: """ src_mapping = contract.source_mapping - content = self._slither.source_code[src_mapping["filename_absolute"]].encode("utf8") + content = self._slither.source_code[src_mapping["filename_absolute"]].encode( + "utf8" + ) start = src_mapping["start"] end = src_mapping["start"] + src_mapping["length"] @@ -97,21 +107,33 @@ class Flattening: ) attributes_end = f.returns_src.source_mapping["start"] attributes = content[attributes_start:attributes_end] - regex = re.search(r"((\sexternal)\s+)|(\sexternal)$|(\)external)$", attributes) + regex = re.search( + r"((\sexternal)\s+)|(\sexternal)$|(\)external)$", attributes + ) if regex: to_patch.append( - Patch(attributes_start + regex.span()[0] + 1, "public_to_external") + Patch( + attributes_start + regex.span()[0] + 1, + "public_to_external", + ) ) else: - raise SlitherException(f"External keyword not found {f.name} {attributes}") + raise SlitherException( + f"External keyword not found {f.name} {attributes}" + ) for var in f.parameters: if var.location == "calldata": calldata_start = var.source_mapping["start"] calldata_end = calldata_start + var.source_mapping["length"] - calldata_idx = content[calldata_start:calldata_end].find(" calldata ") + calldata_idx = content[calldata_start:calldata_end].find( + " calldata " + ) to_patch.append( - Patch(calldata_start + calldata_idx + 1, "calldata_to_memory") + Patch( + calldata_start + calldata_idx + 1, + "calldata_to_memory", + ) ) if self._private_to_internal: @@ -119,25 +141,34 @@ class Flattening: if variable.visibility == "private": print(variable.source_mapping) attributes_start = variable.source_mapping["start"] - attributes_end = attributes_start + variable.source_mapping["length"] + attributes_end = ( + attributes_start + variable.source_mapping["length"] + ) attributes = content[attributes_start:attributes_end] print(attributes) regex = re.search(r" private ", attributes) if regex: to_patch.append( - Patch(attributes_start + regex.span()[0] + 1, "private_to_internal") + Patch( + attributes_start + regex.span()[0] + 1, + "private_to_internal", + ) ) else: - raise SlitherException(f"private keyword not found {v.name} {attributes}") + raise SlitherException( + f"private keyword not found {variable.name} {attributes}" + ) if self._remove_assert: for function in contract.functions_and_modifiers_declared: for node in function.nodes: for ir in node.irs: - if isinstance(ir, SolidityCall) and ir.function == SolidityFunction( - "assert(bool)" - ): - to_patch.append(Patch(node.source_mapping["start"], "line_removal")) + if isinstance( + ir, SolidityCall + ) and ir.function == SolidityFunction("assert(bool)"): + to_patch.append( + Patch(node.source_mapping["start"], "line_removal") + ) logger.info( f"Code commented: {node.expression} ({node.source_mapping_str})" ) @@ -150,11 +181,17 @@ class Flattening: index = patch.index index = index - start if patch_type == "public_to_external": - content = content[:index] + "public" + content[index + len("external") :] + content = ( + content[:index] + "public" + content[index + len("external") :] + ) if patch_type == "private_to_internal": - content = content[:index] + "internal" + content[index + len("private") :] + content = ( + content[:index] + "internal" + content[index + len("private") :] + ) elif patch_type == "calldata_to_memory": - content = content[:index] + "memory" + content[index + len("calldata") :] + content = ( + content[:index] + "memory" + content[index + len("calldata") :] + ) else: assert patch_type == "line_removal" content = content[:index] + " // " + content[index:] @@ -180,7 +217,9 @@ class Flattening: if isinstance(t, UserDefinedType): if isinstance(t.type, (Enum, Structure)): if t.type.contract != contract and t.type.contract not in exported: - self._export_list_used_contracts(t.type.contract, exported, list_contract) + self._export_list_used_contracts( + t.type.contract, exported, list_contract + ) else: assert isinstance(t.type, Contract) if t.type != contract and t.type not in exported: @@ -191,7 +230,7 @@ class Flattening: elif isinstance(t, ArrayType): self._export_from_type(t.type, contract, exported, list_contract) - def _export_list_used_contracts( + def _export_list_used_contracts( # pylint: disable=too-many-branches self, contract: Contract, exported: Set[str], list_contract: List[Contract] ): if contract.name in exported: @@ -204,7 +243,7 @@ class Flattening: externals = contract.all_library_calls + contract.all_high_level_calls # externals is a list of (contract, function) # We also filter call to itself to avoid infilite loop - externals = list(set([e[0] for e in externals if e[0] != contract])) + externals = list({e[0] for e in externals if e[0] != contract}) for inherited in externals: self._export_list_used_contracts(inherited, exported, list_contract) @@ -225,7 +264,10 @@ class Flattening: for f in contract.functions_declared: for ir in f.slithir_operations: if isinstance(ir, NewContract): - if ir.contract_created != contract and not ir.contract_created in exported: + if ( + ir.contract_created != contract + and not ir.contract_created in exported + ): self._export_list_used_contracts( ir.contract_created, exported, list_contract ) @@ -242,8 +284,8 @@ class Flattening: content = "" content += self._pragmas() - for contract in list_contracts: - content += self._source_codes[contract] + for listed_contract in list_contracts: + content += self._source_codes[listed_contract] content += "\n" return Export(filename=path, content=content) @@ -255,7 +297,7 @@ class Flattening: return ret def _export_all(self) -> List[Export]: - path = Path(self._export_path, f"export.sol") + path = Path(self._export_path, "export.sol") content = "" content += self._pragmas() @@ -266,17 +308,17 @@ class Flattening: # We only need the inheritance order here, as solc can compile # a contract that use another contract type (ex: state variable) that he has not seen yet while contract_to_explore: - next = contract_to_explore.pop(0) + next_to_explore = contract_to_explore.pop(0) - if not next.inheritance or all( - (father in contract_seen for father in next.inheritance) + if not next_to_explore.inheritance or all( + (father in contract_seen for father in next_to_explore.inheritance) ): content += "\n" - content += self._source_codes[next] + content += self._source_codes[next_to_explore] content += "\n" - contract_seen.add(next) + contract_seen.add(next_to_explore) else: - contract_to_explore.append(next) + contract_to_explore.append(next_to_explore) return [Export(filename=path, content=content)] @@ -299,12 +341,12 @@ class Flattening: exports.append(Export(filename=path, content=content)) return exports - def export( + def export( # pylint: disable=too-many-arguments,too-few-public-methods self, strategy: Strategy, target: Optional[str] = None, json: Optional[str] = None, - zip: Optional[str] = None, + zip: Optional[str] = None, # pylint: disable=redefined-builtin zip_type: Optional[str] = None, ): diff --git a/slither/tools/kspec_coverage/__main__.py b/slither/tools/kspec_coverage/__main__.py index 33bd3a162..b6ce0f81b 100644 --- a/slither/tools/kspec_coverage/__main__.py +++ b/slither/tools/kspec_coverage/__main__.py @@ -1,9 +1,8 @@ import sys import logging import argparse -from slither import Slither -from .kspec_coverage import kspec_coverage from crytic_compile import cryticparser +from slither.tools.kspec_coverage.kspec_coverage import kspec_coverage logging.basicConfig() logger = logging.getLogger("Slither.kspec") @@ -23,18 +22,23 @@ def parse_args(): :return: Returns the arguments for the program. """ parser = argparse.ArgumentParser( - description="slither-kspec-coverage", usage="slither-kspec-coverage contract.sol kspec.md" + description="slither-kspec-coverage", + usage="slither-kspec-coverage contract.sol kspec.md", ) parser.add_argument( "contract", help="The filename of the contract or truffle directory to analyze." ) parser.add_argument( - "kspec", help="The filename of the Klab spec markdown for the analyzed contract(s)" + "kspec", + help="The filename of the Klab spec markdown for the analyzed contract(s)", ) parser.add_argument( - "--version", help="displays the current version", version="0.1.0", action="version" + "--version", + help="displays the current version", + version="0.1.0", + action="version", ) parser.add_argument( "--json", diff --git a/slither/tools/kspec_coverage/analysis.py b/slither/tools/kspec_coverage/analysis.py index 08e42ad76..ab8151203 100755 --- a/slither/tools/kspec_coverage/analysis.py +++ b/slither/tools/kspec_coverage/analysis.py @@ -10,8 +10,10 @@ logging.basicConfig(level=logging.WARNING) logger = logging.getLogger("Slither.kspec") -def _refactor_type(type): - return {"uint": "uint256", "int": "int256"}.get(type, type) +# pylint: disable=anomalous-backslash-in-string + +def _refactor_type(targeted_type): + return {"uint": "uint256", "int": "int256"}.get(targeted_type, targeted_type) def _get_all_covered_kspec_functions(target): @@ -35,12 +37,18 @@ def _get_all_covered_kspec_functions(target): match = INTERFACE_PATTERN.match(lines[i + 1]) if match: function_full_name = match.groups()[0] - start, end = function_full_name.index("(") + 1, function_full_name.index(")") + start, end = ( + function_full_name.index("(") + 1, + function_full_name.index(")"), + ) function_arguments = function_full_name[start:end].split(",") function_arguments = [ - _refactor_type(arg.strip().split(" ")[0]) for arg in function_arguments + _refactor_type(arg.strip().split(" ")[0]) + for arg in function_arguments ] - function_full_name = function_full_name[:start] + ",".join(function_arguments) + ")" + function_full_name = ( + function_full_name[:start] + ",".join(function_arguments) + ")" + ) covered_functions.add((contract_name, function_full_name)) i += 1 i += 1 @@ -53,16 +61,20 @@ def _get_slither_functions(slither): f for f in slither.functions if ( - f.contract == f.contract_declarer - and f.is_implemented - and not f.is_constructor - and not f.is_constructor_variables + f.contract == f.contract_declarer + and f.is_implemented + and not f.is_constructor + and not f.is_constructor_variables ) ] # Use list(set()) because same state variable instances can be shared accross contracts # TODO: integrate state variables all_functions_declared += list( - set([s for s in slither.state_variables if s.visibility in ["public", "external"]]) + { + s + for s in slither.state_variables + if s.visibility in ["public", "external"] + } ) slither_functions = { (function.contract.name, function.full_name): function @@ -95,7 +107,9 @@ def _generate_output_unresolved(kspec, message, color, generate_json): logger.info(color(info)) if generate_json: - json_kspec_present = output.Output(info, additional_fields={"signatures": kspec}) + json_kspec_present = output.Output( + info, additional_fields={"signatures": kspec} + ) return json_kspec_present.data return None @@ -152,15 +166,15 @@ def _run_coverage_analysis(args, slither, kspec_functions): ) -def run_analysis(args, slither, kspec): +def run_analysis(args, slither, kspec_arg): # Get all of our kspec'd functions (tuple(contract_name, function_name)). - if "," in kspec: - kspecs = kspec.split(",") + if "," in kspec_arg: + kspecs = kspec_arg.split(",") kspec_functions = set() for kspec in kspecs: kspec_functions |= _get_all_covered_kspec_functions(kspec) else: - kspec_functions = _get_all_covered_kspec_functions(kspec) + kspec_functions = _get_all_covered_kspec_functions(kspec_arg) # Run coverage analysis _run_coverage_analysis(args, slither, kspec_functions) diff --git a/slither/tools/possible_paths/__main__.py b/slither/tools/possible_paths/__main__.py index c13a3f390..e6940fcab 100644 --- a/slither/tools/possible_paths/__main__.py +++ b/slither/tools/possible_paths/__main__.py @@ -1,10 +1,15 @@ -import os import argparse -from slither import Slither -from slither.utils.colors import red +import sys + import logging -from .possible_paths import find_target_paths, resolve_functions, ResolveFunctionException from crytic_compile import cryticparser +from slither import Slither +from slither.utils.colors import red +from slither.tools.possible_paths.possible_paths import ( + find_target_paths, + resolve_functions, + ResolveFunctionException, +) logging.basicConfig() logging.getLogger("Slither").setLevel(logging.INFO) @@ -16,7 +21,8 @@ def parse_args(): :return: Returns the arguments for the program. """ parser = argparse.ArgumentParser( - description="PossiblePaths", usage="possible_paths.py filename [contract.function targets]" + description="PossiblePaths", + usage="possible_paths.py filename [contract.function targets]", ) parser.add_argument( @@ -44,22 +50,22 @@ def main(): try: targets = resolve_functions(slither, args.targets) - except ResolveFunctionException as r: - print(red(r)) - exit(-1) + except ResolveFunctionException as resolvefunction: + print(red(resolvefunction)) + sys.exit(-1) # Print out all target functions. - print(f"Target functions:") + print("Target functions:") for target in targets: print(f"- {target.contract_declarer.name}.{target.full_name}") print("\n") # Obtain all paths which reach the target functions. reaching_paths = find_target_paths(slither, targets) - reaching_functions = set([y for x in reaching_paths for y in x if y not in targets]) + reaching_functions = {y for x in reaching_paths for y in x if y not in targets} # Print out all function names which can reach the targets. - print(f"The following functions reach the specified targets:") + print("The following functions reach the specified targets:") for function_desc in sorted([f"{f.canonical_name}" for f in reaching_functions]): print(f"- {function_desc}") print("\n") @@ -71,7 +77,7 @@ def main(): ] # Print a sorted list of all function paths which can reach the targets. - print(f"The following paths reach the specified targets:") + print("The following paths reach the specified targets:") for reaching_path in sorted(reaching_paths_str): print(f"{reaching_path}\n") diff --git a/slither/tools/possible_paths/possible_paths.py b/slither/tools/possible_paths/possible_paths.py index 8137f3d1b..558b97e3a 100644 --- a/slither/tools/possible_paths/possible_paths.py +++ b/slither/tools/possible_paths/possible_paths.py @@ -14,11 +14,14 @@ def resolve_function(slither, contract_name, function_name): # Verify the contract was resolved successfully if contract is None: - raise ResolveFunctionException(f"Could not resolve target contract: {contract_name}") + raise ResolveFunctionException( + f"Could not resolve target contract: {contract_name}" + ) # Obtain the target function target_function = next( - (function for function in contract.functions if function.name == function_name), None + (function for function in contract.functions if function.name == function_name), + None, ) # Verify we have resolved the function specified. @@ -43,7 +46,9 @@ def resolve_functions(slither, functions): # Verify that the provided argument is a list. if not isinstance(functions, list): - raise ResolveFunctionException("Provided functions to resolve must be a list type.") + raise ResolveFunctionException( + "Provided functions to resolve must be a list type." + ) # Loop for each item in the list. for item in functions: @@ -85,8 +90,8 @@ def all_function_definitions(function): ] -def __find_target_paths(slither, target_function, current_path=[]): - +def __find_target_paths(slither, target_function, current_path=None): + current_path = current_path if current_path else [] # Create our results list results = set() @@ -105,13 +110,17 @@ def __find_target_paths(slither, target_function, current_path=[]): continue # Find all function calls in this function (except for low level) - called_functions = [f for (_, f) in function.high_level_calls + function.library_calls] + called_functions = [ + f for (_, f) in function.high_level_calls + function.library_calls + ] called_functions += function.internal_calls called_functions = set(called_functions) # If any of our target functions are reachable from this function, it's a result. if all_target_functions.intersection(called_functions): - path_results = __find_target_paths(slither, function, current_path.copy()) + path_results = __find_target_paths( + slither, function, current_path.copy() + ) if path_results: results = results.union(path_results) diff --git a/slither/tools/properties/__main__.py b/slither/tools/properties/__main__.py index 25685fb6b..9df407151 100644 --- a/slither/tools/properties/__main__.py +++ b/slither/tools/properties/__main__.py @@ -2,12 +2,16 @@ import argparse import logging import sys -from slither import Slither from crytic_compile import cryticparser -from slither.tools.properties.addresses.address import Addresses +from slither import Slither from slither.tools.properties.properties.erc20 import generate_erc20, ERC20_PROPERTIES -from slither.tools.properties.addresses.address import OWNER_ADDRESS, USER_ADDRESS, ATTACKER_ADDRESS +from slither.tools.properties.addresses.address import ( + Addresses, + OWNER_ADDRESS, + USER_ADDRESS, + ATTACKER_ADDRESS, +) from slither.utils.myprettytable import MyPrettyTable logging.basicConfig() @@ -41,14 +45,14 @@ def _all_properties(): return table -class ListScenarios(argparse.Action): - def __call__(self, parser, *args, **kwargs): +class ListScenarios(argparse.Action): # pylint: disable=too-few-public-methods + def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs logger.info(_all_scenarios()) parser.exit() -class ListProperties(argparse.Action): - def __call__(self, parser, *args, **kwargs): +class ListProperties(argparse.Action): # pylint: disable=too-few-public-methods + def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs logger.info(_all_properties()) parser.exit() @@ -72,7 +76,7 @@ def parse_args(): parser.add_argument( "--scenario", - help=f"Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable", + help="Test a specific scenario. Use --list-scenarios to see the available scenarios. Default Transferable", default="Transferable", ) @@ -101,7 +105,9 @@ def parse_args(): ) parser.add_argument( - "--address-attacker", help=f"Attacker address. Default {ATTACKER_ADDRESS}", default=None + "--address-attacker", + help=f"Attacker address. Default {ATTACKER_ADDRESS}", + default=None, ) # Add default arguments from crytic-compile @@ -126,9 +132,10 @@ def main(): contract = slither.contracts[0] else: if args.contract is None: - logger.error(f"Specify the target: --contract ContractName") + to_log = "Specify the target: --contract ContractName" else: - logger.error(f"{args.contract} not found") + to_log = f"{args.contract} not found" + logger.error(to_log) return addresses = Addresses(args.address_owner, args.address_user, args.address_attacker) diff --git a/slither/tools/properties/addresses/address.py b/slither/tools/properties/addresses/address.py index 2068bca23..cf5aec2ee 100644 --- a/slither/tools/properties/addresses/address.py +++ b/slither/tools/properties/addresses/address.py @@ -7,7 +7,7 @@ USER_ADDRESS = "0xf17f52151EbEF6C7334FAD080c5704D77216b732" ATTACKER_ADDRESS = "0xC5fdf4076b8F3A5357c5E395ab970B5B54098Fef" -class Addresses: +class Addresses: # pylint: disable=too-few-public-methods def __init__( self, owner: Optional[str] = None, diff --git a/slither/tools/properties/platforms/truffle.py b/slither/tools/properties/platforms/truffle.py index 48627fab4..028694628 100644 --- a/slither/tools/properties/platforms/truffle.py +++ b/slither/tools/properties/platforms/truffle.py @@ -4,7 +4,11 @@ from pathlib import Path from typing import List from slither.tools.properties.addresses.address import Addresses -from slither.tools.properties.properties.properties import PropertyReturn, Property, PropertyCaller +from slither.tools.properties.properties.properties import ( + PropertyReturn, + Property, + PropertyCaller, +) from slither.tools.properties.utils import write_file PATTERN_TRUFFLE_MIGRATION = re.compile("^[0-9]*_") @@ -64,7 +68,7 @@ async function catchRevertThrow(promise) { """ -def generate_unit_test( +def generate_unit_test( # pylint: disable=too-many-arguments,too-many-branches test_contract: str, filename: str, unit_tests: List[Property], @@ -99,14 +103,18 @@ def generate_unit_test( for caller in callers: content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n" if assert_message: - content += f'\t\tassert.equal(test_{caller}, true, "{assert_message}");\n' + content += ( + f'\t\tassert.equal(test_{caller}, true, "{assert_message}");\n' + ) else: content += f"\t\tassert.equal(test_{caller}, true);\n" elif unit_test.return_type == PropertyReturn.FAIL: for caller in callers: content += f"\t\tlet test_{caller} = await instance.{unit_test.name[:-2]}.call({{from: {caller}}});\n" if assert_message: - content += f'\t\tassert.equal(test_{caller}, false, "{assert_message}");\n' + content += ( + f'\t\tassert.equal(test_{caller}, false, "{assert_message}");\n' + ) else: content += f"\t\tassert.equal(test_{caller}, false);\n" elif unit_test.return_type == PropertyReturn.FAIL_OR_THROW: diff --git a/slither/tools/properties/properties/erc20.py b/slither/tools/properties/properties/erc20.py index 31f6ff8cc..3b94e7387 100644 --- a/slither/tools/properties/properties/erc20.py +++ b/slither/tools/properties/properties/erc20.py @@ -9,9 +9,15 @@ from crytic_compile.platform import Type as PlatformType from slither.core.declarations import Contract from slither.tools.properties.addresses.address import Addresses from slither.tools.properties.platforms.echidna import generate_echidna_config -from slither.tools.properties.properties.ercs.erc20.properties.burn import ERC20_NotBurnable -from slither.tools.properties.properties.ercs.erc20.properties.initialization import ERC20_CONFIG -from slither.tools.properties.properties.ercs.erc20.properties.mint import ERC20_NotMintable +from slither.tools.properties.properties.ercs.erc20.properties.burn import ( + ERC20_NotBurnable, +) +from slither.tools.properties.properties.ercs.erc20.properties.initialization import ( + ERC20_CONFIG, +) +from slither.tools.properties.properties.ercs.erc20.properties.mint import ( + ERC20_NotMintable, +) from slither.tools.properties.properties.ercs.erc20.properties.mint_and_burn import ( ERC20_NotMintableNotBurnable, ) @@ -19,8 +25,13 @@ from slither.tools.properties.properties.ercs.erc20.properties.transfer import ( ERC20_Transferable, ERC20_Pausable, ) -from slither.tools.properties.properties.ercs.erc20.unit_tests.truffle import generate_truffle_test -from slither.tools.properties.properties.properties import property_to_solidity, Property +from slither.tools.properties.properties.ercs.erc20.unit_tests.truffle import ( + generate_truffle_test, +) +from slither.tools.properties.properties.properties import ( + property_to_solidity, + Property, +) from slither.tools.properties.solidity.generate_properties import ( generate_solidity_properties, generate_test_contract, @@ -33,15 +44,22 @@ logger = logging.getLogger("Slither") PropertyDescription = namedtuple("PropertyDescription", ["properties", "description"]) ERC20_PROPERTIES = { - "Transferable": PropertyDescription(ERC20_Transferable, "Test the correct tokens transfer"), + "Transferable": PropertyDescription( + ERC20_Transferable, "Test the correct tokens transfer" + ), "Pausable": PropertyDescription(ERC20_Pausable, "Test the pausable functionality"), - "NotMintable": PropertyDescription(ERC20_NotMintable, "Test that no one can mint tokens"), + "NotMintable": PropertyDescription( + ERC20_NotMintable, "Test that no one can mint tokens" + ), "NotMintableNotBurnable": PropertyDescription( ERC20_NotMintableNotBurnable, "Test that no one can mint or burn tokens" ), - "NotBurnable": PropertyDescription(ERC20_NotBurnable, "Test that no one can burn tokens"), + "NotBurnable": PropertyDescription( + ERC20_NotBurnable, "Test that no one can burn tokens" + ), "Burnable": PropertyDescription( - ERC20_NotBurnable, 'Test the burn of tokens. Require the "burn(address) returns()" function' + ERC20_NotBurnable, + 'Test the burn of tokens. Require the "burn(address) returns()" function', ), } @@ -63,8 +81,13 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses) :param type_property: One of ERC20_PROPERTIES.keys() :return: """ - if contract.slither.crytic_compile.type not in [PlatformType.TRUFFLE, PlatformType.SOLC]: - logging.error(f"{contract.slither.crytic_compile.type} not yet supported by slither-prop") + if contract.slither.crytic_compile.type not in [ + PlatformType.TRUFFLE, + PlatformType.SOLC, + ]: + logging.error( + f"{contract.slither.crytic_compile.type} not yet supported by slither-prop" + ) return # Check if the contract is an ERC20 contract and if the functions have the correct visibility @@ -76,7 +99,7 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses) properties = ERC20_PROPERTIES.get(type_property, None) if properties is None: logger.error( - f"{type_property} unknown. Types available {[x for x in ERC20_PROPERTIES.keys()]}" + f"{type_property} unknown. Types available {ERC20_PROPERTIES.keys()}" ) return properties = properties.properties @@ -97,7 +120,11 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses) # Generate the Test contract initialization_recommendation = _initialization_recommendation(type_property) contract_filename, contract_name = generate_test_contract( - contract, type_property, output_dir, property_file, initialization_recommendation + contract, + type_property, + output_dir, + property_file, + initialization_recommendation, ) # Generate Echidna config file @@ -109,10 +136,14 @@ def generate_erc20(contract: Contract, type_property: str, addresses: Addresses) # If truffle, generate unit tests if contract.slither.crytic_compile.type == PlatformType.TRUFFLE: - unit_test_info = generate_truffle_test(contract, type_property, unit_tests, addresses) + unit_test_info = generate_truffle_test( + contract, type_property, unit_tests, addresses + ) logger.info("################################################") - logger.info(green(f"Update the constructor in {Path(output_dir, contract_filename)}")) + logger.info( + green(f"Update the constructor in {Path(output_dir, contract_filename)}") + ) if unit_test_info: logger.info(green(unit_test_info)) @@ -147,6 +178,7 @@ def _platform_to_output_dir(platform: AbstractPlatform) -> Path: return Path(platform.target, "contracts", "crytic") if platform.TYPE == PlatformType.SOLC: return Path(platform.target).parent + return Path() def _check_compatibility(contract): @@ -160,7 +192,9 @@ def _check_compatibility(contract): if transfer.visibility != "public": errors = f"slither-prop requires {transfer.canonical_name} to be public. Please change the visibility" - transfer_from = contract.get_function_from_signature("transferFrom(address,address,uint256)") + transfer_from = contract.get_function_from_signature( + "transferFrom(address,address,uint256)" + ) if transfer_from.visibility != "public": if errors: errors += "\n" @@ -179,7 +213,9 @@ def _get_properties(contract, properties: List[Property]) -> Tuple[str, List[Pro solidity_properties = "" if contract.slither.crytic_compile.type == PlatformType.TRUFFLE: - solidity_properties += "\n".join([property_to_solidity(p) for p in ERC20_CONFIG]) + solidity_properties += "\n".join( + [property_to_solidity(p) for p in ERC20_CONFIG] + ) solidity_properties += "\n".join([property_to_solidity(p) for p in properties]) unit_tests = [p for p in properties if p.is_unit_test] diff --git a/slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py b/slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py index 614447b4d..a44d7161a 100644 --- a/slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py +++ b/slither/tools/properties/properties/ercs/erc20/unit_tests/truffle.py @@ -4,15 +4,23 @@ from typing import List from slither.core.declarations import Contract from slither.tools.properties.addresses.address import Addresses -from slither.tools.properties.platforms.truffle import generate_migration, generate_unit_test -from slither.tools.properties.properties.ercs.erc20.properties.initialization import ERC20_CONFIG +from slither.tools.properties.platforms.truffle import ( + generate_migration, + generate_unit_test, +) +from slither.tools.properties.properties.ercs.erc20.properties.initialization import ( + ERC20_CONFIG, +) from slither.tools.properties.properties.properties import Property logger = logging.getLogger("Slither") def generate_truffle_test( - contract: Contract, type_property: str, unit_tests: List[Property], addresses: Addresses + contract: Contract, + type_property: str, + unit_tests: List[Property], + addresses: Addresses, ) -> str: test_contract = f"Test{contract.name}{type_property}" filename_init = f"Initialization{test_contract}.js" diff --git a/slither/tools/properties/solidity/generate_properties.py b/slither/tools/properties/solidity/generate_properties.py index 006ab24c3..fcc7aa1de 100644 --- a/slither/tools/properties/solidity/generate_properties.py +++ b/slither/tools/properties/solidity/generate_properties.py @@ -13,14 +13,12 @@ def generate_solidity_properties( contract: Contract, type_property: str, solidity_properties: str, output_dir: Path ) -> Path: - solidity_import = f'import "./interfaces.sol";\n' + solidity_import = 'import "./interfaces.sol";\n' solidity_import += f'import "../{contract.source_mapping["filename_short"]}";' test_contract_name = f"Properties{contract.name}{type_property}" - solidity_content = ( - f"{solidity_import}\ncontract {test_contract_name} is CryticInterface,{contract.name}" - ) + solidity_content = f"{solidity_import}\ncontract {test_contract_name} is CryticInterface,{contract.name}" solidity_content += f"{{\n\n{solidity_properties}\n}}\n" filename = f"{test_contract_name}.sol" @@ -44,7 +42,9 @@ def generate_test_contract( content += f"contract {test_contract_name} is {properties_name} {{\n" content += "\tconstructor() public{\n" content += "\t\t// Existing addresses:\n" - content += "\t\t// - crytic_owner: If the contract has an owner, it must be crytic_owner\n" + content += ( + "\t\t// - crytic_owner: If the contract has an owner, it must be crytic_owner\n" + ) content += "\t\t// - crytic_user: Legitimate user\n" content += "\t\t// - crytic_attacker: Attacker\n" content += "\t\t// \n" diff --git a/slither/tools/properties/utils.py b/slither/tools/properties/utils.py index 541d85712..ef70aca2c 100644 --- a/slither/tools/properties/utils.py +++ b/slither/tools/properties/utils.py @@ -27,7 +27,9 @@ def write_file( if discard_if_exist: return if not allow_overwrite: - logger.info(yellow(f"{file_to_write} already exist and will not be overwritten")) + logger.info( + yellow(f"{file_to_write} already exist and will not be overwritten") + ) return logger.info(yellow(f"Overwrite {file_to_write}")) else: diff --git a/slither/tools/similarity/__main__.py b/slither/tools/similarity/__main__.py index 85f837115..90000c989 100755 --- a/slither/tools/similarity/__main__.py +++ b/slither/tools/similarity/__main__.py @@ -3,15 +3,13 @@ import argparse import logging import sys -import traceback -import operator from crytic_compile import cryticparser -from .info import info -from .test import test -from .train import train -from .plot import plot +from slither.tools.similarity.info import info +from slither.tools.similarity.test import test +from slither.tools.similarity.train import train +from slither.tools.similarity.plot import plot logging.basicConfig() logger = logging.getLogger("Slither-simil") @@ -28,11 +26,15 @@ def parse_args(): parser.add_argument("model", help="model.bin") - parser.add_argument("--filename", action="store", dest="filename", help="contract.sol") + parser.add_argument( + "--filename", action="store", dest="filename", help="contract.sol" + ) parser.add_argument("--fname", action="store", dest="fname", help="Target function") - parser.add_argument("--ext", action="store", dest="ext", help="Extension to filter contracts") + parser.add_argument( + "--ext", action="store", dest="ext", help="Extension to filter contracts" + ) parser.add_argument( "--nsamples", @@ -56,7 +58,10 @@ def parse_args(): ) parser.add_argument( - "--version", help="displays the current version", version="0.0", action="version" + "--version", + help="displays the current version", + version="0.0", + action="version", ) cryticparser.init(parser) @@ -94,7 +99,8 @@ def main(): elif mode == "plot": plot(args) else: - logger.error("Invalid mode!. It should be one of these: %s" % ", ".join(modes)) + to_log = "Invalid mode!. It should be one of these: %s" % ", ".join(modes) + logger.error(to_log) sys.exit(-1) diff --git a/slither/tools/similarity/encode.py b/slither/tools/similarity/encode.py index 615ed98ca..f83a9636c 100644 --- a/slither/tools/similarity/encode.py +++ b/slither/tools/similarity/encode.py @@ -9,9 +9,16 @@ from slither.core.declarations import ( SolidityVariable, Function, ) -from slither.core.solidity_types import ElementaryType, ArrayType, MappingType, UserDefinedType +from slither.core.solidity_types import ( + ElementaryType, + ArrayType, + MappingType, + UserDefinedType, +) from slither.core.variables.local_variable import LocalVariable -from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple +from slither.core.variables.local_variable_init_from_tuple import ( + LocalVariableInitFromTuple, +) from slither.core.variables.state_variable import StateVariable from slither.slithir.operations import ( Assignment, @@ -42,7 +49,12 @@ from slither.slithir.operations import ( InitArray, InternalCall, ) -from slither.slithir.variables import TemporaryVariable, TupleVariable, Constant, ReferenceVariable +from slither.slithir.variables import ( + TemporaryVariable, + TupleVariable, + Constant, + ReferenceVariable, +) from .cache import load_cache simil_logger = logging.getLogger("Slither-simil") @@ -59,10 +71,12 @@ def parse_target(target): parts = target.split(".") if len(parts) == 1: return None, parts[0] - elif len(parts) == 2: + if len(parts) == 2: return parts - else: - simil_logger.error("Invalid target. It should be 'function' or 'Contract.function'") + simil_logger.error( + "Invalid target. It should be 'function' or 'Contract.function'" + ) + return None def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs): @@ -80,7 +94,7 @@ def load_and_encode(infile, vmodel, ext=None, nsamples=None, **kwargs): return r -def load_contracts(dirname, ext=None, nsamples=None, **kwargs): +def load_contracts(dirname, ext=None, nsamples=None): r = [] walk = list(os.walk(dirname)) for x, y, files in walk: @@ -90,12 +104,12 @@ def load_contracts(dirname, ext=None, nsamples=None, **kwargs): if nsamples is None: return r - else: - # TODO: shuffle - return r[:nsamples] + + # TODO: shuffle + return r[:nsamples] -def ntype(_type): +def ntype(_type): # pylint: disable=too-many-branches if isinstance(_type, ElementaryType): _type = str(_type) elif isinstance(_type, ArrayType): @@ -119,24 +133,23 @@ def ntype(_type): if "struct" in _type: return "struct" - elif "enum" in _type: + if "enum" in _type: return "enum" - elif "tuple" in _type: + if "tuple" in _type: return "tuple" - elif "contract" in _type: + if "contract" in _type: return "contract" - elif "mapping" in _type: + if "mapping" in _type: return "mapping" - else: - return _type.replace(" ", "_") + return _type.replace(" ", "_") -def encode_ir(ir): +def encode_ir(ir): # pylint: disable=too-many-branches # operations if isinstance(ir, Assignment): return "({}):=({})".format(encode_ir(ir.lvalue), encode_ir(ir.rvalue)) if isinstance(ir, Index): - return "index({})".format(ntype(ir._type)) + return "index({})".format(ntype(ir.index_type)) if isinstance(ir, Member): return "member" # .format(ntype(ir._type)) if isinstance(ir, Length): @@ -154,9 +167,9 @@ def encode_ir(ir): if isinstance(ir, NewContract): return "new_contract" if isinstance(ir, NewArray): - return "new_array({})".format(ntype(ir._array_type)) + return "new_array({})".format(ntype(ir.array_type)) if isinstance(ir, NewElementaryType): - return "new_elementary({})".format(ntype(ir._type)) + return "new_elementary({})".format(ntype(ir.type)) if isinstance(ir, Push): return "push({},{})".format(encode_ir(ir.value), encode_ir(ir.lvalue)) if isinstance(ir, Delete): @@ -164,7 +177,7 @@ def encode_ir(ir): if isinstance(ir, SolidityCall): return "solidity_call({})".format(ir.function.full_name) if isinstance(ir, InternalCall): - return "internal_call({})".format(ntype(ir._type_call)) + return "internal_call({})".format(ntype(ir.type_call)) if isinstance(ir, EventCall): # is this useful? return "event" if isinstance(ir, LibraryCall): @@ -192,7 +205,7 @@ def encode_ir(ir): # variables if isinstance(ir, Constant): - return "constant({})".format(ntype(ir._type)) + return "constant({})".format(ntype(ir.type)) if isinstance(ir, SolidityVariableComposed): return "solidity_variable_composed({})".format(ir.name) if isinstance(ir, SolidityVariable): @@ -200,20 +213,19 @@ def encode_ir(ir): if isinstance(ir, TemporaryVariable): return "temporary_variable" if isinstance(ir, ReferenceVariable): - return "reference({})".format(ntype(ir._type)) + return "reference({})".format(ntype(ir.type)) if isinstance(ir, LocalVariable): - return "local_solc_variable({})".format(ir._location) + return "local_solc_variable({})".format(ir.location) if isinstance(ir, StateVariable): - return "state_solc_variable({})".format(ntype(ir._type)) + return "state_solc_variable({})".format(ntype(ir.type)) if isinstance(ir, LocalVariableInitFromTuple): return "local_variable_init_tuple" if isinstance(ir, TupleVariable): return "tuple_variable" # default - else: - simil_logger.error(type(ir), "is missing encoding!") - return "" + simil_logger.error(type(ir), "is missing encoding!") + return "" def encode_contract(cfilename, **kwargs): @@ -222,8 +234,10 @@ def encode_contract(cfilename, **kwargs): # Init slither try: slither = Slither(cfilename, **kwargs) - except: - simil_logger.error("Compilation failed for %s using %s", cfilename, kwargs["solc"]) + except Exception: # pylint: disable=broad-except + simil_logger.error( + "Compilation failed for %s using %s", cfilename, kwargs["solc"] + ) return r # Iterate over all the contracts diff --git a/slither/tools/similarity/info.py b/slither/tools/similarity/info.py index b577bfd93..f39cf1773 100644 --- a/slither/tools/similarity/info.py +++ b/slither/tools/similarity/info.py @@ -22,7 +22,6 @@ def info(args): filename = args.filename contract, fname = parse_target(args.fname) - solc = args.solc if filename is None and contract is None and fname is None: logger.info("%s uses the following words:", args.model) @@ -31,7 +30,9 @@ def info(args): sys.exit(0) if filename is None or contract is None or fname is None: - logger.error("The encode mode requires filename, contract and fname parameters.") + logger.error( + "The encode mode requires filename, contract and fname parameters." + ) sys.exit(-1) irs = encode_contract(filename, **vars(args)) @@ -41,13 +42,15 @@ def info(args): x = (filename, contract, fname) y = " ".join(irs[x]) - logger.info("Function {} in contract {} is encoded as:".format(fname, contract)) + to_log = "Function {} in contract {} is encoded as:".format(fname, contract) + logger.info(to_log) logger.info(y) if model is not None: fvector = model.get_sentence_vector(y) logger.info(fvector) - except Exception: - logger.error("Error in %s" % args.filename) + except Exception: # pylint: disable=broad-except + to_log = "Error in %s" % args.filename + logger.error(to_log) logger.error(traceback.format_exc()) sys.exit(-1) diff --git a/slither/tools/similarity/model.py b/slither/tools/similarity/model.py index 4f3412113..88f06c7b3 100644 --- a/slither/tools/similarity/model.py +++ b/slither/tools/similarity/model.py @@ -1,9 +1,12 @@ import sys try: + # pylint: disable=unused-import from fastText import load_model from fastText import train_unsupervised except ImportError: print("ERROR: in order to use slither-simil, you need to install fastText 0.2.0:") - print("$ pip3 install https://github.com/facebookresearch/fastText/archive/0.2.0.zip --user\n") + print( + "$ pip3 install https://github.com/facebookresearch/fastText/archive/0.2.0.zip --user\n" + ) sys.exit(-1) diff --git a/slither/tools/similarity/plot.py b/slither/tools/similarity/plot.py index 75ef90c15..efdf8add9 100644 --- a/slither/tools/similarity/plot.py +++ b/slither/tools/similarity/plot.py @@ -1,12 +1,19 @@ import logging +import random import sys import traceback -import operator -import numpy as np -import random -from .model import load_model -from .encode import load_and_encode, parse_target +try: + import numpy as np +except ImportError: + print("ERROR: in order to use slither-simil, you need to install numpy:") + print( + "$ pip3 install numpy --user\n" + ) + sys.exit(-1) + +from slither.tools.similarity.encode import load_and_encode, parse_target +from slither.tools.similarity.model import load_model try: from sklearn import decomposition @@ -18,7 +25,7 @@ except ImportError: logger = logging.getLogger("Slither-simil") -def plot(args): +def plot(args): # pylint: disable=too-many-locals if decomposition is None or plt is None: logger.error( @@ -31,7 +38,6 @@ def plot(args): model = args.model model = load_model(model) - filename = args.filename # contract = args.contract contract, fname = parse_target(args.fname) # solc = args.solc @@ -75,7 +81,7 @@ def plot(args): logger.info("Saving figure to plot.png..") plt.savefig("plot.png", bbox_inches="tight") - except Exception: + except Exception: # pylint: disable=broad-except logger.error("Error in %s" % args.filename) logger.error(traceback.format_exc()) sys.exit(-1) diff --git a/slither/tools/similarity/similarity.py b/slither/tools/similarity/similarity.py index 3cf30acda..df9bf3580 100644 --- a/slither/tools/similarity/similarity.py +++ b/slither/tools/similarity/similarity.py @@ -1,4 +1,13 @@ -import numpy as np +import sys + +try: + import numpy as np +except ImportError: + print("ERROR: in order to use slither-simil, you need to install numpy:") + print( + "$ pip3 install numpy --user\n" + ) + sys.exit(-1) def similarity(v1, v2): diff --git a/slither/tools/similarity/test.py b/slither/tools/similarity/test.py index 89043a5a1..56a0e9e92 100755 --- a/slither/tools/similarity/test.py +++ b/slither/tools/similarity/test.py @@ -1,14 +1,11 @@ -import argparse import logging +import operator import sys import traceback -import operator -import numpy as np -from .model import load_model -from .encode import encode_contract, load_and_encode, parse_target -from .cache import save_cache -from .similarity import similarity +from slither.tools.similarity.encode import encode_contract, load_and_encode, parse_target +from slither.tools.similarity.model import load_model +from slither.tools.similarity.similarity import similarity logger = logging.getLogger("Slither-simil") @@ -24,7 +21,9 @@ def test(args): ntop = args.ntop if filename is None or contract is None or fname is None or infile is None: - logger.error("The test mode requires filename, contract, fname and input parameters.") + logger.error( + "The test mode requires filename, contract, fname and input parameters." + ) sys.exit(-1) irs = encode_contract(filename, **vars(args)) @@ -42,14 +41,16 @@ def test(args): r[x] = similarity(fvector, y) r = sorted(r.items(), key=operator.itemgetter(1), reverse=True) - logger.info("Reviewed %d functions, listing the %d most similar ones:", len(r), ntop) + logger.info( + "Reviewed %d functions, listing the %d most similar ones:", len(r), ntop + ) format_table = "{: <65} {: <20} {: <20} {: <10}" logger.info(format_table.format(*["filename", "contract", "function", "score"])) for x, score in r[:ntop]: score = str(round(score, 3)) logger.info(format_table.format(*(list(x) + [score]))) - except Exception: + except Exception: # pylint: disable=broad-except logger.error("Error in %s" % args.filename) logger.error(traceback.format_exc()) sys.exit(-1) diff --git a/slither/tools/similarity/train.py b/slither/tools/similarity/train.py index 3052ae6c5..23ce9e4d1 100755 --- a/slither/tools/similarity/train.py +++ b/slither/tools/similarity/train.py @@ -1,24 +1,21 @@ -import argparse import logging +import os import sys import traceback -import operator -import os -from .model import train_unsupervised -from .encode import encode_contract, load_contracts -from .cache import save_cache +from slither.tools.similarity.cache import save_cache +from slither.tools.similarity.encode import encode_contract, load_contracts +from slither.tools.similarity.model import train_unsupervised logger = logging.getLogger("Slither-simil") -def train(args): +def train(args): # pylint: disable=too-many-locals try: last_data_train_filename = "last_data_train.txt" model_filename = args.model dirname = args.input - nsamples = args.nsamples if dirname is None: logger.error("The train mode requires the input parameter.") @@ -30,13 +27,15 @@ def train(args): with open(last_data_train_filename, "w") as f: for filename in contracts: # cache[filename] = dict() - for (filename, contract, function), ir in encode_contract( + for (filename_inner, contract, function), ir in encode_contract( filename, **vars(args) ).items(): if ir != []: x = " ".join(ir) f.write(x + "\n") - cache.append((os.path.split(filename)[-1], contract, function, x)) + cache.append( + (os.path.split(filename_inner)[-1], contract, function, x) + ) logger.info("Starting training") model = train_unsupervised(input=last_data_train_filename, model="skipgram") @@ -51,7 +50,7 @@ def train(args): save_cache(cache, "cache.npz") logger.info("Done!") - except Exception: + except Exception: # pylint: disable=broad-except logger.error("Error in %s" % args.filename) logger.error(traceback.format_exc()) sys.exit(-1) diff --git a/slither/tools/slither_format/__main__.py b/slither/tools/slither_format/__main__.py index b99485236..26952bcc4 100644 --- a/slither/tools/slither_format/__main__.py +++ b/slither/tools/slither_format/__main__.py @@ -1,10 +1,11 @@ import sys import argparse -from slither import Slither -from slither.utils.command_line import read_config_file import logging -from .slither_format import slither_format from crytic_compile import cryticparser +from slither import Slither +from slither.utils.command_line import read_config_file +from slither.tools.slither_format.slither_format import slither_format + logging.basicConfig() logger = logging.getLogger("Slither").setLevel(logging.INFO) @@ -29,7 +30,9 @@ def parse_args(): Parse the underlying arguments for the program. :return: Returns the arguments for the program. """ - parser = argparse.ArgumentParser(description="slither_format", usage="slither_format filename") + parser = argparse.ArgumentParser( + description="slither_format", usage="slither_format filename" + ) parser.add_argument( "filename", help="The filename of the contract or truffle directory to analyze." @@ -42,10 +45,17 @@ def parse_args(): default=False, ) parser.add_argument( - "--verbose-json", "-j", help="verbose json output", action="store_true", default=False + "--verbose-json", + "-j", + help="verbose json output", + action="store_true", + default=False, ) parser.add_argument( - "--version", help="displays the current version", version="0.1.0", action="version" + "--version", + help="displays the current version", + version="0.1.0", + action="version", ) parser.add_argument( diff --git a/slither/tools/slither_format/slither_format.py b/slither/tools/slither_format/slither_format.py index 659b69557..9a972a146 100644 --- a/slither/tools/slither_format/slither_format.py +++ b/slither/tools/slither_format/slither_format.py @@ -5,7 +5,9 @@ from slither.detectors.attributes.incorrect_solc import IncorrectSolc from slither.detectors.attributes.constant_pragma import ConstantPragma from slither.detectors.naming_convention.naming_convention import NamingConvention from slither.detectors.functions.external_function import ExternalFunction -from slither.detectors.variables.possible_const_state_variables import ConstCandidateStateVars +from slither.detectors.variables.possible_const_state_variables import ( + ConstCandidateStateVars, +) from slither.detectors.attributes.const_functions_asm import ConstantFunctionsAsm from slither.detectors.attributes.const_functions_state import ConstantFunctionsState from slither.utils.colors import yellow @@ -25,7 +27,7 @@ all_detectors = { } -def slither_format(slither, **kwargs): +def slither_format(slither, **kwargs): # pylint: disable=too-many-locals """' Keyword Args: detectors_to_run (str): Comma-separated list of detectors, defaults to all @@ -42,7 +44,9 @@ def slither_format(slither, **kwargs): detector_results = slither.run_detectors() detector_results = [x for x in detector_results if x] # remove empty results - detector_results = [item for sublist in detector_results for item in sublist] # flatten + detector_results = [ + item for sublist in detector_results for item in sublist + ] # flatten export = Path("crytic-export", "patches") @@ -50,7 +54,11 @@ def slither_format(slither, **kwargs): counter_result = 0 - logger.info(yellow("slither-format is in beta, carefully review each patch before merging it.")) + logger.info( + yellow( + "slither-format is in beta, carefully review each patch before merging it." + ) + ) for result in detector_results: if not "patches" in result: @@ -65,7 +73,7 @@ def slither_format(slither, **kwargs): logger.info(f"Issue: {one_line_description}") logger.info(f"Generated: ({export_result})") - for file, diff, in result["patches_diff"].items(): + for _, diff, in result["patches_diff"].items(): filename = f"fix_{counter}.patch" path = Path(export_result, filename) logger.info(f"\t- {filename}") @@ -139,8 +147,8 @@ def print_patches_json(number_of_slither_results, patches): print('"Patch file":' + '"' + file + '",') print('"Number of patches":' + '"' + str(len(patches[file])) + '"', ",") print('"Patches":' + "[") - for index, patch in enumerate(patches[file]): - if index > 0: + for inner_index, patch in enumerate(patches[file]): + if inner_index > 0: print(",") print("{", end="") print('"Detector":' + '"' + patch["detector"] + '",') diff --git a/slither/tools/upgradeability/__main__.py b/slither/tools/upgradeability/__main__.py index 8903c01c5..8e78e7794 100644 --- a/slither/tools/upgradeability/__main__.py +++ b/slither/tools/upgradeability/__main__.py @@ -10,9 +10,9 @@ from slither import Slither from slither.exceptions import SlitherException from slither.utils.colors import red from slither.utils.output import output_to_json -from .checks import all_checks -from .checks.abstract_checks import AbstractCheck -from .utils.command_line import ( +from slither.tools.upgradeability.checks import all_checks +from slither.tools.upgradeability.checks.abstract_checks import AbstractCheck +from slither.tools.upgradeability.utils.command_line import ( output_detectors_json, output_wiki, output_detectors, @@ -57,7 +57,10 @@ def parse_args(): ) parser.add_argument( - "--markdown-root", help="URL for markdown generation", action="store", default="" + "--markdown-root", + help="URL for markdown generation", + action="store", + default="", ) parser.add_argument( @@ -72,7 +75,9 @@ def parse_args(): default=False, ) - parser.add_argument("--markdown", help=argparse.SUPPRESS, action=OutputMarkdown, default=False) + parser.add_argument( + "--markdown", help=argparse.SUPPRESS, action=OutputMarkdown, default=False + ) cryticparser.init(parser) @@ -92,34 +97,36 @@ def parse_args(): def _get_checks(): detectors = [getattr(all_checks, name) for name in dir(all_checks)] - detectors = [c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractCheck)] + detectors = [ + c for c in detectors if inspect.isclass(c) and issubclass(c, AbstractCheck) + ] return detectors -class ListDetectors(argparse.Action): - def __call__(self, parser, *args, **kwargs): +class ListDetectors(argparse.Action): # pylint: disable=too-few-public-methods + def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs checks = _get_checks() output_detectors(checks) parser.exit() -class ListDetectorsJson(argparse.Action): - def __call__(self, parser, *args, **kwargs): +class ListDetectorsJson(argparse.Action): # pylint: disable=too-few-public-methods + def __call__(self, parser, *args, **kwargs): # pylint: disable=signature-differs checks = _get_checks() detector_types_json = output_detectors_json(checks) print(json.dumps(detector_types_json)) parser.exit() -class OutputMarkdown(argparse.Action): - def __call__(self, parser, args, values, option_string=None): +class OutputMarkdown(argparse.Action): # pylint: disable=too-few-public-methods + def __call__(self, parser, args, values, option_string=None): # pylint: disable=signature-differs checks = _get_checks() output_to_markdown(checks, values) parser.exit() -class OutputWiki(argparse.Action): - def __call__(self, parser, args, values, option_string=None): +class OutputWiki(argparse.Action): # pylint: disable=too-few-public-methods + def __call__(self, parser, args, values, option_string=None): # pylint: disable=signature-differs checks = _get_checks() output_wiki(checks, values) parser.exit() @@ -143,7 +150,9 @@ def _checks_on_contract(detectors, contract): def _checks_on_contract_update(detectors, contract_v1, contract_v2): detectors = [ - d(logger, contract_v1, contract_v2=contract_v2) for d in detectors if d.REQUIRE_CONTRACT_V2 + d(logger, contract_v1, contract_v2=contract_v2) + for d in detectors + if d.REQUIRE_CONTRACT_V2 ] return _run_checks(detectors), len(detectors) @@ -160,9 +169,13 @@ def _checks_on_contract_and_proxy(detectors, contract, proxy): ################################################################################### ################################################################################### - +# pylint: disable=too-many-statements,too-many-branches,too-many-locals def main(): - json_results = {"proxy-present": False, "contract_v2-present": False, "detectors": []} + json_results = { + "proxy-present": False, + "contract_v2-present": False, + "detectors": [], + } args = parse_args() @@ -170,19 +183,21 @@ def main(): number_detectors_run = 0 detectors = _get_checks() try: - v1 = Slither(v1_filename, **vars(args)) + variable1 = Slither(v1_filename, **vars(args)) # Analyze logic contract v1_name = args.ContractName - v1_contract = v1.get_contract_from_name(v1_name) + v1_contract = variable1.get_contract_from_name(v1_name) if v1_contract is None: - info = "Contract {} not found in {}".format(v1_name, v1.filename) + info = "Contract {} not found in {}".format(v1_name, variable1.filename) logger.error(red(info)) if args.json: output_to_json(args.json, str(info), json_results) return - detectors_results, number_detectors = _checks_on_contract(detectors, v1_contract) + detectors_results, number_detectors = _checks_on_contract( + detectors, v1_contract + ) json_results["detectors"] += detectors_results number_detectors_run += number_detectors @@ -192,11 +207,13 @@ def main(): if args.proxy_filename: proxy = Slither(args.proxy_filename, **vars(args)) else: - proxy = v1 + proxy = variable1 proxy_contract = proxy.get_contract_from_name(args.proxy_name) if proxy_contract is None: - info = "Proxy {} not found in {}".format(args.proxy_name, proxy.filename) + info = "Proxy {} not found in {}".format( + args.proxy_name, proxy.filename + ) logger.error(red(info)) if args.json: output_to_json(args.json, str(info), json_results) @@ -211,14 +228,14 @@ def main(): # Analyze new version if args.new_contract_name: if args.new_contract_filename: - v2 = Slither(args.new_contract_filename, **vars(args)) + variable2 = Slither(args.new_contract_filename, **vars(args)) else: - v2 = v1 + variable2 = variable1 - v2_contract = v2.get_contract_from_name(args.new_contract_name) + v2_contract = variable2.get_contract_from_name(args.new_contract_name) if v2_contract is None: info = "New logic contract {} not found in {}".format( - args.new_contract_name, v2.filename + args.new_contract_name, variable2.filename ) logger.error(red(info)) if args.json: @@ -244,16 +261,15 @@ def main(): json_results["detectors"] += detectors_results number_detectors_run += number_detectors - logger.info( - f'{len(json_results["detectors"])} findings, {number_detectors_run} detectors run' - ) + to_log = f'{len(json_results["detectors"])} findings, {number_detectors_run} detectors run' + logger.info(to_log) if args.json: output_to_json(args.json, None, json_results) - except SlitherException as e: - logger.error(str(e)) + except SlitherException as slither_exception: + logger.error(str(slither_exception)) if args.json: - output_to_json(args.json, str(e), json_results) + output_to_json(args.json, str(slither_exception), json_results) return diff --git a/slither/tools/upgradeability/checks/abstract_checks.py b/slither/tools/upgradeability/checks/abstract_checks.py index 05a0d1182..19f456e89 100644 --- a/slither/tools/upgradeability/checks/abstract_checks.py +++ b/slither/tools/upgradeability/checks/abstract_checks.py @@ -8,7 +8,7 @@ class IncorrectCheckInitialization(Exception): pass -class CheckClassification: +class CheckClassification: # pylint: disable=too-few-public-methods HIGH = 0 MEDIUM = 1 LOW = 2 @@ -81,12 +81,16 @@ class AbstractCheck(metaclass=abc.ABCMeta): CheckClassification.INFORMATIONAL ]: raise IncorrectCheckInitialization( - "WIKI_EXPLOIT_SCENARIO is not initialized {}".format(self.__class__.__name__) + "WIKI_EXPLOIT_SCENARIO is not initialized {}".format( + self.__class__.__name__ + ) ) if not self.WIKI_RECOMMENDATION: raise IncorrectCheckInitialization( - "WIKI_RECOMMENDATION is not initialized {}".format(self.__class__.__name__) + "WIKI_RECOMMENDATION is not initialized {}".format( + self.__class__.__name__ + ) ) if self.REQUIRE_PROXY and self.REQUIRE_CONTRACT_V2: @@ -129,14 +133,16 @@ class AbstractCheck(metaclass=abc.ABCMeta): if all_results: if self.logger: info = "\n" - for idx, result in enumerate(all_results): + for result in all_results: info += result["description"] info += "Reference: {}".format(self.WIKI) self._log(info) return all_results def generate_result(self, info, additional_fields=None): - output = Output(info, additional_fields, markdown_root=self.contract.slither.markdown_root) + output = Output( + info, additional_fields, markdown_root=self.contract.slither.markdown_root + ) output.data["check"] = self.ARGUMENT diff --git a/slither/tools/upgradeability/checks/all_checks.py b/slither/tools/upgradeability/checks/all_checks.py index fcedb69f8..2289c3808 100644 --- a/slither/tools/upgradeability/checks/all_checks.py +++ b/slither/tools/upgradeability/checks/all_checks.py @@ -1,4 +1,5 @@ -from .initialization import ( +# pylint: disable=unused-import +from slither.tools.upgradeability.checks.initialization import ( InitializablePresent, InitializableInherited, InitializableInitializer, @@ -8,11 +9,11 @@ from .initialization import ( InitializeTarget, ) -from .functions_ids import IDCollision, FunctionShadowing +from slither.tools.upgradeability.checks.functions_ids import IDCollision, FunctionShadowing -from .variable_initialization import VariableWithInit +from slither.tools.upgradeability.checks.variable_initialization import VariableWithInit -from .variables_order import ( +from slither.tools.upgradeability.checks.variables_order import ( MissingVariable, DifferentVariableContractProxy, DifferentVariableContractNewContract, @@ -20,4 +21,4 @@ from .variables_order import ( ExtraVariablesNewContract, ) -from .constant import WereConstant, BecameConstant +from slither.tools.upgradeability.checks.constant import WereConstant, BecameConstant diff --git a/slither/tools/upgradeability/checks/constant.py b/slither/tools/upgradeability/checks/constant.py index e1d547e28..467bc2776 100644 --- a/slither/tools/upgradeability/checks/constant.py +++ b/slither/tools/upgradeability/checks/constant.py @@ -1,4 +1,7 @@ -from slither.tools.upgradeability.checks.abstract_checks import AbstractCheck, CheckClassification +from slither.tools.upgradeability.checks.abstract_checks import ( + AbstractCheck, + CheckClassification, +) class WereConstant(AbstractCheck): @@ -67,7 +70,8 @@ Do not remove `constant` from a state variables during an update. if not state_v2.is_constant: # If v2 has additional non constant variables, we need to skip them if ( - state_v1.name != state_v2.name or state_v1.type != state_v2.type + state_v1.name != state_v2.name + or state_v1.type != state_v2.type ) and v2_additional_variables > 0: v2_additional_variables -= 1 idx_v2 += 1 @@ -149,7 +153,8 @@ Do not make an existing state variable `constant`. if not state_v2.is_constant: # If v2 has additional non constant variables, we need to skip them if ( - state_v1.name != state_v2.name or state_v1.type != state_v2.type + state_v1.name != state_v2.name + or state_v1.type != state_v2.type ) and v2_additional_variables > 0: v2_additional_variables -= 1 idx_v2 += 1 diff --git a/slither/tools/upgradeability/checks/functions_ids.py b/slither/tools/upgradeability/checks/functions_ids.py index cbc822d20..5ffc62ff3 100644 --- a/slither/tools/upgradeability/checks/functions_ids.py +++ b/slither/tools/upgradeability/checks/functions_ids.py @@ -1,5 +1,8 @@ from slither.exceptions import SlitherError -from slither.tools.upgradeability.checks.abstract_checks import AbstractCheck, CheckClassification +from slither.tools.upgradeability.checks.abstract_checks import ( + AbstractCheck, + CheckClassification, +) from slither.utils.function import get_function_id @@ -8,12 +11,16 @@ def get_signatures(c): functions = [ f.full_name for f in functions - if f.visibility in ["public", "external"] and not f.is_constructor and not f.is_fallback + if f.visibility in ["public", "external"] + and not f.is_constructor + and not f.is_fallback ] variables = c.state_variables variables = [ - variable.name + "()" for variable in variables if variable.visibility in ["public"] + variable.name + "()" + for variable in variables + if variable.visibility in ["public"] ] return list(set(functions + variables)) @@ -85,7 +92,9 @@ Rename the function. Avoid public functions in the proxy. implem_function = _get_function_or_variable( self.contract, signatures_ids_implem[k] ) - proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k]) + proxy_function = _get_function_or_variable( + self.proxy, signatures_ids_proxy[k] + ) info = [ "Function id collision found: ", @@ -151,7 +160,9 @@ Rename the function. Avoid public functions in the proxy. implem_function = _get_function_or_variable( self.contract, signatures_ids_implem[k] ) - proxy_function = _get_function_or_variable(self.proxy, signatures_ids_proxy[k]) + proxy_function = _get_function_or_variable( + self.proxy, signatures_ids_proxy[k] + ) info = [ "Function shadowing found: ", diff --git a/slither/tools/upgradeability/checks/initialization.py b/slither/tools/upgradeability/checks/initialization.py index 2e37457dc..b8fdb12b6 100644 --- a/slither/tools/upgradeability/checks/initialization.py +++ b/slither/tools/upgradeability/checks/initialization.py @@ -1,9 +1,11 @@ import logging from slither.slithir.operations import InternalCall -from slither.tools.upgradeability.checks.abstract_checks import AbstractCheck, CheckClassification -from slither.utils.output import Output -from slither.utils.colors import red, yellow, green +from slither.tools.upgradeability.checks.abstract_checks import ( + AbstractCheck, + CheckClassification, +) +from slither.utils.colors import red logger = logging.getLogger("Slither-check-upgradeability") @@ -13,7 +15,9 @@ class MultipleInitTarget(Exception): def _get_initialize_functions(contract): - return [f for f in contract.functions if f.name == "initialize" and f.is_implemented] + return [ + f for f in contract.functions if f.name == "initialize" and f.is_implemented + ] def _get_all_internal_calls(function): @@ -26,7 +30,9 @@ def _get_all_internal_calls(function): def _get_most_derived_init(contract): - init_functions = [f for f in contract.functions if not f.is_shadowed and f.name == "initialize"] + init_functions = [ + f for f in contract.functions if not f.is_shadowed and f.name == "initialize" + ] if len(init_functions) > 1: if len([f for f in init_functions if f.contract_declarer == contract]) == 1: return next((f for f in init_functions if f.contract_declarer == contract)) @@ -120,7 +126,9 @@ Review manually the contract's initialization. Consider inheriting a `Initializa if initializable not in self.contract.inheritance: return [] - initializer = self.contract.get_modifier_from_canonical_name("Initializable.initializer()") + initializer = self.contract.get_modifier_from_canonical_name( + "Initializable.initializer()" + ) if initializer is None: info = ["Initializable.initializer() does not exist.\n"] json = self.generate_result(info) @@ -165,7 +173,9 @@ Use `Initializable.initializer()`. # See InitializableInherited if initializable not in self.contract.inheritance: return [] - initializer = self.contract.get_modifier_from_canonical_name("Initializable.initializer()") + initializer = self.contract.get_modifier_from_canonical_name( + "Initializable.initializer()" + ) # InitializableInitializer if initializer is None: return [] @@ -228,8 +238,12 @@ Ensure all the initialize functions are reached by the most derived initialize f return [] all_init_functions = _get_initialize_functions(self.contract) - all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init] - missing_calls = [f for f in all_init_functions if not f in all_init_functions_called] + all_init_functions_called = _get_all_internal_calls(most_derived_init) + [ + most_derived_init + ] + missing_calls = [ + f for f in all_init_functions if not f in all_init_functions_called + ] for f in missing_calls: info = ["Missing call to ", f, " in ", most_derived_init, ".\n"] json = self.generate_result(info) @@ -292,10 +306,14 @@ Call only one time every initialize function. if most_derived_init is None: return [] - all_init_functions_called = _get_all_internal_calls(most_derived_init) + [most_derived_init] - double_calls = list( - set([f for f in all_init_functions_called if all_init_functions_called.count(f) > 1]) - ) + all_init_functions_called = _get_all_internal_calls(most_derived_init) + [ + most_derived_init + ] + double_calls = list({ + f + for f in all_init_functions_called + if all_init_functions_called.count(f) > 1 + }) for f in double_calls: info = [f, " is called multiple times in ", most_derived_init, ".\n"] json = self.generate_result(info) @@ -337,6 +355,11 @@ Ensure that the function is called at deployment. if most_derived_init is None: return [] - info = [self.contract, f" needs to be initialized by ", most_derived_init, ".\n"] + info = [ + self.contract, + " needs to be initialized by ", + most_derived_init, + ".\n", + ] json = self.generate_result(info) return [json] diff --git a/slither/tools/upgradeability/checks/variable_initialization.py b/slither/tools/upgradeability/checks/variable_initialization.py index 7b9316ef7..cab71d018 100644 --- a/slither/tools/upgradeability/checks/variable_initialization.py +++ b/slither/tools/upgradeability/checks/variable_initialization.py @@ -1,4 +1,7 @@ -from slither.tools.upgradeability.checks.abstract_checks import CheckClassification, AbstractCheck +from slither.tools.upgradeability.checks.abstract_checks import ( + CheckClassification, + AbstractCheck, +) class VariableWithInit(AbstractCheck): diff --git a/slither/tools/upgradeability/checks/variables_order.py b/slither/tools/upgradeability/checks/variables_order.py index 735a8a1e1..a38b0ad06 100644 --- a/slither/tools/upgradeability/checks/variables_order.py +++ b/slither/tools/upgradeability/checks/variables_order.py @@ -1,4 +1,7 @@ -from slither.tools.upgradeability.checks.abstract_checks import CheckClassification, AbstractCheck +from slither.tools.upgradeability.checks.abstract_checks import ( + CheckClassification, + AbstractCheck, +) class MissingVariable(AbstractCheck): @@ -6,7 +9,9 @@ class MissingVariable(AbstractCheck): IMPACT = CheckClassification.MEDIUM HELP = "Variable missing in the v2" - WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-variables" + WIKI = ( + "https://github.com/crytic/slither/wiki/Upgradeability-Checks#missing-variables" + ) WIKI_TITLE = "Missing variables" WIKI_DESCRIPTION = """ Detect variables that were present in the original contracts but are not in the updated one. @@ -37,11 +42,19 @@ Do not change the order of the state variables in the updated contract. def _check(self): contract1 = self.contract contract2 = self.contract_v2 - order1 = [variable for variable in contract1.state_variables if not variable.is_constant] - order2 = [variable for variable in contract2.state_variables if not variable.is_constant] + order1 = [ + variable + for variable in contract1.state_variables + if not variable.is_constant + ] + order2 = [ + variable + for variable in contract2.state_variables + if not variable.is_constant + ] results = [] - for idx in range(0, len(order1)): + for idx, _ in enumerate(order1): variable1 = order1[idx] if len(order2) <= idx: info = ["Variable missing in ", contract2, ": ", variable1, "\n"] @@ -92,11 +105,19 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s def _check(self): contract1 = self._contract1() contract2 = self._contract2() - order1 = [variable for variable in contract1.state_variables if not variable.is_constant] - order2 = [variable for variable in contract2.state_variables if not variable.is_constant] + order1 = [ + variable + for variable in contract1.state_variables + if not variable.is_constant + ] + order2 = [ + variable + for variable in contract2.state_variables + if not variable.is_constant + ] results = [] - for idx in range(0, len(order1)): + for idx, _ in enumerate(order1): if len(order2) <= idx: # Handle by MissingVariable return results @@ -104,9 +125,15 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s variable1 = order1[idx] variable2 = order2[idx] if (variable1.name != variable2.name) or (variable1.type != variable2.type): - info = ["Different variables between ", contract1, " and ", contract2, "\n"] - info += [f"\t ", variable1, "\n"] - info += [f"\t ", variable2, "\n"] + info = [ + "Different variables between ", + contract1, + " and ", + contract2, + "\n", + ] + info += ["\t ", variable1, "\n"] + info += ["\t ", variable2, "\n"] json = self.generate_result(info) results.append(json) @@ -154,9 +181,7 @@ class ExtraVariablesProxy(AbstractCheck): IMPACT = CheckClassification.MEDIUM HELP = "Extra vars in the proxy" - WIKI = ( - "https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-proxy" - ) + WIKI = "https://github.com/crytic/slither/wiki/Upgradeability-Checks#extra-variables-in-the-proxy" WIKI_TITLE = "Extra variables in the proxy" WIKI_DESCRIPTION = """ @@ -193,8 +218,16 @@ Avoid variables in the proxy. If a variable is in the proxy, ensure it has the s def _check(self): contract1 = self._contract1() contract2 = self._contract2() - order1 = [variable for variable in contract1.state_variables if not variable.is_constant] - order2 = [variable for variable in contract2.state_variables if not variable.is_constant] + order1 = [ + variable + for variable in contract1.state_variables + if not variable.is_constant + ] + order2 = [ + variable + for variable in contract2.state_variables + if not variable.is_constant + ] results = [] diff --git a/slither/tools/upgradeability/utils/command_line.py b/slither/tools/upgradeability/utils/command_line.py index 57a6ef88d..e02c1624f 100644 --- a/slither/tools/upgradeability/utils/command_line.py +++ b/slither/tools/upgradeability/utils/command_line.py @@ -40,10 +40,14 @@ def output_detectors(detector_classes): require_proxy = detector.REQUIRE_PROXY require_v2 = detector.REQUIRE_CONTRACT_V2 detectors_list.append((argument, help_info, impact, require_proxy, require_v2)) - table = MyPrettyTable(["Num", "Check", "What it Detects", "Impact", "Proxy", "Contract V2"]) + table = MyPrettyTable( + ["Num", "Check", "What it Detects", "Impact", "Proxy", "Contract V2"] + ) # Sort by impact, confidence, and name - detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) + detectors_list = sorted( + detectors_list, key=lambda element: (element[2], element[0]) + ) idx = 1 for (argument, help_info, impact, proxy, v2) in detectors_list: table.add_row( @@ -60,7 +64,7 @@ def output_detectors(detector_classes): print(table) -def output_to_markdown(detector_classes, filter_wiki): +def output_to_markdown(detector_classes, _filter_wiki): def extract_help(cls): if cls.WIKI == "": return cls.HELP @@ -76,7 +80,9 @@ def output_to_markdown(detector_classes, filter_wiki): detectors_list.append((argument, help_info, impact, require_proxy, require_v2)) # Sort by impact, confidence, and name - detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) + detectors_list = sorted( + detectors_list, key=lambda element: (element[2], element[0]) + ) idx = 1 for (argument, help_info, impact, proxy, v2) in detectors_list: print( @@ -115,7 +121,9 @@ def output_detectors_json(detector_classes): ) # Sort by impact, confidence, and name - detectors_list = sorted(detectors_list, key=lambda element: (element[2], element[0])) + detectors_list = sorted( + detectors_list, key=lambda element: (element[2], element[0]) + ) idx = 1 table = [] for ( diff --git a/slither/utils/arithmetic.py b/slither/utils/arithmetic.py index 03b2c1de1..1206764b9 100644 --- a/slither/utils/arithmetic.py +++ b/slither/utils/arithmetic.py @@ -3,7 +3,7 @@ from decimal import Decimal from slither.exceptions import SlitherException -def convert_subdenomination(value: str, sub: str) -> int: +def convert_subdenomination(value: str, sub: str) -> int: # pylint: disable=too-many-return-statements # to allow 0.1 ether conversion if value[0:2] == "0x": @@ -31,4 +31,6 @@ def convert_subdenomination(value: str, sub: str) -> int: if sub == "years": return int(decimal_value * 60 * 60 * 24 * 7 * 365) - raise SlitherException(f"Subdemonination conversion impossible {decimal_value} {sub}") + raise SlitherException( + f"Subdemonination conversion impossible {decimal_value} {sub}" + ) diff --git a/slither/utils/colors.py b/slither/utils/colors.py index c1a0e73af..0f2585501 100644 --- a/slither/utils/colors.py +++ b/slither/utils/colors.py @@ -2,7 +2,7 @@ from functools import partial import platform -class Colors: +class Colors: # pylint: disable=too-few-public-methods COLORIZATION_ENABLED = True RED = "\033[91m" GREEN = "\033[92m" @@ -15,8 +15,7 @@ class Colors: def colorize(color: Colors, txt: str) -> str: if Colors.COLORIZATION_ENABLED: return "{}{}{}".format(color, txt, Colors.END) - else: - return txt + return txt def enable_windows_virtual_terminal_sequences() -> bool: @@ -26,6 +25,7 @@ def enable_windows_virtual_terminal_sequences() -> bool: """ try: + # pylint: disable=import-outside-toplevel from ctypes import windll, byref from ctypes.wintypes import DWORD, HANDLE @@ -55,7 +55,7 @@ def enable_windows_virtual_terminal_sequences() -> bool: current_handle, current_mode.value | virtual_terminal_flag ): return False - except: + except Exception: # pylint: disable=broad-except # Any generic failure (possibly from calling these methods on older Windows builds where they do not exist) # will fall back onto disabling colorization. return False diff --git a/slither/utils/command_line.py b/slither/utils/command_line.py index 3e0659772..15e2d049e 100644 --- a/slither/utils/command_line.py +++ b/slither/utils/command_line.py @@ -7,8 +7,8 @@ from crytic_compile.cryticparser.defaults import ( ) from slither.detectors.abstract_detector import classification_txt -from .colors import yellow, red -from .myprettytable import MyPrettyTable +from slither.utils.colors import yellow, red +from slither.utils.myprettytable import MyPrettyTable logger = logging.getLogger("Slither") @@ -56,7 +56,9 @@ def read_config_file(args): if key not in defaults_flag_in_config: logger.info( yellow( - "{} has an unknown key: {} : {}".format(args.config_file, key, elem) + "{} has an unknown key: {} : {}".format( + args.config_file, key, elem + ) ) ) continue @@ -64,7 +66,11 @@ def read_config_file(args): setattr(args, key, elem) except json.decoder.JSONDecodeError as e: logger.error( - red("Impossible to read {}, please check the file {}".format(args.config_file, e)) + red( + "Impossible to read {}, please check the file {}".format( + args.config_file, e + ) + ) ) @@ -171,7 +177,8 @@ def output_wiki(detector_classes, filter_wiki): # Sort by impact, confidence, and name detectors_list = sorted( - detector_classes, key=lambda element: (element.IMPACT, element.CONFIDENCE, element.ARGUMENT) + detector_classes, + key=lambda element: (element.IMPACT, element.CONFIDENCE, element.ARGUMENT), ) for detector in detectors_list: @@ -222,12 +229,14 @@ def output_detectors(detector_classes): ) idx = 1 for (argument, help_info, impact, confidence) in detectors_list: - table.add_row([idx, argument, help_info, classification_txt[impact], confidence]) + table.add_row( + [idx, argument, help_info, classification_txt[impact], confidence] + ) idx = idx + 1 print(table) -def output_detectors_json(detector_classes): +def output_detectors_json(detector_classes): # pylint: disable=too-many-locals detectors_list = [] for detector in detector_classes: argument = detector.ARGUMENT diff --git a/slither/utils/erc.py b/slither/utils/erc.py index b289c2962..0fb5193d1 100644 --- a/slither/utils/erc.py +++ b/slither/utils/erc.py @@ -1,7 +1,9 @@ from collections import namedtuple -from typing import Union, List +from typing import List -ERC = namedtuple("ERC", ["name", "parameters", "return_type", "view", "required", "events"]) +ERC = namedtuple( + "ERC", ["name", "parameters", "return_type", "view", "required", "events"] +) ERC_EVENT = namedtuple("ERC_EVENT", ["name", "parameters", "indexes"]) @@ -17,14 +19,20 @@ def erc_to_signatures(erc: List[ERC]): # Final # https://eips.ethereum.org/EIPS/eip-20 -ERC20_transfer_event = ERC_EVENT("Transfer", ["address", "address", "uint256"], [True, True, False]) -ERC20_approval_event = ERC_EVENT("Approval", ["address", "address", "uint256"], [True, True, False]) +ERC20_transfer_event = ERC_EVENT( + "Transfer", ["address", "address", "uint256"], [True, True, False] +) +ERC20_approval_event = ERC_EVENT( + "Approval", ["address", "address", "uint256"], [True, True, False] +) ERC20_EVENTS = [ERC20_transfer_event, ERC20_approval_event] ERC20 = [ ERC("totalSupply", [], "uint256", True, True, []), ERC("balanceOf", ["address"], "uint256", True, True, []), - ERC("transfer", ["address", "uint256"], "bool", False, True, [ERC20_transfer_event]), + ERC( + "transfer", ["address", "uint256"], "bool", False, True, [ERC20_transfer_event] + ), ERC( "transferFrom", ["address", "address", "uint256"], @@ -62,8 +70,17 @@ ERC223 = [ ERC("decimals", [], "uint8", True, True, []), ERC("totalSupply", [], "uint256", True, True, []), ERC("balanceOf", ["address"], "uint256", True, True, []), - ERC("transfer", ["address", "uint256"], "bool", False, True, [ERC223_transfer_event]), - ERC("transfer", ["address", "uint256", "bytes"], "bool", False, True, [ERC223_transfer_event]), + ERC( + "transfer", ["address", "uint256"], "bool", False, True, [ERC223_transfer_event] + ), + ERC( + "transfer", + ["address", "uint256", "bytes"], + "bool", + False, + True, + [ERC223_transfer_event], + ), ERC( "transfer", ["address", "uint256", "bytes", "string"], @@ -87,12 +104,20 @@ ERC165_signatures = erc_to_signatures(ERC165) # https://eips.ethereum.org/EIPS/eip-721 # Must have ERC165 -ERC721_transfer_event = ERC_EVENT("Transfer", ["address", "address", "uint256"], [True, True, True]) -ERC721_approval_event = ERC_EVENT("Approval", ["address", "address", "uint256"], [True, True, True]) +ERC721_transfer_event = ERC_EVENT( + "Transfer", ["address", "address", "uint256"], [True, True, True] +) +ERC721_approval_event = ERC_EVENT( + "Approval", ["address", "address", "uint256"], [True, True, True] +) ERC721_approvalforall_event = ERC_EVENT( "ApprovalForAll", ["address", "address", "bool"], [True, True, False] ) -ERC721_EVENTS = [ERC721_transfer_event, ERC721_approval_event, ERC721_approvalforall_event] +ERC721_EVENTS = [ + ERC721_transfer_event, + ERC721_approval_event, + ERC721_approvalforall_event, +] ERC721 = [ ERC("balanceOf", ["address"], "uint256", True, True, []), @@ -114,10 +139,22 @@ ERC721 = [ [ERC721_transfer_event], ), ERC( - "transferFrom", ["address", "address", "uint256"], "", False, True, [ERC721_transfer_event] + "transferFrom", + ["address", "address", "uint256"], + "", + False, + True, + [ERC721_transfer_event], ), ERC("approve", ["address", "uint256"], "", False, True, [ERC721_approval_event]), - ERC("setApprovalForAll", ["address", "bool"], "", False, True, [ERC721_approvalforall_event]), + ERC( + "setApprovalForAll", + ["address", "bool"], + "", + False, + True, + [ERC721_approvalforall_event], + ), ERC("getApproved", ["uint256"], "address", True, True, []), ERC("isApprovedForAll", ["address", "address"], "bool", True, True, []), ] + ERC165 @@ -136,7 +173,14 @@ ERC721_signatures = erc_to_signatures(ERC721) # https://eips.ethereum.org/EIPS/eip-1820 ERC1820_EVENTS = [] ERC1820 = [ - ERC("canImplementInterfaceForAddress", ["bytes32", "address"], "bytes32", True, True, []) + ERC( + "canImplementInterfaceForAddress", + ["bytes32", "address"], + "bytes32", + True, + True, + [], + ) ] ERC1820_signatures = erc_to_signatures(ERC1820) @@ -148,15 +192,21 @@ ERC777_sent_event = ERC_EVENT( [True, True, True, False, False, False], ) ERC777_minted_event = ERC_EVENT( - "Minted", ["address", "address", "uint256", "bytes", "bytes"], [True, True, False, False, False] + "Minted", + ["address", "address", "uint256", "bytes", "bytes"], + [True, True, False, False, False], ) ERC777_burned_event = ERC_EVENT( - "Burned", ["address", "address", "uint256", "bytes", "bytes"], [True, True, False, False, False] + "Burned", + ["address", "address", "uint256", "bytes", "bytes"], + [True, True, False, False, False], ) ERC777_authorizedOperator_event = ERC_EVENT( "AuthorizedOperator", ["address", "address"], [True, True] ) -ERC777_revokedoperator_event = ERC_EVENT("RevokedOperator", ["address", "address"], [True, True]) +ERC777_revokedoperator_event = ERC_EVENT( + "RevokedOperator", ["address", "address"], [True, True] +) ERC777_EVENTS = [ ERC777_sent_event, ERC777_minted_event, @@ -173,7 +223,14 @@ ERC777 = [ ERC("granularity", [], "uint256", True, True, []), ERC("defaultOperators", [], "address[]", True, True, []), ERC("isOperatorFor", ["address", "address"], "bool", True, True, []), - ERC("authorizeOperator", ["address"], "", False, True, [ERC777_authorizedOperator_event]), + ERC( + "authorizeOperator", + ["address"], + "", + False, + True, + [ERC777_authorizedOperator_event], + ), ERC("revokeOperator", ["address"], "", False, True, [ERC777_revokedoperator_event]), ERC("send", ["address", "uint256", "bytes"], "", False, True, [ERC777_sent_event]), ERC( diff --git a/slither/utils/expression_manipulations.py b/slither/utils/expression_manipulations.py index d3a3fd429..aea109391 100644 --- a/slither/utils/expression_manipulations.py +++ b/slither/utils/expression_manipulations.py @@ -19,7 +19,7 @@ from slither.core.expressions.tuple_expression import TupleExpression from slither.core.expressions.type_conversion import TypeConversion from slither.all_exceptions import SlitherException - +# pylint: disable=protected-access def f_expressions(e, x): e._expressions.append(x) @@ -36,7 +36,7 @@ def f_called(e, x): e._called = x -class SplitTernaryExpression(object): +class SplitTernaryExpression: def __init__(self, expression): if isinstance(expression, ConditionalExpression): @@ -47,7 +47,9 @@ class SplitTernaryExpression(object): self.true_expression = copy.copy(expression) self.false_expression = copy.copy(expression) self.condition = None - self.copy_expression(expression, self.true_expression, self.false_expression) + self.copy_expression( + expression, self.true_expression, self.false_expression + ) def apply_copy(self, next_expr, true_expression, false_expression, f): @@ -56,39 +58,49 @@ class SplitTernaryExpression(object): f(false_expression, copy.copy(next_expr.else_expression)) self.condition = copy.copy(next_expr.if_expression) return False - else: - f(true_expression, copy.copy(next_expr)) - f(false_expression, copy.copy(next_expr)) - return True - def copy_expression(self, expression, true_expression, false_expression): + f(true_expression, copy.copy(next_expr)) + f(false_expression, copy.copy(next_expr)) + return True + + def copy_expression(self, expression, true_expression, false_expression): # pylint: disable=too-many-branches if self.condition: return if isinstance(expression, ConditionalExpression): raise SlitherException("Nested ternary operator not handled") - if isinstance(expression, (Literal, Identifier, IndexAccess, NewArray, NewContract)): - return None + if isinstance( + expression, (Literal, Identifier, IndexAccess, NewArray, NewContract) + ): + return # case of lib # (.. ? .. : ..).add if isinstance(expression, MemberAccess): next_expr = expression.expression - if self.apply_copy(next_expr, true_expression, false_expression, f_expression): + if self.apply_copy( + next_expr, true_expression, false_expression, f_expression + ): self.copy_expression( next_expr, true_expression.expression, false_expression.expression ) - elif isinstance(expression, (AssignmentOperation, BinaryOperation, TupleExpression)): + elif isinstance( + expression, (AssignmentOperation, BinaryOperation, TupleExpression) + ): true_expression._expressions = [] false_expression._expressions = [] for next_expr in expression.expressions: - if self.apply_copy(next_expr, true_expression, false_expression, f_expressions): + if self.apply_copy( + next_expr, true_expression, false_expression, f_expressions + ): # always on last arguments added self.copy_expression( - next_expr, true_expression.expressions[-1], false_expression.expressions[-1] + next_expr, + true_expression.expressions[-1], + false_expression.expressions[-1], ) elif isinstance(expression, CallExpression): @@ -97,33 +109,49 @@ class SplitTernaryExpression(object): # case of lib # (.. ? .. : ..).add if self.apply_copy(next_expr, true_expression, false_expression, f_called): - self.copy_expression(next_expr, true_expression.called, false_expression.called) + self.copy_expression( + next_expr, true_expression.called, false_expression.called + ) true_expression._arguments = [] false_expression._arguments = [] for next_expr in expression.arguments: - if self.apply_copy(next_expr, true_expression, false_expression, f_call): + if self.apply_copy( + next_expr, true_expression, false_expression, f_call + ): # always on last arguments added self.copy_expression( - next_expr, true_expression.arguments[-1], false_expression.arguments[-1] + next_expr, + true_expression.arguments[-1], + false_expression.arguments[-1], ) elif isinstance(expression, TypeConversion): next_expr = expression.expression - if self.apply_copy(next_expr, true_expression, false_expression, f_expression): + if self.apply_copy( + next_expr, true_expression, false_expression, f_expression + ): self.copy_expression( - expression.expression, true_expression.expression, false_expression.expression + expression.expression, + true_expression.expression, + false_expression.expression, ) elif isinstance(expression, UnaryOperation): next_expr = expression.expression - if self.apply_copy(next_expr, true_expression, false_expression, f_expression): + if self.apply_copy( + next_expr, true_expression, false_expression, f_expression + ): self.copy_expression( - expression.expression, true_expression.expression, false_expression.expression + expression.expression, + true_expression.expression, + false_expression.expression, ) else: raise SlitherException( - "Ternary operation not handled {}({})".format(expression, type(expression)) + "Ternary operation not handled {}({})".format( + expression, type(expression) + ) ) diff --git a/slither/utils/inheritance_analysis.py b/slither/utils/inheritance_analysis.py index 4678cb384..66b9e8ac3 100644 --- a/slither/utils/inheritance_analysis.py +++ b/slither/utils/inheritance_analysis.py @@ -10,7 +10,9 @@ if TYPE_CHECKING: from slither.core.variables.state_variable import StateVariable -def detect_c3_function_shadowing(contract: "Contract") -> Dict["Function", Set["Function"]]: +def detect_c3_function_shadowing( + contract: "Contract", +) -> Dict["Function", Set["Function"]]: """ Detects and obtains functions which are indirectly shadowed via multiple inheritance by C3 linearization properties, despite not directly inheriting from each other. @@ -48,7 +50,7 @@ def detect_state_variable_shadowing( inherited. The contracts are simply included to denote the immediate inheritance path from which the shadowed variable originates. """ - results: Set[Tuple["Contract", StateVariable, "Contract", "StateVariable"]] = set() + results: Set[Tuple["Contract", "StateVariable", "Contract", "StateVariable"]] = set() for contract in contracts: variables_declared: Dict[str, "StateVariable"] = { variable.name: variable for variable in contract.state_variables_declared diff --git a/slither/utils/output.py b/slither/utils/output.py index d46035a4c..946bd0598 100644 --- a/slither/utils/output.py +++ b/slither/utils/output.py @@ -47,7 +47,9 @@ def output_to_json(filename: str, error, results: Dict): else: # Write json to file if os.path.isfile(filename): - logger.info(yellow(f"{filename} exists already, the overwrite is prevented")) + logger.info( + yellow(f"{filename} exists already, the overwrite is prevented") + ) else: with open(filename, "w", encoding="utf8") as f: json.dump(json_result, f, indent=2) @@ -62,7 +64,9 @@ ZIP_TYPES_ACCEPTED = { } -def output_to_zip(filename: str, error: Optional[str], results: Dict, zip_type: str = "lzma"): +def output_to_zip( + filename: str, error: Optional[str], results: Dict, zip_type: str = "lzma" +): """ Output the results to a zip The file in the zip is named slither_results.json @@ -78,9 +82,13 @@ def output_to_zip(filename: str, error: Optional[str], results: Dict, zip_type: logger.info(yellow(f"{filename} exists already, the overwrite is prevented")) else: with ZipFile( - filename, "w", compression=ZIP_TYPES_ACCEPTED.get(zip_type, zipfile.ZIP_LZMA) + filename, + "w", + compression=ZIP_TYPES_ACCEPTED.get(zip_type, zipfile.ZIP_LZMA), ) as file_desc: - file_desc.writestr("slither_results.json", json.dumps(json_result).encode("utf8")) + file_desc.writestr( + "slither_results.json", json.dumps(json_result).encode("utf8") + ) # endregion @@ -96,13 +104,14 @@ def _convert_to_description(d): return d if not isinstance(d, SourceMapping): - raise SlitherError(f"{d} does not inherit from SourceMapping, conversion impossible") + raise SlitherError( + f"{d} does not inherit from SourceMapping, conversion impossible" + ) if isinstance(d, Node): if d.expression: return f"{d.expression} ({d.source_mapping_str})" - else: - return f"{str(d)} ({d.source_mapping_str})" + return f"{str(d)} ({d.source_mapping_str})" if hasattr(d, "canonical_name"): return f"{d.canonical_name} ({d.source_mapping_str})" @@ -118,13 +127,14 @@ def _convert_to_markdown(d, markdown_root): return d if not isinstance(d, SourceMapping): - raise SlitherError(f"{d} does not inherit from SourceMapping, conversion impossible") + raise SlitherError( + f"{d} does not inherit from SourceMapping, conversion impossible" + ) if isinstance(d, Node): if d.expression: return f"[{d.expression}]({d.source_mapping_to_markdown(markdown_root)})" - else: - return f"[{str(d)}]({d.source_mapping_to_markdown(markdown_root)})" + return f"[{str(d)}]({d.source_mapping_to_markdown(markdown_root)})" if hasattr(d, "canonical_name"): return f"[{d.canonical_name}]({d.source_mapping_to_markdown(markdown_root)})" @@ -145,13 +155,14 @@ def _convert_to_id(d): return d if not isinstance(d, SourceMapping): - raise SlitherError(f"{d} does not inherit from SourceMapping, conversion impossible") + raise SlitherError( + f"{d} does not inherit from SourceMapping, conversion impossible" + ) if isinstance(d, Node): if d.expression: return f"{d.expression} ({d.source_mapping_str})" - else: - return f"{str(d)} ({d.source_mapping_str})" + return f"{str(d)} ({d.source_mapping_str})" if isinstance(d, Pragma): return f"{d} ({d.source_mapping_str})" @@ -174,13 +185,13 @@ def _convert_to_id(d): def _create_base_element( - type, name, source_mapping, type_specific_fields=None, additional_fields=None + custom_type, name, source_mapping, type_specific_fields=None, additional_fields=None ): if additional_fields is None: additional_fields = {} if type_specific_fields is None: type_specific_fields = {} - element = {"type": type, "name": name, "source_mapping": source_mapping} + element = {"type": custom_type, "name": name, "source_mapping": source_mapping} if type_specific_fields: element["type_specific_fields"] = type_specific_fields if additional_fields: @@ -189,6 +200,7 @@ def _create_base_element( def _create_parent_element(element): + # pylint: disable=import-outside-toplevel from slither.core.children.child_contract import ChildContract from slither.core.children.child_function import ChildFunction from slither.core.children.child_inheritance import ChildInheritance @@ -211,7 +223,9 @@ def _create_parent_element(element): return None -SupportedOutput = Union[Variable, Contract, Function, Enum, Event, Structure, Pragma, Node] +SupportedOutput = Union[ + Variable, Contract, Function, Enum, Event, Structure, Pragma, Node +] class Output: @@ -235,7 +249,9 @@ class Output: self._data: Dict[str, Any] = OrderedDict() self._data["elements"] = [] self._data["description"] = "".join(_convert_to_description(d) for d in info) - self._data["markdown"] = "".join(_convert_to_markdown(d, markdown_root) for d in info) + self._data["markdown"] = "".join( + _convert_to_markdown(d, markdown_root) for d in info + ) id_txt = "".join(_convert_to_id(d) for d in info) self._data["id"] = hashlib.sha3_256(id_txt.encode("utf-8")).hexdigest() @@ -284,7 +300,9 @@ class Output: ################################################################################### ################################################################################### - def add_variable(self, variable: Variable, additional_fields: Optional[Dict] = None): + def add_variable( + self, variable: Variable, additional_fields: Optional[Dict] = None + ): if additional_fields is None: additional_fields = {} type_specific_fields = {"parent": _create_parent_element(variable)} @@ -308,7 +326,9 @@ class Output: ################################################################################### ################################################################################### - def add_contract(self, contract: Contract, additional_fields: Optional[Dict] = None): + def add_contract( + self, contract: Contract, additional_fields: Optional[Dict] = None + ): if additional_fields is None: additional_fields = {} element = _create_base_element( @@ -323,7 +343,9 @@ class Output: ################################################################################### ################################################################################### - def add_function(self, function: Function, additional_fields: Optional[Dict] = None): + def add_function( + self, function: Function, additional_fields: Optional[Dict] = None + ): if additional_fields is None: additional_fields = {} type_specific_fields = { @@ -339,7 +361,9 @@ class Output: ) self._data["elements"].append(element) - def add_functions(self, functions: List[Function], additional_fields: Optional[Dict] = None): + def add_functions( + self, functions: List[Function], additional_fields: Optional[Dict] = None + ): if additional_fields is None: additional_fields = {} for function in sorted(functions, key=lambda x: x.name): @@ -357,7 +381,11 @@ class Output: additional_fields = {} type_specific_fields = {"parent": _create_parent_element(enum)} element = _create_base_element( - "enum", enum.name, enum.source_mapping, type_specific_fields, additional_fields + "enum", + enum.name, + enum.source_mapping, + type_specific_fields, + additional_fields, ) self._data["elements"].append(element) @@ -373,7 +401,11 @@ class Output: additional_fields = {} type_specific_fields = {"parent": _create_parent_element(struct)} element = _create_base_element( - "struct", struct.name, struct.source_mapping, type_specific_fields, additional_fields + "struct", + struct.name, + struct.source_mapping, + type_specific_fields, + additional_fields, ) self._data["elements"].append(element) @@ -392,7 +424,11 @@ class Output: "signature": event.full_name, } element = _create_base_element( - "event", event.name, event.source_mapping, type_specific_fields, additional_fields + "event", + event.name, + event.source_mapping, + type_specific_fields, + additional_fields, ) self._data["elements"].append(element) @@ -412,7 +448,11 @@ class Output: } node_name = str(node.expression) if node.expression else "" element = _create_base_element( - "node", node_name, node.source_mapping, type_specific_fields, additional_fields + "node", + node_name, + node.source_mapping, + type_specific_fields, + additional_fields, ) self._data["elements"].append(element) @@ -432,7 +472,11 @@ class Output: additional_fields = {} type_specific_fields = {"directive": pragma.directive} element = _create_base_element( - "pragma", pragma.version, pragma.source_mapping, type_specific_fields, additional_fields + "pragma", + pragma.version, + pragma.source_mapping, + type_specific_fields, + additional_fields, ) self._data["elements"].append(element) @@ -443,7 +487,9 @@ class Output: ################################################################################### ################################################################################### - def add_file(self, filename: str, content: str, additional_fields: Optional[Dict] = None): + def add_file( + self, filename: str, content: str, additional_fields: Optional[Dict] = None + ): if additional_fields is None: additional_fields = {} type_specific_fields = {"filename": filename, "content": content} @@ -459,12 +505,17 @@ class Output: ################################################################################### def add_pretty_table( - self, content: MyPrettyTable, name: str, additional_fields: Optional[Dict] = None + self, + content: MyPrettyTable, + name: str, + additional_fields: Optional[Dict] = None, ): if additional_fields is None: additional_fields = {} type_specific_fields = {"content": content.to_json(), "name": name} - element = _create_base_element("pretty_table", type_specific_fields, additional_fields) + element = _create_base_element( + "pretty_table", type_specific_fields, additional_fields + ) self._data["elements"].append(element) @@ -476,7 +527,11 @@ class Output: ################################################################################### def add_other( - self, name: str, source_mapping, slither, additional_fields: Optional[Dict] = None + self, + name: str, + source_mapping, + slither, + additional_fields: Optional[Dict] = None, ): # If this a tuple with (filename, start, end), convert it to a source mapping. if additional_fields is None: @@ -487,7 +542,10 @@ class Output: source_id = next( ( source_unit_id - for (source_unit_id, source_unit_filename) in slither.source_units.items() + for ( + source_unit_id, + source_unit_filename, + ) in slither.source_units.items() if source_unit_filename == filename ), -1, @@ -507,5 +565,7 @@ class Output: source_mapping = source_mapping.source_mapping # Create the underlying element and add it to our resulting json - element = _create_base_element("other", name, source_mapping, {}, additional_fields) + element = _create_base_element( + "other", name, source_mapping, {}, additional_fields + ) self._data["elements"].append(element) diff --git a/slither/utils/standard_libraries.py b/slither/utils/standard_libraries.py index d4cbdbe5e..ffc24586d 100644 --- a/slither/utils/standard_libraries.py +++ b/slither/utils/standard_libraries.py @@ -26,7 +26,9 @@ libraries = { "AragonOS-UnsafeAragonApp": lambda x: is_aragonos_unsafe_aragon_app(x), "AragonOS-Autopetrified": lambda x: is_aragonos_autopetrified(x), "AragonOS-DelegateProxy": lambda x: is_aragonos_delegate_proxy(x), - "AragonOS-DepositableDelegateProxy": lambda x: is_aragonos_depositable_delegate_proxy(x), + "AragonOS-DepositableDelegateProxy": lambda x: is_aragonos_depositable_delegate_proxy( + x + ), "AragonOS-DepositableStorage": lambda x: is_aragonos_delegate_proxy(x), "AragonOS-Initializable": lambda x: is_aragonos_initializable(x), "AragonOS-IsContract": lambda x: is_aragonos_is_contract(x), @@ -55,7 +57,10 @@ def is_openzepellin(contract: "Contract") -> bool: if not contract.is_from_dependency(): return False path = Path(contract.source_mapping["filename_absolute"]).parts - is_zep = "openzeppelin-solidity" in Path(contract.source_mapping["filename_absolute"]).parts + is_zep = ( + "openzeppelin-solidity" + in Path(contract.source_mapping["filename_absolute"]).parts + ) try: is_zep |= path[path.index("@openzeppelin") + 1] == "contracts" except IndexError: diff --git a/slither/utils/type_helpers.py b/slither/utils/type_helpers.py index 86bdc38a9..522bade5d 100644 --- a/slither/utils/type_helpers.py +++ b/slither/utils/type_helpers.py @@ -1,10 +1,16 @@ from typing import Union, Tuple, TYPE_CHECKING if TYPE_CHECKING: - from slither.core.declarations import Function, SolidityFunction, Contract, SolidityVariable + from slither.core.declarations import ( + Function, + SolidityFunction, + Contract, + SolidityVariable, + ) from slither.core.variables.variable import Variable ### core.declaration +# pylint: disable=used-before-assignment InternalCallType = Union[Function, SolidityFunction] HighLevelCallType = Tuple[Contract, Union[Function, Variable]] LibraryCallType = Tuple[Contract, Function] diff --git a/slither/visitors/expression/constants_folding.py b/slither/visitors/expression/constants_folding.py index 4a82ccfae..7a00cad16 100644 --- a/slither/visitors/expression/constants_folding.py +++ b/slither/visitors/expression/constants_folding.py @@ -1,7 +1,5 @@ -import logging - -from .expression import ExpressionVisitor from slither.core.expressions import BinaryOperationType, Literal +from slither.visitors.expression.expression import ExpressionVisitor class NotConstant(Exception): @@ -23,8 +21,8 @@ def set_val(expression, val): class ConstantFolding(ExpressionVisitor): - def __init__(self, expression, type): - self._type = type + def __init__(self, expression, custom_type): + self._type = custom_type super(ConstantFolding, self).__init__(expression) def result(self): diff --git a/slither/visitors/expression/export_values.py b/slither/visitors/expression/export_values.py index f478a83e1..246cc8384 100644 --- a/slither/visitors/expression/export_values.py +++ b/slither/visitors/expression/export_values.py @@ -1,9 +1,5 @@ from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.assignment_operation import AssignmentOperationType - -from slither.core.variables.variable import Variable - key = "ExportValues" diff --git a/slither/visitors/expression/expression.py b/slither/visitors/expression/expression.py index 01cdf81d9..ed84b6b33 100644 --- a/slither/visitors/expression/expression.py +++ b/slither/visitors/expression/expression.py @@ -5,7 +5,9 @@ from slither.core.expressions.assignment_operation import AssignmentOperation from slither.core.expressions.binary_operation import BinaryOperation from slither.core.expressions.call_expression import CallExpression from slither.core.expressions.conditional_expression import ConditionalExpression -from slither.core.expressions.elementary_type_name_expression import ElementaryTypeNameExpression +from slither.core.expressions.elementary_type_name_expression import ( + ElementaryTypeNameExpression, +) from slither.core.expressions.expression import Expression from slither.core.expressions.identifier import Identifier from slither.core.expressions.index_access import IndexAccess @@ -38,7 +40,7 @@ class ExpressionVisitor: # visit an expression # call pre_visit, visit_expression_name, post_visit - def _visit_expression(self, expression: Expression): + def _visit_expression(self, expression: Expression): # pylint: disable=too-many-branches self._pre_visit(expression) if isinstance(expression, AssignmentOperation): @@ -159,7 +161,7 @@ class ExpressionVisitor: # pre visit - def _pre_visit(self, expression): + def _pre_visit(self, expression): # pylint: disable=too-many-branches if isinstance(expression, AssignmentOperation): self._pre_assignement_operation(expression) @@ -260,7 +262,7 @@ class ExpressionVisitor: # post visit - def _post_visit(self, expression): + def _post_visit(self, expression): # pylint: disable=too-many-branches if isinstance(expression, AssignmentOperation): self._post_assignement_operation(expression) diff --git a/slither/visitors/expression/find_push.py b/slither/visitors/expression/find_push.py index b3e79b0c4..cf2b07e60 100644 --- a/slither/visitors/expression/find_push.py +++ b/slither/visitors/expression/find_push.py @@ -1,6 +1,4 @@ from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.identifier import Identifier -from slither.core.expressions.index_access import IndexAccess from slither.visitors.expression.right_value import RightValue diff --git a/slither/visitors/expression/write_var.py b/slither/visitors/expression/write_var.py index 7267415a6..0509a2eb7 100644 --- a/slither/visitors/expression/write_var.py +++ b/slither/visitors/expression/write_var.py @@ -1,14 +1,5 @@ from slither.visitors.expression.expression import ExpressionVisitor -from slither.core.expressions.assignment_operation import AssignmentOperation - -from slither.core.variables.variable import Variable - -from slither.core.expressions.member_access import MemberAccess - -from slither.core.expressions.index_access import IndexAccess - - key = "WriteVar" diff --git a/slither/visitors/slithir/expression_to_slithir.py b/slither/visitors/slithir/expression_to_slithir.py index 066a9fab0..fb6a1d350 100644 --- a/slither/visitors/slithir/expression_to_slithir.py +++ b/slither/visitors/slithir/expression_to_slithir.py @@ -1,6 +1,10 @@ import logging -from slither.core.declarations import Function, SolidityVariable, SolidityVariableComposed +from slither.core.declarations import ( + Function, + SolidityVariable, + SolidityVariableComposed, +) from slither.core.expressions import ( AssignmentOperationType, UnaryOperationType, @@ -8,7 +12,10 @@ from slither.core.expressions import ( ) from slither.core.solidity_types import ArrayType, ElementaryType from slither.core.solidity_types.type import Type -from slither.core.variables.local_variable_init_from_tuple import LocalVariableInitFromTuple +from slither.core.variables.local_variable_init_from_tuple import ( + LocalVariableInitFromTuple, +) +from slither.slithir.exceptions import SlithIRError from slither.slithir.operations import ( Assignment, Binary, @@ -18,8 +25,6 @@ from slither.slithir.operations import ( InitArray, InternalCall, Member, - NewArray, - NewContract, TypeConversion, Unary, Unpack, @@ -30,11 +35,14 @@ from slither.slithir.tmp_operations.tmp_call import TmpCall from slither.slithir.tmp_operations.tmp_new_array import TmpNewArray from slither.slithir.tmp_operations.tmp_new_contract import TmpNewContract from slither.slithir.tmp_operations.tmp_new_elementary_type import TmpNewElementaryType -from slither.slithir.variables import Constant, ReferenceVariable, TemporaryVariable, TupleVariable +from slither.slithir.variables import ( + Constant, + ReferenceVariable, + TemporaryVariable, + TupleVariable, +) from slither.visitors.expression.expression import ExpressionVisitor -from slither.slithir.exceptions import SlithIRError - logger = logging.getLogger("VISTIOR:ExpressionToSlithIR") key = "expressionToSlithIR" @@ -85,33 +93,33 @@ _signed_to_unsigned = { def convert_assignment(left, right, t, return_type): if t == AssignmentOperationType.ASSIGN: return Assignment(left, right, return_type) - elif t == AssignmentOperationType.ASSIGN_OR: + if t == AssignmentOperationType.ASSIGN_OR: return Binary(left, left, right, BinaryType.OR) - elif t == AssignmentOperationType.ASSIGN_CARET: + if t == AssignmentOperationType.ASSIGN_CARET: return Binary(left, left, right, BinaryType.CARET) - elif t == AssignmentOperationType.ASSIGN_AND: + if t == AssignmentOperationType.ASSIGN_AND: return Binary(left, left, right, BinaryType.AND) - elif t == AssignmentOperationType.ASSIGN_LEFT_SHIFT: + if t == AssignmentOperationType.ASSIGN_LEFT_SHIFT: return Binary(left, left, right, BinaryType.LEFT_SHIFT) - elif t == AssignmentOperationType.ASSIGN_RIGHT_SHIFT: + if t == AssignmentOperationType.ASSIGN_RIGHT_SHIFT: return Binary(left, left, right, BinaryType.RIGHT_SHIFT) - elif t == AssignmentOperationType.ASSIGN_ADDITION: + if t == AssignmentOperationType.ASSIGN_ADDITION: return Binary(left, left, right, BinaryType.ADDITION) - elif t == AssignmentOperationType.ASSIGN_SUBTRACTION: + if t == AssignmentOperationType.ASSIGN_SUBTRACTION: return Binary(left, left, right, BinaryType.SUBTRACTION) - elif t == AssignmentOperationType.ASSIGN_MULTIPLICATION: + if t == AssignmentOperationType.ASSIGN_MULTIPLICATION: return Binary(left, left, right, BinaryType.MULTIPLICATION) - elif t == AssignmentOperationType.ASSIGN_DIVISION: + if t == AssignmentOperationType.ASSIGN_DIVISION: return Binary(left, left, right, BinaryType.DIVISION) - elif t == AssignmentOperationType.ASSIGN_MODULO: + if t == AssignmentOperationType.ASSIGN_MODULO: return Binary(left, left, right, BinaryType.MODULO) raise SlithIRError("Missing type during assignment conversion") class ExpressionToSlithIR(ExpressionVisitor): - def __init__(self, expression, node): - from slither.core.cfg.node import NodeType + def __init__(self, expression, node): # pylint: disable=super-init-not-called + from slither.core.cfg.node import NodeType # pylint: disable=import-outside-toplevel self._expression = expression self._node = node @@ -133,7 +141,7 @@ class ExpressionToSlithIR(ExpressionVisitor): if isinstance(left, list): # tuple expression: if isinstance(right, list): # unbox assigment assert len(left) == len(right) - for idx in range(len(left)): + for idx, _ in enumerate(left): if not left[idx] is None: operation = convert_assignment( left[idx], @@ -146,7 +154,7 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, None) else: assert isinstance(right, TupleVariable) - for idx in range(len(left)): + for idx, _ in enumerate(left): if not left[idx] is None: index = idx # The following test is probably always true? @@ -209,7 +217,9 @@ class ExpressionToSlithIR(ExpressionVisitor): new_right = right new_final = TemporaryVariable(self._node) - operation = Binary(new_final, new_left, new_right, _signed_to_unsigned[expression.type]) + operation = Binary( + new_final, new_left, new_right, _signed_to_unsigned[expression.type] + ) operation.set_expression(expression) self._result.append(operation) @@ -223,7 +233,7 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, val) - def _post_call_expression(self, expression): + def _post_call_expression(self, expression): # pylint: disable=too-many-branches,too-many-statements called = get(expression.called) args = [get(a) for a in expression.arguments if a] for arg in args: @@ -234,7 +244,10 @@ class ExpressionToSlithIR(ExpressionVisitor): # internal call # If tuple - if expression.type_call.startswith("tuple(") and expression.type_call != "tuple()": + if ( + expression.type_call.startswith("tuple(") + and expression.type_call != "tuple()" + ): val = TupleVariable(self._node) else: val = TemporaryVariable(self._node) @@ -261,7 +274,9 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, val) elif called.name == "selfbalance()": val = TemporaryVariable(self._node) - var = TypeConversion(val, SolidityVariable("this"), ElementaryType("address")) + var = TypeConversion( + val, SolidityVariable("this"), ElementaryType("address") + ) self._result.append(var) val1 = ReferenceVariable(self._node) @@ -270,7 +285,9 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, val1) elif called.name == "address()": val = TemporaryVariable(self._node) - var = TypeConversion(val, SolidityVariable("this"), ElementaryType("address")) + var = TypeConversion( + val, SolidityVariable("this"), ElementaryType("address") + ) self._result.append(var) set_val(expression, val) elif called.name == "callvalue()": @@ -280,7 +297,10 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, val) else: # If tuple - if expression.type_call.startswith("tuple(") and expression.type_call != "tuple()": + if ( + expression.type_call.startswith("tuple(") + and expression.type_call != "tuple()" + ): val = TupleVariable(self._node) else: val = TemporaryVariable(self._node) @@ -302,7 +322,9 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, val) def _post_conditional_expression(self, expression): - raise Exception("Ternary operator are not convertible to SlithIR {}".format(expression)) + raise Exception( + "Ternary operator are not convertible to SlithIR {}".format(expression) + ) def _post_elementary_type_name_expression(self, expression): set_val(expression, expression.type) @@ -393,7 +415,7 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(operation) set_val(expression, val) - def _post_unary_operation(self, expression): + def _post_unary_operation(self, expression): # pylint: disable=too-many-branches,too-many-statements value = get(expression.expression) if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]: lvalue = TemporaryVariable(self._node) @@ -407,12 +429,16 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(operation) set_val(expression, value) elif expression.type in [UnaryOperationType.PLUSPLUS_PRE]: - operation = Binary(value, value, Constant("1", value.type), BinaryType.ADDITION) + operation = Binary( + value, value, Constant("1", value.type), BinaryType.ADDITION + ) operation.set_expression(expression) self._result.append(operation) set_val(expression, value) elif expression.type in [UnaryOperationType.MINUSMINUS_PRE]: - operation = Binary(value, value, Constant("1", value.type), BinaryType.SUBTRACTION) + operation = Binary( + value, value, Constant("1", value.type), BinaryType.SUBTRACTION + ) operation.set_expression(expression) self._result.append(operation) set_val(expression, value) @@ -421,7 +447,9 @@ class ExpressionToSlithIR(ExpressionVisitor): operation = Assignment(lvalue, value, value.type) operation.set_expression(expression) self._result.append(operation) - operation = Binary(value, value, Constant("1", value.type), BinaryType.ADDITION) + operation = Binary( + value, value, Constant("1", value.type), BinaryType.ADDITION + ) operation.set_expression(expression) self._result.append(operation) set_val(expression, lvalue) @@ -430,7 +458,9 @@ class ExpressionToSlithIR(ExpressionVisitor): operation = Assignment(lvalue, value, value.type) operation.set_expression(expression) self._result.append(operation) - operation = Binary(value, value, Constant("1", value.type), BinaryType.SUBTRACTION) + operation = Binary( + value, value, Constant("1", value.type), BinaryType.SUBTRACTION + ) operation.set_expression(expression) self._result.append(operation) set_val(expression, lvalue) @@ -438,9 +468,13 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, value) elif expression.type in [UnaryOperationType.MINUS_PRE]: lvalue = TemporaryVariable(self._node) - operation = Binary(lvalue, Constant("0", value.type), value, BinaryType.SUBTRACTION) + operation = Binary( + lvalue, Constant("0", value.type), value, BinaryType.SUBTRACTION + ) operation.set_expression(expression) self._result.append(operation) set_val(expression, lvalue) else: - raise SlithIRError("Unary operation to IR not supported {}".format(expression)) + raise SlithIRError( + "Unary operation to IR not supported {}".format(expression) + )