diff --git a/mythril/analysis/security.py b/mythril/analysis/security.py index bae08ff8..7a98e128 100644 --- a/mythril/analysis/security.py +++ b/mythril/analysis/security.py @@ -1,12 +1,35 @@ -from mythril.analysis.report import Report +from collections import defaultdict +from ethereum.opcodes import opcodes from mythril.analysis import modules import pkgutil import logging -def get_detection_modules(entrypoint, except_modules=None): - except_modules = [] if except_modules is None else except_modules - except_modules.append("base") # always exclude base class file +OPCODE_LIST = [c[0] for _, c in opcodes.items()] + + +def get_detection_module_hooks(): + hook_dict = defaultdict(list) + _modules = get_detection_modules(entrypoint="callback") + for module in _modules: + for op_code in map(lambda x: x.upper(), module.detector.hooks): + if op_code in OPCODE_LIST: + hook_dict[op_code].append(module.detector.execute) + elif op_code.endswith("*"): + to_register = filter(lambda x: x.startswith(op_code[:-1]), OPCODE_LIST) + for actual_hook in to_register: + hook_dict[actual_hook].append(module.detector.execute) + else: + logging.error( + "Encountered invalid hook opcode %s in module %s", + op_code, + module.detector.name, + ) + return dict(hook_dict) + + +def get_detection_modules(entrypoint, except_modules=()): + except_modules = list(except_modules) + ["base"] _modules = [] for loader, name, _ in pkgutil.walk_packages(modules.__path__): @@ -21,7 +44,7 @@ def get_detection_modules(entrypoint, except_modules=None): return _modules -def fire_lasers(statespace, module_names=None): +def fire_lasers(statespace, module_names=()): logging.info("Starting analysis") issues = [] diff --git a/mythril/analysis/symbolic.py b/mythril/analysis/symbolic.py index 4b01afc3..947443f4 100644 --- a/mythril/analysis/symbolic.py +++ b/mythril/analysis/symbolic.py @@ -1,3 +1,4 @@ +from mythril.analysis.security import get_detection_module_hooks from mythril.laser.ethereum import svm from mythril.laser.ethereum.state.account import Account from mythril.ether.soliditycontract import SolidityContract, ETHContract @@ -59,6 +60,7 @@ class SymExecWrapper: create_timeout=create_timeout, max_transaction_count=max_transaction_count, ) + self.laser.register_hooks(hook_type="post", hook_dict=get_detection_module_hooks()) if isinstance(contract, SolidityContract): self.laser.sym_exec( diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index e96ce4f7..5e5e2cf8 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -35,9 +35,6 @@ Main symbolic execution engine. """ -OPCODE_LIST = [c[0] for _, c in opcodes.items()] - - class LaserEVM: """ Laser EVM class @@ -79,29 +76,20 @@ class LaserEVM: self.pre_hooks = defaultdict(list) self.post_hooks = defaultdict(list) - self.register_detection_modules() - logging.info( "LASER EVM initialized with dynamic loader: " + str(dynamic_loader) ) - def register_detection_modules(self): - modules = get_detection_modules(entrypoint="callback") - for module in modules: - for hook in module.detector.hooks: - hook = hook.upper() - if hook in OPCODE_LIST: - self.post_hooks[hook].append(module.detector.execute) - elif hook.endswith("*"): - to_register = filter(lambda x: x.startswith(hook[:-1]), OPCODE_LIST) - for actual_hook in to_register: - self.post_hooks[actual_hook].append(module.detector.execute) - else: - logging.error( - "Encountered invalid hook opcode %s in module %s", - hook, - module.detector.name, - ) + def register_hooks(self, hook_type: str, hook_dict: Dict[str, List[Callable]]): + if hook_type == "pre": + entrypoint = self.pre_hooks + elif hook_type == "post": + entrypoint = self.post_hooks + else: + raise ValueError("Invalid hook type %s. Must be one of {pre, post}", hook_type) + + for op_code, funcs in hook_dict.items(): + entrypoint[op_code].extend(funcs) @property def accounts(self) -> Dict[str, Account]: diff --git a/tests/native_test.py b/tests/native_test.py index 74f7344b..4ceb0467 100644 --- a/tests/native_test.py +++ b/tests/native_test.py @@ -1,6 +1,4 @@ -import json from mythril.ether.soliditycontract import SolidityContract - from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.machine_state import MachineState from mythril.laser.ethereum.state.global_state import GlobalState diff --git a/tests/svm_test.py b/tests/svm_test.py index d0509da2..19abfa05 100644 --- a/tests/svm_test.py +++ b/tests/svm_test.py @@ -1,4 +1,5 @@ import json +from mythril.analysis.security import get_detection_module_hooks from mythril.analysis.symbolic import SymExecWrapper from mythril.analysis.callgraph import generate_graph from mythril.ether.ethcontract import ETHContract @@ -85,6 +86,7 @@ class SVMTestCase(BaseTestCase): accounts = {account.address: account} laser = svm.LaserEVM(accounts, max_depth=22, max_transaction_count=1) + laser.register_hooks(hook_type="post", hook_dict=get_detection_module_hooks()) laser.sym_exec(account.address) laser_info = _all_info(laser)