diff --git a/mythril/analysis/security.py b/mythril/analysis/security.py index 5c5e520a..4d95497a 100644 --- a/mythril/analysis/security.py +++ b/mythril/analysis/security.py @@ -65,38 +65,36 @@ def get_detection_modules(entrypoint, include_modules=(), custom_modules_directo _modules = [] - for loader, module_name, _ in pkgutil.walk_packages(modules.__path__): + custom_modules_directory = os.path.abspath(custom_modules_directory) + + if custom_modules_directory and custom_modules_directory not in sys.path: + sys.path.append(custom_modules_directory) + + custom_packages = ( + list(pkgutil.walk_packages([custom_modules_directory])) + if custom_modules_directory + else [] + ) + packages = list(pkgutil.walk_packages(modules.__path__)) + custom_packages + + for loader, module_name, _ in packages: if include_modules and module_name not in include_modules: continue if module_name != "base": - module = importlib.import_module("mythril.analysis.modules." + module_name) + try: + module = importlib.import_module( + "mythril.analysis.modules." + module_name + ) + except ModuleNotFoundError: + try: + module = importlib.import_module(module_name) + except ModuleNotFoundError: + raise ModuleNotFoundError module.log.setLevel(log.level) if module.detector.entrypoint == entrypoint: _modules.append(module) - if custom_modules_directory: - custom_modules = [os.path.abspath(custom_modules_directory)] - sys.path.append(custom_modules_directory) - for loader, module_name, _ in pkgutil.walk_packages(custom_modules): - if include_modules and module_name not in include_modules: - continue - - if module_name != "base": - module = importlib.import_module(module_name, custom_modules[0]) - module.log.setLevel(log.level) - if module.detector.entrypoint == entrypoint: - _modules.append(module) - - """ - for loader, module_name, _ in pkgutil.walk_packages([custom_modules_path]): - - custom_modules_path = os.path.abspath("custom/") - sys.path.append(custom_modules_path); - module = importlib.import_module( - module_name, custom_modules_path - ) - """ log.info("Found %s detection modules", len(_modules)) return _modules