diff --git a/slither/tools/upgradeability/__main__.py b/slither/tools/upgradeability/__main__.py index f6991923c..eb56bb0d8 100644 --- a/slither/tools/upgradeability/__main__.py +++ b/slither/tools/upgradeability/__main__.py @@ -2,6 +2,7 @@ import logging import argparse import sys import json +import os from slither import Slither from crytic_compile import cryticparser @@ -31,6 +32,11 @@ def parse_args(): parser.add_argument('--new-version', help='New implementation filename') parser.add_argument('--new-contract-name', help='New contract name (if changed)') + parser.add_argument('--json', + help='Export the results as a JSON file ("--json -" to export to stdout)', + action='store', + default=False) + cryticparser.init(parser) if len(sys.argv) == 1: @@ -39,6 +45,41 @@ def parse_args(): return parser.parse_args() +################################################################################### +################################################################################### +# region Output +################################################################################### +################################################################################### + + +def output_json(filename, error, results): + # Create our encapsulated JSON result. + json_result = { + "success": error is None, + "error": error, + "results": results + } + + # Determine if we should output to stdout + if filename is None: + # Write json to console + print(json.dumps(json_result)) + else: + # Write json to file + if os.path.isfile(filename): + logger.info(yellow(f'{filename} exists already, the overwrite is prevented')) + else: + with open(filename, 'w', encoding='utf8') as f: + json.dump(json_result, f, indent=2) + +# endregion +################################################################################### +################################################################################### +# region Main +################################################################################### +################################################################################### + + def main(): args = parse_args() @@ -50,20 +91,28 @@ def main(): v1 = Slither(v1_filename, **vars(args)) v1_name = args.ContractName - results = OrderedDict() - - results['check_initialization'] = check_initialization(v1) + # Define some variables for potential JSON output + json_results = {} + output_error = None + outputting_json = args.json is not None + outputting_json_stdout = args.json == '-' + + json_results['check-initialization'] = check_initialization(v1) if not args.new_version: - results['compare_function_ids'] = compare_function_ids(v1, v1_name, proxy, proxy_name) - results['compare_variables_order_proxy'] = compare_variables_order_proxy(v1, v1_name, proxy, proxy_name) + json_results['compare-function-ids'] = compare_function_ids(v1, v1_name, proxy, proxy_name) + json_results['compare-variables-order-proxy'] = compare_variables_order_proxy(v1, v1_name, proxy, proxy_name) else: v2 = Slither(args.new_version, **vars(args)) v2_name = v1_name if not args.new_contract_name else args.new_contract_name - results['check_initialization_v2'] = check_initialization(v2) - results['compare_function_ids'] = compare_function_ids(v2, v2_name, proxy, proxy_name) - results['compare_variables_order_proxy'] = compare_variables_order_proxy(v2, v2_name, proxy, proxy_name) - results['compare_variables_order_implementation'] = compare_variables_order_implementation(v1, v1_name, v2, v2_name) + json_results['check-initialization-v2'] = check_initialization(v2) + json_results['compare-function-ids'] = compare_function_ids(v2, v2_name, proxy, proxy_name) + json_results['compare-variables-order-proxy'] = compare_variables_order_proxy(v2, v2_name, proxy, proxy_name) + json_results['compare-variables-order-implementation'] = compare_variables_order_implementation(v1, v1_name, v2, v2_name) + + # If we are outputting JSON, capture the redirected output and disable the redirect to output the final JSON. + if outputting_json: + output_json(None if outputting_json_stdout else args.json, output_error, json_results) - with open('results.json', 'w') as fp: - json.dump(results, fp) + +# endregion