diff --git a/slither/__main__.py b/slither/__main__.py index 38b5c25ef..7dbec0077 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -19,7 +19,7 @@ from slither.detectors.abstract_detector import (AbstractDetector, from slither.printers import all_printers from slither.printers.abstract_printer import AbstractPrinter from slither.slither import Slither -from slither.utils.output_redirect import StandardOutputRedirect +from slither.utils.output_capture import StandardOutputCapture from slither.utils.colors import red, yellow, set_colorization_enabled from slither.utils.command_line import (output_detectors, output_results_to_markdown, output_detectors_json, output_printers, @@ -109,8 +109,8 @@ def output_json(filename, error, results): "results": results } - # Determine if our filename is referring to stdout - if filename == "-": + # Determine if we should output to stdout + if filename is None: # Write json to console print(json.dumps(json_result)) else: @@ -511,12 +511,16 @@ def main_impl(all_detector_classes, all_printer_classes): # Set colorization option set_colorization_enabled(not args.disable_color) - # If we are outputting json to stdout, we'll want to define some variables and redirect stdout - output_error = None + # Define some variables for potential JSON output json_results = {} + output_error = None outputting_json = args.json is not None + outputting_json_stdout = args.json == '-' + + # If we are outputting JSON, capture all standard output. If we are outputting to stdout, we block typical stdout + # output. if outputting_json: - StandardOutputRedirect.enable() + StandardOutputCapture.enable(outputting_json_stdout) printer_classes = choose_printers(args, all_printer_classes) detector_classes = choose_detectors(args, all_detector_classes) @@ -602,10 +606,10 @@ def main_impl(all_detector_classes, all_printer_classes): # If we are outputting JSON, capture the redirected output and disable the redirect to output the final JSON. if outputting_json: - json_results['stdout'] = StandardOutputRedirect.get_stdout_output() - json_results['stderr'] = StandardOutputRedirect.get_stderr_output() - StandardOutputRedirect.disable() - output_json(args.json, output_error, json_results) + json_results['stdout'] = StandardOutputCapture.get_stdout_output() + json_results['stderr'] = StandardOutputCapture.get_stderr_output() + StandardOutputCapture.disable() + output_json(None if outputting_json_stdout else args.json, output_error, json_results) # Exit with the appropriate status code if output_error: diff --git a/slither/utils/output_capture.py b/slither/utils/output_capture.py new file mode 100644 index 000000000..d557ede1e --- /dev/null +++ b/slither/utils/output_capture.py @@ -0,0 +1,90 @@ +import io +import logging +import sys + + +class CapturingStringIO(io.StringIO): + """ + I/O implementation which captures output, and optionally mirrors it to the original I/O stream it replaces. + """ + def __init__(self, original_io=None): + super(CapturingStringIO, self).__init__() + self.original_io = original_io + + def write(self, s): + super().write(s) + if self.original_io: + self.original_io.write(s) + + +class StandardOutputCapture: + """ + Redirects and captures standard output/errors. + """ + original_stdout = None + original_stderr = None + original_logger_handlers = None + + @staticmethod + def enable(block_original=True): + """ + Redirects stdout and stderr to a capturable StringIO. + :param block_original: If True, blocks all output to the original stream. If False, duplicates output. + :return: None + """ + # Redirect stdout + if StandardOutputCapture.original_stdout is None: + StandardOutputCapture.original_stdout = sys.stdout + sys.stdout = CapturingStringIO(None if block_original else StandardOutputCapture.original_stdout) + + # Redirect stderr + if StandardOutputCapture.original_stderr is None: + StandardOutputCapture.original_stderr = sys.stderr + sys.stderr = CapturingStringIO(None if block_original else StandardOutputCapture.original_stderr) + + # Backup and swap root logger handlers + root_logger = logging.getLogger() + StandardOutputCapture.original_logger_handlers = root_logger.handlers + root_logger.handlers = [logging.StreamHandler(sys.stderr)] + + @staticmethod + def disable(): + """ + Disables redirection of stdout/stderr, if previously enabled. + :return: None + """ + # If we have a stdout backup, restore it. + if StandardOutputCapture.original_stdout is not None: + sys.stdout.close() + sys.stdout = StandardOutputCapture.original_stdout + StandardOutputCapture.original_stdout = None + + # If we have an stderr backup, restore it. + if StandardOutputCapture.original_stderr is not None: + sys.stderr.close() + sys.stderr = StandardOutputCapture.original_stderr + StandardOutputCapture.original_stderr = None + + # Restore our logging handlers + if StandardOutputCapture.original_logger_handlers is not None: + root_logger = logging.getLogger() + root_logger.handlers = StandardOutputCapture.original_logger_handlers + StandardOutputCapture.original_logger_handlers = None + + @staticmethod + def get_stdout_output(): + """ + Obtains the output from the currently set stdout + :return: Returns stdout output as a string + """ + sys.stdout.seek(0) + return sys.stdout.read() + + @staticmethod + def get_stderr_output(): + """ + Obtains the output from the currently set stderr + :return: Returns stderr output as a string + """ + sys.stderr.seek(0) + return sys.stderr.read() diff --git a/slither/utils/output_redirect.py b/slither/utils/output_redirect.py deleted file mode 100644 index 9811b21b2..000000000 --- a/slither/utils/output_redirect.py +++ /dev/null @@ -1,67 +0,0 @@ -import io -import logging -import sys - - -class StandardOutputRedirect: - """ - Redirects and captures standard output/errors. - """ - original_stdout = None - original_stderr = None - - @staticmethod - def enable(): - """ - Redirects stdout and/or stderr to a captureable StringIO. - :param redirect_stdout: True if redirection is desired for stdout. - :param redirect_stderr: True if redirection is desired for stderr. - :return: None - """ - # Redirect stdout - if StandardOutputRedirect.original_stdout is None: - StandardOutputRedirect.original_stdout = sys.stdout - sys.stdout = io.StringIO() - - # Redirect stderr - if StandardOutputRedirect.original_stderr is None: - StandardOutputRedirect.original_stderr = sys.stderr - sys.stderr = io.StringIO() - root_logger = logging.getLogger() - root_logger.handlers = [logging.StreamHandler(sys.stderr)] - - @staticmethod - def disable(): - """ - Disables redirection of stdout/stderr, if previously enabled. - :return: None - """ - # If we have a stdout backup, restore it. - if StandardOutputRedirect.original_stdout is not None: - sys.stdout.close() - sys.stdout = StandardOutputRedirect.original_stdout - StandardOutputRedirect.original_stdout = None - - # If we have an stderr backup, restore it. - if StandardOutputRedirect.original_stderr is not None: - sys.stderr.close() - sys.stderr = StandardOutputRedirect.original_stderr - StandardOutputRedirect.original_stderr = None - - @staticmethod - def get_stdout_output(): - """ - Obtains the output from stdout - :return: Returns stdout output as a string - """ - sys.stdout.seek(0) - return sys.stdout.read() - - @staticmethod - def get_stderr_output(): - """ - Obtains the output from stdout - :return: Returns stdout output as a string - """ - sys.stderr.seek(0) - return sys.stderr.read() \ No newline at end of file