diff --git a/mythril/analysis/security.py b/mythril/analysis/security.py index 71f0a516..ed741ecb 100644 --- a/mythril/analysis/security.py +++ b/mythril/analysis/security.py @@ -6,22 +6,28 @@ from mythril.analysis import modules import pkgutil import importlib.util import logging +import os +import sys log = logging.getLogger(__name__) OPCODE_LIST = [c[0] for _, c in opcodes.items()] -def reset_callback_modules(): +def reset_callback_modules(module_names=(), custom_modules_directory=""): """Clean the issue records of every callback-based module.""" - modules = get_detection_modules("callback") + modules = get_detection_modules("callback", module_names, custom_modules_directory) for module in modules: module.detector.reset_module() -def get_detection_module_hooks(modules, hook_type="pre"): +def get_detection_module_hooks(modules, hook_type="pre", custom_modules_directory=""): hook_dict = defaultdict(list) - _modules = get_detection_modules(entrypoint="callback", include_modules=modules) + _modules = get_detection_modules( + entrypoint="callback", + include_modules=modules, + custom_modules_directory=custom_modules_directory, + ) for module in _modules: hooks = ( module.detector.pre_hooks @@ -45,14 +51,13 @@ def get_detection_module_hooks(modules, hook_type="pre"): return dict(hook_dict) -def get_detection_modules(entrypoint, include_modules=()): +def get_detection_modules(entrypoint, include_modules=(), custom_modules_directory=""): """ :param entrypoint: :param include_modules: :return: """ - module = importlib.import_module("mythril.analysis.modules.base") module.log.setLevel(log.level) @@ -60,27 +65,35 @@ def get_detection_modules(entrypoint, include_modules=()): _modules = [] - if not include_modules: - for loader, module_name, _ in pkgutil.walk_packages(modules.__path__): + for loader, module_name, _ in pkgutil.walk_packages(modules.__path__): + if include_modules and module_name not in include_modules: + continue + + if module_name != "base": + module = importlib.import_module("mythril.analysis.modules." + module_name) + module.log.setLevel(log.level) + if module.detector.entrypoint == entrypoint: + _modules.append(module) + if custom_modules_directory: + custom_modules_path = os.path.abspath(custom_modules_directory) + if custom_modules_path not in sys.path: + sys.path.append(custom_modules_path) + + for loader, module_name, _ in pkgutil.walk_packages([custom_modules_path]): + if include_modules and module_name not in include_modules: + continue + if module_name != "base": - module = importlib.import_module( - "mythril.analysis.modules." + module_name - ) + module = importlib.import_module(module_name, custom_modules_path) module.log.setLevel(log.level) if module.detector.entrypoint == entrypoint: _modules.append(module) - else: - for module_name in include_modules: - module = importlib.import_module("mythril.analysis.modules." + module_name) - if module.__name__ != "base" and module.detector.entrypoint == entrypoint: - module.log.setLevel(log.level) - _modules.append(module) log.info("Found %s detection modules", len(_modules)) return _modules -def fire_lasers(statespace, module_names=()): +def fire_lasers(statespace, module_names=(), custom_modules_directory=""): """ :param statespace: @@ -91,22 +104,28 @@ def fire_lasers(statespace, module_names=()): issues = [] for module in get_detection_modules( - entrypoint="post", include_modules=module_names + entrypoint="post", + include_modules=module_names, + custom_modules_directory=custom_modules_directory, ): log.info("Executing " + module.detector.name) issues += module.detector.execute(statespace) - issues += retrieve_callback_issues(module_names) + issues += retrieve_callback_issues(module_names, custom_modules_directory) return issues -def retrieve_callback_issues(module_names=()): +def retrieve_callback_issues(module_names=(), custom_modules_directory=""): issues = [] for module in get_detection_modules( - entrypoint="callback", include_modules=module_names + entrypoint="callback", + include_modules=module_names, + custom_modules_directory=custom_modules_directory, ): log.debug("Retrieving results for " + module.detector.name) issues += module.detector.issues - reset_callback_modules() + reset_callback_modules( + module_names=module_names, custom_modules_directory=custom_modules_directory + ) return issues diff --git a/mythril/analysis/symbolic.py b/mythril/analysis/symbolic.py index 4164758f..02b5c8ca 100644 --- a/mythril/analysis/symbolic.py +++ b/mythril/analysis/symbolic.py @@ -56,6 +56,7 @@ class SymExecWrapper: disable_dependency_pruning=False, run_analysis_modules=True, enable_coverage_strategy=False, + custom_modules_directory="", ): """ @@ -93,7 +94,8 @@ class SymExecWrapper: ) requires_statespace = ( - compulsory_statespace or len(get_detection_modules("post", modules)) > 0 + compulsory_statespace + or len(get_detection_modules("post", modules, custom_modules_directory)) > 0 ) if not contract.creation_code: self.accounts = {hex(ATTACKER_ADDRESS): attacker_account} @@ -135,11 +137,19 @@ class SymExecWrapper: if run_analysis_modules: self.laser.register_hooks( hook_type="pre", - hook_dict=get_detection_module_hooks(modules, hook_type="pre"), + hook_dict=get_detection_module_hooks( + modules, + hook_type="pre", + custom_modules_directory=custom_modules_directory, + ), ) self.laser.register_hooks( hook_type="post", - hook_dict=get_detection_module_hooks(modules, hook_type="post"), + hook_dict=get_detection_module_hooks( + modules, + hook_type="post", + custom_modules_directory=custom_modules_directory, + ), ) if isinstance(contract, SolidityContract): diff --git a/mythril/interfaces/cli.py b/mythril/interfaces/cli.py index 839d1fd5..8962fa16 100644 --- a/mythril/interfaces/cli.py +++ b/mythril/interfaces/cli.py @@ -396,6 +396,11 @@ def create_analyzer_parser(analyzer_parser: ArgumentParser): action="store_true", help="enable coverage based search strategy", ) + options.add_argument( + "--custom-modules-directory", + help="designates a separate directory to search for custom analysis modules", + metavar="CUSTOM_MODULES_DIRECTORY", + ) def validate_args(args: Namespace): @@ -567,6 +572,9 @@ def execute_command( solver_timeout=args.solver_timeout, requires_dynld=not args.no_onchain_storage_access, enable_coverage_strategy=args.enable_coverage_strategy, + custom_modules_directory=args.custom_modules_directory + if args.custom_modules_directory + else "", ) if not disassembler.contracts: diff --git a/mythril/mythril/mythril_analyzer.py b/mythril/mythril/mythril_analyzer.py index deb86281..d5e87369 100644 --- a/mythril/mythril/mythril_analyzer.py +++ b/mythril/mythril/mythril_analyzer.py @@ -42,6 +42,7 @@ class MythrilAnalyzer: disable_dependency_pruning: bool = False, solver_timeout: Optional[int] = None, enable_coverage_strategy: bool = False, + custom_modules_directory: str = "", ): """ @@ -63,6 +64,7 @@ class MythrilAnalyzer: self.enable_iprof = enable_iprof self.disable_dependency_pruning = disable_dependency_pruning self.enable_coverage_strategy = enable_coverage_strategy + self.custom_modules_directory = custom_modules_directory analysis_args.set_loop_bound(loop_bound) analysis_args.set_solver_timeout(solver_timeout) @@ -89,6 +91,7 @@ class MythrilAnalyzer: disable_dependency_pruning=self.disable_dependency_pruning, run_analysis_modules=False, enable_coverage_strategy=self.enable_coverage_strategy, + custom_modules_directory=self.custom_modules_directory, ) return get_serializable_statespace(sym) @@ -125,6 +128,7 @@ class MythrilAnalyzer: disable_dependency_pruning=self.disable_dependency_pruning, run_analysis_modules=False, enable_coverage_strategy=self.enable_coverage_strategy, + custom_modules_directory=self.custom_modules_directory, ) return generate_graph(sym, physics=enable_physics, phrackify=phrackify) @@ -163,18 +167,22 @@ class MythrilAnalyzer: enable_iprof=self.enable_iprof, disable_dependency_pruning=self.disable_dependency_pruning, enable_coverage_strategy=self.enable_coverage_strategy, + custom_modules_directory=self.custom_modules_directory, ) - - issues = fire_lasers(sym, modules) + issues = fire_lasers(sym, modules, self.custom_modules_directory) except KeyboardInterrupt: log.critical("Keyboard Interrupt") - issues = retrieve_callback_issues(modules) + issues = retrieve_callback_issues( + modules, self.custom_modules_directory + ) except Exception: log.critical( "Exception occurred, aborting analysis. Please report this issue to the Mythril GitHub page.\n" + traceback.format_exc() ) - issues = retrieve_callback_issues(modules) + issues = retrieve_callback_issues( + modules, self.custom_modules_directory + ) exceptions.append(traceback.format_exc()) for issue in issues: issue.add_code_info(contract)